├── LICENSE ├── README.md ├── datasets ├── msra_center │ ├── center_test_0_refined.txt │ ├── center_test_1_refined.txt │ ├── center_test_2_refined.txt │ ├── center_test_3_refined.txt │ ├── center_test_4_refined.txt │ ├── center_test_5_refined.txt │ ├── center_test_6_refined.txt │ ├── center_test_7_refined.txt │ ├── center_test_8_refined.txt │ ├── center_train_0_refined.txt │ ├── center_train_1_refined.txt │ ├── center_train_2_refined.txt │ ├── center_train_3_refined.txt │ ├── center_train_4_refined.txt │ ├── center_train_5_refined.txt │ ├── center_train_6_refined.txt │ ├── center_train_7_refined.txt │ └── center_train_8_refined.txt └── msra_hand.py ├── experiments └── msra-subject3 │ ├── gen_gt.py │ ├── main.py │ ├── show_acc.py │ ├── test_res.txt │ └── test_s3_gt.txt ├── figs ├── Challenge_result.png ├── Paper_result_hand_graph.png ├── Paper_result_hand_msra.png ├── Paper_result_hand_table.png ├── Paper_result_human_table.png ├── V2V-PoseNet.png ├── integral_pose_msra_s3_joint_acc.png ├── integral_pose_msra_s3_joint_mean_error.png ├── mean_error_compare.png ├── msra_s3_joint_acc.png ├── msra_s3_joint_mean_error.png └── result │ ├── Paper_result_HANDS2017.png │ ├── Paper_result_ICVL.png │ ├── Paper_result_ITOP_front.png │ ├── Paper_result_ITOP_top.png │ ├── Paper_result_MSRA.png │ └── Paper_result_NYU.png ├── integral-pose ├── accuracy.py ├── compare_acc.py ├── gen_gt.py ├── loss.py ├── main.py ├── model.py ├── msra_hand.py ├── plot.py ├── progressbar.py ├── sampler.py ├── show_acc.py ├── solver.py ├── test_res.txt ├── test_s3_gt.txt ├── v2v_model.py ├── v2v_util.py └── v2vposenet-loss_test_res.txt ├── lib ├── accuracy.py ├── progressbar.py ├── sampler.py └── solver.py ├── src ├── v2v_model.py └── v2v_util.py └── vis └── plot.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Gyeongsik Moon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # V2V-PoseNet-pytorch 2 | This is a pytorch implementation of V2V-PoseNet([V2V-PoseNet: Voxel-to-Voxel Prediction Network for Accurate 3D Hand and Human Pose Estimation from a Single Depth Map](https://arxiv.org/abs/1711.07399)), which is largely based on the author's [torch7 implementation](https://github.com/mks0601/V2V-PoseNet_RELEASE). 
3 | 4 | This repository provides 5 | * V2V-PoseNet core modules (model, voxelization, ...) 6 | * An experiment demo on the MSRA hand pose dataset, resulting in ~11mm mean error. 7 | * *Additional [Integral Pose Loss](https://arxiv.org/abs/1711.08229) (or [PoseFix Loss](https://arxiv.org/abs/1812.03595)) implementation*, resulting in ~10mm mean error on the same demo. 8 | 9 | ## Requirements 10 | * pytorch 0.4.1 or pytorch 1.0 11 | * python 3.6 12 | * numpy 13 | 14 | ### **Warning on pytorch 0.4.1 cudnn**: 15 | You may need to **disable cudnn for batchnorm**, or just use cuda only instead. With cudnn enabled for batchnorm and in float precision, the model cannot train well. My simple experiments show that: 16 | 17 | ``` 18 | cudnn+float: does NOT work (the loss decreases much more slowly and results in a higher loss) 19 | cudnn+float+(disable batchnorm's cudnn): works (the loss decreases faster and results in a lower loss) 20 | cudnn+double: works, but the speed is slow 21 | cuda+(float/double): works, but uses much more memory 22 | ``` 23 | 24 | A similar issue is pointed out by https://github.com/Microsoft/human-pose-estimation.pytorch. As suggested there, disable cudnn for batchnorm by patching functional.py (a runtime alternative is sketched after the usage steps below): 25 | 26 | ``` 27 | PYTORCH=/path/to/pytorch 28 | # for pytorch v0.4.0 29 | sed -i "1194s/torch\.backends\.cudnn\.enabled/False/g" ${PYTORCH}/torch/nn/functional.py 30 | # for pytorch v0.4.1 31 | sed -i "1254s/torch\.backends\.cudnn\.enabled/False/g" ${PYTORCH}/torch/nn/functional.py 32 | ``` 33 | 34 | ## MSRA hand dataset demo 35 | ### Usage 36 | - Clone this repo: 37 | ``` 38 | git clone https://github.com/dragonbook/V2V-PoseNet-pytorch.git 39 | cd V2V-PoseNet-pytorch 40 | ``` 41 | 42 | - Download the [MSRA hand dataset](https://jimmysuen.github.io/) and extract it to the directory path/to/msra-hand. 43 | 44 | - Download the [estimated centers](https://cv.snu.ac.kr/research/V2V-PoseNet/MSRA/center/center.tar.gz) of the MSRA hand dataset, which are required by V2V-PoseNet and provided by the [author's implementation](https://github.com/mks0601/V2V-PoseNet_RELEASE). Extract them to the directory path/to/msra-hand-center. 45 | ``` 46 | Note: this repository contains a copy of the MSRA hand centers under ./datasets/msra_center. 47 | ``` 48 | 49 | - Configure data_dir=path/to/msra-hand and center_dir=path/to/msra-hand-center in ./experiments/msra-subject3/main.py. Then run the following command to perform training and testing. It will train the model for a few epochs and evaluate it on the test dataset. The test result will be saved as test_res.txt and the fit result on the training data will be saved as fit_res.txt. 50 | ``` 51 | PYTHONPATH=./ python ./experiments/msra-subject3/main.py 52 | ``` 53 | 54 | - Configure data_dir=path/to/msra-hand and center_dir=path/to/msra-hand-center in ./experiments/msra-subject3/gen_gt.py. Run it to generate the ground truth labels train_s3_gt.txt and test_s3_gt.txt. 55 | 56 | - Configure gt_file=path/to/test_s3_gt.txt and pred_file=path/to/test_res.txt in ./experiments/msra-subject3/show_acc.py. Run it to plot accuracy and error. 57 | 58 | - The following figures show that this simple experiment results in about 11mm mean error. 
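As an alternative to the `sed` patch shown in the cudnn warning above, the cuDNN batchnorm kernel can also be bypassed at runtime from the training script. This is a minimal sketch, not part of this repository (the wrapper name `batch_norm_no_cudnn` is ours), and it assumes the slower non-cuDNN batchnorm kernel is acceptable:

```python
import torch
import torch.nn.functional as F

# Option A: disable cuDNN entirely ("just use cuda instead").
# BatchNorm then trains correctly, at the cost of slower convolutions.
# torch.backends.cudnn.enabled = False

# Option B: keep cuDNN for convolutions, but force batch_norm onto the non-cuDNN kernel.
_orig_batch_norm = F.batch_norm

def batch_norm_no_cudnn(*args, **kwargs):
    # F.batch_norm reads torch.backends.cudnn.enabled at call time, so toggle it temporarily.
    prev = torch.backends.cudnn.enabled
    torch.backends.cudnn.enabled = False
    try:
        return _orig_batch_norm(*args, **kwargs)
    finally:
        torch.backends.cudnn.enabled = prev

F.batch_norm = batch_norm_no_cudnn  # nn.BatchNorm3d looks up F.batch_norm at call time
```

Either option can be placed near the top of ./experiments/msra-subject3/main.py before the model is constructed; note that main.py currently re-enables cudnn on GPU (torch.backends.cudnn.enabled = True), so Option A would also need that line adjusted.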
59 | 60 | ![msra_s3_acc](/figs/msra_s3_joint_acc.png) 61 | 62 | ![msra_s3_mean_error](/figs/msra_s3_joint_mean_error.png) 63 | 64 | 65 | ## Additional [IntegralPose](https://arxiv.org/abs/1711.08229)/[PoseFix](https://arxiv.org/abs/1812.03595) style loss implementation 66 | Replaced V2V-PoseNet's loss with PoseFix's loss(one-hot heatmap loss + L1 coord loss), and it's independently implemented under ./integral-pose directory. Also, configure data_dir and center_dir in ./integral-pose/main.py, and start training. The result shows about 10mm mean error. 67 | 68 | ![integral_loss_s3_acc](/figs/integral_pose_msra_s3_joint_acc.png) 69 | 70 | ![integral_loss_mean_error](/figs/integral_pose_msra_s3_joint_mean_error.png) 71 | 72 | ![compare_mean_error](/figs/mean_error_compare.png) 73 | 74 | # Below is from author's README for reference 75 | # V2V-PoseNet: Voxel-to-Voxel Prediction Network for Accurate 3D Hand and Human Pose Estimation from a Single Depth Map 76 | 77 | # Introduction 78 | 79 | This is our project repository for the paper, **V2V-PoseNet: Voxel-to-Voxel Prediction Network for Accurate 3D Hand and Human Pose Estimation from a Single Depth Map ([CVPR 2018](http://cvpr2018.thecvf.com))**. 80 | 81 | We, **Team SNU CVLAB**, (Gyeongsik Moon, Juyong Chang, and Kyoung Mu Lee of [**Computer Vision Lab, Seoul National University**](https://cv.snu.ac.kr/)) are **winners** of [**HANDS2017 Challenge**](http://icvl.ee.ic.ac.uk/hands17/challenge/) on frame-based 3D hand pose estimation. 82 | 83 | 84 | 85 | Please refer to our paper for details. 86 | 87 | If you find our work useful in your research or publication, please cite our work: 88 | 89 | [1] Moon, Gyeongsik, Ju Yong Chang, and Kyoung Mu Lee. **"V2V-PoseNet: Voxel-to-Voxel Prediction Network for Accurate 3D Hand and Human Pose Estimation from a Single Depth Map."** CVPR 2018. [[arXiv](https://arxiv.org/abs/1711.07399)] 90 | 91 | ``` 92 | @InProceedings{Moon_2018_CVPR_V2V-PoseNet, 93 | author = {Moon, Gyeongsik and Chang, Juyong and Lee, Kyoung Mu}, 94 | title = {V2V-PoseNet: Voxel-to-Voxel Prediction Network for Accurate 3D Hand and Human Pose Estimation from a Single Depth Map}, 95 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 96 | year = {2018} 97 | } 98 | ``` 99 | 100 | In this repository, we provide 101 | * Our model architecture description (V2V-PoseNet) 102 | * HANDS2017 frame-based 3D hand pose estimation Challenge Results 103 | * Comparison with the previous state-of-the-art methods 104 | * Training code 105 | * Datasets we used (ICVL, NYU, MSRA, ITOP) 106 | * Trained models and estimated results 107 | * 3D hand and human pose estimation examples 108 | 109 | 110 | ## Model Architecture 111 | 112 | ![V2V-PoseNet](/figs/V2V-PoseNet.png) 113 | 114 | ## HANDS2017 frame-based 3D hand pose estimation Challenge Results 115 | 116 | ![Challenge_result](/figs/Challenge_result.png) 117 | 118 | 119 | ## Comparison with the previous state-of-the-art methods 120 | 121 | ![Paper_result_hand_graph](/figs/Paper_result_hand_graph.png) 122 | 123 | ![Paper_result_hand_table](/figs/Paper_result_hand_table.png) 124 | 125 | ![Paper_result_human_table](/figs/Paper_result_human_table.png) 126 | 127 | # About our code 128 | ## Dependencies 129 | * [Torch7](http://torch.ch) 130 | * [CUDA](https://developer.nvidia.com/cuda-downloads) 131 | * [cuDNN](https://developer.nvidia.com/cudnn) 132 | 133 | Our code is tested under Ubuntu 14.04 and 16.04 environment with Titan X GPUs (12GB VRAM). 
134 | 135 | ## Code 136 | Clone this repository into any place you want. You may follow the example below. 137 | ```bash 138 | makeReposit = [/the/directory/as/you/wish] 139 | mkdir -p $makeReposit/; cd $makeReposit/ 140 | git clone https://github.com/mks0601/V2V-PoseNet_RELEASE.git 141 | ``` 142 | * The `src` folder contains lua script files for the data loader, trainer, tester and other utilities. 143 | * The `data` folder contains the data converter which converts image files to binary files. 144 | 145 | To train our model, please run the following command in the `src` directory: 146 | 147 | ```bash 148 | th rum_me.lua 149 | ``` 150 | 151 | * There are some optional configurations you can adjust in `config.lua`. 152 | * You have to convert the `.png` images of the ICVL and NYU datasets to `.bin` files by running the code in the `data` folder. 153 | * The directory where you have to put the dataset files and the computed centers of each frame is defined in `src/data/dataset_name/data.lua`. 154 | * Visualization code is finally uploaded! You have to prepare 'result_pixel.txt' for each dataset. Each row of the result file has to contain the pixel coordinates x, y and depth of all joints (i.e., x1 y1 z1 x2 y2 z2 ...). Then run the pixel2world script and run draw_DB.m. 155 | 156 | # Dataset 157 | We trained and tested our model on four 3D hand pose estimation datasets and one 3D human pose estimation dataset. 158 | 159 | * ICVL Hand Posture Dataset [[link](https://labicvl.github.io/hand.html)] [[paper](http://www.iis.ee.ic.ac.uk/dtang/cvpr_14.pdf)] 160 | * NYU Hand Pose Dataset [[link](https://cims.nyu.edu/~tompson/NYU_Hand_Pose_Dataset.htm)] [[paper](https://cims.nyu.edu/~tompson/others/TOG_2014_paper_PREPRINT.pdf)] 161 | * MSRA Hand Pose Dataset [[link](https://jimmysuen.github.io/)] [[paper](https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Sun_Cascaded_Hand_Pose_2015_CVPR_paper.pdf)] 162 | * HANDS2017 Challenge Dataset [[link](http://icvl.ee.ic.ac.uk/hands17/challenge/)] [[paper](https://arxiv.org/abs/1712.03917)] [[challenge benchmark paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Yuan_Depth-Based_3D_Hand_CVPR_2018_paper.pdf)] 163 | * ITOP Human Pose Dataset [[link](https://www.albert.cm/projects/viewpoint_3d_pose/)] [[paper](https://arxiv.org/abs/1603.07076)] 164 | 165 | 166 | # Results 167 | Here we provide the precomputed centers, estimated 3D coordinates and pre-trained models. 168 | 169 | The precomputed centers are obtained by training the hand center estimation network from [DeepPrior++ ](https://arxiv.org/pdf/1708.08325.pdf). Each line represents the 3D world coordinates of one frame. 170 | For the ICVL, NYU and MSRA datasets, if a depth map does not exist or does not contain a hand, that frame is considered invalid. 171 | For the ITOP dataset, if the 'valid' variable of a certain frame is false, that frame is considered invalid. 172 | All test images are considered valid. 173 | 174 | The 3D coordinates estimated on the ICVL, NYU and MSRA datasets are pixel coordinates, and the 3D coordinates estimated on the ITOP datasets are world coordinates. The estimated results are from an ensembled model. You can make the results from a single model by downloading the pre-trained model and testing it. 
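Since the ICVL, NYU and MSRA estimation files linked below store pixel coordinates (u, v, depth), they have to be back-projected before comparing against world-coordinate ground truth. A minimal sketch, using the same convention as `pixel2world()` in this repository's datasets/msra_hand.py and assuming the MSRA camera parameters (320x240 image, fx = fy = 241.42); other datasets need their own image size and focal lengths, and the commented loading lines are only an illustration of the per-row layout described above:

```python
import numpy as np

# MSRA intrinsics as used in datasets/msra_hand.py; substitute per dataset.
IMG_W, IMG_H, FX, FY = 320, 240, 241.42, 241.42

def pixel_to_world(uvd):
    """uvd: (..., 3) array of (u, v, depth) pixel coordinates -> (..., 3) world coordinates."""
    u, v, d = uvd[..., 0], uvd[..., 1], uvd[..., 2]
    x = (u - IMG_W / 2) * d / FX
    y = (IMG_H / 2 - v) * d / FY
    return np.stack([x, y, d], axis=-1)

# Hypothetical usage (one frame per line, u v d for each of the 21 joints):
# uvd = np.loadtxt('result.txt').reshape(-1, 21, 3)
# xyz = pixel_to_world(uvd)
```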
175 | 176 | * ICVL Hand Posture Dataset [[center_trainset](https://cv.snu.ac.kr/research/V2V-PoseNet/ICVL/center/center_train_refined.txt)] [[center_testset](https://cv.snu.ac.kr/research/V2V-PoseNet/ICVL/center/center_test_refined.txt)] [[estimation](https://cv.snu.ac.kr/research/V2V-PoseNet/ICVL/coordinate/result.txt)] [[models](https://cv.snu.ac.kr/research/V2V-PoseNet/ICVL/model/model.tar.gz)] 177 | * NYU Hand Pose Dataset [[center_trainset](https://cv.snu.ac.kr/research/V2V-PoseNet/NYU/center/center_train_refined.txt)] [[center_testset](https://cv.snu.ac.kr/research/V2V-PoseNet/NYU/center/center_test_refined.txt)] [[estimation](https://cv.snu.ac.kr/research/V2V-PoseNet/NYU/coordinate/result.txt)] [[models](https://cv.snu.ac.kr/research/V2V-PoseNet/NYU/model/model.tar.gz)] 178 | * MSRA Hand Pose Dataset [[center](https://cv.snu.ac.kr/research/V2V-PoseNet/MSRA/center/center.tar.gz)] [[estimation](https://cv.snu.ac.kr/research/V2V-PoseNet/MSRA/coordinate/result.txt)] [[models](https://cv.snu.ac.kr/research/V2V-PoseNet/MSRA/model/model.tar.gz)] 179 | * ITOP Human Pose Dataset (front-view) [[center_trainset](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_front/center/center_train.txt)] [[center_testset](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_front/center/center_test.txt)] [[estimation](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_front/coordinate/result.txt)] [[models](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_front/model/model.tar.gz)] 180 | * ITOP Human Pose Dataset (top-view) [[center_trainset](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_top/center/center_train.txt)] [[center_testset](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_top/center/center_test.txt)] [[estimation](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_top/coordinate/result.txt)] [[models](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_top/model/model.tar.gz)] 181 | 182 | We used [awesome-hand-pose-estimation ](https://github.com/xinghaochen/awesome-hand-pose-estimation) to evaluate the accuracy of V2V-PoseNet on the ICVL, NYU and MSRA datasets. 183 | 184 | Below are qualitative results. 
185 | ![result_1](/figs/result/Paper_result_ICVL.png) 186 | ![result_2](/figs/result/Paper_result_NYU.png) 187 | ![result_3](/figs/result/Paper_result_MSRA.png) 188 | ![result_4](/figs/result/Paper_result_HANDS2017.png) 189 | ![result_5](/figs/result/Paper_result_ITOP_front.png) 190 | ![result_6](/figs/result/Paper_result_ITOP_top.png) 191 | -------------------------------------------------------------------------------- /datasets/msra_hand.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import sys 4 | import struct 5 | from torch.utils.data import Dataset 6 | 7 | 8 | def pixel2world(x, y, z, img_width, img_height, fx, fy): 9 | w_x = (x - img_width / 2) * z / fx 10 | w_y = (img_height / 2 - y) * z / fy 11 | w_z = z 12 | return w_x, w_y, w_z 13 | 14 | 15 | def world2pixel(x, y, z, img_width, img_height, fx, fy): 16 | p_x = x * fx / z + img_width / 2 17 | p_y = img_height / 2 - y * fy / z 18 | return p_x, p_y 19 | 20 | 21 | def depthmap2points(image, fx, fy): 22 | h, w = image.shape 23 | x, y = np.meshgrid(np.arange(w) + 1, np.arange(h) + 1) 24 | points = np.zeros((h, w, 3), dtype=np.float32) 25 | points[:,:,0], points[:,:,1], points[:,:,2] = pixel2world(x, y, image, w, h, fx, fy) 26 | return points 27 | 28 | 29 | def points2pixels(points, img_width, img_height, fx, fy): 30 | pixels = np.zeros((points.shape[0], 2)) 31 | pixels[:, 0], pixels[:, 1] = \ 32 | world2pixel(points[:,0], points[:, 1], points[:, 2], img_width, img_height, fx, fy) 33 | return pixels 34 | 35 | 36 | def load_depthmap(filename, img_width, img_height, max_depth): 37 | with open(filename, mode='rb') as f: 38 | data = f.read() 39 | _, _, left, top, right, bottom = struct.unpack('I'*6, data[:6*4]) 40 | num_pixel = (right - left) * (bottom - top) 41 | cropped_image = struct.unpack('f'*num_pixel, data[6*4:]) 42 | 43 | cropped_image = np.asarray(cropped_image).reshape(bottom-top, -1) 44 | depth_image = np.zeros((img_height, img_width), dtype=np.float32) 45 | depth_image[top:bottom, left:right] = cropped_image 46 | depth_image[depth_image == 0] = max_depth 47 | 48 | return depth_image 49 | 50 | 51 | class MARAHandDataset(Dataset): 52 | def __init__(self, root, center_dir, mode, test_subject_id, transform=None): 53 | self.img_width = 320 54 | self.img_height = 240 55 | self.min_depth = 100 56 | self.max_depth = 700 57 | self.fx = 241.42 58 | self.fy = 241.42 59 | self.joint_num = 21 60 | self.world_dim = 3 61 | self.folder_list = ['1','2','3','4','5','6','7','8','9','I','IP','L','MP','RP','T','TIP','Y'] 62 | self.subject_num = 9 63 | 64 | self.root = root 65 | self.center_dir = center_dir 66 | self.mode = mode 67 | self.test_subject_id = test_subject_id 68 | self.transform = transform 69 | 70 | if not self.mode in ['train', 'test']: raise ValueError('Invalid mode') 71 | assert self.test_subject_id >= 0 and self.test_subject_id < self.subject_num 72 | 73 | if not self._check_exists(): raise RuntimeError('Invalid MSRA hand dataset') 74 | 75 | self._load() 76 | 77 | def __getitem__(self, index): 78 | depthmap = load_depthmap(self.names[index], self.img_width, self.img_height, self.max_depth) 79 | points = depthmap2points(depthmap, self.fx, self.fy) 80 | points = points.reshape((-1, 3)) 81 | 82 | sample = { 83 | 'name': self.names[index], 84 | 'points': points, 85 | 'joints': self.joints_world[index], 86 | 'refpoint': self.ref_pts[index] 87 | } 88 | 89 | if self.transform: sample = self.transform(sample) 90 | 91 | return sample 92 | 93 | def __len__(self): 
94 | return self.num_samples 95 | 96 | def _load(self): 97 | self._compute_dataset_size() 98 | 99 | self.num_samples = self.train_size if self.mode == 'train' else self.test_size 100 | self.joints_world = np.zeros((self.num_samples, self.joint_num, self.world_dim)) 101 | self.ref_pts = np.zeros((self.num_samples, self.world_dim)) 102 | self.names = [] 103 | 104 | # Collect reference center points strings 105 | if self.mode == 'train': ref_pt_file = 'center_train_' + str(self.test_subject_id) + '_refined.txt' 106 | else: ref_pt_file = 'center_test_' + str(self.test_subject_id) + '_refined.txt' 107 | 108 | with open(os.path.join(self.center_dir, ref_pt_file)) as f: 109 | ref_pt_str = [l.rstrip() for l in f] 110 | 111 | # 112 | file_id = 0 113 | frame_id = 0 114 | 115 | for mid in range(self.subject_num): 116 | if self.mode == 'train': model_chk = (mid != self.test_subject_id) 117 | elif self.mode == 'test': model_chk = (mid == self.test_subject_id) 118 | else: raise RuntimeError('unsupported mode {}'.format(self.mode)) 119 | 120 | if model_chk: 121 | for fd in self.folder_list: 122 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt') 123 | 124 | lines = [] 125 | with open(annot_file) as f: 126 | lines = [line.rstrip() for line in f] 127 | 128 | # skip first line 129 | for i in range(1, len(lines)): 130 | # referece point 131 | splitted = ref_pt_str[file_id].split() 132 | if splitted[0] == 'invalid': 133 | print('Warning: found invalid reference frame') 134 | file_id += 1 135 | continue 136 | else: 137 | self.ref_pts[frame_id, 0] = float(splitted[0]) 138 | self.ref_pts[frame_id, 1] = float(splitted[1]) 139 | self.ref_pts[frame_id, 2] = float(splitted[2]) 140 | 141 | # joint point 142 | splitted = lines[i].split() 143 | for jid in range(self.joint_num): 144 | self.joints_world[frame_id, jid, 0] = float(splitted[jid * self.world_dim]) 145 | self.joints_world[frame_id, jid, 1] = float(splitted[jid * self.world_dim + 1]) 146 | self.joints_world[frame_id, jid, 2] = -float(splitted[jid * self.world_dim + 2]) 147 | 148 | filename = os.path.join(self.root, 'P'+str(mid), fd, '{:0>6d}'.format(i-1) + '_depth.bin') 149 | self.names.append(filename) 150 | 151 | frame_id += 1 152 | file_id += 1 153 | 154 | def _compute_dataset_size(self): 155 | self.train_size, self.test_size = 0, 0 156 | 157 | for mid in range(self.subject_num): 158 | num = 0 159 | for fd in self.folder_list: 160 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt') 161 | with open(annot_file) as f: 162 | num = int(f.readline().rstrip()) 163 | if mid == self.test_subject_id: self.test_size += num 164 | else: self.train_size += num 165 | 166 | def _check_exists(self): 167 | # Check basic data 168 | for mid in range(self.subject_num): 169 | for fd in self.folder_list: 170 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt') 171 | if not os.path.exists(annot_file): 172 | print('Error: annotation file {} does not exist'.format(annot_file)) 173 | return False 174 | 175 | # Check precomputed centers by v2v-hand model's author 176 | for subject_id in range(self.subject_num): 177 | center_train = os.path.join(self.center_dir, 'center_train_' + str(subject_id) + '_refined.txt') 178 | center_test = os.path.join(self.center_dir, 'center_test_' + str(subject_id) + '_refined.txt') 179 | if not os.path.exists(center_train) or not os.path.exists(center_test): 180 | print('Error: precomputed center files do not exist') 181 | return False 182 | 183 | return True 184 | 
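A minimal usage sketch for the dataset class above, run from the repository root with PYTHONPATH=./ as in the README; the two paths are placeholders for the extracted dataset and the precomputed centers:

```python
from datasets.msra_hand import MARAHandDataset

# No transform: each sample is the raw dict built in __getitem__ above.
dataset = MARAHandDataset(root='path/to/msra-hand',
                          center_dir='path/to/msra-hand-center',
                          mode='test', test_subject_id=3)

sample = dataset[0]
print(sample['name'])            # path of the _depth.bin frame
print(sample['points'].shape)    # (320*240, 3) point cloud in world coordinates
print(sample['joints'].shape)    # (21, 3) ground-truth joint positions
print(sample['refpoint'].shape)  # (3,) precomputed hand center

# For training, pass a transform (e.g. the V2VVoxelization-based transform_train
# in experiments/msra-subject3/main.py) and wrap the dataset in a DataLoader.
```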
-------------------------------------------------------------------------------- /experiments/msra-subject3/gen_gt.py: -------------------------------------------------------------------------------- 1 | ##%% 2 | import os 3 | import numpy as np 4 | import sys 5 | import struct 6 | from torch.utils.data import Dataset 7 | 8 | 9 | def pixel2world(x, y, z, img_width, img_height, fx, fy): 10 | w_x = (x - img_width / 2) * z / fx 11 | w_y = (img_height / 2 - y) * z / fy 12 | w_z = z 13 | return w_x, w_y, w_z 14 | 15 | 16 | def world2pixel(x, y, z, img_width, img_height, fx, fy): 17 | p_x = x * fx / z + img_width / 2 18 | p_y = img_height / 2 - y * fy / z 19 | return p_x, p_y 20 | 21 | 22 | def depthmap2points(image, fx, fy): 23 | h, w = image.shape 24 | x, y = np.meshgrid(np.arange(w) + 1, np.arange(h) + 1) 25 | points = np.zeros((h, w, 3), dtype=np.float32) 26 | points[:,:,0], points[:,:,1], points[:,:,2] = pixel2world(x, y, image, w, h, fx, fy) 27 | return points 28 | 29 | 30 | def points2pixels(points, img_width, img_height, fx, fy): 31 | pixels = np.zeros((points.shape[0], 2)) 32 | pixels[:, 0], pixels[:, 1] = \ 33 | world2pixel(points[:,0], points[:, 1], points[:, 2], img_width, img_height, fx, fy) 34 | return pixels 35 | 36 | 37 | def load_depthmap(filename, img_width, img_height, max_depth): 38 | with open(filename, mode='rb') as f: 39 | data = f.read() 40 | _, _, left, top, right, bottom = struct.unpack('I'*6, data[:6*4]) 41 | num_pixel = (right - left) * (bottom - top) 42 | cropped_image = struct.unpack('f'*num_pixel, data[6*4:]) 43 | 44 | cropped_image = np.asarray(cropped_image).reshape(bottom-top, -1) 45 | depth_image = np.zeros((img_height, img_width), dtype=np.float32) 46 | depth_image[top:bottom, left:right] = cropped_image 47 | depth_image[depth_image == 0] = max_depth 48 | 49 | return depth_image 50 | 51 | 52 | class MARAHandDataset(Dataset): 53 | def __init__(self, root, center_dir, mode, test_subject_id, transform=None): 54 | self.img_width = 320 55 | self.img_height = 240 56 | self.min_depth = 100 57 | self.max_depth = 700 58 | self.fx = 241.42 59 | self.fy = 241.42 60 | self.joint_num = 21 61 | self.world_dim = 3 62 | self.folder_list = ['1','2','3','4','5','6','7','8','9','I','IP','L','MP','RP','T','TIP','Y'] 63 | self.subject_num = 9 64 | 65 | self.root = root 66 | self.center_dir = center_dir 67 | self.mode = mode 68 | self.test_subject_id = test_subject_id 69 | self.transform = transform 70 | 71 | if not self.mode in ['train', 'test']: raise ValueError('Invalid mode') 72 | assert self.test_subject_id >= 0 and self.test_subject_id < self.subject_num 73 | 74 | if not self._check_exists(): raise RuntimeError('Invalid MSRA hand dataset') 75 | 76 | self._load() 77 | 78 | def get_data(self): 79 | return self.names, self.joints_world, self.ref_pts 80 | 81 | def __getitem__(self, index): 82 | depthmap = load_depthmap(self.names[index], self.img_width, self.img_height, self.max_depth) 83 | points = depthmap2points(depthmap, self.fx, self.fy) 84 | points = points.reshape((-1, 3)) 85 | 86 | sample = { 87 | 'name': self.names[index], 88 | 'points': points, 89 | 'joints': self.joints_world[index], 90 | 'refpoint': self.ref_pts[index] 91 | } 92 | 93 | if self.transform: sample = self.transform(sample) 94 | 95 | return sample 96 | 97 | def __len__(self): 98 | return self.num_samples 99 | 100 | def _load(self): 101 | self._compute_dataset_size() 102 | 103 | self.num_samples = self.train_size if self.mode == 'train' else self.test_size 104 | self.joints_world = 
np.zeros((self.num_samples, self.joint_num, self.world_dim)) 105 | self.ref_pts = np.zeros((self.num_samples, self.world_dim)) 106 | self.names = [] 107 | 108 | # Collect reference center points strings 109 | if self.mode == 'train': ref_pt_file = 'center_train_' + str(self.test_subject_id) + '_refined.txt' 110 | else: ref_pt_file = 'center_test_' + str(self.test_subject_id) + '_refined.txt' 111 | 112 | with open(os.path.join(self.center_dir, ref_pt_file)) as f: 113 | ref_pt_str = [l.rstrip() for l in f] 114 | 115 | # 116 | file_id = 0 117 | frame_id = 0 118 | 119 | for mid in range(self.subject_num): 120 | if self.mode == 'train': model_chk = (mid != self.test_subject_id) 121 | elif self.mode == 'test': model_chk = (mid == self.test_subject_id) 122 | else: raise RuntimeError('unsupported mode {}'.format(self.mode)) 123 | 124 | if model_chk: 125 | for fd in self.folder_list: 126 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt') 127 | 128 | lines = [] 129 | with open(annot_file) as f: 130 | lines = [line.rstrip() for line in f] 131 | 132 | # skip first line 133 | for i in range(1, len(lines)): 134 | # referece point 135 | splitted = ref_pt_str[file_id].split() 136 | if splitted[0] == 'invalid': 137 | print('Warning: found invalid reference frame') 138 | file_id += 1 139 | continue 140 | else: 141 | self.ref_pts[frame_id, 0] = float(splitted[0]) 142 | self.ref_pts[frame_id, 1] = float(splitted[1]) 143 | self.ref_pts[frame_id, 2] = float(splitted[2]) 144 | 145 | # joint point 146 | splitted = lines[i].split() 147 | for jid in range(self.joint_num): 148 | self.joints_world[frame_id, jid, 0] = float(splitted[jid * self.world_dim]) 149 | self.joints_world[frame_id, jid, 1] = float(splitted[jid * self.world_dim + 1]) 150 | self.joints_world[frame_id, jid, 2] = -float(splitted[jid * self.world_dim + 2]) 151 | 152 | filename = os.path.join(self.root, 'P'+str(mid), fd, '{:0>6d}'.format(i-1) + '_depth.bin') 153 | self.names.append(filename) 154 | 155 | frame_id += 1 156 | file_id += 1 157 | 158 | def _compute_dataset_size(self): 159 | self.train_size, self.test_size = 0, 0 160 | 161 | for mid in range(self.subject_num): 162 | num = 0 163 | for fd in self.folder_list: 164 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt') 165 | with open(annot_file) as f: 166 | num = int(f.readline().rstrip()) 167 | if mid == self.test_subject_id: self.test_size += num 168 | else: self.train_size += num 169 | 170 | def _check_exists(self): 171 | # Check basic data 172 | for mid in range(self.subject_num): 173 | for fd in self.folder_list: 174 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt') 175 | if not os.path.exists(annot_file): 176 | print('Error: annotation file {} does not exist'.format(annot_file)) 177 | return False 178 | 179 | # Check precomputed centers by v2v-hand model's author 180 | for subject_id in range(self.subject_num): 181 | center_train = os.path.join(self.center_dir, 'center_train_' + str(subject_id) + '_refined.txt') 182 | center_test = os.path.join(self.center_dir, 'center_test_' + str(subject_id) + '_refined.txt') 183 | if not os.path.exists(center_train) or not os.path.exists(center_test): 184 | print('Error: precomputed center files do not exist') 185 | return False 186 | 187 | return True 188 | 189 | 190 | ##%% 191 | # Generate train_subject3_gt.txt and test_subject3_gt.txt 192 | data_dir = r'/home/maiqi/yalong/dataset/cvpr15_MSRAHandGestureDB' 193 | center_dir = 
r'/home/maiqi/yalong/project/KeyPoint/Code/V2V-PoseNet-Rlease-Codes/V2V-PoseNet_RELEASE-hand/data-result/MSRA-result/center' 194 | test_subject_id = 3 195 | 196 | 197 | ##%% 198 | def save_keypoints(filename, keypoints): 199 | # Reshape one sample keypoints into one line 200 | keypoints = keypoints.reshape(keypoints.shape[0], -1) 201 | np.savetxt(filename, keypoints, fmt='%0.4f') 202 | 203 | 204 | ##%% 205 | train_dataset = MARAHandDataset(root=data_dir, center_dir=center_dir, mode='train', test_subject_id=test_subject_id) 206 | names, joints_world, ref_pts = train_dataset.get_data() 207 | print('save train reslt ..') 208 | save_keypoints('./train_s3_gt.txt', joints_world) 209 | print('done ..') 210 | 211 | 212 | ##%% 213 | test_dataset = MARAHandDataset(root=data_dir, center_dir=center_dir, mode='test', test_subject_id=test_subject_id) 214 | names, joints_world, ref_pts = test_dataset.get_data() 215 | print('save test reslt ..') 216 | save_keypoints('./test_s3_gt.txt', joints_world) 217 | print('done ..') 218 | -------------------------------------------------------------------------------- /experiments/msra-subject3/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.backends.cudnn as cudnn 6 | import argparse 7 | import os 8 | 9 | from lib.solver import train_epoch, val_epoch, test_epoch 10 | from lib.sampler import ChunkSampler 11 | from src.v2v_model import V2VModel 12 | from src.v2v_util import V2VVoxelization 13 | from datasets.msra_hand import MARAHandDataset 14 | 15 | 16 | ####################################################################################### 17 | # Note, 18 | # Run in project root direcotry(ROOT_DIR) with: 19 | # PYTHONPATH=./ python experiments/msra-subject3/main.py 20 | # 21 | # This script will train model on MSRA hand datasets, save checkpoints to ROOT_DIR/checkpoint, 22 | # and save test results(test_res.txt) and fit results(fit_res.txt) to ROOT_DIR. 23 | # 24 | 25 | 26 | ####################################################################################### 27 | ## Some helpers 28 | def parse_args(): 29 | parser = argparse.ArgumentParser(description='PyTorch Hand Keypoints Estimation Training') 30 | #parser.add_argument('--resume', 'r', action='store_true', help='resume from checkpoint') 31 | parser.add_argument('--resume', '-r', default=-1, type=int, help='resume after epoch') 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | ####################################################################################### 37 | ## Configurations 38 | print('Warning: disable cudnn for batchnorm first, or just use only cuda instead!') 39 | 40 | # When we need to resume training, enable randomness to avoid seeing the determinstic 41 | # (agumented) samples many times. 
42 | # np.random.seed(1) 43 | # torch.manual_seed(1) 44 | # torch.cuda.manual_seed(1) 45 | 46 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 47 | dtype = torch.float 48 | 49 | # 50 | args = parse_args() 51 | resume_train = args.resume >= 0 52 | resume_after_epoch = args.resume 53 | 54 | save_checkpoint = True 55 | checkpoint_per_epochs = 1 56 | checkpoint_dir = r'./checkpoint' 57 | 58 | start_epoch = 0 59 | epochs_num = 15 60 | 61 | batch_size = 12 62 | 63 | 64 | ####################################################################################### 65 | ## Data, transform, dataset and loader 66 | # Data 67 | print('==> Preparing data ..') 68 | data_dir = r'/home/maiqi/yalong/dataset/cvpr15_MSRAHandGestureDB' 69 | center_dir = r'/home/maiqi/yalong/project/KeyPoint/Code/V2V-PoseNet-Rlease-Codes/V2V-PoseNet_RELEASE-hand/data-result/MSRA-result/center' 70 | keypoints_num = 21 71 | test_subject_id = 3 72 | cubic_size = 200 73 | 74 | 75 | # Transform 76 | voxelization_train = V2VVoxelization(cubic_size=200, augmentation=True) 77 | voxelization_val = V2VVoxelization(cubic_size=200, augmentation=False) 78 | 79 | 80 | def transform_train(sample): 81 | points, keypoints, refpoint = sample['points'], sample['joints'], sample['refpoint'] 82 | assert(keypoints.shape[0] == keypoints_num) 83 | input, heatmap = voxelization_train({'points': points, 'keypoints': keypoints, 'refpoint': refpoint}) 84 | return (torch.from_numpy(input), torch.from_numpy(heatmap)) 85 | 86 | 87 | def transform_val(sample): 88 | points, keypoints, refpoint = sample['points'], sample['joints'], sample['refpoint'] 89 | assert(keypoints.shape[0] == keypoints_num) 90 | input, heatmap = voxelization_val({'points': points, 'keypoints': keypoints, 'refpoint': refpoint}) 91 | return (torch.from_numpy(input), torch.from_numpy(heatmap)) 92 | 93 | 94 | # Dataset and loader 95 | train_set = MARAHandDataset(data_dir, center_dir, 'train', test_subject_id, transform_train) 96 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=6) 97 | #train_num = 1 98 | #train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=False, num_workers=6,sampler=ChunkSampler(train_num, 0)) 99 | 100 | # No separate validation dataset, just use test dataset instead 101 | val_set = MARAHandDataset(data_dir, center_dir, 'test', test_subject_id, transform_val) 102 | val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=6) 103 | 104 | 105 | ####################################################################################### 106 | ## Model, criterion and optimizer 107 | print('==> Constructing model ..') 108 | net = V2VModel(input_channels=1, output_channels=keypoints_num) 109 | 110 | net = net.to(device, dtype) 111 | if device == torch.device('cuda'): 112 | torch.backends.cudnn.enabled = True 113 | cudnn.benchmark = True 114 | print('cudnn.enabled: ', torch.backends.cudnn.enabled) 115 | 116 | criterion = nn.MSELoss() 117 | 118 | optimizer = optim.Adam(net.parameters()) 119 | #optimizer = optim.RMSprop(net.parameters(), lr=2.5e-4) 120 | 121 | 122 | ####################################################################################### 123 | ## Resume 124 | if resume_train: 125 | # Load checkpoint 126 | epoch = resume_after_epoch 127 | checkpoint_file = os.path.join(checkpoint_dir, 'epoch'+str(epoch)+'.pth') 128 | 129 | print('==> Resuming from checkpoint after epoch {} ..'.format(epoch)) 130 | assert 
os.path.isdir(checkpoint_dir), 'Error: no checkpoint directory found!' 131 | assert os.path.isfile(checkpoint_file), 'Error: no checkpoint file of epoch {}'.format(epoch) 132 | 133 | checkpoint = torch.load(os.path.join(checkpoint_dir, 'epoch'+str(epoch)+'.pth')) 134 | net.load_state_dict(checkpoint['model_state_dict']) 135 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 136 | start_epoch = checkpoint['epoch'] + 1 137 | 138 | 139 | ####################################################################################### 140 | ## Train and Validate 141 | print('==> Training ..') 142 | for epoch in range(start_epoch, start_epoch + epochs_num): 143 | print('Epoch: {}'.format(epoch)) 144 | train_epoch(net, criterion, optimizer, train_loader, device=device, dtype=dtype) 145 | val_epoch(net, criterion, val_loader, device=device, dtype=dtype) 146 | 147 | if save_checkpoint and epoch % checkpoint_per_epochs == 0: 148 | if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir) 149 | checkpoint_file = os.path.join(checkpoint_dir, 'epoch'+str(epoch)+'.pth') 150 | checkpoint = { 151 | 'model_state_dict': net.state_dict(), 152 | 'optimizer_state_dict': optimizer.state_dict(), 153 | 'epoch': epoch 154 | } 155 | torch.save(checkpoint, checkpoint_file) 156 | 157 | 158 | ####################################################################################### 159 | ## Test 160 | print('==> Testing ..') 161 | voxelize_input = voxelization_train.voxelize 162 | evaluate_keypoints = voxelization_train.evaluate 163 | 164 | 165 | def transform_test(sample): 166 | points, refpoint = sample['points'], sample['refpoint'] 167 | input = voxelize_input(points, refpoint) 168 | return torch.from_numpy(input), torch.from_numpy(refpoint.reshape((1, -1))) 169 | 170 | 171 | def transform_output(heatmaps, refpoints): 172 | keypoints = evaluate_keypoints(heatmaps, refpoints) 173 | return keypoints 174 | 175 | 176 | class BatchResultCollector(): 177 | def __init__(self, samples_num, transform_output): 178 | self.samples_num = samples_num 179 | self.transform_output = transform_output 180 | self.keypoints = None 181 | self.idx = 0 182 | 183 | def __call__(self, data_batch): 184 | inputs_batch, outputs_batch, extra_batch = data_batch 185 | outputs_batch = outputs_batch.cpu().numpy() 186 | refpoints_batch = extra_batch.cpu().numpy() 187 | 188 | keypoints_batch = self.transform_output(outputs_batch, refpoints_batch) 189 | 190 | if self.keypoints is None: 191 | # Initialize keypoints until dimensions awailable now 192 | self.keypoints = np.zeros((self.samples_num, *keypoints_batch.shape[1:])) 193 | 194 | batch_size = keypoints_batch.shape[0] 195 | self.keypoints[self.idx:self.idx+batch_size] = keypoints_batch 196 | self.idx += batch_size 197 | 198 | def get_result(self): 199 | return self.keypoints 200 | 201 | 202 | print('Test on test dataset ..') 203 | def save_keypoints(filename, keypoints): 204 | # Reshape one sample keypoints into one line 205 | keypoints = keypoints.reshape(keypoints.shape[0], -1) 206 | np.savetxt(filename, keypoints, fmt='%0.4f') 207 | 208 | 209 | test_set = MARAHandDataset(data_dir, center_dir, 'test', test_subject_id, transform_test) 210 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=6) 211 | test_res_collector = BatchResultCollector(len(test_set), transform_output) 212 | 213 | test_epoch(net, test_loader, test_res_collector, device, dtype) 214 | keypoints_test = test_res_collector.get_result() 215 | save_keypoints('./test_res.txt', 
keypoints_test) 216 | 217 | 218 | print('Fit on train dataset ..') 219 | fit_set = MARAHandDataset(data_dir, center_dir, 'train', test_subject_id, transform_test) 220 | fit_loader = torch.utils.data.DataLoader(fit_set, batch_size=batch_size, shuffle=False, num_workers=6) 221 | fit_res_collector = BatchResultCollector(len(fit_set), transform_output) 222 | 223 | test_epoch(net, fit_loader, fit_res_collector, device, dtype) 224 | keypoints_fit = fit_res_collector.get_result() 225 | save_keypoints('./fit_res.txt', keypoints_fit) 226 | 227 | print('All done ..') 228 | -------------------------------------------------------------------------------- /experiments/msra-subject3/show_acc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../../') 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from lib.accuracy import * 7 | from vis.plot import * 8 | 9 | 10 | gt_file = r'./test_s3_gt.txt' 11 | pred_file = r'./test_res.txt' 12 | 13 | 14 | gt = np.loadtxt(gt_file) 15 | gt = gt.reshape(gt.shape[0], -1, 3) 16 | 17 | pred = np.loadtxt(pred_file) 18 | pred = pred.reshape(pred.shape[0], -1, 3) 19 | 20 | print('gt: ', gt.shape) 21 | print('pred: ', pred.shape) 22 | 23 | 24 | keypoints_num = 21 25 | names = ['joint'+str(i+1) for i in range(keypoints_num)] 26 | 27 | 28 | dist, acc = compute_dist_acc_wrapper(pred, gt, max_dist=100, num=100) 29 | 30 | fig, ax = plt.subplots() 31 | plot_acc(ax, dist, acc, names) 32 | fig.savefig('msra_s3_joint_acc.png') 33 | plt.show() 34 | 35 | 36 | mean_err = compute_mean_err(pred, gt) 37 | fig, ax = plt.subplots() 38 | plot_mean_err(ax, mean_err, names) 39 | fig.savefig('msra_s3_joint_mean_error.png') 40 | plt.show() 41 | 42 | 43 | print('mean_err: {}'.format(mean_err)) 44 | mean_err_all = compute_mean_err(pred.reshape((-1, 1, 3)), gt.reshape((-1, 1,3))) 45 | print('mean_err_all: ', mean_err_all) 46 | -------------------------------------------------------------------------------- /figs/Challenge_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/Challenge_result.png -------------------------------------------------------------------------------- /figs/Paper_result_hand_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/Paper_result_hand_graph.png -------------------------------------------------------------------------------- /figs/Paper_result_hand_msra.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/Paper_result_hand_msra.png -------------------------------------------------------------------------------- /figs/Paper_result_hand_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/Paper_result_hand_table.png -------------------------------------------------------------------------------- /figs/Paper_result_human_table.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/Paper_result_human_table.png -------------------------------------------------------------------------------- /figs/V2V-PoseNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/V2V-PoseNet.png -------------------------------------------------------------------------------- /figs/integral_pose_msra_s3_joint_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/integral_pose_msra_s3_joint_acc.png -------------------------------------------------------------------------------- /figs/integral_pose_msra_s3_joint_mean_error.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/integral_pose_msra_s3_joint_mean_error.png -------------------------------------------------------------------------------- /figs/mean_error_compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/mean_error_compare.png -------------------------------------------------------------------------------- /figs/msra_s3_joint_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/msra_s3_joint_acc.png -------------------------------------------------------------------------------- /figs/msra_s3_joint_mean_error.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/msra_s3_joint_mean_error.png -------------------------------------------------------------------------------- /figs/result/Paper_result_HANDS2017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/result/Paper_result_HANDS2017.png -------------------------------------------------------------------------------- /figs/result/Paper_result_ICVL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/result/Paper_result_ICVL.png -------------------------------------------------------------------------------- /figs/result/Paper_result_ITOP_front.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/result/Paper_result_ITOP_front.png -------------------------------------------------------------------------------- /figs/result/Paper_result_ITOP_top.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/result/Paper_result_ITOP_top.png 
-------------------------------------------------------------------------------- /figs/result/Paper_result_MSRA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/result/Paper_result_MSRA.png -------------------------------------------------------------------------------- /figs/result/Paper_result_NYU.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/result/Paper_result_NYU.png -------------------------------------------------------------------------------- /integral-pose/accuracy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_dist_acc_wrapper(pred, gt, max_dist=10, num=100): 5 | ''' 6 | pred: (N, K, 3) 7 | gt: (N, K, 3) 8 | 9 | return dist: (K, ) 10 | return acc: (K, num) 11 | ''' 12 | assert(pred.shape == gt.shape) 13 | assert(len(pred.shape) == 3) 14 | 15 | dist = np.linspace(0, max_dist, num) 16 | return dist, compute_dist_acc(pred, gt, dist) 17 | 18 | 19 | def compute_dist_acc(pred, gt, dist): 20 | ''' 21 | pred: (N, K, 3) 22 | gt: (N, K, 3) 23 | dist: (M, ) 24 | 25 | return acc: (K, M) 26 | ''' 27 | assert(pred.shape == gt.shape) 28 | assert(len(pred.shape) == 3) 29 | 30 | N, K = pred.shape[0], pred.shape[1] 31 | err_dist = np.sqrt(np.sum((pred - gt)**2, axis=2)) # (N, K) 32 | 33 | acc = np.zeros((K, dist.shape[0])) 34 | 35 | for i, d in enumerate(dist): 36 | acc_d = (err_dist < d).sum(axis=0) / N 37 | acc[:,i] = acc_d 38 | 39 | return acc 40 | 41 | 42 | def compute_mean_err(pred, gt): 43 | ''' 44 | pred: (N, K, 3) 45 | gt: (N, K, 3) 46 | 47 | mean_err: (K,) 48 | ''' 49 | N, K = pred.shape[0], pred.shape[1] 50 | err_dist = np.sqrt(np.sum((pred - gt)**2, axis=2)) # (N, K) 51 | return np.mean(err_dist, axis=0) 52 | 53 | 54 | def compute_dist_err(pred, gt): 55 | ''' 56 | pred: (N, K, 3) 57 | return: (N, K) 58 | ''' 59 | return np.sqrt(np.sum((pred - gt)**2, axis=2)) 60 | -------------------------------------------------------------------------------- /integral-pose/compare_acc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from accuracy import * 5 | from plot import * 6 | 7 | 8 | gt_file = r'./test_s3_gt.txt' 9 | pred_file = r'./v2vposenet-loss_test_res.txt' # copied from ./experiments/msra-subject3/test_res.txt 10 | pred_file1 = r'./test_res.txt' # one-hot heatmap loss + L1 coord loss 11 | 12 | 13 | gt = np.loadtxt(gt_file) 14 | gt = gt.reshape(gt.shape[0], -1, 3) 15 | 16 | pred = np.loadtxt(pred_file) 17 | pred = pred.reshape(pred.shape[0], -1, 3) 18 | 19 | pred1 = np.loadtxt(pred_file1) 20 | pred1 = pred1.reshape(pred1.shape[0], -1, 3) 21 | 22 | print('gt: ', gt.shape) 23 | print('pred: ', pred.shape) 24 | print('pred1: ', pred1.shape) 25 | 26 | 27 | names = ['kp'+str(i+1) for i in range(gt.shape[1]) ] 28 | 29 | 30 | ## 31 | dist, acc = compute_dist_acc_wrapper(pred, gt, 100, 100) 32 | 33 | fig, ax = plt.subplots() 34 | plot_acc(ax, dist, acc, names) 35 | plt.show() 36 | 37 | 38 | ## 39 | _, acc1 = compute_dist_acc_wrapper(pred1, gt, 100, 100) 40 | 41 | fig, ax = plt.subplots() 42 | plot_acc(ax, dist, acc1, names) 43 | plt.show() 44 | 45 | ## 46 | mean_err = compute_mean_err(pred, gt) 47 | mean_err1 
= compute_mean_err(pred1, gt) 48 | 49 | fig, ax = plt.subplots() 50 | _pos = np.arange(len(names)) 51 | ax.bar(_pos, mean_err, width=0.1, label='gaussian-heatmap-loss(V2VPoseNet-loss)') 52 | ax.bar(_pos+0.1, mean_err1, width=0.1, label='one_hot-heamap-loss + L1 coord loss') 53 | ax.set_xticks(_pos) 54 | ax.set_xticklabels(names) 55 | ax.set_xlabel('keypoints categories') 56 | ax.set_ylabel('distance mean error (mm)') 57 | ax.legend(loc='upper right') 58 | 59 | plt.show() 60 | 61 | 62 | all_mean_err = np.mean(mean_err) 63 | all_mean_err1 = np.mean(mean_err1) 64 | print('all_mean_error: ', all_mean_err) 65 | print('all_mean_error1: ', all_mean_err1) 66 | 67 | 68 | ## histogram 69 | dist_err = compute_dist_err(pred, gt) 70 | dist_err1 = compute_dist_err(pred1, gt) 71 | 72 | fig, ax = plt.subplots() 73 | bins = np.linspace(0, 10, 200) 74 | ax.hist([dist_err[:].ravel(), dist_err1[:].ravel()], bins=bins, label=['gaussian-heatmap-loss(V2VPoseNet-loss)', 'one_hot-heamap-loss + L1 coord loss']) 75 | ax.set_xlabel('distance error (mm)') 76 | ax.set_ylabel('keypoints number') 77 | plt.legend(loc='upper right') 78 | 79 | plt.show() -------------------------------------------------------------------------------- /integral-pose/gen_gt.py: -------------------------------------------------------------------------------- 1 | ##%% 2 | import os 3 | import numpy as np 4 | import sys 5 | import struct 6 | from torch.utils.data import Dataset 7 | 8 | 9 | def pixel2world(x, y, z, img_width, img_height, fx, fy): 10 | w_x = (x - img_width / 2) * z / fx 11 | w_y = (img_height / 2 - y) * z / fy 12 | w_z = z 13 | return w_x, w_y, w_z 14 | 15 | 16 | def world2pixel(x, y, z, img_width, img_height, fx, fy): 17 | p_x = x * fx / z + img_width / 2 18 | p_y = img_height / 2 - y * fy / z 19 | return p_x, p_y 20 | 21 | 22 | def depthmap2points(image, fx, fy): 23 | h, w = image.shape 24 | x, y = np.meshgrid(np.arange(w) + 1, np.arange(h) + 1) 25 | points = np.zeros((h, w, 3), dtype=np.float32) 26 | points[:,:,0], points[:,:,1], points[:,:,2] = pixel2world(x, y, image, w, h, fx, fy) 27 | return points 28 | 29 | 30 | def points2pixels(points, img_width, img_height, fx, fy): 31 | pixels = np.zeros((points.shape[0], 2)) 32 | pixels[:, 0], pixels[:, 1] = \ 33 | world2pixel(points[:,0], points[:, 1], points[:, 2], img_width, img_height, fx, fy) 34 | return pixels 35 | 36 | 37 | def load_depthmap(filename, img_width, img_height, max_depth): 38 | with open(filename, mode='rb') as f: 39 | data = f.read() 40 | _, _, left, top, right, bottom = struct.unpack('I'*6, data[:6*4]) 41 | num_pixel = (right - left) * (bottom - top) 42 | cropped_image = struct.unpack('f'*num_pixel, data[6*4:]) 43 | 44 | cropped_image = np.asarray(cropped_image).reshape(bottom-top, -1) 45 | depth_image = np.zeros((img_height, img_width), dtype=np.float32) 46 | depth_image[top:bottom, left:right] = cropped_image 47 | depth_image[depth_image == 0] = max_depth 48 | 49 | return depth_image 50 | 51 | 52 | class MARAHandDataset(Dataset): 53 | def __init__(self, root, center_dir, mode, test_subject_id, transform=None): 54 | self.img_width = 320 55 | self.img_height = 240 56 | self.min_depth = 100 57 | self.max_depth = 700 58 | self.fx = 241.42 59 | self.fy = 241.42 60 | self.joint_num = 21 61 | self.world_dim = 3 62 | self.folder_list = ['1','2','3','4','5','6','7','8','9','I','IP','L','MP','RP','T','TIP','Y'] 63 | self.subject_num = 9 64 | 65 | self.root = root 66 | self.center_dir = center_dir 67 | self.mode = mode 68 | self.test_subject_id = test_subject_id 69 | 
self.transform = transform 70 | 71 | if not self.mode in ['train', 'test']: raise ValueError('Invalid mode') 72 | assert self.test_subject_id >= 0 and self.test_subject_id < self.subject_num 73 | 74 | if not self._check_exists(): raise RuntimeError('Invalid MSRA hand dataset') 75 | 76 | self._load() 77 | 78 | def get_data(self): 79 | return self.names, self.joints_world, self.ref_pts 80 | 81 | def __getitem__(self, index): 82 | depthmap = load_depthmap(self.names[index], self.img_width, self.img_height, self.max_depth) 83 | points = depthmap2points(depthmap, self.fx, self.fy) 84 | points = points.reshape((-1, 3)) 85 | 86 | sample = { 87 | 'name': self.names[index], 88 | 'points': points, 89 | 'joints': self.joints_world[index], 90 | 'refpoint': self.ref_pts[index] 91 | } 92 | 93 | if self.transform: sample = self.transform(sample) 94 | 95 | return sample 96 | 97 | def __len__(self): 98 | return self.num_samples 99 | 100 | def _load(self): 101 | self._compute_dataset_size() 102 | 103 | self.num_samples = self.train_size if self.mode == 'train' else self.test_size 104 | self.joints_world = np.zeros((self.num_samples, self.joint_num, self.world_dim)) 105 | self.ref_pts = np.zeros((self.num_samples, self.world_dim)) 106 | self.names = [] 107 | 108 | # Collect reference center points strings 109 | if self.mode == 'train': ref_pt_file = 'center_train_' + str(self.test_subject_id) + '_refined.txt' 110 | else: ref_pt_file = 'center_test_' + str(self.test_subject_id) + '_refined.txt' 111 | 112 | with open(os.path.join(self.center_dir, ref_pt_file)) as f: 113 | ref_pt_str = [l.rstrip() for l in f] 114 | 115 | # 116 | file_id = 0 117 | frame_id = 0 118 | 119 | for mid in range(self.subject_num): 120 | if self.mode == 'train': model_chk = (mid != self.test_subject_id) 121 | elif self.mode == 'test': model_chk = (mid == self.test_subject_id) 122 | else: raise RuntimeError('unsupported mode {}'.format(self.mode)) 123 | 124 | if model_chk: 125 | for fd in self.folder_list: 126 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt') 127 | 128 | lines = [] 129 | with open(annot_file) as f: 130 | lines = [line.rstrip() for line in f] 131 | 132 | # skip first line 133 | for i in range(1, len(lines)): 134 | # referece point 135 | splitted = ref_pt_str[file_id].split() 136 | if splitted[0] == 'invalid': 137 | print('Warning: found invalid reference frame') 138 | file_id += 1 139 | continue 140 | else: 141 | self.ref_pts[frame_id, 0] = float(splitted[0]) 142 | self.ref_pts[frame_id, 1] = float(splitted[1]) 143 | self.ref_pts[frame_id, 2] = float(splitted[2]) 144 | 145 | # joint point 146 | splitted = lines[i].split() 147 | for jid in range(self.joint_num): 148 | self.joints_world[frame_id, jid, 0] = float(splitted[jid * self.world_dim]) 149 | self.joints_world[frame_id, jid, 1] = float(splitted[jid * self.world_dim + 1]) 150 | self.joints_world[frame_id, jid, 2] = -float(splitted[jid * self.world_dim + 2]) 151 | 152 | filename = os.path.join(self.root, 'P'+str(mid), fd, '{:0>6d}'.format(i-1) + '_depth.bin') 153 | self.names.append(filename) 154 | 155 | frame_id += 1 156 | file_id += 1 157 | 158 | def _compute_dataset_size(self): 159 | self.train_size, self.test_size = 0, 0 160 | 161 | for mid in range(self.subject_num): 162 | num = 0 163 | for fd in self.folder_list: 164 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt') 165 | with open(annot_file) as f: 166 | num = int(f.readline().rstrip()) 167 | if mid == self.test_subject_id: self.test_size += num 168 | else: self.train_size 
+= num 169 | 170 | def _check_exists(self): 171 | # Check basic data 172 | for mid in range(self.subject_num): 173 | for fd in self.folder_list: 174 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt') 175 | if not os.path.exists(annot_file): 176 | print('Error: annotation file {} does not exist'.format(annot_file)) 177 | return False 178 | 179 | # Check precomputed centers by v2v-hand model's author 180 | for subject_id in range(self.subject_num): 181 | center_train = os.path.join(self.center_dir, 'center_train_' + str(subject_id) + '_refined.txt') 182 | center_test = os.path.join(self.center_dir, 'center_test_' + str(subject_id) + '_refined.txt') 183 | if not os.path.exists(center_train) or not os.path.exists(center_test): 184 | print('Error: precomputed center files do not exist') 185 | return False 186 | 187 | return True 188 | 189 | 190 | ##%% 191 | # Generate train_subject3_gt.txt and test_subject3_gt.txt 192 | data_dir = r'/home/maiqi/yalong/dataset/cvpr15_MSRAHandGestureDB' 193 | center_dir = r'/home/maiqi/yalong/project/KeyPoint/Code/V2V-PoseNet-Rlease-Codes/V2V-PoseNet_RELEASE-hand/data-result/MSRA-result/center' 194 | test_subject_id = 3 195 | 196 | 197 | ##%% 198 | def save_keypoints(filename, keypoints): 199 | # Reshape one sample keypoints into one line 200 | keypoints = keypoints.reshape(keypoints.shape[0], -1) 201 | np.savetxt(filename, keypoints, fmt='%0.4f') 202 | 203 | 204 | ##%% 205 | train_dataset = MARAHandDataset(root=data_dir, center_dir=center_dir, mode='train', test_subject_id=test_subject_id) 206 | names, joints_world, ref_pts = train_dataset.get_data() 207 | print('save train result ..') 208 | save_keypoints('./train_s3_gt.txt', joints_world) 209 | print('done ..') 210 | 211 | 212 | ##%% 213 | test_dataset = MARAHandDataset(root=data_dir, center_dir=center_dir, mode='test', test_subject_id=test_subject_id) 214 | names, joints_world, ref_pts = test_dataset.get_data() 215 | print('save test result ..') 216 | save_keypoints('./test_s3_gt.txt', joints_world) 217 | print('done ..') 218 | -------------------------------------------------------------------------------- /integral-pose/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SoftmaxCrossEntropyWithLogits(nn.Module): 7 | ''' 8 | Similar to tensorflow's tf.nn.softmax_cross_entropy_with_logits 9 | ref: https://gist.github.com/tejaskhot/cf3d087ce4708c422e68b3b747494b9f 10 | 11 | The 'input' is unnormalized scores. 12 | The 'target' is a probability distribution.
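Concretely, forward() below computes loss_i = -sum_j target_ij * log_softmax(input_i)_j for each sample i and returns the mean of this value over the batch.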
13 | 14 | Shape: 15 | Input: (N, C), batch size N, with C classes 16 | Target: (N, C), batch size N, with C classes 17 | ''' 18 | def __init__(self): 19 | super(SoftmaxCrossEntropyWithLogits, self).__init__() 20 | 21 | def forward(self, input, target): 22 | loss = torch.sum(-target * F.log_softmax(input, -1), -1) 23 | mean_loss = torch.mean(loss) 24 | return mean_loss 25 | 26 | 27 | class MixedLoss(nn.Module): 28 | ''' 29 | ref: https://github.com/mks0601/PoseFix_RELEASE/blob/master/main/model.py 30 | 31 | input: { 32 | 'heatmap': (N, C, X, Y, Z), unnormalized 33 | 'coord': (N, C, 3) 34 | } 35 | 36 | target: { 37 | 'heatmap': (N, C, X, Y, Z), normalized 38 | 'coord': (N, C, 3) 39 | } 40 | 41 | ''' 42 | def __init__(self, heatmap_weight=0.5): 43 | # def __init__(self, heatmap_weight=0.05): 44 | super(MixedLoss, self).__init__() 45 | self.w1 = heatmap_weight 46 | self.w2 = 1 - self.w1 47 | self.cross_entropy_loss = SoftmaxCrossEntropyWithLogits() 48 | 49 | def forward(self, input, target): 50 | pred_heatmap, pred_coord = input['heatmap'], input['coord'] 51 | gt_heatmap, gt_coord = target['heatmap'], target['coord'] 52 | 53 | # Heatmap loss 54 | N, C = pred_heatmap.shape[0:2] 55 | pred_heatmap = pred_heatmap.view(N*C, -1) 56 | gt_heatmap = gt_heatmap.view(N*C, -1) 57 | 58 | # Note, averaged over N*C 59 | hm_loss = self.cross_entropy_loss(pred_heatmap, gt_heatmap) 60 | 61 | # Coord L1 loss 62 | l1_loss = torch.mean(torch.abs(pred_coord - gt_coord)) 63 | 64 | return self.w1 * hm_loss + self.w2 * l1_loss 65 | -------------------------------------------------------------------------------- /integral-pose/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.backends.cudnn as cudnn 6 | import argparse 7 | import os 8 | 9 | from solver import train_epoch, val_epoch, test_epoch 10 | from sampler import ChunkSampler 11 | from model import Model 12 | from loss import MixedLoss 13 | from v2v_util import V2VVoxelization 14 | from msra_hand import MARAHandDataset 15 | 16 | 17 | ####################################################################################### 18 | # Note, 19 | # Run in the project root directory (ROOT_DIR) with: 20 | # PYTHONPATH=./ python ./integral-pose/main.py 21 | # 22 | # This script will train the model on the MSRA hand dataset, save checkpoints to ROOT_DIR/checkpoint, 23 | # and save test results (test_res.txt) and fit results (fit_res.txt) to ROOT_DIR. 24 | # 25 | 26 | 27 | ####################################################################################### 28 | ## Some helpers 29 | def parse_args(): 30 | parser = argparse.ArgumentParser(description='PyTorch Hand Keypoints Estimation Training') 31 | #parser.add_argument('--resume', 'r', action='store_true', help='resume from checkpoint') 32 | parser.add_argument('--resume', '-r', default=-1, type=int, help='resume after epoch') 33 | args = parser.parse_args() 34 | return args 35 | 36 | 37 | ####################################################################################### 38 | ## Configurations 39 | print('Warning: disable cudnn for batchnorm first, or just use only cuda instead!') 40 | 41 | # When we need to resume training, enable randomness to avoid seeing the deterministic 42 | # (augmented) samples many times.
43 | # np.random.seed(1) 44 | # torch.manual_seed(1) 45 | # torch.cuda.manual_seed(1) 46 | 47 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 48 | dtype = torch.float 49 | 50 | # 51 | args = parse_args() 52 | resume_train = args.resume >= 0 53 | resume_after_epoch = args.resume 54 | 55 | save_checkpoint = True 56 | checkpoint_per_epochs = 1 57 | checkpoint_dir = r'./checkpoint' 58 | 59 | start_epoch = 0 60 | epochs_num = 15 61 | 62 | batch_size = 12 63 | 64 | 65 | ####################################################################################### 66 | ## Data, transform, dataset and loader 67 | # Data 68 | print('==> Preparing data ..') 69 | data_dir = r'/home/maiqi/yalong/dataset/cvpr15_MSRAHandGestureDB' 70 | center_dir = r'/home/maiqi/yalong/project/KeyPoint/Code/V2V-PoseNet-Rlease-Codes/V2V-PoseNet_RELEASE-hand/data-result/MSRA-result/center' 71 | keypoints_num = 21 72 | test_subject_id = 3 73 | cubic_size = 200 74 | 75 | 76 | # Transform 77 | voxelization_train = V2VVoxelization(cubic_size=200, augmentation=True) 78 | voxelization_val = V2VVoxelization(cubic_size=200, augmentation=False) 79 | 80 | 81 | def transform_train(sample): 82 | points, keypoints, refpoint = sample['points'], sample['joints'], sample['refpoint'] 83 | assert(keypoints.shape[0] == keypoints_num) 84 | 85 | input, heatmap, coord = voxelization_train({'points': points, 'keypoints': keypoints, 'refpoint': refpoint}) 86 | 87 | sample = { 88 | 'input': input, 89 | 'target': { 90 | 'heatmap': heatmap, 91 | 'coord': coord 92 | }, 93 | 'extra': { 94 | 'refpoint': refpoint.reshape((1, -1)) 95 | } 96 | } 97 | 98 | return sample 99 | 100 | 101 | 102 | def transform_val(sample): 103 | points, keypoints, refpoint = sample['points'], sample['joints'], sample['refpoint'] 104 | assert(keypoints.shape[0] == keypoints_num) 105 | 106 | input, heatmap, coord = voxelization_val({'points': points, 'keypoints': keypoints, 'refpoint': refpoint}) 107 | 108 | sample = { 109 | 'input': input, 110 | 'target': { 111 | 'heatmap': heatmap, 112 | 'coord': coord 113 | }, 114 | 'extra': { 115 | 'refpoint': refpoint.reshape((1, -1)) 116 | } 117 | } 118 | 119 | return sample 120 | 121 | 122 | # Dataset and loader 123 | train_set = MARAHandDataset(data_dir, center_dir, 'train', test_subject_id, transform_train) 124 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=6) 125 | # train_num = 24 126 | # train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=False, num_workers=6,sampler=ChunkSampler(train_num, 0)) 127 | 128 | # No separate validation dataset, just use test dataset instead 129 | val_set = MARAHandDataset(data_dir, center_dir, 'test', test_subject_id, transform_val) 130 | val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=6) 131 | # val_num = 24 132 | # val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=6, sampler=ChunkSampler(val_num)) 133 | 134 | 135 | ####################################################################################### 136 | ## Model, criterion and optimizer 137 | print('==> Constructing model ..') 138 | output_res = 44 139 | net = Model(in_channels=1, out_channels=keypoints_num, output_res=output_res) 140 | 141 | net = net.to(device, dtype) 142 | if device == torch.device('cuda'): 143 | torch.backends.cudnn.enabled = True 144 | cudnn.benchmark = True 145 | print('cudnn.enabled: ', 
torch.backends.cudnn.enabled) 146 | 147 | 148 | criterion = MixedLoss() 149 | optimizer = optim.Adam(net.parameters()) 150 | 151 | 152 | ####################################################################################### 153 | ## Resume 154 | if resume_train: 155 | # Load checkpoint 156 | epoch = resume_after_epoch 157 | checkpoint_file = os.path.join(checkpoint_dir, 'epoch'+str(epoch)+'.pth') 158 | 159 | print('==> Resuming from checkpoint after epoch {} ..'.format(epoch)) 160 | assert os.path.isdir(checkpoint_dir), 'Error: no checkpoint directory found!' 161 | assert os.path.isfile(checkpoint_file), 'Error: no checkpoint file of epoch {}'.format(epoch) 162 | 163 | checkpoint = torch.load(os.path.join(checkpoint_dir, 'epoch'+str(epoch)+'.pth')) 164 | net.load_state_dict(checkpoint['model_state_dict']) 165 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 166 | start_epoch = checkpoint['epoch'] + 1 167 | 168 | 169 | ####################################################################################### 170 | ## Train and Validate 171 | print('==> Training ..') 172 | for epoch in range(start_epoch, start_epoch + epochs_num): 173 | print('Epoch: {}'.format(epoch)) 174 | train_epoch(net, criterion, optimizer, train_loader, device=device, dtype=dtype) 175 | val_epoch(net, criterion, val_loader, device=device, dtype=dtype) 176 | 177 | if save_checkpoint and epoch % checkpoint_per_epochs == 0: 178 | if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir) 179 | checkpoint_file = os.path.join(checkpoint_dir, 'epoch'+str(epoch)+'.pth') 180 | checkpoint = { 181 | 'model_state_dict': net.state_dict(), 182 | 'optimizer_state_dict': optimizer.state_dict(), 183 | 'epoch': epoch 184 | } 185 | torch.save(checkpoint, checkpoint_file) 186 | 187 | 188 | ####################################################################################### 189 | ## Test 190 | print('==> Testing ..') 191 | 192 | def transform_test(sample): 193 | points, keypoints, refpoint = sample['points'], sample['joints'], sample['refpoint'] 194 | assert(keypoints.shape[0] == keypoints_num) 195 | 196 | input, heatmap, coord = voxelization_val({'points': points, 'keypoints': keypoints, 'refpoint': refpoint}) 197 | 198 | sample = { 199 | 'input': input, 200 | 'target': { 201 | 'heatmap': heatmap, 202 | 'coord': coord 203 | }, 204 | 'extra': { 205 | 'refpoint': refpoint.reshape((1, -1)) 206 | } 207 | } 208 | 209 | return sample 210 | 211 | 212 | def transform_coord(coords, refpoints): 213 | keypoints = voxelization_val.warp2continuous_raw(coords, refpoints) 214 | return keypoints 215 | 216 | transform_output = transform_coord 217 | 218 | 219 | class BatchResultCollector(): 220 | def __init__(self, samples_num, transform_output): 221 | self.samples_num = samples_num 222 | self.transform_output = transform_output 223 | self.keypoints = None 224 | self.idx = 0 225 | 226 | def __call__(self, data_batch): 227 | outputs_batch = data_batch['output']['coord'] 228 | refpoints_batch = data_batch['extra']['refpoint'] 229 | 230 | outputs_batch = outputs_batch.cpu().numpy() 231 | refpoints_batch = refpoints_batch.numpy() 232 | 233 | keypoints_batch = self.transform_output(outputs_batch, refpoints_batch) 234 | 235 | if self.keypoints is None: 236 | # Lazily initialize keypoints once the output dimensions are known 237 | self.keypoints = np.zeros((self.samples_num, *keypoints_batch.shape[1:])) 238 | 239 | batch_size = keypoints_batch.shape[0] 240 | self.keypoints[self.idx:self.idx+batch_size] = keypoints_batch 241 | self.idx += batch_size 242 
| 243 | def get_result(self): 244 | return self.keypoints 245 | 246 | 247 | print('Test on test dataset ..') 248 | def save_keypoints(filename, keypoints): 249 | # Reshape one sample keypoints into one line 250 | keypoints = keypoints.reshape(keypoints.shape[0], -1) 251 | np.savetxt(filename, keypoints, fmt='%0.4f') 252 | 253 | 254 | test_set = MARAHandDataset(data_dir, center_dir, 'test', test_subject_id, transform_test) 255 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=6) 256 | test_res_collector = BatchResultCollector(len(test_set), transform_output) 257 | 258 | test_epoch(net, test_loader, test_res_collector, device, dtype) 259 | keypoints_test = test_res_collector.get_result() 260 | save_keypoints('./test_res.txt', keypoints_test) 261 | 262 | 263 | print('Fit on train dataset ..') 264 | fit_set = MARAHandDataset(data_dir, center_dir, 'train', test_subject_id, transform_test) 265 | fit_loader = torch.utils.data.DataLoader(fit_set, batch_size=batch_size, shuffle=False, num_workers=6) 266 | fit_res_collector = BatchResultCollector(len(fit_set), transform_output) 267 | 268 | test_epoch(net, fit_loader, fit_res_collector, device, dtype) 269 | keypoints_fit = fit_res_collector.get_result() 270 | save_keypoints('./fit_res.txt', keypoints_fit) 271 | 272 | print('All done ..') 273 | -------------------------------------------------------------------------------- /integral-pose/model.py: -------------------------------------------------------------------------------- 1 | # from collections import namedtuple 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from v2v_model import V2VModel 6 | import numpy as np 7 | 8 | 9 | class VolumetricSoftmax(nn.Module): 10 | ''' 11 | TODO: soft-argmax: norm coord to [-1, 1], instead of [0, N] 12 | 13 | ref: https://gist.github.com/jeasinema/1cba9b40451236ba2cfb507687e08834 14 | ''' 15 | 16 | def __init__(self, channel, sizes): 17 | super(VolumetricSoftmax, self).__init__() 18 | self.channel = channel 19 | self.xsize, self.ysize, self.zsize = sizes[0], sizes[1], sizes[2] 20 | self.volume_size = self.xsize * self.ysize * self.zsize 21 | 22 | # TODO: optimize, compute x, y, z together 23 | # pos = np.meshgrid(np.arange(self.xsize), np.arange(self.ysize), np.arange(self.zsize), indexing='ij') 24 | # pos = np.array(pos).reshape((3, -1)) 25 | # pos = torch.from_numpy(pos) 26 | # self.register_buffer('pos', pos) 27 | 28 | pos_x, pos_y, pos_z = np.meshgrid(np.arange(self.xsize), np.arange(self.ysize), np.arange(self.zsize), indexing='ij') 29 | 30 | pos_x = torch.from_numpy(pos_x.reshape((-1))).float() 31 | pos_y = torch.from_numpy(pos_y.reshape((-1))).float() 32 | pos_z = torch.from_numpy(pos_z.reshape((-1))).float() 33 | 34 | self.register_buffer('pos_x', pos_x) 35 | self.register_buffer('pos_y', pos_y) 36 | self.register_buffer('pos_z', pos_z) 37 | 38 | def forward(self, x): 39 | # x: (N, C, X, Y, Z) 40 | x = x.view(-1, self.volume_size) 41 | p = F.softmax(x, dim=1) 42 | 43 | #print('self.pos_x: {}, device: {}, dtype: {}'.format(type(self.pos_x), self.pos_x.device, self.pos_x.dtype)) 44 | #print('p: {}, device: {}, dtype: {}'.format(type(p), p.device, p.dtype)) 45 | 46 | expected_x = torch.sum(self.pos_x * p, dim=1, keepdim=True) 47 | expected_y = torch.sum(self.pos_y * p, dim=1, keepdim=True) 48 | expected_z = torch.sum(self.pos_z * p, dim=1, keepdim=True) 49 | 50 | expected_xyz = torch.cat([expected_x, expected_y, expected_z], 1) 51 | out = expected_xyz.view(-1, 
self.channel, 3) 52 | 53 | return out 54 | 55 | 56 | 57 | class Model(nn.Module): 58 | def __init__(self, in_channels, out_channels, output_res=44): 59 | super(Model, self).__init__() 60 | self.output_res = output_res 61 | self.basic_model = V2VModel(in_channels, out_channels) 62 | self.spatial_softmax = VolumetricSoftmax(out_channels, (self.output_res, self.output_res, self.output_res)) 63 | 64 | def forward(self, x): 65 | heatmap = self.basic_model(x) 66 | coord = self.spatial_softmax(heatmap) 67 | 68 | #print('model heatmap: {}'.format(heatmap.dtype)) 69 | 70 | output = { 71 | 'heatmap': heatmap, 72 | 'coord': coord 73 | } 74 | 75 | return output 76 | -------------------------------------------------------------------------------- /integral-pose/msra_hand.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import sys 4 | import struct 5 | from torch.utils.data import Dataset 6 | 7 | 8 | def pixel2world(x, y, z, img_width, img_height, fx, fy): 9 | w_x = (x - img_width / 2) * z / fx 10 | w_y = (img_height / 2 - y) * z / fy 11 | w_z = z 12 | return w_x, w_y, w_z 13 | 14 | 15 | def world2pixel(x, y, z, img_width, img_height, fx, fy): 16 | p_x = x * fx / z + img_width / 2 17 | p_y = img_height / 2 - y * fy / z 18 | return p_x, p_y 19 | 20 | 21 | def depthmap2points(image, fx, fy): 22 | h, w = image.shape 23 | x, y = np.meshgrid(np.arange(w) + 1, np.arange(h) + 1) 24 | points = np.zeros((h, w, 3), dtype=np.float32) 25 | points[:,:,0], points[:,:,1], points[:,:,2] = pixel2world(x, y, image, w, h, fx, fy) 26 | return points 27 | 28 | 29 | def points2pixels(points, img_width, img_height, fx, fy): 30 | pixels = np.zeros((points.shape[0], 2)) 31 | pixels[:, 0], pixels[:, 1] = \ 32 | world2pixel(points[:,0], points[:, 1], points[:, 2], img_width, img_height, fx, fy) 33 | return pixels 34 | 35 | 36 | def load_depthmap(filename, img_width, img_height, max_depth): 37 | with open(filename, mode='rb') as f: 38 | data = f.read() 39 | _, _, left, top, right, bottom = struct.unpack('I'*6, data[:6*4]) 40 | num_pixel = (right - left) * (bottom - top) 41 | cropped_image = struct.unpack('f'*num_pixel, data[6*4:]) 42 | 43 | cropped_image = np.asarray(cropped_image).reshape(bottom-top, -1) 44 | depth_image = np.zeros((img_height, img_width), dtype=np.float32) 45 | depth_image[top:bottom, left:right] = cropped_image 46 | depth_image[depth_image == 0] = max_depth 47 | 48 | return depth_image 49 | 50 | 51 | class MARAHandDataset(Dataset): 52 | def __init__(self, root, center_dir, mode, test_subject_id, transform=None): 53 | self.img_width = 320 54 | self.img_height = 240 55 | self.min_depth = 100 56 | self.max_depth = 700 57 | self.fx = 241.42 58 | self.fy = 241.42 59 | self.joint_num = 21 60 | self.world_dim = 3 61 | self.folder_list = ['1','2','3','4','5','6','7','8','9','I','IP','L','MP','RP','T','TIP','Y'] 62 | self.subject_num = 9 63 | 64 | self.root = root 65 | self.center_dir = center_dir 66 | self.mode = mode 67 | self.test_subject_id = test_subject_id 68 | self.transform = transform 69 | 70 | if not self.mode in ['train', 'test']: raise ValueError('Invalid mode') 71 | assert self.test_subject_id >= 0 and self.test_subject_id < self.subject_num 72 | 73 | if not self._check_exists(): raise RuntimeError('Invalid MSRA hand dataset') 74 | 75 | self._load() 76 | 77 | def __getitem__(self, index): 78 | depthmap = load_depthmap(self.names[index], self.img_width, self.img_height, self.max_depth) 79 | points = depthmap2points(depthmap, 
self.fx, self.fy) 80 | points = points.reshape((-1, 3)) 81 | 82 | sample = { 83 | 'name': self.names[index], 84 | 'points': points, 85 | 'joints': self.joints_world[index], 86 | 'refpoint': self.ref_pts[index] 87 | } 88 | 89 | if self.transform: sample = self.transform(sample) 90 | 91 | return sample 92 | 93 | def __len__(self): 94 | return self.num_samples 95 | 96 | def _load(self): 97 | self._compute_dataset_size() 98 | 99 | self.num_samples = self.train_size if self.mode == 'train' else self.test_size 100 | self.joints_world = np.zeros((self.num_samples, self.joint_num, self.world_dim)) 101 | self.ref_pts = np.zeros((self.num_samples, self.world_dim)) 102 | self.names = [] 103 | 104 | # Collect reference center points strings 105 | if self.mode == 'train': ref_pt_file = 'center_train_' + str(self.test_subject_id) + '_refined.txt' 106 | else: ref_pt_file = 'center_test_' + str(self.test_subject_id) + '_refined.txt' 107 | 108 | with open(os.path.join(self.center_dir, ref_pt_file)) as f: 109 | ref_pt_str = [l.rstrip() for l in f] 110 | 111 | # 112 | file_id = 0 113 | frame_id = 0 114 | 115 | for mid in range(self.subject_num): 116 | if self.mode == 'train': model_chk = (mid != self.test_subject_id) 117 | elif self.mode == 'test': model_chk = (mid == self.test_subject_id) 118 | else: raise RuntimeError('unsupported mode {}'.format(self.mode)) 119 | 120 | if model_chk: 121 | for fd in self.folder_list: 122 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt') 123 | 124 | lines = [] 125 | with open(annot_file) as f: 126 | lines = [line.rstrip() for line in f] 127 | 128 | # skip first line 129 | for i in range(1, len(lines)): 130 | # referece point 131 | splitted = ref_pt_str[file_id].split() 132 | if splitted[0] == 'invalid': 133 | print('Warning: found invalid reference frame') 134 | file_id += 1 135 | continue 136 | else: 137 | self.ref_pts[frame_id, 0] = float(splitted[0]) 138 | self.ref_pts[frame_id, 1] = float(splitted[1]) 139 | self.ref_pts[frame_id, 2] = float(splitted[2]) 140 | 141 | # joint point 142 | splitted = lines[i].split() 143 | for jid in range(self.joint_num): 144 | self.joints_world[frame_id, jid, 0] = float(splitted[jid * self.world_dim]) 145 | self.joints_world[frame_id, jid, 1] = float(splitted[jid * self.world_dim + 1]) 146 | self.joints_world[frame_id, jid, 2] = -float(splitted[jid * self.world_dim + 2]) 147 | 148 | filename = os.path.join(self.root, 'P'+str(mid), fd, '{:0>6d}'.format(i-1) + '_depth.bin') 149 | self.names.append(filename) 150 | 151 | frame_id += 1 152 | file_id += 1 153 | 154 | def _compute_dataset_size(self): 155 | self.train_size, self.test_size = 0, 0 156 | 157 | for mid in range(self.subject_num): 158 | num = 0 159 | for fd in self.folder_list: 160 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt') 161 | with open(annot_file) as f: 162 | num = int(f.readline().rstrip()) 163 | if mid == self.test_subject_id: self.test_size += num 164 | else: self.train_size += num 165 | 166 | def _check_exists(self): 167 | # Check basic data 168 | for mid in range(self.subject_num): 169 | for fd in self.folder_list: 170 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt') 171 | if not os.path.exists(annot_file): 172 | print('Error: annotation file {} does not exist'.format(annot_file)) 173 | return False 174 | 175 | # Check precomputed centers by v2v-hand model's author 176 | for subject_id in range(self.subject_num): 177 | center_train = os.path.join(self.center_dir, 'center_train_' + str(subject_id) + 
'_refined.txt') 178 | center_test = os.path.join(self.center_dir, 'center_test_' + str(subject_id) + '_refined.txt') 179 | if not os.path.exists(center_train) or not os.path.exists(center_test): 180 | print('Error: precomputed center files do not exist') 181 | return False 182 | 183 | return True 184 | -------------------------------------------------------------------------------- /integral-pose/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | 4 | def plot_acc(ax, dist, acc, name): 5 | ''' 6 | acc: (K, num) 7 | dist: (K, ) 8 | name: (K, ) 9 | ''' 10 | assert(acc.shape[0] == len(name)) 11 | 12 | for i in range(len(name)): 13 | ax.plot(dist, acc[i], label=name[i]) 14 | 15 | ax.legend() 16 | 17 | ax.set_xlabel('Maximum allowed distance to GT (mm)') 18 | ax.set_ylabel('Fraction of samples within distance') 19 | 20 | 21 | def plot_mean_err(ax, mean_err, name): 22 | ''' 23 | mean_err: (K, ) 24 | name: (K, ) 25 | ''' 26 | ax.bar(name, mean_err) 27 | -------------------------------------------------------------------------------- /integral-pose/progressbar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | 6 | _, term_width = os.popen('stty size', 'r').read().split() 7 | term_width = int(term_width) 8 | TOTAL_BAR_LENGTH = 65. 9 | last_time = time.time() 10 | begin_time = last_time 11 | 12 | def progress_bar(current, total, msg=None): 13 | global last_time, begin_time 14 | if current == 0: 15 | begin_time = time.time() # Reset for new bar. 16 | 17 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 18 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 19 | 20 | sys.stdout.write(' [') 21 | for i in range(cur_len): 22 | sys.stdout.write('=') 23 | sys.stdout.write('>') 24 | for i in range(rest_len): 25 | sys.stdout.write('.') 26 | sys.stdout.write(']') 27 | 28 | cur_time = time.time() 29 | step_time = cur_time - last_time 30 | last_time = cur_time 31 | tot_time = cur_time - begin_time 32 | 33 | L = [] 34 | L.append(' Step: %s' % format_time(step_time)) 35 | L.append(' | Tot: %s' % format_time(tot_time)) 36 | if msg: 37 | L.append(' | ' + msg) 38 | 39 | msg = ''.join(L) 40 | sys.stdout.write(msg) 41 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 42 | sys.stdout.write(' ') 43 | 44 | # Go back to the center of the bar. 
45 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 46 | sys.stdout.write('\b') 47 | sys.stdout.write(' %d/%d ' % (current+1, total)) 48 | 49 | if current < total-1: 50 | sys.stdout.write('\r') 51 | else: 52 | sys.stdout.write('\n') 53 | sys.stdout.flush() 54 | 55 | 56 | def format_time(seconds): 57 | days = int(seconds / 3600/24) 58 | seconds = seconds - days*3600*24 59 | hours = int(seconds / 3600) 60 | seconds = seconds - hours*3600 61 | minutes = int(seconds / 60) 62 | seconds = seconds - minutes*60 63 | secondsf = int(seconds) 64 | seconds = seconds - secondsf 65 | millis = int(seconds*1000) 66 | 67 | f = '' 68 | i = 1 69 | if days > 0: 70 | f += str(days) + 'D' 71 | i += 1 72 | if hours > 0 and i <= 2: 73 | f += str(hours) + 'h' 74 | i += 1 75 | if minutes > 0 and i <= 2: 76 | f += str(minutes) + 'm' 77 | i += 1 78 | if secondsf > 0 and i <= 2: 79 | f += str(secondsf) + 's' 80 | i += 1 81 | if millis > 0 and i <= 2: 82 | f += str(millis) + 'ms' 83 | i += 1 84 | if f == '': 85 | f = '0ms' 86 | return f 87 | -------------------------------------------------------------------------------- /integral-pose/sampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import sampler 2 | from torch.utils.data import Dataset 3 | 4 | 5 | class ChunkSampler(sampler.Sampler): 6 | """Samples elements sequentially from some offset. 7 | Arguments: 8 | num_samples: # of desired datapoints 9 | start: offset where we should start selecting from 10 | """ 11 | def __init__(self, num_samples, start=0): 12 | self.num_samples = num_samples 13 | self.start = start 14 | 15 | def __iter__(self): 16 | return iter(range(self.start, self.start + self.num_samples)) 17 | 18 | def __len__(self): 19 | return self.num_samples 20 | 21 | 22 | class ChunkDataset(Dataset): 23 | ''' 24 | A wrapper around common datasets 25 | ''' 26 | def __init__(self, data_set, num_samples, start=0): 27 | self.data_set = data_set 28 | self.num_samples = num_samples 29 | self.start = start 30 | assert(self.start + self.num_samples <= len(data_set)) 31 | 32 | def __getitem__(self, index): 33 | return self.data_set[index] 34 | 35 | def __len__(self): 36 | return self.num_samples 37 | -------------------------------------------------------------------------------- /integral-pose/show_acc.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from accuracy import * 5 | from plot import * 6 | 7 | 8 | gt_file = r'./test_s3_gt.txt' 9 | pred_file = r'./test_res.txt' 10 | 11 | 12 | gt = np.loadtxt(gt_file) 13 | gt = gt.reshape(gt.shape[0], -1, 3) 14 | 15 | pred = np.loadtxt(pred_file) 16 | pred = pred.reshape(pred.shape[0], -1, 3) 17 | 18 | print('gt: ', gt.shape) 19 | print('pred: ', pred.shape) 20 | 21 | 22 | keypoints_num = 21 23 | names = ['joint'+str(i+1) for i in range(keypoints_num)] 24 | 25 | 26 | dist, acc = compute_dist_acc_wrapper(pred, gt, max_dist=100, num=100) 27 | 28 | fig, ax = plt.subplots() 29 | plot_acc(ax, dist, acc, names) 30 | fig.savefig('msra_s3_joint_acc.png') 31 | plt.show() 32 | 33 | 34 | mean_err = compute_mean_err(pred, gt) 35 | fig, ax = plt.subplots() 36 | plot_mean_err(ax, mean_err, names) 37 | fig.savefig('msra_s3_joint_mean_error.png') 38 | plt.show() 39 | 40 | 41 | print('mean_err: {}'.format(mean_err)) 42 | mean_err_all = compute_mean_err(pred.reshape((-1, 1, 3)), gt.reshape((-1, 1,3))) 43 | print('mean_err_all: ', mean_err_all) 44 | 
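# Note: gt_file and pred_file above are resolved relative to the current working directory.
# A typical (assumed) workflow: run gen_gt.py to produce ./test_s3_gt.txt and main.py to produce
# ./test_res.txt, then run this script from the same directory; it writes msra_s3_joint_acc.png
# and msra_s3_joint_mean_error.png.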
-------------------------------------------------------------------------------- /integral-pose/solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch.utils.data import DataLoader 4 | import os 5 | from progressbar import progress_bar 6 | 7 | 8 | def train_epoch(model, criterion, optimizer, train_loader, device=torch.device('cuda'), dtype=torch.float, collector=None): 9 | model.train() 10 | train_loss = 0 11 | 12 | for batch_idx, batch_data in enumerate(train_loader): 13 | input, target, extra = batch_data['input'], batch_data['target'], batch_data['extra'] 14 | 15 | input = input.to(device, dtype) 16 | 17 | if isinstance(target, torch.Tensor): 18 | target = target.to(device, dtype) 19 | elif isinstance(target, dict): 20 | for k in target: 21 | if isinstance(target[k], torch.Tensor): 22 | target[k] = target[k].to(device, dtype) 23 | 24 | #print('solver target[heatmap]: ', target['heatmap'].dtype) 25 | 26 | optimizer.zero_grad() 27 | output = model(input) 28 | loss = criterion(output, target) 29 | loss.backward() 30 | optimizer.step() 31 | 32 | if collector is None: 33 | train_loss += loss.item() 34 | progress_bar(batch_idx, len(train_loader), 'Loss: {0:.4e}'.format(train_loss/(batch_idx+1))) 35 | #print('loss: {0: .4e}'.format(train_loss/(batch_idx+1))) 36 | else: 37 | model.eval() 38 | with torch.no_grad(): 39 | extra['batch_idx'], extra['loader_len'], extra['batch_avg_loss'] = batch_idx, len(train_loader), loss.item() 40 | collector({'model': model, 'input': input, 'target': target, 'output': output, 'extra': extra}) 41 | 42 | # Keep train mode 43 | model.train() 44 | 45 | 46 | def val_epoch(model, criterion, val_loader, device=torch.device('cuda'), dtype=torch.float, collector=None): 47 | model.eval() 48 | val_loss = 0 49 | 50 | with torch.no_grad(): 51 | for batch_idx, batch_data in enumerate(val_loader): 52 | input, target, extra = batch_data['input'], batch_data['target'], batch_data['extra'] 53 | 54 | input = input.to(device, dtype) 55 | 56 | if isinstance(target, torch.Tensor): 57 | target = target.to(device, dtype) 58 | elif isinstance(target, dict): 59 | for k in target: 60 | if isinstance(target[k], torch.Tensor): 61 | target[k] = target[k].to(device, dtype) 62 | 63 | output = model(input) 64 | loss = criterion(output, target) 65 | 66 | if collector is None: 67 | val_loss += loss.item() 68 | progress_bar(batch_idx, len(val_loader), 'Loss: {0:.4e}'.format(val_loss/(batch_idx+1))) 69 | #print('loss: {0: .4e}'.format(val_loss/(batch_idx+1))) 70 | else: 71 | extra['batch_idx'], extra['loader_len'], extra['batch_avg_loss'] = batch_idx, len(val_loader), loss.item() 72 | collector({'model': model, 'input': input, 'target': target, 'output': output, 'extra': extra}) 73 | 74 | # Keep eval mode 75 | model.eval() 76 | 77 | 78 | def test_epoch(model, test_loader, collector, device=torch.device('cuda'), dtype=torch.float): 79 | model.eval() 80 | 81 | with torch.no_grad(): 82 | for batch_idx, batch_data in enumerate(test_loader): 83 | input, target, extra = batch_data['input'], batch_data['target'], batch_data['extra'] 84 | output = model(input.to(device, dtype)) 85 | 86 | extra['batch_idx'], extra['loader_len'] = batch_idx, len(test_loader) 87 | collector({'model': model, 'input': input, 'target': target, 'output': output, 'extra': extra}) 88 | 89 | # Keep eval mode 90 | model.eval() 91 | 92 | 93 | class Solver(): 94 | def __init__(self, train_set, model, criterion, optimizer, 
device=torch.device('cuda'), dtype=torch.float, **kwargs): 95 | self.train_set = train_set 96 | self.model = model 97 | self.criterion = criterion 98 | self.optimizer = optimizer 99 | self.device = device 100 | self.dtype = dtype 101 | 102 | self.batch_size = kwargs.pop('batch_size', 1) 103 | self.num_epochs = kwargs.pop('num_epochs', 1) 104 | self.num_workers = kwargs.pop('num_workers', 6) 105 | 106 | self.val_set = kwargs.pop('val_set', None) 107 | 108 | # Result collectors 109 | self.train_collector = kwargs.pop('train_collector', None) 110 | self.val_collector = kwargs.pop('val_collector', None) 111 | 112 | # Save check point and resume 113 | self.checkpoint_config = kwargs.pop('checkpoint_config', None) 114 | if self.checkpoint_config is not None: 115 | self.save_checkpoint = self.checkpoint_config['save_checkpoint'] 116 | self.checkpoint_dir = self.checkpoint_config['checkpoint_dir'] 117 | self.checkpoint_per_epochs = self.checkpoint_config['checkpoint_per_epochs'] 118 | 119 | self.resume_training = self.checkpoint_config['resume_training'] 120 | self.resume_after_epoch = self.checkpoint_config['resume_after_epoch'] 121 | else: 122 | self.save_checkpoint = False 123 | self.resume_training = False 124 | 125 | self._init() 126 | 127 | def _init(self): 128 | self.train_loader = DataLoader(self.train_set, self.batch_size, shuffle=True, num_workers=self.num_workers) 129 | 130 | if self.val_set is not None: 131 | self.val_loader = DataLoader(self.val_set, self.batch_size, shuffle=False, num_workers=self.num_workers) 132 | 133 | self.start_epoch = 0 134 | 135 | if self.resume_training: 136 | self._load_checkpoint(self.resume_after_epoch) 137 | 138 | def _train_epoch(self): 139 | train_epoch(self.model, self.criterion, self.optimizer, self.train_loader, 140 | self.device, self.dtype, 141 | self.train_collector) 142 | 143 | def _val_epoch(self): 144 | val_epoch(self.model, self.criterion, self.val_loader, 145 | self.device, self.dtype, 146 | self.val_collector) 147 | 148 | def _save_checkpoint(self, epoch): 149 | if not os.path.exists(self.checkpoint_dir): os.mkdir(self.checkpoint_dir) 150 | checkpoint_file = os.path.join(self.checkpoint_dir, 'epoch'+str(epoch)+'.pth') 151 | 152 | checkpoint = { 153 | 'model_state_dict': self.model.state_dict(), 154 | 'optimizer_state_dict': self.optimizer.state_dict(), 155 | 'epoch': epoch 156 | } 157 | 158 | torch.save(checkpoint, checkpoint_file) 159 | 160 | def _load_checkpoint(self, epoch): 161 | checkpoint_file = os.path.join(self.checkpoint_dir, 'epoch'+str(epoch)+'.pth') 162 | 163 | print('==> Resuming from checkpoint after epoch {} ..'.format(epoch)) 164 | assert os.path.isdir(self.checkpoint_dir), 'Error: no checkpoint directory found!' 
165 | assert os.path.isfile(checkpoint_file), 'Error: no checkpoint file of epoch {}'.format(epoch) 166 | 167 | checkpoint = torch.load(os.path.join(self.checkpoint_dir, 'epoch'+str(epoch)+'.pth')) 168 | self.model.load_state_dict(checkpoint['model_state_dict']) 169 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 170 | self.start_epoch = checkpoint['epoch'] + 1 171 | 172 | def train(self): 173 | for epoch in range(self.start_epoch, self.start_epoch + self.num_epochs): 174 | print('Epoch {}: '.format(epoch)) 175 | self._train_epoch() 176 | 177 | if self.val_set is not None: 178 | self._val_epoch() 179 | 180 | if self.save_checkpoint and epoch % self.checkpoint_per_epochs == 0: 181 | self._save_checkpoint(epoch) 182 | -------------------------------------------------------------------------------- /integral-pose/v2v_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class Basic3DBlock(nn.Module): 6 | def __init__(self, in_planes, out_planes, kernel_size): 7 | super(Basic3DBlock, self).__init__() 8 | self.block = nn.Sequential( 9 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1, padding=((kernel_size-1)//2)), 10 | nn.BatchNorm3d(out_planes), 11 | nn.ReLU(True) 12 | ) 13 | 14 | def forward(self, x): 15 | return self.block(x) 16 | 17 | 18 | class Res3DBlock(nn.Module): 19 | def __init__(self, in_planes, out_planes): 20 | super(Res3DBlock, self).__init__() 21 | self.res_branch = nn.Sequential( 22 | nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=1, padding=1), 23 | nn.BatchNorm3d(out_planes), 24 | nn.ReLU(True), 25 | nn.Conv3d(out_planes, out_planes, kernel_size=3, stride=1, padding=1), 26 | nn.BatchNorm3d(out_planes) 27 | ) 28 | 29 | if in_planes == out_planes: 30 | self.skip_con = nn.Sequential() 31 | else: 32 | self.skip_con = nn.Sequential( 33 | nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=1, padding=0), 34 | nn.BatchNorm3d(out_planes) 35 | ) 36 | 37 | def forward(self, x): 38 | res = self.res_branch(x) 39 | skip = self.skip_con(x) 40 | return F.relu(res + skip, True) 41 | 42 | 43 | class Pool3DBlock(nn.Module): 44 | def __init__(self, pool_size): 45 | super(Pool3DBlock, self).__init__() 46 | self.pool_size = pool_size 47 | 48 | def forward(self, x): 49 | return F.max_pool3d(x, kernel_size=self.pool_size, stride=self.pool_size) 50 | 51 | 52 | class Upsample3DBlock(nn.Module): 53 | def __init__(self, in_planes, out_planes, kernel_size, stride): 54 | super(Upsample3DBlock, self).__init__() 55 | assert(kernel_size == 2) 56 | assert(stride == 2) 57 | self.block = nn.Sequential( 58 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0, output_padding=0), 59 | nn.BatchNorm3d(out_planes), 60 | nn.ReLU(True) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.block(x) 65 | 66 | 67 | class EncoderDecorder(nn.Module): 68 | def __init__(self): 69 | super(EncoderDecorder, self).__init__() 70 | 71 | self.encoder_pool1 = Pool3DBlock(2) 72 | self.encoder_res1 = Res3DBlock(32, 64) 73 | self.encoder_pool2 = Pool3DBlock(2) 74 | self.encoder_res2 = Res3DBlock(64, 128) 75 | 76 | self.mid_res = Res3DBlock(128, 128) 77 | 78 | self.decoder_res2 = Res3DBlock(128, 128) 79 | self.decoder_upsample2 = Upsample3DBlock(128, 64, 2, 2) 80 | self.decoder_res1 = Res3DBlock(64, 64) 81 | self.decoder_upsample1 = Upsample3DBlock(64, 32, 2, 2) 82 | 83 | self.skip_res1 = Res3DBlock(32, 32) 84 | self.skip_res2 = 
Res3DBlock(64, 64) 85 | 86 | def forward(self, x): 87 | skip_x1 = self.skip_res1(x) 88 | x = self.encoder_pool1(x) 89 | x = self.encoder_res1(x) 90 | skip_x2 = self.skip_res2(x) 91 | x = self.encoder_pool2(x) 92 | x = self.encoder_res2(x) 93 | 94 | x = self.mid_res(x) 95 | 96 | x = self.decoder_res2(x) 97 | x = self.decoder_upsample2(x) 98 | x = x + skip_x2 99 | x = self.decoder_res1(x) 100 | x = self.decoder_upsample1(x) 101 | x = x + skip_x1 102 | 103 | return x 104 | 105 | 106 | class V2VModel(nn.Module): 107 | def __init__(self, input_channels, output_channels): 108 | super(V2VModel, self).__init__() 109 | 110 | self.front_layers = nn.Sequential( 111 | Basic3DBlock(input_channels, 16, 7), 112 | Pool3DBlock(2), 113 | Res3DBlock(16, 32), 114 | Res3DBlock(32, 32), 115 | Res3DBlock(32, 32) 116 | ) 117 | 118 | self.encoder_decoder = EncoderDecorder() 119 | 120 | self.back_layers = nn.Sequential( 121 | Res3DBlock(32, 32), 122 | Basic3DBlock(32, 32, 1), 123 | Basic3DBlock(32, 32, 1), 124 | ) 125 | 126 | self.output_layer = nn.Conv3d(32, output_channels, kernel_size=1, stride=1, padding=0) 127 | 128 | self._initialize_weights() 129 | 130 | def forward(self, x): 131 | x = self.front_layers(x) 132 | x = self.encoder_decoder(x) 133 | x = self.back_layers(x) 134 | x = self.output_layer(x) 135 | return x 136 | 137 | def _initialize_weights(self): 138 | for m in self.modules(): 139 | if isinstance(m, nn.Conv3d): 140 | nn.init.normal_(m.weight, 0, 0.001) 141 | nn.init.constant_(m.bias, 0) 142 | elif isinstance(m, nn.ConvTranspose3d): 143 | nn.init.normal_(m.weight, 0, 0.001) 144 | nn.init.constant_(m.bias, 0) 145 | -------------------------------------------------------------------------------- /integral-pose/v2v_util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | 4 | 5 | def discretize(coord, cropped_size): 6 | '''[-1, 1] -> [0, cropped_size]''' 7 | min_normalized = -1 8 | max_normalized = 1 9 | scale = (max_normalized - min_normalized) / cropped_size 10 | return (coord - min_normalized) / scale 11 | 12 | 13 | def warp2continuous(coord, refpoint, cubic_size, cropped_size): 14 | ''' 15 | Map coordinates in set [0, 1, .., cropped_size-1] to original range [-cubic_size/2+refpoint, cubic_size/2 + refpoint] 16 | ''' 17 | min_normalized = -1 18 | max_normalized = 1 19 | 20 | scale = (max_normalized - min_normalized) / cropped_size 21 | coord = coord * scale + min_normalized # -> [-1, 1] 22 | 23 | coord = coord * cubic_size / 2 + refpoint 24 | 25 | return coord 26 | 27 | 28 | def scattering(coord, cropped_size): 29 | # coord: [0, cropped_size] 30 | # Assign range[0, 1) -> 0, [1, 2) -> 1, .. [cropped_size-1, cropped_size) -> cropped_size-1 31 | # That is, around center 0.5 -> 0, around center 1.5 -> 1 .. 
around center cropped_size-0.5 -> cropped_size-1 32 | coord = coord.astype(np.int32) 33 | 34 | mask = (coord[:, 0] >= 0) & (coord[:, 0] < cropped_size) & \ 35 | (coord[:, 1] >= 0) & (coord[:, 1] < cropped_size) & \ 36 | (coord[:, 2] >= 0) & (coord[:, 2] < cropped_size) 37 | 38 | coord = coord[mask, :] 39 | 40 | cubic = np.zeros((cropped_size, cropped_size, cropped_size)) 41 | 42 | # Note, directly map point coordinate (x, y, z) to index (i, j, k), instead of (k, j, i) 43 | # Need to be consistent with heatmap generation and coordinate extraction from the heatmap 44 | cubic[coord[:, 0], coord[:, 1], coord[:, 2]] = 1 45 | 46 | return cubic 47 | 48 | 49 | def extract_coord_from_output(output, center=True): 50 | ''' 51 | output: shape (batch, jointNum, volumeSize, volumeSize, volumeSize) 52 | center: if True, add 0.5, default is True 53 | return: shape (batch, jointNum, 3) 54 | ''' 55 | assert(len(output.shape) >= 3) 56 | vsize = output.shape[-3:] 57 | 58 | output_rs = output.reshape(-1, np.prod(vsize)) 59 | max_index = np.unravel_index(np.argmax(output_rs, axis=1), vsize) 60 | max_index = np.array(max_index).T 61 | 62 | xyz_output = max_index.reshape([*output.shape[:-3], 3]) 63 | 64 | # Note, a discrete coord represents the real range [coord, coord+1), see function scattering() 65 | # So, move coord to the range center for a better fit 66 | if center: xyz_output = xyz_output + 0.5 67 | 68 | return xyz_output 69 | 70 | 71 | def generate_coord(points, refpoint, new_size, angle, trans, sizes): 72 | cubic_size, cropped_size, original_size = sizes 73 | 74 | # points shape: (n, 3) 75 | coord = points 76 | 77 | # note, will consider points within range [refpoint-cubic_size/2, refpoint+cubic_size/2] as candidates 78 | 79 | # normalize 80 | coord = (coord - refpoint) / (cubic_size/2) # -> [-1, 1] 81 | 82 | # discretize 83 | coord = discretize(coord, cropped_size) # -> [0, cropped_size] 84 | 85 | # move cropped center to (virtual larger, [0, original_size]) original volume center 86 | # that is, treat the current data as cropped from the center of the original volume, and now put it back 87 | coord += (original_size / 2 - cropped_size / 2) 88 | 89 | # resize around original center with scale new_size/100 90 | resize_scale = 100 / new_size 91 | if new_size < 100: 92 | coord = coord * resize_scale + original_size/2 * (1 - resize_scale) 93 | elif new_size > 100: 94 | coord = coord * resize_scale - original_size/2 * (resize_scale - 1) 95 | else: 96 | # new_size = 100 if it is in test mode 97 | pass 98 | 99 | # rotation 100 | if angle != 0: 101 | original_coord = coord.copy() 102 | original_coord[:,0] -= original_size/2 103 | original_coord[:,1] -= original_size/2 104 | coord[:,0] = original_coord[:,0]*np.cos(angle) - original_coord[:,1]*np.sin(angle) 105 | coord[:,1] = original_coord[:,0]*np.sin(angle) + original_coord[:,1]*np.cos(angle) 106 | coord[:,0] += original_size/2 107 | coord[:,1] += original_size/2 108 | 109 | # translation 110 | # Note, if trans = (original_size/2 - cropped_size/2 + 1), the following translation will 111 | # cancel the above translation (after discretization). It will be set like this in test mode. 112 | # TODO: Can only achieve translation [-4, 4]? 
113 | coord -= trans 114 | 115 | return coord 116 | 117 | 118 | def generate_cubic_input(points, refpoint, new_size, angle, trans, sizes): 119 | _, cropped_size, _ = sizes 120 | coord = generate_coord(points, refpoint, new_size, angle, trans, sizes) 121 | 122 | # scattering 123 | cubic = scattering(coord, cropped_size) 124 | 125 | return cubic 126 | 127 | 128 | # def generate_heatmap_gt(keypoints, refpoint, new_size, angle, trans, sizes, d3outputs, pool_factor, std): 129 | # _, cropped_size, _ = sizes 130 | # d3output_x, d3output_y, d3output_z = d3outputs 131 | 132 | # coord = generate_coord(keypoints, refpoint, new_size, angle, trans, sizes) # [0, cropped_size] 133 | # coord /= pool_factor # [0, cropped_size/pool_factor] 134 | 135 | # # heatmap generation 136 | # output_size = int(cropped_size / pool_factor) 137 | # heatmap = np.zeros((keypoints.shape[0], output_size, output_size, output_size)) 138 | 139 | # # use center of cell 140 | # center_offset = 0.5 141 | 142 | # for i in range(coord.shape[0]): 143 | # xi, yi, zi= coord[i] 144 | # heatmap[i] = np.exp(-(np.power((d3output_x+center_offset-xi)/std, 2)/2 + \ 145 | # np.power((d3output_y+center_offset-yi)/std, 2)/2 + \ 146 | # np.power((d3output_z+center_offset-zi)/std, 2)/2)) # +0.5, move coordinate to range center 147 | 148 | # return heatmap 149 | 150 | 151 | #--- 152 | def containing_box_coord(point): 153 | ''' 154 | point: (K, 3) 155 | return: (K, 8, 3), eight box vertices coords 156 | ''' 157 | # if np.any(point >= 43): 158 | # res = point[point >= 43] 159 | # print('invalid point: {}'.format(res)) 160 | # exit() 161 | 162 | 163 | box_grid = np.meshgrid([0, 1], [0, 1], [0, 1], indexing='ij') 164 | box_grid = np.array(box_grid).reshape((3, 8)).transpose() 165 | 166 | floor = np.floor(point) 167 | box_coord = floor.reshape((-1, 1, 3)) + box_grid 168 | 169 | return box_coord 170 | 171 | 172 | def box_coord_prob(point, box_coord): 173 | ''' 174 | point: (K, 3) 175 | box_coord: (K, 8, 3) 176 | return: (K, 8) 177 | ''' 178 | diff = box_coord - point.reshape((-1, 1, 3)) 179 | weight = np.maximum(0, 1 - np.abs(diff)) 180 | prob = weight[:,:,0] * weight[:,:,1] * weight[:,:,2] 181 | norm = np.sum(prob, axis=1, keepdims=True) 182 | norm[norm <= 0] = 1.0 # avoid zero 183 | prob = prob / norm 184 | 185 | return prob 186 | 187 | 188 | def onehot_heatmap_impl(coord, output_size): 189 | ''' 190 | coord: (K, 3) 191 | return: (output_size, output_size, output_size) 192 | ''' 193 | coord = np.array(coord) 194 | 195 | box_coord = containing_box_coord(coord) # (K, 8, 3) 196 | box_prob = box_coord_prob(coord, box_coord) # (K, 8) 197 | box_coord = box_coord.astype(np.int32) 198 | 199 | # Generate K heatmaps 200 | heatmap = np.zeros((coord.shape[0], output_size, output_size, output_size)) 201 | for i in range(coord.shape[0]): 202 | heatmap[i][box_coord[i,:,0], box_coord[i,:,1], box_coord[i,:,2]] = box_prob[i] 203 | 204 | return heatmap 205 | 206 | 207 | def generate_onehot_heatmap_gt(keypoints, refpoint, new_size, angle, trans, sizes, d3outputs, pool_factor, std): 208 | _, cropped_size, _ = sizes 209 | d3output_x, d3output_y, d3output_z = d3outputs 210 | 211 | coord = generate_coord(keypoints, refpoint, new_size, angle, trans, sizes) # [0, cropped_size] 212 | coord /= pool_factor # [0, cropped_size/pool_factor] 213 | 214 | # heatmap generation 215 | output_size = int(cropped_size / pool_factor) 216 | 217 | # Warning, clip joints into [0, 42], make sure the containing box coord indices will not exceed 43(44-1) 218 | # TODO: check here 219 | 
target_output_size = 44 220 | coord[coord >= target_output_size-2] = target_output_size - 2 221 | coord[coord < 0] = 0 222 | 223 | return onehot_heatmap_impl(coord, output_size) 224 | 225 | 226 | def generate_volume_coord_gt(keypoints, refpoint, new_size, angle, trans, sizes, pool_factor): 227 | coord = generate_coord(keypoints, refpoint, new_size, angle, trans, sizes) # [0, cropped_size] 228 | coord /= pool_factor 229 | 230 | # Warning, clip joints into [0, 42], make sure the containing box coord indices will not exceed 43(44-1) 231 | # TODO: check here 232 | target_output_size = 44 233 | coord[coord >= target_output_size-2] = target_output_size - 2 234 | coord[coord < 0] = 0 235 | 236 | return coord 237 | 238 | 239 | class V2VVoxelization(object): 240 | def __init__(self, cubic_size, augmentation=True): 241 | self.cubic_size = cubic_size 242 | self.cropped_size, self.original_size = 88, 96 243 | self.sizes = (self.cubic_size, self.cropped_size, self.original_size) 244 | self.pool_factor = 2 245 | self.std = 1.7 246 | self.augmentation = augmentation 247 | 248 | output_size = int(self.cropped_size / self.pool_factor) 249 | # Note, range(size) and indexing = 'ij' 250 | self.d3outputs = np.meshgrid(np.arange(output_size), np.arange(output_size), np.arange(output_size), indexing='ij') 251 | 252 | def __call__(self, sample): 253 | points, keypoints, refpoint = sample['points'], sample['keypoints'], sample['refpoint'] 254 | 255 | ## Augmentations 256 | # Resize 257 | new_size = np.random.rand() * 40 + 80 258 | 259 | # Rotation 260 | angle = np.random.rand() * 80/180*np.pi - 40/180*np.pi 261 | 262 | # Translation 263 | trans = np.random.rand(3) * (self.original_size - self.cropped_size) 264 | 265 | if not self.augmentation: 266 | new_size = 100 267 | angle = 0 268 | trans = self.original_size/2 - self.cropped_size/2 269 | 270 | # Add noise and random selection 271 | add_noise = False 272 | random_selection = False 273 | 274 | if self.augmentation and add_noise: 275 | # noise, [-0.5, 0.5] 276 | scale = 0.5 277 | noise = (np.random.rand(*points.shape) * 2 - 1) * scale 278 | points += noise 279 | 280 | if self.augmentation and random_selection: 281 | threshold = np.random.rand(1)[0] * 0.5 # <= 0.5 282 | prob = np.random.rand(points.shape[0]) 283 | mask = prob > threshold 284 | points = points[mask, :] 285 | 286 | input = generate_cubic_input(points, refpoint, new_size, angle, trans, self.sizes) 287 | # heatmap = generate_heatmap_gt(keypoints, refpoint, new_size, angle, trans, self.sizes, self.d3outputs, self.pool_factor, self.std) 288 | # Use One-hot heatmap 289 | heatmap = generate_onehot_heatmap_gt(keypoints, refpoint, new_size, angle, trans, self.sizes, self.d3outputs, self.pool_factor, self.std) 290 | keypoints_volume_coords = generate_volume_coord_gt(keypoints, refpoint, new_size, angle, trans, self.sizes, self.pool_factor) 291 | 292 | # one channel 293 | input = input.reshape((1, *input.shape)) 294 | 295 | return input, heatmap, keypoints_volume_coords 296 | 297 | # def voxelize(self, points, refpoint): 298 | # new_size, angle, trans = 100, 0, self.original_size/2 - self.cropped_size/2 299 | # input = generate_cubic_input(points, refpoint, new_size, angle, trans, self.sizes) 300 | # return input.reshape((1, *input.shape)) 301 | 302 | # def generate_heatmap(self, keypoints, refpoint): 303 | # new_size, angle, trans = 100, 0, self.original_size/2 - self.cropped_size/2 304 | # heatmap = generate_heatmap_gt(keypoints, refpoint, new_size, angle, trans, self.sizes, self.d3outputs, 
self.pool_factor, self.std) 305 | # return heatmap 306 | 307 | # def evaluate(self, heatmaps, refpoints): 308 | # coords = extract_coord_from_output(heatmaps, center=True) 309 | # coords *= self.pool_factor 310 | # keypoints = warp2continuous(coords, refpoints, self.cubic_size, self.cropped_size) 311 | # return keypoints 312 | 313 | # def warp2continuous(self, coords, refpoints): 314 | # print('Warning: added 0.5 on input coord') 315 | # coords += 0.5 # move to grid cell center 316 | # coords *= self.pool_factor 317 | # keypoints = warp2continuous(coords, refpoints, self.cubic_size, self.cropped_size) 318 | # return keypoints 319 | 320 | # def generate_coord_raw(self, points, refpoint): 321 | # new_size, angle, trans = 100, 0, self.original_size/2 - self.cropped_size/2 322 | # coord = generate_coord(points, refpoint, new_size, angle, trans, self.sizes) 323 | # return coord 324 | 325 | def warp2continuous_raw(self, coords, refpoints): 326 | # Do not add 0.5, since coords have float precison 327 | coords = coords * self.pool_factor 328 | keypoints = warp2continuous(coords, refpoints, self.cubic_size, self.cropped_size) 329 | return keypoints 330 | -------------------------------------------------------------------------------- /lib/accuracy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_dist_acc_wrapper(pred, gt, max_dist=10, num=100): 5 | ''' 6 | pred: (N, K, 3) 7 | gt: (N, K, 3) 8 | 9 | return dist: (K, ) 10 | return acc: (K, num) 11 | ''' 12 | assert(pred.shape == gt.shape) 13 | assert(len(pred.shape) == 3) 14 | 15 | dist = np.linspace(0, max_dist, num) 16 | return dist, compute_dist_acc(pred, gt, dist) 17 | 18 | 19 | def compute_dist_acc(pred, gt, dist): 20 | ''' 21 | pred: (N, K, 3) 22 | gt: (N, K, 3) 23 | dist: (M, ) 24 | 25 | return acc: (K, M) 26 | ''' 27 | assert(pred.shape == gt.shape) 28 | assert(len(pred.shape) == 3) 29 | 30 | N, K = pred.shape[0], pred.shape[1] 31 | err_dist = np.sqrt(np.sum((pred - gt)**2, axis=2)) # (N, K) 32 | 33 | acc = np.zeros((K, dist.shape[0])) 34 | 35 | for i, d in enumerate(dist): 36 | acc_d = (err_dist < d).sum(axis=0) / N 37 | acc[:,i] = acc_d 38 | 39 | return acc 40 | 41 | 42 | def compute_mean_err(pred, gt): 43 | ''' 44 | pred: (N, K, 3) 45 | gt: (N, K, 3) 46 | 47 | mean_err: (K,) 48 | ''' 49 | N, K = pred.shape[0], pred.shape[1] 50 | err_dist = np.sqrt(np.sum((pred - gt)**2, axis=2)) # (N, K) 51 | return np.mean(err_dist, axis=0) 52 | -------------------------------------------------------------------------------- /lib/progressbar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | 6 | _, term_width = os.popen('stty size', 'r').read().split() 7 | term_width = int(term_width) 8 | TOTAL_BAR_LENGTH = 65. 9 | last_time = time.time() 10 | begin_time = last_time 11 | 12 | def progress_bar(current, total, msg=None): 13 | global last_time, begin_time 14 | if current == 0: 15 | begin_time = time.time() # Reset for new bar. 
16 | 17 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 18 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 19 | 20 | sys.stdout.write(' [') 21 | for i in range(cur_len): 22 | sys.stdout.write('=') 23 | sys.stdout.write('>') 24 | for i in range(rest_len): 25 | sys.stdout.write('.') 26 | sys.stdout.write(']') 27 | 28 | cur_time = time.time() 29 | step_time = cur_time - last_time 30 | last_time = cur_time 31 | tot_time = cur_time - begin_time 32 | 33 | L = [] 34 | L.append(' Step: %s' % format_time(step_time)) 35 | L.append(' | Tot: %s' % format_time(tot_time)) 36 | if msg: 37 | L.append(' | ' + msg) 38 | 39 | msg = ''.join(L) 40 | sys.stdout.write(msg) 41 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 42 | sys.stdout.write(' ') 43 | 44 | # Go back to the center of the bar. 45 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 46 | sys.stdout.write('\b') 47 | sys.stdout.write(' %d/%d ' % (current+1, total)) 48 | 49 | if current < total-1: 50 | sys.stdout.write('\r') 51 | else: 52 | sys.stdout.write('\n') 53 | sys.stdout.flush() 54 | 55 | 56 | def format_time(seconds): 57 | days = int(seconds / 3600/24) 58 | seconds = seconds - days*3600*24 59 | hours = int(seconds / 3600) 60 | seconds = seconds - hours*3600 61 | minutes = int(seconds / 60) 62 | seconds = seconds - minutes*60 63 | secondsf = int(seconds) 64 | seconds = seconds - secondsf 65 | millis = int(seconds*1000) 66 | 67 | f = '' 68 | i = 1 69 | if days > 0: 70 | f += str(days) + 'D' 71 | i += 1 72 | if hours > 0 and i <= 2: 73 | f += str(hours) + 'h' 74 | i += 1 75 | if minutes > 0 and i <= 2: 76 | f += str(minutes) + 'm' 77 | i += 1 78 | if secondsf > 0 and i <= 2: 79 | f += str(secondsf) + 's' 80 | i += 1 81 | if millis > 0 and i <= 2: 82 | f += str(millis) + 'ms' 83 | i += 1 84 | if f == '': 85 | f = '0ms' 86 | return f 87 | -------------------------------------------------------------------------------- /lib/sampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import sampler 2 | 3 | 4 | class ChunkSampler(sampler.Sampler): 5 | """Samples elements sequentially from some offset. 
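For example, ChunkSampler(num_samples=24, start=8) yields the indices 8, 9, ..., 31 in order.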
6 | Arguments: 7 | num_samples: # of desired datapoints 8 | start: offset where we should start selecting from 9 | """ 10 | def __init__(self, num_samples, start=0): 11 | self.num_samples = num_samples 12 | self.start = start 13 | 14 | def __iter__(self): 15 | return iter(range(self.start, self.start + self.num_samples)) 16 | 17 | def __len__(self): 18 | return self.num_samples 19 | -------------------------------------------------------------------------------- /lib/solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from lib.progressbar import progress_bar 4 | 5 | 6 | def train_epoch(model, criterion, optimizer, train_loader, device=torch.device('cuda'), dtype=torch.float): 7 | model.train() 8 | train_loss = 0 9 | 10 | for batch_idx, (inputs, targets) in enumerate(train_loader): 11 | inputs, targets = inputs.to(device, dtype), targets.to(device, dtype) 12 | optimizer.zero_grad() 13 | outputs = model(inputs) 14 | loss = criterion(outputs, targets) 15 | loss.backward() 16 | optimizer.step() 17 | 18 | train_loss += loss.item() 19 | progress_bar(batch_idx, len(train_loader), 'Loss: {0:.4e}'.format(train_loss/(batch_idx+1))) 20 | #print('loss: {0: .4e}'.format(train_loss/(batch_idx+1))) 21 | 22 | 23 | def val_epoch(model, criterion, val_loader, device=torch.device('cuda'), dtype=torch.float): 24 | model.eval() 25 | val_loss = 0 26 | 27 | with torch.no_grad(): 28 | for batch_idx, (inputs, targets) in enumerate(val_loader): 29 | inputs, targets = inputs.to(device, dtype), targets.to(device, dtype) 30 | outputs = model(inputs) 31 | loss = criterion(outputs, targets) 32 | 33 | val_loss += loss.item() 34 | progress_bar(batch_idx, len(val_loader), 'Loss: {0:.4e}'.format(val_loss/(batch_idx+1))) 35 | #print('loss: {0: .4e}'.format(val_loss/(batch_idx+1))) 36 | 37 | 38 | def test_epoch(model, test_loader, result_collector, device=torch.device('cuda'), dtype=torch.float): 39 | model.eval() 40 | 41 | with torch.no_grad(): 42 | for batch_idx, (inputs, extra) in enumerate(test_loader): 43 | outputs = model(inputs.to(device, dtype)) 44 | result_collector((inputs, outputs, extra)) 45 | 46 | progress_bar(batch_idx, len(test_loader)) 47 | -------------------------------------------------------------------------------- /src/v2v_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class Basic3DBlock(nn.Module): 6 | def __init__(self, in_planes, out_planes, kernel_size): 7 | super(Basic3DBlock, self).__init__() 8 | self.block = nn.Sequential( 9 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1, padding=((kernel_size-1)//2)), 10 | nn.BatchNorm3d(out_planes), 11 | nn.ReLU(True) 12 | ) 13 | 14 | def forward(self, x): 15 | return self.block(x) 16 | 17 | 18 | class Res3DBlock(nn.Module): 19 | def __init__(self, in_planes, out_planes): 20 | super(Res3DBlock, self).__init__() 21 | self.res_branch = nn.Sequential( 22 | nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=1, padding=1), 23 | nn.BatchNorm3d(out_planes), 24 | nn.ReLU(True), 25 | nn.Conv3d(out_planes, out_planes, kernel_size=3, stride=1, padding=1), 26 | nn.BatchNorm3d(out_planes) 27 | ) 28 | 29 | if in_planes == out_planes: 30 | self.skip_con = nn.Sequential() 31 | else: 32 | self.skip_con = nn.Sequential( 33 | nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=1, padding=0), 34 | nn.BatchNorm3d(out_planes) 35 | ) 36 | 37 | def 
forward(self, x): 38 | res = self.res_branch(x) 39 | skip = self.skip_con(x) 40 | return F.relu(res + skip, True) 41 | 42 | 43 | class Pool3DBlock(nn.Module): 44 | def __init__(self, pool_size): 45 | super(Pool3DBlock, self).__init__() 46 | self.pool_size = pool_size 47 | 48 | def forward(self, x): 49 | return F.max_pool3d(x, kernel_size=self.pool_size, stride=self.pool_size) 50 | 51 | 52 | class Upsample3DBlock(nn.Module): 53 | def __init__(self, in_planes, out_planes, kernel_size, stride): 54 | super(Upsample3DBlock, self).__init__() 55 | assert(kernel_size == 2) 56 | assert(stride == 2) 57 | self.block = nn.Sequential( 58 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0, output_padding=0), 59 | nn.BatchNorm3d(out_planes), 60 | nn.ReLU(True) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.block(x) 65 | 66 | 67 | class EncoderDecorder(nn.Module): 68 | def __init__(self): 69 | super(EncoderDecorder, self).__init__() 70 | 71 | self.encoder_pool1 = Pool3DBlock(2) 72 | self.encoder_res1 = Res3DBlock(32, 64) 73 | self.encoder_pool2 = Pool3DBlock(2) 74 | self.encoder_res2 = Res3DBlock(64, 128) 75 | 76 | self.mid_res = Res3DBlock(128, 128) 77 | 78 | self.decoder_res2 = Res3DBlock(128, 128) 79 | self.decoder_upsample2 = Upsample3DBlock(128, 64, 2, 2) 80 | self.decoder_res1 = Res3DBlock(64, 64) 81 | self.decoder_upsample1 = Upsample3DBlock(64, 32, 2, 2) 82 | 83 | self.skip_res1 = Res3DBlock(32, 32) 84 | self.skip_res2 = Res3DBlock(64, 64) 85 | 86 | def forward(self, x): 87 | skip_x1 = self.skip_res1(x) 88 | x = self.encoder_pool1(x) 89 | x = self.encoder_res1(x) 90 | skip_x2 = self.skip_res2(x) 91 | x = self.encoder_pool2(x) 92 | x = self.encoder_res2(x) 93 | 94 | x = self.mid_res(x) 95 | 96 | x = self.decoder_res2(x) 97 | x = self.decoder_upsample2(x) 98 | x = x + skip_x2 99 | x = self.decoder_res1(x) 100 | x = self.decoder_upsample1(x) 101 | x = x + skip_x1 102 | 103 | return x 104 | 105 | 106 | class V2VModel(nn.Module): 107 | def __init__(self, input_channels, output_channels): 108 | super(V2VModel, self).__init__() 109 | 110 | self.front_layers = nn.Sequential( 111 | Basic3DBlock(input_channels, 16, 7), 112 | Pool3DBlock(2), 113 | Res3DBlock(16, 32), 114 | Res3DBlock(32, 32), 115 | Res3DBlock(32, 32) 116 | ) 117 | 118 | self.encoder_decoder = EncoderDecorder() 119 | 120 | self.back_layers = nn.Sequential( 121 | Res3DBlock(32, 32), 122 | Basic3DBlock(32, 32, 1), 123 | Basic3DBlock(32, 32, 1), 124 | ) 125 | 126 | self.output_layer = nn.Conv3d(32, output_channels, kernel_size=1, stride=1, padding=0) 127 | 128 | self._initialize_weights() 129 | 130 | def forward(self, x): 131 | x = self.front_layers(x) 132 | x = self.encoder_decoder(x) 133 | x = self.back_layers(x) 134 | x = self.output_layer(x) 135 | return x 136 | 137 | def _initialize_weights(self): 138 | for m in self.modules(): 139 | if isinstance(m, nn.Conv3d): 140 | nn.init.normal_(m.weight, 0, 0.001) 141 | nn.init.constant_(m.bias, 0) 142 | elif isinstance(m, nn.ConvTranspose3d): 143 | nn.init.normal_(m.weight, 0, 0.001) 144 | nn.init.constant_(m.bias, 0) 145 | -------------------------------------------------------------------------------- /src/v2v_util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | 4 | 5 | def discretize(coord, cropped_size): 6 | '''[-1, 1] -> [0, cropped_size]''' 7 | min_normalized = -1 8 | max_normalized = 1 9 | scale = (max_normalized - min_normalized) / cropped_size 10 | return (coord 
- min_normalized) / scale 11 | 12 | 13 | def warp2continuous(coord, refpoint, cubic_size, cropped_size): 14 | ''' 15 | Map coordinates in set [0, 1, .., cropped_size-1] to original range [-cubic_size/2+refpoint, cubic_size/2 + refpoint] 16 | ''' 17 | min_normalized = -1 18 | max_normalized = 1 19 | 20 | scale = (max_normalized - min_normalized) / cropped_size 21 | coord = coord * scale + min_normalized # -> [-1, 1] 22 | 23 | coord = coord * cubic_size / 2 + refpoint 24 | 25 | return coord 26 | 27 | 28 | def scattering(coord, cropped_size): 29 | # coord: [0, cropped_size] 30 | # Assign range[0, 1) -> 0, [1, 2) -> 1, .. [cropped_size-1, cropped_size) -> cropped_size-1 31 | # That is, around center 0.5 -> 0, around center 1.5 -> 1 .. around center cropped_size-0.5 -> cropped_size-1 32 | coord = coord.astype(np.int32) 33 | 34 | mask = (coord[:, 0] >= 0) & (coord[:, 0] < cropped_size) & \ 35 | (coord[:, 1] >= 0) & (coord[:, 1] < cropped_size) & \ 36 | (coord[:, 2] >= 0) & (coord[:, 2] < cropped_size) 37 | 38 | coord = coord[mask, :] 39 | 40 | cubic = np.zeros((cropped_size, cropped_size, cropped_size)) 41 | 42 | # Note, directly map point coordinate (x, y, z) to index (i, j, k), instead of (k, j, i) 43 | # Need to be consistent with heatmap generating and coordinates extration from heatmap 44 | cubic[coord[:, 0], coord[:, 1], coord[:, 2]] = 1 45 | 46 | return cubic 47 | 48 | 49 | def extract_coord_from_output(output, center=True): 50 | ''' 51 | output: shape (batch, jointNum, volumeSize, volumeSize, volumeSize) 52 | center: if True, add 0.5, default is true 53 | return: shape (batch, jointNum, 3) 54 | ''' 55 | assert(len(output.shape) >= 3) 56 | vsize = output.shape[-3:] 57 | 58 | output_rs = output.reshape(-1, np.prod(vsize)) 59 | max_index = np.unravel_index(np.argmax(output_rs, axis=1), vsize) 60 | max_index = np.array(max_index).T 61 | 62 | xyz_output = max_index.reshape([*output.shape[:-3], 3]) 63 | 64 | # Note discrete coord can represents real range [coord, coord+1), see function scattering() 65 | # So, move coord to range center for better fittness 66 | if center: xyz_output = xyz_output + 0.5 67 | 68 | return xyz_output 69 | 70 | 71 | def generate_coord(points, refpoint, new_size, angle, trans, sizes): 72 | cubic_size, cropped_size, original_size = sizes 73 | 74 | # points shape: (n, 3) 75 | coord = points 76 | 77 | # note, will consider points within range [refpoint-cubic_size/2, refpoint+cubic_size/2] as candidates 78 | 79 | # normalize 80 | coord = (coord - refpoint) / (cubic_size/2) # -> [-1, 1] 81 | 82 | # discretize 83 | coord = discretize(coord, cropped_size) # -> [0, cropped_size] 84 | coord += (original_size / 2 - cropped_size / 2) # move center to original volume 85 | 86 | # resize around original volume center 87 | resize_scale = new_size / 100 88 | if new_size < 100: 89 | coord = coord * resize_scale + original_size/2 * (1 - resize_scale) 90 | elif new_size > 100: 91 | coord = coord * resize_scale - original_size/2 * (resize_scale - 1) 92 | else: 93 | # new_size = 100 if it is in test mode 94 | pass 95 | 96 | # rotation 97 | if angle != 0: 98 | original_coord = coord.copy() 99 | original_coord[:,0] -= original_size / 2 100 | original_coord[:,1] -= original_size / 2 101 | coord[:,0] = original_coord[:,0]*np.cos(angle) - original_coord[:,1]*np.sin(angle) 102 | coord[:,1] = original_coord[:,0]*np.sin(angle) + original_coord[:,1]*np.cos(angle) 103 | coord[:,0] += original_size / 2 104 | coord[:,1] += original_size / 2 105 | 106 | # translation 107 | # Note, if trans = 
(original_size/2 - cropped_size/2), the following translation will 108 | # cancel the above translation(after discretion). It will be set it when in test mode. 109 | coord -= trans 110 | 111 | return coord 112 | 113 | 114 | def generate_cubic_input(points, refpoint, new_size, angle, trans, sizes): 115 | _, cropped_size, _ = sizes 116 | coord = generate_coord(points, refpoint, new_size, angle, trans, sizes) 117 | 118 | # scattering 119 | cubic = scattering(coord, cropped_size) 120 | 121 | return cubic 122 | 123 | 124 | def generate_heatmap_gt(keypoints, refpoint, new_size, angle, trans, sizes, d3outputs, pool_factor, std): 125 | _, cropped_size, _ = sizes 126 | d3output_x, d3output_y, d3output_z = d3outputs 127 | 128 | coord = generate_coord(keypoints, refpoint, new_size, angle, trans, sizes) # [0, cropped_size] 129 | coord /= pool_factor # [0, cropped_size/pool_factor] 130 | 131 | # heatmap generation 132 | output_size = int(cropped_size / pool_factor) 133 | heatmap = np.zeros((keypoints.shape[0], output_size, output_size, output_size)) 134 | 135 | # use center of cell 136 | center_offset = 0.5 137 | 138 | for i in range(coord.shape[0]): 139 | xi, yi, zi= coord[i] 140 | heatmap[i] = np.exp(-(np.power((d3output_x+center_offset-xi)/std, 2)/2 + \ 141 | np.power((d3output_y+center_offset-yi)/std, 2)/2 + \ 142 | np.power((d3output_z+center_offset-zi)/std, 2)/2)) 143 | 144 | return heatmap 145 | 146 | 147 | class V2VVoxelization(object): 148 | def __init__(self, cubic_size, augmentation=True): 149 | self.cubic_size = cubic_size 150 | self.cropped_size, self.original_size = 88, 96 151 | self.sizes = (self.cubic_size, self.cropped_size, self.original_size) 152 | self.pool_factor = 2 153 | self.std = 1.7 154 | self.augmentation = augmentation 155 | 156 | output_size = int(self.cropped_size / self.pool_factor) 157 | # Note, range(size) and indexing = 'ij' 158 | self.d3outputs = np.meshgrid(np.arange(output_size), np.arange(output_size), np.arange(output_size), indexing='ij') 159 | 160 | def __call__(self, sample): 161 | points, keypoints, refpoint = sample['points'], sample['keypoints'], sample['refpoint'] 162 | 163 | ## Augmentations 164 | # Resize 165 | new_size = np.random.rand() * 40 + 80 166 | 167 | # Rotation 168 | angle = np.random.rand() * 80/180*np.pi - 40/180*np.pi 169 | 170 | # Translation 171 | trans = np.random.rand(3) * (self.original_size-self.cropped_size) 172 | 173 | if not self.augmentation: 174 | new_size = 100 175 | angle = 0 176 | trans = self.original_size/2 - self.cropped_size/2 177 | 178 | input = generate_cubic_input(points, refpoint, new_size, angle, trans, self.sizes) 179 | heatmap = generate_heatmap_gt(keypoints, refpoint, new_size, angle, trans, self.sizes, self.d3outputs, self.pool_factor, self.std) 180 | 181 | return input.reshape((1, *input.shape)), heatmap 182 | 183 | def voxelize(self, points, refpoint): 184 | new_size, angle, trans = 100, 0, self.original_size/2 - self.cropped_size/2 185 | input = generate_cubic_input(points, refpoint, new_size, angle, trans, self.sizes) 186 | return input.reshape((1, *input.shape)) 187 | 188 | def generate_heatmap(self, keypoints, refpoint): 189 | new_size, angle, trans = 100, 0, self.original_size/2 - self.cropped_size/2 190 | heatmap = generate_heatmap_gt(keypoints, refpoint, new_size, angle, trans, self.sizes, self.d3outputs, self.pool_factor, self.std) 191 | return heatmap 192 | 193 | def evaluate(self, heatmaps, refpoints): 194 | coords = extract_coord_from_output(heatmaps) 195 | coords *= self.pool_factor 196 | keypoints = 
warp2continuous(coords, refpoints, self.cubic_size, self.cropped_size)
197 |         return keypoints
198 | 
--------------------------------------------------------------------------------
/vis/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | 
3 | 
4 | def plot_acc(ax, dist, acc, name):
5 |     '''
6 |     acc: (K, num)
7 |     dist: (num, )
8 |     name: (K, )
9 |     '''
10 |     assert(acc.shape[0] == len(name))
11 | 
12 |     for i in range(len(name)):
13 |         ax.plot(dist, acc[i], label=name[i])
14 | 
15 |     ax.legend()
16 | 
17 |     ax.set_xlabel('Maximum allowed distance to GT (mm)')
18 |     ax.set_ylabel('Fraction of samples within distance')
19 | 
20 | 
21 | def plot_mean_err(ax, mean_err, name):
22 |     '''
23 |     mean_err: (K, )
24 |     name: (K, )
25 |     '''
26 |     ax.bar(name, mean_err)
27 | 
--------------------------------------------------------------------------------
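The modules above compose end to end; a few hedged sketches follow. First, a minimal training harness built from lib/sampler.py, lib/solver.py and src/v2v_model.py. This is not the experiment script: random tensors stand in for voxelized MSRA samples (88^3 occupancy input, 44^3 per-joint heatmaps, 21 joints assumed), MSE is a placeholder heatmap loss, and the repo root is assumed to be on PYTHONPATH. Note that progress_bar reads the terminal size via `stty size`, so this needs a real terminal.

```
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from lib.sampler import ChunkSampler
from lib.solver import train_epoch, val_epoch
from src.v2v_model import V2VModel

# Random stand-ins for voxelized samples: 88^3 occupancy grids in, 44^3 per-joint heatmaps out.
inputs = torch.rand(16, 1, 88, 88, 88)
targets = torch.rand(16, 21, 44, 44, 44)
dataset = TensorDataset(inputs, targets)

# ChunkSampler picks a contiguous index range: first 12 samples for training,
# the following 4 for validation.
train_loader = DataLoader(dataset, batch_size=2, sampler=ChunkSampler(12, start=0))
val_loader = DataLoader(dataset, batch_size=2, sampler=ChunkSampler(4, start=12))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = V2VModel(input_channels=1, output_channels=21).to(device)
criterion = nn.MSELoss()              # placeholder heatmap regression loss
optimizer = optim.Adam(net.parameters())

for epoch in range(2):
    train_epoch(net, criterion, optimizer, train_loader, device=device)
    val_epoch(net, criterion, val_loader, device=device)
```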
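Next, a single-frame inference sketch combining src/v2v_util.py and src/v2v_model.py. The point cloud, reference point, cubic_size of 200 mm and the 21-joint output are illustrative assumptions, not values taken from the experiment configuration.

```
import numpy as np
import torch

from src.v2v_model import V2VModel
from src.v2v_util import V2VVoxelization

# Hypothetical hand point cloud (N, 3) in mm and its reference point (3,),
# standing in for a depth frame plus the corresponding center entry.
points = np.random.rand(2000, 3).astype(np.float32) * 200.0
refpoint = points.mean(axis=0)

voxelization = V2VVoxelization(cubic_size=200, augmentation=False)
net = V2VModel(input_channels=1, output_channels=21).eval()

voxel = voxelization.voxelize(points, refpoint)                     # (1, 88, 88, 88)
inp = torch.from_numpy(voxel[np.newaxis, ...].astype(np.float32))   # add batch dim

with torch.no_grad():
    heatmaps = net(inp)                                             # (1, 21, 44, 44, 44)

# Per-joint argmax on the 44^3 heatmaps, scaled by pool_factor and warped back
# to metric coordinates around the reference point.
keypoints = voxelization.evaluate(heatmaps.cpu().numpy(), refpoint) # (1, 21, 3)
print(keypoints.shape)
```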
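Finally, a sketch of the evaluation and plotting path through lib/accuracy.py and vis/plot.py, again with synthetic predictions, synthetic ground truth and generic joint names.

```
import numpy as np
import matplotlib.pyplot as plt

from lib.accuracy import compute_dist_acc_wrapper, compute_mean_err
from vis.plot import plot_acc, plot_mean_err

# Synthetic predictions/ground truth: N samples, K joints, xyz in mm.
N, K = 500, 21
gt = np.random.rand(N, K, 3) * 100.0
pred = gt + np.random.randn(N, K, 3) * 5.0
names = ['joint%d' % i for i in range(K)]

dist, acc = compute_dist_acc_wrapper(pred, gt, max_dist=40, num=100)  # (100,), (K, 100)
mean_err = compute_mean_err(pred, gt)                                 # (K,)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
plot_acc(ax1, dist, acc, names)
plot_mean_err(ax2, mean_err, names)
plt.show()
```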