├── .gitignore
├── LICENSE
├── README.md
├── docs
│   ├── algebraic-model.svg
│   ├── unprojection.gif
│   ├── video-preview.jpg
│   └── volumetric-model.svg
├── experiments
│   └── human36m
│       ├── eval
│       │   ├── human36m_alg.yaml
│       │   ├── human36m_ransac.yaml
│       │   └── human36m_vol_softmax.yaml
│       └── train
│           ├── human36m_alg.yaml
│           ├── human36m_alg_no_conf.yaml
│           └── human36m_vol_softmax.yaml
├── mvn
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── human36m.py
│   │   ├── human36m_preprocessing
│   │   │   ├── README.md
│   │   │   ├── action_to_bbox_filename.py
│   │   │   ├── action_to_una_dinosauria.py
│   │   │   ├── collect-bboxes.py
│   │   │   ├── generate-labels-npy-multiview.py
│   │   │   ├── undistort-h36m.py
│   │   │   └── view-dataset.py
│   │   └── utils.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── loss.py
│   │   ├── pose_resnet.py
│   │   ├── triangulation.py
│   │   └── v2v.py
│   └── utils
│       ├── __init__.py
│       ├── cfg.py
│       ├── img.py
│       ├── misc.py
│       ├── multiview.py
│       ├── op.py
│       ├── vis.py
│       └── volumetric.py
├── requirements.txt
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2019 Samsung AI Center Moscow, Karim Iskakov (k.iskakov@samsung.com)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://paperswithcode.com/sota/3d-human-pose-estimation-on-human36m?p=190505754)
2 |
3 | # Learnable Triangulation of Human Pose
4 | This repository is an official PyTorch implementation of the paper ["Learnable Triangulation of Human Pose"](https://arxiv.org/abs/1905.05754) (ICCV 2019, oral). Here we tackle the problem of 3D human pose estimation from multiple cameras. We present 2 novel methods — Algebraic and Volumetric learnable triangulation — that **outperform** previous state of the art.
5 |
6 | If you find a bug, have a question, or know how to improve the code, please open an issue!
7 |
8 | :arrow_forward: [ICCV 2019 talk](https://youtu.be/zem03fZWLrQ?t=3477)
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | # How to use
17 | This project doesn't have any special or difficult-to-install dependencies. Everything can be installed with:
18 | ```bash
19 | pip install -r requirements.txt
20 | ```
21 |
22 | ## Data
23 | Only [Human3.6M](http://vision.imar.ro/human3.6m/description.php) training/evaluation is available right now. Unfortunately, we cannot add [CMU Panoptic](http://domedb.perception.cs.cmu.edu/).
24 |
25 | #### Human3.6M
26 | 1. Download and preprocess the dataset by following the instructions in [mvn/datasets/human36m_preprocessing/README.md](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/mvn/datasets/human36m_preprocessing/README.md).
27 | 2. Download the pretrained backbone weights from [here](https://disk.yandex.ru/d/hv-uH_7TY0ONpg) and place them at `./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth` (a ResNet-152 trained on COCO and fine-tuned jointly on MPII and Human3.6M).
28 | 3. If you want to train the Volumetric model, you need rough estimates of the pelvis's 3D position for both the train and val splits. In the paper we estimate them using the Algebraic model. You can use the [pretrained](#model-zoo) Algebraic model to produce predictions or just take the [precalculated 3D skeletons](#model-zoo).
29 |
30 | ## Model zoo
31 | In this section we collect pretrained models and configs. All **pretrained weights** and **precalculated 3D skeletons** can be downloaded at once [from here](https://disk.yandex.ru/d/jbsyN8XVzD-Y7Q) and placed in `./data/pretrained`, so that the eval configs work out of the box (without any additional path setup). Alternatively, the table below provides separate links to those files.
32 |
33 | **Human3.6M:**
34 |
35 | | Model | Train config | Eval config | Weights | Precalculated results | MPJPE (relative to pelvis), mm |
36 | |----------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------------------------------------:|-------------------------------:|
37 | | Algebraic | [train/human36m_alg.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/experiments/human36m/train/human36m_alg.yaml) | [eval/human36m_alg.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/experiments/human36m/eval/human36m_alg.yaml) | [link](https://disk.yandex.ru/d/3TJMKaa6iKaymw) | [train](https://disk.yandex.ru/d/2Gwk7JZ_QWpFvw), [val](https://disk.yandex.ru/d/ZsQ4GV5EX_Wsog) | 22.5 |
38 | | Volumetric (softmax) | [train/human36m_vol_softmax.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/experiments/human36m/train/human36m_vol_softmax.yaml) | [eval/human36m_vol_softmax.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/experiments/human36m/eval/human36m_vol_softmax.yaml) | [link](https://disk.yandex.ru/d/MvD3orcBc6wqRA) | — | **20.4** |
39 |
40 | ## Train
41 | Every experiment is defined by a `.yaml` config file. Configs for the experiments from the paper can be found in the `./experiments` directory (see [model zoo](#model-zoo)).
42 |
43 | #### Single-GPU
44 | To train a Volumetric model with softmax aggregation using **1 GPU**, run:
45 | ```bash
46 | python3 train.py \
47 | --config experiments/human36m/train/human36m_vol_softmax.yaml \
48 | --logdir ./logs
49 | ```
50 |
51 | The training will start with the config file specified by `--config`, and logs (including tensorboard files) will be stored in `--logdir`.
52 |
53 | #### Multi-GPU (*in testing*)
54 | Multi-GPU training is implemented with PyTorch's [DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#distributeddataparallel). It can be used both for single-machine and multi-machine (cluster) training. To run the processes use the PyTorch [launch utility](https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py).
55 |
56 | To train a Volumetric model with softmax aggregation using **2 GPUs on a single machine**, run:
57 | ```bash
58 | python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=2345 \
59 | train.py \
60 | --config experiments/human36m/train/human36m_vol_softmax.yaml \
61 | --logdir ./logs
62 | ```
63 |
64 | ## Tensorboard
65 | To watch your experiments' progress, run TensorBoard:
66 | ```bash
67 | tensorboard --logdir ./logs
68 | ```
69 |
70 | ## Evaluation
71 | After training, you can evaluate the model. Inside the same config file, add the path to the learned weights (they are dumped to the `logs` directory during training):
72 | ```yaml
73 | model:
74 | init_weights: true
75 | checkpoint: {PATH_TO_WEIGHTS}
76 | ```
77 |
78 | You can also change other config parameters, e.g. `retain_every_n_frames_in_test`.
79 |
80 | Run:
81 | ```bash
82 | python3 train.py \
83 | --eval --eval_dataset val \
84 | --config experiments/human36m/eval/human36m_vol_softmax.yaml \
85 | --logdir ./logs
86 | ```
87 | The `--eval_dataset` argument can be `val` or `train`. Results can be found in the `logs` directory or in TensorBoard.
88 |
89 | # Results
90 | * We conduct experiments on two available large multi-view datasets: Human3.6M [\[2\]](#references) and CMU Panoptic [\[3\]](#references).
91 | * The main metric is **MPJPE** (Mean Per Joint Position Error), which is the L2 distance between predicted and ground-truth joint positions, averaged over all joints (a minimal sketch of the computation is given below).
92 |
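For reference, here is a minimal NumPy sketch of how MPJPE is computed (it mirrors `evaluate()` in `mvn/datasets/human36m.py`; the function and array names are illustrative):

```python
import numpy as np

def mpjpe(pred, gt, root_index=6):
    """Mean Per Joint Position Error in mm.

    pred, gt:   arrays of shape (n_poses, n_joints, 3), in millimeters
    root_index: joint used as the root (pelvis) for the relative metric
    """
    # absolute MPJPE: per-joint L2 distance, averaged over joints and poses
    absolute = np.sqrt(((gt - pred) ** 2).sum(-1)).mean()

    # relative MPJPE: subtract the root joint from both skeletons first
    pred_rel = pred - pred[:, root_index:root_index + 1]
    gt_rel = gt - gt[:, root_index:root_index + 1]
    relative = np.sqrt(((gt_rel - pred_rel) ** 2).sum(-1)).mean()

    return absolute, relative
```
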
93 | ## Human3.6M
94 | * We significantly improved upon the previous state of the art (error is measured relative to pelvis, without alignment).
95 | * Our best model reaches **17.7 mm** error in absolute coordinates, which was unattainable before.
96 | * Our Volumetric model is able to estimate 3D human pose using **any number of cameras**, even **only 1 camera**. In the single-view setup, we get results comparable to the current state of the art [\[6\]](#references) (49.9 mm vs. 49.6 mm).
97 |
98 |
99 | MPJPE relative to pelvis:
100 |
101 | | | MPJPE (averaged across all actions), mm |
102 | |----------------------------- |:--------: |
103 | | Multi-View Martinez [\[4\]](#references) | 57.0 |
104 | | Pavlakos et al. [\[8\]](#references) | 56.9 |
105 | | Tome et al. [\[4\]](#references) | 52.8 |
106 | | Kadkhodamohammadi & Padoy [\[5\]](#references) | 49.1 |
107 | | [Qiu et al.](https://github.com/microsoft/multiview-human-pose-estimation-pytorch) [\[9\]](#references) | 26.2 |
108 | | RANSAC (our implementation) | 27.4 |
109 | | **Ours, algebraic** | 22.4 |
110 | | **Ours, volumetric** | **20.5** |
111 |
112 |
113 | MPJPE absolute (scenes with invalid ground-truth annotations are excluded):
114 |
115 | | | MPJPE (averaged across all actions), mm |
116 | |----------------------------- |:--------: |
117 | | RANSAC (our implementation) | 22.8 |
118 | | **Ours, algebraic** | 19.2 |
119 | | **Ours, volumetric** | **17.7** |
120 |
121 |
122 | MPJPE relative to pelvis (single-view methods):
123 |
124 | | | MPJPE (averaged across all actions), mm |
125 | |----------------------------- |:-----------------------------------: |
126 | | Martinez et al. [\[7\]](#references) | 62.9 |
127 | | Sun et al. [\[6\]](#references) | **49.6** |
128 | | **Ours, volumetric single view** | **49.9** |
129 |
130 |
131 | ## CMU Panoptic
132 | * Our best model reaches **13.7 mm** error in absolute coordinates with 4 cameras.
133 | * We managed to get much smoother and more accurate 3D pose annotations than the dataset's own annotations (see the [video demonstration](http://www.youtube.com/watch?v=z3f3aPSuhqg)).
134 |
135 |
136 | MPJPE relative to pelvis [4 cameras]:
137 |
138 | | | MPJPE, mm |
139 | |----------------------------- |:--------: |
140 | | RANSAC (our implementation) | 39.5 |
141 | | **Ours, algebraic** | 21.3 |
142 | | **Ours, volumetric** | **13.7** |
143 |
144 | # Method overview
145 | We present 2 novel methods of learnable triangulation: Algebraic and Volumetric.
146 |
147 | ## Algebraic
148 | 
149 |
150 | Our first method is based on Algebraic triangulation. It is similar to previous approaches, but differs in two critical aspects:
151 | 1. It is **fully differentiable**. To achieve this, we use soft-argmax aggregation and triangulate keypoints via a differentiable SVD (see the sketch below).
152 | 2. The neural network additionally predicts a **scalar confidence for each joint**, which is passed to the triangulation module and helps to deal with outliers and occluded joints.
153 |
154 | For the most popular Human3.6M dataset, this method already dramatically reduces the error by **2.2 times (!)** compared to the previous state of the art.
155 |
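Below is a minimal PyTorch sketch of these two ingredients: a spatial soft-argmax over 2D heatmaps followed by confidence-weighted DLT triangulation through a differentiable SVD. Function and tensor names are illustrative, not the repository's API; see `mvn/models/triangulation.py` in this repository for the actual models.

```python
import torch

def soft_argmax_2d(heatmaps, softmax_multiplier=100.0):
    """heatmaps: (batch, n_joints, h, w) -> expected 2D keypoints (batch, n_joints, 2), in pixels."""
    batch, n_joints, h, w = heatmaps.shape
    probs = torch.softmax(softmax_multiplier * heatmaps.reshape(batch, n_joints, -1), dim=-1)
    probs = probs.reshape(batch, n_joints, h, w)
    xs = torch.arange(w, dtype=probs.dtype, device=probs.device)
    ys = torch.arange(h, dtype=probs.dtype, device=probs.device)
    x = (probs.sum(dim=2) * xs).sum(dim=-1)  # expectation of the x coordinate
    y = (probs.sum(dim=3) * ys).sum(dim=-1)  # expectation of the y coordinate
    return torch.stack([x, y], dim=-1)

def triangulate_point(proj_matrices, points_2d, confidences):
    """Confidence-weighted DLT triangulation of a single joint from several views.

    proj_matrices: (n_views, 3, 4), points_2d: (n_views, 2), confidences: (n_views,)
    Returns the 3D point of shape (3,). Differentiable thanks to torch.svd.
    """
    # each view contributes two equations of the standard DLT system
    A = proj_matrices[:, 2:3] * points_2d.unsqueeze(-1) - proj_matrices[:, :2]  # (n_views, 2, 4)
    A = (A * confidences.view(-1, 1, 1)).reshape(-1, 4)
    _, _, V = torch.svd(A)  # solution = right singular vector of the smallest singular value
    point_homogeneous = V[:, -1]
    return point_homogeneous[:3] / point_homogeneous[3]
```

Because every step (the softmax and the weighted least squares solved via SVD) is differentiable, the gradient of the 3D loss flows back into both the heatmaps and the predicted confidences.
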
156 |
157 | ## Volumetric
158 | 
159 |
160 | In the Volumetric triangulation model, intermediate 2D feature maps are densely unprojected into a volumetric cube and then processed with a 3D convolutional neural network. The unprojection operation allows **dense aggregation from multiple views**, and the 3D convolutional network is able to model an **implicit human pose prior**.
161 |
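Conceptually, the unprojection step can be sketched as follows: project every voxel center into each camera, bilinearly sample the 2D feature map at that location, and aggregate the per-view volumes (here with a simple per-voxel softmax over views). This is a simplified illustration with assumed tensor names, not the repository's exact implementation.

```python
import torch
import torch.nn.functional as F

def unproject_features(features, proj_matrices, coord_volume):
    """Fill a voxel grid with 2D features aggregated over all camera views.

    features:      (n_views, channels, h, w) 2D feature maps
    proj_matrices: (n_views, 3, 4) projection matrices rescaled to the feature-map resolution
    coord_volume:  (vx, vy, vz, 3) world coordinates (mm) of the voxel centers
    Returns a volume of shape (channels, vx, vy, vz).
    """
    n_views, channels, h, w = features.shape
    vx, vy, vz, _ = coord_volume.shape

    # homogeneous voxel coordinates: (n_voxels, 4)
    grid = coord_volume.reshape(-1, 3)
    grid = torch.cat([grid, torch.ones(grid.shape[0], 1, device=grid.device)], dim=1)

    volumes = []
    for view in range(n_views):
        # project voxel centers into this view and convert to pixel coordinates
        proj = grid @ proj_matrices[view].t()                     # (n_voxels, 3)
        uv = proj[:, :2] / proj[:, 2:3].clamp(min=1e-6)

        # normalize to [-1, 1] and bilinearly sample the feature map
        uv_norm = torch.stack([2 * uv[:, 0] / (w - 1) - 1,
                               2 * uv[:, 1] / (h - 1) - 1], dim=-1)
        sampled = F.grid_sample(features[view:view + 1],
                                uv_norm.view(1, 1, -1, 2), align_corners=True)  # (1, C, 1, n_voxels)
        volumes.append(sampled.view(channels, vx, vy, vz))

    # softmax aggregation over views, per voxel; alternatives: sum, max, learned weights
    volumes = torch.stack(volumes)                                # (n_views, C, vx, vy, vz)
    weights = torch.softmax(volumes, dim=0)
    return (weights * volumes).sum(dim=0)
```
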
162 | Volumetric triangulation further improves accuracy, reducing the previous state-of-the-art error by **2.4 times!** Even compared to the best concurrently developed [method](https://github.com/microsoft/multiview-human-pose-estimation-pytorch) by the MSRA group, our method still offers a significantly lower error of **21 mm**.
163 |
164 |
165 |
166 |
167 |
168 |
169 | # Cite us!
170 | ```bibtex
171 | @inproceedings{iskakov2019learnable,
172 | title={Learnable Triangulation of Human Pose},
173 | author={Iskakov, Karim and Burkov, Egor and Lempitsky, Victor and Malkov, Yury},
174 | booktitle = {International Conference on Computer Vision (ICCV)},
175 | year={2019}
176 | }
177 | ```
178 |
179 | # Contributors
180 | - [Karim Iskakov](https://github.com/karfly)
181 | - [Egor Burkov](https://github.com/shrubb)
182 | - [Victor Lempitsky](https://scholar.google.com/citations?user=gYYVokYAAAAJ&hl=ru)
183 | - [Yury Malkov](https://github.com/yurymalkov)
184 | - [Rasul Kerimov](https://github.com/rrkarim)
185 | - [Ivan Bulygin](https://github.com/blufzzz)
186 |
187 | # News
188 | - **26 Nov 2019:** Updated [precalculated results](#model-zoo) (see [this issue](https://github.com/karfly/learnable-triangulation-pytorch/issues/37)).
189 | - **18 Oct 2019:** Pretrained models (algebraic and volumetric) for Human3.6M are released.
190 | - **8 Oct 2019:** Code is released!
191 |
192 | # References
193 | * [\[1\]](#references) R. Hartley and A. Zisserman. **Multiple view geometry in computer vision**.
194 | * [\[2\]](#references) C. Ionescu, D. Papava, V. Olaru, and C. Sminchisescu. **Human3.6m: Large scale datasets and predictive methods for 3d human sensing in natural environments**.
195 | * [\[3\]](#references) H. Joo, T. Simon, X. Li, H. Liu, L. Tan, L. Gui, S. Banerjee, T. S. Godisart, B. Nabbe, I. Matthews, T. Kanade,S. Nobuhara, and Y. Sheikh. **Panoptic studio: A massively multiview system for social interaction capture**.
196 | * [\[4\]](#references) D. Tome, M. Toso, L. Agapito, and C. Russell. **Rethinking Pose in 3D: Multi-stage Refinement and Recovery for Markerless Motion Capture**.
197 | * [\[5\]](#references) A. Kadkhodamohammadi and N. Padoy. **A generalizable approach for multi-view 3D human pose regression**.
198 | * [\[6\]](#references) X. Sun, B. Xiao, S. Liang, and Y. Wei. **Integral human pose regression**.
199 | * [\[7\]](#references) J. Martinez, R. Hossain, J. Romero, and J. J. Little. **A simple yet effective baseline for 3d human pose estimation**.
200 | * [\[8\]](#references) G. Pavlakos, X. Zhou, K. G. Derpanis, and K. Daniilidis. **Harvesting multiple views for marker-less 3D human pose annotations**.
201 | * [\[9\]](#references) H. Qiu, C. Wang, J. Wang, N. Wang and W. Zeng. (2019). **Cross View Fusion for 3D Human Pose Estimation**, [GitHub](https://github.com/microsoft/multiview-human-pose-estimation-pytorch)
202 |
--------------------------------------------------------------------------------
/docs/unprojection.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/karfly/learnable-triangulation-pytorch/8dcc4e97a407ba0474d2e5299a92bb32cfaeadfe/docs/unprojection.gif
--------------------------------------------------------------------------------
/docs/video-preview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/karfly/learnable-triangulation-pytorch/8dcc4e97a407ba0474d2e5299a92bb32cfaeadfe/docs/video-preview.jpg
--------------------------------------------------------------------------------
/experiments/human36m/eval/human36m_alg.yaml:
--------------------------------------------------------------------------------
1 | title: "human36m_alg"
2 | kind: "human36m"
3 | vis_freq: 1000
4 | vis_n_elements: 10
5 |
6 | image_shape: [384, 384]
7 |
8 | opt:
9 | criterion: "MSESmooth"
10 | mse_smooth_threshold: 400
11 |
12 | n_objects_per_epoch: 15000
13 | n_epochs: 9999
14 |
15 | batch_size: 8
16 | val_batch_size: 100
17 |
18 | lr: 0.00001
19 |
20 | scale_keypoints_3d: 0.1
21 |
22 | model:
23 | name: "alg"
24 |
25 | init_weights: true
26 | checkpoint: "./data/pretrained/human36m/human36m_alg_10-04-2019/checkpoints/0060/weights.pth"
27 |
28 | use_confidences: true
29 | heatmap_multiplier: 100.0
30 | heatmap_softmax: true
31 |
32 | backbone:
33 | name: "resnet152"
34 | style: "simple"
35 |
36 | init_weights: true
37 | checkpoint: "./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth"
38 |
39 | num_joints: 17
40 | num_layers: 152
41 |
42 | dataset:
43 | kind: "human36m"
44 |
45 | train:
46 | h36m_root: "./data/human36m/processed/"
47 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
48 | with_damaged_actions: true
49 | undistort_images: true
50 |
51 | scale_bbox: 1.0
52 |
53 | shuffle: true
54 | randomize_n_views: false
55 | min_n_views: null
56 | max_n_views: null
57 | num_workers: 8
58 |
59 | val:
60 | h36m_root: "./data/human36m/processed/"
61 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
62 | with_damaged_actions: true
63 | undistort_images: true
64 |
65 | scale_bbox: 1.0
66 |
67 | shuffle: false
68 | randomize_n_views: false
69 | min_n_views: null
70 | max_n_views: null
71 | num_workers: 8
72 |
73 | retain_every_n_frames_in_test: 1
74 |
--------------------------------------------------------------------------------
/experiments/human36m/eval/human36m_ransac.yaml:
--------------------------------------------------------------------------------
1 | title: "human36m_ransac"
2 | kind: "human36m"
3 | vis_freq: 1000
4 | vis_n_elements: 10
5 |
6 | image_shape: [384, 384]
7 |
8 | opt:
9 | criterion: "MSESmooth"
10 | mse_smooth_threshold: 400
11 |
12 | n_objects_per_epoch: 15000
13 | n_epochs: 9999
14 |
15 | batch_size: 8
16 | val_batch_size: 100
17 |
18 | lr: 0.00001
19 |
20 | scale_keypoints_3d: 0.1
21 |
22 | model:
23 | name: "ransac"
24 |
25 | init_weights: false
26 | checkpoint: ""
27 |
28 | direct_optimization: true
29 | heatmap_multiplier: 100.0
30 | heatmap_softmax: true
31 |
32 | backbone:
33 | name: "resnet152"
34 | style: "simple"
35 |
36 | init_weights: true
37 | checkpoint: "./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth"
38 |
39 | num_joints: 17
40 | num_layers: 152
41 |
42 | dataset:
43 | kind: "human36m"
44 |
45 | train:
46 | h36m_root: "./data/human36m/processed/"
47 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
48 | with_damaged_actions: true
49 | undistort_images: true
50 |
51 | scale_bbox: 1.0
52 |
53 | shuffle: true
54 | randomize_n_views: false
55 | min_n_views: null
56 | max_n_views: null
57 | num_workers: 8
58 |
59 | val:
60 | h36m_root: "./data/human36m/processed/"
61 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
62 | with_damaged_actions: true
63 | undistort_images: true
64 |
65 | scale_bbox: 1.0
66 |
67 | shuffle: false
68 | randomize_n_views: false
69 | min_n_views: null
70 | max_n_views: null
71 | num_workers: 8
72 |
73 | retain_every_n_frames_in_test: 1
74 |
--------------------------------------------------------------------------------
/experiments/human36m/eval/human36m_vol_softmax.yaml:
--------------------------------------------------------------------------------
1 | title: "human36m_vol_softmax"
2 | kind: "human36m"
3 | vis_freq: 1000
4 | vis_n_elements: 10
5 |
6 | image_shape: [384, 384]
7 |
8 | opt:
9 | criterion: "MAE"
10 |
11 | use_volumetric_ce_loss: true
12 | volumetric_ce_loss_weight: 0.01
13 |
14 | n_objects_per_epoch: 15000
15 | n_epochs: 9999
16 |
17 | batch_size: 5
18 | val_batch_size: 75
19 |
20 | lr: 0.0001
21 | process_features_lr: 0.001
22 | volume_net_lr: 0.001
23 |
24 | scale_keypoints_3d: 0.1
25 |
26 | model:
27 | name: "vol"
28 | kind: "mpii"
29 | volume_aggregation_method: "softmax"
30 |
31 | init_weights: true
32 | checkpoint: "./data/pretrained/human36m/human36m_vol_softmax_10-08-2019/checkpoints/0040/weights.pth"
33 |
34 | use_gt_pelvis: false
35 |
36 | cuboid_side: 2500.0
37 |
38 | volume_size: 64
39 | volume_multiplier: 1.0
40 | volume_softmax: true
41 |
42 | heatmap_softmax: true
43 | heatmap_multiplier: 100.0
44 |
45 | backbone:
46 | name: "resnet152"
47 | style: "simple"
48 |
49 | init_weights: true
50 | checkpoint: "./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth"
51 |
52 | num_joints: 17
53 | num_layers: 152
54 |
55 | dataset:
56 | kind: "human36m"
57 |
58 | train:
59 | h36m_root: "./data/human36m/processed/"
60 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
61 | pred_results_path: "./data/pretrained/human36m/human36m_alg_10-04-2019/checkpoints/0060/results/train.pkl"
62 |
63 | with_damaged_actions: true
64 | undistort_images: true
65 |
66 | scale_bbox: 1.0
67 |
68 | shuffle: true
69 | randomize_n_views: false
70 | min_n_views: null
71 | max_n_views: null
72 | num_workers: 5
73 |
74 | val:
75 | h36m_root: "./data/human36m/processed/"
76 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
77 | pred_results_path: "./data/pretrained/human36m/human36m_alg_10-04-2019/checkpoints/0060/results/val.pkl"
78 |
79 | with_damaged_actions: true
80 | undistort_images: true
81 |
82 | scale_bbox: 1.0
83 |
84 | shuffle: false
85 | randomize_n_views: false
86 | min_n_views: null
87 | max_n_views: null
88 | num_workers: 8
89 |
90 | retain_every_n_frames_in_test: 1
91 |
--------------------------------------------------------------------------------
/experiments/human36m/train/human36m_alg.yaml:
--------------------------------------------------------------------------------
1 | title: "human36m_alg"
2 | kind: "human36m"
3 | vis_freq: 1000
4 | vis_n_elements: 10
5 |
6 | image_shape: [384, 384]
7 |
8 | opt:
9 | criterion: "MSESmooth"
10 | mse_smooth_threshold: 400
11 |
12 | n_objects_per_epoch: 15000
13 | n_epochs: 9999
14 |
15 | batch_size: 8
16 | val_batch_size: 16
17 |
18 | lr: 0.00001
19 |
20 | scale_keypoints_3d: 0.1
21 |
22 | model:
23 | name: "alg"
24 |
25 | init_weights: false
26 | checkpoint: ""
27 |
28 | use_confidences: true
29 | heatmap_multiplier: 100.0
30 | heatmap_softmax: true
31 |
32 | backbone:
33 | name: "resnet152"
34 | style: "simple"
35 |
36 | init_weights: true
37 | checkpoint: "./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth"
38 |
39 | num_joints: 17
40 | num_layers: 152
41 |
42 | dataset:
43 | kind: "human36m"
44 |
45 | train:
46 | h36m_root: "./data/human36m/processed/"
47 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
48 | with_damaged_actions: true
49 | undistort_images: true
50 |
51 | scale_bbox: 1.0
52 |
53 | shuffle: true
54 | randomize_n_views: false
55 | min_n_views: null
56 | max_n_views: null
57 | num_workers: 8
58 |
59 | val:
60 | h36m_root: "./data/human36m/processed/"
61 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
62 | with_damaged_actions: true
63 | undistort_images: true
64 |
65 | scale_bbox: 1.0
66 |
67 | shuffle: false
68 | randomize_n_views: false
69 | min_n_views: null
70 | max_n_views: null
71 | num_workers: 8
72 |
73 | retain_every_n_frames_in_test: 1
74 |
--------------------------------------------------------------------------------
/experiments/human36m/train/human36m_alg_no_conf.yaml:
--------------------------------------------------------------------------------
1 | title: "human36m_alg"
2 | kind: "human36m"
3 | vis_freq: 1000
4 | vis_n_elements: 10
5 |
6 | image_shape: [384, 384]
7 |
8 | opt:
9 | criterion: "MSESmooth"
10 | mse_smooth_threshold: 400
11 |
12 | n_objects_per_epoch: 10000
13 | n_epochs: 9999
14 |
15 | batch_size: 8
16 | val_batch_size: 16
17 |
18 | lr: 0.00001
19 |
20 | scale_keypoints_3d: 0.1
21 |
22 | model:
23 | name: "alg"
24 |
25 | init_weights: false
26 | checkpoint: ""
27 |
28 | use_confidences: false
29 | heatmap_multiplier: 100.0
30 | heatmap_softmax: true
31 |
32 | backbone:
33 | name: "resnet152"
34 | style: "simple"
35 |
36 | init_weights: true
37 | checkpoint: "./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth"
38 |
39 | num_joints: 17
40 | num_layers: 152
41 |
42 | dataset:
43 | kind: "human36m"
44 |
45 | train:
46 | h36m_root: "./data/human36m/processed/"
47 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
48 | with_damaged_actions: true
49 | undistort_images: true
50 |
51 | scale_bbox: 1.0
52 |
53 | shuffle: true
54 | randomize_n_views: false
55 | min_n_views: null
56 | max_n_views: null
57 | num_workers: 8
58 |
59 | val:
60 | h36m_root: "./data/human36m/processed/"
61 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
62 | with_damaged_actions: true
63 | undistort_images: true
64 |
65 | scale_bbox: 1.0
66 |
67 | shuffle: false
68 | randomize_n_views: false
69 | min_n_views: null
70 | max_n_views: null
71 | num_workers: 8
72 |
73 | retain_every_n_frames_in_test: 1
74 |
--------------------------------------------------------------------------------
/experiments/human36m/train/human36m_vol_softmax.yaml:
--------------------------------------------------------------------------------
1 | title: "human36m_vol_softmax"
2 | kind: "human36m"
3 | vis_freq: 1000
4 | vis_n_elements: 10
5 |
6 | image_shape: [384, 384]
7 |
8 | opt:
9 | criterion: "MAE"
10 |
11 | use_volumetric_ce_loss: true
12 | volumetric_ce_loss_weight: 0.01
13 |
14 | n_objects_per_epoch: 15000
15 | n_epochs: 9999
16 |
17 | batch_size: 5
18 | val_batch_size: 10
19 |
20 | lr: 0.0001
21 | process_features_lr: 0.001
22 | volume_net_lr: 0.001
23 |
24 | scale_keypoints_3d: 0.1
25 |
26 | model:
27 | name: "vol"
28 | kind: "mpii"
29 | volume_aggregation_method: "softmax"
30 |
31 | init_weights: false
32 | checkpoint: ""
33 |
34 | use_gt_pelvis: false
35 |
36 | cuboid_side: 2500.0
37 |
38 | volume_size: 64
39 | volume_multiplier: 1.0
40 | volume_softmax: true
41 |
42 | heatmap_softmax: true
43 | heatmap_multiplier: 100.0
44 |
45 | backbone:
46 | name: "resnet152"
47 | style: "simple"
48 |
49 | init_weights: true
50 | checkpoint: "./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth"
51 |
52 | num_joints: 17
53 | num_layers: 152
54 |
55 | dataset:
56 | kind: "human36m"
57 |
58 | train:
59 | h36m_root: "./data/human36m/processed/"
60 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
61 | pred_results_path: "./data/pretrained/human36m/human36m_alg_10-04-2019/checkpoints/0060/results/train.pkl"
62 |
63 | with_damaged_actions: true
64 | undistort_images: true
65 |
66 | scale_bbox: 1.0
67 |
68 | shuffle: true
69 | randomize_n_views: false
70 | min_n_views: null
71 | max_n_views: null
72 | num_workers: 5
73 |
74 | val:
75 | h36m_root: "./data/human36m/processed/"
76 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
77 | pred_results_path: "./data/pretrained/human36m/human36m_alg_10-04-2019/checkpoints/0060/results/val.pkl"
78 |
79 | with_damaged_actions: true
80 | undistort_images: true
81 |
82 | scale_bbox: 1.0
83 |
84 | shuffle: false
85 | randomize_n_views: false
86 | min_n_views: null
87 | max_n_views: null
88 | num_workers: 10
89 |
90 | retain_every_n_frames_in_test: 1
91 |
--------------------------------------------------------------------------------
/mvn/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
--------------------------------------------------------------------------------
/mvn/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
--------------------------------------------------------------------------------
/mvn/datasets/human36m.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import defaultdict
3 | import pickle
4 |
5 | import numpy as np
6 | import cv2
7 |
8 | import torch
9 | from torch.utils.data import Dataset
10 |
11 | from mvn.utils.multiview import Camera
12 | from mvn.utils.img import get_square_bbox, resize_image, crop_image, normalize_image, scale_bbox
13 | from mvn.utils import volumetric
14 |
15 |
16 | class Human36MMultiViewDataset(Dataset):
17 | """
18 | Human3.6M for multiview tasks.
19 | """
20 | def __init__(self,
21 | h36m_root='/Vol1/dbstore/datasets/Human3.6M/processed/',
22 | labels_path='/Vol1/dbstore/datasets/Human3.6M/extra/human36m-multiview-labels-SSDbboxes.npy',
23 | pred_results_path=None,
24 | image_shape=(256, 256),
25 | train=False,
26 | test=False,
27 | retain_every_n_frames_in_test=1,
28 | with_damaged_actions=False,
29 | cuboid_side=2000.0,
30 | scale_bbox=1.5,
31 | norm_image=True,
32 | kind="mpii",
33 | undistort_images=False,
34 | ignore_cameras=[],
35 | crop=True
36 | ):
37 | """
38 | h36m_root:
39 | Path to 'processed/' directory in Human3.6M
40 | labels_path:
41 | Path to 'human36m-multiview-labels.npy' generated by 'generate-labels-npy-multiview.py'
42 | from https://github.sec.samsung.net/RRU8-VIOLET/human36m-preprocessing
43 | retain_every_n_frames_in_test:
44 | By default, there are 159 181 frames in training set and 26 634 in test (val) set.
45 | With this parameter, test set frames are evenly subsampled so that the
46 | test set size is `26634 // retain_every_n_frames_in_test`.
47 | Use a value of 13 to get 2049 frames in the test set.
48 | with_damaged_actions:
49 | If `True`, will include 'S9/[Greeting-2,SittingDown-2,Waiting-1]' in test set.
50 | kind:
51 | Keypoint format, 'mpii' or 'human36m'
52 | ignore_cameras:
53 | A list with indices of cameras to exclude (0 to 3 inclusive)
54 | """
55 | assert train or test, '`Human36MMultiViewDataset` must be constructed with at least ' \
56 | 'one of `test=True` / `train=True`'
57 | assert kind in ("mpii", "human36m")
58 |
59 | self.h36m_root = h36m_root
60 | self.labels_path = labels_path
61 | self.image_shape = None if image_shape is None else tuple(image_shape)
62 | self.scale_bbox = scale_bbox
63 | self.norm_image = norm_image
64 | self.cuboid_side = cuboid_side
65 | self.kind = kind
66 | self.undistort_images = undistort_images
67 | self.ignore_cameras = ignore_cameras
68 | self.crop = crop
69 |
70 | self.labels = np.load(labels_path, allow_pickle=True).item()
71 |
72 | n_cameras = len(self.labels['camera_names'])
73 | assert all(camera_idx in range(n_cameras) for camera_idx in self.ignore_cameras)
74 |
75 | train_subjects = ['S1', 'S5', 'S6', 'S7', 'S8']
76 | test_subjects = ['S9', 'S11']
77 |
78 | train_subjects = list(self.labels['subject_names'].index(x) for x in train_subjects)
79 | test_subjects = list(self.labels['subject_names'].index(x) for x in test_subjects)
80 |
81 | indices = []
82 | if train:
83 | mask = np.isin(self.labels['table']['subject_idx'], train_subjects, assume_unique=True)
84 | indices.append(np.nonzero(mask)[0])
85 | if test:
86 | mask = np.isin(self.labels['table']['subject_idx'], test_subjects, assume_unique=True)
87 |
88 | if not with_damaged_actions:
89 | mask_S9 = self.labels['table']['subject_idx'] == self.labels['subject_names'].index('S9')
90 |
91 | damaged_actions = 'Greeting-2', 'SittingDown-2', 'Waiting-1'
92 | damaged_actions = [self.labels['action_names'].index(x) for x in damaged_actions]
93 | mask_damaged_actions = np.isin(self.labels['table']['action_idx'], damaged_actions)
94 |
95 | mask &= ~(mask_S9 & mask_damaged_actions)
96 |
97 | indices.append(np.nonzero(mask)[0][::retain_every_n_frames_in_test])
98 |
99 | self.labels['table'] = self.labels['table'][np.concatenate(indices)]
100 |
101 | self.num_keypoints = 16 if kind == "mpii" else 17
102 | assert self.labels['table']['keypoints'].shape[1] == 17, "Use a newer 'labels' file"
103 |
104 | self.keypoints_3d_pred = None
105 | if pred_results_path is not None:
106 | pred_results = np.load(pred_results_path, allow_pickle=True)
107 | keypoints_3d_pred = pred_results['keypoints_3d'][np.argsort(pred_results['indexes'])]
108 | self.keypoints_3d_pred = keypoints_3d_pred[::retain_every_n_frames_in_test]
109 | assert len(self.keypoints_3d_pred) == len(self), \
110 | f"[train={train}, test={test}] {labels_path} has {len(self)} samples, but '{pred_results_path}' " + \
111 | f"has {len(self.keypoints_3d_pred)}. Did you follow all preprocessing instructions carefully?"
112 |
113 | def __len__(self):
114 | return len(self.labels['table'])
115 |
116 | def __getitem__(self, idx):
117 | sample = defaultdict(list) # return value
118 | shot = self.labels['table'][idx]
119 |
120 | subject = self.labels['subject_names'][shot['subject_idx']]
121 | action = self.labels['action_names'][shot['action_idx']]
122 | frame_idx = shot['frame_idx']
123 |
124 | for camera_idx, camera_name in enumerate(self.labels['camera_names']):
125 | if camera_idx in self.ignore_cameras:
126 | continue
127 |
128 | # load bounding box
129 | bbox = shot['bbox_by_camera_tlbr'][camera_idx][[1,0,3,2]] # TLBR to LTRB
130 | bbox_height = bbox[2] - bbox[0]
131 | if bbox_height == 0:
132 | # convention: if the bbox is empty, then this view is missing
133 | continue
134 |
135 | # scale the bounding box
136 | bbox = scale_bbox(bbox, self.scale_bbox)
137 |
138 | # load image
139 | image_path = os.path.join(
140 | self.h36m_root, subject, action, 'imageSequence' + '-undistorted' * self.undistort_images,
141 | camera_name, 'img_%06d.jpg' % (frame_idx+1))
142 | assert os.path.isfile(image_path), '%s doesn\'t exist' % image_path
143 | image = cv2.imread(image_path)
144 |
145 | # load camera
146 | shot_camera = self.labels['cameras'][shot['subject_idx'], camera_idx]
147 | retval_camera = Camera(shot_camera['R'], shot_camera['t'], shot_camera['K'], shot_camera['dist'], camera_name)
148 |
149 | if self.crop:
150 | # crop image
151 | image = crop_image(image, bbox)
152 | retval_camera.update_after_crop(bbox)
153 |
154 | if self.image_shape is not None:
155 | # resize
156 | image_shape_before_resize = image.shape[:2]
157 | image = resize_image(image, self.image_shape)
158 | retval_camera.update_after_resize(image_shape_before_resize, self.image_shape)
159 |
160 | sample['image_shapes_before_resize'].append(image_shape_before_resize)
161 |
162 | if self.norm_image:
163 | image = normalize_image(image)
164 |
165 | sample['images'].append(image)
166 | sample['detections'].append(bbox + (1.0,)) # TODO add real confidences
167 | sample['cameras'].append(retval_camera)
168 | sample['proj_matrices'].append(retval_camera.projection)
169 |
170 | # 3D keypoints
171 | # add dummy confidences
172 | sample['keypoints_3d'] = np.pad(
173 | shot['keypoints'][:self.num_keypoints],
174 | ((0,0), (0,1)), 'constant', constant_values=1.0)
175 |
176 | # build cuboid
177 | # base_point = sample['keypoints_3d'][6, :3]
178 | # sides = np.array([self.cuboid_side, self.cuboid_side, self.cuboid_side])
179 | # position = base_point - sides / 2
180 | # sample['cuboids'] = volumetric.Cuboid3D(position, sides)
181 |
182 | # save sample's index
183 | sample['indexes'] = idx
184 |
185 | if self.keypoints_3d_pred is not None:
186 | sample['pred_keypoints_3d'] = self.keypoints_3d_pred[idx]
187 |
188 | sample.default_factory = None
189 | return sample
190 |
191 | def evaluate_using_per_pose_error(self, per_pose_error, split_by_subject):
192 | def evaluate_by_actions(self, per_pose_error, mask=None):
193 | if mask is None:
194 | mask = np.ones_like(per_pose_error, dtype=bool)
195 |
196 | action_scores = {
197 | 'Average': {'total_loss': per_pose_error[mask].sum(), 'frame_count': np.count_nonzero(mask)}
198 | }
199 |
200 | for action_idx in range(len(self.labels['action_names'])):
201 | action_mask = (self.labels['table']['action_idx'] == action_idx) & mask
202 | action_per_pose_error = per_pose_error[action_mask]
203 | action_scores[self.labels['action_names'][action_idx]] = {
204 | 'total_loss': action_per_pose_error.sum(), 'frame_count': len(action_per_pose_error)
205 | }
206 |
207 | action_names_without_trials = \
208 | [name[:-2] for name in self.labels['action_names'] if name.endswith('-1')]
209 |
210 | for action_name_without_trial in action_names_without_trials:
211 | combined_score = {'total_loss': 0.0, 'frame_count': 0}
212 |
213 | for trial in 1, 2:
214 | action_name = '%s-%d' % (action_name_without_trial, trial)
215 | combined_score['total_loss' ] += action_scores[action_name]['total_loss']
216 | combined_score['frame_count'] += action_scores[action_name]['frame_count']
217 | del action_scores[action_name]
218 |
219 | action_scores[action_name_without_trial] = combined_score
220 |
221 | for k, v in action_scores.items():
222 | action_scores[k] = float('nan') if v['frame_count'] == 0 else (v['total_loss'] / v['frame_count'])
223 |
224 | return action_scores
225 |
226 | subject_scores = {
227 | 'Average': evaluate_by_actions(self, per_pose_error)
228 | }
229 |
230 | for subject_idx in range(len(self.labels['subject_names'])):
231 | subject_mask = self.labels['table']['subject_idx'] == subject_idx
232 | subject_scores[self.labels['subject_names'][subject_idx]] = \
233 | evaluate_by_actions(self, per_pose_error, subject_mask)
234 |
235 | return subject_scores
236 |
237 | def evaluate(self, keypoints_3d_predicted, split_by_subject=False, transfer_cmu_to_human36m=False, transfer_human36m_to_human36m=False):
238 | keypoints_gt = self.labels['table']['keypoints'][:, :self.num_keypoints]
239 | if keypoints_3d_predicted.shape != keypoints_gt.shape:
240 | raise ValueError(
241 | '`keypoints_3d_predicted` shape should be %s, got %s' % \
242 | (keypoints_gt.shape, keypoints_3d_predicted.shape))
243 |
244 | if transfer_cmu_to_human36m or transfer_human36m_to_human36m:
245 | human36m_joints = [10, 11, 15, 14, 1, 4]
246 | if transfer_human36m_to_human36m:
247 | cmu_joints = [10, 11, 15, 14, 1, 4]
248 | else:
249 | cmu_joints = [10, 8, 9, 7, 14, 13]
250 |
251 | keypoints_gt = keypoints_gt[:, human36m_joints]
252 | keypoints_3d_predicted = keypoints_3d_predicted[:, cmu_joints]
253 |
254 | # mean error per 16/17 joints in mm, for each pose
255 | per_pose_error = np.sqrt(((keypoints_gt - keypoints_3d_predicted) ** 2).sum(2)).mean(1)
256 |
257 | # relative mean error per 16/17 joints in mm, for each pose
258 | if not (transfer_cmu_to_human36m or transfer_human36m_to_human36m):
259 | root_index = 6  # the root (pelvis) joint has index 6 for both 'mpii' and 'human36m' kinds
260 | else:
261 | root_index = 0
262 |
263 | keypoints_gt_relative = keypoints_gt - keypoints_gt[:, root_index:root_index + 1, :]
264 | keypoints_3d_predicted_relative = keypoints_3d_predicted - keypoints_3d_predicted[:, root_index:root_index + 1, :]
265 |
266 | per_pose_error_relative = np.sqrt(((keypoints_gt_relative - keypoints_3d_predicted_relative) ** 2).sum(2)).mean(1)
267 |
268 | result = {
269 | 'per_pose_error': self.evaluate_using_per_pose_error(per_pose_error, split_by_subject),
270 | 'per_pose_error_relative': self.evaluate_using_per_pose_error(per_pose_error_relative, split_by_subject)
271 | }
272 |
273 | return result['per_pose_error_relative']['Average']['Average'], result
274 |
--------------------------------------------------------------------------------
/mvn/datasets/human36m_preprocessing/README.md:
--------------------------------------------------------------------------------
1 | Human3.6M preprocessing scripts
2 | =======
3 |
4 | These scripts help preprocess Human3.6M dataset so that it can be used with `class Human36MMultiViewDataset`.
5 |
6 | Here is how we do it (brace yourselves):
7 |
8 | 0. Make sure you have a lot (around 200 GiB?) of free disk space. Otherwise, be prepared to always carefully delete intermediate files (e.g. after you extract movies, delete the archives).
9 |
10 | 1. Allocate a folder for the dataset. Make it accessible as `$THIS_REPOSITORY/data/human36m/`, i.e. either
11 |
12 | * store your data `$SOMEWHERE_ELSE` and make a soft symbolic link:
13 | ```bash
14 | mkdir $THIS_REPOSITORY/data
15 | ln -s $SOMEWHERE_ELSE $THIS_REPOSITORY/data/human36m
16 | ```
17 | * or just store the dataset along with the code at `$THIS_REPOSITORY/data/human36m/`.
18 |
19 | 1. Clone [this toolbox](https://github.com/anibali/h36m-fetch). Follow their manual to download, extract and unpack Human3.6M into image files. Move the result to `$THIS_REPOSITORY/data/human36m`.
20 |
21 | After that, you should have images unpacked as e.g. `$THIS_REPOSITORY/data/human36m/processed/S1/Phoning-1/imageSequence/54138969/img_000001.jpg`.
22 |
23 | 2. Additionally, if you want to use ground truth bounding boxes for training, download them as well (the website calls them *"Segments BBoxes MAT"*) and unpack them like so: `"$THIS_REPOSITORY/data/human36m/processed/S1/MySegmentsMat/ground_truth_bb/Phoning 1.58860488.mat"`.
24 |
25 | 3. Convert those bounding boxes into a sane format. This will create `$THIS_REPOSITORY/data/human36m/extra/bboxes-Human36M-GT.npy`:
26 |
27 | ```bash
28 | cd $THIS_REPOSITORY/mvn/datasets/human36m_preprocessing
29 | # in our environment, this took around 6 minutes with 40 processes
30 | python3 collect-bboxes.py $THIS_REPOSITORY/data/human36m <number-of-processes>
31 | ```
32 |
33 | 4. Existing 3D keypoint positions and camera intrinsics are difficult to decipher, so initially we used the converted ones [from Julieta Martinez](https://github.com/una-dinosauria/3d-pose-baseline/):
34 |
35 | ```bash
36 | mkdir -p $THIS_REPOSITORY/data/human36m/extra/una-dinosauria-data
37 | cd $THIS_REPOSITORY/data/human36m/extra/una-dinosauria-data
38 | ```
39 |
40 | Download `h36m.zip` from [here](https://disk.yandex.ru/d/3gPRFzLSFpS27Q) and unzip it to the current directory:
41 | ```bash
42 | unzip h36m.zip
43 | cd -
44 | ```
45 |
46 | 5. Wrap the 3D keypoint positions, bounding boxes and camera intrinsics together. This will create `$THIS_REPOSITORY/data/human36m/extra/human36m-multiview-labels-GTbboxes.npy`:
47 |
48 | ```bash
49 | python3 generate-labels-npy-multiview.py $THIS_REPOSITORY/data/human36m $THIS_REPOSITORY/data/human36m/extra/una-dinosauria-data/h36m $THIS_REPOSITORY/data/human36m/extra/bboxes-Human36M-GT.npy
50 | ```
51 |
52 | You should see only one warning saying `camera 54138969 isn't present in S11/Directions-2`. That's fine.
53 |
54 | Now you can train and evaluate models by setting these config values (already set by default in the example configs):
55 |
56 | ```yaml
57 | dataset:
58 | {train,val}:
59 | h36m_root: "data/human36m/processed/"
60 | labels_path: "data/human36m/extra/human36m-multiview-labels-GTbboxes.npy"
61 | ```
62 |
63 | 6. To use `undistort_images: true`, undistort the images beforehand. This will put undistorted images to e.g. `$THIS_REPOSITORY/data/human36m/processed/S1/Phoning-1/imageSequence-undistorted/54138969/img_000001.jpg`:
64 |
65 | ```bash
66 | # in our environment, this took around 90 minutes with 50 processes
67 | python3 undistort-h36m.py $THIS_REPOSITORY/data/human36m $THIS_REPOSITORY/data/human36m/extra/human36m-multiview-labels-GTbboxes.npy
68 | ```
69 |
70 | *TODO: move undistortion to the dataloader. We can do it on the fly during training.*
71 |
72 | 7. Optionally, you can test if everything went well by viewing frames with skeletons and bounding boxes on a GUI machine:
73 |
74 | ```bash
75 | python3 view-dataset.py $THIS_REPOSITORY/data/human36m/processed $THIS_REPOSITORY/data/human36m/extra/human36m-multiview-labels-GTbboxes.npy
76 | ```
77 |
78 | You can test different settings by changing dataset constructor parameters in `view-dataset.py`.
79 |
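After all the preprocessing steps are done, the dataset can also be constructed directly in Python. Here is a minimal sketch (the argument values below mirror the dataset-related values in the example configs; adjust them to your setup):

```python
from mvn.datasets.human36m import Human36MMultiViewDataset

dataset = Human36MMultiViewDataset(
    h36m_root="data/human36m/processed/",
    labels_path="data/human36m/extra/human36m-multiview-labels-GTbboxes.npy",
    train=True,               # or test=True for the S9/S11 split
    image_shape=(384, 384),
    scale_bbox=1.0,
    undistort_images=True,    # requires the pre-undistorted images from step 6
    kind="human36m",          # 17-joint keypoint format; use "mpii" for 16 joints
)

print(len(dataset))
sample = dataset[0]  # dict with 'images', 'cameras', 'proj_matrices', 'keypoints_3d', ...
```
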
--------------------------------------------------------------------------------
/mvn/datasets/human36m_preprocessing/action_to_bbox_filename.py:
--------------------------------------------------------------------------------
1 | action_to_bbox_filename = {
2 | 'S11': {
3 | 'Directions-2': 'Directions',
4 | 'Discussion-2': 'Discussion 2',
5 | 'Eating-2': 'Eating',
6 | 'Greeting-1': 'Greeting 2',
7 | 'Greeting-2': 'Greeting',
8 | 'Phoning-1': 'Phoning 3',
9 | 'Phoning-2': 'Phoning 2',
10 | 'Posing-2': 'Posing',
11 | 'Purchases-2': 'Purchases',
12 | 'Sitting-2': 'Sitting',
13 | 'SittingDown-1': 'SittingDown',
14 | 'SittingDown-2': 'SittingDown 1',
15 | 'Smoking-1': 'Smoking 2',
16 | 'Smoking-2': 'Smoking',
17 | 'TakingPhoto-1': 'Photo 1',
18 | 'TakingPhoto-2': 'Photo',
19 | 'Waiting-2': 'Waiting',
20 | 'Walking-2': 'Walking',
21 | 'WalkingDog-1': 'WalkDog 1',
22 | 'WalkingDog-2': 'WalkDog',
23 | 'WalkingTogether-1': 'WalkTogether 1',
24 | 'WalkingTogether-2': 'WalkTogether'
25 | },
26 | 'S9': {
27 | 'Directions-2': 'Directions',
28 | 'Discussion-2': 'Discussion 2',
29 | 'Eating-2': 'Eating',
30 | 'Greeting-2': 'Greeting',
31 | 'Phoning-1': 'Phoning 1',
32 | 'Phoning-2': 'Phoning',
33 | 'Posing-2': 'Posing',
34 | 'Purchases-2': 'Purchases',
35 | 'Sitting-2': 'Sitting',
36 | 'SittingDown-1': 'SittingDown',
37 | 'SittingDown-2': 'SittingDown 1',
38 | 'Smoking-2': 'Smoking',
39 | 'TakingPhoto-1': 'Photo 1',
40 | 'TakingPhoto-2': 'Photo',
41 | 'Waiting-2': 'Waiting',
42 | 'Walking-2': 'Walking',
43 | 'WalkingDog-1': 'WalkDog 1',
44 | 'WalkingDog-2': 'WalkDog',
45 | 'WalkingTogether-1': 'WalkTogether 1',
46 | 'WalkingTogether-2': 'WalkTogether'
47 | },
48 | 'S8': {
49 | 'Directions-2': 'Directions',
50 | 'Discussion-2': 'Discussion',
51 | 'Eating-2': 'Eating',
52 | 'Greeting-2': 'Greeting',
53 | 'Phoning-1': 'Phoning 1',
54 | 'Phoning-2': 'Phoning',
55 | 'Posing-2': 'Posing',
56 | 'Purchases-2': 'Purchases',
57 | 'Sitting-2': 'Sitting',
58 | 'SittingDown-1': 'SittingDown',
59 | 'SittingDown-2': 'SittingDown 1',
60 | 'Smoking-2': 'Smoking',
61 | 'TakingPhoto-1': 'Photo 1',
62 | 'TakingPhoto-2': 'Photo',
63 | 'Waiting-2': 'Waiting',
64 | 'Walking-2': 'Walking',
65 | 'WalkingDog-1': 'WalkDog 1',
66 | 'WalkingDog-2': 'WalkDog',
67 | 'WalkingTogether-1': 'WalkTogether 1',
68 | 'WalkingTogether-2': 'WalkTogether 2'
69 | },
70 | 'S7': {
71 | 'Directions-2': 'Directions',
72 | 'Discussion-2': 'Discussion',
73 | 'Eating-2': 'Eating',
74 | 'Greeting-2': 'Greeting',
75 | 'Phoning-1': 'Phoning 2',
76 | 'Phoning-2': 'Phoning',
77 | 'Posing-2': 'Posing',
78 | 'Purchases-2': 'Purchases',
79 | 'Sitting-2': 'Sitting',
80 | 'SittingDown-1': 'SittingDown',
81 | 'SittingDown-2': 'SittingDown 1',
82 | 'Smoking-2': 'Smoking',
83 | 'TakingPhoto-1': 'Photo',
84 | 'TakingPhoto-2': 'Photo 1',
85 | 'WalkingDog-1': 'WalkDog 1',
86 | 'WalkingDog-2': 'WalkDog',
87 | 'WalkingTogether-1': 'WalkTogether 1',
88 | 'WalkingTogether-2': 'WalkTogether'
89 | },
90 | 'S6': {
91 | 'Directions-2': 'Directions',
92 | 'Discussion-1': 'Discussion 1',
93 | 'Discussion-2': 'Discussion',
94 | 'Eating-2': 'Eating 2',
95 | 'Eating-1': 'Eating 1',
96 | 'Greeting-2': 'Greeting',
97 | 'Phoning-2': 'Phoning',
98 | 'Posing-1': 'Posing 2',
99 | 'Posing-2': 'Posing',
100 | 'Purchases-2': 'Purchases',
101 | 'SittingDown-1': 'SittingDown 1',
102 | 'SittingDown-2': 'SittingDown',
103 | 'Smoking-2': 'Smoking',
104 | 'TakingPhoto-1': 'Photo',
105 | 'TakingPhoto-2': 'Photo 1',
106 | 'Waiting-1': 'Waiting 3',
107 | 'Waiting-2': 'Waiting',
108 | 'Walking-2': 'Walking',
109 | 'WalkingDog-1': 'WalkDog 1',
110 | 'WalkingDog-2': 'WalkDog',
111 | 'WalkingTogether-1': 'WalkTogether 1',
112 | 'WalkingTogether-2': 'WalkTogether'
113 | },
114 | 'S1': {
115 | 'Discussion-2': 'Discussion',
116 | 'Directions-2': 'Directions',
117 | 'Eating-2': 'Eating',
118 | 'Eating-1': 'Eating 2',
119 | 'Greeting-2': 'Greeting',
120 | 'Phoning-2': 'Phoning',
121 | 'Posing-2': 'Posing',
122 | 'Purchases-2': 'Purchases',
123 | 'SittingDown-1': 'SittingDown 2',
124 | 'SittingDown-2': 'SittingDown',
125 | 'Smoking-2': 'Smoking',
126 | 'TakingPhoto-2': 'TakingPhoto',
127 | 'Waiting-2': 'Waiting',
128 | 'Walking-2': 'Walking',
129 | 'WalkingDog-2': 'WalkingDog',
130 | 'WalkingTogether-1': 'WalkTogether 1',
131 | 'WalkingTogether-2': 'WalkTogether'
132 | },
133 | 'S5': {
134 | 'Discussion-1': 'Discussion 2',
135 | 'Discussion-2': 'Discussion 3',
136 | 'Eating-2': 'Eating',
137 | 'Eating-1': 'Eating 1',
138 | 'Phoning-2': 'Phoning',
139 | 'Posing-2': 'Posing',
140 | 'Purchases-2': 'Purchases',
141 | 'Sitting-2': 'Sitting',
142 | 'SittingDown-1': 'SittingDown',
143 | 'SittingDown-2': 'SittingDown 1',
144 | 'Smoking-2': 'Smoking',
145 | 'TakingPhoto-1': 'Photo',
146 | 'TakingPhoto-2': 'Photo 2',
147 | 'Waiting-2': 'Waiting 2',
148 | 'Walking-2': 'Walking',
149 | 'WalkingDog-1': 'WalkDog 1',
150 | 'WalkingDog-2': 'WalkDog',
151 | 'WalkingTogether-1': 'WalkTogether 1',
152 | 'WalkingTogether-2': 'WalkTogether'
153 | },
154 | }
155 |
--------------------------------------------------------------------------------
/mvn/datasets/human36m_preprocessing/action_to_una_dinosauria.py:
--------------------------------------------------------------------------------
1 | action_to_una_dinosauria = {
2 | 'S11': {
3 | 'Directions-2': 'Directions',
4 | 'Discussion-2': 'Discussion 2',
5 | 'Eating-2': 'Eating',
6 | 'Greeting-1': 'Greeting 2',
7 | 'Greeting-2': 'Greeting',
8 | 'Phoning-1': 'Phoning 3',
9 | 'Phoning-2': 'Phoning 2',
10 | 'Posing-2': 'Posing',
11 | 'Purchases-2': 'Purchases',
12 | 'Sitting-2': 'Sitting',
13 | 'SittingDown-1': 'SittingDown',
14 | 'SittingDown-2': 'SittingDown 1',
15 | 'Smoking-1': 'Smoking 2',
16 | 'Smoking-2': 'Smoking',
17 | 'TakingPhoto-1': 'Photo 1',
18 | 'TakingPhoto-2': 'Photo',
19 | 'Waiting-2': 'Waiting',
20 | 'Walking-2': 'Walking',
21 | 'WalkingDog-1': 'WalkDog 1',
22 | 'WalkingDog-2': 'WalkDog',
23 | 'WalkingTogether-1': 'WalkTogether 1',
24 | 'WalkingTogether-2': 'WalkTogether'
25 | },
26 | 'S9': {
27 | 'Directions-2': 'Directions',
28 | 'Discussion-2': 'Discussion 2',
29 | 'Eating-2': 'Eating',
30 | 'Greeting-2': 'Greeting',
31 | 'Phoning-1': 'Phoning 1',
32 | 'Phoning-2': 'Phoning',
33 | 'Posing-2': 'Posing',
34 | 'Purchases-2': 'Purchases',
35 | 'Sitting-2': 'Sitting',
36 | 'SittingDown-1': 'SittingDown',
37 | 'SittingDown-2': 'SittingDown 1',
38 | 'Smoking-2': 'Smoking',
39 | 'TakingPhoto-1': 'Photo 1',
40 | 'TakingPhoto-2': 'Photo',
41 | 'Waiting-2': 'Waiting',
42 | 'Walking-2': 'Walking',
43 | 'WalkingDog-1': 'WalkDog 1',
44 | 'WalkingDog-2': 'WalkDog',
45 | 'WalkingTogether-1': 'WalkTogether 1',
46 | 'WalkingTogether-2': 'WalkTogether'
47 | },
48 | 'S8': {
49 | 'Directions-2': 'Directions',
50 | 'Discussion-2': 'Discussion',
51 | 'Eating-2': 'Eating',
52 | 'Greeting-2': 'Greeting',
53 | 'Phoning-1': 'Phoning 1',
54 | 'Phoning-2': 'Phoning',
55 | 'Posing-2': 'Posing',
56 | 'Purchases-2': 'Purchases',
57 | 'Sitting-2': 'Sitting',
58 | 'SittingDown-1': 'SittingDown',
59 | 'SittingDown-2': 'SittingDown 1',
60 | 'Smoking-2': 'Smoking',
61 | 'TakingPhoto-1': 'Photo 1',
62 | 'TakingPhoto-2': 'Photo',
63 | 'Waiting-2': 'Waiting',
64 | 'Walking-2': 'Walking',
65 | 'WalkingDog-1': 'WalkDog 1',
66 | 'WalkingDog-2': 'WalkDog',
67 | 'WalkingTogether-1': 'WalkTogether 1',
68 | 'WalkingTogether-2': 'WalkTogether 2'
69 | },
70 | 'S7': {
71 | 'Directions-2': 'Directions',
72 | 'Discussion-2': 'Discussion',
73 | 'Eating-2': 'Eating',
74 | 'Greeting-2': 'Greeting',
75 | 'Phoning-1': 'Phoning 2',
76 | 'Phoning-2': 'Phoning',
77 | 'Posing-2': 'Posing',
78 | 'Purchases-2': 'Purchases',
79 | 'Sitting-2': 'Sitting',
80 | 'SittingDown-1': 'SittingDown',
81 | 'SittingDown-2': 'SittingDown 1',
82 | 'Smoking-2': 'Smoking',
83 | 'TakingPhoto-1': 'Photo',
84 | 'TakingPhoto-2': 'Photo 1',
85 | 'WalkingDog-1': 'WalkDog 1',
86 | 'WalkingDog-2': 'WalkDog',
87 | 'WalkingTogether-1': 'WalkTogether 1',
88 | 'WalkingTogether-2': 'WalkTogether'
89 | },
90 | 'S6': {
91 | 'Directions-2': 'Directions',
92 | 'Discussion-1': 'Discussion 1',
93 | 'Discussion-2': 'Discussion',
94 | 'Eating-2': 'Eating 2',
95 | 'Eating-1': 'Eating 1',
96 | 'Greeting-2': 'Greeting',
97 | 'Phoning-2': 'Phoning',
98 | 'Posing-1': 'Posing 2',
99 | 'Posing-2': 'Posing',
100 | 'Purchases-2': 'Purchases',
101 | 'SittingDown-1': 'SittingDown 1',
102 | 'SittingDown-2': 'SittingDown',
103 | 'Smoking-2': 'Smoking',
104 | 'TakingPhoto-1': 'Photo',
105 | 'TakingPhoto-2': 'Photo 1',
106 | 'Waiting-1': 'Waiting 3',
107 | 'Waiting-2': 'Waiting',
108 | 'Walking-2': 'Walking',
109 | 'WalkingDog-1': 'WalkDog 1',
110 | 'WalkingDog-2': 'WalkDog',
111 | 'WalkingTogether-1': 'WalkTogether 1',
112 | 'WalkingTogether-2': 'WalkTogether'
113 | },
114 | 'S1': {
115 | 'Discussion-2': 'Discussion',
116 | 'Directions-2': 'Directions',
117 | 'Eating-2': 'Eating',
118 | 'Eating-1': 'Eating 2',
119 | 'Greeting-2': 'Greeting',
120 | 'Phoning-2': 'Phoning',
121 | 'Posing-2': 'Posing',
122 | 'Purchases-2': 'Purchases',
123 | 'SittingDown-1': 'SittingDown 2',
124 | 'SittingDown-2': 'SittingDown',
125 | 'Smoking-2': 'Smoking',
126 | 'TakingPhoto-1': 'Photo 1',
127 | 'TakingPhoto-2': 'Photo',
128 | 'Waiting-2': 'Waiting',
129 | 'Walking-2': 'Walking',
130 | 'WalkingDog-1': 'WalkDog 1',
131 | 'WalkingDog-2': 'WalkDog',
132 | 'WalkingTogether-1': 'WalkTogether 1',
133 | 'WalkingTogether-2': 'WalkTogether'
134 | },
135 | 'S5': {
136 | 'Discussion-1': 'Discussion 2',
137 | 'Discussion-2': 'Discussion 3',
138 | 'Eating-2': 'Eating',
139 | 'Eating-1': 'Eating 1',
140 | 'Phoning-2': 'Phoning',
141 | 'Posing-2': 'Posing',
142 | 'Purchases-2': 'Purchases',
143 | 'Sitting-2': 'Sitting',
144 | 'SittingDown-1': 'SittingDown',
145 | 'SittingDown-2': 'SittingDown 1',
146 | 'Smoking-2': 'Smoking',
147 | 'TakingPhoto-1': 'Photo',
148 | 'TakingPhoto-2': 'Photo 2',
149 | 'Waiting-2': 'Waiting 2',
150 | 'Walking-2': 'Walking',
151 | 'WalkingDog-1': 'WalkDog 1',
152 | 'WalkingDog-2': 'WalkDog',
153 | 'WalkingTogether-1': 'WalkTogether 1',
154 | 'WalkingTogether-2': 'WalkTogether'
155 | },
156 | }
157 |
--------------------------------------------------------------------------------
/mvn/datasets/human36m_preprocessing/collect-bboxes.py:
--------------------------------------------------------------------------------
1 | """
2 | Read bbox *.mat files from Human3.6M and convert them to a single *.npy file.
3 | Example of an original bbox file:
4 | /S1/MySegmentsMat/ground_truth_bb/WalkingDog 1.54138969.mat
5 |
6 | Usage:
7 | python3 collect-bboxes.py <dataset-root> <number-of-processes>
8 | """
9 | import os, sys
10 | import numpy as np
11 | import h5py
12 |
13 | dataset_root = sys.argv[1]
14 | data_path = os.path.join(dataset_root, "processed")
15 | subjects = [x for x in os.listdir(data_path) if x.startswith('S')]
16 | assert len(subjects) == 7
17 |
18 | destination_dir = os.path.join(dataset_root, "extra")
19 | os.makedirs(destination_dir, exist_ok=True)
20 | destination_file_path = os.path.join(destination_dir, "bboxes-Human36M-GT.npy")
21 |
22 | # Some bbox files do not exist, can be misaligned, damaged etc.
23 | from action_to_bbox_filename import action_to_bbox_filename
24 |
25 | from collections import defaultdict
26 | nesteddict = lambda: defaultdict(nesteddict)
27 |
28 | bboxes_retval = nesteddict()
29 |
30 | def load_bboxes(data_path, subject, action, camera):
31 | print(subject, action, camera)
32 |
33 | def mask_to_bbox(mask):
34 | h_mask = mask.max(0)
35 | w_mask = mask.max(1)
36 |
37 | top = h_mask.argmax()
38 | bottom = len(h_mask) - h_mask[::-1].argmax()
39 |
40 | left = w_mask.argmax()
41 | right = len(w_mask) - w_mask[::-1].argmax()
42 |
43 | return top, left, bottom, right
44 |
45 | try:
46 | try:
47 | corrected_action = action_to_bbox_filename[subject][action]
48 | except KeyError:
49 | corrected_action = action.replace('-', ' ')
50 |
51 | # TODO use pathlib
52 | bboxes_path = os.path.join(
53 | data_path,
54 | subject,
55 | 'MySegmentsMat',
56 | 'ground_truth_bb',
57 | '%s.%s.mat' % (corrected_action, camera))
58 |
59 | with h5py.File(bboxes_path, 'r') as h5file:
60 | retval = np.empty((len(h5file['Masks']), 4), dtype=np.int32)
61 |
62 | for frame_idx, mask_reference in enumerate(h5file['Masks'][:,0]):
63 | bbox_mask = np.array(h5file[mask_reference])
64 | retval[frame_idx] = mask_to_bbox(bbox_mask)
65 |
66 | top, left, bottom, right = retval[frame_idx]
67 | if right-left < 2 or bottom-top < 2:
68 | raise Exception(str(bboxes_path) + ' $ ' + str(frame_idx))
69 | except Exception as ex:
70 | # reraise with path information
71 | raise Exception(str(ex) + '; %s %s %s' % (subject, action, camera))
72 |
73 | return retval, subject, action, camera
74 |
75 | # retval['S1']['Talking-1']['54534623'].shape = (n_frames, 4) # top, left, bottom, right
76 | def add_result_to_retval(args):
77 | bboxes, subject, action, camera = args
78 | bboxes_retval[subject][action][camera] = bboxes
79 |
80 | import multiprocessing
81 | num_processes = int(sys.argv[2])
82 | pool = multiprocessing.Pool(num_processes)
83 | async_errors = []
84 |
85 | for subject in subjects:
86 | subject_path = os.path.join(data_path, subject)
87 | actions = os.listdir(subject_path)
88 | try:
89 | actions.remove('MySegmentsMat') # folder with bbox *.mat files
90 | except ValueError:
91 | pass
92 |
93 | for action in actions:
94 | cameras = '54138969', '55011271', '58860488', '60457274'
95 |
96 | for camera in cameras:
97 | async_result = pool.apply_async(
98 | load_bboxes,
99 | args=(data_path, subject, action, camera),
100 | callback=add_result_to_retval)
101 | async_errors.append(async_result)
102 |
103 | pool.close()
104 | pool.join()
105 |
106 | # raise any exceptions from pool's processes
107 | for async_result in async_errors:
108 | async_result.get()
109 |
110 | def freeze_defaultdict(x):
111 | x.default_factory = None
112 | for value in x.values():
113 | if type(value) is defaultdict:
114 | freeze_defaultdict(value)
115 |
116 | # convert to normal dict
117 | freeze_defaultdict(bboxes_retval)
118 | np.save(destination_file_path, bboxes_retval)
119 |
--------------------------------------------------------------------------------
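
The resulting bboxes-Human36M-GT.npy stores a plain nested dict (subject -> action -> camera -> array of per-frame TLBR boxes), as the comment above notes. A minimal sketch of loading it back, with the keys chosen purely for illustration:

    import numpy as np

    bboxes = np.load("bboxes-Human36M-GT.npy", allow_pickle=True).item()
    boxes = bboxes['S1']['Directions-1']['54138969']
    print(boxes.shape)   # (n_frames, 4): top, left, bottom, right
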
/mvn/datasets/human36m_preprocessing/generate-labels-npy-multiview.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate 'labels.npy' for multiview 'human36m.py'
3 | from https://github.sec.samsung.net/RRU8-VIOLET/multi-view-net/
4 |
5 | Usage: `python3 generate-labels-npy-multiview.py <h36m-root> <una-dinosauria-root> <path-to-bboxes-npy>`
6 | """
7 | import os, sys
8 | import numpy as np
9 | import h5py
10 |
11 | # Change this line if you want to use Mask-RCNN or SSD bounding boxes instead of H36M's "ground truth".
12 | BBOXES_SOURCE = 'GT' # or 'MRCNN' or 'SSD'
13 |
14 | retval = {
15 | 'subject_names': ['S1', 'S5', 'S6', 'S7', 'S8', 'S9', 'S11'],
16 | 'camera_names': ['54138969', '55011271', '58860488', '60457274'],
17 | 'action_names': [
18 | 'Directions-1', 'Directions-2',
19 | 'Discussion-1', 'Discussion-2',
20 | 'Eating-1', 'Eating-2',
21 | 'Greeting-1', 'Greeting-2',
22 | 'Phoning-1', 'Phoning-2',
23 | 'Posing-1', 'Posing-2',
24 | 'Purchases-1', 'Purchases-2',
25 | 'Sitting-1', 'Sitting-2',
26 | 'SittingDown-1', 'SittingDown-2',
27 | 'Smoking-1', 'Smoking-2',
28 | 'TakingPhoto-1', 'TakingPhoto-2',
29 | 'Waiting-1', 'Waiting-2',
30 | 'Walking-1', 'Walking-2',
31 | 'WalkingDog-1', 'WalkingDog-2',
32 | 'WalkingTogether-1', 'WalkingTogether-2']
33 | }
34 | retval['cameras'] = np.empty(
35 | (len(retval['subject_names']), len(retval['camera_names'])),
36 | dtype=[
37 | ('R', np.float32, (3,3)),
38 | ('t', np.float32, (3,1)),
39 | ('K', np.float32, (3,3)),
40 | ('dist', np.float32, 5)
41 | ]
42 | )
43 |
44 | table_dtype = np.dtype([
45 | ('subject_idx', np.int8),
46 | ('action_idx', np.int8),
47 | ('frame_idx', np.int16),
48 | ('keypoints', np.float32, (17,3)), # roughly MPII format
49 | ('bbox_by_camera_tlbr', np.int16, (len(retval['camera_names']),4))
50 | ])
51 | retval['table'] = []
52 |
53 | h36m_root = sys.argv[1]
54 |
55 | destination_file_path = os.path.join(h36m_root, "extra", f"human36m-multiview-labels-{BBOXES_SOURCE}bboxes.npy")
56 |
57 | una_dinosauria_root = sys.argv[2]
58 | cameras_params = h5py.File(os.path.join(una_dinosauria_root, 'cameras.h5'), 'r')
59 |
60 | # Fill retval['cameras']
61 | for subject_idx, subject in enumerate(retval['subject_names']):
62 | for camera_idx, camera in enumerate(retval['camera_names']):
63 | assert len(cameras_params[subject.replace('S', 'subject')]) == 4
64 | camera_params = cameras_params[subject.replace('S', 'subject')]['camera%d' % (camera_idx+1)]
65 | camera_retval = retval['cameras'][subject_idx][camera_idx]
66 |
67 | def camera_array_to_name(array):
68 | return ''.join(chr(int(x[0])) for x in array)
69 | assert camera_array_to_name(camera_params['Name']) == camera
70 |
71 | camera_retval['R'] = np.array(camera_params['R']).T
72 | camera_retval['t'] = -camera_retval['R'] @ camera_params['T']
73 |
74 | camera_retval['K'] = 0
75 | camera_retval['K'][:2, 2] = camera_params['c'][:, 0]
76 | camera_retval['K'][0, 0] = camera_params['f'][0]
77 | camera_retval['K'][1, 1] = camera_params['f'][1]
78 | camera_retval['K'][2, 2] = 1.0
79 |
80 | camera_retval['dist'][:2] = camera_params['k'][:2, 0]
81 | camera_retval['dist'][2:4] = camera_params['p'][:, 0]
82 | camera_retval['dist'][4] = camera_params['k'][2, 0]
83 |
84 | # Fill bounding boxes
85 | bboxes = np.load(sys.argv[3], allow_pickle=True).item()
86 |
87 | def square_the_bbox(bbox):
88 | top, left, bottom, right = bbox
89 | width = right - left
90 | height = bottom - top
91 |
92 | if height < width:
93 | center = (top + bottom) * 0.5
94 | top = int(round(center - width * 0.5))
95 | bottom = top + width
96 | else:
97 | center = (left + right) * 0.5
98 | left = int(round(center - height * 0.5))
99 | right = left + height
100 |
101 | return top, left, bottom, right
102 |
103 | for subject in bboxes.keys():
104 | for action in bboxes[subject].keys():
105 | for camera, bbox_array in bboxes[subject][action].items():
106 | for frame_idx, bbox in enumerate(bbox_array):
107 | bbox[:] = square_the_bbox(bbox)
108 |
109 | if BBOXES_SOURCE != 'GT':
110 | def replace_gt_bboxes_with_cnn(bboxes_gt, bboxes_detected_path, detections_file_list):
111 | """
112 | Replace ground truth bounding boxes with boxes from a CNN detector.
113 | """
114 | with open(bboxes_detected_path, 'r') as f:
115 | import json
116 | bboxes_detected = json.load(f)
117 |
118 | with open(detections_file_list, 'r') as f:
119 | for bbox, filename in zip(bboxes_detected, f):
120 | # parse filename
121 | filename = filename.strip()
122 | filename, frame_idx = filename[:-15], int(filename[-10:-4])-1
123 | filename, camera_name = filename[:-23], filename[-8:]
124 | slash_idx = filename.rfind('/')
125 | filename, action_name = filename[:slash_idx], filename[slash_idx+1:]
126 | subject_name = filename[filename.rfind('/')+1:]
127 |
128 | bbox, _ = bbox[:4], bbox[4] # throw confidence away
129 | bbox = square_the_bbox([bbox[1], bbox[0], bbox[3]+1, bbox[2]+1]) # LTRB to TLBR
130 | bboxes_gt[subject_name][action_name][camera_name][frame_idx] = bbox
131 |
132 | detections_paths = {
133 | 'MRCNN': {
134 | 'train': "/Vol1/dbstore/datasets/Human3.6M/extra/train_human36m_MRCNN.json",
135 | 'test': "/Vol1/dbstore/datasets/Human3.6M/extra/test_human36m_MRCNN.json"
136 | },
137 | 'SSD': {
138 | 'train': "/Vol1/dbstore/datasets/k.iskakov/share/ssd-detections-train-human36m.json",
139 | 'test': "/Vol1/dbstore/datasets/k.iskakov/share/ssd-detections-human36m.json"
140 | }
141 | }
142 |
143 | replace_gt_bboxes_with_cnn(
144 | bboxes,
145 | detections_paths[BBOXES_SOURCE]['train'],
146 | "/Vol1/dbstore/datasets/Human3.6M/train-images-list.txt")
147 |
148 | replace_gt_bboxes_with_cnn(
149 | bboxes,
150 | detections_paths[BBOXES_SOURCE]['test'],
151 | "/Vol1/dbstore/datasets/Human3.6M/test-images-list.txt")
152 |
153 | # fill retval['table']
154 | from action_to_una_dinosauria import action_to_una_dinosauria
155 |
156 | for subject_idx, subject in enumerate(retval['subject_names']):
157 | subject_path = os.path.join(h36m_root, "processed", subject)
158 | actions = os.listdir(subject_path)
159 | try:
160 | actions.remove('MySegmentsMat') # folder with bbox *.mat files
161 | except ValueError:
162 | pass
163 |
164 | for action_idx, action in enumerate(retval['action_names']):
165 | action_path = os.path.join(subject_path, action, 'imageSequence')
166 | if not os.path.isdir(action_path):
167 | raise FileNotFoundError(action_path)
168 |
169 | for camera_idx, camera in enumerate(retval['camera_names']):
170 | camera_path = os.path.join(action_path, camera)
171 | if os.path.isdir(camera_path):
172 | frame_idxs = sorted([int(name[4:-4])-1 for name in os.listdir(camera_path)])
173 | assert len(frame_idxs) > 15, 'Too few frames in %s' % camera_path # otherwise WTF
174 | break
175 | else:
176 | raise FileNotFoundError(action_path)
177 |
178 | # 16 joints in MPII order + "Neck/Nose"
179 | valid_joints = (3,2,1,6,7,8,0,12,13,15,27,26,25,17,18,19) + (14,)
180 | with h5py.File(os.path.join(una_dinosauria_root, subject, 'MyPoses', '3D_positions',
181 | '%s.h5' % action_to_una_dinosauria[subject].get(action, action.replace('-', ' '))), 'r') as poses_file:
182 | poses_world = np.array(poses_file['3D_positions']).T.reshape(-1, 32, 3)[frame_idxs][:, valid_joints]
183 |
184 | table_segment = np.empty(len(frame_idxs), dtype=table_dtype)
185 | table_segment['subject_idx'] = subject_idx
186 | table_segment['action_idx'] = action_idx
187 | table_segment['frame_idx'] = frame_idxs
188 | table_segment['keypoints'] = poses_world
189 | table_segment['bbox_by_camera_tlbr'] = 0 # let a (0,0,0,0) bbox mean that this view is missing
190 |
191 | for (camera_idx, camera) in enumerate(retval['camera_names']):
192 | camera_path = os.path.join(action_path, camera)
193 | if not os.path.isdir(camera_path):
194 | print('Warning: camera %s isn\'t present in %s/%s' % (camera, subject, action))
195 | continue
196 |
197 | for bbox, frame_idx in zip(table_segment['bbox_by_camera_tlbr'], frame_idxs):
198 | bbox[camera_idx] = bboxes[subject][action][camera][frame_idx]
199 |
200 | retval['table'].append(table_segment)
201 |
202 | retval['table'] = np.concatenate(retval['table'])
203 | assert retval['table'].ndim == 1
204 |
205 | print("Total frames in Human3.6Million:", len(retval['table']))
206 | np.save(destination_file_path, retval)
207 |
--------------------------------------------------------------------------------
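
A minimal sketch of inspecting the generated labels file, following the structure assembled above (the file name assumes BBOXES_SOURCE = 'GT'; it is written under <h36m-root>/extra/):

    import numpy as np

    labels = np.load("human36m-multiview-labels-GTbboxes.npy", allow_pickle=True).item()
    print(labels['subject_names'], labels['camera_names'])
    print(labels['cameras'].shape)              # (n_subjects, n_cameras), fields R, t, K, dist
    print(labels['table'].dtype.names)          # subject_idx, action_idx, frame_idx, keypoints, bbox_by_camera_tlbr
    print(labels['table']['keypoints'].shape)   # (n_frames_total, 17, 3)
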
/mvn/datasets/human36m_preprocessing/undistort-h36m.py:
--------------------------------------------------------------------------------
1 | """
2 | Undistort images in Human3.6M and save them alongside (in ".../imageSequence-undistorted/...").
3 |
4 | Usage: `python3 undistort-h36m.py <h36m-root> <multiview-labels-npy-path> <number-of-processes>`
5 | """
6 | import torch
7 | import numpy as np
8 | import cv2
9 | from tqdm import tqdm
10 |
11 | import os, sys
12 |
13 | sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../.."))
14 | from mvn.datasets.human36m import Human36MMultiViewDataset
15 |
16 | h36m_root = os.path.join(sys.argv[1], "processed")
17 | labels_multiview_npy_path = sys.argv[2]
18 | number_of_processes = int(sys.argv[3])
19 |
20 | dataset = Human36MMultiViewDataset(
21 | h36m_root,
22 | labels_multiview_npy_path,
23 | train=True, # include all possible data
24 | test=True,
25 | image_shape=None, # don't resize
26 | retain_every_n_frames_in_test=1, # yes actually ALL possible data
27 | with_damaged_actions=True, # I said ALL DATA
28 | kind="mpii",
29 | norm_image=False, # don't do unnecessary image processing
30 | crop=False) # don't crop
31 | print("Dataset length:", len(dataset))
32 |
33 | n_subjects = len(dataset.labels['subject_names'])
34 | n_cameras = len(dataset.labels['camera_names'])
35 |
36 | # First, prepare: compute distorted meshgrids
37 | print("Computing distorted meshgrids")
38 | meshgrids = np.empty((n_subjects, n_cameras), dtype=object)
39 |
40 | for sample_idx in tqdm(range(len(dataset))):
41 | subject_idx = dataset.labels['table']['subject_idx'][sample_idx]
42 |
43 | if not meshgrids[subject_idx].any():
44 | bboxes = dataset.labels['table']['bbox_by_camera_tlbr'][sample_idx]
45 |
46 | if (bboxes[:, 2] - bboxes[:, 0]).min() > 0: # if == 0, then some camera is missing
47 | sample = dataset[sample_idx]
48 | assert len(sample['images']) == n_cameras
49 |
50 | for camera_idx, (camera, image) in enumerate(zip(sample['cameras'], sample['images'])):
51 | h, w = image.shape[:2]
52 |
53 | fx, fy = camera.K[0, 0], camera.K[1, 1]
54 | cx, cy = camera.K[0, 2], camera.K[1, 2]
55 |
56 | grid_x = (np.arange(w, dtype=np.float32) - cx) / fx
57 | grid_y = (np.arange(h, dtype=np.float32) - cy) / fy
58 | meshgrid = np.stack(np.meshgrid(grid_x, grid_y), axis=2).reshape(-1, 2)
59 |
60 | # distort meshgrid points
61 | k = camera.dist[:3].copy(); k[2] = camera.dist[-1]
62 | p = camera.dist[2:4].copy()
63 |
64 | r2 = meshgrid[:, 0] ** 2 + meshgrid[:, 1] ** 2
65 | radial = meshgrid * (1 + k[0] * r2 + k[1] * r2**2 + k[2] * r2**3).reshape(-1, 1)
66 | tangential_1 = p.reshape(1, 2) * np.broadcast_to(meshgrid[:, 0:1] * meshgrid[:, 1:2], (len(meshgrid), 2))
67 | tangential_2 = p[::-1].reshape(1, 2) * (meshgrid**2 + np.broadcast_to(r2.reshape(-1, 1), (len(meshgrid), 2)))
68 |
69 | meshgrid = radial + tangential_1 + tangential_2
70 |
71 | # move back to screen coordinates
72 | meshgrid *= np.array([fx, fy]).reshape(1, 2)
73 | meshgrid += np.array([cx, cy]).reshape(1, 2)
74 |
75 | # cache (save) distortion maps
76 | meshgrids[subject_idx, camera_idx] = cv2.convertMaps(meshgrid.reshape((h, w, 2)), None, cv2.CV_16SC2)
77 |
78 | # Now the main part: undistort images
79 | def undistort_and_save(idx):
80 | sample = dataset[idx]
81 |
82 | shot = dataset.labels['table'][idx]
83 | subject_idx = shot['subject_idx']
84 | action_idx = shot['action_idx']
85 | frame_idx = shot['frame_idx']
86 |
87 | subject = dataset.labels['subject_names'][subject_idx]
88 | action = dataset.labels['action_names'][action_idx]
89 |
90 |     available_cameras = list(range(len(dataset.labels['camera_names'])))
91 | for camera_idx, bbox in enumerate(shot['bbox_by_camera_tlbr']):
92 | if bbox[2] == bbox[0]: # bbox is empty, which means that this camera is missing
93 | available_cameras.remove(camera_idx)
94 |
95 | for camera_idx, image in zip(available_cameras, sample['images']):
96 | camera_name = dataset.labels['camera_names'][camera_idx]
97 |
98 | output_image_folder = os.path.join(
99 | h36m_root, subject, action, 'imageSequence-undistorted', camera_name)
100 | output_image_path = os.path.join(output_image_folder, 'img_%06d.jpg' % (frame_idx+1))
101 | os.makedirs(output_image_folder, exist_ok=True)
102 |
103 | meshgrid_int16 = meshgrids[subject_idx, camera_idx]
104 | image_undistorted = cv2.remap(image, *meshgrid_int16, cv2.INTER_CUBIC)
105 |
106 | cv2.imwrite(output_image_path, image_undistorted)
107 |
108 | print(f"Undistorting images using {number_of_processes} parallel processes")
109 | cv2.setNumThreads(1)
110 | import multiprocessing
111 |
112 | pool = multiprocessing.Pool(number_of_processes)
113 | for _ in tqdm(pool.imap_unordered(
114 | undistort_and_save, range(len(dataset)), chunksize=10), total=len(dataset)):
115 | pass
116 |
117 | pool.close()
118 | pool.join()
119 |
--------------------------------------------------------------------------------
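
The per-camera distortion maps computed above are converted with cv2.convertMaps so that cv2.remap can reuse a faster fixed-point representation. A minimal standalone sketch of that two-step pattern, using an identity mapping on a dummy image:

    import numpy as np
    import cv2

    image = np.zeros((480, 640, 3), dtype=np.uint8)

    # float32 map of shape (h, w, 2): each pixel stores the source (x, y) to sample from
    grid_x, grid_y = np.meshgrid(np.arange(640, dtype=np.float32),
                                 np.arange(480, dtype=np.float32))
    float_map = np.stack([grid_x, grid_y], axis=2)

    map1, map2 = cv2.convertMaps(float_map, None, cv2.CV_16SC2)  # cache as fixed-point maps
    remapped = cv2.remap(image, map1, map2, cv2.INTER_CUBIC)     # identical to the input here
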
/mvn/datasets/human36m_preprocessing/view-dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | A GUI script for inspecting Human3.6M and its wrappers, namely:
3 | - `Human36MMultiViewDataset`
4 | - `human36m-multiview-labels-**bboxes.npy`
5 |
6 | Usage: `python3 view-dataset.py <h36m-root> <labels-npy-path> [<start-sample-idx> [<step>]]`
7 | """
8 | import torch
9 | import numpy as np
10 | import cv2
11 |
12 | import os, sys
13 |
14 | h36m_root = sys.argv[1]
15 | labels_path = sys.argv[2]
16 |
17 | try: sample_idx = int(sys.argv[3])
18 | except: sample_idx = 0
19 |
20 | try: step = int(sys.argv[4])
21 | except: step = 10
22 |
23 | sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../.."))
24 | from mvn.datasets.human36m import Human36MMultiViewDataset
25 |
26 | dataset = Human36MMultiViewDataset(
27 | h36m_root,
28 | labels_path,
29 | train=True,
30 | test=True,
31 | image_shape=(512,512),
32 | retain_every_n_frames_in_test=1,
33 | with_damaged_actions=True,
34 | scale_bbox=1.0,
35 | kind='human36m',
36 | norm_image=False,
37 | undistort_images=True,
38 | ignore_cameras=[])
39 | print(len(dataset))
40 |
41 | prev_action = None
42 | patience = 0
43 |
44 | while True:
45 | sample = dataset[sample_idx]
46 |
47 | camera_idx = 0
48 | image = sample['images'][camera_idx]
49 | camera = sample['cameras'][camera_idx]
50 |
51 | display = image.copy()
52 |
53 | from mvn.utils.multiview import project_3d_points_to_image_plane_without_distortion as project
54 | keypoints_2d = project(camera.projection, sample['keypoints_3d'][:, :3])
55 |
56 | for i,(x,y) in enumerate(keypoints_2d):
57 | cv2.circle(display, (int(x), int(y)), 3, (0,0,255), -1)
58 | # cv2.putText(display, str(i), (int(x)+3, int(y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255))
59 |
60 | # Get window name
61 | sample_info = dataset.labels['table'][sample_idx]
62 | subject_name = dataset.labels['subject_names'][sample_info['subject_idx']]
63 | action_name = dataset.labels['action_names'][sample_info['action_idx']]
64 | camera_name = dataset.labels['camera_names'][camera_idx]
65 | frame_idx = sample_info['frame_idx']
66 |
67 | cv2.imshow('w', display)
68 | cv2.setWindowTitle('w', f"{subject_name}/{action_name}/{camera_name}/{frame_idx}")
69 | cv2.waitKey(0)
70 |
71 | action = dataset.labels['table'][sample_idx]['action_idx']
72 | if action != prev_action: # started a new action
73 | prev_action = action
74 | patience = 2000
75 | sample_idx += step
76 |     elif patience == 0: # the action ended, jump to the start of the next action
77 | while True:
78 | sample_idx += step
79 | action = dataset.labels['table'][sample_idx]['action_idx']
80 | if action != prev_action:
81 | break
82 |     else: # in progress, just increment sample_idx
83 | patience -= 1
84 | sample_idx += step
85 |
--------------------------------------------------------------------------------
/mvn/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from mvn.utils.img import image_batch_to_torch
5 |
6 | def make_collate_fn(randomize_n_views=True, min_n_views=10, max_n_views=31):
7 |
8 | def collate_fn(items):
9 | items = list(filter(lambda x: x is not None, items))
10 | if len(items) == 0:
11 | print("All items in batch are None")
12 | return None
13 |
14 | batch = dict()
15 | total_n_views = min(len(item['images']) for item in items)
16 |
17 | indexes = np.arange(total_n_views)
18 | if randomize_n_views:
19 | n_views = np.random.randint(min_n_views, min(total_n_views, max_n_views) + 1)
20 | indexes = np.random.choice(np.arange(total_n_views), size=n_views, replace=False)
21 | else:
22 | indexes = np.arange(total_n_views)
23 |
24 | batch['images'] = np.stack([np.stack([item['images'][i] for item in items], axis=0) for i in indexes], axis=0).swapaxes(0, 1)
25 | batch['detections'] = np.array([[item['detections'][i] for item in items] for i in indexes]).swapaxes(0, 1)
26 | batch['cameras'] = [[item['cameras'][i] for item in items] for i in indexes]
27 |
28 | batch['keypoints_3d'] = [item['keypoints_3d'] for item in items]
29 | # batch['cuboids'] = [item['cuboids'] for item in items]
30 | batch['indexes'] = [item['indexes'] for item in items]
31 |
32 | try:
33 | batch['pred_keypoints_3d'] = np.array([item['pred_keypoints_3d'] for item in items])
34 |         except:  # some items may lack 'pred_keypoints_3d'
35 | pass
36 |
37 | return batch
38 |
39 | return collate_fn
40 |
41 |
42 | def worker_init_fn(worker_id):
43 | np.random.seed(np.random.get_state()[1][0] + worker_id)
44 |
45 | def prepare_batch(batch, device, config, is_train=True):
46 | # images
47 | images_batch = []
48 | for image_batch in batch['images']:
49 | image_batch = image_batch_to_torch(image_batch)
50 | image_batch = image_batch.to(device)
51 | images_batch.append(image_batch)
52 |
53 | images_batch = torch.stack(images_batch, dim=0)
54 |
55 | # 3D keypoints
56 | keypoints_3d_batch_gt = torch.from_numpy(np.stack(batch['keypoints_3d'], axis=0)[:, :, :3]).float().to(device)
57 |
58 | # 3D keypoints validity
59 | keypoints_3d_validity_batch_gt = torch.from_numpy(np.stack(batch['keypoints_3d'], axis=0)[:, :, 3:]).float().to(device)
60 |
61 |     # projection matrices
62 | proj_matricies_batch = torch.stack([torch.stack([torch.from_numpy(camera.projection) for camera in camera_batch], dim=0) for camera_batch in batch['cameras']], dim=0).transpose(1, 0) # shape (batch_size, n_views, 3, 4)
63 | proj_matricies_batch = proj_matricies_batch.float().to(device)
64 |
65 | return images_batch, keypoints_3d_batch_gt, keypoints_3d_validity_batch_gt, proj_matricies_batch
66 |
--------------------------------------------------------------------------------
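
A minimal sketch of how these helpers plug into a torch DataLoader; `train_dataset`, `device` and `config` are assumed to be constructed elsewhere (as in train.py), and the batch size and worker count are placeholders:

    from torch.utils.data import DataLoader
    from mvn.datasets import utils as dataset_utils

    train_loader = DataLoader(
        train_dataset,                       # e.g. a Human36MMultiViewDataset
        batch_size=8,
        shuffle=True,
        collate_fn=dataset_utils.make_collate_fn(randomize_n_views=False),
        worker_init_fn=dataset_utils.worker_init_fn,
        num_workers=4)

    for batch in train_loader:
        images, keypoints_3d_gt, keypoints_3d_validity_gt, proj_matricies = \
            dataset_utils.prepare_batch(batch, device, config)
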
/mvn/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
--------------------------------------------------------------------------------
/mvn/models/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class KeypointsMSELoss(nn.Module):
8 | def __init__(self):
9 | super().__init__()
10 |
11 | def forward(self, keypoints_pred, keypoints_gt, keypoints_binary_validity):
12 | dimension = keypoints_pred.shape[-1]
13 | loss = torch.sum((keypoints_gt - keypoints_pred) ** 2 * keypoints_binary_validity)
14 | loss = loss / (dimension * max(1, torch.sum(keypoints_binary_validity).item()))
15 | return loss
16 |
17 | class KeypointsMSESmoothLoss(nn.Module):
18 | def __init__(self, threshold=400):
19 | super().__init__()
20 |
21 | self.threshold = threshold
22 |
23 | def forward(self, keypoints_pred, keypoints_gt, keypoints_binary_validity):
24 | dimension = keypoints_pred.shape[-1]
25 | diff = (keypoints_gt - keypoints_pred) ** 2 * keypoints_binary_validity
26 | diff[diff > self.threshold] = torch.pow(diff[diff > self.threshold], 0.1) * (self.threshold ** 0.9)
27 | loss = torch.sum(diff) / (dimension * max(1, torch.sum(keypoints_binary_validity).item()))
28 | return loss
29 |
30 |
31 | class KeypointsMAELoss(nn.Module):
32 | def __init__(self):
33 | super().__init__()
34 |
35 | def forward(self, keypoints_pred, keypoints_gt, keypoints_binary_validity):
36 | dimension = keypoints_pred.shape[-1]
37 | loss = torch.sum(torch.abs(keypoints_gt - keypoints_pred) * keypoints_binary_validity)
38 | loss = loss / (dimension * max(1, torch.sum(keypoints_binary_validity).item()))
39 | return loss
40 |
41 |
42 | class KeypointsL2Loss(nn.Module):
43 | def __init__(self):
44 | super().__init__()
45 |
46 | def forward(self, keypoints_pred, keypoints_gt, keypoints_binary_validity):
47 | loss = torch.sum(torch.sqrt(torch.sum((keypoints_gt - keypoints_pred) ** 2 * keypoints_binary_validity, dim=2)))
48 | loss = loss / max(1, torch.sum(keypoints_binary_validity).item())
49 | return loss
50 |
51 |
52 | class VolumetricCELoss(nn.Module):
53 | def __init__(self):
54 | super().__init__()
55 |
56 | def forward(self, coord_volumes_batch, volumes_batch_pred, keypoints_gt, keypoints_binary_validity):
57 | loss = 0.0
58 | n_losses = 0
59 |
60 | batch_size = volumes_batch_pred.shape[0]
61 | for batch_i in range(batch_size):
62 | coord_volume = coord_volumes_batch[batch_i]
63 | keypoints_gt_i = keypoints_gt[batch_i]
64 |
65 | coord_volume_unsq = coord_volume.unsqueeze(0)
66 | keypoints_gt_i_unsq = keypoints_gt_i.unsqueeze(1).unsqueeze(1).unsqueeze(1)
67 |
68 | dists = torch.sqrt(((coord_volume_unsq - keypoints_gt_i_unsq) ** 2).sum(-1))
69 | dists = dists.view(dists.shape[0], -1)
70 |
71 | min_indexes = torch.argmin(dists, dim=-1).detach().cpu().numpy()
72 | min_indexes = np.stack(np.unravel_index(min_indexes, volumes_batch_pred.shape[-3:]), axis=1)
73 |
74 | for joint_i, index in enumerate(min_indexes):
75 | validity = keypoints_binary_validity[batch_i, joint_i]
76 | loss += validity[0] * (-torch.log(volumes_batch_pred[batch_i, joint_i, index[0], index[1], index[2]] + 1e-6))
77 | n_losses += 1
78 |
79 |
80 | return loss / n_losses
81 |
--------------------------------------------------------------------------------
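
The keypoint losses above all share one calling convention: predictions and ground truth of shape (batch_size, n_joints, 3) plus a per-joint validity mask of shape (batch_size, n_joints, 1). A minimal sketch with random tensors:

    import torch
    from mvn.models.loss import KeypointsMSESmoothLoss

    criterion = KeypointsMSESmoothLoss(threshold=400)

    keypoints_pred = torch.rand(2, 17, 3) * 1000   # e.g. millimetres in world space
    keypoints_gt = torch.rand(2, 17, 3) * 1000
    validity = torch.ones(2, 17, 1)                # 1 = joint annotated, 0 = ignored

    loss = criterion(keypoints_pred, keypoints_gt, validity)
    print(loss.item())
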
/mvn/models/pose_resnet.py:
--------------------------------------------------------------------------------
1 | # Reference: https://github.com/microsoft/human-pose-estimation.pytorch
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import os
8 | import logging
9 |
10 | import torch
11 | import torch.nn as nn
12 | from collections import OrderedDict
13 |
14 |
15 | BN_MOMENTUM = 0.1
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | def conv3x3(in_planes, out_planes, stride=1):
20 | """3x3 convolution with padding"""
21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22 | padding=1, bias=False)
23 |
24 |
25 | class BasicBlock(nn.Module):
26 | expansion = 1
27 |
28 | def __init__(self, inplanes, planes, stride=1, downsample=None):
29 | super(BasicBlock, self).__init__()
30 | self.conv1 = conv3x3(inplanes, planes, stride)
31 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
32 | self.relu = nn.ReLU(inplace=True)
33 | self.conv2 = conv3x3(planes, planes)
34 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
35 | self.downsample = downsample
36 | self.stride = stride
37 |
38 | def forward(self, x):
39 | residual = x
40 |
41 | out = self.conv1(x)
42 | out = self.bn1(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv2(out)
46 | out = self.bn2(out)
47 |
48 | if self.downsample is not None:
49 | residual = self.downsample(x)
50 |
51 | out += residual
52 | out = self.relu(out)
53 |
54 | return out
55 |
56 |
57 | class Bottleneck(nn.Module):
58 | expansion = 4
59 |
60 | def __init__(self, inplanes, planes, stride=1, downsample=None):
61 | super(Bottleneck, self).__init__()
62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
63 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
65 | padding=1, bias=False)
66 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
67 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
68 | bias=False)
69 | self.bn3 = nn.BatchNorm2d(planes * self.expansion,
70 | momentum=BN_MOMENTUM)
71 | self.relu = nn.ReLU(inplace=True)
72 | self.downsample = downsample
73 | self.stride = stride
74 |
75 | def forward(self, x):
76 | residual = x
77 |
78 | out = self.conv1(x)
79 | out = self.bn1(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv2(out)
83 | out = self.bn2(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv3(out)
87 | out = self.bn3(out)
88 |
89 | if self.downsample is not None:
90 | residual = self.downsample(x)
91 |
92 | out += residual
93 | out = self.relu(out)
94 |
95 | return out
96 |
97 |
98 | class Bottleneck_CAFFE(nn.Module):
99 | expansion = 4
100 |
101 | def __init__(self, inplanes, planes, stride=1, downsample=None):
102 | super(Bottleneck_CAFFE, self).__init__()
103 | # add stride to conv1x1
104 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
105 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
106 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
107 | padding=1, bias=False)
108 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
109 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
110 | bias=False)
111 | self.bn3 = nn.BatchNorm2d(planes * self.expansion,
112 | momentum=BN_MOMENTUM)
113 | self.relu = nn.ReLU(inplace=True)
114 | self.downsample = downsample
115 | self.stride = stride
116 |
117 | def forward(self, x):
118 | residual = x
119 |
120 | out = self.conv1(x)
121 | out = self.bn1(out)
122 | out = self.relu(out)
123 |
124 | out = self.conv2(out)
125 | out = self.bn2(out)
126 | out = self.relu(out)
127 |
128 | out = self.conv3(out)
129 | out = self.bn3(out)
130 |
131 | if self.downsample is not None:
132 | residual = self.downsample(x)
133 |
134 | out += residual
135 | out = self.relu(out)
136 |
137 | return out
138 |
139 |
140 | class GlobalAveragePoolingHead(nn.Module):
141 | def __init__(self, in_channels, n_classes):
142 | super().__init__()
143 |
144 | self.features = nn.Sequential(
145 | nn.Conv2d(in_channels, 512, 3, stride=1, padding=1),
146 | nn.BatchNorm2d(512, momentum=BN_MOMENTUM),
147 | nn.MaxPool2d(2),
148 | nn.ReLU(inplace=True),
149 |
150 | nn.Conv2d(512, 256, 3, stride=1, padding=1),
151 | nn.BatchNorm2d(256, momentum=BN_MOMENTUM),
152 | nn.MaxPool2d(2),
153 | nn.ReLU(inplace=True),
154 | )
155 |
156 | self.head = nn.Sequential(
157 | nn.Linear(256, 512),
158 | nn.ReLU(inplace=True),
159 | nn.Linear(512, 256),
160 | nn.ReLU(inplace=True),
161 | nn.Linear(256, n_classes),
162 | nn.Sigmoid()
163 | )
164 |
165 | def forward(self, x):
166 | x = self.features(x)
167 |
168 | batch_size, n_channels = x.shape[:2]
169 | x = x.view((batch_size, n_channels, -1))
170 | x = x.mean(dim=-1)
171 |
172 | out = self.head(x)
173 |
174 | return out
175 |
176 |
177 | resnet_spec = {18: (BasicBlock, [2, 2, 2, 2]),
178 | 34: (BasicBlock, [3, 4, 6, 3]),
179 | 50: (Bottleneck, [3, 4, 6, 3]),
180 | 101: (Bottleneck, [3, 4, 23, 3]),
181 | 152: (Bottleneck, [3, 8, 36, 3])}
182 |
183 |
184 | class PoseResNet(nn.Module):
185 | def __init__(self, block, layers, num_joints,
186 | num_input_channels=3,
187 | deconv_with_bias=False,
188 | num_deconv_layers=3,
189 | num_deconv_filters=(256, 256, 256),
190 | num_deconv_kernels=(4, 4, 4),
191 | final_conv_kernel=1,
192 | alg_confidences=False,
193 | vol_confidences=False
194 | ):
195 | super().__init__()
196 |
197 | self.num_joints = num_joints
198 | self.num_input_channels = num_input_channels
199 | self.inplanes = 64
200 |
201 | self.deconv_with_bias = deconv_with_bias
202 | self.num_deconv_layers, self.num_deconv_filters, self.num_deconv_kernels = num_deconv_layers, num_deconv_filters, num_deconv_kernels
203 | self.final_conv_kernel = final_conv_kernel
204 |
205 | self.conv1 = nn.Conv2d(num_input_channels, 64, kernel_size=7, stride=2, padding=3,
206 | bias=False)
207 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
208 | self.relu = nn.ReLU(inplace=True)
209 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
210 | self.layer1 = self._make_layer(block, 64, layers[0])
211 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
212 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
213 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
214 |
215 | if alg_confidences:
216 | self.alg_confidences = GlobalAveragePoolingHead(512 * block.expansion, num_joints)
217 |
218 | if vol_confidences:
219 | self.vol_confidences = GlobalAveragePoolingHead(512 * block.expansion, 32)
220 |
221 | # used for deconv layers
222 | self.deconv_layers = self._make_deconv_layer(
223 | self.num_deconv_layers,
224 | self.num_deconv_filters,
225 | self.num_deconv_kernels,
226 | )
227 |
228 | self.final_layer = nn.Conv2d(
229 | in_channels=self.num_deconv_filters[-1],
230 | out_channels=self.num_joints,
231 | kernel_size=self.final_conv_kernel,
232 | stride=1,
233 | padding=1 if self.final_conv_kernel == 3 else 0
234 | )
235 |
236 | def _make_layer(self, block, planes, blocks, stride=1):
237 | downsample = None
238 | if stride != 1 or self.inplanes != planes * block.expansion:
239 | downsample = nn.Sequential(
240 | nn.Conv2d(self.inplanes, planes * block.expansion,
241 | kernel_size=1, stride=stride, bias=False),
242 | nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
243 | )
244 |
245 | layers = []
246 | layers.append(block(self.inplanes, planes, stride, downsample))
247 | self.inplanes = planes * block.expansion
248 | for i in range(1, blocks):
249 | layers.append(block(self.inplanes, planes))
250 |
251 | return nn.Sequential(*layers)
252 |
253 | def _get_deconv_cfg(self, deconv_kernel, index):
254 | if deconv_kernel == 4:
255 | padding = 1
256 | output_padding = 0
257 | elif deconv_kernel == 3:
258 | padding = 1
259 | output_padding = 1
260 | elif deconv_kernel == 2:
261 | padding = 0
262 | output_padding = 0
263 |
264 | return deconv_kernel, padding, output_padding
265 |
266 | def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
267 | assert num_layers == len(num_filters), \
268 |             'ERROR: num_deconv_layers is different from len(num_deconv_filters)'
269 | assert num_layers == len(num_kernels), \
270 |             'ERROR: num_deconv_layers is different from len(num_deconv_kernels)'
271 |
272 | layers = []
273 | for i in range(num_layers):
274 | kernel, padding, output_padding = \
275 | self._get_deconv_cfg(num_kernels[i], i)
276 |
277 | planes = num_filters[i]
278 | layers.append(
279 | nn.ConvTranspose2d(
280 | in_channels=self.inplanes,
281 | out_channels=planes,
282 | kernel_size=kernel,
283 | stride=2,
284 | padding=padding,
285 | output_padding=output_padding,
286 | bias=self.deconv_with_bias))
287 | layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
288 | layers.append(nn.ReLU(inplace=True))
289 | self.inplanes = planes
290 |
291 | return nn.Sequential(*layers)
292 |
293 | def forward(self, x):
294 | x = self.conv1(x)
295 | x = self.bn1(x)
296 | x = self.relu(x)
297 | x = self.maxpool(x)
298 |
299 | x = self.layer1(x)
300 | x = self.layer2(x)
301 | x = self.layer3(x)
302 | x = self.layer4(x)
303 |
304 | alg_confidences = None
305 | if hasattr(self, "alg_confidences"):
306 | alg_confidences = self.alg_confidences(x)
307 |
308 | vol_confidences = None
309 | if hasattr(self, "vol_confidences"):
310 | vol_confidences = self.vol_confidences(x)
311 |
312 | x = self.deconv_layers(x)
313 | features = x
314 |
315 | x = self.final_layer(x)
316 | heatmaps = x
317 |
318 | return heatmaps, features, alg_confidences, vol_confidences
319 |
320 |
321 | def get_pose_net(config, device='cuda:0'):
322 | block_class, layers = resnet_spec[config.num_layers]
323 | if config.style == 'caffe':
324 | block_class = Bottleneck_CAFFE
325 |
326 | model = PoseResNet(
327 | block_class, layers, config.num_joints,
328 | num_input_channels=3,
329 | deconv_with_bias=False,
330 | num_deconv_layers=3,
331 | num_deconv_filters=(256, 256, 256),
332 | num_deconv_kernels=(4, 4, 4),
333 | final_conv_kernel=1,
334 | alg_confidences=config.alg_confidences,
335 | vol_confidences=config.vol_confidences
336 | )
337 |
338 | if config.init_weights:
339 | print("Loading pretrained weights from: {}".format(config.checkpoint))
340 | model_state_dict = model.state_dict()
341 | pretrained_state_dict = torch.load(config.checkpoint, map_location=device)
342 |
343 | if 'state_dict' in pretrained_state_dict:
344 | pretrained_state_dict = pretrained_state_dict['state_dict']
345 |
346 | prefix = "module."
347 |
348 | new_pretrained_state_dict = {}
349 | for k, v in pretrained_state_dict.items():
350 | if k.replace(prefix, "") in model_state_dict and v.shape == model_state_dict[k.replace(prefix, "")].shape:
351 | new_pretrained_state_dict[k.replace(prefix, "")] = v
352 | elif k.replace(prefix, "") == "final_layer.weight": # TODO
353 | print("Reiniting final layer filters:", k)
354 |
355 | o = torch.zeros_like(model_state_dict[k.replace(prefix, "")][:, :, :, :])
356 | nn.init.xavier_uniform_(o)
357 | n_filters = min(o.shape[0], v.shape[0])
358 | o[:n_filters, :, :, :] = v[:n_filters, :, :, :]
359 |
360 | new_pretrained_state_dict[k.replace(prefix, "")] = o
361 | elif k.replace(prefix, "") == "final_layer.bias":
362 | print("Reiniting final layer biases:", k)
363 | o = torch.zeros_like(model_state_dict[k.replace(prefix, "")][:])
364 | nn.init.zeros_(o)
365 | n_filters = min(o.shape[0], v.shape[0])
366 | o[:n_filters] = v[:n_filters]
367 |
368 | new_pretrained_state_dict[k.replace(prefix, "")] = o
369 |
370 | not_inited_params = set(map(lambda x: x.replace(prefix, ""), pretrained_state_dict.keys())) - set(new_pretrained_state_dict.keys())
371 | if len(not_inited_params) > 0:
372 | print("Parameters [{}] were not inited".format(not_inited_params))
373 |
374 | model.load_state_dict(new_pretrained_state_dict, strict=False)
375 | print("Successfully loaded pretrained weights for backbone")
376 |
377 | return model
378 |
--------------------------------------------------------------------------------
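
A minimal sketch of the backbone in isolation, bypassing get_pose_net and its config/checkpoint handling. With a 256x256 input, the ResNet-50 layout ([3, 4, 6, 3] Bottleneck blocks) plus the three stride-2 deconvolutions yields 64x64 heatmaps:

    import torch
    from mvn.models.pose_resnet import PoseResNet, Bottleneck

    model = PoseResNet(Bottleneck, [3, 4, 6, 3], num_joints=17, alg_confidences=True)

    images = torch.rand(2, 3, 256, 256)
    heatmaps, features, alg_confidences, vol_confidences = model(images)
    print(heatmaps.shape)          # torch.Size([2, 17, 64, 64])
    print(features.shape)          # torch.Size([2, 256, 64, 64])
    print(alg_confidences.shape)   # torch.Size([2, 17])
    print(vol_confidences)         # None -- this head was not requested
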
/mvn/models/triangulation.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import numpy as np
3 | import pickle
4 | import random
5 |
6 | from scipy.optimize import least_squares
7 |
8 | import torch
9 | from torch import nn
10 |
11 | from mvn.utils import op, multiview, img, misc, volumetric
12 |
13 | from mvn.models import pose_resnet
14 | from mvn.models.v2v import V2VModel
15 |
16 |
17 | class RANSACTriangulationNet(nn.Module):
18 | def __init__(self, config, device='cuda:0'):
19 | super().__init__()
20 |
21 | config.model.backbone.alg_confidences = False
22 | config.model.backbone.vol_confidences = False
23 | self.backbone = pose_resnet.get_pose_net(config.model.backbone, device=device)
24 |
25 | self.direct_optimization = config.model.direct_optimization
26 |
27 | def forward(self, images, proj_matricies, batch):
28 | batch_size, n_views = images.shape[:2]
29 |
30 | # reshape n_views dimension to batch dimension
31 | images = images.view(-1, *images.shape[2:])
32 |
33 | # forward backbone and integrate
34 | heatmaps, _, _, _ = self.backbone(images)
35 |
36 | # reshape back
37 | images = images.view(batch_size, n_views, *images.shape[1:])
38 | heatmaps = heatmaps.view(batch_size, n_views, *heatmaps.shape[1:])
39 |
40 |         # calculate shapes
41 | image_shape = tuple(images.shape[3:])
42 | batch_size, n_views, n_joints, heatmap_shape = heatmaps.shape[0], heatmaps.shape[1], heatmaps.shape[2], tuple(heatmaps.shape[3:])
43 |
44 | # keypoints 2d
45 | _, max_indicies = torch.max(heatmaps.view(batch_size, n_views, n_joints, -1), dim=-1)
46 | keypoints_2d = torch.stack([max_indicies % heatmap_shape[1], max_indicies // heatmap_shape[1]], dim=-1).to(images.device)
47 |
48 | # upscale keypoints_2d, because image shape != heatmap shape
49 | keypoints_2d_transformed = torch.zeros_like(keypoints_2d)
50 | keypoints_2d_transformed[:, :, :, 0] = keypoints_2d[:, :, :, 0] * (image_shape[1] / heatmap_shape[1])
51 | keypoints_2d_transformed[:, :, :, 1] = keypoints_2d[:, :, :, 1] * (image_shape[0] / heatmap_shape[0])
52 | keypoints_2d = keypoints_2d_transformed
53 |
54 | # triangulate (cpu)
55 | keypoints_2d_np = keypoints_2d.detach().cpu().numpy()
56 | proj_matricies_np = proj_matricies.detach().cpu().numpy()
57 |
58 | keypoints_3d = np.zeros((batch_size, n_joints, 3))
59 | confidences = np.zeros((batch_size, n_views, n_joints)) # plug
60 | for batch_i in range(batch_size):
61 | for joint_i in range(n_joints):
62 | current_proj_matricies = proj_matricies_np[batch_i]
63 | points = keypoints_2d_np[batch_i, :, joint_i]
64 | keypoint_3d, _ = self.triangulate_ransac(current_proj_matricies, points, direct_optimization=self.direct_optimization)
65 | keypoints_3d[batch_i, joint_i] = keypoint_3d
66 |
67 | keypoints_3d = torch.from_numpy(keypoints_3d).type(torch.float).to(images.device)
68 | confidences = torch.from_numpy(confidences).type(torch.float).to(images.device)
69 |
70 | return keypoints_3d, keypoints_2d, heatmaps, confidences
71 |
72 | def triangulate_ransac(self, proj_matricies, points, n_iters=10, reprojection_error_epsilon=15, direct_optimization=True):
73 | assert len(proj_matricies) == len(points)
74 | assert len(points) >= 2
75 |
76 | proj_matricies = np.array(proj_matricies)
77 | points = np.array(points)
78 |
79 | n_views = len(points)
80 |
81 | # determine inliers
82 | view_set = set(range(n_views))
83 | inlier_set = set()
84 | for i in range(n_iters):
85 |             sampled_views = sorted(random.sample(list(view_set), 2))  # random.sample needs a sequence in Python 3.11+
86 |
87 | keypoint_3d_in_base_camera = multiview.triangulate_point_from_multiple_views_linear(proj_matricies[sampled_views], points[sampled_views])
88 | reprojection_error_vector = multiview.calc_reprojection_error_matrix(np.array([keypoint_3d_in_base_camera]), points, proj_matricies)[0]
89 |
90 | new_inlier_set = set(sampled_views)
91 | for view in view_set:
92 | current_reprojection_error = reprojection_error_vector[view]
93 | if current_reprojection_error < reprojection_error_epsilon:
94 | new_inlier_set.add(view)
95 |
96 | if len(new_inlier_set) > len(inlier_set):
97 | inlier_set = new_inlier_set
98 |
99 | # triangulate using inlier_set
100 | if len(inlier_set) == 0:
101 | inlier_set = view_set.copy()
102 |
103 | inlier_list = np.array(sorted(inlier_set))
104 | inlier_proj_matricies = proj_matricies[inlier_list]
105 | inlier_points = points[inlier_list]
106 |
107 | keypoint_3d_in_base_camera = multiview.triangulate_point_from_multiple_views_linear(inlier_proj_matricies, inlier_points)
108 | reprojection_error_vector = multiview.calc_reprojection_error_matrix(np.array([keypoint_3d_in_base_camera]), inlier_points, inlier_proj_matricies)[0]
109 | reprojection_error_mean = np.mean(reprojection_error_vector)
110 |
111 | keypoint_3d_in_base_camera_before_direct_optimization = keypoint_3d_in_base_camera
112 | reprojection_error_before_direct_optimization = reprojection_error_mean
113 |
114 | # direct reprojection error minimization
115 | if direct_optimization:
116 | def residual_function(x):
117 | reprojection_error_vector = multiview.calc_reprojection_error_matrix(np.array([x]), inlier_points, inlier_proj_matricies)[0]
118 | residuals = reprojection_error_vector
119 | return residuals
120 |
121 | x_0 = np.array(keypoint_3d_in_base_camera)
122 | res = least_squares(residual_function, x_0, loss='huber', method='trf')
123 |
124 | keypoint_3d_in_base_camera = res.x
125 | reprojection_error_vector = multiview.calc_reprojection_error_matrix(np.array([keypoint_3d_in_base_camera]), inlier_points, inlier_proj_matricies)[0]
126 | reprojection_error_mean = np.mean(reprojection_error_vector)
127 |
128 | return keypoint_3d_in_base_camera, inlier_list
129 |
130 |
131 | class AlgebraicTriangulationNet(nn.Module):
132 | def __init__(self, config, device='cuda:0'):
133 | super().__init__()
134 |
135 | self.use_confidences = config.model.use_confidences
136 |
137 | config.model.backbone.alg_confidences = False
138 | config.model.backbone.vol_confidences = False
139 |
140 | if self.use_confidences:
141 | config.model.backbone.alg_confidences = True
142 |
143 | self.backbone = pose_resnet.get_pose_net(config.model.backbone, device=device)
144 |
145 | self.heatmap_softmax = config.model.heatmap_softmax
146 | self.heatmap_multiplier = config.model.heatmap_multiplier
147 |
148 |
149 | def forward(self, images, proj_matricies, batch):
150 | device = images.device
151 | batch_size, n_views = images.shape[:2]
152 |
153 | # reshape n_views dimension to batch dimension
154 | images = images.view(-1, *images.shape[2:])
155 |
156 |         # forward backbone and integrate
157 | if self.use_confidences:
158 | heatmaps, _, alg_confidences, _ = self.backbone(images)
159 | else:
160 | heatmaps, _, _, _ = self.backbone(images)
161 | alg_confidences = torch.ones(batch_size * n_views, heatmaps.shape[1]).type(torch.float).to(device)
162 |
163 | heatmaps_before_softmax = heatmaps.view(batch_size, n_views, *heatmaps.shape[1:])
164 | keypoints_2d, heatmaps = op.integrate_tensor_2d(heatmaps * self.heatmap_multiplier, self.heatmap_softmax)
165 |
166 | # reshape back
167 | images = images.view(batch_size, n_views, *images.shape[1:])
168 | heatmaps = heatmaps.view(batch_size, n_views, *heatmaps.shape[1:])
169 | keypoints_2d = keypoints_2d.view(batch_size, n_views, *keypoints_2d.shape[1:])
170 | alg_confidences = alg_confidences.view(batch_size, n_views, *alg_confidences.shape[1:])
171 |
172 | # norm confidences
173 | alg_confidences = alg_confidences / alg_confidences.sum(dim=1, keepdim=True)
174 | alg_confidences = alg_confidences + 1e-5 # for numerical stability
175 |
176 |         # calculate shapes
177 | image_shape = tuple(images.shape[3:])
178 | batch_size, n_views, n_joints, heatmap_shape = heatmaps.shape[0], heatmaps.shape[1], heatmaps.shape[2], tuple(heatmaps.shape[3:])
179 |
180 | # upscale keypoints_2d, because image shape != heatmap shape
181 | keypoints_2d_transformed = torch.zeros_like(keypoints_2d)
182 | keypoints_2d_transformed[:, :, :, 0] = keypoints_2d[:, :, :, 0] * (image_shape[1] / heatmap_shape[1])
183 | keypoints_2d_transformed[:, :, :, 1] = keypoints_2d[:, :, :, 1] * (image_shape[0] / heatmap_shape[0])
184 | keypoints_2d = keypoints_2d_transformed
185 |
186 | # triangulate
187 | try:
188 | keypoints_3d = multiview.triangulate_batch_of_points(
189 | proj_matricies, keypoints_2d,
190 | confidences_batch=alg_confidences
191 | )
192 | except RuntimeError as e:
193 | print("Error: ", e)
194 |
195 | print("confidences =", confidences_batch_pred)
196 | print("proj_matricies = ", proj_matricies)
197 | print("keypoints_2d_batch_pred =", keypoints_2d_batch_pred)
198 | exit()
199 |
200 | return keypoints_3d, keypoints_2d, heatmaps, alg_confidences
201 |
202 |
203 | class VolumetricTriangulationNet(nn.Module):
204 | def __init__(self, config, device='cuda:0'):
205 | super().__init__()
206 |
207 | self.num_joints = config.model.backbone.num_joints
208 | self.volume_aggregation_method = config.model.volume_aggregation_method
209 |
210 | # volume
211 | self.volume_softmax = config.model.volume_softmax
212 | self.volume_multiplier = config.model.volume_multiplier
213 | self.volume_size = config.model.volume_size
214 |
215 | self.cuboid_side = config.model.cuboid_side
216 |
217 | self.kind = config.model.kind
218 | self.use_gt_pelvis = config.model.use_gt_pelvis
219 |
220 | # heatmap
221 | self.heatmap_softmax = config.model.heatmap_softmax
222 | self.heatmap_multiplier = config.model.heatmap_multiplier
223 |
224 | # transfer
225 | self.transfer_cmu_to_human36m = config.model.transfer_cmu_to_human36m if hasattr(config.model, "transfer_cmu_to_human36m") else False
226 |
227 | # modules
228 | config.model.backbone.alg_confidences = False
229 | config.model.backbone.vol_confidences = False
230 | if self.volume_aggregation_method.startswith('conf'):
231 | config.model.backbone.vol_confidences = True
232 |
233 | self.backbone = pose_resnet.get_pose_net(config.model.backbone, device=device)
234 |
235 | for p in self.backbone.final_layer.parameters():
236 | p.requires_grad = False
237 |
238 | self.process_features = nn.Sequential(
239 | nn.Conv2d(256, 32, 1)
240 | )
241 |
242 | self.volume_net = V2VModel(32, self.num_joints)
243 |
244 |
245 | def forward(self, images, proj_matricies, batch):
246 | device = images.device
247 | batch_size, n_views = images.shape[:2]
248 |
249 | # reshape for backbone forward
250 | images = images.view(-1, *images.shape[2:])
251 |
252 | # forward backbone
253 | heatmaps, features, _, vol_confidences = self.backbone(images)
254 |
255 | # reshape back
256 | images = images.view(batch_size, n_views, *images.shape[1:])
257 | heatmaps = heatmaps.view(batch_size, n_views, *heatmaps.shape[1:])
258 | features = features.view(batch_size, n_views, *features.shape[1:])
259 |
260 | if vol_confidences is not None:
261 | vol_confidences = vol_confidences.view(batch_size, n_views, *vol_confidences.shape[1:])
262 |
263 |         # calculate shapes
264 | image_shape, heatmap_shape = tuple(images.shape[3:]), tuple(heatmaps.shape[3:])
265 | n_joints = heatmaps.shape[2]
266 |
267 | # norm vol confidences
268 | if self.volume_aggregation_method == 'conf_norm':
269 | vol_confidences = vol_confidences / vol_confidences.sum(dim=1, keepdim=True)
270 |
271 | # change camera intrinsics
272 | new_cameras = deepcopy(batch['cameras'])
273 | for view_i in range(n_views):
274 | for batch_i in range(batch_size):
275 | new_cameras[view_i][batch_i].update_after_resize(image_shape, heatmap_shape)
276 |
277 | proj_matricies = torch.stack([torch.stack([torch.from_numpy(camera.projection) for camera in camera_batch], dim=0) for camera_batch in new_cameras], dim=0).transpose(1, 0) # shape (batch_size, n_views, 3, 4)
278 | proj_matricies = proj_matricies.float().to(device)
279 |
280 | # build coord volumes
281 | cuboids = []
282 | base_points = torch.zeros(batch_size, 3, device=device)
283 | coord_volumes = torch.zeros(batch_size, self.volume_size, self.volume_size, self.volume_size, 3, device=device)
284 | for batch_i in range(batch_size):
285 | # if self.use_precalculated_pelvis:
286 | if self.use_gt_pelvis:
287 | keypoints_3d = batch['keypoints_3d'][batch_i]
288 | else:
289 | keypoints_3d = batch['pred_keypoints_3d'][batch_i]
290 |
291 | if self.kind == "coco":
292 | base_point = (keypoints_3d[11, :3] + keypoints_3d[12, :3]) / 2
293 | elif self.kind == "mpii":
294 | base_point = keypoints_3d[6, :3]
295 |
296 | base_points[batch_i] = torch.from_numpy(base_point).to(device)
297 |
298 | # build cuboid
299 | sides = np.array([self.cuboid_side, self.cuboid_side, self.cuboid_side])
300 | position = base_point - sides / 2
301 | cuboid = volumetric.Cuboid3D(position, sides)
302 |
303 | cuboids.append(cuboid)
304 |
305 | # build coord volume
306 | xxx, yyy, zzz = torch.meshgrid(torch.arange(self.volume_size, device=device), torch.arange(self.volume_size, device=device), torch.arange(self.volume_size, device=device))
307 | grid = torch.stack([xxx, yyy, zzz], dim=-1).type(torch.float)
308 | grid = grid.reshape((-1, 3))
309 |
310 | grid_coord = torch.zeros_like(grid)
311 | grid_coord[:, 0] = position[0] + (sides[0] / (self.volume_size - 1)) * grid[:, 0]
312 | grid_coord[:, 1] = position[1] + (sides[1] / (self.volume_size - 1)) * grid[:, 1]
313 | grid_coord[:, 2] = position[2] + (sides[2] / (self.volume_size - 1)) * grid[:, 2]
314 |
315 | coord_volume = grid_coord.reshape(self.volume_size, self.volume_size, self.volume_size, 3)
316 |
317 | # random rotation
318 | if self.training:
319 | theta = np.random.uniform(0.0, 2 * np.pi)
320 | else:
321 | theta = 0.0
322 |
323 | if self.kind == "coco":
324 | axis = [0, 1, 0] # y axis
325 | elif self.kind == "mpii":
326 | axis = [0, 0, 1] # z axis
327 |
328 | center = torch.from_numpy(base_point).type(torch.float).to(device)
329 |
330 | # rotate
331 | coord_volume = coord_volume - center
332 | coord_volume = volumetric.rotate_coord_volume(coord_volume, theta, axis)
333 | coord_volume = coord_volume + center
334 |
335 | # transfer
336 | if self.transfer_cmu_to_human36m: # different world coordinates
337 | coord_volume = coord_volume.permute(0, 2, 1, 3)
338 | inv_idx = torch.arange(coord_volume.shape[1] - 1, -1, -1).long().to(device)
339 | coord_volume = coord_volume.index_select(1, inv_idx)
340 |
341 | coord_volumes[batch_i] = coord_volume
342 |
343 | # process features before unprojecting
344 | features = features.view(-1, *features.shape[2:])
345 | features = self.process_features(features)
346 | features = features.view(batch_size, n_views, *features.shape[1:])
347 |
348 | # lift to volume
349 | volumes = op.unproject_heatmaps(features, proj_matricies, coord_volumes, volume_aggregation_method=self.volume_aggregation_method, vol_confidences=vol_confidences)
350 |
351 | # integral 3d
352 | volumes = self.volume_net(volumes)
353 | vol_keypoints_3d, volumes = op.integrate_tensor_3d_with_coordinates(volumes * self.volume_multiplier, coord_volumes, softmax=self.volume_softmax)
354 |
355 | return vol_keypoints_3d, features, volumes, vol_confidences, cuboids, coord_volumes, base_points
356 |
--------------------------------------------------------------------------------
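
The RANSAC loop above delegates the actual triangulation to multiview.triangulate_point_from_multiple_views_linear (in mvn/utils/multiview.py). For orientation, here is a textbook homogeneous-DLT sketch of that operation; it is not necessarily identical to the repository's implementation:

    import numpy as np

    def triangulate_point_dlt(proj_matricies, points_2d):
        """proj_matricies: (n_views, 3, 4), points_2d: (n_views, 2) -> 3D point (3,)."""
        A = []
        for P, (x, y) in zip(proj_matricies, points_2d):
            A.append(x * P[2] - P[0])
            A.append(y * P[2] - P[1])
        A = np.stack(A)

        # the solution is the right singular vector with the smallest singular value
        _, _, vt = np.linalg.svd(A)
        X = vt[-1]
        return X[:3] / X[3]
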
/mvn/models/v2v.py:
--------------------------------------------------------------------------------
1 | # Reference: https://github.com/dragonbook/V2V-PoseNet-pytorch
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class Basic3DBlock(nn.Module):
8 | def __init__(self, in_planes, out_planes, kernel_size):
9 | super(Basic3DBlock, self).__init__()
10 | self.block = nn.Sequential(
11 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1, padding=((kernel_size-1)//2)),
12 | nn.BatchNorm3d(out_planes),
13 | nn.ReLU(True)
14 | )
15 |
16 | def forward(self, x):
17 | return self.block(x)
18 |
19 |
20 | class Res3DBlock(nn.Module):
21 | def __init__(self, in_planes, out_planes):
22 | super(Res3DBlock, self).__init__()
23 | self.res_branch = nn.Sequential(
24 | nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=1, padding=1),
25 | nn.BatchNorm3d(out_planes),
26 | nn.ReLU(True),
27 | nn.Conv3d(out_planes, out_planes, kernel_size=3, stride=1, padding=1),
28 | nn.BatchNorm3d(out_planes)
29 | )
30 |
31 | if in_planes == out_planes:
32 | self.skip_con = nn.Sequential()
33 | else:
34 | self.skip_con = nn.Sequential(
35 | nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=1, padding=0),
36 | nn.BatchNorm3d(out_planes)
37 | )
38 |
39 | def forward(self, x):
40 | res = self.res_branch(x)
41 | skip = self.skip_con(x)
42 | return F.relu(res + skip, True)
43 |
44 |
45 | class Pool3DBlock(nn.Module):
46 | def __init__(self, pool_size):
47 | super(Pool3DBlock, self).__init__()
48 | self.pool_size = pool_size
49 |
50 | def forward(self, x):
51 | return F.max_pool3d(x, kernel_size=self.pool_size, stride=self.pool_size)
52 |
53 |
54 | class Upsample3DBlock(nn.Module):
55 | def __init__(self, in_planes, out_planes, kernel_size, stride):
56 | super(Upsample3DBlock, self).__init__()
57 | assert(kernel_size == 2)
58 | assert(stride == 2)
59 | self.block = nn.Sequential(
60 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0, output_padding=0),
61 | nn.BatchNorm3d(out_planes),
62 | nn.ReLU(True)
63 | )
64 |
65 | def forward(self, x):
66 | return self.block(x)
67 |
68 |
69 | class EncoderDecorder(nn.Module):
70 | def __init__(self):
71 | super().__init__()
72 |
73 | self.encoder_pool1 = Pool3DBlock(2)
74 | self.encoder_res1 = Res3DBlock(32, 64)
75 | self.encoder_pool2 = Pool3DBlock(2)
76 | self.encoder_res2 = Res3DBlock(64, 128)
77 | self.encoder_pool3 = Pool3DBlock(2)
78 | self.encoder_res3 = Res3DBlock(128, 128)
79 | self.encoder_pool4 = Pool3DBlock(2)
80 | self.encoder_res4 = Res3DBlock(128, 128)
81 | self.encoder_pool5 = Pool3DBlock(2)
82 | self.encoder_res5 = Res3DBlock(128, 128)
83 |
84 | self.mid_res = Res3DBlock(128, 128)
85 |
86 | self.decoder_res5 = Res3DBlock(128, 128)
87 | self.decoder_upsample5 = Upsample3DBlock(128, 128, 2, 2)
88 | self.decoder_res4 = Res3DBlock(128, 128)
89 | self.decoder_upsample4 = Upsample3DBlock(128, 128, 2, 2)
90 | self.decoder_res3 = Res3DBlock(128, 128)
91 | self.decoder_upsample3 = Upsample3DBlock(128, 128, 2, 2)
92 | self.decoder_res2 = Res3DBlock(128, 128)
93 | self.decoder_upsample2 = Upsample3DBlock(128, 64, 2, 2)
94 | self.decoder_res1 = Res3DBlock(64, 64)
95 | self.decoder_upsample1 = Upsample3DBlock(64, 32, 2, 2)
96 |
97 | self.skip_res1 = Res3DBlock(32, 32)
98 | self.skip_res2 = Res3DBlock(64, 64)
99 | self.skip_res3 = Res3DBlock(128, 128)
100 | self.skip_res4 = Res3DBlock(128, 128)
101 | self.skip_res5 = Res3DBlock(128, 128)
102 |
103 | def forward(self, x):
104 | skip_x1 = self.skip_res1(x)
105 | x = self.encoder_pool1(x)
106 | x = self.encoder_res1(x)
107 | skip_x2 = self.skip_res2(x)
108 | x = self.encoder_pool2(x)
109 | x = self.encoder_res2(x)
110 | skip_x3 = self.skip_res3(x)
111 | x = self.encoder_pool3(x)
112 | x = self.encoder_res3(x)
113 | skip_x4 = self.skip_res4(x)
114 | x = self.encoder_pool4(x)
115 | x = self.encoder_res4(x)
116 | skip_x5 = self.skip_res5(x)
117 | x = self.encoder_pool5(x)
118 | x = self.encoder_res5(x)
119 |
120 | x = self.mid_res(x)
121 |
122 | x = self.decoder_res5(x)
123 | x = self.decoder_upsample5(x)
124 | x = x + skip_x5
125 | x = self.decoder_res4(x)
126 | x = self.decoder_upsample4(x)
127 | x = x + skip_x4
128 | x = self.decoder_res3(x)
129 | x = self.decoder_upsample3(x)
130 | x = x + skip_x3
131 | x = self.decoder_res2(x)
132 | x = self.decoder_upsample2(x)
133 | x = x + skip_x2
134 | x = self.decoder_res1(x)
135 | x = self.decoder_upsample1(x)
136 | x = x + skip_x1
137 |
138 | return x
139 |
140 |
141 | class V2VModel(nn.Module):
142 | def __init__(self, input_channels, output_channels):
143 | super().__init__()
144 |
145 | self.front_layers = nn.Sequential(
146 | Basic3DBlock(input_channels, 16, 7),
147 | Res3DBlock(16, 32),
148 | Res3DBlock(32, 32),
149 | Res3DBlock(32, 32)
150 | )
151 |
152 |         self.encoder_decoder = EncoderDecoder()
153 |
154 | self.back_layers = nn.Sequential(
155 | Res3DBlock(32, 32),
156 | Basic3DBlock(32, 32, 1),
157 | Basic3DBlock(32, 32, 1),
158 | )
159 |
160 | self.output_layer = nn.Conv3d(32, output_channels, kernel_size=1, stride=1, padding=0)
161 |
162 | self._initialize_weights()
163 |
164 | def forward(self, x):
165 | x = self.front_layers(x)
166 | x = self.encoder_decoder(x)
167 | x = self.back_layers(x)
168 | x = self.output_layer(x)
169 | return x
170 |
171 | def _initialize_weights(self):
172 | for m in self.modules():
173 | if isinstance(m, nn.Conv3d):
174 | nn.init.xavier_normal_(m.weight)
175 | # nn.init.normal_(m.weight, 0, 0.001)
176 | nn.init.constant_(m.bias, 0)
177 | elif isinstance(m, nn.ConvTranspose3d):
178 | nn.init.xavier_normal_(m.weight)
179 | # nn.init.normal_(m.weight, 0, 0.001)
180 | nn.init.constant_(m.bias, 0)
181 |
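182 |
183 | if __name__ == "__main__":
184 |     # Illustrative smoke test (added as an example, not part of the training pipeline).
185 |     # The encoder pools five times by a factor of 2, so the input side length must be
186 |     # divisible by 32; the 64**3 volume and 17 joints below are assumed example values.
187 |     import torch
188 |
189 |     model = V2VModel(input_channels=32, output_channels=17)
190 |     volumes = torch.randn(1, 32, 64, 64, 64)
191 |     heatmaps = model(volumes)
192 |     print(heatmaps.shape)  # expected: torch.Size([1, 17, 64, 64, 64])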
--------------------------------------------------------------------------------
/mvn/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
--------------------------------------------------------------------------------
/mvn/utils/cfg.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from easydict import EasyDict as edict
3 |
4 |
5 | def load_config(path):
6 | with open(path) as fin:
7 | config = edict(yaml.safe_load(fin))
8 |
9 | return config
10 |
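11 |
12 | if __name__ == "__main__":
13 |     # Usage example (illustration only). The path below is a placeholder; point it at one
14 |     # of the experiment YAML files to get an EasyDict with attribute-style access.
15 |     config = load_config("path/to/experiment_config.yaml")
16 |     print(config.kind, config.model.name)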
--------------------------------------------------------------------------------
/mvn/utils/img.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from PIL import Image
4 |
5 | import torch
6 |
7 | IMAGENET_MEAN, IMAGENET_STD = np.array([0.485, 0.456, 0.406]), np.array([0.229, 0.224, 0.225])
8 |
9 |
10 | def crop_image(image, bbox):
11 |     """Crops the area specified by bbox from image. Always returns an area of the same size as bbox, filling missing parts with zeros.
12 | Args:
13 | image numpy array of shape (height, width, 3): input image
14 | bbox tuple of size 4: input bbox (left, upper, right, lower)
15 |
16 | Returns:
17 | cropped_image numpy array of shape (height, width, 3): resulting cropped image
18 |
19 | """
20 |
21 | image_pil = Image.fromarray(image)
22 | image_pil = image_pil.crop(bbox)
23 |
24 | return np.asarray(image_pil)
25 |
26 |
27 | def resize_image(image, shape):
28 | return cv2.resize(image, (shape[1], shape[0]), interpolation=cv2.INTER_AREA)
29 |
30 |
31 | def get_square_bbox(bbox):
32 |     """Makes a square bbox from any bbox by stretching its shorter side.
33 |
34 | Args:
35 | bbox tuple of size 4: input bbox (left, upper, right, lower)
36 |
37 | Returns:
38 | bbox: tuple of size 4: resulting square bbox (left, upper, right, lower)
39 | """
40 |
41 | left, upper, right, lower = bbox
42 | width, height = right - left, lower - upper
43 |
44 | if width > height:
45 | y_center = (upper + lower) // 2
46 | upper = y_center - width // 2
47 | lower = upper + width
48 | else:
49 | x_center = (left + right) // 2
50 | left = x_center - height // 2
51 | right = left + height
52 |
53 | return left, upper, right, lower
54 |
55 |
56 | def scale_bbox(bbox, scale):
57 | left, upper, right, lower = bbox
58 | width, height = right - left, lower - upper
59 |
60 | x_center, y_center = (right + left) // 2, (lower + upper) // 2
61 | new_width, new_height = int(scale * width), int(scale * height)
62 |
63 | new_left = x_center - new_width // 2
64 | new_right = new_left + new_width
65 |
66 | new_upper = y_center - new_height // 2
67 | new_lower = new_upper + new_height
68 |
69 | return new_left, new_upper, new_right, new_lower
70 |
71 |
72 | def to_numpy(tensor):
73 | if torch.is_tensor(tensor):
74 | return tensor.cpu().detach().numpy()
75 | elif type(tensor).__module__ != 'numpy':
76 | raise ValueError("Cannot convert {} to numpy array"
77 | .format(type(tensor)))
78 | return tensor
79 |
80 |
81 | def to_torch(ndarray):
82 | if type(ndarray).__module__ == 'numpy':
83 | return torch.from_numpy(ndarray)
84 | elif not torch.is_tensor(ndarray):
85 | raise ValueError("Cannot convert {} to torch tensor"
86 | .format(type(ndarray)))
87 | return ndarray
88 |
89 |
90 | def image_batch_to_numpy(image_batch):
91 | image_batch = to_numpy(image_batch)
92 | image_batch = np.transpose(image_batch, (0, 2, 3, 1)) # BxCxHxW -> BxHxWxC
93 | return image_batch
94 |
95 |
96 | def image_batch_to_torch(image_batch):
97 | image_batch = np.transpose(image_batch, (0, 3, 1, 2)) # BxHxWxC -> BxCxHxW
98 | image_batch = to_torch(image_batch).float()
99 | return image_batch
100 |
101 |
102 | def normalize_image(image):
103 | """Normalizes image using ImageNet mean and std
104 |
105 | Args:
106 | image numpy array of shape (h, w, 3): image
107 |
108 | Returns normalized_image numpy array of shape (h, w, 3): normalized image
109 | """
110 | return (image / 255.0 - IMAGENET_MEAN) / IMAGENET_STD
111 |
112 |
113 | def denormalize_image(image):
114 |     """Inverse of normalize_image()."""
115 | return np.clip(255.0 * (image * IMAGENET_STD + IMAGENET_MEAN), 0, 255)
116 |
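117 |
118 | if __name__ == "__main__":
119 |     # Usage example (illustration only): square and scale a bbox, then crop, resize and
120 |     # normalize an image. The image, bbox and target size below are synthetic placeholders.
121 |     image = np.random.randint(0, 255, size=(1000, 1000, 3), dtype=np.uint8)
122 |     bbox = (100, 150, 400, 600)  # (left, upper, right, lower)
123 |     bbox = get_square_bbox(scale_bbox(bbox, 1.5))
124 |     crop = resize_image(crop_image(image, bbox), (256, 256))
125 |     crop = normalize_image(crop)
126 |     print(crop.shape, crop.dtype)  # (256, 256, 3) float64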
--------------------------------------------------------------------------------
/mvn/utils/misc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import json
4 | import re
5 |
6 | import torch
7 |
8 |
9 | def config_to_str(config):
10 |     return yaml.dump(yaml.safe_load(json.dumps(config)))  # round-trip through JSON to turn the config into plain containers for YAML
11 |
12 |
13 | class AverageMeter(object):
14 | """Computes and stores the average and current value"""
15 | def __init__(self):
16 | self.reset()
17 |
18 | def reset(self):
19 | self.val = 0
20 | self.avg = 0
21 | self.sum = 0
22 | self.count = 0
23 |
24 | def update(self, val, n=1):
25 | self.val = val
26 | self.sum += val * n
27 | self.count += n
28 | self.avg = self.sum / self.count
29 |
30 |
31 | def calc_gradient_norm(named_parameters):
32 | total_norm = 0.0
33 | for name, p in named_parameters:
34 | # print(name)
35 | param_norm = p.grad.data.norm(2)
36 | total_norm += param_norm.item() ** 2
37 |
38 | total_norm = total_norm ** (1. / 2)
39 |
40 | return total_norm
41 |
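42 |
43 | if __name__ == "__main__":
44 |     # Usage example (illustration only): AverageMeter keeps a running average weighted by n,
45 |     # e.g. a per-batch loss weighted by batch size. The numbers below are made up.
46 |     meter = AverageMeter()
47 |     for loss, batch_size in [(0.5, 8), (0.3, 8), (0.1, 4)]:
48 |         meter.update(loss, n=batch_size)
49 |     print(round(meter.avg, 2))  # (0.5 * 8 + 0.3 * 8 + 0.1 * 4) / 20 = 0.34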
--------------------------------------------------------------------------------
/mvn/utils/multiview.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class Camera:
6 | def __init__(self, R, t, K, dist=None, name=""):
7 | self.R = np.array(R).copy()
8 | assert self.R.shape == (3, 3)
9 |
10 | self.t = np.array(t).copy()
11 | assert self.t.size == 3
12 | self.t = self.t.reshape(3, 1)
13 |
14 | self.K = np.array(K).copy()
15 | assert self.K.shape == (3, 3)
16 |
17 | self.dist = dist
18 | if self.dist is not None:
19 | self.dist = np.array(self.dist).copy().flatten()
20 |
21 | self.name = name
22 |
23 | def update_after_crop(self, bbox):
24 | left, upper, right, lower = bbox
25 |
26 | cx, cy = self.K[0, 2], self.K[1, 2]
27 |
28 | new_cx = cx - left
29 | new_cy = cy - upper
30 |
31 | self.K[0, 2], self.K[1, 2] = new_cx, new_cy
32 |
33 | def update_after_resize(self, image_shape, new_image_shape):
34 | height, width = image_shape
35 | new_height, new_width = new_image_shape
36 |
37 | fx, fy, cx, cy = self.K[0, 0], self.K[1, 1], self.K[0, 2], self.K[1, 2]
38 |
39 | new_fx = fx * (new_width / width)
40 | new_fy = fy * (new_height / height)
41 | new_cx = cx * (new_width / width)
42 | new_cy = cy * (new_height / height)
43 |
44 | self.K[0, 0], self.K[1, 1], self.K[0, 2], self.K[1, 2] = new_fx, new_fy, new_cx, new_cy
45 |
46 | @property
47 | def projection(self):
48 | return self.K.dot(self.extrinsics)
49 |
50 | @property
51 | def extrinsics(self):
52 | return np.hstack([self.R, self.t])
53 |
54 |
55 | def euclidean_to_homogeneous(points):
56 | """Converts euclidean points to homogeneous
57 |
58 | Args:
59 | points numpy array or torch tensor of shape (N, M): N euclidean points of dimension M
60 |
61 | Returns:
62 | numpy array or torch tensor of shape (N, M + 1): homogeneous points
63 | """
64 | if isinstance(points, np.ndarray):
65 | return np.hstack([points, np.ones((len(points), 1))])
66 | elif torch.is_tensor(points):
67 | return torch.cat([points, torch.ones((points.shape[0], 1), dtype=points.dtype, device=points.device)], dim=1)
68 | else:
69 | raise TypeError("Works only with numpy arrays and PyTorch tensors.")
70 |
71 |
72 | def homogeneous_to_euclidean(points):
73 | """Converts homogeneous points to euclidean
74 |
75 | Args:
76 | points numpy array or torch tensor of shape (N, M + 1): N homogeneous points of dimension M
77 |
78 | Returns:
79 | numpy array or torch tensor of shape (N, M): euclidean points
80 | """
81 | if isinstance(points, np.ndarray):
82 | return (points.T[:-1] / points.T[-1]).T
83 | elif torch.is_tensor(points):
84 | return (points.transpose(1, 0)[:-1] / points.transpose(1, 0)[-1]).transpose(1, 0)
85 | else:
86 | raise TypeError("Works only with numpy arrays and PyTorch tensors.")
87 |
88 |
89 | def project_3d_points_to_image_plane_without_distortion(proj_matrix, points_3d, convert_back_to_euclidean=True):
90 |     """Projects 3D points to the image plane without taking distortion into account.
91 | Args:
92 | proj_matrix numpy array or torch tensor of shape (3, 4): projection matrix
93 | points_3d numpy array or torch tensor of shape (N, 3): 3D points
94 | convert_back_to_euclidean bool: if True, then resulting points will be converted to euclidean coordinates
95 |         NOTE: division by zero can occur if z = 0
96 | Returns:
97 | numpy array or torch tensor of shape (N, 2): 3D points projected to image plane
98 | """
99 | if isinstance(proj_matrix, np.ndarray) and isinstance(points_3d, np.ndarray):
100 | result = euclidean_to_homogeneous(points_3d) @ proj_matrix.T
101 | if convert_back_to_euclidean:
102 | result = homogeneous_to_euclidean(result)
103 | return result
104 | elif torch.is_tensor(proj_matrix) and torch.is_tensor(points_3d):
105 | result = euclidean_to_homogeneous(points_3d) @ proj_matrix.t()
106 | if convert_back_to_euclidean:
107 | result = homogeneous_to_euclidean(result)
108 | return result
109 | else:
110 | raise TypeError("Works only with numpy arrays and PyTorch tensors.")
111 |
112 |
113 | def triangulate_point_from_multiple_views_linear(proj_matricies, points):
114 | """Triangulates one point from multiple (N) views using direct linear transformation (DLT).
115 | For more information look at "Multiple view geometry in computer vision",
116 | Richard Hartley and Andrew Zisserman, 12.2 (p. 312).
117 |
118 | Args:
119 |         proj_matricies numpy array of shape (N, 3, 4): sequence of projection matrices (3x4)
120 | points numpy array of shape (N, 2): sequence of points' coordinates
121 |
122 | Returns:
123 | point_3d numpy array of shape (3,): triangulated point
124 | """
125 | assert len(proj_matricies) == len(points)
126 |
127 | n_views = len(proj_matricies)
128 | A = np.zeros((2 * n_views, 4))
129 | for j in range(len(proj_matricies)):
130 | A[j * 2 + 0] = points[j][0] * proj_matricies[j][2, :] - proj_matricies[j][0, :]
131 | A[j * 2 + 1] = points[j][1] * proj_matricies[j][2, :] - proj_matricies[j][1, :]
132 |
133 | u, s, vh = np.linalg.svd(A, full_matrices=False)
134 | point_3d_homo = vh[3, :]
135 |
136 | point_3d = homogeneous_to_euclidean(point_3d_homo)
137 |
138 | return point_3d
139 |
140 |
141 | def triangulate_point_from_multiple_views_linear_torch(proj_matricies, points, confidences=None):
142 |     """Similar to triangulate_point_from_multiple_views_linear(), but for PyTorch.
143 | For more information see its documentation.
144 | Args:
145 |         proj_matricies torch tensor of shape (N, 3, 4): sequence of projection matrices (3x4)
146 |         points torch tensor of shape (N, 2): sequence of points' coordinates
147 | confidences None or torch tensor of shape (N,): confidences of points [0.0, 1.0].
148 | If None, all confidences are supposed to be 1.0
149 | Returns:
150 |         point_3d torch tensor of shape (3,): triangulated point
151 | """
152 | assert len(proj_matricies) == len(points)
153 |
154 | n_views = len(proj_matricies)
155 |
156 | if confidences is None:
157 | confidences = torch.ones(n_views, dtype=torch.float32, device=points.device)
158 |
159 | A = proj_matricies[:, 2:3].expand(n_views, 2, 4) * points.view(n_views, 2, 1)
160 | A -= proj_matricies[:, :2]
161 | A *= confidences.view(-1, 1, 1)
162 |
163 | u, s, vh = torch.svd(A.view(-1, 4))
164 |
165 | point_3d_homo = -vh[:, 3]
166 | point_3d = homogeneous_to_euclidean(point_3d_homo.unsqueeze(0))[0]
167 |
168 | return point_3d
169 |
170 |
171 | def triangulate_batch_of_points(proj_matricies_batch, points_batch, confidences_batch=None):
172 | batch_size, n_views, n_joints = points_batch.shape[:3]
173 | point_3d_batch = torch.zeros(batch_size, n_joints, 3, dtype=torch.float32, device=points_batch.device)
174 |
175 | for batch_i in range(batch_size):
176 | for joint_i in range(n_joints):
177 | points = points_batch[batch_i, :, joint_i, :]
178 |
179 | confidences = confidences_batch[batch_i, :, joint_i] if confidences_batch is not None else None
180 | point_3d = triangulate_point_from_multiple_views_linear_torch(proj_matricies_batch[batch_i], points, confidences=confidences)
181 | point_3d_batch[batch_i, joint_i] = point_3d
182 |
183 | return point_3d_batch
184 |
185 |
186 | def calc_reprojection_error_matrix(keypoints_3d, keypoints_2d_list, proj_matricies):
187 | reprojection_error_matrix = []
188 | for keypoints_2d, proj_matrix in zip(keypoints_2d_list, proj_matricies):
189 | keypoints_2d_projected = project_3d_points_to_image_plane_without_distortion(proj_matrix, keypoints_3d)
190 | reprojection_error = 1 / 2 * np.sqrt(np.sum((keypoints_2d - keypoints_2d_projected) ** 2, axis=1))
191 | reprojection_error_matrix.append(reprojection_error)
192 |
193 | return np.vstack(reprojection_error_matrix).T
194 |
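195 |
196 | if __name__ == "__main__":
197 |     # Usage example (illustration only): project a synthetic 3D point into two synthetic
198 |     # cameras and recover it with DLT triangulation. All camera parameters are placeholders.
199 |     K = np.array([[1000.0, 0.0, 500.0], [0.0, 1000.0, 500.0], [0.0, 0.0, 1.0]])
200 |     cam_1 = Camera(R=np.eye(3), t=np.zeros(3), K=K, name="cam_1")
201 |     R_2 = np.array([[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]])  # 90 deg around y
202 |     cam_2 = Camera(R=R_2, t=-R_2 @ np.array([5000.0, 0.0, 5000.0]), K=K, name="cam_2")
203 |
204 |     point_3d = np.array([[100.0, -200.0, 4000.0]])
205 |     points_2d = [project_3d_points_to_image_plane_without_distortion(cam.projection, point_3d)[0]
206 |                  for cam in (cam_1, cam_2)]
207 |     proj_matricies = np.stack([cam_1.projection, cam_2.projection])
208 |     print(triangulate_point_from_multiple_views_linear(proj_matricies, points_2d))  # ~[100, -200, 4000]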
--------------------------------------------------------------------------------
/mvn/utils/op.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from mvn.utils.img import to_numpy, to_torch
8 | from mvn.utils import multiview
9 |
10 |
11 | def integrate_tensor_2d(heatmaps, softmax=True):
12 |     """Applies softmax to heatmaps and integrates them to get their "centers of mass"
13 |
14 | Args:
15 | heatmaps torch tensor of shape (batch_size, n_heatmaps, h, w): input heatmaps
16 |
17 | Returns:
18 | coordinates torch tensor of shape (batch_size, n_heatmaps, 2): coordinates of center of masses of all heatmaps
19 |
20 | """
21 | batch_size, n_heatmaps, h, w = heatmaps.shape
22 |
23 | heatmaps = heatmaps.reshape((batch_size, n_heatmaps, -1))
24 | if softmax:
25 | heatmaps = nn.functional.softmax(heatmaps, dim=2)
26 | else:
27 | heatmaps = nn.functional.relu(heatmaps)
28 |
29 | heatmaps = heatmaps.reshape((batch_size, n_heatmaps, h, w))
30 |
31 | mass_x = heatmaps.sum(dim=2)
32 | mass_y = heatmaps.sum(dim=3)
33 |
34 | mass_times_coord_x = mass_x * torch.arange(w).type(torch.float).to(mass_x.device)
35 | mass_times_coord_y = mass_y * torch.arange(h).type(torch.float).to(mass_y.device)
36 |
37 | x = mass_times_coord_x.sum(dim=2, keepdim=True)
38 | y = mass_times_coord_y.sum(dim=2, keepdim=True)
39 |
40 | if not softmax:
41 | x = x / mass_x.sum(dim=2, keepdim=True)
42 | y = y / mass_y.sum(dim=2, keepdim=True)
43 |
44 | coordinates = torch.cat((x, y), dim=2)
45 | coordinates = coordinates.reshape((batch_size, n_heatmaps, 2))
46 |
47 | return coordinates, heatmaps
48 |
49 |
50 | def integrate_tensor_3d(volumes, softmax=True):
51 | batch_size, n_volumes, x_size, y_size, z_size = volumes.shape
52 |
53 | volumes = volumes.reshape((batch_size, n_volumes, -1))
54 | if softmax:
55 | volumes = nn.functional.softmax(volumes, dim=2)
56 | else:
57 | volumes = nn.functional.relu(volumes)
58 |
59 | volumes = volumes.reshape((batch_size, n_volumes, x_size, y_size, z_size))
60 |
61 | mass_x = volumes.sum(dim=3).sum(dim=3)
62 | mass_y = volumes.sum(dim=2).sum(dim=3)
63 | mass_z = volumes.sum(dim=2).sum(dim=2)
64 |
65 | mass_times_coord_x = mass_x * torch.arange(x_size).type(torch.float).to(mass_x.device)
66 | mass_times_coord_y = mass_y * torch.arange(y_size).type(torch.float).to(mass_y.device)
67 | mass_times_coord_z = mass_z * torch.arange(z_size).type(torch.float).to(mass_z.device)
68 |
69 | x = mass_times_coord_x.sum(dim=2, keepdim=True)
70 | y = mass_times_coord_y.sum(dim=2, keepdim=True)
71 | z = mass_times_coord_z.sum(dim=2, keepdim=True)
72 |
73 | if not softmax:
74 | x = x / mass_x.sum(dim=2, keepdim=True)
75 | y = y / mass_y.sum(dim=2, keepdim=True)
76 | z = z / mass_z.sum(dim=2, keepdim=True)
77 |
78 | coordinates = torch.cat((x, y, z), dim=2)
79 | coordinates = coordinates.reshape((batch_size, n_volumes, 3))
80 |
81 | return coordinates, volumes
82 |
83 |
84 | def integrate_tensor_3d_with_coordinates(volumes, coord_volumes, softmax=True):
85 | batch_size, n_volumes, x_size, y_size, z_size = volumes.shape
86 |
87 | volumes = volumes.reshape((batch_size, n_volumes, -1))
88 | if softmax:
89 | volumes = nn.functional.softmax(volumes, dim=2)
90 | else:
91 | volumes = nn.functional.relu(volumes)
92 |
93 | volumes = volumes.reshape((batch_size, n_volumes, x_size, y_size, z_size))
94 | coordinates = torch.einsum("bnxyz, bxyzc -> bnc", volumes, coord_volumes)
95 |
96 | return coordinates, volumes
97 |
98 |
99 | def unproject_heatmaps(heatmaps, proj_matricies, coord_volumes, volume_aggregation_method='sum', vol_confidences=None):
100 | device = heatmaps.device
101 | batch_size, n_views, n_joints, heatmap_shape = heatmaps.shape[0], heatmaps.shape[1], heatmaps.shape[2], tuple(heatmaps.shape[3:])
102 | volume_shape = coord_volumes.shape[1:4]
103 |
104 | volume_batch = torch.zeros(batch_size, n_joints, *volume_shape, device=device)
105 |
106 |     # TODO: speed up this loop
107 | for batch_i in range(batch_size):
108 | coord_volume = coord_volumes[batch_i]
109 | grid_coord = coord_volume.reshape((-1, 3))
110 |
111 | volume_batch_to_aggregate = torch.zeros(n_views, n_joints, *volume_shape, device=device)
112 |
113 | for view_i in range(n_views):
114 | heatmap = heatmaps[batch_i, view_i]
115 | heatmap = heatmap.unsqueeze(0)
116 |
117 | grid_coord_proj = multiview.project_3d_points_to_image_plane_without_distortion(
118 | proj_matricies[batch_i, view_i], grid_coord, convert_back_to_euclidean=False
119 | )
120 |
121 | invalid_mask = grid_coord_proj[:, 2] <= 0.0 # depth must be larger than 0.0
122 |
123 | grid_coord_proj[grid_coord_proj[:, 2] == 0.0, 2] = 1.0 # not to divide by zero
124 | grid_coord_proj = multiview.homogeneous_to_euclidean(grid_coord_proj)
125 |
126 | # transform to [-1.0, 1.0] range
127 | grid_coord_proj_transformed = torch.zeros_like(grid_coord_proj)
128 |             grid_coord_proj_transformed[:, 0] = 2 * (grid_coord_proj[:, 0] / heatmap_shape[1] - 0.5)  # x is normalized by the heatmap width
129 |             grid_coord_proj_transformed[:, 1] = 2 * (grid_coord_proj[:, 1] / heatmap_shape[0] - 0.5)  # y is normalized by the heatmap height
130 | grid_coord_proj = grid_coord_proj_transformed
131 |
132 | # prepare to F.grid_sample
133 | grid_coord_proj = grid_coord_proj.unsqueeze(1).unsqueeze(0)
134 | try:
135 | current_volume = F.grid_sample(heatmap, grid_coord_proj, align_corners=True)
136 | except TypeError: # old PyTorch
137 | current_volume = F.grid_sample(heatmap, grid_coord_proj)
138 |
139 | # zero out non-valid points
140 | current_volume = current_volume.view(n_joints, -1)
141 | current_volume[:, invalid_mask] = 0.0
142 |
143 | # reshape back to volume
144 | current_volume = current_volume.view(n_joints, *volume_shape)
145 |
146 | # collect
147 | volume_batch_to_aggregate[view_i] = current_volume
148 |
149 |         # aggregate the per-view volumes
150 | if volume_aggregation_method.startswith('conf'):
151 | volume_batch[batch_i] = (volume_batch_to_aggregate * vol_confidences[batch_i].view(n_views, n_joints, 1, 1, 1)).sum(0)
152 | elif volume_aggregation_method == 'sum':
153 | volume_batch[batch_i] = volume_batch_to_aggregate.sum(0)
154 | elif volume_aggregation_method == 'max':
155 | volume_batch[batch_i] = volume_batch_to_aggregate.max(0)[0]
156 | elif volume_aggregation_method == 'softmax':
157 | volume_batch_to_aggregate_softmin = volume_batch_to_aggregate.clone()
158 | volume_batch_to_aggregate_softmin = volume_batch_to_aggregate_softmin.view(n_views, -1)
159 | volume_batch_to_aggregate_softmin = nn.functional.softmax(volume_batch_to_aggregate_softmin, dim=0)
160 | volume_batch_to_aggregate_softmin = volume_batch_to_aggregate_softmin.view(n_views, n_joints, *volume_shape)
161 |
162 | volume_batch[batch_i] = (volume_batch_to_aggregate * volume_batch_to_aggregate_softmin).sum(0)
163 | else:
164 | raise ValueError("Unknown volume_aggregation_method: {}".format(volume_aggregation_method))
165 |
166 | return volume_batch
167 |
168 |
169 | def gaussian_2d_pdf(coords, means, sigmas, normalize=True):
170 | normalization = 1.0
171 | if normalize:
172 |         normalization = 2 * np.pi * sigmas[:, 0] * sigmas[:, 1]
173 |
174 | exp = torch.exp(-((coords[:, 0] - means[:, 0]) ** 2 / sigmas[:, 0] ** 2 + (coords[:, 1] - means[:, 1]) ** 2 / sigmas[:, 1] ** 2) / 2)
175 | return exp / normalization
176 |
177 |
178 | def render_points_as_2d_gaussians(points, sigmas, image_shape, normalize=True):
179 | device = points.device
180 | n_points = points.shape[0]
181 |
182 | yy, xx = torch.meshgrid(torch.arange(image_shape[0]).to(device), torch.arange(image_shape[1]).to(device))
183 | grid = torch.stack([xx, yy], dim=-1).type(torch.float32)
184 | grid = grid.unsqueeze(0).repeat(n_points, 1, 1, 1) # (n_points, h, w, 2)
185 | grid = grid.reshape((-1, 2))
186 |
187 | points = points.unsqueeze(1).unsqueeze(1).repeat(1, image_shape[0], image_shape[1], 1)
188 | points = points.reshape(-1, 2)
189 |
190 | sigmas = sigmas.unsqueeze(1).unsqueeze(1).repeat(1, image_shape[0], image_shape[1], 1)
191 | sigmas = sigmas.reshape(-1, 2)
192 |
193 | images = gaussian_2d_pdf(grid, points, sigmas, normalize=normalize)
194 | images = images.reshape(n_points, *image_shape)
195 |
196 | return images
197 |
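198 |
199 | if __name__ == "__main__":
200 |     # Usage example (illustration only): render a synthetic 2D Gaussian heatmap and check
201 |     # that the center-of-mass variant of integrate_tensor_2d recovers the point it was
202 |     # centred on. The point, sigma and image size below are arbitrary placeholders.
203 |     points = torch.tensor([[24.0, 40.0]])  # (x, y)
204 |     sigmas = torch.tensor([[3.0, 3.0]])
205 |     heatmap = render_points_as_2d_gaussians(points, sigmas, image_shape=(64, 64))  # (1, 64, 64)
206 |     coords, _ = integrate_tensor_2d(heatmap.unsqueeze(0), softmax=False)
207 |     print(coords)  # approximately [[[24., 40.]]]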
--------------------------------------------------------------------------------
/mvn/utils/vis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.ndimage
3 | import skimage.transform
4 | import cv2
5 |
6 | import torch
7 |
8 | import matplotlib
9 | matplotlib.use('Agg')
10 | from matplotlib import pylab as plt
11 | from mpl_toolkits.mplot3d import axes3d, Axes3D
12 |
13 |
14 | from mvn.utils.img import image_batch_to_numpy, to_numpy, denormalize_image, resize_image
15 | from mvn.utils.multiview import project_3d_points_to_image_plane_without_distortion
16 |
17 | CONNECTIVITY_DICT = {
18 | 'cmu': [(0, 2), (0, 9), (1, 0), (1, 17), (2, 12), (3, 0), (4, 3), (5, 4), (6, 2), (7, 6), (8, 7), (9, 10), (10, 11), (12, 13), (13, 14), (15, 1), (16, 15), (17, 18)],
19 | 'coco': [(0, 1), (0, 2), (1, 3), (2, 4), (5, 7), (7, 9), (6, 8), (8, 10), (11, 13), (13, 15), (12, 14), (14, 16), (5, 6), (5, 11), (6, 12), (11, 12)],
20 | "mpii": [(0, 1), (1, 2), (2, 6), (5, 4), (4, 3), (3, 6), (6, 7), (7, 8), (8, 9), (8, 12), (8, 13), (10, 11), (11, 12), (13, 14), (14, 15)],
21 | "human36m": [(0, 1), (1, 2), (2, 6), (5, 4), (4, 3), (3, 6), (6, 7), (7, 8), (8, 16), (9, 16), (8, 12), (11, 12), (10, 11), (8, 13), (13, 14), (14, 15)],
22 | "kth": [(0, 1), (1, 2), (5, 4), (4, 3), (6, 7), (7, 8), (11, 10), (10, 9), (2, 3), (3, 9), (2, 8), (9, 12), (8, 12), (12, 13)],
23 | }
24 |
25 | COLOR_DICT = {
26 | 'coco': [
27 | (102, 0, 153), (153, 0, 102), (51, 0, 153), (153, 0, 153), # head
28 | (51, 153, 0), (0, 153, 0), # left arm
29 | (153, 102, 0), (153, 153, 0), # right arm
30 | (0, 51, 153), (0, 0, 153), # left leg
31 | (0, 153, 102), (0, 153, 153), # right leg
32 | (153, 0, 0), (153, 0, 0), (153, 0, 0), (153, 0, 0) # body
33 | ],
34 |
35 | 'human36m': [
36 | (0, 153, 102), (0, 153, 153), (0, 153, 153), # right leg
37 | (0, 51, 153), (0, 0, 153), (0, 0, 153), # left leg
38 | (153, 0, 0), (153, 0, 0), # body
39 | (153, 0, 102), (153, 0, 102), # head
40 | (153, 153, 0), (153, 153, 0), (153, 102, 0), # right arm
41 | (0, 153, 0), (0, 153, 0), (51, 153, 0) # left arm
42 | ],
43 |
44 | 'kth': [
45 | (0, 153, 102), (0, 153, 153), # right leg
46 | (0, 51, 153), (0, 0, 153), # left leg
47 | (153, 102, 0), (153, 153, 0), # right arm
48 | (51, 153, 0), (0, 153, 0), # left arm
49 | (153, 0, 0), (153, 0, 0), (153, 0, 0), (153, 0, 0), (153, 0, 0), # body
50 | (102, 0, 153) # head
51 | ]
52 | }
53 |
54 | JOINT_NAMES_DICT = {
55 | 'coco': {
56 | 0: "nose",
57 | 1: "left_eye",
58 | 2: "right_eye",
59 | 3: "left_ear",
60 | 4: "right_ear",
61 | 5: "left_shoulder",
62 | 6: "right_shoulder",
63 | 7: "left_elbow",
64 | 8: "right_elbow",
65 | 9: "left_wrist",
66 | 10: "right_wrist",
67 | 11: "left_hip",
68 | 12: "right_hip",
69 | 13: "left_knee",
70 | 14: "right_knee",
71 | 15: "left_ankle",
72 | 16: "right_ankle"
73 | }
74 | }
75 |
76 |
77 | def fig_to_array(fig):
78 | fig.canvas.draw()
79 | fig_image = np.array(fig.canvas.renderer._renderer)
80 |
81 | return fig_image
82 |
83 |
84 | def visualize_batch(images_batch, heatmaps_batch, keypoints_2d_batch, proj_matricies_batch,
85 | keypoints_3d_batch_gt, keypoints_3d_batch_pred,
86 | kind="cmu",
87 | cuboids_batch=None,
88 | confidences_batch=None,
89 | batch_index=0, size=5,
90 | max_n_cols=10,
91 | pred_kind=None
92 | ):
93 | if pred_kind is None:
94 | pred_kind = kind
95 |
96 | n_views, n_joints = heatmaps_batch.shape[1], heatmaps_batch.shape[2]
97 |
98 | n_rows = 3
99 | n_rows = n_rows + 1 if keypoints_2d_batch is not None else n_rows
100 | n_rows = n_rows + 1 if cuboids_batch is not None else n_rows
101 | n_rows = n_rows + 1 if confidences_batch is not None else n_rows
102 |
103 | n_cols = min(n_views, max_n_cols)
104 | fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(n_cols * size, n_rows * size))
105 | axes = axes.reshape(n_rows, n_cols)
106 |
107 | image_shape = images_batch.shape[3:]
108 | heatmap_shape = heatmaps_batch.shape[3:]
109 |
110 | row_i = 0
111 |
112 | # images
113 | axes[row_i, 0].set_ylabel("image", size='large')
114 |
115 | images = image_batch_to_numpy(images_batch[batch_index])
116 | images = denormalize_image(images).astype(np.uint8)
117 | images = images[..., ::-1] # bgr -> rgb
118 |
119 | for view_i in range(n_cols):
120 | axes[row_i][view_i].imshow(images[view_i])
121 | row_i += 1
122 |
123 | # 2D keypoints (pred)
124 | if keypoints_2d_batch is not None:
125 | axes[row_i, 0].set_ylabel("2d keypoints (pred)", size='large')
126 |
127 | keypoints_2d = to_numpy(keypoints_2d_batch)[batch_index]
128 | for view_i in range(n_cols):
129 | axes[row_i][view_i].imshow(images[view_i])
130 | draw_2d_pose(keypoints_2d[view_i], axes[row_i][view_i], kind=kind)
131 | row_i += 1
132 |
133 | # 2D keypoints (gt projected)
134 | axes[row_i, 0].set_ylabel("2d keypoints (gt projected)", size='large')
135 |
136 | for view_i in range(n_cols):
137 | axes[row_i][view_i].imshow(images[view_i])
138 | keypoints_2d_gt_proj = project_3d_points_to_image_plane_without_distortion(proj_matricies_batch[batch_index, view_i].detach().cpu().numpy(), keypoints_3d_batch_gt[batch_index].detach().cpu().numpy())
139 | draw_2d_pose(keypoints_2d_gt_proj, axes[row_i][view_i], kind=kind)
140 | row_i += 1
141 |
142 | # 2D keypoints (pred projected)
143 | axes[row_i, 0].set_ylabel("2d keypoints (pred projected)", size='large')
144 |
145 | for view_i in range(n_cols):
146 | axes[row_i][view_i].imshow(images[view_i])
147 | keypoints_2d_pred_proj = project_3d_points_to_image_plane_without_distortion(proj_matricies_batch[batch_index, view_i].detach().cpu().numpy(), keypoints_3d_batch_pred[batch_index].detach().cpu().numpy())
148 | draw_2d_pose(keypoints_2d_pred_proj, axes[row_i][view_i], kind=pred_kind)
149 | row_i += 1
150 |
151 | # cuboids
152 | if cuboids_batch is not None:
153 | axes[row_i, 0].set_ylabel("cuboid", size='large')
154 |
155 | for view_i in range(n_cols):
156 | cuboid = cuboids_batch[batch_index]
157 | axes[row_i][view_i].imshow(cuboid.render(proj_matricies_batch[batch_index, view_i].detach().cpu().numpy(), images[view_i].copy()))
158 | row_i += 1
159 |
160 | # confidences
161 | if confidences_batch is not None:
162 | axes[row_i, 0].set_ylabel("confidences", size='large')
163 |
164 | for view_i in range(n_cols):
165 | confidences = to_numpy(confidences_batch[batch_index, view_i])
166 | xs = np.arange(len(confidences))
167 |
168 | axes[row_i, view_i].bar(xs, confidences, color='green')
169 | axes[row_i, view_i].set_xticks(xs)
170 | if torch.max(confidences_batch).item() <= 1.0:
171 | axes[row_i, view_i].set_ylim(0.0, 1.0)
172 |
173 | fig.tight_layout()
174 |
175 | fig_image = fig_to_array(fig)
176 |
177 | plt.close('all')
178 |
179 | return fig_image
180 |
181 |
182 | def visualize_heatmaps(images_batch, heatmaps_batch,
183 | kind="cmu",
184 | batch_index=0, size=5,
185 | max_n_rows=10, max_n_cols=10):
186 | n_views, n_joints = heatmaps_batch.shape[1], heatmaps_batch.shape[2]
187 | heatmap_shape = heatmaps_batch.shape[3:]
188 |
189 | n_cols, n_rows = min(n_joints + 1, max_n_cols), min(n_views, max_n_rows)
190 | fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(n_cols * size, n_rows * size))
191 | axes = axes.reshape(n_rows, n_cols)
192 |
193 | # images
194 | images = image_batch_to_numpy(images_batch[batch_index])
195 | images = denormalize_image(images).astype(np.uint8)
196 |     images = images[..., ::-1]  # bgr -> rgb
197 |
198 | # heatmaps
199 | heatmaps = to_numpy(heatmaps_batch[batch_index])
200 |
201 | for row in range(n_rows):
202 | for col in range(n_cols):
203 | if col == 0:
204 | axes[row, col].set_ylabel(str(row), size='large')
205 | axes[row, col].imshow(images[row])
206 | else:
207 | if row == 0:
208 | joint_name = JOINT_NAMES_DICT[kind][col - 1] if kind in JOINT_NAMES_DICT else str(col - 1)
209 | axes[row, col].set_title(joint_name)
210 |
211 | axes[row, col].imshow(resize_image(images[row], heatmap_shape))
212 | axes[row, col].imshow(heatmaps[row, col - 1], alpha=0.5)
213 |
214 | fig.tight_layout()
215 |
216 | fig_image = fig_to_array(fig)
217 |
218 | plt.close('all')
219 |
220 | return fig_image
221 |
222 |
223 | def visualize_volumes(images_batch, volumes_batch, proj_matricies_batch,
224 | kind="cmu",
225 | cuboids_batch=None,
226 | batch_index=0, size=5,
227 | max_n_rows=10, max_n_cols=10):
228 | n_views, n_joints = volumes_batch.shape[1], volumes_batch.shape[2]
229 |
230 | n_cols, n_rows = min(n_joints + 1, max_n_cols), min(n_views, max_n_rows)
231 | fig = plt.figure(figsize=(n_cols * size, n_rows * size))
232 |
233 | # images
234 | images = image_batch_to_numpy(images_batch[batch_index])
235 | images = denormalize_image(images).astype(np.uint8)
236 |     images = images[..., ::-1]  # bgr -> rgb
237 |
238 | # heatmaps
239 | volumes = to_numpy(volumes_batch[batch_index])
240 |
241 | for row in range(n_rows):
242 | for col in range(n_cols):
243 | if col == 0:
244 | ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1)
245 | ax.set_ylabel(str(row), size='large')
246 |
247 | cuboid = cuboids_batch[batch_index]
248 | ax.imshow(cuboid.render(proj_matricies_batch[batch_index, row].detach().cpu().numpy(), images[row].copy()))
249 | else:
250 | ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1, projection='3d')
251 |
252 | if row == 0:
253 | joint_name = JOINT_NAMES_DICT[kind][col - 1] if kind in JOINT_NAMES_DICT else str(col - 1)
254 | ax.set_title(joint_name)
255 |
256 | draw_voxels(volumes[col - 1], ax, norm=True)
257 |
258 | fig.tight_layout()
259 |
260 | fig_image = fig_to_array(fig)
261 |
262 | plt.close('all')
263 |
264 | return fig_image
265 |
266 |
267 | def draw_2d_pose(keypoints, ax, kind='cmu', keypoints_mask=None, point_size=2, line_width=1, radius=None, color=None):
268 | """
269 | Visualizes a 2d skeleton
270 |
271 |     Args:
272 |         keypoints numpy array of shape (n_joints, 2): pose to draw, in the skeleton format given by `kind`.
273 | ax: matplotlib axis to draw on
274 | """
275 | connectivity = CONNECTIVITY_DICT[kind]
276 |
277 | color = 'blue' if color is None else color
278 |
279 | if keypoints_mask is None:
280 | keypoints_mask = [True] * len(keypoints)
281 |
282 | # points
283 | ax.scatter(keypoints[keypoints_mask][:, 0], keypoints[keypoints_mask][:, 1], c='red', s=point_size)
284 |
285 | # connections
286 | for (index_from, index_to) in connectivity:
287 | if keypoints_mask[index_from] and keypoints_mask[index_to]:
288 | xs, ys = [np.array([keypoints[index_from, j], keypoints[index_to, j]]) for j in range(2)]
289 | ax.plot(xs, ys, c=color, lw=line_width)
290 |
291 | if radius is not None:
292 | root_keypoint_index = 0
293 | xroot, yroot = keypoints[root_keypoint_index, 0], keypoints[root_keypoint_index, 1]
294 |
295 | ax.set_xlim([-radius + xroot, radius + xroot])
296 | ax.set_ylim([-radius + yroot, radius + yroot])
297 |
298 | ax.set_aspect('equal')
299 |
300 |
301 | def draw_2d_pose_cv2(keypoints, canvas, kind='cmu', keypoints_mask=None, point_size=2, point_color=(255, 255, 255), line_width=1, radius=None, color=None, anti_aliasing_scale=1):
302 | canvas = canvas.copy()
303 |
304 | shape = np.array(canvas.shape[:2])
305 | new_shape = shape * anti_aliasing_scale
306 | canvas = resize_image(canvas, tuple(new_shape))
307 |
308 | keypoints = keypoints * anti_aliasing_scale
309 | point_size = point_size * anti_aliasing_scale
310 | line_width = line_width * anti_aliasing_scale
311 |
312 | connectivity = CONNECTIVITY_DICT[kind]
313 |
314 | color = 'blue' if color is None else color
315 |
316 | if keypoints_mask is None:
317 | keypoints_mask = [True] * len(keypoints)
318 |
319 | # connections
320 | for i, (index_from, index_to) in enumerate(connectivity):
321 | if keypoints_mask[index_from] and keypoints_mask[index_to]:
322 | pt_from = tuple(np.array(keypoints[index_from, :]).astype(int))
323 | pt_to = tuple(np.array(keypoints[index_to, :]).astype(int))
324 |
325 | if kind in COLOR_DICT:
326 | color = COLOR_DICT[kind][i]
327 | else:
328 | color = (0, 0, 255)
329 |
330 | cv2.line(canvas, pt_from, pt_to, color=color, thickness=line_width)
331 |
332 | if kind == 'coco':
333 | mid_collarbone = (keypoints[5, :] + keypoints[6, :]) / 2
334 | nose = keypoints[0, :]
335 |
336 | pt_from = tuple(np.array(nose).astype(int))
337 | pt_to = tuple(np.array(mid_collarbone).astype(int))
338 |
339 | if kind in COLOR_DICT:
340 | color = (153, 0, 51)
341 | else:
342 | color = (0, 0, 255)
343 |
344 | cv2.line(canvas, pt_from, pt_to, color=color, thickness=line_width)
345 |
346 | # points
347 | for pt in keypoints[keypoints_mask]:
348 | cv2.circle(canvas, tuple(pt.astype(int)), point_size, color=point_color, thickness=-1)
349 |
350 | canvas = resize_image(canvas, tuple(shape))
351 |
352 | return canvas
353 |
354 |
355 | def draw_3d_pose(keypoints, ax, keypoints_mask=None, kind='cmu', radius=None, root=None, point_size=2, line_width=2, draw_connections=True):
356 | connectivity = CONNECTIVITY_DICT[kind]
357 |
358 | if keypoints_mask is None:
359 | keypoints_mask = [True] * len(keypoints)
360 |
361 | if draw_connections:
362 | # Make connection matrix
363 | for i, joint in enumerate(connectivity):
364 | if keypoints_mask[joint[0]] and keypoints_mask[joint[1]]:
365 | xs, ys, zs = [np.array([keypoints[joint[0], j], keypoints[joint[1], j]]) for j in range(3)]
366 |
367 | if kind in COLOR_DICT:
368 | color = COLOR_DICT[kind][i]
369 | else:
370 | color = (0, 0, 255)
371 |
372 | color = np.array(color) / 255
373 |
374 | ax.plot(xs, ys, zs, lw=line_width, c=color)
375 |
376 | if kind == 'coco':
377 | mid_collarbone = (keypoints[5, :] + keypoints[6, :]) / 2
378 | nose = keypoints[0, :]
379 |
380 | xs, ys, zs = [np.array([nose[j], mid_collarbone[j]]) for j in range(3)]
381 |
382 | if kind in COLOR_DICT:
383 | color = (153, 0, 51)
384 | else:
385 | color = (0, 0, 255)
386 |
387 | color = np.array(color) / 255
388 |
389 | ax.plot(xs, ys, zs, lw=line_width, c=color)
390 |
391 |
392 | ax.scatter(keypoints[keypoints_mask][:, 0], keypoints[keypoints_mask][:, 1], keypoints[keypoints_mask][:, 2],
393 | s=point_size, c=np.array([230, 145, 56])/255, edgecolors='black') # np.array([230, 145, 56])/255
394 |
395 | if radius is not None:
396 | if root is None:
397 | root = np.mean(keypoints, axis=0)
398 | xroot, yroot, zroot = root
399 | ax.set_xlim([-radius + xroot, radius + xroot])
400 | ax.set_ylim([-radius + yroot, radius + yroot])
401 | ax.set_zlim([-radius + zroot, radius + zroot])
402 |
403 | ax.set_aspect('equal')
404 |
405 |
406 | # Get rid of the panes
407 | background_color = np.array([252, 252, 252]) / 255
408 |
409 | ax.w_xaxis.set_pane_color(background_color)
410 | ax.w_yaxis.set_pane_color(background_color)
411 | ax.w_zaxis.set_pane_color(background_color)
412 |
413 | # Get rid of the ticks
414 | ax.set_xticklabels([])
415 | ax.set_yticklabels([])
416 | ax.set_zticklabels([])
417 |
418 |
419 | def draw_voxels(voxels, ax, shape=(8, 8, 8), norm=True, alpha=0.1):
420 | # resize for visualization
421 | zoom = np.array(shape) / np.array(voxels.shape)
422 | voxels = skimage.transform.resize(voxels, shape, mode='constant', anti_aliasing=True)
423 | voxels = voxels.transpose(2, 0, 1)
424 |
425 | if norm and voxels.max() - voxels.min() > 0:
426 | voxels = (voxels - voxels.min()) / (voxels.max() - voxels.min())
427 |
428 | filled = np.ones(voxels.shape)
429 |
430 | # facecolors
431 | cmap = plt.get_cmap("Blues")
432 |
433 | facecolors_a = cmap(voxels, alpha=alpha)
434 | facecolors_a = facecolors_a.reshape(-1, 4)
435 |
436 | facecolors_hex = np.array(list(map(lambda x: matplotlib.colors.to_hex(x, keep_alpha=True), facecolors_a)))
437 | facecolors_hex = facecolors_hex.reshape(*voxels.shape)
438 |
439 | # explode voxels to perform 3d alpha rendering (https://matplotlib.org/devdocs/gallery/mplot3d/voxels_numpy_logo.html)
440 | def explode(data):
441 | size = np.array(data.shape) * 2
442 | data_e = np.zeros(size - 1, dtype=data.dtype)
443 | data_e[::2, ::2, ::2] = data
444 | return data_e
445 |
446 | filled_2 = explode(filled)
447 | facecolors_2 = explode(facecolors_hex)
448 |
449 | # shrink the gaps
450 | x, y, z = np.indices(np.array(filled_2.shape) + 1).astype(float) // 2
451 | x[0::2, :, :] += 0.05
452 | y[:, 0::2, :] += 0.05
453 | z[:, :, 0::2] += 0.05
454 | x[1::2, :, :] += 0.95
455 | y[:, 1::2, :] += 0.95
456 | z[:, :, 1::2] += 0.95
457 |
458 | # draw voxels
459 | ax.voxels(x, y, z, filled_2, facecolors=facecolors_2)
460 |
461 | ax.set_xlabel("z"); ax.set_ylabel("x"); ax.set_zlabel("y")
462 | ax.invert_xaxis(); ax.invert_zaxis()
463 |
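464 |
465 | if __name__ == "__main__":
466 |     # Usage example (illustration only): draw a random 17-joint COCO skeleton on a blank
467 |     # canvas with the OpenCV renderer and save it to disk. The keypoints and the output
468 |     # file name are placeholders.
469 |     canvas = np.zeros((256, 256, 3), dtype=np.uint8)
470 |     keypoints = np.random.uniform(32, 224, size=(17, 2))
471 |     canvas = draw_2d_pose_cv2(keypoints, canvas, kind='coco', anti_aliasing_scale=2)
472 |     cv2.imwrite("pose_example.jpg", canvas)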
--------------------------------------------------------------------------------
/mvn/utils/volumetric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import torch
4 |
5 | from mvn.utils import multiview
6 |
7 |
8 | class Point3D:
9 | def __init__(self, point, size=3, color=(0, 0, 255)):
10 | self.point = point
11 | self.size = size
12 | self.color = color
13 |
14 | def render(self, proj_matrix, canvas):
15 | point_2d = multiview.project_3d_points_to_image_plane_without_distortion(
16 | proj_matrix, np.array([self.point])
17 | )[0]
18 |
19 | point_2d = tuple(map(int, point_2d))
20 | cv2.circle(canvas, point_2d, self.size, self.color, self.size)
21 |
22 | return canvas
23 |
24 |
25 | class Line3D:
26 | def __init__(self, start_point, end_point, size=2, color=(0, 0, 255)):
27 | self.start_point, self.end_point = start_point, end_point
28 | self.size = size
29 | self.color = color
30 |
31 | def render(self, proj_matrix, canvas):
32 | start_point_2d, end_point_2d = multiview.project_3d_points_to_image_plane_without_distortion(
33 | proj_matrix, np.array([self.start_point, self.end_point])
34 | )
35 |
36 | start_point_2d = tuple(map(int, start_point_2d))
37 | end_point_2d = tuple(map(int, end_point_2d))
38 |
39 | cv2.line(canvas, start_point_2d, end_point_2d, self.color, self.size)
40 |
41 | return canvas
42 |
43 |
44 | class Cuboid3D:
45 | def __init__(self, position, sides):
46 | self.position = position
47 | self.sides = sides
48 |
49 | def build(self):
50 | primitives = []
51 |
52 | line_color = (255, 255, 0)
53 |
54 | start = self.position + np.array([0, 0, 0])
55 | primitives.append(Line3D(start, start + np.array([self.sides[0], 0, 0]), color=(255, 0, 0)))
56 | primitives.append(Line3D(start, start + np.array([0, self.sides[1], 0]), color=(0, 255, 0)))
57 | primitives.append(Line3D(start, start + np.array([0, 0, self.sides[2]]), color=(0, 0, 255)))
58 |
59 | start = self.position + np.array([self.sides[0], 0, self.sides[2]])
60 | primitives.append(Line3D(start, start + np.array([-self.sides[0], 0, 0]), color=line_color))
61 | primitives.append(Line3D(start, start + np.array([0, self.sides[1], 0]), color=line_color))
62 | primitives.append(Line3D(start, start + np.array([0, 0, -self.sides[2]]), color=line_color))
63 |
64 | start = self.position + np.array([self.sides[0], self.sides[1], 0])
65 | primitives.append(Line3D(start, start + np.array([-self.sides[0], 0, 0]), color=line_color))
66 | primitives.append(Line3D(start, start + np.array([0, -self.sides[1], 0]), color=line_color))
67 | primitives.append(Line3D(start, start + np.array([0, 0, self.sides[2]]), color=line_color))
68 |
69 | start = self.position + np.array([0, self.sides[1], self.sides[2]])
70 | primitives.append(Line3D(start, start + np.array([self.sides[0], 0, 0]), color=line_color))
71 | primitives.append(Line3D(start, start + np.array([0, -self.sides[1], 0]), color=line_color))
72 | primitives.append(Line3D(start, start + np.array([0, 0, -self.sides[2]]), color=line_color))
73 |
74 | return primitives
75 |
76 | def render(self, proj_matrix, canvas):
77 | # TODO: support rotation
78 |
79 | primitives = self.build()
80 |
81 | for primitive in primitives:
82 | canvas = primitive.render(proj_matrix, canvas)
83 |
84 | return canvas
85 |
86 |
87 | def get_rotation_matrix(axis, theta):
88 | """Returns the rotation matrix associated with counterclockwise rotation about
89 | the given axis by theta radians.
90 | """
91 | axis = np.asarray(axis)
92 | axis = axis / np.sqrt(np.dot(axis, axis))
93 | a = np.cos(theta / 2.0)
94 | b, c, d = -axis * np.sin(theta / 2.0)
95 | aa, bb, cc, dd = a * a, b * b, c * c, d * d
96 | bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
97 | return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
98 | [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
99 | [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]])
100 |
101 |
102 | def rotate_coord_volume(coord_volume, theta, axis):
103 | shape = coord_volume.shape
104 | device = coord_volume.device
105 |
106 | rot = get_rotation_matrix(axis, theta)
107 | rot = torch.from_numpy(rot).type(torch.float).to(device)
108 |
109 | coord_volume = coord_volume.view(-1, 3)
110 | coord_volume = rot.mm(coord_volume.t()).t()
111 |
112 | coord_volume = coord_volume.view(*shape)
113 |
114 | return coord_volume
115 |
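116 |
117 | if __name__ == "__main__":
118 |     # Usage example (illustration only): rotate a small coordinate volume by 90 degrees
119 |     # around the z axis. The grid size and coordinate range below are placeholders.
120 |     xs = torch.linspace(0.0, 1.0, 4)
121 |     coord_volume = torch.stack(torch.meshgrid(xs, xs, xs), dim=-1)  # (4, 4, 4, 3)
122 |     rotated = rotate_coord_volume(coord_volume, theta=np.pi / 2, axis=[0.0, 0.0, 1.0])
123 |     print(coord_volume[3, 0, 0], rotated[3, 0, 0])  # (1, 0, 0) -> approximately (0, 1, 0)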
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cycler==0.10.0
2 | decorator==4.4.0
3 | easydict==1.9
4 | imageio==2.5.0
5 | kiwisolver==1.1.0
6 | matplotlib==3.1.1
7 | networkx==2.3
8 | numpy==1.17.2
9 | opencv-python==4.1.1.26
10 | Pillow==6.2.0
11 | protobuf==3.10.0
12 | pyparsing==2.4.2
13 | python-dateutil==2.8.0
14 | PyWavelets==1.0.3
15 | PyYAML==5.1.2
16 | scikit-image==0.15.0
17 | scipy==1.3.1
18 | six==1.12.0
19 | tensorboardX==1.8
20 | torch==1.0.1
21 | torchvision==0.2.2
22 | h5py==2.10.0
23 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import argparse
4 | import time
5 | import json
6 | from datetime import datetime
7 | from collections import defaultdict
8 | from itertools import islice
9 | import pickle
10 | import copy
11 |
12 | import numpy as np
13 | import cv2
14 |
15 | import torch
16 | from torch import nn
17 | from torch import autograd
18 | import torch.nn.functional as F
19 | import torch.optim as optim
20 | from torch.utils.data import DataLoader
21 | from torch.nn.parallel import DistributedDataParallel
22 |
23 | from tensorboardX import SummaryWriter
24 |
25 | from mvn.models.triangulation import RANSACTriangulationNet, AlgebraicTriangulationNet, VolumetricTriangulationNet
26 | from mvn.models.loss import KeypointsMSELoss, KeypointsMSESmoothLoss, KeypointsMAELoss, KeypointsL2Loss, VolumetricCELoss
27 |
28 | from mvn.utils import img, multiview, op, vis, misc, cfg
29 | from mvn.datasets import human36m
30 | from mvn.datasets import utils as dataset_utils
31 |
32 |
33 | def parse_args():
34 | parser = argparse.ArgumentParser()
35 |
36 |     parser.add_argument("--config", type=str, required=True, help="Path to the experiment config file")
37 | parser.add_argument('--eval', action='store_true', help="If set, then only evaluation will be done")
38 |     parser.add_argument('--eval_dataset', type=str, default='val', help="Dataset split to evaluate on. Can be 'train' or 'val'")
39 |
40 | parser.add_argument("--local_rank", type=int, help="Local rank of the process on the node")
41 | parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
42 |
43 |     parser.add_argument("--logdir", type=str, default="/Vol1/dbstore/datasets/k.iskakov/logs/multi-view-net-repr", help="Path where logs will be stored")
44 |
45 | args = parser.parse_args()
46 | return args
47 |
48 |
49 | def setup_human36m_dataloaders(config, is_train, distributed_train):
50 | train_dataloader = None
51 | if is_train:
52 | # train
53 | train_dataset = human36m.Human36MMultiViewDataset(
54 | h36m_root=config.dataset.train.h36m_root,
55 | pred_results_path=config.dataset.train.pred_results_path if hasattr(config.dataset.train, "pred_results_path") else None,
56 | train=True,
57 | test=False,
58 | image_shape=config.image_shape if hasattr(config, "image_shape") else (256, 256),
59 | labels_path=config.dataset.train.labels_path,
60 | with_damaged_actions=config.dataset.train.with_damaged_actions,
61 | scale_bbox=config.dataset.train.scale_bbox,
62 | kind=config.kind,
63 | undistort_images=config.dataset.train.undistort_images,
64 | ignore_cameras=config.dataset.train.ignore_cameras if hasattr(config.dataset.train, "ignore_cameras") else [],
65 | crop=config.dataset.train.crop if hasattr(config.dataset.train, "crop") else True,
66 | )
67 |
68 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed_train else None
69 |
70 | train_dataloader = DataLoader(
71 | train_dataset,
72 | batch_size=config.opt.batch_size,
73 | shuffle=config.dataset.train.shuffle and (train_sampler is None), # debatable
74 | sampler=train_sampler,
75 | collate_fn=dataset_utils.make_collate_fn(randomize_n_views=config.dataset.train.randomize_n_views,
76 | min_n_views=config.dataset.train.min_n_views,
77 | max_n_views=config.dataset.train.max_n_views),
78 | num_workers=config.dataset.train.num_workers,
79 | worker_init_fn=dataset_utils.worker_init_fn,
80 | pin_memory=True
81 | )
82 |
83 | # val
84 | val_dataset = human36m.Human36MMultiViewDataset(
85 | h36m_root=config.dataset.val.h36m_root,
86 | pred_results_path=config.dataset.val.pred_results_path if hasattr(config.dataset.val, "pred_results_path") else None,
87 | train=False,
88 | test=True,
89 | image_shape=config.image_shape if hasattr(config, "image_shape") else (256, 256),
90 | labels_path=config.dataset.val.labels_path,
91 | with_damaged_actions=config.dataset.val.with_damaged_actions,
92 | retain_every_n_frames_in_test=config.dataset.val.retain_every_n_frames_in_test,
93 | scale_bbox=config.dataset.val.scale_bbox,
94 | kind=config.kind,
95 | undistort_images=config.dataset.val.undistort_images,
96 | ignore_cameras=config.dataset.val.ignore_cameras if hasattr(config.dataset.val, "ignore_cameras") else [],
97 | crop=config.dataset.val.crop if hasattr(config.dataset.val, "crop") else True,
98 | )
99 |
100 | val_dataloader = DataLoader(
101 | val_dataset,
102 | batch_size=config.opt.val_batch_size if hasattr(config.opt, "val_batch_size") else config.opt.batch_size,
103 | shuffle=config.dataset.val.shuffle,
104 | collate_fn=dataset_utils.make_collate_fn(randomize_n_views=config.dataset.val.randomize_n_views,
105 | min_n_views=config.dataset.val.min_n_views,
106 | max_n_views=config.dataset.val.max_n_views),
107 | num_workers=config.dataset.val.num_workers,
108 | worker_init_fn=dataset_utils.worker_init_fn,
109 | pin_memory=True
110 | )
111 |
112 | return train_dataloader, val_dataloader, train_sampler
113 |
114 |
115 | def setup_dataloaders(config, is_train=True, distributed_train=False):
116 | if config.dataset.kind == 'human36m':
117 | train_dataloader, val_dataloader, train_sampler = setup_human36m_dataloaders(config, is_train, distributed_train)
118 | else:
119 | raise NotImplementedError("Unknown dataset: {}".format(config.dataset.kind))
120 |
121 | return train_dataloader, val_dataloader, train_sampler
122 |
123 |
124 | def setup_experiment(config, model_name, is_train=True):
125 | prefix = "" if is_train else "eval_"
126 |
127 | if config.title:
128 | experiment_title = config.title + "_" + model_name
129 | else:
130 | experiment_title = model_name
131 |
132 | experiment_title = prefix + experiment_title
133 |
134 | experiment_name = '{}@{}'.format(experiment_title, datetime.now().strftime("%d.%m.%Y-%H:%M:%S"))
135 | print("Experiment name: {}".format(experiment_name))
136 |
137 | experiment_dir = os.path.join(args.logdir, experiment_name)
138 | os.makedirs(experiment_dir, exist_ok=True)
139 |
140 | checkpoints_dir = os.path.join(experiment_dir, "checkpoints")
141 | os.makedirs(checkpoints_dir, exist_ok=True)
142 |
143 | shutil.copy(args.config, os.path.join(experiment_dir, "config.yaml"))
144 |
145 | # tensorboard
146 | writer = SummaryWriter(os.path.join(experiment_dir, "tb"))
147 |
148 | # dump config to tensorboard
149 |     writer.add_text("config", misc.config_to_str(config), 0)  # add_text expects (tag, text, global_step)
150 |
151 | return experiment_dir, writer
152 |
153 |
154 | def one_epoch(model, criterion, opt, config, dataloader, device, epoch, n_iters_total=0, is_train=True, caption='', master=False, experiment_dir=None, writer=None):
155 | name = "train" if is_train else "val"
156 | model_type = config.model.name
157 |
158 | if is_train:
159 | model.train()
160 | else:
161 | model.eval()
162 |
163 | metric_dict = defaultdict(list)
164 |
165 | results = defaultdict(list)
166 |
167 | # used to turn on/off gradients
168 | grad_context = torch.autograd.enable_grad if is_train else torch.no_grad
169 | with grad_context():
170 | end = time.time()
171 |
172 | iterator = enumerate(dataloader)
173 | if is_train and config.opt.n_iters_per_epoch is not None:
174 | iterator = islice(iterator, config.opt.n_iters_per_epoch)
175 |
176 | for iter_i, batch in iterator:
177 | with autograd.detect_anomaly():
178 | # measure data loading time
179 | data_time = time.time() - end
180 |
181 | if batch is None:
182 | print("Found None batch")
183 | continue
184 |
185 | images_batch, keypoints_3d_gt, keypoints_3d_validity_gt, proj_matricies_batch = dataset_utils.prepare_batch(batch, device, config)
186 |
187 | keypoints_2d_pred, cuboids_pred, base_points_pred = None, None, None
188 | if model_type == "alg" or model_type == "ransac":
189 | keypoints_3d_pred, keypoints_2d_pred, heatmaps_pred, confidences_pred = model(images_batch, proj_matricies_batch, batch)
190 | elif model_type == "vol":
191 | keypoints_3d_pred, heatmaps_pred, volumes_pred, confidences_pred, cuboids_pred, coord_volumes_pred, base_points_pred = model(images_batch, proj_matricies_batch, batch)
192 |
193 | batch_size, n_views, image_shape = images_batch.shape[0], images_batch.shape[1], tuple(images_batch.shape[3:])
194 | n_joints = keypoints_3d_pred.shape[1]
195 |
196 | keypoints_3d_binary_validity_gt = (keypoints_3d_validity_gt > 0.0).type(torch.float32)
197 |
198 | scale_keypoints_3d = config.opt.scale_keypoints_3d if hasattr(config.opt, "scale_keypoints_3d") else 1.0
199 |
200 | # 1-view case
201 | if n_views == 1:
202 | if config.kind == "human36m":
203 | base_joint = 6
204 | elif config.kind == "coco":
205 | base_joint = 11
206 |
207 | keypoints_3d_gt_transformed = keypoints_3d_gt.clone()
208 | keypoints_3d_gt_transformed[:, torch.arange(n_joints) != base_joint] -= keypoints_3d_gt_transformed[:, base_joint:base_joint + 1]
209 | keypoints_3d_gt = keypoints_3d_gt_transformed
210 |
211 | keypoints_3d_pred_transformed = keypoints_3d_pred.clone()
212 | keypoints_3d_pred_transformed[:, torch.arange(n_joints) != base_joint] -= keypoints_3d_pred_transformed[:, base_joint:base_joint + 1]
213 | keypoints_3d_pred = keypoints_3d_pred_transformed
214 |
215 | # calculate loss
216 | total_loss = 0.0
217 | loss = criterion(keypoints_3d_pred * scale_keypoints_3d, keypoints_3d_gt * scale_keypoints_3d, keypoints_3d_binary_validity_gt)
218 | total_loss += loss
219 | metric_dict[f'{config.opt.criterion}'].append(loss.item())
220 |
221 | # volumetric ce loss
222 | use_volumetric_ce_loss = config.opt.use_volumetric_ce_loss if hasattr(config.opt, "use_volumetric_ce_loss") else False
223 | if use_volumetric_ce_loss:
224 | volumetric_ce_criterion = VolumetricCELoss()
225 |
226 | loss = volumetric_ce_criterion(coord_volumes_pred, volumes_pred, keypoints_3d_gt, keypoints_3d_binary_validity_gt)
227 | metric_dict['volumetric_ce_loss'].append(loss.item())
228 |
229 | weight = config.opt.volumetric_ce_loss_weight if hasattr(config.opt, "volumetric_ce_loss_weight") else 1.0
230 | total_loss += weight * loss
231 |
232 | metric_dict['total_loss'].append(total_loss.item())
233 |
234 | if is_train:
235 | opt.zero_grad()
236 | total_loss.backward()
237 |
238 | if hasattr(config.opt, "grad_clip"):
239 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.opt.grad_clip / config.opt.lr)
240 |
241 | metric_dict['grad_norm_times_lr'].append(config.opt.lr * misc.calc_gradient_norm(filter(lambda x: x[1].requires_grad, model.named_parameters())))
242 |
243 | opt.step()
244 |
245 | # calculate metrics
246 | l2 = KeypointsL2Loss()(keypoints_3d_pred * scale_keypoints_3d, keypoints_3d_gt * scale_keypoints_3d, keypoints_3d_binary_validity_gt)
247 | metric_dict['l2'].append(l2.item())
248 |
249 | # base point l2
250 | if base_points_pred is not None:
251 | base_point_l2_list = []
252 | for batch_i in range(batch_size):
253 | base_point_pred = base_points_pred[batch_i]
254 |
255 | if config.model.kind == "coco":
256 |                             base_point_gt = (keypoints_3d_gt[batch_i, 11, :3] + keypoints_3d_gt[batch_i, 12, :3]) / 2
257 | elif config.model.kind == "mpii":
258 | base_point_gt = keypoints_3d_gt[batch_i, 6, :3]
259 |
260 | base_point_l2_list.append(torch.sqrt(torch.sum((base_point_pred * scale_keypoints_3d - base_point_gt * scale_keypoints_3d) ** 2)).item())
261 |
262 | base_point_l2 = 0.0 if len(base_point_l2_list) == 0 else np.mean(base_point_l2_list)
263 | metric_dict['base_point_l2'].append(base_point_l2)
264 |
265 |                 # save answers for evaluation
266 | if not is_train:
267 | results['keypoints_3d'].append(keypoints_3d_pred.detach().cpu().numpy())
268 | results['indexes'].append(batch['indexes'])
269 |
270 | # plot visualization
271 | if master:
272 |                     if n_iters_total % config.vis_freq == 0:  # or total_l2.item() > 500.0:
273 | vis_kind = config.kind
274 | if (config.transfer_cmu_to_human36m if hasattr(config, "transfer_cmu_to_human36m") else False):
275 | vis_kind = "coco"
276 |
277 | for batch_i in range(min(batch_size, config.vis_n_elements)):
278 | keypoints_vis = vis.visualize_batch(
279 | images_batch, heatmaps_pred, keypoints_2d_pred, proj_matricies_batch,
280 | keypoints_3d_gt, keypoints_3d_pred,
281 | kind=vis_kind,
282 | cuboids_batch=cuboids_pred,
283 | confidences_batch=confidences_pred,
284 | batch_index=batch_i, size=5,
285 | max_n_cols=10
286 | )
287 | writer.add_image(f"{name}/keypoints_vis/{batch_i}", keypoints_vis.transpose(2, 0, 1), global_step=n_iters_total)
288 |
289 | heatmaps_vis = vis.visualize_heatmaps(
290 | images_batch, heatmaps_pred,
291 | kind=vis_kind,
292 | batch_index=batch_i, size=5,
293 | max_n_rows=10, max_n_cols=10
294 | )
295 | writer.add_image(f"{name}/heatmaps/{batch_i}", heatmaps_vis.transpose(2, 0, 1), global_step=n_iters_total)
296 |
297 | if model_type == "vol":
298 | volumes_vis = vis.visualize_volumes(
299 | images_batch, volumes_pred, proj_matricies_batch,
300 | kind=vis_kind,
301 | cuboids_batch=cuboids_pred,
302 | batch_index=batch_i, size=5,
303 | max_n_rows=1, max_n_cols=16
304 | )
305 | writer.add_image(f"{name}/volumes/{batch_i}", volumes_vis.transpose(2, 0, 1), global_step=n_iters_total)
306 |
307 |                     # dump weights to tensorboard
308 | if n_iters_total % config.vis_freq == 0:
309 | for p_name, p in model.named_parameters():
310 | try:
311 | writer.add_histogram(p_name, p.clone().cpu().data.numpy(), n_iters_total)
312 | except ValueError as e:
313 | print(e)
314 | print(p_name, p)
315 | exit()
316 |
317 | # dump to tensorboard per-iter loss/metric stats
318 | if is_train:
319 | for title, value in metric_dict.items():
320 | writer.add_scalar(f"{name}/{title}", value[-1], n_iters_total)
321 |
322 | # measure elapsed time
323 | batch_time = time.time() - end
324 | end = time.time()
325 |
326 | # dump to tensorboard per-iter time stats
327 | writer.add_scalar(f"{name}/batch_time", batch_time, n_iters_total)
328 | writer.add_scalar(f"{name}/data_time", data_time, n_iters_total)
329 |
330 | # dump to tensorboard per-iter stats about sizes
331 | writer.add_scalar(f"{name}/batch_size", batch_size, n_iters_total)
332 | writer.add_scalar(f"{name}/n_views", n_views, n_iters_total)
333 |
334 | n_iters_total += 1
335 |
336 | # calculate evaluation metrics
337 | if master:
338 | if not is_train:
339 | results['keypoints_3d'] = np.concatenate(results['keypoints_3d'], axis=0)
340 | results['indexes'] = np.concatenate(results['indexes'])
341 |
342 | try:
343 | scalar_metric, full_metric = dataloader.dataset.evaluate(results['keypoints_3d'])
344 | except Exception as e:
345 | print("Failed to evaluate. Reason: ", e)
346 | scalar_metric, full_metric = 0.0, {}
347 |
348 | metric_dict['dataset_metric'].append(scalar_metric)
349 |
350 | checkpoint_dir = os.path.join(experiment_dir, "checkpoints", "{:04}".format(epoch))
351 | os.makedirs(checkpoint_dir, exist_ok=True)
352 |
353 | # dump results
354 | with open(os.path.join(checkpoint_dir, "results.pkl"), 'wb') as fout:
355 | pickle.dump(results, fout)
356 |
357 | # dump full metric
358 |             with open(os.path.join(checkpoint_dir, "metric.json"), 'w') as fout:
359 | json.dump(full_metric, fout, indent=4, sort_keys=True)
360 |
361 | # dump to tensorboard per-epoch stats
362 | for title, value in metric_dict.items():
363 | writer.add_scalar(f"{name}/{title}_epoch", np.mean(value), epoch)
364 |
365 | return n_iters_total
366 |
367 |
368 | def init_distributed(args):
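    # WORLD_SIZE / RANK / MASTER_PORT are expected to be set by the PyTorch distributed launcher
    # (e.g. python -m torch.distributed.launch), matching the "env://" init method used below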
369 | if "WORLD_SIZE" not in os.environ or int(os.environ["WORLD_SIZE"]) < 1:
370 | return False
371 |
372 | torch.cuda.set_device(args.local_rank)
373 |
374 |     assert "MASTER_PORT" in os.environ, "set the MASTER_PORT variable or use pytorch launcher"
375 |     assert "RANK" in os.environ, "use pytorch launcher and explicitly state the rank of the process"
376 |
377 | torch.manual_seed(args.seed)
378 | torch.distributed.init_process_group(backend="nccl", init_method="env://")
379 |
380 | return True
381 |
382 |
383 | def main(args):
384 | print("Number of available GPUs: {}".format(torch.cuda.device_count()))
385 |
386 | is_distributed = init_distributed(args)
387 | master = True
388 | if is_distributed and os.environ["RANK"]:
389 | master = int(os.environ["RANK"]) == 0
390 |
391 | if is_distributed:
392 | device = torch.device(args.local_rank)
393 | else:
394 | device = torch.device(0)
395 |
396 | # config
397 | config = cfg.load_config(args.config)
398 | config.opt.n_iters_per_epoch = config.opt.n_objects_per_epoch // config.opt.batch_size
399 |
400 | model = {
401 | "ransac": RANSACTriangulationNet,
402 | "alg": AlgebraicTriangulationNet,
403 | "vol": VolumetricTriangulationNet
404 | }[config.model.name](config, device=device).to(device)
405 |
406 | if config.model.init_weights:
407 | state_dict = torch.load(config.model.checkpoint)
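        # checkpoints saved from a (Distributed)DataParallel-wrapped model prefix every key with
        # "module."; strip it so the state dict matches the bare model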
408 | for key in list(state_dict.keys()):
409 | new_key = key.replace("module.", "")
410 | state_dict[new_key] = state_dict.pop(key)
411 |
412 | model.load_state_dict(state_dict, strict=True)
413 | print("Successfully loaded pretrained weights for whole model")
414 |
415 | # criterion
416 | criterion_class = {
417 | "MSE": KeypointsMSELoss,
418 | "MSESmooth": KeypointsMSESmoothLoss,
419 | "MAE": KeypointsMAELoss
420 | }[config.opt.criterion]
421 |
422 | if config.opt.criterion == "MSESmooth":
423 | criterion = criterion_class(config.opt.mse_smooth_threshold)
424 | else:
425 | criterion = criterion_class()
426 |
427 | # optimizer
428 | opt = None
429 | if not args.eval:
430 | if config.model.name == "vol":
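            # the volumetric model uses separate learning rates for the 2D backbone,
            # the feature-processing layers and the volumetric (V2V) network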
431 | opt = torch.optim.Adam(
432 | [{'params': model.backbone.parameters()},
433 | {'params': model.process_features.parameters(), 'lr': config.opt.process_features_lr if hasattr(config.opt, "process_features_lr") else config.opt.lr},
434 | {'params': model.volume_net.parameters(), 'lr': config.opt.volume_net_lr if hasattr(config.opt, "volume_net_lr") else config.opt.lr}
435 | ],
436 | lr=config.opt.lr
437 | )
438 | else:
439 | opt = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.opt.lr)
440 |
441 |
442 | # datasets
443 | print("Loading data...")
444 | train_dataloader, val_dataloader, train_sampler = setup_dataloaders(config, distributed_train=is_distributed)
445 |
446 | # experiment
447 | experiment_dir, writer = None, None
448 | if master:
449 | experiment_dir, writer = setup_experiment(config, type(model).__name__, is_train=not args.eval)
450 |
451 | # multi-gpu
452 | if is_distributed:
453 | model = DistributedDataParallel(model, device_ids=[device])
454 |
455 | if not args.eval:
456 | # train loop
457 | n_iters_total_train, n_iters_total_val = 0, 0
458 | for epoch in range(config.opt.n_epochs):
459 | if train_sampler is not None:
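                # make the DistributedSampler reshuffle the data with a different seed each epoch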
460 | train_sampler.set_epoch(epoch)
461 |
462 | n_iters_total_train = one_epoch(model, criterion, opt, config, train_dataloader, device, epoch, n_iters_total=n_iters_total_train, is_train=True, master=master, experiment_dir=experiment_dir, writer=writer)
463 | n_iters_total_val = one_epoch(model, criterion, opt, config, val_dataloader, device, epoch, n_iters_total=n_iters_total_val, is_train=False, master=master, experiment_dir=experiment_dir, writer=writer)
464 |
465 | if master:
466 | checkpoint_dir = os.path.join(experiment_dir, "checkpoints", "{:04}".format(epoch))
467 | os.makedirs(checkpoint_dir, exist_ok=True)
468 |
469 | torch.save(model.state_dict(), os.path.join(checkpoint_dir, "weights.pth"))
470 |
471 | print(f"{n_iters_total_train} iters done.")
472 | else:
473 | if args.eval_dataset == 'train':
474 | one_epoch(model, criterion, opt, config, train_dataloader, device, 0, n_iters_total=0, is_train=False, master=master, experiment_dir=experiment_dir, writer=writer)
475 | else:
476 | one_epoch(model, criterion, opt, config, val_dataloader, device, 0, n_iters_total=0, is_train=False, master=master, experiment_dir=experiment_dir, writer=writer)
477 |
478 | print("Done.")
479 |
480 | if __name__ == '__main__':
481 | args = parse_args()
482 | print("args: {}".format(args))
483 | main(args)
484 |
--------------------------------------------------------------------------------