├── .gitignore ├── .gitmodules ├── LICENSE ├── NOTICE.md ├── README.md ├── azure-pipelines.yml ├── docs ├── CODE_OF_CONDUCT.md ├── DEMO.md ├── DOWNLOAD.md ├── EXP.md ├── Fig1.gif ├── Fig2.gif ├── Fig3.gif ├── Fig4.gif ├── INSTALL.md ├── SECURITY.md ├── SUPPORT.md └── graphormer_overview.png ├── requirements.txt ├── samples ├── hand │ ├── freihand_sample1.jpg │ ├── freihand_sample1_graphormer_pred.jpg │ ├── freihand_sample2.jpg │ ├── freihand_sample2_graphormer_pred.jpg │ ├── freihand_sample3.jpg │ ├── freihand_sample3_graphormer_pred.jpg │ ├── internet_fig1.jpg │ ├── internet_fig1_graphormer_pred.jpg │ ├── internet_fig2.jpg │ ├── internet_fig2_graphormer_pred.jpg │ ├── internet_fig3.jpg │ ├── internet_fig3_graphormer_pred.jpg │ ├── internet_fig4.jpg │ └── internet_fig4_graphormer_pred.jpg └── human-body │ ├── 3dpw_test1.jpg │ ├── 3dpw_test1_graphormer_pred.jpg │ ├── 3dpw_test2.jpg │ ├── 3dpw_test2_graphormer_pred.jpg │ ├── 3dpw_test3.jpg │ ├── 3dpw_test3_graphormer_pred.jpg │ ├── 3dpw_test4.jpg │ ├── 3dpw_test4_graphormer_pred.jpg │ ├── 3dpw_test5.jpg │ ├── 3dpw_test5_graphormer_pred.jpg │ ├── 3dpw_test6.jpg │ ├── 3dpw_test6_graphormer_pred.jpg │ ├── 3dpw_test7.jpg │ └── 3dpw_test7_graphormer_pred.jpg ├── scripts ├── download_models.sh └── download_preds.sh ├── setup.py └── src ├── __init__.py ├── datasets ├── __init__.py ├── build.py ├── hand_mesh_tsv.py └── human_mesh_tsv.py ├── modeling ├── __init__.py ├── _gcnn.py ├── _mano.py ├── _smpl.py ├── bert │ ├── __init__.py │ ├── bert-base-uncased │ │ └── config.json │ ├── e2e_body_network.py │ ├── e2e_hand_network.py │ ├── file_utils.py │ ├── modeling_bert.py │ ├── modeling_graphormer.py │ └── modeling_utils.py ├── data │ ├── J_regressor_extra.npy │ ├── J_regressor_h36m_correct.npy │ ├── README.md │ ├── config.py │ ├── mano_195_adjmat_indices.pt │ ├── mano_195_adjmat_size.pt │ ├── mano_195_adjmat_values.pt │ ├── mano_downsampling.npz │ ├── mesh_downsampling.npz │ ├── smpl_431_adjmat_indices.pt │ ├── smpl_431_adjmat_size.pt │ ├── smpl_431_adjmat_values.pt │ └── smpl_431_faces.npy └── hrnet │ ├── config │ ├── __init__.py │ ├── default.py │ └── models.py │ ├── hrnet_cls_net.py │ └── hrnet_cls_net_gridfeat.py ├── tools ├── run_gphmer_bodymesh.py ├── run_gphmer_bodymesh_inference.py ├── run_gphmer_handmesh.py ├── run_gphmer_handmesh_inference.py └── run_hand_multiscale.py └── utils ├── __init__.py ├── comm.py ├── dataset_utils.py ├── geometric_layers.py ├── image_ops.py ├── logger.py ├── metric_logger.py ├── metric_pampjpe.py ├── miscellaneous.py ├── renderer.py ├── tsv_file.py └── tsv_file_ops.py /.gitignore: -------------------------------------------------------------------------------- 1 | # compilation and distribution 2 | __pycache__ 3 | _ext 4 | *.pyc 5 | *.so 6 | build/ 7 | dist/ 8 | *.egg-info/ 9 | 10 | # pytorch/python formats 11 | *.pth 12 | *.pkl 13 | 14 | # ipython/jupyter notebooks 15 | *.ipynb 16 | **/.ipynb_checkpoints/ 17 | 18 | # Editor temporaries 19 | *.swn 20 | *.swo 21 | *.swp 22 | *~ 23 | 24 | # Pycharm editor settings 25 | .idea 26 | 27 | # VS code 28 | .vscode 29 | 30 | # MacOS 31 | .DS_Store 32 | 33 | # project dirs 34 | /data 35 | /datasets 36 | models 37 | output 38 | exps 39 | predictions 40 | 41 | # hidden folder for aml configs 42 | .azureml 43 | .git -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "manopth"] 2 | path = manopth 3 | url = https://github.com/hassony2/manopth.git 
4 | [submodule "transformers"] 5 | path = transformers 6 | url = https://github.com/huggingface/transformers.git 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MeshGraphormer ✨✨ 2 | 3 | 4 | This is our research code for [Mesh Graphormer](https://arxiv.org/abs/2104.00272). 5 | 6 | Mesh Graphormer is a new transformer-based method for human pose and mesh reconstruction from an input image. In this work, we study how to combine graph convolutions and self-attention in a transformer to better model both local and global interactions. 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | ## Installation 15 | Check [INSTALL.md](docs/INSTALL.md) for installation instructions. 16 | 17 | 18 | ## Model Zoo and Download 19 | Please download our pre-trained models and the other files required to run our code. 20 | 21 | Check [DOWNLOAD.md](docs/DOWNLOAD.md) for details. 22 | 23 | ## Quick demo 24 | We provide demo code to run end-to-end inference on test images. 25 | 26 | Check [DEMO.md](docs/DEMO.md) for details. 27 | 28 | ## Experiments 29 | We provide Python code for training and evaluation. 30 | 31 | Check [EXP.md](docs/EXP.md) for details. 32 | 33 | 34 | ## License 35 | 36 | Our research code is released under the MIT license. See [LICENSE](LICENSE) for details. 37 | 38 | We use submodules from third parties, such as [huggingface/transformers](https://github.com/huggingface/transformers) and [hassony2/manopth](https://github.com/hassony2/manopth). Please see [NOTICE](NOTICE.md) for details. 39 | 40 | Our models depend on the SMPL and MANO models. Please note that any use of the SMPL and MANO models is subject to the **Software Copyright License for non-commercial scientific research purposes**. Please see the [SMPL-Model License](https://smpl.is.tue.mpg.de/modellicense) and [MANO License](https://mano.is.tue.mpg.de/license) for details. 41 | 42 | 43 | ## Contributing 44 | 45 | We welcome contributions and suggestions. Please check [CONTRIBUTE](docs/CONTRIBUTE.md) and [CODE_OF_CONDUCT](docs/CODE_OF_CONDUCT.md) for details.
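## Architecture sketch (illustrative)

The following is a minimal, self-contained PyTorch sketch of the core idea above: an encoder block that combines multi-head self-attention (global interactions across all joints/vertices) with a graph convolution over the mesh topology (local interactions). It is an illustration only, not the released implementation; the class names, dimensions, dense adjacency matrix, and layer ordering are our assumptions. The actual model lives in `src/modeling/bert/modeling_graphormer.py` and `src/modeling/_gcnn.py`, and builds its adjacency matrices from the files in `src/modeling/data`. Note that `batch_first=True` requires PyTorch ≥ 1.9, newer than the version pinned in INSTALL.md, so treat this purely as a reading aid.

```python
import torch
import torch.nn as nn

class GraphConv(nn.Module):
    """Toy graph convolution: out = adj @ (x W), with `adj` a normalized
    (V, V) mesh adjacency matrix (dense here for simplicity)."""
    def __init__(self, dim, adj):
        super().__init__()
        self.register_buffer("adj", adj)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):  # x: (batch, V, dim)
        return torch.matmul(self.adj, self.proj(x))

class GraphormerBlock(nn.Module):
    """One encoder block mixing self-attention (global) with graph conv (local)."""
    def __init__(self, dim, num_heads, adj):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.gconv = GraphConv(dim, adj)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(),
                                 nn.Linear(4 * dim, dim))

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # global interactions
        x = x + self.gconv(x)                              # local interactions
        return x + self.mlp(self.norm2(x))

# Example: 431 coarse body-mesh vertices; identity adjacency as a placeholder
block = GraphormerBlock(dim=64, num_heads=4, adj=torch.eye(431))
out = block(torch.randn(2, 431, 64))  # -> (2, 431, 64)
```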
46 | 47 | 48 | ## Citations 49 | If you find our work useful in your research, please consider citing: 50 | 51 | ```bibtex 52 | @inproceedings{lin2021-mesh-graphormer, 53 | author = {Lin, Kevin and Wang, Lijuan and Liu, Zicheng}, 54 | title = {Mesh Graphormer}, 55 | booktitle = {ICCV}, 56 | year = {2021}, 57 | } 58 | ``` 59 | 60 | 61 | ## Acknowledgments 62 | 63 | Our implementation and experiments are built on top of open-source GitHub repositories. We thank all the authors who made their code public, which tremendously accelerated our progress. If you find these works helpful, please consider citing them as well. 64 | 65 | [huggingface/transformers](https://github.com/huggingface/transformers) 66 | 67 | [HRNet/HRNet-Image-Classification](https://github.com/HRNet/HRNet-Image-Classification) 68 | 69 | [nkolot/GraphCMR](https://github.com/nkolot/GraphCMR) 70 | 71 | [akanazawa/hmr](https://github.com/akanazawa/hmr) 72 | 73 | [MandyMo/pytorch_HMR](https://github.com/MandyMo/pytorch_HMR) 74 | 75 | [hassony2/manopth](https://github.com/hassony2/manopth) 76 | 77 | [hongsukchoi/Pose2Mesh_RELEASE](https://github.com/hongsukchoi/Pose2Mesh_RELEASE) 78 | 79 | [mks0601/I2L-MeshNet_RELEASE](https://github.com/mks0601/I2L-MeshNet_RELEASE) 80 | 81 | [open-mmlab/mmdetection](https://github.com/open-mmlab/mmdetection) 82 | -------------------------------------------------------------------------------- /azure-pipelines.yml: -------------------------------------------------------------------------------- 1 | # Starter pipeline 2 | # Start with a minimal pipeline that you can customize to build and deploy your code. 3 | # Add steps that build, run tests, deploy, and more: 4 | # https://aka.ms/yaml 5 | 6 | trigger: 7 | - main 8 | 9 | pool: 10 | vmImage: ubuntu-latest 11 | strategy: 12 | matrix: 13 | Python36: 14 | python.version: '3.6' 15 | 16 | steps: 17 | - task: UsePythonVersion@0 18 | inputs: 19 | versionSpec: '$(python.version)' 20 | displayName: 'Use Python $(python.version)' 21 | 22 | - script: | 23 | python -m pip install --upgrade pip 24 | pip install -r requirements.txt 25 | displayName: 'Install dependencies' 26 | 27 | -------------------------------------------------------------------------------- /docs/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /docs/DEMO.md: -------------------------------------------------------------------------------- 1 | # Quick Demo 2 | We provide demo code for end-to-end inference here. 3 | 4 | Our inference code iterates over all images in a given folder and generates the results. 5 | 6 | ## Important notes 7 | 8 | - **This demo doesn't perform human/hand detection**. Our model requires a centered target in the image. 9 | - As **Mesh Graphormer is a data-driven approach**, it may not perform well if the test samples are very different from the training data. We observe that our model does not work well if the target is out of view.
Some examples can be found in our supplementary material (Sec. I Limitations). 10 | 11 | ## Human Body Reconstruction 12 | 13 | This demo runs 3D human mesh reconstruction from a single image. 14 | 15 | Our code requires input images that are already **cropped with the person centered** in the image. The input images should be `224x224` in size. To run the demo, please place your test images under `./samples/human-body`, and then run the following script. 16 | 17 | 18 | ```bash 19 | python ./src/tools/run_gphmer_bodymesh_inference.py \ 20 | --resume_checkpoint ./models/graphormer_release/graphormer_3dpw_state_dict.bin \ 21 | --image_file_or_path ./samples/human-body 22 | ``` 23 | After running, it will generate the results in the folder `./samples/human-body`. 24 | 25 | 26 | 27 | ## Hand Reconstruction 28 | 29 | This demo runs 3D hand reconstruction from a single image. 30 | 31 | Please provide images that are already **cropped with the right hand centered** in the image. The input images should be `224x224` in size. Place the images under `./samples/hand`, and run the following script. 32 | 33 | ```bash 34 | python ./src/tools/run_gphmer_handmesh_inference.py \ 35 | --resume_checkpoint ./models/graphormer_release/graphormer_hand_state_dict.bin \ 36 | --image_file_or_path ./samples/hand 37 | ``` 38 | After running, it will output the results in the folder `./samples/hand`. 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /docs/DOWNLOAD.md: -------------------------------------------------------------------------------- 1 | # Download 2 | 3 | ### Getting Started 4 | 5 | 1. Create folders that store pretrained models, datasets, and predictions. 6 | ```bash 7 | export REPO_DIR=$PWD 8 | mkdir -p $REPO_DIR/models # pre-trained models 9 | mkdir -p $REPO_DIR/datasets # datasets 10 | mkdir -p $REPO_DIR/predictions # prediction outputs 11 | ``` 12 | 13 | 2. Download pretrained models. 14 | 15 | Our pre-trained models can be downloaded with the following command. 16 | ```bash 17 | cd $REPO_DIR 18 | bash scripts/download_models.sh 19 | ``` 20 | The script will download three models trained for mesh reconstruction on Human3.6M, 3DPW, and FreiHAND, respectively. For your convenience, it will also download the HRNet pre-trained weights, which are used in training. 21 | 22 | The resulting data structure should follow the hierarchy below. 23 | ``` 24 | ${REPO_DIR} 25 | |-- models 26 | | |-- graphormer_release 27 | | | |-- graphormer_h36m_state_dict.bin 28 | | | |-- graphormer_3dpw_state_dict.bin 29 | | | |-- graphormer_hand_state_dict.bin 30 | | |-- hrnet 31 | | | |-- hrnetv2_w40_imagenet_pretrained.pth 32 | | | |-- hrnetv2_w64_imagenet_pretrained.pth 33 | | | |-- cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml 34 | | | |-- cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml 35 | |-- src 36 | |-- datasets 37 | |-- predictions 38 | |-- README.md 39 | |-- ... 40 | |-- ... 41 | ``` 42 | 43 | 3. Download the SMPL and MANO models from their official websites. 44 | 45 | To run our code smoothly, please visit the following websites to download the SMPL and MANO models. 46 | 47 | - Download `basicModel_neutral_lbs_10_207_0_v1.0.0.pkl` from [SMPLify](http://smplify.is.tue.mpg.de/), and place it at `${REPO_DIR}/src/modeling/data`. 48 | - Download `MANO_RIGHT.pkl` from [MANO](https://mano.is.tue.mpg.de/), and place it at `${REPO_DIR}/src/modeling/data`.
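Once both files are downloaded, a quick check can confirm they are in place. This snippet is an optional addition; it assumes `$REPO_DIR` is set as in step 1.

```bash
# Both files should exist and be non-trivial in size
ls -lh $REPO_DIR/src/modeling/data/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl \
       $REPO_DIR/src/modeling/data/MANO_RIGHT.pkl
```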
49 | 50 | The data structure should follow the hierarchy below. 51 | ``` 52 | ${REPO_DIR} 53 | |-- src 54 | | |-- modeling 55 | | | |-- data 56 | | | | |-- basicModel_neutral_lbs_10_207_0_v1.0.0.pkl 57 | | | | |-- MANO_RIGHT.pkl 58 | |-- models 59 | |-- datasets 60 | |-- predictions 61 | |-- README.md 62 | |-- ... 63 | |-- ... 64 | ``` 65 | Please check [/src/modeling/data/README.md](../src/modeling/data/README.md) for further details. 66 | 67 | 4. Download the prediction files that were evaluated on the FreiHAND Leaderboard. 68 | 69 | The prediction files can be downloaded with the following command. 70 | ```bash 71 | cd $REPO_DIR 72 | bash scripts/download_preds.sh 73 | ``` 74 | You can submit the prediction files to the FreiHAND Leaderboard to reproduce our results. 75 | 76 | 5. Download datasets and pseudo labels for training. 77 | 78 | We use the same data as our previous project, [METRO](https://github.com/microsoft/MeshTransformer). 79 | 80 | Please visit our previous project page to download the datasets and annotations for the experiments; see this [LINK](https://github.com/microsoft/MeshTransformer/blob/main/docs/DOWNLOAD.md). 81 | -------------------------------------------------------------------------------- /docs/EXP.md: -------------------------------------------------------------------------------- 1 | # Training and evaluation 2 | 3 | 4 | ## Table of contents 5 | * [3D hand experiment](#3D-hand-reconstruction-from-a-single-image) 6 | * [Training](#Training) 7 | * [Testing](#Testing) 8 | * [3D human body experiment](#Human-mesh-reconstruction-from-a-single-image) 9 | * [Training with mixed 2D+3D datasets](#Training-with-mixed-datasets) 10 | * [Evaluation on Human3.6M](#Evaluation-on-Human3.6M) 11 | * [Training with 3DPW](#Training-with-3DPW-dataset) 12 | * [Evaluation on 3DPW](#Evaluation-on-3DPW) 13 | 14 | 15 | ## 3D hand reconstruction from a single image 16 | 17 | ### Training 18 | 19 | We use the following script to train on the FreiHAND dataset. 20 | 21 | ```bash 22 | python -m torch.distributed.launch --nproc_per_node=8 \ 23 | src/tools/run_gphmer_handmesh.py \ 24 | --train_yaml freihand/train.yaml \ 25 | --val_yaml freihand/test.yaml \ 26 | --arch hrnet-w64 \ 27 | --num_workers 4 \ 28 | --per_gpu_train_batch_size 32 \ 29 | --per_gpu_eval_batch_size 32 \ 30 | --num_hidden_layers 4 \ 31 | --num_attention_heads 4 \ 32 | --lr 1e-4 \ 33 | --num_train_epochs 200 \ 34 | --input_feat_dim 2051,512,128 \ 35 | --hidden_feat_dim 1024,256,64 36 | ``` 37 | 38 | 39 | An example training log can be found here: [2021-03-06-graphormer_freihand_log.txt](https://datarelease.blob.core.windows.net/metro/models/2021-03-06-graphormer_freihand_log.txt) 40 | 41 | ### Testing 42 | 43 | After training, we use the final checkpoint (trained for 200 epochs) for testing. 44 | 45 | We use the following script to generate predictions. It will generate a prediction file called `ckpt200-sc10_rot0-pred.zip`. After that, please submit the prediction file to the [FreiHAND Leaderboard](https://competitions.codalab.org/competitions/21238) to obtain the evaluation scores. 46 | 47 | 48 | In the following script, we perform prediction with test-time augmentation for the FreiHAND experiments. It will generate a prediction file `ckpt200-multisc-pred.zip`.
49 | 50 | ```bash 51 | python src/tools/run_hand_multiscale.py \ 52 | --multiscale_inference \ 53 | --model_path models/graphormer_release/graphormer_hand_state_dict.bin \ 54 | ``` 55 | 56 | To reproduce our results, we have released our prediction file `ckpt200-multisc-pred.zip` (see `docs/DOWNLOAD.md`). You may want to submit it to the Leaderboard, and it should produce the following results. 57 | 58 | ```bash 59 | Evaluation 3D KP results: 60 | auc=0.000, mean_kp3d_avg=71.48 cm 61 | Evaluation 3D KP ALIGNED results: 62 | auc=0.883, mean_kp3d_avg=0.59 cm 63 | 64 | Evaluation 3D MESH results: 65 | auc=0.000, mean_kp3d_avg=71.47 cm 66 | Evaluation 3D MESH ALIGNED results: 67 | auc=0.880, mean_kp3d_avg=0.60 cm 68 | 69 | F-scores 70 | F@5.0mm = 0.000 F_aligned@5.0mm = 0.764 71 | F@15.0mm = 0.000 F_aligned@15.0mm = 0.986 72 | ``` 73 | 74 | Note that our method predicts relative coordinates (there is no global alignment). Therefore, only the aligned scores are meaningful in our case. 75 | 76 | 77 | ## Human mesh reconstruction from a single image 78 | 79 | 80 | ### Training with mixed datasets 81 | 82 | We conduct large-scale training on multiple 2D and 3D datasets, including Human3.6M, COCO, MUCO, UP3D, MPII. During training, it will evaluate the performance per epoch, and save the best checkpoints. 83 | 84 | ```bash 85 | python -m torch.distributed.launch --nproc_per_node=8 \ 86 | src/tools/run_gphmer_bodymesh.py \ 87 | --train_yaml Tax-H36m-coco40k-Muco-UP-Mpii/train.yaml \ 88 | --val_yaml human3.6m/valid.protocol2.yaml \ 89 | --arch hrnet-w64 \ 90 | --num_workers 4 \ 91 | --per_gpu_train_batch_size 25 \ 92 | --per_gpu_eval_batch_size 25 \ 93 | --num_hidden_layers 4 \ 94 | --num_attention_heads 4 \ 95 | --lr 1e-4 \ 96 | --num_train_epochs 200 \ 97 | --input_feat_dim 2051,512,128 \ 98 | --hidden_feat_dim 1024,256,64 99 | ``` 100 | 101 | Example training log can be found here [2021-02-25-graphormer_h36m_log](https://datarelease.blob.core.windows.net/metro/models/2021-02-25-graphormer_h36m_log.txt) 102 | 103 | ### Evaluation on Human3.6M 104 | 105 | In the following script, we evaluate our model `graphormer_h36m_state_dict.bin` on Human3.6M validation set. Check `docs/DOWNLOAD.md` for more details about downloading the model file. 106 | 107 | ```bash 108 | python -m torch.distributed.launch --nproc_per_node=8 \ 109 | src/tools/run_gphmer_bodymesh.py \ 110 | --val_yaml human3.6m/valid.protocol2.yaml \ 111 | --arch hrnet-w64 \ 112 | --num_workers 4 \ 113 | --per_gpu_eval_batch_size 25 \ 114 | --num_hidden_layers 4 \ 115 | --num_attention_heads 4 \ 116 | --input_feat_dim 2051,512,128 \ 117 | --hidden_feat_dim 1024,256,64 \ 118 | --run_eval_only \ 119 | --resume_checkpoint ./models/graphormer_release/graphormer_h36m_state_dict.bin 120 | ``` 121 | 122 | We show the example outputs of this script as below. 123 | ```bash 124 | 2021-09-19 13:18:14,416 Graphormer INFO: Using 8 GPUs 125 | 2021-09-19 13:18:18,712 Graphormer INFO: Update config parameter num_hidden_layers: 12 -> 4 126 | 2021-09-19 13:18:18,718 Graphormer INFO: Update config parameter hidden_size: 768 -> 1024 127 | 2021-09-19 13:18:18,725 Graphormer INFO: Update config parameter num_attention_heads: 12 -> 4 128 | 2021-09-19 13:18:18,731 Graphormer INFO: Update config parameter intermediate_size: 3072 -> 2048 129 | 2021-09-19 13:18:19,983 Graphormer INFO: Init model from scratch. 
130 | 2021-09-19 13:18:19,990 Graphormer INFO: Update config parameter num_hidden_layers: 12 -> 4 131 | 2021-09-19 13:18:19,995 Graphormer INFO: Update config parameter hidden_size: 768 -> 256 132 | 2021-09-19 13:18:20,001 Graphormer INFO: Update config parameter num_attention_heads: 12 -> 4 133 | 2021-09-19 13:18:20,006 Graphormer INFO: Update config parameter intermediate_size: 3072 -> 512 134 | 2021-09-19 13:18:20,210 Graphormer INFO: Init model from scratch. 135 | 2021-09-19 13:18:20,217 Graphormer INFO: Add Graph Conv 136 | 2021-09-19 13:18:20,223 Graphormer INFO: Update config parameter num_hidden_layers: 12 -> 4 137 | 2021-09-19 13:18:20,228 Graphormer INFO: Update config parameter hidden_size: 768 -> 64 138 | 2021-09-19 13:18:20,233 Graphormer INFO: Update config parameter num_attention_heads: 12 -> 4 139 | 2021-09-19 13:18:20,239 Graphormer INFO: Update config parameter intermediate_size: 3072 -> 128 140 | 2021-09-19 13:18:20,295 Graphormer INFO: Init model from scratch. 141 | 2021-09-19 13:18:23,797 Graphormer INFO: => loading hrnet-v2-w64 model 142 | 2021-09-19 13:18:23,805 Graphormer INFO: Graphormer encoders total parameters: 83318598 143 | 2021-09-19 13:18:23,814 Graphormer INFO: Backbone total parameters: 128059944 144 | 2021-09-19 13:18:23,892 Graphormer INFO: Loading state dict from checkpoint _output/graphormer_release/graphormer_h36m_state_dict.bin 145 | 2021-09-19 13:19:26,299 Graphormer INFO: Validation epoch: 0 mPVE: 0.00, mPJPE: 51.20, PAmPJPE: 34.55 146 | ``` 147 | 148 | 149 | 150 | ### Training with 3DPW dataset 151 | 152 | We follow prior works that also use 3DPW training data. In order to make the training faster, we **fine-tune** our pre-trained model (`graphormer_h36m_state_dict.bin`) on 3DPW training set. 153 | 154 | We use the following script for fine-tuning. During fine-tuning, it will evaluate the performance per epoch, and save the best checkpoints. 155 | 156 | ```bash 157 | python -m torch.distributed.launch --nproc_per_node=8 \ 158 | src/tools/run_gphmer_bodymesh.py \ 159 | --train_yaml 3dpw/train.yaml \ 160 | --val_yaml 3dpw/test_has_gender.yaml \ 161 | --arch hrnet-w64 \ 162 | --num_workers 4 \ 163 | --per_gpu_train_batch_size 20 \ 164 | --per_gpu_eval_batch_size 20 \ 165 | --num_hidden_layers 4 \ 166 | --num_attention_heads 4 \ 167 | --lr 1e-4 \ 168 | --num_train_epochs 5 \ 169 | --input_feat_dim 2051,512,128 \ 170 | --hidden_feat_dim 1024,256,64 \ 171 | --resume_checkpoint {YOUR_PATH/state_dict.bin} \ 172 | ``` 173 | 174 | 175 | ### Evaluation on 3DPW 176 | In the following script, we evaluate our model `graphormer_3dpw_state_dict.bin` on 3DPW test set. Check `docs/DOWNLOAD.md` for more details about downloading the model file. 
177 | 178 | 179 | ```bash 180 | python -m torch.distributed.launch --nproc_per_node=8 \ 181 | src/tools/run_gphmer_bodymesh.py \ 182 | --val_yaml 3dpw/test.yaml \ 183 | --arch hrnet-w64 \ 184 | --num_workers 4 \ 185 | --per_gpu_eval_batch_size 25 \ 186 | --num_hidden_layers 4 \ 187 | --num_attention_heads 4 \ 188 | --input_feat_dim 2051,512,128 \ 189 | --hidden_feat_dim 1024,256,64 \ 190 | --run_eval_only \ 191 | --resume_checkpoint ./models/graphormer_release/graphormer_3dpw_state_dict.bin 192 | ``` 193 | 194 | After evaluation, it should reproduce the results below 195 | ```bash 196 | 2021-09-20 00:54:46,178 Graphormer INFO: Using 8 GPUs 197 | 2021-09-20 00:54:50,339 Graphormer INFO: Update config parameter num_hidden_layers: 12 -> 4 198 | 2021-09-20 00:54:50,345 Graphormer INFO: Update config parameter hidden_size: 768 -> 1024 199 | 2021-09-20 00:54:50,351 Graphormer INFO: Update config parameter num_attention_heads: 12 -> 4 200 | 2021-09-20 00:54:50,357 Graphormer INFO: Update config parameter intermediate_size: 3072 -> 2048 201 | 2021-09-20 00:54:51,602 Graphormer INFO: Init model from scratch. 202 | 2021-09-20 00:54:51,613 Graphormer INFO: Update config parameter num_hidden_layers: 12 -> 4 203 | 2021-09-20 00:54:51,625 Graphormer INFO: Update config parameter hidden_size: 768 -> 256 204 | 2021-09-20 00:54:51,646 Graphormer INFO: Update config parameter num_attention_heads: 12 -> 4 205 | 2021-09-20 00:54:51,652 Graphormer INFO: Update config parameter intermediate_size: 3072 -> 512 206 | 2021-09-20 00:54:51,855 Graphormer INFO: Init model from scratch. 207 | 2021-09-20 00:54:51,862 Graphormer INFO: Add Graph Conv 208 | 2021-09-20 00:54:51,868 Graphormer INFO: Update config parameter num_hidden_layers: 12 -> 4 209 | 2021-09-20 00:54:51,873 Graphormer INFO: Update config parameter hidden_size: 768 -> 64 210 | 2021-09-20 00:54:51,880 Graphormer INFO: Update config parameter num_attention_heads: 12 -> 4 211 | 2021-09-20 00:54:51,885 Graphormer INFO: Update config parameter intermediate_size: 3072 -> 128 212 | 2021-09-20 00:54:51,948 Graphormer INFO: Init model from scratch. 
213 | 2021-09-20 00:54:55,564 Graphormer INFO: => loading hrnet-v2-w64 model 214 | 2021-09-20 00:54:55,572 Graphormer INFO: Graphormer encoders total parameters: 83318598 215 | 2021-09-20 00:54:55,580 Graphormer INFO: Backbone total parameters: 128059944 216 | 2021-09-20 00:54:55,655 Graphormer INFO: Loading state dict from checkpoint _output/graphormer_release/graphormer_3dpw_state_dict.bin 217 | 2021-09-20 00:56:24,334 Graphormer INFO: Validation epoch: 0 mPVE: 87.57, mPJPE: 73.98, PAmPJPE: 45.35 218 | ``` 219 | 220 | -------------------------------------------------------------------------------- /docs/Fig1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/docs/Fig1.gif -------------------------------------------------------------------------------- /docs/Fig2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/docs/Fig2.gif -------------------------------------------------------------------------------- /docs/Fig3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/docs/Fig3.gif -------------------------------------------------------------------------------- /docs/Fig4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/docs/Fig4.gif -------------------------------------------------------------------------------- /docs/INSTALL.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | Our codebase is developed on Ubuntu 16.04 with NVIDIA GPUs. 4 | 5 | ### Requirements 6 | - Python 3.7 7 | - PyTorch 1.4 8 | - torchvision 0.5.0 9 | - CUDA 10.1 10 | 11 | ### Setup with Conda 12 | 13 | We suggest creating a new conda environment and installing all the relevant dependencies. 14 | 15 | ```bash 16 | # Create a new environment 17 | conda create --name gphmr python=3.7 18 | conda activate gphmr 19 | 20 | # Install PyTorch 21 | conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch 22 | 23 | export INSTALL_DIR=$PWD 24 | 25 | # Install apex 26 | cd $INSTALL_DIR 27 | git clone https://github.com/NVIDIA/apex.git 28 | cd apex 29 | python setup.py install --cuda_ext --cpp_ext 30 | 31 | # Install OpenDR 32 | pip install matplotlib 33 | pip install git+https://gitlab.eecs.umich.edu/ngv-python-modules/opendr.git 34 | 35 | # Install MeshGraphormer 36 | cd $INSTALL_DIR 37 | git clone --recursive https://github.com/microsoft/MeshGraphormer.git 38 | cd MeshGraphormer 39 | python setup.py build develop 40 | 41 | # Install requirements 42 | pip install -r requirements.txt 43 | 44 | # Install manopth 45 | cd $INSTALL_DIR 46 | cd MeshGraphormer 47 | pip install ./manopth/.
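# (Optional) sanity-check the environment; this check assumes the
# steps above completed successfully
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import manopth; print('manopth imported OK')"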
48 | 49 | unset INSTALL_DIR 50 | ``` 51 | 52 | 53 | -------------------------------------------------------------------------------- /docs/SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /docs/SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /docs/graphormer_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/docs/graphormer_overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | yacs==0.1.8 2 | cython 3 | opencv-python 4 | tqdm 5 | nltk 6 | numpy 7 | scipy==1.4.1 8 | chumpy 9 | boto3 10 | requests 11 | -------------------------------------------------------------------------------- /samples/hand/freihand_sample1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/freihand_sample1.jpg -------------------------------------------------------------------------------- /samples/hand/freihand_sample1_graphormer_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/freihand_sample1_graphormer_pred.jpg -------------------------------------------------------------------------------- /samples/hand/freihand_sample2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/freihand_sample2.jpg -------------------------------------------------------------------------------- /samples/hand/freihand_sample2_graphormer_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/freihand_sample2_graphormer_pred.jpg -------------------------------------------------------------------------------- /samples/hand/freihand_sample3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/freihand_sample3.jpg -------------------------------------------------------------------------------- 
/samples/hand/freihand_sample3_graphormer_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/freihand_sample3_graphormer_pred.jpg -------------------------------------------------------------------------------- /samples/hand/internet_fig1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig1.jpg -------------------------------------------------------------------------------- /samples/hand/internet_fig1_graphormer_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig1_graphormer_pred.jpg -------------------------------------------------------------------------------- /samples/hand/internet_fig2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig2.jpg -------------------------------------------------------------------------------- /samples/hand/internet_fig2_graphormer_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig2_graphormer_pred.jpg -------------------------------------------------------------------------------- /samples/hand/internet_fig3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig3.jpg -------------------------------------------------------------------------------- /samples/hand/internet_fig3_graphormer_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig3_graphormer_pred.jpg -------------------------------------------------------------------------------- /samples/hand/internet_fig4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig4.jpg -------------------------------------------------------------------------------- /samples/hand/internet_fig4_graphormer_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig4_graphormer_pred.jpg -------------------------------------------------------------------------------- /samples/human-body/3dpw_test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test1.jpg -------------------------------------------------------------------------------- /samples/human-body/3dpw_test1_graphormer_pred.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test1_graphormer_pred.jpg -------------------------------------------------------------------------------- /samples/human-body/3dpw_test2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test2.jpg -------------------------------------------------------------------------------- /samples/human-body/3dpw_test2_graphormer_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test2_graphormer_pred.jpg -------------------------------------------------------------------------------- /samples/human-body/3dpw_test3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test3.jpg -------------------------------------------------------------------------------- /samples/human-body/3dpw_test3_graphormer_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test3_graphormer_pred.jpg -------------------------------------------------------------------------------- /samples/human-body/3dpw_test4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test4.jpg -------------------------------------------------------------------------------- /samples/human-body/3dpw_test4_graphormer_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test4_graphormer_pred.jpg -------------------------------------------------------------------------------- /samples/human-body/3dpw_test5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test5.jpg -------------------------------------------------------------------------------- /samples/human-body/3dpw_test5_graphormer_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test5_graphormer_pred.jpg -------------------------------------------------------------------------------- /samples/human-body/3dpw_test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test6.jpg -------------------------------------------------------------------------------- /samples/human-body/3dpw_test6_graphormer_pred.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test6_graphormer_pred.jpg -------------------------------------------------------------------------------- /samples/human-body/3dpw_test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test7.jpg -------------------------------------------------------------------------------- /samples/human-body/3dpw_test7_graphormer_pred.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test7_graphormer_pred.jpg -------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | # -------------------------------- 2 | # Setup 3 | # -------------------------------- 4 | export REPO_DIR=$PWD 5 | if [ ! -d $REPO_DIR/models ] ; then 6 | mkdir -p $REPO_DIR/models 7 | fi 8 | BLOB='https://datarelease.blob.core.windows.net/metro' 9 | 10 | 11 | # -------------------------------- 12 | # Download our pre-trained models 13 | # -------------------------------- 14 | if [ ! -d $REPO_DIR/models/graphormer_release ] ; then 15 | mkdir -p $REPO_DIR/models/graphormer_release 16 | fi 17 | # (1) Mesh Graphormer for human mesh reconstruction (trained on H3.6M + COCO + MuCO + UP3D + MPII) 18 | wget -nc $BLOB/models/graphormer_h36m_state_dict.bin -O $REPO_DIR/models/graphormer_release/graphormer_h36m_state_dict.bin 19 | # (2) Mesh Graphormer for human mesh reconstruction (trained on H3.6M + COCO + MuCO + UP3D + MPII, then fine-tuned on 3DPW) 20 | wget -nc $BLOB/models/graphormer_3dpw_state_dict.bin -O $REPO_DIR/models/graphormer_release/graphormer_3dpw_state_dict.bin 21 | # (3) Mesh Graphormer for hand mesh reconstruction (trained on FreiHAND) 22 | wget -nc $BLOB/models/graphormer_hand_state_dict.bin -O $REPO_DIR/models/graphormer_release/graphormer_hand_state_dict.bin 23 | 24 | 25 | # -------------------------------- 26 | # Download the ImageNet pre-trained HRNet models 27 | # The weights are provided by https://github.com/HRNet/HRNet-Image-Classification 28 | # -------------------------------- 29 | if [ ! -d $REPO_DIR/models/hrnet ] ; then 30 | mkdir -p $REPO_DIR/models/hrnet 31 | fi 32 | wget -nc $BLOB/models/hrnetv2_w64_imagenet_pretrained.pth -O $REPO_DIR/models/hrnet/hrnetv2_w64_imagenet_pretrained.pth 33 | wget -nc $BLOB/models/hrnetv2_w40_imagenet_pretrained.pth -O $REPO_DIR/models/hrnet/hrnetv2_w40_imagenet_pretrained.pth 34 | wget -nc $BLOB/models/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml -O $REPO_DIR/models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml 35 | wget -nc $BLOB/models/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml -O $REPO_DIR/models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml 36 | 37 | 38 | -------------------------------------------------------------------------------- /scripts/download_preds.sh: -------------------------------------------------------------------------------- 1 | # -------------------------------- 2 | # Setup 3 | # -------------------------------- 4 | export REPO_DIR=$PWD 5 | if [ ! 
-d $REPO_DIR/models ] ; then 6 | mkdir -p $REPO_DIR/models 7 | fi 8 | BLOB='https://datarelease.blob.core.windows.net/metro' 9 | 10 | # -------------------------------- 11 | # Download our model predictions that can be submitted to FreiHAND Leaderboard 12 | # -------------------------------- 13 | if [ ! -d $REPO_DIR/predictions ] ; then 14 | mkdir -p $REPO_DIR/predictions 15 | fi 16 | # Our model + test-time augmentation. It achieves 5.9 PA-MPVPE on FreiHAND Leaderboard 17 | wget -nc $BLOB/graphormer-release-ckpt200-multisc-pred.zip -O $REPO_DIR/predictions/ckpt200-multisc-pred.zip 18 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | from __future__ import print_function 4 | import os 5 | import sys 6 | import re 7 | import os.path as op 8 | from setuptools import find_packages, setup 9 | 10 | # change directory to this module path 11 | try: 12 | this_file = __file__ 13 | except NameError: 14 | this_file = sys.argv[0] 15 | this_file = os.path.abspath(this_file) 16 | if op.dirname(this_file): 17 | os.chdir(op.dirname(this_file)) 18 | script_dir = os.getcwd() 19 | 20 | def readme(fname): 21 | """Read text out of a file in the same directory as setup.py. 22 | """ 23 | return open(op.join(script_dir, fname)).read() 24 | 25 | 26 | def find_version(fname): 27 | version_file = readme(fname) 28 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", 29 | version_file, re.M) 30 | if version_match: 31 | return version_match.group(1) 32 | raise RuntimeError("Unable to find version string.") 33 | 34 | 35 | setup( 36 | name="graphormer", 37 | version=find_version("src/__init__.py"), 38 | description="graphormer", 39 | long_description=readme('README.md'), 40 | packages=find_packages(), 41 | classifiers=[ 42 | 'Intended Audience :: Developers', 43 | "Programming Language :: Python", 44 | 'Topic :: Software Development', 45 | ] 46 | ) 47 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/datasets/build.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
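This module builds TSV/YAML-backed mesh datasets and their data loaders:
build_dataset/make_data_loader for human body meshes and
build_hand_dataset/make_hand_data_loader for hand meshes, with support for
distributed sampling and iteration-based batch sampling.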
4 | 5 | """ 6 | 7 | 8 | import os.path as op 9 | import torch 10 | import logging 11 | import code 12 | from src.utils.comm import get_world_size 13 | from src.datasets.human_mesh_tsv import (MeshTSVDataset, MeshTSVYamlDataset) 14 | from src.datasets.hand_mesh_tsv import (HandMeshTSVDataset, HandMeshTSVYamlDataset) 15 | 16 | 17 | def build_dataset(yaml_file, args, is_train=True, scale_factor=1): 18 | print(yaml_file) 19 | if not op.isfile(yaml_file): 20 | yaml_file = op.join(args.data_dir, yaml_file) 21 | # code.interact(local=locals()) 22 | assert op.isfile(yaml_file) 23 | return MeshTSVYamlDataset(yaml_file, is_train, False, scale_factor) 24 | 25 | 26 | class IterationBasedBatchSampler(torch.utils.data.sampler.BatchSampler): 27 | """ 28 | Wraps a BatchSampler, resampling from it until 29 | a specified number of iterations have been sampled 30 | """ 31 | 32 | def __init__(self, batch_sampler, num_iterations, start_iter=0): 33 | self.batch_sampler = batch_sampler 34 | self.num_iterations = num_iterations 35 | self.start_iter = start_iter 36 | 37 | def __iter__(self): 38 | iteration = self.start_iter 39 | while iteration <= self.num_iterations: 40 | # if the underlying sampler has a set_epoch method, like 41 | # DistributedSampler, used for making each process see 42 | # a different split of the dataset, then set it 43 | if hasattr(self.batch_sampler.sampler, "set_epoch"): 44 | self.batch_sampler.sampler.set_epoch(iteration) 45 | for batch in self.batch_sampler: 46 | iteration += 1 47 | if iteration > self.num_iterations: 48 | break 49 | yield batch 50 | 51 | def __len__(self): 52 | return self.num_iterations 53 | 54 | 55 | def make_batch_data_sampler(sampler, images_per_gpu, num_iters=None, start_iter=0): 56 | batch_sampler = torch.utils.data.sampler.BatchSampler( 57 | sampler, images_per_gpu, drop_last=False 58 | ) 59 | if num_iters is not None and num_iters >= 0: 60 | batch_sampler = IterationBasedBatchSampler( 61 | batch_sampler, num_iters, start_iter 62 | ) 63 | return batch_sampler 64 | 65 | 66 | def make_data_sampler(dataset, shuffle, distributed): 67 | if distributed: 68 | return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle) 69 | if shuffle: 70 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 71 | else: 72 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 73 | return sampler 74 | 75 | 76 | def make_data_loader(args, yaml_file, is_distributed=True, 77 | is_train=True, start_iter=0, scale_factor=1): 78 | 79 | dataset = build_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor) 80 | logger = logging.getLogger(__name__) 81 | if is_train==True: 82 | shuffle = True 83 | images_per_gpu = args.per_gpu_train_batch_size 84 | images_per_batch = images_per_gpu * get_world_size() 85 | iters_per_batch = len(dataset) // images_per_batch 86 | num_iters = iters_per_batch * args.num_train_epochs 87 | logger.info("Train with {} images per GPU.".format(images_per_gpu)) 88 | logger.info("Total batch size {}".format(images_per_batch)) 89 | logger.info("Total training steps {}".format(num_iters)) 90 | else: 91 | shuffle = False 92 | images_per_gpu = args.per_gpu_eval_batch_size 93 | num_iters = None 94 | start_iter = 0 95 | 96 | sampler = make_data_sampler(dataset, shuffle, is_distributed) 97 | batch_sampler = make_batch_data_sampler( 98 | sampler, images_per_gpu, num_iters, start_iter 99 | ) 100 | data_loader = torch.utils.data.DataLoader( 101 | dataset, num_workers=args.num_workers, batch_sampler=batch_sampler, 102 | pin_memory=True, 
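# pin_memory=True allocates batches in page-locked memory, which speeds up
# host-to-GPU transfers when training on CUDA devices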
103 | ) 104 | return data_loader 105 | 106 | 107 | #============================================================================================== 108 | 109 | def build_hand_dataset(yaml_file, args, is_train=True, scale_factor=1): 110 | print(yaml_file) 111 | if not op.isfile(yaml_file): 112 | yaml_file = op.join(args.data_dir, yaml_file) 113 | # code.interact(local=locals()) 114 | assert op.isfile(yaml_file) 115 | return HandMeshTSVYamlDataset(args, yaml_file, is_train, False, scale_factor) 116 | 117 | 118 | def make_hand_data_loader(args, yaml_file, is_distributed=True, 119 | is_train=True, start_iter=0, scale_factor=1): 120 | 121 | dataset = build_hand_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor) 122 | logger = logging.getLogger(__name__) 123 | if is_train==True: 124 | shuffle = True 125 | images_per_gpu = args.per_gpu_train_batch_size 126 | images_per_batch = images_per_gpu * get_world_size() 127 | iters_per_batch = len(dataset) // images_per_batch 128 | num_iters = iters_per_batch * args.num_train_epochs 129 | logger.info("Train with {} images per GPU.".format(images_per_gpu)) 130 | logger.info("Total batch size {}".format(images_per_batch)) 131 | logger.info("Total training steps {}".format(num_iters)) 132 | else: 133 | shuffle = False 134 | images_per_gpu = args.per_gpu_eval_batch_size 135 | num_iters = None 136 | start_iter = 0 137 | 138 | sampler = make_data_sampler(dataset, shuffle, is_distributed) 139 | batch_sampler = make_batch_data_sampler( 140 | sampler, images_per_gpu, num_iters, start_iter 141 | ) 142 | data_loader = torch.utils.data.DataLoader( 143 | dataset, num_workers=args.num_workers, batch_sampler=batch_sampler, 144 | pin_memory=True, 145 | ) 146 | return data_loader 147 | 148 | -------------------------------------------------------------------------------- /src/datasets/hand_mesh_tsv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
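TSV-backed dataset classes for 3D hand mesh reconstruction (e.g., FreiHAND).
Images are decoded from base64-encoded TSV rows, and training-time augmentation
(scaling, rotation, flipping, and channel-wise pixel noise) is applied to the
image, the 2D/3D joints, and the pose parameters.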
4 | 5 | """ 6 | 7 | 8 | import cv2 9 | import math 10 | import json 11 | from PIL import Image 12 | import os.path as op 13 | import numpy as np 14 | import code 15 | 16 | from src.utils.tsv_file import TSVFile, CompositeTSVFile 17 | from src.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml 18 | from src.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa 19 | import torch 20 | import torchvision.transforms as transforms 21 | 22 | 23 | class HandMeshTSVDataset(object): 24 | def __init__(self, args, img_file, label_file=None, hw_file=None, 25 | linelist_file=None, is_train=True, cv2_output=False, scale_factor=1): 26 | 27 | self.args = args 28 | self.img_file = img_file 29 | self.label_file = label_file 30 | self.hw_file = hw_file 31 | self.linelist_file = linelist_file 32 | self.img_tsv = self.get_tsv_file(img_file) 33 | self.label_tsv = None if label_file is None else self.get_tsv_file(label_file) 34 | self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file) 35 | 36 | if self.is_composite: 37 | assert op.isfile(self.linelist_file) 38 | self.line_list = [i for i in range(self.hw_tsv.num_rows())] 39 | else: 40 | self.line_list = load_linelist_file(linelist_file) 41 | 42 | self.cv2_output = cv2_output 43 | self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], 44 | std=[0.229, 0.224, 0.225]) 45 | self.is_train = is_train 46 | self.scale_factor = 0.25 # rescale bounding boxes by a factor of [1-options.scale_factor,1+options.scale_factor] 47 | self.noise_factor = 0.4 48 | self.rot_factor = 90 # Random rotation in the range [-rot_factor, rot_factor] 49 | self.img_res = 224 50 | self.image_keys = self.prepare_image_keys() 51 | self.joints_definition = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 52 | 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4') 53 | self.root_index = self.joints_definition.index('Wrist') 54 | 55 | def get_tsv_file(self, tsv_file): 56 | if tsv_file: 57 | if self.is_composite: 58 | return CompositeTSVFile(tsv_file, self.linelist_file, 59 | root=self.root) 60 | tsv_path = find_file_path_in_yaml(tsv_file, self.root) 61 | return TSVFile(tsv_path) 62 | 63 | def get_valid_tsv(self): 64 | # sorted by file size 65 | if self.hw_tsv: 66 | return self.hw_tsv 67 | if self.label_tsv: 68 | return self.label_tsv 69 | 70 | def prepare_image_keys(self): 71 | tsv = self.get_valid_tsv() 72 | return [tsv.get_key(i) for i in range(tsv.num_rows())] 73 | 74 | def prepare_image_key_to_index(self): 75 | tsv = self.get_valid_tsv() 76 | return {tsv.get_key(i) : i for i in range(tsv.num_rows())} 77 | 78 | 79 | def augm_params(self): 80 | """Get augmentation parameters.""" 81 | flip = 0 # flipping 82 | pn = np.ones(3) # per channel pixel-noise 83 | 84 | if self.args.multiscale_inference == False: 85 | rot = 0 # rotation 86 | sc = 1.0 # scaling 87 | elif self.args.multiscale_inference == True: 88 | rot = self.args.rot 89 | sc = self.args.sc 90 | 91 | if self.is_train: 92 | sc = 1.0 93 | # Each channel is multiplied with a number 94 | # in the area [1-opt.noiseFactor,1+opt.noiseFactor] 95 | pn = np.random.uniform(1-self.noise_factor, 1+self.noise_factor, 3) 96 | 97 | # The rotation is a number in the area [-2*rotFactor, 2*rotFactor] 98 | rot = min(2*self.rot_factor, 99 | max(-2*self.rot_factor, np.random.randn()*self.rot_factor)) 100 | 101 | # The scale is 
multiplied with a number 102 | # in the area [1-scaleFactor,1+scaleFactor] 103 | sc = min(1+self.scale_factor, 104 | max(1-self.scale_factor, np.random.randn()*self.scale_factor+1)) 105 | # but it is zero with probability 3/5 106 | if np.random.uniform() <= 0.6: 107 | rot = 0 108 | 109 | return flip, pn, rot, sc 110 | 111 | def rgb_processing(self, rgb_img, center, scale, rot, flip, pn): 112 | """Process rgb image and do augmentation.""" 113 | rgb_img = crop(rgb_img, center, scale, 114 | [self.img_res, self.img_res], rot=rot) 115 | # flip the image 116 | if flip: 117 | rgb_img = flip_img(rgb_img) 118 | # in the rgb image we add pixel noise in a channel-wise manner 119 | rgb_img[:,:,0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,0]*pn[0])) 120 | rgb_img[:,:,1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,1]*pn[1])) 121 | rgb_img[:,:,2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,2]*pn[2])) 122 | # (3,224,224),float,[0,1] 123 | rgb_img = np.transpose(rgb_img.astype('float32'),(2,0,1))/255.0 124 | return rgb_img 125 | 126 | def j2d_processing(self, kp, center, scale, r, f): 127 | """Process gt 2D keypoints and apply all augmentation transforms.""" 128 | nparts = kp.shape[0] 129 | for i in range(nparts): 130 | kp[i,0:2] = transform(kp[i,0:2]+1, center, scale, 131 | [self.img_res, self.img_res], rot=r) 132 | # convert to normalized coordinates 133 | kp[:,:-1] = 2.*kp[:,:-1]/self.img_res - 1. 134 | # flip the x coordinates 135 | if f: 136 | kp = flip_kp(kp) 137 | kp = kp.astype('float32') 138 | return kp 139 | 140 | 141 | def j3d_processing(self, S, r, f): 142 | """Process gt 3D keypoints and apply all augmentation transforms.""" 143 | # in-plane rotation 144 | rot_mat = np.eye(3) 145 | if not r == 0: 146 | rot_rad = -r * np.pi / 180 147 | sn,cs = np.sin(rot_rad), np.cos(rot_rad) 148 | rot_mat[0,:2] = [cs, -sn] 149 | rot_mat[1,:2] = [sn, cs] 150 | S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1]) 151 | # flip the x coordinates 152 | if f: 153 | S = flip_kp(S) 154 | S = S.astype('float32') 155 | return S 156 | 157 | def pose_processing(self, pose, r, f): 158 | """Process SMPL theta parameters and apply all augmentation transforms.""" 159 | # rotation or the pose parameters 160 | pose = pose.astype('float32') 161 | pose[:3] = rot_aa(pose[:3], r) 162 | # flip the pose parameters 163 | if f: 164 | pose = flip_pose(pose) 165 | # (72),float 166 | pose = pose.astype('float32') 167 | return pose 168 | 169 | def get_line_no(self, idx): 170 | return idx if self.line_list is None else self.line_list[idx] 171 | 172 | def get_image(self, idx): 173 | line_no = self.get_line_no(idx) 174 | row = self.img_tsv[line_no] 175 | # use -1 to support old format with multiple columns. 176 | cv2_im = img_from_base64(row[-1]) 177 | if self.cv2_output: 178 | return cv2_im.astype(np.float32, copy=True) 179 | cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB) 180 | return cv2_im 181 | 182 | def get_annotations(self, idx): 183 | line_no = self.get_line_no(idx) 184 | if self.label_tsv is not None: 185 | row = self.label_tsv[line_no] 186 | annotations = json.loads(row[1]) 187 | return annotations 188 | else: 189 | return [] 190 | 191 | def get_target_from_annotations(self, annotations, img_size, idx): 192 | # This function will be overwritten by each dataset to 193 | # decode the labels to specific formats for each task. 
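        # For reference, a decoded label here is a list with one dict per image.
        # Below is a hypothetical FreiHAND-style record (hypothetical values, shown
        # only to make the keys that __getitem__ expects concrete):
        #   [{'center': [112.0, 112.0], 'scale': 1.0,
        #     'has_2d_joints': 1, '2d_joints': [[x, y, vis], ...],     # 21 hand joints
        #     'has_3d_joints': 1, '3d_joints': [[X, Y, Z, vis], ...],  # re-rooted at the wrist in __getitem__
        #     'has_smpl': 1, 'pose': [...], 'betas': [...]}]           # MANO parameters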
194 |         return annotations
195 | 
196 |     def get_img_info(self, idx):
197 |         if self.hw_tsv is not None:
198 |             line_no = self.get_line_no(idx)
199 |             row = self.hw_tsv[line_no]
200 |             try:
201 |                 # json string format with "height" and "width" being the keys
202 |                 return json.loads(row[1])[0]
203 |             except ValueError:
204 |                 # list of strings representing height and width in order
205 |                 hw_str = row[1].split(' ')
206 |                 hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
207 |                 return hw_dict
208 | 
209 |     def get_img_key(self, idx):
210 |         line_no = self.get_line_no(idx)
211 |         # choose the cheapest file to read, based on the overhead of reading each row.
212 |         if self.hw_tsv:
213 |             return self.hw_tsv[line_no][0]
214 |         elif self.label_tsv:
215 |             return self.label_tsv[line_no][0]
216 |         else:
217 |             return self.img_tsv[line_no][0]
218 | 
219 |     def __len__(self):
220 |         if self.line_list is None:
221 |             return self.img_tsv.num_rows()
222 |         else:
223 |             return len(self.line_list)
224 | 
225 |     def __getitem__(self, idx):
226 | 
227 |         img = self.get_image(idx)
228 |         img_key = self.get_img_key(idx)
229 |         annotations = self.get_annotations(idx)
230 | 
231 |         annotations = annotations[0]
232 |         center = annotations['center']
233 |         scale = annotations['scale']
234 |         has_2d_joints = annotations['has_2d_joints']
235 |         has_3d_joints = annotations['has_3d_joints']
236 |         joints_2d = np.asarray(annotations['2d_joints'])
237 |         joints_3d = np.asarray(annotations['3d_joints'])
238 | 
239 |         if joints_2d.ndim==3:
240 |             joints_2d = joints_2d[0]
241 |         if joints_3d.ndim==3:
242 |             joints_3d = joints_3d[0]
243 | 
244 |         # Get SMPL parameters, if available
245 |         has_smpl = np.asarray(annotations['has_smpl'])
246 |         pose = np.asarray(annotations['pose'])
247 |         betas = np.asarray(annotations['betas'])
248 | 
249 |         # Get augmentation parameters
250 |         flip,pn,rot,sc = self.augm_params()
251 | 
252 |         # Process image
253 |         img = self.rgb_processing(img, center, sc*scale, rot, flip, pn)
254 |         img = torch.from_numpy(img).float()
255 |         # Store image before normalization to use it in visualization
256 |         transfromed_img = self.normalize_img(img)
257 | 
258 |         # normalize 3d pose by aligning the wrist as the root (at origin)
259 |         root_coord = joints_3d[self.root_index,:-1]
260 |         joints_3d[:,:-1] = joints_3d[:,:-1] - root_coord[None,:]
261 |         # 3d pose augmentation (random flip + rotation, consistent with image and SMPL)
262 |         joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip)
263 |         # 2d pose augmentation
264 |         joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc*scale, rot, flip)
265 | 
266 |         ###################################
267 |         # Masking percentage
268 |         # We observe that 0% or 5% works better for 3D hand mesh
269 |         # We think this is probably because 3D vertices are quite sparse in the down-sampled hand mesh
270 |         mvm_percent = 0.0 # or 0.05
271 |         ###################################
272 | 
273 |         mjm_mask = np.ones((21,1))
274 |         if self.is_train:
275 |             num_joints = 21
276 |             pb = np.random.random_sample()
277 |             masked_num = int(pb * mvm_percent * num_joints) # at most x% of the joints could be masked
278 |             indices = np.random.choice(np.arange(num_joints),replace=False,size=masked_num)
279 |             mjm_mask[indices,:] = 0.0
280 |         mjm_mask = torch.from_numpy(mjm_mask).float()
281 | 
282 |         mvm_mask = np.ones((195,1))
283 |         if self.is_train:
284 |             num_vertices = 195
285 |             pb = np.random.random_sample()
286 |             masked_num = int(pb * mvm_percent * num_vertices) # at most x% of the vertices could be masked
287 |             indices = 
np.random.choice(np.arange(num_vertices),replace=False,size=masked_num) 288 | mvm_mask[indices,:] = 0.0 289 | mvm_mask = torch.from_numpy(mvm_mask).float() 290 | 291 | meta_data = {} 292 | meta_data['ori_img'] = img 293 | meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float() 294 | meta_data['betas'] = torch.from_numpy(betas).float() 295 | meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float() 296 | meta_data['has_3d_joints'] = has_3d_joints 297 | meta_data['has_smpl'] = has_smpl 298 | meta_data['mjm_mask'] = mjm_mask 299 | meta_data['mvm_mask'] = mvm_mask 300 | 301 | # Get 2D keypoints and apply augmentation transforms 302 | meta_data['has_2d_joints'] = has_2d_joints 303 | meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float() 304 | 305 | meta_data['scale'] = float(sc * scale) 306 | meta_data['center'] = np.asarray(center).astype(np.float32) 307 | 308 | return img_key, transfromed_img, meta_data 309 | 310 | 311 | class HandMeshTSVYamlDataset(HandMeshTSVDataset): 312 | """ TSVDataset taking a Yaml file for easy function call 313 | """ 314 | def __init__(self, args, yaml_file, is_train=True, cv2_output=False, scale_factor=1): 315 | self.cfg = load_from_yaml_file(yaml_file) 316 | self.is_composite = self.cfg.get('composite', False) 317 | self.root = op.dirname(yaml_file) 318 | 319 | if self.is_composite==False: 320 | img_file = find_file_path_in_yaml(self.cfg['img'], self.root) 321 | label_file = find_file_path_in_yaml(self.cfg.get('label', None), 322 | self.root) 323 | hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root) 324 | linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None), 325 | self.root) 326 | else: 327 | img_file = self.cfg['img'] 328 | hw_file = self.cfg['hw'] 329 | label_file = self.cfg.get('label', None) 330 | linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None), 331 | self.root) 332 | 333 | super(HandMeshTSVYamlDataset, self).__init__( 334 | args, img_file, label_file, hw_file, linelist_file, is_train, cv2_output=cv2_output, scale_factor=scale_factor) 335 | -------------------------------------------------------------------------------- /src/datasets/human_mesh_tsv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | """ 6 | 7 | import cv2 8 | import math 9 | import json 10 | from PIL import Image 11 | import os.path as op 12 | import numpy as np 13 | import code 14 | 15 | from src.utils.tsv_file import TSVFile, CompositeTSVFile 16 | from src.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml 17 | from src.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa 18 | import torch 19 | import torchvision.transforms as transforms 20 | 21 | 22 | class MeshTSVDataset(object): 23 | def __init__(self, img_file, label_file=None, hw_file=None, 24 | linelist_file=None, is_train=True, cv2_output=False, scale_factor=1): 25 | 26 | self.img_file = img_file 27 | self.label_file = label_file 28 | self.hw_file = hw_file 29 | self.linelist_file = linelist_file 30 | self.img_tsv = self.get_tsv_file(img_file) 31 | self.label_tsv = None if label_file is None else self.get_tsv_file(label_file) 32 | self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file) 33 | 34 | if self.is_composite: 35 | assert op.isfile(self.linelist_file) 36 | self.line_list = [i for i in range(self.hw_tsv.num_rows())] 37 | else: 38 | self.line_list = load_linelist_file(linelist_file) 39 | 40 | self.cv2_output = cv2_output 41 | self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], 42 | std=[0.229, 0.224, 0.225]) 43 | self.is_train = is_train 44 | self.scale_factor = 0.25 # rescale bounding boxes by a factor of [1-options.scale_factor,1+options.scale_factor] 45 | self.noise_factor = 0.4 46 | self.rot_factor = 30 # Random rotation in the range [-rot_factor, rot_factor] 47 | self.img_res = 224 48 | 49 | self.image_keys = self.prepare_image_keys() 50 | 51 | self.joints_definition = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder', 52 | 'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear') 53 | self.pelvis_index = self.joints_definition.index('Pelvis') 54 | 55 | def get_tsv_file(self, tsv_file): 56 | if tsv_file: 57 | if self.is_composite: 58 | return CompositeTSVFile(tsv_file, self.linelist_file, 59 | root=self.root) 60 | tsv_path = find_file_path_in_yaml(tsv_file, self.root) 61 | return TSVFile(tsv_path) 62 | 63 | def get_valid_tsv(self): 64 | # sorted by file size 65 | if self.hw_tsv: 66 | return self.hw_tsv 67 | if self.label_tsv: 68 | return self.label_tsv 69 | 70 | def prepare_image_keys(self): 71 | tsv = self.get_valid_tsv() 72 | return [tsv.get_key(i) for i in range(tsv.num_rows())] 73 | 74 | def prepare_image_key_to_index(self): 75 | tsv = self.get_valid_tsv() 76 | return {tsv.get_key(i) : i for i in range(tsv.num_rows())} 77 | 78 | 79 | def augm_params(self): 80 | """Get augmentation parameters.""" 81 | flip = 0 # flipping 82 | pn = np.ones(3) # per channel pixel-noise 83 | rot = 0 # rotation 84 | sc = 1 # scaling 85 | if self.is_train: 86 | # We flip with probability 1/2 87 | if np.random.uniform() <= 0.5: 88 | flip = 1 89 | 90 | # Each channel is multiplied with a number 91 | # in the area [1-opt.noiseFactor,1+opt.noiseFactor] 92 | pn = np.random.uniform(1-self.noise_factor, 1+self.noise_factor, 3) 93 | 94 | # The rotation is a number in the area [-2*rotFactor, 2*rotFactor] 95 | rot = min(2*self.rot_factor, 96 | max(-2*self.rot_factor, np.random.randn()*self.rot_factor)) 97 | 98 | # The scale is multiplied with a number 99 | # in the area [1-scaleFactor,1+scaleFactor] 100 | sc = min(1+self.scale_factor, 101 | 
max(1-self.scale_factor, np.random.randn()*self.scale_factor+1)) 102 | # but it is zero with probability 3/5 103 | if np.random.uniform() <= 0.6: 104 | rot = 0 105 | 106 | return flip, pn, rot, sc 107 | 108 | def rgb_processing(self, rgb_img, center, scale, rot, flip, pn): 109 | """Process rgb image and do augmentation.""" 110 | rgb_img = crop(rgb_img, center, scale, 111 | [self.img_res, self.img_res], rot=rot) 112 | # flip the image 113 | if flip: 114 | rgb_img = flip_img(rgb_img) 115 | # in the rgb image we add pixel noise in a channel-wise manner 116 | rgb_img[:,:,0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,0]*pn[0])) 117 | rgb_img[:,:,1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,1]*pn[1])) 118 | rgb_img[:,:,2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,2]*pn[2])) 119 | # (3,224,224),float,[0,1] 120 | rgb_img = np.transpose(rgb_img.astype('float32'),(2,0,1))/255.0 121 | return rgb_img 122 | 123 | def j2d_processing(self, kp, center, scale, r, f): 124 | """Process gt 2D keypoints and apply all augmentation transforms.""" 125 | nparts = kp.shape[0] 126 | for i in range(nparts): 127 | kp[i,0:2] = transform(kp[i,0:2]+1, center, scale, 128 | [self.img_res, self.img_res], rot=r) 129 | # convert to normalized coordinates 130 | kp[:,:-1] = 2.*kp[:,:-1]/self.img_res - 1. 131 | # flip the x coordinates 132 | if f: 133 | kp = flip_kp(kp) 134 | kp = kp.astype('float32') 135 | return kp 136 | 137 | def j3d_processing(self, S, r, f): 138 | """Process gt 3D keypoints and apply all augmentation transforms.""" 139 | # in-plane rotation 140 | rot_mat = np.eye(3) 141 | if not r == 0: 142 | rot_rad = -r * np.pi / 180 143 | sn,cs = np.sin(rot_rad), np.cos(rot_rad) 144 | rot_mat[0,:2] = [cs, -sn] 145 | rot_mat[1,:2] = [sn, cs] 146 | S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1]) 147 | # flip the x coordinates 148 | if f: 149 | S = flip_kp(S) 150 | S = S.astype('float32') 151 | return S 152 | 153 | def pose_processing(self, pose, r, f): 154 | """Process SMPL theta parameters and apply all augmentation transforms.""" 155 | # rotation or the pose parameters 156 | pose = pose.astype('float32') 157 | pose[:3] = rot_aa(pose[:3], r) 158 | # flip the pose parameters 159 | if f: 160 | pose = flip_pose(pose) 161 | # (72),float 162 | pose = pose.astype('float32') 163 | return pose 164 | 165 | def get_line_no(self, idx): 166 | return idx if self.line_list is None else self.line_list[idx] 167 | 168 | def get_image(self, idx): 169 | line_no = self.get_line_no(idx) 170 | row = self.img_tsv[line_no] 171 | # use -1 to support old format with multiple columns. 172 | cv2_im = img_from_base64(row[-1]) 173 | if self.cv2_output: 174 | return cv2_im.astype(np.float32, copy=True) 175 | cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB) 176 | 177 | return cv2_im 178 | 179 | def get_annotations(self, idx): 180 | line_no = self.get_line_no(idx) 181 | if self.label_tsv is not None: 182 | row = self.label_tsv[line_no] 183 | annotations = json.loads(row[1]) 184 | return annotations 185 | else: 186 | return [] 187 | 188 | def get_target_from_annotations(self, annotations, img_size, idx): 189 | # This function will be overwritten by each dataset to 190 | # decode the labels to specific formats for each task. 
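        # The decoded list holds one dict per image; __getitem__ below reads the keys
        # 'center', 'scale', 'has_2d_joints', 'has_3d_joints', '2d_joints', '3d_joints',
        # 'has_smpl', 'pose', 'betas', and an optional 'gender' ('none' when absent).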
191 |         return annotations
192 | 
193 | 
194 |     def get_img_info(self, idx):
195 |         if self.hw_tsv is not None:
196 |             line_no = self.get_line_no(idx)
197 |             row = self.hw_tsv[line_no]
198 |             try:
199 |                 # json string format with "height" and "width" being the keys
200 |                 return json.loads(row[1])[0]
201 |             except ValueError:
202 |                 # list of strings representing height and width in order
203 |                 hw_str = row[1].split(' ')
204 |                 hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
205 |                 return hw_dict
206 | 
207 |     def get_img_key(self, idx):
208 |         line_no = self.get_line_no(idx)
209 |         # choose the cheapest file to read, based on the overhead of reading each row.
210 |         if self.hw_tsv:
211 |             return self.hw_tsv[line_no][0]
212 |         elif self.label_tsv:
213 |             return self.label_tsv[line_no][0]
214 |         else:
215 |             return self.img_tsv[line_no][0]
216 | 
217 |     def __len__(self):
218 |         if self.line_list is None:
219 |             return self.img_tsv.num_rows()
220 |         else:
221 |             return len(self.line_list)
222 | 
223 |     def __getitem__(self, idx):
224 | 
225 |         img = self.get_image(idx)
226 |         img_key = self.get_img_key(idx)
227 |         annotations = self.get_annotations(idx)
228 | 
229 |         annotations = annotations[0]
230 |         center = annotations['center']
231 |         scale = annotations['scale']
232 |         has_2d_joints = annotations['has_2d_joints']
233 |         has_3d_joints = annotations['has_3d_joints']
234 |         joints_2d = np.asarray(annotations['2d_joints'])
235 |         joints_3d = np.asarray(annotations['3d_joints'])
236 | 
237 |         if joints_2d.ndim==3:
238 |             joints_2d = joints_2d[0]
239 |         if joints_3d.ndim==3:
240 |             joints_3d = joints_3d[0]
241 | 
242 |         # Get SMPL parameters, if available
243 |         has_smpl = np.asarray(annotations['has_smpl'])
244 |         pose = np.asarray(annotations['pose'])
245 |         betas = np.asarray(annotations['betas'])
246 | 
247 |         try:
248 |             gender = annotations['gender']
249 |         except KeyError:
250 |             gender = 'none'
251 | 
252 |         # Get augmentation parameters
253 |         flip,pn,rot,sc = self.augm_params()
254 | 
255 |         # Process image
256 |         img = self.rgb_processing(img, center, sc*scale, rot, flip, pn)
257 |         img = torch.from_numpy(img).float()
258 |         # Store image before normalization to use it in visualization
259 |         transfromed_img = self.normalize_img(img)
260 | 
261 |         # normalize 3d pose by aligning the pelvis as the root (at origin)
262 |         root_pelvis = joints_3d[self.pelvis_index,:-1]
263 |         joints_3d[:,:-1] = joints_3d[:,:-1] - root_pelvis[None,:]
264 |         # 3d pose augmentation (random flip + rotation, consistent with image and SMPL)
265 |         joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip)
266 |         # 2d pose augmentation
267 |         joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc*scale, rot, flip)
268 | 
269 |         ###################################
270 |         # Masking percentage
271 |         # We observe that 30% works better for human body mesh. Further details are reported in the paper.
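        # A worked example of the masking arithmetic below: with mvm_percent = 0.3 and
        # pb drawn uniformly from [0, 1),
        #   joints:   masked_num = int(pb * 0.3 * 14)  is 0..4   (at most ~28.6% of 14)
        #   vertices: masked_num = int(pb * 0.3 * 431) is 0..129 (at most ~29.9% of 431)
        # so at most 30% of the tokens are masked, and about half that on average.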
272 | mvm_percent = 0.3 273 | ################################### 274 | 275 | mjm_mask = np.ones((14,1)) 276 | if self.is_train: 277 | num_joints = 14 278 | pb = np.random.random_sample() 279 | masked_num = int(pb * mvm_percent * num_joints) # at most x% of the joints could be masked 280 | indices = np.random.choice(np.arange(num_joints),replace=False,size=masked_num) 281 | mjm_mask[indices,:] = 0.0 282 | mjm_mask = torch.from_numpy(mjm_mask).float() 283 | 284 | mvm_mask = np.ones((431,1)) 285 | if self.is_train: 286 | num_vertices = 431 287 | pb = np.random.random_sample() 288 | masked_num = int(pb * mvm_percent * num_vertices) # at most x% of the vertices could be masked 289 | indices = np.random.choice(np.arange(num_vertices),replace=False,size=masked_num) 290 | mvm_mask[indices,:] = 0.0 291 | mvm_mask = torch.from_numpy(mvm_mask).float() 292 | 293 | meta_data = {} 294 | meta_data['ori_img'] = img 295 | meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float() 296 | meta_data['betas'] = torch.from_numpy(betas).float() 297 | meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float() 298 | meta_data['has_3d_joints'] = has_3d_joints 299 | meta_data['has_smpl'] = has_smpl 300 | 301 | meta_data['mjm_mask'] = mjm_mask 302 | meta_data['mvm_mask'] = mvm_mask 303 | 304 | # Get 2D keypoints and apply augmentation transforms 305 | meta_data['has_2d_joints'] = has_2d_joints 306 | meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float() 307 | meta_data['scale'] = float(sc * scale) 308 | meta_data['center'] = np.asarray(center).astype(np.float32) 309 | meta_data['gender'] = gender 310 | return img_key, transfromed_img, meta_data 311 | 312 | 313 | 314 | class MeshTSVYamlDataset(MeshTSVDataset): 315 | """ TSVDataset taking a Yaml file for easy function call 316 | """ 317 | def __init__(self, yaml_file, is_train=True, cv2_output=False, scale_factor=1): 318 | self.cfg = load_from_yaml_file(yaml_file) 319 | self.is_composite = self.cfg.get('composite', False) 320 | self.root = op.dirname(yaml_file) 321 | 322 | if self.is_composite==False: 323 | img_file = find_file_path_in_yaml(self.cfg['img'], self.root) 324 | label_file = find_file_path_in_yaml(self.cfg.get('label', None), 325 | self.root) 326 | hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root) 327 | linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None), 328 | self.root) 329 | else: 330 | img_file = self.cfg['img'] 331 | hw_file = self.cfg['hw'] 332 | label_file = self.cfg.get('label', None) 333 | linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None), 334 | self.root) 335 | 336 | super(MeshTSVYamlDataset, self).__init__( 337 | img_file, label_file, hw_file, linelist_file, is_train, cv2_output=cv2_output, scale_factor=scale_factor) 338 | -------------------------------------------------------------------------------- /src/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/__init__.py -------------------------------------------------------------------------------- /src/modeling/_gcnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import scipy.sparse 6 | import math 7 | 8 | class SparseMM(torch.autograd.Function): 9 | """Redefine sparse 
@ dense matrix multiplication to enable backpropagation. 10 | The builtin matrix multiplication operation does not support backpropagation in some cases. 11 | """ 12 | @staticmethod 13 | def forward(ctx, sparse, dense): 14 | ctx.req_grad = dense.requires_grad 15 | ctx.save_for_backward(sparse) 16 | return torch.matmul(sparse, dense) 17 | 18 | @staticmethod 19 | def backward(ctx, grad_output): 20 | grad_input = None 21 | sparse, = ctx.saved_tensors 22 | if ctx.req_grad: 23 | grad_input = torch.matmul(sparse.t(), grad_output) 24 | return None, grad_input 25 | 26 | def spmm(sparse, dense): 27 | return SparseMM.apply(sparse, dense) 28 | 29 | 30 | def gelu(x): 31 | """Implementation of the gelu activation function. 32 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 33 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 34 | Also see https://arxiv.org/abs/1606.08415 35 | """ 36 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 37 | 38 | class BertLayerNorm(torch.nn.Module): 39 | def __init__(self, hidden_size, eps=1e-12): 40 | """Construct a layernorm module in the TF style (epsilon inside the square root). 41 | """ 42 | super(BertLayerNorm, self).__init__() 43 | self.weight = torch.nn.Parameter(torch.ones(hidden_size)) 44 | self.bias = torch.nn.Parameter(torch.zeros(hidden_size)) 45 | self.variance_epsilon = eps 46 | 47 | def forward(self, x): 48 | u = x.mean(-1, keepdim=True) 49 | s = (x - u).pow(2).mean(-1, keepdim=True) 50 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 51 | return self.weight * x + self.bias 52 | 53 | 54 | class GraphResBlock(torch.nn.Module): 55 | """ 56 | Graph Residual Block similar to the Bottleneck Residual Block in ResNet 57 | """ 58 | def __init__(self, in_channels, out_channels, mesh_type='body'): 59 | super(GraphResBlock, self).__init__() 60 | self.in_channels = in_channels 61 | self.out_channels = out_channels 62 | self.lin1 = GraphLinear(in_channels, out_channels // 2) 63 | self.conv = GraphConvolution(out_channels // 2, out_channels // 2, mesh_type) 64 | self.lin2 = GraphLinear(out_channels // 2, out_channels) 65 | self.skip_conv = GraphLinear(in_channels, out_channels) 66 | # print('Use BertLayerNorm in GraphResBlock') 67 | self.pre_norm = BertLayerNorm(in_channels) 68 | self.norm1 = BertLayerNorm(out_channels // 2) 69 | self.norm2 = BertLayerNorm(out_channels // 2) 70 | 71 | def forward(self, x): 72 | trans_y = F.relu(self.pre_norm(x)).transpose(1,2) 73 | y = self.lin1(trans_y).transpose(1,2) 74 | 75 | y = F.relu(self.norm1(y)) 76 | y = self.conv(y) 77 | 78 | trans_y = F.relu(self.norm2(y)).transpose(1,2) 79 | y = self.lin2(trans_y).transpose(1,2) 80 | 81 | z = x+y 82 | 83 | return z 84 | 85 | # class GraphResBlock(torch.nn.Module): 86 | # """ 87 | # Graph Residual Block similar to the Bottleneck Residual Block in ResNet 88 | # """ 89 | # def __init__(self, in_channels, out_channels, mesh_type='body'): 90 | # super(GraphResBlock, self).__init__() 91 | # self.in_channels = in_channels 92 | # self.out_channels = out_channels 93 | # self.conv = GraphConvolution(self.in_channels, self.out_channels, mesh_type) 94 | # print('Use BertLayerNorm and GeLU in GraphResBlock') 95 | # self.norm = BertLayerNorm(self.out_channels) 96 | # def forward(self, x): 97 | # y = self.conv(x) 98 | # y = self.norm(y) 99 | # y = gelu(y) 100 | # z = x+y 101 | # return z 102 | 103 | class GraphLinear(torch.nn.Module): 104 | """ 105 | Generalization of 1x1 convolutions on Graphs 106 | """ 107 
| def __init__(self, in_channels, out_channels): 108 | super(GraphLinear, self).__init__() 109 | self.in_channels = in_channels 110 | self.out_channels = out_channels 111 | self.W = torch.nn.Parameter(torch.FloatTensor(out_channels, in_channels)) 112 | self.b = torch.nn.Parameter(torch.FloatTensor(out_channels)) 113 | self.reset_parameters() 114 | 115 | def reset_parameters(self): 116 | w_stdv = 1 / (self.in_channels * self.out_channels) 117 | self.W.data.uniform_(-w_stdv, w_stdv) 118 | self.b.data.uniform_(-w_stdv, w_stdv) 119 | 120 | def forward(self, x): 121 | return torch.matmul(self.W[None, :], x) + self.b[None, :, None] 122 | 123 | class GraphConvolution(torch.nn.Module): 124 | """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907.""" 125 | def __init__(self, in_features, out_features, mesh='body', bias=True): 126 | super(GraphConvolution, self).__init__() 127 | device=torch.device('cuda') 128 | self.in_features = in_features 129 | self.out_features = out_features 130 | 131 | if mesh=='body': 132 | adj_indices = torch.load('./src/modeling/data/smpl_431_adjmat_indices.pt') 133 | adj_mat_value = torch.load('./src/modeling/data/smpl_431_adjmat_values.pt') 134 | adj_mat_size = torch.load('./src/modeling/data/smpl_431_adjmat_size.pt') 135 | elif mesh=='hand': 136 | adj_indices = torch.load('./src/modeling/data/mano_195_adjmat_indices.pt') 137 | adj_mat_value = torch.load('./src/modeling/data/mano_195_adjmat_values.pt') 138 | adj_mat_size = torch.load('./src/modeling/data/mano_195_adjmat_size.pt') 139 | 140 | self.adjmat = torch.sparse_coo_tensor(adj_indices, adj_mat_value, size=adj_mat_size).to(device) 141 | 142 | self.weight = torch.nn.Parameter(torch.FloatTensor(in_features, out_features)) 143 | if bias: 144 | self.bias = torch.nn.Parameter(torch.FloatTensor(out_features)) 145 | else: 146 | self.register_parameter('bias', None) 147 | self.reset_parameters() 148 | 149 | def reset_parameters(self): 150 | # stdv = 1. / math.sqrt(self.weight.size(1)) 151 | stdv = 6. 
/ math.sqrt(self.weight.size(0) + self.weight.size(1)) 152 | self.weight.data.uniform_(-stdv, stdv) 153 | if self.bias is not None: 154 | self.bias.data.uniform_(-stdv, stdv) 155 | 156 | def forward(self, x): 157 | if x.ndimension() == 2: 158 | support = torch.matmul(x, self.weight) 159 | output = torch.matmul(self.adjmat, support) 160 | if self.bias is not None: 161 | output = output + self.bias 162 | return output 163 | else: 164 | output = [] 165 | for i in range(x.shape[0]): 166 | support = torch.matmul(x[i], self.weight) 167 | # output.append(torch.matmul(self.adjmat, support)) 168 | output.append(spmm(self.adjmat, support)) 169 | output = torch.stack(output, dim=0) 170 | if self.bias is not None: 171 | output = output + self.bias 172 | return output 173 | 174 | def __repr__(self): 175 | return self.__class__.__name__ + ' (' \ 176 | + str(self.in_features) + ' -> ' \ 177 | + str(self.out_features) + ')' -------------------------------------------------------------------------------- /src/modeling/_mano.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the MANO defination and mesh sampling operations for MANO mesh 3 | 4 | Adapted from opensource projects 5 | MANOPTH (https://github.com/hassony2/manopth) 6 | Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE) 7 | GraphCMR (https://github.com/nkolot/GraphCMR/) 8 | """ 9 | 10 | from __future__ import division 11 | import numpy as np 12 | import torch 13 | import torch.nn as nn 14 | import os.path as osp 15 | import json 16 | import code 17 | from manopth.manolayer import ManoLayer 18 | import scipy.sparse 19 | import src.modeling.data.config as cfg 20 | 21 | class MANO(nn.Module): 22 | def __init__(self): 23 | super(MANO, self).__init__() 24 | 25 | self.mano_dir = 'src/modeling/data' 26 | self.layer = self.get_layer() 27 | self.vertex_num = 778 28 | self.face = self.layer.th_faces.numpy() 29 | self.joint_regressor = self.layer.th_J_regressor.numpy() 30 | 31 | self.joint_num = 21 32 | self.joints_name = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4') 33 | self.skeleton = ( (0,1), (0,5), (0,9), (0,13), (0,17), (1,2), (2,3), (3,4), (5,6), (6,7), (7,8), (9,10), (10,11), (11,12), (13,14), (14,15), (15,16), (17,18), (18,19), (19,20) ) 34 | self.root_joint_idx = self.joints_name.index('Wrist') 35 | 36 | # add fingertips to joint_regressor 37 | self.fingertip_vertex_idx = [745, 317, 444, 556, 673] # mesh vertex idx (right hand) 38 | thumbtip_onehot = np.array([1 if i == 745 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) 39 | indextip_onehot = np.array([1 if i == 317 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) 40 | middletip_onehot = np.array([1 if i == 445 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) 41 | ringtip_onehot = np.array([1 if i == 556 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) 42 | pinkytip_onehot = np.array([1 if i == 673 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) 43 | self.joint_regressor = np.concatenate((self.joint_regressor, thumbtip_onehot, indextip_onehot, middletip_onehot, ringtip_onehot, pinkytip_onehot)) 44 | self.joint_regressor = self.joint_regressor[[0, 13, 14, 
15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20],:] 45 | joint_regressor_torch = torch.from_numpy(self.joint_regressor).float() 46 | self.register_buffer('joint_regressor_torch', joint_regressor_torch) 47 | 48 | def get_layer(self): 49 | return ManoLayer(mano_root=osp.join(self.mano_dir), flat_hand_mean=False, use_pca=False) # load right hand MANO model 50 | 51 | def get_3d_joints(self, vertices): 52 | """ 53 | This method is used to get the joint locations from the SMPL mesh 54 | Input: 55 | vertices: size = (B, 778, 3) 56 | Output: 57 | 3D joints: size = (B, 21, 3) 58 | """ 59 | joints = torch.einsum('bik,ji->bjk', [vertices, self.joint_regressor_torch]) 60 | return joints 61 | 62 | 63 | class SparseMM(torch.autograd.Function): 64 | """Redefine sparse @ dense matrix multiplication to enable backpropagation. 65 | The builtin matrix multiplication operation does not support backpropagation in some cases. 66 | """ 67 | @staticmethod 68 | def forward(ctx, sparse, dense): 69 | ctx.req_grad = dense.requires_grad 70 | ctx.save_for_backward(sparse) 71 | return torch.matmul(sparse, dense) 72 | 73 | @staticmethod 74 | def backward(ctx, grad_output): 75 | grad_input = None 76 | sparse, = ctx.saved_tensors 77 | if ctx.req_grad: 78 | grad_input = torch.matmul(sparse.t(), grad_output) 79 | return None, grad_input 80 | 81 | def spmm(sparse, dense): 82 | return SparseMM.apply(sparse, dense) 83 | 84 | 85 | def scipy_to_pytorch(A, U, D): 86 | """Convert scipy sparse matrices to pytorch sparse matrix.""" 87 | ptU = [] 88 | ptD = [] 89 | 90 | for i in range(len(U)): 91 | u = scipy.sparse.coo_matrix(U[i]) 92 | i = torch.LongTensor(np.array([u.row, u.col])) 93 | v = torch.FloatTensor(u.data) 94 | ptU.append(torch.sparse.FloatTensor(i, v, u.shape)) 95 | 96 | for i in range(len(D)): 97 | d = scipy.sparse.coo_matrix(D[i]) 98 | i = torch.LongTensor(np.array([d.row, d.col])) 99 | v = torch.FloatTensor(d.data) 100 | ptD.append(torch.sparse.FloatTensor(i, v, d.shape)) 101 | 102 | return ptU, ptD 103 | 104 | 105 | def adjmat_sparse(adjmat, nsize=1): 106 | """Create row-normalized sparse graph adjacency matrix.""" 107 | adjmat = scipy.sparse.csr_matrix(adjmat) 108 | if nsize > 1: 109 | orig_adjmat = adjmat.copy() 110 | for _ in range(1, nsize): 111 | adjmat = adjmat * orig_adjmat 112 | adjmat.data = np.ones_like(adjmat.data) 113 | for i in range(adjmat.shape[0]): 114 | adjmat[i,i] = 1 115 | num_neighbors = np.array(1 / adjmat.sum(axis=-1)) 116 | adjmat = adjmat.multiply(num_neighbors) 117 | adjmat = scipy.sparse.coo_matrix(adjmat) 118 | row = adjmat.row 119 | col = adjmat.col 120 | data = adjmat.data 121 | i = torch.LongTensor(np.array([row, col])) 122 | v = torch.from_numpy(data).float() 123 | adjmat = torch.sparse.FloatTensor(i, v, adjmat.shape) 124 | return adjmat 125 | 126 | def get_graph_params(filename, nsize=1): 127 | """Load and process graph adjacency matrix and upsampling/downsampling matrices.""" 128 | data = np.load(filename, encoding='latin1', allow_pickle=True) 129 | A = data['A'] 130 | U = data['U'] 131 | D = data['D'] 132 | U, D = scipy_to_pytorch(A, U, D) 133 | A = [adjmat_sparse(a, nsize=nsize) for a in A] 134 | return A, U, D 135 | 136 | 137 | class Mesh(object): 138 | """Mesh object that is used for handling certain graph operations.""" 139 | def __init__(self, filename=cfg.MANO_sampling_matrix, 140 | num_downsampling=1, nsize=1, device=torch.device('cuda')): 141 | self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize) 142 | # self._A = [a.to(device) for a in 
self._A] 143 | self._U = [u.to(device) for u in self._U] 144 | self._D = [d.to(device) for d in self._D] 145 | self.num_downsampling = num_downsampling 146 | 147 | def downsample(self, x, n1=0, n2=None): 148 | """Downsample mesh.""" 149 | if n2 is None: 150 | n2 = self.num_downsampling 151 | if x.ndimension() < 3: 152 | for i in range(n1, n2): 153 | x = spmm(self._D[i], x) 154 | elif x.ndimension() == 3: 155 | out = [] 156 | for i in range(x.shape[0]): 157 | y = x[i] 158 | for j in range(n1, n2): 159 | y = spmm(self._D[j], y) 160 | out.append(y) 161 | x = torch.stack(out, dim=0) 162 | return x 163 | 164 | def upsample(self, x, n1=1, n2=0): 165 | """Upsample mesh.""" 166 | if x.ndimension() < 3: 167 | for i in reversed(range(n2, n1)): 168 | x = spmm(self._U[i], x) 169 | elif x.ndimension() == 3: 170 | out = [] 171 | for i in range(x.shape[0]): 172 | y = x[i] 173 | for j in reversed(range(n2, n1)): 174 | y = spmm(self._U[j], y) 175 | out.append(y) 176 | x = torch.stack(out, dim=0) 177 | return x 178 | -------------------------------------------------------------------------------- /src/modeling/_smpl.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the definition of the SMPL model 3 | 4 | It is adapted from opensource project GraphCMR (https://github.com/nkolot/GraphCMR/) 5 | """ 6 | from __future__ import division 7 | 8 | import torch 9 | import torch.nn as nn 10 | import numpy as np 11 | import scipy.sparse 12 | try: 13 | import cPickle as pickle 14 | except ImportError: 15 | import pickle 16 | 17 | from src.utils.geometric_layers import rodrigues 18 | import src.modeling.data.config as cfg 19 | 20 | class SMPL(nn.Module): 21 | 22 | def __init__(self, gender='neutral'): 23 | super(SMPL, self).__init__() 24 | 25 | if gender=='m': 26 | model_file=cfg.SMPL_Male 27 | elif gender=='f': 28 | model_file=cfg.SMPL_Female 29 | else: 30 | model_file=cfg.SMPL_FILE 31 | 32 | smpl_model = pickle.load(open(model_file, 'rb'), encoding='latin1') 33 | J_regressor = smpl_model['J_regressor'].tocoo() 34 | row = J_regressor.row 35 | col = J_regressor.col 36 | data = J_regressor.data 37 | i = torch.LongTensor([row, col]) 38 | v = torch.FloatTensor(data) 39 | J_regressor_shape = [24, 6890] 40 | self.register_buffer('J_regressor', torch.sparse.FloatTensor(i, v, J_regressor_shape).to_dense()) 41 | self.register_buffer('weights', torch.FloatTensor(smpl_model['weights'])) 42 | self.register_buffer('posedirs', torch.FloatTensor(smpl_model['posedirs'])) 43 | self.register_buffer('v_template', torch.FloatTensor(smpl_model['v_template'])) 44 | self.register_buffer('shapedirs', torch.FloatTensor(np.array(smpl_model['shapedirs']))) 45 | self.register_buffer('faces', torch.from_numpy(smpl_model['f'].astype(np.int64))) 46 | self.register_buffer('kintree_table', torch.from_numpy(smpl_model['kintree_table'].astype(np.int64))) 47 | id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])} 48 | self.register_buffer('parent', torch.LongTensor([id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])])) 49 | 50 | self.pose_shape = [24, 3] 51 | self.beta_shape = [10] 52 | self.translation_shape = [3] 53 | 54 | self.pose = torch.zeros(self.pose_shape) 55 | self.beta = torch.zeros(self.beta_shape) 56 | self.translation = torch.zeros(self.translation_shape) 57 | 58 | self.verts = None 59 | self.J = None 60 | self.R = None 61 | 62 | J_regressor_extra = 
torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_TRAIN_EXTRA)).float() 63 | self.register_buffer('J_regressor_extra', J_regressor_extra) 64 | self.joints_idx = cfg.JOINTS_IDX 65 | 66 | J_regressor_h36m_correct = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_H36M_correct)).float() 67 | self.register_buffer('J_regressor_h36m_correct', J_regressor_h36m_correct) 68 | 69 | 70 | def forward(self, pose, beta): 71 | device = pose.device 72 | batch_size = pose.shape[0] 73 | v_template = self.v_template[None, :] 74 | shapedirs = self.shapedirs.view(-1,10)[None, :].expand(batch_size, -1, -1) 75 | beta = beta[:, :, None] 76 | v_shaped = torch.matmul(shapedirs, beta).view(-1, 6890, 3) + v_template 77 | # batched sparse matmul not supported in pytorch 78 | J = [] 79 | for i in range(batch_size): 80 | J.append(torch.matmul(self.J_regressor, v_shaped[i])) 81 | J = torch.stack(J, dim=0) 82 | # input it rotmat: (bs,24,3,3) 83 | if pose.ndimension() == 4: 84 | R = pose 85 | # input it rotmat: (bs,72) 86 | elif pose.ndimension() == 2: 87 | pose_cube = pose.view(-1, 3) # (batch_size * 24, 1, 3) 88 | R = rodrigues(pose_cube).view(batch_size, 24, 3, 3) 89 | R = R.view(batch_size, 24, 3, 3) 90 | I_cube = torch.eye(3)[None, None, :].to(device) 91 | # I_cube = torch.eye(3)[None, None, :].expand(theta.shape[0], R.shape[1]-1, -1, -1) 92 | lrotmin = (R[:,1:,:] - I_cube).view(batch_size, -1) 93 | posedirs = self.posedirs.view(-1,207)[None, :].expand(batch_size, -1, -1) 94 | v_posed = v_shaped + torch.matmul(posedirs, lrotmin[:, :, None]).view(-1, 6890, 3) 95 | J_ = J.clone() 96 | J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :] 97 | G_ = torch.cat([R, J_[:, :, :, None]], dim=-1) 98 | pad_row = torch.FloatTensor([0,0,0,1]).to(device).view(1,1,1,4).expand(batch_size, 24, -1, -1) 99 | G_ = torch.cat([G_, pad_row], dim=2) 100 | G = [G_[:, 0].clone()] 101 | for i in range(1, 24): 102 | G.append(torch.matmul(G[self.parent[i-1]], G_[:, i, :, :])) 103 | G = torch.stack(G, dim=1) 104 | 105 | rest = torch.cat([J, torch.zeros(batch_size, 24, 1).to(device)], dim=2).view(batch_size, 24, 4, 1) 106 | zeros = torch.zeros(batch_size, 24, 4, 3).to(device) 107 | rest = torch.cat([zeros, rest], dim=-1) 108 | rest = torch.matmul(G, rest) 109 | G = G - rest 110 | T = torch.matmul(self.weights, G.permute(1,0,2,3).contiguous().view(24,-1)).view(6890, batch_size, 4, 4).transpose(0,1) 111 | rest_shape_h = torch.cat([v_posed, torch.ones_like(v_posed)[:, :, [0]]], dim=-1) 112 | v = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0] 113 | return v 114 | 115 | def get_joints(self, vertices): 116 | """ 117 | This method is used to get the joint locations from the SMPL mesh 118 | Input: 119 | vertices: size = (B, 6890, 3) 120 | Output: 121 | 3D joints: size = (B, 38, 3) 122 | """ 123 | joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor]) 124 | joints_extra = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_extra]) 125 | joints = torch.cat((joints, joints_extra), dim=1) 126 | joints = joints[:, cfg.JOINTS_IDX] 127 | return joints 128 | 129 | def get_h36m_joints(self, vertices): 130 | """ 131 | This method is used to get the joint locations from the SMPL mesh 132 | Input: 133 | vertices: size = (B, 6890, 3) 134 | Output: 135 | 3D joints: size = (B, 24, 3) 136 | """ 137 | joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct]) 138 | return joints 139 | 140 | class SparseMM(torch.autograd.Function): 141 | """Redefine sparse @ dense matrix multiplication to enable backpropagation. 
142 | The builtin matrix multiplication operation does not support backpropagation in some cases. 143 | """ 144 | @staticmethod 145 | def forward(ctx, sparse, dense): 146 | ctx.req_grad = dense.requires_grad 147 | ctx.save_for_backward(sparse) 148 | return torch.matmul(sparse, dense) 149 | 150 | @staticmethod 151 | def backward(ctx, grad_output): 152 | grad_input = None 153 | sparse, = ctx.saved_tensors 154 | if ctx.req_grad: 155 | grad_input = torch.matmul(sparse.t(), grad_output) 156 | return None, grad_input 157 | 158 | def spmm(sparse, dense): 159 | return SparseMM.apply(sparse, dense) 160 | 161 | 162 | def scipy_to_pytorch(A, U, D): 163 | """Convert scipy sparse matrices to pytorch sparse matrix.""" 164 | ptU = [] 165 | ptD = [] 166 | 167 | for i in range(len(U)): 168 | u = scipy.sparse.coo_matrix(U[i]) 169 | i = torch.LongTensor(np.array([u.row, u.col])) 170 | v = torch.FloatTensor(u.data) 171 | ptU.append(torch.sparse.FloatTensor(i, v, u.shape)) 172 | 173 | for i in range(len(D)): 174 | d = scipy.sparse.coo_matrix(D[i]) 175 | i = torch.LongTensor(np.array([d.row, d.col])) 176 | v = torch.FloatTensor(d.data) 177 | ptD.append(torch.sparse.FloatTensor(i, v, d.shape)) 178 | 179 | return ptU, ptD 180 | 181 | 182 | def adjmat_sparse(adjmat, nsize=1): 183 | """Create row-normalized sparse graph adjacency matrix.""" 184 | adjmat = scipy.sparse.csr_matrix(adjmat) 185 | if nsize > 1: 186 | orig_adjmat = adjmat.copy() 187 | for _ in range(1, nsize): 188 | adjmat = adjmat * orig_adjmat 189 | adjmat.data = np.ones_like(adjmat.data) 190 | for i in range(adjmat.shape[0]): 191 | adjmat[i,i] = 1 192 | num_neighbors = np.array(1 / adjmat.sum(axis=-1)) 193 | adjmat = adjmat.multiply(num_neighbors) 194 | adjmat = scipy.sparse.coo_matrix(adjmat) 195 | row = adjmat.row 196 | col = adjmat.col 197 | data = adjmat.data 198 | i = torch.LongTensor(np.array([row, col])) 199 | v = torch.from_numpy(data).float() 200 | adjmat = torch.sparse.FloatTensor(i, v, adjmat.shape) 201 | return adjmat 202 | 203 | def get_graph_params(filename, nsize=1): 204 | """Load and process graph adjacency matrix and upsampling/downsampling matrices.""" 205 | data = np.load(filename, encoding='latin1', allow_pickle=True) 206 | A = data['A'] 207 | U = data['U'] 208 | D = data['D'] 209 | U, D = scipy_to_pytorch(A, U, D) 210 | A = [adjmat_sparse(a, nsize=nsize) for a in A] 211 | return A, U, D 212 | 213 | 214 | class Mesh(object): 215 | """Mesh object that is used for handling certain graph operations.""" 216 | def __init__(self, filename=cfg.SMPL_sampling_matrix, 217 | num_downsampling=1, nsize=1, device=torch.device('cuda')): 218 | self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize) 219 | # self._A = [a.to(device) for a in self._A] 220 | self._U = [u.to(device) for u in self._U] 221 | self._D = [d.to(device) for d in self._D] 222 | self.num_downsampling = num_downsampling 223 | 224 | # load template vertices from SMPL and normalize them 225 | smpl = SMPL() 226 | ref_vertices = smpl.v_template 227 | center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None] 228 | ref_vertices -= center 229 | ref_vertices /= ref_vertices.abs().max().item() 230 | 231 | self._ref_vertices = ref_vertices.to(device) 232 | self.faces = smpl.faces.int().to(device) 233 | 234 | # @property 235 | # def adjmat(self): 236 | # """Return the graph adjacency matrix at the specified subsampling level.""" 237 | # return self._A[self.num_downsampling].float() 238 | 239 | @property 240 | def ref_vertices(self): 241 | """Return 
the template vertices at the specified subsampling level.""" 242 | ref_vertices = self._ref_vertices 243 | for i in range(self.num_downsampling): 244 | ref_vertices = torch.spmm(self._D[i], ref_vertices) 245 | return ref_vertices 246 | 247 | def downsample(self, x, n1=0, n2=None): 248 | """Downsample mesh.""" 249 | if n2 is None: 250 | n2 = self.num_downsampling 251 | if x.ndimension() < 3: 252 | for i in range(n1, n2): 253 | x = spmm(self._D[i], x) 254 | elif x.ndimension() == 3: 255 | out = [] 256 | for i in range(x.shape[0]): 257 | y = x[i] 258 | for j in range(n1, n2): 259 | y = spmm(self._D[j], y) 260 | out.append(y) 261 | x = torch.stack(out, dim=0) 262 | return x 263 | 264 | def upsample(self, x, n1=1, n2=0): 265 | """Upsample mesh.""" 266 | if x.ndimension() < 3: 267 | for i in reversed(range(n2, n1)): 268 | x = spmm(self._U[i], x) 269 | elif x.ndimension() == 3: 270 | out = [] 271 | for i in range(x.shape[0]): 272 | y = x[i] 273 | for j in reversed(range(n2, n1)): 274 | y = spmm(self._U[j], y) 275 | out.append(y) 276 | x = torch.stack(out, dim=0) 277 | return x 278 | -------------------------------------------------------------------------------- /src/modeling/bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.0" 2 | 3 | from .modeling_bert import (BertConfig, BertModel, 4 | load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, 5 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP) 6 | 7 | from .modeling_graphormer import Graphormer 8 | 9 | from .e2e_body_network import Graphormer_Body_Network 10 | 11 | from .e2e_hand_network import Graphormer_Hand_Network 12 | 13 | from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, 14 | PretrainedConfig, PreTrainedModel, prune_layer, Conv1D) 15 | 16 | from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path) 17 | -------------------------------------------------------------------------------- /src/modeling/bert/bert-base-uncased/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertForMaskedLM" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "max_position_embeddings": 512, 12 | "num_attention_heads": 12, 13 | "num_hidden_layers": 12, 14 | "type_vocab_size": 2, 15 | "vocab_size": 30522 16 | } 17 | -------------------------------------------------------------------------------- /src/modeling/bert/e2e_body_network.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | """ 6 | 7 | import torch 8 | import src.modeling.data.config as cfg 9 | 10 | class Graphormer_Body_Network(torch.nn.Module): 11 | ''' 12 | End-to-end Graphormer network for human pose and mesh reconstruction from a single image. 
13 |     '''
14 |     def __init__(self, args, config, backbone, trans_encoder, mesh_sampler):
15 |         super(Graphormer_Body_Network, self).__init__()
16 |         self.config = config
17 |         self.config.device = args.device
18 |         self.backbone = backbone
19 |         self.trans_encoder = trans_encoder
20 |         self.upsampling = torch.nn.Linear(431, 1723)
21 |         self.upsampling2 = torch.nn.Linear(1723, 6890)
22 |         self.cam_param_fc = torch.nn.Linear(3, 1)
23 |         self.cam_param_fc2 = torch.nn.Linear(431, 250)
24 |         self.cam_param_fc3 = torch.nn.Linear(250, 3)
25 |         self.grid_feat_dim = torch.nn.Linear(1024, 2051)
26 | 
27 | 
28 |     def forward(self, images, smpl, mesh_sampler, meta_masks=None, is_train=False):
29 |         batch_size = images.size(0)
30 |         # Generate T-pose template mesh
31 |         template_pose = torch.zeros((1,72))
32 |         template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord
33 |         template_pose = template_pose.cuda(self.config.device)
34 |         template_betas = torch.zeros((1,10)).cuda(self.config.device)
35 |         template_vertices = smpl(template_pose, template_betas)
36 | 
37 |         # template mesh simplification
38 |         template_vertices_sub = mesh_sampler.downsample(template_vertices)
39 |         template_vertices_sub2 = mesh_sampler.downsample(template_vertices_sub, n1=1, n2=2)
40 | 
41 |         # template mesh-to-joint regression
42 |         template_3d_joints = smpl.get_h36m_joints(template_vertices)
43 |         template_pelvis = template_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
44 |         template_3d_joints = template_3d_joints[:,cfg.H36M_J17_TO_J14,:]
45 |         num_joints = template_3d_joints.shape[1]
46 | 
47 |         # normalize
48 |         template_3d_joints = template_3d_joints - template_pelvis[:, None, :]
49 |         template_vertices_sub2 = template_vertices_sub2 - template_pelvis[:, None, :]
50 | 
51 |         # concatenate template joints and template vertices, and then duplicate to batch size
52 |         ref_vertices = torch.cat([template_3d_joints, template_vertices_sub2],dim=1)
53 |         ref_vertices = ref_vertices.expand(batch_size, -1, -1)
54 | 
55 |         # extract grid features and global image features using a CNN backbone
56 |         image_feat, grid_feat = self.backbone(images)
57 |         # duplicate the global image feature for every joint/vertex query
58 |         image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1)
59 |         # process grid features
60 |         grid_feat = torch.flatten(grid_feat, start_dim=2)
61 |         grid_feat = grid_feat.transpose(1,2)
62 |         grid_feat = self.grid_feat_dim(grid_feat)
63 |         # concatenate image feat and template mesh to form the joint/vertex queries
64 |         features = torch.cat([ref_vertices, image_feat], dim=2)
65 |         # prepare input tokens including joint/vertex queries and grid features
66 |         features = torch.cat([features, grid_feat],dim=1)
67 | 
68 |         if is_train==True:
69 |             # apply mask vertex/joint modeling
70 |             # meta_masks is a tensor of all the masks, randomly generated in dataloader
71 |             # we pre-define a [MASK] token, which is a floating-value vector with 0.01s
72 |             special_token = torch.ones_like(features[:,:-49,:]).cuda()*0.01
73 |             features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks)
74 | 
75 |         # forward pass
76 |         if self.config.output_attentions==True:
77 |             features, hidden_states, att = self.trans_encoder(features)
78 |         else:
79 |             features = self.trans_encoder(features)
80 | 
81 |         pred_3d_joints = features[:,:num_joints,:]
82 |         pred_vertices_sub2 = features[:,num_joints:-49,:]
83 | 
84 |         # learn camera parameters
85 |         x = self.cam_param_fc(pred_vertices_sub2)
86 |         x = x.transpose(1,2)
87 |         x = self.cam_param_fc2(x)
88 |         x = self.cam_param_fc3(x)
89 |         cam_param = x.transpose(1,2)
90 |         cam_param = cam_param.squeeze()
91 | 
92 |         temp_transpose = pred_vertices_sub2.transpose(1,2)
93 |         pred_vertices_sub = self.upsampling(temp_transpose)
94 |         pred_vertices_full = self.upsampling2(pred_vertices_sub)
95 |         pred_vertices_sub = pred_vertices_sub.transpose(1,2)
96 |         pred_vertices_full = pred_vertices_full.transpose(1,2)
97 | 
98 |         if self.config.output_attentions==True:
99 |             return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full, hidden_states, att
100 |         else:
101 |             return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full
-------------------------------------------------------------------------------- /src/modeling/bert/e2e_hand_network.py: --------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 | 
5 | """
6 | 
7 | import torch
8 | import src.modeling.data.config as cfg
9 | 
10 | class Graphormer_Hand_Network(torch.nn.Module):
11 |     '''
12 |     End-to-end Graphormer network for hand pose and mesh reconstruction from a single image.
13 |     '''
14 |     def __init__(self, args, config, backbone, trans_encoder):
15 |         super(Graphormer_Hand_Network, self).__init__()
16 |         self.config = config
17 |         self.backbone = backbone
18 |         self.trans_encoder = trans_encoder
19 |         self.upsampling = torch.nn.Linear(195, 778)
20 |         self.cam_param_fc = torch.nn.Linear(3, 1)
21 |         self.cam_param_fc2 = torch.nn.Linear(195+21, 150)
22 |         self.cam_param_fc3 = torch.nn.Linear(150, 3)
23 |         self.grid_feat_dim = torch.nn.Linear(1024, 2051)
24 | 
25 |     def forward(self, images, mesh_model, mesh_sampler, meta_masks=None, is_train=False):
26 |         batch_size = images.size(0)
27 |         # Generate T-pose template mesh
28 |         template_pose = torch.zeros((1,48))
29 |         template_pose = template_pose.cuda()
30 |         template_betas = torch.zeros((1,10)).cuda()
31 |         template_vertices, template_3d_joints = mesh_model.layer(template_pose, template_betas)
32 |         template_vertices = template_vertices/1000.0
33 |         template_3d_joints = template_3d_joints/1000.0
34 | 
35 |         template_vertices_sub = mesh_sampler.downsample(template_vertices)
36 | 
37 |         # normalize
38 |         template_root = template_3d_joints[:,cfg.J_NAME.index('Wrist'),:]
39 |         template_3d_joints = template_3d_joints - template_root[:, None, :]
40 |         template_vertices = template_vertices - template_root[:, None, :]
41 |         template_vertices_sub = template_vertices_sub - template_root[:, None, :]
42 |         num_joints = template_3d_joints.shape[1]
43 | 
44 |         # concatenate template joints and template vertices, and then duplicate to batch size
45 |         ref_vertices = torch.cat([template_3d_joints, template_vertices_sub],dim=1)
46 |         ref_vertices = ref_vertices.expand(batch_size, -1, -1)
47 | 
48 |         # extract grid features and global image features using a CNN backbone
49 |         image_feat, grid_feat = self.backbone(images)
50 |         # duplicate the global image feature for every joint/vertex query
51 |         image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1)
52 |         # process grid features
53 |         grid_feat = torch.flatten(grid_feat, start_dim=2)
54 |         grid_feat = grid_feat.transpose(1,2)
55 |         grid_feat = self.grid_feat_dim(grid_feat)
56 |         # concatenate image feat and template mesh to form the joint/vertex queries
57 |         features = torch.cat([ref_vertices, image_feat], dim=2)
58 |         # prepare input tokens including joint/vertex queries and grid features
59 |         features = torch.cat([features, 
grid_feat],dim=1) 60 | 61 | if is_train==True: 62 | # apply mask vertex/joint modeling 63 | # meta_masks is a tensor of all the masks, randomly generated in dataloader 64 | # we pre-define a [MASK] token, which is a floating-value vector with 0.01s 65 | special_token = torch.ones_like(features[:,:-49,:]).cuda()*0.01 66 | features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks) 67 | 68 | # forward pass 69 | if self.config.output_attentions==True: 70 | features, hidden_states, att = self.trans_encoder(features) 71 | else: 72 | features = self.trans_encoder(features) 73 | 74 | pred_3d_joints = features[:,:num_joints,:] 75 | pred_vertices_sub = features[:,num_joints:-49,:] 76 | 77 | # learn camera parameters 78 | x = self.cam_param_fc(features[:,:-49,:]) 79 | x = x.transpose(1,2) 80 | x = self.cam_param_fc2(x) 81 | x = self.cam_param_fc3(x) 82 | cam_param = x.transpose(1,2) 83 | cam_param = cam_param.squeeze() 84 | 85 | temp_transpose = pred_vertices_sub.transpose(1,2) 86 | pred_vertices = self.upsampling(temp_transpose) 87 | pred_vertices = pred_vertices.transpose(1,2) 88 | 89 | if self.config.output_attentions==True: 90 | return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att 91 | else: 92 | return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices -------------------------------------------------------------------------------- /src/modeling/bert/file_utils.py: -------------------------------------------------------------------------------- 1 | ../../../transformers/pytorch_transformers/file_utils.py -------------------------------------------------------------------------------- /src/modeling/bert/modeling_bert.py: -------------------------------------------------------------------------------- 1 | ../../../transformers/pytorch_transformers/modeling_bert.py -------------------------------------------------------------------------------- /src/modeling/bert/modeling_graphormer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | """ 6 | 7 | from __future__ import absolute_import, division, print_function, unicode_literals 8 | 9 | import logging 10 | import math 11 | import os 12 | import code 13 | import torch 14 | from torch import nn 15 | from .modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, BertIntermediate, BertOutput, BertSelfOutput 16 | import src.modeling.data.config as cfg 17 | from src.modeling._gcnn import GraphConvolution, GraphResBlock 18 | from .modeling_utils import prune_linear_layer 19 | LayerNormClass = torch.nn.LayerNorm 20 | BertLayerNorm = torch.nn.LayerNorm 21 | 22 | 23 | class BertSelfAttention(nn.Module): 24 | def __init__(self, config): 25 | super(BertSelfAttention, self).__init__() 26 | if config.hidden_size % config.num_attention_heads != 0: 27 | raise ValueError( 28 | "The hidden size (%d) is not a multiple of the number of attention " 29 | "heads (%d)" % (config.hidden_size, config.num_attention_heads)) 30 | self.output_attentions = config.output_attentions 31 | 32 | self.num_attention_heads = config.num_attention_heads 33 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 34 | self.all_head_size = self.num_attention_heads * self.attention_head_size 35 | 36 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 37 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 38 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 39 | 40 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 41 | 42 | def transpose_for_scores(self, x): 43 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 44 | x = x.view(*new_x_shape) 45 | return x.permute(0, 2, 1, 3) 46 | 47 | def forward(self, hidden_states, attention_mask, head_mask=None, 48 | history_state=None): 49 | if history_state is not None: 50 | x_states = torch.cat([history_state, hidden_states], dim=1) 51 | mixed_query_layer = self.query(hidden_states) 52 | mixed_key_layer = self.key(x_states) 53 | mixed_value_layer = self.value(x_states) 54 | else: 55 | mixed_query_layer = self.query(hidden_states) 56 | mixed_key_layer = self.key(hidden_states) 57 | mixed_value_layer = self.value(hidden_states) 58 | 59 | query_layer = self.transpose_for_scores(mixed_query_layer) 60 | key_layer = self.transpose_for_scores(mixed_key_layer) 61 | value_layer = self.transpose_for_scores(mixed_value_layer) 62 | 63 | # Take the dot product between "query" and "key" to get the raw attention scores. 64 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 65 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 66 | # Apply the attention mask (precomputed for all layers in BertModel's forward() function) 67 | attention_scores = attention_scores + attention_mask 68 | 69 | # Normalize the attention scores to probabilities. 70 | attention_probs = nn.Softmax(dim=-1)(attention_scores) 71 | 72 | # This is actually dropping out entire tokens to attend to, which might 73 | # seem a bit unusual, but it is taken from the original Transformer paper.
74 | attention_probs = self.dropout(attention_probs) 75 | 76 | # Mask heads if we want to 77 | if head_mask is not None: 78 | attention_probs = attention_probs * head_mask 79 | 80 | context_layer = torch.matmul(attention_probs, value_layer) 81 | 82 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 83 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 84 | context_layer = context_layer.view(*new_context_layer_shape) 85 | 86 | outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) 87 | return outputs 88 | 89 | class BertAttention(nn.Module): 90 | def __init__(self, config): 91 | super(BertAttention, self).__init__() 92 | self.self = BertSelfAttention(config) 93 | self.output = BertSelfOutput(config) 94 | 95 | def prune_heads(self, heads): 96 | if len(heads) == 0: 97 | return 98 | mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) 99 | for head in heads: 100 | mask[head] = 0 101 | mask = mask.view(-1).contiguous().eq(1) 102 | index = torch.arange(len(mask))[mask].long() 103 | # Prune linear layers 104 | self.self.query = prune_linear_layer(self.self.query, index) 105 | self.self.key = prune_linear_layer(self.self.key, index) 106 | self.self.value = prune_linear_layer(self.self.value, index) 107 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) 108 | # Update hyper params 109 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads) 110 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads 111 | 112 | def forward(self, input_tensor, attention_mask, head_mask=None, 113 | history_state=None): 114 | self_outputs = self.self(input_tensor, attention_mask, head_mask, 115 | history_state) 116 | attention_output = self.output(self_outputs[0], input_tensor) 117 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them 118 | return outputs 119 | 120 | 121 | class GraphormerLayer(nn.Module): 122 | def __init__(self, config): 123 | super(GraphormerLayer, self).__init__() 124 | self.attention = BertAttention(config) 125 | self.has_graph_conv = config.graph_conv 126 | self.mesh_type = config.mesh_type 127 | 128 | if self.has_graph_conv == True: 129 | self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type) 130 | 131 | self.intermediate = BertIntermediate(config) 132 | self.output = BertOutput(config) 133 | 134 | def MHA_GCN(self, hidden_states, attention_mask, head_mask=None, 135 | history_state=None): 136 | attention_outputs = self.attention(hidden_states, attention_mask, 137 | head_mask, history_state) 138 | attention_output = attention_outputs[0] 139 | 140 | if self.has_graph_conv==True: 141 | if self.mesh_type == 'body': 142 | joints = attention_output[:,0:14,:] 143 | vertices = attention_output[:,14:-49,:] 144 | img_tokens = attention_output[:,-49:,:] 145 | 146 | elif self.mesh_type == 'hand': 147 | joints = attention_output[:,0:21,:] 148 | vertices = attention_output[:,21:-49,:] 149 | img_tokens = attention_output[:,-49:,:] 150 | 151 | vertices = self.graph_conv(vertices) 152 | joints_vertices = torch.cat([joints,vertices,img_tokens],dim=1) 153 | else: 154 | joints_vertices = attention_output 155 | 156 | intermediate_output = self.intermediate(joints_vertices) 157 | layer_output = self.output(intermediate_output, joints_vertices) 158 | outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them 159 | return outputs 160 | 161 | 
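The slicing in MHA_GCN above is the heart of the layer: only the coarse-mesh vertex tokens pass through the graph residual block, while joint and grid tokens bypass it and are re-concatenated before the feed-forward sublayer. The following minimal sketch illustrates that token routing for the hand model; the toy batch size, hidden size, and the plain linear layer standing in for GraphResBlock are assumptions for illustration, not the repository's actual shapes or block.

import torch
from torch import nn

# Hand-model token layout: [21 joint tokens | 195 coarse-vertex tokens | 49 grid tokens]
B, H = 2, 64                          # toy batch size and hidden size (assumptions)
tokens = torch.randn(B, 21 + 195 + 49, H)

graph_conv = nn.Linear(H, H)          # stand-in for GraphResBlock (assumption)

joints = tokens[:, :21, :]
vertices = tokens[:, 21:-49, :]
grid = tokens[:, -49:, :]
vertices = graph_conv(vertices)       # only the vertex tokens are refined by the graph conv
tokens = torch.cat([joints, vertices, grid], dim=1)
assert tokens.shape == (B, 265, H)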
def forward(self, hidden_states, attention_mask, head_mask=None, 162 | history_state=None): 163 | return self.MHA_GCN(hidden_states, attention_mask, head_mask,history_state) 164 | 165 | 166 | class GraphormerEncoder(nn.Module): 167 | def __init__(self, config): 168 | super(GraphormerEncoder, self).__init__() 169 | self.output_attentions = config.output_attentions 170 | self.output_hidden_states = config.output_hidden_states 171 | self.layer = nn.ModuleList([GraphormerLayer(config) for _ in range(config.num_hidden_layers)]) 172 | 173 | def forward(self, hidden_states, attention_mask, head_mask=None, 174 | encoder_history_states=None): 175 | all_hidden_states = () 176 | all_attentions = () 177 | for i, layer_module in enumerate(self.layer): 178 | if self.output_hidden_states: 179 | all_hidden_states = all_hidden_states + (hidden_states,) 180 | 181 | history_state = None if encoder_history_states is None else encoder_history_states[i] 182 | layer_outputs = layer_module( 183 | hidden_states, attention_mask, head_mask[i], 184 | history_state) 185 | hidden_states = layer_outputs[0] 186 | 187 | if self.output_attentions: 188 | all_attentions = all_attentions + (layer_outputs[1],) 189 | 190 | # Add last layer 191 | if self.output_hidden_states: 192 | all_hidden_states = all_hidden_states + (hidden_states,) 193 | 194 | outputs = (hidden_states,) 195 | if self.output_hidden_states: 196 | outputs = outputs + (all_hidden_states,) 197 | if self.output_attentions: 198 | outputs = outputs + (all_attentions,) 199 | 200 | return outputs # outputs, (hidden states), (attentions) 201 | 202 | class EncoderBlock(BertPreTrainedModel): 203 | def __init__(self, config): 204 | super(EncoderBlock, self).__init__(config) 205 | self.config = config 206 | self.embeddings = BertEmbeddings(config) 207 | self.encoder = GraphormerEncoder(config) 208 | self.pooler = BertPooler(config) 209 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) 210 | self.img_dim = config.img_feature_dim 211 | 212 | try: 213 | self.use_img_layernorm = config.use_img_layernorm 214 | except: 215 | self.use_img_layernorm = None 216 | 217 | self.img_embedding = nn.Linear(self.img_dim, self.config.hidden_size, bias=True) 218 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 219 | if self.use_img_layernorm: 220 | self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.img_layer_norm_eps) 221 | 222 | self.apply(self.init_weights) 223 | 224 | def _prune_heads(self, heads_to_prune): 225 | """ Prunes heads of the model. 
226 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 227 | See base class PreTrainedModel 228 | """ 229 | for layer, heads in heads_to_prune.items(): 230 | self.encoder.layer[layer].attention.prune_heads(heads) 231 | 232 | def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, 233 | position_ids=None, head_mask=None): 234 | 235 | batch_size = len(img_feats) 236 | seq_length = len(img_feats[0]) 237 | input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).cuda() 238 | 239 | if position_ids is None: 240 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) 241 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 242 | 243 | position_embeddings = self.position_embeddings(position_ids) 244 | 245 | if attention_mask is None: 246 | attention_mask = torch.ones_like(input_ids) 247 | 248 | if token_type_ids is None: 249 | token_type_ids = torch.zeros_like(input_ids) 250 | 251 | if attention_mask.dim() == 2: 252 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 253 | elif attention_mask.dim() == 3: 254 | extended_attention_mask = attention_mask.unsqueeze(1) 255 | else: 256 | raise NotImplementedError 257 | 258 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 259 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 260 | 261 | if head_mask is not None: 262 | if head_mask.dim() == 1: 263 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 264 | head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) 265 | elif head_mask.dim() == 2: 266 | head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer 267 | head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility 268 | else: 269 | head_mask = [None] * self.config.num_hidden_layers 270 | 271 | # Project input token features to the specified hidden size 272 | img_embedding_output = self.img_embedding(img_feats) 273 | 274 | # We empirically observe that adding an additional learnable position embedding leads to more stable training 275 | embeddings = position_embeddings + img_embedding_output 276 | 277 | if self.use_img_layernorm: 278 | embeddings = self.LayerNorm(embeddings) 279 | embeddings = self.dropout(embeddings) 280 | 281 | encoder_outputs = self.encoder(embeddings, 282 | extended_attention_mask, head_mask=head_mask) 283 | sequence_output = encoder_outputs[0] 284 | 285 | outputs = (sequence_output,) 286 | if self.config.output_hidden_states: 287 | all_hidden_states = encoder_outputs[1] 288 | outputs = outputs + (all_hidden_states,) 289 | if self.config.output_attentions: 290 | all_attentions = encoder_outputs[-1] 291 | outputs = outputs + (all_attentions,) 292 | 293 | return outputs 294 | 295 | class Graphormer(BertPreTrainedModel): 296 | ''' 297 | The architecture of a transformer encoder block used in Graphormer 298 | ''' 299 | def __init__(self, config): 300 | super(Graphormer, self).__init__(config) 301 | self.config = config 302 | self.bert = EncoderBlock(config) 303 | self.cls_head = nn.Linear(config.hidden_size, self.config.output_feature_dim) 304 | self.residual = nn.Linear(config.img_feature_dim, self.config.output_feature_dim) 305 | self.apply(self.init_weights) 306 | 307 | def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None, 308 |
next_sentence_label=None, position_ids=None, head_mask=None): 309 | ''' 310 | # self.bert has three outputs 311 | # predictions[0]: output tokens 312 | # predictions[1]: all_hidden_states, returned if "self.config.output_hidden_states" is enabled 313 | # predictions[2]: attentions, returned if "self.config.output_attentions" is enabled 314 | ''' 315 | predictions = self.bert(img_feats=img_feats, input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, 316 | attention_mask=attention_mask, head_mask=head_mask) 317 | 318 | # We use "self.cls_head" to perform dimensionality reduction. We don't use it for classification. 319 | pred_score = self.cls_head(predictions[0]) 320 | res_img_feats = self.residual(img_feats) 321 | pred_score = pred_score + res_img_feats 322 | 323 | if self.config.output_attentions and self.config.output_hidden_states: 324 | return pred_score, predictions[1], predictions[-1] 325 | else: 326 | return pred_score 327 | 328 | -------------------------------------------------------------------------------- /src/modeling/bert/modeling_utils.py: -------------------------------------------------------------------------------- 1 | ../../../transformers/pytorch_transformers/modeling_utils.py -------------------------------------------------------------------------------- /src/modeling/data/J_regressor_extra.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/J_regressor_extra.npy -------------------------------------------------------------------------------- /src/modeling/data/J_regressor_h36m_correct.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/J_regressor_h36m_correct.npy -------------------------------------------------------------------------------- /src/modeling/data/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Extra data 3 | Adapted from the open source projects [GraphCMR](https://github.com/nkolot/GraphCMR/) and [Pose2Mesh](https://github.com/hongsukchoi/Pose2Mesh_RELEASE) 4 | 5 | Our code requires additional data to run smoothly. 6 | 7 | ### J_regressor_extra.npy 8 | Joints regressor for joints or landmarks that are not included in the standard set of SMPL joints. 9 | 10 | ### J_regressor_h36m_correct.npy 11 | Joints regressor reflecting the Human3.6M joints. 12 | 13 | ### mesh_downsampling.npz 14 | Extra file with precomputed downsampling for the SMPL body mesh. 15 | 16 | ### mano_downsampling.npz 17 | Extra file with precomputed downsampling for the MANO hand mesh. 18 | 19 | ### basicModel_neutral_lbs_10_207_0_v1.0.0.pkl 20 | SMPL neutral model. Please visit the official website to download the file [http://smplify.is.tue.mpg.de/](http://smplify.is.tue.mpg.de/) 21 | 22 | ### basicModel_m_lbs_10_207_0_v1.0.0.pkl 23 | SMPL male model. Please visit the official website to download the file [https://smpl.is.tue.mpg.de/](https://smpl.is.tue.mpg.de/) 24 | 25 | ### basicModel_f_lbs_10_207_0_v1.0.0.pkl 26 | SMPL female model. Please visit the official website to download the file [https://smpl.is.tue.mpg.de/](https://smpl.is.tue.mpg.de/) 27 | 28 | ### MANO_RIGHT.pkl 29 | MANO hand model.
Please visit the official website to download the file [https://mano.is.tue.mpg.de/](https://mano.is.tue.mpg.de/) 30 | 31 | -------------------------------------------------------------------------------- /src/modeling/data/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains definitions of useful data structures and the paths 3 | for the datasets and data files necessary to run the code. 4 | 5 | Adapted from the open source projects GraphCMR (https://github.com/nkolot/GraphCMR/) and Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE) 6 | 7 | """ 8 | 9 | from os.path import join 10 | folder_path = 'src/modeling/' 11 | JOINT_REGRESSOR_TRAIN_EXTRA = folder_path + 'data/J_regressor_extra.npy' 12 | JOINT_REGRESSOR_H36M_correct = folder_path + 'data/J_regressor_h36m_correct.npy' 13 | SMPL_FILE = folder_path + 'data/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl' 14 | SMPL_Male = folder_path + 'data/basicModel_m_lbs_10_207_0_v1.0.0.pkl' 15 | SMPL_Female = folder_path + 'data/basicModel_f_lbs_10_207_0_v1.0.0.pkl' 16 | SMPL_sampling_matrix = folder_path + 'data/mesh_downsampling.npz' 17 | MANO_FILE = folder_path + 'data/MANO_RIGHT.pkl' 18 | MANO_sampling_matrix = folder_path + 'data/mano_downsampling.npz' 19 | 20 | JOINTS_IDX = [8, 5, 29, 30, 4, 7, 21, 19, 17, 16, 18, 20, 31, 32, 33, 34, 35, 36, 37, 24, 26, 25, 28, 27] 21 | 22 | 23 | """ 24 | We follow the body joint definition, loss functions, and evaluation metrics from the 25 | open source project GraphCMR (https://github.com/nkolot/GraphCMR/) 26 | 27 | Each dataset uses a different set of joints. 28 | We use a superset of 24 joints such that we include all joints from every dataset. 29 | If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
30 | The joints used here are: 31 | """ 32 | J24_NAME = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder', 33 | 'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear') 34 | H36M_J17_NAME = ( 'Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head', 35 | 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist') 36 | J24_TO_J14 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18] 37 | H36M_J17_TO_J14 = [3, 2, 1, 4, 5, 6, 16, 15, 14, 11, 12, 13, 8, 10] 38 | 39 | """ 40 | We follow the hand joint definition and mesh topology from 41 | open source project Manopth (https://github.com/hassony2/manopth) 42 | 43 | The hand joints used here are: 44 | """ 45 | J_NAME = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 46 | 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4') 47 | ROOT_INDEX = 0 -------------------------------------------------------------------------------- /src/modeling/data/mano_195_adjmat_indices.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/mano_195_adjmat_indices.pt -------------------------------------------------------------------------------- /src/modeling/data/mano_195_adjmat_size.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/mano_195_adjmat_size.pt -------------------------------------------------------------------------------- /src/modeling/data/mano_195_adjmat_values.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/mano_195_adjmat_values.pt -------------------------------------------------------------------------------- /src/modeling/data/mano_downsampling.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/mano_downsampling.npz -------------------------------------------------------------------------------- /src/modeling/data/mesh_downsampling.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/mesh_downsampling.npz -------------------------------------------------------------------------------- /src/modeling/data/smpl_431_adjmat_indices.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/smpl_431_adjmat_indices.pt -------------------------------------------------------------------------------- /src/modeling/data/smpl_431_adjmat_size.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/smpl_431_adjmat_size.pt 
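The mano_195_adjmat_*.pt files above (and the smpl_431_adjmat_*.pt files below) store a sparse mesh adjacency matrix split into COO indices, values, and size parts. The following minimal sketch shows how such a triplet can be reassembled into a torch sparse tensor; it assumes the files follow torch's COO conventions, which is an assumption, not something documented here (the tensors are presumably consumed by the graph convolution in src/modeling/_gcnn.py).

import torch

# Assumed layout: indices of shape [2, nnz], values of shape [nnz], size = matrix shape.
prefix = 'src/modeling/data/mano_195_adjmat'
indices = torch.load(prefix + '_indices.pt')
values = torch.load(prefix + '_values.pt')
size = torch.load(prefix + '_size.pt')
adjmat = torch.sparse_coo_tensor(indices, values, tuple(int(s) for s in size))
print(adjmat.to_dense().shape)   # presumably torch.Size([195, 195]) for the coarse hand mesh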
-------------------------------------------------------------------------------- /src/modeling/data/smpl_431_adjmat_values.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/smpl_431_adjmat_values.pt -------------------------------------------------------------------------------- /src/modeling/data/smpl_431_faces.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/smpl_431_faces.npy -------------------------------------------------------------------------------- /src/modeling/hrnet/config/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from .default import _C as config 8 | from .default import update_config 9 | from .models import MODEL_EXTRAS 10 | -------------------------------------------------------------------------------- /src/modeling/hrnet/config/default.py: -------------------------------------------------------------------------------- 1 | 2 | # ------------------------------------------------------------------------------ 3 | # Copyright (c) Microsoft 4 | # Licensed under the MIT License. 5 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 6 | # Modified by Ke Sun (sunk@mail.ustc.edu.cn) 7 | # ------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | import os 14 | 15 | from yacs.config import CfgNode as CN 16 | 17 | 18 | _C = CN() 19 | 20 | _C.OUTPUT_DIR = '' 21 | _C.LOG_DIR = '' 22 | _C.DATA_DIR = '' 23 | _C.GPUS = (0,) 24 | _C.WORKERS = 4 25 | _C.PRINT_FREQ = 20 26 | _C.AUTO_RESUME = False 27 | _C.PIN_MEMORY = True 28 | _C.RANK = 0 29 | 30 | # Cudnn related params 31 | _C.CUDNN = CN() 32 | _C.CUDNN.BENCHMARK = True 33 | _C.CUDNN.DETERMINISTIC = False 34 | _C.CUDNN.ENABLED = True 35 | 36 | # common params for NETWORK 37 | _C.MODEL = CN() 38 | _C.MODEL.NAME = 'cls_hrnet' 39 | _C.MODEL.INIT_WEIGHTS = True 40 | _C.MODEL.PRETRAINED = '' 41 | _C.MODEL.NUM_JOINTS = 17 42 | _C.MODEL.NUM_CLASSES = 1000 43 | _C.MODEL.TAG_PER_JOINT = True 44 | _C.MODEL.TARGET_TYPE = 'gaussian' 45 | _C.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256 46 | _C.MODEL.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32 47 | _C.MODEL.SIGMA = 2 48 | _C.MODEL.EXTRA = CN(new_allowed=True) 49 | 50 | _C.LOSS = CN() 51 | _C.LOSS.USE_OHKM = False 52 | _C.LOSS.TOPK = 8 53 | _C.LOSS.USE_TARGET_WEIGHT = True 54 | _C.LOSS.USE_DIFFERENT_JOINTS_WEIGHT = False 55 | 56 | # DATASET related params 57 | _C.DATASET = CN() 58 | _C.DATASET.ROOT = '' 59 | _C.DATASET.DATASET = 'mpii' 60 | _C.DATASET.TRAIN_SET = 'train' 61 | _C.DATASET.TEST_SET = 'valid' 62 | _C.DATASET.DATA_FORMAT = 'jpg' 63 | _C.DATASET.HYBRID_JOINTS_TYPE = '' 64 | _C.DATASET.SELECT_DATA = False 65 | 66 | # training data augmentation 67 | _C.DATASET.FLIP = True 68 | _C.DATASET.SCALE_FACTOR = 0.25 69 | _C.DATASET.ROT_FACTOR = 30 70 | _C.DATASET.PROB_HALF_BODY = 
0.0 71 | _C.DATASET.NUM_JOINTS_HALF_BODY = 8 72 | _C.DATASET.COLOR_RGB = False 73 | 74 | # train 75 | _C.TRAIN = CN() 76 | 77 | _C.TRAIN.LR_FACTOR = 0.1 78 | _C.TRAIN.LR_STEP = [90, 110] 79 | _C.TRAIN.LR = 0.001 80 | 81 | _C.TRAIN.OPTIMIZER = 'adam' 82 | _C.TRAIN.MOMENTUM = 0.9 83 | _C.TRAIN.WD = 0.0001 84 | _C.TRAIN.NESTEROV = False 85 | _C.TRAIN.GAMMA1 = 0.99 86 | _C.TRAIN.GAMMA2 = 0.0 87 | 88 | _C.TRAIN.BEGIN_EPOCH = 0 89 | _C.TRAIN.END_EPOCH = 140 90 | 91 | _C.TRAIN.RESUME = False 92 | _C.TRAIN.CHECKPOINT = '' 93 | 94 | _C.TRAIN.BATCH_SIZE_PER_GPU = 32 95 | _C.TRAIN.SHUFFLE = True 96 | 97 | # testing 98 | _C.TEST = CN() 99 | 100 | # size of images for each device 101 | _C.TEST.BATCH_SIZE_PER_GPU = 32 102 | # Test Model Epoch 103 | _C.TEST.FLIP_TEST = False 104 | _C.TEST.POST_PROCESS = False 105 | _C.TEST.SHIFT_HEATMAP = False 106 | 107 | _C.TEST.USE_GT_BBOX = False 108 | 109 | # nms 110 | _C.TEST.IMAGE_THRE = 0.1 111 | _C.TEST.NMS_THRE = 0.6 112 | _C.TEST.SOFT_NMS = False 113 | _C.TEST.OKS_THRE = 0.5 114 | _C.TEST.IN_VIS_THRE = 0.0 115 | _C.TEST.COCO_BBOX_FILE = '' 116 | _C.TEST.BBOX_THRE = 1.0 117 | _C.TEST.MODEL_FILE = '' 118 | 119 | # debug 120 | _C.DEBUG = CN() 121 | _C.DEBUG.DEBUG = False 122 | _C.DEBUG.SAVE_BATCH_IMAGES_GT = False 123 | _C.DEBUG.SAVE_BATCH_IMAGES_PRED = False 124 | _C.DEBUG.SAVE_HEATMAPS_GT = False 125 | _C.DEBUG.SAVE_HEATMAPS_PRED = False 126 | 127 | 128 | def update_config(cfg, config_file): 129 | cfg.defrost() 130 | cfg.merge_from_file(config_file) 131 | cfg.freeze() 132 | 133 | 134 | if __name__ == '__main__': 135 | import sys 136 | with open(sys.argv[1], 'w') as f: 137 | print(_C, file=f) 138 | 139 | -------------------------------------------------------------------------------- /src/modeling/hrnet/config/models.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Created by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # Modified by Ke Sun (sunk@mail.ustc.edu.cn) 6 | # ------------------------------------------------------------------------------ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | from yacs.config import CfgNode as CN 13 | 14 | # high_resolution_net related params for classification 15 | POSE_HIGH_RESOLUTION_NET = CN() 16 | POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*'] 17 | POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = 64 18 | POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1 19 | POSE_HIGH_RESOLUTION_NET.WITH_HEAD = True 20 | 21 | POSE_HIGH_RESOLUTION_NET.STAGE2 = CN() 22 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1 23 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2 24 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4] 25 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64] 26 | POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC' 27 | POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM' 28 | 29 | POSE_HIGH_RESOLUTION_NET.STAGE3 = CN() 30 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1 31 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3 32 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4] 33 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128] 34 | POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC' 35 | POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM' 36 | 37 | POSE_HIGH_RESOLUTION_NET.STAGE4 = CN() 38 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1 39 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4 40 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 41 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] 42 | POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC' 43 | POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM' 44 | 45 | MODEL_EXTRAS = { 46 | 'cls_hrnet': POSE_HIGH_RESOLUTION_NET, 47 | } 48 | -------------------------------------------------------------------------------- /src/tools/run_gphmer_handmesh_inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license.
4 | 5 | End-to-end inference code for 6 | 3D hand mesh reconstruction from an image 7 | """ 8 | 9 | from __future__ import absolute_import, division, print_function 10 | import argparse 11 | import os 12 | import os.path as op 13 | import code 14 | import json 15 | import time 16 | import datetime 17 | import torch 18 | import torchvision.models as models 19 | from torchvision.utils import make_grid 20 | import gc 21 | import numpy as np 22 | import cv2 23 | from src.modeling.bert import BertConfig, Graphormer 24 | from src.modeling.bert import Graphormer_Hand_Network as Graphormer_Network 25 | from src.modeling._mano import MANO, Mesh 26 | from src.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat 27 | from src.modeling.hrnet.config import config as hrnet_config 28 | from src.modeling.hrnet.config import update_config as hrnet_update_config 29 | import src.modeling.data.config as cfg 30 | from src.datasets.build import make_hand_data_loader 31 | 32 | from src.utils.logger import setup_logger 33 | from src.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather 34 | from src.utils.miscellaneous import mkdir, set_seed 35 | from src.utils.metric_logger import AverageMeter 36 | from src.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text 37 | from src.utils.metric_pampjpe import reconstruction_error 38 | from src.utils.geometric_layers import orthographic_projection 39 | 40 | from PIL import Image 41 | from torchvision import transforms 42 | 43 | transform = transforms.Compose([ 44 | transforms.Resize(224), 45 | transforms.CenterCrop(224), 46 | transforms.ToTensor(), 47 | transforms.Normalize( 48 | mean=[0.485, 0.456, 0.406], 49 | std=[0.229, 0.224, 0.225])]) 50 | 51 | transform_visualize = transforms.Compose([ 52 | transforms.Resize(224), 53 | transforms.CenterCrop(224), 54 | transforms.ToTensor()]) 55 | 56 | def run_inference(args, image_list, Graphormer_model, mano, renderer, mesh_sampler): 57 | # switch to evaluate mode 58 | Graphormer_model.eval() 59 | mano.eval() 60 | with torch.no_grad(): 61 | for image_file in image_list: 62 | if 'pred' not in image_file: 63 | att_all = [] 64 | print(image_file) 65 | img = Image.open(image_file) 66 | img_tensor = transform(img) 67 | img_visual = transform_visualize(img) 68 | 69 | batch_imgs = torch.unsqueeze(img_tensor, 0).cuda() 70 | batch_visual_imgs = torch.unsqueeze(img_visual, 0).cuda() 71 | # forward pass 72 | pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler) 73 | # obtain 3d joints from full mesh 74 | pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices) 75 | pred_3d_pelvis = pred_3d_joints_from_mesh[:,cfg.J_NAME.index('Wrist'),:] 76 | pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :] 77 | pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :] 78 | 79 | # save attention 80 | att_max_value = att[-1] 81 | att_cpu = np.asarray(att_max_value.cpu().detach()) 82 | att_all.append(att_cpu) 83 | 84 | # obtain 3d joints, which are regressed from the full mesh 85 | pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices) 86 | # obtain 2d joints, which are projected from 3d joints of mesh 87 | pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous()) 88 | pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous())
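# pred_camera is a weak-perspective camera (s, tx, ty): orthographic_projection
# (defined in src/utils/geometric_layers.py) computes x_2d = s * (x_3d[:, :, :2] + t),
# i.e. it translates the x/y coordinates in the image plane and then scales them.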
90 | 91 | visual_imgs_output = visualize_mesh( renderer, batch_visual_imgs[0], 92 | pred_vertices[0].detach(), 93 | pred_camera.detach()) 94 | # visual_imgs_output = visualize_mesh_and_attention( renderer, batch_visual_imgs[0], 95 | # pred_vertices[0].detach(), 96 | # pred_vertices_sub[0].detach(), 97 | # pred_2d_coarse_vertices_from_mesh[0].detach(), 98 | # pred_2d_joints_from_mesh[0].detach(), 99 | # pred_camera.detach(), 100 | # att[-1][0].detach()) 101 | visual_imgs = visual_imgs_output.transpose(1,2,0) 102 | visual_imgs = np.asarray(visual_imgs) 103 | 104 | temp_fname = image_file[:-4] + '_graphormer_pred.jpg' 105 | print('save to ', temp_fname) 106 | cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) 107 | return 108 | 109 | def visualize_mesh( renderer, images, 110 | pred_vertices_full, 111 | pred_camera): 112 | img = images.cpu().numpy().transpose(1,2,0) 113 | # Get the predicted vertices for this example 114 | vertices_full = pred_vertices_full.cpu().numpy() 115 | cam = pred_camera.cpu().numpy() 116 | # Visualize only mesh reconstruction 117 | rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue') 118 | rend_img = rend_img.transpose(2,0,1) 119 | return rend_img 120 | 121 | def visualize_mesh_and_attention( renderer, images, 122 | pred_vertices_full, 123 | pred_vertices, 124 | pred_2d_vertices, 125 | pred_2d_joints, 126 | pred_camera, 127 | attention): 128 | img = images.cpu().numpy().transpose(1,2,0) 129 | # Get the predicted vertices for this example 130 | vertices_full = pred_vertices_full.cpu().numpy() 131 | vertices = pred_vertices.cpu().numpy() 132 | vertices_2d = pred_2d_vertices.cpu().numpy() 133 | joints_2d = pred_2d_joints.cpu().numpy() 134 | cam = pred_camera.cpu().numpy() 135 | att = attention.cpu().numpy() 136 | # Visualize reconstruction and attention 137 | rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue') 138 | rend_img = rend_img.transpose(2,0,1) 139 | return rend_img 140 | 141 | def parse_args(): 142 | parser = argparse.ArgumentParser() 143 | ######################################################### 144 | # Data related arguments 145 | ######################################################### 146 | parser.add_argument("--num_workers", default=4, type=int, 147 | help="Workers in dataloader.") 148 | parser.add_argument("--img_scale_factor", default=1, type=int, 149 | help="adjust image resolution.") 150 | parser.add_argument("--image_file_or_path", default='./samples/hand', type=str, 151 | help="test data") 152 | ######################################################### 153 | # Loading/saving checkpoints 154 | ######################################################### 155 | parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False, 156 | help="Path to pre-trained transformer model or model type.") 157 | parser.add_argument("--resume_checkpoint", default=None, type=str, required=False, 158 | help="Path to specific checkpoint for resume training.") 159 | parser.add_argument("--output_dir", default='output/', type=str, required=False, 160 | help="The output directory to save checkpoint and test results.") 161 | parser.add_argument("--config_name", default="", type=str, 162 | help="Pretrained config name or path if not the same as model_name.") 163 | parser.add_argument('-a', '--arch', default='hrnet-w64', 164 | help='CNN backbone architecture:
hrnet-w64, hrnet, resnet50') 165 | ######################################################### 166 | # Model architectures 167 | ######################################################### 168 | parser.add_argument("--num_hidden_layers", default=4, type=int, required=False, 169 | help="Update model config if given") 170 | parser.add_argument("--hidden_size", default=-1, type=int, required=False, 171 | help="Update model config if given") 172 | parser.add_argument("--num_attention_heads", default=4, type=int, required=False, 173 | help="Update model config if given. Note that hidden_size " 174 | "must be divisible by num_attention_heads.") 175 | parser.add_argument("--intermediate_size", default=-1, type=int, required=False, 176 | help="Update model config if given.") 177 | parser.add_argument("--input_feat_dim", default='2051,512,128', type=str, 178 | help="The image feature dimension.") 179 | parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str, 180 | help="The hidden feature dimension.") 181 | parser.add_argument("--which_gcn", default='0,0,1', type=str, 182 | help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv") 183 | parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand") 184 | 185 | ######################################################### 186 | # Others 187 | ######################################################### 188 | parser.add_argument("--run_eval_only", default=True, action='store_true',) 189 | parser.add_argument("--device", type=str, default='cuda', 190 | help="cuda or cpu") 191 | parser.add_argument('--seed', type=int, default=88, 192 | help="random seed for initialization.") 193 | args = parser.parse_args() 194 | return args 195 | 196 | def main(args): 197 | global logger 198 | # Setup CUDA, GPU & distributed training 199 | args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 200 | os.environ['OMP_NUM_THREADS'] = str(args.num_workers) 201 | print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS'])) 202 | 203 | mkdir(args.output_dir) 204 | logger = setup_logger("Graphormer", args.output_dir, get_rank()) 205 | set_seed(args.seed, args.num_gpus) 206 | logger.info("Using {} GPUs".format(args.num_gpus)) 207 | 208 | # Mesh and MANO utils 209 | mano_model = MANO().to(args.device) 210 | mano_model.layer = mano_model.layer.cuda() 211 | mesh_sampler = Mesh() 212 | 213 | # Renderer for visualization 214 | renderer = Renderer(faces=mano_model.face) 215 | 216 | # Load pretrained model 217 | trans_encoder = [] 218 | 219 | input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')] 220 | hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')] 221 | output_feat_dim = input_feat_dim[1:] + [3] 222 | 223 | # which encoder block to have graph convs 224 | which_blk_graph = [int(item) for item in args.which_gcn.split(',')] 225 | 226 | if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint: 227 | # if only run eval, load checkpoint 228 | logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint)) 229 | _model = torch.load(args.resume_checkpoint) 230 | 231 | else: 232 | # init three transformer-encoder blocks in a loop 233 | for i in range(len(output_feat_dim)): 234 | config_class, model_class = BertConfig, Graphormer 235 | config = config_class.from_pretrained(args.config_name if
args.config_name \ 236 | else args.model_name_or_path) 237 | 238 | config.output_attentions = False 239 | config.img_feature_dim = input_feat_dim[i] 240 | config.output_feature_dim = output_feat_dim[i] 241 | args.hidden_size = hidden_feat_dim[i] 242 | args.intermediate_size = int(args.hidden_size*2) 243 | 244 | if which_blk_graph[i]==1: 245 | config.graph_conv = True 246 | logger.info("Add Graph Conv") 247 | else: 248 | config.graph_conv = False 249 | 250 | config.mesh_type = args.mesh_type 251 | 252 | # update model structure if specified in arguments 253 | update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] 254 | for idx, param in enumerate(update_params): 255 | arg_param = getattr(args, param) 256 | config_param = getattr(config, param) 257 | if arg_param > 0 and arg_param != config_param: 258 | logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param)) 259 | setattr(config, param, arg_param) 260 | 261 | # init a transformer encoder and append it to a list 262 | assert config.hidden_size % config.num_attention_heads == 0 263 | model = model_class(config=config) 264 | logger.info("Init model from scratch.") 265 | trans_encoder.append(model) 266 | 267 | # create backbone model 268 | if args.arch=='hrnet': 269 | hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' 270 | hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth' 271 | hrnet_update_config(hrnet_config, hrnet_yaml) 272 | backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) 273 | logger.info('=> loading hrnet-v2-w40 model') 274 | elif args.arch=='hrnet-w64': 275 | hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' 276 | hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth' 277 | hrnet_update_config(hrnet_config, hrnet_yaml) 278 | backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) 279 | logger.info('=> loading hrnet-v2-w64 model') 280 | else: 281 | print("=> using pre-trained model '{}'".format(args.arch)) 282 | backbone = models.__dict__[args.arch](pretrained=True) 283 | # remove the last fc layer 284 | backbone = torch.nn.Sequential(*list(backbone.children())[:-1]) 285 | 286 | trans_encoder = torch.nn.Sequential(*trans_encoder) 287 | total_params = sum(p.numel() for p in trans_encoder.parameters()) 288 | logger.info('Graphormer encoders total parameters: {}'.format(total_params)) 289 | backbone_total_params = sum(p.numel() for p in backbone.parameters()) 290 | logger.info('Backbone total parameters: {}'.format(backbone_total_params)) 291 | 292 | # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder) 293 | _model = Graphormer_Network(args, config, backbone, trans_encoder) 294 | 295 | if args.resume_checkpoint!=None and args.resume_checkpoint!='None': 296 | # for fine-tuning or resume training or inference, load weights from checkpoint 297 | logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint)) 298 | # workaround approach to load sparse tensor in graph conv. 
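# (strict=False below tolerates a mismatch between checkpoint keys and the
# module's state dict, e.g. the sparse adjacency tensors registered by the
# graph conv, which cannot be restored like ordinary dense parameters.)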
299 | state_dict = torch.load(args.resume_checkpoint) 300 | _model.load_state_dict(state_dict, strict=False) 301 | del state_dict 302 | gc.collect() 303 | torch.cuda.empty_cache() 304 | 305 | # update configs to enable attention outputs 306 | setattr(_model.trans_encoder[-1].config,'output_attentions', True) 307 | setattr(_model.trans_encoder[-1].config,'output_hidden_states', True) 308 | _model.trans_encoder[-1].bert.encoder.output_attentions = True 309 | _model.trans_encoder[-1].bert.encoder.output_hidden_states = True 310 | for iter_layer in range(4): 311 | _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True 312 | for inter_block in range(3): 313 | setattr(_model.trans_encoder[-1].config,'device', args.device) 314 | 315 | _model.to(args.device) 316 | logger.info("Run inference") 317 | 318 | image_list = [] 319 | if not args.image_file_or_path: 320 | raise ValueError("image_file_or_path not specified") 321 | if op.isfile(args.image_file_or_path): 322 | image_list = [args.image_file_or_path] 323 | elif op.isdir(args.image_file_or_path): 324 | # should be a path with images only 325 | for filename in os.listdir(args.image_file_or_path): 326 | if (filename.endswith(".png") or filename.endswith(".jpg")) and 'pred' not in filename: 327 | image_list.append(args.image_file_or_path+'/'+filename) 328 | else: 329 | raise ValueError("Cannot find images at {}".format(args.image_file_or_path)) 330 | 331 | run_inference(args, image_list, _model, mano_model, renderer, mesh_sampler) 332 | 333 | if __name__ == "__main__": 334 | args = parse_args() 335 | main(args) 336 | -------------------------------------------------------------------------------- /src/tools/run_hand_multiscale.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import os 5 | import os.path as op 6 | import code 7 | import json 8 | import zipfile 9 | import torch 10 | import numpy as np 11 | from src.utils.metric_pampjpe import get_alignMesh 12 | 13 | 14 | def load_pred_json(filepath): 15 | archive = zipfile.ZipFile(filepath, 'r') 16 | jsondata = archive.read('pred.json') 17 | reference = json.loads(jsondata.decode("utf-8")) 18 | return reference[0], reference[1] 19 | 20 | 21 | def multiscale_fusion(output_dir): 22 | s = '10' 23 | filepath = output_dir+'ckpt200-sc10_rot0-pred.zip' 24 | ref_joints, ref_vertices = load_pred_json(filepath) 25 | ref_joints_array = np.asarray(ref_joints) 26 | ref_vertices_array = np.asarray(ref_vertices) 27 | 28 | rotations = [0.0] 29 | for i in range(1,10): 30 | rotations.append(i*10) 31 | rotations.append(i*-10) 32 | 33 | scale = [0.7,0.8,0.9,1.0,1.1] 34 | multiscale_joints = [] 35 | multiscale_vertices = [] 36 | 37 | counter = 0 38 | for s in scale: 39 | for r in rotations: 40 | setting = 'sc%02d_rot%s'%(int(s*10),str(int(r))) 41 | filepath = output_dir+'ckpt200-'+setting+'-pred.zip' 42 | joints, vertices = load_pred_json(filepath) 43 | joints_array = np.asarray(joints) 44 | vertices_array = np.asarray(vertices) 45 | 46 | pa_joint_error, pa_joint_array, _ = get_alignMesh(joints_array, ref_joints_array, reduction=None) 47 | pa_vertices_error, pa_vertices_array, _ = get_alignMesh(vertices_array, ref_vertices_array, reduction=None) 48 | print('--------------------------') 49 | print('scale:', s, 'rotate', r) 50 | print('PAMPJPE:', 1000*np.mean(pa_joint_error)) 51 | print('PAMPVPE:', 1000*np.mean(pa_vertices_error)) 52 |
multiscale_joints.append(pa_joint_array) 53 | multiscale_vertices.append(pa_vertices_array) 54 | counter = counter + 1 55 | 56 | overall_joints_array = ref_joints_array.copy() 57 | overall_vertices_array = ref_vertices_array.copy() 58 | for i in range(counter): 59 | overall_joints_array += multiscale_joints[i] 60 | overall_vertices_array += multiscale_vertices[i] 61 | 62 | overall_joints_array /= (1+counter) 63 | overall_vertices_array /= (1+counter) 64 | pa_joint_error, pa_joint_array, _ = get_alignMesh(overall_joints_array, ref_joints_array, reduction=None) 65 | pa_vertices_error, pa_vertices_array, _ = get_alignMesh(overall_vertices_array, ref_vertices_array, reduction=None) 66 | print('--------------------------') 67 | print('overall:') 68 | print('PAMPJPE:', 1000*np.mean(pa_joint_error)) 69 | print('PAMPVPE:', 1000*np.mean(pa_vertices_error)) 70 | 71 | joint_output_save = overall_joints_array.tolist() 72 | mesh_output_save = overall_vertices_array.tolist() 73 | 74 | print('save results to pred.json') 75 | with open('pred.json', 'w') as f: 76 | json.dump([joint_output_save, mesh_output_save], f) 77 | 78 | 79 | filepath = output_dir+'ckpt200-multisc-pred.zip' 80 | resolved_submit_cmd = 'zip ' + filepath + ' ' + 'pred.json' 81 | print(resolved_submit_cmd) 82 | os.system(resolved_submit_cmd) 83 | resolved_submit_cmd = 'rm pred.json' 84 | print(resolved_submit_cmd) 85 | os.system(resolved_submit_cmd) 86 | 87 | 88 | def run_multiscale_inference(model_path, mode, output_dir): 89 | 90 | if mode==True: 91 | rotations = [0.0] 92 | for i in range(1,10): 93 | rotations.append(i*10) 94 | rotations.append(i*-10) 95 | scale = [0.7,0.8,0.9,1.0,1.1] 96 | else: 97 | rotations = [0.0] 98 | scale = [1.0] 99 | 100 | job_cmd = "python ./src/tools/run_gphmer_handmesh.py " \ 101 | "--val_yaml freihand_v3/test.yaml " \ 102 | "--resume_checkpoint %s " \ 103 | "--per_gpu_eval_batch_size 32 --run_eval_only --num_worker 2 " \ 104 | "--multiscale_inference " \ 105 | "--rot %f " \ 106 | "--sc %s " \ 107 | "--arch hrnet-w64 " \ 108 | "--num_hidden_layers 4 " \ 109 | "--num_attention_heads 4 " \ 110 | "--input_feat_dim 2051,512,128 " \ 111 | "--hidden_feat_dim 1024,256,64 " \ 112 | "--output_dir %s" 113 | 114 | for s in scale: 115 | for r in rotations: 116 | resolved_submit_cmd = job_cmd%(model_path, r, s, output_dir) 117 | print(resolved_submit_cmd) 118 | os.system(resolved_submit_cmd) 119 | 120 | def main(args): 121 | model_path = args.model_path 122 | mode = args.multiscale_inference 123 | output_dir = args.output_dir 124 | run_multiscale_inference(model_path, mode, output_dir) 125 | if mode==True: 126 | multiscale_fusion(output_dir) 127 | 128 | 129 | if __name__ == "__main__": 130 | parser = argparse.ArgumentParser(description="Evaluate a checkpoint in the folder") 131 | parser.add_argument("--model_path") 132 | parser.add_argument("--multiscale_inference", default=False, action='store_true',) 133 | parser.add_argument("--output_dir", default='output/', type=str, required=False, 134 | help="The output directory to save checkpoint and test results.") 135 | args = parser.parse_args() 136 | main(args) 137 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/comm.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | This file contains primitives for multi-gpu communication. 6 | This is useful when doing distributed training. 7 | """ 8 | 9 | import pickle 10 | import time 11 | 12 | import torch 13 | import torch.distributed as dist 14 | 15 | 16 | def get_world_size(): 17 | if not dist.is_available(): 18 | return 1 19 | if not dist.is_initialized(): 20 | return 1 21 | return dist.get_world_size() 22 | 23 | 24 | def get_rank(): 25 | if not dist.is_available(): 26 | return 0 27 | if not dist.is_initialized(): 28 | return 0 29 | return dist.get_rank() 30 | 31 | 32 | def is_main_process(): 33 | return get_rank() == 0 34 | 35 | 36 | def synchronize(): 37 | """ 38 | Helper function to synchronize (barrier) among all processes when 39 | using distributed training 40 | """ 41 | if not dist.is_available(): 42 | return 43 | if not dist.is_initialized(): 44 | return 45 | world_size = dist.get_world_size() 46 | if world_size == 1: 47 | return 48 | dist.barrier() 49 | 50 | 51 | def gather_on_master(data): 52 | """Same as all_gather, but gathers data on the master process only, using CPU. 53 | Thus, this does not work with the NCCL backend unless it adds CPU support. 54 | 55 | The memory consumption of this function is ~ 3x of data size. While in 56 | principle it should be ~2x, it's not easy to force Python to release 57 | memory immediately and thus, peak memory usage could be up to 3x. 58 | """ 59 | world_size = get_world_size() 60 | if world_size == 1: 61 | return [data] 62 | 63 | # serialize to a Tensor 64 | buffer = pickle.dumps(data) 65 | # trying to optimize memory, but in fact, it's not guaranteed to be released 66 | del data 67 | storage = torch.ByteStorage.from_buffer(buffer) 68 | del buffer 69 | tensor = torch.ByteTensor(storage) 70 | 71 | # obtain Tensor size of each rank 72 | local_size = torch.LongTensor([tensor.numel()]) 73 | size_list = [torch.LongTensor([0]) for _ in range(world_size)] 74 | dist.all_gather(size_list, local_size) 75 | size_list = [int(size.item()) for size in size_list] 76 | max_size = max(size_list) 77 | 78 | if local_size != max_size: 79 | padding = torch.ByteTensor(size=(max_size - local_size,)) 80 | tensor = torch.cat((tensor, padding), dim=0) 81 | del padding 82 | 83 | if is_main_process(): 84 | tensor_list = [] 85 | for _ in size_list: 86 | tensor_list.append(torch.ByteTensor(size=(max_size,))) 87 | dist.gather(tensor, gather_list=tensor_list, dst=0) 88 | del tensor 89 | else: 90 | dist.gather(tensor, gather_list=[], dst=0) 91 | del tensor 92 | return 93 | 94 | data_list = [] 95 | for tensor in tensor_list: 96 | buffer = tensor.cpu().numpy().tobytes() 97 | del tensor 98 | data_list.append(pickle.loads(buffer)) 99 | del buffer 100 | 101 | return data_list 102 | 103 | 104 | def all_gather(data): 105 | """ 106 | Run all_gather on arbitrary picklable data (not necessarily tensors) 107 | Args: 108 | data: any picklable object 109 | Returns: 110 | list[data]: list of data gathered from each rank 111 | """ 112 | world_size = get_world_size() 113 | if world_size == 1: 114 | return [data] 115 | 116 | # serialize to a Tensor 117 | buffer = pickle.dumps(data) 118 | storage = torch.ByteStorage.from_buffer(buffer) 119 | tensor = torch.ByteTensor(storage).to("cuda") 120 | 121 | # obtain Tensor size of each rank 122 | local_size = torch.LongTensor([tensor.numel()]).to("cuda") 123 | size_list =
[torch.LongTensor([0]).to("cuda") for _ in range(world_size)] 124 | dist.all_gather(size_list, local_size) 125 | size_list = [int(size.item()) for size in size_list] 126 | max_size = max(size_list) 127 | 128 | # receiving Tensor from all ranks 129 | # we pad the tensor because torch all_gather does not support 130 | # gathering tensors of different shapes 131 | tensor_list = [] 132 | for _ in size_list: 133 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 134 | if local_size != max_size: 135 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 136 | tensor = torch.cat((tensor, padding), dim=0) 137 | dist.all_gather(tensor_list, tensor) 138 | 139 | data_list = [] 140 | for size, tensor in zip(size_list, tensor_list): 141 | buffer = tensor.cpu().numpy().tobytes()[:size] 142 | data_list.append(pickle.loads(buffer)) 143 | 144 | return data_list 145 | 146 | 147 | def reduce_dict(input_dict, average=True): 148 | """ 149 | Args: 150 | input_dict (dict): all the values will be reduced 151 | average (bool): whether to do average or sum 152 | Reduce the values in the dictionary from all processes so that process with rank 153 | 0 has the averaged results. Returns a dict with the same fields as 154 | input_dict, after reduction. 155 | """ 156 | world_size = get_world_size() 157 | if world_size < 2: 158 | return input_dict 159 | with torch.no_grad(): 160 | names = [] 161 | values = [] 162 | # sort the keys so that they are consistent across processes 163 | for k in sorted(input_dict.keys()): 164 | names.append(k) 165 | values.append(input_dict[k]) 166 | values = torch.stack(values, dim=0) 167 | dist.reduce(values, dst=0) 168 | if dist.get_rank() == 0 and average: 169 | # only main process gets accumulated, so only divide by 170 | # world_size in this case 171 | values /= world_size 172 | reduced_dict = {k: v for k, v in zip(names, values)} 173 | return reduced_dict 174 | -------------------------------------------------------------------------------- /src/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
4 | 5 | """ 6 | 7 | 8 | import os 9 | import os.path as op 10 | import numpy as np 11 | import base64 12 | import cv2 13 | import yaml 14 | from collections import OrderedDict 15 | 16 | 17 | def img_from_base64(imagestring): 18 | try: 19 | jpgbytestring = base64.b64decode(imagestring) 20 | nparr = np.frombuffer(jpgbytestring, np.uint8) 21 | r = cv2.imdecode(nparr, cv2.IMREAD_COLOR) 22 | return r 23 | except Exception: 24 | return None 25 | 26 | 27 | def load_labelmap(labelmap_file): 28 | label_dict = None 29 | if labelmap_file is not None and op.isfile(labelmap_file): 30 | label_dict = OrderedDict() 31 | with open(labelmap_file, 'r') as fp: 32 | for line in fp: 33 | label = line.strip().split('\t')[0] 34 | if label in label_dict: 35 | raise ValueError("Duplicate label " + label + " in labelmap.") 36 | else: 37 | label_dict[label] = len(label_dict) 38 | return label_dict 39 | 40 | 41 | def load_shuffle_file(shuf_file): 42 | shuf_list = None 43 | if shuf_file is not None: 44 | with open(shuf_file, 'r') as fp: 45 | shuf_list = [] 46 | for i in fp: 47 | shuf_list.append(int(i.strip())) 48 | return shuf_list 49 | 50 | 51 | def load_box_shuffle_file(shuf_file): 52 | if shuf_file is not None: 53 | with open(shuf_file, 'r') as fp: 54 | img_shuf_list = [] 55 | box_shuf_list = [] 56 | for i in fp: 57 | idx = [int(_) for _ in i.strip().split('\t')] 58 | img_shuf_list.append(idx[0]) 59 | box_shuf_list.append(idx[1]) 60 | return [img_shuf_list, box_shuf_list] 61 | return None 62 | 63 | 64 | def load_from_yaml_file(file_name): 65 | with open(file_name, 'r') as fp: 66 | return yaml.load(fp, Loader=yaml.CLoader) 67 | -------------------------------------------------------------------------------- /src/utils/geometric_layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Useful geometric operations, e.g. orthographic projection and a differentiable Rodrigues formula 3 | 4 | Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR 5 | """ 6 | import torch 7 | 8 | def rodrigues(theta): 9 | """Convert axis-angle representation to rotation matrix. 10 | Args: 11 | theta: size = [B, 3] 12 | Returns: 13 | Rotation matrix corresponding to the axis-angle input -- size = [B, 3, 3] 14 | """ 15 | l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1) 16 | angle = torch.unsqueeze(l1norm, -1) 17 | normalized = torch.div(theta, angle) 18 | angle = angle * 0.5 19 | v_cos = torch.cos(angle) 20 | v_sin = torch.sin(angle) 21 | quat = torch.cat([v_cos, v_sin * normalized], dim = 1) 22 | return quat2mat(quat) 23 | 24 | def quat2mat(quat): 25 | """Convert quaternion coefficients to rotation matrix.
24 | def quat2mat(quat):
25 |     """Convert quaternion coefficients to rotation matrix.
26 |     Args:
27 |         quat: size = [B, 4], ordered (w, x, y, z)
28 |     Returns:
29 |         Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
30 |     """
31 |     norm_quat = quat
32 |     norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
33 |     w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
34 | 
35 |     B = quat.size(0)
36 | 
37 |     w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
38 |     wx, wy, wz = w*x, w*y, w*z
39 |     xy, xz, yz = x*y, x*z, y*z
40 | 
41 |     rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
42 |                           2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
43 |                           2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
44 |     return rotMat
45 | 
46 | def orthographic_projection(X, camera):
47 |     """Perform orthographic projection of 3D points X using the camera parameters
48 |     Args:
49 |         X: size = [B, N, 3]
50 |         camera: size = [B, 3], per-sample (scale, tx, ty)
51 |     Returns:
52 |         Projected 2D points -- size = [B, N, 2]
53 |     """
54 |     camera = camera.view(-1, 1, 3)
55 |     X_trans = X[:, :, :2] + camera[:, :, 1:]
56 |     shape = X_trans.shape
57 |     X_2d = (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape)
58 |     return X_2d
59 | 
-------------------------------------------------------------------------------- /src/utils/image_ops.py: --------------------------------------------------------------------------------
1 | """
2 | Image processing tools
3 | 
4 | Modified from open source projects:
5 | (https://github.com/nkolot/GraphCMR/)
6 | (https://github.com/open-mmlab/mmdetection)
7 | 
8 | """
9 | 
10 | import numpy as np
11 | import base64
12 | import cv2
13 | import torch
14 | 
15 | 
16 | def img_from_base64(imagestring):
17 |     try:
18 |         jpgbytestring = base64.b64decode(imagestring)
19 |         nparr = np.frombuffer(jpgbytestring, np.uint8)
20 |         r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
21 |         return r
22 |     except ValueError:
23 |         return None
24 | 
25 | def myimrotate(img, angle, center=None, scale=1.0, border_value=0, auto_bound=False):
26 |     if center is not None and auto_bound:
27 |         raise ValueError('`auto_bound` conflicts with `center`')
28 |     h, w = img.shape[:2]
29 |     if center is None:
30 |         center = ((w - 1) * 0.5, (h - 1) * 0.5)
31 |     assert isinstance(center, tuple)
32 | 
33 |     matrix = cv2.getRotationMatrix2D(center, angle, scale)
34 |     if auto_bound:
35 |         cos = np.abs(matrix[0, 0])
36 |         sin = np.abs(matrix[0, 1])
37 |         new_w = h * sin + w * cos
38 |         new_h = h * cos + w * sin
39 |         matrix[0, 2] += (new_w - w) * 0.5
40 |         matrix[1, 2] += (new_h - h) * 0.5
41 |         w = int(np.round(new_w))
42 |         h = int(np.round(new_h))
43 |     rotated = cv2.warpAffine(img, matrix, (w, h), borderValue=border_value)
44 |     return rotated
45 | 
46 | def myimresize(img, size, return_scale=False, interpolation='bilinear'):
47 |     interp_code = cv2.INTER_NEAREST if interpolation == 'nearest' else cv2.INTER_LINEAR
48 |     h, w = img.shape[:2]
49 |     resized_img = cv2.resize(
50 |         img, (size[0], size[1]), interpolation=interp_code)
51 |     if not return_scale:
52 |         return resized_img
53 |     else:
54 |         w_scale = size[0] / w
55 |         h_scale = size[1] / h
56 |         return resized_img, w_scale, h_scale
57 | 
58 | 
59 | def get_transform(center, scale, res, rot=0):
60 |     """Generate transformation matrix."""
61 |     h = 200 * scale
62 |     t = np.zeros((3, 3))
63 |     t[0, 0] = float(res[1]) / h
64 |     t[1, 1] = float(res[0]) / h
65 |     t[0, 2] = res[1] * (-float(center[0]) / h + .5)
66 |     t[1, 2] = res[0] * (-float(center[1]) / h + .5)
67 |     t[2, 2] = 1
68 |     if rot != 0:
69 |         rot = -rot  # To match direction of rotation from cropping
70 |         rot_mat = np.zeros((3,3))
71 |         rot_rad = rot * np.pi / 180
72 |         sn, cs = np.sin(rot_rad), np.cos(rot_rad)
73 |         rot_mat[0,:2] = [cs, -sn]
74 |         rot_mat[1,:2] = [sn, cs]
75 |         rot_mat[2,2] = 1
76 |         # Need to rotate around center
77 |         t_mat = np.eye(3)
78 |         t_mat[0,2] = -res[1]/2
79 |         t_mat[1,2] = -res[0]/2
80 |         t_inv = t_mat.copy()
81 |         t_inv[:2,2] *= -1
82 |         t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
83 |     return t
84 | 
85 | def transform(pt, center, scale, res, invert=0, rot=0):
86 |     """Transform pixel location to different reference."""
87 |     t = get_transform(center, scale, res, rot=rot)
88 |     if invert:
89 |         # t = np.linalg.inv(t)  (the 3x3 inverse is taken via torch below)
90 |         t_torch = torch.from_numpy(t)
91 |         t_torch = torch.inverse(t_torch)
92 |         t = t_torch.numpy()
93 |     new_pt = np.array([pt[0]-1, pt[1]-1, 1.]).T
94 |     new_pt = np.dot(t, new_pt)
95 |     return new_pt[:2].astype(int)+1
96 | 
97 | def crop(img, center, scale, res, rot=0):
98 |     """Crop image according to the supplied bounding box."""
99 |     # Upper left point
100 |     ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
101 |     # Bottom right point
102 |     br = np.array(transform([res[0]+1,
103 |                              res[1]+1], center, scale, res, invert=1))-1
104 |     # Padding so that when rotated proper amount of context is included
105 |     pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
106 |     if rot != 0:
107 |         ul -= pad
108 |         br += pad
109 |     new_shape = [br[1] - ul[1], br[0] - ul[0]]
110 |     if len(img.shape) > 2:
111 |         new_shape += [img.shape[2]]
112 |     new_img = np.zeros(new_shape)
113 | 
114 |     # Range to fill new array
115 |     new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
116 |     new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
117 |     # Range to sample from original image
118 |     old_x = max(0, ul[0]), min(len(img[0]), br[0])
119 |     old_y = max(0, ul[1]), min(len(img), br[1])
120 | 
121 |     new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
122 |                                                         old_x[0]:old_x[1]]
123 |     if rot != 0:
124 |         # Rotate, then remove the padding
125 |         # new_img = scipy.misc.imrotate(new_img, rot)
126 |         new_img = myimrotate(new_img, rot)
127 |         new_img = new_img[pad:-pad, pad:-pad]
128 | 
129 |     # new_img = scipy.misc.imresize(new_img, res)
130 |     new_img = myimresize(new_img, [res[0], res[1]])
131 |     return new_img
132 | 
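# A minimal sketch of how crop() is typically driven (not part of the original
# module), assuming a 224x224 network input; the image path and the
# center/scale values are hypothetical. By the `h = 200 * scale` convention in
# get_transform above, `scale` is the subject's pixel height divided by 200.
def _demo_crop():
    img = cv2.imread('input.jpg')                  # hypothetical path, HxWx3 BGR
    center = [img.shape[1] / 2, img.shape[0] / 2]  # crop around the image center
    scale = img.shape[0] / 200.0                   # subject spans the full image
    patch = crop(img, center, scale, [224, 224], rot=0)
    return patch                                   # 224x224x3 float array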
136 | """ 137 | res = img.shape[:2] 138 | # Upper left point 139 | ul = np.array(transform([1, 1], center, scale, res, invert=1))-1 140 | # Bottom right point 141 | br = np.array(transform([res[0]+1,res[1]+1], center, scale, res, invert=1))-1 142 | # size of cropped image 143 | crop_shape = [br[1] - ul[1], br[0] - ul[0]] 144 | 145 | new_shape = [br[1] - ul[1], br[0] - ul[0]] 146 | if len(img.shape) > 2: 147 | new_shape += [img.shape[2]] 148 | new_img = np.zeros(orig_shape, dtype=np.uint8) 149 | # Range to fill new array 150 | new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0] 151 | new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1] 152 | # Range to sample from original image 153 | old_x = max(0, ul[0]), min(orig_shape[1], br[0]) 154 | old_y = max(0, ul[1]), min(orig_shape[0], br[1]) 155 | # img = scipy.misc.imresize(img, crop_shape, interp='nearest') 156 | img = myimresize(img, [crop_shape[0],crop_shape[1]]) 157 | new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]] 158 | return new_img 159 | 160 | def rot_aa(aa, rot): 161 | """Rotate axis angle parameters.""" 162 | # pose parameters 163 | R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0], 164 | [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], 165 | [0, 0, 1]]) 166 | # find the rotation of the body in camera frame 167 | per_rdg, _ = cv2.Rodrigues(aa) 168 | # apply the global rotation to the global orientation 169 | resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg)) 170 | aa = (resrot.T)[0] 171 | return aa 172 | 173 | def flip_img(img): 174 | """Flip rgb images or masks. 175 | channels come last, e.g. (256,256,3). 176 | """ 177 | img = np.fliplr(img) 178 | return img 179 | 180 | def flip_kp(kp): 181 | """Flip keypoints.""" 182 | flipped_parts = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22] 183 | kp = kp[flipped_parts] 184 | kp[:,0] = - kp[:,0] 185 | return kp 186 | 187 | def flip_pose(pose): 188 | """Flip pose. 189 | The flipping is based on SMPL parameters. 190 | """ 191 | flippedParts = [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13, 192 | 14 ,18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33, 193 | 34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41, 194 | 45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55, 195 | 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68] 196 | pose = pose[flippedParts] 197 | # we also negate the second and the third dimension of the axis-angle 198 | pose[1::3] = -pose[1::3] 199 | pose[2::3] = -pose[2::3] 200 | return pose 201 | 202 | def flip_aa(aa): 203 | """Flip axis-angle representation. 204 | We negate the second and the third dimension of the axis-angle. 205 | """ 206 | aa[1] = -aa[1] 207 | aa[2] = -aa[2] 208 | return aa -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import logging 3 | import os 4 | import sys 5 | from logging import StreamHandler, Handler, getLevelName 6 | 7 | 8 | # this class is a copy of logging.FileHandler except we end self.close() 9 | # at the end of each emit. 
While closing file and reopening file after each 10 | # write is not efficient, it allows us to see partial logs when writing to 11 | # fused Azure blobs, which is very convenient 12 | class FileHandler(StreamHandler): 13 | """ 14 | A handler class which writes formatted logging records to disk files. 15 | """ 16 | def __init__(self, filename, mode='a', encoding=None, delay=False): 17 | """ 18 | Open the specified file and use it as the stream for logging. 19 | """ 20 | # Issue #27493: add support for Path objects to be passed in 21 | filename = os.fspath(filename) 22 | #keep the absolute path, otherwise derived classes which use this 23 | #may come a cropper when the current directory changes 24 | self.baseFilename = os.path.abspath(filename) 25 | self.mode = mode 26 | self.encoding = encoding 27 | self.delay = delay 28 | if delay: 29 | #We don't open the stream, but we still need to call the 30 | #Handler constructor to set level, formatter, lock etc. 31 | Handler.__init__(self) 32 | self.stream = None 33 | else: 34 | StreamHandler.__init__(self, self._open()) 35 | 36 | def close(self): 37 | """ 38 | Closes the stream. 39 | """ 40 | self.acquire() 41 | try: 42 | try: 43 | if self.stream: 44 | try: 45 | self.flush() 46 | finally: 47 | stream = self.stream 48 | self.stream = None 49 | if hasattr(stream, "close"): 50 | stream.close() 51 | finally: 52 | # Issue #19523: call unconditionally to 53 | # prevent a handler leak when delay is set 54 | StreamHandler.close(self) 55 | finally: 56 | self.release() 57 | 58 | def _open(self): 59 | """ 60 | Open the current base file with the (original) mode and encoding. 61 | Return the resulting stream. 62 | """ 63 | return open(self.baseFilename, self.mode, encoding=self.encoding) 64 | 65 | def emit(self, record): 66 | """ 67 | Emit a record. 68 | 69 | If the stream was not opened because 'delay' was specified in the 70 | constructor, open it before calling the superclass's emit. 71 | """ 72 | if self.stream is None: 73 | self.stream = self._open() 74 | StreamHandler.emit(self, record) 75 | self.close() 76 | 77 | def __repr__(self): 78 | level = getLevelName(self.level) 79 | return '<%s %s (%s)>' % (self.__class__.__name__, self.baseFilename, level) 80 | 81 | 82 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt"): 83 | logger = logging.getLogger(name) 84 | logger.setLevel(logging.DEBUG) 85 | # don't log results for the non-master process 86 | if distributed_rank > 0: 87 | return logger 88 | ch = logging.StreamHandler(stream=sys.stdout) 89 | ch.setLevel(logging.DEBUG) 90 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 91 | ch.setFormatter(formatter) 92 | logger.addHandler(ch) 93 | 94 | if save_dir: 95 | fh = FileHandler(os.path.join(save_dir, filename)) 96 | fh.setLevel(logging.DEBUG) 97 | fh.setFormatter(formatter) 98 | logger.addHandler(fh) 99 | 100 | return logger 101 | -------------------------------------------------------------------------------- /src/utils/metric_logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 4 | 5 | Basic logger. 
It computes and stores the average and the current value.
6 | """
7 | 
8 | class AverageMeter(object):
9 | 
10 |     def __init__(self):
11 |         self.reset()
12 | 
13 |     def reset(self):
14 |         self.val = 0
15 |         self.avg = 0
16 |         self.sum = 0
17 |         self.count = 0
18 | 
19 |     def update(self, val, n=1):
20 |         self.val = val
21 |         self.sum += val * n
22 |         self.count += n
23 |         self.avg = self.sum / self.count
24 | 
25 | 
26 | 
27 | class EvalMetricsLogger(object):
28 | 
29 |     def __init__(self):
30 |         self.reset()
31 | 
32 |     def reset(self):
33 |         # define an upper-bound performance (worst case);
34 |         # the numbers are 100 millimeters, stored in meters
35 |         self.PAmPJPE = 100.0/1000.0
36 |         self.mPJPE = 100.0/1000.0
37 |         self.mPVE = 100.0/1000.0
38 | 
39 |         self.epoch = 0
40 | 
41 |     def update(self, mPVE, mPJPE, PAmPJPE, epoch):
42 |         self.PAmPJPE = PAmPJPE
43 |         self.mPJPE = mPJPE
44 |         self.mPVE = mPVE
45 |         self.epoch = epoch
46 | 
-------------------------------------------------------------------------------- /src/utils/metric_pampjpe.py: --------------------------------------------------------------------------------
1 | """
2 | Functions for computing Procrustes alignment and reconstruction error
3 | 
4 | Parts of the code are adapted from https://github.com/akanazawa/hmr
5 | 
6 | """
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 | import numpy as np
11 | 
12 | def compute_similarity_transform(S1, S2):
13 |     """Computes a similarity transform (sR, t) that best aligns
14 |     a set of 3D points S1 (3 x N) to a set of 3D points S2,
15 |     where R is a 3x3 rotation matrix, t a 3x1 translation, and s a scale factor,
16 |     i.e. solves the orthogonal Procrustes problem.
17 |     """
18 |     transposed = False
19 |     if S1.shape[0] != 3 and S1.shape[0] != 2:
20 |         S1 = S1.T
21 |         S2 = S2.T
22 |         transposed = True
23 |     assert(S2.shape[1] == S1.shape[1])
24 | 
25 |     # 1. Remove mean.
26 |     mu1 = S1.mean(axis=1, keepdims=True)
27 |     mu2 = S2.mean(axis=1, keepdims=True)
28 |     X1 = S1 - mu1
29 |     X2 = S2 - mu2
30 | 
31 |     # 2. Compute variance of X1 used for scale.
32 |     var1 = np.sum(X1**2)
33 | 
34 |     # 3. The outer product of X1 and X2.
35 |     K = X1.dot(X2.T)
36 | 
37 |     # 4. Solution that maximizes trace(R'K) is R=U*V', where U, V are
38 |     # singular vectors of K.
39 |     U, s, Vh = np.linalg.svd(K)
40 |     V = Vh.T
41 |     # Construct Z that fixes the orientation of R to get det(R)=1.
42 |     Z = np.eye(U.shape[0])
43 |     Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
44 |     # Construct R.
45 |     R = V.dot(Z.dot(U.T))
46 | 
47 |     # 5. Recover scale.
48 |     scale = np.trace(R.dot(K)) / var1
49 | 
50 |     # 6. Recover translation.
51 |     t = mu2 - scale*(R.dot(mu1))
52 | 
53 |     # 7. Apply the similarity transform to S1.
54 |     S1_hat = scale*R.dot(S1) + t
55 | 
56 |     if transposed:
57 |         S1_hat = S1_hat.T
58 | 
59 |     return S1_hat
60 | 
61 | def compute_similarity_transform_batch(S1, S2):
62 |     """Batched version of compute_similarity_transform."""
63 |     S1_hat = np.zeros_like(S1)
64 |     for i in range(S1.shape[0]):
65 |         S1_hat[i] = compute_similarity_transform(S1[i], S2[i])
66 |     return S1_hat
67 | 
68 | def reconstruction_error(S1, S2, reduction='mean'):
69 |     """Do Procrustes alignment and compute reconstruction error."""
70 |     S1_hat = compute_similarity_transform_batch(S1, S2)
71 |     re = np.sqrt(((S1_hat - S2)**2).sum(axis=-1)).mean(axis=-1)
72 |     if reduction == 'mean':
73 |         re = re.mean()
74 |     elif reduction == 'sum':
75 |         re = re.sum()
76 |     return re
77 | 
78 | 
79 | def reconstruction_error_v2(S1, S2, J24_TO_J14, reduction='mean'):
80 |     """Do Procrustes alignment and compute reconstruction error on a joint subset."""
81 |     S1_hat = compute_similarity_transform_batch(S1, S2)
82 |     S1_hat = S1_hat[:,J24_TO_J14,:]
83 |     S2 = S2[:,J24_TO_J14,:]
84 |     re = np.sqrt(((S1_hat - S2)**2).sum(axis=-1)).mean(axis=-1)
85 |     if reduction == 'mean':
86 |         re = re.mean()
87 |     elif reduction == 'sum':
88 |         re = re.sum()
89 |     return re
90 | 
91 | def get_alignMesh(S1, S2, reduction='mean'):
92 |     """Do Procrustes alignment and return the error together with the aligned meshes."""
93 |     S1_hat = compute_similarity_transform_batch(S1, S2)
94 |     re = np.sqrt(((S1_hat - S2)**2).sum(axis=-1)).mean(axis=-1)
95 |     if reduction == 'mean':
96 |         re = re.mean()
97 |     elif reduction == 'sum':
98 |         re = re.sum()
99 |     return re, S1_hat, S2
100 | 
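# A minimal sketch of computing PA-MPJPE with the helpers above (not part of
# the original module); the batch size, joint count, and arrays are
# illustrative. Since the target is an exact similarity transform of the
# prediction, the error after alignment should vanish.
def _demo_pampjpe():
    rng = np.random.RandomState(0)
    pred = rng.randn(8, 14, 3)               # predicted joints, [B, N, 3]
    gt = 2.0 * pred + 0.5                    # a similarity transform of pred
    err = reconstruction_error(pred, gt)     # mean per-joint error after alignment
    assert err < 1e-6                        # alignment removes scale/R/t exactly
    return err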
-------------------------------------------------------------------------------- /src/utils/miscellaneous.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import errno
3 | import os
4 | import os.path as op
5 | import re
6 | import logging
7 | import numpy as np
8 | import torch
9 | import random
10 | import shutil
11 | from .comm import is_main_process
12 | import yaml
13 | 
14 | 
15 | def mkdir(path):
16 |     # if path is the current folder, skip;
17 |     # otherwise os.makedirs would raise FileNotFoundError
18 |     if path == '':
19 |         return
20 |     try:
21 |         os.makedirs(path)
22 |     except OSError as e:
23 |         if e.errno != errno.EEXIST:
24 |             raise
25 | 
26 | 
27 | def save_config(cfg, path):
28 |     if is_main_process():
29 |         with open(path, 'w') as f:
30 |             f.write(cfg.dump())
31 | 
32 | 
33 | def config_iteration(output_dir, max_iter):
34 |     save_file = os.path.join(output_dir, 'last_checkpoint')
35 |     iteration = -1
36 |     if os.path.exists(save_file):
37 |         with open(save_file, 'r') as f:
38 |             fname = f.read().strip()
39 |         model_name = os.path.basename(fname)
40 |         model_path = os.path.dirname(fname)
41 |         if model_name.startswith('model_') and len(model_name) == 17:
42 |             iteration = int(model_name[-11:-4])
43 |         elif model_name == "model_final":
44 |             iteration = max_iter
45 |         elif model_path.startswith('checkpoint-') and len(model_path) == 18:
46 |             iteration = int(model_path.split('-')[-1])
47 |     return iteration
48 | 
49 | 
50 | def get_matching_parameters(model, regexp, none_on_empty=True):
51 |     """Returns parameters whose names match the regular expression."""
52 |     if not regexp:
53 |         if none_on_empty:
54 |             return {}
55 |         else:
56 |             return dict(model.named_parameters())
57 |     compiled_pattern = re.compile(regexp)
58 |     params = {}
59 |     for weight_name, weight in model.named_parameters():
60 |         if compiled_pattern.match(weight_name):
61 |             params[weight_name] = weight
62 |     return params
63 | 
64 | 
65 | def freeze_weights(model, regexp):
66 |     """Freeze weights based on a regular expression."""
67 |     logger = logging.getLogger("maskrcnn_benchmark.trainer")
68 |     for weight_name, weight in get_matching_parameters(model, regexp).items():
69 |         weight.requires_grad = False
70 |         logger.info("Disabled training of {}".format(weight_name))
71 | 
72 | 
73 | def unfreeze_weights(model, regexp, backbone_freeze_at=-1,
74 |                      is_distributed=False):
75 |     """Unfreeze weights based on a regular expression.
76 |     This is helpful during training to unfreeze frozen weights after
77 |     the other weights have been trained for some iterations.
78 | """ 79 | logger = logging.getLogger("maskrcnn_benchmark.trainer") 80 | for weight_name, weight in get_matching_parameters(model, regexp).items(): 81 | weight.requires_grad = True 82 | logger.info("Enabled training of {}".format(weight_name)) 83 | if backbone_freeze_at >= 0: 84 | logger.info("Freeze backbone at stage: {}".format(backbone_freeze_at)) 85 | if is_distributed: 86 | model.module.backbone.body._freeze_backbone(backbone_freeze_at) 87 | else: 88 | model.backbone.body._freeze_backbone(backbone_freeze_at) 89 | 90 | 91 | def delete_tsv_files(tsvs): 92 | for t in tsvs: 93 | if op.isfile(t): 94 | try_delete(t) 95 | line = op.splitext(t)[0] + '.lineidx' 96 | if op.isfile(line): 97 | try_delete(line) 98 | 99 | 100 | def concat_files(ins, out): 101 | mkdir(op.dirname(out)) 102 | out_tmp = out + '.tmp' 103 | with open(out_tmp, 'wb') as fp_out: 104 | for i, f in enumerate(ins): 105 | logging.info('concating {}/{} - {}'.format(i, len(ins), f)) 106 | with open(f, 'rb') as fp_in: 107 | shutil.copyfileobj(fp_in, fp_out, 1024*1024*10) 108 | os.rename(out_tmp, out) 109 | 110 | 111 | def concat_tsv_files(tsvs, out_tsv): 112 | concat_files(tsvs, out_tsv) 113 | sizes = [os.stat(t).st_size for t in tsvs] 114 | sizes = np.cumsum(sizes) 115 | all_idx = [] 116 | for i, t in enumerate(tsvs): 117 | for idx in load_list_file(op.splitext(t)[0] + '.lineidx'): 118 | if i == 0: 119 | all_idx.append(idx) 120 | else: 121 | all_idx.append(str(int(idx) + sizes[i - 1])) 122 | with open(op.splitext(out_tsv)[0] + '.lineidx', 'w') as f: 123 | f.write('\n'.join(all_idx)) 124 | 125 | 126 | def load_list_file(fname): 127 | with open(fname, 'r') as fp: 128 | lines = fp.readlines() 129 | result = [line.strip() for line in lines] 130 | if len(result) > 0 and result[-1] == '': 131 | result = result[:-1] 132 | return result 133 | 134 | 135 | def try_once(func): 136 | def func_wrapper(*args, **kwargs): 137 | try: 138 | return func(*args, **kwargs) 139 | except Exception as e: 140 | logging.info('ignore error \n{}'.format(str(e))) 141 | return func_wrapper 142 | 143 | 144 | @try_once 145 | def try_delete(f): 146 | os.remove(f) 147 | 148 | 149 | def set_seed(seed, n_gpu): 150 | random.seed(seed) 151 | np.random.seed(seed) 152 | torch.manual_seed(seed) 153 | if n_gpu > 0: 154 | torch.cuda.manual_seed_all(seed) 155 | 156 | 157 | def print_and_run_cmd(cmd): 158 | print(cmd) 159 | os.system(cmd) 160 | 161 | 162 | def write_to_yaml_file(context, file_name): 163 | with open(file_name, 'w') as fp: 164 | yaml.dump(context, fp, encoding='utf-8') 165 | 166 | 167 | def load_from_yaml_file(yaml_file): 168 | with open(yaml_file, 'r') as fp: 169 | return yaml.load(fp, Loader=yaml.CLoader) 170 | 171 | 172 | -------------------------------------------------------------------------------- /src/utils/tsv_file.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Microsoft Corporation. 3 | Licensed under the MIT license. 
-------------------------------------------------------------------------------- /src/utils/tsv_file.py: --------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 | 
5 | Definition of the TSV file classes
6 | """
7 | 
8 | 
9 | import logging
10 | import os
11 | import os.path as op
12 | 
13 | 
14 | def generate_lineidx(filein, idxout):
15 |     idxout_tmp = idxout + '.tmp'
16 |     with open(filein, 'r') as tsvin, open(idxout_tmp, 'w') as tsvout:
17 |         fsize = os.fstat(tsvin.fileno()).st_size
18 |         fpos = 0
19 |         while fpos != fsize:
20 |             tsvout.write(str(fpos) + "\n")
21 |             tsvin.readline()
22 |             fpos = tsvin.tell()
23 |     os.rename(idxout_tmp, idxout)
24 | 
25 | # module-level alias: the TSVFile constructor argument below shadows the
26 | # function name, so the call inside __init__ must go through this alias
27 | _generate_lineidx = generate_lineidx
28 | 
29 | def read_to_character(fp, c):
30 |     result = []
31 |     while True:
32 |         s = fp.read(32)
33 |         assert s != ''
34 |         if c in s:
35 |             result.append(s[: s.index(c)])
36 |             break
37 |         else:
38 |             result.append(s)
39 |     return ''.join(result)
40 | 
41 | 
42 | class TSVFile(object):
43 |     def __init__(self, tsv_file, generate_lineidx=False):
44 |         self.tsv_file = tsv_file
45 |         self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
46 |         self._fp = None
47 |         self._lineidx = None
48 |         # remember the pid of the process that opened the file;
49 |         # if it differs from the current pid, we re-open the file
50 |         self.pid = None
51 |         # generate lineidx if it does not exist
52 |         if not op.isfile(self.lineidx) and generate_lineidx:
53 |             _generate_lineidx(self.tsv_file, self.lineidx)
54 | 
55 |     def __del__(self):
56 |         if self._fp:
57 |             self._fp.close()
58 | 
59 |     def __str__(self):
60 |         return "TSVFile(tsv_file='{}')".format(self.tsv_file)
61 | 
62 |     def __repr__(self):
63 |         return str(self)
64 | 
65 |     def num_rows(self):
66 |         self._ensure_lineidx_loaded()
67 |         return len(self._lineidx)
68 | 
69 |     def seek(self, idx):
70 |         self._ensure_tsv_opened()
71 |         self._ensure_lineidx_loaded()
72 |         try:
73 |             pos = self._lineidx[idx]
74 |         except Exception:
75 |             logging.info('{}-{}'.format(self.tsv_file, idx))
76 |             raise
77 |         self._fp.seek(pos)
78 |         return [s.strip() for s in self._fp.readline().split('\t')]
79 | 
80 |     def seek_first_column(self, idx):
81 |         self._ensure_tsv_opened()
82 |         self._ensure_lineidx_loaded()
83 |         pos = self._lineidx[idx]
84 |         self._fp.seek(pos)
85 |         return read_to_character(self._fp, '\t')
86 | 
87 |     def get_key(self, idx):
88 |         return self.seek_first_column(idx)
89 | 
90 |     def __getitem__(self, index):
91 |         return self.seek(index)
92 | 
93 |     def __len__(self):
94 |         return self.num_rows()
95 | 
96 |     def _ensure_lineidx_loaded(self):
97 |         if self._lineidx is None:
98 |             logging.info('loading lineidx: {}'.format(self.lineidx))
99 |             with open(self.lineidx, 'r') as fp:
100 |                 self._lineidx = [int(i.strip()) for i in fp.readlines()]
101 | 
102 |     def _ensure_tsv_opened(self):
103 |         if self._fp is None:
104 |             self._fp = open(self.tsv_file, 'r')
105 |             self.pid = os.getpid()
106 | 
107 |         if self.pid != os.getpid():
108 |             logging.info('re-open {} because the process id changed'.format(self.tsv_file))
109 |             self._fp = open(self.tsv_file, 'r')
110 |             self.pid = os.getpid()
111 | 
112 | 
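# A minimal usage sketch for TSVFile (not part of the original module),
# assuming a tab-separated file "train.tsv"; the path and the column layout
# (key in column 0, payload in the remaining columns) are illustrative.
def _demo_tsvfile():
    tsv = TSVFile('train.tsv', generate_lineidx=True)  # builds train.lineidx if missing
    print(len(tsv))                                    # number of rows
    row = tsv[0]                                       # list of stripped column strings
    key = tsv.get_key(0)                               # column 0 only, no full parse
    return row, key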
113 | class CompositeTSVFile(object):
114 |     def __init__(self, file_list, seq_file, root='.'):
115 |         if isinstance(file_list, str):
116 |             self.file_list = load_list_file(file_list)
117 |         else:
118 |             assert isinstance(file_list, list)
119 |             self.file_list = file_list
120 | 
121 |         self.seq_file = seq_file
122 |         self.root = root
123 |         self.initialized = False
124 |         self.initialize()
125 | 
126 |     def get_key(self, index):
127 |         idx_source, idx_row = self.seq[index]
128 |         k = self.tsvs[idx_source].get_key(idx_row)
129 |         return '_'.join([self.file_list[idx_source], k])
130 | 
131 |     def num_rows(self):
132 |         return len(self.seq)
133 | 
134 |     def __getitem__(self, index):
135 |         idx_source, idx_row = self.seq[index]
136 |         return self.tsvs[idx_source].seek(idx_row)
137 | 
138 |     def __len__(self):
139 |         return len(self.seq)
140 | 
141 |     def initialize(self):
142 |         '''
143 |         this function has to be called in the constructor if cache_policy is
144 |         enabled; thus, we always call it in the constructor to keep it simple.
145 |         '''
146 |         if self.initialized:
147 |             return
148 |         self.seq = []
149 |         with open(self.seq_file, 'r') as fp:
150 |             for line in fp:
151 |                 parts = line.strip().split('\t')
152 |                 self.seq.append([int(parts[0]), int(parts[1])])
153 |         self.tsvs = [TSVFile(op.join(self.root, f)) for f in self.file_list]
154 |         self.initialized = True
155 | 
156 | 
157 | def load_list_file(fname):
158 |     with open(fname, 'r') as fp:
159 |         lines = fp.readlines()
160 |     result = [line.strip() for line in lines]
161 |     if len(result) > 0 and result[-1] == '':
162 |         result = result[:-1]
163 |     return result
164 | 
-------------------------------------------------------------------------------- /src/utils/tsv_file_ops.py: --------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 | 
5 | Basic operations for TSV files
6 | """
7 | 
8 | import errno
9 | import os
10 | import os.path as op
11 | import json
12 | import numpy as np
13 | import base64
14 | import cv2
15 | from tqdm import tqdm
16 | import yaml
17 | from src.utils.miscellaneous import mkdir
18 | from src.utils.tsv_file import TSVFile
19 | 
20 | 
21 | def img_from_base64(imagestring):
22 |     try:
23 |         jpgbytestring = base64.b64decode(imagestring)
24 |         nparr = np.frombuffer(jpgbytestring, np.uint8)
25 |         r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
26 |         return r
27 |     except ValueError:
28 |         return None
29 | 
30 | def load_linelist_file(linelist_file):
31 |     if linelist_file is not None:
32 |         line_list = []
33 |         with open(linelist_file, 'r') as fp:
34 |             for i in fp:
35 |                 line_list.append(int(i.strip()))
36 |         return line_list
37 | 
38 | def tsv_writer(values, tsv_file, sep='\t'):
39 |     mkdir(op.dirname(tsv_file))
40 |     lineidx_file = op.splitext(tsv_file)[0] + '.lineidx'
41 |     idx = 0
42 |     tsv_file_tmp = tsv_file + '.tmp'
43 |     lineidx_file_tmp = lineidx_file + '.tmp'
44 |     with open(tsv_file_tmp, 'w') as fp, open(lineidx_file_tmp, 'w') as fpidx:
45 |         assert values is not None
46 |         for value in values:
47 |             assert value is not None
48 |             value = [v if not isinstance(v, bytes) else v.decode('utf-8') for v in value]
49 |             v = '{0}\n'.format(sep.join(map(str, value)))
50 |             fp.write(v)
51 |             fpidx.write(str(idx) + '\n')
52 |             idx = idx + len(v)
53 |     os.rename(tsv_file_tmp, tsv_file)
54 |     os.rename(lineidx_file_tmp, lineidx_file)
55 | 
56 | def tsv_reader(tsv_file, sep='\t'):
57 |     with open(tsv_file, 'r') as fp:
58 |         for i, line in enumerate(fp):
59 |             yield [x.strip() for x in line.split(sep)]
60 | 
61 | def config_save_file(tsv_file, save_file=None, append_str='.new.tsv'):
62 |     if save_file is not None:
63 |         return save_file
64 |     return op.splitext(tsv_file)[0] + append_str
65 | 
66 | def get_line_list(linelist_file=None, num_rows=None):
67 |     if linelist_file is not None:
68 |         return load_linelist_file(linelist_file)
69 | 
70 |     if num_rows is not None:
71 |         return [i for i in range(num_rows)]
72 | 
73 | def generate_hw_file(img_file, save_file=None):
74 |     rows = tsv_reader(img_file)
75 |     def gen_rows():
76 |         for i, row in tqdm(enumerate(rows)):
77 |             row1 = [row[0]]
78 |             img = img_from_base64(row[-1])
79 |             height = img.shape[0]
80 |             width = img.shape[1]
81 |             row1.append(json.dumps([{"height": height, "width": width}]))
82 |             yield row1
83 | 
84 |     save_file = config_save_file(img_file, save_file, '.hw.tsv')
85 |     tsv_writer(gen_rows(), save_file)
86 | 
87 | def generate_linelist_file(label_file, save_file=None, ignore_attrs=()):
88 |     # generate a list of images that have labels;
89 |     # images with only ignore labels are not selected.
90 |     line_list = []
91 |     rows = tsv_reader(label_file)
92 |     for i, row in tqdm(enumerate(rows)):
93 |         labels = json.loads(row[1])
94 |         if labels:
95 |             if ignore_attrs and all([any([lab[attr] for attr in ignore_attrs if attr in lab]) \
96 |                     for lab in labels]):
97 |                 continue
98 |             line_list.append([i])
99 | 
100 |     save_file = config_save_file(label_file, save_file, '.linelist.tsv')
101 |     tsv_writer(line_list, save_file)
102 | 
103 | def load_from_yaml_file(yaml_file):
104 |     with open(yaml_file, 'r') as fp:
105 |         return yaml.load(fp, Loader=yaml.CLoader)
106 | 
107 | def find_file_path_in_yaml(fname, root):
108 |     if fname is not None:
109 |         if op.isfile(fname):
110 |             return fname
111 |         elif op.isfile(op.join(root, fname)):
112 |             return op.join(root, fname)
113 |         else:
114 |             raise FileNotFoundError(
115 |                 errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname)
116 |             )
117 | 
--------------------------------------------------------------------------------
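# A minimal end-to-end sketch of the TSV helpers in src/utils/tsv_file_ops.py
# (not part of the original module), assuming write access to the working
# directory; the file name and rows are illustrative.
def _demo_tsv_roundtrip():
    rows = [['img_0', 'payload_a'], ['img_1', 'payload_b']]
    tsv_writer(rows, 'demo.tsv')                     # also writes demo.lineidx
    back = [row for row in tsv_reader('demo.tsv')]   # stripped column strings
    assert back == rows
    return back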