├── img
│   └── framework.jpg
├── SAR
│   ├── __init__.py
│   ├── data.py
│   └── model.py
├── LICENSE
├── README.md
├── .gitignore
└── configs
    ├── coco-res50.py
    ├── wholebody-hr48.py
    └── coco-hr48.py

/img/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kennethwdk/SAR/HEAD/img/framework.jpg
--------------------------------------------------------------------------------
/SAR/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import MSRAHeatmapCoord
2 | from .model import SARHead
3 | 
4 | __all__ = ['MSRAHeatmapCoord', 'SARHead']
5 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Dongkai Wang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spatial-Aware Regression for Keypoint Localization
2 | 
3 | [[`Paper`](https://openaccess.thecvf.com/content/CVPR2024/papers/Wang_Spatial-Aware_Regression_for_Keypoint_Localization_CVPR_2024_paper.pdf)]
4 | 
5 | > [Spatial-Aware Regression for Keypoint Localization](https://openaccess.thecvf.com/content/CVPR2024/html/Wang_Spatial-Aware_Regression_for_Keypoint_Localization_CVPR_2024_paper.html)
6 | > Dongkai Wang, Shiliang Zhang
7 | > CVPR 2024 *Highlight*
8 | 
9 | ![overview](./img/framework.jpg)
10 | 
11 | ## Installation
12 | 
13 | ### 1. Clone code
14 | ```shell
15 | git clone https://github.com/kennethwdk/SAR
16 | cd ./SAR
17 | ```
18 | ### 2. Create a conda environment for this repo
19 | ```shell
20 | conda create -n sar python=3.10
21 | conda activate sar
22 | ```
23 | ### 3. Install PyTorch following the official instructions (other versions may not work)
24 | ```shell
25 | conda install pytorch==2.0.1 torchvision==0.15.2 pytorch-cuda=11.7 -c pytorch -c nvidia
26 | ```
27 | ### 4. Install the remaining Python dependencies (do not change package versions)
28 | ```shell
29 | pip install -U openmim
30 | mim install mmengine
31 | mim install "mmcv>=2.0.1"
32 | mim install "mmpose==1.3.1"
33 | ```
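Optionally, verify that the pinned packages import correctly before preparing the data (a quick sanity check; all three modules expose `__version__`):
```shell
python -c "import torch, mmcv, mmpose; print(torch.__version__, mmcv.__version__, mmpose.__version__)"
```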
34 | ### 5. Prepare dataset
35 | Download [COCO](https://cocodataset.org/#home) and [COCO-WholeBody](https://github.com/jin-s13/COCO-WholeBody) from their websites and arrange the extracted files under the directory structure below; the (xxx.json) files keep their original names.
36 | 
37 | ```
38 | ./data
39 | └── coco
40 |     ├── annotations
41 |     │   ├── person_keypoints_train2017.json
42 |     │   ├── person_keypoints_val2017.json
43 |     │   ├── coco_wholebody_train_v1.0.json
44 |     │   └── coco_wholebody_val_v1.0.json
45 |     └── images
46 |         ├── train2017
47 |         │   └── 000000000009.jpg
48 |         └── val2017
49 |             └── 000000000139.jpg
50 | ```
51 | ## Usage
52 | 
53 | ### 1. Download trained model
54 | 
55 | ```shell
56 | git lfs install
57 | # clone directly into ./weights (the repo root already contains a SAR/ package folder)
58 | git clone https://huggingface.co/d0ntcare/SAR weights
59 | ```
60 | 
61 | ### 2. Evaluate Model
62 | 
63 | ```shell
64 | # evaluate on coco val set
65 | export PYTHONPATH=`pwd`:$PYTHONPATH
66 | CUDA_VISIBLE_DEVICES=0 mim test mmpose configs/coco-res50.py --checkpoint weights/coco-res50/best_coco_AP_epoch_210.pth
67 | ```
68 | 
69 | ### 3. Train Model
70 | 
71 | ```shell
72 | # train on coco (two GPUs; keep --gpus consistent with CUDA_VISIBLE_DEVICES)
73 | export PYTHONPATH=`pwd`:$PYTHONPATH
74 | CUDA_VISIBLE_DEVICES=0,1 mim train mmpose configs/coco-res50.py --launcher pytorch --gpus 2
75 | ```
76 | 
77 | 
78 | ## Citations
79 | If you find this code useful for your research, please cite our paper:
80 | 
81 | ```
82 | @InProceedings{Wang_2024_CVPR,
83 |     author = {Wang, Dongkai and Zhang, Shiliang},
84 |     title = {Spatial-Aware Regression for Keypoint Localization},
85 |     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
86 |     month = {June},
87 |     year = {2024},
88 |     pages = {624-633}
89 | }
90 | ```
91 | ## Contact me
92 | If you have any questions about this code or paper, feel free to contact me at
93 | dongkai.wang@pku.edu.cn.
94 | 
95 | ## Acknowledgement
96 | The code is built on [mmpose](https://github.com/open-mmlab/mmpose).
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | .dist_test/
7 | data
8 | weights/
9 | work_dirs_old/
10 | moniter_gpu.py
11 | 
12 | # C extensions
13 | *.so
14 | 
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 | 
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
166 | #.idea/ 167 | -------------------------------------------------------------------------------- /configs/coco-res50.py: -------------------------------------------------------------------------------- 1 | _base_ = 'mmpose::_base_/default_runtime.py' 2 | 3 | custom_imports = dict(imports='SAR') 4 | 5 | # runtime 6 | train_cfg = dict(max_epochs=210, val_interval=70) 7 | 8 | # optimizer 9 | optim_wrapper = dict(optimizer=dict( 10 | type='Adam', 11 | lr=5e-4, 12 | )) 13 | 14 | # learning policy 15 | param_scheduler = [ 16 | dict( 17 | type='LinearLR', begin=0, end=500, start_factor=0.001, 18 | by_epoch=False), # warm-up 19 | dict( 20 | type='MultiStepLR', 21 | begin=0, 22 | end=210, 23 | milestones=[170, 200], 24 | gamma=0.1, 25 | by_epoch=True) 26 | ] 27 | 28 | # automatically scaling LR based on the actual training batch size 29 | auto_scale_lr = dict(base_batch_size=512) 30 | 31 | # hooks 32 | default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater')) 33 | 34 | # codec settings 35 | codec = dict( 36 | type='MSRAHeatmapCoord', input_size=(288, 384), heatmap_size=(72, 96), sigma=3) 37 | 38 | # model settings 39 | model = dict( 40 | type='TopdownPoseEstimator', 41 | data_preprocessor=dict( 42 | type='PoseDataPreprocessor', 43 | mean=[123.675, 116.28, 103.53], 44 | std=[58.395, 57.12, 57.375], 45 | bgr_to_rgb=True), 46 | backbone=dict( 47 | type='ResNet', 48 | depth=50, 49 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50'), 50 | ), 51 | head=dict( 52 | type='SARHead', 53 | in_channels=2048, 54 | out_channels=17, 55 | with_heatmap=True, 56 | codec=codec), 57 | test_cfg=dict( 58 | flip_test=True, 59 | flip_mode='heatmap', 60 | shift_heatmap=True, 61 | )) 62 | 63 | # base dataset settings 64 | dataset_type = 'CocoDataset' 65 | data_mode = 'topdown' 66 | data_root = 'data/coco/' 67 | 68 | # pipelines 69 | train_pipeline = [ 70 | dict(type='LoadImage'), 71 | dict(type='GetBBoxCenterScale'), 72 | dict(type='RandomFlip', direction='horizontal'), 73 | dict(type='RandomHalfBody'), 74 | dict(type='RandomBBoxTransform'), 75 | dict(type='TopdownAffine', input_size=codec['input_size']), 76 | dict(type='GenerateTarget', encoder=codec), 77 | dict(type='PackPoseInputs') 78 | ] 79 | val_pipeline = [ 80 | dict(type='LoadImage'), 81 | dict(type='GetBBoxCenterScale'), 82 | dict(type='TopdownAffine', input_size=codec['input_size']), 83 | dict(type='PackPoseInputs') 84 | ] 85 | 86 | # data loaders 87 | train_dataloader = dict( 88 | batch_size=96, 89 | num_workers=4, 90 | persistent_workers=True, 91 | sampler=dict(type='DefaultSampler', shuffle=True), 92 | dataset=dict( 93 | type=dataset_type, 94 | data_root=data_root, 95 | data_mode=data_mode, 96 | ann_file='annotations/person_keypoints_train2017.json', 97 | data_prefix=dict(img='train2017/'), 98 | pipeline=train_pipeline, 99 | )) 100 | val_dataloader = dict( 101 | batch_size=256, 102 | num_workers=2, 103 | persistent_workers=True, 104 | drop_last=False, 105 | sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), 106 | dataset=dict( 107 | type=dataset_type, 108 | data_root=data_root, 109 | data_mode=data_mode, 110 | ann_file='annotations/person_keypoints_val2017.json', 111 | bbox_file='data/coco/person_detection_results/' 112 | 'COCO_val2017_detections_AP_H_56_person.json', 113 | data_prefix=dict(img='val2017/'), 114 | test_mode=True, 115 | pipeline=val_pipeline, 116 | )) 117 | test_dataloader = val_dataloader 118 | 119 | # evaluators 120 | val_evaluator = dict( 121 | type='CocoMetric', 122 | ann_file=data_root 
+ 'annotations/person_keypoints_val2017.json') 123 | test_evaluator = val_evaluator 124 | 125 | # test_dataloader = dict( 126 | # batch_size=256, 127 | # num_workers=2, 128 | # persistent_workers=True, 129 | # drop_last=False, 130 | # sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), 131 | # dataset=dict( 132 | # type=dataset_type, 133 | # data_root=data_root, 134 | # data_mode=data_mode, 135 | # ann_file='annotations/coco_test.json', 136 | # bbox_file='data/coco/person_detection_results/COCO_test-dev2017_detections_AP_H_609_person.json', 137 | # data_prefix=dict(img='test2017/'), 138 | # test_mode=True, 139 | # pipeline=val_pipeline, 140 | # )) 141 | # test_evaluator = dict( 142 | # type='CocoMetric', 143 | # ann_file=data_root + 'annotations/coco_test.json', 144 | # format_only=True, 145 | # outfile_prefix='SAR-Res50') 146 | 147 | -------------------------------------------------------------------------------- /configs/wholebody-hr48.py: -------------------------------------------------------------------------------- 1 | # Directly inherit the entire recipe you want to use. 2 | _base_ = ['mmpose::_base_/default_runtime.py', 3 | 'mmpose::_base_/datasets/coco_wholebody.py'] 4 | 5 | custom_imports = dict(imports='SAR') 6 | 7 | # runtime 8 | train_cfg = dict(max_epochs=210, val_interval=35) 9 | 10 | # optimizer 11 | optim_wrapper = dict(optimizer=dict( 12 | type='Adam', 13 | lr=5e-4, 14 | )) 15 | 16 | # learning policy 17 | param_scheduler = [ 18 | dict( 19 | type='LinearLR', begin=0, end=500, start_factor=0.001, 20 | by_epoch=False), # warm-up 21 | dict( 22 | type='MultiStepLR', 23 | begin=0, 24 | end=210, 25 | milestones=[170, 200], 26 | gamma=0.1, 27 | by_epoch=True) 28 | ] 29 | 30 | # automatically scaling LR based on the actual training batch size 31 | auto_scale_lr = dict(base_batch_size=512) 32 | 33 | # hooks 34 | default_hooks = dict( 35 | checkpoint=dict(save_best='coco-wholebody/AP', rule='greater')) 36 | 37 | # codec settings 38 | codec = [dict(type='MSRAHeatmapCoord', input_size=(288, 384), heatmap_size=(72, 96), sigma=3)] 39 | 40 | # model settings 41 | model = dict( 42 | type='TopdownPoseEstimator', 43 | data_preprocessor=dict( 44 | type='PoseDataPreprocessor', 45 | mean=[123.675, 116.28, 103.53], 46 | std=[58.395, 57.12, 57.375], 47 | bgr_to_rgb=True), 48 | backbone=dict( 49 | type='HRNet', 50 | in_channels=3, 51 | extra=dict( 52 | stage1=dict( 53 | num_modules=1, 54 | num_branches=1, 55 | block='BOTTLENECK', 56 | num_blocks=(4, ), 57 | num_channels=(64, )), 58 | stage2=dict( 59 | num_modules=1, 60 | num_branches=2, 61 | block='BASIC', 62 | num_blocks=(4, 4), 63 | num_channels=(48, 96)), 64 | stage3=dict( 65 | num_modules=4, 66 | num_branches=3, 67 | block='BASIC', 68 | num_blocks=(4, 4, 4), 69 | num_channels=(48, 96, 192)), 70 | stage4=dict( 71 | num_modules=3, 72 | num_branches=4, 73 | block='BASIC', 74 | num_blocks=(4, 4, 4, 4), 75 | num_channels=(48, 96, 192, 384))), 76 | init_cfg=dict( 77 | type='Pretrained', 78 | checkpoint='https://download.openmmlab.com/mmpose/' 79 | 'pretrain_models/hrnet_w48-8ef0771d.pth'), 80 | ), 81 | head=dict( 82 | type='SARHead', 83 | in_channels=48, 84 | out_channels=133, 85 | deconv_out_channels=None, 86 | with_heatmap=True, 87 | codec=codec[0]), 88 | test_cfg=dict( 89 | flip_test=True, 90 | flip_mode='heatmap', 91 | shift_heatmap=True, 92 | )) 93 | 94 | # base dataset settings 95 | dataset_type = 'CocoWholeBodyDataset' 96 | data_mode = 'topdown' 97 | data_root = 'data/coco/' 98 | 99 | # pipelines 100 | train_pipeline = [ 101 
| dict(type='LoadImage'), 102 | dict(type='GetBBoxCenterScale'), 103 | dict(type='RandomFlip', direction='horizontal'), 104 | dict(type='RandomHalfBody'), 105 | dict(type='RandomBBoxTransform'), 106 | dict(type='TopdownAffine', input_size=codec[0]['input_size']), 107 | dict(type='GenerateTarget', encoder=codec), 108 | dict(type='PackPoseInputs') 109 | ] 110 | val_pipeline = [ 111 | dict(type='LoadImage'), 112 | dict(type='GetBBoxCenterScale'), 113 | dict(type='TopdownAffine', input_size=codec[0]['input_size']), 114 | dict(type='PackPoseInputs') 115 | ] 116 | 117 | # data loaders 118 | train_dataloader = dict( 119 | batch_size=36, 120 | num_workers=2, 121 | persistent_workers=True, 122 | sampler=dict(type='DefaultSampler', shuffle=True), 123 | dataset=dict( 124 | type=dataset_type, 125 | data_root=data_root, 126 | data_mode=data_mode, 127 | ann_file='annotations/coco_wholebody_train_v1.0.json', 128 | data_prefix=dict(img='train2017/'), 129 | pipeline=train_pipeline, 130 | )) 131 | val_dataloader = dict( 132 | batch_size=96, 133 | num_workers=2, 134 | persistent_workers=True, 135 | drop_last=False, 136 | sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), 137 | dataset=dict( 138 | type=dataset_type, 139 | data_root=data_root, 140 | data_mode=data_mode, 141 | ann_file='annotations/coco_wholebody_val_v1.0.json', 142 | data_prefix=dict(img='val2017/'), 143 | test_mode=True, 144 | bbox_file='data/coco/person_detection_results/' 145 | 'COCO_val2017_detections_AP_H_56_person.json', 146 | pipeline=val_pipeline, 147 | )) 148 | test_dataloader = val_dataloader 149 | 150 | val_evaluator = dict( 151 | type='CocoWholeBodyMetric', 152 | ann_file=data_root + 'annotations/coco_wholebody_val_v1.0.json') 153 | test_evaluator = val_evaluator 154 | 155 | -------------------------------------------------------------------------------- /configs/coco-hr48.py: -------------------------------------------------------------------------------- 1 | _base_ = 'mmpose::_base_/default_runtime.py' 2 | 3 | custom_imports = dict(imports='SAR') 4 | 5 | # runtime 6 | train_cfg = dict(max_epochs=210, val_interval=70) 7 | 8 | # optimizer 9 | optim_wrapper = dict(optimizer=dict( 10 | type='Adam', 11 | lr=5e-4, 12 | )) 13 | 14 | # learning policy 15 | param_scheduler = [ 16 | dict( 17 | type='LinearLR', begin=0, end=500, start_factor=0.001, 18 | by_epoch=False), # warm-up 19 | dict( 20 | type='MultiStepLR', 21 | begin=0, 22 | end=210, 23 | milestones=[170, 200], 24 | gamma=0.1, 25 | by_epoch=True) 26 | ] 27 | 28 | # automatically scaling LR based on the actual training batch size 29 | auto_scale_lr = dict(base_batch_size=512) 30 | 31 | # hooks 32 | default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater')) 33 | 34 | # codec settings 35 | codec = dict( 36 | type='MSRAHeatmapCoord', input_size=(288, 384), heatmap_size=(72, 96), sigma=3) 37 | 38 | # model settings 39 | model = dict( 40 | type='TopdownPoseEstimator', 41 | data_preprocessor=dict( 42 | type='PoseDataPreprocessor', 43 | mean=[123.675, 116.28, 103.53], 44 | std=[58.395, 57.12, 57.375], 45 | bgr_to_rgb=True), 46 | backbone=dict( 47 | type='HRNet', 48 | in_channels=3, 49 | extra=dict( 50 | stage1=dict( 51 | num_modules=1, 52 | num_branches=1, 53 | block='BOTTLENECK', 54 | num_blocks=(4, ), 55 | num_channels=(64, )), 56 | stage2=dict( 57 | num_modules=1, 58 | num_branches=2, 59 | block='BASIC', 60 | num_blocks=(4, 4), 61 | num_channels=(48, 96)), 62 | stage3=dict( 63 | num_modules=4, 64 | num_branches=3, 65 | block='BASIC', 66 | num_blocks=(4, 
4, 4), 67 | num_channels=(48, 96, 192)), 68 | stage4=dict( 69 | num_modules=3, 70 | num_branches=4, 71 | block='BASIC', 72 | num_blocks=(4, 4, 4, 4), 73 | num_channels=(48, 96, 192, 384))), 74 | init_cfg=dict( 75 | type='Pretrained', 76 | checkpoint='https://download.openmmlab.com/mmpose/' 77 | 'pretrain_models/hrnet_w48-8ef0771d.pth'), 78 | ), 79 | head=dict( 80 | type='SARHead', 81 | in_channels=48, 82 | out_channels=17, 83 | with_heatmap=True, 84 | deconv_out_channels=None, 85 | codec=codec), 86 | test_cfg=dict( 87 | flip_test=True, 88 | flip_mode='heatmap', 89 | shift_heatmap=True, 90 | )) 91 | 92 | # base dataset settings 93 | dataset_type = 'CocoDataset' 94 | data_mode = 'topdown' 95 | data_root = 'data/coco/' 96 | 97 | # pipelines 98 | train_pipeline = [ 99 | dict(type='LoadImage'), 100 | dict(type='GetBBoxCenterScale'), 101 | dict(type='RandomFlip', direction='horizontal'), 102 | dict(type='RandomHalfBody'), 103 | dict(type='RandomBBoxTransform'), 104 | dict(type='TopdownAffine', input_size=codec['input_size']), 105 | dict(type='GenerateTarget', encoder=codec), 106 | dict(type='PackPoseInputs') 107 | ] 108 | val_pipeline = [ 109 | dict(type='LoadImage'), 110 | dict(type='GetBBoxCenterScale'), 111 | dict(type='TopdownAffine', input_size=codec['input_size']), 112 | dict(type='PackPoseInputs') 113 | ] 114 | 115 | # data loaders 116 | train_dataloader = dict( 117 | batch_size=36, 118 | num_workers=2, 119 | persistent_workers=True, 120 | sampler=dict(type='DefaultSampler', shuffle=True), 121 | dataset=dict( 122 | type=dataset_type, 123 | data_root=data_root, 124 | data_mode=data_mode, 125 | ann_file='annotations/person_keypoints_train2017.json', 126 | data_prefix=dict(img='train2017/'), 127 | pipeline=train_pipeline, 128 | )) 129 | val_dataloader = dict( 130 | batch_size=256, 131 | num_workers=2, 132 | persistent_workers=True, 133 | drop_last=False, 134 | sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), 135 | dataset=dict( 136 | type=dataset_type, 137 | data_root=data_root, 138 | data_mode=data_mode, 139 | ann_file='annotations/person_keypoints_val2017.json', 140 | bbox_file='data/coco/person_detection_results/' 141 | 'COCO_val2017_detections_AP_H_56_person.json', 142 | data_prefix=dict(img='val2017/'), 143 | test_mode=True, 144 | pipeline=val_pipeline, 145 | )) 146 | test_dataloader = val_dataloader 147 | 148 | # evaluators 149 | val_evaluator = dict( 150 | type='CocoMetric', 151 | ann_file=data_root + 'annotations/person_keypoints_val2017.json') 152 | test_evaluator = val_evaluator 153 | 154 | # test_dataloader = dict( 155 | # batch_size=256, 156 | # num_workers=2, 157 | # persistent_workers=True, 158 | # drop_last=False, 159 | # sampler=dict(type='DefaultSampler', shuffle=False, round_up=False), 160 | # dataset=dict( 161 | # type=dataset_type, 162 | # data_root=data_root, 163 | # data_mode=data_mode, 164 | # ann_file='annotations/coco_test.json', 165 | # bbox_file='data/coco/person_detection_results/COCO_test-dev2017_detections_AP_H_609_person.json', 166 | # data_prefix=dict(img='test2017/'), 167 | # test_mode=True, 168 | # pipeline=val_pipeline, 169 | # )) 170 | 171 | # test_evaluator = dict( 172 | # type='CocoMetric', 173 | # ann_file=data_root + 'annotations/coco_test.json', 174 | # format_only=True, 175 | # outfile_prefix='SAR-HR48') -------------------------------------------------------------------------------- /SAR/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from typing import Optional, Tuple
3 | 
4 | import numpy as np
5 | 
6 | from mmpose.registry import KEYPOINT_CODECS
7 | from mmpose.codecs.base import BaseKeypointCodec
8 | from mmpose.codecs.utils.gaussian_heatmap import (generate_gaussian_heatmaps,
9 |                                                   generate_unbiased_gaussian_heatmaps)
10 | from mmpose.codecs.utils.post_processing import get_heatmap_maximum
11 | from mmpose.codecs.utils.refinement import refine_keypoints, refine_keypoints_dark
12 | 
13 | @KEYPOINT_CODECS.register_module()
14 | class MSRAHeatmapCoord(BaseKeypointCodec):
15 |     """Represent keypoints as heatmaps via the "MSRA" approach. See the paper:
16 |     `Simple Baselines for Human Pose Estimation and Tracking`_ by Xiao et al.
17 |     (2018) for details.
18 | 
19 |     Note:
20 | 
21 |         - instance number: N
22 |         - keypoint number: K
23 |         - keypoint dimension: D
24 |         - image size: [w, h]
25 |         - heatmap size: [W, H]
26 | 
27 |     Encoded:
28 | 
29 |         - heatmaps (np.ndarray): The generated heatmap in shape (K, H, W), where [W, H] is the `heatmap_size`
30 |         - keypoint_labels (np.ndarray): Keypoint coordinates mapped into heatmap space, in shape (N, K, D)
31 |         - keypoint_weights (np.ndarray): The target weights in shape (N, K)
32 | 
33 |     Args:
34 |         input_size (tuple): Image size in [w, h]
35 |         heatmap_size (tuple): Heatmap size in [W, H]
36 |         sigma (float): The sigma value of the Gaussian heatmap
37 |         unbiased (bool): Whether to use the unbiased method (DarkPose) in
38 |             ``'msra'`` encoding. See `Dark Pose`_ for details. Defaults to ``False``
39 |         blur_kernel_size (int): The Gaussian blur kernel size of the heatmap
40 |             modulation in DarkPose. The kernel size and sigma should follow
41 |             the empirical formula :math:`sigma = 0.3*((ks-1)*0.5-1)+0.8`.
42 |             Defaults to 11
43 | 
44 |     .. _`Simple Baselines for Human Pose Estimation and Tracking`:
45 |         https://arxiv.org/abs/1804.06208
46 |     .. _`Dark Pose`: https://arxiv.org/abs/1910.06278
47 |     """
48 | 
49 |     def __init__(self,
50 |                  input_size: Tuple[int, int],
51 |                  heatmap_size: Tuple[int, int],
52 |                  sigma: float,
53 |                  unbiased: bool = False,
54 |                  blur_kernel_size: int = 11) -> None:
55 |         super().__init__()
56 |         self.input_size = input_size
57 |         self.heatmap_size = heatmap_size
58 |         self.sigma = sigma
59 |         self.unbiased = unbiased
60 | 
61 |         # The Gaussian blur kernel size of the heatmap modulation
62 |         # in DarkPose and the sigma value follows the empirical
63 |         # formula :math:`sigma = 0.3*((ks-1)*0.5-1)+0.8`
64 |         # which gives:
65 |         #     sigma~=3 if ks=17
66 |         #     sigma=2 if ks=11;
67 |         #     sigma~=1.5 if ks=7;
68 |         #     sigma~=1 if ks=3;
69 |         self.blur_kernel_size = blur_kernel_size
70 |         self.scale_factor = (np.array(input_size) /
71 |                              heatmap_size).astype(np.float32)
72 | 
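    # A round-trip sketch of this codec under the coco-res50 settings
    # (illustrative usage, assuming input_size=(288, 384), heatmap_size=(72, 96)):
    #
    #     codec = MSRAHeatmapCoord(input_size=(288, 384), heatmap_size=(72, 96), sigma=3)
    #     encoded = codec.encode(np.random.rand(1, 17, 2) * (288, 384))
    #     encoded['heatmaps'].shape         # (17, 96, 72), i.e. (K, H, W)
    #     encoded['keypoint_labels'].shape  # (1, 17, 2), in heatmap coordinates
    #     keypoints, scores = codec.decode(encoded['heatmaps'])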
73 |     def encode(self,
74 |                keypoints: np.ndarray,
75 |                keypoints_visible: Optional[np.ndarray] = None) -> dict:
76 |         """Encode keypoints into heatmaps. Note that the original keypoint
77 |         coordinates should be in the input image space.
78 | 
79 |         Args:
80 |             keypoints (np.ndarray): Keypoint coordinates in shape (N, K, D)
81 |             keypoints_visible (np.ndarray): Keypoint visibilities in shape
82 |                 (N, K)
83 | 
84 |         Returns:
85 |             dict:
86 |             - heatmaps (np.ndarray): The generated heatmap in shape
87 |                 (K, H, W) where [W, H] is the `heatmap_size`
88 |             - keypoint_labels (np.ndarray): Keypoints in heatmap coordinates, in shape (N, K, D)
89 |             - keypoint_weights (np.ndarray): The target weights in shape (N, K)
90 |         """
91 | 
92 |         assert keypoints.shape[0] == 1, (
93 |             f'{self.__class__.__name__} only supports single-instance '
94 |             'keypoint encoding')
95 | 
96 |         if keypoints_visible is None:
97 |             keypoints_visible = np.ones(keypoints.shape[:2], dtype=np.float32)
98 | 
99 |         if self.unbiased:
100 |             heatmaps, keypoint_weights = generate_unbiased_gaussian_heatmaps(
101 |                 heatmap_size=self.heatmap_size,
102 |                 keypoints=keypoints / self.scale_factor,
103 |                 keypoints_visible=keypoints_visible,
104 |                 sigma=self.sigma)
105 |         else:
106 |             heatmaps, keypoint_weights = generate_gaussian_heatmaps(
107 |                 heatmap_size=self.heatmap_size,
108 |                 keypoints=keypoints / self.scale_factor,
109 |                 keypoints_visible=keypoints_visible,
110 |                 sigma=self.sigma)
111 | 
112 |         encoded = dict(heatmaps=heatmaps, keypoint_labels=keypoints / self.scale_factor, keypoint_weights=keypoint_weights)
113 | 
114 |         return encoded
115 | 
116 |     def decode(self, encoded: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
117 |         """Decode keypoint coordinates from heatmaps. The decoded keypoint
118 |         coordinates are in the input image space.
119 | 
120 |         Args:
121 |             encoded (np.ndarray): Heatmaps in shape (K, H, W)
122 | 
123 |         Returns:
124 |             tuple:
125 |             - keypoints (np.ndarray): Decoded keypoint coordinates in shape
126 |                 (N, K, D)
127 |             - scores (np.ndarray): The keypoint scores in shape (N, K). It
128 |                 usually represents the confidence of the keypoint prediction
129 |         """
130 |         heatmaps = encoded.copy()
131 |         K, H, W = heatmaps.shape
132 | 
133 |         keypoints, scores = get_heatmap_maximum(heatmaps)
134 | 
135 |         # Unsqueeze the instance dimension for single-instance results
136 |         keypoints, scores = keypoints[None], scores[None]
137 | 
138 |         if self.unbiased:
139 |             # Alleviate biased coordinate
140 |             keypoints = refine_keypoints_dark(
141 |                 keypoints, heatmaps, blur_kernel_size=self.blur_kernel_size)
142 | 
143 |         else:
144 |             keypoints = refine_keypoints(keypoints, heatmaps)
145 | 
146 |         # Restore the keypoint scale
147 |         keypoints = keypoints * self.scale_factor
148 | 
149 |         return keypoints, scores
--------------------------------------------------------------------------------
/SAR/model.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence, Tuple, Union
2 | 
3 | import torch
4 | import torch.nn.functional as F
5 | from mmcv.cnn import build_conv_layer, build_upsample_layer
6 | from mmengine.structures import InstanceData
7 | from torch import Tensor, nn
8 | import numpy as np
9 | 
10 | from mmpose.registry import KEYPOINT_CODECS, MODELS
11 | from mmpose.utils.typing import (ConfigType, Features, OptConfigType,
12 |                                  OptSampleList, Predictions)
13 | 
14 | from mmpose.models.heads import BaseHead
15 | 
16 | OptIntSeq = Optional[Sequence[int]]
17 | 
18 | 
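# A shape walkthrough of SARHead under the coco-res50 config, for orientation
# (illustrative sizes, assuming a stride-32 ResNet-50 and input_size=(288, 384)):
#
#     x:        (B, 2048, 12, 9)    backbone feature
#     x:        (B, 256, 96, 72)    after the three stride-2 deconv layers
#     logit:    (B, 17, 96, 72)     per-keypoint spatial classification scores
#     offset:   (B, 17, 2, 96, 72)  per-location offsets to the keypoint
#     keypoint: location - offset, normalized to [0, 1] over the feature map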
19 | @MODELS.register_module()
20 | class SARHead(BaseHead):
21 |     _version = 2
22 | 
23 |     def __init__(self,
24 |                  in_channels: Union[int, Sequence[int]],
25 |                  out_channels: int,
26 |                  deconv_out_channels: OptIntSeq = (256, 256, 256),
27 |                  deconv_kernel_sizes: OptIntSeq = (4, 4, 4),
28 |                  with_heatmap: bool = True,
29 |                  codec: OptConfigType = None,
30 |                  init_cfg: OptConfigType = None):
31 | 
32 |         if init_cfg is None:
33 |             init_cfg = self.default_init_cfg
34 |         import math
35 |         prior_prob = 0.01
36 |         bias_value = -math.log((1 - prior_prob) / prior_prob)
37 |         override_dict = [dict(
38 |             type='Normal',
39 |             layer=['Conv2d'],
40 |             std=0.001,
41 |             override=dict(
42 |                 name='logit_conv',
43 |                 type='Normal',
44 |                 std=0.001,
45 |                 bias=bias_value))]
46 |         if with_heatmap:
47 |             override_dict += [dict(
48 |                 type='Normal',
49 |                 layer=['Conv2d'],
50 |                 std=0.001,
51 |                 override=dict(
52 |                     name='heatmap_conv',
53 |                     type='Normal',
54 |                     std=0.001,
55 |                     bias=bias_value))]
56 |         init_cfg = init_cfg + override_dict
57 |         super().__init__(init_cfg)
58 | 
59 |         self.in_channels = in_channels
60 |         self.out_channels = out_channels
61 |         self.codec = codec
62 | 
63 |         if deconv_out_channels:
64 |             if deconv_kernel_sizes is None or len(deconv_out_channels) != len(
65 |                     deconv_kernel_sizes):
66 |                 raise ValueError(
67 |                     '"deconv_out_channels" and "deconv_kernel_sizes" should '
68 |                     'be integer sequences with the same length. Got '
69 |                     f'mismatched lengths {deconv_out_channels} and '
70 |                     f'{deconv_kernel_sizes}')
71 | 
72 |             self.deconv_layers = self._make_deconv_layers(
73 |                 in_channels=in_channels,
74 |                 layer_out_channels=deconv_out_channels,
75 |                 layer_kernel_sizes=deconv_kernel_sizes,
76 |             )
77 |             in_channels = deconv_out_channels[-1]
78 |         else:
79 |             self.deconv_layers = nn.Identity()
80 | 
81 |         self.with_heatmap = with_heatmap
82 |         if self.with_heatmap:
83 |             heatmap_cfg = dict(
84 |                 type='Conv2d',
85 |                 in_channels=in_channels,
86 |                 out_channels=out_channels,
87 |                 kernel_size=1)
88 |             self.heatmap_conv = build_conv_layer(heatmap_cfg)
89 | 
90 |         logit_cfg = dict(
91 |             type='Conv2d',
92 |             in_channels=in_channels,
93 |             out_channels=out_channels,
94 |             kernel_size=1)
95 |         offset_cfg = dict(
96 |             type='Conv2d',
97 |             in_channels=in_channels,
98 |             out_channels=out_channels*2,
99 |             kernel_size=1)
100 |         self.logit_conv = build_conv_layer(logit_cfg)
101 |         self.offset_conv = build_conv_layer(offset_cfg)
102 | 
103 |         self.heatmap_loss = FocalLoss()
104 |         self.num_keypoints = out_channels
105 | 
106 |         # Register the hook to automatically convert old version state dicts
107 |         self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
108 | 
109 |     def _make_deconv_layers(self, in_channels: int,
110 |                             layer_out_channels: Sequence[int],
111 |                             layer_kernel_sizes: Sequence[int]) -> nn.Module:
112 |         """Create deconvolutional layers from given parameters."""
113 | 
114 |         layers = []
115 |         for out_channels, kernel_size in zip(layer_out_channels,
116 |                                              layer_kernel_sizes):
117 |             if kernel_size == 4:
118 |                 padding = 1
119 |                 output_padding = 0
120 |             elif kernel_size == 3:
121 |                 padding = 1
122 |                 output_padding = 1
123 |             elif kernel_size == 2:
124 |                 padding = 0
125 |                 output_padding = 0
126 |             else:
127 |                 raise ValueError(f'Unsupported kernel size {kernel_size} for '
128 |                                  'deconvolutional layers in '
129 |                                  f'{self.__class__.__name__}')
130 |             cfg = dict(
131 |                 type='deconv',
132 |                 in_channels=in_channels,
133 |                 out_channels=out_channels,
134 |                 kernel_size=kernel_size,
135 |                 stride=2,
136 |                 padding=padding,
137 |                 output_padding=output_padding,
138 |                 bias=False)
139 |             layers.append(build_upsample_layer(cfg))
140 |             layers.append(nn.BatchNorm2d(num_features=out_channels))
141 |             layers.append(nn.ReLU(inplace=True))
142 |             in_channels = out_channels
143 | 
144 |         return nn.Sequential(*layers)
145 | 
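    # Note on the (kernel_size, padding, output_padding) choices above: for a
    # stride-2 ConvTranspose2d, out = (in - 1) * 2 - 2 * padding + kernel_size
    # + output_padding, so (4, 1, 0), (3, 1, 1) and (2, 0, 0) all exactly
    # double the spatial resolution.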
146 |     @property
147 |     def default_init_cfg(self):
148 |         init_cfg = [
149 |             dict(
150 |                 type='Normal', layer=['Conv2d', 'ConvTranspose2d'], std=0.001),
151 |             dict(type='Constant', layer='BatchNorm2d', val=1)
152 |         ]
153 |         return init_cfg
154 | 
155 |     def _sigmoid(self, x):
156 |         y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4)
157 |         return y
158 | 
159 |     @torch.no_grad()
160 |     def locations(self, features):
161 |         h, w = features.size()[-2:]
162 |         device = features.device
163 |         shifts_x = torch.arange(0, w, dtype=torch.float32, device=device)
164 |         shifts_y = torch.arange(0, h, dtype=torch.float32, device=device)
165 |         shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing='ij')
166 |         shift_x = shift_x.reshape(-1) / w
167 |         shift_y = shift_y.reshape(-1) / h
168 |         locations = torch.stack((shift_x, shift_y), dim=1)
169 |         locations = locations.reshape(h, w, 2).permute(2, 0, 1)
170 |         return locations
171 | 
172 |     def forward(self, feats: Tuple[Tensor]) -> list:
173 |         """Forward the network. The input is multi scale feature maps and
174 |         the output is a list of predictions.
175 | 
176 |         Args:
177 |             feats (Tuple[Tensor]): Multi scale feature maps.
178 | 
179 |         Returns:
180 |             list: [logit, keypoint] and, if ``with_heatmap``, an auxiliary heatmap.
181 |         """
182 |         x = feats[-1]
183 | 
184 |         x = self.deconv_layers(x)
185 |         bs, c, h, w = x.size()
186 |         logit = self.logit_conv(x).sigmoid()
187 |         offset = self.offset_conv(x).reshape(bs, self.num_keypoints, 2, h, w)
188 |         location = self.locations(offset)[None, None]
189 |         keypoint = location - offset
190 | 
191 |         ret = [logit, keypoint]
192 |         if self.with_heatmap:
193 |             heatmap = self.heatmap_conv(x)
194 |             ret.append(heatmap)
195 | 
196 |         return ret
197 | 
198 |     def predict(self,
199 |                 feats: Features,
200 |                 batch_data_samples: OptSampleList,
201 |                 test_cfg: ConfigType = {}) -> Predictions:
202 | 
203 |         assert test_cfg.get('flip_test', False)
204 |         if test_cfg.get('flip_test', False):
205 |             # only flip-test (TTA) inference is implemented for this head
206 | 207 | # TTA: flip test -> feats = [orig, flipped] 208 | assert isinstance(feats, list) and len(feats) == 2 209 | flip_indices = batch_data_samples[0].metainfo['flip_indices'] 210 | input_size = batch_data_samples[0].metainfo['input_size'] 211 | 212 | _feats, _feats_flip = feats 213 | batch_rets = self.forward(_feats) 214 | flip_batch_rets = self.forward(_feats_flip) 215 | preds, vals = self.flip_decode(batch_rets, flip_batch_rets, flip_indices, input_size) 216 | else: 217 | batch_rets = self.forward(feats) 218 | preds, vals = self.decode(batch_rets) 219 | 220 | preds = [ 221 | InstanceData(keypoints=keypoints[None, :, :], keypoint_scores=scores[None, :]) 222 | for keypoints, scores in zip(preds, vals) 223 | ] 224 | 225 | return preds 226 | 227 | def flip_decode(self, batch_rets, flip_batch_rets, flip_indices, input_size): 228 | logit, keypoint = batch_rets[:2] 229 | bs, k, h, w = logit.shape 230 | logits = logit.reshape(bs*k, h*w) 231 | logits = logits / logits.sum(dim=1, keepdim=True) 232 | keypoints = keypoint.reshape(bs*k, 2, h*w).permute(0, 2, 1) 233 | maxvals, maxinds = logits.max(dim=1) 234 | coords = keypoints[torch.arange(bs*k, dtype=torch.long).to(keypoints.device), maxinds] 235 | 236 | logit, keypoint = flip_batch_rets[:2] 237 | bs, k, h, w = logit.shape 238 | logits = logit.reshape(bs*k, h*w) 239 | logits = logits / logits.sum(dim=1, keepdim=True) 240 | keypoints = keypoint.reshape(bs*k, 2, h*w).permute(0, 2, 1) 241 | maxvals_flip, maxinds = logits.max(dim=1) 242 | coords_flip = keypoints[torch.arange(bs*k, dtype=torch.long).to(keypoints.device), maxinds] 243 | 244 | coords_flip[:, 0] = 1 - coords_flip[:, 0] - 1.0 / (w * 4) 245 | coords_flip = coords_flip.reshape(bs, k, 2)[:, flip_indices, :].reshape(bs*k, 2) 246 | maxvals_flip = maxvals_flip.reshape(bs, k)[:, flip_indices].reshape(-1) 247 | coords = (coords + coords_flip) / 2.0 248 | coords[..., 0] *= w 249 | coords[..., 1] *= h 250 | maxvals = (maxvals + maxvals_flip) / 2.0 251 | 252 | # hmvals 253 | # heatmap = batch_rets[2] 254 | # heatmap_flip = flip_batch_rets[2] 255 | # heatmap_flip = heatmap_flip.flip(3)[:, flip_indices, :, :] 256 | # heatmap = (heatmap + heatmap_flip) / 2.0 257 | # bs, k, h, w = heatmap.shape 258 | # heatmap = heatmap.reshape(bs*k, 1, h, w).sigmoid() 259 | # coord_inds = torch.stack(( 260 | # coords[:, 0] / (w - 1) * 2 - 1, 261 | # coords[:, 1] / (h - 1) * 2 - 1, 262 | # ), dim=-1) 263 | # coord_inds = coord_inds[:, None, None, :] 264 | # keypoint_scores = torch.nn.functional.grid_sample( 265 | # heatmap, coord_inds, 266 | # padding_mode='border').reshape(bs*k, -1) 267 | # maxvals = keypoint_scores 268 | 269 | preds = coords.reshape(bs, k, 2).cpu().numpy() 270 | maxvals = maxvals.reshape(bs, k).cpu().numpy() 271 | if self.codec.get('type', 'MSRAHeatmap') == 'UDPHeatmap': 272 | preds = preds / [w-1, h-1] 273 | preds = preds * self.codec['input_size'] 274 | else: 275 | stride = self.codec['input_size'][0] / self.codec['heatmap_size'][0] 276 | preds = preds * stride 277 | 278 | return preds, maxvals 279 | 280 | def decode(self, batch_rets): 281 | logit, keypoint = batch_rets[:2] 282 | bs, k, h, w = logit.shape 283 | logits = logit.reshape(bs*k, h*w) 284 | logits = logits / logits.sum(dim=1, keepdim=True) 285 | keypoints = keypoint.reshape(bs*k, 2, h*w).permute(0, 2, 1) 286 | maxvals, maxinds = logits.max(dim=1) 287 | coords = keypoints[torch.arange(bs*k, dtype=torch.long).to(keypoints.device), maxinds] 288 | coords[..., 0] *= w 289 | coords[..., 1] *= h 290 | 291 | # hmvals 292 | # heatmap = 
batch_rets[2]
293 |         # bs, k, h, w = heatmap.shape
294 |         # heatmap = heatmap.reshape(bs*k, 1, h, w).sigmoid()
295 |         # coord_inds = torch.stack((
296 |         #     coords[:, 0] / (w - 1) * 2 - 1,
297 |         #     coords[:, 1] / (h - 1) * 2 - 1,
298 |         # ), dim=-1)
299 |         # coord_inds = coord_inds[:, None, None, :]
300 |         # keypoint_scores = torch.nn.functional.grid_sample(
301 |         #     heatmap, coord_inds,
302 |         #     padding_mode='border').reshape(bs*k, -1)
303 |         # maxvals = keypoint_scores
304 | 
305 |         preds = coords.reshape(bs, k, 2).cpu().numpy()
306 |         maxvals = maxvals.reshape(bs, k).cpu().numpy()
307 |         if self.codec.get('type', 'MSRAHeatmap') == 'UDPHeatmap':
308 |             preds = preds / [w-1, h-1]
309 |             preds = preds * self.codec['input_size']
310 |         else:
311 |             stride = self.codec['input_size'][0] / self.codec['heatmap_size'][0]
312 |             preds = preds * stride
313 | 
314 |         return preds, maxvals
315 | 
316 |     def loss(self,
317 |              feats: Tuple[Tensor],
318 |              batch_data_samples: OptSampleList,
319 |              train_cfg: ConfigType = {}) -> dict:
320 |         """Calculate losses from a batch of inputs and data samples.
321 | 
322 |         Args:
323 |             feats (Tuple[Tensor]): The multi-stage features
324 |             batch_data_samples (List[:obj:`PoseDataSample`]): The batch
325 |                 data samples
326 |             train_cfg (dict): The runtime config for training process.
327 |                 Defaults to {}
328 | 
329 |         Returns:
330 |             dict: A dictionary of losses.
331 |         """
332 |         pred_fields = self.forward(feats)
333 |         gt_heatmaps = torch.stack(
334 |             [d.gt_fields.heatmaps for d in batch_data_samples])
335 |         gt_weights = torch.cat([
336 |             d.gt_instance_labels.keypoint_weights for d in batch_data_samples
337 |         ])
338 |         gt_keypoints = torch.cat([
339 |             d.gt_instance_labels.keypoint_labels for d in batch_data_samples
340 |         ])
341 | 
342 |         logit, keypoint = pred_fields[:2]
343 |         if self.with_heatmap: heatmap = pred_fields[2]
344 | 
345 |         bs, k, h, w = logit.size()
346 |         assert k == keypoint.size(1) and keypoint.size(2) == 2
347 | 
348 |         valid_mask = gt_weights.reshape(bs*k) > 0
349 |         # get cls score
350 |         cls_score = logit.reshape(bs*k, h*w)[valid_mask]
351 |         # get reg score
352 |         pred_keypoints = keypoint.reshape(bs*k, 2, h*w).permute(0, 2, 1)
353 |         gt_keypoints = gt_keypoints.reshape(bs*k, 2)[:, None, :]
354 |         hh, ww = gt_heatmaps.shape[2:]
355 |         gt_keypoints[..., 0] /= ww
356 |         gt_keypoints[..., 1] /= hh
357 | 
358 |         dist_mat = torch.abs(pred_keypoints - gt_keypoints)
359 |         dist_mat = dist_mat * 16.0
360 |         reg_score = torch.exp(-dist_mat.sum(dim=2))[valid_mask]
361 | 
362 |         norm_cls_score = cls_score / cls_score.sum(dim=1, keepdim=True)
363 |         normcls2reg_loss = torch.sum(norm_cls_score * reg_score, dim=1)
364 |         normcls2reg_loss = -torch.log(normcls2reg_loss + 1e-6)
365 |         loss = normcls2reg_loss.mean()
366 | 
367 |         if self.with_heatmap:
368 |             bs, k, h, w = gt_heatmaps.size()
369 |             heatmap = heatmap.reshape(bs*k, h, w)[valid_mask]
370 |             gt_heatmap = gt_heatmaps.reshape(bs*k, h, w)[valid_mask]
371 |             pos_label = gt_heatmap > 0
372 |             num_pos = max(torch.sum(gt_heatmap > 0.7).item(), 1)  # guard against division by zero when there are no positives
373 |             heatmap_loss = gfl_loss(heatmap, gt_heatmap, pos_label) / num_pos
374 |             loss = loss + heatmap_loss
375 | 
376 |         # calculate losses
377 |         losses = dict()
378 | 
379 |         losses.update(loss_kpt=loss)
380 | 
381 |         return losses
382 | 
383 |     def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
384 |                                   **kwargs):
385 |         """A hook function to convert old-version (before MMPose v1.0.0)
386 |         state dicts, which used a combined ``final_layer``, to a format
387 |         compatible with :class:`SARHead`.
388 | 389 | The hook will be automatically registered during initialization. 390 | """ 391 | version = local_meta.get('version', None) 392 | if version and version >= self._version: 393 | return 394 | 395 | # convert old-version state dict 396 | keys = list(state_dict.keys()) 397 | for _k in keys: 398 | if not _k.startswith(prefix): 399 | continue 400 | v = state_dict.pop(_k) 401 | k = _k[len(prefix):] 402 | # In old version, "final_layer" includes both intermediate 403 | # conv layers (new "conv_layers") and final conv layers (new 404 | # "final_layer"). 405 | # 406 | # If there is no intermediate conv layer, old "final_layer" will 407 | # have keys like "final_layer.xxx", which should be still 408 | # named "final_layer.xxx"; 409 | # 410 | # If there are intermediate conv layers, old "final_layer" will 411 | # have keys like "final_layer.n.xxx", where the weights of the last 412 | # one should be renamed "final_layer.xxx", and others should be 413 | # renamed "conv_layers.n.xxx" 414 | k_parts = k.split('.') 415 | if k_parts[0] == 'final_layer': 416 | if len(k_parts) == 3: 417 | assert isinstance(self.conv_layers, nn.Sequential) 418 | idx = int(k_parts[1]) 419 | if idx < len(self.conv_layers): 420 | # final_layer.n.xxx -> conv_layers.n.xxx 421 | k_new = 'conv_layers.' + '.'.join(k_parts[1:]) 422 | else: 423 | # final_layer.n.xxx -> final_layer.xxx 424 | k_new = 'final_layer.' + k_parts[2] 425 | else: 426 | # final_layer.xxx remains final_layer.xxx 427 | k_new = k 428 | else: 429 | k_new = k 430 | 431 | state_dict[prefix + k_new] = v 432 | 433 | class FocalLoss(nn.Module): 434 | def __init__(self): 435 | super(FocalLoss, self).__init__() 436 | self.alpha = 2 437 | self.beta = 4 438 | 439 | def forward(self, pred, gt, mask=None): 440 | pos_inds = gt.eq(1).float() 441 | neg_inds = gt.lt(1).float() 442 | 443 | if mask is not None: 444 | pos_inds = pos_inds * mask 445 | neg_inds = neg_inds * mask 446 | 447 | neg_weights = torch.pow(1 - gt, self.beta) 448 | 449 | loss = 0 450 | 451 | pos_loss = torch.log(pred) * torch.pow(1 - pred, self.alpha) * pos_inds 452 | neg_loss = torch.log(1 - pred) * torch.pow(pred, self.alpha) * neg_weights * neg_inds 453 | 454 | num_pos = pos_inds.float().sum() 455 | pos_loss = pos_loss.sum() 456 | neg_loss = neg_loss.sum() 457 | 458 | if num_pos == 0: 459 | loss = loss - neg_loss 460 | else: 461 | loss = loss - (pos_loss + neg_loss) / num_pos 462 | return loss 463 | 464 | def gfl_loss( 465 | inputs: torch.Tensor, 466 | targets: torch.Tensor, 467 | pos_label: torch.Tensor, 468 | gamma: float = 2, 469 | ): 470 | score = targets 471 | pred = inputs 472 | # negatives are supervised by 0 quality score 473 | pred_sigmoid = pred.sigmoid() 474 | scale_factor = pred_sigmoid 475 | zerolabel = scale_factor.new_zeros(pred.shape) 476 | loss = F.binary_cross_entropy_with_logits( 477 | pred, zerolabel, reduction='none') * scale_factor.pow(gamma) 478 | 479 | # FG cat_id: [0, num_classes -1], BG cat_id: num_classes 480 | # positives are supervised by bbox quality (IoU) score 481 | scale_factor = score[pos_label] - pred_sigmoid[pos_label] 482 | loss[pos_label] = F.binary_cross_entropy_with_logits( 483 | pred[pos_label], score[pos_label], 484 | reduction='none') * scale_factor.abs().pow(gamma) 485 | 486 | loss = loss.sum() 487 | return loss --------------------------------------------------------------------------------