├── .gitignore
├── .gitmodules
├── LICENSE
├── NOTICE.md
├── README.md
├── azure-pipelines.yml
├── docs
│   ├── CODE_OF_CONDUCT.md
│   ├── DEMO.md
│   ├── DOWNLOAD.md
│   ├── EXP.md
│   ├── Fig1.gif
│   ├── Fig2.gif
│   ├── Fig3.gif
│   ├── Fig4.gif
│   ├── INSTALL.md
│   ├── SECURITY.md
│   ├── SUPPORT.md
│   └── graphormer_overview.png
├── requirements.txt
├── samples
│   ├── hand
│   │   ├── freihand_sample1.jpg
│   │   ├── freihand_sample1_graphormer_pred.jpg
│   │   ├── freihand_sample2.jpg
│   │   ├── freihand_sample2_graphormer_pred.jpg
│   │   ├── freihand_sample3.jpg
│   │   ├── freihand_sample3_graphormer_pred.jpg
│   │   ├── internet_fig1.jpg
│   │   ├── internet_fig1_graphormer_pred.jpg
│   │   ├── internet_fig2.jpg
│   │   ├── internet_fig2_graphormer_pred.jpg
│   │   ├── internet_fig3.jpg
│   │   ├── internet_fig3_graphormer_pred.jpg
│   │   ├── internet_fig4.jpg
│   │   └── internet_fig4_graphormer_pred.jpg
│   └── human-body
│       ├── 3dpw_test1.jpg
│       ├── 3dpw_test1_graphormer_pred.jpg
│       ├── 3dpw_test2.jpg
│       ├── 3dpw_test2_graphormer_pred.jpg
│       ├── 3dpw_test3.jpg
│       ├── 3dpw_test3_graphormer_pred.jpg
│       ├── 3dpw_test4.jpg
│       ├── 3dpw_test4_graphormer_pred.jpg
│       ├── 3dpw_test5.jpg
│       ├── 3dpw_test5_graphormer_pred.jpg
│       ├── 3dpw_test6.jpg
│       ├── 3dpw_test6_graphormer_pred.jpg
│       ├── 3dpw_test7.jpg
│       └── 3dpw_test7_graphormer_pred.jpg
├── scripts
│   ├── download_models.sh
│   └── download_preds.sh
├── setup.py
└── src
    ├── __init__.py
    ├── datasets
    │   ├── __init__.py
    │   ├── build.py
    │   ├── hand_mesh_tsv.py
    │   └── human_mesh_tsv.py
    ├── modeling
    │   ├── __init__.py
    │   ├── _gcnn.py
    │   ├── _mano.py
    │   ├── _smpl.py
    │   ├── bert
    │   │   ├── __init__.py
    │   │   ├── bert-base-uncased
    │   │   │   └── config.json
    │   │   ├── e2e_body_network.py
    │   │   ├── e2e_hand_network.py
    │   │   ├── file_utils.py
    │   │   ├── modeling_bert.py
    │   │   ├── modeling_graphormer.py
    │   │   └── modeling_utils.py
    │   ├── data
    │   │   ├── J_regressor_extra.npy
    │   │   ├── J_regressor_h36m_correct.npy
    │   │   ├── README.md
    │   │   ├── config.py
    │   │   ├── mano_195_adjmat_indices.pt
    │   │   ├── mano_195_adjmat_size.pt
    │   │   ├── mano_195_adjmat_values.pt
    │   │   ├── mano_downsampling.npz
    │   │   ├── mesh_downsampling.npz
    │   │   ├── smpl_431_adjmat_indices.pt
    │   │   ├── smpl_431_adjmat_size.pt
    │   │   ├── smpl_431_adjmat_values.pt
    │   │   └── smpl_431_faces.npy
    │   └── hrnet
    │       ├── config
    │       │   ├── __init__.py
    │       │   ├── default.py
    │       │   └── models.py
    │       ├── hrnet_cls_net.py
    │       └── hrnet_cls_net_gridfeat.py
    ├── tools
    │   ├── run_gphmer_bodymesh.py
    │   ├── run_gphmer_bodymesh_inference.py
    │   ├── run_gphmer_handmesh.py
    │   ├── run_gphmer_handmesh_inference.py
    │   └── run_hand_multiscale.py
    └── utils
        ├── __init__.py
        ├── comm.py
        ├── dataset_utils.py
        ├── geometric_layers.py
        ├── image_ops.py
        ├── logger.py
        ├── metric_logger.py
        ├── metric_pampjpe.py
        ├── miscellaneous.py
        ├── renderer.py
        ├── tsv_file.py
        └── tsv_file_ops.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # compilation and distribution
2 | __pycache__
3 | _ext
4 | *.pyc
5 | *.so
6 | build/
7 | dist/
8 | *.egg-info/
9 |
10 | # pytorch/python formats
11 | *.pth
12 | *.pkl
13 |
14 | # ipython/jupyter notebooks
15 | *.ipynb
16 | **/.ipynb_checkpoints/
17 |
18 | # Editor temporaries
19 | *.swn
20 | *.swo
21 | *.swp
22 | *~
23 |
24 | # Pycharm editor settings
25 | .idea
26 |
27 | # VS code
28 | .vscode
29 |
30 | # MacOS
31 | .DS_Store
32 |
33 | # project dirs
34 | /data
35 | /datasets
36 | models
37 | output
38 | exps
39 | predictions
40 |
41 | # hidden folder for aml configs
42 | .azureml
43 | .git
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "manopth"]
2 | path = manopth
3 | url = https://github.com/hassony2/manopth.git
4 | [submodule "transformers"]
5 | path = transformers
6 | url = https://github.com/huggingface/transformers.git
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MeshGraphormer ✨✨
2 |
3 |
4 | This is our research code of [Mesh Graphormer](https://arxiv.org/abs/2104.00272).
5 |
6 | Mesh Graphormer is a new transformer-based method for human pose and mesh reconstruction from an input image. In this work, we study how to combine graph convolutions and self-attentions in a transformer to better model both local and global interactions.
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | ## Installation
15 | Check [INSTALL.md](docs/INSTALL.md) for installation instructions.
16 |
17 |
18 | ## Model Zoo and Download
19 | Please download our pre-trained models and the other relevant files required to run our code.
20 |
21 | Check [DOWNLOAD.md](docs/DOWNLOAD.md) for details.
22 |
23 | ## Quick demo
24 | We provide demo codes to run end-to-end inference on the test images.
25 |
26 | Check [DEMO.md](docs/DEMO.md) for details.
27 |
28 | ## Experiments
29 | We provide Python code for training and evaluation.
30 |
31 | Check [EXP.md](docs/EXP.md) for details.
32 |
33 |
34 | ## License
35 |
36 | Our research code is released under the MIT license. See [LICENSE](LICENSE) for details.
37 |
38 | We use submodules from third parties, such as [huggingface/transformers](https://github.com/huggingface/transformers) and [hassony2/manopth](https://github.com/hassony2/manopth). Please see [NOTICE](NOTICE.md) for details.
39 |
40 | Our models depend on the SMPL and MANO models. Please note that any use of the SMPL and MANO models is subject to the **Software Copyright License for non-commercial scientific research purposes**. Please see [SMPL-Model License](https://smpl.is.tue.mpg.de/modellicense) and [MANO License](https://mano.is.tue.mpg.de/license) for details.
41 |
42 |
43 | ## Contributing
44 |
45 | We welcome contributions and suggestions. Please check [CONTRIBUTE](docs/CONTRIBUTE.md) and [CODE_OF_CONDUCT](docs/CODE_OF_CONDUCT.md) for details.
46 |
47 |
48 | ## Citations
49 | If you find our work useful in your research, please consider citing:
50 |
51 | ```bibtex
52 | @inproceedings{lin2021-mesh-graphormer,
53 | author = {Lin, Kevin and Wang, Lijuan and Liu, Zicheng},
54 | title = {Mesh Graphormer},
55 | booktitle = {ICCV},
56 | year = {2021},
57 | }
58 | ```
59 |
60 |
61 | ## Acknowledgments
62 |
63 | Our implementation and experiments are built on top of open-source GitHub repositories. We thank all the authors who made their code public, which tremendously accelerates our project progress. If you find these works helpful, please consider citing them as well.
64 |
65 | [huggingface/transformers](https://github.com/huggingface/transformers)
66 |
67 | [HRNet/HRNet-Image-Classification](https://github.com/HRNet/HRNet-Image-Classification)
68 |
69 | [nkolot/GraphCMR](https://github.com/nkolot/GraphCMR)
70 |
71 | [akanazawa/hmr](https://github.com/akanazawa/hmr)
72 |
73 | [MandyMo/pytorch_HMR](https://github.com/MandyMo/pytorch_HMR)
74 |
75 | [hassony2/manopth](https://github.com/hassony2/manopth)
76 |
77 | [hongsukchoi/Pose2Mesh_RELEASE](https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
78 |
79 | [mks0601/I2L-MeshNet_RELEASE](https://github.com/mks0601/I2L-MeshNet_RELEASE)
80 |
81 | [open-mmlab/mmdetection](https://github.com/open-mmlab/mmdetection)
82 |
--------------------------------------------------------------------------------
/azure-pipelines.yml:
--------------------------------------------------------------------------------
1 | # Starter pipeline
2 | # Start with a minimal pipeline that you can customize to build and deploy your code.
3 | # Add steps that build, run tests, deploy, and more:
4 | # https://aka.ms/yaml
5 |
6 | trigger:
7 | - main
8 |
9 | pool:
10 | vmImage: ubuntu-latest
11 | strategy:
12 | matrix:
13 | Python36:
14 | python.version: '3.6'
15 |
16 | steps:
17 | - task: UsePythonVersion@0
18 | inputs:
19 | versionSpec: '$(python.version)'
20 | displayName: 'Use Python $(python.version)'
21 |
22 | - script: |
23 | python -m pip install --upgrade pip
24 | pip install -r requirements.txt
25 | displayName: 'Install dependencies'
26 |
27 |
--------------------------------------------------------------------------------
/docs/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/docs/DEMO.md:
--------------------------------------------------------------------------------
1 | # Quick Demo
2 | We provide demo code for end-to-end inference here.
3 | 
4 | Our inference code iterates over all images in a given folder and generates the results.
5 |
6 | ## Important notes
7 |
8 | - **This demo doesn't perform human/hand detection**. Our model requires a centered target in the image.
9 | - As **Mesh Graphormer is a data-driven approach**, it may not perform well if the test samples are very different from the training data. We observe that our model does not work well if the target is out of view. Some examples can be found in our supplementary material (Sec. I Limitations).
10 |
11 | ## Human Body Reconstruction
12 |
13 | This demo runs 3D human mesh reconstruction from a single image.
14 |
15 | Our code requires input images that are already **cropped with the person centered** in the image. The input images should be of size `224x224`. To run the demo, please place your test images under `./samples/human-body`, and then run the following script.
16 |
17 |
18 | ```bash
19 | python ./src/tools/run_gphmer_bodymesh_inference.py \
20 |        --resume_checkpoint ./models/graphormer_release/graphormer_3dpw_state_dict.bin \
21 |        --image_file_or_path ./samples/human-body
22 | ```
23 | After running, it will generate the results in the folder `./samples/human-body`.
24 |
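If your own test images are not yet square `224x224` crops, the sketch below shows one possible way to prepare them before placing them under `./samples/human-body` (or `./samples/hand` for the hand demo). This is a minimal example and not part of the released demo code: it assumes Pillow is installed, the `center_crop_224` helper and the `./my_raw_images` folder are made up for illustration, and a plain center crop only works when the subject is already roughly centered.

```python
# Hypothetical preprocessing helper (not part of this repo): center-crop each
# image to a square and resize it to 224x224 to match the demo's expected input.
from pathlib import Path
from PIL import Image

def center_crop_224(src_dir, dst_dir):
    dst = Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    for img_path in sorted(Path(src_dir).glob("*.jpg")):
        img = Image.open(img_path).convert("RGB")
        w, h = img.size
        side = min(w, h)
        left, top = (w - side) // 2, (h - side) // 2
        square = img.crop((left, top, left + side, top + side))
        square.resize((224, 224)).save(dst / img_path.name)

if __name__ == "__main__":
    center_crop_224("./my_raw_images", "./samples/human-body")
```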
25 |
26 |
27 | ## Hand Reconstruction
28 |
29 | This demo runs 3D hand reconstruction from a single image.
30 |
31 | You will need to provide images that are already **cropped with the right hand centered** in the image. The input images should be of size `224x224`. Please place the images under `./samples/hand`, and run the following script.
32 |
33 | ```bash
34 | python ./src/tools/run_gphmer_handmesh_inference.py \
35 |        --resume_checkpoint ./models/graphormer_release/graphormer_hand_state_dict.bin \
36 |        --image_file_or_path ./samples/hand
37 | ```
38 | After running, it will output the results in the folder `./samples/hand`.
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
--------------------------------------------------------------------------------
/docs/DOWNLOAD.md:
--------------------------------------------------------------------------------
1 | # Download
2 |
3 | ### Getting Started
4 |
5 | 1. Create folders that store pretrained models, datasets, and predictions.
6 | ```bash
7 | export REPO_DIR=$PWD
8 | mkdir -p $REPO_DIR/models # pre-trained models
9 | mkdir -p $REPO_DIR/datasets # datasets
10 | mkdir -p $REPO_DIR/predictions # prediction outputs
11 | ```
12 |
13 | 2. Download pretrained models.
14 |
15 | Our pre-trained models can be downloaded with the following command.
16 | ```bash
17 | cd $REPO_DIR
18 | bash scripts/download_models.sh
19 | ```
20 | The script will download three models that are trained for mesh reconstruction on Human3.6M, 3DPW, and FreiHAND, respectively. For your convenience, it will also download the HRNet pre-trained weights, which are used in training.
21 |
22 | The resulting data structure should follow the hierarchy below.
23 | ```
24 | ${REPO_DIR}
25 | |-- models
26 | | |-- graphormer_release
27 | | | |-- graphormer_h36m_state_dict.bin
28 | | | |-- graphormer_3dpw_state_dict.bin
29 | | | |-- graphormer_hand_state_dict.bin
30 | | |-- hrnet
31 | | | |-- hrnetv2_w40_imagenet_pretrained.pth
32 | | | |-- hrnetv2_w64_imagenet_pretrained.pth
33 | | | |-- cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml
34 | | | |-- cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml
35 | |-- src
36 | |-- datasets
37 | |-- predictions
38 | |-- README.md
39 | |-- ...
40 | |-- ...
41 | ```
42 |
43 | 3. Download SMPL and MANO models from their official websites
44 |
45 | To run our code smoothly, please visit the following websites to download SMPL and MANO models.
46 |
47 | - Download `basicModel_neutral_lbs_10_207_0_v1.0.0.pkl` from [SMPLify](http://smplify.is.tue.mpg.de/), and place it at `${REPO_DIR}/src/modeling/data`.
48 | - Download `MANO_RIGHT.pkl` from [MANO](https://mano.is.tue.mpg.de/), and place it at `${REPO_DIR}/src/modeling/data`.
49 |
50 | Please put the downloaded files under the `${REPO_DIR}/src/modeling/data` directory. The data structure should follow the hierarchy below.
51 | ```
52 | ${REPO_DIR}
53 | |-- src
54 | | |-- modeling
55 | | | |-- data
56 | | | | |-- basicModel_neutral_lbs_10_207_0_v1.0.0.pkl
57 | | | | |-- MANO_RIGHT.pkl
58 | |-- models
59 | |-- datasets
60 | |-- predictions
61 | |-- README.md
62 | |-- ...
63 | |-- ...
64 | ```
65 | Please check [/src/modeling/data/README.md](../src/modeling/data/README.md) for further details.
66 |
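Before moving on, it can help to verify that everything is in place. The snippet below is an optional sanity check of our own, not part of the repository scripts; the file names are taken from the directory layouts above, and it assumes you run it from the repository root.

```python
# Optional sanity check (not part of the repo scripts): verify that the SMPL/MANO
# model files and the released checkpoints sit where the code expects them.
import os

expected_files = [
    "src/modeling/data/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl",
    "src/modeling/data/MANO_RIGHT.pkl",
    "models/graphormer_release/graphormer_h36m_state_dict.bin",
    "models/graphormer_release/graphormer_3dpw_state_dict.bin",
    "models/graphormer_release/graphormer_hand_state_dict.bin",
]
for rel_path in expected_files:
    status = "OK     " if os.path.isfile(rel_path) else "MISSING"
    print(f"{status} {rel_path}")
```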
67 | 4. Download prediction files that were evaluated on the FreiHAND Leaderboard.
68 |
69 | The prediction files can be downloaded with the following command.
70 | ```bash
71 | cd $REPO_DIR
72 | bash scripts/download_preds.sh
73 | ```
74 | You can submit the prediction files to the FreiHAND Leaderboard to reproduce our results.
75 |
76 | 5. Download datasets and pseudo labels for training.
77 |
78 | We use the same data as our previous project [METRO](https://github.com/microsoft/MeshTransformer).
79 | 
80 | Please visit our previous project page to download the datasets and annotations for the experiments; see [this link](https://github.com/microsoft/MeshTransformer/blob/main/docs/DOWNLOAD.md) for details.
81 |
--------------------------------------------------------------------------------
/docs/EXP.md:
--------------------------------------------------------------------------------
1 | # Training and evaluation
2 |
3 |
4 | ## Table of contents
5 | * [3D hand experiment](#3D-hand-reconstruction-from-a-single-image)
6 | * [Training](#Training)
7 | * [Testing](#Testing)
8 | * [3D human body experiment](#Human-mesh-reconstruction-from-a-single-image)
9 | * [Training with mixed 2D+3D datasets](#Training-with-mixed-datasets)
10 | * [Evaluation on Human3.6M](#Evaluation-on-Human3.6M)
11 | * [Training with 3DPW](#Training-with-3DPW-dataset)
12 | * [Evaluation on 3DPW](#Evaluation-on-3DPW)
13 |
14 |
15 | ## 3D hand reconstruction from a single image
16 |
17 | ### Training
18 |
19 | We use the following script to train on the FreiHAND dataset.
20 |
21 | ```bash
22 | python -m torch.distributed.launch --nproc_per_node=8 \
23 | src/tools/run_gphmer_handmesh.py \
24 | --train_yaml freihand/train.yaml \
25 | --val_yaml freihand/test.yaml \
26 | --arch hrnet-w64 \
27 | --num_workers 4 \
28 | --per_gpu_train_batch_size 32 \
29 | --per_gpu_eval_batch_size 32 \
30 | --num_hidden_layers 4 \
31 | --num_attention_heads 4 \
32 | --lr 1e-4 \
33 | --num_train_epochs 200 \
34 | --input_feat_dim 2051,512,128 \
35 | --hidden_feat_dim 1024,256,64
36 | ```
37 |
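The flags `--input_feat_dim 2051,512,128` and `--hidden_feat_dim 1024,256,64` configure the three cascaded Graphormer encoder blocks. The sketch below is a rough illustration only (not the actual argument parsing in the training script): the comma-separated values pair up block by block, and, judging from the config-update logs shown later in this document, each block's intermediate size is set to twice its hidden size.

```python
# Illustrative sketch only: how the comma-separated flags map onto the three
# encoder blocks. The "intermediate = 2 * hidden" relation is inferred from the
# config-update log messages shown in the evaluation sections below.
input_feat_dim = [int(d) for d in "2051,512,128".split(",")]
hidden_feat_dim = [int(d) for d in "1024,256,64".split(",")]

for block_id, (in_dim, hid_dim) in enumerate(zip(input_feat_dim, hidden_feat_dim), 1):
    print(f"encoder block {block_id}: input_dim={in_dim}, "
          f"hidden_size={hid_dim}, intermediate_size={2 * hid_dim}")
```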
38 |
39 | Example training log can be found here [2021-03-06-graphormer_freihand_log.txt](https://datarelease.blob.core.windows.net/metro/models/2021-03-06-graphormer_freihand_log.txt)
40 |
41 | ### Testing
42 |
43 | After training, we use the final checkpoint (trained for 200 epochs) for testing.
44 | 
45 | We use the following script to generate predictions. It will generate a prediction file called `ckpt200-sc10_rot0-pred.zip`. After that, please submit the prediction file to the [FreiHAND Leaderboard](https://competitions.codalab.org/competitions/21238) to obtain the evaluation scores.
46 | 
47 | 
48 | In the following script, we perform prediction with test-time augmentation for the FreiHAND experiments. It will generate a prediction file `ckpt200-multisc-pred.zip`.
49 |
50 | ```bash
51 | python src/tools/run_hand_multiscale.py \
52 | --multiscale_inference \
53 |        --model_path models/graphormer_release/graphormer_hand_state_dict.bin
54 | ```
55 |
56 | To help you reproduce our results, we have released our prediction file `ckpt200-multisc-pred.zip` (see `docs/DOWNLOAD.md`). You can submit it to the Leaderboard, and it should produce the following results.
57 |
58 | ```bash
59 | Evaluation 3D KP results:
60 | auc=0.000, mean_kp3d_avg=71.48 cm
61 | Evaluation 3D KP ALIGNED results:
62 | auc=0.883, mean_kp3d_avg=0.59 cm
63 |
64 | Evaluation 3D MESH results:
65 | auc=0.000, mean_kp3d_avg=71.47 cm
66 | Evaluation 3D MESH ALIGNED results:
67 | auc=0.880, mean_kp3d_avg=0.60 cm
68 |
69 | F-scores
70 | F@5.0mm = 0.000 F_aligned@5.0mm = 0.764
71 | F@15.0mm = 0.000 F_aligned@15.0mm = 0.986
72 | ```
73 |
74 | Note that our method predicts relative coordinates (there is no global alignment). Therefore, only the aligned scores are meaningful in our case.
75 |
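The "aligned" numbers above are computed after rigidly aligning the prediction to the ground truth with a similarity (Procrustes) transform, which removes the unknown global rotation, scale, and translation. The snippet below is a generic NumPy sketch of that alignment in the spirit of `src/utils/metric_pampjpe.py`; it is for illustration only and is not the leaderboard's evaluation code.

```python
# Generic similarity (Procrustes) alignment sketch: rotate, scale, and translate
# the prediction onto the ground truth before measuring the error.
import numpy as np

def procrustes_align(pred, gt):
    """Align pred (N, 3) to gt (N, 3) with an optimal rotation, scale, and translation."""
    mu_pred, mu_gt = pred.mean(axis=0), gt.mean(axis=0)
    X, Y = pred - mu_pred, gt - mu_gt                # center both point sets
    U, S, Vt = np.linalg.svd(X.T @ Y)                # optimal rotation via SVD
    if np.linalg.det(Vt.T @ U.T) < 0:                # avoid an improper reflection
        Vt[-1] *= -1
        S[-1] *= -1
    R = Vt.T @ U.T
    scale = S.sum() / (X ** 2).sum()                 # optimal isotropic scale
    return scale * X @ R.T + mu_gt                   # aligned prediction

# Toy example with 21 hand joints: error after alignment (PA-style metric).
pred = np.random.randn(21, 3)
gt = np.random.randn(21, 3)
aligned = procrustes_align(pred, gt)
print("toy PA error:", np.linalg.norm(aligned - gt, axis=1).mean())
```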
76 |
77 | ## Human mesh reconstruction from a single image
78 |
79 |
80 | ### Training with mixed datasets
81 |
82 | We conduct large-scale training on multiple 2D and 3D datasets, including Human3.6M, COCO, MUCO, UP3D, and MPII. During training, the script evaluates the performance every epoch and saves the best checkpoints.
83 |
84 | ```bash
85 | python -m torch.distributed.launch --nproc_per_node=8 \
86 | src/tools/run_gphmer_bodymesh.py \
87 | --train_yaml Tax-H36m-coco40k-Muco-UP-Mpii/train.yaml \
88 | --val_yaml human3.6m/valid.protocol2.yaml \
89 | --arch hrnet-w64 \
90 | --num_workers 4 \
91 | --per_gpu_train_batch_size 25 \
92 | --per_gpu_eval_batch_size 25 \
93 | --num_hidden_layers 4 \
94 | --num_attention_heads 4 \
95 | --lr 1e-4 \
96 | --num_train_epochs 200 \
97 | --input_feat_dim 2051,512,128 \
98 | --hidden_feat_dim 1024,256,64
99 | ```
100 |
101 | Example training log can be found here [2021-02-25-graphormer_h36m_log](https://datarelease.blob.core.windows.net/metro/models/2021-02-25-graphormer_h36m_log.txt)
102 |
103 | ### Evaluation on Human3.6M
104 |
105 | In the following script, we evaluate our model `graphormer_h36m_state_dict.bin` on the Human3.6M validation set. Check `docs/DOWNLOAD.md` for more details about downloading the model file.
106 |
107 | ```bash
108 | python -m torch.distributed.launch --nproc_per_node=8 \
109 | src/tools/run_gphmer_bodymesh.py \
110 | --val_yaml human3.6m/valid.protocol2.yaml \
111 | --arch hrnet-w64 \
112 | --num_workers 4 \
113 | --per_gpu_eval_batch_size 25 \
114 | --num_hidden_layers 4 \
115 | --num_attention_heads 4 \
116 | --input_feat_dim 2051,512,128 \
117 | --hidden_feat_dim 1024,256,64 \
118 | --run_eval_only \
119 | --resume_checkpoint ./models/graphormer_release/graphormer_h36m_state_dict.bin
120 | ```
121 |
122 | We show example outputs of this script below.
123 | ```bash
124 | 2021-09-19 13:18:14,416 Graphormer INFO: Using 8 GPUs
125 | 2021-09-19 13:18:18,712 Graphormer INFO: Update config parameter num_hidden_layers: 12 -> 4
126 | 2021-09-19 13:18:18,718 Graphormer INFO: Update config parameter hidden_size: 768 -> 1024
127 | 2021-09-19 13:18:18,725 Graphormer INFO: Update config parameter num_attention_heads: 12 -> 4
128 | 2021-09-19 13:18:18,731 Graphormer INFO: Update config parameter intermediate_size: 3072 -> 2048
129 | 2021-09-19 13:18:19,983 Graphormer INFO: Init model from scratch.
130 | 2021-09-19 13:18:19,990 Graphormer INFO: Update config parameter num_hidden_layers: 12 -> 4
131 | 2021-09-19 13:18:19,995 Graphormer INFO: Update config parameter hidden_size: 768 -> 256
132 | 2021-09-19 13:18:20,001 Graphormer INFO: Update config parameter num_attention_heads: 12 -> 4
133 | 2021-09-19 13:18:20,006 Graphormer INFO: Update config parameter intermediate_size: 3072 -> 512
134 | 2021-09-19 13:18:20,210 Graphormer INFO: Init model from scratch.
135 | 2021-09-19 13:18:20,217 Graphormer INFO: Add Graph Conv
136 | 2021-09-19 13:18:20,223 Graphormer INFO: Update config parameter num_hidden_layers: 12 -> 4
137 | 2021-09-19 13:18:20,228 Graphormer INFO: Update config parameter hidden_size: 768 -> 64
138 | 2021-09-19 13:18:20,233 Graphormer INFO: Update config parameter num_attention_heads: 12 -> 4
139 | 2021-09-19 13:18:20,239 Graphormer INFO: Update config parameter intermediate_size: 3072 -> 128
140 | 2021-09-19 13:18:20,295 Graphormer INFO: Init model from scratch.
141 | 2021-09-19 13:18:23,797 Graphormer INFO: => loading hrnet-v2-w64 model
142 | 2021-09-19 13:18:23,805 Graphormer INFO: Graphormer encoders total parameters: 83318598
143 | 2021-09-19 13:18:23,814 Graphormer INFO: Backbone total parameters: 128059944
144 | 2021-09-19 13:18:23,892 Graphormer INFO: Loading state dict from checkpoint _output/graphormer_release/graphormer_h36m_state_dict.bin
145 | 2021-09-19 13:19:26,299 Graphormer INFO: Validation epoch: 0 mPVE: 0.00, mPJPE: 51.20, PAmPJPE: 34.55
146 | ```
147 |
148 |
149 |
150 | ### Training with 3DPW dataset
151 |
152 | We follow prior works that also use 3DPW training data. In order to make the training faster, we **fine-tune** our pre-trained model (`graphormer_h36m_state_dict.bin`) on the 3DPW training set.
153 | 
154 | We use the following script for fine-tuning. During fine-tuning, it evaluates the performance every epoch and saves the best checkpoints.
155 |
156 | ```bash
157 | python -m torch.distributed.launch --nproc_per_node=8 \
158 | src/tools/run_gphmer_bodymesh.py \
159 | --train_yaml 3dpw/train.yaml \
160 | --val_yaml 3dpw/test_has_gender.yaml \
161 | --arch hrnet-w64 \
162 | --num_workers 4 \
163 | --per_gpu_train_batch_size 20 \
164 | --per_gpu_eval_batch_size 20 \
165 | --num_hidden_layers 4 \
166 | --num_attention_heads 4 \
167 | --lr 1e-4 \
168 | --num_train_epochs 5 \
169 | --input_feat_dim 2051,512,128 \
170 | --hidden_feat_dim 1024,256,64 \
171 |        --resume_checkpoint {YOUR_PATH/state_dict.bin}
172 | ```
173 |
174 |
175 | ### Evaluation on 3DPW
176 | In the following script, we evaluate our model `graphormer_3dpw_state_dict.bin` on the 3DPW test set. Check `docs/DOWNLOAD.md` for more details about downloading the model file.
177 |
178 |
179 | ```bash
180 | python -m torch.distributed.launch --nproc_per_node=8 \
181 | src/tools/run_gphmer_bodymesh.py \
182 | --val_yaml 3dpw/test.yaml \
183 | --arch hrnet-w64 \
184 | --num_workers 4 \
185 | --per_gpu_eval_batch_size 25 \
186 | --num_hidden_layers 4 \
187 | --num_attention_heads 4 \
188 | --input_feat_dim 2051,512,128 \
189 | --hidden_feat_dim 1024,256,64 \
190 | --run_eval_only \
191 | --resume_checkpoint ./models/graphormer_release/graphormer_3dpw_state_dict.bin
192 | ```
193 |
194 | After evaluation, it should reproduce the results below.
195 | ```bash
196 | 2021-09-20 00:54:46,178 Graphormer INFO: Using 8 GPUs
197 | 2021-09-20 00:54:50,339 Graphormer INFO: Update config parameter num_hidden_layers: 12 -> 4
198 | 2021-09-20 00:54:50,345 Graphormer INFO: Update config parameter hidden_size: 768 -> 1024
199 | 2021-09-20 00:54:50,351 Graphormer INFO: Update config parameter num_attention_heads: 12 -> 4
200 | 2021-09-20 00:54:50,357 Graphormer INFO: Update config parameter intermediate_size: 3072 -> 2048
201 | 2021-09-20 00:54:51,602 Graphormer INFO: Init model from scratch.
202 | 2021-09-20 00:54:51,613 Graphormer INFO: Update config parameter num_hidden_layers: 12 -> 4
203 | 2021-09-20 00:54:51,625 Graphormer INFO: Update config parameter hidden_size: 768 -> 256
204 | 2021-09-20 00:54:51,646 Graphormer INFO: Update config parameter num_attention_heads: 12 -> 4
205 | 2021-09-20 00:54:51,652 Graphormer INFO: Update config parameter intermediate_size: 3072 -> 512
206 | 2021-09-20 00:54:51,855 Graphormer INFO: Init model from scratch.
207 | 2021-09-20 00:54:51,862 Graphormer INFO: Add Graph Conv
208 | 2021-09-20 00:54:51,868 Graphormer INFO: Update config parameter num_hidden_layers: 12 -> 4
209 | 2021-09-20 00:54:51,873 Graphormer INFO: Update config parameter hidden_size: 768 -> 64
210 | 2021-09-20 00:54:51,880 Graphormer INFO: Update config parameter num_attention_heads: 12 -> 4
211 | 2021-09-20 00:54:51,885 Graphormer INFO: Update config parameter intermediate_size: 3072 -> 128
212 | 2021-09-20 00:54:51,948 Graphormer INFO: Init model from scratch.
213 | 2021-09-20 00:54:55,564 Graphormer INFO: => loading hrnet-v2-w64 model
214 | 2021-09-20 00:54:55,572 Graphormer INFO: Graphormer encoders total parameters: 83318598
215 | 2021-09-20 00:54:55,580 Graphormer INFO: Backbone total parameters: 128059944
216 | 2021-09-20 00:54:55,655 Graphormer INFO: Loading state dict from checkpoint _output/graphormer_release/graphormer_3dpw_state_dict.bin
217 | 2021-09-20 00:56:24,334 Graphormer INFO: Validation epoch: 0 mPVE: 87.57, mPJPE: 73.98, PAmPJPE: 45.35
218 | ```
219 |
220 |
--------------------------------------------------------------------------------
/docs/Fig1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/docs/Fig1.gif
--------------------------------------------------------------------------------
/docs/Fig2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/docs/Fig2.gif
--------------------------------------------------------------------------------
/docs/Fig3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/docs/Fig3.gif
--------------------------------------------------------------------------------
/docs/Fig4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/docs/Fig4.gif
--------------------------------------------------------------------------------
/docs/INSTALL.md:
--------------------------------------------------------------------------------
1 | ## Installation
2 |
3 | Our codebase is developed on Ubuntu 16.04 with NVIDIA GPUs.
4 |
5 | ### Requirements
6 | - Python 3.7
7 | - Pytorch 1.4
8 | - torchvision 0.5.0
9 | - cuda 10.1
10 |
11 | ### Setup with Conda
12 |
13 | We suggest creating a new conda environment and installing all the relevant dependencies.
14 |
15 | ```bash
16 | # Create a new environment
17 | conda create --name gphmr python=3.7
18 | conda activate gphmr
19 |
20 | # Install Pytorch
21 | conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch
22 |
23 | export INSTALL_DIR=$PWD
24 |
25 | # Install apex
26 | cd $INSTALL_DIR
27 | git clone https://github.com/NVIDIA/apex.git
28 | cd apex
29 | python setup.py install --cuda_ext --cpp_ext
30 |
31 | # Install OpenDR
32 | pip install matplotlib
33 | pip install git+https://gitlab.eecs.umich.edu/ngv-python-modules/opendr.git
34 |
35 | # Install MeshGraphormer
36 | cd $INSTALL_DIR
37 | git clone --recursive https://github.com/microsoft/MeshGraphormer.git
38 | cd MeshGraphormer
39 | python setup.py build develop
40 |
41 | # Install requirements
42 | pip install -r requirements.txt
43 |
44 | # Install manopth
45 | cd $INSTALL_DIR
46 | cd MeshGraphormer
47 | pip install ./manopth/.
48 |
49 | unset INSTALL_DIR
50 | ```
51 |
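After the steps above, a quick way to confirm that the environment is usable is to check the PyTorch/torchvision versions, CUDA availability, and the MANO layer import. This is an optional check of our own, not part of the official installation steps.

```python
# Optional environment check (not part of the official install steps).
import torch
import torchvision

print("torch:", torch.__version__)               # expected: 1.4.0
print("torchvision:", torchvision.__version__)   # expected: 0.5.0
print("CUDA available:", torch.cuda.is_available())

# The hand experiments additionally rely on the MANO layer from manopth.
try:
    from manopth.manolayer import ManoLayer
    print("manopth: OK")
except ImportError as err:
    print("manopth import failed:", err)
```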
52 |
53 |
--------------------------------------------------------------------------------
/docs/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
40 |
41 |
--------------------------------------------------------------------------------
/docs/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/docs/graphormer_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/docs/graphormer_overview.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | yacs==0.1.8
2 | cython
3 | opencv-python
4 | tqdm
5 | nltk
6 | numpy
7 | scipy==1.4.1
8 | chumpy
9 | boto3
10 | requests
11 |
--------------------------------------------------------------------------------
/samples/hand/freihand_sample1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/freihand_sample1.jpg
--------------------------------------------------------------------------------
/samples/hand/freihand_sample1_graphormer_pred.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/freihand_sample1_graphormer_pred.jpg
--------------------------------------------------------------------------------
/samples/hand/freihand_sample2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/freihand_sample2.jpg
--------------------------------------------------------------------------------
/samples/hand/freihand_sample2_graphormer_pred.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/freihand_sample2_graphormer_pred.jpg
--------------------------------------------------------------------------------
/samples/hand/freihand_sample3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/freihand_sample3.jpg
--------------------------------------------------------------------------------
/samples/hand/freihand_sample3_graphormer_pred.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/freihand_sample3_graphormer_pred.jpg
--------------------------------------------------------------------------------
/samples/hand/internet_fig1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig1.jpg
--------------------------------------------------------------------------------
/samples/hand/internet_fig1_graphormer_pred.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig1_graphormer_pred.jpg
--------------------------------------------------------------------------------
/samples/hand/internet_fig2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig2.jpg
--------------------------------------------------------------------------------
/samples/hand/internet_fig2_graphormer_pred.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig2_graphormer_pred.jpg
--------------------------------------------------------------------------------
/samples/hand/internet_fig3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig3.jpg
--------------------------------------------------------------------------------
/samples/hand/internet_fig3_graphormer_pred.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig3_graphormer_pred.jpg
--------------------------------------------------------------------------------
/samples/hand/internet_fig4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig4.jpg
--------------------------------------------------------------------------------
/samples/hand/internet_fig4_graphormer_pred.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/hand/internet_fig4_graphormer_pred.jpg
--------------------------------------------------------------------------------
/samples/human-body/3dpw_test1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test1.jpg
--------------------------------------------------------------------------------
/samples/human-body/3dpw_test1_graphormer_pred.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test1_graphormer_pred.jpg
--------------------------------------------------------------------------------
/samples/human-body/3dpw_test2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test2.jpg
--------------------------------------------------------------------------------
/samples/human-body/3dpw_test2_graphormer_pred.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test2_graphormer_pred.jpg
--------------------------------------------------------------------------------
/samples/human-body/3dpw_test3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test3.jpg
--------------------------------------------------------------------------------
/samples/human-body/3dpw_test3_graphormer_pred.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test3_graphormer_pred.jpg
--------------------------------------------------------------------------------
/samples/human-body/3dpw_test4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test4.jpg
--------------------------------------------------------------------------------
/samples/human-body/3dpw_test4_graphormer_pred.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test4_graphormer_pred.jpg
--------------------------------------------------------------------------------
/samples/human-body/3dpw_test5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test5.jpg
--------------------------------------------------------------------------------
/samples/human-body/3dpw_test5_graphormer_pred.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test5_graphormer_pred.jpg
--------------------------------------------------------------------------------
/samples/human-body/3dpw_test6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test6.jpg
--------------------------------------------------------------------------------
/samples/human-body/3dpw_test6_graphormer_pred.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test6_graphormer_pred.jpg
--------------------------------------------------------------------------------
/samples/human-body/3dpw_test7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test7.jpg
--------------------------------------------------------------------------------
/samples/human-body/3dpw_test7_graphormer_pred.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/samples/human-body/3dpw_test7_graphormer_pred.jpg
--------------------------------------------------------------------------------
/scripts/download_models.sh:
--------------------------------------------------------------------------------
1 | # --------------------------------
2 | # Setup
3 | # --------------------------------
4 | export REPO_DIR=$PWD
5 | if [ ! -d $REPO_DIR/models ] ; then
6 | mkdir -p $REPO_DIR/models
7 | fi
8 | BLOB='https://datarelease.blob.core.windows.net/metro'
9 |
10 |
11 | # --------------------------------
12 | # Download our pre-trained models
13 | # --------------------------------
14 | if [ ! -d $REPO_DIR/models/graphormer_release ] ; then
15 | mkdir -p $REPO_DIR/models/graphormer_release
16 | fi
17 | # (1) Mesh Graphormer for human mesh reconstruction (trained on H3.6M + COCO + MuCO + UP3D + MPII)
18 | wget -nc $BLOB/models/graphormer_h36m_state_dict.bin -O $REPO_DIR/models/graphormer_release/graphormer_h36m_state_dict.bin
19 | # (2) Mesh Graphormer for human mesh reconstruction (trained on H3.6M + COCO + MuCO + UP3D + MPII, then fine-tuned on 3DPW)
20 | wget -nc $BLOB/models/graphormer_3dpw_state_dict.bin -O $REPO_DIR/models/graphormer_release/graphormer_3dpw_state_dict.bin
21 | # (3) Mesh Graphormer for hand mesh reconstruction (trained on FreiHAND)
22 | wget -nc $BLOB/models/graphormer_hand_state_dict.bin -O $REPO_DIR/models/graphormer_release/graphormer_hand_state_dict.bin
23 |
24 |
25 | # --------------------------------
26 | # Download the ImageNet pre-trained HRNet models
27 | # The weights are provided by https://github.com/HRNet/HRNet-Image-Classification
28 | # --------------------------------
29 | if [ ! -d $REPO_DIR/models/hrnet ] ; then
30 | mkdir -p $REPO_DIR/models/hrnet
31 | fi
32 | wget -nc $BLOB/models/hrnetv2_w64_imagenet_pretrained.pth -O $REPO_DIR/models/hrnet/hrnetv2_w64_imagenet_pretrained.pth
33 | wget -nc $BLOB/models/hrnetv2_w40_imagenet_pretrained.pth -O $REPO_DIR/models/hrnet/hrnetv2_w40_imagenet_pretrained.pth
34 | wget -nc $BLOB/models/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml -O $REPO_DIR/models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml
35 | wget -nc $BLOB/models/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml -O $REPO_DIR/models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml
36 |
37 |
38 |
--------------------------------------------------------------------------------
/scripts/download_preds.sh:
--------------------------------------------------------------------------------
1 | # --------------------------------
2 | # Setup
3 | # --------------------------------
4 | export REPO_DIR=$PWD
5 | if [ ! -d $REPO_DIR/models ] ; then
6 | mkdir -p $REPO_DIR/models
7 | fi
8 | BLOB='https://datarelease.blob.core.windows.net/metro'
9 |
10 | # --------------------------------
11 | # Download our model predictions that can be submitted to FreiHAND Leaderboard
12 | # --------------------------------
13 | if [ ! -d $REPO_DIR/predictions ] ; then
14 | mkdir -p $REPO_DIR/predictions
15 | fi
16 | # Our model + test-time augmentation. It achieves 5.9 PA-MPVPE on FreiHAND Leaderboard
17 | wget -nc $BLOB/graphormer-release-ckpt200-multisc-pred.zip -O $REPO_DIR/predictions/ckpt200-multisc-pred.zip
18 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 |
3 | from __future__ import print_function
4 | import os
5 | import sys
6 | import re
7 | import os.path as op
8 | from setuptools import find_packages, setup
9 |
10 | # change directory to this module path
11 | try:
12 | this_file = __file__
13 | except NameError:
14 | this_file = sys.argv[0]
15 | this_file = os.path.abspath(this_file)
16 | if op.dirname(this_file):
17 | os.chdir(op.dirname(this_file))
18 | script_dir = os.getcwd()
19 |
20 | def readme(fname):
21 | """Read text out of a file in the same directory as setup.py.
22 | """
23 | return open(op.join(script_dir, fname)).read()
24 |
25 |
26 | def find_version(fname):
27 | version_file = readme(fname)
28 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]",
29 | version_file, re.M)
30 | if version_match:
31 | return version_match.group(1)
32 | raise RuntimeError("Unable to find version string.")
33 |
34 |
35 | setup(
36 | name="graphormer",
37 | version=find_version("src/__init__.py"),
38 | description="graphormer",
39 | long_description=readme('README.md'),
40 | packages=find_packages(),
41 | classifiers=[
42 | 'Intended Audience :: Developers',
43 | "Programming Language :: Python",
44 | 'Topic :: Software Development',
45 | ]
46 | )
47 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.1.0"
2 |
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/datasets/build.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | """
6 |
7 |
8 | import os.path as op
9 | import torch
10 | import logging
11 | import code
12 | from src.utils.comm import get_world_size
13 | from src.datasets.human_mesh_tsv import (MeshTSVDataset, MeshTSVYamlDataset)
14 | from src.datasets.hand_mesh_tsv import (HandMeshTSVDataset, HandMeshTSVYamlDataset)
15 |
16 |
17 | def build_dataset(yaml_file, args, is_train=True, scale_factor=1):
18 | print(yaml_file)
19 | if not op.isfile(yaml_file):
20 | yaml_file = op.join(args.data_dir, yaml_file)
21 | # code.interact(local=locals())
22 | assert op.isfile(yaml_file)
23 | return MeshTSVYamlDataset(yaml_file, is_train, False, scale_factor)
24 |
25 |
26 | class IterationBasedBatchSampler(torch.utils.data.sampler.BatchSampler):
27 | """
28 | Wraps a BatchSampler, resampling from it until
29 | a specified number of iterations have been sampled
30 | """
31 |
32 | def __init__(self, batch_sampler, num_iterations, start_iter=0):
33 | self.batch_sampler = batch_sampler
34 | self.num_iterations = num_iterations
35 | self.start_iter = start_iter
36 |
37 | def __iter__(self):
38 | iteration = self.start_iter
39 | while iteration <= self.num_iterations:
40 | # if the underlying sampler has a set_epoch method, like
41 | # DistributedSampler, used for making each process see
42 | # a different split of the dataset, then set it
43 | if hasattr(self.batch_sampler.sampler, "set_epoch"):
44 | self.batch_sampler.sampler.set_epoch(iteration)
45 | for batch in self.batch_sampler:
46 | iteration += 1
47 | if iteration > self.num_iterations:
48 | break
49 | yield batch
50 |
51 | def __len__(self):
52 | return self.num_iterations
53 |
54 |
55 | def make_batch_data_sampler(sampler, images_per_gpu, num_iters=None, start_iter=0):
56 | batch_sampler = torch.utils.data.sampler.BatchSampler(
57 | sampler, images_per_gpu, drop_last=False
58 | )
59 | if num_iters is not None and num_iters >= 0:
60 | batch_sampler = IterationBasedBatchSampler(
61 | batch_sampler, num_iters, start_iter
62 | )
63 | return batch_sampler
64 |
65 |
66 | def make_data_sampler(dataset, shuffle, distributed):
67 | if distributed:
68 | return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
69 | if shuffle:
70 | sampler = torch.utils.data.sampler.RandomSampler(dataset)
71 | else:
72 | sampler = torch.utils.data.sampler.SequentialSampler(dataset)
73 | return sampler
74 |
75 |
76 | def make_data_loader(args, yaml_file, is_distributed=True,
77 | is_train=True, start_iter=0, scale_factor=1):
78 |
79 | dataset = build_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
80 | logger = logging.getLogger(__name__)
81 | if is_train==True:
82 | shuffle = True
83 | images_per_gpu = args.per_gpu_train_batch_size
84 | images_per_batch = images_per_gpu * get_world_size()
85 | iters_per_batch = len(dataset) // images_per_batch
86 | num_iters = iters_per_batch * args.num_train_epochs
87 | logger.info("Train with {} images per GPU.".format(images_per_gpu))
88 | logger.info("Total batch size {}".format(images_per_batch))
89 | logger.info("Total training steps {}".format(num_iters))
90 | else:
91 | shuffle = False
92 | images_per_gpu = args.per_gpu_eval_batch_size
93 | num_iters = None
94 | start_iter = 0
95 |
96 | sampler = make_data_sampler(dataset, shuffle, is_distributed)
97 | batch_sampler = make_batch_data_sampler(
98 | sampler, images_per_gpu, num_iters, start_iter
99 | )
100 | data_loader = torch.utils.data.DataLoader(
101 | dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
102 | pin_memory=True,
103 | )
104 | return data_loader
105 |
106 |
107 | #==============================================================================================
108 |
109 | def build_hand_dataset(yaml_file, args, is_train=True, scale_factor=1):
110 | print(yaml_file)
111 | if not op.isfile(yaml_file):
112 | yaml_file = op.join(args.data_dir, yaml_file)
113 | # code.interact(local=locals())
114 | assert op.isfile(yaml_file)
115 | return HandMeshTSVYamlDataset(args, yaml_file, is_train, False, scale_factor)
116 |
117 |
118 | def make_hand_data_loader(args, yaml_file, is_distributed=True,
119 | is_train=True, start_iter=0, scale_factor=1):
120 |
121 | dataset = build_hand_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
122 | logger = logging.getLogger(__name__)
123 | if is_train==True:
124 | shuffle = True
125 | images_per_gpu = args.per_gpu_train_batch_size
126 | images_per_batch = images_per_gpu * get_world_size()
127 | iters_per_batch = len(dataset) // images_per_batch
128 | num_iters = iters_per_batch * args.num_train_epochs
129 | logger.info("Train with {} images per GPU.".format(images_per_gpu))
130 | logger.info("Total batch size {}".format(images_per_batch))
131 | logger.info("Total training steps {}".format(num_iters))
132 | else:
133 | shuffle = False
134 | images_per_gpu = args.per_gpu_eval_batch_size
135 | num_iters = None
136 | start_iter = 0
137 |
138 | sampler = make_data_sampler(dataset, shuffle, is_distributed)
139 | batch_sampler = make_batch_data_sampler(
140 | sampler, images_per_gpu, num_iters, start_iter
141 | )
142 | data_loader = torch.utils.data.DataLoader(
143 | dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
144 | pin_memory=True,
145 | )
146 | return data_loader
147 |
148 |
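# --- Illustrative usage note (added for exposition; not part of the original file) ---
# IterationBasedBatchSampler keeps re-iterating the wrapped BatchSampler until a fixed
# number of iterations has been yielded, which is how make_data_loader turns
# "epochs * iterations-per-epoch" into a single finite stream of batches. A minimal,
# self-contained example (toy_dataset and NUM_ITERS are made up for the demo):
#
#   import torch
#   toy_dataset = list(range(10))
#   sampler = torch.utils.data.sampler.SequentialSampler(toy_dataset)
#   batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, batch_size=4, drop_last=False)
#   NUM_ITERS = 7
#   wrapped = IterationBasedBatchSampler(batch_sampler, NUM_ITERS)
#   print(len(list(wrapped)))  # 7 batches, wrapping around the 3-batches-per-epoch dataset
# --------------------------------------------------------------------------------------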
--------------------------------------------------------------------------------
/src/datasets/hand_mesh_tsv.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | """
6 |
7 |
8 | import cv2
9 | import math
10 | import json
11 | from PIL import Image
12 | import os.path as op
13 | import numpy as np
14 | import code
15 |
16 | from src.utils.tsv_file import TSVFile, CompositeTSVFile
17 | from src.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml
18 | from src.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa
19 | import torch
20 | import torchvision.transforms as transforms
21 |
22 |
23 | class HandMeshTSVDataset(object):
24 | def __init__(self, args, img_file, label_file=None, hw_file=None,
25 | linelist_file=None, is_train=True, cv2_output=False, scale_factor=1):
26 |
27 | self.args = args
28 | self.img_file = img_file
29 | self.label_file = label_file
30 | self.hw_file = hw_file
31 | self.linelist_file = linelist_file
32 | self.img_tsv = self.get_tsv_file(img_file)
33 | self.label_tsv = None if label_file is None else self.get_tsv_file(label_file)
34 | self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file)
35 |
36 | if self.is_composite:
37 | assert op.isfile(self.linelist_file)
38 | self.line_list = [i for i in range(self.hw_tsv.num_rows())]
39 | else:
40 | self.line_list = load_linelist_file(linelist_file)
41 |
42 | self.cv2_output = cv2_output
43 | self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
44 | std=[0.229, 0.224, 0.225])
45 | self.is_train = is_train
46 | self.scale_factor = 0.25 # rescale bounding boxes by a factor of [1-options.scale_factor,1+options.scale_factor]
47 | self.noise_factor = 0.4
48 | self.rot_factor = 90 # Random rotation in the range [-rot_factor, rot_factor]
49 | self.img_res = 224
50 | self.image_keys = self.prepare_image_keys()
51 | self.joints_definition = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1',
52 | 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
53 | self.root_index = self.joints_definition.index('Wrist')
54 |
55 | def get_tsv_file(self, tsv_file):
56 | if tsv_file:
57 | if self.is_composite:
58 | return CompositeTSVFile(tsv_file, self.linelist_file,
59 | root=self.root)
60 | tsv_path = find_file_path_in_yaml(tsv_file, self.root)
61 | return TSVFile(tsv_path)
62 |
63 | def get_valid_tsv(self):
64 |         # prefer the smaller tsv files, in order of file size (hw, then label)
65 | if self.hw_tsv:
66 | return self.hw_tsv
67 | if self.label_tsv:
68 | return self.label_tsv
69 |
70 | def prepare_image_keys(self):
71 | tsv = self.get_valid_tsv()
72 | return [tsv.get_key(i) for i in range(tsv.num_rows())]
73 |
74 | def prepare_image_key_to_index(self):
75 | tsv = self.get_valid_tsv()
76 | return {tsv.get_key(i) : i for i in range(tsv.num_rows())}
77 |
78 |
79 | def augm_params(self):
80 | """Get augmentation parameters."""
81 | flip = 0 # flipping
82 | pn = np.ones(3) # per channel pixel-noise
83 |
84 | if self.args.multiscale_inference == False:
85 | rot = 0 # rotation
86 | sc = 1.0 # scaling
87 | elif self.args.multiscale_inference == True:
88 | rot = self.args.rot
89 | sc = self.args.sc
90 |
91 | if self.is_train:
92 | sc = 1.0
93 | # Each channel is multiplied with a number
94 | # in the area [1-opt.noiseFactor,1+opt.noiseFactor]
95 | pn = np.random.uniform(1-self.noise_factor, 1+self.noise_factor, 3)
96 |
97 | # The rotation is a number in the area [-2*rotFactor, 2*rotFactor]
98 | rot = min(2*self.rot_factor,
99 | max(-2*self.rot_factor, np.random.randn()*self.rot_factor))
100 |
101 | # The scale is multiplied with a number
102 | # in the area [1-scaleFactor,1+scaleFactor]
103 | sc = min(1+self.scale_factor,
104 | max(1-self.scale_factor, np.random.randn()*self.scale_factor+1))
105 | # but it is zero with probability 3/5
106 | if np.random.uniform() <= 0.6:
107 | rot = 0
108 |
109 | return flip, pn, rot, sc
110 |
111 | def rgb_processing(self, rgb_img, center, scale, rot, flip, pn):
112 | """Process rgb image and do augmentation."""
113 | rgb_img = crop(rgb_img, center, scale,
114 | [self.img_res, self.img_res], rot=rot)
115 | # flip the image
116 | if flip:
117 | rgb_img = flip_img(rgb_img)
118 | # in the rgb image we add pixel noise in a channel-wise manner
119 | rgb_img[:,:,0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,0]*pn[0]))
120 | rgb_img[:,:,1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,1]*pn[1]))
121 | rgb_img[:,:,2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,2]*pn[2]))
122 | # (3,224,224),float,[0,1]
123 | rgb_img = np.transpose(rgb_img.astype('float32'),(2,0,1))/255.0
124 | return rgb_img
125 |
126 | def j2d_processing(self, kp, center, scale, r, f):
127 | """Process gt 2D keypoints and apply all augmentation transforms."""
128 | nparts = kp.shape[0]
129 | for i in range(nparts):
130 | kp[i,0:2] = transform(kp[i,0:2]+1, center, scale,
131 | [self.img_res, self.img_res], rot=r)
132 | # convert to normalized coordinates
133 | kp[:,:-1] = 2.*kp[:,:-1]/self.img_res - 1.
134 | # flip the x coordinates
135 | if f:
136 | kp = flip_kp(kp)
137 | kp = kp.astype('float32')
138 | return kp
139 |
140 |
141 | def j3d_processing(self, S, r, f):
142 | """Process gt 3D keypoints and apply all augmentation transforms."""
143 | # in-plane rotation
144 | rot_mat = np.eye(3)
145 | if not r == 0:
146 | rot_rad = -r * np.pi / 180
147 | sn,cs = np.sin(rot_rad), np.cos(rot_rad)
148 | rot_mat[0,:2] = [cs, -sn]
149 | rot_mat[1,:2] = [sn, cs]
150 | S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1])
151 | # flip the x coordinates
152 | if f:
153 | S = flip_kp(S)
154 | S = S.astype('float32')
155 | return S
156 |
157 | def pose_processing(self, pose, r, f):
158 | """Process SMPL theta parameters and apply all augmentation transforms."""
159 |         # rotation of the pose parameters
160 | pose = pose.astype('float32')
161 | pose[:3] = rot_aa(pose[:3], r)
162 | # flip the pose parameters
163 | if f:
164 | pose = flip_pose(pose)
165 | # (72),float
166 | pose = pose.astype('float32')
167 | return pose
168 |
169 | def get_line_no(self, idx):
170 | return idx if self.line_list is None else self.line_list[idx]
171 |
172 | def get_image(self, idx):
173 | line_no = self.get_line_no(idx)
174 | row = self.img_tsv[line_no]
175 | # use -1 to support old format with multiple columns.
176 | cv2_im = img_from_base64(row[-1])
177 | if self.cv2_output:
178 | return cv2_im.astype(np.float32, copy=True)
179 | cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB)
180 | return cv2_im
181 |
182 | def get_annotations(self, idx):
183 | line_no = self.get_line_no(idx)
184 | if self.label_tsv is not None:
185 | row = self.label_tsv[line_no]
186 | annotations = json.loads(row[1])
187 | return annotations
188 | else:
189 | return []
190 |
191 | def get_target_from_annotations(self, annotations, img_size, idx):
192 | # This function will be overwritten by each dataset to
193 | # decode the labels to specific formats for each task.
194 | return annotations
195 |
196 | def get_img_info(self, idx):
197 | if self.hw_tsv is not None:
198 | line_no = self.get_line_no(idx)
199 | row = self.hw_tsv[line_no]
200 | try:
201 | # json string format with "height" and "width" being the keys
202 | return json.loads(row[1])[0]
203 | except ValueError:
204 | # list of strings representing height and width in order
205 | hw_str = row[1].split(' ')
206 | hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
207 | return hw_dict
208 |
209 | def get_img_key(self, idx):
210 | line_no = self.get_line_no(idx)
211 |         # prefer the lighter hw/label tsv, based on the overhead of reading each row.
212 | if self.hw_tsv:
213 | return self.hw_tsv[line_no][0]
214 | elif self.label_tsv:
215 | return self.label_tsv[line_no][0]
216 | else:
217 | return self.img_tsv[line_no][0]
218 |
219 | def __len__(self):
220 | if self.line_list is None:
221 | return self.img_tsv.num_rows()
222 | else:
223 | return len(self.line_list)
224 |
225 | def __getitem__(self, idx):
226 |
227 | img = self.get_image(idx)
228 | img_key = self.get_img_key(idx)
229 | annotations = self.get_annotations(idx)
230 |
231 | annotations = annotations[0]
232 | center = annotations['center']
233 | scale = annotations['scale']
234 | has_2d_joints = annotations['has_2d_joints']
235 | has_3d_joints = annotations['has_3d_joints']
236 | joints_2d = np.asarray(annotations['2d_joints'])
237 | joints_3d = np.asarray(annotations['3d_joints'])
238 |
239 | if joints_2d.ndim==3:
240 | joints_2d = joints_2d[0]
241 | if joints_3d.ndim==3:
242 | joints_3d = joints_3d[0]
243 |
244 | # Get SMPL parameters, if available
245 | has_smpl = np.asarray(annotations['has_smpl'])
246 | pose = np.asarray(annotations['pose'])
247 | betas = np.asarray(annotations['betas'])
248 |
249 | # Get augmentation parameters
250 | flip,pn,rot,sc = self.augm_params()
251 |
252 | # Process image
253 | img = self.rgb_processing(img, center, sc*scale, rot, flip, pn)
254 | img = torch.from_numpy(img).float()
255 | # Store image before normalization to use it in visualization
256 |         transformed_img = self.normalize_img(img)
257 |
258 | # normalize 3d pose by aligning the wrist as the root (at origin)
259 | root_coord = joints_3d[self.root_index,:-1]
260 | joints_3d[:,:-1] = joints_3d[:,:-1] - root_coord[None,:]
261 | # 3d pose augmentation (random flip + rotation, consistent to image and SMPL)
262 | joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip)
263 | # 2d pose augmentation
264 | joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc*scale, rot, flip)
265 |
266 | ###################################
267 |         # Masking percentage
268 |         # We observe that 0% or 5% works better for 3D hand mesh
269 |         # We think this is probably because 3D vertices are quite sparse in the down-sampled hand mesh
270 | mvm_percent = 0.0 # or 0.05
271 | ###################################
272 |
273 | mjm_mask = np.ones((21,1))
274 | if self.is_train:
275 | num_joints = 21
276 | pb = np.random.random_sample()
277 | masked_num = int(pb * mvm_percent * num_joints) # at most x% of the joints could be masked
278 | indices = np.random.choice(np.arange(num_joints),replace=False,size=masked_num)
279 | mjm_mask[indices,:] = 0.0
280 | mjm_mask = torch.from_numpy(mjm_mask).float()
281 |
282 | mvm_mask = np.ones((195,1))
283 | if self.is_train:
284 | num_vertices = 195
285 | pb = np.random.random_sample()
286 | masked_num = int(pb * mvm_percent * num_vertices) # at most x% of the vertices could be masked
287 | indices = np.random.choice(np.arange(num_vertices),replace=False,size=masked_num)
288 | mvm_mask[indices,:] = 0.0
289 | mvm_mask = torch.from_numpy(mvm_mask).float()
290 |
291 | meta_data = {}
292 | meta_data['ori_img'] = img
293 | meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float()
294 | meta_data['betas'] = torch.from_numpy(betas).float()
295 | meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float()
296 | meta_data['has_3d_joints'] = has_3d_joints
297 | meta_data['has_smpl'] = has_smpl
298 | meta_data['mjm_mask'] = mjm_mask
299 | meta_data['mvm_mask'] = mvm_mask
300 |
301 | # Get 2D keypoints and apply augmentation transforms
302 | meta_data['has_2d_joints'] = has_2d_joints
303 | meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float()
304 |
305 | meta_data['scale'] = float(sc * scale)
306 | meta_data['center'] = np.asarray(center).astype(np.float32)
307 |
308 |         return img_key, transformed_img, meta_data
309 |
310 |
311 | class HandMeshTSVYamlDataset(HandMeshTSVDataset):
312 | """ TSVDataset taking a Yaml file for easy function call
313 | """
314 | def __init__(self, args, yaml_file, is_train=True, cv2_output=False, scale_factor=1):
315 | self.cfg = load_from_yaml_file(yaml_file)
316 | self.is_composite = self.cfg.get('composite', False)
317 | self.root = op.dirname(yaml_file)
318 |
319 | if self.is_composite==False:
320 | img_file = find_file_path_in_yaml(self.cfg['img'], self.root)
321 | label_file = find_file_path_in_yaml(self.cfg.get('label', None),
322 | self.root)
323 | hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root)
324 | linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
325 | self.root)
326 | else:
327 | img_file = self.cfg['img']
328 | hw_file = self.cfg['hw']
329 | label_file = self.cfg.get('label', None)
330 | linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
331 | self.root)
332 |
333 | super(HandMeshTSVYamlDataset, self).__init__(
334 | args, img_file, label_file, hw_file, linelist_file, is_train, cv2_output=cv2_output, scale_factor=scale_factor)
335 |
--------------------------------------------------------------------------------
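A standalone sketch of the train-time augmentation sampling implemented in HandMeshTSVDataset.augm_params above, using the same constants (noise_factor=0.4, rot_factor=90, scale_factor=0.25). It is illustrative only and does not cover the multiscale-inference branch.

    import numpy as np

    def sample_augmentation(noise_factor=0.4, rot_factor=90, scale_factor=0.25):
        # per-channel pixel noise in [1 - noise_factor, 1 + noise_factor]
        pn = np.random.uniform(1 - noise_factor, 1 + noise_factor, 3)
        # rotation drawn from a Gaussian, clipped to [-2*rot_factor, 2*rot_factor]
        rot = min(2 * rot_factor, max(-2 * rot_factor, np.random.randn() * rot_factor))
        # scale drawn from a Gaussian, clipped to [1 - scale_factor, 1 + scale_factor]
        sc = min(1 + scale_factor, max(1 - scale_factor, np.random.randn() * scale_factor + 1))
        if np.random.uniform() <= 0.6:   # rotation is disabled 60% of the time
            rot = 0
        return pn, rot, sc

    pn, rot, sc = sample_augmentation()
    print(pn, rot, sc)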
/src/datasets/human_mesh_tsv.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | """
6 |
7 | import cv2
8 | import math
9 | import json
10 | from PIL import Image
11 | import os.path as op
12 | import numpy as np
13 | import code
14 |
15 | from src.utils.tsv_file import TSVFile, CompositeTSVFile
16 | from src.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml
17 | from src.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa
18 | import torch
19 | import torchvision.transforms as transforms
20 |
21 |
22 | class MeshTSVDataset(object):
23 | def __init__(self, img_file, label_file=None, hw_file=None,
24 | linelist_file=None, is_train=True, cv2_output=False, scale_factor=1):
25 |
26 | self.img_file = img_file
27 | self.label_file = label_file
28 | self.hw_file = hw_file
29 | self.linelist_file = linelist_file
30 | self.img_tsv = self.get_tsv_file(img_file)
31 | self.label_tsv = None if label_file is None else self.get_tsv_file(label_file)
32 | self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file)
33 |
34 | if self.is_composite:
35 | assert op.isfile(self.linelist_file)
36 | self.line_list = [i for i in range(self.hw_tsv.num_rows())]
37 | else:
38 | self.line_list = load_linelist_file(linelist_file)
39 |
40 | self.cv2_output = cv2_output
41 | self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406],
42 | std=[0.229, 0.224, 0.225])
43 | self.is_train = is_train
44 | self.scale_factor = 0.25 # rescale bounding boxes by a factor of [1-options.scale_factor,1+options.scale_factor]
45 | self.noise_factor = 0.4
46 | self.rot_factor = 30 # Random rotation in the range [-rot_factor, rot_factor]
47 | self.img_res = 224
48 |
49 | self.image_keys = self.prepare_image_keys()
50 |
51 | self.joints_definition = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder',
52 | 'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear')
53 | self.pelvis_index = self.joints_definition.index('Pelvis')
54 |
55 | def get_tsv_file(self, tsv_file):
56 | if tsv_file:
57 | if self.is_composite:
58 | return CompositeTSVFile(tsv_file, self.linelist_file,
59 | root=self.root)
60 | tsv_path = find_file_path_in_yaml(tsv_file, self.root)
61 | return TSVFile(tsv_path)
62 |
63 | def get_valid_tsv(self):
64 |         # prefer the smaller tsv files, in order of file size (hw, then label)
65 | if self.hw_tsv:
66 | return self.hw_tsv
67 | if self.label_tsv:
68 | return self.label_tsv
69 |
70 | def prepare_image_keys(self):
71 | tsv = self.get_valid_tsv()
72 | return [tsv.get_key(i) for i in range(tsv.num_rows())]
73 |
74 | def prepare_image_key_to_index(self):
75 | tsv = self.get_valid_tsv()
76 | return {tsv.get_key(i) : i for i in range(tsv.num_rows())}
77 |
78 |
79 | def augm_params(self):
80 | """Get augmentation parameters."""
81 | flip = 0 # flipping
82 | pn = np.ones(3) # per channel pixel-noise
83 | rot = 0 # rotation
84 | sc = 1 # scaling
85 | if self.is_train:
86 | # We flip with probability 1/2
87 | if np.random.uniform() <= 0.5:
88 | flip = 1
89 |
90 | # Each channel is multiplied with a number
91 | # in the area [1-opt.noiseFactor,1+opt.noiseFactor]
92 | pn = np.random.uniform(1-self.noise_factor, 1+self.noise_factor, 3)
93 |
94 | # The rotation is a number in the area [-2*rotFactor, 2*rotFactor]
95 | rot = min(2*self.rot_factor,
96 | max(-2*self.rot_factor, np.random.randn()*self.rot_factor))
97 |
98 | # The scale is multiplied with a number
99 | # in the area [1-scaleFactor,1+scaleFactor]
100 | sc = min(1+self.scale_factor,
101 | max(1-self.scale_factor, np.random.randn()*self.scale_factor+1))
102 | # but it is zero with probability 3/5
103 | if np.random.uniform() <= 0.6:
104 | rot = 0
105 |
106 | return flip, pn, rot, sc
107 |
108 | def rgb_processing(self, rgb_img, center, scale, rot, flip, pn):
109 | """Process rgb image and do augmentation."""
110 | rgb_img = crop(rgb_img, center, scale,
111 | [self.img_res, self.img_res], rot=rot)
112 | # flip the image
113 | if flip:
114 | rgb_img = flip_img(rgb_img)
115 | # in the rgb image we add pixel noise in a channel-wise manner
116 | rgb_img[:,:,0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,0]*pn[0]))
117 | rgb_img[:,:,1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,1]*pn[1]))
118 | rgb_img[:,:,2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,2]*pn[2]))
119 | # (3,224,224),float,[0,1]
120 | rgb_img = np.transpose(rgb_img.astype('float32'),(2,0,1))/255.0
121 | return rgb_img
122 |
123 | def j2d_processing(self, kp, center, scale, r, f):
124 | """Process gt 2D keypoints and apply all augmentation transforms."""
125 | nparts = kp.shape[0]
126 | for i in range(nparts):
127 | kp[i,0:2] = transform(kp[i,0:2]+1, center, scale,
128 | [self.img_res, self.img_res], rot=r)
129 | # convert to normalized coordinates
130 | kp[:,:-1] = 2.*kp[:,:-1]/self.img_res - 1.
131 | # flip the x coordinates
132 | if f:
133 | kp = flip_kp(kp)
134 | kp = kp.astype('float32')
135 | return kp
136 |
137 | def j3d_processing(self, S, r, f):
138 | """Process gt 3D keypoints and apply all augmentation transforms."""
139 | # in-plane rotation
140 | rot_mat = np.eye(3)
141 | if not r == 0:
142 | rot_rad = -r * np.pi / 180
143 | sn,cs = np.sin(rot_rad), np.cos(rot_rad)
144 | rot_mat[0,:2] = [cs, -sn]
145 | rot_mat[1,:2] = [sn, cs]
146 | S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1])
147 | # flip the x coordinates
148 | if f:
149 | S = flip_kp(S)
150 | S = S.astype('float32')
151 | return S
152 |
153 | def pose_processing(self, pose, r, f):
154 | """Process SMPL theta parameters and apply all augmentation transforms."""
155 | # rotation or the pose parameters
156 | pose = pose.astype('float32')
157 | pose[:3] = rot_aa(pose[:3], r)
158 | # flip the pose parameters
159 | if f:
160 | pose = flip_pose(pose)
161 | # (72),float
162 | pose = pose.astype('float32')
163 | return pose
164 |
165 | def get_line_no(self, idx):
166 | return idx if self.line_list is None else self.line_list[idx]
167 |
168 | def get_image(self, idx):
169 | line_no = self.get_line_no(idx)
170 | row = self.img_tsv[line_no]
171 | # use -1 to support old format with multiple columns.
172 | cv2_im = img_from_base64(row[-1])
173 | if self.cv2_output:
174 | return cv2_im.astype(np.float32, copy=True)
175 | cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB)
176 |
177 | return cv2_im
178 |
179 | def get_annotations(self, idx):
180 | line_no = self.get_line_no(idx)
181 | if self.label_tsv is not None:
182 | row = self.label_tsv[line_no]
183 | annotations = json.loads(row[1])
184 | return annotations
185 | else:
186 | return []
187 |
188 | def get_target_from_annotations(self, annotations, img_size, idx):
189 | # This function will be overwritten by each dataset to
190 | # decode the labels to specific formats for each task.
191 | return annotations
192 |
193 |
194 | def get_img_info(self, idx):
195 | if self.hw_tsv is not None:
196 | line_no = self.get_line_no(idx)
197 | row = self.hw_tsv[line_no]
198 | try:
199 | # json string format with "height" and "width" being the keys
200 | return json.loads(row[1])[0]
201 | except ValueError:
202 | # list of strings representing height and width in order
203 | hw_str = row[1].split(' ')
204 | hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])}
205 | return hw_dict
206 |
207 | def get_img_key(self, idx):
208 | line_no = self.get_line_no(idx)
209 |         # prefer the lighter hw/label tsv, based on the overhead of reading each row.
210 | if self.hw_tsv:
211 | return self.hw_tsv[line_no][0]
212 | elif self.label_tsv:
213 | return self.label_tsv[line_no][0]
214 | else:
215 | return self.img_tsv[line_no][0]
216 |
217 | def __len__(self):
218 | if self.line_list is None:
219 | return self.img_tsv.num_rows()
220 | else:
221 | return len(self.line_list)
222 |
223 | def __getitem__(self, idx):
224 |
225 | img = self.get_image(idx)
226 | img_key = self.get_img_key(idx)
227 | annotations = self.get_annotations(idx)
228 |
229 | annotations = annotations[0]
230 | center = annotations['center']
231 | scale = annotations['scale']
232 | has_2d_joints = annotations['has_2d_joints']
233 | has_3d_joints = annotations['has_3d_joints']
234 | joints_2d = np.asarray(annotations['2d_joints'])
235 | joints_3d = np.asarray(annotations['3d_joints'])
236 |
237 | if joints_2d.ndim==3:
238 | joints_2d = joints_2d[0]
239 | if joints_3d.ndim==3:
240 | joints_3d = joints_3d[0]
241 |
242 | # Get SMPL parameters, if available
243 | has_smpl = np.asarray(annotations['has_smpl'])
244 | pose = np.asarray(annotations['pose'])
245 | betas = np.asarray(annotations['betas'])
246 |
247 | try:
248 | gender = annotations['gender']
249 | except KeyError:
250 | gender = 'none'
251 |
252 | # Get augmentation parameters
253 | flip,pn,rot,sc = self.augm_params()
254 |
255 | # Process image
256 | img = self.rgb_processing(img, center, sc*scale, rot, flip, pn)
257 | img = torch.from_numpy(img).float()
258 | # Store image before normalization to use it in visualization
259 |         transformed_img = self.normalize_img(img)
260 |
261 | # normalize 3d pose by aligning the pelvis as the root (at origin)
262 | root_pelvis = joints_3d[self.pelvis_index,:-1]
263 | joints_3d[:,:-1] = joints_3d[:,:-1] - root_pelvis[None,:]
264 | # 3d pose augmentation (random flip + rotation, consistent to image and SMPL)
265 | joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip)
266 | # 2d pose augmentation
267 | joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc*scale, rot, flip)
268 |
269 | ###################################
270 |         # Masking percentage
271 | # We observe that 30% works better for human body mesh. Further details are reported in the paper.
272 | mvm_percent = 0.3
273 | ###################################
274 |
275 | mjm_mask = np.ones((14,1))
276 | if self.is_train:
277 | num_joints = 14
278 | pb = np.random.random_sample()
279 | masked_num = int(pb * mvm_percent * num_joints) # at most x% of the joints could be masked
280 | indices = np.random.choice(np.arange(num_joints),replace=False,size=masked_num)
281 | mjm_mask[indices,:] = 0.0
282 | mjm_mask = torch.from_numpy(mjm_mask).float()
283 |
284 | mvm_mask = np.ones((431,1))
285 | if self.is_train:
286 | num_vertices = 431
287 | pb = np.random.random_sample()
288 | masked_num = int(pb * mvm_percent * num_vertices) # at most x% of the vertices could be masked
289 | indices = np.random.choice(np.arange(num_vertices),replace=False,size=masked_num)
290 | mvm_mask[indices,:] = 0.0
291 | mvm_mask = torch.from_numpy(mvm_mask).float()
292 |
293 | meta_data = {}
294 | meta_data['ori_img'] = img
295 | meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float()
296 | meta_data['betas'] = torch.from_numpy(betas).float()
297 | meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float()
298 | meta_data['has_3d_joints'] = has_3d_joints
299 | meta_data['has_smpl'] = has_smpl
300 |
301 | meta_data['mjm_mask'] = mjm_mask
302 | meta_data['mvm_mask'] = mvm_mask
303 |
304 | # Get 2D keypoints and apply augmentation transforms
305 | meta_data['has_2d_joints'] = has_2d_joints
306 | meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float()
307 | meta_data['scale'] = float(sc * scale)
308 | meta_data['center'] = np.asarray(center).astype(np.float32)
309 | meta_data['gender'] = gender
310 |         return img_key, transformed_img, meta_data
311 |
312 |
313 |
314 | class MeshTSVYamlDataset(MeshTSVDataset):
315 | """ TSVDataset taking a Yaml file for easy function call
316 | """
317 | def __init__(self, yaml_file, is_train=True, cv2_output=False, scale_factor=1):
318 | self.cfg = load_from_yaml_file(yaml_file)
319 | self.is_composite = self.cfg.get('composite', False)
320 | self.root = op.dirname(yaml_file)
321 |
322 | if self.is_composite==False:
323 | img_file = find_file_path_in_yaml(self.cfg['img'], self.root)
324 | label_file = find_file_path_in_yaml(self.cfg.get('label', None),
325 | self.root)
326 | hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root)
327 | linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
328 | self.root)
329 | else:
330 | img_file = self.cfg['img']
331 | hw_file = self.cfg['hw']
332 | label_file = self.cfg.get('label', None)
333 | linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None),
334 | self.root)
335 |
336 | super(MeshTSVYamlDataset, self).__init__(
337 | img_file, label_file, hw_file, linelist_file, is_train, cv2_output=cv2_output, scale_factor=scale_factor)
338 |
--------------------------------------------------------------------------------
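A sketch of how the masked vertex/joint modeling masks are drawn in __getitem__ above (mvm_percent=0.3 for the body mesh, with 14 coarse joints and 431 coarse vertices). The helper name and printed summary are illustrative.

    import numpy as np

    def make_mask(num_elements, mvm_percent=0.3):
        mask = np.ones((num_elements, 1))
        pb = np.random.random_sample()                     # uniform in [0, 1)
        masked_num = int(pb * mvm_percent * num_elements)  # at most mvm_percent masked
        idx = np.random.choice(np.arange(num_elements), replace=False, size=masked_num)
        mask[idx, :] = 0.0                                 # 0 marks a masked element
        return mask

    mjm_mask = make_mask(14)    # joint mask
    mvm_mask = make_mask(431)   # vertex mask
    print(int(mjm_mask.sum()), int(mvm_mask.sum()))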
/src/modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/__init__.py
--------------------------------------------------------------------------------
/src/modeling/_gcnn.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import scipy.sparse
6 | import math
7 |
8 | class SparseMM(torch.autograd.Function):
9 | """Redefine sparse @ dense matrix multiplication to enable backpropagation.
10 | The builtin matrix multiplication operation does not support backpropagation in some cases.
11 | """
12 | @staticmethod
13 | def forward(ctx, sparse, dense):
14 | ctx.req_grad = dense.requires_grad
15 | ctx.save_for_backward(sparse)
16 | return torch.matmul(sparse, dense)
17 |
18 | @staticmethod
19 | def backward(ctx, grad_output):
20 | grad_input = None
21 | sparse, = ctx.saved_tensors
22 | if ctx.req_grad:
23 | grad_input = torch.matmul(sparse.t(), grad_output)
24 | return None, grad_input
25 |
26 | def spmm(sparse, dense):
27 | return SparseMM.apply(sparse, dense)
28 |
29 |
30 | def gelu(x):
31 | """Implementation of the gelu activation function.
32 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
33 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
34 | Also see https://arxiv.org/abs/1606.08415
35 | """
36 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
37 |
38 | class BertLayerNorm(torch.nn.Module):
39 | def __init__(self, hidden_size, eps=1e-12):
40 | """Construct a layernorm module in the TF style (epsilon inside the square root).
41 | """
42 | super(BertLayerNorm, self).__init__()
43 | self.weight = torch.nn.Parameter(torch.ones(hidden_size))
44 | self.bias = torch.nn.Parameter(torch.zeros(hidden_size))
45 | self.variance_epsilon = eps
46 |
47 | def forward(self, x):
48 | u = x.mean(-1, keepdim=True)
49 | s = (x - u).pow(2).mean(-1, keepdim=True)
50 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
51 | return self.weight * x + self.bias
52 |
53 |
54 | class GraphResBlock(torch.nn.Module):
55 | """
56 | Graph Residual Block similar to the Bottleneck Residual Block in ResNet
57 | """
58 | def __init__(self, in_channels, out_channels, mesh_type='body'):
59 | super(GraphResBlock, self).__init__()
60 | self.in_channels = in_channels
61 | self.out_channels = out_channels
62 | self.lin1 = GraphLinear(in_channels, out_channels // 2)
63 | self.conv = GraphConvolution(out_channels // 2, out_channels // 2, mesh_type)
64 | self.lin2 = GraphLinear(out_channels // 2, out_channels)
65 | self.skip_conv = GraphLinear(in_channels, out_channels)
66 | # print('Use BertLayerNorm in GraphResBlock')
67 | self.pre_norm = BertLayerNorm(in_channels)
68 | self.norm1 = BertLayerNorm(out_channels // 2)
69 | self.norm2 = BertLayerNorm(out_channels // 2)
70 |
71 | def forward(self, x):
72 | trans_y = F.relu(self.pre_norm(x)).transpose(1,2)
73 | y = self.lin1(trans_y).transpose(1,2)
74 |
75 | y = F.relu(self.norm1(y))
76 | y = self.conv(y)
77 |
78 | trans_y = F.relu(self.norm2(y)).transpose(1,2)
79 | y = self.lin2(trans_y).transpose(1,2)
80 |
81 | z = x+y
82 |
83 | return z
84 |
85 | # class GraphResBlock(torch.nn.Module):
86 | # """
87 | # Graph Residual Block similar to the Bottleneck Residual Block in ResNet
88 | # """
89 | # def __init__(self, in_channels, out_channels, mesh_type='body'):
90 | # super(GraphResBlock, self).__init__()
91 | # self.in_channels = in_channels
92 | # self.out_channels = out_channels
93 | # self.conv = GraphConvolution(self.in_channels, self.out_channels, mesh_type)
94 | # print('Use BertLayerNorm and GeLU in GraphResBlock')
95 | # self.norm = BertLayerNorm(self.out_channels)
96 | # def forward(self, x):
97 | # y = self.conv(x)
98 | # y = self.norm(y)
99 | # y = gelu(y)
100 | # z = x+y
101 | # return z
102 |
103 | class GraphLinear(torch.nn.Module):
104 | """
105 | Generalization of 1x1 convolutions on Graphs
106 | """
107 | def __init__(self, in_channels, out_channels):
108 | super(GraphLinear, self).__init__()
109 | self.in_channels = in_channels
110 | self.out_channels = out_channels
111 | self.W = torch.nn.Parameter(torch.FloatTensor(out_channels, in_channels))
112 | self.b = torch.nn.Parameter(torch.FloatTensor(out_channels))
113 | self.reset_parameters()
114 |
115 | def reset_parameters(self):
116 | w_stdv = 1 / (self.in_channels * self.out_channels)
117 | self.W.data.uniform_(-w_stdv, w_stdv)
118 | self.b.data.uniform_(-w_stdv, w_stdv)
119 |
120 | def forward(self, x):
121 | return torch.matmul(self.W[None, :], x) + self.b[None, :, None]
122 |
123 | class GraphConvolution(torch.nn.Module):
124 | """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
125 | def __init__(self, in_features, out_features, mesh='body', bias=True):
126 | super(GraphConvolution, self).__init__()
127 | device=torch.device('cuda')
128 | self.in_features = in_features
129 | self.out_features = out_features
130 |
131 | if mesh=='body':
132 | adj_indices = torch.load('./src/modeling/data/smpl_431_adjmat_indices.pt')
133 | adj_mat_value = torch.load('./src/modeling/data/smpl_431_adjmat_values.pt')
134 | adj_mat_size = torch.load('./src/modeling/data/smpl_431_adjmat_size.pt')
135 | elif mesh=='hand':
136 | adj_indices = torch.load('./src/modeling/data/mano_195_adjmat_indices.pt')
137 | adj_mat_value = torch.load('./src/modeling/data/mano_195_adjmat_values.pt')
138 | adj_mat_size = torch.load('./src/modeling/data/mano_195_adjmat_size.pt')
139 |
140 | self.adjmat = torch.sparse_coo_tensor(adj_indices, adj_mat_value, size=adj_mat_size).to(device)
141 |
142 | self.weight = torch.nn.Parameter(torch.FloatTensor(in_features, out_features))
143 | if bias:
144 | self.bias = torch.nn.Parameter(torch.FloatTensor(out_features))
145 | else:
146 | self.register_parameter('bias', None)
147 | self.reset_parameters()
148 |
149 | def reset_parameters(self):
150 | # stdv = 1. / math.sqrt(self.weight.size(1))
151 | stdv = 6. / math.sqrt(self.weight.size(0) + self.weight.size(1))
152 | self.weight.data.uniform_(-stdv, stdv)
153 | if self.bias is not None:
154 | self.bias.data.uniform_(-stdv, stdv)
155 |
156 | def forward(self, x):
157 | if x.ndimension() == 2:
158 | support = torch.matmul(x, self.weight)
159 | output = torch.matmul(self.adjmat, support)
160 | if self.bias is not None:
161 | output = output + self.bias
162 | return output
163 | else:
164 | output = []
165 | for i in range(x.shape[0]):
166 | support = torch.matmul(x[i], self.weight)
167 | # output.append(torch.matmul(self.adjmat, support))
168 | output.append(spmm(self.adjmat, support))
169 | output = torch.stack(output, dim=0)
170 | if self.bias is not None:
171 | output = output + self.bias
172 | return output
173 |
174 | def __repr__(self):
175 | return self.__class__.__name__ + ' (' \
176 | + str(self.in_features) + ' -> ' \
177 | + str(self.out_features) + ')'
--------------------------------------------------------------------------------
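A usage sketch for the spmm helper defined in _gcnn.py above, assuming the repository root is on PYTHONPATH so `from src.modeling._gcnn import spmm` resolves. It multiplies a small sparse matrix with a dense feature matrix and backpropagates into the dense operand; the shapes are illustrative.

    import torch
    from src.modeling._gcnn import spmm

    # a 3x3 sparse "adjacency" with three nonzeros
    indices = torch.tensor([[0, 1, 2], [1, 2, 0]])
    values = torch.tensor([1.0, 1.0, 1.0])
    sparse = torch.sparse_coo_tensor(indices, values, (3, 3))

    dense = torch.randn(3, 4, requires_grad=True)   # per-node features
    out = spmm(sparse, dense)                       # sparse @ dense with custom backward
    out.sum().backward()
    print(dense.grad.shape)                         # torch.Size([3, 4])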
/src/modeling/_mano.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains the MANO definition and mesh sampling operations for the MANO mesh
3 |
4 | Adapted from the open-source projects
5 | MANOPTH (https://github.com/hassony2/manopth)
6 | Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
7 | GraphCMR (https://github.com/nkolot/GraphCMR/)
8 | """
9 |
10 | from __future__ import division
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | import os.path as osp
15 | import json
16 | import code
17 | from manopth.manolayer import ManoLayer
18 | import scipy.sparse
19 | import src.modeling.data.config as cfg
20 |
21 | class MANO(nn.Module):
22 | def __init__(self):
23 | super(MANO, self).__init__()
24 |
25 | self.mano_dir = 'src/modeling/data'
26 | self.layer = self.get_layer()
27 | self.vertex_num = 778
28 | self.face = self.layer.th_faces.numpy()
29 | self.joint_regressor = self.layer.th_J_regressor.numpy()
30 |
31 | self.joint_num = 21
32 | self.joints_name = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
33 | self.skeleton = ( (0,1), (0,5), (0,9), (0,13), (0,17), (1,2), (2,3), (3,4), (5,6), (6,7), (7,8), (9,10), (10,11), (11,12), (13,14), (14,15), (15,16), (17,18), (18,19), (19,20) )
34 | self.root_joint_idx = self.joints_name.index('Wrist')
35 |
36 | # add fingertips to joint_regressor
37 | self.fingertip_vertex_idx = [745, 317, 444, 556, 673] # mesh vertex idx (right hand)
38 | thumbtip_onehot = np.array([1 if i == 745 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
39 | indextip_onehot = np.array([1 if i == 317 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
40 | middletip_onehot = np.array([1 if i == 445 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
41 | ringtip_onehot = np.array([1 if i == 556 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
42 | pinkytip_onehot = np.array([1 if i == 673 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1)
43 | self.joint_regressor = np.concatenate((self.joint_regressor, thumbtip_onehot, indextip_onehot, middletip_onehot, ringtip_onehot, pinkytip_onehot))
44 | self.joint_regressor = self.joint_regressor[[0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20],:]
45 | joint_regressor_torch = torch.from_numpy(self.joint_regressor).float()
46 | self.register_buffer('joint_regressor_torch', joint_regressor_torch)
47 |
48 | def get_layer(self):
49 | return ManoLayer(mano_root=osp.join(self.mano_dir), flat_hand_mean=False, use_pca=False) # load right hand MANO model
50 |
51 | def get_3d_joints(self, vertices):
52 | """
53 |         This method is used to get the joint locations from the MANO mesh
54 | Input:
55 | vertices: size = (B, 778, 3)
56 | Output:
57 | 3D joints: size = (B, 21, 3)
58 | """
59 | joints = torch.einsum('bik,ji->bjk', [vertices, self.joint_regressor_torch])
60 | return joints
61 |
62 |
63 | class SparseMM(torch.autograd.Function):
64 | """Redefine sparse @ dense matrix multiplication to enable backpropagation.
65 | The builtin matrix multiplication operation does not support backpropagation in some cases.
66 | """
67 | @staticmethod
68 | def forward(ctx, sparse, dense):
69 | ctx.req_grad = dense.requires_grad
70 | ctx.save_for_backward(sparse)
71 | return torch.matmul(sparse, dense)
72 |
73 | @staticmethod
74 | def backward(ctx, grad_output):
75 | grad_input = None
76 | sparse, = ctx.saved_tensors
77 | if ctx.req_grad:
78 | grad_input = torch.matmul(sparse.t(), grad_output)
79 | return None, grad_input
80 |
81 | def spmm(sparse, dense):
82 | return SparseMM.apply(sparse, dense)
83 |
84 |
85 | def scipy_to_pytorch(A, U, D):
86 | """Convert scipy sparse matrices to pytorch sparse matrix."""
87 | ptU = []
88 | ptD = []
89 |
90 | for i in range(len(U)):
91 | u = scipy.sparse.coo_matrix(U[i])
92 | i = torch.LongTensor(np.array([u.row, u.col]))
93 | v = torch.FloatTensor(u.data)
94 | ptU.append(torch.sparse.FloatTensor(i, v, u.shape))
95 |
96 | for i in range(len(D)):
97 | d = scipy.sparse.coo_matrix(D[i])
98 | i = torch.LongTensor(np.array([d.row, d.col]))
99 | v = torch.FloatTensor(d.data)
100 | ptD.append(torch.sparse.FloatTensor(i, v, d.shape))
101 |
102 | return ptU, ptD
103 |
104 |
105 | def adjmat_sparse(adjmat, nsize=1):
106 | """Create row-normalized sparse graph adjacency matrix."""
107 | adjmat = scipy.sparse.csr_matrix(adjmat)
108 | if nsize > 1:
109 | orig_adjmat = adjmat.copy()
110 | for _ in range(1, nsize):
111 | adjmat = adjmat * orig_adjmat
112 | adjmat.data = np.ones_like(adjmat.data)
113 | for i in range(adjmat.shape[0]):
114 | adjmat[i,i] = 1
115 | num_neighbors = np.array(1 / adjmat.sum(axis=-1))
116 | adjmat = adjmat.multiply(num_neighbors)
117 | adjmat = scipy.sparse.coo_matrix(adjmat)
118 | row = adjmat.row
119 | col = adjmat.col
120 | data = adjmat.data
121 | i = torch.LongTensor(np.array([row, col]))
122 | v = torch.from_numpy(data).float()
123 | adjmat = torch.sparse.FloatTensor(i, v, adjmat.shape)
124 | return adjmat
125 |
126 | def get_graph_params(filename, nsize=1):
127 | """Load and process graph adjacency matrix and upsampling/downsampling matrices."""
128 | data = np.load(filename, encoding='latin1', allow_pickle=True)
129 | A = data['A']
130 | U = data['U']
131 | D = data['D']
132 | U, D = scipy_to_pytorch(A, U, D)
133 | A = [adjmat_sparse(a, nsize=nsize) for a in A]
134 | return A, U, D
135 |
136 |
137 | class Mesh(object):
138 | """Mesh object that is used for handling certain graph operations."""
139 | def __init__(self, filename=cfg.MANO_sampling_matrix,
140 | num_downsampling=1, nsize=1, device=torch.device('cuda')):
141 | self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
142 | # self._A = [a.to(device) for a in self._A]
143 | self._U = [u.to(device) for u in self._U]
144 | self._D = [d.to(device) for d in self._D]
145 | self.num_downsampling = num_downsampling
146 |
147 | def downsample(self, x, n1=0, n2=None):
148 | """Downsample mesh."""
149 | if n2 is None:
150 | n2 = self.num_downsampling
151 | if x.ndimension() < 3:
152 | for i in range(n1, n2):
153 | x = spmm(self._D[i], x)
154 | elif x.ndimension() == 3:
155 | out = []
156 | for i in range(x.shape[0]):
157 | y = x[i]
158 | for j in range(n1, n2):
159 | y = spmm(self._D[j], y)
160 | out.append(y)
161 | x = torch.stack(out, dim=0)
162 | return x
163 |
164 | def upsample(self, x, n1=1, n2=0):
165 | """Upsample mesh."""
166 | if x.ndimension() < 3:
167 | for i in reversed(range(n2, n1)):
168 | x = spmm(self._U[i], x)
169 | elif x.ndimension() == 3:
170 | out = []
171 | for i in range(x.shape[0]):
172 | y = x[i]
173 | for j in reversed(range(n2, n1)):
174 | y = spmm(self._U[j], y)
175 | out.append(y)
176 | x = torch.stack(out, dim=0)
177 | return x
178 |
--------------------------------------------------------------------------------
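A self-contained sketch of the mesh-to-joint regression pattern used by MANO.get_3d_joints above: a (J x V) regressor matrix maps V=778 mesh vertices to J=21 joint locations. The random regressor is a stand-in for the real th_J_regressor plus the fingertip rows.

    import torch

    B, V, J = 2, 778, 21
    vertices = torch.randn(B, V, 3)        # batch of MANO meshes
    joint_regressor = torch.rand(J, V)     # illustrative stand-in for the MANO regressor

    joints = torch.einsum('bik,ji->bjk', [vertices, joint_regressor])
    print(joints.shape)                    # torch.Size([2, 21, 3])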
/src/modeling/_smpl.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains the definition of the SMPL model
3 |
4 | It is adapted from the open-source project GraphCMR (https://github.com/nkolot/GraphCMR/)
5 | """
6 | from __future__ import division
7 |
8 | import torch
9 | import torch.nn as nn
10 | import numpy as np
11 | import scipy.sparse
12 | try:
13 | import cPickle as pickle
14 | except ImportError:
15 | import pickle
16 |
17 | from src.utils.geometric_layers import rodrigues
18 | import src.modeling.data.config as cfg
19 |
20 | class SMPL(nn.Module):
21 |
22 | def __init__(self, gender='neutral'):
23 | super(SMPL, self).__init__()
24 |
25 | if gender=='m':
26 | model_file=cfg.SMPL_Male
27 | elif gender=='f':
28 | model_file=cfg.SMPL_Female
29 | else:
30 | model_file=cfg.SMPL_FILE
31 |
32 | smpl_model = pickle.load(open(model_file, 'rb'), encoding='latin1')
33 | J_regressor = smpl_model['J_regressor'].tocoo()
34 | row = J_regressor.row
35 | col = J_regressor.col
36 | data = J_regressor.data
37 | i = torch.LongTensor([row, col])
38 | v = torch.FloatTensor(data)
39 | J_regressor_shape = [24, 6890]
40 | self.register_buffer('J_regressor', torch.sparse.FloatTensor(i, v, J_regressor_shape).to_dense())
41 | self.register_buffer('weights', torch.FloatTensor(smpl_model['weights']))
42 | self.register_buffer('posedirs', torch.FloatTensor(smpl_model['posedirs']))
43 | self.register_buffer('v_template', torch.FloatTensor(smpl_model['v_template']))
44 | self.register_buffer('shapedirs', torch.FloatTensor(np.array(smpl_model['shapedirs'])))
45 | self.register_buffer('faces', torch.from_numpy(smpl_model['f'].astype(np.int64)))
46 | self.register_buffer('kintree_table', torch.from_numpy(smpl_model['kintree_table'].astype(np.int64)))
47 | id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])}
48 | self.register_buffer('parent', torch.LongTensor([id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])]))
49 |
50 | self.pose_shape = [24, 3]
51 | self.beta_shape = [10]
52 | self.translation_shape = [3]
53 |
54 | self.pose = torch.zeros(self.pose_shape)
55 | self.beta = torch.zeros(self.beta_shape)
56 | self.translation = torch.zeros(self.translation_shape)
57 |
58 | self.verts = None
59 | self.J = None
60 | self.R = None
61 |
62 | J_regressor_extra = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_TRAIN_EXTRA)).float()
63 | self.register_buffer('J_regressor_extra', J_regressor_extra)
64 | self.joints_idx = cfg.JOINTS_IDX
65 |
66 | J_regressor_h36m_correct = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_H36M_correct)).float()
67 | self.register_buffer('J_regressor_h36m_correct', J_regressor_h36m_correct)
68 |
69 |
70 | def forward(self, pose, beta):
71 | device = pose.device
72 | batch_size = pose.shape[0]
73 | v_template = self.v_template[None, :]
74 | shapedirs = self.shapedirs.view(-1,10)[None, :].expand(batch_size, -1, -1)
75 | beta = beta[:, :, None]
76 | v_shaped = torch.matmul(shapedirs, beta).view(-1, 6890, 3) + v_template
77 | # batched sparse matmul not supported in pytorch
78 | J = []
79 | for i in range(batch_size):
80 | J.append(torch.matmul(self.J_regressor, v_shaped[i]))
81 | J = torch.stack(J, dim=0)
82 |         # input is rotmat: (bs,24,3,3)
83 | if pose.ndimension() == 4:
84 | R = pose
85 |         # input is axis-angle pose: (bs,72)
86 |         elif pose.ndimension() == 2:
87 |             pose_cube = pose.view(-1, 3) # (batch_size * 24, 3)
88 | R = rodrigues(pose_cube).view(batch_size, 24, 3, 3)
89 | R = R.view(batch_size, 24, 3, 3)
90 | I_cube = torch.eye(3)[None, None, :].to(device)
91 | # I_cube = torch.eye(3)[None, None, :].expand(theta.shape[0], R.shape[1]-1, -1, -1)
92 | lrotmin = (R[:,1:,:] - I_cube).view(batch_size, -1)
93 | posedirs = self.posedirs.view(-1,207)[None, :].expand(batch_size, -1, -1)
94 | v_posed = v_shaped + torch.matmul(posedirs, lrotmin[:, :, None]).view(-1, 6890, 3)
95 | J_ = J.clone()
96 | J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :]
97 | G_ = torch.cat([R, J_[:, :, :, None]], dim=-1)
98 | pad_row = torch.FloatTensor([0,0,0,1]).to(device).view(1,1,1,4).expand(batch_size, 24, -1, -1)
99 | G_ = torch.cat([G_, pad_row], dim=2)
100 | G = [G_[:, 0].clone()]
101 | for i in range(1, 24):
102 | G.append(torch.matmul(G[self.parent[i-1]], G_[:, i, :, :]))
103 | G = torch.stack(G, dim=1)
104 |
105 | rest = torch.cat([J, torch.zeros(batch_size, 24, 1).to(device)], dim=2).view(batch_size, 24, 4, 1)
106 | zeros = torch.zeros(batch_size, 24, 4, 3).to(device)
107 | rest = torch.cat([zeros, rest], dim=-1)
108 | rest = torch.matmul(G, rest)
109 | G = G - rest
110 | T = torch.matmul(self.weights, G.permute(1,0,2,3).contiguous().view(24,-1)).view(6890, batch_size, 4, 4).transpose(0,1)
111 | rest_shape_h = torch.cat([v_posed, torch.ones_like(v_posed)[:, :, [0]]], dim=-1)
112 | v = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0]
113 | return v
114 |
115 | def get_joints(self, vertices):
116 | """
117 | This method is used to get the joint locations from the SMPL mesh
118 | Input:
119 | vertices: size = (B, 6890, 3)
120 | Output:
121 |         3D joints: size = (B, 24, 3)
122 | """
123 | joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor])
124 | joints_extra = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_extra])
125 | joints = torch.cat((joints, joints_extra), dim=1)
126 | joints = joints[:, cfg.JOINTS_IDX]
127 | return joints
128 |
129 | def get_h36m_joints(self, vertices):
130 | """
131 | This method is used to get the joint locations from the SMPL mesh
132 | Input:
133 | vertices: size = (B, 6890, 3)
134 | Output:
135 |         3D joints: size = (B, 17, 3)
136 | """
137 | joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct])
138 | return joints
139 |
140 | class SparseMM(torch.autograd.Function):
141 | """Redefine sparse @ dense matrix multiplication to enable backpropagation.
142 | The builtin matrix multiplication operation does not support backpropagation in some cases.
143 | """
144 | @staticmethod
145 | def forward(ctx, sparse, dense):
146 | ctx.req_grad = dense.requires_grad
147 | ctx.save_for_backward(sparse)
148 | return torch.matmul(sparse, dense)
149 |
150 | @staticmethod
151 | def backward(ctx, grad_output):
152 | grad_input = None
153 | sparse, = ctx.saved_tensors
154 | if ctx.req_grad:
155 | grad_input = torch.matmul(sparse.t(), grad_output)
156 | return None, grad_input
157 |
158 | def spmm(sparse, dense):
159 | return SparseMM.apply(sparse, dense)
160 |
161 |
162 | def scipy_to_pytorch(A, U, D):
163 | """Convert scipy sparse matrices to pytorch sparse matrix."""
164 | ptU = []
165 | ptD = []
166 |
167 | for i in range(len(U)):
168 | u = scipy.sparse.coo_matrix(U[i])
169 | i = torch.LongTensor(np.array([u.row, u.col]))
170 | v = torch.FloatTensor(u.data)
171 | ptU.append(torch.sparse.FloatTensor(i, v, u.shape))
172 |
173 | for i in range(len(D)):
174 | d = scipy.sparse.coo_matrix(D[i])
175 | i = torch.LongTensor(np.array([d.row, d.col]))
176 | v = torch.FloatTensor(d.data)
177 | ptD.append(torch.sparse.FloatTensor(i, v, d.shape))
178 |
179 | return ptU, ptD
180 |
181 |
182 | def adjmat_sparse(adjmat, nsize=1):
183 | """Create row-normalized sparse graph adjacency matrix."""
184 | adjmat = scipy.sparse.csr_matrix(adjmat)
185 | if nsize > 1:
186 | orig_adjmat = adjmat.copy()
187 | for _ in range(1, nsize):
188 | adjmat = adjmat * orig_adjmat
189 | adjmat.data = np.ones_like(adjmat.data)
190 | for i in range(adjmat.shape[0]):
191 | adjmat[i,i] = 1
192 | num_neighbors = np.array(1 / adjmat.sum(axis=-1))
193 | adjmat = adjmat.multiply(num_neighbors)
194 | adjmat = scipy.sparse.coo_matrix(adjmat)
195 | row = adjmat.row
196 | col = adjmat.col
197 | data = adjmat.data
198 | i = torch.LongTensor(np.array([row, col]))
199 | v = torch.from_numpy(data).float()
200 | adjmat = torch.sparse.FloatTensor(i, v, adjmat.shape)
201 | return adjmat
202 |
203 | def get_graph_params(filename, nsize=1):
204 | """Load and process graph adjacency matrix and upsampling/downsampling matrices."""
205 | data = np.load(filename, encoding='latin1', allow_pickle=True)
206 | A = data['A']
207 | U = data['U']
208 | D = data['D']
209 | U, D = scipy_to_pytorch(A, U, D)
210 | A = [adjmat_sparse(a, nsize=nsize) for a in A]
211 | return A, U, D
212 |
213 |
214 | class Mesh(object):
215 | """Mesh object that is used for handling certain graph operations."""
216 | def __init__(self, filename=cfg.SMPL_sampling_matrix,
217 | num_downsampling=1, nsize=1, device=torch.device('cuda')):
218 | self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
219 | # self._A = [a.to(device) for a in self._A]
220 | self._U = [u.to(device) for u in self._U]
221 | self._D = [d.to(device) for d in self._D]
222 | self.num_downsampling = num_downsampling
223 |
224 | # load template vertices from SMPL and normalize them
225 | smpl = SMPL()
226 | ref_vertices = smpl.v_template
227 | center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None]
228 | ref_vertices -= center
229 | ref_vertices /= ref_vertices.abs().max().item()
230 |
231 | self._ref_vertices = ref_vertices.to(device)
232 | self.faces = smpl.faces.int().to(device)
233 |
234 | # @property
235 | # def adjmat(self):
236 | # """Return the graph adjacency matrix at the specified subsampling level."""
237 | # return self._A[self.num_downsampling].float()
238 |
239 | @property
240 | def ref_vertices(self):
241 | """Return the template vertices at the specified subsampling level."""
242 | ref_vertices = self._ref_vertices
243 | for i in range(self.num_downsampling):
244 | ref_vertices = torch.spmm(self._D[i], ref_vertices)
245 | return ref_vertices
246 |
247 | def downsample(self, x, n1=0, n2=None):
248 | """Downsample mesh."""
249 | if n2 is None:
250 | n2 = self.num_downsampling
251 | if x.ndimension() < 3:
252 | for i in range(n1, n2):
253 | x = spmm(self._D[i], x)
254 | elif x.ndimension() == 3:
255 | out = []
256 | for i in range(x.shape[0]):
257 | y = x[i]
258 | for j in range(n1, n2):
259 | y = spmm(self._D[j], y)
260 | out.append(y)
261 | x = torch.stack(out, dim=0)
262 | return x
263 |
264 | def upsample(self, x, n1=1, n2=0):
265 | """Upsample mesh."""
266 | if x.ndimension() < 3:
267 | for i in reversed(range(n2, n1)):
268 | x = spmm(self._U[i], x)
269 | elif x.ndimension() == 3:
270 | out = []
271 | for i in range(x.shape[0]):
272 | y = x[i]
273 | for j in reversed(range(n2, n1)):
274 | y = spmm(self._U[j], y)
275 | out.append(y)
276 | x = torch.stack(out, dim=0)
277 | return x
278 |
--------------------------------------------------------------------------------
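A standalone sketch of the row normalization performed by adjmat_sparse above: add self-connections, then divide every row by its neighbour count so the sparse adjacency acts as a neighbourhood average. The tiny 3x3 graph is illustrative.

    import numpy as np
    import scipy.sparse
    import torch

    adj = scipy.sparse.lil_matrix(np.array([[0, 1, 0],
                                            [1, 0, 1],
                                            [0, 1, 0]], dtype=np.float32))
    adj.setdiag(1)                                   # self-connections
    adj = adj.tocsr()
    num_neighbors = np.array(1 / adj.sum(axis=-1))   # 1 / row degree
    adj = scipy.sparse.coo_matrix(adj.multiply(num_neighbors))

    i = torch.LongTensor(np.array([adj.row, adj.col]))
    v = torch.from_numpy(adj.data).float()
    adjmat = torch.sparse_coo_tensor(i, v, adj.shape)
    print(adjmat.to_dense().sum(dim=1))              # each row sums to 1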
/src/modeling/bert/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.0"
2 |
3 | from .modeling_bert import (BertConfig, BertModel,
4 | load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
5 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP)
6 |
7 | from .modeling_graphormer import Graphormer
8 |
9 | from .e2e_body_network import Graphormer_Body_Network
10 |
11 | from .e2e_hand_network import Graphormer_Hand_Network
12 |
13 | from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
14 | PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)
15 |
16 | from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
17 |
--------------------------------------------------------------------------------
/src/modeling/bert/bert-base-uncased/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "BertForMaskedLM"
4 | ],
5 | "attention_probs_dropout_prob": 0.1,
6 | "hidden_act": "gelu",
7 | "hidden_dropout_prob": 0.1,
8 | "hidden_size": 768,
9 | "initializer_range": 0.02,
10 | "intermediate_size": 3072,
11 | "max_position_embeddings": 512,
12 | "num_attention_heads": 12,
13 | "num_hidden_layers": 12,
14 | "type_vocab_size": 2,
15 | "vocab_size": 30522
16 | }
17 |
--------------------------------------------------------------------------------
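The JSON above matches the standard bert-base-uncased configuration (768-dim hidden states, 12 layers, 12 attention heads). A minimal way to inspect it with the standard library, assuming the path is taken relative to the repository root:

    import json

    with open('src/modeling/bert/bert-base-uncased/config.json') as f:
        cfg = json.load(f)
    print(cfg['hidden_size'], cfg['num_hidden_layers'], cfg['num_attention_heads'])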
/src/modeling/bert/e2e_body_network.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | """
6 |
7 | import torch
8 | import src.modeling.data.config as cfg
9 |
10 | class Graphormer_Body_Network(torch.nn.Module):
11 | '''
12 | End-to-end Graphormer network for human pose and mesh reconstruction from a single image.
13 | '''
14 | def __init__(self, args, config, backbone, trans_encoder, mesh_sampler):
15 | super(Graphormer_Body_Network, self).__init__()
16 | self.config = config
17 | self.config.device = args.device
18 | self.backbone = backbone
19 | self.trans_encoder = trans_encoder
20 | self.upsampling = torch.nn.Linear(431, 1723)
21 | self.upsampling2 = torch.nn.Linear(1723, 6890)
22 | self.cam_param_fc = torch.nn.Linear(3, 1)
23 | self.cam_param_fc2 = torch.nn.Linear(431, 250)
24 | self.cam_param_fc3 = torch.nn.Linear(250, 3)
25 | self.grid_feat_dim = torch.nn.Linear(1024, 2051)
26 |
27 |
28 | def forward(self, images, smpl, mesh_sampler, meta_masks=None, is_train=False):
29 | batch_size = images.size(0)
30 | # Generate T-pose template mesh
31 | template_pose = torch.zeros((1,72))
32 | template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord
33 | template_pose = template_pose.cuda(self.config.device)
34 | template_betas = torch.zeros((1,10)).cuda(self.config.device)
35 | template_vertices = smpl(template_pose, template_betas)
36 |
37 | # template mesh simplification
38 | template_vertices_sub = mesh_sampler.downsample(template_vertices)
39 | template_vertices_sub2 = mesh_sampler.downsample(template_vertices_sub, n1=1, n2=2)
40 |
41 | # template mesh-to-joint regression
42 | template_3d_joints = smpl.get_h36m_joints(template_vertices)
43 | template_pelvis = template_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:]
44 | template_3d_joints = template_3d_joints[:,cfg.H36M_J17_TO_J14,:]
45 | num_joints = template_3d_joints.shape[1]
46 |
47 | # normalize
48 | template_3d_joints = template_3d_joints - template_pelvis[:, None, :]
49 | template_vertices_sub2 = template_vertices_sub2 - template_pelvis[:, None, :]
50 |
51 |         # concatenate template joints and template vertices, and then duplicate to batch size
52 | ref_vertices = torch.cat([template_3d_joints, template_vertices_sub2],dim=1)
53 | ref_vertices = ref_vertices.expand(batch_size, -1, -1)
54 |
55 | # extract grid features and global image features using a CNN backbone
56 | image_feat, grid_feat = self.backbone(images)
57 |         # concatenate image feat and 3d mesh template
58 | image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1)
59 | # process grid features
60 | grid_feat = torch.flatten(grid_feat, start_dim=2)
61 | grid_feat = grid_feat.transpose(1,2)
62 | grid_feat = self.grid_feat_dim(grid_feat)
63 |         # concatenate image feat and template mesh to form the joint/vertex queries
64 | features = torch.cat([ref_vertices, image_feat], dim=2)
65 | # prepare input tokens including joint/vertex queries and grid features
66 | features = torch.cat([features, grid_feat],dim=1)
67 |
68 | if is_train==True:
69 | # apply mask vertex/joint modeling
70 | # meta_masks is a tensor of all the masks, randomly generated in dataloader
71 | # we pre-define a [MASK] token, which is a floating-value vector with 0.01s
72 | special_token = torch.ones_like(features[:,:-49,:]).cuda()*0.01
73 | features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks)
74 |
75 | # forward pass
76 | if self.config.output_attentions==True:
77 | features, hidden_states, att = self.trans_encoder(features)
78 | else:
79 | features = self.trans_encoder(features)
80 |
81 | pred_3d_joints = features[:,:num_joints,:]
82 | pred_vertices_sub2 = features[:,num_joints:-49,:]
83 |
84 | # learn camera parameters
85 | x = self.cam_param_fc(pred_vertices_sub2)
86 | x = x.transpose(1,2)
87 | x = self.cam_param_fc2(x)
88 | x = self.cam_param_fc3(x)
89 | cam_param = x.transpose(1,2)
90 | cam_param = cam_param.squeeze()
91 |
92 | temp_transpose = pred_vertices_sub2.transpose(1,2)
93 | pred_vertices_sub = self.upsampling(temp_transpose)
94 | pred_vertices_full = self.upsampling2(pred_vertices_sub)
95 | pred_vertices_sub = pred_vertices_sub.transpose(1,2)
96 | pred_vertices_full = pred_vertices_full.transpose(1,2)
97 |
98 | if self.config.output_attentions==True:
99 | return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full, hidden_states, att
100 | else:
101 | return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full
--------------------------------------------------------------------------------
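A sketch of the mask-token substitution applied during training above: positions where meta_masks is 0 are replaced by a constant 0.01 "[MASK]" vector. The shapes below are small and illustrative (the real joint/vertex tokens are 2051-dimensional and the masks come from the dataloader).

    import torch

    features = torch.randn(2, 14 + 431, 8)                    # joint + vertex queries
    meta_masks = (torch.rand(2, 14 + 431, 1) > 0.3).float()   # 1 = keep, 0 = mask

    special_token = torch.ones_like(features) * 0.01
    features = features * meta_masks + special_token * (1 - meta_masks)
    print(features.shape)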
/src/modeling/bert/e2e_hand_network.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | """
6 |
7 | import torch
8 | import src.modeling.data.config as cfg
9 |
10 | class Graphormer_Hand_Network(torch.nn.Module):
11 | '''
12 | End-to-end Graphormer network for hand pose and mesh reconstruction from a single image.
13 | '''
14 | def __init__(self, args, config, backbone, trans_encoder):
15 | super(Graphormer_Hand_Network, self).__init__()
16 | self.config = config
17 | self.backbone = backbone
18 | self.trans_encoder = trans_encoder
19 | self.upsampling = torch.nn.Linear(195, 778)
20 | self.cam_param_fc = torch.nn.Linear(3, 1)
21 | self.cam_param_fc2 = torch.nn.Linear(195+21, 150)
22 | self.cam_param_fc3 = torch.nn.Linear(150, 3)
23 | self.grid_feat_dim = torch.nn.Linear(1024, 2051)
24 |
25 | def forward(self, images, mesh_model, mesh_sampler, meta_masks=None, is_train=False):
26 | batch_size = images.size(0)
27 | # Generate T-pose template mesh
28 | template_pose = torch.zeros((1,48))
29 | template_pose = template_pose.cuda()
30 | template_betas = torch.zeros((1,10)).cuda()
31 | template_vertices, template_3d_joints = mesh_model.layer(template_pose, template_betas)
32 | template_vertices = template_vertices/1000.0
33 | template_3d_joints = template_3d_joints/1000.0
34 |
35 | template_vertices_sub = mesh_sampler.downsample(template_vertices)
36 |
37 | # normalize
38 | template_root = template_3d_joints[:,cfg.J_NAME.index('Wrist'),:]
39 | template_3d_joints = template_3d_joints - template_root[:, None, :]
40 | template_vertices = template_vertices - template_root[:, None, :]
41 | template_vertices_sub = template_vertices_sub - template_root[:, None, :]
42 | num_joints = template_3d_joints.shape[1]
43 |
44 |         # concatenate template joints and template vertices, and then duplicate to batch size
45 | ref_vertices = torch.cat([template_3d_joints, template_vertices_sub],dim=1)
46 | ref_vertices = ref_vertices.expand(batch_size, -1, -1)
47 |
48 | # extract grid features and global image features using a CNN backbone
49 | image_feat, grid_feat = self.backbone(images)
50 |         # concatenate image feat and mesh template
51 | image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1)
52 | # process grid features
53 | grid_feat = torch.flatten(grid_feat, start_dim=2)
54 | grid_feat = grid_feat.transpose(1,2)
55 | grid_feat = self.grid_feat_dim(grid_feat)
56 | # concatenate image feat and template mesh to form the joint/vertex queries
57 | features = torch.cat([ref_vertices, image_feat], dim=2)
58 | # prepare input tokens including joint/vertex queries and grid features
59 | features = torch.cat([features, grid_feat],dim=1)
60 |
61 | if is_train==True:
62 | # apply mask vertex/joint modeling
63 | # meta_masks is a tensor of all the masks, randomly generated in the dataloader
64 | # we pre-define a [MASK] token, which is a constant vector filled with 0.01
65 | special_token = torch.ones_like(features[:,:-49,:]).cuda()*0.01
66 | features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks)
67 |
68 | # forward pass
69 | if self.config.output_attentions==True:
70 | features, hidden_states, att = self.trans_encoder(features)
71 | else:
72 | features = self.trans_encoder(features)
73 |
74 | pred_3d_joints = features[:,:num_joints,:]
75 | pred_vertices_sub = features[:,num_joints:-49,:]
76 |
77 | # learn camera parameters
78 | x = self.cam_param_fc(features[:,:-49,:])
79 | x = x.transpose(1,2)
80 | x = self.cam_param_fc2(x)
81 | x = self.cam_param_fc3(x)
82 | cam_param = x.transpose(1,2)
83 | cam_param = cam_param.squeeze()
84 |
85 | temp_transpose = pred_vertices_sub.transpose(1,2)
86 | pred_vertices = self.upsampling(temp_transpose)
87 | pred_vertices = pred_vertices.transpose(1,2)
88 |
89 | if self.config.output_attentions==True:
90 | return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att
91 | else:
92 | return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices
--------------------------------------------------------------------------------
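
The hand network above slices its encoder output into 21 joint tokens, 195 coarse-vertex tokens, and 49 grid-feature tokens. The following is a minimal sketch with random tensors (not part of the repository) that only illustrates the slicing and the 195 -> 778 upsampling; the `upsampling` layer here is a fresh, untrained stand-in for `self.upsampling`.

import torch

# Hypothetical shapes only: 21 joint tokens + 195 coarse-vertex tokens + 49 grid tokens.
batch_size, num_joints, num_coarse, num_grid = 2, 21, 195, 49
features = torch.randn(batch_size, num_joints + num_coarse + num_grid, 3)

pred_3d_joints = features[:, :num_joints, :]              # [2, 21, 3]
pred_vertices_sub = features[:, num_joints:-num_grid, :]  # [2, 195, 3]

# Upsample the coarse mesh to the full MANO resolution (195 -> 778 vertices),
# mirroring self.upsampling above (here an untrained stand-in layer).
upsampling = torch.nn.Linear(195, 778)
pred_vertices = upsampling(pred_vertices_sub.transpose(1, 2)).transpose(1, 2)
print(pred_3d_joints.shape, pred_vertices.shape)  # [2, 21, 3] [2, 778, 3]
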
/src/modeling/bert/file_utils.py:
--------------------------------------------------------------------------------
1 | ../../../transformers/pytorch_transformers/file_utils.py
--------------------------------------------------------------------------------
/src/modeling/bert/modeling_bert.py:
--------------------------------------------------------------------------------
1 | ../../../transformers/pytorch_transformers/modeling_bert.py
--------------------------------------------------------------------------------
/src/modeling/bert/modeling_graphormer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | """
6 |
7 | from __future__ import absolute_import, division, print_function, unicode_literals
8 |
9 | import logging
10 | import math
11 | import os
12 | import code
13 | import torch
14 | from torch import nn
15 | from .modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, BertIntermediate, BertOutput, BertSelfOutput
16 | import src.modeling.data.config as cfg
17 | from src.modeling._gcnn import GraphConvolution, GraphResBlock
18 | from .modeling_utils import prune_linear_layer
19 | LayerNormClass = torch.nn.LayerNorm
20 | BertLayerNorm = torch.nn.LayerNorm
21 |
22 |
23 | class BertSelfAttention(nn.Module):
24 | def __init__(self, config):
25 | super(BertSelfAttention, self).__init__()
26 | if config.hidden_size % config.num_attention_heads != 0:
27 | raise ValueError(
28 | "The hidden size (%d) is not a multiple of the number of attention "
29 | "heads (%d)" % (config.hidden_size, config.num_attention_heads))
30 | self.output_attentions = config.output_attentions
31 |
32 | self.num_attention_heads = config.num_attention_heads
33 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
34 | self.all_head_size = self.num_attention_heads * self.attention_head_size
35 |
36 | self.query = nn.Linear(config.hidden_size, self.all_head_size)
37 | self.key = nn.Linear(config.hidden_size, self.all_head_size)
38 | self.value = nn.Linear(config.hidden_size, self.all_head_size)
39 |
40 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
41 |
42 | def transpose_for_scores(self, x):
43 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
44 | x = x.view(*new_x_shape)
45 | return x.permute(0, 2, 1, 3)
46 |
47 | def forward(self, hidden_states, attention_mask, head_mask=None,
48 | history_state=None):
49 | if history_state is not None:
50 | x_states = torch.cat([history_state, hidden_states], dim=1)
51 | mixed_query_layer = self.query(hidden_states)
52 | mixed_key_layer = self.key(x_states)
53 | mixed_value_layer = self.value(x_states)
54 | else:
55 | mixed_query_layer = self.query(hidden_states)
56 | mixed_key_layer = self.key(hidden_states)
57 | mixed_value_layer = self.value(hidden_states)
58 |
59 | query_layer = self.transpose_for_scores(mixed_query_layer)
60 | key_layer = self.transpose_for_scores(mixed_key_layer)
61 | value_layer = self.transpose_for_scores(mixed_value_layer)
62 |
63 | # Take the dot product between "query" and "key" to get the raw attention scores.
64 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
65 | attention_scores = attention_scores / math.sqrt(self.attention_head_size)
66 | # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
67 | attention_scores = attention_scores + attention_mask
68 |
69 | # Normalize the attention scores to probabilities.
70 | attention_probs = nn.Softmax(dim=-1)(attention_scores)
71 |
72 | # This is actually dropping out entire tokens to attend to, which might
73 | # seem a bit unusual, but is taken from the original Transformer paper.
74 | attention_probs = self.dropout(attention_probs)
75 |
76 | # Mask heads if we want to
77 | if head_mask is not None:
78 | attention_probs = attention_probs * head_mask
79 |
80 | context_layer = torch.matmul(attention_probs, value_layer)
81 |
82 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
83 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
84 | context_layer = context_layer.view(*new_context_layer_shape)
85 |
86 | outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
87 | return outputs
88 |
89 | class BertAttention(nn.Module):
90 | def __init__(self, config):
91 | super(BertAttention, self).__init__()
92 | self.self = BertSelfAttention(config)
93 | self.output = BertSelfOutput(config)
94 |
95 | def prune_heads(self, heads):
96 | if len(heads) == 0:
97 | return
98 | mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
99 | for head in heads:
100 | mask[head] = 0
101 | mask = mask.view(-1).contiguous().eq(1)
102 | index = torch.arange(len(mask))[mask].long()
103 | # Prune linear layers
104 | self.self.query = prune_linear_layer(self.self.query, index)
105 | self.self.key = prune_linear_layer(self.self.key, index)
106 | self.self.value = prune_linear_layer(self.self.value, index)
107 | self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
108 | # Update hyper params
109 | self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
110 | self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
111 |
112 | def forward(self, input_tensor, attention_mask, head_mask=None,
113 | history_state=None):
114 | self_outputs = self.self(input_tensor, attention_mask, head_mask,
115 | history_state)
116 | attention_output = self.output(self_outputs[0], input_tensor)
117 | outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
118 | return outputs
119 |
120 |
121 | class GraphormerLayer(nn.Module):
122 | def __init__(self, config):
123 | super(GraphormerLayer, self).__init__()
124 | self.attention = BertAttention(config)
125 | self.has_graph_conv = config.graph_conv
126 | self.mesh_type = config.mesh_type
127 |
128 | if self.has_graph_conv == True:
129 | self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type)
130 |
131 | self.intermediate = BertIntermediate(config)
132 | self.output = BertOutput(config)
133 |
134 | def MHA_GCN(self, hidden_states, attention_mask, head_mask=None,
135 | history_state=None):
136 | attention_outputs = self.attention(hidden_states, attention_mask,
137 | head_mask, history_state)
138 | attention_output = attention_outputs[0]
139 |
140 | if self.has_graph_conv==True:
141 | if self.mesh_type == 'body':
142 | joints = attention_output[:,0:14,:]
143 | vertices = attention_output[:,14:-49,:]
144 | img_tokens = attention_output[:,-49:,:]
145 |
146 | elif self.mesh_type == 'hand':
147 | joints = attention_output[:,0:21,:]
148 | vertices = attention_output[:,21:-49,:]
149 | img_tokens = attention_output[:,-49:,:]
150 |
151 | vertices = self.graph_conv(vertices)
152 | joints_vertices = torch.cat([joints,vertices,img_tokens],dim=1)
153 | else:
154 | joints_vertices = attention_output
155 |
156 | intermediate_output = self.intermediate(joints_vertices)
157 | layer_output = self.output(intermediate_output, joints_vertices)
158 | outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
159 | return outputs
160 |
161 | def forward(self, hidden_states, attention_mask, head_mask=None,
162 | history_state=None):
163 | return self.MHA_GCN(hidden_states, attention_mask, head_mask,history_state)
164 |
165 |
166 | class GraphormerEncoder(nn.Module):
167 | def __init__(self, config):
168 | super(GraphormerEncoder, self).__init__()
169 | self.output_attentions = config.output_attentions
170 | self.output_hidden_states = config.output_hidden_states
171 | self.layer = nn.ModuleList([GraphormerLayer(config) for _ in range(config.num_hidden_layers)])
172 |
173 | def forward(self, hidden_states, attention_mask, head_mask=None,
174 | encoder_history_states=None):
175 | all_hidden_states = ()
176 | all_attentions = ()
177 | for i, layer_module in enumerate(self.layer):
178 | if self.output_hidden_states:
179 | all_hidden_states = all_hidden_states + (hidden_states,)
180 |
181 | history_state = None if encoder_history_states is None else encoder_history_states[i]
182 | layer_outputs = layer_module(
183 | hidden_states, attention_mask, head_mask[i],
184 | history_state)
185 | hidden_states = layer_outputs[0]
186 |
187 | if self.output_attentions:
188 | all_attentions = all_attentions + (layer_outputs[1],)
189 |
190 | # Add last layer
191 | if self.output_hidden_states:
192 | all_hidden_states = all_hidden_states + (hidden_states,)
193 |
194 | outputs = (hidden_states,)
195 | if self.output_hidden_states:
196 | outputs = outputs + (all_hidden_states,)
197 | if self.output_attentions:
198 | outputs = outputs + (all_attentions,)
199 |
200 | return outputs # outputs, (hidden states), (attentions)
201 |
202 | class EncoderBlock(BertPreTrainedModel):
203 | def __init__(self, config):
204 | super(EncoderBlock, self).__init__(config)
205 | self.config = config
206 | self.embeddings = BertEmbeddings(config)
207 | self.encoder = GraphormerEncoder(config)
208 | self.pooler = BertPooler(config)
209 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
210 | self.img_dim = config.img_feature_dim
211 |
212 | try:
213 | self.use_img_layernorm = config.use_img_layernorm
214 | except AttributeError:
215 | self.use_img_layernorm = None
216 |
217 | self.img_embedding = nn.Linear(self.img_dim, self.config.hidden_size, bias=True)
218 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
219 | if self.use_img_layernorm:
220 | self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.img_layer_norm_eps)
221 |
222 | self.apply(self.init_weights)
223 |
224 | def _prune_heads(self, heads_to_prune):
225 | """ Prunes heads of the model.
226 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
227 | See base class PreTrainedModel
228 | """
229 | for layer, heads in heads_to_prune.items():
230 | self.encoder.layer[layer].attention.prune_heads(heads)
231 |
232 | def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None,
233 | position_ids=None, head_mask=None):
234 |
235 | batch_size = len(img_feats)
236 | seq_length = len(img_feats[0])
237 | input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).cuda()
238 |
239 | if position_ids is None:
240 | position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
241 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
242 |
243 | position_embeddings = self.position_embeddings(position_ids)
244 |
245 | if attention_mask is None:
246 | attention_mask = torch.ones_like(input_ids)
247 |
248 | if token_type_ids is None:
249 | token_type_ids = torch.zeros_like(input_ids)
250 |
251 | if attention_mask.dim() == 2:
252 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
253 | elif attention_mask.dim() == 3:
254 | extended_attention_mask = attention_mask.unsqueeze(1)
255 | else:
256 | raise NotImplementedError
257 |
258 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
259 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
260 |
261 | if head_mask is not None:
262 | if head_mask.dim() == 1:
263 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
264 | head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
265 | elif head_mask.dim() == 2:
266 | head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
267 | head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
268 | else:
269 | head_mask = [None] * self.config.num_hidden_layers
270 |
271 | # Project input token features to have the specified hidden size
272 | img_embedding_output = self.img_embedding(img_feats)
273 |
274 | # We empirically observe that adding an additional learnable position embedding leads to more stable training
275 | embeddings = position_embeddings + img_embedding_output
276 |
277 | if self.use_img_layernorm:
278 | embeddings = self.LayerNorm(embeddings)
279 | embeddings = self.dropout(embeddings)
280 |
281 | encoder_outputs = self.encoder(embeddings,
282 | extended_attention_mask, head_mask=head_mask)
283 | sequence_output = encoder_outputs[0]
284 |
285 | outputs = (sequence_output,)
286 | if self.config.output_hidden_states:
287 | all_hidden_states = encoder_outputs[1]
288 | outputs = outputs + (all_hidden_states,)
289 | if self.config.output_attentions:
290 | all_attentions = encoder_outputs[-1]
291 | outputs = outputs + (all_attentions,)
292 |
293 | return outputs
294 |
295 | class Graphormer(BertPreTrainedModel):
296 | '''
297 | The architecture of a transformer encoder block used in Graphormer
298 | '''
299 | def __init__(self, config):
300 | super(Graphormer, self).__init__(config)
301 | self.config = config
302 | self.bert = EncoderBlock(config)
303 | self.cls_head = nn.Linear(config.hidden_size, self.config.output_feature_dim)
304 | self.residual = nn.Linear(config.img_feature_dim, self.config.output_feature_dim)
305 | self.apply(self.init_weights)
306 |
307 | def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None,
308 | next_sentence_label=None, position_ids=None, head_mask=None):
309 | '''
310 | # self.bert has three outputs
311 | # predictions[0]: output tokens
312 | # predictions[1]: all_hidden_states, if enable "self.config.output_hidden_states"
313 | # predictions[2]: attentions, if enable "self.config.output_attentions"
314 | '''
315 | predictions = self.bert(img_feats=img_feats, input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
316 | attention_mask=attention_mask, head_mask=head_mask)
317 |
318 | # We use "self.cls_head" to perform dimensionality reduction. We don't use it for classification.
319 | pred_score = self.cls_head(predictions[0])
320 | res_img_feats = self.residual(img_feats)
321 | pred_score = pred_score + res_img_feats
322 |
323 | if self.config.output_attentions and self.config.output_hidden_states:
324 | return pred_score, predictions[1], predictions[-1]
325 | else:
326 | return pred_score
327 |
328 |
--------------------------------------------------------------------------------
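
The `Graphormer` block above maps a sequence of per-token image features to per-token outputs through `EncoderBlock` plus a residual linear projection. Below is a sketch of a single block on dummy features; it assumes a CUDA device and the `config.json` bundled under `src/modeling/bert/bert-base-uncased/`, disables graph convolution so no precomputed adjacency matrices are needed, and borrows the feature dimensions of the first encoder block used by the tools scripts.

import torch
from src.modeling.bert import BertConfig, Graphormer

# Sketch: one encoder block on dummy features (assumes a CUDA device and the
# config.json bundled under src/modeling/bert/bert-base-uncased/).
config = BertConfig.from_pretrained('src/modeling/bert/bert-base-uncased/')
config.img_feature_dim = 2051      # per-token input feature dimension
config.output_feature_dim = 512    # per-token output feature dimension
config.graph_conv = False          # no graph conv, so no adjacency matrices needed
config.mesh_type = 'hand'
config.output_attentions = False
config.output_hidden_states = False

model = Graphormer(config).cuda().eval()

# 21 joint queries + 195 coarse-vertex queries + 49 grid tokens = 265 tokens
img_feats = torch.randn(2, 265, 2051).cuda()
with torch.no_grad():
    pred = model(img_feats)
print(pred.shape)  # torch.Size([2, 265, 512])
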
/src/modeling/bert/modeling_utils.py:
--------------------------------------------------------------------------------
1 | ../../../transformers/pytorch_transformers/modeling_utils.py
--------------------------------------------------------------------------------
/src/modeling/data/J_regressor_extra.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/J_regressor_extra.npy
--------------------------------------------------------------------------------
/src/modeling/data/J_regressor_h36m_correct.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/J_regressor_h36m_correct.npy
--------------------------------------------------------------------------------
/src/modeling/data/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Extra data
3 | Adapted from the open source projects [GraphCMR](https://github.com/nkolot/GraphCMR/) and [Pose2Mesh](https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
4 |
5 | Our code requires additional data to run smoothly.
6 |
7 | ### J_regressor_extra.npy
8 | Joints regressor for joints or landmarks that are not included in the standard set of SMPL joints.
9 |
10 | ### J_regressor_h36m_correct.npy
11 | Joints regressor reflecting the Human3.6M joints.
12 |
13 | ### mesh_downsampling.npz
14 | Extra file with precomputed downsampling for the SMPL body mesh.
15 |
16 | ### mano_downsampling.npz
17 | Extra file with precomputed downsampling for the MANO hand mesh.
18 |
19 | ### basicModel_neutral_lbs_10_207_0_v1.0.0.pkl
20 | SMPL neutral model. Please visit the official website to download the file [http://smplify.is.tue.mpg.de/](http://smplify.is.tue.mpg.de/)
21 |
22 | ### basicModel_m_lbs_10_207_0_v1.0.0.pkl
23 | SMPL male model. Please visit the official website to download the file [https://smpl.is.tue.mpg.de/](https://smpl.is.tue.mpg.de/)
24 |
25 | ### basicModel_f_lbs_10_207_0_v1.0.0.pkl
26 | SMPL female model. Please visit the official website to download the file [https://smpl.is.tue.mpg.de/](https://smpl.is.tue.mpg.de/)
27 |
28 | ### MANO_RIGHT.pkl
29 | MANO hand model. Please visit the official website to download the file [https://mano.is.tue.mpg.de/](https://mano.is.tue.mpg.de/)
30 |
31 |
--------------------------------------------------------------------------------
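
Since the SMPL and MANO model files must be downloaded manually, a quick sanity check (a sketch, not shipped with the repository) can confirm they sit where `src/modeling/data/config.py` expects them; it assumes it is run from the repository root.

import os.path as op
import src.modeling.data.config as cfg

# Check that the manually downloaded files are at the paths config.py expects.
required_files = [cfg.SMPL_FILE, cfg.SMPL_Male, cfg.SMPL_Female, cfg.MANO_FILE,
                  cfg.SMPL_sampling_matrix, cfg.MANO_sampling_matrix]
for path in required_files:
    status = 'ok' if op.isfile(path) else 'MISSING'
    print(status, path)
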
/src/modeling/data/config.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains definitions of useful data structures and the paths
3 | for the datasets and data files necessary to run the code.
4 |
5 | Adapted from the open source projects GraphCMR (https://github.com/nkolot/GraphCMR/) and Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE)
6 |
7 | """
8 |
9 | from os.path import join
10 | folder_path = 'src/modeling/'
11 | JOINT_REGRESSOR_TRAIN_EXTRA = folder_path + 'data/J_regressor_extra.npy'
12 | JOINT_REGRESSOR_H36M_correct = folder_path + 'data/J_regressor_h36m_correct.npy'
13 | SMPL_FILE = folder_path + 'data/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl'
14 | SMPL_Male = folder_path + 'data/basicModel_m_lbs_10_207_0_v1.0.0.pkl'
15 | SMPL_Female = folder_path + 'data/basicModel_f_lbs_10_207_0_v1.0.0.pkl'
16 | SMPL_sampling_matrix = folder_path + 'data/mesh_downsampling.npz'
17 | MANO_FILE = folder_path + 'data/MANO_RIGHT.pkl'
18 | MANO_sampling_matrix = folder_path + 'data/mano_downsampling.npz'
19 |
20 | JOINTS_IDX = [8, 5, 29, 30, 4, 7, 21, 19, 17, 16, 18, 20, 31, 32, 33, 34, 35, 36, 37, 24, 26, 25, 28, 27]
21 |
22 |
23 | """
24 | We follow the body joint definition, loss functions, and evaluation metrics from
25 | open source project GraphCMR (https://github.com/nkolot/GraphCMR/)
26 |
27 | Each dataset uses different sets of joints.
28 | We use a superset of 24 joints such that we include all joints from every dataset.
29 | If a dataset doesn't provide annotations for a specific joint, we simply ignore it.
30 | The joints used here are:
31 | """
32 | J24_NAME = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder',
33 | 'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear')
34 | H36M_J17_NAME = ( 'Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head',
35 | 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist')
36 | J24_TO_J14 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18]
37 | H36M_J17_TO_J14 = [3, 2, 1, 4, 5, 6, 16, 15, 14, 11, 12, 13, 8, 10]
38 |
39 | """
40 | We follow the hand joint definition and mesh topology from
41 | open source project Manopth (https://github.com/hassony2/manopth)
42 |
43 | The hand joints used here are:
44 | """
45 | J_NAME = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1',
46 | 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4')
47 | ROOT_INDEX = 0
--------------------------------------------------------------------------------
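
The index lists above convert between joint conventions. A small illustration with synthetic arrays shows how `J24_TO_J14` and `H36M_J17_TO_J14` select the common 14-joint subset used for evaluation.

import numpy as np
import src.modeling.data.config as cfg

# Synthetic joints in the 24-joint superset and in the Human3.6M 17-joint set.
joints_24 = np.random.rand(24, 3)
joints_h36m_17 = np.random.rand(17, 3)

# Select the common 14-joint subset from each convention.
joints_14_a = joints_24[cfg.J24_TO_J14]
joints_14_b = joints_h36m_17[cfg.H36M_J17_TO_J14]
print(joints_14_a.shape, joints_14_b.shape)  # (14, 3) (14, 3)
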
/src/modeling/data/mano_195_adjmat_indices.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/mano_195_adjmat_indices.pt
--------------------------------------------------------------------------------
/src/modeling/data/mano_195_adjmat_size.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/mano_195_adjmat_size.pt
--------------------------------------------------------------------------------
/src/modeling/data/mano_195_adjmat_values.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/mano_195_adjmat_values.pt
--------------------------------------------------------------------------------
/src/modeling/data/mano_downsampling.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/mano_downsampling.npz
--------------------------------------------------------------------------------
/src/modeling/data/mesh_downsampling.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/mesh_downsampling.npz
--------------------------------------------------------------------------------
/src/modeling/data/smpl_431_adjmat_indices.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/smpl_431_adjmat_indices.pt
--------------------------------------------------------------------------------
/src/modeling/data/smpl_431_adjmat_size.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/smpl_431_adjmat_size.pt
--------------------------------------------------------------------------------
/src/modeling/data/smpl_431_adjmat_values.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/smpl_431_adjmat_values.pt
--------------------------------------------------------------------------------
/src/modeling/data/smpl_431_faces.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/modeling/data/smpl_431_faces.npy
--------------------------------------------------------------------------------
/src/modeling/hrnet/config/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5 | # ------------------------------------------------------------------------------
6 |
7 | from .default import _C as config
8 | from .default import update_config
9 | from .models import MODEL_EXTRAS
10 |
--------------------------------------------------------------------------------
/src/modeling/hrnet/config/default.py:
--------------------------------------------------------------------------------
1 |
2 | # ------------------------------------------------------------------------------
3 | # Copyright (c) Microsoft
4 | # Licensed under the MIT License.
5 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
6 | # Modified by Ke Sun (sunk@mail.ustc.edu.cn)
7 | # ------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import division
11 | from __future__ import print_function
12 |
13 | import os
14 |
15 | from yacs.config import CfgNode as CN
16 |
17 |
18 | _C = CN()
19 |
20 | _C.OUTPUT_DIR = ''
21 | _C.LOG_DIR = ''
22 | _C.DATA_DIR = ''
23 | _C.GPUS = (0,)
24 | _C.WORKERS = 4
25 | _C.PRINT_FREQ = 20
26 | _C.AUTO_RESUME = False
27 | _C.PIN_MEMORY = True
28 | _C.RANK = 0
29 |
30 | # Cudnn related params
31 | _C.CUDNN = CN()
32 | _C.CUDNN.BENCHMARK = True
33 | _C.CUDNN.DETERMINISTIC = False
34 | _C.CUDNN.ENABLED = True
35 |
36 | # common params for NETWORK
37 | _C.MODEL = CN()
38 | _C.MODEL.NAME = 'cls_hrnet'
39 | _C.MODEL.INIT_WEIGHTS = True
40 | _C.MODEL.PRETRAINED = ''
41 | _C.MODEL.NUM_JOINTS = 17
42 | _C.MODEL.NUM_CLASSES = 1000
43 | _C.MODEL.TAG_PER_JOINT = True
44 | _C.MODEL.TARGET_TYPE = 'gaussian'
45 | _C.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256
46 | _C.MODEL.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32
47 | _C.MODEL.SIGMA = 2
48 | _C.MODEL.EXTRA = CN(new_allowed=True)
49 |
50 | _C.LOSS = CN()
51 | _C.LOSS.USE_OHKM = False
52 | _C.LOSS.TOPK = 8
53 | _C.LOSS.USE_TARGET_WEIGHT = True
54 | _C.LOSS.USE_DIFFERENT_JOINTS_WEIGHT = False
55 |
56 | # DATASET related params
57 | _C.DATASET = CN()
58 | _C.DATASET.ROOT = ''
59 | _C.DATASET.DATASET = 'mpii'
60 | _C.DATASET.TRAIN_SET = 'train'
61 | _C.DATASET.TEST_SET = 'valid'
62 | _C.DATASET.DATA_FORMAT = 'jpg'
63 | _C.DATASET.HYBRID_JOINTS_TYPE = ''
64 | _C.DATASET.SELECT_DATA = False
65 |
66 | # training data augmentation
67 | _C.DATASET.FLIP = True
68 | _C.DATASET.SCALE_FACTOR = 0.25
69 | _C.DATASET.ROT_FACTOR = 30
70 | _C.DATASET.PROB_HALF_BODY = 0.0
71 | _C.DATASET.NUM_JOINTS_HALF_BODY = 8
72 | _C.DATASET.COLOR_RGB = False
73 |
74 | # train
75 | _C.TRAIN = CN()
76 |
77 | _C.TRAIN.LR_FACTOR = 0.1
78 | _C.TRAIN.LR_STEP = [90, 110]
79 | _C.TRAIN.LR = 0.001
80 |
81 | _C.TRAIN.OPTIMIZER = 'adam'
82 | _C.TRAIN.MOMENTUM = 0.9
83 | _C.TRAIN.WD = 0.0001
84 | _C.TRAIN.NESTEROV = False
85 | _C.TRAIN.GAMMA1 = 0.99
86 | _C.TRAIN.GAMMA2 = 0.0
87 |
88 | _C.TRAIN.BEGIN_EPOCH = 0
89 | _C.TRAIN.END_EPOCH = 140
90 |
91 | _C.TRAIN.RESUME = False
92 | _C.TRAIN.CHECKPOINT = ''
93 |
94 | _C.TRAIN.BATCH_SIZE_PER_GPU = 32
95 | _C.TRAIN.SHUFFLE = True
96 |
97 | # testing
98 | _C.TEST = CN()
99 |
100 | # size of images for each device
101 | _C.TEST.BATCH_SIZE_PER_GPU = 32
102 | # Test Model Epoch
103 | _C.TEST.FLIP_TEST = False
104 | _C.TEST.POST_PROCESS = False
105 | _C.TEST.SHIFT_HEATMAP = False
106 |
107 | _C.TEST.USE_GT_BBOX = False
108 |
109 | # nms
110 | _C.TEST.IMAGE_THRE = 0.1
111 | _C.TEST.NMS_THRE = 0.6
112 | _C.TEST.SOFT_NMS = False
113 | _C.TEST.OKS_THRE = 0.5
114 | _C.TEST.IN_VIS_THRE = 0.0
115 | _C.TEST.COCO_BBOX_FILE = ''
116 | _C.TEST.BBOX_THRE = 1.0
117 | _C.TEST.MODEL_FILE = ''
118 |
119 | # debug
120 | _C.DEBUG = CN()
121 | _C.DEBUG.DEBUG = False
122 | _C.DEBUG.SAVE_BATCH_IMAGES_GT = False
123 | _C.DEBUG.SAVE_BATCH_IMAGES_PRED = False
124 | _C.DEBUG.SAVE_HEATMAPS_GT = False
125 | _C.DEBUG.SAVE_HEATMAPS_PRED = False
126 |
127 |
128 | def update_config(cfg, config_file):
129 | cfg.defrost()
130 | cfg.merge_from_file(config_file)
131 | cfg.freeze()
132 |
133 |
134 | if __name__ == '__main__':
135 | import sys
136 | with open(sys.argv[1], 'w') as f:
137 | print(_C, file=f)
138 |
139 |
--------------------------------------------------------------------------------
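
A short usage sketch of the yacs defaults above, mirroring what the inference tools do: load the frozen defaults and merge an experiment yaml. The yaml path is the hrnet-w64 file referenced in `run_gphmer_handmesh_inference.py` and assumes the pretrained HRNet files have already been fetched (e.g. via `scripts/download_models.sh`).

from src.modeling.hrnet.config import config as hrnet_config
from src.modeling.hrnet.config import update_config as hrnet_update_config

# Merge an experiment yaml into the frozen defaults, as the inference tools do.
hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
hrnet_update_config(hrnet_config, hrnet_yaml)
print(hrnet_config.MODEL.NAME, hrnet_config.MODEL.IMAGE_SIZE)
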
/src/modeling/hrnet/config/models.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Created by Bin Xiao (Bin.Xiao@microsoft.com)
5 | # Modified by Ke Sun (sunk@mail.ustc.edu.cn)
6 | # ------------------------------------------------------------------------------
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | from yacs.config import CfgNode as CN
13 |
14 | # high_resoluton_net related params for classification
15 | POSE_HIGH_RESOLUTION_NET = CN()
16 | POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*']
17 | POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = 64
18 | POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1
19 | POSE_HIGH_RESOLUTION_NET.WITH_HEAD = True
20 |
21 | POSE_HIGH_RESOLUTION_NET.STAGE2 = CN()
22 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1
23 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2
24 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4]
25 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64]
26 | POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC'
27 | POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM'
28 |
29 | POSE_HIGH_RESOLUTION_NET.STAGE3 = CN()
30 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1
31 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3
32 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4]
33 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128]
34 | POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC'
35 | POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM'
36 |
37 | POSE_HIGH_RESOLUTION_NET.STAGE4 = CN()
38 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1
39 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4
40 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
41 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
42 | POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC'
43 | POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM'
44 |
45 | MODEL_EXTRAS = {
46 | 'cls_hrnet': POSE_HIGH_RESOLUTION_NET,
47 | }
48 |
--------------------------------------------------------------------------------
/src/tools/run_gphmer_handmesh_inference.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | End-to-end inference code for
6 | 3D hand mesh reconstruction from an image
7 | """
8 |
9 | from __future__ import absolute_import, division, print_function
10 | import argparse
11 | import os
12 | import os.path as op
13 | import code
14 | import json
15 | import time
16 | import datetime
17 | import torch
18 | import torchvision.models as models
19 | from torchvision.utils import make_grid
20 | import gc
21 | import numpy as np
22 | import cv2
23 | from src.modeling.bert import BertConfig, Graphormer
24 | from src.modeling.bert import Graphormer_Hand_Network as Graphormer_Network
25 | from src.modeling._mano import MANO, Mesh
26 | from src.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat
27 | from src.modeling.hrnet.config import config as hrnet_config
28 | from src.modeling.hrnet.config import update_config as hrnet_update_config
29 | import src.modeling.data.config as cfg
30 | from src.datasets.build import make_hand_data_loader
31 |
32 | from src.utils.logger import setup_logger
33 | from src.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather
34 | from src.utils.miscellaneous import mkdir, set_seed
35 | from src.utils.metric_logger import AverageMeter
36 | from src.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text
37 | from src.utils.metric_pampjpe import reconstruction_error
38 | from src.utils.geometric_layers import orthographic_projection
39 |
40 | from PIL import Image
41 | from torchvision import transforms
42 |
43 | transform = transforms.Compose([
44 | transforms.Resize(224),
45 | transforms.CenterCrop(224),
46 | transforms.ToTensor(),
47 | transforms.Normalize(
48 | mean=[0.485, 0.456, 0.406],
49 | std=[0.229, 0.224, 0.225])])
50 |
51 | transform_visualize = transforms.Compose([
52 | transforms.Resize(224),
53 | transforms.CenterCrop(224),
54 | transforms.ToTensor()])
55 |
56 | def run_inference(args, image_list, Graphormer_model, mano, renderer, mesh_sampler):
57 | # switch to evaluate mode
58 | Graphormer_model.eval()
59 | mano.eval()
60 | with torch.no_grad():
61 | for image_file in image_list:
62 | if 'pred' not in image_file:
63 | att_all = []
64 | print(image_file)
65 | img = Image.open(image_file)
66 | img_tensor = transform(img)
67 | img_visual = transform_visualize(img)
68 |
69 | batch_imgs = torch.unsqueeze(img_tensor, 0).cuda()
70 | batch_visual_imgs = torch.unsqueeze(img_visual, 0).cuda()
71 | # forward-pass
72 | pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler)
73 | # obtain 3d joints from full mesh
74 | pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
75 | pred_3d_pelvis = pred_3d_joints_from_mesh[:,cfg.J_NAME.index('Wrist'),:]
76 | pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :]
77 | pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :]
78 |
79 | # save attention
80 | att_max_value = att[-1]
81 | att_cpu = np.asarray(att_max_value.cpu().detach())
82 | att_all.append(att_cpu)
83 |
84 | # obtain 3d joints, which are regressed from the full mesh
85 | pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices)
86 | # obtain 2d joints, which are projected from 3d joints of mesh
87 | pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous())
88 | pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous())
89 |
90 |
91 | visual_imgs_output = visualize_mesh( renderer, batch_visual_imgs[0],
92 | pred_vertices[0].detach(),
93 | pred_camera.detach())
94 | # visual_imgs_output = visualize_mesh_and_attention( renderer, batch_visual_imgs[0],
95 | # pred_vertices[0].detach(),
96 | # pred_vertices_sub[0].detach(),
97 | # pred_2d_coarse_vertices_from_mesh[0].detach(),
98 | # pred_2d_joints_from_mesh[0].detach(),
99 | # pred_camera.detach(),
100 | # att[-1][0].detach())
101 | visual_imgs = visual_imgs_output.transpose(1,2,0)
102 | visual_imgs = np.asarray(visual_imgs)
103 |
104 | temp_fname = image_file[:-4] + '_graphormer_pred.jpg'
105 | print('save to ', temp_fname)
106 | cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))
107 | return
108 |
109 | def visualize_mesh( renderer, images,
110 | pred_vertices_full,
111 | pred_camera):
112 | img = images.cpu().numpy().transpose(1,2,0)
113 | # Get predicted vertices for the particular example
114 | vertices_full = pred_vertices_full.cpu().numpy()
115 | cam = pred_camera.cpu().numpy()
116 | # Visualize only mesh reconstruction
117 | rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue')
118 | rend_img = rend_img.transpose(2,0,1)
119 | return rend_img
120 |
121 | def visualize_mesh_and_attention( renderer, images,
122 | pred_vertices_full,
123 | pred_vertices,
124 | pred_2d_vertices,
125 | pred_2d_joints,
126 | pred_camera,
127 | attention):
128 | img = images.cpu().numpy().transpose(1,2,0)
129 | # Get predicted vertices for the particular example
130 | vertices_full = pred_vertices_full.cpu().numpy()
131 | vertices = pred_vertices.cpu().numpy()
132 | vertices_2d = pred_2d_vertices.cpu().numpy()
133 | joints_2d = pred_2d_joints.cpu().numpy()
134 | cam = pred_camera.cpu().numpy()
135 | att = attention.cpu().numpy()
136 | # Visualize reconstruction and attention
137 | rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue')
138 | rend_img = rend_img.transpose(2,0,1)
139 | return rend_img
140 |
141 | def parse_args():
142 | parser = argparse.ArgumentParser()
143 | #########################################################
144 | # Data related arguments
145 | #########################################################
146 | parser.add_argument("--num_workers", default=4, type=int,
147 | help="Workers in dataloader.")
148 | parser.add_argument("--img_scale_factor", default=1, type=int,
149 | help="adjust image resolution.")
150 | parser.add_argument("--image_file_or_path", default='./samples/hand', type=str,
151 | help="test data")
152 | #########################################################
153 | # Loading/saving checkpoints
154 | #########################################################
155 | parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False,
156 | help="Path to pre-trained transformer model or model type.")
157 | parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
158 | help="Path to specific checkpoint for resume training.")
159 | parser.add_argument("--output_dir", default='output/', type=str, required=False,
160 | help="The output directory to save checkpoint and test results.")
161 | parser.add_argument("--config_name", default="", type=str,
162 | help="Pretrained config name or path if not the same as model_name.")
163 | parser.add_argument('-a', '--arch', default='hrnet-w64',
164 | help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
165 | #########################################################
166 | # Model architectures
167 | #########################################################
168 | parser.add_argument("--num_hidden_layers", default=4, type=int, required=False,
169 | help="Update model config if given")
170 | parser.add_argument("--hidden_size", default=-1, type=int, required=False,
171 | help="Update model config if given")
172 | parser.add_argument("--num_attention_heads", default=4, type=int, required=False,
173 | help="Update model config if given. Note that "
174 | "hidden_size must be divisible by num_attention_heads.")
175 | parser.add_argument("--intermediate_size", default=-1, type=int, required=False,
176 | help="Update model config if given.")
177 | parser.add_argument("--input_feat_dim", default='2051,512,128', type=str,
178 | help="The Image Feature Dimension.")
179 | parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str,
180 | help="The hidden feature dimension.")
181 | parser.add_argument("--which_gcn", default='0,0,1', type=str,
182 | help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv")
183 | parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand")
184 |
185 | #########################################################
186 | # Others
187 | #########################################################
188 | parser.add_argument("--run_eval_only", default=True, action='store_true',)
189 | parser.add_argument("--device", type=str, default='cuda',
190 | help="cuda or cpu")
191 | parser.add_argument('--seed', type=int, default=88,
192 | help="random seed for initialization.")
193 | args = parser.parse_args()
194 | return args
195 |
196 | def main(args):
197 | global logger
198 | # Setup CUDA, GPU & distributed training
199 | args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
200 | os.environ['OMP_NUM_THREADS'] = str(args.num_workers)
201 | print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS']))
202 |
203 | mkdir(args.output_dir)
204 | logger = setup_logger("Graphormer", args.output_dir, get_rank())
205 | set_seed(args.seed, args.num_gpus)
206 | logger.info("Using {} GPUs".format(args.num_gpus))
207 |
208 | # Mesh and MANO utils
209 | mano_model = MANO().to(args.device)
210 | mano_model.layer = mano_model.layer.cuda()
211 | mesh_sampler = Mesh()
212 |
213 | # Renderer for visualization
214 | renderer = Renderer(faces=mano_model.face)
215 |
216 | # Load pretrained model
217 | trans_encoder = []
218 |
219 | input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')]
220 | hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')]
221 | output_feat_dim = input_feat_dim[1:] + [3]
222 |
223 | # which encoder block to have graph convs
224 | which_blk_graph = [int(item) for item in args.which_gcn.split(',')]
225 |
226 | if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint:
227 | # if only run eval, load checkpoint
228 | logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint))
229 | _model = torch.load(args.resume_checkpoint)
230 |
231 | else:
232 | # init three transformer-encoder blocks in a loop
233 | for i in range(len(output_feat_dim)):
234 | config_class, model_class = BertConfig, Graphormer
235 | config = config_class.from_pretrained(args.config_name if args.config_name \
236 | else args.model_name_or_path)
237 |
238 | config.output_attentions = False
239 | config.img_feature_dim = input_feat_dim[i]
240 | config.output_feature_dim = output_feat_dim[i]
241 | args.hidden_size = hidden_feat_dim[i]
242 | args.intermediate_size = int(args.hidden_size*2)
243 |
244 | if which_blk_graph[i]==1:
245 | config.graph_conv = True
246 | logger.info("Add Graph Conv")
247 | else:
248 | config.graph_conv = False
249 |
250 | config.mesh_type = args.mesh_type
251 |
252 | # update model structure if specified in arguments
253 | update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']
254 | for idx, param in enumerate(update_params):
255 | arg_param = getattr(args, param)
256 | config_param = getattr(config, param)
257 | if arg_param > 0 and arg_param != config_param:
258 | logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
259 | setattr(config, param, arg_param)
260 |
261 | # init a transformer encoder and append it to a list
262 | assert config.hidden_size % config.num_attention_heads == 0
263 | model = model_class(config=config)
264 | logger.info("Init model from scratch.")
265 | trans_encoder.append(model)
266 |
267 | # create backbone model
268 | if args.arch=='hrnet':
269 | hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
270 | hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth'
271 | hrnet_update_config(hrnet_config, hrnet_yaml)
272 | backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
273 | logger.info('=> loading hrnet-v2-w40 model')
274 | elif args.arch=='hrnet-w64':
275 | hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml'
276 | hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth'
277 | hrnet_update_config(hrnet_config, hrnet_yaml)
278 | backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint)
279 | logger.info('=> loading hrnet-v2-w64 model')
280 | else:
281 | print("=> using pre-trained model '{}'".format(args.arch))
282 | backbone = models.__dict__[args.arch](pretrained=True)
283 | # remove the last fc layer
284 | backbone = torch.nn.Sequential(*list(backbone.children())[:-1])
285 |
286 | trans_encoder = torch.nn.Sequential(*trans_encoder)
287 | total_params = sum(p.numel() for p in trans_encoder.parameters())
288 | logger.info('Graphormer encoders total parameters: {}'.format(total_params))
289 | backbone_total_params = sum(p.numel() for p in backbone.parameters())
290 | logger.info('Backbone total parameters: {}'.format(backbone_total_params))
291 |
292 | # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder)
293 | _model = Graphormer_Network(args, config, backbone, trans_encoder)
294 |
295 | if args.resume_checkpoint!=None and args.resume_checkpoint!='None':
296 | # for fine-tuning or resume training or inference, load weights from checkpoint
297 | logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint))
298 | # workaround approach to load sparse tensor in graph conv.
299 | state_dict = torch.load(args.resume_checkpoint)
300 | _model.load_state_dict(state_dict, strict=False)
301 | del state_dict
302 | gc.collect()
303 | torch.cuda.empty_cache()
304 |
305 | # update configs to enable attention outputs
306 | setattr(_model.trans_encoder[-1].config,'output_attentions', True)
307 | setattr(_model.trans_encoder[-1].config,'output_hidden_states', True)
308 | _model.trans_encoder[-1].bert.encoder.output_attentions = True
309 | _model.trans_encoder[-1].bert.encoder.output_hidden_states = True
310 | for iter_layer in range(4):
311 | _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True
312 | for inter_block in range(3):
313 | setattr(_model.trans_encoder[-1].config,'device', args.device)
314 |
315 | _model.to(args.device)
316 | logger.info("Run inference")
317 |
318 | image_list = []
319 | if not args.image_file_or_path:
320 | raise ValueError("image_file_or_path not specified")
321 | if op.isfile(args.image_file_or_path):
322 | image_list = [args.image_file_or_path]
323 | elif op.isdir(args.image_file_or_path):
324 | # should be a path with images only
325 | for filename in os.listdir(args.image_file_or_path):
326 | if (filename.endswith(".png") or filename.endswith(".jpg")) and 'pred' not in filename:
327 | image_list.append(args.image_file_or_path+'/'+filename)
328 | else:
329 | raise ValueError("Cannot find images at {}".format(args.image_file_or_path))
330 |
331 | run_inference(args, image_list, _model, mano_model, renderer, mesh_sampler)
332 |
333 | if __name__ == "__main__":
334 | args = parse_args()
335 | main(args)
336 |
--------------------------------------------------------------------------------
/src/tools/run_hand_multiscale.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import argparse
4 | import os
5 | import os.path as op
6 | import code
7 | import json
8 | import zipfile
9 | import torch
10 | import numpy as np
11 | from src.utils.metric_pampjpe import get_alignMesh
12 |
13 |
14 | def load_pred_json(filepath):
15 | archive = zipfile.ZipFile(filepath, 'r')
16 | jsondata = archive.read('pred.json')
17 | reference = json.loads(jsondata.decode("utf-8"))
18 | return reference[0], reference[1]
19 |
20 |
21 | def multiscale_fusion(output_dir):
22 | s = '10'
23 | filepath = output_dir+'ckpt200-sc10_rot0-pred.zip'
24 | ref_joints, ref_vertices = load_pred_json(filepath)
25 | ref_joints_array = np.asarray(ref_joints)
26 | ref_vertices_array = np.asarray(ref_vertices)
27 |
28 | rotations = [0.0]
29 | for i in range(1,10):
30 | rotations.append(i*10)
31 | rotations.append(i*-10)
32 |
33 | scale = [0.7,0.8,0.9,1.0,1.1]
34 | multiscale_joints = []
35 | multiscale_vertices = []
36 |
37 | counter = 0
38 | for s in scale:
39 | for r in rotations:
40 | setting = 'sc%02d_rot%s'%(int(s*10),str(int(r)))
41 | filepath = output_dir+'ckpt200-'+setting+'-pred.zip'
42 | joints, vertices = load_pred_json(filepath)
43 | joints_array = np.asarray(joints)
44 | vertices_array = np.asarray(vertices)
45 |
46 | pa_joint_error, pa_joint_array, _ = get_alignMesh(joints_array, ref_joints_array, reduction=None)
47 | pa_vertices_error, pa_vertices_array, _ = get_alignMesh(vertices_array, ref_vertices_array, reduction=None)
48 | print('--------------------------')
49 | print('scale:', s, 'rotate', r)
50 | print('PAMPJPE:', 1000*np.mean(pa_joint_error))
51 | print('PAMPVPE:', 1000*np.mean(pa_vertices_error))
52 | multiscale_joints.append(pa_joint_array)
53 | multiscale_vertices.append(pa_vertices_array)
54 | counter = counter + 1
55 |
56 | overall_joints_array = ref_joints_array.copy()
57 | overall_vertices_array = ref_vertices_array.copy()
58 | for i in range(counter):
59 | overall_joints_array += multiscale_joints[i]
60 | overall_vertices_array += multiscale_vertices[i]
61 |
62 | overall_joints_array /= (1+counter)
63 | overall_vertices_array /= (1+counter)
64 | pa_joint_error, pa_joint_array, _ = get_alignMesh(overall_joints_array, ref_joints_array, reduction=None)
65 | pa_vertices_error, pa_vertices_array, _ = get_alignMesh(overall_vertices_array, ref_vertices_array, reduction=None)
66 | print('--------------------------')
67 | print('overall:')
68 | print('PAMPJPE:', 1000*np.mean(pa_joint_error))
69 | print('PAMPVPE:', 1000*np.mean(pa_vertices_error))
70 |
71 | joint_output_save = overall_joints_array.tolist()
72 | mesh_output_save = overall_vertices_array.tolist()
73 |
74 | print('save results to pred.json')
75 | with open('pred.json', 'w') as f:
76 | json.dump([joint_output_save, mesh_output_save], f)
77 |
78 |
79 | filepath = output_dir+'ckpt200-multisc-pred.zip'
80 | resolved_submit_cmd = 'zip ' + filepath + ' ' + 'pred.json'
81 | print(resolved_submit_cmd)
82 | os.system(resolved_submit_cmd)
83 | resolved_submit_cmd = 'rm pred.json'
84 | print(resolved_submit_cmd)
85 | os.system(resolved_submit_cmd)
86 |
87 |
88 | def run_multiscale_inference(model_path, mode, output_dir):
89 |
90 | if mode==True:
91 | rotations = [0.0]
92 | for i in range(1,10):
93 | rotations.append(i*10)
94 | rotations.append(i*-10)
95 | scale = [0.7,0.8,0.9,1.0,1.1]
96 | else:
97 | rotations = [0.0]
98 | scale = [1.0]
99 |
100 | job_cmd = "python ./src/tools/run_gphmer_handmesh.py " \
101 | "--val_yaml freihand_v3/test.yaml " \
102 | "--resume_checkpoint %s " \
103 | "--per_gpu_eval_batch_size 32 --run_eval_only --num_worker 2 " \
104 | "--multiscale_inference " \
105 | "--rot %f " \
106 | "--sc %s " \
107 | "--arch hrnet-w64 " \
108 | "--num_hidden_layers 4 " \
109 | "--num_attention_heads 4 " \
110 | "--input_feat_dim 2051,512,128 " \
111 | "--hidden_feat_dim 1024,256,64 " \
112 | "--output_dir %s"
113 |
114 | for s in scale:
115 | for r in rotations:
116 | resolved_submit_cmd = job_cmd%(model_path, r, s, output_dir)
117 | print(resolved_submit_cmd)
118 | os.system(resolved_submit_cmd)
119 |
120 | def main(args):
121 | model_path = args.model_path
122 | mode = args.multiscale_inference
123 | output_dir = args.output_dir
124 | run_multiscale_inference(model_path, mode, output_dir)
125 | if mode==True:
126 | multiscale_fusion(output_dir)
127 |
128 |
129 | if __name__ == "__main__":
130 | parser = argparse.ArgumentParser(description="Evaluate a checkpoint in the folder")
131 | parser.add_argument("--model_path")
132 | parser.add_argument("--multiscale_inference", default=False, action='store_true',)
133 | parser.add_argument("--output_dir", default='output/', type=str, required=False,
134 | help="The output directory to save checkpoint and test results.")
135 | args = parser.parse_args()
136 | main(args)
137 |
--------------------------------------------------------------------------------
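
The fusion step in `multiscale_fusion` averages the reference prediction with the Procrustes-aligned predictions from every scale/rotation. The tiny synthetic sketch below reproduces just that averaging; the alignment performed by `get_alignMesh` is omitted.

import numpy as np

# Synthetic stand-ins: a reference prediction and four already-aligned predictions.
ref = np.random.rand(100, 778, 3)
aligned = [ref + 0.001 * np.random.randn(*ref.shape) for _ in range(4)]

# Average the reference with all aligned predictions, as multiscale_fusion does.
fused = ref.copy()
for pred in aligned:
    fused += pred
fused /= (1 + len(aligned))
print(np.abs(fused - ref).mean())  # small residual, as expected
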
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/MeshGraphormer/27f7cdb33d2ce9e77969352e71a4e84eb5d9a522/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/comm.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | This file contains primitives for multi-gpu communication.
6 | This is useful when doing distributed training.
7 | """
8 |
9 | import pickle
10 | import time
11 |
12 | import torch
13 | import torch.distributed as dist
14 |
15 |
16 | def get_world_size():
17 | if not dist.is_available():
18 | return 1
19 | if not dist.is_initialized():
20 | return 1
21 | return dist.get_world_size()
22 |
23 |
24 | def get_rank():
25 | if not dist.is_available():
26 | return 0
27 | if not dist.is_initialized():
28 | return 0
29 | return dist.get_rank()
30 |
31 |
32 | def is_main_process():
33 | return get_rank() == 0
34 |
35 |
36 | def synchronize():
37 | """
38 | Helper function to synchronize (barrier) among all processes when
39 | using distributed training
40 | """
41 | if not dist.is_available():
42 | return
43 | if not dist.is_initialized():
44 | return
45 | world_size = dist.get_world_size()
46 | if world_size == 1:
47 | return
48 | dist.barrier()
49 |
50 |
51 | def gather_on_master(data):
52 | """Same as all_gather, but gathers data on master process only, using CPU.
53 | Thus, this does not work with NCCL backend unless they add CPU support.
54 |
55 | The memory consumption of this function is ~3x the data size. While in
56 | principle it should be ~2x, it is not easy to force Python to release
57 | memory immediately, so peak memory usage can be up to 3x.
58 | """
59 | world_size = get_world_size()
60 | if world_size == 1:
61 | return [data]
62 |
63 | # serialized to a Tensor
64 | buffer = pickle.dumps(data)
65 | # trying to optimize memory, but in fact, it's not guaranteed to be released
66 | del data
67 | storage = torch.ByteStorage.from_buffer(buffer)
68 | del buffer
69 | tensor = torch.ByteTensor(storage)
70 |
71 | # obtain Tensor size of each rank
72 | local_size = torch.LongTensor([tensor.numel()])
73 | size_list = [torch.LongTensor([0]) for _ in range(world_size)]
74 | dist.all_gather(size_list, local_size)
75 | size_list = [int(size.item()) for size in size_list]
76 | max_size = max(size_list)
77 |
78 | if local_size != max_size:
79 | padding = torch.ByteTensor(size=(max_size - local_size,))
80 | tensor = torch.cat((tensor, padding), dim=0)
81 | del padding
82 |
83 | if is_main_process():
84 | tensor_list = []
85 | for _ in size_list:
86 | tensor_list.append(torch.ByteTensor(size=(max_size,)))
87 | dist.gather(tensor, gather_list=tensor_list, dst=0)
88 | del tensor
89 | else:
90 | dist.gather(tensor, gather_list=[], dst=0)
91 | del tensor
92 | return
93 |
94 | data_list = []
95 | for tensor in tensor_list:
96 | buffer = tensor.cpu().numpy().tobytes()
97 | del tensor
98 | data_list.append(pickle.loads(buffer))
99 | del buffer
100 |
101 | return data_list
102 |
103 |
104 | def all_gather(data):
105 | """
106 | Run all_gather on arbitrary picklable data (not necessarily tensors)
107 | Args:
108 | data: any picklable object
109 | Returns:
110 | list[data]: list of data gathered from each rank
111 | """
112 | world_size = get_world_size()
113 | if world_size == 1:
114 | return [data]
115 |
116 | # serialized to a Tensor
117 | buffer = pickle.dumps(data)
118 | storage = torch.ByteStorage.from_buffer(buffer)
119 | tensor = torch.ByteTensor(storage).to("cuda")
120 |
121 | # obtain Tensor size of each rank
122 | local_size = torch.LongTensor([tensor.numel()]).to("cuda")
123 | size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
124 | dist.all_gather(size_list, local_size)
125 | size_list = [int(size.item()) for size in size_list]
126 | max_size = max(size_list)
127 |
128 | # receiving Tensor from all ranks
129 | # we pad the tensor because torch all_gather does not support
130 | # gathering tensors of different shapes
131 | tensor_list = []
132 | for _ in size_list:
133 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
134 | if local_size != max_size:
135 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
136 | tensor = torch.cat((tensor, padding), dim=0)
137 | dist.all_gather(tensor_list, tensor)
138 |
139 | data_list = []
140 | for size, tensor in zip(size_list, tensor_list):
141 | buffer = tensor.cpu().numpy().tobytes()[:size]
142 | data_list.append(pickle.loads(buffer))
143 |
144 | return data_list
145 |
146 |
147 | def reduce_dict(input_dict, average=True):
148 | """
149 | Args:
150 | input_dict (dict): all the values will be reduced
151 | average (bool): whether to do average or sum
152 | Reduce the values in the dictionary from all processes so that process with rank
153 | 0 has the averaged results. Returns a dict with the same fields as
154 | input_dict, after reduction.
155 | """
156 | world_size = get_world_size()
157 | if world_size < 2:
158 | return input_dict
159 | with torch.no_grad():
160 | names = []
161 | values = []
162 | # sort the keys so that they are consistent across processes
163 | for k in sorted(input_dict.keys()):
164 | names.append(k)
165 | values.append(input_dict[k])
166 | values = torch.stack(values, dim=0)
167 | dist.reduce(values, dst=0)
168 | if dist.get_rank() == 0 and average:
169 | # only main process gets accumulated, so only divide by
170 | # world_size in this case
171 | values /= world_size
172 | reduced_dict = {k: v for k, v in zip(names, values)}
173 | return reduced_dict
174 |
--------------------------------------------------------------------------------
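A minimal usage sketch for the distributed helpers in comm.py above, assuming a process group has already been initialized (e.g. via torch.distributed.init_process_group) and CUDA is available, since all_gather stages its byte tensors on the GPU; the loss value is illustrative only:

    import torch
    from src.utils.comm import all_gather, reduce_dict, synchronize, is_main_process

    # average scalar statistics onto rank 0 (returns the input unchanged on a single process)
    loss_dict = {'loss_3d_joints': torch.tensor(0.42, device='cuda')}
    loss_dict_reduced = reduce_dict(loss_dict, average=True)

    # collect arbitrary picklable results from every rank
    local_preds = {'rank_result': [1, 2, 3]}
    all_preds = all_gather(local_preds)   # list with one entry per rank

    synchronize()                         # barrier so all ranks leave this point together
    if is_main_process():
        print(loss_dict_reduced, len(all_preds))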
/src/utils/dataset_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | """
6 |
7 |
8 | import os
9 | import os.path as op
10 | import numpy as np
11 | import base64
12 | import cv2
13 | import yaml
14 | from collections import OrderedDict
15 |
16 |
17 | def img_from_base64(imagestring):
18 | try:
19 | jpgbytestring = base64.b64decode(imagestring)
20 | nparr = np.frombuffer(jpgbytestring, np.uint8)
21 | r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
22 | return r
23 | except Exception:
24 | return None
25 |
26 |
27 | def load_labelmap(labelmap_file):
28 | label_dict = None
29 | if labelmap_file is not None and op.isfile(labelmap_file):
30 | label_dict = OrderedDict()
31 | with open(labelmap_file, 'r') as fp:
32 | for line in fp:
33 | label = line.strip().split('\t')[0]
34 | if label in label_dict:
35 | raise ValueError("Duplicate label " + label + " in labelmap.")
36 | else:
37 | label_dict[label] = len(label_dict)
38 | return label_dict
39 |
40 |
41 | def load_shuffle_file(shuf_file):
42 | shuf_list = None
43 | if shuf_file is not None:
44 | with open(shuf_file, 'r') as fp:
45 | shuf_list = []
46 | for i in fp:
47 | shuf_list.append(int(i.strip()))
48 | return shuf_list
49 |
50 |
51 | def load_box_shuffle_file(shuf_file):
52 | if shuf_file is not None:
53 | with open(shuf_file, 'r') as fp:
54 | img_shuf_list = []
55 | box_shuf_list = []
56 | for i in fp:
57 | idx = [int(_) for _ in i.strip().split('\t')]
58 | img_shuf_list.append(idx[0])
59 | box_shuf_list.append(idx[1])
60 | return [img_shuf_list, box_shuf_list]
61 | return None
62 |
63 |
64 | def load_from_yaml_file(file_name):
65 | with open(file_name, 'r') as fp:
66 | return yaml.load(fp, Loader=yaml.CLoader)
67 |
--------------------------------------------------------------------------------
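A minimal sketch of the helpers above; labelmap.txt and sample.jpg are hypothetical files used only for illustration:

    import base64
    from src.utils.dataset_utils import img_from_base64, load_labelmap

    label_dict = load_labelmap('labelmap.txt')   # OrderedDict mapping label -> index, or None

    with open('sample.jpg', 'rb') as f:
        b64_string = base64.b64encode(f.read())
    img = img_from_base64(b64_string)            # BGR numpy array, or None on decode failure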
/src/utils/geometric_layers.py:
--------------------------------------------------------------------------------
1 | """
2 | Useful geometric operations, e.g. Orthographic projection and a differentiable Rodrigues formula
3 |
4 | Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR
5 | """
6 | import torch
7 |
8 | def rodrigues(theta):
9 | """Convert axis-angle representation to rotation matrix.
10 | Args:
11 | theta: size = [B, 3]
12 | Returns:
13 | Rotation matrix corresponding to the axis-angle input -- size = [B, 3, 3]
14 | """
15 | l2norm = torch.norm(theta + 1e-8, p = 2, dim = 1)
16 | angle = torch.unsqueeze(l2norm, -1)
17 | normalized = torch.div(theta, angle)
18 | angle = angle * 0.5
19 | v_cos = torch.cos(angle)
20 | v_sin = torch.sin(angle)
21 | quat = torch.cat([v_cos, v_sin * normalized], dim = 1)
22 | return quat2mat(quat)
23 |
24 | def quat2mat(quat):
25 | """Convert quaternion coefficients to rotation matrix.
26 | Args:
27 | quat: size = [B, 4] 4 <===>(w, x, y, z)
28 | Returns:
29 | Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
30 | """
31 | norm_quat = quat
32 | norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
33 | w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
34 |
35 | B = quat.size(0)
36 |
37 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
38 | wx, wy, wz = w*x, w*y, w*z
39 | xy, xz, yz = x*y, x*z, y*z
40 |
41 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
42 | 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
43 | 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3)
44 | return rotMat
45 |
46 | def orthographic_projection(X, camera):
47 | """Perform orthographic projection of 3D points X using the camera parameters
48 | Args:
49 | X: size = [B, N, 3]
50 | camera: size = [B, 3]
51 | Returns:
52 | Projected 2D points -- size = [B, N, 2]
53 | """
54 | camera = camera.view(-1, 1, 3)
55 | X_trans = X[:, :, :2] + camera[:, :, 1:]
56 | shape = X_trans.shape
57 | X_2d = (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape)
58 | return X_2d
59 |
--------------------------------------------------------------------------------
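A minimal sketch of the two operations above on random tensors (shapes follow the docstrings; the camera values are illustrative):

    import torch
    from src.utils.geometric_layers import rodrigues, orthographic_projection

    theta = torch.zeros(2, 3)                     # zero axis-angle -> (near-)identity rotations
    R = rodrigues(theta)                          # [2, 3, 3]

    X = torch.randn(2, 14, 3)                     # batch of 3D joints
    camera = torch.tensor([[1.0, 0.0, 0.0],       # per-sample [scale, tx, ty]
                           [0.9, 0.1, -0.1]])
    X_2d = orthographic_projection(X, camera)     # [2, 14, 2]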
/src/utils/image_ops.py:
--------------------------------------------------------------------------------
1 | """
2 | Image processing tools
3 |
4 | Modified from open source projects:
5 | (https://github.com/nkolot/GraphCMR/)
6 | (https://github.com/open-mmlab/mmdetection)
7 |
8 | """
9 |
10 | import numpy as np
11 | import base64
12 | import cv2
13 | import torch
14 |
15 |
16 | def img_from_base64(imagestring):
17 | try:
18 | jpgbytestring = base64.b64decode(imagestring)
19 | nparr = np.frombuffer(jpgbytestring, np.uint8)
20 | r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
21 | return r
22 | except ValueError:
23 | return None
24 |
25 | def myimrotate(img, angle, center=None, scale=1.0, border_value=0, auto_bound=False):
26 | if center is not None and auto_bound:
27 | raise ValueError('`auto_bound` conflicts with `center`')
28 | h, w = img.shape[:2]
29 | if center is None:
30 | center = ((w - 1) * 0.5, (h - 1) * 0.5)
31 | assert isinstance(center, tuple)
32 |
33 | matrix = cv2.getRotationMatrix2D(center, angle, scale)
34 | if auto_bound:
35 | cos = np.abs(matrix[0, 0])
36 | sin = np.abs(matrix[0, 1])
37 | new_w = h * sin + w * cos
38 | new_h = h * cos + w * sin
39 | matrix[0, 2] += (new_w - w) * 0.5
40 | matrix[1, 2] += (new_h - h) * 0.5
41 | w = int(np.round(new_w))
42 | h = int(np.round(new_h))
43 | rotated = cv2.warpAffine(img, matrix, (w, h), borderValue=border_value)
44 | return rotated
45 |
46 | def myimresize(img, size, return_scale=False, interpolation='bilinear'):
47 |
48 | h, w = img.shape[:2]
49 | resized_img = cv2.resize(
50 | img, (size[0],size[1]), interpolation=cv2.INTER_LINEAR)
51 | if not return_scale:
52 | return resized_img
53 | else:
54 | w_scale = size[0] / w
55 | h_scale = size[1] / h
56 | return resized_img, w_scale, h_scale
57 |
58 |
59 | def get_transform(center, scale, res, rot=0):
60 | """Generate transformation matrix."""
61 | h = 200 * scale
62 | t = np.zeros((3, 3))
63 | t[0, 0] = float(res[1]) / h
64 | t[1, 1] = float(res[0]) / h
65 | t[0, 2] = res[1] * (-float(center[0]) / h + .5)
66 | t[1, 2] = res[0] * (-float(center[1]) / h + .5)
67 | t[2, 2] = 1
68 | if not rot == 0:
69 | rot = -rot # To match direction of rotation from cropping
70 | rot_mat = np.zeros((3,3))
71 | rot_rad = rot * np.pi / 180
72 | sn,cs = np.sin(rot_rad), np.cos(rot_rad)
73 | rot_mat[0,:2] = [cs, -sn]
74 | rot_mat[1,:2] = [sn, cs]
75 | rot_mat[2,2] = 1
76 | # Need to rotate around center
77 | t_mat = np.eye(3)
78 | t_mat[0,2] = -res[1]/2
79 | t_mat[1,2] = -res[0]/2
80 | t_inv = t_mat.copy()
81 | t_inv[:2,2] *= -1
82 | t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t)))
83 | return t
84 |
85 | def transform(pt, center, scale, res, invert=0, rot=0):
86 | """Transform pixel location to different reference."""
87 | t = get_transform(center, scale, res, rot=rot)
88 | if invert:
89 | # t = np.linalg.inv(t)
90 | t_torch = torch.from_numpy(t)
91 | t_torch = torch.inverse(t_torch)
92 | t = t_torch.numpy()
93 | new_pt = np.array([pt[0]-1, pt[1]-1, 1.]).T
94 | new_pt = np.dot(t, new_pt)
95 | return new_pt[:2].astype(int)+1
96 |
97 | def crop(img, center, scale, res, rot=0):
98 | """Crop image according to the supplied bounding box."""
99 | # Upper left point
100 | ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
101 | # Bottom right point
102 | br = np.array(transform([res[0]+1,
103 | res[1]+1], center, scale, res, invert=1))-1
104 | # Padding so that when rotated proper amount of context is included
105 | pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
106 | if not rot == 0:
107 | ul -= pad
108 | br += pad
109 | new_shape = [br[1] - ul[1], br[0] - ul[0]]
110 | if len(img.shape) > 2:
111 | new_shape += [img.shape[2]]
112 | new_img = np.zeros(new_shape)
113 |
114 | # Range to fill new array
115 | new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
116 | new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
117 | # Range to sample from original image
118 | old_x = max(0, ul[0]), min(len(img[0]), br[0])
119 | old_y = max(0, ul[1]), min(len(img), br[1])
120 |
121 | new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
122 | old_x[0]:old_x[1]]
123 | if not rot == 0:
124 | # Remove padding
125 | # new_img = scipy.misc.imrotate(new_img, rot)
126 | new_img = myimrotate(new_img, rot)
127 | new_img = new_img[pad:-pad, pad:-pad]
128 |
129 | # new_img = scipy.misc.imresize(new_img, res)
130 | new_img = myimresize(new_img, [res[0], res[1]])
131 | return new_img
132 |
133 | def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
134 | """'Undo' the image cropping/resizing.
135 | This function is used when evaluating mask/part segmentation.
136 | """
137 | res = img.shape[:2]
138 | # Upper left point
139 | ul = np.array(transform([1, 1], center, scale, res, invert=1))-1
140 | # Bottom right point
141 | br = np.array(transform([res[0]+1,res[1]+1], center, scale, res, invert=1))-1
142 | # size of cropped image
143 | crop_shape = [br[1] - ul[1], br[0] - ul[0]]
144 |
145 | new_shape = [br[1] - ul[1], br[0] - ul[0]]
146 | if len(img.shape) > 2:
147 | new_shape += [img.shape[2]]
148 | new_img = np.zeros(orig_shape, dtype=np.uint8)
149 | # Range to fill new array
150 | new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
151 | new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
152 | # Range to sample from original image
153 | old_x = max(0, ul[0]), min(orig_shape[1], br[0])
154 | old_y = max(0, ul[1]), min(orig_shape[0], br[1])
155 | # img = scipy.misc.imresize(img, crop_shape, interp='nearest')
156 | img = myimresize(img, [crop_shape[0],crop_shape[1]])
157 | new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
158 | return new_img
159 |
160 | def rot_aa(aa, rot):
161 | """Rotate axis angle parameters."""
162 | # pose parameters
163 | R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
164 | [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
165 | [0, 0, 1]])
166 | # find the rotation of the body in camera frame
167 | per_rdg, _ = cv2.Rodrigues(aa)
168 | # apply the global rotation to the global orientation
169 | resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg))
170 | aa = (resrot.T)[0]
171 | return aa
172 |
173 | def flip_img(img):
174 | """Flip rgb images or masks.
175 | channels come last, e.g. (256,256,3).
176 | """
177 | img = np.fliplr(img)
178 | return img
179 |
180 | def flip_kp(kp):
181 | """Flip keypoints."""
182 | flipped_parts = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22]
183 | kp = kp[flipped_parts]
184 | kp[:,0] = - kp[:,0]
185 | return kp
186 |
187 | def flip_pose(pose):
188 | """Flip pose.
189 | The flipping is based on SMPL parameters.
190 | """
191 | flippedParts = [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13,
192 | 14 ,18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33,
193 | 34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41,
194 | 45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55,
195 | 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68]
196 | pose = pose[flippedParts]
197 | # we also negate the second and the third dimension of the axis-angle
198 | pose[1::3] = -pose[1::3]
199 | pose[2::3] = -pose[2::3]
200 | return pose
201 |
202 | def flip_aa(aa):
203 | """Flip axis-angle representation.
204 | We negate the second and the third dimension of the axis-angle.
205 | """
206 | aa[1] = -aa[1]
207 | aa[2] = -aa[2]
208 | return aa
--------------------------------------------------------------------------------
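A minimal sketch of the cropping utilities above on a dummy image; the center, scale and 224x224 resolution are illustrative values (scale follows the 200 * scale person-height convention in get_transform):

    import numpy as np
    from src.utils.image_ops import crop, flip_img

    img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)   # dummy H x W x 3 image
    center = [320, 240]                                          # bounding-box center (x, y)
    scale = 1.0                                                  # person height ~ 200 px
    patch = crop(img, center, scale, res=[224, 224], rot=0)      # 224 x 224 crop
    patch_flipped = flip_img(patch)                              # horizontal flip augmentation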
/src/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import logging
3 | import os
4 | import sys
5 | from logging import StreamHandler, Handler, getLevelName
6 |
7 |
8 | # this class is a copy of logging.FileHandler except that we call self.close()
9 | # at the end of each emit. While closing and reopening the file after each
10 | # write is not efficient, it allows us to see partial logs when writing to
11 | # fuse-mounted Azure blobs, which is very convenient
12 | class FileHandler(StreamHandler):
13 | """
14 | A handler class which writes formatted logging records to disk files.
15 | """
16 | def __init__(self, filename, mode='a', encoding=None, delay=False):
17 | """
18 | Open the specified file and use it as the stream for logging.
19 | """
20 | # Issue #27493: add support for Path objects to be passed in
21 | filename = os.fspath(filename)
22 | #keep the absolute path, otherwise derived classes which use this
23 | #may come a cropper when the current directory changes
24 | self.baseFilename = os.path.abspath(filename)
25 | self.mode = mode
26 | self.encoding = encoding
27 | self.delay = delay
28 | if delay:
29 | #We don't open the stream, but we still need to call the
30 | #Handler constructor to set level, formatter, lock etc.
31 | Handler.__init__(self)
32 | self.stream = None
33 | else:
34 | StreamHandler.__init__(self, self._open())
35 |
36 | def close(self):
37 | """
38 | Closes the stream.
39 | """
40 | self.acquire()
41 | try:
42 | try:
43 | if self.stream:
44 | try:
45 | self.flush()
46 | finally:
47 | stream = self.stream
48 | self.stream = None
49 | if hasattr(stream, "close"):
50 | stream.close()
51 | finally:
52 | # Issue #19523: call unconditionally to
53 | # prevent a handler leak when delay is set
54 | StreamHandler.close(self)
55 | finally:
56 | self.release()
57 |
58 | def _open(self):
59 | """
60 | Open the current base file with the (original) mode and encoding.
61 | Return the resulting stream.
62 | """
63 | return open(self.baseFilename, self.mode, encoding=self.encoding)
64 |
65 | def emit(self, record):
66 | """
67 | Emit a record.
68 |
69 | If the stream was not opened because 'delay' was specified in the
70 | constructor, open it before calling the superclass's emit.
71 | """
72 | if self.stream is None:
73 | self.stream = self._open()
74 | StreamHandler.emit(self, record)
75 | self.close()
76 |
77 | def __repr__(self):
78 | level = getLevelName(self.level)
79 | return '<%s %s (%s)>' % (self.__class__.__name__, self.baseFilename, level)
80 |
81 |
82 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt"):
83 | logger = logging.getLogger(name)
84 | logger.setLevel(logging.DEBUG)
85 | # don't log results for the non-master process
86 | if distributed_rank > 0:
87 | return logger
88 | ch = logging.StreamHandler(stream=sys.stdout)
89 | ch.setLevel(logging.DEBUG)
90 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
91 | ch.setFormatter(formatter)
92 | logger.addHandler(ch)
93 |
94 | if save_dir:
95 | fh = FileHandler(os.path.join(save_dir, filename))
96 | fh.setLevel(logging.DEBUG)
97 | fh.setFormatter(formatter)
98 | logger.addHandler(fh)
99 |
100 | return logger
101 |
--------------------------------------------------------------------------------
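A minimal sketch of setup_logger; the output directory is hypothetical and is assumed to exist, since FileHandler opens save_dir/log.txt directly:

    import os
    from src.utils.logger import setup_logger

    os.makedirs('output', exist_ok=True)
    logger = setup_logger("Graphormer", "output", distributed_rank=0, filename="log.txt")
    logger.info("logging to stdout and to output/log.txt")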
/src/utils/metric_logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | Basic logger. It computes and stores the average and current value
6 | """
7 |
8 | class AverageMeter(object):
9 |
10 | def __init__(self):
11 | self.reset()
12 |
13 | def reset(self):
14 | self.val = 0
15 | self.avg = 0
16 | self.sum = 0
17 | self.count = 0
18 |
19 | def update(self, val, n=1):
20 | self.val = val
21 | self.sum += val * n
22 | self.count += n
23 | self.avg = self.sum / self.count
24 |
25 |
26 |
27 | class EvalMetricsLogger(object):
28 |
29 | def __init__(self):
30 | self.reset()
31 |
32 | def reset(self):
33 | # define an upper-bound performance (worst case)
34 | # values are stored in meters (0.1 m = 100 mm)
35 | self.PAmPJPE = 100.0/1000.0
36 | self.mPJPE = 100.0/1000.0
37 | self.mPVE = 100.0/1000.0
38 |
39 | self.epoch = 0
40 |
41 | def update(self, mPVE, mPJPE, PAmPJPE, epoch):
42 | self.PAmPJPE = PAmPJPE
43 | self.mPJPE = mPJPE
44 | self.mPVE = mPVE
45 | self.epoch = epoch
46 |
--------------------------------------------------------------------------------
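A minimal sketch of the two meters above; the loss values and errors are made up for illustration:

    from src.utils.metric_logger import AverageMeter, EvalMetricsLogger

    loss_meter = AverageMeter()
    for batch_loss in [0.9, 0.7, 0.6]:
        loss_meter.update(batch_loss, n=1)
    print(loss_meter.avg)                 # running mean: (0.9 + 0.7 + 0.6) / 3

    eval_log = EvalMetricsLogger()        # starts at the 100 mm worst-case bounds
    eval_log.update(mPVE=0.060, mPJPE=0.055, PAmPJPE=0.035, epoch=10)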
/src/utils/metric_pampjpe.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for computing Procrustes alignment and reconstruction error
3 |
4 | Parts of the code are adapted from https://github.com/akanazawa/hmr
5 |
6 | """
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 | import numpy as np
11 |
12 | def compute_similarity_transform(S1, S2):
13 | """Computes a similarity transform (sR, t) that takes
14 | a set of 3D points S1 (3 x N) closest to a set of 3D points S2,
15 | where R is an 3x3 rotation matrix, t 3x1 translation, s scale.
16 | i.e. solves the orthogonal Procrutes problem.
17 | """
18 | transposed = False
19 | if S1.shape[0] != 3 and S1.shape[0] != 2:
20 | S1 = S1.T
21 | S2 = S2.T
22 | transposed = True
23 | assert(S2.shape[1] == S1.shape[1])
24 |
25 | # 1. Remove mean.
26 | mu1 = S1.mean(axis=1, keepdims=True)
27 | mu2 = S2.mean(axis=1, keepdims=True)
28 | X1 = S1 - mu1
29 | X2 = S2 - mu2
30 |
31 | # 2. Compute variance of X1 used for scale.
32 | var1 = np.sum(X1**2)
33 |
34 | # 3. The outer product of X1 and X2.
35 | K = X1.dot(X2.T)
36 |
37 | # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are
38 | # singular vectors of K.
39 | U, s, Vh = np.linalg.svd(K)
40 | V = Vh.T
41 | # Construct Z that fixes the orientation of R to get det(R)=1.
42 | Z = np.eye(U.shape[0])
43 | Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))
44 | # Construct R.
45 | R = V.dot(Z.dot(U.T))
46 |
47 | # 5. Recover scale.
48 | scale = np.trace(R.dot(K)) / var1
49 |
50 | # 6. Recover translation.
51 | t = mu2 - scale*(R.dot(mu1))
52 |
53 | # 7. Apply the recovered transform to S1.
54 | S1_hat = scale*R.dot(S1) + t
55 |
56 | if transposed:
57 | S1_hat = S1_hat.T
58 |
59 | return S1_hat
60 |
61 | def compute_similarity_transform_batch(S1, S2):
62 | """Batched version of compute_similarity_transform."""
63 | S1_hat = np.zeros_like(S1)
64 | for i in range(S1.shape[0]):
65 | S1_hat[i] = compute_similarity_transform(S1[i], S2[i])
66 | return S1_hat
67 |
68 | def reconstruction_error(S1, S2, reduction='mean'):
69 | """Do Procrustes alignment and compute reconstruction error."""
70 | S1_hat = compute_similarity_transform_batch(S1, S2)
71 | re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1)
72 | if reduction == 'mean':
73 | re = re.mean()
74 | elif reduction == 'sum':
75 | re = re.sum()
76 | return re
77 |
78 |
79 | def reconstruction_error_v2(S1, S2, J24_TO_J14, reduction='mean'):
80 | """Do Procrustes alignment and compute reconstruction error."""
81 | S1_hat = compute_similarity_transform_batch(S1, S2)
82 | S1_hat = S1_hat[:,J24_TO_J14,:]
83 | S2 = S2[:,J24_TO_J14,:]
84 | re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1)
85 | if reduction == 'mean':
86 | re = re.mean()
87 | elif reduction == 'sum':
88 | re = re.sum()
89 | return re
90 |
91 | def get_alignMesh(S1, S2, reduction='mean'):
92 | """Do Procrustes alignment and compute reconstruction error."""
93 | S1_hat = compute_similarity_transform_batch(S1, S2)
94 | re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1)
95 | if reduction == 'mean':
96 | re = re.mean()
97 | elif reduction == 'sum':
98 | re = re.sum()
99 | return re, S1_hat, S2
100 |
--------------------------------------------------------------------------------
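A minimal sketch of the PA-MPJPE computation above on synthetic joints (a batch of 2 samples with 14 joints each); the noise level is arbitrary:

    import numpy as np
    from src.utils.metric_pampjpe import reconstruction_error

    gt_joints = np.random.rand(2, 14, 3)
    pred_joints = gt_joints + 0.01 * np.random.randn(2, 14, 3)

    # Procrustes-align each predicted set to the ground truth, then average the joint error
    pa_mpjpe = reconstruction_error(pred_joints, gt_joints, reduction='mean')
    print(pa_mpjpe)   # same units as the inputs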
/src/utils/miscellaneous.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import errno
3 | import os
4 | import os.path as op
5 | import re
6 | import logging
7 | import numpy as np
8 | import torch
9 | import random
10 | import shutil
11 | from .comm import is_main_process
12 | import yaml
13 |
14 |
15 | def mkdir(path):
16 | # if it is the current folder, skip.
17 | # otherwise the original code will raise FileNotFoundError
18 | if path == '':
19 | return
20 | try:
21 | os.makedirs(path)
22 | except OSError as e:
23 | if e.errno != errno.EEXIST:
24 | raise
25 |
26 |
27 | def save_config(cfg, path):
28 | if is_main_process():
29 | with open(path, 'w') as f:
30 | f.write(cfg.dump())
31 |
32 |
33 | def config_iteration(output_dir, max_iter):
34 | save_file = os.path.join(output_dir, 'last_checkpoint')
35 | iteration = -1
36 | if os.path.exists(save_file):
37 | with open(save_file, 'r') as f:
38 | fname = f.read().strip()
39 | model_name = os.path.basename(fname)
40 | model_path = os.path.dirname(fname)
41 | if model_name.startswith('model_') and len(model_name) == 17:
42 | iteration = int(model_name[-11:-4])
43 | elif model_name == "model_final":
44 | iteration = max_iter
45 | elif model_path.startswith('checkpoint-') and len(model_path) == 18:
46 | iteration = int(model_path.split('-')[-1])
47 | return iteration
48 |
49 |
50 | def get_matching_parameters(model, regexp, none_on_empty=True):
51 | """Returns parameters matching regular expression"""
52 | if not regexp:
53 | if none_on_empty:
54 | return {}
55 | else:
56 | return dict(model.named_parameters())
57 | compiled_pattern = re.compile(regexp)
58 | params = {}
59 | for weight_name, weight in model.named_parameters():
60 | if compiled_pattern.match(weight_name):
61 | params[weight_name] = weight
62 | return params
63 |
64 |
65 | def freeze_weights(model, regexp):
66 | """Freeze weights based on regular expression."""
67 | logger = logging.getLogger("maskrcnn_benchmark.trainer")
68 | for weight_name, weight in get_matching_parameters(model, regexp).items():
69 | weight.requires_grad = False
70 | logger.info("Disabled training of {}".format(weight_name))
71 |
72 |
73 | def unfreeze_weights(model, regexp, backbone_freeze_at=-1,
74 | is_distributed=False):
75 | """Unfreeze weights based on regular expression.
76 | This is helpful during training to unfreeze frozen weights after
77 | the other (unfrozen) weights have been trained for some iterations.
78 | """
79 | logger = logging.getLogger("maskrcnn_benchmark.trainer")
80 | for weight_name, weight in get_matching_parameters(model, regexp).items():
81 | weight.requires_grad = True
82 | logger.info("Enabled training of {}".format(weight_name))
83 | if backbone_freeze_at >= 0:
84 | logger.info("Freeze backbone at stage: {}".format(backbone_freeze_at))
85 | if is_distributed:
86 | model.module.backbone.body._freeze_backbone(backbone_freeze_at)
87 | else:
88 | model.backbone.body._freeze_backbone(backbone_freeze_at)
89 |
90 |
91 | def delete_tsv_files(tsvs):
92 | for t in tsvs:
93 | if op.isfile(t):
94 | try_delete(t)
95 | line = op.splitext(t)[0] + '.lineidx'
96 | if op.isfile(line):
97 | try_delete(line)
98 |
99 |
100 | def concat_files(ins, out):
101 | mkdir(op.dirname(out))
102 | out_tmp = out + '.tmp'
103 | with open(out_tmp, 'wb') as fp_out:
104 | for i, f in enumerate(ins):
105 | logging.info('concatenating {}/{} - {}'.format(i, len(ins), f))
106 | with open(f, 'rb') as fp_in:
107 | shutil.copyfileobj(fp_in, fp_out, 1024*1024*10)
108 | os.rename(out_tmp, out)
109 |
110 |
111 | def concat_tsv_files(tsvs, out_tsv):
112 | concat_files(tsvs, out_tsv)
113 | sizes = [os.stat(t).st_size for t in tsvs]
114 | sizes = np.cumsum(sizes)
115 | all_idx = []
116 | for i, t in enumerate(tsvs):
117 | for idx in load_list_file(op.splitext(t)[0] + '.lineidx'):
118 | if i == 0:
119 | all_idx.append(idx)
120 | else:
121 | all_idx.append(str(int(idx) + sizes[i - 1]))
122 | with open(op.splitext(out_tsv)[0] + '.lineidx', 'w') as f:
123 | f.write('\n'.join(all_idx))
124 |
125 |
126 | def load_list_file(fname):
127 | with open(fname, 'r') as fp:
128 | lines = fp.readlines()
129 | result = [line.strip() for line in lines]
130 | if len(result) > 0 and result[-1] == '':
131 | result = result[:-1]
132 | return result
133 |
134 |
135 | def try_once(func):
136 | def func_wrapper(*args, **kwargs):
137 | try:
138 | return func(*args, **kwargs)
139 | except Exception as e:
140 | logging.info('ignore error \n{}'.format(str(e)))
141 | return func_wrapper
142 |
143 |
144 | @try_once
145 | def try_delete(f):
146 | os.remove(f)
147 |
148 |
149 | def set_seed(seed, n_gpu):
150 | random.seed(seed)
151 | np.random.seed(seed)
152 | torch.manual_seed(seed)
153 | if n_gpu > 0:
154 | torch.cuda.manual_seed_all(seed)
155 |
156 |
157 | def print_and_run_cmd(cmd):
158 | print(cmd)
159 | os.system(cmd)
160 |
161 |
162 | def write_to_yaml_file(context, file_name):
163 | with open(file_name, 'w') as fp:
164 | yaml.dump(context, fp, encoding='utf-8')
165 |
166 |
167 | def load_from_yaml_file(yaml_file):
168 | with open(yaml_file, 'r') as fp:
169 | return yaml.load(fp, Loader=yaml.CLoader)
170 |
171 |
172 |
--------------------------------------------------------------------------------
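A minimal sketch of a few helpers above; the regular expression and directory name are hypothetical:

    import torch
    from src.utils.miscellaneous import mkdir, set_seed, freeze_weights

    set_seed(88, n_gpu=torch.cuda.device_count())   # seed python, numpy, torch (and CUDA if present)
    mkdir('output/demo')                            # no error if the directory already exists

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
    freeze_weights(model, r'0\..*')                 # freeze the parameters of the first layer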
/src/utils/tsv_file.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | Definition of TSV class
6 | """
7 |
8 |
9 | import logging
10 | import os
11 | import os.path as op
12 |
13 |
14 | def generate_lineidx(filein, idxout):
15 | idxout_tmp = idxout + '.tmp'
16 | with open(filein, 'r') as tsvin, open(idxout_tmp,'w') as tsvout:
17 | fsize = os.fstat(tsvin.fileno()).st_size
18 | fpos = 0
19 | while fpos!=fsize:
20 | tsvout.write(str(fpos)+"\n")
21 | tsvin.readline()
22 | fpos = tsvin.tell()
23 | os.rename(idxout_tmp, idxout)
24 |
25 |
26 | def read_to_character(fp, c):
27 | result = []
28 | while True:
29 | s = fp.read(32)
30 | assert s != ''
31 | if c in s:
32 | result.append(s[: s.index(c)])
33 | break
34 | else:
35 | result.append(s)
36 | return ''.join(result)
37 |
38 |
39 | class TSVFile(object):
40 | def __init__(self, tsv_file, generate_lineidx_if_missing=False):
41 | self.tsv_file = tsv_file
42 | self.lineidx = op.splitext(tsv_file)[0] + '.lineidx'
43 | self._fp = None
44 | self._lineidx = None
45 | # keep track of the process (pid) that opened the file.
46 | # If the pid is not equal to the current pid, we will re-open the file.
47 | self.pid = None
48 | # generate the lineidx file if it does not exist
49 | if not op.isfile(self.lineidx) and generate_lineidx_if_missing:
50 | generate_lineidx(self.tsv_file, self.lineidx)
51 |
52 | def __del__(self):
53 | if self._fp:
54 | self._fp.close()
55 |
56 | def __str__(self):
57 | return "TSVFile(tsv_file='{}')".format(self.tsv_file)
58 |
59 | def __repr__(self):
60 | return str(self)
61 |
62 | def num_rows(self):
63 | self._ensure_lineidx_loaded()
64 | return len(self._lineidx)
65 |
66 | def seek(self, idx):
67 | self._ensure_tsv_opened()
68 | self._ensure_lineidx_loaded()
69 | try:
70 | pos = self._lineidx[idx]
71 | except:
72 | logging.info('{}-{}'.format(self.tsv_file, idx))
73 | raise
74 | self._fp.seek(pos)
75 | return [s.strip() for s in self._fp.readline().split('\t')]
76 |
77 | def seek_first_column(self, idx):
78 | self._ensure_tsv_opened()
79 | self._ensure_lineidx_loaded()
80 | pos = self._lineidx[idx]
81 | self._fp.seek(pos)
82 | return read_to_character(self._fp, '\t')
83 |
84 | def get_key(self, idx):
85 | return self.seek_first_column(idx)
86 |
87 | def __getitem__(self, index):
88 | return self.seek(index)
89 |
90 | def __len__(self):
91 | return self.num_rows()
92 |
93 | def _ensure_lineidx_loaded(self):
94 | if self._lineidx is None:
95 | logging.info('loading lineidx: {}'.format(self.lineidx))
96 | with open(self.lineidx, 'r') as fp:
97 | self._lineidx = [int(i.strip()) for i in fp.readlines()]
98 |
99 | def _ensure_tsv_opened(self):
100 | if self._fp is None:
101 | self._fp = open(self.tsv_file, 'r')
102 | self.pid = os.getpid()
103 |
104 | if self.pid != os.getpid():
105 | logging.info('re-open {} because the process id changed'.format(self.tsv_file))
106 | self._fp = open(self.tsv_file, 'r')
107 | self.pid = os.getpid()
108 |
109 |
110 | class CompositeTSVFile():
111 | def __init__(self, file_list, seq_file, root='.'):
112 | if isinstance(file_list, str):
113 | self.file_list = load_list_file(file_list)
114 | else:
115 | assert isinstance(file_list, list)
116 | self.file_list = file_list
117 |
118 | self.seq_file = seq_file
119 | self.root = root
120 | self.initialized = False
121 | self.initialize()
122 |
123 | def get_key(self, index):
124 | idx_source, idx_row = self.seq[index]
125 | k = self.tsvs[idx_source].get_key(idx_row)
126 | return '_'.join([self.file_list[idx_source], k])
127 |
128 | def num_rows(self):
129 | return len(self.seq)
130 |
131 | def __getitem__(self, index):
132 | idx_source, idx_row = self.seq[index]
133 | return self.tsvs[idx_source].seek(idx_row)
134 |
135 | def __len__(self):
136 | return len(self.seq)
137 |
138 | def initialize(self):
139 | '''
140 | This function has to be called in the init function if cache_policy is
141 | enabled. Thus, we always call it in the init function to keep things simple.
142 | '''
143 | if self.initialized:
144 | return
145 | self.seq = []
146 | with open(self.seq_file, 'r') as fp:
147 | for line in fp:
148 | parts = line.strip().split('\t')
149 | self.seq.append([int(parts[0]), int(parts[1])])
150 | self.tsvs = [TSVFile(op.join(self.root, f)) for f in self.file_list]
151 | self.initialized = True
152 |
153 |
154 | def load_list_file(fname):
155 | with open(fname, 'r') as fp:
156 | lines = fp.readlines()
157 | result = [line.strip() for line in lines]
158 | if len(result) > 0 and result[-1] == '':
159 | result = result[:-1]
160 | return result
161 |
162 |
163 |
--------------------------------------------------------------------------------
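A minimal sketch of reading a TSV through the class above; train.tsv is a hypothetical file that already has a matching train.lineidx:

    from src.utils.tsv_file import TSVFile

    tsv = TSVFile('train.tsv')
    print(tsv.num_rows())          # row count, taken from the .lineidx offsets
    key = tsv.get_key(0)           # first column of row 0
    row = tsv[0]                   # full row 0 as a list of stripped strings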
/src/utils/tsv_file_ops.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Microsoft Corporation.
3 | Licensed under the MIT license.
4 |
5 | Basic operations for TSV files
6 | """
7 |
8 |
9 | import os, errno  # errno is needed by find_file_path_in_yaml
10 | import os.path as op
11 | import json
12 | import numpy as np
13 | import base64
14 | import cv2
15 | from tqdm import tqdm
16 | import yaml
17 | from src.utils.miscellaneous import mkdir
18 | from src.utils.tsv_file import TSVFile
19 |
20 |
21 | def img_from_base64(imagestring):
22 | try:
23 | jpgbytestring = base64.b64decode(imagestring)
24 | nparr = np.frombuffer(jpgbytestring, np.uint8)
25 | r = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
26 | return r
27 | except ValueError:
28 | return None
29 |
30 | def load_linelist_file(linelist_file):
31 | if linelist_file is not None:
32 | line_list = []
33 | with open(linelist_file, 'r') as fp:
34 | for i in fp:
35 | line_list.append(int(i.strip()))
36 | return line_list
37 |
38 | def tsv_writer(values, tsv_file, sep='\t'):
39 | mkdir(op.dirname(tsv_file))
40 | lineidx_file = op.splitext(tsv_file)[0] + '.lineidx'
41 | idx = 0
42 | tsv_file_tmp = tsv_file + '.tmp'
43 | lineidx_file_tmp = lineidx_file + '.tmp'
44 | with open(tsv_file_tmp, 'w') as fp, open(lineidx_file_tmp, 'w') as fpidx:
45 | assert values is not None
46 | for value in values:
47 | assert value is not None
48 | value = [v if type(v)!=bytes else v.decode('utf-8') for v in value]
49 | v = '{0}\n'.format(sep.join(map(str, value)))
50 | fp.write(v)
51 | fpidx.write(str(idx) + '\n')
52 | idx = idx + len(v)
53 | os.rename(tsv_file_tmp, tsv_file)
54 | os.rename(lineidx_file_tmp, lineidx_file)
55 |
56 | def tsv_reader(tsv_file, sep='\t'):
57 | with open(tsv_file, 'r') as fp:
58 | for i, line in enumerate(fp):
59 | yield [x.strip() for x in line.split(sep)]
60 |
61 | def config_save_file(tsv_file, save_file=None, append_str='.new.tsv'):
62 | if save_file is not None:
63 | return save_file
64 | return op.splitext(tsv_file)[0] + append_str
65 |
66 | def get_line_list(linelist_file=None, num_rows=None):
67 | if linelist_file is not None:
68 | return load_linelist_file(linelist_file)
69 |
70 | if num_rows is not None:
71 | return [i for i in range(num_rows)]
72 |
73 | def generate_hw_file(img_file, save_file=None):
74 | rows = tsv_reader(img_file)
75 | def gen_rows():
76 | for i, row in tqdm(enumerate(rows)):
77 | row1 = [row[0]]
78 | img = img_from_base64(row[-1])
79 | height = img.shape[0]
80 | width = img.shape[1]
81 | row1.append(json.dumps([{"height":height, "width": width}]))
82 | yield row1
83 |
84 | save_file = config_save_file(img_file, save_file, '.hw.tsv')
85 | tsv_writer(gen_rows(), save_file)
86 |
87 | def generate_linelist_file(label_file, save_file=None, ignore_attrs=()):
88 | # generate a list of image that has labels
89 | # images with only ignore labels are not selected.
90 | line_list = []
91 | rows = tsv_reader(label_file)
92 | for i, row in tqdm(enumerate(rows)):
93 | labels = json.loads(row[1])
94 | if labels:
95 | if ignore_attrs and all([any([lab[attr] for attr in ignore_attrs if attr in lab]) \
96 | for lab in labels]):
97 | continue
98 | line_list.append([i])
99 |
100 | save_file = config_save_file(label_file, save_file, '.linelist.tsv')
101 | tsv_writer(line_list, save_file)
102 |
103 | def load_from_yaml_file(yaml_file):
104 | with open(yaml_file, 'r') as fp:
105 | return yaml.load(fp, Loader=yaml.CLoader)
106 |
107 | def find_file_path_in_yaml(fname, root):
108 | if fname is not None:
109 | if op.isfile(fname):
110 | return fname
111 | elif op.isfile(op.join(root, fname)):
112 | return op.join(root, fname)
113 | else:
114 | raise FileNotFoundError(
115 | errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname)
116 | )
117 |
--------------------------------------------------------------------------------
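A minimal round-trip sketch for the TSV helpers above; output/demo.tsv is a hypothetical path:

    from src.utils.tsv_file_ops import tsv_writer, tsv_reader

    rows = [["img_0", "label_a"], ["img_1", "label_b"]]
    tsv_writer(rows, 'output/demo.tsv')        # also writes output/demo.lineidx
    for row in tsv_reader('output/demo.tsv'):
        print(row)                             # ['img_0', 'label_a'], then ['img_1', 'label_b']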