├── LICENSE ├── README.md ├── experiments ├── cls_hrnet_w18_sgd_lr5e-2_wd1e-4_bs32_x100.yaml ├── cls_hrnet_w18_small_v1_sgd_lr5e-2_wd1e-4_bs32_x100.yaml ├── cls_hrnet_w18_small_v2_sgd_lr5e-2_wd1e-4_bs32_x100.yaml ├── cls_hrnet_w30_sgd_lr5e-2_wd1e-4_bs32_x100.yaml ├── cls_hrnet_w32_sgd_lr5e-2_wd1e-4_bs32_x100.yaml ├── cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml ├── cls_hrnet_w44_sgd_lr5e-2_wd1e-4_bs32_x100.yaml ├── cls_hrnet_w48_sgd_lr5e-2_wd1e-4_bs32_x100.yaml └── cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml ├── figures ├── cls-head.png ├── cls-hrnet.png └── hrnet.png ├── lib ├── config │ ├── __init__.py │ ├── default.py │ └── models.py ├── core │ ├── evaluate.py │ └── function.py ├── models │ ├── __init__.py │ └── cls_hrnet.py └── utils │ ├── modelsummary.py │ └── utils.py ├── requirements.txt └── tools ├── _init_paths.py ├── train.py └── valid.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Microsoft Corporation 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # High-resolution networks (HRNets) for Image classification 2 | 3 | ## News 4 | 5 | - [2021/01/20] Added some stronger ImageNet pretrained models; e.g., HRNet_W48_C_ssld_pretrained.pth achieves a top-1 accuracy of 83.6%. 6 | 7 | - [2020/03/13] Our paper was accepted by TPAMI: [Deep High-Resolution Representation Learning for Visual Recognition](https://arxiv.org/pdf/1908.07919.pdf). 8 | 9 | - Per request, we provide two small HRNet models whose #parameters and GFLOPs are similar to those of ResNet-18. The segmentation results using the two small models are also available at https://github.com/HRNet/HRNet-Semantic-Segmentation. 10 | 11 | - A TensorFlow implementation is available at https://github.com/yuanyuanli85/tf-hrnet. Thanks [VictorLi](https://github.com/yuanyuanli85)! 12 | 13 | - ONNX export is enabled after fixing issues. Thanks [Bowen Bao](https://github.com/BowenBao)! 14 | 15 | ## Introduction 16 | This is the official code of [high-resolution representations for ImageNet classification](https://arxiv.org/abs/1904.04514). 17 | We augment the HRNet with a classification head, shown in the figure below. 
First, the four-resolution feature maps are fed into a bottleneck block, and the numbers of output channels are increased to 128, 256, 512, and 1024, respectively. Then, we downsample the highest-resolution representation with a stride-2 3x3 convolution that outputs 256 channels, and add it to the second-highest-resolution representation. This process is repeated twice to obtain 1024 channels at the smallest resolution. Finally, we transform the 1024 channels into 2048 channels through a 1x1 convolution, followed by a global average pooling operation. The resulting 2048-dimensional representation is fed into the classifier. 18 | 19 | ![](figures/cls-hrnet.png) 20 | 21 | ## ImageNet pretrained models 22 | HRNetV2 ImageNet pretrained models are now available! 23 | 24 | | model |#Params | GFLOPs |top-1 error| top-5 error| Link | 25 | | :--: | :--: | :--: | :--: | :--: | :--: | 26 | | HRNet-W18-C-Small-v1 | 13.2M | 1.49 | 27.7% | 9.3% |[OneDrive](https://1drv.ms/u/s!Aus8VCZ_C_33gRv2PI1vjJyn2g7G?e=i8Rdzx)/[BaiduYun(Access Code:v3sw)](https://pan.baidu.com/s/1snP_gTz50pJp2g07anVIEA) 27 | | HRNet-W18-C-Small-v2 | 15.6M | 2.42 | 24.9% | 7.6% |[OneDrive](https://1drv.ms/u/s!Aus8VCZ_C_33gRmfdPR79WBS61Qn?e=HVZUi8)/[BaiduYun(Access Code:bnc9)](https://pan.baidu.com/s/1tbL45sOS4mXNGgyS4YCQww) 28 | | HRNet-W18-C | 21.3M | 3.99 | 23.2% | 6.6% |[OneDrive](https://1drv.ms/u/s!Aus8VCZ_C_33cMkPimlmClRvmpw)/[BaiduYun(Access Code:r5xn)](https://pan.baidu.com/s/1Px_g1E2BLVRkKC5t-b-R5Q)| 29 | | HRNet-W30-C | 37.7M | 7.55 | 21.8% | 5.8% |[OneDrive](https://1drv.ms/u/s!Aus8VCZ_C_33cQoACCEfrzcSaVI)/[BaiduYun(Access Code:ajc1)](https://pan.baidu.com/s/1yEz7hKaJT-H7eHLteAotbQ)| 30 | | HRNet-W32-C | 41.2M | 8.31 | 21.5% | 5.8% |[OneDrive](https://1drv.ms/u/s!Aus8VCZ_C_33dYBMemi9xOUFR0w)/[BaiduYun(Access Code:itc1)](https://pan.baidu.com/s/1xn92PSCg5KtXkKcnnLOycw)| 31 | | HRNet-W40-C | 57.6M | 11.8 | 21.1% | 5.5% |[OneDrive](https://1drv.ms/u/s!Aus8VCZ_C_33ck0gvo5jfoWBOPo)/[BaiduYun(Access Code:i58x)](https://pan.baidu.com/s/1DD3WKxgLM1jawR87WdAtsw)| 32 | | HRNet-W44-C | 67.1M | 13.9 | 21.1% | 5.6% |[OneDrive](https://1drv.ms/u/s!Aus8VCZ_C_33czZQ0woUb980gRs)/[BaiduYun(Access Code:3imd)](https://pan.baidu.com/s/1F679dvz9iJ8aFAp6YKr9Rw)| 33 | | HRNet-W48-C | 77.5M | 16.1 | 20.7% | 5.5% |[OneDrive](https://1drv.ms/u/s!Aus8VCZ_C_33dKvqI6pBZlifgJk)/[BaiduYun(Access Code:68g2)](https://pan.baidu.com/s/13b8srQn8ARF9zHsaxvpRWA)| 34 | | HRNet-W64-C | 128.1M | 26.9 | 20.5% | 5.4% |[OneDrive](https://1drv.ms/u/s!Aus8VCZ_C_33gQbJsUPTIj3rQu99)/[BaiduYun(Access Code:6kw4)](https://pan.baidu.com/s/16ycW99VAYat3fHjgKpUXvQ)| 35 | 36 | Newly added checkpoints: 37 | 38 | | model |#Params | GFLOPs |top-1 error | Link | 39 | | :--: | :--: | :--: | :--: | :--: | 40 | | HRNet-W18-C (w/ CosineLR + CutMix + 300 epochs) | 21.3M | 3.99 | 22.1% | [Link](https://github.com/HRNet/HRNet-Image-Classification/releases/download/PretrainedWeights/HRNet_W18_C_cosinelr_cutmix_300epoch.pth.tar) 41 | | HRNet-W48-C (w/ CosineLR + CutMix + 300 epochs) | 77.5M | 16.1 | 18.9% | [Link](https://github.com/HRNet/HRNet-Image-Classification/releases/download/PretrainedWeights/HRNet_W48_C_cosinelr_cutmix_300epoch.pth.tar) 42 | | HRNet-W18-C-ssld (converted from PaddlePaddle) | 21.3M | 3.99 | 18.8% | [Link](https://github.com/HRNet/HRNet-Image-Classification/releases/download/PretrainedWeights/HRNet_W18_C_ssld_pretrained.pth) 43 | | HRNet-W48-C-ssld (converted from PaddlePaddle) | 77.5M | 16.1 | 16.4% | [Link](https://github.com/HRNet/HRNet-Image-Classification/releases/download/PretrainedWeights/HRNet_W48_C_ssld_pretrained.pth)
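These weights drop straight into the model defined in lib/models/cls_hrnet.py. Below is a minimal loading sketch, assuming `lib` is on `PYTHONPATH` (which tools/_init_paths.py arranges for the provided scripts) and that the file holds a plain `state_dict`, as the `--testModel` checkpoint used in the Quick start does; the `.pth.tar` releases may wrap the weights differently.

````python
# Build HRNet-W18 from its experiment config and load downloaded weights.
# Assumes `lib` is on PYTHONPATH; the config merge mirrors update_config()
# in lib/config/default.py. The checkpoint name is the one tools/valid.py uses.
import torch
from config import config                 # lib/config/__init__.py
from models.cls_hrnet import get_cls_net

config.defrost()
config.merge_from_file('experiments/cls_hrnet_w18_sgd_lr5e-2_wd1e-4_bs32_x100.yaml')
config.freeze()

model = get_cls_net(config)
state_dict = torch.load('hrnetv2_w18_imagenet_pretrained.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
````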
44 | 45 | In the above table, the first two checkpoints are trained with a cosine learning-rate schedule, CutMix data augmentation, and a longer schedule of 300 epochs. The other two checkpoints are converted 46 | from [PaddleClas](https://github.com/PaddlePaddle/PaddleClas). Please refer to the [SSLD tutorial](https://github.com/PaddlePaddle/PaddleClas/blob/dygraph/docs/en/advanced_tutorials/distillation/distillation_en.md#ssld) for more details. 47 | 48 | ## Quick start 49 | ### Install 50 | 1. Install PyTorch 0.4.1 following the [official instructions](https://pytorch.org/) 51 | 2. git clone https://github.com/HRNet/HRNet-Image-Classification 52 | 3. Install dependencies: pip install -r requirements.txt 53 | 54 | ### Data preparation 55 | You can follow the PyTorch implementation: 56 | https://github.com/pytorch/examples/tree/master/imagenet 57 | 58 | The data should be under ./data/imagenet/images/. 59 | 60 | ### Train and test 61 | Please specify the configuration file. 62 | 63 | For example, to train HRNet-W18 on ImageNet with a total batch size of 128 (32 per GPU) on 4 GPUs: 64 | ````bash 65 | python tools/train.py --cfg experiments/cls_hrnet_w18_sgd_lr5e-2_wd1e-4_bs32_x100.yaml 66 | ```` 67 | 68 | For example, to test HRNet-W18 on ImageNet on 4 GPUs: 69 | ````bash 70 | python tools/valid.py --cfg experiments/cls_hrnet_w18_sgd_lr5e-2_wd1e-4_bs32_x100.yaml --testModel hrnetv2_w18_imagenet_pretrained.pth 71 | ```` 72 | 73 | ## Other applications of HRNet 74 | * [Human pose estimation](https://github.com/leoxiaobin/deep-high-resolution-net.pytorch) 75 | * [Semantic segmentation](https://github.com/HRNet/HRNet-Semantic-Segmentation) 76 | * [Object detection](https://github.com/HRNet/HRNet-Object-Detection) 77 | * [Facial landmark detection](https://github.com/HRNet/HRNet-Facial-Landmark-Detection) 78 | 79 | ## Citation 80 | If you find this work or code helpful in your research, please cite: 81 | ```` 82 | @inproceedings{SunXLW19, 83 | title={Deep High-Resolution Representation Learning for Human Pose Estimation}, 84 | author={Ke Sun and Bin Xiao and Dong Liu and Jingdong Wang}, 85 | booktitle={CVPR}, 86 | year={2019} 87 | } 88 | 89 | @article{WangSCJDZLMTWLX19, 90 | title={Deep High-Resolution Representation Learning for Visual Recognition}, 91 | author={Jingdong Wang and Ke Sun and Tianheng Cheng and 92 | Borui Jiang and Chaorui Deng and Yang Zhao and Dong Liu and Yadong Mu and 93 | Mingkui Tan and Xinggang Wang and Wenyu Liu and Bin Xiao}, 94 | journal={TPAMI}, 95 | year={2019} 96 | } 97 | ```` 98 | 99 | ## Reference 100 | [1] Deep High-Resolution Representation Learning for Visual Recognition. Jingdong Wang, Ke Sun, Tianheng Cheng, 101 | Borui Jiang, Chaorui Deng, Yang Zhao, Dong Liu, Yadong Mu, Mingkui Tan, Xinggang Wang, Wenyu Liu, Bin Xiao. Accepted by TPAMI. [download](https://arxiv.org/pdf/1908.07919.pdf)
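The experiment files below fully specify each variant. If you want to double-check the #Params/GFLOPs numbers quoted above for any of them, the repository's own lib/utils/modelsummary.py can be pointed at a freshly built model; a sketch under the same `PYTHONPATH` assumption as above:

````python
# Rough #Params / multiply-adds check for one experiment file, using the
# repo's get_model_summary. Assumes `lib` is on PYTHONPATH.
import torch
from config import config
from models.cls_hrnet import get_cls_net
from utils.modelsummary import get_model_summary

config.defrost()
config.merge_from_file('experiments/cls_hrnet_w48_sgd_lr5e-2_wd1e-4_bs32_x100.yaml')
config.freeze()

model = get_cls_net(config)
# A 1x3x224x224 input matches MODEL.IMAGE_SIZE in every experiment file.
print(get_model_summary(model, torch.rand(1, 3, 224, 224), verbose=True))
````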
102 | -------------------------------------------------------------------------------- /experiments/cls_hrnet_w18_sgd_lr5e-2_wd1e-4_bs32_x100.yaml: -------------------------------------------------------------------------------- 1 | GPUS: (0,1,2,3) 2 | LOG_DIR: 'log/' 3 | DATA_DIR: '' 4 | OUTPUT_DIR: 'output/' 5 | WORKERS: 4 6 | PRINT_FREQ: 1000 7 | 8 | MODEL: 9 | NAME: cls_hrnet 10 | IMAGE_SIZE: 11 | - 224 12 | - 224 13 | EXTRA: 14 | STAGE1: 15 | NUM_MODULES: 1 16 | NUM_BRANCHES: 1 17 | BLOCK: BOTTLENECK 18 | NUM_BLOCKS: 19 | - 4 20 | NUM_CHANNELS: 21 | - 64 22 | FUSE_METHOD: SUM 23 | STAGE2: 24 | NUM_MODULES: 1 25 | NUM_BRANCHES: 2 26 | BLOCK: BASIC 27 | NUM_BLOCKS: 28 | - 4 29 | - 4 30 | NUM_CHANNELS: 31 | - 18 32 | - 36 33 | FUSE_METHOD: SUM 34 | STAGE3: 35 | NUM_MODULES: 4 36 | NUM_BRANCHES: 3 37 | BLOCK: BASIC 38 | NUM_BLOCKS: 39 | - 4 40 | - 4 41 | - 4 42 | NUM_CHANNELS: 43 | - 18 44 | - 36 45 | - 72 46 | FUSE_METHOD: SUM 47 | STAGE4: 48 | NUM_MODULES: 3 49 | NUM_BRANCHES: 4 50 | BLOCK: BASIC 51 | NUM_BLOCKS: 52 | - 4 53 | - 4 54 | - 4 55 | - 4 56 | NUM_CHANNELS: 57 | - 18 58 | - 36 59 | - 72 60 | - 144 61 | FUSE_METHOD: SUM 62 | CUDNN: 63 | BENCHMARK: true 64 | DETERMINISTIC: false 65 | ENABLED: true 66 | DATASET: 67 | DATASET: 'imagenet' 68 | DATA_FORMAT: 'jpg' 69 | ROOT: 'data/imagenet/' 70 | TEST_SET: 'val' 71 | TRAIN_SET: 'train' 72 | TEST: 73 | BATCH_SIZE_PER_GPU: 32 74 | MODEL_FILE: '' 75 | TRAIN: 76 | BATCH_SIZE_PER_GPU: 32 77 | BEGIN_EPOCH: 0 78 | END_EPOCH: 100 79 | RESUME: true 80 | LR_FACTOR: 0.1 81 | LR_STEP: 82 | - 30 83 | - 60 84 | - 90 85 | OPTIMIZER: sgd 86 | LR: 0.05 87 | WD: 0.0001 88 | MOMENTUM: 0.9 89 | NESTEROV: true 90 | SHUFFLE: true 91 | DEBUG: 92 | DEBUG: false 93 | -------------------------------------------------------------------------------- /experiments/cls_hrnet_w18_small_v1_sgd_lr5e-2_wd1e-4_bs32_x100.yaml: -------------------------------------------------------------------------------- 1 | GPUS: (0,1,2,3) 2 | LOG_DIR: 'log/' 3 | DATA_DIR: '' 4 | OUTPUT_DIR: 'output/' 5 | WORKERS: 4 6 | PRINT_FREQ: 1000 7 | 8 | MODEL: 9 | NAME: cls_hrnet 10 | IMAGE_SIZE: 11 | - 224 12 | - 224 13 | EXTRA: 14 | WITH_HEAD: true 15 | STAGE1: 16 | NUM_MODULES: 1 17 | NUM_BRANCHES: 1 18 | BLOCK: BOTTLENECK 19 | NUM_BLOCKS: 20 | - 1 21 | NUM_CHANNELS: 22 | - 32 23 | FUSE_METHOD: SUM 24 | STAGE2: 25 | NUM_MODULES: 1 26 | NUM_BRANCHES: 2 27 | BLOCK: BASIC 28 | NUM_BLOCKS: 29 | - 2 30 | - 2 31 | NUM_CHANNELS: 32 | - 16 33 | - 32 34 | FUSE_METHOD: SUM 35 | STAGE3: 36 | NUM_MODULES: 1 37 | NUM_BRANCHES: 3 38 | BLOCK: BASIC 39 | NUM_BLOCKS: 40 | - 2 41 | - 2 42 | - 2 43 | NUM_CHANNELS: 44 | - 16 45 | - 32 46 | - 64 47 | FUSE_METHOD: SUM 48 | STAGE4: 49 | NUM_MODULES: 1 50 | NUM_BRANCHES: 4 51 | BLOCK: BASIC 52 | NUM_BLOCKS: 53 | - 2 54 | - 2 55 | - 2 56 | - 2 57 | NUM_CHANNELS: 58 | - 16 59 | - 32 60 | - 64 61 | - 128 62 | FUSE_METHOD: SUM 63 | CUDNN: 64 | BENCHMARK: true 65 | DETERMINISTIC: false 66 | ENABLED: true 67 | DATASET: 68 | DATASET: 'imagenet' 69 | DATA_FORMAT: 'jpg' 70 | ROOT: 'data/imagenet/' 71 | TEST_SET: 'val' 72 | TRAIN_SET: 'train' 73 | TEST: 74 | BATCH_SIZE_PER_GPU: 32 75 | MODEL_FILE: '' 76 | TRAIN: 77 | BATCH_SIZE_PER_GPU: 32 78 | BEGIN_EPOCH: 0 79 | END_EPOCH: 100 80 | RESUME: true 81 | LR_FACTOR: 0.1 82 | LR_STEP: 83 | - 30 84 | - 60 85 | - 90 86 | OPTIMIZER: sgd 87 | LR: 0.05 88 | WD: 0.0001 89 | MOMENTUM: 0.9 90 | NESTEROV: true 91 | SHUFFLE: true 92 | DEBUG: 93 | DEBUG: false 94 | 
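One detail that is easy to miss in these files: STAGE1 uses BOTTLENECK blocks while the later stages use BASIC ones, and in lib/models/cls_hrnet.py `Bottleneck.expansion = 4` whereas `BasicBlock.expansion = 1`. So STAGE1's `NUM_CHANNELS: 64` (or 32 in small_v1) actually hands four times that many channels to the first transition layer:

````python
# Stage-1 output width = NUM_CHANNELS * block expansion; cls_hrnet.py computes
# stage1_out_channel = block.expansion * num_channels the same way.
expansion = {'BOTTLENECK': 4, 'BASIC': 1}
print(64 * expansion['BOTTLENECK'])   # 256 channels into transition1 (w18)
print(32 * expansion['BOTTLENECK'])   # 128 channels for small_v1
````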
-------------------------------------------------------------------------------- /experiments/cls_hrnet_w18_small_v2_sgd_lr5e-2_wd1e-4_bs32_x100.yaml: -------------------------------------------------------------------------------- 1 | GPUS: (0,1,2,3) 2 | LOG_DIR: 'log/' 3 | DATA_DIR: '' 4 | OUTPUT_DIR: 'output/' 5 | WORKERS: 4 6 | PRINT_FREQ: 1000 7 | 8 | MODEL: 9 | NAME: cls_hrnet 10 | IMAGE_SIZE: 11 | - 224 12 | - 224 13 | EXTRA: 14 | WITH_HEAD: true 15 | STAGE1: 16 | NUM_MODULES: 1 17 | NUM_BRANCHES: 1 18 | BLOCK: BOTTLENECK 19 | NUM_BLOCKS: 20 | - 2 21 | NUM_CHANNELS: 22 | - 64 23 | FUSE_METHOD: SUM 24 | STAGE2: 25 | NUM_MODULES: 1 26 | NUM_BRANCHES: 2 27 | BLOCK: BASIC 28 | NUM_BLOCKS: 29 | - 2 30 | - 2 31 | NUM_CHANNELS: 32 | - 18 33 | - 36 34 | FUSE_METHOD: SUM 35 | STAGE3: 36 | NUM_MODULES: 3 37 | NUM_BRANCHES: 3 38 | BLOCK: BASIC 39 | NUM_BLOCKS: 40 | - 2 41 | - 2 42 | - 2 43 | NUM_CHANNELS: 44 | - 18 45 | - 36 46 | - 72 47 | FUSE_METHOD: SUM 48 | STAGE4: 49 | NUM_MODULES: 2 50 | NUM_BRANCHES: 4 51 | BLOCK: BASIC 52 | NUM_BLOCKS: 53 | - 2 54 | - 2 55 | - 2 56 | - 2 57 | NUM_CHANNELS: 58 | - 18 59 | - 36 60 | - 72 61 | - 144 62 | FUSE_METHOD: SUM 63 | CUDNN: 64 | BENCHMARK: true 65 | DETERMINISTIC: false 66 | ENABLED: true 67 | DATASET: 68 | DATASET: 'imagenet' 69 | DATA_FORMAT: 'zip' 70 | ROOT: 'data/imagenet/' 71 | TEST_SET: 'val' 72 | TRAIN_SET: 'train' 73 | TEST: 74 | BATCH_SIZE_PER_GPU: 32 75 | MODEL_FILE: '' 76 | TRAIN: 77 | BATCH_SIZE_PER_GPU: 32 78 | BEGIN_EPOCH: 0 79 | END_EPOCH: 100 80 | RESUME: true 81 | LR_FACTOR: 0.1 82 | LR_STEP: 83 | - 30 84 | - 60 85 | - 90 86 | OPTIMIZER: sgd 87 | LR: 0.05 88 | WD: 0.0001 89 | MOMENTUM: 0.9 90 | NESTEROV: true 91 | SHUFFLE: true 92 | DEBUG: 93 | DEBUG: false 94 | -------------------------------------------------------------------------------- /experiments/cls_hrnet_w30_sgd_lr5e-2_wd1e-4_bs32_x100.yaml: -------------------------------------------------------------------------------- 1 | GPUS: (0,1,2,3) 2 | LOG_DIR: 'log/' 3 | DATA_DIR: '' 4 | OUTPUT_DIR: 'output/' 5 | WORKERS: 4 6 | PRINT_FREQ: 1000 7 | 8 | MODEL: 9 | NAME: cls_hrnet 10 | IMAGE_SIZE: 11 | - 224 12 | - 224 13 | EXTRA: 14 | STAGE1: 15 | NUM_MODULES: 1 16 | NUM_BRANCHES: 1 17 | BLOCK: BOTTLENECK 18 | NUM_BLOCKS: 19 | - 4 20 | NUM_CHANNELS: 21 | - 64 22 | FUSE_METHOD: SUM 23 | STAGE2: 24 | NUM_MODULES: 1 25 | NUM_BRANCHES: 2 26 | BLOCK: BASIC 27 | NUM_BLOCKS: 28 | - 4 29 | - 4 30 | NUM_CHANNELS: 31 | - 30 32 | - 60 33 | FUSE_METHOD: SUM 34 | STAGE3: 35 | NUM_MODULES: 4 36 | NUM_BRANCHES: 3 37 | BLOCK: BASIC 38 | NUM_BLOCKS: 39 | - 4 40 | - 4 41 | - 4 42 | NUM_CHANNELS: 43 | - 30 44 | - 60 45 | - 120 46 | FUSE_METHOD: SUM 47 | STAGE4: 48 | NUM_MODULES: 3 49 | NUM_BRANCHES: 4 50 | BLOCK: BASIC 51 | NUM_BLOCKS: 52 | - 4 53 | - 4 54 | - 4 55 | - 4 56 | NUM_CHANNELS: 57 | - 30 58 | - 60 59 | - 120 60 | - 240 61 | FUSE_METHOD: SUM 62 | CUDNN: 63 | BENCHMARK: true 64 | DETERMINISTIC: false 65 | ENABLED: true 66 | DATASET: 67 | DATASET: 'imagenet' 68 | DATA_FORMAT: 'jpg' 69 | ROOT: 'data/imagenet/' 70 | TEST_SET: 'val' 71 | TRAIN_SET: 'train' 72 | TEST: 73 | BATCH_SIZE_PER_GPU: 32 74 | MODEL_FILE: '' 75 | TRAIN: 76 | BATCH_SIZE_PER_GPU: 32 77 | BEGIN_EPOCH: 0 78 | END_EPOCH: 100 79 | RESUME: true 80 | LR_FACTOR: 0.1 81 | LR_STEP: 82 | - 30 83 | - 60 84 | - 90 85 | OPTIMIZER: sgd 86 | LR: 0.05 87 | WD: 0.0001 88 | MOMENTUM: 0.9 89 | NESTEROV: true 90 | SHUFFLE: true 91 | DEBUG: 92 | DEBUG: false 93 | --------------------------------------------------------------------------------
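The TRAIN block is identical across all of these files: SGD with Nesterov momentum at LR 0.05, decayed by LR_FACTOR 0.1 at epochs 30, 60, and 90, over 100 epochs. A self-contained illustration of that schedule follows; the real wiring lives in tools/train.py, which is not shown here, so treat the scheduler choice as an assumption:

````python
# Stand-alone illustration of the LR_STEP / LR_FACTOR schedule above.
import torch
from torch import optim

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = optim.SGD(params, lr=0.05, momentum=0.9,
                      weight_decay=1e-4, nesterov=True)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                           milestones=[30, 60, 90], gamma=0.1)

for epoch in range(100):
    # ... one training epoch at optimizer.param_groups[0]['lr'] ...
    scheduler.step()
# Epochs 0-29 run at 5e-2, 30-59 at 5e-3, 60-89 at 5e-4, 90-99 at 5e-5.
````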
/experiments/cls_hrnet_w32_sgd_lr5e-2_wd1e-4_bs32_x100.yaml: -------------------------------------------------------------------------------- 1 | GPUS: (0,1,2,3) 2 | LOG_DIR: 'log/' 3 | DATA_DIR: '' 4 | OUTPUT_DIR: 'output/' 5 | WORKERS: 4 6 | PRINT_FREQ: 1000 7 | 8 | MODEL: 9 | NAME: cls_hrnet 10 | IMAGE_SIZE: 11 | - 224 12 | - 224 13 | EXTRA: 14 | STAGE1: 15 | NUM_MODULES: 1 16 | NUM_BRANCHES: 1 17 | BLOCK: BOTTLENECK 18 | NUM_BLOCKS: 19 | - 4 20 | NUM_CHANNELS: 21 | - 64 22 | FUSE_METHOD: SUM 23 | STAGE2: 24 | NUM_MODULES: 1 25 | NUM_BRANCHES: 2 26 | BLOCK: BASIC 27 | NUM_BLOCKS: 28 | - 4 29 | - 4 30 | NUM_CHANNELS: 31 | - 32 32 | - 64 33 | FUSE_METHOD: SUM 34 | STAGE3: 35 | NUM_MODULES: 4 36 | NUM_BRANCHES: 3 37 | BLOCK: BASIC 38 | NUM_BLOCKS: 39 | - 4 40 | - 4 41 | - 4 42 | NUM_CHANNELS: 43 | - 32 44 | - 64 45 | - 128 46 | FUSE_METHOD: SUM 47 | STAGE4: 48 | NUM_MODULES: 3 49 | NUM_BRANCHES: 4 50 | BLOCK: BASIC 51 | NUM_BLOCKS: 52 | - 4 53 | - 4 54 | - 4 55 | - 4 56 | NUM_CHANNELS: 57 | - 32 58 | - 64 59 | - 128 60 | - 256 61 | FUSE_METHOD: SUM 62 | CUDNN: 63 | BENCHMARK: true 64 | DETERMINISTIC: false 65 | ENABLED: true 66 | DATASET: 67 | DATASET: 'imagenet' 68 | DATA_FORMAT: 'jpg' 69 | ROOT: 'data/imagenet/' 70 | TEST_SET: 'val' 71 | TRAIN_SET: 'train' 72 | TEST: 73 | BATCH_SIZE_PER_GPU: 32 74 | MODEL_FILE: '' 75 | TRAIN: 76 | BATCH_SIZE_PER_GPU: 32 77 | BEGIN_EPOCH: 0 78 | END_EPOCH: 100 79 | RESUME: true 80 | LR_FACTOR: 0.1 81 | LR_STEP: 82 | - 30 83 | - 60 84 | - 90 85 | OPTIMIZER: sgd 86 | LR: 0.05 87 | WD: 0.0001 88 | MOMENTUM: 0.9 89 | NESTEROV: true 90 | SHUFFLE: true 91 | DEBUG: 92 | DEBUG: false 93 | -------------------------------------------------------------------------------- /experiments/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml: -------------------------------------------------------------------------------- 1 | GPUS: (0,1,2,3) 2 | LOG_DIR: 'log/' 3 | DATA_DIR: '' 4 | OUTPUT_DIR: 'output/' 5 | WORKERS: 4 6 | PRINT_FREQ: 1000 7 | 8 | MODEL: 9 | NAME: cls_hrnet 10 | IMAGE_SIZE: 11 | - 224 12 | - 224 13 | EXTRA: 14 | STAGE1: 15 | NUM_MODULES: 1 16 | NUM_BRANCHES: 1 17 | BLOCK: BOTTLENECK 18 | NUM_BLOCKS: 19 | - 4 20 | NUM_CHANNELS: 21 | - 64 22 | FUSE_METHOD: SUM 23 | STAGE2: 24 | NUM_MODULES: 1 25 | NUM_BRANCHES: 2 26 | BLOCK: BASIC 27 | NUM_BLOCKS: 28 | - 4 29 | - 4 30 | NUM_CHANNELS: 31 | - 40 32 | - 80 33 | FUSE_METHOD: SUM 34 | STAGE3: 35 | NUM_MODULES: 4 36 | NUM_BRANCHES: 3 37 | BLOCK: BASIC 38 | NUM_BLOCKS: 39 | - 4 40 | - 4 41 | - 4 42 | NUM_CHANNELS: 43 | - 40 44 | - 80 45 | - 160 46 | FUSE_METHOD: SUM 47 | STAGE4: 48 | NUM_MODULES: 3 49 | NUM_BRANCHES: 4 50 | BLOCK: BASIC 51 | NUM_BLOCKS: 52 | - 4 53 | - 4 54 | - 4 55 | - 4 56 | NUM_CHANNELS: 57 | - 40 58 | - 80 59 | - 160 60 | - 320 61 | FUSE_METHOD: SUM 62 | CUDNN: 63 | BENCHMARK: true 64 | DETERMINISTIC: false 65 | ENABLED: true 66 | DATASET: 67 | DATASET: 'imagenet' 68 | DATA_FORMAT: 'jpg' 69 | ROOT: 'data/imagenet/' 70 | TEST_SET: 'val' 71 | TRAIN_SET: 'train' 72 | TEST: 73 | BATCH_SIZE_PER_GPU: 32 74 | MODEL_FILE: '' 75 | TRAIN: 76 | BATCH_SIZE_PER_GPU: 32 77 | BEGIN_EPOCH: 0 78 | END_EPOCH: 100 79 | RESUME: true 80 | LR_FACTOR: 0.1 81 | LR_STEP: 82 | - 30 83 | - 60 84 | - 90 85 | OPTIMIZER: sgd 86 | LR: 0.05 87 | WD: 0.0001 88 | MOMENTUM: 0.9 89 | NESTEROV: true 90 | SHUFFLE: true 91 | DEBUG: 92 | DEBUG: false 93 | -------------------------------------------------------------------------------- /experiments/cls_hrnet_w44_sgd_lr5e-2_wd1e-4_bs32_x100.yaml: 
-------------------------------------------------------------------------------- 1 | GPUS: (0,1,2,3) 2 | LOG_DIR: 'log/' 3 | DATA_DIR: '' 4 | OUTPUT_DIR: 'output/' 5 | WORKERS: 4 6 | PRINT_FREQ: 1000 7 | 8 | MODEL: 9 | NAME: cls_hrnet 10 | IMAGE_SIZE: 11 | - 224 12 | - 224 13 | EXTRA: 14 | STAGE1: 15 | NUM_MODULES: 1 16 | NUM_BRANCHES: 1 17 | BLOCK: BOTTLENECK 18 | NUM_BLOCKS: 19 | - 4 20 | NUM_CHANNELS: 21 | - 64 22 | FUSE_METHOD: SUM 23 | STAGE2: 24 | NUM_MODULES: 1 25 | NUM_BRANCHES: 2 26 | BLOCK: BASIC 27 | NUM_BLOCKS: 28 | - 4 29 | - 4 30 | NUM_CHANNELS: 31 | - 44 32 | - 88 33 | FUSE_METHOD: SUM 34 | STAGE3: 35 | NUM_MODULES: 4 36 | NUM_BRANCHES: 3 37 | BLOCK: BASIC 38 | NUM_BLOCKS: 39 | - 4 40 | - 4 41 | - 4 42 | NUM_CHANNELS: 43 | - 44 44 | - 88 45 | - 176 46 | FUSE_METHOD: SUM 47 | STAGE4: 48 | NUM_MODULES: 3 49 | NUM_BRANCHES: 4 50 | BLOCK: BASIC 51 | NUM_BLOCKS: 52 | - 4 53 | - 4 54 | - 4 55 | - 4 56 | NUM_CHANNELS: 57 | - 44 58 | - 88 59 | - 176 60 | - 352 61 | FUSE_METHOD: SUM 62 | CUDNN: 63 | BENCHMARK: true 64 | DETERMINISTIC: false 65 | ENABLED: true 66 | DATASET: 67 | DATASET: 'imagenet' 68 | DATA_FORMAT: 'jpg' 69 | ROOT: 'data/imagenet/' 70 | TEST_SET: 'val' 71 | TRAIN_SET: 'train' 72 | TEST: 73 | BATCH_SIZE_PER_GPU: 32 74 | MODEL_FILE: '' 75 | TRAIN: 76 | BATCH_SIZE_PER_GPU: 32 77 | BEGIN_EPOCH: 0 78 | END_EPOCH: 100 79 | RESUME: true 80 | LR_FACTOR: 0.1 81 | LR_STEP: 82 | - 30 83 | - 60 84 | - 90 85 | OPTIMIZER: sgd 86 | LR: 0.05 87 | WD: 0.0001 88 | MOMENTUM: 0.9 89 | NESTEROV: true 90 | SHUFFLE: true 91 | DEBUG: 92 | DEBUG: false 93 | -------------------------------------------------------------------------------- /experiments/cls_hrnet_w48_sgd_lr5e-2_wd1e-4_bs32_x100.yaml: -------------------------------------------------------------------------------- 1 | GPUS: (0,1,2,3) 2 | LOG_DIR: 'log/' 3 | DATA_DIR: '' 4 | OUTPUT_DIR: 'output/' 5 | WORKERS: 4 6 | PRINT_FREQ: 1000 7 | 8 | MODEL: 9 | NAME: cls_hrnet 10 | IMAGE_SIZE: 11 | - 224 12 | - 224 13 | EXTRA: 14 | STAGE1: 15 | NUM_MODULES: 1 16 | NUM_BRANCHES: 1 17 | BLOCK: BOTTLENECK 18 | NUM_BLOCKS: 19 | - 4 20 | NUM_CHANNELS: 21 | - 64 22 | FUSE_METHOD: SUM 23 | STAGE2: 24 | NUM_MODULES: 1 25 | NUM_BRANCHES: 2 26 | BLOCK: BASIC 27 | NUM_BLOCKS: 28 | - 4 29 | - 4 30 | NUM_CHANNELS: 31 | - 48 32 | - 96 33 | FUSE_METHOD: SUM 34 | STAGE3: 35 | NUM_MODULES: 4 36 | NUM_BRANCHES: 3 37 | BLOCK: BASIC 38 | NUM_BLOCKS: 39 | - 4 40 | - 4 41 | - 4 42 | NUM_CHANNELS: 43 | - 48 44 | - 96 45 | - 192 46 | FUSE_METHOD: SUM 47 | STAGE4: 48 | NUM_MODULES: 3 49 | NUM_BRANCHES: 4 50 | BLOCK: BASIC 51 | NUM_BLOCKS: 52 | - 4 53 | - 4 54 | - 4 55 | - 4 56 | NUM_CHANNELS: 57 | - 48 58 | - 96 59 | - 192 60 | - 384 61 | FUSE_METHOD: SUM 62 | CUDNN: 63 | BENCHMARK: true 64 | DETERMINISTIC: false 65 | ENABLED: true 66 | DATASET: 67 | DATASET: 'imagenet' 68 | DATA_FORMAT: 'jpg' 69 | ROOT: 'data/imagenet/' 70 | TEST_SET: 'val' 71 | TRAIN_SET: 'train' 72 | TEST: 73 | BATCH_SIZE_PER_GPU: 32 74 | MODEL_FILE: '' 75 | TRAIN: 76 | BATCH_SIZE_PER_GPU: 32 77 | BEGIN_EPOCH: 0 78 | END_EPOCH: 100 79 | RESUME: true 80 | LR_FACTOR: 0.1 81 | LR_STEP: 82 | - 30 83 | - 60 84 | - 90 85 | OPTIMIZER: sgd 86 | LR: 0.05 87 | WD: 0.0001 88 | MOMENTUM: 0.9 89 | NESTEROV: true 90 | SHUFFLE: true 91 | DEBUG: 92 | DEBUG: false 93 | -------------------------------------------------------------------------------- /experiments/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml: -------------------------------------------------------------------------------- 1 | GPUS: (0,1,2,3) 2 | 
LOG_DIR: 'log/' 3 | DATA_DIR: '' 4 | OUTPUT_DIR: 'output/' 5 | WORKERS: 4 6 | PRINT_FREQ: 1000 7 | 8 | MODEL: 9 | NAME: cls_hrnet 10 | IMAGE_SIZE: 11 | - 224 12 | - 224 13 | EXTRA: 14 | STAGE1: 15 | NUM_MODULES: 1 16 | NUM_BRANCHES: 1 17 | BLOCK: BOTTLENECK 18 | NUM_BLOCKS: 19 | - 4 20 | NUM_CHANNELS: 21 | - 64 22 | FUSE_METHOD: SUM 23 | STAGE2: 24 | NUM_MODULES: 1 25 | NUM_BRANCHES: 2 26 | BLOCK: BASIC 27 | NUM_BLOCKS: 28 | - 4 29 | - 4 30 | NUM_CHANNELS: 31 | - 64 32 | - 128 33 | FUSE_METHOD: SUM 34 | STAGE3: 35 | NUM_MODULES: 4 36 | NUM_BRANCHES: 3 37 | BLOCK: BASIC 38 | NUM_BLOCKS: 39 | - 4 40 | - 4 41 | - 4 42 | NUM_CHANNELS: 43 | - 64 44 | - 128 45 | - 256 46 | FUSE_METHOD: SUM 47 | STAGE4: 48 | NUM_MODULES: 3 49 | NUM_BRANCHES: 4 50 | BLOCK: BASIC 51 | NUM_BLOCKS: 52 | - 4 53 | - 4 54 | - 4 55 | - 4 56 | NUM_CHANNELS: 57 | - 64 58 | - 128 59 | - 256 60 | - 512 61 | FUSE_METHOD: SUM 62 | CUDNN: 63 | BENCHMARK: true 64 | DETERMINISTIC: false 65 | ENABLED: true 66 | DATASET: 67 | DATASET: 'imagenet' 68 | DATA_FORMAT: 'jpg' 69 | ROOT: 'data/imagenet/' 70 | TEST_SET: 'val' 71 | TRAIN_SET: 'train' 72 | TEST: 73 | BATCH_SIZE_PER_GPU: 32 74 | MODEL_FILE: '' 75 | TRAIN: 76 | BATCH_SIZE_PER_GPU: 32 77 | BEGIN_EPOCH: 0 78 | END_EPOCH: 100 79 | RESUME: true 80 | LR_FACTOR: 0.1 81 | LR_STEP: 82 | - 30 83 | - 60 84 | - 90 85 | OPTIMIZER: sgd 86 | LR: 0.05 87 | WD: 0.0001 88 | MOMENTUM: 0.9 89 | NESTEROV: true 90 | SHUFFLE: true 91 | DEBUG: 92 | DEBUG: false 93 | -------------------------------------------------------------------------------- /figures/cls-head.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HRNet/HRNet-Image-Classification/f760c988482cdb8a1f69b10b219d669721144582/figures/cls-head.png -------------------------------------------------------------------------------- /figures/cls-hrnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HRNet/HRNet-Image-Classification/f760c988482cdb8a1f69b10b219d669721144582/figures/cls-hrnet.png -------------------------------------------------------------------------------- /figures/hrnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HRNet/HRNet-Image-Classification/f760c988482cdb8a1f69b10b219d669721144582/figures/hrnet.png -------------------------------------------------------------------------------- /lib/config/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from .default import _C as config 8 | from .default import update_config 9 | from .models import MODEL_EXTRAS 10 | -------------------------------------------------------------------------------- /lib/config/default.py: -------------------------------------------------------------------------------- 1 | 2 | # ------------------------------------------------------------------------------ 3 | # Copyright (c) Microsoft 4 | # Licensed under the MIT License. 
5 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 6 | # Modified by Ke Sun (sunk@mail.ustc.edu.cn) 7 | # ------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | import os 14 | 15 | from yacs.config import CfgNode as CN 16 | 17 | 18 | _C = CN() 19 | 20 | _C.OUTPUT_DIR = '' 21 | _C.LOG_DIR = '' 22 | _C.DATA_DIR = '' 23 | _C.GPUS = (0,) 24 | _C.WORKERS = 4 25 | _C.PRINT_FREQ = 20 26 | _C.AUTO_RESUME = False 27 | _C.PIN_MEMORY = True 28 | _C.RANK = 0 29 | 30 | # Cudnn related params 31 | _C.CUDNN = CN() 32 | _C.CUDNN.BENCHMARK = True 33 | _C.CUDNN.DETERMINISTIC = False 34 | _C.CUDNN.ENABLED = True 35 | 36 | # common params for NETWORK 37 | _C.MODEL = CN() 38 | _C.MODEL.NAME = 'cls_hrnet' 39 | _C.MODEL.INIT_WEIGHTS = True 40 | _C.MODEL.PRETRAINED = '' 41 | _C.MODEL.NUM_JOINTS = 17 42 | _C.MODEL.NUM_CLASSES = 1000 43 | _C.MODEL.TAG_PER_JOINT = True 44 | _C.MODEL.TARGET_TYPE = 'gaussian' 45 | _C.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256 46 | _C.MODEL.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32 47 | _C.MODEL.SIGMA = 2 48 | _C.MODEL.EXTRA = CN(new_allowed=True) 49 | 50 | _C.LOSS = CN() 51 | _C.LOSS.USE_OHKM = False 52 | _C.LOSS.TOPK = 8 53 | _C.LOSS.USE_TARGET_WEIGHT = True 54 | _C.LOSS.USE_DIFFERENT_JOINTS_WEIGHT = False 55 | 56 | # DATASET related params 57 | _C.DATASET = CN() 58 | _C.DATASET.ROOT = '' 59 | _C.DATASET.DATASET = 'mpii' 60 | _C.DATASET.TRAIN_SET = 'train' 61 | _C.DATASET.TEST_SET = 'valid' 62 | _C.DATASET.DATA_FORMAT = 'jpg' 63 | _C.DATASET.HYBRID_JOINTS_TYPE = '' 64 | _C.DATASET.SELECT_DATA = False 65 | 66 | # training data augmentation 67 | _C.DATASET.FLIP = True 68 | _C.DATASET.SCALE_FACTOR = 0.25 69 | _C.DATASET.ROT_FACTOR = 30 70 | _C.DATASET.PROB_HALF_BODY = 0.0 71 | _C.DATASET.NUM_JOINTS_HALF_BODY = 8 72 | _C.DATASET.COLOR_RGB = False 73 | 74 | # train 75 | _C.TRAIN = CN() 76 | 77 | _C.TRAIN.LR_FACTOR = 0.1 78 | _C.TRAIN.LR_STEP = [90, 110] 79 | _C.TRAIN.LR = 0.001 80 | 81 | _C.TRAIN.OPTIMIZER = 'adam' 82 | _C.TRAIN.MOMENTUM = 0.9 83 | _C.TRAIN.WD = 0.0001 84 | _C.TRAIN.NESTEROV = False 85 | _C.TRAIN.GAMMA1 = 0.99 86 | _C.TRAIN.GAMMA2 = 0.0 87 | 88 | _C.TRAIN.BEGIN_EPOCH = 0 89 | _C.TRAIN.END_EPOCH = 140 90 | 91 | _C.TRAIN.RESUME = False 92 | _C.TRAIN.CHECKPOINT = '' 93 | 94 | _C.TRAIN.BATCH_SIZE_PER_GPU = 32 95 | _C.TRAIN.SHUFFLE = True 96 | 97 | # testing 98 | _C.TEST = CN() 99 | 100 | # size of images for each device 101 | _C.TEST.BATCH_SIZE_PER_GPU = 32 102 | # Test Model Epoch 103 | _C.TEST.FLIP_TEST = False 104 | _C.TEST.POST_PROCESS = False 105 | _C.TEST.SHIFT_HEATMAP = False 106 | 107 | _C.TEST.USE_GT_BBOX = False 108 | 109 | # nms 110 | _C.TEST.IMAGE_THRE = 0.1 111 | _C.TEST.NMS_THRE = 0.6 112 | _C.TEST.SOFT_NMS = False 113 | _C.TEST.OKS_THRE = 0.5 114 | _C.TEST.IN_VIS_THRE = 0.0 115 | _C.TEST.COCO_BBOX_FILE = '' 116 | _C.TEST.BBOX_THRE = 1.0 117 | _C.TEST.MODEL_FILE = '' 118 | 119 | # debug 120 | _C.DEBUG = CN() 121 | _C.DEBUG.DEBUG = False 122 | _C.DEBUG.SAVE_BATCH_IMAGES_GT = False 123 | _C.DEBUG.SAVE_BATCH_IMAGES_PRED = False 124 | _C.DEBUG.SAVE_HEATMAPS_GT = False 125 | _C.DEBUG.SAVE_HEATMAPS_PRED = False 126 | 127 | 128 | def update_config(cfg, args): 129 | cfg.defrost() 130 | cfg.merge_from_file(args.cfg) 131 | 132 | if args.modelDir: 133 | cfg.OUTPUT_DIR = args.modelDir 134 | 135 | if args.logDir: 136 | cfg.LOG_DIR = args.logDir 137 | 138 | if args.dataDir: 139 | 
cfg.DATA_DIR = args.dataDir 140 | 141 | if args.testModel: 142 | cfg.TEST.MODEL_FILE = args.testModel 143 | 144 | cfg.DATASET.ROOT = os.path.join( 145 | cfg.DATA_DIR, cfg.DATASET.DATASET, 'images') 146 | 147 | cfg.freeze() 148 | 149 | 150 | if __name__ == '__main__': 151 | import sys 152 | with open(sys.argv[1], 'w') as f: 153 | print(_C, file=f) 154 | 155 | -------------------------------------------------------------------------------- /lib/config/models.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Created by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # Modified by Ke Sun (sunk@mail.ustc.edu.cn) 6 | # ------------------------------------------------------------------------------ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | from yacs.config import CfgNode as CN 13 | 14 | # high_resolution_net related params for classification 15 | POSE_HIGH_RESOLUTION_NET = CN() 16 | POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*'] 17 | POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = 64 18 | POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1 19 | POSE_HIGH_RESOLUTION_NET.WITH_HEAD = True 20 | 21 | POSE_HIGH_RESOLUTION_NET.STAGE2 = CN() 22 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1 23 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2 24 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4] 25 | POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64] 26 | POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC' 27 | POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM' 28 | 29 | POSE_HIGH_RESOLUTION_NET.STAGE3 = CN() 30 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1 31 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3 32 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4] 33 | POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128] 34 | POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC' 35 | POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM' 36 | 37 | POSE_HIGH_RESOLUTION_NET.STAGE4 = CN() 38 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1 39 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4 40 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 41 | POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] 42 | POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC' 43 | POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM' 44 | 45 | MODEL_EXTRAS = { 46 | 'cls_hrnet': POSE_HIGH_RESOLUTION_NET, 47 | } 48 | -------------------------------------------------------------------------------- /lib/core/evaluate.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import torch 12 | 13 | 14 | def accuracy(output, target, topk=(1,)): 15 | """Computes the precision@k for the specified values of k""" 16 | with torch.no_grad(): 17 | maxk = max(topk) 18 | batch_size = target.size(0) 19 | 20 | _, pred = output.topk(maxk, 1, True, True) 21 | pred = pred.t() 22 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 23 | 24 | res = [] 25 | for k in topk: 26 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 27 | res.append(correct_k.mul_(100.0 / batch_size)) 28 | return res 29 | -------------------------------------------------------------------------------- /lib/core/function.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import time 12 | import logging 13 | 14 | import torch 15 | 16 | from core.evaluate import accuracy 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def train(config, train_loader, model, criterion, optimizer, epoch, 23 | output_dir, tb_log_dir, writer_dict): 24 | batch_time = AverageMeter() 25 | data_time = AverageMeter() 26 | losses = AverageMeter() 27 | top1 = AverageMeter() 28 | top5 = AverageMeter() 29 | 30 | 31 | # switch to train mode 32 | model.train() 33 | 34 | end = time.time() 35 | for i, (input, target) in enumerate(train_loader): 36 | # measure data loading time 37 | data_time.update(time.time() - end) 38 | #target = target - 1 # Specific for imagenet 39 | 40 | # compute output 41 | output = model(input) 42 | target = target.cuda(non_blocking=True) 43 | 44 | loss = criterion(output, target) 45 | 46 | # compute gradient and do update step 47 | optimizer.zero_grad() 48 | loss.backward() 49 | optimizer.step() 50 | 51 | # measure accuracy and record loss 52 | losses.update(loss.item(), input.size(0)) 53 | 54 | prec1, prec5 = accuracy(output, target, (1, 5)) 55 | 56 | top1.update(prec1[0], input.size(0)) 57 | top5.update(prec5[0], input.size(0)) 58 | 59 | # measure elapsed time 60 | batch_time.update(time.time() - end) 61 | end = time.time() 62 | 63 | if i % config.PRINT_FREQ == 0: 64 | msg = 'Epoch: [{0}][{1}/{2}]\t' \ 65 | 'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \ 66 | 'Speed {speed:.1f} samples/s\t' \ 67 | 'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \ 68 | 'Loss {loss.val:.5f} ({loss.avg:.5f})\t' \ 69 | 'Accuracy@1 {top1.val:.3f} ({top1.avg:.3f})\t' \ 70 | 'Accuracy@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format( 71 | epoch, i, len(train_loader), batch_time=batch_time, 72 | speed=input.size(0)/batch_time.val, 73 | data_time=data_time, loss=losses, top1=top1, top5=top5) 74 | logger.info(msg) 75 | 76 | if writer_dict: 77 | writer = writer_dict['writer'] 78 | global_steps = writer_dict['train_global_steps'] 79 | writer.add_scalar('train_loss', losses.val, global_steps) 80 | writer.add_scalar('train_top1', top1.val, global_steps) 81 | 
writer_dict['train_global_steps'] = global_steps + 1 82 | 83 | 84 | def validate(config, val_loader, model, criterion, output_dir, tb_log_dir, 85 | writer_dict=None): 86 | batch_time = AverageMeter() 87 | losses = AverageMeter() 88 | top1 = AverageMeter() 89 | top5 = AverageMeter() 90 | 91 | # switch to evaluate mode 92 | model.eval() 93 | 94 | with torch.no_grad(): 95 | end = time.time() 96 | for i, (input, target) in enumerate(val_loader): 97 | # compute output 98 | output = model(input) 99 | 100 | target = target.cuda(non_blocking=True) 101 | 102 | loss = criterion(output, target) 103 | 104 | # measure accuracy and record loss 105 | losses.update(loss.item(), input.size(0)) 106 | prec1, prec5 = accuracy(output, target, (1, 5)) 107 | top1.update(prec1[0], input.size(0)) 108 | top5.update(prec5[0], input.size(0)) 109 | 110 | # measure elapsed time 111 | batch_time.update(time.time() - end) 112 | end = time.time() 113 | 114 | msg = 'Test: Time {batch_time.avg:.3f}\t' \ 115 | 'Loss {loss.avg:.4f}\t' \ 116 | 'Error@1 {error1:.3f}\t' \ 117 | 'Error@5 {error5:.3f}\t' \ 118 | 'Accuracy@1 {top1.avg:.3f}\t' \ 119 | 'Accuracy@5 {top5.avg:.3f}\t'.format( 120 | batch_time=batch_time, loss=losses, top1=top1, top5=top5, 121 | error1=100-top1.avg, error5=100-top5.avg) 122 | logger.info(msg) 123 | 124 | if writer_dict: 125 | writer = writer_dict['writer'] 126 | global_steps = writer_dict['valid_global_steps'] 127 | writer.add_scalar('valid_loss', losses.avg, global_steps) 128 | writer.add_scalar('valid_top1', top1.avg, global_steps) 129 | writer_dict['valid_global_steps'] = global_steps + 1 130 | 131 | return top1.avg 132 | 133 | 134 | class AverageMeter(object): 135 | """Computes and stores the average and current value""" 136 | def __init__(self): 137 | self.reset() 138 | 139 | def reset(self): 140 | self.val = 0 141 | self.avg = 0 142 | self.sum = 0 143 | self.count = 0 144 | 145 | def update(self, val, n=1): 146 | self.val = val 147 | self.sum += val * n 148 | self.count += n 149 | self.avg = self.sum / self.count 150 | -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import models.cls_hrnet 12 | -------------------------------------------------------------------------------- /lib/models/cls_hrnet.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # Modified by Ke Sun (sunk@mail.ustc.edu.cn) 6 | # ------------------------------------------------------------------------------ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import os 13 | import logging 14 | import functools 15 | 16 | import numpy as np 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch._utils 21 | import torch.nn.functional as F 22 | 23 | BN_MOMENTUM = 0.1 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def conv3x3(in_planes, out_planes, stride=1): 28 | """3x3 convolution with padding""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=1, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, inplanes, planes, stride=1, downsample=None): 37 | super(BasicBlock, self).__init__() 38 | self.conv1 = conv3x3(inplanes, planes, stride) 39 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.conv2 = conv3x3(planes, planes) 42 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | 56 | if self.downsample is not None: 57 | residual = self.downsample(x) 58 | 59 | out += residual 60 | out = self.relu(out) 61 | 62 | return out 63 | 64 | 65 | class Bottleneck(nn.Module): 66 | expansion = 4 67 | 68 | def __init__(self, inplanes, planes, stride=1, downsample=None): 69 | super(Bottleneck, self).__init__() 70 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 71 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 72 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 73 | padding=1, bias=False) 74 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 75 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, 76 | bias=False) 77 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, 78 | momentum=BN_MOMENTUM) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.downsample = downsample 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | residual = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv3(out) 95 | out = self.bn3(out) 96 | 97 | if self.downsample is not None: 98 | residual = self.downsample(x) 99 | 100 | out += residual 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | 106 | class HighResolutionModule(nn.Module): 107 | def __init__(self, num_branches, blocks, num_blocks, num_inchannels, 108 | num_channels, fuse_method, multi_scale_output=True): 109 | super(HighResolutionModule, self).__init__() 110 | self._check_branches( 111 | num_branches, blocks, num_blocks, num_inchannels, num_channels) 112 | 113 | self.num_inchannels = num_inchannels 114 | self.fuse_method = fuse_method 115 | self.num_branches = num_branches 116 | 117 | self.multi_scale_output = multi_scale_output 118 | 119 | self.branches = self._make_branches( 120 | num_branches, blocks, num_blocks, num_channels) 121 | self.fuse_layers = self._make_fuse_layers() 122 | self.relu = nn.ReLU(False) 123 | 124 | def _check_branches(self, num_branches, 
blocks, num_blocks, 125 | num_inchannels, num_channels): 126 | if num_branches != len(num_blocks): 127 | error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( 128 | num_branches, len(num_blocks)) 129 | logger.error(error_msg) 130 | raise ValueError(error_msg) 131 | 132 | if num_branches != len(num_channels): 133 | error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( 134 | num_branches, len(num_channels)) 135 | logger.error(error_msg) 136 | raise ValueError(error_msg) 137 | 138 | if num_branches != len(num_inchannels): 139 | error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( 140 | num_branches, len(num_inchannels)) 141 | logger.error(error_msg) 142 | raise ValueError(error_msg) 143 | 144 | def _make_one_branch(self, branch_index, block, num_blocks, num_channels, 145 | stride=1): 146 | downsample = None 147 | if stride != 1 or \ 148 | self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: 149 | downsample = nn.Sequential( 150 | nn.Conv2d(self.num_inchannels[branch_index], 151 | num_channels[branch_index] * block.expansion, 152 | kernel_size=1, stride=stride, bias=False), 153 | nn.BatchNorm2d(num_channels[branch_index] * block.expansion, 154 | momentum=BN_MOMENTUM), 155 | ) 156 | 157 | layers = [] 158 | layers.append(block(self.num_inchannels[branch_index], 159 | num_channels[branch_index], stride, downsample)) 160 | self.num_inchannels[branch_index] = \ 161 | num_channels[branch_index] * block.expansion 162 | for i in range(1, num_blocks[branch_index]): 163 | layers.append(block(self.num_inchannels[branch_index], 164 | num_channels[branch_index])) 165 | 166 | return nn.Sequential(*layers) 167 | 168 | def _make_branches(self, num_branches, block, num_blocks, num_channels): 169 | branches = [] 170 | 171 | for i in range(num_branches): 172 | branches.append( 173 | self._make_one_branch(i, block, num_blocks, num_channels)) 174 | 175 | return nn.ModuleList(branches) 176 | 177 | def _make_fuse_layers(self): 178 | if self.num_branches == 1: 179 | return None 180 | 181 | num_branches = self.num_branches 182 | num_inchannels = self.num_inchannels 183 | fuse_layers = [] 184 | for i in range(num_branches if self.multi_scale_output else 1): 185 | fuse_layer = [] 186 | for j in range(num_branches): 187 | if j > i: 188 | fuse_layer.append(nn.Sequential( 189 | nn.Conv2d(num_inchannels[j], 190 | num_inchannels[i], 191 | 1, 192 | 1, 193 | 0, 194 | bias=False), 195 | nn.BatchNorm2d(num_inchannels[i], 196 | momentum=BN_MOMENTUM), 197 | nn.Upsample(scale_factor=2**(j-i), mode='nearest'))) 198 | elif j == i: 199 | fuse_layer.append(None) 200 | else: 201 | conv3x3s = [] 202 | for k in range(i-j): 203 | if k == i - j - 1: 204 | num_outchannels_conv3x3 = num_inchannels[i] 205 | conv3x3s.append(nn.Sequential( 206 | nn.Conv2d(num_inchannels[j], 207 | num_outchannels_conv3x3, 208 | 3, 2, 1, bias=False), 209 | nn.BatchNorm2d(num_outchannels_conv3x3, 210 | momentum=BN_MOMENTUM))) 211 | else: 212 | num_outchannels_conv3x3 = num_inchannels[j] 213 | conv3x3s.append(nn.Sequential( 214 | nn.Conv2d(num_inchannels[j], 215 | num_outchannels_conv3x3, 216 | 3, 2, 1, bias=False), 217 | nn.BatchNorm2d(num_outchannels_conv3x3, 218 | momentum=BN_MOMENTUM), 219 | nn.ReLU(False))) 220 | fuse_layer.append(nn.Sequential(*conv3x3s)) 221 | fuse_layers.append(nn.ModuleList(fuse_layer)) 222 | 223 | return nn.ModuleList(fuse_layers) 224 | 225 | def get_num_inchannels(self): 226 | return self.num_inchannels 227 | 228 | def forward(self, x): 229 | if self.num_branches == 1: 230 | return 
[self.branches[0](x[0])] 231 | 232 | for i in range(self.num_branches): 233 | x[i] = self.branches[i](x[i]) 234 | 235 | x_fuse = [] 236 | for i in range(len(self.fuse_layers)): 237 | y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) 238 | for j in range(1, self.num_branches): 239 | if i == j: 240 | y = y + x[j] 241 | else: 242 | y = y + self.fuse_layers[i][j](x[j]) 243 | x_fuse.append(self.relu(y)) 244 | 245 | return x_fuse 246 | 247 | 248 | blocks_dict = { 249 | 'BASIC': BasicBlock, 250 | 'BOTTLENECK': Bottleneck 251 | } 252 | 253 | 254 | class HighResolutionNet(nn.Module): 255 | 256 | def __init__(self, cfg, **kwargs): 257 | super(HighResolutionNet, self).__init__() 258 | 259 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, 260 | bias=False) 261 | self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) 262 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, 263 | bias=False) 264 | self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) 265 | self.relu = nn.ReLU(inplace=True) 266 | 267 | self.stage1_cfg = cfg['MODEL']['EXTRA']['STAGE1'] 268 | num_channels = self.stage1_cfg['NUM_CHANNELS'][0] 269 | block = blocks_dict[self.stage1_cfg['BLOCK']] 270 | num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] 271 | self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) 272 | stage1_out_channel = block.expansion*num_channels 273 | 274 | self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2'] 275 | num_channels = self.stage2_cfg['NUM_CHANNELS'] 276 | block = blocks_dict[self.stage2_cfg['BLOCK']] 277 | num_channels = [ 278 | num_channels[i] * block.expansion for i in range(len(num_channels))] 279 | self.transition1 = self._make_transition_layer( 280 | [stage1_out_channel], num_channels) 281 | self.stage2, pre_stage_channels = self._make_stage( 282 | self.stage2_cfg, num_channels) 283 | 284 | self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3'] 285 | num_channels = self.stage3_cfg['NUM_CHANNELS'] 286 | block = blocks_dict[self.stage3_cfg['BLOCK']] 287 | num_channels = [ 288 | num_channels[i] * block.expansion for i in range(len(num_channels))] 289 | self.transition2 = self._make_transition_layer( 290 | pre_stage_channels, num_channels) 291 | self.stage3, pre_stage_channels = self._make_stage( 292 | self.stage3_cfg, num_channels) 293 | 294 | self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4'] 295 | num_channels = self.stage4_cfg['NUM_CHANNELS'] 296 | block = blocks_dict[self.stage4_cfg['BLOCK']] 297 | num_channels = [ 298 | num_channels[i] * block.expansion for i in range(len(num_channels))] 299 | self.transition3 = self._make_transition_layer( 300 | pre_stage_channels, num_channels) 301 | self.stage4, pre_stage_channels = self._make_stage( 302 | self.stage4_cfg, num_channels, multi_scale_output=True) 303 | 304 | # Classification Head 305 | self.incre_modules, self.downsamp_modules, \ 306 | self.final_layer = self._make_head(pre_stage_channels) 307 | 308 | self.classifier = nn.Linear(2048, 1000) 309 | 310 | def _make_head(self, pre_stage_channels): 311 | head_block = Bottleneck 312 | head_channels = [32, 64, 128, 256] 313 | 314 | # Increasing the #channels on each resolution 315 | # from C, 2C, 4C, 8C to 128, 256, 512, 1024 316 | incre_modules = [] 317 | for i, channels in enumerate(pre_stage_channels): 318 | incre_module = self._make_layer(head_block, 319 | channels, 320 | head_channels[i], 321 | 1, 322 | stride=1) 323 | incre_modules.append(incre_module) 324 | incre_modules = nn.ModuleList(incre_modules) 325 | 326 | # downsampling modules 327 | downsamp_modules = [] 328 | 
for i in range(len(pre_stage_channels)-1): 329 | in_channels = head_channels[i] * head_block.expansion 330 | out_channels = head_channels[i+1] * head_block.expansion 331 | 332 | downsamp_module = nn.Sequential( 333 | nn.Conv2d(in_channels=in_channels, 334 | out_channels=out_channels, 335 | kernel_size=3, 336 | stride=2, 337 | padding=1), 338 | nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM), 339 | nn.ReLU(inplace=True) 340 | ) 341 | 342 | downsamp_modules.append(downsamp_module) 343 | downsamp_modules = nn.ModuleList(downsamp_modules) 344 | 345 | final_layer = nn.Sequential( 346 | nn.Conv2d( 347 | in_channels=head_channels[3] * head_block.expansion, 348 | out_channels=2048, 349 | kernel_size=1, 350 | stride=1, 351 | padding=0 352 | ), 353 | nn.BatchNorm2d(2048, momentum=BN_MOMENTUM), 354 | nn.ReLU(inplace=True) 355 | ) 356 | 357 | return incre_modules, downsamp_modules, final_layer 358 | 359 | def _make_transition_layer( 360 | self, num_channels_pre_layer, num_channels_cur_layer): 361 | num_branches_cur = len(num_channels_cur_layer) 362 | num_branches_pre = len(num_channels_pre_layer) 363 | 364 | transition_layers = [] 365 | for i in range(num_branches_cur): 366 | if i < num_branches_pre: 367 | if num_channels_cur_layer[i] != num_channels_pre_layer[i]: 368 | transition_layers.append(nn.Sequential( 369 | nn.Conv2d(num_channels_pre_layer[i], 370 | num_channels_cur_layer[i], 371 | 3, 372 | 1, 373 | 1, 374 | bias=False), 375 | nn.BatchNorm2d( 376 | num_channels_cur_layer[i], momentum=BN_MOMENTUM), 377 | nn.ReLU(inplace=True))) 378 | else: 379 | transition_layers.append(None) 380 | else: 381 | conv3x3s = [] 382 | for j in range(i+1-num_branches_pre): 383 | inchannels = num_channels_pre_layer[-1] 384 | outchannels = num_channels_cur_layer[i] \ 385 | if j == i-num_branches_pre else inchannels 386 | conv3x3s.append(nn.Sequential( 387 | nn.Conv2d( 388 | inchannels, outchannels, 3, 2, 1, bias=False), 389 | nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM), 390 | nn.ReLU(inplace=True))) 391 | transition_layers.append(nn.Sequential(*conv3x3s)) 392 | 393 | return nn.ModuleList(transition_layers) 394 | 395 | def _make_layer(self, block, inplanes, planes, blocks, stride=1): 396 | downsample = None 397 | if stride != 1 or inplanes != planes * block.expansion: 398 | downsample = nn.Sequential( 399 | nn.Conv2d(inplanes, planes * block.expansion, 400 | kernel_size=1, stride=stride, bias=False), 401 | nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), 402 | ) 403 | 404 | layers = [] 405 | layers.append(block(inplanes, planes, stride, downsample)) 406 | inplanes = planes * block.expansion 407 | for i in range(1, blocks): 408 | layers.append(block(inplanes, planes)) 409 | 410 | return nn.Sequential(*layers) 411 | 412 | def _make_stage(self, layer_config, num_inchannels, 413 | multi_scale_output=True): 414 | num_modules = layer_config['NUM_MODULES'] 415 | num_branches = layer_config['NUM_BRANCHES'] 416 | num_blocks = layer_config['NUM_BLOCKS'] 417 | num_channels = layer_config['NUM_CHANNELS'] 418 | block = blocks_dict[layer_config['BLOCK']] 419 | fuse_method = layer_config['FUSE_METHOD'] 420 | 421 | modules = [] 422 | for i in range(num_modules): 423 | # multi_scale_output is only used last module 424 | if not multi_scale_output and i == num_modules - 1: 425 | reset_multi_scale_output = False 426 | else: 427 | reset_multi_scale_output = True 428 | 429 | modules.append( 430 | HighResolutionModule(num_branches, 431 | block, 432 | num_blocks, 433 | num_inchannels, 434 | num_channels, 435 | 
fuse_method, 436 | reset_multi_scale_output) 437 | ) 438 | num_inchannels = modules[-1].get_num_inchannels() 439 | 440 | return nn.Sequential(*modules), num_inchannels 441 | 442 | def forward(self, x): 443 | x = self.conv1(x) 444 | x = self.bn1(x) 445 | x = self.relu(x) 446 | x = self.conv2(x) 447 | x = self.bn2(x) 448 | x = self.relu(x) 449 | x = self.layer1(x) 450 | 451 | x_list = [] 452 | for i in range(self.stage2_cfg['NUM_BRANCHES']): 453 | if self.transition1[i] is not None: 454 | x_list.append(self.transition1[i](x)) 455 | else: 456 | x_list.append(x) 457 | y_list = self.stage2(x_list) 458 | 459 | x_list = [] 460 | for i in range(self.stage3_cfg['NUM_BRANCHES']): 461 | if self.transition2[i] is not None: 462 | x_list.append(self.transition2[i](y_list[-1])) 463 | else: 464 | x_list.append(y_list[i]) 465 | y_list = self.stage3(x_list) 466 | 467 | x_list = [] 468 | for i in range(self.stage4_cfg['NUM_BRANCHES']): 469 | if self.transition3[i] is not None: 470 | x_list.append(self.transition3[i](y_list[-1])) 471 | else: 472 | x_list.append(y_list[i]) 473 | y_list = self.stage4(x_list) 474 | 475 | # Classification Head 476 | y = self.incre_modules[0](y_list[0]) 477 | for i in range(len(self.downsamp_modules)): 478 | y = self.incre_modules[i+1](y_list[i+1]) + \ 479 | self.downsamp_modules[i](y) 480 | 481 | y = self.final_layer(y) 482 | 483 | if torch._C._get_tracing_state(): 484 | y = y.flatten(start_dim=2).mean(dim=2) 485 | else: 486 | y = F.avg_pool2d(y, kernel_size=y.size() 487 | [2:]).view(y.size(0), -1) 488 | 489 | y = self.classifier(y) 490 | 491 | return y 492 | 493 | def init_weights(self, pretrained='',): 494 | logger.info('=> init weights from normal distribution') 495 | for m in self.modules(): 496 | if isinstance(m, nn.Conv2d): 497 | nn.init.kaiming_normal_( 498 | m.weight, mode='fan_out', nonlinearity='relu') 499 | elif isinstance(m, nn.BatchNorm2d): 500 | nn.init.constant_(m.weight, 1) 501 | nn.init.constant_(m.bias, 0) 502 | if os.path.isfile(pretrained): 503 | pretrained_dict = torch.load(pretrained) 504 | logger.info('=> loading pretrained model {}'.format(pretrained)) 505 | model_dict = self.state_dict() 506 | pretrained_dict = {k: v for k, v in pretrained_dict.items() 507 | if k in model_dict.keys()} 508 | for k, _ in pretrained_dict.items(): 509 | logger.info( 510 | '=> loading {} pretrained model {}'.format(k, pretrained)) 511 | model_dict.update(pretrained_dict) 512 | self.load_state_dict(model_dict) 513 | 514 | 515 | def get_cls_net(config, **kwargs): 516 | model = HighResolutionNet(config, **kwargs) 517 | model.init_weights() 518 | return model 519 | -------------------------------------------------------------------------------- /lib/utils/modelsummary.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
--------------------------------------------------------------------------------
/lib/utils/modelsummary.py:
--------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
from collections import namedtuple

import torch
import torch.nn as nn

def get_model_summary(model, *input_tensors, item_length=26, verbose=False):
    """
    :param model: network to summarise
    :param input_tensors: example input(s) forwarded through the model
    :param item_length: column width used for the verbose per-layer table
    :param verbose: include a per-layer breakdown when True
    :return: summary string with parameter and multiply-add totals
    """

    summary = []

    ModuleDetails = namedtuple(
        "Layer", ["name", "input_size", "output_size", "num_parameters", "multiply_adds"])
    hooks = []
    layer_instances = {}

    def add_hooks(module):

        def hook(module, input, output):
            class_name = str(module.__class__.__name__)

            instance_index = 1
            if class_name not in layer_instances:
                layer_instances[class_name] = instance_index
            else:
                instance_index = layer_instances[class_name] + 1
                layer_instances[class_name] = instance_index

            layer_name = class_name + "_" + str(instance_index)

            params = 0

            if class_name.find("Conv") != -1 or class_name.find("BatchNorm") != -1 or \
               class_name.find("Linear") != -1:
                for param_ in module.parameters():
                    params += param_.view(-1).size(0)

            flops = "Not Available"
            if class_name.find("Conv") != -1 and hasattr(module, "weight"):
                flops = (
                    torch.prod(
                        torch.LongTensor(list(module.weight.data.size()))) *
                    torch.prod(
                        torch.LongTensor(list(output.size())[2:]))).item()
            elif isinstance(module, nn.Linear):
                flops = (torch.prod(torch.LongTensor(list(output.size()))) \
                         * input[0].size(1)).item()

            if isinstance(input[0], list):
                input = input[0]
            if isinstance(output, list):
                output = output[0]

            summary.append(
                ModuleDetails(
                    name=layer_name,
                    input_size=list(input[0].size()),
                    output_size=list(output.size()),
                    num_parameters=params,
                    multiply_adds=flops)
            )

        # hook every leaf module (skip pure containers and the model itself)
        if not isinstance(module, nn.ModuleList) \
           and not isinstance(module, nn.Sequential) \
           and module != model:
            hooks.append(module.register_forward_hook(hook))

    model.eval()
    model.apply(add_hooks)

    space_len = item_length

    model(*input_tensors)
    for hook in hooks:
        hook.remove()

    details = ''
    if verbose:
        details = "Model Summary" + \
            os.linesep + \
            "Name{}Input Size{}Output Size{}Parameters{}Multiply Adds (Flops){}".format(
                ' ' * (space_len - len("Name")),
                ' ' * (space_len - len("Input Size")),
                ' ' * (space_len - len("Output Size")),
                ' ' * (space_len - len("Parameters")),
                ' ' * (space_len - len("Multiply Adds (Flops)"))) \
            + os.linesep + '-' * space_len * 5 + os.linesep
    params_sum = 0
    flops_sum = 0
    for layer in summary:
        params_sum += layer.num_parameters
        if layer.multiply_adds != "Not Available":
            flops_sum += layer.multiply_adds
        if verbose:
            details += "{}{}{}{}{}{}{}{}{}{}".format(
                layer.name,
                ' ' * (space_len - len(layer.name)),
                layer.input_size,
                ' ' * (space_len - len(str(layer.input_size))),
                layer.output_size,
                ' ' * (space_len - len(str(layer.output_size))),
                layer.num_parameters,
                ' ' * (space_len - len(str(layer.num_parameters))),
                layer.multiply_adds,
                ' ' * (space_len - len(str(layer.multiply_adds)))) \
                + os.linesep + '-' * space_len * 5 + os.linesep

    details += os.linesep \
        + "Total Parameters: {:,}".format(params_sum) \
        + os.linesep + '-' * space_len * 5 + os.linesep
    # note: divides by 1024**3 rather than 1e9, so these "GFLOPs" are binary giga
    details += "Total Multiply Adds (For Convolution and Linear Layers only): {:,} GFLOPs".format(flops_sum/(1024**3)) \
        + os.linesep + '-' * space_len * 5 + os.linesep
    details += "Number of Layers" + os.linesep
    for layer in layer_instances:
        details += "{} : {} layers   ".format(layer, layer_instances[layer])

    return details
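A quick usage sketch for the helper above. Any `nn.Module` works; torchvision's `resnet18` is used here purely as a stand-in, and the `sys.path` line assumes the repository root as working directory:

import sys
sys.path.insert(0, 'lib')

import torch
import torchvision.models as tvm
from utils.modelsummary import get_model_summary

model = tvm.resnet18()
dump_input = torch.rand(1, 3, 224, 224)

# totals only; pass verbose=True for the per-layer table as well
print(get_model_summary(model, dump_input))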
--------------------------------------------------------------------------------
/lib/utils/utils.py:
--------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import time
from pathlib import Path

import torch
import torch.optim as optim

def create_logger(cfg, cfg_name, phase='train'):
    root_output_dir = Path(cfg.OUTPUT_DIR)
    # set up logger
    if not root_output_dir.exists():
        print('=> creating {}'.format(root_output_dir))
        root_output_dir.mkdir()

    dataset = cfg.DATASET.DATASET
    model = cfg.MODEL.NAME
    cfg_name = os.path.basename(cfg_name).split('.')[0]

    final_output_dir = root_output_dir / dataset / cfg_name

    print('=> creating {}'.format(final_output_dir))
    final_output_dir.mkdir(parents=True, exist_ok=True)

    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase)
    final_log_file = final_output_dir / log_file
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=str(final_log_file),
                        format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    console = logging.StreamHandler()
    logging.getLogger('').addHandler(console)

    tensorboard_log_dir = Path(cfg.LOG_DIR) / dataset / model / \
        (cfg_name + '_' + time_str)
    print('=> creating {}'.format(tensorboard_log_dir))
    tensorboard_log_dir.mkdir(parents=True, exist_ok=True)

    return logger, str(final_output_dir), str(tensorboard_log_dir)


def get_optimizer(cfg, model):
    # only parameters with requires_grad=True are optimised
    optimizer = None
    if cfg.TRAIN.OPTIMIZER == 'sgd':
        optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=cfg.TRAIN.LR,
            momentum=cfg.TRAIN.MOMENTUM,
            weight_decay=cfg.TRAIN.WD,
            nesterov=cfg.TRAIN.NESTEROV
        )
    elif cfg.TRAIN.OPTIMIZER == 'adam':
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=cfg.TRAIN.LR
        )
    elif cfg.TRAIN.OPTIMIZER == 'rmsprop':
        optimizer = optim.RMSprop(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=cfg.TRAIN.LR,
            momentum=cfg.TRAIN.MOMENTUM,
            weight_decay=cfg.TRAIN.WD,
            alpha=cfg.TRAIN.RMSPROP_ALPHA,
            centered=cfg.TRAIN.RMSPROP_CENTERED
        )

    return optimizer


def save_checkpoint(states, is_best, output_dir,
                    filename='checkpoint.pth.tar'):
    torch.save(states, os.path.join(output_dir, filename))
    if is_best and 'state_dict' in states:
        torch.save(states['state_dict'],
                   os.path.join(output_dir, 'model_best.pth.tar'))
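A minimal sketch of `get_optimizer` and `save_checkpoint` driven by a hand-built yacs node. The LR and weight decay mirror the `lr5e-2_wd1e-4` experiment file names; the momentum and nesterov values are illustrative assumptions, and the tiny `nn.Linear` stands in for the real network:

import sys
sys.path.insert(0, 'lib')

import torch.nn as nn
from yacs.config import CfgNode as CN
from utils.utils import get_optimizer, save_checkpoint

cfg = CN()
cfg.TRAIN = CN()
cfg.TRAIN.OPTIMIZER = 'sgd'
cfg.TRAIN.LR = 0.05          # lr5e-2 in the experiment file names
cfg.TRAIN.MOMENTUM = 0.9     # illustrative
cfg.TRAIN.WD = 1e-4          # wd1e-4 in the experiment file names
cfg.TRAIN.NESTEROV = True    # illustrative

model = nn.Linear(10, 2)
optimizer = get_optimizer(cfg, model)   # SGD over trainable params only

save_checkpoint({'state_dict': model.state_dict(), 'epoch': 1, 'perf': 0.0},
                is_best=True, output_dir='.')
# writes ./checkpoint.pth.tar and, because is_best=True, ./model_best.pth.tar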
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
EasyDict==1.7
opencv-python==3.4.1.15
shapely==1.6.4
Cython
scipy
pandas
pyyaml
json_tricks
scikit-image
yacs>=0.1.5
tensorboardX>=1.6
--------------------------------------------------------------------------------
/tools/_init_paths.py:
--------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path as osp
import sys


def add_path(path):
    if path not in sys.path:
        sys.path.insert(0, path)


# make lib/ importable from the tools/ scripts
this_dir = osp.dirname(__file__)

lib_path = osp.join(this_dir, '..', 'lib')
add_path(lib_path)
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import pprint
import shutil

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter

import _init_paths
import models
from config import config
from config import update_config
from core.function import train
from core.function import validate
from utils.modelsummary import get_model_summary
from utils.utils import get_optimizer
from utils.utils import save_checkpoint
from utils.utils import create_logger


def parse_args():
    parser = argparse.ArgumentParser(description='Train classification network')

    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)

    parser.add_argument('--modelDir',
                        help='model directory',
                        type=str,
                        default='')
    parser.add_argument('--logDir',
                        help='log directory',
                        type=str,
                        default='')
    parser.add_argument('--dataDir',
                        help='data directory',
                        type=str,
                        default='')
    parser.add_argument('--testModel',
                        help='testModel',
                        type=str,
                        default='')

    args = parser.parse_args()
    update_config(config, args)

    return args

def main():
    args = parse_args()

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    # look up the model constructor named in the config
    model = eval('models.'+config.MODEL.NAME+'.get_cls_net')(
        config)

    dump_input = torch.rand(
        (1, 3, config.MODEL.IMAGE_SIZE[1], config.MODEL.IMAGE_SIZE[0])
    )
    logger.info(get_model_summary(model, dump_input))

    # copy model file
    this_dir = os.path.dirname(__file__)
    models_dst_dir = os.path.join(final_output_dir, 'models')
    if os.path.exists(models_dst_dir):
        shutil.rmtree(models_dst_dir)
    shutil.copytree(os.path.join(this_dir, '../lib/models'), models_dst_dir)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    gpus = list(config.GPUS)
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()

    optimizer = get_optimizer(config, model)

    best_perf = 0.0
    best_model = False
    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir,
                                        'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            checkpoint = torch.load(model_state_file)
            last_epoch = checkpoint['epoch']
            best_perf = checkpoint['perf']
            model.module.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})"
                        .format(checkpoint['epoch']))
            best_model = True

    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch-1
        )
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch-1
        )

    # Data loading code
    traindir = os.path.join(config.DATASET.ROOT, config.DATASET.TRAIN_SET)
    valdir = os.path.join(config.DATASET.ROOT, config.DATASET.TEST_SET)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU*len(gpus),
        shuffle=True,
        num_workers=config.WORKERS,
        pin_memory=True
    )

    valid_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(int(config.MODEL.IMAGE_SIZE[0] / 0.875)),
            transforms.CenterCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=config.TEST.BATCH_SIZE_PER_GPU*len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True
    )

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        # advance the LR schedule for this epoch (the scheduler was constructed
        # with last_epoch-1, so stepping at the start sets the intended LR)
        lr_scheduler.step()
        # train for one epoch
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)
        # evaluate on validation set
        perf_indicator = validate(config, valid_loader, model, criterion,
                                  final_output_dir, tb_log_dir, writer_dict)

        if perf_indicator > best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint({
            'epoch': epoch + 1,
            'model': config.MODEL.NAME,
            'state_dict': model.module.state_dict(),
            'perf': perf_indicator,
            'optimizer': optimizer.state_dict(),
        }, best_model, final_output_dir, filename='checkpoint.pth.tar')

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth.tar')
    logger.info('saving final model state to {}'.format(
        final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()


if __name__ == '__main__':
    main()
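For reference, the scheduler wiring in the loop above boils down to the following standalone sketch. The milestones and decay factor are illustrative assumptions (the real values come from TRAIN.LR_STEP and TRAIN.LR_FACTOR in the experiment config), and a bare `optimizer.step()` stands in for a full training epoch; train.py steps the scheduler at the start of each epoch instead, matching its `last_epoch-1` construction:

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
# hypothetical 100-epoch run with decays at epochs 30/60/90
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[30, 60, 90], gamma=0.1, last_epoch=-1)

for epoch in range(100):
    optimizer.step()        # stands in for a full training epoch
    lr_scheduler.step()

print(optimizer.param_groups[0]['lr'])   # ~5e-05 after the three decays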
--------------------------------------------------------------------------------
/tools/valid.py:
--------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao (Bin.Xiao@microsoft.com)
# Modified by Ke Sun (sunk@mail.ustc.edu.cn)
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import pprint

import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import _init_paths
import models
from config import config
from config import update_config
from core.function import validate
from utils.modelsummary import get_model_summary
from utils.utils import create_logger


def parse_args():
    parser = argparse.ArgumentParser(description='Validate classification network')

    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)

    parser.add_argument('--modelDir',
                        help='model directory',
                        type=str,
                        default='')
    parser.add_argument('--logDir',
                        help='log directory',
                        type=str,
                        default='')
    parser.add_argument('--dataDir',
                        help='data directory',
                        type=str,
                        default='')
    parser.add_argument('--testModel',
                        help='testModel',
                        type=str,
                        default='')

    args = parser.parse_args()
    update_config(config, args)

    return args

def main():
    args = parse_args()

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    model = eval('models.'+config.MODEL.NAME+'.get_cls_net')(
        config)

    dump_input = torch.rand(
        (1, 3, config.MODEL.IMAGE_SIZE[1], config.MODEL.IMAGE_SIZE[0])
    )
    logger.info(get_model_summary(model, dump_input))

    # prefer an explicit TEST.MODEL_FILE, else fall back to the training output
    if config.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(config.TEST.MODEL_FILE))
        model.load_state_dict(torch.load(config.TEST.MODEL_FILE))
    else:
        model_state_file = os.path.join(final_output_dir,
                                        'final_state.pth.tar')
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))

    gpus = list(config.GPUS)
    model = torch.nn.DataParallel(model, device_ids=gpus).cuda()

    # define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss().cuda()

    # Data loading code
    valdir = os.path.join(config.DATASET.ROOT,
                          config.DATASET.TEST_SET)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    valid_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(int(config.MODEL.IMAGE_SIZE[0] / 0.875)),
            transforms.CenterCrop(config.MODEL.IMAGE_SIZE[0]),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=config.TEST.BATCH_SIZE_PER_GPU*len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True
    )

    # evaluate on validation set
    validate(config, valid_loader, model, criterion, final_output_dir,
             tb_log_dir, None)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
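A closing note on the evaluation preprocessing shared by train.py and valid.py: the short side is resized to IMAGE_SIZE/0.875 before the center crop, which for the default 224x224 input reproduces the standard 256 to 224 ImageNet protocol. A standalone sketch (the dummy image size is illustrative):

import torchvision.transforms as transforms
from PIL import Image

image_size = 224   # config.MODEL.IMAGE_SIZE[0] by default
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
eval_tf = transforms.Compose([
    transforms.Resize(int(image_size / 0.875)),   # short side -> 256
    transforms.CenterCrop(image_size),            # 224x224 crop
    transforms.ToTensor(),
    normalize,
])

img = Image.new('RGB', (500, 375))                # dummy image
print(eval_tf(img).shape)                         # torch.Size([3, 224, 224])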