├── LICENSE ├── README.md ├── assets ├── framework.jpg └── internet video.gif ├── environment.yml └── src ├── base_modules.py ├── dataset ├── __init__.py └── human36m.py ├── lifting.py ├── main.py ├── network ├── __init__.py ├── dgridconv.py ├── dgridconv_autogrids.py └── gridconv.py └── tool ├── __init__.py ├── argument.py ├── log.py └── util.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | Copyright (c) 2023 OSVAI/GridConv 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. 
For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. 
You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright [yyyy] [name of copyright owner] 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 3D Human Pose Lifting with Grid Convolution 2 | 3 | --- 4 | By Yangyuxuan Kang, Yuyang Liu, Anbang Yao, Shandong Wang, and Enhua Wu. 5 | 6 | This repository is an official Pytorch implementation of "[3D Human Pose Lifting with Grid Convolution](http://arxiv.org/abs/2302.08760)", 7 | dubbed GridConv. The paper is published in AAAI 2023 as an **oral presentation**. 
8 | 9 | GridConv is a powerful new representation learning paradigm that lifts a 2D human pose 10 | to its 3D estimate. It relies on a learnable, regular, weave-like grid pose 11 | representation instead of the predominant irregular graph structures. 12 | 13 | ![](assets/framework.jpg) 14 | *Figure 1. Overview of the grid lifting network regressing a 3D human pose from a 2D skeleton input.* 15 | 16 | For the definition and implementation of the SGT designs and grid convolution layers, please refer to our paper 17 | for a thorough explanation. 18 | 19 | ## Installation 20 | 21 | Our experiments are conducted on a GPU server running Ubuntu 18.04 LTS, Python 2.7, and PyTorch 1.4. 22 | 23 | ``` 24 | cd GridConv 25 | conda env create -f environment.yml 26 | conda activate gridconv 27 | ``` 28 | 29 | ## Dataset Setup 30 | 31 | --- 32 | ### Human3.6M 33 | - Get the preprocessed `h36m.zip` ([Google Drive](https://drive.google.com/file/d/168_bVbJA0zMz37_IgYP18XECYT-_FuHM/view?usp=share_link)) 34 | - `mv h36m.zip ${GridConv_repo}/src/` 35 | - `unzip h36m.zip` 36 | 37 | 38 | ### Customized Dataset 39 | - The directory structure should look like: 40 | ``` 41 | ${GridConv_repo} 42 | ├── src 43 | ├── data 44 | ├── DATASET_NAME 45 | ├── train_custom_2d_unnorm.pth.tar 46 | ├── train_custom_3d_unnorm.pth.tar 47 | ├── test_custom_2d_unnorm.pth.tar 48 | ├── test_custom_3d_unnorm.pth.tar 49 | ``` 50 | - `*_2d_unnorm.pth.tar` files are `dict`s whose keys are `(SUBJECT, ACTION, FILE_NAME)` 51 | and whose values are 2D positions of shape `(N, 34)`. 52 | - `*_3d_unnorm.pth.tar` files are `dict`s whose keys are `(SUBJECT, ACTION, FILE_NAME)` 53 | and whose values are `dict`s of `{ 'pelvis':N*3, 'joints_3d':N*51, 'camera':[fx,fy,cx,cy] }` (see the preparation sketch near the end of this README). 54 | 55 | 56 | ## Results and Models 57 | ![](assets/internet%20video.gif) 58 | *Figure 2. Qualitative results on Internet videos.* 59 | 60 | The grid lifting network has 2 residual blocks of D-GridConv, 256 latent channels, and a 5x5 grid size, 61 | and is trained on the Human3.6M training set for 100 epochs. 
62 | 63 | Evaluation results of the pretrained models on the Human3.6M test set (S9, S11): 64 | 65 | | 2D Detections | SGT design | MPJPE | PA-MPJPE | Google Drive | 66 | |:------:|:------:|:----:|:----:|:----:| 67 | |GT|Handcrafted|37.15|28.32|[model](https://drive.google.com/file/d/1cH9ZhvRe-1dmzczwa2JnyvCmqK1YJcJW/view?usp=share_link)| 68 | |GT|Learnable|36.39|28.29|[model](https://drive.google.com/file/d/1q7YqGKl-i799nRw_oPeL07B5yy__hQP7/view?usp=share_link)| 69 | |HRNet|Handcrafted|47.93|37.85|[model](https://drive.google.com/file/d/14InSsbMeWInM1X5JYWxod0_h3ptXS8O3/view?usp=share_link)| 70 | |HRNet|Learnable|47.56|37.43|[model](https://drive.google.com/file/d/1O45DjCEcKE74c5Nw939Woie0o9lqll54/view?usp=share_link)| 71 | 72 | ## Evaluation of pretrained models 73 | 74 | --- 75 | Test on HRNet input using the handcrafted SGT: 76 | ``` 77 | cd ./src 78 | python main.py --eval --input hrnet \ 79 | --load pretrained_model/hrnet_d-gridconv.pth.tar \ 80 | --lifting_model dgridconv --padding_mode c z 81 | ``` 82 | Test on HRNet input using the learnable SGT: 83 | ``` 84 | python main.py --eval --input hrnet \ 85 | --load pretrained_model/hrnet_d-gridconv_autogrids.pth.tar \ 86 | --lifting_model dgridconv_autogrids --padding_mode c z 87 | ``` 88 | Test on ground-truth input using the handcrafted SGT: 89 | ``` 90 | python main.py --eval --input gt \ 91 | --load pretrained_model/gt_d-gridconv.pth.tar \ 92 | --lifting_model dgridconv --padding_mode c r 93 | ``` 94 | Test on ground-truth input using the learnable SGT: 95 | ``` 96 | python main.py --eval --input gt \ 97 | --load pretrained_model/gt_d-gridconv_autogrids.pth.tar \ 98 | --lifting_model dgridconv_autogrids --padding_mode c r 99 | ``` 100 | 101 | 102 | 103 | ## Training the model from scratch 104 | 105 | --- 106 | To reproduce the results of our pretrained models, run the following command. 107 | 108 | ``` 109 | python main.py --exp hrnet_dgridconv-autogrids_5x5 \ 110 | --input hrnet --lifting_model dgridconv_autogrids \ 111 | --grid_shape 5 5 --num_block 2 --hidsize 256 \ 112 | --padding_mode c z 113 | ``` 114 | Training on one 1080Ti GPU typically takes about 20 minutes per epoch. We train each model for 100 epochs 115 | with the Adam optimizer. Several settings influence the performance: 116 | - `--grid_shape H W`: the grid pose is 5x5 by default. 117 | With the learnable SGT, the grid size can be set to arbitrary values and may affect accuracy. 118 | - `--padding_mode c/z/r c/z/r`: the grid pose is padded with a 1x1 border before being fed into 119 | `nn.Conv2d`. `c`/`z`/`r` denote `circular` / `zeros` / `replicate` padding, respectively. We found 120 | `(c,r)` works better for GT input and `(c,z)` for HRNet input. 121 | 122 | See [src/tool/argument.py](src/tool/argument.py) for more details on the available arguments. 123 | 124 | ## Citation 125 | 126 | --- 127 | If you find our work useful in your research, please consider citing: 128 | ``` 129 | @inproceedings{kang2023gridconv, 130 | title={3D Human Pose Lifting with Grid Convolution}, 131 | author={Yangyuxuan Kang and Yuyang Liu and Anbang Yao and Shandong Wang and Enhua Wu}, 132 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 133 | year={2023}, 134 | } 135 | ``` 136 | 137 | ## License 138 | 139 | --- 140 | GridConv is released under the Apache License. We encourage its use 141 | for both research and commercial purposes, as long as proper attribution is given. 
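
## Customized dataset preparation (sketch)

---
For the customized dataset format described in the Dataset Setup section above, the snippet below is a minimal, editor-added preparation sketch rather than part of the repository: the subject/action/file names, joint arrays, and camera intrinsics are placeholders to be replaced with your own data.

```python
import numpy as np
import torch

N = 100                                        # number of frames in one placeholder sequence
key = ('S1', 'Walking', 'Walking.54138969')    # (SUBJECT, ACTION, FILE_NAME)

# 2D poses: 17 joints x 2 coordinates = 34 values per frame.
data_2d = {key: np.zeros((N, 34), dtype=np.float32)}

# 3D annotations: pelvis position, 17 joints x 3 = 51 coordinates, and camera intrinsics.
data_3d = {key: {'pelvis': np.zeros((N, 3), dtype=np.float32),
                 'joints_3d': np.zeros((N, 51), dtype=np.float32),
                 'camera': [1000.0, 1000.0, 500.0, 500.0]}}   # [fx, fy, cx, cy], dummy values

torch.save(data_2d, 'data/DATASET_NAME/train_custom_2d_unnorm.pth.tar')
torch.save(data_3d, 'data/DATASET_NAME/train_custom_3d_unnorm.pth.tar')
```

The same structure applies to the `test_custom_*_unnorm.pth.tar` files.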
142 | 143 | ## Acknowledgement 144 | 145 | --- 146 | This repository is built based on [ERD_3DPose](https://github.com/kyang-06/ERD_3DPose), [3d_pose_baseline_pytorch](https://github.com/weigq/3d_pose_baseline_pytorch), 147 | and fine-tuned HRNet detection is fetched from [EvoSkeleton](https://github.com/Nicholasli1995/EvoSkeleton). 148 | We thank the authors for kindly releasing the codes. 149 | -------------------------------------------------------------------------------- /assets/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSVAI/GridConv/f151b118e9d455b8dc45155d3a494b64e95f335a/assets/framework.jpg -------------------------------------------------------------------------------- /assets/internet video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSVAI/GridConv/f151b118e9d455b8dc45155d3a494b64e95f335a/assets/internet video.gif -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gridconv 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - blas=1.0=mkl 8 | - configparser=4.0.2=py27_0 9 | - cudatoolkit=10.1.243=h036e899_8 10 | - hdf5=1.10.2=hba1933b_1 11 | - intel-openmp=2022.0.1=h06a4308_3633 12 | - ipykernel=4.10.0=py27_0 13 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 14 | - ipywidgets=7.6.0=pyhd3eb1b0_1 15 | - libopencv=3.4.2=hb342d67_1 16 | - mkl=2020.2=256 17 | - mkl-service=2.3.0=py27he904b0f_0 18 | - mkl_fft=1.0.15=py27ha843d7b_0 19 | - mkl_random=1.1.0=py27hd6b4f25_0 20 | - ninja=1.10.2=h5e70eb0_2 21 | - numpy=1.16.6=py27hbc911f0_0 22 | - numpy-base=1.16.6=py27hde5b4d6_0 23 | - opencv=3.4.2=py27h6fd60c2_1 24 | - pillow=6.2.1=py27h34e0f95_0 25 | - pip=19.3.1=py27_0 26 | - py-opencv=3.4.2=py27hb342d67_1 27 | - pycparser=2.20=py_2 28 | - pyparsing=2.4.7=pyhd3eb1b0_0 29 | - python=2.7.18=ha1903f6_2 30 | - python-dateutil=2.8.2=pyhd3eb1b0_0 31 | - pytorch=1.4.0=py2.7_cuda10.1.243_cudnn7.6.3_0 32 | - six=1.16.0=pyhd3eb1b0_1 33 | - torchaudio=0.4.0=py27 34 | - torchvision=0.5.0=py27_cu101 35 | - wheel=0.37.1=pyhd3eb1b0_0 36 | - widgetsnbextension=3.5.1=py27_0 37 | - zlib=1.2.11=h7f8727e_4 38 | - pip: 39 | - ipdb==0.13.9 40 | - ipython==5.10.0 41 | - kiwisolver==1.1.0 42 | - matplotlib==2.2.5 43 | - pickleshare==0.7.5 44 | - progress==1.6 45 | - protobuf==3.17.3 46 | - pyyaml==5.4.1 47 | - scikit-learn==0.20.4 48 | - scipy==1.2.3 49 | - sklearn==0.0 50 | - subprocess32==3.5.4 51 | - tqdm==4.63.0 52 | -------------------------------------------------------------------------------- /src/base_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import os 5 | 6 | from network.gridconv import GridLiftingNetwork 7 | from network.dgridconv import DynamicGridLiftingNetwork 8 | from network.dgridconv_autogrids import AutoDynamicGridLiftingNetwork 9 | from dataset.human36m import Human36M 10 | 11 | def get_dataloader(opt, is_train=False, shuffle=False): 12 | if not is_train and opt.input != 'gt': 13 | exclude_drift_data = True 14 | else: 15 | exclude_drift_data = False 16 | actual_data_dir = os.path.join(opt.data_rootdir, opt.input) 17 | 18 | dataset = Human36M(data_path=actual_data_dir, is_train=is_train, exclude_drift_data=exclude_drift_data, 
prepare_grid=opt.prepare_grid) 19 | 20 | dataloader = torch.utils.data.DataLoader( 21 | dataset=dataset, 22 | batch_size=opt.batch if is_train else opt.test_batch, 23 | shuffle=shuffle, 24 | num_workers=0, 25 | pin_memory=True 26 | ) 27 | return dataloader 28 | 29 | def get_lifting_model(opt): 30 | if opt.lifting_model == 'gridconv': 31 | model = GridLiftingNetwork(hidden_size=opt.hidsize, 32 | num_block=opt.num_block) 33 | elif opt.lifting_model == 'dgridconv': 34 | model = DynamicGridLiftingNetwork(hidden_size=opt.hidsize, 35 | num_block=opt.num_block, 36 | grid_shape=opt.grid_shape, 37 | padding_mode=opt.padding_mode) 38 | elif opt.lifting_model == 'dgridconv_autogrids': 39 | model = AutoDynamicGridLiftingNetwork(hidden_size=opt.hidsize, 40 | num_block=opt.num_block, 41 | grid_shape=opt.grid_shape, 42 | padding_mode=opt.padding_mode, 43 | autosgt_prior=opt.autosgt_prior) 44 | else: 45 | raise Exception('Unexpected argument, %s' % opt.lifting_model) 46 | model = model.cuda() 47 | if opt.load: 48 | ckpt = torch.load(opt.load) 49 | model.load_state_dict(ckpt['state_dict']) 50 | return model 51 | 52 | def get_optimizer(model, opt): 53 | optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr) 54 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=opt.lr_gamma) 55 | 56 | return optimizer, scheduler 57 | 58 | 59 | def get_loss(opt): 60 | if opt.loss == 'l2': 61 | criterion = nn.MSELoss(reduction='mean').cuda() 62 | elif opt.loss == 'sqrtl2': 63 | criterion = lambda output, target: torch.mean(torch.norm(output - target, dim=-1)) 64 | elif opt.loss == 'l1': 65 | criterion = nn.L1Loss(reduction='mean').cuda() 66 | else: 67 | raise Exception('Unknown loss type %s' % opt.loss) 68 | 69 | return criterion 70 | -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSVAI/GridConv/f151b118e9d455b8dc45155d3a494b64e95f335a/src/dataset/__init__.py -------------------------------------------------------------------------------- /src/dataset/human36m.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | from tool import util 6 | import sys 7 | from tqdm import tqdm 8 | 9 | 10 | S9_drift_fname_list = ['Waiting 1.60457274', 'Greeting.60457274', 'Greeting.58860488', 'SittingDown 1.58860488', 11 | 'Waiting 1.54138969', 'SittingDown 1.54138969', 'Waiting 1.55011271', 'Greeting.54138969', 12 | 'Greeting.55011271', 'SittingDown 1.60457274', 'SittingDown 1.55011271', 'Waiting 1.58860488'] 13 | 14 | 15 | class Human36M(Dataset): 16 | def __init__(self, data_path, is_train, exclude_drift_data, prepare_grid): 17 | self.data_path = data_path 18 | self.exclude_drift_data = exclude_drift_data 19 | self.num_jts = 17 20 | self.inp, self.out = [], [] 21 | self.meta = {'info':[]} 22 | self.confidence_2d = [] 23 | self.subject_list = ['S1','S5','S6','S7','S8'] if is_train else ['S9', 'S11'] 24 | self.prepare_grid = prepare_grid 25 | 26 | data_2d = {} 27 | data_3d = {} 28 | self.phase = 'train' if is_train else 'test' 29 | 30 | for data_prefix in [self.phase]: 31 | data_2d_file = '%s_custom_2d_unnorm.pth.tar' % data_prefix 32 | data_3d_file = '%s_custom_3d_unnorm.pth.tar' % data_prefix 33 | cur_data_2d = torch.load(os.path.join(data_path, data_2d_file)) 34 | cur_data_3d = torch.load(os.path.join(data_path, 
data_3d_file)) 35 | data_2d.update(cur_data_2d) 36 | data_3d.update(cur_data_3d) 37 | 38 | ordered_key = sorted(data_2d.keys()) 39 | ordered_key = list(filter(lambda x: x[0] in self.subject_list, ordered_key)) 40 | sample_step = 1 41 | for key in tqdm(ordered_key): 42 | sub, act, fname = key 43 | fullact = fname.split('.')[0] 44 | num_f = data_2d[key].shape[0] 45 | if (sub == 'S11') and (fullact == 'Directions'): 46 | continue 47 | if self.exclude_drift_data and sub == 'S9' and fname in S9_drift_fname_list: 48 | continue 49 | for i in range(0, num_f, sample_step): 50 | p2d_ori = data_2d[key][i].reshape(self.num_jts, 2) 51 | p3d_ori = data_3d[key]['joint_3d'][i].reshape(self.num_jts, 3) 52 | 53 | p2d = (p2d_ori - 500) / 500. 54 | p3d = p3d_ori / 1000. 55 | self.inp.append(p2d) 56 | self.out.append(p3d) 57 | self.meta['info'].append({'subject':sub, 'action':fullact, 'camid':fname.split('.')[-1], 'frid':i}) 58 | 59 | 60 | 61 | def __getitem__(self, index): 62 | inputs = self.inp[index].copy() 63 | outputs = self.out[index].copy() 64 | 65 | if self.prepare_grid: 66 | inputs = util.semantic_grid_trans(np.expand_dims(inputs, axis=0)).squeeze(0) 67 | outputs = util.semantic_grid_trans(np.expand_dims(outputs, axis=0)).squeeze(0) 68 | 69 | inputs = torch.Tensor(inputs).float() 70 | outputs = torch.Tensor(outputs).float() 71 | 72 | meta = self.meta['info'][index] 73 | for key in self.meta: 74 | if key != 'info': 75 | meta[key] = self.meta[key] 76 | 77 | return inputs, outputs, meta 78 | 79 | def __len__(self): 80 | return len(self.inp) -------------------------------------------------------------------------------- /src/lifting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from progress.bar import Bar 4 | import numpy as np 5 | 6 | from tool import util 7 | 8 | def train(epoch, train_loader, lifting_model, criterion, optimizer, opt): 9 | losses = util.AverageMeter() 10 | 11 | start = time.time() 12 | batch_time = 0 13 | train_loader_len = len(train_loader) 14 | bar = Bar('>>>', fill='>', max=train_loader_len) 15 | 16 | lifting_model.train() 17 | 18 | for i, (inputs, targets, _) in enumerate(train_loader): 19 | batch_size = targets.shape[0] 20 | 21 | if hasattr(lifting_model, "net_update_temperature") and epoch < opt.temp_epoch: 22 | temperature = util.get_temperature(0, epoch, opt.temp_epoch, i, train_loader_len, 23 | method='linear', max_temp=opt.max_temp, increase=False) 24 | lifting_model.net_update_temperature(temperature) 25 | 26 | inputs_gpu = inputs.cuda() 27 | targets_gpu = targets.cuda() 28 | 29 | outputs_gpu = lifting_model(inputs_gpu) 30 | optimizer.zero_grad() 31 | loss = criterion(outputs_gpu, targets_gpu) 32 | losses.update(loss.item(), batch_size) 33 | loss.backward() 34 | 35 | optimizer.step() 36 | 37 | # update summary 38 | if (i + 1) % 100 == 0: 39 | batch_time = time.time() - start 40 | start = time.time() 41 | 42 | bar.suffix = '({batch}/{size}) | batch: {batchtime:.1}ms | Total: {ttl} | ETA: {eta:} | loss: {loss:.6f}' \ 43 | .format(batch=i + 1, 44 | size=len(train_loader), 45 | batchtime=batch_time * 10.0, 46 | ttl=bar.elapsed_td, 47 | eta=bar.eta_td, 48 | loss=losses.avg) 49 | bar.next() 50 | bar.finish() 51 | return losses.avg 52 | 53 | def evaluate(test_loader, lifting_model, criterion, opt): 54 | loss_test, outputs_array, targets_array, corresponding_info = inference(test_loader, lifting_model, criterion, opt) 55 | num_sample = len(outputs_array) 56 | 57 | outputs_array_by_action = 
util.rearrange_by_key(outputs_array, corresponding_info) 58 | targets_array_by_action = util.rearrange_by_key(targets_array, corresponding_info) 59 | 60 | err_ttl, err_act, err_dim = evaluate_actionwise(outputs_array_by_action, targets_array_by_action, opt.procrustes) 61 | 62 | print(">>> error mean of %d samples: %.3f <<<" % (num_sample, err_ttl)) 63 | print(">>> error by dim: x: %.3f, y:%.3f, z:%.3f <<<" % (tuple(err_dim))) 64 | return loss_test, err_ttl, err_dim 65 | 66 | def inference(test_loader, lifting_model, criterion, opt): 67 | print('Inferring...') 68 | losses = util.AverageMeter() 69 | lifting_model.eval() 70 | outputs_array = [] 71 | targets_array = [] 72 | corresponding_info = [] 73 | 74 | start = time.time() 75 | batch_time = 0 76 | bar = Bar('>>>', fill='>', max=len(test_loader)) 77 | 78 | with torch.no_grad(): 79 | for i, (inputs, targets, meta) in enumerate(test_loader): 80 | batch_size = targets.shape[0] 81 | inputs_gpu = inputs.cuda() 82 | info = meta 83 | info['fullaction'] = meta['action'] 84 | info['action'] = list(map(lambda x: x.split(' ')[0], meta['action'])) 85 | outputs_gpu = lifting_model(inputs_gpu) 86 | targets_gpu = targets.cuda() 87 | 88 | loss = criterion(outputs_gpu, targets_gpu) 89 | losses.update(loss.item(), batch_size) 90 | 91 | if opt.prepare_grid: 92 | outputs_pose = util.inverse_semantic_grid_trans(outputs_gpu.cpu().data.numpy()) 93 | targets_pose = util.inverse_semantic_grid_trans(targets.data.numpy()) 94 | else: 95 | outputs_pose = outputs_gpu.cpu().data.numpy() 96 | targets_pose = targets.data.numpy() 97 | 98 | outputs_array.append(outputs_pose) 99 | targets_array.append(targets_pose) 100 | info_list = util.dict2list(info) 101 | corresponding_info += info_list 102 | 103 | bar.suffix = '({batch}/{size}) | batch: {batchtime:.1}ms | Total: {ttl} | ETA: {eta:} | loss: {loss:.6f}' \ 104 | .format(batch=i + 1, 105 | size=len(test_loader), 106 | batchtime=batch_time * 10.0, 107 | ttl=bar.elapsed_td, 108 | eta=bar.eta_td, 109 | loss=losses.avg) 110 | bar.next() 111 | bar.finish() 112 | 113 | outputs_array = np.vstack(outputs_array) 114 | targets_array = np.vstack(targets_array) 115 | return losses.avg, outputs_array, targets_array, corresponding_info 116 | 117 | def evaluate_actionwise(outputs_array, targets_array, procrustes): 118 | err_ttl = util.AverageMeter() 119 | err_act = {} 120 | err_dim = [util.AverageMeter(), util.AverageMeter(), util.AverageMeter()] 121 | 122 | for act in sorted(outputs_array.keys()): 123 | num_sample = outputs_array[act].shape[0] 124 | predict = outputs_array[act] * 1000. 125 | gt = targets_array[act] * 1000. 
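        # Optional Procrustes alignment: with --procrustes, each predicted pose is aligned to the
        # ground truth by a similarity transform (rotation T, scale b, translation c) returned by
        # util.get_procrustes_transformation, yielding the PA-MPJPE metric instead of plain MPJPE.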
126 | 127 | if procrustes: 128 | pred_procrustes = [] 129 | for i in range(num_sample): 130 | _, Z, T, b, c = util.get_procrustes_transformation(gt[i], predict[i], True) 131 | pred_procrustes.append((b * predict[i].dot(T)) + c) 132 | predict = np.array(pred_procrustes) 133 | 134 | err_act[act] = (((predict - gt) ** 2).sum(-1)**0.5).mean() 135 | err_ttl.update(err_act[act], 1) 136 | for dim_i in range(len(err_dim)): 137 | err = (np.abs(predict[:, :, dim_i] - gt[:, :, dim_i])).mean() 138 | err_dim[dim_i].update(err, 1) 139 | 140 | for dim_i in range(len(err_dim)): 141 | err_dim[dim_i] = err_dim[dim_i].avg 142 | 143 | return err_ttl.avg, err_act, err_dim -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import time 4 | import numpy as np 5 | 6 | from tool import log 7 | from tool.argument import Options 8 | from lifting import train, evaluate 9 | from base_modules import get_dataloader, get_lifting_model, get_loss, get_optimizer 10 | 11 | def main(opt): 12 | date = time.strftime("%y_%m_%d_%H_%M", time.localtime()) 13 | log.save_options(opt, opt.ckpt) 14 | 15 | lifting_model = get_lifting_model(opt) 16 | 17 | print(">>> Loading dataset...") 18 | if not opt.eval: 19 | train_loader = get_dataloader(opt, is_train=True, shuffle=True) 20 | test_loader = get_dataloader(opt, is_train=False, shuffle=False) 21 | 22 | criterion = get_loss(opt) 23 | 24 | if opt.eval: 25 | logger = log.Logger(os.path.join(opt.ckpt, 'inference-%s.txt' % date)) 26 | logger.set_names(['loss_test', 'err_test', 'err_x', 'err_y', 'err_z']) 27 | print(">>> Test lifting<<<") 28 | loss_test, err_test, err_dim = evaluate(test_loader, lifting_model, criterion, opt) 29 | logger.addmsg('lifting') 30 | logger.append([loss_test, err_test, err_dim[0], err_dim[1], err_dim[2]], 31 | ['float', 'float', 'float', 'float', 'float']) 32 | return 33 | 34 | logger = log.Logger(os.path.join(opt.ckpt, 'log-%s.txt' % date)) 35 | logger.set_names(['epoch', 'lr', 'loss_train', 'loss_test', 'err_test', 'err_x', 'err_y', 'err_z']) 36 | optimizer, scheduler = get_optimizer(lifting_model, opt) 37 | 38 | err_best = np.inf 39 | 40 | for epoch in range(opt.epoch): 41 | lr_now = scheduler.get_lr()[0] 42 | print('==========================') 43 | print('>>> epoch: {} | lr: {:.8f}'.format(epoch + 1, lr_now)) 44 | 45 | loss_train = train(epoch, train_loader, lifting_model, criterion, optimizer, opt) 46 | loss_test, err_test, err_dim = evaluate(test_loader, lifting_model, criterion, opt) 47 | 48 | if epoch % opt.lr_decay == 0: 49 | scheduler.step() 50 | 51 | logger.append([epoch+1, lr_now, loss_train, loss_test, err_test, err_dim[0], err_dim[1], err_dim[2]], 52 | ['int', 'float', 'float', 'float', 'float', 'float', 'float', 'float']) 53 | 54 | is_best = err_test < err_best 55 | err_best = min(err_test, err_best) 56 | 57 | # save ckpt 58 | stored_model_weight = lifting_model.state_dict() 59 | log.save_ckpt({'epoch': epoch+1, 60 | 'lr': lr_now, 61 | 'error': err_test, 62 | 'state_dict': stored_model_weight, 63 | 'optimizer': optimizer}, 64 | ckpt_path=opt.ckpt, 65 | is_best=is_best) 66 | 67 | 68 | if __name__ == '__main__': 69 | opt = Options().parse() 70 | main(opt) -------------------------------------------------------------------------------- /src/network/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OSVAI/GridConv/f151b118e9d455b8dc45155d3a494b64e95f335a/src/network/__init__.py -------------------------------------------------------------------------------- /src/network/dgridconv.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import math 5 | import torch.nn.init as init 6 | 7 | class DynamicGridLiftingNetwork(nn.Module): 8 | def __init__(self, 9 | hidden_size=256, 10 | num_block=2, 11 | num_jts=17, 12 | out_num_jts=17, 13 | p_dropout=0.25, 14 | input_dim=2, 15 | output_dim=3, 16 | grid_shape=(5,5), 17 | temperature=30, 18 | padding_mode=('c','r')): 19 | super(DynamicGridLiftingNetwork, self).__init__() 20 | 21 | self.hidden_size = hidden_size 22 | self.num_stage = num_block 23 | self.num_jts = num_jts 24 | self.out_num_jts = num_jts 25 | self.out_dim = output_dim 26 | self.p_dropout = p_dropout 27 | 28 | self.input_size = num_jts * input_dim 29 | self.output_size = out_num_jts * output_dim 30 | 31 | conv = TwoBranchDGridConv 32 | 33 | self.w1 = conv(in_channels=2, out_channels=self.hidden_size, kernel_size=3, padding_mode=padding_mode, bias=True) 34 | self.batch_norm1 = nn.BatchNorm2d(self.hidden_size) 35 | self.dropout = nn.Dropout2d(p=self.p_dropout) 36 | 37 | self.atten_conv1 = DynamicAttention2D(in_planes=input_dim, out_planes=self.hidden_size, grid_shape=grid_shape, kernel_size=3, 38 | ratios=1/16., temperature=temperature, groups=1) 39 | self.linear_stages = [] 40 | for l in range(num_block): 41 | self.linear_stages.append(CNNBlock(self.hidden_size, grid_shape=grid_shape, padding_mode=padding_mode, p_dropout=self.p_dropout, temperature=temperature)) 42 | self.linear_stages = nn.ModuleList(self.linear_stages) 43 | 44 | self.w2 = conv(in_channels=self.hidden_size, out_channels=3, kernel_size=3, padding_mode=padding_mode, bias=False) 45 | 46 | self.atten_conv2 = DynamicAttention2D(in_planes=self.hidden_size, out_planes=output_dim, grid_shape=grid_shape, kernel_size=3, 47 | ratios=1/16., temperature=temperature, groups=1) 48 | self.relu = nn.ReLU(inplace=True) 49 | 50 | def net_update_temperature(self, temperature): 51 | for m in self.modules(): 52 | if hasattr(m, "update_temperature"): 53 | m.update_temperature(temperature) 54 | 55 | def forward(self, x): 56 | atten1 = self.atten_conv1(x) 57 | y = self.w1(x, atten1) 58 | y = self.batch_norm1(y) 59 | y = self.relu(y) 60 | y = self.dropout(y) 61 | 62 | for i in range(self.num_stage): 63 | y = self.linear_stages[i](y) 64 | 65 | atten2 = self.atten_conv2(y) 66 | y = self.w2(y, atten2) 67 | 68 | return y 69 | 70 | class CNNBlock(nn.Module): 71 | def __init__(self, hidden_size, grid_shape, padding_mode, p_dropout=0.25, biased=True, temperature=30): 72 | super(CNNBlock, self).__init__() 73 | self.hid_size = hidden_size 74 | 75 | self.relu = nn.ReLU(inplace=True) 76 | self.kernel_size = 3 77 | conv = TwoBranchDGridConv 78 | 79 | self.w1 = conv(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3, padding_mode=padding_mode, bias=biased) 80 | self.batch_norm1 = nn.BatchNorm2d(self.hid_size) 81 | 82 | self.w2 = conv(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3, padding_mode=padding_mode, bias=biased) 83 | self.batch_norm2 = nn.BatchNorm2d(self.hid_size) 84 | 85 | self.atten_conv1 = DynamicAttention2D(in_planes=hidden_size, out_planes=hidden_size, grid_shape=grid_shape, kernel_size=self.kernel_size, 86 | ratios=1 / 16., temperature=temperature, groups=1) 87 | 
self.atten_conv2 = DynamicAttention2D(in_planes=hidden_size, out_planes=hidden_size, grid_shape=grid_shape, kernel_size=self.kernel_size, 88 | ratios=1 / 16., temperature=temperature, groups=1) 89 | 90 | self.dropout = nn.Dropout2d(p=p_dropout) 91 | 92 | 93 | def forward(self, x): 94 | atten1 = self.atten_conv1(x) 95 | y = self.w1(x,atten1) 96 | y = self.batch_norm1(y) 97 | y = self.relu(y) 98 | y = self.dropout(y) 99 | 100 | atten2 = self.atten_conv2(y) 101 | y = self.w2(y,atten2) 102 | y = self.batch_norm2(y) 103 | y = self.relu(y) 104 | y = self.dropout(y) 105 | 106 | 107 | out = x + y 108 | 109 | return out 110 | 111 | class TwoBranchDGridConv(nn.Module): 112 | def __init__(self, in_channels, out_channels, kernel_size, padding_mode, bias=False): 113 | super(TwoBranchDGridConv, self).__init__() 114 | self.kernel_size = kernel_size 115 | self.in_chn = in_channels 116 | self.out_chn = out_channels 117 | self.branch1_weight = nn.Parameter(torch.zeros(out_channels, in_channels, self.kernel_size, self.kernel_size)) 118 | self.branch2_weight = nn.Parameter(torch.zeros(out_channels, in_channels, self.kernel_size, self.kernel_size)) 119 | self.padding_mode = padding_mode 120 | if bias: 121 | self.branch1_bias = nn.Parameter(torch.zeros(out_channels)) 122 | self.branch2_bias = nn.Parameter(torch.zeros(out_channels)) 123 | else: 124 | self.register_parameter('branch1_bias', None) 125 | self.register_parameter('branch2_bias', None) 126 | 127 | self.reset_parameters() 128 | 129 | def reset_parameters(self): 130 | for weight, bias in [[self.branch1_weight, self.branch1_bias], [self.branch2_weight, self.branch2_bias]]: 131 | init.kaiming_uniform_(weight, a=math.sqrt(5)) 132 | if bias is not None: 133 | fan_in, _ = init._calculate_fan_in_and_fan_out(weight) 134 | bound = 1 / math.sqrt(fan_in) 135 | init.uniform_(bias, -bound, bound) 136 | 137 | def unfolding_conv(self, x_pad, weight, bias, atten): 138 | kernel_size = self.branch1_weight.shape[-2] 139 | batch_size, cin, h_pad, w_pad = x_pad.shape # B*C*7*7 140 | h = h_pad - kernel_size + 1 141 | w = w_pad - kernel_size + 1 142 | x_unfold = F.unfold(x_pad, (kernel_size, kernel_size)) # B*(C*k*k)*num_block, num_block=5*5 143 | x_unfold_avg_weight = (x_unfold.reshape(batch_size, cin, kernel_size*kernel_size, h*w) * atten.reshape(batch_size, 1, kernel_size*kernel_size, h*w)).reshape(batch_size, cin*kernel_size*kernel_size, h*w) 144 | out = x_unfold_avg_weight.transpose(1,2).matmul(weight.view(weight.shape[0], -1).t()).transpose(1,2) # B*(C*k*K)*(5*5) 145 | out_fold = F.fold(out, (h, w), (1,1)) # B*Cout*5*5 146 | if bias is not None: 147 | out_fold = out_fold + bias.reshape(1, -1, 1, 1) 148 | 149 | return out_fold 150 | 151 | 152 | def forward(self, x, atten=None): 153 | pad_size = self.kernel_size // 2 154 | padding_kwargs = { 155 | 'c':dict(mode='circular'), 156 | 'z':dict(mode='constant', value=0), 157 | 'r':dict(mode='replicate') 158 | } 159 | x_branch1 = F.pad(x, [pad_size, pad_size, pad_size, pad_size], **padding_kwargs[self.padding_mode[0]]) 160 | x_branch2 = F.pad(x, [pad_size, pad_size, pad_size, pad_size], **padding_kwargs[self.padding_mode[1]]) 161 | 162 | y_branch1 = self.unfolding_conv(x_branch1, weight=self.branch1_weight, bias=self.branch1_bias, atten=atten) 163 | y_branch2 = self.unfolding_conv(x_branch2, weight=self.branch2_weight, bias=self.branch2_bias, atten=atten) 164 | 165 | out = y_branch1 + y_branch2 166 | 167 | return out 168 | 169 | class DynamicAttention2D(nn.Module): 170 | def __init__(self, in_planes, out_planes, grid_shape, 
kernel_size, ratios, temperature, init_weight=True, 171 | min_channel=16, groups=1): 172 | super(DynamicAttention2D, self).__init__() 173 | self.temperature = temperature 174 | self.kernel_size = kernel_size 175 | self.in_planes = in_planes 176 | self.out_planes = out_planes 177 | self.groups = groups 178 | self.attention_channel = max(int(in_planes * ratios), min_channel) 179 | self.out_spatial_size = grid_shape 180 | 181 | self.avgpool = nn.AdaptiveAvgPool2d(1) 182 | self.fc = nn.Conv2d(in_planes, self.attention_channel, 1, bias=False) 183 | self.bn = nn.BatchNorm2d(self.attention_channel) 184 | self.sigmoid = nn.Sigmoid() 185 | self.relu = nn.ReLU(inplace=True) 186 | 187 | self.position_fc = nn.Conv2d(self.attention_channel, self.kernel_size * self.kernel_size * grid_shape[0] * grid_shape[1], 1, bias=True) 188 | 189 | 190 | if init_weight: 191 | self._initialize_weights() 192 | 193 | self.forward_func = self.forward_vanila 194 | 195 | def _initialize_weights(self): 196 | for m in self.modules(): 197 | if isinstance(m, nn.Conv2d): 198 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 199 | if m.bias is not None: 200 | nn.init.constant_(m.bias, 0) 201 | if isinstance(m ,nn.BatchNorm2d): 202 | nn.init.constant_(m.weight, 1) 203 | nn.init.constant_(m.bias, 0) 204 | 205 | def update_temperature(self, temperature): 206 | self.temperature = temperature 207 | 208 | def forward_vanila(self, x): 209 | x = self.relu(self.bn(self.fc(self.avgpool(x)))) 210 | 211 | x = self.position_fc(x).view(x.size(0), self.kernel_size**2, self.out_spatial_size[0], self.out_spatial_size[1]) 212 | x = self.sigmoid(x/self.temperature) 213 | 214 | return x 215 | 216 | def forward(self, x): 217 | return self.forward_func(x) 218 | 219 | 220 | 221 | 222 | 223 | 224 | -------------------------------------------------------------------------------- /src/network/dgridconv_autogrids.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import torch.nn as nn 4 | import torch 5 | import torch.nn.functional as F 6 | import ipdb 7 | import math 8 | import torch.nn.init as init 9 | from collections import Counter 10 | import numpy as np 11 | 12 | ReLU = nn.ReLU 13 | 14 | class AutoDynamicGridLiftingNetwork(nn.Module): 15 | def __init__(self, 16 | hidden_size=256, 17 | num_block=2, 18 | num_jts=17, 19 | out_num_jts=17, 20 | p_dropout=0.25, 21 | input_dim=2, 22 | output_dim = 3, 23 | temperature=30, 24 | grid_shape=(5,5), 25 | padding_mode=('c','z'), 26 | autosgt_prior='standard'): 27 | super(AutoDynamicGridLiftingNetwork, self).__init__() 28 | 29 | self.linear_size = hidden_size 30 | self.num_stage = num_block 31 | self.num_jts = num_jts 32 | self.out_num_jts = num_jts 33 | self.inp_dim = input_dim 34 | self.out_dim = output_dim 35 | self.p_dropout = p_dropout 36 | 37 | self.input_size = num_jts * input_dim 38 | self.output_size = out_num_jts * output_dim 39 | 40 | conv3 = TwoBranchDGridConv 41 | 42 | self.w1 = conv3(in_channels=2, out_channels=self.linear_size, kernel_size=3, padding_mode=padding_mode, bias=True) 43 | self.batch_norm1 = nn.BatchNorm2d(self.linear_size) 44 | self.dropout = nn.Dropout2d(p=self.p_dropout) 45 | 46 | self.atten_conv1 = DynamicAttention2D(in_planes=input_dim, out_planes=self.linear_size, kernel_size=3, 47 | spatial_size=grid_shape, ratios=1/16., temperature=temperature, groups=1) 48 | self.linear_stages = [] 49 | for l in range(num_block): 50 | self.linear_stages.append(CNNBlock(self.linear_size, grid_shape=grid_shape, 
padding_mode=padding_mode, p_dropout=self.p_dropout, temperature=temperature)) 51 | self.linear_stages = nn.ModuleList(self.linear_stages) 52 | 53 | self.w2 = conv3(in_channels=self.linear_size, out_channels=3, kernel_size=3, padding_mode=padding_mode, bias=False) 54 | 55 | self.atten_conv2 = DynamicAttention2D(in_planes=self.linear_size, out_planes=output_dim, kernel_size=3, 56 | spatial_size=grid_shape, ratios=1/16., temperature=temperature, groups=1) 57 | self.relu = ReLU(inplace=True) 58 | 59 | self.grid_shape = list(grid_shape) 60 | self.sgt_layer = AutoSGT(num_jts=num_jts, grid_shape=grid_shape, autosgt_prior=autosgt_prior) 61 | 62 | def net_update_temperature(self, temperature): 63 | for m in self.modules(): 64 | if hasattr(m, "update_temperature"): 65 | m.update_temperature(temperature) 66 | 67 | def forward(self, x, gumbel_temp=1.0, use_gumbel_noise=False, is_training=False): 68 | batch_size = x.shape[0] 69 | sgt_trans_mat_hard = self.sgt_layer(gumbel_temp=gumbel_temp, use_gumbel_noise=use_gumbel_noise, is_training=is_training).repeat([batch_size, 1, 1]) 70 | 71 | x = torch.bmm(sgt_trans_mat_hard, x) 72 | x = x.reshape([batch_size] + list(self.grid_shape) + [self.inp_dim]).permute([0, 3, 1, 2]) # B*HW*C -> B*C*H*W 73 | 74 | atten1 = self.atten_conv1(x) 75 | y = self.w1(x, atten1) 76 | y = self.batch_norm1(y) 77 | y = self.relu(y) 78 | y = self.dropout(y) 79 | 80 | for i in range(self.num_stage): 81 | y = self.linear_stages[i](y) 82 | 83 | atten2 = self.atten_conv2(y) 84 | y = self.w2(y, atten2) 85 | 86 | y = y.permute([0, 2, 3, 1]).reshape(batch_size, np.prod(self.grid_shape), self.out_dim) # B*C*H*W -> B*HW*C 87 | sgt_trans_mat_inverse = sgt_trans_mat_hard.permute([0, 2, 1]) 88 | joint_reweight = sgt_trans_mat_inverse.sum(dim=-1, keepdim=True) + 1e-8 89 | y = torch.bmm(sgt_trans_mat_inverse, y) / joint_reweight 90 | 91 | return y 92 | 93 | class CNNBlock(nn.Module): 94 | def __init__(self, linear_size, grid_shape, padding_mode, p_dropout=0.25, biased=True, temperature=30): 95 | super(CNNBlock, self).__init__() 96 | self.l_size = linear_size 97 | 98 | self.relu = ReLU(inplace=True) 99 | self.kernel_size = 3 100 | conv3 = TwoBranchDGridConv 101 | 102 | self.w1 = conv3(in_channels=linear_size, out_channels=linear_size, kernel_size=3, padding_mode=padding_mode, bias=biased) 103 | self.batch_norm1 = nn.BatchNorm2d(self.l_size) 104 | 105 | self.w2 = conv3(in_channels=linear_size, out_channels=linear_size, kernel_size=3, padding_mode=padding_mode, bias=biased) 106 | self.batch_norm2 = nn.BatchNorm2d(self.l_size) 107 | 108 | self.atten_conv1 = DynamicAttention2D(in_planes=linear_size, out_planes=linear_size, kernel_size=self.kernel_size, 109 | spatial_size=grid_shape, ratios=1 / 16., temperature=temperature, groups=1) 110 | self.atten_conv2 = DynamicAttention2D(in_planes=linear_size, out_planes=linear_size, kernel_size=self.kernel_size, 111 | spatial_size=grid_shape, ratios=1 / 16., temperature=temperature, groups=1) 112 | 113 | self.dropout = nn.Dropout2d(p=p_dropout) 114 | 115 | def forward(self, x): 116 | atten1 = self.atten_conv1(x) 117 | y = self.w1(x,atten1) 118 | y = self.batch_norm1(y) 119 | y = self.relu(y) 120 | y = self.dropout(y) 121 | 122 | atten2 = self.atten_conv2(y) 123 | y = self.w2(y,atten2) 124 | y = self.batch_norm2(y) 125 | y = self.relu(y) 126 | y = self.dropout(y) 127 | 128 | out = x + y 129 | 130 | return out 131 | 132 | class TwoBranchDGridConv(nn.Module): 133 | def __init__(self, in_channels, out_channels, kernel_size, bias=False, padding_mode=None): 134 | 
super(TwoBranchDGridConv, self).__init__() 135 | self.kernel_size = kernel_size 136 | self.in_chn = in_channels 137 | self.out_chn = out_channels 138 | self.branch1_weight = nn.Parameter(torch.zeros(out_channels, in_channels, self.kernel_size, self.kernel_size)) 139 | self.branch2_weight = nn.Parameter(torch.zeros(out_channels, in_channels, self.kernel_size, self.kernel_size)) 140 | self.has_bias = bias 141 | self.padding_mode = padding_mode 142 | if bias: 143 | self.branch1_bias = nn.Parameter(torch.zeros(out_channels)) 144 | self.branch2_bias = nn.Parameter(torch.zeros(out_channels)) 145 | else: 146 | self.register_parameter('branch1_bias', None) 147 | self.register_parameter('branch2_bias', None) 148 | 149 | 150 | self.reset_parameters() 151 | 152 | def reset_parameters(self): 153 | for weight, bias in [[self.branch1_weight, self.branch1_bias], [self.branch2_weight, self.branch2_bias]]: 154 | init.kaiming_uniform_(weight, a=math.sqrt(5)) 155 | if bias is not None: 156 | fan_in, _ = init._calculate_fan_in_and_fan_out(weight) 157 | bound = 1 / math.sqrt(fan_in) 158 | init.uniform_(bias, -bound, bound) 159 | 160 | def unfolding_conv(self, x_pad, weight, bias, atten): 161 | kernel_size = self.branch1_weight.shape[-2] 162 | batch_size, cin, h_pad, w_pad = x_pad.shape 163 | h = h_pad - kernel_size + 1 164 | w = w_pad - kernel_size + 1 165 | x_unfold = F.unfold(x_pad, (kernel_size, kernel_size)) 166 | x_unfold_avg_weight = (x_unfold.reshape(batch_size, cin, kernel_size*kernel_size, h*w) * atten.reshape(batch_size, 1, kernel_size*kernel_size, h*w)).reshape(batch_size, cin*kernel_size*kernel_size, h*w) 167 | out = x_unfold_avg_weight.transpose(1,2).matmul(weight.view(weight.shape[0], -1).t()).transpose(1,2) # B*(C*k*K)*(5*5) 168 | out_fold = F.fold(out, (h, w), (1,1)) 169 | if bias is not None: 170 | out_fold = out_fold + bias.reshape(1, -1, 1, 1) 171 | 172 | return out_fold 173 | 174 | 175 | def forward(self, x, atten=None): 176 | padding_kwargs = { 177 | 'c':dict(mode='circular'), 178 | 'z':dict(mode='constant', value=0), 179 | 'r':dict(mode='replicate') 180 | } 181 | x_branch1 = F.pad(x, [1, 1, 1, 1], **padding_kwargs[self.padding_mode[0]]) 182 | x_branch2 = F.pad(x, [1, 1, 1, 1], **padding_kwargs[self.padding_mode[1]]) 183 | 184 | y_branch1 = self.unfolding_conv(x_branch1, weight=self.branch1_weight, bias=self.branch1_bias, atten=atten) 185 | y_branch2 = self.unfolding_conv(x_branch2, weight=self.branch2_weight, bias=self.branch2_bias, atten=atten) 186 | 187 | out = y_branch1 + y_branch2 188 | 189 | return out 190 | 191 | class DynamicAttention2D(nn.Module): 192 | def __init__(self, in_planes, out_planes, spatial_size, kernel_size, ratios, temperature, init_weight=True, 193 | min_channel=16, groups=1): 194 | super(DynamicAttention2D, self).__init__() 195 | self.temperature = temperature 196 | self.kernel_size = kernel_size 197 | self.spatial_size = spatial_size 198 | self.in_planes = in_planes 199 | self.out_planes = out_planes 200 | self.groups = groups 201 | self.attention_channel = max(int(in_planes * ratios), min_channel) 202 | 203 | self.avgpool = nn.AdaptiveAvgPool2d(1) 204 | self.fc = nn.Conv2d(in_planes, self.attention_channel, 1, bias=False) 205 | self.bn = nn.BatchNorm2d(self.attention_channel) 206 | self.sigmoid = nn.Sigmoid() 207 | self.relu = nn.ReLU(inplace=True) 208 | 209 | self.position_fc = nn.Conv2d(self.attention_channel, self.kernel_size * self.kernel_size * np.prod(spatial_size), 1, bias=True) 210 | 211 | 212 | if init_weight: 213 | self._initialize_weights() 214 | 215 | 
self.forward_func = self.forward_vanila 216 | 217 | def _initialize_weights(self): 218 | for m in self.modules(): 219 | if isinstance(m, nn.Conv2d): 220 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 221 | if m.bias is not None: 222 | nn.init.constant_(m.bias, 0) 223 | if isinstance(m ,nn.BatchNorm2d): 224 | nn.init.constant_(m.weight, 1) 225 | nn.init.constant_(m.bias, 0) 226 | 227 | def update_temperature(self, temperature): 228 | self.temperature = temperature 229 | 230 | 231 | def forward_vanila(self, x): 232 | x = self.relu(self.bn(self.fc(self.avgpool(x)))) 233 | 234 | x = self.position_fc(x).view(x.size(0), self.kernel_size**2, self.spatial_size[0], self.spatial_size[1]) 235 | x = self.sigmoid(x/self.temperature) 236 | 237 | return x 238 | 239 | def forward(self, x): 240 | return self.forward_func(x) 241 | 242 | 243 | class AutoSGT(nn.Module): 244 | def __init__(self, num_jts, grid_shape, autosgt_prior): 245 | super(AutoSGT, self).__init__() 246 | self.grid_shape = grid_shape 247 | self.J = num_jts 248 | self.HW = np.prod(grid_shape) 249 | 250 | self.register_parameter('sgt_trans_mat', torch.nn.Parameter(self.init_sgt_prior(autosgt_prior))) 251 | 252 | def forward(self, use_gumbel_noise, gumbel_temp, is_training=False): 253 | sgt_trans_mat = self.sgt_trans_mat 254 | if is_training: 255 | if use_gumbel_noise: 256 | sgt_trans_mat_hard = F.gumbel_softmax(sgt_trans_mat, tau=gumbel_temp, hard=False, dim=-1) 257 | else: 258 | dim = -1 259 | index = sgt_trans_mat.max(dim, keepdim=True)[1] 260 | y_hard = torch.zeros_like(sgt_trans_mat, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) 261 | sgt_trans_mat_hard = y_hard - sgt_trans_mat.detach() + sgt_trans_mat 262 | else: 263 | sgt_trans_mat_hard = F.one_hot(torch.argmax(sgt_trans_mat, -1)).float() 264 | 265 | return sgt_trans_mat_hard 266 | 267 | def init_sgt_prior(self, prior_type): 268 | assert self.J == 17 and self.HW == 25 269 | if prior_type == 'standard': 270 | prior_sgt_mat = torch.zeros(self.grid_shape + [self.J]) 271 | # row 0 272 | prior_sgt_mat[0, :, 7] = 1 273 | # row 1 274 | prior_sgt_mat[1, [0, -1], 0] = 1 275 | prior_sgt_mat[1, [1, 2, 3], 8] = 1 276 | # row 2 277 | prior_sgt_mat[2, 0, 4] = 1 278 | prior_sgt_mat[2, 1, 11] = 1 279 | prior_sgt_mat[2, 2, 9] = 1 280 | prior_sgt_mat[2, 3, 14] = 1 281 | prior_sgt_mat[2, 4, 1] = 1 282 | # row 3 283 | prior_sgt_mat[3, 0, 5] = 1 284 | prior_sgt_mat[3, 1, 12] = 1 285 | prior_sgt_mat[3, 2, 9] = 1 286 | prior_sgt_mat[3, 3, 15] = 1 287 | prior_sgt_mat[3, 4, 2] = 1 288 | # row 4 289 | prior_sgt_mat[4, 0, 6] = 1 290 | prior_sgt_mat[4, 1, 13] = 1 291 | prior_sgt_mat[4, 2, 10] = 1 292 | prior_sgt_mat[4, 3, 16] = 1 293 | prior_sgt_mat[4, 4, 3] = 1 294 | prior_sgt_mat = prior_sgt_mat.reshape(1, self.HW, self.J) 295 | elif prior_type == 'learnt_type1': 296 | prior_sgt_mat = torch.LongTensor([[7,4,7,1,0, 297 | 0,8,8,8,0, 298 | 4,11,9,14,1, 299 | 5,12,9,15,2, 300 | 6,13,10,16,3]]) 301 | prior_sgt_mat = F.one_hot(prior_sgt_mat, num_classes=self.J).float() # 1*self.HW*self.J 302 | elif prior_type == 'learnt_type2': 303 | prior_sgt_mat = torch.LongTensor([[0,15,7,1,0, 304 | 1,14,8,7,0, 305 | 4,0,9,13,1, 306 | 2,6,11,10,2, 307 | 5,12,14,16,3]]) 308 | prior_sgt_mat = F.one_hot(prior_sgt_mat, num_classes=self.J).float() # 1*self.HW*self.J 309 | elif prior_type == 'learnt_type3': 310 | prior_sgt_mat = torch.LongTensor([[9,7,7,10,7, 311 | 13,8,10,15,16, 312 | 9,12,7,14,1, 313 | 4,5,7,3,11, 314 | 7,6,9,2,14]]) 315 | prior_sgt_mat = F.one_hot(prior_sgt_mat, 
num_classes=self.J).float() # 1*self.HW*self.J 316 | elif prior_type == 'random_prob': 317 | prior_sgt_mat = torch.rand([self.HW, self.J]) 318 | prior_sgt_mat = F.softmax(prior_sgt_mat, dim=-1).unsqueeze(0) 319 | else: 320 | raise Exception() 321 | 322 | return prior_sgt_mat 323 | 324 | -------------------------------------------------------------------------------- /src/network/gridconv.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class GridLiftingNetwork(nn.Module): 4 | def __init__(self, 5 | hidden_size=256, 6 | num_block=2, 7 | num_jts=17, 8 | out_num_jts=17, 9 | p_dropout=0.25, 10 | input_dim=2, 11 | output_dim=3): 12 | super(GridLiftingNetwork, self).__init__() 13 | 14 | self.hidden_size = hidden_size 15 | self.num_stage = num_block 16 | self.num_jts = num_jts 17 | self.out_num_jts = num_jts 18 | self.out_dim = output_dim 19 | self.p_dropout = p_dropout 20 | 21 | self.input_size = num_jts * input_dim 22 | self.output_size = out_num_jts * output_dim 23 | 24 | conv = TwoBranchGridConv 25 | 26 | self.w1 = conv(in_channels=2, out_channels=self.hidden_size, kernel_size=3, bias=True) 27 | self.batch_norm1 = nn.BatchNorm2d(self.hidden_size) 28 | self.dropout = nn.Dropout2d(p=self.p_dropout) 29 | 30 | self.linear_stages = [] 31 | for l in range(num_block): 32 | self.linear_stages.append(CNNBlock(self.hidden_size, p_dropout=self.p_dropout)) 33 | self.linear_stages = nn.ModuleList(self.linear_stages) 34 | 35 | self.w2 = conv(in_channels=self.hidden_size, out_channels=3, kernel_size=3, bias=False) 36 | 37 | self.relu = nn.ReLU(inplace=True) 38 | 39 | 40 | def forward(self, x): 41 | y = self.w1(x) 42 | y = self.batch_norm1(y) 43 | y = self.relu(y) 44 | y = self.dropout(y) 45 | 46 | for i in range(self.num_stage): 47 | y = self.linear_stages[i](y) 48 | 49 | y = self.w2(y) 50 | 51 | return y 52 | 53 | class CNNBlock(nn.Module): 54 | def __init__(self, hidden_size, p_dropout=0.25, biased=True): 55 | super(CNNBlock, self).__init__() 56 | self.hid_size = hidden_size 57 | 58 | self.relu = nn.ReLU(inplace=True) 59 | 60 | conv = TwoBranchGridConv 61 | 62 | self.w1 = conv(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3, bias=biased) 63 | self.batch_norm1 = nn.BatchNorm2d(self.hid_size) 64 | 65 | self.w2 = conv(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3, bias=biased) 66 | self.batch_norm2 = nn.BatchNorm2d(self.hid_size) 67 | 68 | self.dropout = nn.Dropout2d(p=p_dropout) 69 | 70 | 71 | def forward(self, x): 72 | y = self.w1(x) 73 | y = self.batch_norm1(y) 74 | y = self.relu(y) 75 | y = self.dropout(y) 76 | 77 | y = self.w2(y) 78 | y = self.batch_norm2(y) 79 | y = self.relu(y) 80 | y = self.dropout(y) 81 | 82 | out = x + y 83 | 84 | return out 85 | 86 | class TwoBranchGridConv(nn.Module): 87 | def __init__(self, in_channels, out_channels, kernel_size, bias=False): 88 | super(TwoBranchGridConv, self).__init__() 89 | self.donut_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 90 | padding=[kernel_size // 2 * 2, kernel_size // 2 * 2], padding_mode='circular', bias=bias) 91 | self.tablet_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 92 | padding=[kernel_size // 2, kernel_size // 2], padding_mode='zero', bias=bias) 93 | 94 | def forward(self, x): 95 | y_cir = self.donut_conv(x) 96 | y_rep = self.tablet_conv(x) 97 | out = y_cir + y_rep 98 | 99 | return out 100 | 101 | 102 | 
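The following is a minimal, editor-added shape-check sketch (not a file of this repository) for the dynamic grid lifting network defined in `src/network/dgridconv.py`. It assumes the interpreter is started from the `src/` directory and runs on CPU with a dummy 5x5 grid pose as input.

```python
import torch

from network.dgridconv import DynamicGridLiftingNetwork

# Dummy 2D grid pose: batch of 4 samples, 2 coordinate channels on a 5x5 semantic grid.
x = torch.randn(4, 2, 5, 5)

model = DynamicGridLiftingNetwork(hidden_size=256, num_block=2,
                                  grid_shape=(5, 5), padding_mode=('c', 'r'))
model.eval()  # disable dropout and use BatchNorm running statistics

with torch.no_grad():
    y = model(x)

print(y.shape)  # torch.Size([4, 3, 5, 5]): a 3D grid pose on the same 5x5 layout
```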
-------------------------------------------------------------------------------- /src/tool/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OSVAI/GridConv/f151b118e9d455b8dc45155d3a494b64e95f335a/src/tool/__init__.py -------------------------------------------------------------------------------- /src/tool/argument.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pprint import pprint 3 | import os 4 | 5 | class Options: 6 | def __init__(self): 7 | self.parser = argparse.ArgumentParser() 8 | self.opt = None 9 | 10 | def _initial(self): 11 | # =============================================================== 12 | # General options 13 | # =============================================================== 14 | self.parser.add_argument('--data_rootdir', type=str, default='./data/') 15 | self.parser.add_argument('--input', type=str, default='gt', help='choices: {gt,cpn,sh}') 16 | 17 | self.parser.add_argument('--eval', dest='eval', action='store_true') 18 | self.parser.set_defaults(eval=False) 19 | self.parser.add_argument('--exp', type=str, default='temporary', help='name of experiment') 20 | self.parser.add_argument('--ckpt', type=str, default='checkpoint') 21 | self.parser.add_argument('--procrustes', dest='procrustes', action='store_true', 22 | help='use Procrustes analysis at test time') 23 | 24 | self.parser.add_argument('--lr', type=float, default=1.0e-3) 25 | self.parser.add_argument('--lr_decay', type=int, default=10, help='milestone epoch for lr decay') 26 | self.parser.add_argument('--lr_gamma', type=float, default=0.96, help='multiplicative lr decay factor') 27 | self.parser.add_argument('--epoch', type=int, default=200) 28 | self.parser.add_argument('--dropout', type=float, default=0.25, help='dropout probability') 29 | self.parser.add_argument('--batch', type=int, default=200) 30 | self.parser.add_argument('--test_batch', type=int, default=1000) 31 | self.parser.add_argument('--loss', type=str, default='l2') 32 | 33 | self.parser.add_argument('--max_temp', type=int, default=30) 34 | self.parser.add_argument('--temp_epoch', type=int, default=10) 35 | 36 | # =============================================================== 37 | # Model options 38 | # =============================================================== 39 | self.parser.add_argument('--lifting_model', type=str, default='gridconv', help='choices: {gridconv, dgridconv, dgridconv_autogrids}') 40 | self.parser.add_argument('--load', type=str, default=None) 41 | self.parser.add_argument('--hidsize', type=int, default=256, help='number of hidden channels in each layer') 42 | self.parser.add_argument('--num_block', type=int, default=2, help='number of residual blocks') 43 | self.parser.add_argument('--padding_mode', type=str, nargs='+', default=['c','r']) 44 | self.parser.add_argument('--grid_shape', type=int, nargs='+', default=[5, 5]) 45 | self.parser.add_argument('--autosgt_prior', type=str, default='standard') 46 | 47 | 48 | def _print(self): 49 | print("\n==================Options=================") 50 | pprint(vars(self.opt), indent=4) 51 | print("==========================================\n") 52 | 53 | def parse(self): 54 | self._initial() 55 | self.opt = self.parser.parse_args() 56 | ckpt = os.path.join(self.opt.ckpt, self.opt.exp) 57 | if not os.path.isdir(ckpt): 58 | os.makedirs(ckpt) 59 | self.opt.ckpt = ckpt 60 | self.opt.prepare_grid = self.opt.lifting_model in ['gridconv', 'dgridconv'] 61 |
self._print() 62 | 63 | return self.opt 64 | -------------------------------------------------------------------------------- /src/tool/log.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import absolute_import 4 | 5 | import json 6 | import os 7 | import torch 8 | import time 9 | 10 | 11 | class Logger(object): 12 | def __init__(self, fpath, title=None, resume=False): 13 | self.file = None 14 | self.resume = resume 15 | self.title = '' if not title else title 16 | if fpath is not None: 17 | if resume: 18 | self.file = open(fpath, 'r') 19 | name = self.file.readline() 20 | self.names = name.rstrip().split('\t') 21 | self.numbers = {} 22 | for _, name in enumerate(self.names): 23 | self.numbers[name] = [] 24 | 25 | for numbers in self.file: 26 | numbers = numbers.rstrip().split('\t') 27 | for i in range(0, len(numbers)): 28 | self.numbers[self.names[i]].append(numbers[i]) 29 | self.file.close() 30 | self.file = open(fpath, 'a') 31 | else: 32 | self.file = open(fpath, 'w') 33 | 34 | def set_names(self, names): 35 | if self.resume: 36 | pass 37 | self.numbers = {} 38 | self.names = names 39 | for _, name in enumerate(self.names): 40 | self.file.write(name) 41 | self.file.write('\t') 42 | self.numbers[name] = [] 43 | self.file.write('\n') 44 | self.file.flush() 45 | 46 | def append(self, member, mem_type): 47 | assert len(self.names) == len(member), '# of data does not match title' 48 | for index, mem in enumerate(member): 49 | if mem_type[index] == 'int': 50 | self.file.write("{}".format(mem)) 51 | else: 52 | self.file.write("{0:.4f}".format(mem)) 53 | self.file.write('\t') 54 | self.numbers[self.names[index]].append(mem) 55 | self.file.write('\n') 56 | self.file.flush() 57 | 58 | def addmsg(self, message): 59 | self.file.write(message + '\n') 60 | self.file.flush() 61 | 62 | def close(self): 63 | if self.file: 64 | self.file.close() 65 | 66 | 67 | def save_options(opt, path): 68 | date = time.strftime("%y_%m_%d_%H_%M", time.localtime()) 69 | file_path = os.path.join(path, 'opt-%s.json' % date) 70 | with open(file_path, 'w') as f: 71 | f.write(json.dumps(vars(opt), sort_keys=True, indent=4)) 72 | 73 | 74 | def save_ckpt(state, ckpt_path, is_best=True): 75 | if is_best: 76 | file_path = os.path.join(ckpt_path, 'ckpt_best.pth.tar') 77 | torch.save(state, file_path) 78 | else: 79 | file_path = os.path.join(ckpt_path, 'ckpt_last.pth.tar') 80 | torch.save(state, file_path) 81 | -------------------------------------------------------------------------------- /src/tool/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | def dict2list(d): 5 | # convert dict with list items to list with dict items. 6 | # eg. 
DL = {'a': [0, 1], 'b': [2, 3]}, LD=[{'a': 0, 'b': 2}, {'a': 1, 'b': 3}] 7 | # DL to LD 8 | l = [dict(zip(d,t)) for t in zip(*d.values())] 9 | return l 10 | 11 | class AverageMeter(object): 12 | def __init__(self): 13 | self.val = 0 14 | self.avg = 0 15 | self.sum = 0 16 | self.count = 0 17 | 18 | def update(self, val, n=1): 19 | self.val = val 20 | self.sum += val * n 21 | self.count += n 22 | self.avg = self.sum / self.count if self.count != 0 else 0 23 | 24 | def rearrange_by_key(array, guide_info, guide_key='action'): 25 | arr_rearrange = {} 26 | for i in range(len(array)): 27 | key = guide_info[i][guide_key] 28 | if key not in arr_rearrange: 29 | arr_rearrange[key] = [] 30 | arr_rearrange[key].append(array[i]) 31 | for key in arr_rearrange: 32 | arr_rearrange[key] = np.array(arr_rearrange[key]) 33 | return arr_rearrange 34 | 35 | def semantic_grid_trans(src_graph_pose): 36 | assert len(src_graph_pose.shape) == 3 # B*J*C 37 | batch_size, _, C = src_graph_pose.shape 38 | grid_pose = np.zeros([batch_size, 5, 5, C]) 39 | grid_pose[:, 0] = src_graph_pose[:, [7, 7, 7, 7, 7]] 40 | grid_pose[:, 1] = src_graph_pose[:, [0, 8, 8, 8, 0]] 41 | grid_pose[:, 2] = src_graph_pose[:, [1, 14, 0, 11, 4]] 42 | grid_pose[:, 2, 2] = src_graph_pose[:, [8, 9]].mean(1) # midpoint of neck and nose 43 | 44 | grid_pose[:, 3] = src_graph_pose[:, [2, 15, 9, 12, 5]] 45 | grid_pose[:, 4] = src_graph_pose[:, [3, 16, 10, 13, 6]] 46 | 47 | grid_pose = grid_pose.transpose([0, 3, 1, 2]) # B*C*5*5 48 | 49 | return grid_pose 50 | 51 | def inverse_semantic_grid_trans(src_grid_pose): 52 | batch_size, C = src_grid_pose.shape[:2] 53 | src_grid_pose = src_grid_pose.transpose([0, 2, 3, 1]) # B*5*5*C 54 | 55 | graph_pose = np.zeros([batch_size, 17, C]) 56 | graph_pose[:, 7] = src_grid_pose[:, 0].mean(axis=1) 57 | graph_pose[:, 0] = src_grid_pose[:, 1, [0, 4]].mean(axis=1) 58 | graph_pose[:, 8] = src_grid_pose[:, 1, [1, 2, 3]].mean(axis=1) 59 | graph_pose[:, [1, 14, 11, 4]] = src_grid_pose[:, 2, [0, 1, 3, 4]] 60 | graph_pose[:, [2, 15, 9, 12, 5]] = src_grid_pose[:, 3] 61 | graph_pose[:, [3, 16, 10, 13, 6]] = src_grid_pose[:, 4] 62 | 63 | 64 | return graph_pose 65 | 66 | def get_temperature(start_epoch, cur_epoch, total_epoch, cur_iter, total_iter, method, max_temp, pow_x=10, 67 | increase=False): 68 | 69 | if cur_epoch >= total_epoch: 70 | return 1 71 | 72 | ratio = ((cur_epoch - start_epoch) + (cur_iter / total_iter)) / (total_epoch - start_epoch) 73 | 74 | if not increase: 75 | ratio = 1.0 - ratio 76 | 77 | if method == 'linear': 78 | return 1 + ratio * (max_temp-1) 79 | elif method == 'exp': 80 | return math.exp(ratio * max_temp) 81 | elif method == 'pow': 82 | return math.pow(pow_x, ratio * max_temp) 83 | else: 84 | raise ValueError("Invalid choice for temperature") 85 | 86 | def get_procrustes_transformation(X, Y, compute_optimal_scale=False): 87 | muX = X.mean(0) 88 | muY = Y.mean(0) 89 | 90 | X0 = X - muX 91 | Y0 = Y - muY 92 | 93 | ssX = (X0 ** 2.).sum() 94 | ssY = (Y0 ** 2.).sum() 95 | 96 | # centred Frobenius norm 97 | normX = np.sqrt(ssX) 98 | normY = np.sqrt(ssY) 99 | 100 | # scale to equal (unit) norm 101 | X0 = X0 / normX 102 | Y0 = Y0 / normY 103 | 104 | # optimum rotation matrix of Y 105 | A = np.dot(X0.T, Y0) 106 | U, s, Vt = np.linalg.svd(A, full_matrices=False) 107 | V = Vt.T 108 | T = np.dot(V, U.T) 109 | 110 | # Make sure we have a rotation 111 | detT = np.linalg.det(T) 112 | V[:, -1] *= np.sign(detT) 113 | s[-1] *= np.sign(detT) 114 | T = np.dot(V, U.T) 115 | 116 | traceTA = s.sum() 117 | 118 | if 
compute_optimal_scale: # Compute optimum scaling of Y. 119 | b = traceTA * normX / normY 120 | d = 1 - traceTA ** 2 121 | Z = normX * traceTA * np.dot(Y0, T) + muX 122 | else: # If no scaling allowed 123 | b = 1 124 | d = 1 + ssY / ssX - 2 * traceTA * normY / normX 125 | Z = normY * np.dot(Y0, T) + muX 126 | 127 | c = muX - b * np.dot(muY, T) 128 | 129 | return d, Z, T, b, c --------------------------------------------------------------------------------
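A minimal sketch (an illustration, not part of the repository) of how get_procrustes_transformation could back the --procrustes option in src/tool/argument.py: align a predicted pose to the ground truth with a similarity transform before measuring joint error. The dummy poses and the PA-MPJPE naming are the editor's assumptions, and src/ is assumed to be on PYTHONPATH.

import numpy as np
from tool.util import get_procrustes_transformation  # path assumed

gt = np.random.randn(17, 3)                  # ground-truth 3D joints (J*3)
pred = gt + 0.05 * np.random.randn(17, 3)    # a perturbed prediction (dummy data)
d, Z, T, b, c = get_procrustes_transformation(gt, pred, compute_optimal_scale=True)
pred_aligned = b * pred.dot(T) + c           # identical to the returned Z
pa_mpjpe = np.linalg.norm(pred_aligned - gt, axis=1).mean()
print('Procrustes-aligned mean per-joint error on dummy data:', pa_mpjpe)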