├── LICENSE
├── README.md
├── datasets
│   ├── msra_center
│   │   ├── center_test_0_refined.txt
│   │   ├── center_test_1_refined.txt
│   │   ├── center_test_2_refined.txt
│   │   ├── center_test_3_refined.txt
│   │   ├── center_test_4_refined.txt
│   │   ├── center_test_5_refined.txt
│   │   ├── center_test_6_refined.txt
│   │   ├── center_test_7_refined.txt
│   │   ├── center_test_8_refined.txt
│   │   ├── center_train_0_refined.txt
│   │   ├── center_train_1_refined.txt
│   │   ├── center_train_2_refined.txt
│   │   ├── center_train_3_refined.txt
│   │   ├── center_train_4_refined.txt
│   │   ├── center_train_5_refined.txt
│   │   ├── center_train_6_refined.txt
│   │   ├── center_train_7_refined.txt
│   │   └── center_train_8_refined.txt
│   └── msra_hand.py
├── experiments
│   └── msra-subject3
│       ├── gen_gt.py
│       ├── main.py
│       ├── show_acc.py
│       ├── test_res.txt
│       └── test_s3_gt.txt
├── figs
│   ├── Challenge_result.png
│   ├── Paper_result_hand_graph.png
│   ├── Paper_result_hand_msra.png
│   ├── Paper_result_hand_table.png
│   ├── Paper_result_human_table.png
│   ├── V2V-PoseNet.png
│   ├── integral_pose_msra_s3_joint_acc.png
│   ├── integral_pose_msra_s3_joint_mean_error.png
│   ├── mean_error_compare.png
│   ├── msra_s3_joint_acc.png
│   ├── msra_s3_joint_mean_error.png
│   └── result
│       ├── Paper_result_HANDS2017.png
│       ├── Paper_result_ICVL.png
│       ├── Paper_result_ITOP_front.png
│       ├── Paper_result_ITOP_top.png
│       ├── Paper_result_MSRA.png
│       └── Paper_result_NYU.png
├── integral-pose
│   ├── accuracy.py
│   ├── compare_acc.py
│   ├── gen_gt.py
│   ├── loss.py
│   ├── main.py
│   ├── model.py
│   ├── msra_hand.py
│   ├── plot.py
│   ├── progressbar.py
│   ├── sampler.py
│   ├── show_acc.py
│   ├── solver.py
│   ├── test_res.txt
│   ├── test_s3_gt.txt
│   ├── v2v_model.py
│   ├── v2v_util.py
│   └── v2vposenet-loss_test_res.txt
├── lib
│   ├── accuracy.py
│   ├── progressbar.py
│   ├── sampler.py
│   └── solver.py
├── src
│   ├── v2v_model.py
│   └── v2v_util.py
└── vis
    └── plot.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Gyeongsik Moon
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # V2V-PoseNet-pytorch
2 | This is a PyTorch implementation of V2V-PoseNet ([V2V-PoseNet: Voxel-to-Voxel Prediction Network for Accurate 3D Hand and Human Pose Estimation from a Single Depth Map](https://arxiv.org/abs/1711.07399)), largely based on the author's [Torch7 implementation](https://github.com/mks0601/V2V-PoseNet_RELEASE).
3 |
4 | This repository provides:
5 | * V2V-PoseNet core modules (model, voxelization, etc.)
6 | * An experiment demo on the MSRA hand pose dataset, resulting in ~11mm mean error.
7 | * *An additional [Integral Pose Loss](https://arxiv.org/abs/1711.08229) (or [PoseFix Loss](https://arxiv.org/abs/1812.03595)) implementation*, resulting in ~10mm mean error on the same demo.
8 |
9 | ## Requirements
10 | * pytorch 0.4.1 or pytorch 1.0
11 | * python 3.6
12 | * numpy
13 |
14 | ### **Warning on PyTorch 0.4.1 cuDNN**:
15 | You may need to **disable cuDNN for batchnorm**, or just use plain CUDA instead. With cuDNN enabled for batchnorm in float precision, the model cannot train well. My simple experiments show that:
16 |
17 | ```
18 | cudnn+float: does NOT work (the loss decreases much more slowly and converges to a higher value)
19 | cudnn+float+(disable batchnorm's cudnn): works (the loss decreases faster and converges to a lower value)
20 | cudnn+double: works, but training is slow
21 | cuda+(float/double): works, but uses much more memory
22 | ```
23 |
24 | A similar issue is pointed out in https://github.com/Microsoft/human-pose-estimation.pytorch. As suggested there, disable cuDNN for batchnorm by patching PyTorch:
25 |
26 | ```
27 | PYTORCH=/path/to/pytorch
28 | # for pytorch v0.4.0
29 | sed -i "1194s/torch\.backends\.cudnn\.enabled/False/g" ${PYTORCH}/torch/nn/functional.py
30 | # for pytorch v0.4.1
31 | sed -i "1254s/torch\.backends\.cudnn\.enabled/False/g" ${PYTORCH}/torch/nn/functional.py
32 | ```
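
If patching the installed package is not desirable, a runtime workaround is to switch cuDNN off, either globally or only around the affected layers. The snippet below is a minimal sketch (not part of this repo), assuming a PyTorch version where `torch.backends.cudnn.flags` is available:

```python
import torch
import torch.nn as nn

# Global switch: fall back to non-cuDNN kernels everywhere (slower, but avoids
# the float-precision batchnorm issue described above).
torch.backends.cudnn.enabled = False

# Narrower switch: disable cuDNN only around the affected forward pass,
# using the torch.backends.cudnn.flags context manager.
bn = nn.BatchNorm3d(8)
x = torch.randn(2, 8, 4, 4, 4)
with torch.backends.cudnn.flags(enabled=False):
    y = bn(x)
```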
33 |
34 | ## MSRA hand dataset demo
35 | ### Usage
36 | - Clone this repo:
37 | ```
38 | git clone https://github.com/dragonbook/V2V-PoseNet-pytorch.git
39 | cd V2V-PoseNet-pytorch
40 | ```
41 |
42 | - Download the [MSRA hand dataset](https://jimmysuen.github.io/) and extract it to a directory path/to/msra-hand.
43 |
44 | - Download the [estimated centers](https://cv.snu.ac.kr/research/V2V-PoseNet/MSRA/center/center.tar.gz) of the MSRA hand dataset, which are required by V2V-PoseNet and provided by the [author's implementation](https://github.com/mks0601/V2V-PoseNet_RELEASE). Extract them to a directory path/to/msra-hand-center.
45 | ```
46 | Note: this repository already contains a copy of the MSRA hand centers under ./datasets/msra_center.
47 | ```
48 |
49 | - Configure data_dir=path/to/msra-hand and center_dir=path/to/msra-hand-center in ./experiments/msra-subject3/main.py, then run the following command to perform training and testing. It trains for a few epochs and evaluates on the test set. The test results are saved as test_res.txt and the fit results on the training data as fit_res.txt.
50 | ```
51 | PYTHONPATH=./ python ./experiments/msra-subject3/main.py
52 | ```
53 |
54 | - Configure data_dir=path/to/msra-hand and center_dir=path/to/msra-hand-center in ./experiments/msra-subject3/gen_gt.py, then run it to generate the ground truth labels train_s3_gt.txt and test_s3_gt.txt.
55 |
56 | - Configure gt_file=path/to/test_s3_gt.txt and pred_file=path/to/test_res.txt in ./experiments/msra-subject3/show_acc.py, then run it to plot per-joint accuracy and mean error (a minimal stand-alone version of this computation is sketched after the figures below).
57 |
58 | - The following figures show that this simple experiment results in about 11mm mean error.
59 |
60 | 
61 |
62 | 
63 |
64 |
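For reference, the sketch below (not one of the repo scripts) shows the core of the mean-error computation performed by ./experiments/msra-subject3/show_acc.py, assuming test_s3_gt.txt and test_res.txt sit in the current directory with one sample per line (21 joints x 3 world coordinates, as written by save_keypoints):

```python
import numpy as np

# Each line holds one sample's 21 joints flattened as x1 y1 z1 ... x21 y21 z21.
gt = np.loadtxt('test_s3_gt.txt').reshape(-1, 21, 3)
pred = np.loadtxt('test_res.txt').reshape(-1, 21, 3)

# Euclidean distance per joint, then averaged over all frames.
err = np.sqrt(((pred - gt) ** 2).sum(axis=2))   # shape (N, 21)
print('per-joint mean error (mm):', err.mean(axis=0))
print('overall mean error (mm):', err.mean())
```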
65 | ## Additional [IntegralPose](https://arxiv.org/abs/1711.08229)/[PoseFix](https://arxiv.org/abs/1812.03595) style loss implementation
66 | V2V-PoseNet's loss is replaced with PoseFix's loss (a one-hot heatmap loss plus an L1 coordinate loss), implemented independently under the ./integral-pose directory. Configure data_dir and center_dir in ./integral-pose/main.py and start training; this variant reaches about 10mm mean error. A condensed sketch of the loss is shown right below.
67 |
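The following is a condensed, stand-alone sketch of that loss; the full class lives in ./integral-pose/loss.py (MixedLoss), and the weighted-sum combination at the end is an assumption based on its w1/w2 fields:

```python
import torch
import torch.nn.functional as F

def mixed_loss(pred_heatmap, pred_coord, gt_heatmap, gt_coord, heatmap_weight=0.5):
    """pred_heatmap/gt_heatmap: (N, C, X, Y, Z) voxel heatmaps, where predictions are
    unnormalized scores and targets are (one-hot) probability distributions.
    pred_coord/gt_coord: (N, C, 3) keypoint coordinates."""
    N, C = pred_heatmap.shape[:2]
    logits = pred_heatmap.view(N * C, -1)
    target = gt_heatmap.view(N * C, -1)

    # Softmax cross-entropy against a distribution target, averaged over N*C.
    hm_loss = torch.mean(torch.sum(-target * F.log_softmax(logits, dim=-1), dim=-1))

    # L1 loss on the regressed coordinates.
    coord_loss = F.l1_loss(pred_coord, gt_coord)

    # Weighted sum, mirroring MixedLoss(heatmap_weight=0.5).
    return heatmap_weight * hm_loss + (1 - heatmap_weight) * coord_loss
```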
68 | 
69 |
70 | 
71 |
72 | 
73 |
74 | # Below is from author's README for reference
75 | # V2V-PoseNet: Voxel-to-Voxel Prediction Network for Accurate 3D Hand and Human Pose Estimation from a Single Depth Map
76 |
77 | # Introduction
78 |
79 | This is our project repository for the paper, **V2V-PoseNet: Voxel-to-Voxel Prediction Network for Accurate 3D Hand and Human Pose Estimation from a Single Depth Map ([CVPR 2018](http://cvpr2018.thecvf.com))**.
80 |
81 | We, **Team SNU CVLAB**, (Gyeongsik Moon, Juyong Chang, and Kyoung Mu Lee of [**Computer Vision Lab, Seoul National University**](https://cv.snu.ac.kr/)) are **winners** of [**HANDS2017 Challenge**](http://icvl.ee.ic.ac.uk/hands17/challenge/) on frame-based 3D hand pose estimation.
82 |
83 |
84 |
85 | Please refer to our paper for details.
86 |
87 | If you find our work useful in your research or publication, please cite our work:
88 |
89 | [1] Moon, Gyeongsik, Ju Yong Chang, and Kyoung Mu Lee. **"V2V-PoseNet: Voxel-to-Voxel Prediction Network for Accurate 3D Hand and Human Pose Estimation from a Single Depth Map."** CVPR 2018. [[arXiv](https://arxiv.org/abs/1711.07399)]
90 |
91 | ```
92 | @InProceedings{Moon_2018_CVPR_V2V-PoseNet,
93 | author = {Moon, Gyeongsik and Chang, Juyong and Lee, Kyoung Mu},
94 | title = {V2V-PoseNet: Voxel-to-Voxel Prediction Network for Accurate 3D Hand and Human Pose Estimation from a Single Depth Map},
95 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
96 | year = {2018}
97 | }
98 | ```
99 |
100 | In this repository, we provide
101 | * Our model architecture description (V2V-PoseNet)
102 | * HANDS2017 frame-based 3D hand pose estimation Challenge Results
103 | * Comparison with the previous state-of-the-art methods
104 | * Training code
105 | * Datasets we used (ICVL, NYU, MSRA, ITOP)
106 | * Trained models and estimated results
107 | * 3D hand and human pose estimation examples
108 |
109 |
110 | ## Model Architecture
111 |
112 | 
113 |
114 | ## HANDS2017 frame-based 3D hand pose estimation Challenge Results
115 |
116 | 
117 |
118 |
119 | ## Comparison with the previous state-of-the-art methods
120 |
121 | 
122 |
123 | 
124 |
125 | 
126 |
127 | # About our code
128 | ## Dependencies
129 | * [Torch7](http://torch.ch)
130 | * [CUDA](https://developer.nvidia.com/cuda-downloads)
131 | * [cuDNN](https://developer.nvidia.com/cudnn)
132 |
133 | Our code is tested under Ubuntu 14.04 and 16.04 environment with Titan X GPUs (12GB VRAM).
134 |
135 | ## Code
136 | Clone this repository into any place you want. You may follow the example below.
137 | ```bash
138 | makeReposit = [/the/directory/as/you/wish]
139 | mkdir -p $makeReposit/; cd $makeReposit/
140 | git clone https://github.com/mks0601/V2V-PoseNet_RELEASE.git
141 | ```
142 | * `src` folder contains lua script files for data loader, trainer, tester and other utilities.
143 | * `data` folder contains data converter which converts image files to the binary files.
144 |
145 | To train our model, please run the following command in the `src` directory:
146 |
147 | ```bash
148 | th rum_me.lua
149 | ```
150 |
151 | * There are some optional configurations you can adjust in the config.lua.
152 | * You have to convert the `.png` images of the ICVL and NYU dataset to the `.bin` files by running the code from `data` folder.
153 | * The directory where you have to put the dataset files and computed centers of each frame is defined in `src/data/dataset_name/data.lua`
154 | * Visualization code is finally uploaded! You have to prepare 'result_pixel.txt' for each dataset. Each row of the result file has to contain the pixel coordinates x, y and depth of all joints (i.e., x1 y1 z1 x2 y2 z2 ...). Then run the pixel2world script and draw_DB.m.
155 |
156 | # Dataset
157 | We trained and tested our model on the four 3D hand pose estimation and one 3D human pose estimation datasets.
158 |
159 | * ICVL Hand Posture Dataset [[link](https://labicvl.github.io/hand.html)] [[paper](http://www.iis.ee.ic.ac.uk/dtang/cvpr_14.pdf)]
160 | * NYU Hand Pose Dataset [[link](https://cims.nyu.edu/~tompson/NYU_Hand_Pose_Dataset.htm)] [[paper](https://cims.nyu.edu/~tompson/others/TOG_2014_paper_PREPRINT.pdf)]
161 | * MSRA Hand Pose Dataset [[link](https://jimmysuen.github.io/)] [[paper](https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Sun_Cascaded_Hand_Pose_2015_CVPR_paper.pdf)]
162 | * HANDS2017 Challenge Dataset [[link](http://icvl.ee.ic.ac.uk/hands17/challenge/)] [[paper](https://arxiv.org/abs/1712.03917)] [[challenge benchmark paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Yuan_Depth-Based_3D_Hand_CVPR_2018_paper.pdf)]
163 | * ITOP Human Pose Dataset [[link](https://www.albert.cm/projects/viewpoint_3d_pose/)] [[paper](https://arxiv.org/abs/1603.07076)]
164 |
165 |
166 | # Results
167 | Here we provide the precomputed centers, estimated 3D coordinates and pre-trained models.
168 |
169 | The precomputed centers are obtained by training the hand center estimation network from [DeepPrior++](https://arxiv.org/pdf/1708.08325.pdf). Each line contains the 3D world coordinates of one frame's center.
170 | For the ICVL, NYU and MSRA datasets, a frame is considered invalid if its depth map does not exist or does not contain a hand.
171 | For the ITOP dataset, a frame is considered invalid if its 'valid' variable is false.
172 | All test images are considered valid.
173 |
174 | The 3D coordinates estimated on the ICVL, NYU and MSRA datasets are pixel coordinates, while the 3D coordinates estimated on the ITOP datasets are world coordinates. The estimated results are from an ensembled model. You can produce single-model results by downloading the pre-trained model and testing it.
175 |
176 | * ICVL Hand Posture Dataset [[center_trainset](https://cv.snu.ac.kr/research/V2V-PoseNet/ICVL/center/center_train_refined.txt)] [[center_testset](https://cv.snu.ac.kr/research/V2V-PoseNet/ICVL/center/center_test_refined.txt)] [[estimation](https://cv.snu.ac.kr/research/V2V-PoseNet/ICVL/coordinate/result.txt)] [[models](https://cv.snu.ac.kr/research/V2V-PoseNet/ICVL/model/model.tar.gz)]
177 | * NYU Hand Pose Dataset [[center_trainset](https://cv.snu.ac.kr/research/V2V-PoseNet/NYU/center/center_train_refined.txt)] [[center_testset](https://cv.snu.ac.kr/research/V2V-PoseNet/NYU/center/center_test_refined.txt)] [[estimation](https://cv.snu.ac.kr/research/V2V-PoseNet/NYU/coordinate/result.txt)] [[models](https://cv.snu.ac.kr/research/V2V-PoseNet/NYU/model/model.tar.gz)]
178 | * MSRA Hand Pose Dataset [[center](https://cv.snu.ac.kr/research/V2V-PoseNet/MSRA/center/center.tar.gz)] [[estimation](https://cv.snu.ac.kr/research/V2V-PoseNet/MSRA/coordinate/result.txt)] [[models](https://cv.snu.ac.kr/research/V2V-PoseNet/MSRA/model/model.tar.gz)]
179 | * ITOP Human Pose Dataset (front-view) [[center_trainset](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_front/center/center_train.txt)] [[center_testset](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_front/center/center_test.txt)] [[estimation](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_front/coordinate/result.txt)] [[models](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_front/model/model.tar.gz)]
180 | * ITOP Human Pose Dataset (top-view) [[center_trainset](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_top/center/center_train.txt)] [[center_testset](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_top/center/center_test.txt)] [[estimation](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_top/coordinate/result.txt)] [[models](https://cv.snu.ac.kr/research/V2V-PoseNet/ITOP_top/model/model.tar.gz)]
181 |
182 | We used [awesome-hand-pose-estimation](https://github.com/xinghaochen/awesome-hand-pose-estimation) to evaluate the accuracy of V2V-PoseNet on the ICVL, NYU and MSRA datasets.
183 |
184 | Below are qualitative results.
185 | 
186 | 
187 | 
188 | 
189 | 
190 | 
191 |
--------------------------------------------------------------------------------
/datasets/msra_hand.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import sys
4 | import struct
5 | from torch.utils.data import Dataset
6 |
7 |
8 | def pixel2world(x, y, z, img_width, img_height, fx, fy):
9 | w_x = (x - img_width / 2) * z / fx
10 | w_y = (img_height / 2 - y) * z / fy
11 | w_z = z
12 | return w_x, w_y, w_z
13 |
14 |
15 | def world2pixel(x, y, z, img_width, img_height, fx, fy):
16 | p_x = x * fx / z + img_width / 2
17 | p_y = img_height / 2 - y * fy / z
18 | return p_x, p_y
19 |
20 |
21 | def depthmap2points(image, fx, fy):
22 | h, w = image.shape
23 | x, y = np.meshgrid(np.arange(w) + 1, np.arange(h) + 1)
24 | points = np.zeros((h, w, 3), dtype=np.float32)
25 | points[:,:,0], points[:,:,1], points[:,:,2] = pixel2world(x, y, image, w, h, fx, fy)
26 | return points
27 |
28 |
29 | def points2pixels(points, img_width, img_height, fx, fy):
30 | pixels = np.zeros((points.shape[0], 2))
31 | pixels[:, 0], pixels[:, 1] = \
32 | world2pixel(points[:,0], points[:, 1], points[:, 2], img_width, img_height, fx, fy)
33 | return pixels
34 |
35 |
36 | def load_depthmap(filename, img_width, img_height, max_depth):
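    # MSRA depth .bin layout: a header of six uint32 values (image width, image height,
    # and the left/top/right/bottom crop bounds), followed by (right-left)*(bottom-top)
    # float32 depth values for the cropped hand region.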
37 | with open(filename, mode='rb') as f:
38 | data = f.read()
39 | _, _, left, top, right, bottom = struct.unpack('I'*6, data[:6*4])
40 | num_pixel = (right - left) * (bottom - top)
41 | cropped_image = struct.unpack('f'*num_pixel, data[6*4:])
42 |
43 | cropped_image = np.asarray(cropped_image).reshape(bottom-top, -1)
44 | depth_image = np.zeros((img_height, img_width), dtype=np.float32)
45 | depth_image[top:bottom, left:right] = cropped_image
46 | depth_image[depth_image == 0] = max_depth
47 |
48 | return depth_image
49 |
50 |
51 | class MARAHandDataset(Dataset):
52 | def __init__(self, root, center_dir, mode, test_subject_id, transform=None):
53 | self.img_width = 320
54 | self.img_height = 240
55 | self.min_depth = 100
56 | self.max_depth = 700
57 | self.fx = 241.42
58 | self.fy = 241.42
59 | self.joint_num = 21
60 | self.world_dim = 3
61 | self.folder_list = ['1','2','3','4','5','6','7','8','9','I','IP','L','MP','RP','T','TIP','Y']
62 | self.subject_num = 9
63 |
64 | self.root = root
65 | self.center_dir = center_dir
66 | self.mode = mode
67 | self.test_subject_id = test_subject_id
68 | self.transform = transform
69 |
70 | if not self.mode in ['train', 'test']: raise ValueError('Invalid mode')
71 | assert self.test_subject_id >= 0 and self.test_subject_id < self.subject_num
72 |
73 | if not self._check_exists(): raise RuntimeError('Invalid MSRA hand dataset')
74 |
75 | self._load()
76 |
77 | def __getitem__(self, index):
78 | depthmap = load_depthmap(self.names[index], self.img_width, self.img_height, self.max_depth)
79 | points = depthmap2points(depthmap, self.fx, self.fy)
80 | points = points.reshape((-1, 3))
81 |
82 | sample = {
83 | 'name': self.names[index],
84 | 'points': points,
85 | 'joints': self.joints_world[index],
86 | 'refpoint': self.ref_pts[index]
87 | }
88 |
89 | if self.transform: sample = self.transform(sample)
90 |
91 | return sample
92 |
93 | def __len__(self):
94 | return self.num_samples
95 |
96 | def _load(self):
97 | self._compute_dataset_size()
98 |
99 | self.num_samples = self.train_size if self.mode == 'train' else self.test_size
100 | self.joints_world = np.zeros((self.num_samples, self.joint_num, self.world_dim))
101 | self.ref_pts = np.zeros((self.num_samples, self.world_dim))
102 | self.names = []
103 |
104 | # Collect reference center points strings
105 | if self.mode == 'train': ref_pt_file = 'center_train_' + str(self.test_subject_id) + '_refined.txt'
106 | else: ref_pt_file = 'center_test_' + str(self.test_subject_id) + '_refined.txt'
107 |
108 | with open(os.path.join(self.center_dir, ref_pt_file)) as f:
109 | ref_pt_str = [l.rstrip() for l in f]
110 |
111 | #
112 | file_id = 0
113 | frame_id = 0
114 |
115 | for mid in range(self.subject_num):
116 | if self.mode == 'train': model_chk = (mid != self.test_subject_id)
117 | elif self.mode == 'test': model_chk = (mid == self.test_subject_id)
118 | else: raise RuntimeError('unsupported mode {}'.format(self.mode))
119 |
120 | if model_chk:
121 | for fd in self.folder_list:
122 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt')
123 |
124 | lines = []
125 | with open(annot_file) as f:
126 | lines = [line.rstrip() for line in f]
127 |
128 | # skip first line
129 | for i in range(1, len(lines)):
130 |                         # reference point
131 | splitted = ref_pt_str[file_id].split()
132 | if splitted[0] == 'invalid':
133 | print('Warning: found invalid reference frame')
134 | file_id += 1
135 | continue
136 | else:
137 | self.ref_pts[frame_id, 0] = float(splitted[0])
138 | self.ref_pts[frame_id, 1] = float(splitted[1])
139 | self.ref_pts[frame_id, 2] = float(splitted[2])
140 |
141 | # joint point
142 | splitted = lines[i].split()
143 | for jid in range(self.joint_num):
144 | self.joints_world[frame_id, jid, 0] = float(splitted[jid * self.world_dim])
145 | self.joints_world[frame_id, jid, 1] = float(splitted[jid * self.world_dim + 1])
146 | self.joints_world[frame_id, jid, 2] = -float(splitted[jid * self.world_dim + 2])
147 |
148 | filename = os.path.join(self.root, 'P'+str(mid), fd, '{:0>6d}'.format(i-1) + '_depth.bin')
149 | self.names.append(filename)
150 |
151 | frame_id += 1
152 | file_id += 1
153 |
154 | def _compute_dataset_size(self):
155 | self.train_size, self.test_size = 0, 0
156 |
157 | for mid in range(self.subject_num):
158 | num = 0
159 | for fd in self.folder_list:
160 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt')
161 | with open(annot_file) as f:
162 | num = int(f.readline().rstrip())
163 | if mid == self.test_subject_id: self.test_size += num
164 | else: self.train_size += num
165 |
166 | def _check_exists(self):
167 | # Check basic data
168 | for mid in range(self.subject_num):
169 | for fd in self.folder_list:
170 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt')
171 | if not os.path.exists(annot_file):
172 | print('Error: annotation file {} does not exist'.format(annot_file))
173 | return False
174 |
175 | # Check precomputed centers by v2v-hand model's author
176 | for subject_id in range(self.subject_num):
177 | center_train = os.path.join(self.center_dir, 'center_train_' + str(subject_id) + '_refined.txt')
178 | center_test = os.path.join(self.center_dir, 'center_test_' + str(subject_id) + '_refined.txt')
179 | if not os.path.exists(center_train) or not os.path.exists(center_test):
180 | print('Error: precomputed center files do not exist')
181 | return False
182 |
183 | return True
184 |
--------------------------------------------------------------------------------
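
For orientation, a minimal usage sketch of this dataset class (the paths below are placeholders; see the README for where the data and precomputed centers come from):

```python
from datasets.msra_hand import MARAHandDataset

# Placeholder paths: point these at the extracted MSRA dataset and the
# precomputed hand centers.
data_dir = '/path/to/msra-hand'
center_dir = '/path/to/msra-hand-center'

# Test split of subject 3, no transform: each sample is a dict of numpy arrays.
dataset = MARAHandDataset(data_dir, center_dir, mode='test', test_subject_id=3)
sample = dataset[0]
print(sample['name'])            # path of the depth .bin file
print(sample['points'].shape)    # (320*240, 3) point cloud in world coordinates
print(sample['joints'].shape)    # (21, 3) ground-truth joint positions
print(sample['refpoint'].shape)  # (3,) precomputed hand center
```

--------------------------------------------------------------------------------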
/experiments/msra-subject3/gen_gt.py:
--------------------------------------------------------------------------------
1 | ##%%
2 | import os
3 | import numpy as np
4 | import sys
5 | import struct
6 | from torch.utils.data import Dataset
7 |
8 |
9 | def pixel2world(x, y, z, img_width, img_height, fx, fy):
10 | w_x = (x - img_width / 2) * z / fx
11 | w_y = (img_height / 2 - y) * z / fy
12 | w_z = z
13 | return w_x, w_y, w_z
14 |
15 |
16 | def world2pixel(x, y, z, img_width, img_height, fx, fy):
17 | p_x = x * fx / z + img_width / 2
18 | p_y = img_height / 2 - y * fy / z
19 | return p_x, p_y
20 |
21 |
22 | def depthmap2points(image, fx, fy):
23 | h, w = image.shape
24 | x, y = np.meshgrid(np.arange(w) + 1, np.arange(h) + 1)
25 | points = np.zeros((h, w, 3), dtype=np.float32)
26 | points[:,:,0], points[:,:,1], points[:,:,2] = pixel2world(x, y, image, w, h, fx, fy)
27 | return points
28 |
29 |
30 | def points2pixels(points, img_width, img_height, fx, fy):
31 | pixels = np.zeros((points.shape[0], 2))
32 | pixels[:, 0], pixels[:, 1] = \
33 | world2pixel(points[:,0], points[:, 1], points[:, 2], img_width, img_height, fx, fy)
34 | return pixels
35 |
36 |
37 | def load_depthmap(filename, img_width, img_height, max_depth):
38 | with open(filename, mode='rb') as f:
39 | data = f.read()
40 | _, _, left, top, right, bottom = struct.unpack('I'*6, data[:6*4])
41 | num_pixel = (right - left) * (bottom - top)
42 | cropped_image = struct.unpack('f'*num_pixel, data[6*4:])
43 |
44 | cropped_image = np.asarray(cropped_image).reshape(bottom-top, -1)
45 | depth_image = np.zeros((img_height, img_width), dtype=np.float32)
46 | depth_image[top:bottom, left:right] = cropped_image
47 | depth_image[depth_image == 0] = max_depth
48 |
49 | return depth_image
50 |
51 |
52 | class MARAHandDataset(Dataset):
53 | def __init__(self, root, center_dir, mode, test_subject_id, transform=None):
54 | self.img_width = 320
55 | self.img_height = 240
56 | self.min_depth = 100
57 | self.max_depth = 700
58 | self.fx = 241.42
59 | self.fy = 241.42
60 | self.joint_num = 21
61 | self.world_dim = 3
62 | self.folder_list = ['1','2','3','4','5','6','7','8','9','I','IP','L','MP','RP','T','TIP','Y']
63 | self.subject_num = 9
64 |
65 | self.root = root
66 | self.center_dir = center_dir
67 | self.mode = mode
68 | self.test_subject_id = test_subject_id
69 | self.transform = transform
70 |
71 | if not self.mode in ['train', 'test']: raise ValueError('Invalid mode')
72 | assert self.test_subject_id >= 0 and self.test_subject_id < self.subject_num
73 |
74 | if not self._check_exists(): raise RuntimeError('Invalid MSRA hand dataset')
75 |
76 | self._load()
77 |
78 | def get_data(self):
79 | return self.names, self.joints_world, self.ref_pts
80 |
81 | def __getitem__(self, index):
82 | depthmap = load_depthmap(self.names[index], self.img_width, self.img_height, self.max_depth)
83 | points = depthmap2points(depthmap, self.fx, self.fy)
84 | points = points.reshape((-1, 3))
85 |
86 | sample = {
87 | 'name': self.names[index],
88 | 'points': points,
89 | 'joints': self.joints_world[index],
90 | 'refpoint': self.ref_pts[index]
91 | }
92 |
93 | if self.transform: sample = self.transform(sample)
94 |
95 | return sample
96 |
97 | def __len__(self):
98 | return self.num_samples
99 |
100 | def _load(self):
101 | self._compute_dataset_size()
102 |
103 | self.num_samples = self.train_size if self.mode == 'train' else self.test_size
104 | self.joints_world = np.zeros((self.num_samples, self.joint_num, self.world_dim))
105 | self.ref_pts = np.zeros((self.num_samples, self.world_dim))
106 | self.names = []
107 |
108 | # Collect reference center points strings
109 | if self.mode == 'train': ref_pt_file = 'center_train_' + str(self.test_subject_id) + '_refined.txt'
110 | else: ref_pt_file = 'center_test_' + str(self.test_subject_id) + '_refined.txt'
111 |
112 | with open(os.path.join(self.center_dir, ref_pt_file)) as f:
113 | ref_pt_str = [l.rstrip() for l in f]
114 |
115 | #
116 | file_id = 0
117 | frame_id = 0
118 |
119 | for mid in range(self.subject_num):
120 | if self.mode == 'train': model_chk = (mid != self.test_subject_id)
121 | elif self.mode == 'test': model_chk = (mid == self.test_subject_id)
122 | else: raise RuntimeError('unsupported mode {}'.format(self.mode))
123 |
124 | if model_chk:
125 | for fd in self.folder_list:
126 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt')
127 |
128 | lines = []
129 | with open(annot_file) as f:
130 | lines = [line.rstrip() for line in f]
131 |
132 | # skip first line
133 | for i in range(1, len(lines)):
134 |                         # reference point
135 | splitted = ref_pt_str[file_id].split()
136 | if splitted[0] == 'invalid':
137 | print('Warning: found invalid reference frame')
138 | file_id += 1
139 | continue
140 | else:
141 | self.ref_pts[frame_id, 0] = float(splitted[0])
142 | self.ref_pts[frame_id, 1] = float(splitted[1])
143 | self.ref_pts[frame_id, 2] = float(splitted[2])
144 |
145 | # joint point
146 | splitted = lines[i].split()
147 | for jid in range(self.joint_num):
148 | self.joints_world[frame_id, jid, 0] = float(splitted[jid * self.world_dim])
149 | self.joints_world[frame_id, jid, 1] = float(splitted[jid * self.world_dim + 1])
150 | self.joints_world[frame_id, jid, 2] = -float(splitted[jid * self.world_dim + 2])
151 |
152 | filename = os.path.join(self.root, 'P'+str(mid), fd, '{:0>6d}'.format(i-1) + '_depth.bin')
153 | self.names.append(filename)
154 |
155 | frame_id += 1
156 | file_id += 1
157 |
158 | def _compute_dataset_size(self):
159 | self.train_size, self.test_size = 0, 0
160 |
161 | for mid in range(self.subject_num):
162 | num = 0
163 | for fd in self.folder_list:
164 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt')
165 | with open(annot_file) as f:
166 | num = int(f.readline().rstrip())
167 | if mid == self.test_subject_id: self.test_size += num
168 | else: self.train_size += num
169 |
170 | def _check_exists(self):
171 | # Check basic data
172 | for mid in range(self.subject_num):
173 | for fd in self.folder_list:
174 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt')
175 | if not os.path.exists(annot_file):
176 | print('Error: annotation file {} does not exist'.format(annot_file))
177 | return False
178 |
179 | # Check precomputed centers by v2v-hand model's author
180 | for subject_id in range(self.subject_num):
181 | center_train = os.path.join(self.center_dir, 'center_train_' + str(subject_id) + '_refined.txt')
182 | center_test = os.path.join(self.center_dir, 'center_test_' + str(subject_id) + '_refined.txt')
183 | if not os.path.exists(center_train) or not os.path.exists(center_test):
184 | print('Error: precomputed center files do not exist')
185 | return False
186 |
187 | return True
188 |
189 |
190 | ##%%
191 | # Generate train_subject3_gt.txt and test_subject3_gt.txt
192 | data_dir = r'/home/maiqi/yalong/dataset/cvpr15_MSRAHandGestureDB'
193 | center_dir = r'/home/maiqi/yalong/project/KeyPoint/Code/V2V-PoseNet-Rlease-Codes/V2V-PoseNet_RELEASE-hand/data-result/MSRA-result/center'
194 | test_subject_id = 3
195 |
196 |
197 | ##%%
198 | def save_keypoints(filename, keypoints):
199 | # Reshape one sample keypoints into one line
200 | keypoints = keypoints.reshape(keypoints.shape[0], -1)
201 | np.savetxt(filename, keypoints, fmt='%0.4f')
202 |
203 |
204 | ##%%
205 | train_dataset = MARAHandDataset(root=data_dir, center_dir=center_dir, mode='train', test_subject_id=test_subject_id)
206 | names, joints_world, ref_pts = train_dataset.get_data()
207 | print('save train result ..')
208 | save_keypoints('./train_s3_gt.txt', joints_world)
209 | print('done ..')
210 |
211 |
212 | ##%%
213 | test_dataset = MARAHandDataset(root=data_dir, center_dir=center_dir, mode='test', test_subject_id=test_subject_id)
214 | names, joints_world, ref_pts = test_dataset.get_data()
215 | print('save test result ..')
216 | save_keypoints('./test_s3_gt.txt', joints_world)
217 | print('done ..')
218 |
--------------------------------------------------------------------------------
/experiments/msra-subject3/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.backends.cudnn as cudnn
6 | import argparse
7 | import os
8 |
9 | from lib.solver import train_epoch, val_epoch, test_epoch
10 | from lib.sampler import ChunkSampler
11 | from src.v2v_model import V2VModel
12 | from src.v2v_util import V2VVoxelization
13 | from datasets.msra_hand import MARAHandDataset
14 |
15 |
16 | #######################################################################################
17 | # Note,
18 | # Run in the project root directory (ROOT_DIR) with:
19 | # PYTHONPATH=./ python experiments/msra-subject3/main.py
20 | #
21 | # This script will train the model on the MSRA hand dataset, save checkpoints to ROOT_DIR/checkpoint,
22 | # and save test results (test_res.txt) and fit results (fit_res.txt) to ROOT_DIR.
23 | #
24 |
25 |
26 | #######################################################################################
27 | ## Some helpers
28 | def parse_args():
29 | parser = argparse.ArgumentParser(description='PyTorch Hand Keypoints Estimation Training')
30 | #parser.add_argument('--resume', 'r', action='store_true', help='resume from checkpoint')
31 | parser.add_argument('--resume', '-r', default=-1, type=int, help='resume after epoch')
32 | args = parser.parse_args()
33 | return args
34 |
35 |
36 | #######################################################################################
37 | ## Configurations
38 | print('Warning: disable cudnn for batchnorm first, or just use only cuda instead!')
39 |
40 | # When we need to resume training, enable randomness to avoid seeing the deterministic
41 | # (augmented) samples many times.
42 | # np.random.seed(1)
43 | # torch.manual_seed(1)
44 | # torch.cuda.manual_seed(1)
45 |
46 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
47 | dtype = torch.float
48 |
49 | #
50 | args = parse_args()
51 | resume_train = args.resume >= 0
52 | resume_after_epoch = args.resume
53 |
54 | save_checkpoint = True
55 | checkpoint_per_epochs = 1
56 | checkpoint_dir = r'./checkpoint'
57 |
58 | start_epoch = 0
59 | epochs_num = 15
60 |
61 | batch_size = 12
62 |
63 |
64 | #######################################################################################
65 | ## Data, transform, dataset and loader
66 | # Data
67 | print('==> Preparing data ..')
68 | data_dir = r'/home/maiqi/yalong/dataset/cvpr15_MSRAHandGestureDB'
69 | center_dir = r'/home/maiqi/yalong/project/KeyPoint/Code/V2V-PoseNet-Rlease-Codes/V2V-PoseNet_RELEASE-hand/data-result/MSRA-result/center'
70 | keypoints_num = 21
71 | test_subject_id = 3
72 | cubic_size = 200
73 |
74 |
75 | # Transform
76 | voxelization_train = V2VVoxelization(cubic_size=200, augmentation=True)
77 | voxelization_val = V2VVoxelization(cubic_size=200, augmentation=False)
78 |
79 |
80 | def transform_train(sample):
81 | points, keypoints, refpoint = sample['points'], sample['joints'], sample['refpoint']
82 | assert(keypoints.shape[0] == keypoints_num)
83 | input, heatmap = voxelization_train({'points': points, 'keypoints': keypoints, 'refpoint': refpoint})
84 | return (torch.from_numpy(input), torch.from_numpy(heatmap))
85 |
86 |
87 | def transform_val(sample):
88 | points, keypoints, refpoint = sample['points'], sample['joints'], sample['refpoint']
89 | assert(keypoints.shape[0] == keypoints_num)
90 | input, heatmap = voxelization_val({'points': points, 'keypoints': keypoints, 'refpoint': refpoint})
91 | return (torch.from_numpy(input), torch.from_numpy(heatmap))
92 |
93 |
94 | # Dataset and loader
95 | train_set = MARAHandDataset(data_dir, center_dir, 'train', test_subject_id, transform_train)
96 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=6)
97 | #train_num = 1
98 | #train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=False, num_workers=6,sampler=ChunkSampler(train_num, 0))
99 |
100 | # No separate validation dataset, just use test dataset instead
101 | val_set = MARAHandDataset(data_dir, center_dir, 'test', test_subject_id, transform_val)
102 | val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=6)
103 |
104 |
105 | #######################################################################################
106 | ## Model, criterion and optimizer
107 | print('==> Constructing model ..')
108 | net = V2VModel(input_channels=1, output_channels=keypoints_num)
109 |
110 | net = net.to(device, dtype)
111 | if device == torch.device('cuda'):
112 | torch.backends.cudnn.enabled = True
113 | cudnn.benchmark = True
114 | print('cudnn.enabled: ', torch.backends.cudnn.enabled)
115 |
116 | criterion = nn.MSELoss()
117 |
118 | optimizer = optim.Adam(net.parameters())
119 | #optimizer = optim.RMSprop(net.parameters(), lr=2.5e-4)
120 |
121 |
122 | #######################################################################################
123 | ## Resume
124 | if resume_train:
125 | # Load checkpoint
126 | epoch = resume_after_epoch
127 | checkpoint_file = os.path.join(checkpoint_dir, 'epoch'+str(epoch)+'.pth')
128 |
129 | print('==> Resuming from checkpoint after epoch {} ..'.format(epoch))
130 | assert os.path.isdir(checkpoint_dir), 'Error: no checkpoint directory found!'
131 | assert os.path.isfile(checkpoint_file), 'Error: no checkpoint file of epoch {}'.format(epoch)
132 |
133 | checkpoint = torch.load(os.path.join(checkpoint_dir, 'epoch'+str(epoch)+'.pth'))
134 | net.load_state_dict(checkpoint['model_state_dict'])
135 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
136 | start_epoch = checkpoint['epoch'] + 1
137 |
138 |
139 | #######################################################################################
140 | ## Train and Validate
141 | print('==> Training ..')
142 | for epoch in range(start_epoch, start_epoch + epochs_num):
143 | print('Epoch: {}'.format(epoch))
144 | train_epoch(net, criterion, optimizer, train_loader, device=device, dtype=dtype)
145 | val_epoch(net, criterion, val_loader, device=device, dtype=dtype)
146 |
147 | if save_checkpoint and epoch % checkpoint_per_epochs == 0:
148 | if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir)
149 | checkpoint_file = os.path.join(checkpoint_dir, 'epoch'+str(epoch)+'.pth')
150 | checkpoint = {
151 | 'model_state_dict': net.state_dict(),
152 | 'optimizer_state_dict': optimizer.state_dict(),
153 | 'epoch': epoch
154 | }
155 | torch.save(checkpoint, checkpoint_file)
156 |
157 |
158 | #######################################################################################
159 | ## Test
160 | print('==> Testing ..')
161 | voxelize_input = voxelization_train.voxelize
162 | evaluate_keypoints = voxelization_train.evaluate
163 |
164 |
165 | def transform_test(sample):
166 | points, refpoint = sample['points'], sample['refpoint']
167 | input = voxelize_input(points, refpoint)
168 | return torch.from_numpy(input), torch.from_numpy(refpoint.reshape((1, -1)))
169 |
170 |
171 | def transform_output(heatmaps, refpoints):
172 | keypoints = evaluate_keypoints(heatmaps, refpoints)
173 | return keypoints
174 |
175 |
176 | class BatchResultCollector():
177 | def __init__(self, samples_num, transform_output):
178 | self.samples_num = samples_num
179 | self.transform_output = transform_output
180 | self.keypoints = None
181 | self.idx = 0
182 |
183 | def __call__(self, data_batch):
184 | inputs_batch, outputs_batch, extra_batch = data_batch
185 | outputs_batch = outputs_batch.cpu().numpy()
186 | refpoints_batch = extra_batch.cpu().numpy()
187 |
188 | keypoints_batch = self.transform_output(outputs_batch, refpoints_batch)
189 |
190 | if self.keypoints is None:
191 |             # Initialize keypoints lazily, now that the output dimensions are available
192 | self.keypoints = np.zeros((self.samples_num, *keypoints_batch.shape[1:]))
193 |
194 | batch_size = keypoints_batch.shape[0]
195 | self.keypoints[self.idx:self.idx+batch_size] = keypoints_batch
196 | self.idx += batch_size
197 |
198 | def get_result(self):
199 | return self.keypoints
200 |
201 |
202 | print('Test on test dataset ..')
203 | def save_keypoints(filename, keypoints):
204 | # Reshape one sample keypoints into one line
205 | keypoints = keypoints.reshape(keypoints.shape[0], -1)
206 | np.savetxt(filename, keypoints, fmt='%0.4f')
207 |
208 |
209 | test_set = MARAHandDataset(data_dir, center_dir, 'test', test_subject_id, transform_test)
210 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=6)
211 | test_res_collector = BatchResultCollector(len(test_set), transform_output)
212 |
213 | test_epoch(net, test_loader, test_res_collector, device, dtype)
214 | keypoints_test = test_res_collector.get_result()
215 | save_keypoints('./test_res.txt', keypoints_test)
216 |
217 |
218 | print('Fit on train dataset ..')
219 | fit_set = MARAHandDataset(data_dir, center_dir, 'train', test_subject_id, transform_test)
220 | fit_loader = torch.utils.data.DataLoader(fit_set, batch_size=batch_size, shuffle=False, num_workers=6)
221 | fit_res_collector = BatchResultCollector(len(fit_set), transform_output)
222 |
223 | test_epoch(net, fit_loader, fit_res_collector, device, dtype)
224 | keypoints_fit = fit_res_collector.get_result()
225 | save_keypoints('./fit_res.txt', keypoints_fit)
226 |
227 | print('All done ..')
228 |
--------------------------------------------------------------------------------
/experiments/msra-subject3/show_acc.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('../../')
3 |
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from lib.accuracy import *
7 | from vis.plot import *
8 |
9 |
10 | gt_file = r'./test_s3_gt.txt'
11 | pred_file = r'./test_res.txt'
12 |
13 |
14 | gt = np.loadtxt(gt_file)
15 | gt = gt.reshape(gt.shape[0], -1, 3)
16 |
17 | pred = np.loadtxt(pred_file)
18 | pred = pred.reshape(pred.shape[0], -1, 3)
19 |
20 | print('gt: ', gt.shape)
21 | print('pred: ', pred.shape)
22 |
23 |
24 | keypoints_num = 21
25 | names = ['joint'+str(i+1) for i in range(keypoints_num)]
26 |
27 |
28 | dist, acc = compute_dist_acc_wrapper(pred, gt, max_dist=100, num=100)
29 |
30 | fig, ax = plt.subplots()
31 | plot_acc(ax, dist, acc, names)
32 | fig.savefig('msra_s3_joint_acc.png')
33 | plt.show()
34 |
35 |
36 | mean_err = compute_mean_err(pred, gt)
37 | fig, ax = plt.subplots()
38 | plot_mean_err(ax, mean_err, names)
39 | fig.savefig('msra_s3_joint_mean_error.png')
40 | plt.show()
41 |
42 |
43 | print('mean_err: {}'.format(mean_err))
44 | mean_err_all = compute_mean_err(pred.reshape((-1, 1, 3)), gt.reshape((-1, 1,3)))
45 | print('mean_err_all: ', mean_err_all)
46 |
--------------------------------------------------------------------------------
/figs/Challenge_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/Challenge_result.png
--------------------------------------------------------------------------------
/figs/Paper_result_hand_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/Paper_result_hand_graph.png
--------------------------------------------------------------------------------
/figs/Paper_result_hand_msra.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/Paper_result_hand_msra.png
--------------------------------------------------------------------------------
/figs/Paper_result_hand_table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/Paper_result_hand_table.png
--------------------------------------------------------------------------------
/figs/Paper_result_human_table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/Paper_result_human_table.png
--------------------------------------------------------------------------------
/figs/V2V-PoseNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/V2V-PoseNet.png
--------------------------------------------------------------------------------
/figs/integral_pose_msra_s3_joint_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/integral_pose_msra_s3_joint_acc.png
--------------------------------------------------------------------------------
/figs/integral_pose_msra_s3_joint_mean_error.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/integral_pose_msra_s3_joint_mean_error.png
--------------------------------------------------------------------------------
/figs/mean_error_compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/mean_error_compare.png
--------------------------------------------------------------------------------
/figs/msra_s3_joint_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/msra_s3_joint_acc.png
--------------------------------------------------------------------------------
/figs/msra_s3_joint_mean_error.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/msra_s3_joint_mean_error.png
--------------------------------------------------------------------------------
/figs/result/Paper_result_HANDS2017.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/result/Paper_result_HANDS2017.png
--------------------------------------------------------------------------------
/figs/result/Paper_result_ICVL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/result/Paper_result_ICVL.png
--------------------------------------------------------------------------------
/figs/result/Paper_result_ITOP_front.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/result/Paper_result_ITOP_front.png
--------------------------------------------------------------------------------
/figs/result/Paper_result_ITOP_top.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/result/Paper_result_ITOP_top.png
--------------------------------------------------------------------------------
/figs/result/Paper_result_MSRA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/result/Paper_result_MSRA.png
--------------------------------------------------------------------------------
/figs/result/Paper_result_NYU.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonbook/V2V-PoseNet-pytorch/90045b61c45f18dc20b410e2de14bd22be55fe0e/figs/result/Paper_result_NYU.png
--------------------------------------------------------------------------------
/integral-pose/accuracy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def compute_dist_acc_wrapper(pred, gt, max_dist=10, num=100):
5 | '''
6 | pred: (N, K, 3)
7 | gt: (N, K, 3)
8 |
9 | return dist: (K, )
10 | return acc: (K, num)
11 | '''
12 | assert(pred.shape == gt.shape)
13 | assert(len(pred.shape) == 3)
14 |
15 | dist = np.linspace(0, max_dist, num)
16 | return dist, compute_dist_acc(pred, gt, dist)
17 |
18 |
19 | def compute_dist_acc(pred, gt, dist):
20 | '''
21 | pred: (N, K, 3)
22 | gt: (N, K, 3)
23 | dist: (M, )
24 |
25 | return acc: (K, M)
26 | '''
27 | assert(pred.shape == gt.shape)
28 | assert(len(pred.shape) == 3)
29 |
30 | N, K = pred.shape[0], pred.shape[1]
31 | err_dist = np.sqrt(np.sum((pred - gt)**2, axis=2)) # (N, K)
32 |
33 | acc = np.zeros((K, dist.shape[0]))
34 |
35 | for i, d in enumerate(dist):
36 | acc_d = (err_dist < d).sum(axis=0) / N
37 | acc[:,i] = acc_d
38 |
39 | return acc
40 |
41 |
42 | def compute_mean_err(pred, gt):
43 | '''
44 | pred: (N, K, 3)
45 | gt: (N, K, 3)
46 |
47 | mean_err: (K,)
48 | '''
49 | N, K = pred.shape[0], pred.shape[1]
50 | err_dist = np.sqrt(np.sum((pred - gt)**2, axis=2)) # (N, K)
51 | return np.mean(err_dist, axis=0)
52 |
53 |
54 | def compute_dist_err(pred, gt):
55 | '''
56 | pred: (N, K, 3)
57 | return: (N, K)
58 | '''
59 | return np.sqrt(np.sum((pred - gt)**2, axis=2))
60 |
--------------------------------------------------------------------------------
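
A quick usage sketch for these helpers (synthetic data, not from the repo), run from inside ./integral-pose:

```python
import numpy as np
from accuracy import compute_dist_acc_wrapper, compute_mean_err

# Synthetic predictions: 100 frames, 21 keypoints in 3D, with ~5mm noise.
rng = np.random.RandomState(0)
gt = rng.uniform(-50, 50, size=(100, 21, 3))
pred = gt + rng.normal(scale=5.0, size=gt.shape)

dist, acc = compute_dist_acc_wrapper(pred, gt, max_dist=30, num=30)
print(acc.shape)                   # (21, 30): per-joint fraction of frames within each threshold
print(compute_mean_err(pred, gt))  # (21,): per-joint mean Euclidean error
```

--------------------------------------------------------------------------------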
/integral-pose/compare_acc.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | from accuracy import *
5 | from plot import *
6 |
7 |
8 | gt_file = r'./test_s3_gt.txt'
9 | pred_file = r'./v2vposenet-loss_test_res.txt' # copied from ./experiments/msra-subject3/test_res.txt
10 | pred_file1 = r'./test_res.txt' # one-hot heatmap loss + L1 coord loss
11 |
12 |
13 | gt = np.loadtxt(gt_file)
14 | gt = gt.reshape(gt.shape[0], -1, 3)
15 |
16 | pred = np.loadtxt(pred_file)
17 | pred = pred.reshape(pred.shape[0], -1, 3)
18 |
19 | pred1 = np.loadtxt(pred_file1)
20 | pred1 = pred1.reshape(pred1.shape[0], -1, 3)
21 |
22 | print('gt: ', gt.shape)
23 | print('pred: ', pred.shape)
24 | print('pred1: ', pred1.shape)
25 |
26 |
27 | names = ['kp'+str(i+1) for i in range(gt.shape[1]) ]
28 |
29 |
30 | ##
31 | dist, acc = compute_dist_acc_wrapper(pred, gt, 100, 100)
32 |
33 | fig, ax = plt.subplots()
34 | plot_acc(ax, dist, acc, names)
35 | plt.show()
36 |
37 |
38 | ##
39 | _, acc1 = compute_dist_acc_wrapper(pred1, gt, 100, 100)
40 |
41 | fig, ax = plt.subplots()
42 | plot_acc(ax, dist, acc1, names)
43 | plt.show()
44 |
45 | ##
46 | mean_err = compute_mean_err(pred, gt)
47 | mean_err1 = compute_mean_err(pred1, gt)
48 |
49 | fig, ax = plt.subplots()
50 | _pos = np.arange(len(names))
51 | ax.bar(_pos, mean_err, width=0.1, label='gaussian-heatmap-loss(V2VPoseNet-loss)')
52 | ax.bar(_pos+0.1, mean_err1, width=0.1, label='one_hot-heatmap-loss + L1 coord loss')
53 | ax.set_xticks(_pos)
54 | ax.set_xticklabels(names)
55 | ax.set_xlabel('keypoints categories')
56 | ax.set_ylabel('distance mean error (mm)')
57 | ax.legend(loc='upper right')
58 |
59 | plt.show()
60 |
61 |
62 | all_mean_err = np.mean(mean_err)
63 | all_mean_err1 = np.mean(mean_err1)
64 | print('all_mean_error: ', all_mean_err)
65 | print('all_mean_error1: ', all_mean_err1)
66 |
67 |
68 | ## histogram
69 | dist_err = compute_dist_err(pred, gt)
70 | dist_err1 = compute_dist_err(pred1, gt)
71 |
72 | fig, ax = plt.subplots()
73 | bins = np.linspace(0, 10, 200)
74 | ax.hist([dist_err[:].ravel(), dist_err1[:].ravel()], bins=bins, label=['gaussian-heatmap-loss(V2VPoseNet-loss)', 'one_hot-heatmap-loss + L1 coord loss'])
75 | ax.set_xlabel('distance error (mm)')
76 | ax.set_ylabel('keypoints number')
77 | plt.legend(loc='upper right')
78 |
79 | plt.show()
--------------------------------------------------------------------------------
/integral-pose/gen_gt.py:
--------------------------------------------------------------------------------
1 | ##%%
2 | import os
3 | import numpy as np
4 | import sys
5 | import struct
6 | from torch.utils.data import Dataset
7 |
8 |
9 | def pixel2world(x, y, z, img_width, img_height, fx, fy):
10 | w_x = (x - img_width / 2) * z / fx
11 | w_y = (img_height / 2 - y) * z / fy
12 | w_z = z
13 | return w_x, w_y, w_z
14 |
15 |
16 | def world2pixel(x, y, z, img_width, img_height, fx, fy):
17 | p_x = x * fx / z + img_width / 2
18 | p_y = img_height / 2 - y * fy / z
19 | return p_x, p_y
20 |
21 |
22 | def depthmap2points(image, fx, fy):
23 | h, w = image.shape
24 | x, y = np.meshgrid(np.arange(w) + 1, np.arange(h) + 1)
25 | points = np.zeros((h, w, 3), dtype=np.float32)
26 | points[:,:,0], points[:,:,1], points[:,:,2] = pixel2world(x, y, image, w, h, fx, fy)
27 | return points
28 |
29 |
30 | def points2pixels(points, img_width, img_height, fx, fy):
31 | pixels = np.zeros((points.shape[0], 2))
32 | pixels[:, 0], pixels[:, 1] = \
33 | world2pixel(points[:,0], points[:, 1], points[:, 2], img_width, img_height, fx, fy)
34 | return pixels
35 |
36 |
37 | def load_depthmap(filename, img_width, img_height, max_depth):
38 | with open(filename, mode='rb') as f:
39 | data = f.read()
40 | _, _, left, top, right, bottom = struct.unpack('I'*6, data[:6*4])
41 | num_pixel = (right - left) * (bottom - top)
42 | cropped_image = struct.unpack('f'*num_pixel, data[6*4:])
43 |
44 | cropped_image = np.asarray(cropped_image).reshape(bottom-top, -1)
45 | depth_image = np.zeros((img_height, img_width), dtype=np.float32)
46 | depth_image[top:bottom, left:right] = cropped_image
47 | depth_image[depth_image == 0] = max_depth
48 |
49 | return depth_image
50 |
51 |
52 | class MARAHandDataset(Dataset):
53 | def __init__(self, root, center_dir, mode, test_subject_id, transform=None):
54 | self.img_width = 320
55 | self.img_height = 240
56 | self.min_depth = 100
57 | self.max_depth = 700
58 | self.fx = 241.42
59 | self.fy = 241.42
60 | self.joint_num = 21
61 | self.world_dim = 3
62 | self.folder_list = ['1','2','3','4','5','6','7','8','9','I','IP','L','MP','RP','T','TIP','Y']
63 | self.subject_num = 9
64 |
65 | self.root = root
66 | self.center_dir = center_dir
67 | self.mode = mode
68 | self.test_subject_id = test_subject_id
69 | self.transform = transform
70 |
71 | if not self.mode in ['train', 'test']: raise ValueError('Invalid mode')
72 | assert self.test_subject_id >= 0 and self.test_subject_id < self.subject_num
73 |
74 | if not self._check_exists(): raise RuntimeError('Invalid MSRA hand dataset')
75 |
76 | self._load()
77 |
78 | def get_data(self):
79 | return self.names, self.joints_world, self.ref_pts
80 |
81 | def __getitem__(self, index):
82 | depthmap = load_depthmap(self.names[index], self.img_width, self.img_height, self.max_depth)
83 | points = depthmap2points(depthmap, self.fx, self.fy)
84 | points = points.reshape((-1, 3))
85 |
86 | sample = {
87 | 'name': self.names[index],
88 | 'points': points,
89 | 'joints': self.joints_world[index],
90 | 'refpoint': self.ref_pts[index]
91 | }
92 |
93 | if self.transform: sample = self.transform(sample)
94 |
95 | return sample
96 |
97 | def __len__(self):
98 | return self.num_samples
99 |
100 | def _load(self):
101 | self._compute_dataset_size()
102 |
103 | self.num_samples = self.train_size if self.mode == 'train' else self.test_size
104 | self.joints_world = np.zeros((self.num_samples, self.joint_num, self.world_dim))
105 | self.ref_pts = np.zeros((self.num_samples, self.world_dim))
106 | self.names = []
107 |
108 | # Collect reference center points strings
109 | if self.mode == 'train': ref_pt_file = 'center_train_' + str(self.test_subject_id) + '_refined.txt'
110 | else: ref_pt_file = 'center_test_' + str(self.test_subject_id) + '_refined.txt'
111 |
112 | with open(os.path.join(self.center_dir, ref_pt_file)) as f:
113 | ref_pt_str = [l.rstrip() for l in f]
114 |
115 | #
116 | file_id = 0
117 | frame_id = 0
118 |
119 | for mid in range(self.subject_num):
120 | if self.mode == 'train': model_chk = (mid != self.test_subject_id)
121 | elif self.mode == 'test': model_chk = (mid == self.test_subject_id)
122 | else: raise RuntimeError('unsupported mode {}'.format(self.mode))
123 |
124 | if model_chk:
125 | for fd in self.folder_list:
126 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt')
127 |
128 | lines = []
129 | with open(annot_file) as f:
130 | lines = [line.rstrip() for line in f]
131 |
132 | # skip first line
133 | for i in range(1, len(lines)):
134 |                         # reference point
135 | splitted = ref_pt_str[file_id].split()
136 | if splitted[0] == 'invalid':
137 | print('Warning: found invalid reference frame')
138 | file_id += 1
139 | continue
140 | else:
141 | self.ref_pts[frame_id, 0] = float(splitted[0])
142 | self.ref_pts[frame_id, 1] = float(splitted[1])
143 | self.ref_pts[frame_id, 2] = float(splitted[2])
144 |
145 | # joint point
146 | splitted = lines[i].split()
147 | for jid in range(self.joint_num):
148 | self.joints_world[frame_id, jid, 0] = float(splitted[jid * self.world_dim])
149 | self.joints_world[frame_id, jid, 1] = float(splitted[jid * self.world_dim + 1])
150 | self.joints_world[frame_id, jid, 2] = -float(splitted[jid * self.world_dim + 2])
151 |
152 | filename = os.path.join(self.root, 'P'+str(mid), fd, '{:0>6d}'.format(i-1) + '_depth.bin')
153 | self.names.append(filename)
154 |
155 | frame_id += 1
156 | file_id += 1
157 |
158 | def _compute_dataset_size(self):
159 | self.train_size, self.test_size = 0, 0
160 |
161 | for mid in range(self.subject_num):
162 | num = 0
163 | for fd in self.folder_list:
164 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt')
165 | with open(annot_file) as f:
166 | num = int(f.readline().rstrip())
167 | if mid == self.test_subject_id: self.test_size += num
168 | else: self.train_size += num
169 |
170 | def _check_exists(self):
171 | # Check basic data
172 | for mid in range(self.subject_num):
173 | for fd in self.folder_list:
174 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt')
175 | if not os.path.exists(annot_file):
176 | print('Error: annotation file {} does not exist'.format(annot_file))
177 | return False
178 |
179 | # Check precomputed centers by v2v-hand model's author
180 | for subject_id in range(self.subject_num):
181 | center_train = os.path.join(self.center_dir, 'center_train_' + str(subject_id) + '_refined.txt')
182 | center_test = os.path.join(self.center_dir, 'center_test_' + str(subject_id) + '_refined.txt')
183 | if not os.path.exists(center_train) or not os.path.exists(center_test):
184 | print('Error: precomputed center files do not exist')
185 | return False
186 |
187 | return True
188 |
189 |
190 | ##%%
191 | # Generate train_s3_gt.txt and test_s3_gt.txt (ground-truth world-coordinate keypoints for the chosen test subject)
192 | data_dir = r'/home/maiqi/yalong/dataset/cvpr15_MSRAHandGestureDB'
193 | center_dir = r'/home/maiqi/yalong/project/KeyPoint/Code/V2V-PoseNet-Rlease-Codes/V2V-PoseNet_RELEASE-hand/data-result/MSRA-result/center'
194 | test_subject_id = 3
195 |
196 |
197 | ##%%
198 | def save_keypoints(filename, keypoints):
199 |     # Reshape each sample's keypoints into a single line
200 | keypoints = keypoints.reshape(keypoints.shape[0], -1)
201 | np.savetxt(filename, keypoints, fmt='%0.4f')
202 |
203 |
204 | ##%%
205 | train_dataset = MARAHandDataset(root=data_dir, center_dir=center_dir, mode='train', test_subject_id=test_subject_id)
206 | names, joints_world, ref_pts = train_dataset.get_data()
207 | print('save train result ..')
208 | save_keypoints('./train_s3_gt.txt', joints_world)
209 | print('done ..')
210 |
211 |
212 | ##%%
213 | test_dataset = MARAHandDataset(root=data_dir, center_dir=center_dir, mode='test', test_subject_id=test_subject_id)
214 | names, joints_world, ref_pts = test_dataset.get_data()
215 | print('save test result ..')
216 | save_keypoints('./test_s3_gt.txt', joints_world)
217 | print('done ..')
218 |
--------------------------------------------------------------------------------
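A minimal sketch (not one of the repository files) of reading back the ground-truth files written by save_keypoints() above; it assumes the script has already produced ./test_s3_gt.txt and that each row holds 21 joints x 3 world coordinates:

    import numpy as np

    # Each line of the saved file is one sample: 21 joints x 3 coords = 63 numbers.
    joints = np.loadtxt('./test_s3_gt.txt')
    joints = joints.reshape(joints.shape[0], 21, 3)   # (num_samples, joint_num, world_dim)
    print(joints.shape)
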
/integral-pose/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class SoftmaxCrossEntropyWithLogits(nn.Module):
7 | '''
8 | Similar to tensorflow's tf.nn.softmax_cross_entropy_with_logits
9 | ref: https://gist.github.com/tejaskhot/cf3d087ce4708c422e68b3b747494b9f
10 |
11 | The 'input' is unnormalized scores.
12 | The 'target' is a probability distribution.
13 |
14 | Shape:
15 | Input: (N, C), batch size N, with C classes
16 | Target: (N, C), batch size N, with C classes
17 | '''
18 | def __init__(self):
19 | super(SoftmaxCrossEntropyWithLogits, self).__init__()
20 |
21 | def forward(self, input, target):
22 | loss = torch.sum(-target * F.log_softmax(input, -1), -1)
23 | mean_loss = torch.mean(loss)
24 | return mean_loss
25 |
26 |
27 | class MixedLoss(nn.Module):
28 | '''
29 | ref: https://github.com/mks0601/PoseFix_RELEASE/blob/master/main/model.py
30 |
31 | input: {
32 | 'heatmap': (N, C, X, Y, Z), unnormalized
33 | 'coord': (N, C, 3)
34 | }
35 |
36 | target: {
37 | 'heatmap': (N, C, X, Y, Z), normalized
38 | 'coord': (N, C, 3)
39 | }
40 |
41 | '''
42 | def __init__(self, heatmap_weight=0.5):
43 | # def __init__(self, heatmap_weight=0.05):
44 | super(MixedLoss, self).__init__()
45 | self.w1 = heatmap_weight
46 | self.w2 = 1 - self.w1
47 | self.cross_entropy_loss = SoftmaxCrossEntropyWithLogits()
48 |
49 | def forward(self, input, target):
50 | pred_heatmap, pred_coord = input['heatmap'], input['coord']
51 | gt_heatmap, gt_coord = target['heatmap'], target['coord']
52 |
53 | # Heatmap loss
54 | N, C = pred_heatmap.shape[0:2]
55 | pred_heatmap = pred_heatmap.view(N*C, -1)
56 | gt_heatmap = gt_heatmap.view(N*C, -1)
57 |
58 | # Note, averaged over N*C
59 | hm_loss = self.cross_entropy_loss(pred_heatmap, gt_heatmap)
60 |
61 | # Coord L1 loss
62 | l1_loss = torch.mean(torch.abs(pred_coord - gt_coord))
63 |
64 | return self.w1 * hm_loss + self.w2 * l1_loss
65 |
--------------------------------------------------------------------------------
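A minimal usage sketch for the MixedLoss above (illustrative, not part of the repository); the shapes follow the docstring, with an assumed batch of 2, 21 joints, and a 44^3 heatmap, and the ground-truth heatmap is normalized so each (N, C) slice is a probability distribution:

    import torch
    from loss import MixedLoss   # assumes integral-pose/ is on PYTHONPATH

    N, C, R = 2, 21, 44
    pred = {'heatmap': torch.randn(N, C, R, R, R), 'coord': torch.rand(N, C, 3) * R}

    gt_hm = torch.rand(N, C, R, R, R)
    gt_hm = gt_hm / gt_hm.view(N, C, -1).sum(-1).view(N, C, 1, 1, 1)  # normalize per joint
    target = {'heatmap': gt_hm, 'coord': torch.rand(N, C, 3) * R}

    criterion = MixedLoss(heatmap_weight=0.5)
    loss = criterion(pred, target)   # 0.5 * softmax cross-entropy + 0.5 * coordinate L1
    print(loss.item())
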
/integral-pose/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.backends.cudnn as cudnn
6 | import argparse
7 | import os
8 |
9 | from solver import train_epoch, val_epoch, test_epoch
10 | from sampler import ChunkSampler
11 | from model import Model
12 | from loss import MixedLoss
13 | from v2v_util import V2VVoxelization
14 | from msra_hand import MARAHandDataset
15 |
16 |
17 | #######################################################################################
18 | # Note,
19 | # Run in the project root directory (ROOT_DIR) with:
20 | #   PYTHONPATH=./ python experiments/msra-subject3/main.py
21 | #
22 | # This script trains the model on the MSRA hand dataset, saves checkpoints to ROOT_DIR/checkpoint,
23 | # and saves test results (test_res.txt) and fit results (fit_res.txt) to ROOT_DIR.
24 | #
25 |
26 |
27 | #######################################################################################
28 | ## Some helpers
29 | def parse_args():
30 | parser = argparse.ArgumentParser(description='PyTorch Hand Keypoints Estimation Training')
31 | #parser.add_argument('--resume', 'r', action='store_true', help='resume from checkpoint')
32 | parser.add_argument('--resume', '-r', default=-1, type=int, help='resume after epoch')
33 | args = parser.parse_args()
34 | return args
35 |
36 |
37 | #######################################################################################
38 | ## Configurations
39 | print('Warning: disable cudnn for batchnorm first, or just use plain cuda instead!')
40 |
41 | # When resuming training, keep randomness enabled to avoid re-seeing the same deterministic
42 | # (augmented) samples many times.
43 | # np.random.seed(1)
44 | # torch.manual_seed(1)
45 | # torch.cuda.manual_seed(1)
46 |
47 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
48 | dtype = torch.float
49 |
50 | #
51 | args = parse_args()
52 | resume_train = args.resume >= 0
53 | resume_after_epoch = args.resume
54 |
55 | save_checkpoint = True
56 | checkpoint_per_epochs = 1
57 | checkpoint_dir = r'./checkpoint'
58 |
59 | start_epoch = 0
60 | epochs_num = 15
61 |
62 | batch_size = 12
63 |
64 |
65 | #######################################################################################
66 | ## Data, transform, dataset and loader
67 | # Data
68 | print('==> Preparing data ..')
69 | data_dir = r'/home/maiqi/yalong/dataset/cvpr15_MSRAHandGestureDB'
70 | center_dir = r'/home/maiqi/yalong/project/KeyPoint/Code/V2V-PoseNet-Rlease-Codes/V2V-PoseNet_RELEASE-hand/data-result/MSRA-result/center'
71 | keypoints_num = 21
72 | test_subject_id = 3
73 | cubic_size = 200
74 |
75 |
76 | # Transform
77 | voxelization_train = V2VVoxelization(cubic_size=200, augmentation=True)
78 | voxelization_val = V2VVoxelization(cubic_size=200, augmentation=False)
79 |
80 |
81 | def transform_train(sample):
82 | points, keypoints, refpoint = sample['points'], sample['joints'], sample['refpoint']
83 | assert(keypoints.shape[0] == keypoints_num)
84 |
85 | input, heatmap, coord = voxelization_train({'points': points, 'keypoints': keypoints, 'refpoint': refpoint})
86 |
87 | sample = {
88 | 'input': input,
89 | 'target': {
90 | 'heatmap': heatmap,
91 | 'coord': coord
92 | },
93 | 'extra': {
94 | 'refpoint': refpoint.reshape((1, -1))
95 | }
96 | }
97 |
98 | return sample
99 |
100 |
101 |
102 | def transform_val(sample):
103 | points, keypoints, refpoint = sample['points'], sample['joints'], sample['refpoint']
104 | assert(keypoints.shape[0] == keypoints_num)
105 |
106 | input, heatmap, coord = voxelization_val({'points': points, 'keypoints': keypoints, 'refpoint': refpoint})
107 |
108 | sample = {
109 | 'input': input,
110 | 'target': {
111 | 'heatmap': heatmap,
112 | 'coord': coord
113 | },
114 | 'extra': {
115 | 'refpoint': refpoint.reshape((1, -1))
116 | }
117 | }
118 |
119 | return sample
120 |
121 |
122 | # Dataset and loader
123 | train_set = MARAHandDataset(data_dir, center_dir, 'train', test_subject_id, transform_train)
124 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=6)
125 | # train_num = 24
126 | # train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=False, num_workers=6,sampler=ChunkSampler(train_num, 0))
127 |
128 | # No separate validation dataset, just use test dataset instead
129 | val_set = MARAHandDataset(data_dir, center_dir, 'test', test_subject_id, transform_val)
130 | val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=6)
131 | # val_num = 24
132 | # val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=6, sampler=ChunkSampler(val_num))
133 |
134 |
135 | #######################################################################################
136 | ## Model, criterion and optimizer
137 | print('==> Constructing model ..')
138 | output_res = 44
139 | net = Model(in_channels=1, out_channels=keypoints_num, output_res=output_res)
140 |
141 | net = net.to(device, dtype)
142 | if device == torch.device('cuda'):
143 | torch.backends.cudnn.enabled = True
144 | cudnn.benchmark = True
145 | print('cudnn.enabled: ', torch.backends.cudnn.enabled)
146 |
147 |
148 | criterion = MixedLoss()
149 | optimizer = optim.Adam(net.parameters())
150 |
151 |
152 | #######################################################################################
153 | ## Resume
154 | if resume_train:
155 | # Load checkpoint
156 | epoch = resume_after_epoch
157 | checkpoint_file = os.path.join(checkpoint_dir, 'epoch'+str(epoch)+'.pth')
158 |
159 | print('==> Resuming from checkpoint after epoch {} ..'.format(epoch))
160 | assert os.path.isdir(checkpoint_dir), 'Error: no checkpoint directory found!'
161 | assert os.path.isfile(checkpoint_file), 'Error: no checkpoint file of epoch {}'.format(epoch)
162 |
163 | checkpoint = torch.load(os.path.join(checkpoint_dir, 'epoch'+str(epoch)+'.pth'))
164 | net.load_state_dict(checkpoint['model_state_dict'])
165 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
166 | start_epoch = checkpoint['epoch'] + 1
167 |
168 |
169 | #######################################################################################
170 | ## Train and Validate
171 | print('==> Training ..')
172 | for epoch in range(start_epoch, start_epoch + epochs_num):
173 | print('Epoch: {}'.format(epoch))
174 | train_epoch(net, criterion, optimizer, train_loader, device=device, dtype=dtype)
175 | val_epoch(net, criterion, val_loader, device=device, dtype=dtype)
176 |
177 | if save_checkpoint and epoch % checkpoint_per_epochs == 0:
178 | if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir)
179 | checkpoint_file = os.path.join(checkpoint_dir, 'epoch'+str(epoch)+'.pth')
180 | checkpoint = {
181 | 'model_state_dict': net.state_dict(),
182 | 'optimizer_state_dict': optimizer.state_dict(),
183 | 'epoch': epoch
184 | }
185 | torch.save(checkpoint, checkpoint_file)
186 |
187 |
188 | #######################################################################################
189 | ## Test
190 | print('==> Testing ..')
191 |
192 | def transform_test(sample):
193 | points, keypoints, refpoint = sample['points'], sample['joints'], sample['refpoint']
194 | assert(keypoints.shape[0] == keypoints_num)
195 |
196 | input, heatmap, coord = voxelization_val({'points': points, 'keypoints': keypoints, 'refpoint': refpoint})
197 |
198 | sample = {
199 | 'input': input,
200 | 'target': {
201 | 'heatmap': heatmap,
202 | 'coord': coord
203 | },
204 | 'extra': {
205 | 'refpoint': refpoint.reshape((1, -1))
206 | }
207 | }
208 |
209 | return sample
210 |
211 |
212 | def transform_coord(coords, refpoints):
213 | keypoints = voxelization_val.warp2continuous_raw(coords, refpoints)
214 | return keypoints
215 |
216 | transform_output = transform_coord
217 |
218 |
219 | class BatchResultCollector():
220 | def __init__(self, samples_num, transform_output):
221 | self.samples_num = samples_num
222 | self.transform_output = transform_output
223 | self.keypoints = None
224 | self.idx = 0
225 |
226 | def __call__(self, data_batch):
227 | outputs_batch = data_batch['output']['coord']
228 | refpoints_batch = data_batch['extra']['refpoint']
229 |
230 | outputs_batch = outputs_batch.cpu().numpy()
231 | refpoints_batch = refpoints_batch.numpy()
232 |
233 | keypoints_batch = self.transform_output(outputs_batch, refpoints_batch)
234 |
235 | if self.keypoints is None:
236 |             # Lazily initialize keypoints now that the output dimensions are available
237 | self.keypoints = np.zeros((self.samples_num, *keypoints_batch.shape[1:]))
238 |
239 | batch_size = keypoints_batch.shape[0]
240 | self.keypoints[self.idx:self.idx+batch_size] = keypoints_batch
241 | self.idx += batch_size
242 |
243 | def get_result(self):
244 | return self.keypoints
245 |
246 |
247 | print('Test on test dataset ..')
248 | def save_keypoints(filename, keypoints):
249 |     # Reshape each sample's keypoints into a single line
250 | keypoints = keypoints.reshape(keypoints.shape[0], -1)
251 | np.savetxt(filename, keypoints, fmt='%0.4f')
252 |
253 |
254 | test_set = MARAHandDataset(data_dir, center_dir, 'test', test_subject_id, transform_test)
255 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=6)
256 | test_res_collector = BatchResultCollector(len(test_set), transform_output)
257 |
258 | test_epoch(net, test_loader, test_res_collector, device, dtype)
259 | keypoints_test = test_res_collector.get_result()
260 | save_keypoints('./test_res.txt', keypoints_test)
261 |
262 |
263 | print('Fit on train dataset ..')
264 | fit_set = MARAHandDataset(data_dir, center_dir, 'train', test_subject_id, transform_test)
265 | fit_loader = torch.utils.data.DataLoader(fit_set, batch_size=batch_size, shuffle=False, num_workers=6)
266 | fit_res_collector = BatchResultCollector(len(fit_set), transform_output)
267 |
268 | test_epoch(net, fit_loader, fit_res_collector, device, dtype)
269 | keypoints_fit = fit_res_collector.get_result()
270 | save_keypoints('./fit_res.txt', keypoints_fit)
271 |
272 | print('All done ..')
273 |
--------------------------------------------------------------------------------
/integral-pose/model.py:
--------------------------------------------------------------------------------
1 | # from collections import namedtuple
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from v2v_model import V2VModel
6 | import numpy as np
7 |
8 |
9 | class VolumetricSoftmax(nn.Module):
10 | '''
11 | TODO: soft-argmax: norm coord to [-1, 1], instead of [0, N]
12 |
13 | ref: https://gist.github.com/jeasinema/1cba9b40451236ba2cfb507687e08834
14 | '''
15 |
16 | def __init__(self, channel, sizes):
17 | super(VolumetricSoftmax, self).__init__()
18 | self.channel = channel
19 | self.xsize, self.ysize, self.zsize = sizes[0], sizes[1], sizes[2]
20 | self.volume_size = self.xsize * self.ysize * self.zsize
21 |
22 | # TODO: optimize, compute x, y, z together
23 | # pos = np.meshgrid(np.arange(self.xsize), np.arange(self.ysize), np.arange(self.zsize), indexing='ij')
24 | # pos = np.array(pos).reshape((3, -1))
25 | # pos = torch.from_numpy(pos)
26 | # self.register_buffer('pos', pos)
27 |
28 | pos_x, pos_y, pos_z = np.meshgrid(np.arange(self.xsize), np.arange(self.ysize), np.arange(self.zsize), indexing='ij')
29 |
30 | pos_x = torch.from_numpy(pos_x.reshape((-1))).float()
31 | pos_y = torch.from_numpy(pos_y.reshape((-1))).float()
32 | pos_z = torch.from_numpy(pos_z.reshape((-1))).float()
33 |
34 | self.register_buffer('pos_x', pos_x)
35 | self.register_buffer('pos_y', pos_y)
36 | self.register_buffer('pos_z', pos_z)
37 |
38 | def forward(self, x):
39 | # x: (N, C, X, Y, Z)
40 | x = x.view(-1, self.volume_size)
41 | p = F.softmax(x, dim=1)
42 |
43 | #print('self.pos_x: {}, device: {}, dtype: {}'.format(type(self.pos_x), self.pos_x.device, self.pos_x.dtype))
44 | #print('p: {}, device: {}, dtype: {}'.format(type(p), p.device, p.dtype))
45 |
46 | expected_x = torch.sum(self.pos_x * p, dim=1, keepdim=True)
47 | expected_y = torch.sum(self.pos_y * p, dim=1, keepdim=True)
48 | expected_z = torch.sum(self.pos_z * p, dim=1, keepdim=True)
49 |
50 | expected_xyz = torch.cat([expected_x, expected_y, expected_z], 1)
51 | out = expected_xyz.view(-1, self.channel, 3)
52 |
53 | return out
54 |
55 |
56 |
57 | class Model(nn.Module):
58 | def __init__(self, in_channels, out_channels, output_res=44):
59 | super(Model, self).__init__()
60 | self.output_res = output_res
61 | self.basic_model = V2VModel(in_channels, out_channels)
62 | self.spatial_softmax = VolumetricSoftmax(out_channels, (self.output_res, self.output_res, self.output_res))
63 |
64 | def forward(self, x):
65 | heatmap = self.basic_model(x)
66 | coord = self.spatial_softmax(heatmap)
67 |
68 | #print('model heatmap: {}'.format(heatmap.dtype))
69 |
70 | output = {
71 | 'heatmap': heatmap,
72 | 'coord': coord
73 | }
74 |
75 | return output
76 |
--------------------------------------------------------------------------------
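A small sanity check (not part of the repository) for the soft-argmax implemented by VolumetricSoftmax above: a single strong logit at voxel (1, 2, 3) of a 4^3 volume should produce an expected coordinate very close to (1, 2, 3):

    import torch
    from model import VolumetricSoftmax   # assumes integral-pose/ is on PYTHONPATH

    vs = VolumetricSoftmax(channel=1, sizes=(4, 4, 4))
    x = torch.zeros(1, 1, 4, 4, 4)
    x[0, 0, 1, 2, 3] = 20.0     # large logit -> nearly all softmax mass at this voxel
    coord = vs(x)               # (1, 1, 3) soft-argmax expectation
    print(coord)                # approximately [[[1., 2., 3.]]]
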
/integral-pose/msra_hand.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import sys
4 | import struct
5 | from torch.utils.data import Dataset
6 |
7 |
8 | def pixel2world(x, y, z, img_width, img_height, fx, fy):
9 | w_x = (x - img_width / 2) * z / fx
10 | w_y = (img_height / 2 - y) * z / fy
11 | w_z = z
12 | return w_x, w_y, w_z
13 |
14 |
15 | def world2pixel(x, y, z, img_width, img_height, fx, fy):
16 | p_x = x * fx / z + img_width / 2
17 | p_y = img_height / 2 - y * fy / z
18 | return p_x, p_y
19 |
20 |
21 | def depthmap2points(image, fx, fy):
22 | h, w = image.shape
23 | x, y = np.meshgrid(np.arange(w) + 1, np.arange(h) + 1)
24 | points = np.zeros((h, w, 3), dtype=np.float32)
25 | points[:,:,0], points[:,:,1], points[:,:,2] = pixel2world(x, y, image, w, h, fx, fy)
26 | return points
27 |
28 |
29 | def points2pixels(points, img_width, img_height, fx, fy):
30 | pixels = np.zeros((points.shape[0], 2))
31 | pixels[:, 0], pixels[:, 1] = \
32 | world2pixel(points[:,0], points[:, 1], points[:, 2], img_width, img_height, fx, fy)
33 | return pixels
34 |
35 |
36 | def load_depthmap(filename, img_width, img_height, max_depth):
37 | with open(filename, mode='rb') as f:
38 | data = f.read()
39 | _, _, left, top, right, bottom = struct.unpack('I'*6, data[:6*4])
40 | num_pixel = (right - left) * (bottom - top)
41 | cropped_image = struct.unpack('f'*num_pixel, data[6*4:])
42 |
43 | cropped_image = np.asarray(cropped_image).reshape(bottom-top, -1)
44 | depth_image = np.zeros((img_height, img_width), dtype=np.float32)
45 | depth_image[top:bottom, left:right] = cropped_image
46 | depth_image[depth_image == 0] = max_depth
47 |
48 | return depth_image
49 |
50 |
51 | class MARAHandDataset(Dataset):
52 | def __init__(self, root, center_dir, mode, test_subject_id, transform=None):
53 | self.img_width = 320
54 | self.img_height = 240
55 | self.min_depth = 100
56 | self.max_depth = 700
57 | self.fx = 241.42
58 | self.fy = 241.42
59 | self.joint_num = 21
60 | self.world_dim = 3
61 | self.folder_list = ['1','2','3','4','5','6','7','8','9','I','IP','L','MP','RP','T','TIP','Y']
62 | self.subject_num = 9
63 |
64 | self.root = root
65 | self.center_dir = center_dir
66 | self.mode = mode
67 | self.test_subject_id = test_subject_id
68 | self.transform = transform
69 |
70 | if not self.mode in ['train', 'test']: raise ValueError('Invalid mode')
71 | assert self.test_subject_id >= 0 and self.test_subject_id < self.subject_num
72 |
73 | if not self._check_exists(): raise RuntimeError('Invalid MSRA hand dataset')
74 |
75 | self._load()
76 |
77 | def __getitem__(self, index):
78 | depthmap = load_depthmap(self.names[index], self.img_width, self.img_height, self.max_depth)
79 | points = depthmap2points(depthmap, self.fx, self.fy)
80 | points = points.reshape((-1, 3))
81 |
82 | sample = {
83 | 'name': self.names[index],
84 | 'points': points,
85 | 'joints': self.joints_world[index],
86 | 'refpoint': self.ref_pts[index]
87 | }
88 |
89 | if self.transform: sample = self.transform(sample)
90 |
91 | return sample
92 |
93 | def __len__(self):
94 | return self.num_samples
95 |
96 | def _load(self):
97 | self._compute_dataset_size()
98 |
99 | self.num_samples = self.train_size if self.mode == 'train' else self.test_size
100 | self.joints_world = np.zeros((self.num_samples, self.joint_num, self.world_dim))
101 | self.ref_pts = np.zeros((self.num_samples, self.world_dim))
102 | self.names = []
103 |
104 |         # Collect the reference center point strings
105 | if self.mode == 'train': ref_pt_file = 'center_train_' + str(self.test_subject_id) + '_refined.txt'
106 | else: ref_pt_file = 'center_test_' + str(self.test_subject_id) + '_refined.txt'
107 |
108 | with open(os.path.join(self.center_dir, ref_pt_file)) as f:
109 | ref_pt_str = [l.rstrip() for l in f]
110 |
111 | #
112 | file_id = 0
113 | frame_id = 0
114 |
115 | for mid in range(self.subject_num):
116 | if self.mode == 'train': model_chk = (mid != self.test_subject_id)
117 | elif self.mode == 'test': model_chk = (mid == self.test_subject_id)
118 | else: raise RuntimeError('unsupported mode {}'.format(self.mode))
119 |
120 | if model_chk:
121 | for fd in self.folder_list:
122 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt')
123 |
124 | lines = []
125 | with open(annot_file) as f:
126 | lines = [line.rstrip() for line in f]
127 |
128 | # skip first line
129 | for i in range(1, len(lines)):
130 |                     # reference point
131 | splitted = ref_pt_str[file_id].split()
132 | if splitted[0] == 'invalid':
133 | print('Warning: found invalid reference frame')
134 | file_id += 1
135 | continue
136 | else:
137 | self.ref_pts[frame_id, 0] = float(splitted[0])
138 | self.ref_pts[frame_id, 1] = float(splitted[1])
139 | self.ref_pts[frame_id, 2] = float(splitted[2])
140 |
141 | # joint point
142 | splitted = lines[i].split()
143 | for jid in range(self.joint_num):
144 | self.joints_world[frame_id, jid, 0] = float(splitted[jid * self.world_dim])
145 | self.joints_world[frame_id, jid, 1] = float(splitted[jid * self.world_dim + 1])
146 | self.joints_world[frame_id, jid, 2] = -float(splitted[jid * self.world_dim + 2])
147 |
148 | filename = os.path.join(self.root, 'P'+str(mid), fd, '{:0>6d}'.format(i-1) + '_depth.bin')
149 | self.names.append(filename)
150 |
151 | frame_id += 1
152 | file_id += 1
153 |
154 | def _compute_dataset_size(self):
155 | self.train_size, self.test_size = 0, 0
156 |
157 | for mid in range(self.subject_num):
158 | num = 0
159 | for fd in self.folder_list:
160 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt')
161 | with open(annot_file) as f:
162 | num = int(f.readline().rstrip())
163 | if mid == self.test_subject_id: self.test_size += num
164 | else: self.train_size += num
165 |
166 | def _check_exists(self):
167 | # Check basic data
168 | for mid in range(self.subject_num):
169 | for fd in self.folder_list:
170 | annot_file = os.path.join(self.root, 'P'+str(mid), fd, 'joint.txt')
171 | if not os.path.exists(annot_file):
172 | print('Error: annotation file {} does not exist'.format(annot_file))
173 | return False
174 |
175 |         # Check the precomputed centers provided by the V2V-PoseNet author
176 | for subject_id in range(self.subject_num):
177 | center_train = os.path.join(self.center_dir, 'center_train_' + str(subject_id) + '_refined.txt')
178 | center_test = os.path.join(self.center_dir, 'center_test_' + str(subject_id) + '_refined.txt')
179 | if not os.path.exists(center_train) or not os.path.exists(center_test):
180 | print('Error: precomputed center files do not exist')
181 | return False
182 |
183 | return True
184 |
--------------------------------------------------------------------------------
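A round-trip check (illustrative, not part of the repository) for the pixel2world / world2pixel helpers above, using the same MSRA camera constants as the dataset class (320x240 image, fx = fy = 241.42); the pixel coordinate and depth are arbitrary example values:

    import numpy as np
    from msra_hand import pixel2world, world2pixel   # assumes integral-pose/ is on PYTHONPATH

    u, v, z = 160.0, 120.0, 400.0   # pixel coords and depth in mm (example values)
    wx, wy, wz = pixel2world(u, v, z, 320, 240, 241.42, 241.42)
    pu, pv = world2pixel(wx, wy, wz, 320, 240, 241.42, 241.42)
    print(np.allclose([pu, pv], [u, v]))   # True: projection inverts back-projection
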
/integral-pose/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 |
4 | def plot_acc(ax, dist, acc, name):
5 | '''
6 | acc: (K, num)
7 | dist: (K, )
8 | name: (K, )
9 | '''
10 | assert(acc.shape[0] == len(name))
11 |
12 | for i in range(len(name)):
13 | ax.plot(dist, acc[i], label=name[i])
14 |
15 | ax.legend()
16 |
17 | ax.set_xlabel('Maximum allowed distance to GT (mm)')
18 | ax.set_ylabel('Fraction of samples within distance')
19 |
20 |
21 | def plot_mean_err(ax, mean_err, name):
22 | '''
23 | mean_err: (K, )
24 | name: (K, )
25 | '''
26 | ax.bar(name, mean_err)
27 |
--------------------------------------------------------------------------------
/integral-pose/progressbar.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 |
5 |
6 | _, term_width = os.popen('stty size', 'r').read().split()
7 | term_width = int(term_width)
8 | TOTAL_BAR_LENGTH = 65.
9 | last_time = time.time()
10 | begin_time = last_time
11 |
12 | def progress_bar(current, total, msg=None):
13 | global last_time, begin_time
14 | if current == 0:
15 | begin_time = time.time() # Reset for new bar.
16 |
17 | cur_len = int(TOTAL_BAR_LENGTH*current/total)
18 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
19 |
20 | sys.stdout.write(' [')
21 | for i in range(cur_len):
22 | sys.stdout.write('=')
23 | sys.stdout.write('>')
24 | for i in range(rest_len):
25 | sys.stdout.write('.')
26 | sys.stdout.write(']')
27 |
28 | cur_time = time.time()
29 | step_time = cur_time - last_time
30 | last_time = cur_time
31 | tot_time = cur_time - begin_time
32 |
33 | L = []
34 | L.append(' Step: %s' % format_time(step_time))
35 | L.append(' | Tot: %s' % format_time(tot_time))
36 | if msg:
37 | L.append(' | ' + msg)
38 |
39 | msg = ''.join(L)
40 | sys.stdout.write(msg)
41 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
42 | sys.stdout.write(' ')
43 |
44 | # Go back to the center of the bar.
45 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
46 | sys.stdout.write('\b')
47 | sys.stdout.write(' %d/%d ' % (current+1, total))
48 |
49 | if current < total-1:
50 | sys.stdout.write('\r')
51 | else:
52 | sys.stdout.write('\n')
53 | sys.stdout.flush()
54 |
55 |
56 | def format_time(seconds):
57 | days = int(seconds / 3600/24)
58 | seconds = seconds - days*3600*24
59 | hours = int(seconds / 3600)
60 | seconds = seconds - hours*3600
61 | minutes = int(seconds / 60)
62 | seconds = seconds - minutes*60
63 | secondsf = int(seconds)
64 | seconds = seconds - secondsf
65 | millis = int(seconds*1000)
66 |
67 | f = ''
68 | i = 1
69 | if days > 0:
70 | f += str(days) + 'D'
71 | i += 1
72 | if hours > 0 and i <= 2:
73 | f += str(hours) + 'h'
74 | i += 1
75 | if minutes > 0 and i <= 2:
76 | f += str(minutes) + 'm'
77 | i += 1
78 | if secondsf > 0 and i <= 2:
79 | f += str(secondsf) + 's'
80 | i += 1
81 | if millis > 0 and i <= 2:
82 | f += str(millis) + 'ms'
83 | i += 1
84 | if f == '':
85 | f = '0ms'
86 | return f
87 |
--------------------------------------------------------------------------------
/integral-pose/sampler.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import sampler
2 | from torch.utils.data import Dataset
3 |
4 |
5 | class ChunkSampler(sampler.Sampler):
6 | """Samples elements sequentially from some offset.
7 | Arguments:
8 | num_samples: # of desired datapoints
9 | start: offset where we should start selecting from
10 | """
11 | def __init__(self, num_samples, start=0):
12 | self.num_samples = num_samples
13 | self.start = start
14 |
15 | def __iter__(self):
16 | return iter(range(self.start, self.start + self.num_samples))
17 |
18 | def __len__(self):
19 | return self.num_samples
20 |
21 |
22 | class ChunkDataset(Dataset):
23 | '''
24 |     A wrapper around a common dataset, exposing a contiguous chunk of it
25 | '''
26 | def __init__(self, data_set, num_samples, start=0):
27 | self.data_set = data_set
28 | self.num_samples = num_samples
29 | self.start = start
30 | assert(self.start + self.num_samples <= len(data_set))
31 |
32 | def __getitem__(self, index):
33 |         return self.data_set[self.start + index]  # offset by start so the chunk begins at the configured position
34 |
35 | def __len__(self):
36 | return self.num_samples
37 |
--------------------------------------------------------------------------------
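A minimal usage sketch (not part of the repository) of the ChunkSampler above, pulling a fixed, ordered slice of a dataset through a DataLoader, as the commented-out debugging loaders in main.py do; the toy TensorDataset is an assumption for illustration:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from sampler import ChunkSampler   # assumes integral-pose/ is on PYTHONPATH

    data = TensorDataset(torch.arange(1000).float())
    loader = DataLoader(data, batch_size=8, sampler=ChunkSampler(num_samples=16, start=100))
    for (batch,) in loader:
        print(batch)   # tensors [100..107] and [108..115]
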
/integral-pose/show_acc.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | from accuracy import *
5 | from plot import *
6 |
7 |
8 | gt_file = r'./test_s3_gt.txt'
9 | pred_file = r'./test_res.txt'
10 |
11 |
12 | gt = np.loadtxt(gt_file)
13 | gt = gt.reshape(gt.shape[0], -1, 3)
14 |
15 | pred = np.loadtxt(pred_file)
16 | pred = pred.reshape(pred.shape[0], -1, 3)
17 |
18 | print('gt: ', gt.shape)
19 | print('pred: ', pred.shape)
20 |
21 |
22 | keypoints_num = 21
23 | names = ['joint'+str(i+1) for i in range(keypoints_num)]
24 |
25 |
26 | dist, acc = compute_dist_acc_wrapper(pred, gt, max_dist=100, num=100)
27 |
28 | fig, ax = plt.subplots()
29 | plot_acc(ax, dist, acc, names)
30 | fig.savefig('msra_s3_joint_acc.png')
31 | plt.show()
32 |
33 |
34 | mean_err = compute_mean_err(pred, gt)
35 | fig, ax = plt.subplots()
36 | plot_mean_err(ax, mean_err, names)
37 | fig.savefig('msra_s3_joint_mean_error.png')
38 | plt.show()
39 |
40 |
41 | print('mean_err: {}'.format(mean_err))
42 | mean_err_all = compute_mean_err(pred.reshape((-1, 1, 3)), gt.reshape((-1, 1,3)))
43 | print('mean_err_all: ', mean_err_all)
44 |
--------------------------------------------------------------------------------
/integral-pose/solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from torch.utils.data import DataLoader
4 | import os
5 | from progressbar import progress_bar
6 |
7 |
8 | def train_epoch(model, criterion, optimizer, train_loader, device=torch.device('cuda'), dtype=torch.float, collector=None):
9 | model.train()
10 | train_loss = 0
11 |
12 | for batch_idx, batch_data in enumerate(train_loader):
13 | input, target, extra = batch_data['input'], batch_data['target'], batch_data['extra']
14 |
15 | input = input.to(device, dtype)
16 |
17 | if isinstance(target, torch.Tensor):
18 | target = target.to(device, dtype)
19 | elif isinstance(target, dict):
20 | for k in target:
21 | if isinstance(target[k], torch.Tensor):
22 | target[k] = target[k].to(device, dtype)
23 |
24 | #print('solver target[heatmap]: ', target['heatmap'].dtype)
25 |
26 | optimizer.zero_grad()
27 | output = model(input)
28 | loss = criterion(output, target)
29 | loss.backward()
30 | optimizer.step()
31 |
32 | if collector is None:
33 | train_loss += loss.item()
34 | progress_bar(batch_idx, len(train_loader), 'Loss: {0:.4e}'.format(train_loss/(batch_idx+1)))
35 | #print('loss: {0: .4e}'.format(train_loss/(batch_idx+1)))
36 | else:
37 | model.eval()
38 | with torch.no_grad():
39 | extra['batch_idx'], extra['loader_len'], extra['batch_avg_loss'] = batch_idx, len(train_loader), loss.item()
40 | collector({'model': model, 'input': input, 'target': target, 'output': output, 'extra': extra})
41 |
42 | # Keep train mode
43 | model.train()
44 |
45 |
46 | def val_epoch(model, criterion, val_loader, device=torch.device('cuda'), dtype=torch.float, collector=None):
47 | model.eval()
48 | val_loss = 0
49 |
50 | with torch.no_grad():
51 | for batch_idx, batch_data in enumerate(val_loader):
52 | input, target, extra = batch_data['input'], batch_data['target'], batch_data['extra']
53 |
54 | input = input.to(device, dtype)
55 |
56 | if isinstance(target, torch.Tensor):
57 | target = target.to(device, dtype)
58 | elif isinstance(target, dict):
59 | for k in target:
60 | if isinstance(target[k], torch.Tensor):
61 | target[k] = target[k].to(device, dtype)
62 |
63 | output = model(input)
64 | loss = criterion(output, target)
65 |
66 | if collector is None:
67 | val_loss += loss.item()
68 | progress_bar(batch_idx, len(val_loader), 'Loss: {0:.4e}'.format(val_loss/(batch_idx+1)))
69 | #print('loss: {0: .4e}'.format(val_loss/(batch_idx+1)))
70 | else:
71 | extra['batch_idx'], extra['loader_len'], extra['batch_avg_loss'] = batch_idx, len(val_loader), loss.item()
72 | collector({'model': model, 'input': input, 'target': target, 'output': output, 'extra': extra})
73 |
74 | # Keep eval mode
75 | model.eval()
76 |
77 |
78 | def test_epoch(model, test_loader, collector, device=torch.device('cuda'), dtype=torch.float):
79 | model.eval()
80 |
81 | with torch.no_grad():
82 | for batch_idx, batch_data in enumerate(test_loader):
83 | input, target, extra = batch_data['input'], batch_data['target'], batch_data['extra']
84 | output = model(input.to(device, dtype))
85 |
86 | extra['batch_idx'], extra['loader_len'] = batch_idx, len(test_loader)
87 | collector({'model': model, 'input': input, 'target': target, 'output': output, 'extra': extra})
88 |
89 | # Keep eval mode
90 | model.eval()
91 |
92 |
93 | class Solver():
94 | def __init__(self, train_set, model, criterion, optimizer, device=torch.device('cuda'), dtype=torch.float, **kwargs):
95 | self.train_set = train_set
96 | self.model = model
97 | self.criterion = criterion
98 | self.optimizer = optimizer
99 | self.device = device
100 | self.dtype = dtype
101 |
102 | self.batch_size = kwargs.pop('batch_size', 1)
103 | self.num_epochs = kwargs.pop('num_epochs', 1)
104 | self.num_workers = kwargs.pop('num_workers', 6)
105 |
106 | self.val_set = kwargs.pop('val_set', None)
107 |
108 | # Result collectors
109 | self.train_collector = kwargs.pop('train_collector', None)
110 | self.val_collector = kwargs.pop('val_collector', None)
111 |
112 | # Save check point and resume
113 | self.checkpoint_config = kwargs.pop('checkpoint_config', None)
114 | if self.checkpoint_config is not None:
115 | self.save_checkpoint = self.checkpoint_config['save_checkpoint']
116 | self.checkpoint_dir = self.checkpoint_config['checkpoint_dir']
117 | self.checkpoint_per_epochs = self.checkpoint_config['checkpoint_per_epochs']
118 |
119 | self.resume_training = self.checkpoint_config['resume_training']
120 | self.resume_after_epoch = self.checkpoint_config['resume_after_epoch']
121 | else:
122 | self.save_checkpoint = False
123 | self.resume_training = False
124 |
125 | self._init()
126 |
127 | def _init(self):
128 | self.train_loader = DataLoader(self.train_set, self.batch_size, shuffle=True, num_workers=self.num_workers)
129 |
130 | if self.val_set is not None:
131 | self.val_loader = DataLoader(self.val_set, self.batch_size, shuffle=False, num_workers=self.num_workers)
132 |
133 | self.start_epoch = 0
134 |
135 | if self.resume_training:
136 | self._load_checkpoint(self.resume_after_epoch)
137 |
138 | def _train_epoch(self):
139 | train_epoch(self.model, self.criterion, self.optimizer, self.train_loader,
140 | self.device, self.dtype,
141 | self.train_collector)
142 |
143 | def _val_epoch(self):
144 | val_epoch(self.model, self.criterion, self.val_loader,
145 | self.device, self.dtype,
146 | self.val_collector)
147 |
148 | def _save_checkpoint(self, epoch):
149 | if not os.path.exists(self.checkpoint_dir): os.mkdir(self.checkpoint_dir)
150 | checkpoint_file = os.path.join(self.checkpoint_dir, 'epoch'+str(epoch)+'.pth')
151 |
152 | checkpoint = {
153 | 'model_state_dict': self.model.state_dict(),
154 | 'optimizer_state_dict': self.optimizer.state_dict(),
155 | 'epoch': epoch
156 | }
157 |
158 | torch.save(checkpoint, checkpoint_file)
159 |
160 | def _load_checkpoint(self, epoch):
161 | checkpoint_file = os.path.join(self.checkpoint_dir, 'epoch'+str(epoch)+'.pth')
162 |
163 | print('==> Resuming from checkpoint after epoch {} ..'.format(epoch))
164 | assert os.path.isdir(self.checkpoint_dir), 'Error: no checkpoint directory found!'
165 | assert os.path.isfile(checkpoint_file), 'Error: no checkpoint file of epoch {}'.format(epoch)
166 |
167 | checkpoint = torch.load(os.path.join(self.checkpoint_dir, 'epoch'+str(epoch)+'.pth'))
168 | self.model.load_state_dict(checkpoint['model_state_dict'])
169 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
170 | self.start_epoch = checkpoint['epoch'] + 1
171 |
172 | def train(self):
173 | for epoch in range(self.start_epoch, self.start_epoch + self.num_epochs):
174 | print('Epoch {}: '.format(epoch))
175 | self._train_epoch()
176 |
177 | if self.val_set is not None:
178 | self._val_epoch()
179 |
180 | if self.save_checkpoint and epoch % self.checkpoint_per_epochs == 0:
181 | self._save_checkpoint(epoch)
182 |
--------------------------------------------------------------------------------
/integral-pose/v2v_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class Basic3DBlock(nn.Module):
6 | def __init__(self, in_planes, out_planes, kernel_size):
7 | super(Basic3DBlock, self).__init__()
8 | self.block = nn.Sequential(
9 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1, padding=((kernel_size-1)//2)),
10 | nn.BatchNorm3d(out_planes),
11 | nn.ReLU(True)
12 | )
13 |
14 | def forward(self, x):
15 | return self.block(x)
16 |
17 |
18 | class Res3DBlock(nn.Module):
19 | def __init__(self, in_planes, out_planes):
20 | super(Res3DBlock, self).__init__()
21 | self.res_branch = nn.Sequential(
22 | nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=1, padding=1),
23 | nn.BatchNorm3d(out_planes),
24 | nn.ReLU(True),
25 | nn.Conv3d(out_planes, out_planes, kernel_size=3, stride=1, padding=1),
26 | nn.BatchNorm3d(out_planes)
27 | )
28 |
29 | if in_planes == out_planes:
30 | self.skip_con = nn.Sequential()
31 | else:
32 | self.skip_con = nn.Sequential(
33 | nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=1, padding=0),
34 | nn.BatchNorm3d(out_planes)
35 | )
36 |
37 | def forward(self, x):
38 | res = self.res_branch(x)
39 | skip = self.skip_con(x)
40 | return F.relu(res + skip, True)
41 |
42 |
43 | class Pool3DBlock(nn.Module):
44 | def __init__(self, pool_size):
45 | super(Pool3DBlock, self).__init__()
46 | self.pool_size = pool_size
47 |
48 | def forward(self, x):
49 | return F.max_pool3d(x, kernel_size=self.pool_size, stride=self.pool_size)
50 |
51 |
52 | class Upsample3DBlock(nn.Module):
53 | def __init__(self, in_planes, out_planes, kernel_size, stride):
54 | super(Upsample3DBlock, self).__init__()
55 | assert(kernel_size == 2)
56 | assert(stride == 2)
57 | self.block = nn.Sequential(
58 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0, output_padding=0),
59 | nn.BatchNorm3d(out_planes),
60 | nn.ReLU(True)
61 | )
62 |
63 | def forward(self, x):
64 | return self.block(x)
65 |
66 |
67 | class EncoderDecorder(nn.Module):
68 | def __init__(self):
69 | super(EncoderDecorder, self).__init__()
70 |
71 | self.encoder_pool1 = Pool3DBlock(2)
72 | self.encoder_res1 = Res3DBlock(32, 64)
73 | self.encoder_pool2 = Pool3DBlock(2)
74 | self.encoder_res2 = Res3DBlock(64, 128)
75 |
76 | self.mid_res = Res3DBlock(128, 128)
77 |
78 | self.decoder_res2 = Res3DBlock(128, 128)
79 | self.decoder_upsample2 = Upsample3DBlock(128, 64, 2, 2)
80 | self.decoder_res1 = Res3DBlock(64, 64)
81 | self.decoder_upsample1 = Upsample3DBlock(64, 32, 2, 2)
82 |
83 | self.skip_res1 = Res3DBlock(32, 32)
84 | self.skip_res2 = Res3DBlock(64, 64)
85 |
86 | def forward(self, x):
87 | skip_x1 = self.skip_res1(x)
88 | x = self.encoder_pool1(x)
89 | x = self.encoder_res1(x)
90 | skip_x2 = self.skip_res2(x)
91 | x = self.encoder_pool2(x)
92 | x = self.encoder_res2(x)
93 |
94 | x = self.mid_res(x)
95 |
96 | x = self.decoder_res2(x)
97 | x = self.decoder_upsample2(x)
98 | x = x + skip_x2
99 | x = self.decoder_res1(x)
100 | x = self.decoder_upsample1(x)
101 | x = x + skip_x1
102 |
103 | return x
104 |
105 |
106 | class V2VModel(nn.Module):
107 | def __init__(self, input_channels, output_channels):
108 | super(V2VModel, self).__init__()
109 |
110 | self.front_layers = nn.Sequential(
111 | Basic3DBlock(input_channels, 16, 7),
112 | Pool3DBlock(2),
113 | Res3DBlock(16, 32),
114 | Res3DBlock(32, 32),
115 | Res3DBlock(32, 32)
116 | )
117 |
118 | self.encoder_decoder = EncoderDecorder()
119 |
120 | self.back_layers = nn.Sequential(
121 | Res3DBlock(32, 32),
122 | Basic3DBlock(32, 32, 1),
123 | Basic3DBlock(32, 32, 1),
124 | )
125 |
126 | self.output_layer = nn.Conv3d(32, output_channels, kernel_size=1, stride=1, padding=0)
127 |
128 | self._initialize_weights()
129 |
130 | def forward(self, x):
131 | x = self.front_layers(x)
132 | x = self.encoder_decoder(x)
133 | x = self.back_layers(x)
134 | x = self.output_layer(x)
135 | return x
136 |
137 | def _initialize_weights(self):
138 | for m in self.modules():
139 | if isinstance(m, nn.Conv3d):
140 | nn.init.normal_(m.weight, 0, 0.001)
141 | nn.init.constant_(m.bias, 0)
142 | elif isinstance(m, nn.ConvTranspose3d):
143 | nn.init.normal_(m.weight, 0, 0.001)
144 | nn.init.constant_(m.bias, 0)
145 |
--------------------------------------------------------------------------------
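A quick shape check (not part of the repository) for the V2VModel above: the front Pool3DBlock(2) halves the 88^3 voxel input once and the encoder-decoder undoes its own downsampling, so each joint's heatmap comes out at 44^3:

    import torch
    from v2v_model import V2VModel   # assumes integral-pose/ is on PYTHONPATH

    net = V2VModel(input_channels=1, output_channels=21).eval()
    x = torch.zeros(1, 1, 88, 88, 88)
    with torch.no_grad():
        y = net(x)
    print(y.shape)   # torch.Size([1, 21, 44, 44, 44])
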
/integral-pose/v2v_util.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import numpy as np
3 |
4 |
5 | def discretize(coord, cropped_size):
6 | '''[-1, 1] -> [0, cropped_size]'''
7 | min_normalized = -1
8 | max_normalized = 1
9 | scale = (max_normalized - min_normalized) / cropped_size
10 | return (coord - min_normalized) / scale
11 |
12 |
13 | def warp2continuous(coord, refpoint, cubic_size, cropped_size):
14 | '''
15 | Map coordinates in set [0, 1, .., cropped_size-1] to original range [-cubic_size/2+refpoint, cubic_size/2 + refpoint]
16 | '''
17 | min_normalized = -1
18 | max_normalized = 1
19 |
20 | scale = (max_normalized - min_normalized) / cropped_size
21 | coord = coord * scale + min_normalized # -> [-1, 1]
22 |
23 | coord = coord * cubic_size / 2 + refpoint
24 |
25 | return coord
26 |
27 |
28 | def scattering(coord, cropped_size):
29 | # coord: [0, cropped_size]
30 | # Assign range[0, 1) -> 0, [1, 2) -> 1, .. [cropped_size-1, cropped_size) -> cropped_size-1
31 | # That is, around center 0.5 -> 0, around center 1.5 -> 1 .. around center cropped_size-0.5 -> cropped_size-1
32 | coord = coord.astype(np.int32)
33 |
34 | mask = (coord[:, 0] >= 0) & (coord[:, 0] < cropped_size) & \
35 | (coord[:, 1] >= 0) & (coord[:, 1] < cropped_size) & \
36 | (coord[:, 2] >= 0) & (coord[:, 2] < cropped_size)
37 |
38 | coord = coord[mask, :]
39 |
40 | cubic = np.zeros((cropped_size, cropped_size, cropped_size))
41 |
42 |     # Note: map the point coordinate (x, y, z) directly to index (i, j, k), not (k, j, i).
43 |     # This must stay consistent with heatmap generation and coordinate extraction from the heatmap.
44 | cubic[coord[:, 0], coord[:, 1], coord[:, 2]] = 1
45 |
46 | return cubic
47 |
48 |
49 | def extract_coord_from_output(output, center=True):
50 | '''
51 | output: shape (batch, jointNum, volumeSize, volumeSize, volumeSize)
52 | center: if True, add 0.5, default is true
53 | return: shape (batch, jointNum, 3)
54 | '''
55 | assert(len(output.shape) >= 3)
56 | vsize = output.shape[-3:]
57 |
58 | output_rs = output.reshape(-1, np.prod(vsize))
59 | max_index = np.unravel_index(np.argmax(output_rs, axis=1), vsize)
60 | max_index = np.array(max_index).T
61 |
62 | xyz_output = max_index.reshape([*output.shape[:-3], 3])
63 |
64 |     # Note: a discrete coord represents the real range [coord, coord+1), see scattering().
65 |     # So move the coord to the center of that range for a better fit.
66 | if center: xyz_output = xyz_output + 0.5
67 |
68 | return xyz_output
69 |
70 |
71 | def generate_coord(points, refpoint, new_size, angle, trans, sizes):
72 | cubic_size, cropped_size, original_size = sizes
73 |
74 | # points shape: (n, 3)
75 | coord = points
76 |
77 | # note, will consider points within range [refpoint-cubic_size/2, refpoint+cubic_size/2] as candidates
78 |
79 | # normalize
80 | coord = (coord - refpoint) / (cubic_size/2) # -> [-1, 1]
81 |
82 | # discretize
83 | coord = discretize(coord, cropped_size) # -> [0, cropped_size]
84 |
85 |     # Move the cropped volume's center to the center of the (virtually larger, [0, original_size]) original volume.
86 |     # That is, treat the current data as if it were cropped from the center of the original volume, and put it back.
87 | coord += (original_size / 2 - cropped_size / 2)
88 |
89 | # resize around original center with scale new_size/100
90 | resize_scale = 100 / new_size
91 | if new_size < 100:
92 | coord = coord * resize_scale + original_size/2 * (1 - resize_scale)
93 | elif new_size > 100:
94 | coord = coord * resize_scale - original_size/2 * (resize_scale - 1)
95 | else:
96 | # new_size = 100 if it is in test mode
97 | pass
98 |
99 | # rotation
100 | if angle != 0:
101 | original_coord = coord.copy()
102 | original_coord[:,0] -= original_size/2
103 | original_coord[:,1] -= original_size/2
104 | coord[:,0] = original_coord[:,0]*np.cos(angle) - original_coord[:,1]*np.sin(angle)
105 | coord[:,1] = original_coord[:,0]*np.sin(angle) + original_coord[:,1]*np.cos(angle)
106 | coord[:,0] += original_size/2
107 | coord[:,1] += original_size/2
108 |
109 | # translation
110 |     # Note: if trans = (original_size/2 - cropped_size/2 + 1), the following translation will
111 |     # cancel the translation above (after discretization). It is set this way in test mode.
112 |     # TODO: can this only achieve translations in [-4, 4]?
113 | coord -= trans
114 |
115 | return coord
116 |
117 |
118 | def generate_cubic_input(points, refpoint, new_size, angle, trans, sizes):
119 | _, cropped_size, _ = sizes
120 | coord = generate_coord(points, refpoint, new_size, angle, trans, sizes)
121 |
122 | # scattering
123 | cubic = scattering(coord, cropped_size)
124 |
125 | return cubic
126 |
127 |
128 | # def generate_heatmap_gt(keypoints, refpoint, new_size, angle, trans, sizes, d3outputs, pool_factor, std):
129 | # _, cropped_size, _ = sizes
130 | # d3output_x, d3output_y, d3output_z = d3outputs
131 |
132 | # coord = generate_coord(keypoints, refpoint, new_size, angle, trans, sizes) # [0, cropped_size]
133 | # coord /= pool_factor # [0, cropped_size/pool_factor]
134 |
135 | # # heatmap generation
136 | # output_size = int(cropped_size / pool_factor)
137 | # heatmap = np.zeros((keypoints.shape[0], output_size, output_size, output_size))
138 |
139 | # # use center of cell
140 | # center_offset = 0.5
141 |
142 | # for i in range(coord.shape[0]):
143 | # xi, yi, zi= coord[i]
144 | # heatmap[i] = np.exp(-(np.power((d3output_x+center_offset-xi)/std, 2)/2 + \
145 | # np.power((d3output_y+center_offset-yi)/std, 2)/2 + \
146 | # np.power((d3output_z+center_offset-zi)/std, 2)/2)) # +0.5, move coordinate to range center
147 |
148 | # return heatmap
149 |
150 |
151 | #---
152 | def containing_box_coord(point):
153 | '''
154 | point: (K, 3)
155 |     return: (K, 8, 3), coordinates of the eight vertices of each point's enclosing unit cell
156 | '''
157 | # if np.any(point >= 43):
158 | # res = point[point >= 43]
159 | # print('invalid point: {}'.format(res))
160 | # exit()
161 |
162 |
163 | box_grid = np.meshgrid([0, 1], [0, 1], [0, 1], indexing='ij')
164 | box_grid = np.array(box_grid).reshape((3, 8)).transpose()
165 |
166 | floor = np.floor(point)
167 | box_coord = floor.reshape((-1, 1, 3)) + box_grid
168 |
169 | return box_coord
170 |
171 |
172 | def box_coord_prob(point, box_coord):
173 | '''
174 | point: (K, 3)
175 | box_coord: (K, 8, 3)
176 | return: (K, 8)
177 | '''
178 | diff = box_coord - point.reshape((-1, 1, 3))
179 | weight = np.maximum(0, 1 - np.abs(diff))
180 | prob = weight[:,:,0] * weight[:,:,1] * weight[:,:,2]
181 | norm = np.sum(prob, axis=1, keepdims=True)
182 | norm[norm <= 0] = 1.0 # avoid zero
183 | prob = prob / norm
184 |
185 | return prob
186 |
187 |
188 | def onehot_heatmap_impl(coord, output_size):
189 | '''
190 | coord: (K, 3)
191 |     return: (K, output_size, output_size, output_size)
192 | '''
193 | coord = np.array(coord)
194 |
195 | box_coord = containing_box_coord(coord) # (K, 8, 3)
196 | box_prob = box_coord_prob(coord, box_coord) # (K, 8)
197 | box_coord = box_coord.astype(np.int32)
198 |
199 | # Generate K heatmaps
200 | heatmap = np.zeros((coord.shape[0], output_size, output_size, output_size))
201 | for i in range(coord.shape[0]):
202 | heatmap[i][box_coord[i,:,0], box_coord[i,:,1], box_coord[i,:,2]] = box_prob[i]
203 |
204 | return heatmap
205 |
206 |
207 | def generate_onehot_heatmap_gt(keypoints, refpoint, new_size, angle, trans, sizes, d3outputs, pool_factor, std):
208 | _, cropped_size, _ = sizes
209 | d3output_x, d3output_y, d3output_z = d3outputs
210 |
211 | coord = generate_coord(keypoints, refpoint, new_size, angle, trans, sizes) # [0, cropped_size]
212 | coord /= pool_factor # [0, cropped_size/pool_factor]
213 |
214 | # heatmap generation
215 | output_size = int(cropped_size / pool_factor)
216 |
217 |     # Warning: clip joint coords into [0, 42] so the containing-box vertex indices cannot exceed 43 (= 44 - 1)
218 | # TODO: check here
219 | target_output_size = 44
220 | coord[coord >= target_output_size-2] = target_output_size - 2
221 | coord[coord < 0] = 0
222 |
223 | return onehot_heatmap_impl(coord, output_size)
224 |
225 |
226 | def generate_volume_coord_gt(keypoints, refpoint, new_size, angle, trans, sizes, pool_factor):
227 | coord = generate_coord(keypoints, refpoint, new_size, angle, trans, sizes) # [0, cropped_size]
228 | coord /= pool_factor
229 |
230 |     # Warning: clip joint coords into [0, 42] so the containing-box vertex indices cannot exceed 43 (= 44 - 1)
231 | # TODO: check here
232 | target_output_size = 44
233 | coord[coord >= target_output_size-2] = target_output_size - 2
234 | coord[coord < 0] = 0
235 |
236 | return coord
237 |
238 |
239 | class V2VVoxelization(object):
240 | def __init__(self, cubic_size, augmentation=True):
241 | self.cubic_size = cubic_size
242 | self.cropped_size, self.original_size = 88, 96
243 | self.sizes = (self.cubic_size, self.cropped_size, self.original_size)
244 | self.pool_factor = 2
245 | self.std = 1.7
246 | self.augmentation = augmentation
247 |
248 | output_size = int(self.cropped_size / self.pool_factor)
249 | # Note, range(size) and indexing = 'ij'
250 | self.d3outputs = np.meshgrid(np.arange(output_size), np.arange(output_size), np.arange(output_size), indexing='ij')
251 |
252 | def __call__(self, sample):
253 | points, keypoints, refpoint = sample['points'], sample['keypoints'], sample['refpoint']
254 |
255 | ## Augmentations
256 | # Resize
257 | new_size = np.random.rand() * 40 + 80
258 |
259 | # Rotation
260 | angle = np.random.rand() * 80/180*np.pi - 40/180*np.pi
261 |
262 | # Translation
263 | trans = np.random.rand(3) * (self.original_size - self.cropped_size)
264 |
265 | if not self.augmentation:
266 | new_size = 100
267 | angle = 0
268 | trans = self.original_size/2 - self.cropped_size/2
269 |
270 | # Add noise and random selection
271 | add_noise = False
272 | random_selection = False
273 |
274 | if self.augmentation and add_noise:
275 | # noise, [-0.5, 0.5]
276 | scale = 0.5
277 | noise = (np.random.rand(*points.shape) * 2 - 1) * scale
278 | points += noise
279 |
280 | if self.augmentation and random_selection:
281 | threshold = np.random.rand(1)[0] * 0.5 # <= 0.5
282 | prob = np.random.rand(points.shape[0])
283 | mask = prob > threshold
284 | points = points[mask, :]
285 |
286 | input = generate_cubic_input(points, refpoint, new_size, angle, trans, self.sizes)
287 | # heatmap = generate_heatmap_gt(keypoints, refpoint, new_size, angle, trans, self.sizes, self.d3outputs, self.pool_factor, self.std)
288 | # Use One-hot heatmap
289 | heatmap = generate_onehot_heatmap_gt(keypoints, refpoint, new_size, angle, trans, self.sizes, self.d3outputs, self.pool_factor, self.std)
290 | keypoints_volume_coords = generate_volume_coord_gt(keypoints, refpoint, new_size, angle, trans, self.sizes, self.pool_factor)
291 |
292 | # one channel
293 | input = input.reshape((1, *input.shape))
294 |
295 | return input, heatmap, keypoints_volume_coords
296 |
297 | # def voxelize(self, points, refpoint):
298 | # new_size, angle, trans = 100, 0, self.original_size/2 - self.cropped_size/2
299 | # input = generate_cubic_input(points, refpoint, new_size, angle, trans, self.sizes)
300 | # return input.reshape((1, *input.shape))
301 |
302 | # def generate_heatmap(self, keypoints, refpoint):
303 | # new_size, angle, trans = 100, 0, self.original_size/2 - self.cropped_size/2
304 | # heatmap = generate_heatmap_gt(keypoints, refpoint, new_size, angle, trans, self.sizes, self.d3outputs, self.pool_factor, self.std)
305 | # return heatmap
306 |
307 | # def evaluate(self, heatmaps, refpoints):
308 | # coords = extract_coord_from_output(heatmaps, center=True)
309 | # coords *= self.pool_factor
310 | # keypoints = warp2continuous(coords, refpoints, self.cubic_size, self.cropped_size)
311 | # return keypoints
312 |
313 | # def warp2continuous(self, coords, refpoints):
314 | # print('Warning: added 0.5 on input coord')
315 | # coords += 0.5 # move to grid cell center
316 | # coords *= self.pool_factor
317 | # keypoints = warp2continuous(coords, refpoints, self.cubic_size, self.cropped_size)
318 | # return keypoints
319 |
320 | # def generate_coord_raw(self, points, refpoint):
321 | # new_size, angle, trans = 100, 0, self.original_size/2 - self.cropped_size/2
322 | # coord = generate_coord(points, refpoint, new_size, angle, trans, self.sizes)
323 | # return coord
324 |
325 | def warp2continuous_raw(self, coords, refpoints):
326 |         # Do not add 0.5, since coords already have float precision
327 | coords = coords * self.pool_factor
328 | keypoints = warp2continuous(coords, refpoints, self.cubic_size, self.cropped_size)
329 | return keypoints
330 |
--------------------------------------------------------------------------------
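A worked example (not part of the repository) of the mapping implemented by discretize() and warp2continuous() above: a world point near the reference point is normalized into [-1, 1], discretized into the [0, cropped_size] voxel range, and then mapped back exactly; the reference point and offset are arbitrary example values:

    import numpy as np
    from v2v_util import discretize, warp2continuous   # assumes integral-pose/ is on PYTHONPATH

    cubic_size, cropped_size = 200.0, 88
    refpoint = np.array([30.0, -20.0, 400.0])
    point = refpoint + np.array([10.0, 0.0, -25.0])       # inside the 200 mm cube

    normalized = (point - refpoint) / (cubic_size / 2)    # -> [-1, 1]
    voxel = discretize(normalized, cropped_size)          # -> [0, cropped_size]
    recovered = warp2continuous(voxel, refpoint, cubic_size, cropped_size)
    print(voxel, np.allclose(recovered, point))           # exact round trip for float voxel coords
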
/lib/accuracy.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def compute_dist_acc_wrapper(pred, gt, max_dist=10, num=100):
5 | '''
6 | pred: (N, K, 3)
7 | gt: (N, K, 3)
8 |
9 | return dist: (K, )
10 | return acc: (K, num)
11 | '''
12 | assert(pred.shape == gt.shape)
13 | assert(len(pred.shape) == 3)
14 |
15 | dist = np.linspace(0, max_dist, num)
16 | return dist, compute_dist_acc(pred, gt, dist)
17 |
18 |
19 | def compute_dist_acc(pred, gt, dist):
20 | '''
21 | pred: (N, K, 3)
22 | gt: (N, K, 3)
23 | dist: (M, )
24 |
25 | return acc: (K, M)
26 | '''
27 | assert(pred.shape == gt.shape)
28 | assert(len(pred.shape) == 3)
29 |
30 | N, K = pred.shape[0], pred.shape[1]
31 | err_dist = np.sqrt(np.sum((pred - gt)**2, axis=2)) # (N, K)
32 |
33 | acc = np.zeros((K, dist.shape[0]))
34 |
35 | for i, d in enumerate(dist):
36 | acc_d = (err_dist < d).sum(axis=0) / N
37 | acc[:,i] = acc_d
38 |
39 | return acc
40 |
41 |
42 | def compute_mean_err(pred, gt):
43 | '''
44 | pred: (N, K, 3)
45 | gt: (N, K, 3)
46 |
47 | mean_err: (K,)
48 | '''
49 | N, K = pred.shape[0], pred.shape[1]
50 | err_dist = np.sqrt(np.sum((pred - gt)**2, axis=2)) # (N, K)
51 | return np.mean(err_dist, axis=0)
52 |
--------------------------------------------------------------------------------
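A small synthetic example (not part of the repository) of the accuracy helpers above, with two joints whose prediction errors are constant 5 mm and 15 mm; the thresholds and sample count are assumptions for illustration:

    import numpy as np
    from accuracy import compute_dist_acc_wrapper, compute_mean_err   # assumes lib/ is on PYTHONPATH

    gt = np.zeros((100, 2, 3))
    pred = gt.copy()
    pred[:, 0, 0] += 5.0     # joint 0: constant 5 mm error
    pred[:, 1, 0] += 15.0    # joint 1: constant 15 mm error

    dist, acc = compute_dist_acc_wrapper(pred, gt, max_dist=20, num=5)
    print(dist)                        # [ 0.  5. 10. 15. 20.]
    print(acc)                         # per-joint fraction of samples with error < each threshold
    print(compute_mean_err(pred, gt))  # [ 5. 15.]
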
/lib/progressbar.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 |
5 |
6 | _, term_width = os.popen('stty size', 'r').read().split()
7 | term_width = int(term_width)
8 | TOTAL_BAR_LENGTH = 65.
9 | last_time = time.time()
10 | begin_time = last_time
11 |
12 | def progress_bar(current, total, msg=None):
13 | global last_time, begin_time
14 | if current == 0:
15 | begin_time = time.time() # Reset for new bar.
16 |
17 | cur_len = int(TOTAL_BAR_LENGTH*current/total)
18 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
19 |
20 | sys.stdout.write(' [')
21 | for i in range(cur_len):
22 | sys.stdout.write('=')
23 | sys.stdout.write('>')
24 | for i in range(rest_len):
25 | sys.stdout.write('.')
26 | sys.stdout.write(']')
27 |
28 | cur_time = time.time()
29 | step_time = cur_time - last_time
30 | last_time = cur_time
31 | tot_time = cur_time - begin_time
32 |
33 | L = []
34 | L.append(' Step: %s' % format_time(step_time))
35 | L.append(' | Tot: %s' % format_time(tot_time))
36 | if msg:
37 | L.append(' | ' + msg)
38 |
39 | msg = ''.join(L)
40 | sys.stdout.write(msg)
41 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
42 | sys.stdout.write(' ')
43 |
44 | # Go back to the center of the bar.
45 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
46 | sys.stdout.write('\b')
47 | sys.stdout.write(' %d/%d ' % (current+1, total))
48 |
49 | if current < total-1:
50 | sys.stdout.write('\r')
51 | else:
52 | sys.stdout.write('\n')
53 | sys.stdout.flush()
54 |
55 |
56 | def format_time(seconds):
57 | days = int(seconds / 3600/24)
58 | seconds = seconds - days*3600*24
59 | hours = int(seconds / 3600)
60 | seconds = seconds - hours*3600
61 | minutes = int(seconds / 60)
62 | seconds = seconds - minutes*60
63 | secondsf = int(seconds)
64 | seconds = seconds - secondsf
65 | millis = int(seconds*1000)
66 |
67 | f = ''
68 | i = 1
69 | if days > 0:
70 | f += str(days) + 'D'
71 | i += 1
72 | if hours > 0 and i <= 2:
73 | f += str(hours) + 'h'
74 | i += 1
75 | if minutes > 0 and i <= 2:
76 | f += str(minutes) + 'm'
77 | i += 1
78 | if secondsf > 0 and i <= 2:
79 | f += str(secondsf) + 's'
80 | i += 1
81 | if millis > 0 and i <= 2:
82 | f += str(millis) + 'ms'
83 | i += 1
84 | if f == '':
85 | f = '0ms'
86 | return f
87 |
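A small illustration of driving progress_bar from a loop; the step count and sleep below are placeholders standing in for real batches:

import time
from lib.progressbar import progress_bar

total = 50
for step in range(total):
    time.sleep(0.02)  # stand-in for one batch of work
    progress_bar(step, total, msg='Loss: {0:.4e}'.format(1.0 / (step + 1)))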
--------------------------------------------------------------------------------
/lib/sampler.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import sampler
2 |
3 |
4 | class ChunkSampler(sampler.Sampler):
5 | """Samples elements sequentially from some offset.
6 | Arguments:
7 | num_samples: # of desired datapoints
8 | start: offset where we should start selecting from
9 | """
10 | def __init__(self, num_samples, start=0):
11 | self.num_samples = num_samples
12 | self.start = start
13 |
14 | def __iter__(self):
15 | return iter(range(self.start, self.start + self.num_samples))
16 |
17 | def __len__(self):
18 | return self.num_samples
19 |
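A sketch of plugging ChunkSampler into a DataLoader; the TensorDataset and the sizes are stand-ins for the real MSRA dataset objects:

import torch
from torch.utils.data import DataLoader, TensorDataset
from lib.sampler import ChunkSampler

dataset = TensorDataset(torch.arange(1000, dtype=torch.float32).unsqueeze(1))
# Iterate over 200 consecutive samples, starting at index 100.
loader = DataLoader(dataset, batch_size=32, sampler=ChunkSampler(num_samples=200, start=100))
print(sum(batch[0].shape[0] for batch in loader))  # 200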
--------------------------------------------------------------------------------
/lib/solver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from lib.progressbar import progress_bar
4 |
5 |
6 | def train_epoch(model, criterion, optimizer, train_loader, device=torch.device('cuda'), dtype=torch.float):
7 | model.train()
8 | train_loss = 0
9 |
10 | for batch_idx, (inputs, targets) in enumerate(train_loader):
11 | inputs, targets = inputs.to(device, dtype), targets.to(device, dtype)
12 | optimizer.zero_grad()
13 | outputs = model(inputs)
14 | loss = criterion(outputs, targets)
15 | loss.backward()
16 | optimizer.step()
17 |
18 | train_loss += loss.item()
19 | progress_bar(batch_idx, len(train_loader), 'Loss: {0:.4e}'.format(train_loss/(batch_idx+1)))
20 | #print('loss: {0: .4e}'.format(train_loss/(batch_idx+1)))
21 |
22 |
23 | def val_epoch(model, criterion, val_loader, device=torch.device('cuda'), dtype=torch.float):
24 | model.eval()
25 | val_loss = 0
26 |
27 | with torch.no_grad():
28 | for batch_idx, (inputs, targets) in enumerate(val_loader):
29 | inputs, targets = inputs.to(device, dtype), targets.to(device, dtype)
30 | outputs = model(inputs)
31 | loss = criterion(outputs, targets)
32 |
33 | val_loss += loss.item()
34 | progress_bar(batch_idx, len(val_loader), 'Loss: {0:.4e}'.format(val_loss/(batch_idx+1)))
35 | #print('loss: {0: .4e}'.format(val_loss/(batch_idx+1)))
36 |
37 |
38 | def test_epoch(model, test_loader, result_collector, device=torch.device('cuda'), dtype=torch.float):
39 | model.eval()
40 |
41 | with torch.no_grad():
42 | for batch_idx, (inputs, extra) in enumerate(test_loader):
43 | outputs = model(inputs.to(device, dtype))
44 | result_collector((inputs, outputs, extra))
45 |
46 | progress_bar(batch_idx, len(test_loader))
47 |
--------------------------------------------------------------------------------
/src/v2v_model.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class Basic3DBlock(nn.Module):
6 | def __init__(self, in_planes, out_planes, kernel_size):
7 | super(Basic3DBlock, self).__init__()
8 | self.block = nn.Sequential(
9 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1, padding=((kernel_size-1)//2)),
10 | nn.BatchNorm3d(out_planes),
11 | nn.ReLU(True)
12 | )
13 |
14 | def forward(self, x):
15 | return self.block(x)
16 |
17 |
18 | class Res3DBlock(nn.Module):
19 | def __init__(self, in_planes, out_planes):
20 | super(Res3DBlock, self).__init__()
21 | self.res_branch = nn.Sequential(
22 | nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=1, padding=1),
23 | nn.BatchNorm3d(out_planes),
24 | nn.ReLU(True),
25 | nn.Conv3d(out_planes, out_planes, kernel_size=3, stride=1, padding=1),
26 | nn.BatchNorm3d(out_planes)
27 | )
28 |
29 | if in_planes == out_planes:
30 | self.skip_con = nn.Sequential()
31 | else:
32 | self.skip_con = nn.Sequential(
33 | nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=1, padding=0),
34 | nn.BatchNorm3d(out_planes)
35 | )
36 |
37 | def forward(self, x):
38 | res = self.res_branch(x)
39 | skip = self.skip_con(x)
40 | return F.relu(res + skip, True)
41 |
42 |
43 | class Pool3DBlock(nn.Module):
44 | def __init__(self, pool_size):
45 | super(Pool3DBlock, self).__init__()
46 | self.pool_size = pool_size
47 |
48 | def forward(self, x):
49 | return F.max_pool3d(x, kernel_size=self.pool_size, stride=self.pool_size)
50 |
51 |
52 | class Upsample3DBlock(nn.Module):
53 | def __init__(self, in_planes, out_planes, kernel_size, stride):
54 | super(Upsample3DBlock, self).__init__()
55 | assert(kernel_size == 2)
56 | assert(stride == 2)
57 | self.block = nn.Sequential(
58 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0, output_padding=0),
59 | nn.BatchNorm3d(out_planes),
60 | nn.ReLU(True)
61 | )
62 |
63 | def forward(self, x):
64 | return self.block(x)
65 |
66 |
67 | class EncoderDecorder(nn.Module):
68 | def __init__(self):
69 | super(EncoderDecorder, self).__init__()
70 |
71 | self.encoder_pool1 = Pool3DBlock(2)
72 | self.encoder_res1 = Res3DBlock(32, 64)
73 | self.encoder_pool2 = Pool3DBlock(2)
74 | self.encoder_res2 = Res3DBlock(64, 128)
75 |
76 | self.mid_res = Res3DBlock(128, 128)
77 |
78 | self.decoder_res2 = Res3DBlock(128, 128)
79 | self.decoder_upsample2 = Upsample3DBlock(128, 64, 2, 2)
80 | self.decoder_res1 = Res3DBlock(64, 64)
81 | self.decoder_upsample1 = Upsample3DBlock(64, 32, 2, 2)
82 |
83 | self.skip_res1 = Res3DBlock(32, 32)
84 | self.skip_res2 = Res3DBlock(64, 64)
85 |
86 | def forward(self, x):
87 | skip_x1 = self.skip_res1(x)
88 | x = self.encoder_pool1(x)
89 | x = self.encoder_res1(x)
90 | skip_x2 = self.skip_res2(x)
91 | x = self.encoder_pool2(x)
92 | x = self.encoder_res2(x)
93 |
94 | x = self.mid_res(x)
95 |
96 | x = self.decoder_res2(x)
97 | x = self.decoder_upsample2(x)
98 | x = x + skip_x2
99 | x = self.decoder_res1(x)
100 | x = self.decoder_upsample1(x)
101 | x = x + skip_x1
102 |
103 | return x
104 |
105 |
106 | class V2VModel(nn.Module):
107 | def __init__(self, input_channels, output_channels):
108 | super(V2VModel, self).__init__()
109 |
110 | self.front_layers = nn.Sequential(
111 | Basic3DBlock(input_channels, 16, 7),
112 | Pool3DBlock(2),
113 | Res3DBlock(16, 32),
114 | Res3DBlock(32, 32),
115 | Res3DBlock(32, 32)
116 | )
117 |
118 | self.encoder_decoder = EncoderDecorder()
119 |
120 | self.back_layers = nn.Sequential(
121 | Res3DBlock(32, 32),
122 | Basic3DBlock(32, 32, 1),
123 | Basic3DBlock(32, 32, 1),
124 | )
125 |
126 | self.output_layer = nn.Conv3d(32, output_channels, kernel_size=1, stride=1, padding=0)
127 |
128 | self._initialize_weights()
129 |
130 | def forward(self, x):
131 | x = self.front_layers(x)
132 | x = self.encoder_decoder(x)
133 | x = self.back_layers(x)
134 | x = self.output_layer(x)
135 | return x
136 |
137 | def _initialize_weights(self):
138 | for m in self.modules():
139 | if isinstance(m, nn.Conv3d):
140 | nn.init.normal_(m.weight, 0, 0.001)
141 | nn.init.constant_(m.bias, 0)
142 | elif isinstance(m, nn.ConvTranspose3d):
143 | nn.init.normal_(m.weight, 0, 0.001)
144 | nn.init.constant_(m.bias, 0)
145 |
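A quick shape check for the network; the 88^3 resolution matches V2VVoxelization's cropped_size in src/v2v_util.py, and 21 output channels assume MSRA's 21 hand joints:

import torch
from src.v2v_model import V2VModel

model = V2VModel(input_channels=1, output_channels=21)
voxels = torch.rand(1, 1, 88, 88, 88)  # (batch, channel, depth, height, width) occupancy grid
with torch.no_grad():
    heatmaps = model(voxels)
print(heatmaps.shape)  # torch.Size([1, 21, 44, 44, 44]) -- downsampled by 2 relative to the input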
--------------------------------------------------------------------------------
/src/v2v_util.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import numpy as np
3 |
4 |
5 | def discretize(coord, cropped_size):
6 | '''[-1, 1] -> [0, cropped_size]'''
7 | min_normalized = -1
8 | max_normalized = 1
9 | scale = (max_normalized - min_normalized) / cropped_size
10 | return (coord - min_normalized) / scale
11 |
12 |
13 | def warp2continuous(coord, refpoint, cubic_size, cropped_size):
14 | '''
15 | Map coordinates in set [0, 1, .., cropped_size-1] to original range [-cubic_size/2+refpoint, cubic_size/2 + refpoint]
16 | '''
17 | min_normalized = -1
18 | max_normalized = 1
19 |
20 | scale = (max_normalized - min_normalized) / cropped_size
21 | coord = coord * scale + min_normalized # -> [-1, 1]
22 |
23 | coord = coord * cubic_size / 2 + refpoint
24 |
25 | return coord
26 |
27 |
28 | def scattering(coord, cropped_size):
29 | # coord: [0, cropped_size]
30 | # Assign range[0, 1) -> 0, [1, 2) -> 1, .. [cropped_size-1, cropped_size) -> cropped_size-1
31 | # That is, around center 0.5 -> 0, around center 1.5 -> 1 .. around center cropped_size-0.5 -> cropped_size-1
32 | coord = coord.astype(np.int32)
33 |
34 | mask = (coord[:, 0] >= 0) & (coord[:, 0] < cropped_size) & \
35 | (coord[:, 1] >= 0) & (coord[:, 1] < cropped_size) & \
36 | (coord[:, 2] >= 0) & (coord[:, 2] < cropped_size)
37 |
38 | coord = coord[mask, :]
39 |
40 | cubic = np.zeros((cropped_size, cropped_size, cropped_size))
41 |
42 |     # Note: map the point coordinate (x, y, z) directly to the index (i, j, k), not (k, j, i).
43 |     # This must stay consistent with heatmap generation and coordinate extraction from the heatmap.
44 | cubic[coord[:, 0], coord[:, 1], coord[:, 2]] = 1
45 |
46 | return cubic
47 |
48 |
49 | def extract_coord_from_output(output, center=True):
50 | '''
51 | output: shape (batch, jointNum, volumeSize, volumeSize, volumeSize)
52 |     center: if True (default), shift each coordinate by 0.5 to the voxel center
53 | return: shape (batch, jointNum, 3)
54 | '''
55 | assert(len(output.shape) >= 3)
56 | vsize = output.shape[-3:]
57 |
58 | output_rs = output.reshape(-1, np.prod(vsize))
59 | max_index = np.unravel_index(np.argmax(output_rs, axis=1), vsize)
60 | max_index = np.array(max_index).T
61 |
62 | xyz_output = max_index.reshape([*output.shape[:-3], 3])
63 |
64 |     # Note: a discrete coord represents the real range [coord, coord+1), see scattering().
65 |     # So move coord to the center of that range for a better fit.
66 | if center: xyz_output = xyz_output + 0.5
67 |
68 | return xyz_output
69 |
70 |
71 | def generate_coord(points, refpoint, new_size, angle, trans, sizes):
72 | cubic_size, cropped_size, original_size = sizes
73 |
74 | # points shape: (n, 3)
75 | coord = points
76 |
77 | # note, will consider points within range [refpoint-cubic_size/2, refpoint+cubic_size/2] as candidates
78 |
79 | # normalize
80 | coord = (coord - refpoint) / (cubic_size/2) # -> [-1, 1]
81 |
82 | # discretize
83 | coord = discretize(coord, cropped_size) # -> [0, cropped_size]
84 | coord += (original_size / 2 - cropped_size / 2) # move center to original volume
85 |
86 | # resize around original volume center
87 | resize_scale = new_size / 100
88 | if new_size < 100:
89 | coord = coord * resize_scale + original_size/2 * (1 - resize_scale)
90 | elif new_size > 100:
91 | coord = coord * resize_scale - original_size/2 * (resize_scale - 1)
92 | else:
93 |         # new_size == 100 (test mode): no resizing needed
94 | pass
95 |
96 | # rotation
97 | if angle != 0:
98 | original_coord = coord.copy()
99 | original_coord[:,0] -= original_size / 2
100 | original_coord[:,1] -= original_size / 2
101 | coord[:,0] = original_coord[:,0]*np.cos(angle) - original_coord[:,1]*np.sin(angle)
102 | coord[:,1] = original_coord[:,0]*np.sin(angle) + original_coord[:,1]*np.cos(angle)
103 | coord[:,0] += original_size / 2
104 | coord[:,1] += original_size / 2
105 |
106 | # translation
107 |     # Note: if trans = (original_size/2 - cropped_size/2), the following translation
108 |     # cancels the translation above (after discretization); trans is set to this value in test mode.
109 | coord -= trans
110 |
111 | return coord
112 |
113 |
114 | def generate_cubic_input(points, refpoint, new_size, angle, trans, sizes):
115 | _, cropped_size, _ = sizes
116 | coord = generate_coord(points, refpoint, new_size, angle, trans, sizes)
117 |
118 | # scattering
119 | cubic = scattering(coord, cropped_size)
120 |
121 | return cubic
122 |
123 |
124 | def generate_heatmap_gt(keypoints, refpoint, new_size, angle, trans, sizes, d3outputs, pool_factor, std):
125 | _, cropped_size, _ = sizes
126 | d3output_x, d3output_y, d3output_z = d3outputs
127 |
128 | coord = generate_coord(keypoints, refpoint, new_size, angle, trans, sizes) # [0, cropped_size]
129 | coord /= pool_factor # [0, cropped_size/pool_factor]
130 |
131 | # heatmap generation
132 | output_size = int(cropped_size / pool_factor)
133 | heatmap = np.zeros((keypoints.shape[0], output_size, output_size, output_size))
134 |
135 | # use center of cell
136 | center_offset = 0.5
137 |
138 | for i in range(coord.shape[0]):
139 | xi, yi, zi= coord[i]
140 | heatmap[i] = np.exp(-(np.power((d3output_x+center_offset-xi)/std, 2)/2 + \
141 | np.power((d3output_y+center_offset-yi)/std, 2)/2 + \
142 | np.power((d3output_z+center_offset-zi)/std, 2)/2))
143 |
144 | return heatmap
145 |
146 |
147 | class V2VVoxelization(object):
148 | def __init__(self, cubic_size, augmentation=True):
149 | self.cubic_size = cubic_size
150 | self.cropped_size, self.original_size = 88, 96
151 | self.sizes = (self.cubic_size, self.cropped_size, self.original_size)
152 | self.pool_factor = 2
153 | self.std = 1.7
154 | self.augmentation = augmentation
155 |
156 | output_size = int(self.cropped_size / self.pool_factor)
157 | # Note, range(size) and indexing = 'ij'
158 | self.d3outputs = np.meshgrid(np.arange(output_size), np.arange(output_size), np.arange(output_size), indexing='ij')
159 |
160 | def __call__(self, sample):
161 | points, keypoints, refpoint = sample['points'], sample['keypoints'], sample['refpoint']
162 |
163 | ## Augmentations
164 | # Resize
165 | new_size = np.random.rand() * 40 + 80
166 |
167 | # Rotation
168 | angle = np.random.rand() * 80/180*np.pi - 40/180*np.pi
169 |
170 | # Translation
171 | trans = np.random.rand(3) * (self.original_size-self.cropped_size)
172 |
173 | if not self.augmentation:
174 | new_size = 100
175 | angle = 0
176 | trans = self.original_size/2 - self.cropped_size/2
177 |
178 | input = generate_cubic_input(points, refpoint, new_size, angle, trans, self.sizes)
179 | heatmap = generate_heatmap_gt(keypoints, refpoint, new_size, angle, trans, self.sizes, self.d3outputs, self.pool_factor, self.std)
180 |
181 | return input.reshape((1, *input.shape)), heatmap
182 |
183 | def voxelize(self, points, refpoint):
184 | new_size, angle, trans = 100, 0, self.original_size/2 - self.cropped_size/2
185 | input = generate_cubic_input(points, refpoint, new_size, angle, trans, self.sizes)
186 | return input.reshape((1, *input.shape))
187 |
188 | def generate_heatmap(self, keypoints, refpoint):
189 | new_size, angle, trans = 100, 0, self.original_size/2 - self.cropped_size/2
190 | heatmap = generate_heatmap_gt(keypoints, refpoint, new_size, angle, trans, self.sizes, self.d3outputs, self.pool_factor, self.std)
191 | return heatmap
192 |
193 | def evaluate(self, heatmaps, refpoints):
194 | coords = extract_coord_from_output(heatmaps)
195 | coords *= self.pool_factor
196 | keypoints = warp2continuous(coords, refpoints, self.cubic_size, self.cropped_size)
197 | return keypoints
198 |
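A sketch of running V2VVoxelization on synthetic data; the cubic size, reference point, and random point cloud are placeholders for what the dataset loader would supply:

import numpy as np
from src.v2v_util import V2VVoxelization

voxelizer = V2VVoxelization(cubic_size=200, augmentation=False)

refpoint = np.array([0.0, 0.0, 500.0])                      # fake hand center (mm)
points = refpoint + (np.random.rand(2000, 3) - 0.5) * 150   # fake depth point cloud around the center
keypoints = refpoint + (np.random.rand(21, 3) - 0.5) * 100  # fake 21-joint ground truth

voxel_input, heatmap = voxelizer({'points': points, 'keypoints': keypoints, 'refpoint': refpoint})
print(voxel_input.shape, heatmap.shape)  # (1, 88, 88, 88) (21, 44, 44, 44)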
--------------------------------------------------------------------------------
/vis/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 |
3 |
4 | def plot_acc(ax, dist, acc, name):
5 | '''
6 | acc: (K, num)
7 |     dist: (num, )
8 | name: (K, )
9 | '''
10 | assert(acc.shape[0] == len(name))
11 |
12 | for i in range(len(name)):
13 | ax.plot(dist, acc[i], label=name[i])
14 |
15 | ax.legend()
16 |
17 | ax.set_xlabel('Maximum allowed distance to GT (mm)')
18 | ax.set_ylabel('Fraction of samples within distance')
19 |
20 |
21 | def plot_mean_err(ax, mean_err, name):
22 | '''
23 | mean_err: (K, )
24 | name: (K, )
25 | '''
26 | ax.bar(name, mean_err)
27 |
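A self-contained sketch of both plotting helpers on fabricated curves; the joint names and error values are purely illustrative:

import numpy as np
import matplotlib.pyplot as plt
from vis.plot import plot_acc, plot_mean_err

names = ['palm', 'thumb_tip', 'index_tip']  # placeholder joint names
dist = np.linspace(0, 80, 100)              # distance thresholds in mm
mean_err = np.array([5.0, 15.0, 10.0])      # fake per-joint mean errors
acc = 1.0 - np.exp(-dist[np.newaxis, :] / mean_err[:, np.newaxis])  # fake accuracy curves

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
plot_acc(ax1, dist, acc, names)
plot_mean_err(ax2, mean_err, names)
plt.show()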
--------------------------------------------------------------------------------