├── .gitignore ├── LICENSE ├── README.md ├── docs ├── algebraic-model.svg ├── unprojection.gif ├── video-preview.jpg └── volumetric-model.svg ├── experiments └── human36m │ ├── eval │ ├── human36m_alg.yaml │ ├── human36m_ransac.yaml │ └── human36m_vol_softmax.yaml │ └── train │ ├── human36m_alg.yaml │ ├── human36m_alg_no_conf.yaml │ └── human36m_vol_softmax.yaml ├── mvn ├── __init__.py ├── datasets │ ├── __init__.py │ ├── human36m.py │ ├── human36m_preprocessing │ │ ├── README.md │ │ ├── action_to_bbox_filename.py │ │ ├── action_to_una_dinosauria.py │ │ ├── collect-bboxes.py │ │ ├── generate-labels-npy-multiview.py │ │ ├── undistort-h36m.py │ │ └── view-dataset.py │ └── utils.py ├── models │ ├── __init__.py │ ├── loss.py │ ├── pose_resnet.py │ ├── triangulation.py │ └── v2v.py └── utils │ ├── __init__.py │ ├── cfg.py │ ├── img.py │ ├── misc.py │ ├── multiview.py │ ├── op.py │ ├── vis.py │ └── volumetric.py ├── requirements.txt └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2019 Samsung AI Center Moscow, Karim Iskakov (k.iskakov@samsung.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/190505754/3d-human-pose-estimation-on-human36m)](https://paperswithcode.com/sota/3d-human-pose-estimation-on-human36m?p=190505754) 2 | 3 | # Learnable Triangulation of Human Pose 4 | This repository is an official PyTorch implementation of the paper ["Learnable Triangulation of Human Pose"](https://arxiv.org/abs/1905.05754) (ICCV 2019, oral). Here we tackle the problem of 3D human pose estimation from multiple cameras. We present 2 novel methods — Algebraic and Volumetric learnable triangulation — that **outperform** the previous state of the art. 5 | 6 | If you find a bug, have a question, or know how to improve the code - please open an issue! 7 | 8 | :arrow_forward: [ICCV 2019 talk](https://youtu.be/zem03fZWLrQ?t=3477) 9 | 10 |

11 | [image placeholder: `docs/video-preview.jpg` (video preview)] 12 | 13 | 14 |

15 | 16 | # How to use 17 | This project doesn't have any special or difficult-to-install dependencies. All installation can be done with: 18 | ```bash 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | ## Data 23 | Sorry, only [Human3.6M](http://vision.imar.ro/human3.6m/description.php) dataset training/evaluation is available right now. We cannot add [CMU Panoptic](http://domedb.perception.cs.cmu.edu/), sorry for that. 24 | 25 | #### Human3.6M 26 | 1. Download and preprocess the dataset by following the instructions in [mvn/datasets/human36m_preprocessing/README.md](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/mvn/datasets/human36m_preprocessing/README.md). 27 | 2. Download pretrained backbone's weights from [here](https://disk.yandex.ru/d/hv-uH_7TY0ONpg) and place them here: `./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth` (ResNet-152 trained on COCO dataset and finetuned jointly on MPII and Human3.6M). 28 | 3. If you want to train Volumetric model, you need rough estimations of the pelvis' 3D positions both for train and val splits. In the paper we estimate them using the Algebraic model. You can use the [pretrained](#model-zoo) Algebraic model to produce predictions or just take [precalculated 3D skeletons](#model-zoo). 29 | 30 | ## Model zoo 31 | In this section we collect pretrained models and configs. All **pretrained weights** and **precalculated 3D skeletons** can be downloaded at once [from here](https://disk.yandex.ru/d/jbsyN8XVzD-Y7Q) and placed to `./data/pretrained`, so that eval configs can work out-of-the-box (without additional setting of paths). Alternatively, the table below provides separate links to those files. 32 | 33 | **Human3.6M:** 34 | 35 | | Model | Train config | Eval config | Weights | Precalculated results | MPJPE (relative to pelvis), mm | 36 | |----------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------------------------------------:|-------------------------------:| 37 | | Algebraic | [train/human36m_alg.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/experiments/human36m/train/human36m_alg.yaml) | [eval/human36m_alg.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/experiments/human36m/eval/human36m_alg.yaml) | [link](https://disk.yandex.ru/d/3TJMKaa6iKaymw) | [train](https://disk.yandex.ru/d/2Gwk7JZ_QWpFvw), [val](https://disk.yandex.ru/d/ZsQ4GV5EX_Wsog) | 22.5 | 38 | | Volumetric (softmax) | [train/human36m_vol_softmax.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/experiments/human36m/train/human36m_vol_softmax.yaml) | [eval/human36m_vol_softmax.yaml](https://github.com/karfly/learnable-triangulation-pytorch/blob/master/experiments/human36m/eval/human36m_vol_softmax.yaml) | [link](https://disk.yandex.ru/d/MvD3orcBc6wqRA) | — | **20.4** | 39 | 40 | ## Train 41 | Every experiment is defined by `.config` files. Configs with experiments from the paper can be found in the `./experiments` directory (see [model zoo](#model-zoo)). 
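These configs are plain YAML, so you can inspect or tweak them programmatically before launching a run. As a minimal sketch (illustration only: the repository's own loader lives in `mvn/utils/cfg.py`, and its interface is not reproduced here), this is how you might peek at the fields that are most often edited, such as checkpoint and dataset paths:

```python
# Illustration only: the repo's actual config handling is in mvn/utils/cfg.py.
import yaml

with open("experiments/human36m/train/human36m_vol_softmax.yaml") as f:
    config = yaml.safe_load(f)

print(config["title"])                            # experiment name, e.g. "human36m_vol_softmax"
print(config["model"]["backbone"]["checkpoint"])  # path to pretrained backbone weights
print(config["dataset"]["train"]["h36m_root"])    # path to processed Human3.6M frames
```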
42 | 43 | #### Single-GPU 44 | To train a Volumetric model with softmax aggregation using **1 GPU**, run: 45 | ```bash 46 | python3 train.py \ 47 | --config experiments/human36m/train/human36m_vol_softmax.yaml \ 48 | --logdir ./logs 49 | ``` 50 | 51 | The training will start with the config file specified by `--config`, and logs (including tensorboard files) will be stored in `--logdir`. 52 | 53 | #### Multi-GPU (*in testing*) 54 | Multi-GPU training is implemented with PyTorch's [DistributedDataParallel](https://pytorch.org/docs/stable/nn.html#distributeddataparallel). It can be used both for single-machine and multi-machine (cluster) training. To run the processes use the PyTorch [launch utility](https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py). 55 | 56 | To train a Volumetric model with softmax aggregation using **2 GPUs on single machine**, run: 57 | ```bash 58 | python3 -m torch.distributed.launch --nproc_per_node=2 --master_port=2345 \ 59 | train.py \ 60 | --config experiments/human36m/train/human36m_vol_softmax.yaml \ 61 | --logdir ./logs 62 | ``` 63 | 64 | ## Tensorboard 65 | To watch your experiments' progress, run tensorboard: 66 | ```bash 67 | tensorboard --logdir ./logs 68 | ``` 69 | 70 | ## Evaluation 71 | After training, you can evaluate the model. Inside the same config file, add path to the learned weights (they are dumped to `logs` dir during training): 72 | ```yaml 73 | model: 74 | init_weights: true 75 | checkpoint: {PATH_TO_WEIGHTS} 76 | ``` 77 | 78 | Also, you can change other config parameters like `retain_every_n_frames_test`. 79 | 80 | Run: 81 | ```bash 82 | python3 train.py \ 83 | --eval --eval_dataset val \ 84 | --config experiments/human36m/eval/human36m_vol_softmax.yaml \ 85 | --logdir ./logs 86 | ``` 87 | Argument `--eval_dataset` can be `val` or `train`. Results can be seen in `logs` directory or in the tensorboard. 88 | 89 | # Results 90 | * We conduct experiments on two available large multi-view datasets: Human3.6M [\[2\]](#references) and CMU Panoptic [\[3\]](#references). 91 | * The main metric is **MPJPE** (Mean Per Joint Position Error) which is L2 distance averaged over all joints. 92 | 93 | ## Human3.6M 94 | * We significantly improved upon the previous state of the art (error is measured relative to pelvis, without alignment). 95 | * Our best model reaches **17.7 mm** error in absolute coordinates, which was unattainable before. 96 | * Our Volumetric model is able to estimate 3D human pose using **any number of cameras**, even using **only 1 camera**. In single-view setup, we get results comparable to current state of the art [\[6\]](#references) (49.9 mm vs. 49.6 mm). 97 | 98 |
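For reference, the **MPJPE** numbers reported below are computed as in `Human36MMultiViewDataset.evaluate()` (see `mvn/datasets/human36m.py` later in this repository); the helper below is only a sketch of that arithmetic, with array shapes assumed to be `(n_poses, n_joints, 3)` in millimeters:

```python
import numpy as np

def mpjpe(keypoints_gt, keypoints_pred, root_index=None):
    """Mean Per Joint Position Error, optionally relative to a root joint.

    In this repo's 17-joint skeleton the pelvis has index 6, so passing
    root_index=6 gives the pelvis-relative numbers shown in the tables.
    """
    if root_index is not None:
        keypoints_gt = keypoints_gt - keypoints_gt[:, root_index:root_index + 1]
        keypoints_pred = keypoints_pred - keypoints_pred[:, root_index:root_index + 1]

    # L2 distance per joint, averaged over joints, then over all poses
    per_pose_error = np.sqrt(((keypoints_gt - keypoints_pred) ** 2).sum(2)).mean(1)
    return per_pose_error.mean()
```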
99 | MPJPE relative to pelvis: 100 | 101 | | | MPJPE (averaged across all actions), mm | 102 | |----------------------------- |:--------: | 103 | | Multi-View Martinez [\[4\]](#references) | 57.0 | 104 | | Pavlakos et al. [\[8\]](#references) | 56.9 | 105 | | Tome et al. [\[4\]](#references) | 52.8 | 106 | | Kadkhodamohammadi & Padoy [\[5\]](#references) | 49.1 | 107 | | [Qiu et al.](https://github.com/microsoft/multiview-human-pose-estimation-pytorch) [\[9\]](#references) | 26.2 | 108 | | RANSAC (our implementation) | 27.4 | 109 | | **Ours, algebraic** | 22.4 | 110 | | **Ours, volumetric** | **20.5** | 111 | 112 |
113 | MPJPE absolute (scenes with invalid ground-truth annotations are excluded): 114 | 115 | | | MPJPE (averaged across all actions), mm | 116 | |----------------------------- |:--------: | 117 | | RANSAC (our implementation) | 22.8 | 118 | | **Ours, algebraic** | 19.2 | 119 | | **Ours, volumetric** | **17.7** | 120 | 121 |
122 | MPJPE relative to pelvis (single-view methods): 123 | 124 | | | MPJPE (averaged across all actions), mm | 125 | |----------------------------- |:-----------------------------------: | 126 | | Martinez et al. [\[7\]](#references) | 62.9 | 127 | | Sun et al. [\[6\]](#references) | **49.6** | 128 | | **Ours, volumetric single view** | **49.9** | 129 | 130 | 131 | ## CMU Panoptic 132 | * Our best model reaches **13.7 mm** error in absolute coordinates for 4 cameras 133 | * We managed to get much smoother and more accurate 3D pose annotations compared to dataset annotations (see [video demonstration](http://www.youtube.com/watch?v=z3f3aPSuhqg)) 134 | 135 |
136 | MPJPE relative to pelvis [4 cameras]: 137 | 138 | | | MPJPE, mm | 139 | |----------------------------- |:--------: | 140 | | RANSAC (our implementation) | 39.5 | 141 | | **Ours, algebraic** | 21.3 | 142 | | **Ours, volumetric** | **13.7** | 143 | 144 | # Method overview 145 | We present 2 novel methods of learnable triangulation: Algebraic and Volumetric. 146 | 147 | ## Algebraic 148 | ![algebraic-model](docs/algebraic-model.svg) 149 | 150 | Our first method is based on Algebraic triangulation. It is similar to previous approaches, but differs in 2 critical aspects: 151 | 1. It is **fully differentiable**. To achieve this, we use soft-argmax aggregation and triangulate keypoints via a differentiable SVD (see the code sketch below). 152 | 2. The neural network additionally predicts **scalar confidences for each joint**, which are passed to the triangulation module and allow it to deal with outliers and occluded joints. 153 | 154 | On the most popular Human3.6M dataset, this method alone already reduces the error by **2.2 times (!)** compared to the previous state of the art. 155 | 156 | 157 | ## Volumetric 158 | ![volumetric-model](docs/volumetric-model.svg) 159 | 160 | In the Volumetric triangulation model, intermediate 2D feature maps are densely unprojected into a volumetric cube and then processed with a 3D convolutional neural network. The unprojection operation allows **dense aggregation from multiple views**, and the 3D convolutional neural network is able to model an **implicit human pose prior**. 161 | 162 | Volumetric triangulation additionally improves accuracy, drastically reducing the previous state-of-the-art error by **2.4 times!** Even compared to the best concurrently developed [method](https://github.com/microsoft/multiview-human-pose-estimation-pytorch) by the MSRA group, our method still offers a significantly lower error of **21 mm**. 163 | 164 |
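To make the algebraic part concrete, here is a minimal, self-contained PyTorch sketch of the two ingredients above: a soft-argmax readout of a joint heatmap and confidence-weighted linear (DLT) triangulation solved with a differentiable SVD. Names, shapes and the single-joint scope are simplifications for illustration and do not mirror the actual interfaces of `mvn/models/triangulation.py` or `mvn/utils/multiview.py`:

```python
import torch

def soft_argmax_2d(heatmap):
    """Differentiable argmax of an (H, W) heatmap, returned as (x, y) pixel coordinates."""
    h, w = heatmap.shape
    probs = torch.softmax(heatmap.reshape(-1), dim=0).reshape(h, w)
    xs = torch.arange(w, dtype=probs.dtype)
    ys = torch.arange(h, dtype=probs.dtype)
    x = (probs.sum(dim=0) * xs).sum()  # expected x coordinate
    y = (probs.sum(dim=1) * ys).sum()  # expected y coordinate
    return torch.stack([x, y])

def triangulate_point(proj_matrices, points_2d, confidences):
    """Confidence-weighted DLT triangulation of a single joint.

    proj_matrices: (n_views, 3, 4) camera projection matrices
    points_2d:     (n_views, 2) 2D joint positions from soft_argmax_2d
    confidences:   (n_views,) per-view scalar confidences predicted by the network
    """
    rows = []
    for P, point, conf in zip(proj_matrices, points_2d, confidences):
        x, y = point
        # standard DLT constraints, scaled by the learned per-view confidence
        rows.append(conf * (x * P[2] - P[0]))
        rows.append(conf * (y * P[2] - P[1]))
    A = torch.stack(rows)  # (2 * n_views, 4)

    # the 3D point is the right singular vector with the smallest singular value
    _, _, vh = torch.linalg.svd(A)
    point_hom = vh[-1]
    return point_hom[:3] / point_hom[3]
```

Because every step is differentiable, the gradient of the 3D keypoint loss reaches both the 2D backbone and the confidence branch, which is what allows the confidences to be learned end-to-end.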

165 | [animation placeholder: `docs/unprojection.gif`, illustrating the unprojection of 2D feature maps into the volumetric cube] 166 |
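Below is a rough sketch of the unprojection operation illustrated above: every voxel of a cuboid placed around the person is projected into a camera view, and that view's 2D feature map is bilinearly sampled at the resulting location. Names and shapes are assumptions made for illustration and do not follow the repository's own implementation; view aggregation (`volume_aggregation_method: "softmax"` in the configs) and the subsequent 3D convolutional network are omitted:

```python
import torch
import torch.nn.functional as F

def unproject_view(features, proj_matrix, grid):
    """Fill a voxel volume with 2D features from a single camera view.

    features:    (C, H, W) feature map produced by the 2D backbone for this view
    proj_matrix: (3, 4) camera projection matrix
    grid:        (D, D, D, 3) world coordinates of voxel centers (cuboid around the pelvis)
    returns:     (C, D, D, D) per-view feature volume
    """
    c, h, w = features.shape
    d = grid.shape[0]

    # project voxel centers into the image plane (homogeneous coordinates)
    ones = torch.ones(d, d, d, 1, dtype=grid.dtype)
    proj = torch.cat([grid, ones], dim=-1) @ proj_matrix.t()  # (D, D, D, 3)
    xy = proj[..., :2] / proj[..., 2:].clamp(min=1e-6)        # pixel coordinates

    # normalize to [-1, 1] and bilinearly sample the feature map at every voxel
    xy_norm = 2 * torch.stack([xy[..., 0] / (w - 1), xy[..., 1] / (h - 1)], dim=-1) - 1
    volume = F.grid_sample(
        features.unsqueeze(0),            # (1, C, H, W)
        xy_norm.reshape(1, d, d * d, 2),  # voxel grid flattened into a 2D sampling grid
        align_corners=True,
    )                                     # (1, C, D, D*D)
    return volume.reshape(c, d, d, d)
```

Per-view volumes built this way are then aggregated across cameras and processed by the 3D convolutional network to produce the final 3D joint estimates.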

167 | 168 | 169 | # Cite us! 170 | ```bibtex 171 | @inproceedings{iskakov2019learnable, 172 | title={Learnable Triangulation of Human Pose}, 173 | author={Iskakov, Karim and Burkov, Egor and Lempitsky, Victor and Malkov, Yury}, 174 | booktitle = {International Conference on Computer Vision (ICCV)}, 175 | year={2019} 176 | } 177 | ``` 178 | 179 | # Contributors 180 | - [Karim Iskakov](https://github.com/karfly) 181 | - [Egor Burkov](https://github.com/shrubb) 182 | - [Victor Lempitsky](https://scholar.google.com/citations?user=gYYVokYAAAAJ&hl=ru) 183 | - [Yury Malkov](https://github.com/yurymalkov) 184 | - [Rasul Kerimov](https://github.com/rrkarim) 185 | - [Ivan Bulygin](https://github.com/blufzzz) 186 | 187 | # News 188 | - **26 Nov 2019:** Updataed [precalculated results](#model-zoo) (see [this issue](https://github.com/karfly/learnable-triangulation-pytorch/issues/37)). 189 | - **18 Oct 2019:** Pretrained models (algebraic and volumetric) for Human3.6M are released. 190 | - **8 Oct 2019:** Code is released! 191 | 192 | # References 193 | * [\[1\]](#references) R. Hartley and A. Zisserman. **Multiple view geometry in computer vision**. 194 | * [\[2\]](#references) C. Ionescu, D. Papava, V. Olaru, and C. Sminchisescu. **Human3.6m: Large scale datasets and predictive methods for 3d human sensing in natural environments**. 195 | * [\[3\]](#references) H. Joo, T. Simon, X. Li, H. Liu, L. Tan, L. Gui, S. Banerjee, T. S. Godisart, B. Nabbe, I. Matthews, T. Kanade,S. Nobuhara, and Y. Sheikh. **Panoptic studio: A massively multiview system for social interaction capture**. 196 | * [\[4\]](#references) D. Tome, M. Toso, L. Agapito, and C. Russell. **Rethinking Pose in 3D: Multi-stage Refinement and Recovery for Markerless Motion Capture**. 197 | * [\[5\]](#references) A. Kadkhodamohammadi and N. Padoy. **A generalizable approach for multi-view 3D human pose regression**. 198 | * [\[6\]](#references) X. Sun, B. Xiao, S. Liang, and Y. Wei. **Integral human pose regression**. 199 | * [\[7\]](#references) J. Martinez, R. Hossain, J. Romero, and J. J. Little. **A simple yet effective baseline for 3d human pose estimation**. 200 | * [\[8\]](#references) G. Pavlakos, X. Zhou, K. G. Derpanis, and K. Daniilidis. **Harvesting multiple views for marker-less 3D human pose annotations**. 201 | * [\[9\]](#references) H. Qiu, C. Wang, J. Wang, N. Wang and W. Zeng. (2019). 
**Cross View Fusion for 3D Human Pose Estimation**, [GitHub](https://github.com/microsoft/multiview-human-pose-estimation-pytorch) 202 | -------------------------------------------------------------------------------- /docs/unprojection.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/karfly/learnable-triangulation-pytorch/8dcc4e97a407ba0474d2e5299a92bb32cfaeadfe/docs/unprojection.gif -------------------------------------------------------------------------------- /docs/video-preview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/karfly/learnable-triangulation-pytorch/8dcc4e97a407ba0474d2e5299a92bb32cfaeadfe/docs/video-preview.jpg -------------------------------------------------------------------------------- /experiments/human36m/eval/human36m_alg.yaml: -------------------------------------------------------------------------------- 1 | title: "human36m_alg" 2 | kind: "human36m" 3 | vis_freq: 1000 4 | vis_n_elements: 10 5 | 6 | image_shape: [384, 384] 7 | 8 | opt: 9 | criterion: "MSESmooth" 10 | mse_smooth_threshold: 400 11 | 12 | n_objects_per_epoch: 15000 13 | n_epochs: 9999 14 | 15 | batch_size: 8 16 | val_batch_size: 100 17 | 18 | lr: 0.00001 19 | 20 | scale_keypoints_3d: 0.1 21 | 22 | model: 23 | name: "alg" 24 | 25 | init_weights: true 26 | checkpoint: "./data/pretrained/human36m/human36m_alg_10-04-2019/checkpoints/0060/weights.pth" 27 | 28 | use_confidences: true 29 | heatmap_multiplier: 100.0 30 | heatmap_softmax: true 31 | 32 | backbone: 33 | name: "resnet152" 34 | style: "simple" 35 | 36 | init_weights: true 37 | checkpoint: "./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth" 38 | 39 | num_joints: 17 40 | num_layers: 152 41 | 42 | dataset: 43 | kind: "human36m" 44 | 45 | train: 46 | h36m_root: "./data/human36m/processed/" 47 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy" 48 | with_damaged_actions: true 49 | undistort_images: true 50 | 51 | scale_bbox: 1.0 52 | 53 | shuffle: true 54 | randomize_n_views: false 55 | min_n_views: null 56 | max_n_views: null 57 | num_workers: 8 58 | 59 | val: 60 | h36m_root: "./data/human36m/processed/" 61 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy" 62 | with_damaged_actions: true 63 | undistort_images: true 64 | 65 | scale_bbox: 1.0 66 | 67 | shuffle: false 68 | randomize_n_views: false 69 | min_n_views: null 70 | max_n_views: null 71 | num_workers: 8 72 | 73 | retain_every_n_frames_in_test: 1 74 | -------------------------------------------------------------------------------- /experiments/human36m/eval/human36m_ransac.yaml: -------------------------------------------------------------------------------- 1 | title: "human36m_ransac" 2 | kind: "human36m" 3 | vis_freq: 1000 4 | vis_n_elements: 10 5 | 6 | image_shape: [384, 384] 7 | 8 | opt: 9 | criterion: "MSESmooth" 10 | mse_smooth_threshold: 400 11 | 12 | n_objects_per_epoch: 15000 13 | n_epochs: 9999 14 | 15 | batch_size: 8 16 | val_batch_size: 100 17 | 18 | lr: 0.00001 19 | 20 | scale_keypoints_3d: 0.1 21 | 22 | model: 23 | name: "ransac" 24 | 25 | init_weights: false 26 | checkpoint: "" 27 | 28 | direct_optimization: true 29 | heatmap_multiplier: 100.0 30 | heatmap_softmax: true 31 | 32 | backbone: 33 | name: "resnet152" 34 | style: "simple" 35 | 36 | init_weights: true 37 | checkpoint: "./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth" 38 | 39 | num_joints: 17 
40 | num_layers: 152 41 | 42 | dataset: 43 | kind: "human36m" 44 | 45 | train: 46 | h36m_root: "./data/human36m/processed/" 47 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy" 48 | with_damaged_actions: true 49 | undistort_images: true 50 | 51 | scale_bbox: 1.0 52 | 53 | shuffle: true 54 | randomize_n_views: false 55 | min_n_views: null 56 | max_n_views: null 57 | num_workers: 8 58 | 59 | val: 60 | h36m_root: "./data/human36m/processed/" 61 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy" 62 | with_damaged_actions: true 63 | undistort_images: true 64 | 65 | scale_bbox: 1.0 66 | 67 | shuffle: false 68 | randomize_n_views: false 69 | min_n_views: null 70 | max_n_views: null 71 | num_workers: 8 72 | 73 | retain_every_n_frames_in_test: 1 74 | -------------------------------------------------------------------------------- /experiments/human36m/eval/human36m_vol_softmax.yaml: -------------------------------------------------------------------------------- 1 | title: "human36m_vol_softmax" 2 | kind: "human36m" 3 | vis_freq: 1000 4 | vis_n_elements: 10 5 | 6 | image_shape: [384, 384] 7 | 8 | opt: 9 | criterion: "MAE" 10 | 11 | use_volumetric_ce_loss: true 12 | volumetric_ce_loss_weight: 0.01 13 | 14 | n_objects_per_epoch: 15000 15 | n_epochs: 9999 16 | 17 | batch_size: 5 18 | val_batch_size: 75 19 | 20 | lr: 0.0001 21 | process_features_lr: 0.001 22 | volume_net_lr: 0.001 23 | 24 | scale_keypoints_3d: 0.1 25 | 26 | model: 27 | name: "vol" 28 | kind: "mpii" 29 | volume_aggregation_method: "softmax" 30 | 31 | init_weights: true 32 | checkpoint: "./data/pretrained/human36m/human36m_vol_softmax_10-08-2019/checkpoints/0040/weights.pth" 33 | 34 | use_gt_pelvis: false 35 | 36 | cuboid_side: 2500.0 37 | 38 | volume_size: 64 39 | volume_multiplier: 1.0 40 | volume_softmax: true 41 | 42 | heatmap_softmax: true 43 | heatmap_multiplier: 100.0 44 | 45 | backbone: 46 | name: "resnet152" 47 | style: "simple" 48 | 49 | init_weights: true 50 | checkpoint: "./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth" 51 | 52 | num_joints: 17 53 | num_layers: 152 54 | 55 | dataset: 56 | kind: "human36m" 57 | 58 | train: 59 | h36m_root: "./data/human36m/processed/" 60 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy" 61 | pred_results_path: "./data/pretrained/human36m/human36m_alg_10-04-2019/checkpoints/0060/results/train.pkl" 62 | 63 | with_damaged_actions: true 64 | undistort_images: true 65 | 66 | scale_bbox: 1.0 67 | 68 | shuffle: true 69 | randomize_n_views: false 70 | min_n_views: null 71 | max_n_views: null 72 | num_workers: 5 73 | 74 | val: 75 | h36m_root: "./data/human36m/processed/" 76 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy" 77 | pred_results_path: "./data/pretrained/human36m/human36m_alg_10-04-2019/checkpoints/0060/results/val.pkl" 78 | 79 | with_damaged_actions: true 80 | undistort_images: true 81 | 82 | scale_bbox: 1.0 83 | 84 | shuffle: false 85 | randomize_n_views: false 86 | min_n_views: null 87 | max_n_views: null 88 | num_workers: 8 89 | 90 | retain_every_n_frames_in_test: 1 91 | -------------------------------------------------------------------------------- /experiments/human36m/train/human36m_alg.yaml: -------------------------------------------------------------------------------- 1 | title: "human36m_alg" 2 | kind: "human36m" 3 | vis_freq: 1000 4 | vis_n_elements: 10 5 | 6 | image_shape: [384, 384] 7 | 8 | opt: 9 | criterion: "MSESmooth" 10 | mse_smooth_threshold: 400 
11 | 12 | n_objects_per_epoch: 15000 13 | n_epochs: 9999 14 | 15 | batch_size: 8 16 | val_batch_size: 16 17 | 18 | lr: 0.00001 19 | 20 | scale_keypoints_3d: 0.1 21 | 22 | model: 23 | name: "alg" 24 | 25 | init_weights: false 26 | checkpoint: "" 27 | 28 | use_confidences: true 29 | heatmap_multiplier: 100.0 30 | heatmap_softmax: true 31 | 32 | backbone: 33 | name: "resnet152" 34 | style: "simple" 35 | 36 | init_weights: true 37 | checkpoint: "./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth" 38 | 39 | num_joints: 17 40 | num_layers: 152 41 | 42 | dataset: 43 | kind: "human36m" 44 | 45 | train: 46 | h36m_root: "./data/human36m/processed/" 47 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy" 48 | with_damaged_actions: true 49 | undistort_images: true 50 | 51 | scale_bbox: 1.0 52 | 53 | shuffle: true 54 | randomize_n_views: false 55 | min_n_views: null 56 | max_n_views: null 57 | num_workers: 8 58 | 59 | val: 60 | h36m_root: "./data/human36m/processed/" 61 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy" 62 | with_damaged_actions: true 63 | undistort_images: true 64 | 65 | scale_bbox: 1.0 66 | 67 | shuffle: false 68 | randomize_n_views: false 69 | min_n_views: null 70 | max_n_views: null 71 | num_workers: 8 72 | 73 | retain_every_n_frames_in_test: 1 74 | -------------------------------------------------------------------------------- /experiments/human36m/train/human36m_alg_no_conf.yaml: -------------------------------------------------------------------------------- 1 | title: "human36m_alg" 2 | kind: "human36m" 3 | vis_freq: 1000 4 | vis_n_elements: 10 5 | 6 | image_shape: [384, 384] 7 | 8 | opt: 9 | criterion: "MSESmooth" 10 | mse_smooth_threshold: 400 11 | 12 | n_objects_per_epoch: 10000 13 | n_epochs: 9999 14 | 15 | batch_size: 8 16 | val_batch_size: 16 17 | 18 | lr: 0.00001 19 | 20 | scale_keypoints_3d: 0.1 21 | 22 | model: 23 | name: "alg" 24 | 25 | init_weights: false 26 | checkpoint: "" 27 | 28 | use_confidences: false 29 | heatmap_multiplier: 100.0 30 | heatmap_softmax: true 31 | 32 | backbone: 33 | name: "resnet152" 34 | style: "simple" 35 | 36 | init_weights: true 37 | checkpoint: "./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth" 38 | 39 | num_joints: 17 40 | num_layers: 152 41 | 42 | dataset: 43 | kind: "human36m" 44 | 45 | train: 46 | h36m_root: "./data/human36m/processed/" 47 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy" 48 | with_damaged_actions: true 49 | undistort_images: true 50 | 51 | scale_bbox: 1.0 52 | 53 | shuffle: true 54 | randomize_n_views: false 55 | min_n_views: null 56 | max_n_views: null 57 | num_workers: 8 58 | 59 | val: 60 | h36m_root: "./data/human36m/processed/" 61 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy" 62 | with_damaged_actions: true 63 | undistort_images: true 64 | 65 | scale_bbox: 1.0 66 | 67 | shuffle: false 68 | randomize_n_views: false 69 | min_n_views: null 70 | max_n_views: null 71 | num_workers: 8 72 | 73 | retain_every_n_frames_in_test: 1 74 | -------------------------------------------------------------------------------- /experiments/human36m/train/human36m_vol_softmax.yaml: -------------------------------------------------------------------------------- 1 | title: "human36m_vol_softmax" 2 | kind: "human36m" 3 | vis_freq: 1000 4 | vis_n_elements: 10 5 | 6 | image_shape: [384, 384] 7 | 8 | opt: 9 | criterion: "MAE" 10 | 11 | use_volumetric_ce_loss: true 12 | volumetric_ce_loss_weight: 0.01 13 | 
14 | n_objects_per_epoch: 15000 15 | n_epochs: 9999 16 | 17 | batch_size: 5 18 | val_batch_size: 10 19 | 20 | lr: 0.0001 21 | process_features_lr: 0.001 22 | volume_net_lr: 0.001 23 | 24 | scale_keypoints_3d: 0.1 25 | 26 | model: 27 | name: "vol" 28 | kind: "mpii" 29 | volume_aggregation_method: "softmax" 30 | 31 | init_weights: false 32 | checkpoint: "" 33 | 34 | use_gt_pelvis: false 35 | 36 | cuboid_side: 2500.0 37 | 38 | volume_size: 64 39 | volume_multiplier: 1.0 40 | volume_softmax: true 41 | 42 | heatmap_softmax: true 43 | heatmap_multiplier: 100.0 44 | 45 | backbone: 46 | name: "resnet152" 47 | style: "simple" 48 | 49 | init_weights: true 50 | checkpoint: "./data/pretrained/human36m/pose_resnet_4.5_pixels_human36m.pth" 51 | 52 | num_joints: 17 53 | num_layers: 152 54 | 55 | dataset: 56 | kind: "human36m" 57 | 58 | train: 59 | h36m_root: "./data/human36m/processed/" 60 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy" 61 | pred_results_path: "./data/pretrained/human36m/human36m_alg_10-04-2019/checkpoints/0060/results/train.pkl" 62 | 63 | with_damaged_actions: true 64 | undistort_images: true 65 | 66 | scale_bbox: 1.0 67 | 68 | shuffle: true 69 | randomize_n_views: false 70 | min_n_views: null 71 | max_n_views: null 72 | num_workers: 5 73 | 74 | val: 75 | h36m_root: "./data/human36m/processed/" 76 | labels_path: "./data/human36m/extra/human36m-multiview-labels-GTbboxes.npy" 77 | pred_results_path: "./data/pretrained/human36m/human36m_alg_10-04-2019/checkpoints/0060/results/val.pkl" 78 | 79 | with_damaged_actions: true 80 | undistort_images: true 81 | 82 | scale_bbox: 1.0 83 | 84 | shuffle: false 85 | randomize_n_views: false 86 | min_n_views: null 87 | max_n_views: null 88 | num_workers: 10 89 | 90 | retain_every_n_frames_in_test: 1 91 | -------------------------------------------------------------------------------- /mvn/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | -------------------------------------------------------------------------------- /mvn/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | -------------------------------------------------------------------------------- /mvn/datasets/human36m.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | import pickle 4 | 5 | import numpy as np 6 | import cv2 7 | 8 | import torch 9 | from torch.utils.data import Dataset 10 | 11 | from mvn.utils.multiview import Camera 12 | from mvn.utils.img import get_square_bbox, resize_image, crop_image, normalize_image, scale_bbox 13 | from mvn.utils import volumetric 14 | 15 | 16 | class Human36MMultiViewDataset(Dataset): 17 | """ 18 | Human3.6M for multiview tasks. 
19 | """ 20 | def __init__(self, 21 | h36m_root='/Vol1/dbstore/datasets/Human3.6M/processed/', 22 | labels_path='/Vol1/dbstore/datasets/Human3.6M/extra/human36m-multiview-labels-SSDbboxes.npy', 23 | pred_results_path=None, 24 | image_shape=(256, 256), 25 | train=False, 26 | test=False, 27 | retain_every_n_frames_in_test=1, 28 | with_damaged_actions=False, 29 | cuboid_side=2000.0, 30 | scale_bbox=1.5, 31 | norm_image=True, 32 | kind="mpii", 33 | undistort_images=False, 34 | ignore_cameras=[], 35 | crop=True 36 | ): 37 | """ 38 | h36m_root: 39 | Path to 'processed/' directory in Human3.6M 40 | labels_path: 41 | Path to 'human36m-multiview-labels.npy' generated by 'generate-labels-npy-multiview.py' 42 | from https://github.sec.samsung.net/RRU8-VIOLET/human36m-preprocessing 43 | retain_every_n_frames_in_test: 44 | By default, there are 159 181 frames in training set and 26 634 in test (val) set. 45 | With this parameter, test set frames will be evenly skipped frames so that the 46 | test set size is `26634 // retain_every_n_frames_test`. 47 | Use a value of 13 to get 2049 frames in test set. 48 | with_damaged_actions: 49 | If `True`, will include 'S9/[Greeting-2,SittingDown-2,Waiting-1]' in test set. 50 | kind: 51 | Keypoint format, 'mpii' or 'human36m' 52 | ignore_cameras: 53 | A list with indices of cameras to exclude (0 to 3 inclusive) 54 | """ 55 | assert train or test, '`Human36MMultiViewDataset` must be constructed with at least ' \ 56 | 'one of `test=True` / `train=True`' 57 | assert kind in ("mpii", "human36m") 58 | 59 | self.h36m_root = h36m_root 60 | self.labels_path = labels_path 61 | self.image_shape = None if image_shape is None else tuple(image_shape) 62 | self.scale_bbox = scale_bbox 63 | self.norm_image = norm_image 64 | self.cuboid_side = cuboid_side 65 | self.kind = kind 66 | self.undistort_images = undistort_images 67 | self.ignore_cameras = ignore_cameras 68 | self.crop = crop 69 | 70 | self.labels = np.load(labels_path, allow_pickle=True).item() 71 | 72 | n_cameras = len(self.labels['camera_names']) 73 | assert all(camera_idx in range(n_cameras) for camera_idx in self.ignore_cameras) 74 | 75 | train_subjects = ['S1', 'S5', 'S6', 'S7', 'S8'] 76 | test_subjects = ['S9', 'S11'] 77 | 78 | train_subjects = list(self.labels['subject_names'].index(x) for x in train_subjects) 79 | test_subjects = list(self.labels['subject_names'].index(x) for x in test_subjects) 80 | 81 | indices = [] 82 | if train: 83 | mask = np.isin(self.labels['table']['subject_idx'], train_subjects, assume_unique=True) 84 | indices.append(np.nonzero(mask)[0]) 85 | if test: 86 | mask = np.isin(self.labels['table']['subject_idx'], test_subjects, assume_unique=True) 87 | 88 | if not with_damaged_actions: 89 | mask_S9 = self.labels['table']['subject_idx'] == self.labels['subject_names'].index('S9') 90 | 91 | damaged_actions = 'Greeting-2', 'SittingDown-2', 'Waiting-1' 92 | damaged_actions = [self.labels['action_names'].index(x) for x in damaged_actions] 93 | mask_damaged_actions = np.isin(self.labels['table']['action_idx'], damaged_actions) 94 | 95 | mask &= ~(mask_S9 & mask_damaged_actions) 96 | 97 | indices.append(np.nonzero(mask)[0][::retain_every_n_frames_in_test]) 98 | 99 | self.labels['table'] = self.labels['table'][np.concatenate(indices)] 100 | 101 | self.num_keypoints = 16 if kind == "mpii" else 17 102 | assert self.labels['table']['keypoints'].shape[1] == 17, "Use a newer 'labels' file" 103 | 104 | self.keypoints_3d_pred = None 105 | if pred_results_path is not None: 106 | pred_results = 
np.load(pred_results_path, allow_pickle=True) 107 | keypoints_3d_pred = pred_results['keypoints_3d'][np.argsort(pred_results['indexes'])] 108 | self.keypoints_3d_pred = keypoints_3d_pred[::retain_every_n_frames_in_test] 109 | assert len(self.keypoints_3d_pred) == len(self), \ 110 | f"[train={train}, test={test}] {labels_path} has {len(self)} samples, but '{pred_results_path}' " + \ 111 | f"has {len(self.keypoints_3d_pred)}. Did you follow all preprocessing instructions carefully?" 112 | 113 | def __len__(self): 114 | return len(self.labels['table']) 115 | 116 | def __getitem__(self, idx): 117 | sample = defaultdict(list) # return value 118 | shot = self.labels['table'][idx] 119 | 120 | subject = self.labels['subject_names'][shot['subject_idx']] 121 | action = self.labels['action_names'][shot['action_idx']] 122 | frame_idx = shot['frame_idx'] 123 | 124 | for camera_idx, camera_name in enumerate(self.labels['camera_names']): 125 | if camera_idx in self.ignore_cameras: 126 | continue 127 | 128 | # load bounding box 129 | bbox = shot['bbox_by_camera_tlbr'][camera_idx][[1,0,3,2]] # TLBR to LTRB 130 | bbox_height = bbox[2] - bbox[0] 131 | if bbox_height == 0: 132 | # convention: if the bbox is empty, then this view is missing 133 | continue 134 | 135 | # scale the bounding box 136 | bbox = scale_bbox(bbox, self.scale_bbox) 137 | 138 | # load image 139 | image_path = os.path.join( 140 | self.h36m_root, subject, action, 'imageSequence' + '-undistorted' * self.undistort_images, 141 | camera_name, 'img_%06d.jpg' % (frame_idx+1)) 142 | assert os.path.isfile(image_path), '%s doesn\'t exist' % image_path 143 | image = cv2.imread(image_path) 144 | 145 | # load camera 146 | shot_camera = self.labels['cameras'][shot['subject_idx'], camera_idx] 147 | retval_camera = Camera(shot_camera['R'], shot_camera['t'], shot_camera['K'], shot_camera['dist'], camera_name) 148 | 149 | if self.crop: 150 | # crop image 151 | image = crop_image(image, bbox) 152 | retval_camera.update_after_crop(bbox) 153 | 154 | if self.image_shape is not None: 155 | # resize 156 | image_shape_before_resize = image.shape[:2] 157 | image = resize_image(image, self.image_shape) 158 | retval_camera.update_after_resize(image_shape_before_resize, self.image_shape) 159 | 160 | sample['image_shapes_before_resize'].append(image_shape_before_resize) 161 | 162 | if self.norm_image: 163 | image = normalize_image(image) 164 | 165 | sample['images'].append(image) 166 | sample['detections'].append(bbox + (1.0,)) # TODO add real confidences 167 | sample['cameras'].append(retval_camera) 168 | sample['proj_matrices'].append(retval_camera.projection) 169 | 170 | # 3D keypoints 171 | # add dummy confidences 172 | sample['keypoints_3d'] = np.pad( 173 | shot['keypoints'][:self.num_keypoints], 174 | ((0,0), (0,1)), 'constant', constant_values=1.0) 175 | 176 | # build cuboid 177 | # base_point = sample['keypoints_3d'][6, :3] 178 | # sides = np.array([self.cuboid_side, self.cuboid_side, self.cuboid_side]) 179 | # position = base_point - sides / 2 180 | # sample['cuboids'] = volumetric.Cuboid3D(position, sides) 181 | 182 | # save sample's index 183 | sample['indexes'] = idx 184 | 185 | if self.keypoints_3d_pred is not None: 186 | sample['pred_keypoints_3d'] = self.keypoints_3d_pred[idx] 187 | 188 | sample.default_factory = None 189 | return sample 190 | 191 | def evaluate_using_per_pose_error(self, per_pose_error, split_by_subject): 192 | def evaluate_by_actions(self, per_pose_error, mask=None): 193 | if mask is None: 194 | mask = np.ones_like(per_pose_error, 
dtype=bool) 195 | 196 | action_scores = { 197 | 'Average': {'total_loss': per_pose_error[mask].sum(), 'frame_count': np.count_nonzero(mask)} 198 | } 199 | 200 | for action_idx in range(len(self.labels['action_names'])): 201 | action_mask = (self.labels['table']['action_idx'] == action_idx) & mask 202 | action_per_pose_error = per_pose_error[action_mask] 203 | action_scores[self.labels['action_names'][action_idx]] = { 204 | 'total_loss': action_per_pose_error.sum(), 'frame_count': len(action_per_pose_error) 205 | } 206 | 207 | action_names_without_trials = \ 208 | [name[:-2] for name in self.labels['action_names'] if name.endswith('-1')] 209 | 210 | for action_name_without_trial in action_names_without_trials: 211 | combined_score = {'total_loss': 0.0, 'frame_count': 0} 212 | 213 | for trial in 1, 2: 214 | action_name = '%s-%d' % (action_name_without_trial, trial) 215 | combined_score['total_loss' ] += action_scores[action_name]['total_loss'] 216 | combined_score['frame_count'] += action_scores[action_name]['frame_count'] 217 | del action_scores[action_name] 218 | 219 | action_scores[action_name_without_trial] = combined_score 220 | 221 | for k, v in action_scores.items(): 222 | action_scores[k] = float('nan') if v['frame_count'] == 0 else (v['total_loss'] / v['frame_count']) 223 | 224 | return action_scores 225 | 226 | subject_scores = { 227 | 'Average': evaluate_by_actions(self, per_pose_error) 228 | } 229 | 230 | for subject_idx in range(len(self.labels['subject_names'])): 231 | subject_mask = self.labels['table']['subject_idx'] == subject_idx 232 | subject_scores[self.labels['subject_names'][subject_idx]] = \ 233 | evaluate_by_actions(self, per_pose_error, subject_mask) 234 | 235 | return subject_scores 236 | 237 | def evaluate(self, keypoints_3d_predicted, split_by_subject=False, transfer_cmu_to_human36m=False, transfer_human36m_to_human36m=False): 238 | keypoints_gt = self.labels['table']['keypoints'][:, :self.num_keypoints] 239 | if keypoints_3d_predicted.shape != keypoints_gt.shape: 240 | raise ValueError( 241 | '`keypoints_3d_predicted` shape should be %s, got %s' % \ 242 | (keypoints_gt.shape, keypoints_3d_predicted.shape)) 243 | 244 | if transfer_cmu_to_human36m or transfer_human36m_to_human36m: 245 | human36m_joints = [10, 11, 15, 14, 1, 4] 246 | if transfer_human36m_to_human36m: 247 | cmu_joints = [10, 11, 15, 14, 1, 4] 248 | else: 249 | cmu_joints = [10, 8, 9, 7, 14, 13] 250 | 251 | keypoints_gt = keypoints_gt[:, human36m_joints] 252 | keypoints_3d_predicted = keypoints_3d_predicted[:, cmu_joints] 253 | 254 | # mean error per 16/17 joints in mm, for each pose 255 | per_pose_error = np.sqrt(((keypoints_gt - keypoints_3d_predicted) ** 2).sum(2)).mean(1) 256 | 257 | # relative mean error per 16/17 joints in mm, for each pose 258 | if not (transfer_cmu_to_human36m or transfer_human36m_to_human36m): 259 | root_index = 6 if self.kind == "mpii" else 6 260 | else: 261 | root_index = 0 262 | 263 | keypoints_gt_relative = keypoints_gt - keypoints_gt[:, root_index:root_index + 1, :] 264 | keypoints_3d_predicted_relative = keypoints_3d_predicted - keypoints_3d_predicted[:, root_index:root_index + 1, :] 265 | 266 | per_pose_error_relative = np.sqrt(((keypoints_gt_relative - keypoints_3d_predicted_relative) ** 2).sum(2)).mean(1) 267 | 268 | result = { 269 | 'per_pose_error': self.evaluate_using_per_pose_error(per_pose_error, split_by_subject), 270 | 'per_pose_error_relative': self.evaluate_using_per_pose_error(per_pose_error_relative, split_by_subject) 271 | } 272 | 273 | return 
result['per_pose_error_relative']['Average']['Average'], result 274 | -------------------------------------------------------------------------------- /mvn/datasets/human36m_preprocessing/README.md: -------------------------------------------------------------------------------- 1 | Human3.6M preprocessing scripts 2 | ======= 3 | 4 | These scripts help preprocess Human3.6M dataset so that it can be used with `class Human36MMultiViewDataset`. 5 | 6 | Here is how we do it (brace yourselves): 7 | 8 | 0. Make sure you have a lot (around 200 GiB?) of free disk space. Otherwise, be prepared to always carefully delete intermediate files (e.g. after you extract movies, delete the archives). 9 | 10 | 1. Allocate a folder for the dataset. Make it accessible as `$THIS_REPOSITORY/data/human36m/`, i.e. either 11 | 12 | * store your data `$SOMEWHERE_ELSE` and make a soft symbolic link: 13 | ```bash 14 | mkdir $THIS_REPOSITORY/data 15 | ln -s $SOMEWHERE_ELSE $THIS_REPOSITORY/data/human36m 16 | ``` 17 | * or just store the dataset along with the code at `$THIS_REPOSITORY/data/human36m/`. 18 | 19 | 1. Clone [this toolbox](https://github.com/anibali/h36m-fetch). Follow their manual to download, extract and unpack Human3.6M into image files. Move the result to `$THIS_REPOSITORY/data/human36m`. 20 | 21 | After that, you should have images unpacked as e.g. `$THIS_REPOSITORY/data/human36m/processed/S1/Phoning-1/imageSequence/54138969/img_000001.jpg`. 22 | 23 | 2. Additionally, if you want to use ground truth bounding boxes for training, download them as well (the website calls them *"Segments BBoxes MAT"*) and unpack them like so: `"$THIS_REPOSITORY/data/human36m/processed/S1/MySegmentsMat/ground_truth_bb/Phoning 1.58860488.mat"`. 24 | 25 | 3. Convert those bounding boxes into sane format. This will create `$THIS_REPOSITORY/data/human36m/extra/bboxes-Human36M-GT.npy`: 26 | 27 | ```bash 28 | cd $THIS_REPOSITORY/mvn/datasets/human36m_preprocessing 29 | # in our environment, this took around 6 minutes with 40 processes 30 | python3 collect-bboxes.py $THIS_REPOSITORY/data/human36m 31 | ``` 32 | 33 | 4. Existing 3D keypoint positions and camera intrinsics are difficult to decipher, so initially we used the converted ones [from Julieta Martinez](https://github.com/una-dinosauria/3d-pose-baseline/): 34 | 35 | ```bash 36 | mkdir -p $THIS_REPOSITORY/data/human36m/extra/una-dinosauria-data 37 | cd $THIS_REPOSITORY/data/human36m/extra/una-dinosauria-data 38 | ``` 39 | 40 | Download `h36m.zip` from [Google Drive](https://disk.yandex.ru/d/3gPRFzLSFpS27Q) and uzip it to current directory: 41 | ```bash 42 | unzip h36m.zip 43 | cd - 44 | ``` 45 | 46 | 5. Wrap the 3D keypoint positions, bounding boxes and camera intrinsics together. This will create `$THIS_REPOSITORY/data/human36m/extra/human36m-multiview-labels-GTbboxes.npy`: 47 | 48 | ```bash 49 | python3 generate-labels-npy-multiview.py $THIS_REPOSITORY/data/human36m $THIS_REPOSITORY/data/human36m/extra/una-dinosauria-data/h36m $THIS_REPOSITORY/data/human36m/extra/bboxes-Human36M-GT.npy 50 | ``` 51 | 52 | You should see only one warning saying `camera 54138969 isn't present in S11/Directions-2`. That's fine. 53 | 54 | Now you can train and evaluate models by setting these config values (already set by default in the example configs): 55 | 56 | ```yaml 57 | dataset: 58 | {train,val}: 59 | h36m_root: "data/human36m/processed/" 60 | labels_path: "data/human36m/extra/human36m-multiview-labels-GTbboxes.npy" 61 | ``` 62 | 63 | 6. 
To use `undistort_images: true`, undistort the images beforehand. This will put undistorted images to e.g. `$THIS_REPOSITORY/data/human36m/processed/S1/Phoning-1/imageSequence-undistorted/54138969/img_000001.jpg`: 64 | 65 | ```bash 66 | # in our environment, this took around 90 minutes with 50 processes 67 | python3 undistort-h36m.py $THIS_REPOSITORY/data/human36m $THIS_REPOSITORY/data/human36m/extra/human36m-multiview-labels-GTbboxes.npy ` 68 | ``` 69 | 70 | *TODO: move undistortion to the dataloader. We can do it on the fly during training.* 71 | 72 | 7. Optionally, you can test if everything went well by viewing frames with skeletons and bounding boxes on a GUI machine: 73 | 74 | ```bash 75 | python3 view-dataset.py $THIS_REPOSITORY/data/human36m/processed $THIS_REPOSITORY/data/human36m/extra/human36m-multiview-labels-GTbboxes.npy 76 | ``` 77 | 78 | You can test different settings by changing dataset constructor parameters in `view-dataset.py`. 79 | -------------------------------------------------------------------------------- /mvn/datasets/human36m_preprocessing/action_to_bbox_filename.py: -------------------------------------------------------------------------------- 1 | action_to_bbox_filename = { 2 | 'S11': { 3 | 'Directions-2': 'Directions', 4 | 'Discussion-2': 'Discussion 2', 5 | 'Eating-2': 'Eating', 6 | 'Greeting-1': 'Greeting 2', 7 | 'Greeting-2': 'Greeting', 8 | 'Phoning-1': 'Phoning 3', 9 | 'Phoning-2': 'Phoning 2', 10 | 'Posing-2': 'Posing', 11 | 'Purchases-2': 'Purchases', 12 | 'Sitting-2': 'Sitting', 13 | 'SittingDown-1': 'SittingDown', 14 | 'SittingDown-2': 'SittingDown 1', 15 | 'Smoking-1': 'Smoking 2', 16 | 'Smoking-2': 'Smoking', 17 | 'TakingPhoto-1': 'Photo 1', 18 | 'TakingPhoto-2': 'Photo', 19 | 'Waiting-2': 'Waiting', 20 | 'Walking-2': 'Walking', 21 | 'WalkingDog-1': 'WalkDog 1', 22 | 'WalkingDog-2': 'WalkDog', 23 | 'WalkingTogether-1': 'WalkTogether 1', 24 | 'WalkingTogether-2': 'WalkTogether' 25 | }, 26 | 'S9': { 27 | 'Directions-2': 'Directions', 28 | 'Discussion-2': 'Discussion 2', 29 | 'Eating-2': 'Eating', 30 | 'Greeting-2': 'Greeting', 31 | 'Phoning-1': 'Phoning 1', 32 | 'Phoning-2': 'Phoning', 33 | 'Posing-2': 'Posing', 34 | 'Purchases-2': 'Purchases', 35 | 'Sitting-2': 'Sitting', 36 | 'SittingDown-1': 'SittingDown', 37 | 'SittingDown-2': 'SittingDown 1', 38 | 'Smoking-2': 'Smoking', 39 | 'TakingPhoto-1': 'Photo 1', 40 | 'TakingPhoto-2': 'Photo', 41 | 'Waiting-2': 'Waiting', 42 | 'Walking-2': 'Walking', 43 | 'WalkingDog-1': 'WalkDog 1', 44 | 'WalkingDog-2': 'WalkDog', 45 | 'WalkingTogether-1': 'WalkTogether 1', 46 | 'WalkingTogether-2': 'WalkTogether' 47 | }, 48 | 'S8': { 49 | 'Directions-2': 'Directions', 50 | 'Discussion-2': 'Discussion', 51 | 'Eating-2': 'Eating', 52 | 'Greeting-2': 'Greeting', 53 | 'Phoning-1': 'Phoning 1', 54 | 'Phoning-2': 'Phoning', 55 | 'Posing-2': 'Posing', 56 | 'Purchases-2': 'Purchases', 57 | 'Sitting-2': 'Sitting', 58 | 'SittingDown-1': 'SittingDown', 59 | 'SittingDown-2': 'SittingDown 1', 60 | 'Smoking-2': 'Smoking', 61 | 'TakingPhoto-1': 'Photo 1', 62 | 'TakingPhoto-2': 'Photo', 63 | 'Waiting-2': 'Waiting', 64 | 'Walking-2': 'Walking', 65 | 'WalkingDog-1': 'WalkDog 1', 66 | 'WalkingDog-2': 'WalkDog', 67 | 'WalkingTogether-1': 'WalkTogether 1', 68 | 'WalkingTogether-2': 'WalkTogether 2' 69 | }, 70 | 'S7': { 71 | 'Directions-2': 'Directions', 72 | 'Discussion-2': 'Discussion', 73 | 'Eating-2': 'Eating', 74 | 'Greeting-2': 'Greeting', 75 | 'Phoning-1': 'Phoning 2', 76 | 'Phoning-2': 'Phoning', 77 | 'Posing-2': 'Posing', 78 | 
'Purchases-2': 'Purchases', 79 | 'Sitting-2': 'Sitting', 80 | 'SittingDown-1': 'SittingDown', 81 | 'SittingDown-2': 'SittingDown 1', 82 | 'Smoking-2': 'Smoking', 83 | 'TakingPhoto-1': 'Photo', 84 | 'TakingPhoto-2': 'Photo 1', 85 | 'WalkingDog-1': 'WalkDog 1', 86 | 'WalkingDog-2': 'WalkDog', 87 | 'WalkingTogether-1': 'WalkTogether 1', 88 | 'WalkingTogether-2': 'WalkTogether' 89 | }, 90 | 'S6': { 91 | 'Directions-2': 'Directions', 92 | 'Discussion-1': 'Discussion 1', 93 | 'Discussion-2': 'Discussion', 94 | 'Eating-2': 'Eating 2', 95 | 'Eating-1': 'Eating 1', 96 | 'Greeting-2': 'Greeting', 97 | 'Phoning-2': 'Phoning', 98 | 'Posing-1': 'Posing 2', 99 | 'Posing-2': 'Posing', 100 | 'Purchases-2': 'Purchases', 101 | 'SittingDown-1': 'SittingDown 1', 102 | 'SittingDown-2': 'SittingDown', 103 | 'Smoking-2': 'Smoking', 104 | 'TakingPhoto-1': 'Photo', 105 | 'TakingPhoto-2': 'Photo 1', 106 | 'Waiting-1': 'Waiting 3', 107 | 'Waiting-2': 'Waiting', 108 | 'Walking-2': 'Walking', 109 | 'WalkingDog-1': 'WalkDog 1', 110 | 'WalkingDog-2': 'WalkDog', 111 | 'WalkingTogether-1': 'WalkTogether 1', 112 | 'WalkingTogether-2': 'WalkTogether' 113 | }, 114 | 'S1': { 115 | 'Discussion-2': 'Discussion', 116 | 'Directions-2': 'Directions', 117 | 'Eating-2': 'Eating', 118 | 'Eating-1': 'Eating 2', 119 | 'Greeting-2': 'Greeting', 120 | 'Phoning-2': 'Phoning', 121 | 'Posing-2': 'Posing', 122 | 'Purchases-2': 'Purchases', 123 | 'SittingDown-1': 'SittingDown 2', 124 | 'SittingDown-2': 'SittingDown', 125 | 'Smoking-2': 'Smoking', 126 | 'TakingPhoto-2': 'TakingPhoto', 127 | 'Waiting-2': 'Waiting', 128 | 'Walking-2': 'Walking', 129 | 'WalkingDog-2': 'WalkingDog', 130 | 'WalkingTogether-1': 'WalkTogether 1', 131 | 'WalkingTogether-2': 'WalkTogether' 132 | }, 133 | 'S5': { 134 | 'Discussion-1': 'Discussion 2', 135 | 'Discussion-2': 'Discussion 3', 136 | 'Eating-2': 'Eating', 137 | 'Eating-1': 'Eating 1', 138 | 'Phoning-2': 'Phoning', 139 | 'Posing-2': 'Posing', 140 | 'Purchases-2': 'Purchases', 141 | 'Sitting-2': 'Sitting', 142 | 'SittingDown-1': 'SittingDown', 143 | 'SittingDown-2': 'SittingDown 1', 144 | 'Smoking-2': 'Smoking', 145 | 'TakingPhoto-1': 'Photo', 146 | 'TakingPhoto-2': 'Photo 2', 147 | 'Waiting-2': 'Waiting 2', 148 | 'Walking-2': 'Walking', 149 | 'WalkingDog-1': 'WalkDog 1', 150 | 'WalkingDog-2': 'WalkDog', 151 | 'WalkingTogether-1': 'WalkTogether 1', 152 | 'WalkingTogether-2': 'WalkTogether' 153 | }, 154 | } 155 | -------------------------------------------------------------------------------- /mvn/datasets/human36m_preprocessing/action_to_una_dinosauria.py: -------------------------------------------------------------------------------- 1 | action_to_una_dinosauria = { 2 | 'S11': { 3 | 'Directions-2': 'Directions', 4 | 'Discussion-2': 'Discussion 2', 5 | 'Eating-2': 'Eating', 6 | 'Greeting-1': 'Greeting 2', 7 | 'Greeting-2': 'Greeting', 8 | 'Phoning-1': 'Phoning 3', 9 | 'Phoning-2': 'Phoning 2', 10 | 'Posing-2': 'Posing', 11 | 'Purchases-2': 'Purchases', 12 | 'Sitting-2': 'Sitting', 13 | 'SittingDown-1': 'SittingDown', 14 | 'SittingDown-2': 'SittingDown 1', 15 | 'Smoking-1': 'Smoking 2', 16 | 'Smoking-2': 'Smoking', 17 | 'TakingPhoto-1': 'Photo 1', 18 | 'TakingPhoto-2': 'Photo', 19 | 'Waiting-2': 'Waiting', 20 | 'Walking-2': 'Walking', 21 | 'WalkingDog-1': 'WalkDog 1', 22 | 'WalkingDog-2': 'WalkDog', 23 | 'WalkingTogether-1': 'WalkTogether 1', 24 | 'WalkingTogether-2': 'WalkTogether' 25 | }, 26 | 'S9': { 27 | 'Directions-2': 'Directions', 28 | 'Discussion-2': 'Discussion 2', 29 | 'Eating-2': 'Eating', 30 | 
'Greeting-2': 'Greeting', 31 | 'Phoning-1': 'Phoning 1', 32 | 'Phoning-2': 'Phoning', 33 | 'Posing-2': 'Posing', 34 | 'Purchases-2': 'Purchases', 35 | 'Sitting-2': 'Sitting', 36 | 'SittingDown-1': 'SittingDown', 37 | 'SittingDown-2': 'SittingDown 1', 38 | 'Smoking-2': 'Smoking', 39 | 'TakingPhoto-1': 'Photo 1', 40 | 'TakingPhoto-2': 'Photo', 41 | 'Waiting-2': 'Waiting', 42 | 'Walking-2': 'Walking', 43 | 'WalkingDog-1': 'WalkDog 1', 44 | 'WalkingDog-2': 'WalkDog', 45 | 'WalkingTogether-1': 'WalkTogether 1', 46 | 'WalkingTogether-2': 'WalkTogether' 47 | }, 48 | 'S8': { 49 | 'Directions-2': 'Directions', 50 | 'Discussion-2': 'Discussion', 51 | 'Eating-2': 'Eating', 52 | 'Greeting-2': 'Greeting', 53 | 'Phoning-1': 'Phoning 1', 54 | 'Phoning-2': 'Phoning', 55 | 'Posing-2': 'Posing', 56 | 'Purchases-2': 'Purchases', 57 | 'Sitting-2': 'Sitting', 58 | 'SittingDown-1': 'SittingDown', 59 | 'SittingDown-2': 'SittingDown 1', 60 | 'Smoking-2': 'Smoking', 61 | 'TakingPhoto-1': 'Photo 1', 62 | 'TakingPhoto-2': 'Photo', 63 | 'Waiting-2': 'Waiting', 64 | 'Walking-2': 'Walking', 65 | 'WalkingDog-1': 'WalkDog 1', 66 | 'WalkingDog-2': 'WalkDog', 67 | 'WalkingTogether-1': 'WalkTogether 1', 68 | 'WalkingTogether-2': 'WalkTogether 2' 69 | }, 70 | 'S7': { 71 | 'Directions-2': 'Directions', 72 | 'Discussion-2': 'Discussion', 73 | 'Eating-2': 'Eating', 74 | 'Greeting-2': 'Greeting', 75 | 'Phoning-1': 'Phoning 2', 76 | 'Phoning-2': 'Phoning', 77 | 'Posing-2': 'Posing', 78 | 'Purchases-2': 'Purchases', 79 | 'Sitting-2': 'Sitting', 80 | 'SittingDown-1': 'SittingDown', 81 | 'SittingDown-2': 'SittingDown 1', 82 | 'Smoking-2': 'Smoking', 83 | 'TakingPhoto-1': 'Photo', 84 | 'TakingPhoto-2': 'Photo 1', 85 | 'WalkingDog-1': 'WalkDog 1', 86 | 'WalkingDog-2': 'WalkDog', 87 | 'WalkingTogether-1': 'WalkTogether 1', 88 | 'WalkingTogether-2': 'WalkTogether' 89 | }, 90 | 'S6': { 91 | 'Directions-2': 'Directions', 92 | 'Discussion-1': 'Discussion 1', 93 | 'Discussion-2': 'Discussion', 94 | 'Eating-2': 'Eating 2', 95 | 'Eating-1': 'Eating 1', 96 | 'Greeting-2': 'Greeting', 97 | 'Phoning-2': 'Phoning', 98 | 'Posing-1': 'Posing 2', 99 | 'Posing-2': 'Posing', 100 | 'Purchases-2': 'Purchases', 101 | 'SittingDown-1': 'SittingDown 1', 102 | 'SittingDown-2': 'SittingDown', 103 | 'Smoking-2': 'Smoking', 104 | 'TakingPhoto-1': 'Photo', 105 | 'TakingPhoto-2': 'Photo 1', 106 | 'Waiting-1': 'Waiting 3', 107 | 'Waiting-2': 'Waiting', 108 | 'Walking-2': 'Walking', 109 | 'WalkingDog-1': 'WalkDog 1', 110 | 'WalkingDog-2': 'WalkDog', 111 | 'WalkingTogether-1': 'WalkTogether 1', 112 | 'WalkingTogether-2': 'WalkTogether' 113 | }, 114 | 'S1': { 115 | 'Discussion-2': 'Discussion', 116 | 'Directions-2': 'Directions', 117 | 'Eating-2': 'Eating', 118 | 'Eating-1': 'Eating 2', 119 | 'Greeting-2': 'Greeting', 120 | 'Phoning-2': 'Phoning', 121 | 'Posing-2': 'Posing', 122 | 'Purchases-2': 'Purchases', 123 | 'SittingDown-1': 'SittingDown 2', 124 | 'SittingDown-2': 'SittingDown', 125 | 'Smoking-2': 'Smoking', 126 | 'TakingPhoto-1': 'Photo 1', 127 | 'TakingPhoto-2': 'Photo', 128 | 'Waiting-2': 'Waiting', 129 | 'Walking-2': 'Walking', 130 | 'WalkingDog-1': 'WalkDog 1', 131 | 'WalkingDog-2': 'WalkDog', 132 | 'WalkingTogether-1': 'WalkTogether 1', 133 | 'WalkingTogether-2': 'WalkTogether' 134 | }, 135 | 'S5': { 136 | 'Discussion-1': 'Discussion 2', 137 | 'Discussion-2': 'Discussion 3', 138 | 'Eating-2': 'Eating', 139 | 'Eating-1': 'Eating 1', 140 | 'Phoning-2': 'Phoning', 141 | 'Posing-2': 'Posing', 142 | 'Purchases-2': 'Purchases', 143 | 'Sitting-2': 'Sitting', 144 
| 'SittingDown-1': 'SittingDown', 145 | 'SittingDown-2': 'SittingDown 1', 146 | 'Smoking-2': 'Smoking', 147 | 'TakingPhoto-1': 'Photo', 148 | 'TakingPhoto-2': 'Photo 2', 149 | 'Waiting-2': 'Waiting 2', 150 | 'Walking-2': 'Walking', 151 | 'WalkingDog-1': 'WalkDog 1', 152 | 'WalkingDog-2': 'WalkDog', 153 | 'WalkingTogether-1': 'WalkTogether 1', 154 | 'WalkingTogether-2': 'WalkTogether' 155 | }, 156 | } 157 | -------------------------------------------------------------------------------- /mvn/datasets/human36m_preprocessing/collect-bboxes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Read bbox *.mat files from Human3.6M and convert them to a single *.npy file. 3 | Example of an original bbox file: 4 | /S1/MySegmentsMat/ground_truth_bb/WalkingDog 1.54138969.mat 5 | 6 | Usage: 7 | python3 collect-bboxes.py 8 | """ 9 | import os, sys 10 | import numpy as np 11 | import h5py 12 | 13 | dataset_root = sys.argv[1] 14 | data_path = os.path.join(dataset_root, "processed") 15 | subjects = [x for x in os.listdir(data_path) if x.startswith('S')] 16 | assert len(subjects) == 7 17 | 18 | destination_dir = os.path.join(dataset_root, "extra") 19 | os.makedirs(destination_dir, exist_ok=True) 20 | destination_file_path = os.path.join(destination_dir, "bboxes-Human36M-GT.npy") 21 | 22 | # Some bbox files do not exist, can be misaligned, damaged etc. 23 | from action_to_bbox_filename import action_to_bbox_filename 24 | 25 | from collections import defaultdict 26 | nesteddict = lambda: defaultdict(nesteddict) 27 | 28 | bboxes_retval = nesteddict() 29 | 30 | def load_bboxes(data_path, subject, action, camera): 31 | print(subject, action, camera) 32 | 33 | def mask_to_bbox(mask): 34 | h_mask = mask.max(0) 35 | w_mask = mask.max(1) 36 | 37 | top = h_mask.argmax() 38 | bottom = len(h_mask) - h_mask[::-1].argmax() 39 | 40 | left = w_mask.argmax() 41 | right = len(w_mask) - w_mask[::-1].argmax() 42 | 43 | return top, left, bottom, right 44 | 45 | try: 46 | try: 47 | corrected_action = action_to_bbox_filename[subject][action] 48 | except KeyError: 49 | corrected_action = action.replace('-', ' ') 50 | 51 | # TODO use pathlib 52 | bboxes_path = os.path.join( 53 | data_path, 54 | subject, 55 | 'MySegmentsMat', 56 | 'ground_truth_bb', 57 | '%s.%s.mat' % (corrected_action, camera)) 58 | 59 | with h5py.File(bboxes_path, 'r') as h5file: 60 | retval = np.empty((len(h5file['Masks']), 4), dtype=np.int32) 61 | 62 | for frame_idx, mask_reference in enumerate(h5file['Masks'][:,0]): 63 | bbox_mask = np.array(h5file[mask_reference]) 64 | retval[frame_idx] = mask_to_bbox(bbox_mask) 65 | 66 | top, left, bottom, right = retval[frame_idx] 67 | if right-left < 2 or bottom-top < 2: 68 | raise Exception(str(bboxes_path) + ' $ ' + str(frame_idx)) 69 | except Exception as ex: 70 | # reraise with path information 71 | raise Exception(str(ex) + '; %s %s %s' % (subject, action, camera)) 72 | 73 | return retval, subject, action, camera 74 | 75 | # retval['S1']['Talking-1']['54534623'].shape = (n_frames, 4) # top, left, bottom, right 76 | def add_result_to_retval(args): 77 | bboxes, subject, action, camera = args 78 | bboxes_retval[subject][action][camera] = bboxes 79 | 80 | import multiprocessing 81 | num_processes = int(sys.argv[2]) 82 | pool = multiprocessing.Pool(num_processes) 83 | async_errors = [] 84 | 85 | for subject in subjects: 86 | subject_path = os.path.join(data_path, subject) 87 | actions = os.listdir(subject_path) 88 | try: 89 | actions.remove('MySegmentsMat') # folder with bbox 
*.mat files 90 | except ValueError: 91 | pass 92 | 93 | for action in actions: 94 | cameras = '54138969', '55011271', '58860488', '60457274' 95 | 96 | for camera in cameras: 97 | async_result = pool.apply_async( 98 | load_bboxes, 99 | args=(data_path, subject, action, camera), 100 | callback=add_result_to_retval) 101 | async_errors.append(async_result) 102 | 103 | pool.close() 104 | pool.join() 105 | 106 | # raise any exceptions from pool's processes 107 | for async_result in async_errors: 108 | async_result.get() 109 | 110 | def freeze_defaultdict(x): 111 | x.default_factory = None 112 | for value in x.values(): 113 | if type(value) is defaultdict: 114 | freeze_defaultdict(value) 115 | 116 | # convert to normal dict 117 | freeze_defaultdict(bboxes_retval) 118 | np.save(destination_file_path, bboxes_retval) 119 | -------------------------------------------------------------------------------- /mvn/datasets/human36m_preprocessing/generate-labels-npy-multiview.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate 'labels.npy' for multiview 'human36m.py' 3 | from https://github.sec.samsung.net/RRU8-VIOLET/multi-view-net/ 4 | 5 | Usage: `python3 generate-labels-npy-multiview.py ` 6 | """ 7 | import os, sys 8 | import numpy as np 9 | import h5py 10 | 11 | # Change this line if you want to use Mask-RCNN or SSD bounding boxes instead of H36M's "ground truth". 12 | BBOXES_SOURCE = 'GT' # or 'MRCNN' or 'SSD' 13 | 14 | retval = { 15 | 'subject_names': ['S1', 'S5', 'S6', 'S7', 'S8', 'S9', 'S11'], 16 | 'camera_names': ['54138969', '55011271', '58860488', '60457274'], 17 | 'action_names': [ 18 | 'Directions-1', 'Directions-2', 19 | 'Discussion-1', 'Discussion-2', 20 | 'Eating-1', 'Eating-2', 21 | 'Greeting-1', 'Greeting-2', 22 | 'Phoning-1', 'Phoning-2', 23 | 'Posing-1', 'Posing-2', 24 | 'Purchases-1', 'Purchases-2', 25 | 'Sitting-1', 'Sitting-2', 26 | 'SittingDown-1', 'SittingDown-2', 27 | 'Smoking-1', 'Smoking-2', 28 | 'TakingPhoto-1', 'TakingPhoto-2', 29 | 'Waiting-1', 'Waiting-2', 30 | 'Walking-1', 'Walking-2', 31 | 'WalkingDog-1', 'WalkingDog-2', 32 | 'WalkingTogether-1', 'WalkingTogether-2'] 33 | } 34 | retval['cameras'] = np.empty( 35 | (len(retval['subject_names']), len(retval['camera_names'])), 36 | dtype=[ 37 | ('R', np.float32, (3,3)), 38 | ('t', np.float32, (3,1)), 39 | ('K', np.float32, (3,3)), 40 | ('dist', np.float32, 5) 41 | ] 42 | ) 43 | 44 | table_dtype = np.dtype([ 45 | ('subject_idx', np.int8), 46 | ('action_idx', np.int8), 47 | ('frame_idx', np.int16), 48 | ('keypoints', np.float32, (17,3)), # roughly MPII format 49 | ('bbox_by_camera_tlbr', np.int16, (len(retval['camera_names']),4)) 50 | ]) 51 | retval['table'] = [] 52 | 53 | h36m_root = sys.argv[1] 54 | 55 | destination_file_path = os.path.join(h36m_root, "extra", f"human36m-multiview-labels-{BBOXES_SOURCE}bboxes.npy") 56 | 57 | una_dinosauria_root = sys.argv[2] 58 | cameras_params = h5py.File(os.path.join(una_dinosauria_root, 'cameras.h5'), 'r') 59 | 60 | # Fill retval['cameras'] 61 | for subject_idx, subject in enumerate(retval['subject_names']): 62 | for camera_idx, camera in enumerate(retval['camera_names']): 63 | assert len(cameras_params[subject.replace('S', 'subject')]) == 4 64 | camera_params = cameras_params[subject.replace('S', 'subject')]['camera%d' % (camera_idx+1)] 65 | camera_retval = retval['cameras'][subject_idx][camera_idx] 66 | 67 | def camera_array_to_name(array): 68 | return ''.join(chr(int(x[0])) for x in array) 69 | assert 
camera_array_to_name(camera_params['Name']) == camera 70 | 71 | camera_retval['R'] = np.array(camera_params['R']).T 72 | camera_retval['t'] = -camera_retval['R'] @ camera_params['T'] 73 | 74 | camera_retval['K'] = 0 75 | camera_retval['K'][:2, 2] = camera_params['c'][:, 0] 76 | camera_retval['K'][0, 0] = camera_params['f'][0] 77 | camera_retval['K'][1, 1] = camera_params['f'][1] 78 | camera_retval['K'][2, 2] = 1.0 79 | 80 | camera_retval['dist'][:2] = camera_params['k'][:2, 0] 81 | camera_retval['dist'][2:4] = camera_params['p'][:, 0] 82 | camera_retval['dist'][4] = camera_params['k'][2, 0] 83 | 84 | # Fill bounding boxes 85 | bboxes = np.load(sys.argv[3], allow_pickle=True).item() 86 | 87 | def square_the_bbox(bbox): 88 | top, left, bottom, right = bbox 89 | width = right - left 90 | height = bottom - top 91 | 92 | if height < width: 93 | center = (top + bottom) * 0.5 94 | top = int(round(center - width * 0.5)) 95 | bottom = top + width 96 | else: 97 | center = (left + right) * 0.5 98 | left = int(round(center - height * 0.5)) 99 | right = left + height 100 | 101 | return top, left, bottom, right 102 | 103 | for subject in bboxes.keys(): 104 | for action in bboxes[subject].keys(): 105 | for camera, bbox_array in bboxes[subject][action].items(): 106 | for frame_idx, bbox in enumerate(bbox_array): 107 | bbox[:] = square_the_bbox(bbox) 108 | 109 | if BBOXES_SOURCE != 'GT': 110 | def replace_gt_bboxes_with_cnn(bboxes_gt, bboxes_detected_path, detections_file_list): 111 | """ 112 | Replace ground truth bounding boxes with boxes from a CNN detector. 113 | """ 114 | with open(bboxes_detected_path, 'r') as f: 115 | import json 116 | bboxes_detected = json.load(f) 117 | 118 | with open(detections_file_list, 'r') as f: 119 | for bbox, filename in zip(bboxes_detected, f): 120 | # parse filename 121 | filename = filename.strip() 122 | filename, frame_idx = filename[:-15], int(filename[-10:-4])-1 123 | filename, camera_name = filename[:-23], filename[-8:] 124 | slash_idx = filename.rfind('/') 125 | filename, action_name = filename[:slash_idx], filename[slash_idx+1:] 126 | subject_name = filename[filename.rfind('/')+1:] 127 | 128 | bbox, _ = bbox[:4], bbox[4] # throw confidence away 129 | bbox = square_the_bbox([bbox[1], bbox[0], bbox[3]+1, bbox[2]+1]) # LTRB to TLBR 130 | bboxes_gt[subject_name][action_name][camera_name][frame_idx] = bbox 131 | 132 | detections_paths = { 133 | 'MRCNN': { 134 | 'train': "/Vol1/dbstore/datasets/Human3.6M/extra/train_human36m_MRCNN.json", 135 | 'test': "/Vol1/dbstore/datasets/Human3.6M/extra/test_human36m_MRCNN.json" 136 | }, 137 | 'SSD': { 138 | 'train': "/Vol1/dbstore/datasets/k.iskakov/share/ssd-detections-train-human36m.json", 139 | 'test': "/Vol1/dbstore/datasets/k.iskakov/share/ssd-detections-human36m.json" 140 | } 141 | } 142 | 143 | replace_gt_bboxes_with_cnn( 144 | bboxes, 145 | detections_paths[BBOXES_SOURCE]['train'], 146 | "/Vol1/dbstore/datasets/Human3.6M/train-images-list.txt") 147 | 148 | replace_gt_bboxes_with_cnn( 149 | bboxes, 150 | detections_paths[BBOXES_SOURCE]['test'], 151 | "/Vol1/dbstore/datasets/Human3.6M/test-images-list.txt") 152 | 153 | # fill retval['table'] 154 | from action_to_una_dinosauria import action_to_una_dinosauria 155 | 156 | for subject_idx, subject in enumerate(retval['subject_names']): 157 | subject_path = os.path.join(h36m_root, "processed", subject) 158 | actions = os.listdir(subject_path) 159 | try: 160 | actions.remove('MySegmentsMat') # folder with bbox *.mat files 161 | except ValueError: 162 | pass 163 | 164 | for 
action_idx, action in enumerate(retval['action_names']): 165 | action_path = os.path.join(subject_path, action, 'imageSequence') 166 | if not os.path.isdir(action_path): 167 | raise FileNotFoundError(action_path) 168 | 169 | for camera_idx, camera in enumerate(retval['camera_names']): 170 | camera_path = os.path.join(action_path, camera) 171 | if os.path.isdir(camera_path): 172 | frame_idxs = sorted([int(name[4:-4])-1 for name in os.listdir(camera_path)]) 173 | assert len(frame_idxs) > 15, 'Too few frames in %s' % camera_path # otherwise WTF 174 | break 175 | else: 176 | raise FileNotFoundError(action_path) 177 | 178 | # 16 joints in MPII order + "Neck/Nose" 179 | valid_joints = (3,2,1,6,7,8,0,12,13,15,27,26,25,17,18,19) + (14,) 180 | with h5py.File(os.path.join(una_dinosauria_root, subject, 'MyPoses', '3D_positions', 181 | '%s.h5' % action_to_una_dinosauria[subject].get(action, action.replace('-', ' '))), 'r') as poses_file: 182 | poses_world = np.array(poses_file['3D_positions']).T.reshape(-1, 32, 3)[frame_idxs][:, valid_joints] 183 | 184 | table_segment = np.empty(len(frame_idxs), dtype=table_dtype) 185 | table_segment['subject_idx'] = subject_idx 186 | table_segment['action_idx'] = action_idx 187 | table_segment['frame_idx'] = frame_idxs 188 | table_segment['keypoints'] = poses_world 189 | table_segment['bbox_by_camera_tlbr'] = 0 # let a (0,0,0,0) bbox mean that this view is missing 190 | 191 | for (camera_idx, camera) in enumerate(retval['camera_names']): 192 | camera_path = os.path.join(action_path, camera) 193 | if not os.path.isdir(camera_path): 194 | print('Warning: camera %s isn\'t present in %s/%s' % (camera, subject, action)) 195 | continue 196 | 197 | for bbox, frame_idx in zip(table_segment['bbox_by_camera_tlbr'], frame_idxs): 198 | bbox[camera_idx] = bboxes[subject][action][camera][frame_idx] 199 | 200 | retval['table'].append(table_segment) 201 | 202 | retval['table'] = np.concatenate(retval['table']) 203 | assert retval['table'].ndim == 1 204 | 205 | print("Total frames in Human3.6Million:", len(retval['table'])) 206 | np.save(destination_file_path, retval) 207 | -------------------------------------------------------------------------------- /mvn/datasets/human36m_preprocessing/undistort-h36m.py: -------------------------------------------------------------------------------- 1 | """ 2 | Undistort images in Human3.6M and save them alongside (in ".../imageSequence-undistorted/..."). 
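The script builds one distortion map per (subject, camera) pair and reuses it for every frame of that pair: pixel coordinates are normalized with the camera intrinsics, the radial (and tangential) distortion model is applied to the normalized grid, the grid is mapped back to pixel coordinates, and the original frames are then resampled through this distorted grid with cv2.remap(), which yields undistorted images.

A minimal standalone sketch of the same idea, radial terms only, with hypothetical intrinsics and distortion coefficients (the real values come from the camera parameters stored in the labels file loaded below):

    import cv2
    import numpy as np

    h, w = 1000, 1000                                  # hypothetical frame size
    fx, fy, cx, cy = 1145.0, 1144.0, 512.5, 515.5      # hypothetical intrinsics
    k1, k2, k3 = -0.21, 0.24, -0.002                   # hypothetical radial coefficients

    # normalized pixel grid
    x = (np.arange(w, dtype=np.float32) - cx) / fx
    y = (np.arange(h, dtype=np.float32) - cy) / fy
    grid = np.stack(np.meshgrid(x, y), axis=2)         # shape (h, w, 2)

    # apply radial distortion to the grid
    r2 = (grid ** 2).sum(axis=2, keepdims=True)
    grid = grid * (1 + k1 * r2 + k2 * r2 ** 2 + k3 * r2 ** 3)

    # back to pixel coordinates, then sample the source image through the distorted grid
    map_x = (grid[..., 0] * fx + cx).astype(np.float32)
    map_y = (grid[..., 1] * fy + cy).astype(np.float32)
    image = np.zeros((h, w, 3), dtype=np.uint8)        # stand-in for a real frame
    undistorted = cv2.remap(image, map_x, map_y, cv2.INTER_CUBIC)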
3 | 4 | Usage: `python3 undistort-h36m.py ` 5 | """ 6 | import torch 7 | import numpy as np 8 | import cv2 9 | from tqdm import tqdm 10 | 11 | import os, sys 12 | 13 | sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../..")) 14 | from mvn.datasets.human36m import Human36MMultiViewDataset 15 | 16 | h36m_root = os.path.join(sys.argv[1], "processed") 17 | labels_multiview_npy_path = sys.argv[2] 18 | number_of_processes = int(sys.argv[3]) 19 | 20 | dataset = Human36MMultiViewDataset( 21 | h36m_root, 22 | labels_multiview_npy_path, 23 | train=True, # include all possible data 24 | test=True, 25 | image_shape=None, # don't resize 26 | retain_every_n_frames_in_test=1, # yes actually ALL possible data 27 | with_damaged_actions=True, # I said ALL DATA 28 | kind="mpii", 29 | norm_image=False, # don't do unnecessary image processing 30 | crop=False) # don't crop 31 | print("Dataset length:", len(dataset)) 32 | 33 | n_subjects = len(dataset.labels['subject_names']) 34 | n_cameras = len(dataset.labels['camera_names']) 35 | 36 | # First, prepare: compute distorted meshgrids 37 | print("Computing distorted meshgrids") 38 | meshgrids = np.empty((n_subjects, n_cameras), dtype=object) 39 | 40 | for sample_idx in tqdm(range(len(dataset))): 41 | subject_idx = dataset.labels['table']['subject_idx'][sample_idx] 42 | 43 | if not meshgrids[subject_idx].any(): 44 | bboxes = dataset.labels['table']['bbox_by_camera_tlbr'][sample_idx] 45 | 46 | if (bboxes[:, 2] - bboxes[:, 0]).min() > 0: # if == 0, then some camera is missing 47 | sample = dataset[sample_idx] 48 | assert len(sample['images']) == n_cameras 49 | 50 | for camera_idx, (camera, image) in enumerate(zip(sample['cameras'], sample['images'])): 51 | h, w = image.shape[:2] 52 | 53 | fx, fy = camera.K[0, 0], camera.K[1, 1] 54 | cx, cy = camera.K[0, 2], camera.K[1, 2] 55 | 56 | grid_x = (np.arange(w, dtype=np.float32) - cx) / fx 57 | grid_y = (np.arange(h, dtype=np.float32) - cy) / fy 58 | meshgrid = np.stack(np.meshgrid(grid_x, grid_y), axis=2).reshape(-1, 2) 59 | 60 | # distort meshgrid points 61 | k = camera.dist[:3].copy(); k[2] = camera.dist[-1] 62 | p = camera.dist[2:4].copy() 63 | 64 | r2 = meshgrid[:, 0] ** 2 + meshgrid[:, 1] ** 2 65 | radial = meshgrid * (1 + k[0] * r2 + k[1] * r2**2 + k[2] * r2**3).reshape(-1, 1) 66 | tangential_1 = p.reshape(1, 2) * np.broadcast_to(meshgrid[:, 0:1] * meshgrid[:, 1:2], (len(meshgrid), 2)) 67 | tangential_2 = p[::-1].reshape(1, 2) * (meshgrid**2 + np.broadcast_to(r2.reshape(-1, 1), (len(meshgrid), 2))) 68 | 69 | meshgrid = radial + tangential_1 + tangential_2 70 | 71 | # move back to screen coordinates 72 | meshgrid *= np.array([fx, fy]).reshape(1, 2) 73 | meshgrid += np.array([cx, cy]).reshape(1, 2) 74 | 75 | # cache (save) distortion maps 76 | meshgrids[subject_idx, camera_idx] = cv2.convertMaps(meshgrid.reshape((h, w, 2)), None, cv2.CV_16SC2) 77 | 78 | # Now the main part: undistort images 79 | def undistort_and_save(idx): 80 | sample = dataset[idx] 81 | 82 | shot = dataset.labels['table'][idx] 83 | subject_idx = shot['subject_idx'] 84 | action_idx = shot['action_idx'] 85 | frame_idx = shot['frame_idx'] 86 | 87 | subject = dataset.labels['subject_names'][subject_idx] 88 | action = dataset.labels['action_names'][action_idx] 89 | 90 | available_cameras = list(range(len(dataset.labels['action_names']))) 91 | for camera_idx, bbox in enumerate(shot['bbox_by_camera_tlbr']): 92 | if bbox[2] == bbox[0]: # bbox is empty, which means that this camera is missing 93 | 
available_cameras.remove(camera_idx) 94 | 95 | for camera_idx, image in zip(available_cameras, sample['images']): 96 | camera_name = dataset.labels['camera_names'][camera_idx] 97 | 98 | output_image_folder = os.path.join( 99 | h36m_root, subject, action, 'imageSequence-undistorted', camera_name) 100 | output_image_path = os.path.join(output_image_folder, 'img_%06d.jpg' % (frame_idx+1)) 101 | os.makedirs(output_image_folder, exist_ok=True) 102 | 103 | meshgrid_int16 = meshgrids[subject_idx, camera_idx] 104 | image_undistorted = cv2.remap(image, *meshgrid_int16, cv2.INTER_CUBIC) 105 | 106 | cv2.imwrite(output_image_path, image_undistorted) 107 | 108 | print(f"Undistorting images using {number_of_processes} parallel processes") 109 | cv2.setNumThreads(1) 110 | import multiprocessing 111 | 112 | pool = multiprocessing.Pool(number_of_processes) 113 | for _ in tqdm(pool.imap_unordered( 114 | undistort_and_save, range(len(dataset)), chunksize=10), total=len(dataset)): 115 | pass 116 | 117 | pool.close() 118 | pool.join() 119 | -------------------------------------------------------------------------------- /mvn/datasets/human36m_preprocessing/view-dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | A GUI script for inspecting Human3.6M and its wrappers, namely: 3 | - `Human36MMultiViewDataset` 4 | - `human36m-multiview-labels-**bboxes.npy` 5 | 6 | Usage: `python3 view-dataset.py 7 | """ 8 | import torch 9 | import numpy as np 10 | import cv2 11 | 12 | import os, sys 13 | 14 | h36m_root = sys.argv[1] 15 | labels_path = sys.argv[2] 16 | 17 | try: sample_idx = int(sys.argv[3]) 18 | except: sample_idx = 0 19 | 20 | try: step = int(sys.argv[4]) 21 | except: step = 10 22 | 23 | sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../..")) 24 | from mvn.datasets.human36m import Human36MMultiViewDataset 25 | 26 | dataset = Human36MMultiViewDataset( 27 | h36m_root, 28 | labels_path, 29 | train=True, 30 | test=True, 31 | image_shape=(512,512), 32 | retain_every_n_frames_in_test=1, 33 | with_damaged_actions=True, 34 | scale_bbox=1.0, 35 | kind='human36m', 36 | norm_image=False, 37 | undistort_images=True, 38 | ignore_cameras=[]) 39 | print(len(dataset)) 40 | 41 | prev_action = None 42 | patience = 0 43 | 44 | while True: 45 | sample = dataset[sample_idx] 46 | 47 | camera_idx = 0 48 | image = sample['images'][camera_idx] 49 | camera = sample['cameras'][camera_idx] 50 | 51 | display = image.copy() 52 | 53 | from mvn.utils.multiview import project_3d_points_to_image_plane_without_distortion as project 54 | keypoints_2d = project(camera.projection, sample['keypoints_3d'][:, :3]) 55 | 56 | for i,(x,y) in enumerate(keypoints_2d): 57 | cv2.circle(display, (int(x), int(y)), 3, (0,0,255), -1) 58 | # cv2.putText(display, str(i), (int(x)+3, int(y)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255)) 59 | 60 | # Get window name 61 | sample_info = dataset.labels['table'][sample_idx] 62 | subject_name = dataset.labels['subject_names'][sample_info['subject_idx']] 63 | action_name = dataset.labels['action_names'][sample_info['action_idx']] 64 | camera_name = dataset.labels['camera_names'][camera_idx] 65 | frame_idx = sample_info['frame_idx'] 66 | 67 | cv2.imshow('w', display) 68 | cv2.setWindowTitle('w', f"{subject_name}/{action_name}/{camera_name}/{frame_idx}") 69 | cv2.waitKey(0) 70 | 71 | action = dataset.labels['table'][sample_idx]['action_idx'] 72 | if action != prev_action: # started a new action 73 | prev_action = action 74 | patience = 2000 75 | 
sample_idx += step 76 | elif patience == 0: # an action ended, jump to the start of a new action 77 | while True: 78 | sample_idx += step 79 | action = dataset.labels['table'][sample_idx]['action_idx'] 80 | if action != prev_action: 81 | break 82 | else: # in progress, just increment sample_idx 83 | patience -= 1 84 | sample_idx += step 85 | -------------------------------------------------------------------------------- /mvn/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from mvn.utils.img import image_batch_to_torch 5 | 6 | def make_collate_fn(randomize_n_views=True, min_n_views=10, max_n_views=31): 7 | 8 | def collate_fn(items): 9 | items = list(filter(lambda x: x is not None, items)) 10 | if len(items) == 0: 11 | print("All items in batch are None") 12 | return None 13 | 14 | batch = dict() 15 | total_n_views = min(len(item['images']) for item in items) 16 | 17 | indexes = np.arange(total_n_views) 18 | if randomize_n_views: 19 | n_views = np.random.randint(min_n_views, min(total_n_views, max_n_views) + 1) 20 | indexes = np.random.choice(np.arange(total_n_views), size=n_views, replace=False) 21 | else: 22 | indexes = np.arange(total_n_views) 23 | 24 | batch['images'] = np.stack([np.stack([item['images'][i] for item in items], axis=0) for i in indexes], axis=0).swapaxes(0, 1) 25 | batch['detections'] = np.array([[item['detections'][i] for item in items] for i in indexes]).swapaxes(0, 1) 26 | batch['cameras'] = [[item['cameras'][i] for item in items] for i in indexes] 27 | 28 | batch['keypoints_3d'] = [item['keypoints_3d'] for item in items] 29 | # batch['cuboids'] = [item['cuboids'] for item in items] 30 | batch['indexes'] = [item['indexes'] for item in items] 31 | 32 | try: 33 | batch['pred_keypoints_3d'] = np.array([item['pred_keypoints_3d'] for item in items]) 34 | except: 35 | pass 36 | 37 | return batch 38 | 39 | return collate_fn 40 | 41 | 42 | def worker_init_fn(worker_id): 43 | np.random.seed(np.random.get_state()[1][0] + worker_id) 44 | 45 | def prepare_batch(batch, device, config, is_train=True): 46 | # images 47 | images_batch = [] 48 | for image_batch in batch['images']: 49 | image_batch = image_batch_to_torch(image_batch) 50 | image_batch = image_batch.to(device) 51 | images_batch.append(image_batch) 52 | 53 | images_batch = torch.stack(images_batch, dim=0) 54 | 55 | # 3D keypoints 56 | keypoints_3d_batch_gt = torch.from_numpy(np.stack(batch['keypoints_3d'], axis=0)[:, :, :3]).float().to(device) 57 | 58 | # 3D keypoints validity 59 | keypoints_3d_validity_batch_gt = torch.from_numpy(np.stack(batch['keypoints_3d'], axis=0)[:, :, 3:]).float().to(device) 60 | 61 | # projection matrices 62 | proj_matricies_batch = torch.stack([torch.stack([torch.from_numpy(camera.projection) for camera in camera_batch], dim=0) for camera_batch in batch['cameras']], dim=0).transpose(1, 0) # shape (batch_size, n_views, 3, 4) 63 | proj_matricies_batch = proj_matricies_batch.float().to(device) 64 | 65 | return images_batch, keypoints_3d_batch_gt, keypoints_3d_validity_batch_gt, proj_matricies_batch 66 | -------------------------------------------------------------------------------- /mvn/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | -------------------------------------------------------------------------------- /mvn/models/loss.py:
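The loss module below implements masked keypoint losses: each criterion takes predicted keypoints, ground-truth keypoints and a per-joint binary validity mask, and averages the error only over valid joints. A minimal usage sketch with hypothetical tensor shapes (assuming the repository root is on PYTHONPATH):

    import torch
    from mvn.models.loss import KeypointsMSESmoothLoss

    criterion = KeypointsMSESmoothLoss(threshold=400)

    keypoints_pred = torch.randn(2, 17, 3, requires_grad=True)  # (batch, n_joints, xyz), dummy values
    keypoints_gt = torch.randn(2, 17, 3)
    validity = torch.ones(2, 17, 1)                             # 1 = joint is labeled, 0 = ignore it

    loss = criterion(keypoints_pred, keypoints_gt, validity)    # scalar tensor
    loss.backward()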
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class KeypointsMSELoss(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def forward(self, keypoints_pred, keypoints_gt, keypoints_binary_validity): 12 | dimension = keypoints_pred.shape[-1] 13 | loss = torch.sum((keypoints_gt - keypoints_pred) ** 2 * keypoints_binary_validity) 14 | loss = loss / (dimension * max(1, torch.sum(keypoints_binary_validity).item())) 15 | return loss 16 | 17 | class KeypointsMSESmoothLoss(nn.Module): 18 | def __init__(self, threshold=400): 19 | super().__init__() 20 | 21 | self.threshold = threshold 22 | 23 | def forward(self, keypoints_pred, keypoints_gt, keypoints_binary_validity): 24 | dimension = keypoints_pred.shape[-1] 25 | diff = (keypoints_gt - keypoints_pred) ** 2 * keypoints_binary_validity 26 | diff[diff > self.threshold] = torch.pow(diff[diff > self.threshold], 0.1) * (self.threshold ** 0.9) 27 | loss = torch.sum(diff) / (dimension * max(1, torch.sum(keypoints_binary_validity).item())) 28 | return loss 29 | 30 | 31 | class KeypointsMAELoss(nn.Module): 32 | def __init__(self): 33 | super().__init__() 34 | 35 | def forward(self, keypoints_pred, keypoints_gt, keypoints_binary_validity): 36 | dimension = keypoints_pred.shape[-1] 37 | loss = torch.sum(torch.abs(keypoints_gt - keypoints_pred) * keypoints_binary_validity) 38 | loss = loss / (dimension * max(1, torch.sum(keypoints_binary_validity).item())) 39 | return loss 40 | 41 | 42 | class KeypointsL2Loss(nn.Module): 43 | def __init__(self): 44 | super().__init__() 45 | 46 | def forward(self, keypoints_pred, keypoints_gt, keypoints_binary_validity): 47 | loss = torch.sum(torch.sqrt(torch.sum((keypoints_gt - keypoints_pred) ** 2 * keypoints_binary_validity, dim=2))) 48 | loss = loss / max(1, torch.sum(keypoints_binary_validity).item()) 49 | return loss 50 | 51 | 52 | class VolumetricCELoss(nn.Module): 53 | def __init__(self): 54 | super().__init__() 55 | 56 | def forward(self, coord_volumes_batch, volumes_batch_pred, keypoints_gt, keypoints_binary_validity): 57 | loss = 0.0 58 | n_losses = 0 59 | 60 | batch_size = volumes_batch_pred.shape[0] 61 | for batch_i in range(batch_size): 62 | coord_volume = coord_volumes_batch[batch_i] 63 | keypoints_gt_i = keypoints_gt[batch_i] 64 | 65 | coord_volume_unsq = coord_volume.unsqueeze(0) 66 | keypoints_gt_i_unsq = keypoints_gt_i.unsqueeze(1).unsqueeze(1).unsqueeze(1) 67 | 68 | dists = torch.sqrt(((coord_volume_unsq - keypoints_gt_i_unsq) ** 2).sum(-1)) 69 | dists = dists.view(dists.shape[0], -1) 70 | 71 | min_indexes = torch.argmin(dists, dim=-1).detach().cpu().numpy() 72 | min_indexes = np.stack(np.unravel_index(min_indexes, volumes_batch_pred.shape[-3:]), axis=1) 73 | 74 | for joint_i, index in enumerate(min_indexes): 75 | validity = keypoints_binary_validity[batch_i, joint_i] 76 | loss += validity[0] * (-torch.log(volumes_batch_pred[batch_i, joint_i, index[0], index[1], index[2]] + 1e-6)) 77 | n_losses += 1 78 | 79 | 80 | return loss / n_losses 81 | -------------------------------------------------------------------------------- /mvn/models/pose_resnet.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/microsoft/human-pose-estimation.pytorch 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import os 8 | import logging 9 | 10 | import 
torch 11 | import torch.nn as nn 12 | from collections import OrderedDict 13 | 14 | 15 | BN_MOMENTUM = 0.1 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 65 | padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 67 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, 68 | bias=False) 69 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, 70 | momentum=BN_MOMENTUM) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class Bottleneck_CAFFE(nn.Module): 99 | expansion = 4 100 | 101 | def __init__(self, inplanes, planes, stride=1, downsample=None): 102 | super(Bottleneck_CAFFE, self).__init__() 103 | # add stride to conv1x1 104 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) 105 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 106 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, 107 | padding=1, bias=False) 108 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 109 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, 110 | bias=False) 111 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, 112 | momentum=BN_MOMENTUM) 113 | self.relu = nn.ReLU(inplace=True) 114 | self.downsample = downsample 115 | self.stride = stride 116 | 117 | def forward(self, x): 118 | residual = x 119 | 120 | out = self.conv1(x) 121 | out = self.bn1(out) 122 | out = self.relu(out) 123 | 124 | out = self.conv2(out) 125 | out = self.bn2(out) 126 | out = self.relu(out) 127 | 128 | out = self.conv3(out) 129 | out = self.bn3(out) 130 | 131 | if self.downsample is not None: 132 | residual = self.downsample(x) 133 | 134 | out 
+= residual 135 | out = self.relu(out) 136 | 137 | return out 138 | 139 | 140 | class GlobalAveragePoolingHead(nn.Module): 141 | def __init__(self, in_channels, n_classes): 142 | super().__init__() 143 | 144 | self.features = nn.Sequential( 145 | nn.Conv2d(in_channels, 512, 3, stride=1, padding=1), 146 | nn.BatchNorm2d(512, momentum=BN_MOMENTUM), 147 | nn.MaxPool2d(2), 148 | nn.ReLU(inplace=True), 149 | 150 | nn.Conv2d(512, 256, 3, stride=1, padding=1), 151 | nn.BatchNorm2d(256, momentum=BN_MOMENTUM), 152 | nn.MaxPool2d(2), 153 | nn.ReLU(inplace=True), 154 | ) 155 | 156 | self.head = nn.Sequential( 157 | nn.Linear(256, 512), 158 | nn.ReLU(inplace=True), 159 | nn.Linear(512, 256), 160 | nn.ReLU(inplace=True), 161 | nn.Linear(256, n_classes), 162 | nn.Sigmoid() 163 | ) 164 | 165 | def forward(self, x): 166 | x = self.features(x) 167 | 168 | batch_size, n_channels = x.shape[:2] 169 | x = x.view((batch_size, n_channels, -1)) 170 | x = x.mean(dim=-1) 171 | 172 | out = self.head(x) 173 | 174 | return out 175 | 176 | 177 | resnet_spec = {18: (BasicBlock, [2, 2, 2, 2]), 178 | 34: (BasicBlock, [3, 4, 6, 3]), 179 | 50: (Bottleneck, [3, 4, 6, 3]), 180 | 101: (Bottleneck, [3, 4, 23, 3]), 181 | 152: (Bottleneck, [3, 8, 36, 3])} 182 | 183 | 184 | class PoseResNet(nn.Module): 185 | def __init__(self, block, layers, num_joints, 186 | num_input_channels=3, 187 | deconv_with_bias=False, 188 | num_deconv_layers=3, 189 | num_deconv_filters=(256, 256, 256), 190 | num_deconv_kernels=(4, 4, 4), 191 | final_conv_kernel=1, 192 | alg_confidences=False, 193 | vol_confidences=False 194 | ): 195 | super().__init__() 196 | 197 | self.num_joints = num_joints 198 | self.num_input_channels = num_input_channels 199 | self.inplanes = 64 200 | 201 | self.deconv_with_bias = deconv_with_bias 202 | self.num_deconv_layers, self.num_deconv_filters, self.num_deconv_kernels = num_deconv_layers, num_deconv_filters, num_deconv_kernels 203 | self.final_conv_kernel = final_conv_kernel 204 | 205 | self.conv1 = nn.Conv2d(num_input_channels, 64, kernel_size=7, stride=2, padding=3, 206 | bias=False) 207 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) 208 | self.relu = nn.ReLU(inplace=True) 209 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 210 | self.layer1 = self._make_layer(block, 64, layers[0]) 211 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 212 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 213 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 214 | 215 | if alg_confidences: 216 | self.alg_confidences = GlobalAveragePoolingHead(512 * block.expansion, num_joints) 217 | 218 | if vol_confidences: 219 | self.vol_confidences = GlobalAveragePoolingHead(512 * block.expansion, 32) 220 | 221 | # used for deconv layers 222 | self.deconv_layers = self._make_deconv_layer( 223 | self.num_deconv_layers, 224 | self.num_deconv_filters, 225 | self.num_deconv_kernels, 226 | ) 227 | 228 | self.final_layer = nn.Conv2d( 229 | in_channels=self.num_deconv_filters[-1], 230 | out_channels=self.num_joints, 231 | kernel_size=self.final_conv_kernel, 232 | stride=1, 233 | padding=1 if self.final_conv_kernel == 3 else 0 234 | ) 235 | 236 | def _make_layer(self, block, planes, blocks, stride=1): 237 | downsample = None 238 | if stride != 1 or self.inplanes != planes * block.expansion: 239 | downsample = nn.Sequential( 240 | nn.Conv2d(self.inplanes, planes * block.expansion, 241 | kernel_size=1, stride=stride, bias=False), 242 | nn.BatchNorm2d(planes * block.expansion, 
momentum=BN_MOMENTUM), 243 | ) 244 | 245 | layers = [] 246 | layers.append(block(self.inplanes, planes, stride, downsample)) 247 | self.inplanes = planes * block.expansion 248 | for i in range(1, blocks): 249 | layers.append(block(self.inplanes, planes)) 250 | 251 | return nn.Sequential(*layers) 252 | 253 | def _get_deconv_cfg(self, deconv_kernel, index): 254 | if deconv_kernel == 4: 255 | padding = 1 256 | output_padding = 0 257 | elif deconv_kernel == 3: 258 | padding = 1 259 | output_padding = 1 260 | elif deconv_kernel == 2: 261 | padding = 0 262 | output_padding = 0 263 | 264 | return deconv_kernel, padding, output_padding 265 | 266 | def _make_deconv_layer(self, num_layers, num_filters, num_kernels): 267 | assert num_layers == len(num_filters), \ 268 | 'ERROR: num_deconv_layers is different len(num_deconv_filters)' 269 | assert num_layers == len(num_kernels), \ 270 | 'ERROR: num_deconv_layers is different len(num_deconv_filters)' 271 | 272 | layers = [] 273 | for i in range(num_layers): 274 | kernel, padding, output_padding = \ 275 | self._get_deconv_cfg(num_kernels[i], i) 276 | 277 | planes = num_filters[i] 278 | layers.append( 279 | nn.ConvTranspose2d( 280 | in_channels=self.inplanes, 281 | out_channels=planes, 282 | kernel_size=kernel, 283 | stride=2, 284 | padding=padding, 285 | output_padding=output_padding, 286 | bias=self.deconv_with_bias)) 287 | layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)) 288 | layers.append(nn.ReLU(inplace=True)) 289 | self.inplanes = planes 290 | 291 | return nn.Sequential(*layers) 292 | 293 | def forward(self, x): 294 | x = self.conv1(x) 295 | x = self.bn1(x) 296 | x = self.relu(x) 297 | x = self.maxpool(x) 298 | 299 | x = self.layer1(x) 300 | x = self.layer2(x) 301 | x = self.layer3(x) 302 | x = self.layer4(x) 303 | 304 | alg_confidences = None 305 | if hasattr(self, "alg_confidences"): 306 | alg_confidences = self.alg_confidences(x) 307 | 308 | vol_confidences = None 309 | if hasattr(self, "vol_confidences"): 310 | vol_confidences = self.vol_confidences(x) 311 | 312 | x = self.deconv_layers(x) 313 | features = x 314 | 315 | x = self.final_layer(x) 316 | heatmaps = x 317 | 318 | return heatmaps, features, alg_confidences, vol_confidences 319 | 320 | 321 | def get_pose_net(config, device='cuda:0'): 322 | block_class, layers = resnet_spec[config.num_layers] 323 | if config.style == 'caffe': 324 | block_class = Bottleneck_CAFFE 325 | 326 | model = PoseResNet( 327 | block_class, layers, config.num_joints, 328 | num_input_channels=3, 329 | deconv_with_bias=False, 330 | num_deconv_layers=3, 331 | num_deconv_filters=(256, 256, 256), 332 | num_deconv_kernels=(4, 4, 4), 333 | final_conv_kernel=1, 334 | alg_confidences=config.alg_confidences, 335 | vol_confidences=config.vol_confidences 336 | ) 337 | 338 | if config.init_weights: 339 | print("Loading pretrained weights from: {}".format(config.checkpoint)) 340 | model_state_dict = model.state_dict() 341 | pretrained_state_dict = torch.load(config.checkpoint, map_location=device) 342 | 343 | if 'state_dict' in pretrained_state_dict: 344 | pretrained_state_dict = pretrained_state_dict['state_dict'] 345 | 346 | prefix = "module." 
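# Note on the loading logic below: checkpoints saved from models wrapped in nn.DataParallel have
# every parameter key prefixed with "module.", so that prefix is stripped before matching keys.
# Only tensors whose stripped names and shapes match the freshly built model are copied. The final
# 1x1 prediction layer is handled specially: its weights are re-initialized (Xavier for the weight,
# zeros for the bias) and then overwritten with as many pretrained joint filters as fit, which makes
# it possible to load a checkpoint trained for a different number of joints. Any parameters that
# remain unmatched are reported and kept at their default initialization, since load_state_dict is
# called with strict=False.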
347 | 348 | new_pretrained_state_dict = {} 349 | for k, v in pretrained_state_dict.items(): 350 | if k.replace(prefix, "") in model_state_dict and v.shape == model_state_dict[k.replace(prefix, "")].shape: 351 | new_pretrained_state_dict[k.replace(prefix, "")] = v 352 | elif k.replace(prefix, "") == "final_layer.weight": # TODO 353 | print("Reiniting final layer filters:", k) 354 | 355 | o = torch.zeros_like(model_state_dict[k.replace(prefix, "")][:, :, :, :]) 356 | nn.init.xavier_uniform_(o) 357 | n_filters = min(o.shape[0], v.shape[0]) 358 | o[:n_filters, :, :, :] = v[:n_filters, :, :, :] 359 | 360 | new_pretrained_state_dict[k.replace(prefix, "")] = o 361 | elif k.replace(prefix, "") == "final_layer.bias": 362 | print("Reiniting final layer biases:", k) 363 | o = torch.zeros_like(model_state_dict[k.replace(prefix, "")][:]) 364 | nn.init.zeros_(o) 365 | n_filters = min(o.shape[0], v.shape[0]) 366 | o[:n_filters] = v[:n_filters] 367 | 368 | new_pretrained_state_dict[k.replace(prefix, "")] = o 369 | 370 | not_inited_params = set(map(lambda x: x.replace(prefix, ""), pretrained_state_dict.keys())) - set(new_pretrained_state_dict.keys()) 371 | if len(not_inited_params) > 0: 372 | print("Parameters [{}] were not inited".format(not_inited_params)) 373 | 374 | model.load_state_dict(new_pretrained_state_dict, strict=False) 375 | print("Successfully loaded pretrained weights for backbone") 376 | 377 | return model 378 | -------------------------------------------------------------------------------- /mvn/models/triangulation.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import numpy as np 3 | import pickle 4 | import random 5 | 6 | from scipy.optimize import least_squares 7 | 8 | import torch 9 | from torch import nn 10 | 11 | from mvn.utils import op, multiview, img, misc, volumetric 12 | 13 | from mvn.models import pose_resnet 14 | from mvn.models.v2v import V2VModel 15 | 16 | 17 | class RANSACTriangulationNet(nn.Module): 18 | def __init__(self, config, device='cuda:0'): 19 | super().__init__() 20 | 21 | config.model.backbone.alg_confidences = False 22 | config.model.backbone.vol_confidences = False 23 | self.backbone = pose_resnet.get_pose_net(config.model.backbone, device=device) 24 | 25 | self.direct_optimization = config.model.direct_optimization 26 | 27 | def forward(self, images, proj_matricies, batch): 28 | batch_size, n_views = images.shape[:2] 29 | 30 | # reshape n_views dimension to batch dimension 31 | images = images.view(-1, *images.shape[2:]) 32 | 33 | # forward backbone and integrate 34 | heatmaps, _, _, _ = self.backbone(images) 35 | 36 | # reshape back 37 | images = images.view(batch_size, n_views, *images.shape[1:]) 38 | heatmaps = heatmaps.view(batch_size, n_views, *heatmaps.shape[1:]) 39 | 40 | # calcualte shapes 41 | image_shape = tuple(images.shape[3:]) 42 | batch_size, n_views, n_joints, heatmap_shape = heatmaps.shape[0], heatmaps.shape[1], heatmaps.shape[2], tuple(heatmaps.shape[3:]) 43 | 44 | # keypoints 2d 45 | _, max_indicies = torch.max(heatmaps.view(batch_size, n_views, n_joints, -1), dim=-1) 46 | keypoints_2d = torch.stack([max_indicies % heatmap_shape[1], max_indicies // heatmap_shape[1]], dim=-1).to(images.device) 47 | 48 | # upscale keypoints_2d, because image shape != heatmap shape 49 | keypoints_2d_transformed = torch.zeros_like(keypoints_2d) 50 | keypoints_2d_transformed[:, :, :, 0] = keypoints_2d[:, :, :, 0] * (image_shape[1] / heatmap_shape[1]) 51 | keypoints_2d_transformed[:, :, :, 1] = 
keypoints_2d[:, :, :, 1] * (image_shape[0] / heatmap_shape[0]) 52 | keypoints_2d = keypoints_2d_transformed 53 | 54 | # triangulate (cpu) 55 | keypoints_2d_np = keypoints_2d.detach().cpu().numpy() 56 | proj_matricies_np = proj_matricies.detach().cpu().numpy() 57 | 58 | keypoints_3d = np.zeros((batch_size, n_joints, 3)) 59 | confidences = np.zeros((batch_size, n_views, n_joints)) # plug 60 | for batch_i in range(batch_size): 61 | for joint_i in range(n_joints): 62 | current_proj_matricies = proj_matricies_np[batch_i] 63 | points = keypoints_2d_np[batch_i, :, joint_i] 64 | keypoint_3d, _ = self.triangulate_ransac(current_proj_matricies, points, direct_optimization=self.direct_optimization) 65 | keypoints_3d[batch_i, joint_i] = keypoint_3d 66 | 67 | keypoints_3d = torch.from_numpy(keypoints_3d).type(torch.float).to(images.device) 68 | confidences = torch.from_numpy(confidences).type(torch.float).to(images.device) 69 | 70 | return keypoints_3d, keypoints_2d, heatmaps, confidences 71 | 72 | def triangulate_ransac(self, proj_matricies, points, n_iters=10, reprojection_error_epsilon=15, direct_optimization=True): 73 | assert len(proj_matricies) == len(points) 74 | assert len(points) >= 2 75 | 76 | proj_matricies = np.array(proj_matricies) 77 | points = np.array(points) 78 | 79 | n_views = len(points) 80 | 81 | # determine inliers 82 | view_set = set(range(n_views)) 83 | inlier_set = set() 84 | for i in range(n_iters): 85 | sampled_views = sorted(random.sample(view_set, 2)) 86 | 87 | keypoint_3d_in_base_camera = multiview.triangulate_point_from_multiple_views_linear(proj_matricies[sampled_views], points[sampled_views]) 88 | reprojection_error_vector = multiview.calc_reprojection_error_matrix(np.array([keypoint_3d_in_base_camera]), points, proj_matricies)[0] 89 | 90 | new_inlier_set = set(sampled_views) 91 | for view in view_set: 92 | current_reprojection_error = reprojection_error_vector[view] 93 | if current_reprojection_error < reprojection_error_epsilon: 94 | new_inlier_set.add(view) 95 | 96 | if len(new_inlier_set) > len(inlier_set): 97 | inlier_set = new_inlier_set 98 | 99 | # triangulate using inlier_set 100 | if len(inlier_set) == 0: 101 | inlier_set = view_set.copy() 102 | 103 | inlier_list = np.array(sorted(inlier_set)) 104 | inlier_proj_matricies = proj_matricies[inlier_list] 105 | inlier_points = points[inlier_list] 106 | 107 | keypoint_3d_in_base_camera = multiview.triangulate_point_from_multiple_views_linear(inlier_proj_matricies, inlier_points) 108 | reprojection_error_vector = multiview.calc_reprojection_error_matrix(np.array([keypoint_3d_in_base_camera]), inlier_points, inlier_proj_matricies)[0] 109 | reprojection_error_mean = np.mean(reprojection_error_vector) 110 | 111 | keypoint_3d_in_base_camera_before_direct_optimization = keypoint_3d_in_base_camera 112 | reprojection_error_before_direct_optimization = reprojection_error_mean 113 | 114 | # direct reprojection error minimization 115 | if direct_optimization: 116 | def residual_function(x): 117 | reprojection_error_vector = multiview.calc_reprojection_error_matrix(np.array([x]), inlier_points, inlier_proj_matricies)[0] 118 | residuals = reprojection_error_vector 119 | return residuals 120 | 121 | x_0 = np.array(keypoint_3d_in_base_camera) 122 | res = least_squares(residual_function, x_0, loss='huber', method='trf') 123 | 124 | keypoint_3d_in_base_camera = res.x 125 | reprojection_error_vector = multiview.calc_reprojection_error_matrix(np.array([keypoint_3d_in_base_camera]), inlier_points, inlier_proj_matricies)[0] 126 | 
reprojection_error_mean = np.mean(reprojection_error_vector) 127 | 128 | return keypoint_3d_in_base_camera, inlier_list 129 | 130 | 131 | class AlgebraicTriangulationNet(nn.Module): 132 | def __init__(self, config, device='cuda:0'): 133 | super().__init__() 134 | 135 | self.use_confidences = config.model.use_confidences 136 | 137 | config.model.backbone.alg_confidences = False 138 | config.model.backbone.vol_confidences = False 139 | 140 | if self.use_confidences: 141 | config.model.backbone.alg_confidences = True 142 | 143 | self.backbone = pose_resnet.get_pose_net(config.model.backbone, device=device) 144 | 145 | self.heatmap_softmax = config.model.heatmap_softmax 146 | self.heatmap_multiplier = config.model.heatmap_multiplier 147 | 148 | 149 | def forward(self, images, proj_matricies, batch): 150 | device = images.device 151 | batch_size, n_views = images.shape[:2] 152 | 153 | # reshape n_views dimension to batch dimension 154 | images = images.view(-1, *images.shape[2:]) 155 | 156 | # forward backbone and integral 157 | if self.use_confidences: 158 | heatmaps, _, alg_confidences, _ = self.backbone(images) 159 | else: 160 | heatmaps, _, _, _ = self.backbone(images) 161 | alg_confidences = torch.ones(batch_size * n_views, heatmaps.shape[1]).type(torch.float).to(device) 162 | 163 | heatmaps_before_softmax = heatmaps.view(batch_size, n_views, *heatmaps.shape[1:]) 164 | keypoints_2d, heatmaps = op.integrate_tensor_2d(heatmaps * self.heatmap_multiplier, self.heatmap_softmax) 165 | 166 | # reshape back 167 | images = images.view(batch_size, n_views, *images.shape[1:]) 168 | heatmaps = heatmaps.view(batch_size, n_views, *heatmaps.shape[1:]) 169 | keypoints_2d = keypoints_2d.view(batch_size, n_views, *keypoints_2d.shape[1:]) 170 | alg_confidences = alg_confidences.view(batch_size, n_views, *alg_confidences.shape[1:]) 171 | 172 | # norm confidences 173 | alg_confidences = alg_confidences / alg_confidences.sum(dim=1, keepdim=True) 174 | alg_confidences = alg_confidences + 1e-5 # for numerical stability 175 | 176 | # calculate shapes 177 | image_shape = tuple(images.shape[3:]) 178 | batch_size, n_views, n_joints, heatmap_shape = heatmaps.shape[0], heatmaps.shape[1], heatmaps.shape[2], tuple(heatmaps.shape[3:]) 179 | 180 | # upscale keypoints_2d, because image shape != heatmap shape 181 | keypoints_2d_transformed = torch.zeros_like(keypoints_2d) 182 | keypoints_2d_transformed[:, :, :, 0] = keypoints_2d[:, :, :, 0] * (image_shape[1] / heatmap_shape[1]) 183 | keypoints_2d_transformed[:, :, :, 1] = keypoints_2d[:, :, :, 1] * (image_shape[0] / heatmap_shape[0]) 184 | keypoints_2d = keypoints_2d_transformed 185 | 186 | # triangulate 187 | try: 188 | keypoints_3d = multiview.triangulate_batch_of_points( 189 | proj_matricies, keypoints_2d, 190 | confidences_batch=alg_confidences 191 | ) 192 | except RuntimeError as e: 193 | print("Error: ", e) 194 | 195 | print("confidences =", alg_confidences) 196 | print("proj_matricies = ", proj_matricies) 197 | print("keypoints_2d =", keypoints_2d) 198 | exit() 199 | 200 | return keypoints_3d, keypoints_2d, heatmaps, alg_confidences 201 | 202 | 203 | class VolumetricTriangulationNet(nn.Module): 204 | def __init__(self, config, device='cuda:0'): 205 | super().__init__() 206 | 207 | self.num_joints = config.model.backbone.num_joints 208 | self.volume_aggregation_method = config.model.volume_aggregation_method 209 | 210 | # volume 211 | self.volume_softmax = config.model.volume_softmax 212 | self.volume_multiplier = 
config.model.volume_multiplier 213 | self.volume_size = config.model.volume_size 214 | 215 | self.cuboid_side = config.model.cuboid_side 216 | 217 | self.kind = config.model.kind 218 | self.use_gt_pelvis = config.model.use_gt_pelvis 219 | 220 | # heatmap 221 | self.heatmap_softmax = config.model.heatmap_softmax 222 | self.heatmap_multiplier = config.model.heatmap_multiplier 223 | 224 | # transfer 225 | self.transfer_cmu_to_human36m = config.model.transfer_cmu_to_human36m if hasattr(config.model, "transfer_cmu_to_human36m") else False 226 | 227 | # modules 228 | config.model.backbone.alg_confidences = False 229 | config.model.backbone.vol_confidences = False 230 | if self.volume_aggregation_method.startswith('conf'): 231 | config.model.backbone.vol_confidences = True 232 | 233 | self.backbone = pose_resnet.get_pose_net(config.model.backbone, device=device) 234 | 235 | for p in self.backbone.final_layer.parameters(): 236 | p.requires_grad = False 237 | 238 | self.process_features = nn.Sequential( 239 | nn.Conv2d(256, 32, 1) 240 | ) 241 | 242 | self.volume_net = V2VModel(32, self.num_joints) 243 | 244 | 245 | def forward(self, images, proj_matricies, batch): 246 | device = images.device 247 | batch_size, n_views = images.shape[:2] 248 | 249 | # reshape for backbone forward 250 | images = images.view(-1, *images.shape[2:]) 251 | 252 | # forward backbone 253 | heatmaps, features, _, vol_confidences = self.backbone(images) 254 | 255 | # reshape back 256 | images = images.view(batch_size, n_views, *images.shape[1:]) 257 | heatmaps = heatmaps.view(batch_size, n_views, *heatmaps.shape[1:]) 258 | features = features.view(batch_size, n_views, *features.shape[1:]) 259 | 260 | if vol_confidences is not None: 261 | vol_confidences = vol_confidences.view(batch_size, n_views, *vol_confidences.shape[1:]) 262 | 263 | # calcualte shapes 264 | image_shape, heatmap_shape = tuple(images.shape[3:]), tuple(heatmaps.shape[3:]) 265 | n_joints = heatmaps.shape[2] 266 | 267 | # norm vol confidences 268 | if self.volume_aggregation_method == 'conf_norm': 269 | vol_confidences = vol_confidences / vol_confidences.sum(dim=1, keepdim=True) 270 | 271 | # change camera intrinsics 272 | new_cameras = deepcopy(batch['cameras']) 273 | for view_i in range(n_views): 274 | for batch_i in range(batch_size): 275 | new_cameras[view_i][batch_i].update_after_resize(image_shape, heatmap_shape) 276 | 277 | proj_matricies = torch.stack([torch.stack([torch.from_numpy(camera.projection) for camera in camera_batch], dim=0) for camera_batch in new_cameras], dim=0).transpose(1, 0) # shape (batch_size, n_views, 3, 4) 278 | proj_matricies = proj_matricies.float().to(device) 279 | 280 | # build coord volumes 281 | cuboids = [] 282 | base_points = torch.zeros(batch_size, 3, device=device) 283 | coord_volumes = torch.zeros(batch_size, self.volume_size, self.volume_size, self.volume_size, 3, device=device) 284 | for batch_i in range(batch_size): 285 | # if self.use_precalculated_pelvis: 286 | if self.use_gt_pelvis: 287 | keypoints_3d = batch['keypoints_3d'][batch_i] 288 | else: 289 | keypoints_3d = batch['pred_keypoints_3d'][batch_i] 290 | 291 | if self.kind == "coco": 292 | base_point = (keypoints_3d[11, :3] + keypoints_3d[12, :3]) / 2 293 | elif self.kind == "mpii": 294 | base_point = keypoints_3d[6, :3] 295 | 296 | base_points[batch_i] = torch.from_numpy(base_point).to(device) 297 | 298 | # build cuboid 299 | sides = np.array([self.cuboid_side, self.cuboid_side, self.cuboid_side]) 300 | position = base_point - sides / 2 301 | cuboid = 
volumetric.Cuboid3D(position, sides) 302 | 303 | cuboids.append(cuboid) 304 | 305 | # build coord volume 306 | xxx, yyy, zzz = torch.meshgrid(torch.arange(self.volume_size, device=device), torch.arange(self.volume_size, device=device), torch.arange(self.volume_size, device=device)) 307 | grid = torch.stack([xxx, yyy, zzz], dim=-1).type(torch.float) 308 | grid = grid.reshape((-1, 3)) 309 | 310 | grid_coord = torch.zeros_like(grid) 311 | grid_coord[:, 0] = position[0] + (sides[0] / (self.volume_size - 1)) * grid[:, 0] 312 | grid_coord[:, 1] = position[1] + (sides[1] / (self.volume_size - 1)) * grid[:, 1] 313 | grid_coord[:, 2] = position[2] + (sides[2] / (self.volume_size - 1)) * grid[:, 2] 314 | 315 | coord_volume = grid_coord.reshape(self.volume_size, self.volume_size, self.volume_size, 3) 316 | 317 | # random rotation 318 | if self.training: 319 | theta = np.random.uniform(0.0, 2 * np.pi) 320 | else: 321 | theta = 0.0 322 | 323 | if self.kind == "coco": 324 | axis = [0, 1, 0] # y axis 325 | elif self.kind == "mpii": 326 | axis = [0, 0, 1] # z axis 327 | 328 | center = torch.from_numpy(base_point).type(torch.float).to(device) 329 | 330 | # rotate 331 | coord_volume = coord_volume - center 332 | coord_volume = volumetric.rotate_coord_volume(coord_volume, theta, axis) 333 | coord_volume = coord_volume + center 334 | 335 | # transfer 336 | if self.transfer_cmu_to_human36m: # different world coordinates 337 | coord_volume = coord_volume.permute(0, 2, 1, 3) 338 | inv_idx = torch.arange(coord_volume.shape[1] - 1, -1, -1).long().to(device) 339 | coord_volume = coord_volume.index_select(1, inv_idx) 340 | 341 | coord_volumes[batch_i] = coord_volume 342 | 343 | # process features before unprojecting 344 | features = features.view(-1, *features.shape[2:]) 345 | features = self.process_features(features) 346 | features = features.view(batch_size, n_views, *features.shape[1:]) 347 | 348 | # lift to volume 349 | volumes = op.unproject_heatmaps(features, proj_matricies, coord_volumes, volume_aggregation_method=self.volume_aggregation_method, vol_confidences=vol_confidences) 350 | 351 | # integral 3d 352 | volumes = self.volume_net(volumes) 353 | vol_keypoints_3d, volumes = op.integrate_tensor_3d_with_coordinates(volumes * self.volume_multiplier, coord_volumes, softmax=self.volume_softmax) 354 | 355 | return vol_keypoints_3d, features, volumes, vol_confidences, cuboids, coord_volumes, base_points 356 | -------------------------------------------------------------------------------- /mvn/models/v2v.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/dragonbook/V2V-PoseNet-pytorch 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Basic3DBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, kernel_size): 9 | super(Basic3DBlock, self).__init__() 10 | self.block = nn.Sequential( 11 | nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=1, padding=((kernel_size-1)//2)), 12 | nn.BatchNorm3d(out_planes), 13 | nn.ReLU(True) 14 | ) 15 | 16 | def forward(self, x): 17 | return self.block(x) 18 | 19 | 20 | class Res3DBlock(nn.Module): 21 | def __init__(self, in_planes, out_planes): 22 | super(Res3DBlock, self).__init__() 23 | self.res_branch = nn.Sequential( 24 | nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=1, padding=1), 25 | nn.BatchNorm3d(out_planes), 26 | nn.ReLU(True), 27 | nn.Conv3d(out_planes, out_planes, kernel_size=3, stride=1, padding=1), 28 | nn.BatchNorm3d(out_planes) 29 | 
) 30 | 31 | if in_planes == out_planes: 32 | self.skip_con = nn.Sequential() 33 | else: 34 | self.skip_con = nn.Sequential( 35 | nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=1, padding=0), 36 | nn.BatchNorm3d(out_planes) 37 | ) 38 | 39 | def forward(self, x): 40 | res = self.res_branch(x) 41 | skip = self.skip_con(x) 42 | return F.relu(res + skip, True) 43 | 44 | 45 | class Pool3DBlock(nn.Module): 46 | def __init__(self, pool_size): 47 | super(Pool3DBlock, self).__init__() 48 | self.pool_size = pool_size 49 | 50 | def forward(self, x): 51 | return F.max_pool3d(x, kernel_size=self.pool_size, stride=self.pool_size) 52 | 53 | 54 | class Upsample3DBlock(nn.Module): 55 | def __init__(self, in_planes, out_planes, kernel_size, stride): 56 | super(Upsample3DBlock, self).__init__() 57 | assert(kernel_size == 2) 58 | assert(stride == 2) 59 | self.block = nn.Sequential( 60 | nn.ConvTranspose3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0, output_padding=0), 61 | nn.BatchNorm3d(out_planes), 62 | nn.ReLU(True) 63 | ) 64 | 65 | def forward(self, x): 66 | return self.block(x) 67 | 68 | 69 | class EncoderDecorder(nn.Module): 70 | def __init__(self): 71 | super().__init__() 72 | 73 | self.encoder_pool1 = Pool3DBlock(2) 74 | self.encoder_res1 = Res3DBlock(32, 64) 75 | self.encoder_pool2 = Pool3DBlock(2) 76 | self.encoder_res2 = Res3DBlock(64, 128) 77 | self.encoder_pool3 = Pool3DBlock(2) 78 | self.encoder_res3 = Res3DBlock(128, 128) 79 | self.encoder_pool4 = Pool3DBlock(2) 80 | self.encoder_res4 = Res3DBlock(128, 128) 81 | self.encoder_pool5 = Pool3DBlock(2) 82 | self.encoder_res5 = Res3DBlock(128, 128) 83 | 84 | self.mid_res = Res3DBlock(128, 128) 85 | 86 | self.decoder_res5 = Res3DBlock(128, 128) 87 | self.decoder_upsample5 = Upsample3DBlock(128, 128, 2, 2) 88 | self.decoder_res4 = Res3DBlock(128, 128) 89 | self.decoder_upsample4 = Upsample3DBlock(128, 128, 2, 2) 90 | self.decoder_res3 = Res3DBlock(128, 128) 91 | self.decoder_upsample3 = Upsample3DBlock(128, 128, 2, 2) 92 | self.decoder_res2 = Res3DBlock(128, 128) 93 | self.decoder_upsample2 = Upsample3DBlock(128, 64, 2, 2) 94 | self.decoder_res1 = Res3DBlock(64, 64) 95 | self.decoder_upsample1 = Upsample3DBlock(64, 32, 2, 2) 96 | 97 | self.skip_res1 = Res3DBlock(32, 32) 98 | self.skip_res2 = Res3DBlock(64, 64) 99 | self.skip_res3 = Res3DBlock(128, 128) 100 | self.skip_res4 = Res3DBlock(128, 128) 101 | self.skip_res5 = Res3DBlock(128, 128) 102 | 103 | def forward(self, x): 104 | skip_x1 = self.skip_res1(x) 105 | x = self.encoder_pool1(x) 106 | x = self.encoder_res1(x) 107 | skip_x2 = self.skip_res2(x) 108 | x = self.encoder_pool2(x) 109 | x = self.encoder_res2(x) 110 | skip_x3 = self.skip_res3(x) 111 | x = self.encoder_pool3(x) 112 | x = self.encoder_res3(x) 113 | skip_x4 = self.skip_res4(x) 114 | x = self.encoder_pool4(x) 115 | x = self.encoder_res4(x) 116 | skip_x5 = self.skip_res5(x) 117 | x = self.encoder_pool5(x) 118 | x = self.encoder_res5(x) 119 | 120 | x = self.mid_res(x) 121 | 122 | x = self.decoder_res5(x) 123 | x = self.decoder_upsample5(x) 124 | x = x + skip_x5 125 | x = self.decoder_res4(x) 126 | x = self.decoder_upsample4(x) 127 | x = x + skip_x4 128 | x = self.decoder_res3(x) 129 | x = self.decoder_upsample3(x) 130 | x = x + skip_x3 131 | x = self.decoder_res2(x) 132 | x = self.decoder_upsample2(x) 133 | x = x + skip_x2 134 | x = self.decoder_res1(x) 135 | x = self.decoder_upsample1(x) 136 | x = x + skip_x1 137 | 138 | return x 139 | 140 | 141 | class V2VModel(nn.Module): 142 | def __init__(self, 
input_channels, output_channels): 143 | super().__init__() 144 | 145 | self.front_layers = nn.Sequential( 146 | Basic3DBlock(input_channels, 16, 7), 147 | Res3DBlock(16, 32), 148 | Res3DBlock(32, 32), 149 | Res3DBlock(32, 32) 150 | ) 151 | 152 | self.encoder_decoder = EncoderDecorder() 153 | 154 | self.back_layers = nn.Sequential( 155 | Res3DBlock(32, 32), 156 | Basic3DBlock(32, 32, 1), 157 | Basic3DBlock(32, 32, 1), 158 | ) 159 | 160 | self.output_layer = nn.Conv3d(32, output_channels, kernel_size=1, stride=1, padding=0) 161 | 162 | self._initialize_weights() 163 | 164 | def forward(self, x): 165 | x = self.front_layers(x) 166 | x = self.encoder_decoder(x) 167 | x = self.back_layers(x) 168 | x = self.output_layer(x) 169 | return x 170 | 171 | def _initialize_weights(self): 172 | for m in self.modules(): 173 | if isinstance(m, nn.Conv3d): 174 | nn.init.xavier_normal_(m.weight) 175 | # nn.init.normal_(m.weight, 0, 0.001) 176 | nn.init.constant_(m.bias, 0) 177 | elif isinstance(m, nn.ConvTranspose3d): 178 | nn.init.xavier_normal_(m.weight) 179 | # nn.init.normal_(m.weight, 0, 0.001) 180 | nn.init.constant_(m.bias, 0) 181 | -------------------------------------------------------------------------------- /mvn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | -------------------------------------------------------------------------------- /mvn/utils/cfg.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from easydict import EasyDict as edict 3 | 4 | 5 | def load_config(path): 6 | with open(path) as fin: 7 | config = edict(yaml.safe_load(fin)) 8 | 9 | return config 10 | -------------------------------------------------------------------------------- /mvn/utils/img.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from PIL import Image 4 | 5 | import torch 6 | 7 | IMAGENET_MEAN, IMAGENET_STD = np.array([0.485, 0.456, 0.406]), np.array([0.229, 0.224, 0.225]) 8 | 9 | 10 | def crop_image(image, bbox): 11 | """Crops area from image specified as bbox. 
Always returns area of size as bbox filling missing parts with zeros 12 | Args: 13 | image numpy array of shape (height, width, 3): input image 14 | bbox tuple of size 4: input bbox (left, upper, right, lower) 15 | 16 | Returns: 17 | cropped_image numpy array of shape (height, width, 3): resulting cropped image 18 | 19 | """ 20 | 21 | image_pil = Image.fromarray(image) 22 | image_pil = image_pil.crop(bbox) 23 | 24 | return np.asarray(image_pil) 25 | 26 | 27 | def resize_image(image, shape): 28 | return cv2.resize(image, (shape[1], shape[0]), interpolation=cv2.INTER_AREA) 29 | 30 | 31 | def get_square_bbox(bbox): 32 | """Makes square bbox from any bbox by stretching of minimal length side 33 | 34 | Args: 35 | bbox tuple of size 4: input bbox (left, upper, right, lower) 36 | 37 | Returns: 38 | bbox: tuple of size 4: resulting square bbox (left, upper, right, lower) 39 | """ 40 | 41 | left, upper, right, lower = bbox 42 | width, height = right - left, lower - upper 43 | 44 | if width > height: 45 | y_center = (upper + lower) // 2 46 | upper = y_center - width // 2 47 | lower = upper + width 48 | else: 49 | x_center = (left + right) // 2 50 | left = x_center - height // 2 51 | right = left + height 52 | 53 | return left, upper, right, lower 54 | 55 | 56 | def scale_bbox(bbox, scale): 57 | left, upper, right, lower = bbox 58 | width, height = right - left, lower - upper 59 | 60 | x_center, y_center = (right + left) // 2, (lower + upper) // 2 61 | new_width, new_height = int(scale * width), int(scale * height) 62 | 63 | new_left = x_center - new_width // 2 64 | new_right = new_left + new_width 65 | 66 | new_upper = y_center - new_height // 2 67 | new_lower = new_upper + new_height 68 | 69 | return new_left, new_upper, new_right, new_lower 70 | 71 | 72 | def to_numpy(tensor): 73 | if torch.is_tensor(tensor): 74 | return tensor.cpu().detach().numpy() 75 | elif type(tensor).__module__ != 'numpy': 76 | raise ValueError("Cannot convert {} to numpy array" 77 | .format(type(tensor))) 78 | return tensor 79 | 80 | 81 | def to_torch(ndarray): 82 | if type(ndarray).__module__ == 'numpy': 83 | return torch.from_numpy(ndarray) 84 | elif not torch.is_tensor(ndarray): 85 | raise ValueError("Cannot convert {} to torch tensor" 86 | .format(type(ndarray))) 87 | return ndarray 88 | 89 | 90 | def image_batch_to_numpy(image_batch): 91 | image_batch = to_numpy(image_batch) 92 | image_batch = np.transpose(image_batch, (0, 2, 3, 1)) # BxCxHxW -> BxHxWxC 93 | return image_batch 94 | 95 | 96 | def image_batch_to_torch(image_batch): 97 | image_batch = np.transpose(image_batch, (0, 3, 1, 2)) # BxHxWxC -> BxCxHxW 98 | image_batch = to_torch(image_batch).float() 99 | return image_batch 100 | 101 | 102 | def normalize_image(image): 103 | """Normalizes image using ImageNet mean and std 104 | 105 | Args: 106 | image numpy array of shape (h, w, 3): image 107 | 108 | Returns normalized_image numpy array of shape (h, w, 3): normalized image 109 | """ 110 | return (image / 255.0 - IMAGENET_MEAN) / IMAGENET_STD 111 | 112 | 113 | def denormalize_image(image): 114 | """Reverse to normalize_image() function""" 115 | return np.clip(255.0 * (image * IMAGENET_STD + IMAGENET_MEAN), 0, 255) 116 | -------------------------------------------------------------------------------- /mvn/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import json 4 | import re 5 | 6 | import torch 7 | 8 | 9 | def config_to_str(config): 10 | return 
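# --- Illustrative sketch (not part of the repository): a typical crop-and-normalize
# --- pipeline built from the mvn/utils/img.py helpers above. The bbox values and image
# --- size are made up for demonstration.
import numpy as np
from mvn.utils.img import get_square_bbox, scale_bbox, crop_image, resize_image, normalize_image

image = np.zeros((480, 640, 3), dtype=np.uint8)       # dummy image
bbox = (100, 50, 300, 350)                            # (left, upper, right, lower)
bbox = get_square_bbox(bbox)                          # stretch the shorter side -> square bbox
bbox = scale_bbox(bbox, 1.5)                          # enlarge around the bbox center
crop = crop_image(image, bbox)                        # out-of-image area is zero-filled
crop = resize_image(crop, (256, 256))                 # target shape is (height, width)
crop = normalize_image(crop)                          # ImageNet mean/std normalization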
yaml.dump(yaml.safe_load(json.dumps(config))) # fuck yeah 11 | 12 | 13 | class AverageMeter(object): 14 | """Computes and stores the average and current value""" 15 | def __init__(self): 16 | self.reset() 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | self.avg = self.sum / self.count 29 | 30 | 31 | def calc_gradient_norm(named_parameters): 32 | total_norm = 0.0 33 | for name, p in named_parameters: 34 | # print(name) 35 | param_norm = p.grad.data.norm(2) 36 | total_norm += param_norm.item() ** 2 37 | 38 | total_norm = total_norm ** (1. / 2) 39 | 40 | return total_norm 41 | -------------------------------------------------------------------------------- /mvn/utils/multiview.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class Camera: 6 | def __init__(self, R, t, K, dist=None, name=""): 7 | self.R = np.array(R).copy() 8 | assert self.R.shape == (3, 3) 9 | 10 | self.t = np.array(t).copy() 11 | assert self.t.size == 3 12 | self.t = self.t.reshape(3, 1) 13 | 14 | self.K = np.array(K).copy() 15 | assert self.K.shape == (3, 3) 16 | 17 | self.dist = dist 18 | if self.dist is not None: 19 | self.dist = np.array(self.dist).copy().flatten() 20 | 21 | self.name = name 22 | 23 | def update_after_crop(self, bbox): 24 | left, upper, right, lower = bbox 25 | 26 | cx, cy = self.K[0, 2], self.K[1, 2] 27 | 28 | new_cx = cx - left 29 | new_cy = cy - upper 30 | 31 | self.K[0, 2], self.K[1, 2] = new_cx, new_cy 32 | 33 | def update_after_resize(self, image_shape, new_image_shape): 34 | height, width = image_shape 35 | new_height, new_width = new_image_shape 36 | 37 | fx, fy, cx, cy = self.K[0, 0], self.K[1, 1], self.K[0, 2], self.K[1, 2] 38 | 39 | new_fx = fx * (new_width / width) 40 | new_fy = fy * (new_height / height) 41 | new_cx = cx * (new_width / width) 42 | new_cy = cy * (new_height / height) 43 | 44 | self.K[0, 0], self.K[1, 1], self.K[0, 2], self.K[1, 2] = new_fx, new_fy, new_cx, new_cy 45 | 46 | @property 47 | def projection(self): 48 | return self.K.dot(self.extrinsics) 49 | 50 | @property 51 | def extrinsics(self): 52 | return np.hstack([self.R, self.t]) 53 | 54 | 55 | def euclidean_to_homogeneous(points): 56 | """Converts euclidean points to homogeneous 57 | 58 | Args: 59 | points numpy array or torch tensor of shape (N, M): N euclidean points of dimension M 60 | 61 | Returns: 62 | numpy array or torch tensor of shape (N, M + 1): homogeneous points 63 | """ 64 | if isinstance(points, np.ndarray): 65 | return np.hstack([points, np.ones((len(points), 1))]) 66 | elif torch.is_tensor(points): 67 | return torch.cat([points, torch.ones((points.shape[0], 1), dtype=points.dtype, device=points.device)], dim=1) 68 | else: 69 | raise TypeError("Works only with numpy arrays and PyTorch tensors.") 70 | 71 | 72 | def homogeneous_to_euclidean(points): 73 | """Converts homogeneous points to euclidean 74 | 75 | Args: 76 | points numpy array or torch tensor of shape (N, M + 1): N homogeneous points of dimension M 77 | 78 | Returns: 79 | numpy array or torch tensor of shape (N, M): euclidean points 80 | """ 81 | if isinstance(points, np.ndarray): 82 | return (points.T[:-1] / points.T[-1]).T 83 | elif torch.is_tensor(points): 84 | return (points.transpose(1, 0)[:-1] / points.transpose(1, 0)[-1]).transpose(1, 0) 85 | else: 86 | raise TypeError("Works only with numpy 
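# --- Illustrative sketch (not part of the repository): constructing a Camera from
# --- mvn/utils/multiview.py above and keeping its intrinsics consistent with cropping
# --- and resizing. The R / t / K values below are dummies for demonstration.
import numpy as np
from mvn.utils.multiview import Camera

R = np.eye(3)
t = np.zeros(3)
K = np.array([[1000.0, 0.0, 320.0],
              [0.0, 1000.0, 240.0],
              [0.0, 0.0, 1.0]])

camera = Camera(R, t, K, name="cam_0")
camera.update_after_crop((100, 50, 420, 370))        # bbox as (left, upper, right, lower)
camera.update_after_resize((320, 320), (256, 256))   # (old_h, old_w) -> (new_h, new_w)
P = camera.projection                                # 3x4 matrix K @ [R | t]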
arrays and PyTorch tensors.") 87 | 88 | 89 | def project_3d_points_to_image_plane_without_distortion(proj_matrix, points_3d, convert_back_to_euclidean=True): 90 | """Project 3D points to image plane not taking into account distortion 91 | Args: 92 | proj_matrix numpy array or torch tensor of shape (3, 4): projection matrix 93 | points_3d numpy array or torch tensor of shape (N, 3): 3D points 94 | convert_back_to_euclidean bool: if True, then resulting points will be converted to euclidean coordinates 95 | NOTE: division by zero can be here if z = 0 96 | Returns: 97 | numpy array or torch tensor of shape (N, 2): 3D points projected to image plane 98 | """ 99 | if isinstance(proj_matrix, np.ndarray) and isinstance(points_3d, np.ndarray): 100 | result = euclidean_to_homogeneous(points_3d) @ proj_matrix.T 101 | if convert_back_to_euclidean: 102 | result = homogeneous_to_euclidean(result) 103 | return result 104 | elif torch.is_tensor(proj_matrix) and torch.is_tensor(points_3d): 105 | result = euclidean_to_homogeneous(points_3d) @ proj_matrix.t() 106 | if convert_back_to_euclidean: 107 | result = homogeneous_to_euclidean(result) 108 | return result 109 | else: 110 | raise TypeError("Works only with numpy arrays and PyTorch tensors.") 111 | 112 | 113 | def triangulate_point_from_multiple_views_linear(proj_matricies, points): 114 | """Triangulates one point from multiple (N) views using direct linear transformation (DLT). 115 | For more information look at "Multiple view geometry in computer vision", 116 | Richard Hartley and Andrew Zisserman, 12.2 (p. 312). 117 | 118 | Args: 119 | proj_matricies numpy array of shape (N, 3, 4): sequence of projection matricies (3x4) 120 | points numpy array of shape (N, 2): sequence of points' coordinates 121 | 122 | Returns: 123 | point_3d numpy array of shape (3,): triangulated point 124 | """ 125 | assert len(proj_matricies) == len(points) 126 | 127 | n_views = len(proj_matricies) 128 | A = np.zeros((2 * n_views, 4)) 129 | for j in range(len(proj_matricies)): 130 | A[j * 2 + 0] = points[j][0] * proj_matricies[j][2, :] - proj_matricies[j][0, :] 131 | A[j * 2 + 1] = points[j][1] * proj_matricies[j][2, :] - proj_matricies[j][1, :] 132 | 133 | u, s, vh = np.linalg.svd(A, full_matrices=False) 134 | point_3d_homo = vh[3, :] 135 | 136 | point_3d = homogeneous_to_euclidean(point_3d_homo) 137 | 138 | return point_3d 139 | 140 | 141 | def triangulate_point_from_multiple_views_linear_torch(proj_matricies, points, confidences=None): 142 | """Similar as triangulate_point_from_multiple_views_linear() but for PyTorch. 143 | For more information see its documentation. 144 | Args: 145 | proj_matricies torch tensor of shape (N, 3, 4): sequence of projection matricies (3x4) 146 | points torch tensor of of shape (N, 2): sequence of points' coordinates 147 | confidences None or torch tensor of shape (N,): confidences of points [0.0, 1.0]. 
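# --- Illustrative sketch (not part of the repository): recovering a 3D point with the DLT
# --- routine triangulate_point_from_multiple_views_linear above. The three cameras and the
# --- ground-truth point are synthetic, chosen only to make the example self-checking.
import numpy as np
from mvn.utils import multiview

point_3d_gt = np.array([100.0, -50.0, 4000.0])

proj_matricies, points_2d = [], []
for shift in (-500.0, 0.0, 500.0):                   # three cameras translated along x
    K = np.array([[1000.0, 0.0, 320.0], [0.0, 1000.0, 240.0], [0.0, 0.0, 1.0]])
    R, t = np.eye(3), np.array([[shift], [0.0], [0.0]])
    P = K @ np.hstack([R, t])
    proj_matricies.append(P)
    points_2d.append(multiview.project_3d_points_to_image_plane_without_distortion(
        P, point_3d_gt[None, :])[0])

point_3d = multiview.triangulate_point_from_multiple_views_linear(
    np.array(proj_matricies), np.array(points_2d))
# point_3d recovers point_3d_gt up to numerical precision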
148 | If None, all confidences are supposed to be 1.0 149 | Returns: 150 | point_3d numpy torch tensor of shape (3,): triangulated point 151 | """ 152 | assert len(proj_matricies) == len(points) 153 | 154 | n_views = len(proj_matricies) 155 | 156 | if confidences is None: 157 | confidences = torch.ones(n_views, dtype=torch.float32, device=points.device) 158 | 159 | A = proj_matricies[:, 2:3].expand(n_views, 2, 4) * points.view(n_views, 2, 1) 160 | A -= proj_matricies[:, :2] 161 | A *= confidences.view(-1, 1, 1) 162 | 163 | u, s, vh = torch.svd(A.view(-1, 4)) 164 | 165 | point_3d_homo = -vh[:, 3] 166 | point_3d = homogeneous_to_euclidean(point_3d_homo.unsqueeze(0))[0] 167 | 168 | return point_3d 169 | 170 | 171 | def triangulate_batch_of_points(proj_matricies_batch, points_batch, confidences_batch=None): 172 | batch_size, n_views, n_joints = points_batch.shape[:3] 173 | point_3d_batch = torch.zeros(batch_size, n_joints, 3, dtype=torch.float32, device=points_batch.device) 174 | 175 | for batch_i in range(batch_size): 176 | for joint_i in range(n_joints): 177 | points = points_batch[batch_i, :, joint_i, :] 178 | 179 | confidences = confidences_batch[batch_i, :, joint_i] if confidences_batch is not None else None 180 | point_3d = triangulate_point_from_multiple_views_linear_torch(proj_matricies_batch[batch_i], points, confidences=confidences) 181 | point_3d_batch[batch_i, joint_i] = point_3d 182 | 183 | return point_3d_batch 184 | 185 | 186 | def calc_reprojection_error_matrix(keypoints_3d, keypoints_2d_list, proj_matricies): 187 | reprojection_error_matrix = [] 188 | for keypoints_2d, proj_matrix in zip(keypoints_2d_list, proj_matricies): 189 | keypoints_2d_projected = project_3d_points_to_image_plane_without_distortion(proj_matrix, keypoints_3d) 190 | reprojection_error = 1 / 2 * np.sqrt(np.sum((keypoints_2d - keypoints_2d_projected) ** 2, axis=1)) 191 | reprojection_error_matrix.append(reprojection_error) 192 | 193 | return np.vstack(reprojection_error_matrix).T 194 | -------------------------------------------------------------------------------- /mvn/utils/op.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from mvn.utils.img import to_numpy, to_torch 8 | from mvn.utils import multiview 9 | 10 | 11 | def integrate_tensor_2d(heatmaps, softmax=True): 12 | """Applies softmax to heatmaps and integrates them to get their's "center of masses" 13 | 14 | Args: 15 | heatmaps torch tensor of shape (batch_size, n_heatmaps, h, w): input heatmaps 16 | 17 | Returns: 18 | coordinates torch tensor of shape (batch_size, n_heatmaps, 2): coordinates of center of masses of all heatmaps 19 | 20 | """ 21 | batch_size, n_heatmaps, h, w = heatmaps.shape 22 | 23 | heatmaps = heatmaps.reshape((batch_size, n_heatmaps, -1)) 24 | if softmax: 25 | heatmaps = nn.functional.softmax(heatmaps, dim=2) 26 | else: 27 | heatmaps = nn.functional.relu(heatmaps) 28 | 29 | heatmaps = heatmaps.reshape((batch_size, n_heatmaps, h, w)) 30 | 31 | mass_x = heatmaps.sum(dim=2) 32 | mass_y = heatmaps.sum(dim=3) 33 | 34 | mass_times_coord_x = mass_x * torch.arange(w).type(torch.float).to(mass_x.device) 35 | mass_times_coord_y = mass_y * torch.arange(h).type(torch.float).to(mass_y.device) 36 | 37 | x = mass_times_coord_x.sum(dim=2, keepdim=True) 38 | y = mass_times_coord_y.sum(dim=2, keepdim=True) 39 | 40 | if not softmax: 41 | x = x / mass_x.sum(dim=2, keepdim=True) 42 | y = y / 
mass_y.sum(dim=2, keepdim=True) 43 | 44 | coordinates = torch.cat((x, y), dim=2) 45 | coordinates = coordinates.reshape((batch_size, n_heatmaps, 2)) 46 | 47 | return coordinates, heatmaps 48 | 49 | 50 | def integrate_tensor_3d(volumes, softmax=True): 51 | batch_size, n_volumes, x_size, y_size, z_size = volumes.shape 52 | 53 | volumes = volumes.reshape((batch_size, n_volumes, -1)) 54 | if softmax: 55 | volumes = nn.functional.softmax(volumes, dim=2) 56 | else: 57 | volumes = nn.functional.relu(volumes) 58 | 59 | volumes = volumes.reshape((batch_size, n_volumes, x_size, y_size, z_size)) 60 | 61 | mass_x = volumes.sum(dim=3).sum(dim=3) 62 | mass_y = volumes.sum(dim=2).sum(dim=3) 63 | mass_z = volumes.sum(dim=2).sum(dim=2) 64 | 65 | mass_times_coord_x = mass_x * torch.arange(x_size).type(torch.float).to(mass_x.device) 66 | mass_times_coord_y = mass_y * torch.arange(y_size).type(torch.float).to(mass_y.device) 67 | mass_times_coord_z = mass_z * torch.arange(z_size).type(torch.float).to(mass_z.device) 68 | 69 | x = mass_times_coord_x.sum(dim=2, keepdim=True) 70 | y = mass_times_coord_y.sum(dim=2, keepdim=True) 71 | z = mass_times_coord_z.sum(dim=2, keepdim=True) 72 | 73 | if not softmax: 74 | x = x / mass_x.sum(dim=2, keepdim=True) 75 | y = y / mass_y.sum(dim=2, keepdim=True) 76 | z = z / mass_z.sum(dim=2, keepdim=True) 77 | 78 | coordinates = torch.cat((x, y, z), dim=2) 79 | coordinates = coordinates.reshape((batch_size, n_volumes, 3)) 80 | 81 | return coordinates, volumes 82 | 83 | 84 | def integrate_tensor_3d_with_coordinates(volumes, coord_volumes, softmax=True): 85 | batch_size, n_volumes, x_size, y_size, z_size = volumes.shape 86 | 87 | volumes = volumes.reshape((batch_size, n_volumes, -1)) 88 | if softmax: 89 | volumes = nn.functional.softmax(volumes, dim=2) 90 | else: 91 | volumes = nn.functional.relu(volumes) 92 | 93 | volumes = volumes.reshape((batch_size, n_volumes, x_size, y_size, z_size)) 94 | coordinates = torch.einsum("bnxyz, bxyzc -> bnc", volumes, coord_volumes) 95 | 96 | return coordinates, volumes 97 | 98 | 99 | def unproject_heatmaps(heatmaps, proj_matricies, coord_volumes, volume_aggregation_method='sum', vol_confidences=None): 100 | device = heatmaps.device 101 | batch_size, n_views, n_joints, heatmap_shape = heatmaps.shape[0], heatmaps.shape[1], heatmaps.shape[2], tuple(heatmaps.shape[3:]) 102 | volume_shape = coord_volumes.shape[1:4] 103 | 104 | volume_batch = torch.zeros(batch_size, n_joints, *volume_shape, device=device) 105 | 106 | # TODO: speed up this this loop 107 | for batch_i in range(batch_size): 108 | coord_volume = coord_volumes[batch_i] 109 | grid_coord = coord_volume.reshape((-1, 3)) 110 | 111 | volume_batch_to_aggregate = torch.zeros(n_views, n_joints, *volume_shape, device=device) 112 | 113 | for view_i in range(n_views): 114 | heatmap = heatmaps[batch_i, view_i] 115 | heatmap = heatmap.unsqueeze(0) 116 | 117 | grid_coord_proj = multiview.project_3d_points_to_image_plane_without_distortion( 118 | proj_matricies[batch_i, view_i], grid_coord, convert_back_to_euclidean=False 119 | ) 120 | 121 | invalid_mask = grid_coord_proj[:, 2] <= 0.0 # depth must be larger than 0.0 122 | 123 | grid_coord_proj[grid_coord_proj[:, 2] == 0.0, 2] = 1.0 # not to divide by zero 124 | grid_coord_proj = multiview.homogeneous_to_euclidean(grid_coord_proj) 125 | 126 | # transform to [-1.0, 1.0] range 127 | grid_coord_proj_transformed = torch.zeros_like(grid_coord_proj) 128 | grid_coord_proj_transformed[:, 0] = 2 * (grid_coord_proj[:, 0] / heatmap_shape[0] - 0.5) 129 | 
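# --- Illustrative sketch (not part of the repository): the differentiable 2D soft-argmax
# --- implemented by integrate_tensor_2d above. The heatmap size, joint count and peak
# --- location are made up for demonstration.
import torch
from mvn.utils.op import integrate_tensor_2d

heatmaps = torch.zeros(1, 17, 96, 96)          # (batch, joints, h, w) raw heatmaps
heatmaps[0, :, 40, 60] = 20.0                  # sharp peak at (x=60, y=40) for every joint
coords, softmaxed = integrate_tensor_2d(heatmaps, softmax=True)
# coords has shape (1, 17, 2) and is approximately (60, 40): the softmax-weighted center
# of mass, which keeps the 2D keypoints differentiable w.r.t. the heatmaps.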
grid_coord_proj_transformed[:, 1] = 2 * (grid_coord_proj[:, 1] / heatmap_shape[1] - 0.5) 130 | grid_coord_proj = grid_coord_proj_transformed 131 | 132 | # prepare to F.grid_sample 133 | grid_coord_proj = grid_coord_proj.unsqueeze(1).unsqueeze(0) 134 | try: 135 | current_volume = F.grid_sample(heatmap, grid_coord_proj, align_corners=True) 136 | except TypeError: # old PyTorch 137 | current_volume = F.grid_sample(heatmap, grid_coord_proj) 138 | 139 | # zero out non-valid points 140 | current_volume = current_volume.view(n_joints, -1) 141 | current_volume[:, invalid_mask] = 0.0 142 | 143 | # reshape back to volume 144 | current_volume = current_volume.view(n_joints, *volume_shape) 145 | 146 | # collect 147 | volume_batch_to_aggregate[view_i] = current_volume 148 | 149 | # agregate resulting volume 150 | if volume_aggregation_method.startswith('conf'): 151 | volume_batch[batch_i] = (volume_batch_to_aggregate * vol_confidences[batch_i].view(n_views, n_joints, 1, 1, 1)).sum(0) 152 | elif volume_aggregation_method == 'sum': 153 | volume_batch[batch_i] = volume_batch_to_aggregate.sum(0) 154 | elif volume_aggregation_method == 'max': 155 | volume_batch[batch_i] = volume_batch_to_aggregate.max(0)[0] 156 | elif volume_aggregation_method == 'softmax': 157 | volume_batch_to_aggregate_softmin = volume_batch_to_aggregate.clone() 158 | volume_batch_to_aggregate_softmin = volume_batch_to_aggregate_softmin.view(n_views, -1) 159 | volume_batch_to_aggregate_softmin = nn.functional.softmax(volume_batch_to_aggregate_softmin, dim=0) 160 | volume_batch_to_aggregate_softmin = volume_batch_to_aggregate_softmin.view(n_views, n_joints, *volume_shape) 161 | 162 | volume_batch[batch_i] = (volume_batch_to_aggregate * volume_batch_to_aggregate_softmin).sum(0) 163 | else: 164 | raise ValueError("Unknown volume_aggregation_method: {}".format(volume_aggregation_method)) 165 | 166 | return volume_batch 167 | 168 | 169 | def gaussian_2d_pdf(coords, means, sigmas, normalize=True): 170 | normalization = 1.0 171 | if normalize: 172 | normalization = (2 * np.pi * sigmas[:, 0] * sigmas[:, 0]) 173 | 174 | exp = torch.exp(-((coords[:, 0] - means[:, 0]) ** 2 / sigmas[:, 0] ** 2 + (coords[:, 1] - means[:, 1]) ** 2 / sigmas[:, 1] ** 2) / 2) 175 | return exp / normalization 176 | 177 | 178 | def render_points_as_2d_gaussians(points, sigmas, image_shape, normalize=True): 179 | device = points.device 180 | n_points = points.shape[0] 181 | 182 | yy, xx = torch.meshgrid(torch.arange(image_shape[0]).to(device), torch.arange(image_shape[1]).to(device)) 183 | grid = torch.stack([xx, yy], dim=-1).type(torch.float32) 184 | grid = grid.unsqueeze(0).repeat(n_points, 1, 1, 1) # (n_points, h, w, 2) 185 | grid = grid.reshape((-1, 2)) 186 | 187 | points = points.unsqueeze(1).unsqueeze(1).repeat(1, image_shape[0], image_shape[1], 1) 188 | points = points.reshape(-1, 2) 189 | 190 | sigmas = sigmas.unsqueeze(1).unsqueeze(1).repeat(1, image_shape[0], image_shape[1], 1) 191 | sigmas = sigmas.reshape(-1, 2) 192 | 193 | images = gaussian_2d_pdf(grid, points, sigmas, normalize=normalize) 194 | images = images.reshape(n_points, *image_shape) 195 | 196 | return images 197 | -------------------------------------------------------------------------------- /mvn/utils/vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.ndimage 3 | import skimage.transform 4 | import cv2 5 | 6 | import torch 7 | 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | from matplotlib import pylab as plt 11 | 
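# --- Illustrative sketch (not part of the repository): un-projecting per-view heatmaps into
# --- a shared 3D volume with unproject_heatmaps from mvn/utils/op.py above. All shapes and
# --- values are synthetic; the sketch only demonstrates the expected tensor shapes.
import torch
from mvn.utils.op import unproject_heatmaps

batch_size, n_views, n_joints = 1, 4, 17
heatmaps = torch.rand(batch_size, n_views, n_joints, 96, 96)
proj_matricies = torch.rand(batch_size, n_views, 3, 4)             # per-view 3x4 projections

# a cubic grid of 3D coordinates (e.g. in mm) centered on the subject
grid_1d = torch.linspace(-1250.0, 1250.0, 32)
xxx, yyy, zzz = torch.meshgrid(grid_1d, grid_1d, grid_1d)
coord_volumes = torch.stack([xxx, yyy, zzz], dim=-1).unsqueeze(0)  # (1, 32, 32, 32, 3)

volumes = unproject_heatmaps(heatmaps, proj_matricies, coord_volumes,
                             volume_aggregation_method='softmax')
# volumes has shape (1, 17, 32, 32, 32): one aggregated volume per joint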
from mpl_toolkits.mplot3d import axes3d, Axes3D 12 | 13 | 14 | from mvn.utils.img import image_batch_to_numpy, to_numpy, denormalize_image, resize_image 15 | from mvn.utils.multiview import project_3d_points_to_image_plane_without_distortion 16 | 17 | CONNECTIVITY_DICT = { 18 | 'cmu': [(0, 2), (0, 9), (1, 0), (1, 17), (2, 12), (3, 0), (4, 3), (5, 4), (6, 2), (7, 6), (8, 7), (9, 10), (10, 11), (12, 13), (13, 14), (15, 1), (16, 15), (17, 18)], 19 | 'coco': [(0, 1), (0, 2), (1, 3), (2, 4), (5, 7), (7, 9), (6, 8), (8, 10), (11, 13), (13, 15), (12, 14), (14, 16), (5, 6), (5, 11), (6, 12), (11, 12)], 20 | "mpii": [(0, 1), (1, 2), (2, 6), (5, 4), (4, 3), (3, 6), (6, 7), (7, 8), (8, 9), (8, 12), (8, 13), (10, 11), (11, 12), (13, 14), (14, 15)], 21 | "human36m": [(0, 1), (1, 2), (2, 6), (5, 4), (4, 3), (3, 6), (6, 7), (7, 8), (8, 16), (9, 16), (8, 12), (11, 12), (10, 11), (8, 13), (13, 14), (14, 15)], 22 | "kth": [(0, 1), (1, 2), (5, 4), (4, 3), (6, 7), (7, 8), (11, 10), (10, 9), (2, 3), (3, 9), (2, 8), (9, 12), (8, 12), (12, 13)], 23 | } 24 | 25 | COLOR_DICT = { 26 | 'coco': [ 27 | (102, 0, 153), (153, 0, 102), (51, 0, 153), (153, 0, 153), # head 28 | (51, 153, 0), (0, 153, 0), # left arm 29 | (153, 102, 0), (153, 153, 0), # right arm 30 | (0, 51, 153), (0, 0, 153), # left leg 31 | (0, 153, 102), (0, 153, 153), # right leg 32 | (153, 0, 0), (153, 0, 0), (153, 0, 0), (153, 0, 0) # body 33 | ], 34 | 35 | 'human36m': [ 36 | (0, 153, 102), (0, 153, 153), (0, 153, 153), # right leg 37 | (0, 51, 153), (0, 0, 153), (0, 0, 153), # left leg 38 | (153, 0, 0), (153, 0, 0), # body 39 | (153, 0, 102), (153, 0, 102), # head 40 | (153, 153, 0), (153, 153, 0), (153, 102, 0), # right arm 41 | (0, 153, 0), (0, 153, 0), (51, 153, 0) # left arm 42 | ], 43 | 44 | 'kth': [ 45 | (0, 153, 102), (0, 153, 153), # right leg 46 | (0, 51, 153), (0, 0, 153), # left leg 47 | (153, 102, 0), (153, 153, 0), # right arm 48 | (51, 153, 0), (0, 153, 0), # left arm 49 | (153, 0, 0), (153, 0, 0), (153, 0, 0), (153, 0, 0), (153, 0, 0), # body 50 | (102, 0, 153) # head 51 | ] 52 | } 53 | 54 | JOINT_NAMES_DICT = { 55 | 'coco': { 56 | 0: "nose", 57 | 1: "left_eye", 58 | 2: "right_eye", 59 | 3: "left_ear", 60 | 4: "right_ear", 61 | 5: "left_shoulder", 62 | 6: "right_shoulder", 63 | 7: "left_elbow", 64 | 8: "right_elbow", 65 | 9: "left_wrist", 66 | 10: "right_wrist", 67 | 11: "left_hip", 68 | 12: "right_hip", 69 | 13: "left_knee", 70 | 14: "right_knee", 71 | 15: "left_ankle", 72 | 16: "right_ankle" 73 | } 74 | } 75 | 76 | 77 | def fig_to_array(fig): 78 | fig.canvas.draw() 79 | fig_image = np.array(fig.canvas.renderer._renderer) 80 | 81 | return fig_image 82 | 83 | 84 | def visualize_batch(images_batch, heatmaps_batch, keypoints_2d_batch, proj_matricies_batch, 85 | keypoints_3d_batch_gt, keypoints_3d_batch_pred, 86 | kind="cmu", 87 | cuboids_batch=None, 88 | confidences_batch=None, 89 | batch_index=0, size=5, 90 | max_n_cols=10, 91 | pred_kind=None 92 | ): 93 | if pred_kind is None: 94 | pred_kind = kind 95 | 96 | n_views, n_joints = heatmaps_batch.shape[1], heatmaps_batch.shape[2] 97 | 98 | n_rows = 3 99 | n_rows = n_rows + 1 if keypoints_2d_batch is not None else n_rows 100 | n_rows = n_rows + 1 if cuboids_batch is not None else n_rows 101 | n_rows = n_rows + 1 if confidences_batch is not None else n_rows 102 | 103 | n_cols = min(n_views, max_n_cols) 104 | fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(n_cols * size, n_rows * size)) 105 | axes = axes.reshape(n_rows, n_cols) 106 | 107 | image_shape = images_batch.shape[3:] 108 
| heatmap_shape = heatmaps_batch.shape[3:] 109 | 110 | row_i = 0 111 | 112 | # images 113 | axes[row_i, 0].set_ylabel("image", size='large') 114 | 115 | images = image_batch_to_numpy(images_batch[batch_index]) 116 | images = denormalize_image(images).astype(np.uint8) 117 | images = images[..., ::-1] # bgr -> rgb 118 | 119 | for view_i in range(n_cols): 120 | axes[row_i][view_i].imshow(images[view_i]) 121 | row_i += 1 122 | 123 | # 2D keypoints (pred) 124 | if keypoints_2d_batch is not None: 125 | axes[row_i, 0].set_ylabel("2d keypoints (pred)", size='large') 126 | 127 | keypoints_2d = to_numpy(keypoints_2d_batch)[batch_index] 128 | for view_i in range(n_cols): 129 | axes[row_i][view_i].imshow(images[view_i]) 130 | draw_2d_pose(keypoints_2d[view_i], axes[row_i][view_i], kind=kind) 131 | row_i += 1 132 | 133 | # 2D keypoints (gt projected) 134 | axes[row_i, 0].set_ylabel("2d keypoints (gt projected)", size='large') 135 | 136 | for view_i in range(n_cols): 137 | axes[row_i][view_i].imshow(images[view_i]) 138 | keypoints_2d_gt_proj = project_3d_points_to_image_plane_without_distortion(proj_matricies_batch[batch_index, view_i].detach().cpu().numpy(), keypoints_3d_batch_gt[batch_index].detach().cpu().numpy()) 139 | draw_2d_pose(keypoints_2d_gt_proj, axes[row_i][view_i], kind=kind) 140 | row_i += 1 141 | 142 | # 2D keypoints (pred projected) 143 | axes[row_i, 0].set_ylabel("2d keypoints (pred projected)", size='large') 144 | 145 | for view_i in range(n_cols): 146 | axes[row_i][view_i].imshow(images[view_i]) 147 | keypoints_2d_pred_proj = project_3d_points_to_image_plane_without_distortion(proj_matricies_batch[batch_index, view_i].detach().cpu().numpy(), keypoints_3d_batch_pred[batch_index].detach().cpu().numpy()) 148 | draw_2d_pose(keypoints_2d_pred_proj, axes[row_i][view_i], kind=pred_kind) 149 | row_i += 1 150 | 151 | # cuboids 152 | if cuboids_batch is not None: 153 | axes[row_i, 0].set_ylabel("cuboid", size='large') 154 | 155 | for view_i in range(n_cols): 156 | cuboid = cuboids_batch[batch_index] 157 | axes[row_i][view_i].imshow(cuboid.render(proj_matricies_batch[batch_index, view_i].detach().cpu().numpy(), images[view_i].copy())) 158 | row_i += 1 159 | 160 | # confidences 161 | if confidences_batch is not None: 162 | axes[row_i, 0].set_ylabel("confidences", size='large') 163 | 164 | for view_i in range(n_cols): 165 | confidences = to_numpy(confidences_batch[batch_index, view_i]) 166 | xs = np.arange(len(confidences)) 167 | 168 | axes[row_i, view_i].bar(xs, confidences, color='green') 169 | axes[row_i, view_i].set_xticks(xs) 170 | if torch.max(confidences_batch).item() <= 1.0: 171 | axes[row_i, view_i].set_ylim(0.0, 1.0) 172 | 173 | fig.tight_layout() 174 | 175 | fig_image = fig_to_array(fig) 176 | 177 | plt.close('all') 178 | 179 | return fig_image 180 | 181 | 182 | def visualize_heatmaps(images_batch, heatmaps_batch, 183 | kind="cmu", 184 | batch_index=0, size=5, 185 | max_n_rows=10, max_n_cols=10): 186 | n_views, n_joints = heatmaps_batch.shape[1], heatmaps_batch.shape[2] 187 | heatmap_shape = heatmaps_batch.shape[3:] 188 | 189 | n_cols, n_rows = min(n_joints + 1, max_n_cols), min(n_views, max_n_rows) 190 | fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(n_cols * size, n_rows * size)) 191 | axes = axes.reshape(n_rows, n_cols) 192 | 193 | # images 194 | images = image_batch_to_numpy(images_batch[batch_index]) 195 | images = denormalize_image(images).astype(np.uint8) 196 | images = images[..., ::-1] # bgr -> 197 | 198 | # heatmaps 199 | heatmaps = 
to_numpy(heatmaps_batch[batch_index]) 200 | 201 | for row in range(n_rows): 202 | for col in range(n_cols): 203 | if col == 0: 204 | axes[row, col].set_ylabel(str(row), size='large') 205 | axes[row, col].imshow(images[row]) 206 | else: 207 | if row == 0: 208 | joint_name = JOINT_NAMES_DICT[kind][col - 1] if kind in JOINT_NAMES_DICT else str(col - 1) 209 | axes[row, col].set_title(joint_name) 210 | 211 | axes[row, col].imshow(resize_image(images[row], heatmap_shape)) 212 | axes[row, col].imshow(heatmaps[row, col - 1], alpha=0.5) 213 | 214 | fig.tight_layout() 215 | 216 | fig_image = fig_to_array(fig) 217 | 218 | plt.close('all') 219 | 220 | return fig_image 221 | 222 | 223 | def visualize_volumes(images_batch, volumes_batch, proj_matricies_batch, 224 | kind="cmu", 225 | cuboids_batch=None, 226 | batch_index=0, size=5, 227 | max_n_rows=10, max_n_cols=10): 228 | n_views, n_joints = volumes_batch.shape[1], volumes_batch.shape[2] 229 | 230 | n_cols, n_rows = min(n_joints + 1, max_n_cols), min(n_views, max_n_rows) 231 | fig = plt.figure(figsize=(n_cols * size, n_rows * size)) 232 | 233 | # images 234 | images = image_batch_to_numpy(images_batch[batch_index]) 235 | images = denormalize_image(images).astype(np.uint8) 236 | images = images[..., ::-1] # bgr -> 237 | 238 | # heatmaps 239 | volumes = to_numpy(volumes_batch[batch_index]) 240 | 241 | for row in range(n_rows): 242 | for col in range(n_cols): 243 | if col == 0: 244 | ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1) 245 | ax.set_ylabel(str(row), size='large') 246 | 247 | cuboid = cuboids_batch[batch_index] 248 | ax.imshow(cuboid.render(proj_matricies_batch[batch_index, row].detach().cpu().numpy(), images[row].copy())) 249 | else: 250 | ax = fig.add_subplot(n_rows, n_cols, row * n_cols + col + 1, projection='3d') 251 | 252 | if row == 0: 253 | joint_name = JOINT_NAMES_DICT[kind][col - 1] if kind in JOINT_NAMES_DICT else str(col - 1) 254 | ax.set_title(joint_name) 255 | 256 | draw_voxels(volumes[col - 1], ax, norm=True) 257 | 258 | fig.tight_layout() 259 | 260 | fig_image = fig_to_array(fig) 261 | 262 | plt.close('all') 263 | 264 | return fig_image 265 | 266 | 267 | def draw_2d_pose(keypoints, ax, kind='cmu', keypoints_mask=None, point_size=2, line_width=1, radius=None, color=None): 268 | """ 269 | Visualizes a 2d skeleton 270 | 271 | Args 272 | keypoints numpy array of shape (19, 2): pose to draw in CMU format. 
273 | ax: matplotlib axis to draw on 274 | """ 275 | connectivity = CONNECTIVITY_DICT[kind] 276 | 277 | color = 'blue' if color is None else color 278 | 279 | if keypoints_mask is None: 280 | keypoints_mask = [True] * len(keypoints) 281 | 282 | # points 283 | ax.scatter(keypoints[keypoints_mask][:, 0], keypoints[keypoints_mask][:, 1], c='red', s=point_size) 284 | 285 | # connections 286 | for (index_from, index_to) in connectivity: 287 | if keypoints_mask[index_from] and keypoints_mask[index_to]: 288 | xs, ys = [np.array([keypoints[index_from, j], keypoints[index_to, j]]) for j in range(2)] 289 | ax.plot(xs, ys, c=color, lw=line_width) 290 | 291 | if radius is not None: 292 | root_keypoint_index = 0 293 | xroot, yroot = keypoints[root_keypoint_index, 0], keypoints[root_keypoint_index, 1] 294 | 295 | ax.set_xlim([-radius + xroot, radius + xroot]) 296 | ax.set_ylim([-radius + yroot, radius + yroot]) 297 | 298 | ax.set_aspect('equal') 299 | 300 | 301 | def draw_2d_pose_cv2(keypoints, canvas, kind='cmu', keypoints_mask=None, point_size=2, point_color=(255, 255, 255), line_width=1, radius=None, color=None, anti_aliasing_scale=1): 302 | canvas = canvas.copy() 303 | 304 | shape = np.array(canvas.shape[:2]) 305 | new_shape = shape * anti_aliasing_scale 306 | canvas = resize_image(canvas, tuple(new_shape)) 307 | 308 | keypoints = keypoints * anti_aliasing_scale 309 | point_size = point_size * anti_aliasing_scale 310 | line_width = line_width * anti_aliasing_scale 311 | 312 | connectivity = CONNECTIVITY_DICT[kind] 313 | 314 | color = 'blue' if color is None else color 315 | 316 | if keypoints_mask is None: 317 | keypoints_mask = [True] * len(keypoints) 318 | 319 | # connections 320 | for i, (index_from, index_to) in enumerate(connectivity): 321 | if keypoints_mask[index_from] and keypoints_mask[index_to]: 322 | pt_from = tuple(np.array(keypoints[index_from, :]).astype(int)) 323 | pt_to = tuple(np.array(keypoints[index_to, :]).astype(int)) 324 | 325 | if kind in COLOR_DICT: 326 | color = COLOR_DICT[kind][i] 327 | else: 328 | color = (0, 0, 255) 329 | 330 | cv2.line(canvas, pt_from, pt_to, color=color, thickness=line_width) 331 | 332 | if kind == 'coco': 333 | mid_collarbone = (keypoints[5, :] + keypoints[6, :]) / 2 334 | nose = keypoints[0, :] 335 | 336 | pt_from = tuple(np.array(nose).astype(int)) 337 | pt_to = tuple(np.array(mid_collarbone).astype(int)) 338 | 339 | if kind in COLOR_DICT: 340 | color = (153, 0, 51) 341 | else: 342 | color = (0, 0, 255) 343 | 344 | cv2.line(canvas, pt_from, pt_to, color=color, thickness=line_width) 345 | 346 | # points 347 | for pt in keypoints[keypoints_mask]: 348 | cv2.circle(canvas, tuple(pt.astype(int)), point_size, color=point_color, thickness=-1) 349 | 350 | canvas = resize_image(canvas, tuple(shape)) 351 | 352 | return canvas 353 | 354 | 355 | def draw_3d_pose(keypoints, ax, keypoints_mask=None, kind='cmu', radius=None, root=None, point_size=2, line_width=2, draw_connections=True): 356 | connectivity = CONNECTIVITY_DICT[kind] 357 | 358 | if keypoints_mask is None: 359 | keypoints_mask = [True] * len(keypoints) 360 | 361 | if draw_connections: 362 | # Make connection matrix 363 | for i, joint in enumerate(connectivity): 364 | if keypoints_mask[joint[0]] and keypoints_mask[joint[1]]: 365 | xs, ys, zs = [np.array([keypoints[joint[0], j], keypoints[joint[1], j]]) for j in range(3)] 366 | 367 | if kind in COLOR_DICT: 368 | color = COLOR_DICT[kind][i] 369 | else: 370 | color = (0, 0, 255) 371 | 372 | color = np.array(color) / 255 373 | 374 | ax.plot(xs, ys, zs, 
lw=line_width, c=color) 375 | 376 | if kind == 'coco': 377 | mid_collarbone = (keypoints[5, :] + keypoints[6, :]) / 2 378 | nose = keypoints[0, :] 379 | 380 | xs, ys, zs = [np.array([nose[j], mid_collarbone[j]]) for j in range(3)] 381 | 382 | if kind in COLOR_DICT: 383 | color = (153, 0, 51) 384 | else: 385 | color = (0, 0, 255) 386 | 387 | color = np.array(color) / 255 388 | 389 | ax.plot(xs, ys, zs, lw=line_width, c=color) 390 | 391 | 392 | ax.scatter(keypoints[keypoints_mask][:, 0], keypoints[keypoints_mask][:, 1], keypoints[keypoints_mask][:, 2], 393 | s=point_size, c=np.array([230, 145, 56])/255, edgecolors='black') # np.array([230, 145, 56])/255 394 | 395 | if radius is not None: 396 | if root is None: 397 | root = np.mean(keypoints, axis=0) 398 | xroot, yroot, zroot = root 399 | ax.set_xlim([-radius + xroot, radius + xroot]) 400 | ax.set_ylim([-radius + yroot, radius + yroot]) 401 | ax.set_zlim([-radius + zroot, radius + zroot]) 402 | 403 | ax.set_aspect('equal') 404 | 405 | 406 | # Get rid of the panes 407 | background_color = np.array([252, 252, 252]) / 255 408 | 409 | ax.w_xaxis.set_pane_color(background_color) 410 | ax.w_yaxis.set_pane_color(background_color) 411 | ax.w_zaxis.set_pane_color(background_color) 412 | 413 | # Get rid of the ticks 414 | ax.set_xticklabels([]) 415 | ax.set_yticklabels([]) 416 | ax.set_zticklabels([]) 417 | 418 | 419 | def draw_voxels(voxels, ax, shape=(8, 8, 8), norm=True, alpha=0.1): 420 | # resize for visualization 421 | zoom = np.array(shape) / np.array(voxels.shape) 422 | voxels = skimage.transform.resize(voxels, shape, mode='constant', anti_aliasing=True) 423 | voxels = voxels.transpose(2, 0, 1) 424 | 425 | if norm and voxels.max() - voxels.min() > 0: 426 | voxels = (voxels - voxels.min()) / (voxels.max() - voxels.min()) 427 | 428 | filled = np.ones(voxels.shape) 429 | 430 | # facecolors 431 | cmap = plt.get_cmap("Blues") 432 | 433 | facecolors_a = cmap(voxels, alpha=alpha) 434 | facecolors_a = facecolors_a.reshape(-1, 4) 435 | 436 | facecolors_hex = np.array(list(map(lambda x: matplotlib.colors.to_hex(x, keep_alpha=True), facecolors_a))) 437 | facecolors_hex = facecolors_hex.reshape(*voxels.shape) 438 | 439 | # explode voxels to perform 3d alpha rendering (https://matplotlib.org/devdocs/gallery/mplot3d/voxels_numpy_logo.html) 440 | def explode(data): 441 | size = np.array(data.shape) * 2 442 | data_e = np.zeros(size - 1, dtype=data.dtype) 443 | data_e[::2, ::2, ::2] = data 444 | return data_e 445 | 446 | filled_2 = explode(filled) 447 | facecolors_2 = explode(facecolors_hex) 448 | 449 | # shrink the gaps 450 | x, y, z = np.indices(np.array(filled_2.shape) + 1).astype(float) // 2 451 | x[0::2, :, :] += 0.05 452 | y[:, 0::2, :] += 0.05 453 | z[:, :, 0::2] += 0.05 454 | x[1::2, :, :] += 0.95 455 | y[:, 1::2, :] += 0.95 456 | z[:, :, 1::2] += 0.95 457 | 458 | # draw voxels 459 | ax.voxels(x, y, z, filled_2, facecolors=facecolors_2) 460 | 461 | ax.set_xlabel("z"); ax.set_ylabel("x"); ax.set_zlabel("y") 462 | ax.invert_xaxis(); ax.invert_zaxis() 463 | -------------------------------------------------------------------------------- /mvn/utils/volumetric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | 5 | from mvn.utils import multiview 6 | 7 | 8 | class Point3D: 9 | def __init__(self, point, size=3, color=(0, 0, 255)): 10 | self.point = point 11 | self.size = size 12 | self.color = color 13 | 14 | def render(self, proj_matrix, canvas): 15 | point_2d = 
multiview.project_3d_points_to_image_plane_without_distortion( 16 | proj_matrix, np.array([self.point]) 17 | )[0] 18 | 19 | point_2d = tuple(map(int, point_2d)) 20 | cv2.circle(canvas, point_2d, self.size, self.color, self.size) 21 | 22 | return canvas 23 | 24 | 25 | class Line3D: 26 | def __init__(self, start_point, end_point, size=2, color=(0, 0, 255)): 27 | self.start_point, self.end_point = start_point, end_point 28 | self.size = size 29 | self.color = color 30 | 31 | def render(self, proj_matrix, canvas): 32 | start_point_2d, end_point_2d = multiview.project_3d_points_to_image_plane_without_distortion( 33 | proj_matrix, np.array([self.start_point, self.end_point]) 34 | ) 35 | 36 | start_point_2d = tuple(map(int, start_point_2d)) 37 | end_point_2d = tuple(map(int, end_point_2d)) 38 | 39 | cv2.line(canvas, start_point_2d, end_point_2d, self.color, self.size) 40 | 41 | return canvas 42 | 43 | 44 | class Cuboid3D: 45 | def __init__(self, position, sides): 46 | self.position = position 47 | self.sides = sides 48 | 49 | def build(self): 50 | primitives = [] 51 | 52 | line_color = (255, 255, 0) 53 | 54 | start = self.position + np.array([0, 0, 0]) 55 | primitives.append(Line3D(start, start + np.array([self.sides[0], 0, 0]), color=(255, 0, 0))) 56 | primitives.append(Line3D(start, start + np.array([0, self.sides[1], 0]), color=(0, 255, 0))) 57 | primitives.append(Line3D(start, start + np.array([0, 0, self.sides[2]]), color=(0, 0, 255))) 58 | 59 | start = self.position + np.array([self.sides[0], 0, self.sides[2]]) 60 | primitives.append(Line3D(start, start + np.array([-self.sides[0], 0, 0]), color=line_color)) 61 | primitives.append(Line3D(start, start + np.array([0, self.sides[1], 0]), color=line_color)) 62 | primitives.append(Line3D(start, start + np.array([0, 0, -self.sides[2]]), color=line_color)) 63 | 64 | start = self.position + np.array([self.sides[0], self.sides[1], 0]) 65 | primitives.append(Line3D(start, start + np.array([-self.sides[0], 0, 0]), color=line_color)) 66 | primitives.append(Line3D(start, start + np.array([0, -self.sides[1], 0]), color=line_color)) 67 | primitives.append(Line3D(start, start + np.array([0, 0, self.sides[2]]), color=line_color)) 68 | 69 | start = self.position + np.array([0, self.sides[1], self.sides[2]]) 70 | primitives.append(Line3D(start, start + np.array([self.sides[0], 0, 0]), color=line_color)) 71 | primitives.append(Line3D(start, start + np.array([0, -self.sides[1], 0]), color=line_color)) 72 | primitives.append(Line3D(start, start + np.array([0, 0, -self.sides[2]]), color=line_color)) 73 | 74 | return primitives 75 | 76 | def render(self, proj_matrix, canvas): 77 | # TODO: support rotation 78 | 79 | primitives = self.build() 80 | 81 | for primitive in primitives: 82 | canvas = primitive.render(proj_matrix, canvas) 83 | 84 | return canvas 85 | 86 | 87 | def get_rotation_matrix(axis, theta): 88 | """Returns the rotation matrix associated with counterclockwise rotation about 89 | the given axis by theta radians. 
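# --- Illustrative sketch (not part of the repository): rendering the Cuboid3D defined in
# --- mvn/utils/volumetric.py above onto a dummy canvas. The camera and cuboid parameters
# --- are placeholders chosen only so that the box projects in front of the camera.
import numpy as np
from mvn.utils.volumetric import Cuboid3D

K = np.array([[1000.0, 0.0, 320.0], [0.0, 1000.0, 240.0], [0.0, 0.0, 1.0]])
P = K @ np.hstack([np.eye(3), np.array([[0.0], [0.0], [5000.0]])])   # 3x4 projection

canvas = np.zeros((480, 640, 3), dtype=np.uint8)
cuboid = Cuboid3D(position=np.array([-1250.0, -1250.0, -1250.0]),
                  sides=np.array([2500.0, 2500.0, 2500.0]))
canvas = cuboid.render(P, canvas)     # draws the 12 cuboid edges with cv2.line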
90 | """ 91 | axis = np.asarray(axis) 92 | axis = axis / np.sqrt(np.dot(axis, axis)) 93 | a = np.cos(theta / 2.0) 94 | b, c, d = -axis * np.sin(theta / 2.0) 95 | aa, bb, cc, dd = a * a, b * b, c * c, d * d 96 | bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d 97 | return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], 98 | [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], 99 | [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) 100 | 101 | 102 | def rotate_coord_volume(coord_volume, theta, axis): 103 | shape = coord_volume.shape 104 | device = coord_volume.device 105 | 106 | rot = get_rotation_matrix(axis, theta) 107 | rot = torch.from_numpy(rot).type(torch.float).to(device) 108 | 109 | coord_volume = coord_volume.view(-1, 3) 110 | coord_volume = rot.mm(coord_volume.t()).t() 111 | 112 | coord_volume = coord_volume.view(*shape) 113 | 114 | return coord_volume 115 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cycler==0.10.0 2 | decorator==4.4.0 3 | easydict==1.9 4 | imageio==2.5.0 5 | kiwisolver==1.1.0 6 | matplotlib==3.1.1 7 | networkx==2.3 8 | numpy==1.17.2 9 | opencv-python==4.1.1.26 10 | Pillow==6.2.0 11 | protobuf==3.10.0 12 | pyparsing==2.4.2 13 | python-dateutil==2.8.0 14 | PyWavelets==1.0.3 15 | PyYAML==5.1.2 16 | scikit-image==0.15.0 17 | scipy==1.3.1 18 | six==1.12.0 19 | tensorboardX==1.8 20 | torch==1.0.1 21 | torchvision==0.2.2 22 | h5py==2.10.0 23 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import argparse 4 | import time 5 | import json 6 | from datetime import datetime 7 | from collections import defaultdict 8 | from itertools import islice 9 | import pickle 10 | import copy 11 | 12 | import numpy as np 13 | import cv2 14 | 15 | import torch 16 | from torch import nn 17 | from torch import autograd 18 | import torch.nn.functional as F 19 | import torch.optim as optim 20 | from torch.utils.data import DataLoader 21 | from torch.nn.parallel import DistributedDataParallel 22 | 23 | from tensorboardX import SummaryWriter 24 | 25 | from mvn.models.triangulation import RANSACTriangulationNet, AlgebraicTriangulationNet, VolumetricTriangulationNet 26 | from mvn.models.loss import KeypointsMSELoss, KeypointsMSESmoothLoss, KeypointsMAELoss, KeypointsL2Loss, VolumetricCELoss 27 | 28 | from mvn.utils import img, multiview, op, vis, misc, cfg 29 | from mvn.datasets import human36m 30 | from mvn.datasets import utils as dataset_utils 31 | 32 | 33 | def parse_args(): 34 | parser = argparse.ArgumentParser() 35 | 36 | parser.add_argument("--config", type=str, required=True, help="Path, where config file is stored") 37 | parser.add_argument('--eval', action='store_true', help="If set, then only evaluation will be done") 38 | parser.add_argument('--eval_dataset', type=str, default='val', help="Dataset split on which evaluate. 
Can be 'train' and 'val'") 39 | 40 | parser.add_argument("--local_rank", type=int, help="Local rank of the process on the node") 41 | parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility") 42 | 43 | parser.add_argument("--logdir", type=str, default="/Vol1/dbstore/datasets/k.iskakov/logs/multi-view-net-repr", help="Path, where logs will be stored") 44 | 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | def setup_human36m_dataloaders(config, is_train, distributed_train): 50 | train_dataloader = None 51 | if is_train: 52 | # train 53 | train_dataset = human36m.Human36MMultiViewDataset( 54 | h36m_root=config.dataset.train.h36m_root, 55 | pred_results_path=config.dataset.train.pred_results_path if hasattr(config.dataset.train, "pred_results_path") else None, 56 | train=True, 57 | test=False, 58 | image_shape=config.image_shape if hasattr(config, "image_shape") else (256, 256), 59 | labels_path=config.dataset.train.labels_path, 60 | with_damaged_actions=config.dataset.train.with_damaged_actions, 61 | scale_bbox=config.dataset.train.scale_bbox, 62 | kind=config.kind, 63 | undistort_images=config.dataset.train.undistort_images, 64 | ignore_cameras=config.dataset.train.ignore_cameras if hasattr(config.dataset.train, "ignore_cameras") else [], 65 | crop=config.dataset.train.crop if hasattr(config.dataset.train, "crop") else True, 66 | ) 67 | 68 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed_train else None 69 | 70 | train_dataloader = DataLoader( 71 | train_dataset, 72 | batch_size=config.opt.batch_size, 73 | shuffle=config.dataset.train.shuffle and (train_sampler is None), # debatable 74 | sampler=train_sampler, 75 | collate_fn=dataset_utils.make_collate_fn(randomize_n_views=config.dataset.train.randomize_n_views, 76 | min_n_views=config.dataset.train.min_n_views, 77 | max_n_views=config.dataset.train.max_n_views), 78 | num_workers=config.dataset.train.num_workers, 79 | worker_init_fn=dataset_utils.worker_init_fn, 80 | pin_memory=True 81 | ) 82 | 83 | # val 84 | val_dataset = human36m.Human36MMultiViewDataset( 85 | h36m_root=config.dataset.val.h36m_root, 86 | pred_results_path=config.dataset.val.pred_results_path if hasattr(config.dataset.val, "pred_results_path") else None, 87 | train=False, 88 | test=True, 89 | image_shape=config.image_shape if hasattr(config, "image_shape") else (256, 256), 90 | labels_path=config.dataset.val.labels_path, 91 | with_damaged_actions=config.dataset.val.with_damaged_actions, 92 | retain_every_n_frames_in_test=config.dataset.val.retain_every_n_frames_in_test, 93 | scale_bbox=config.dataset.val.scale_bbox, 94 | kind=config.kind, 95 | undistort_images=config.dataset.val.undistort_images, 96 | ignore_cameras=config.dataset.val.ignore_cameras if hasattr(config.dataset.val, "ignore_cameras") else [], 97 | crop=config.dataset.val.crop if hasattr(config.dataset.val, "crop") else True, 98 | ) 99 | 100 | val_dataloader = DataLoader( 101 | val_dataset, 102 | batch_size=config.opt.val_batch_size if hasattr(config.opt, "val_batch_size") else config.opt.batch_size, 103 | shuffle=config.dataset.val.shuffle, 104 | collate_fn=dataset_utils.make_collate_fn(randomize_n_views=config.dataset.val.randomize_n_views, 105 | min_n_views=config.dataset.val.min_n_views, 106 | max_n_views=config.dataset.val.max_n_views), 107 | num_workers=config.dataset.val.num_workers, 108 | worker_init_fn=dataset_utils.worker_init_fn, 109 | pin_memory=True 110 | ) 111 | 112 | return train_dataloader, 
val_dataloader, train_sampler 113 | 114 | 115 | def setup_dataloaders(config, is_train=True, distributed_train=False): 116 | if config.dataset.kind == 'human36m': 117 | train_dataloader, val_dataloader, train_sampler = setup_human36m_dataloaders(config, is_train, distributed_train) 118 | else: 119 | raise NotImplementedError("Unknown dataset: {}".format(config.dataset.kind)) 120 | 121 | return train_dataloader, val_dataloader, train_sampler 122 | 123 | 124 | def setup_experiment(config, model_name, is_train=True): 125 | prefix = "" if is_train else "eval_" 126 | 127 | if config.title: 128 | experiment_title = config.title + "_" + model_name 129 | else: 130 | experiment_title = model_name 131 | 132 | experiment_title = prefix + experiment_title 133 | 134 | experiment_name = '{}@{}'.format(experiment_title, datetime.now().strftime("%d.%m.%Y-%H:%M:%S")) 135 | print("Experiment name: {}".format(experiment_name)) 136 | 137 | experiment_dir = os.path.join(args.logdir, experiment_name) 138 | os.makedirs(experiment_dir, exist_ok=True) 139 | 140 | checkpoints_dir = os.path.join(experiment_dir, "checkpoints") 141 | os.makedirs(checkpoints_dir, exist_ok=True) 142 | 143 | shutil.copy(args.config, os.path.join(experiment_dir, "config.yaml")) 144 | 145 | # tensorboard 146 | writer = SummaryWriter(os.path.join(experiment_dir, "tb")) 147 | 148 | # dump config to tensorboard 149 | writer.add_text(misc.config_to_str(config), "config", 0) 150 | 151 | return experiment_dir, writer 152 | 153 | 154 | def one_epoch(model, criterion, opt, config, dataloader, device, epoch, n_iters_total=0, is_train=True, caption='', master=False, experiment_dir=None, writer=None): 155 | name = "train" if is_train else "val" 156 | model_type = config.model.name 157 | 158 | if is_train: 159 | model.train() 160 | else: 161 | model.eval() 162 | 163 | metric_dict = defaultdict(list) 164 | 165 | results = defaultdict(list) 166 | 167 | # used to turn on/off gradients 168 | grad_context = torch.autograd.enable_grad if is_train else torch.no_grad 169 | with grad_context(): 170 | end = time.time() 171 | 172 | iterator = enumerate(dataloader) 173 | if is_train and config.opt.n_iters_per_epoch is not None: 174 | iterator = islice(iterator, config.opt.n_iters_per_epoch) 175 | 176 | for iter_i, batch in iterator: 177 | with autograd.detect_anomaly(): 178 | # measure data loading time 179 | data_time = time.time() - end 180 | 181 | if batch is None: 182 | print("Found None batch") 183 | continue 184 | 185 | images_batch, keypoints_3d_gt, keypoints_3d_validity_gt, proj_matricies_batch = dataset_utils.prepare_batch(batch, device, config) 186 | 187 | keypoints_2d_pred, cuboids_pred, base_points_pred = None, None, None 188 | if model_type == "alg" or model_type == "ransac": 189 | keypoints_3d_pred, keypoints_2d_pred, heatmaps_pred, confidences_pred = model(images_batch, proj_matricies_batch, batch) 190 | elif model_type == "vol": 191 | keypoints_3d_pred, heatmaps_pred, volumes_pred, confidences_pred, cuboids_pred, coord_volumes_pred, base_points_pred = model(images_batch, proj_matricies_batch, batch) 192 | 193 | batch_size, n_views, image_shape = images_batch.shape[0], images_batch.shape[1], tuple(images_batch.shape[3:]) 194 | n_joints = keypoints_3d_pred.shape[1] 195 | 196 | keypoints_3d_binary_validity_gt = (keypoints_3d_validity_gt > 0.0).type(torch.float32) 197 | 198 | scale_keypoints_3d = config.opt.scale_keypoints_3d if hasattr(config.opt, "scale_keypoints_3d") else 1.0 199 | 200 | # 1-view case 201 | if n_views == 1: 202 | if config.kind == 
"human36m": 203 | base_joint = 6 204 | elif config.kind == "coco": 205 | base_joint = 11 206 | 207 | keypoints_3d_gt_transformed = keypoints_3d_gt.clone() 208 | keypoints_3d_gt_transformed[:, torch.arange(n_joints) != base_joint] -= keypoints_3d_gt_transformed[:, base_joint:base_joint + 1] 209 | keypoints_3d_gt = keypoints_3d_gt_transformed 210 | 211 | keypoints_3d_pred_transformed = keypoints_3d_pred.clone() 212 | keypoints_3d_pred_transformed[:, torch.arange(n_joints) != base_joint] -= keypoints_3d_pred_transformed[:, base_joint:base_joint + 1] 213 | keypoints_3d_pred = keypoints_3d_pred_transformed 214 | 215 | # calculate loss 216 | total_loss = 0.0 217 | loss = criterion(keypoints_3d_pred * scale_keypoints_3d, keypoints_3d_gt * scale_keypoints_3d, keypoints_3d_binary_validity_gt) 218 | total_loss += loss 219 | metric_dict[f'{config.opt.criterion}'].append(loss.item()) 220 | 221 | # volumetric ce loss 222 | use_volumetric_ce_loss = config.opt.use_volumetric_ce_loss if hasattr(config.opt, "use_volumetric_ce_loss") else False 223 | if use_volumetric_ce_loss: 224 | volumetric_ce_criterion = VolumetricCELoss() 225 | 226 | loss = volumetric_ce_criterion(coord_volumes_pred, volumes_pred, keypoints_3d_gt, keypoints_3d_binary_validity_gt) 227 | metric_dict['volumetric_ce_loss'].append(loss.item()) 228 | 229 | weight = config.opt.volumetric_ce_loss_weight if hasattr(config.opt, "volumetric_ce_loss_weight") else 1.0 230 | total_loss += weight * loss 231 | 232 | metric_dict['total_loss'].append(total_loss.item()) 233 | 234 | if is_train: 235 | opt.zero_grad() 236 | total_loss.backward() 237 | 238 | if hasattr(config.opt, "grad_clip"): 239 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.opt.grad_clip / config.opt.lr) 240 | 241 | metric_dict['grad_norm_times_lr'].append(config.opt.lr * misc.calc_gradient_norm(filter(lambda x: x[1].requires_grad, model.named_parameters()))) 242 | 243 | opt.step() 244 | 245 | # calculate metrics 246 | l2 = KeypointsL2Loss()(keypoints_3d_pred * scale_keypoints_3d, keypoints_3d_gt * scale_keypoints_3d, keypoints_3d_binary_validity_gt) 247 | metric_dict['l2'].append(l2.item()) 248 | 249 | # base point l2 250 | if base_points_pred is not None: 251 | base_point_l2_list = [] 252 | for batch_i in range(batch_size): 253 | base_point_pred = base_points_pred[batch_i] 254 | 255 | if config.model.kind == "coco": 256 | base_point_gt = (keypoints_3d_gt[batch_i, 11, :3] + keypoints_3d[batch_i, 12, :3]) / 2 257 | elif config.model.kind == "mpii": 258 | base_point_gt = keypoints_3d_gt[batch_i, 6, :3] 259 | 260 | base_point_l2_list.append(torch.sqrt(torch.sum((base_point_pred * scale_keypoints_3d - base_point_gt * scale_keypoints_3d) ** 2)).item()) 261 | 262 | base_point_l2 = 0.0 if len(base_point_l2_list) == 0 else np.mean(base_point_l2_list) 263 | metric_dict['base_point_l2'].append(base_point_l2) 264 | 265 | # save answers for evalulation 266 | if not is_train: 267 | results['keypoints_3d'].append(keypoints_3d_pred.detach().cpu().numpy()) 268 | results['indexes'].append(batch['indexes']) 269 | 270 | # plot visualization 271 | if master: 272 | if n_iters_total % config.vis_freq == 0:# or total_l2.item() > 500.0: 273 | vis_kind = config.kind 274 | if (config.transfer_cmu_to_human36m if hasattr(config, "transfer_cmu_to_human36m") else False): 275 | vis_kind = "coco" 276 | 277 | for batch_i in range(min(batch_size, config.vis_n_elements)): 278 | keypoints_vis = vis.visualize_batch( 279 | images_batch, heatmaps_pred, keypoints_2d_pred, proj_matricies_batch, 280 | 
keypoints_3d_gt, keypoints_3d_pred, 281 | kind=vis_kind, 282 | cuboids_batch=cuboids_pred, 283 | confidences_batch=confidences_pred, 284 | batch_index=batch_i, size=5, 285 | max_n_cols=10 286 | ) 287 | writer.add_image(f"{name}/keypoints_vis/{batch_i}", keypoints_vis.transpose(2, 0, 1), global_step=n_iters_total) 288 | 289 | heatmaps_vis = vis.visualize_heatmaps( 290 | images_batch, heatmaps_pred, 291 | kind=vis_kind, 292 | batch_index=batch_i, size=5, 293 | max_n_rows=10, max_n_cols=10 294 | ) 295 | writer.add_image(f"{name}/heatmaps/{batch_i}", heatmaps_vis.transpose(2, 0, 1), global_step=n_iters_total) 296 | 297 | if model_type == "vol": 298 | volumes_vis = vis.visualize_volumes( 299 | images_batch, volumes_pred, proj_matricies_batch, 300 | kind=vis_kind, 301 | cuboids_batch=cuboids_pred, 302 | batch_index=batch_i, size=5, 303 | max_n_rows=1, max_n_cols=16 304 | ) 305 | writer.add_image(f"{name}/volumes/{batch_i}", volumes_vis.transpose(2, 0, 1), global_step=n_iters_total) 306 | 307 | # dump weights to tensoboard 308 | if n_iters_total % config.vis_freq == 0: 309 | for p_name, p in model.named_parameters(): 310 | try: 311 | writer.add_histogram(p_name, p.clone().cpu().data.numpy(), n_iters_total) 312 | except ValueError as e: 313 | print(e) 314 | print(p_name, p) 315 | exit() 316 | 317 | # dump to tensorboard per-iter loss/metric stats 318 | if is_train: 319 | for title, value in metric_dict.items(): 320 | writer.add_scalar(f"{name}/{title}", value[-1], n_iters_total) 321 | 322 | # measure elapsed time 323 | batch_time = time.time() - end 324 | end = time.time() 325 | 326 | # dump to tensorboard per-iter time stats 327 | writer.add_scalar(f"{name}/batch_time", batch_time, n_iters_total) 328 | writer.add_scalar(f"{name}/data_time", data_time, n_iters_total) 329 | 330 | # dump to tensorboard per-iter stats about sizes 331 | writer.add_scalar(f"{name}/batch_size", batch_size, n_iters_total) 332 | writer.add_scalar(f"{name}/n_views", n_views, n_iters_total) 333 | 334 | n_iters_total += 1 335 | 336 | # calculate evaluation metrics 337 | if master: 338 | if not is_train: 339 | results['keypoints_3d'] = np.concatenate(results['keypoints_3d'], axis=0) 340 | results['indexes'] = np.concatenate(results['indexes']) 341 | 342 | try: 343 | scalar_metric, full_metric = dataloader.dataset.evaluate(results['keypoints_3d']) 344 | except Exception as e: 345 | print("Failed to evaluate. 
Reason: ", e) 346 | scalar_metric, full_metric = 0.0, {} 347 | 348 | metric_dict['dataset_metric'].append(scalar_metric) 349 | 350 | checkpoint_dir = os.path.join(experiment_dir, "checkpoints", "{:04}".format(epoch)) 351 | os.makedirs(checkpoint_dir, exist_ok=True) 352 | 353 | # dump results 354 | with open(os.path.join(checkpoint_dir, "results.pkl"), 'wb') as fout: 355 | pickle.dump(results, fout) 356 | 357 | # dump full metric 358 | with open(os.path.join(checkpoint_dir, "metric.json".format(epoch)), 'w') as fout: 359 | json.dump(full_metric, fout, indent=4, sort_keys=True) 360 | 361 | # dump to tensorboard per-epoch stats 362 | for title, value in metric_dict.items(): 363 | writer.add_scalar(f"{name}/{title}_epoch", np.mean(value), epoch) 364 | 365 | return n_iters_total 366 | 367 | 368 | def init_distributed(args): 369 | if "WORLD_SIZE" not in os.environ or int(os.environ["WORLD_SIZE"]) < 1: 370 | return False 371 | 372 | torch.cuda.set_device(args.local_rank) 373 | 374 | assert os.environ["MASTER_PORT"], "set the MASTER_PORT variable or use pytorch launcher" 375 | assert os.environ["RANK"], "use pytorch launcher and explicityly state the rank of the process" 376 | 377 | torch.manual_seed(args.seed) 378 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 379 | 380 | return True 381 | 382 | 383 | def main(args): 384 | print("Number of available GPUs: {}".format(torch.cuda.device_count())) 385 | 386 | is_distributed = init_distributed(args) 387 | master = True 388 | if is_distributed and os.environ["RANK"]: 389 | master = int(os.environ["RANK"]) == 0 390 | 391 | if is_distributed: 392 | device = torch.device(args.local_rank) 393 | else: 394 | device = torch.device(0) 395 | 396 | # config 397 | config = cfg.load_config(args.config) 398 | config.opt.n_iters_per_epoch = config.opt.n_objects_per_epoch // config.opt.batch_size 399 | 400 | model = { 401 | "ransac": RANSACTriangulationNet, 402 | "alg": AlgebraicTriangulationNet, 403 | "vol": VolumetricTriangulationNet 404 | }[config.model.name](config, device=device).to(device) 405 | 406 | if config.model.init_weights: 407 | state_dict = torch.load(config.model.checkpoint) 408 | for key in list(state_dict.keys()): 409 | new_key = key.replace("module.", "") 410 | state_dict[new_key] = state_dict.pop(key) 411 | 412 | model.load_state_dict(state_dict, strict=True) 413 | print("Successfully loaded pretrained weights for whole model") 414 | 415 | # criterion 416 | criterion_class = { 417 | "MSE": KeypointsMSELoss, 418 | "MSESmooth": KeypointsMSESmoothLoss, 419 | "MAE": KeypointsMAELoss 420 | }[config.opt.criterion] 421 | 422 | if config.opt.criterion == "MSESmooth": 423 | criterion = criterion_class(config.opt.mse_smooth_threshold) 424 | else: 425 | criterion = criterion_class() 426 | 427 | # optimizer 428 | opt = None 429 | if not args.eval: 430 | if config.model.name == "vol": 431 | opt = torch.optim.Adam( 432 | [{'params': model.backbone.parameters()}, 433 | {'params': model.process_features.parameters(), 'lr': config.opt.process_features_lr if hasattr(config.opt, "process_features_lr") else config.opt.lr}, 434 | {'params': model.volume_net.parameters(), 'lr': config.opt.volume_net_lr if hasattr(config.opt, "volume_net_lr") else config.opt.lr} 435 | ], 436 | lr=config.opt.lr 437 | ) 438 | else: 439 | opt = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.opt.lr) 440 | 441 | 442 | # datasets 443 | print("Loading data...") 444 | train_dataloader, val_dataloader, train_sampler = 
setup_dataloaders(config, distributed_train=is_distributed) 445 | 446 | # experiment 447 | experiment_dir, writer = None, None 448 | if master: 449 | experiment_dir, writer = setup_experiment(config, type(model).__name__, is_train=not args.eval) 450 | 451 | # multi-gpu 452 | if is_distributed: 453 | model = DistributedDataParallel(model, device_ids=[device]) 454 | 455 | if not args.eval: 456 | # train loop 457 | n_iters_total_train, n_iters_total_val = 0, 0 458 | for epoch in range(config.opt.n_epochs): 459 | if train_sampler is not None: 460 | train_sampler.set_epoch(epoch) 461 | 462 | n_iters_total_train = one_epoch(model, criterion, opt, config, train_dataloader, device, epoch, n_iters_total=n_iters_total_train, is_train=True, master=master, experiment_dir=experiment_dir, writer=writer) 463 | n_iters_total_val = one_epoch(model, criterion, opt, config, val_dataloader, device, epoch, n_iters_total=n_iters_total_val, is_train=False, master=master, experiment_dir=experiment_dir, writer=writer) 464 | 465 | if master: 466 | checkpoint_dir = os.path.join(experiment_dir, "checkpoints", "{:04}".format(epoch)) 467 | os.makedirs(checkpoint_dir, exist_ok=True) 468 | 469 | torch.save(model.state_dict(), os.path.join(checkpoint_dir, "weights.pth")) 470 | 471 | print(f"{n_iters_total_train} iters done.") 472 | else: 473 | if args.eval_dataset == 'train': 474 | one_epoch(model, criterion, opt, config, train_dataloader, device, 0, n_iters_total=0, is_train=False, master=master, experiment_dir=experiment_dir, writer=writer) 475 | else: 476 | one_epoch(model, criterion, opt, config, val_dataloader, device, 0, n_iters_total=0, is_train=False, master=master, experiment_dir=experiment_dir, writer=writer) 477 | 478 | print("Done.") 479 | 480 | if __name__ == '__main__': 481 | args = parse_args() 482 | print("args: {}".format(args)) 483 | main(args) 484 | --------------------------------------------------------------------------------
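# --- Illustrative sketch (not part of the repository): the model-selection logic used in
# --- main() above, reduced to a single-GPU smoke test. The config path is one of the
# --- provided experiment files; depending on that config, constructing the net may expect
# --- pretrained backbone checkpoints to be available, and data setup is omitted here.
import torch
from mvn.utils import cfg
from mvn.models.triangulation import RANSACTriangulationNet, AlgebraicTriangulationNet, VolumetricTriangulationNet

config = cfg.load_config("experiments/human36m/train/human36m_alg.yaml")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_class = {
    "ransac": RANSACTriangulationNet,
    "alg": AlgebraicTriangulationNet,
    "vol": VolumetricTriangulationNet,
}[config.model.name]
model = model_class(config, device=device).to(device)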