├── .gitignore
├── LICENSE
├── README.md
├── bake.png
├── imagenet
│   ├── LICENSE
│   ├── README.md
│   ├── configs
│   │   ├── effnet
│   │   │   ├── BAKE-EN-B0_dds_8gpu.yaml
│   │   │   └── EN-B0_dds_8gpu.yaml
│   │   ├── mobilenet
│   │   │   ├── BAKE-M-V2-W1_dds_4gpu.yaml
│   │   │   └── M-V2-W1_dds_4gpu.yaml
│   │   ├── resnest
│   │   │   ├── BAKE-S-101_dds_8gpu.yaml
│   │   │   ├── BAKE-S-50_dds_8gpu.yaml
│   │   │   ├── S-101_dds_8gpu.yaml
│   │   │   └── S-50_dds_8gpu.yaml
│   │   ├── resnet
│   │   │   ├── BAKE-R-101-1x64d_dds_8gpu.yaml
│   │   │   ├── BAKE-R-152-1x64d_dds_8gpu.yaml
│   │   │   ├── BAKE-R-50-1x64d_dds_8gpu.yaml
│   │   │   ├── R-101-1x64d_dds_8gpu.yaml
│   │   │   ├── R-152-1x64d_dds_8gpu.yaml
│   │   │   └── R-50-1x64d_dds_8gpu.yaml
│   │   └── resnext
│   │       ├── BAKE-X-101-32x4d_dds_8gpu.yaml
│   │       ├── BAKE-X-152-32x4d_dds_8gpu.yaml
│   │       ├── X-101-32x4d_dds_8gpu.yaml
│   │       └── X-152-32x4d_dds_8gpu.yaml
│   ├── pycls
│   │   ├── __init__.py
│   │   ├── core
│   │   │   ├── __init__.py
│   │   │   ├── benchmark.py
│   │   │   ├── builders.py
│   │   │   ├── checkpoint.py
│   │   │   ├── config.py
│   │   │   ├── distributed.py
│   │   │   ├── io.py
│   │   │   ├── logging.py
│   │   │   ├── meters.py
│   │   │   ├── net.py
│   │   │   ├── optimizer.py
│   │   │   ├── plotting.py
│   │   │   ├── timer.py
│   │   │   └── trainer.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── imagenet.py
│   │   │   ├── loader.py
│   │   │   ├── sampler.py
│   │   │   └── transforms.py
│   │   └── models
│   │       ├── __init__.py
│   │       ├── anynet.py
│   │       ├── effnet.py
│   │       ├── mobilenetv2.py
│   │       ├── resnest.py
│   │       └── resnet.py
│   ├── requirements.txt
│   ├── setup.py
│   └── tools
│       ├── dist_test.sh
│       ├── dist_train.sh
│       ├── slurm_test.sh
│       ├── slurm_train.sh
│       ├── test_net.py
│       └── train_net.py
└── small_scale
    ├── README.md
    ├── data
    │   ├── README.md
    │   └── tools
    │       ├── cub.py
    │       ├── mit67.py
    │       ├── stanford_dogs.py
    │       └── tinyimagenet.py
    ├── datasets.py
    ├── models
    │   ├── __init__.py
    │   ├── densenet.py
    │   ├── densenet3.py
    │   └── resnet.py
    ├── scripts
    │   ├── train_bake.sh
    │   ├── train_baseline.sh
    │   └── val.sh
    ├── train.py
    └── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Shared objects 7 | *.so 8 | 9 | # Distribution / packaging 10 | build/ 11 | *.egg-info/ 12 | *.egg 13 | 14 | # Temporary files 15 | *.swn 16 | *.swo 17 | *.swp 18 | 19 | # PyCharm 20 | .idea/ 21 | 22 | # Mac 23 | .DS_Store 24 | 25 | # Data symlinks 26 | pycls/datasets/data/ 27 | 28 | # Other 29 | logs/ 30 | scratch* 31 | 32 | *.un~ 33 | 34 | arun_log/ 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Yixiao Ge 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-distillation with Batch Knowledge Ensembling Improves ImageNet Classification 2 | 3 | 4 |
5 | 6 |
7 | 8 | 9 | ## Updates 10 | 11 | [2021-06-08] The implementation of BAKE on small-scale datasets has been added; please refer to [small_scale](small_scale/). 12 | [2021-06-09] The implementation of BAKE on ImageNet has been added; please refer to [imagenet](imagenet/). 13 | 14 | 15 | ## Citation 16 | 17 | If you find **BAKE** helpful in your research, please consider citing: 18 | 19 | ``` 20 | @misc{ge2020bake, 21 | title={Self-distillation with Batch Knowledge Ensembling Improves ImageNet Classification}, 22 | author={Yixiao Ge and Ching Lam Choi and Xiao Zhang and Peipei Zhao and Feng Zhu and Rui Zhao and Hongsheng Li}, 23 | year={2021}, 24 | eprint={2104.13298}, 25 | archivePrefix={arXiv}, 26 | primaryClass={cs.CV} 27 | } 28 | ``` 29 | -------------------------------------------------------------------------------- /bake.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxgeee/BAKE/07c4f668ea19311d5b50121026e73d2f035d5765/bake.png -------------------------------------------------------------------------------- /imagenet/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. Modified by Yixiao Ge. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /imagenet/README.md: -------------------------------------------------------------------------------- 1 | # BAKE on ImageNet 2 | 3 | PyTorch implementation of [Self-distillation with Batch Knowledge Ensembling Improves ImageNet Classification](https://arxiv.org/abs/2104.13298) on ImageNet. 4 | 5 | 6 | ## Installation 7 | 8 | - Install Python dependencies: 9 | 10 | ``` 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | - Set up Python modules: 15 | 16 | ``` 17 | python setup.py develop --user 18 | ``` 19 | 20 | ## Dataset 21 | 22 | - Download ImageNet with the following expected structure: 23 | 24 | ``` 25 | imagenet 26 | |_ train 27 | | |_ n01440764 28 | | |_ ... 29 | | |_ n15075141 30 | |_ val 31 | | |_ n01440764 32 | | |_ ... 33 | | |_ n15075141 34 | |_ ...
35 | ``` 36 | 37 | - Create a directory containing symlinks: 38 | 39 | ``` 40 | mkdir -p pycls/datasets/data 41 | ``` 42 | 43 | - Symlink ImageNet: 44 | 45 | ``` 46 | ln -s /path/imagenet pycls/datasets/data/imagenet 47 | ``` 48 | 49 | 50 | ## BAKE Training 51 | 52 | ``` 53 | sh tools/dist_train.sh [work_dir] [config_file] 54 | ``` 55 | 56 | For example, 57 | ``` 58 | sh tools/dist_train.sh logs/resnet50_bake configs/resnet/BAKE-R-50-1x64d_dds_8gpu.yaml 59 | ``` 60 | 61 | **Note**: you can use `tools/slurm_train.sh` for distributed training across multiple machines. 62 | 63 | 64 | ## Validation 65 | 66 | ``` 67 | sh tools/dist_test.sh [config_file] [checkpoint_file] 68 | ``` 69 | 70 | For example, 71 | ``` 72 | sh tools/dist_test.sh configs/resnet/BAKE-R-50-1x64d_dds_8gpu.yaml logs/resnet50_bake/checkpoints/model_epoch_0100.pyth 73 | ``` 74 | 75 | ## Results 76 | 77 | |architecture|ImageNet top-1 acc.|config|download| 78 | |---|:--:|:--:|:--:| 79 | |ResNet-50|78.0|[config](configs/resnet/BAKE-R-50-1x64d_dds_8gpu.yaml)|[model](https://drive.google.com/file/d/1RJeUxXLHnSc6m3iiaEslVN2hx3Vsbldc/view?usp=sharing)| 80 | |ResNet-101|79.3|[config](configs/resnet/BAKE-R-101-1x64d_dds_8gpu.yaml)|| 81 | |ResNet-152|79.6|[config](configs/resnet/BAKE-R-152-1x64d_dds_8gpu.yaml)|| 82 | |ResNeSt-50|79.4|[config](configs/resnest/BAKE-S-50_dds_8gpu.yaml)|| 83 | |ResNeSt-101|80.4|[config](configs/resnest/BAKE-S-101_dds_8gpu.yaml)|| 84 | |ResNeXt-101(32x4d)|79.3|[config](configs/resnext/BAKE-X-101-32x4d_dds_8gpu.yaml)|| 85 | |ResNeXt-152(32x4d)|79.7|[config](configs/resnext/BAKE-X-152-32x4d_dds_8gpu.yaml)|| 86 | |MobileNet-V2|72.0|[config](configs/mobilenet/BAKE-M-V2-W1_dds_4gpu.yaml)|| 87 | |EfficientNet-B0|76.2|[config](configs/effnet/BAKE-EN-B0_dds_8gpu.yaml)|| 88 | 89 | ## Thanks 90 | The code is modified from [pycls](https://github.com/facebookresearch/pycls).
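For intuition about the BAKE options used throughout the configs below (`TRAIN.USE_BAKE`, `OMEGA`, `TEMP`, `LAMBDA`, `GLOBAL_KE`), here is a minimal sketch of the batch knowledge ensembling objective, loosely following the closed-form knowledge propagation described in the paper. It is illustrative only — not the repository's actual trainer code — and every name in it is hypothetical:

```
import torch
import torch.nn.functional as F

def bake_distillation_loss(feats, logits, omega=0.5, temp=4.0):
    """Sketch: build soft targets for each sample by ensembling the
    temperature-softened predictions of similar samples in the batch."""
    # Pairwise affinities from L2-normalized features; self-affinity removed
    feats = F.normalize(feats, dim=1)
    affinity = feats @ feats.t()
    affinity.fill_diagonal_(float("-inf"))
    affinity = F.softmax(affinity, dim=1)  # each row sums to 1
    # Temperature-softened predictions of the whole batch (TRAIN.TEMP)
    probs = F.softmax(logits / temp, dim=1)
    # Closed-form propagation: Q = (1 - omega) * (I - omega * A)^(-1) * A * P,
    # where omega (TRAIN.OMEGA) balances a sample's own prediction against
    # its neighbors' ensembled knowledge
    eye = torch.eye(probs.size(0), device=probs.device)
    soft_targets = (1 - omega) * torch.linalg.solve(eye - omega * affinity, affinity @ probs)
    soft_targets = soft_targets.detach()  # teacher signal: no gradient
    # KD-style KL term scaled by temp^2; the total objective would then be
    # cross_entropy + LAMBDA * this term
    log_student = F.log_softmax(logits / temp, dim=1)
    return F.kl_div(log_student, soft_targets, reduction="batchmean") * temp ** 2
```

`GLOBAL_KE: True` in the configs suggests the ensembling batch is gathered across all GPUs rather than built per device, which is also why the BAKE configs enable `SYNC_BN`.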
91 | -------------------------------------------------------------------------------- /imagenet/configs/effnet/BAKE-EN-B0_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: effnet 3 | NUM_CLASSES: 1000 4 | EN: 5 | STEM_W: 32 6 | STRIDES: [1, 2, 2, 2, 1, 2, 1] 7 | DEPTHS: [1, 2, 2, 3, 3, 4, 1] 8 | WIDTHS: [16, 24, 40, 80, 112, 192, 320] 9 | EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6] 10 | KERNELS: [3, 3, 5, 3, 5, 5, 3] 11 | HEAD_W: 1280 12 | OPTIM: 13 | LR_POLICY: cos 14 | BASE_LR: 0.8 15 | MAX_EPOCH: 100 16 | MOMENTUM: 0.9 17 | WEIGHT_DECAY: 1e-5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 512 22 | INTRA_IMGS: 1 23 | SYNC_BN: True 24 | USE_BAKE: True 25 | OMEGA: 0.5 26 | TEMP: 4.0 27 | LAMBDA: 1.0 28 | GLOBAL_KE: True 29 | TEST: 30 | DATASET: imagenet 31 | IM_SIZE: 256 32 | BATCH_SIZE: 400 33 | -------------------------------------------------------------------------------- /imagenet/configs/effnet/EN-B0_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: effnet 3 | NUM_CLASSES: 1000 4 | EN: 5 | STEM_W: 32 6 | STRIDES: [1, 2, 2, 2, 1, 2, 1] 7 | DEPTHS: [1, 2, 2, 3, 3, 4, 1] 8 | WIDTHS: [16, 24, 40, 80, 112, 192, 320] 9 | EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6] 10 | KERNELS: [3, 3, 5, 3, 5, 5, 3] 11 | HEAD_W: 1280 12 | OPTIM: 13 | LR_POLICY: cos 14 | BASE_LR: 0.4 15 | MAX_EPOCH: 100 16 | MOMENTUM: 0.9 17 | WEIGHT_DECAY: 1e-5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 256 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 200 26 | -------------------------------------------------------------------------------- /imagenet/configs/mobilenet/BAKE-M-V2-W1_dds_4gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: mobilenetv2 3 | NUM_CLASSES: 1000 4 | OPTIM: 5 | LR_POLICY: cos 6 | BASE_LR: 0.1 7 | MAX_EPOCH: 100 8 | MOMENTUM: 0.9 9 | WEIGHT_DECAY: 4e-5 10 | TRAIN: 11 | DATASET: imagenet 12 | IM_SIZE: 224 13 | BATCH_SIZE: 512 14 | INTRA_IMGS: 1 15 | SYNC_BN: True 16 | USE_BAKE: True 17 | OMEGA: 0.5 18 | TEMP: 4.0 19 | LAMBDA: 1.0 20 | GLOBAL_KE: True 21 | TEST: 22 | DATASET: imagenet 23 | IM_SIZE: 256 24 | BATCH_SIZE: 400 25 | -------------------------------------------------------------------------------- /imagenet/configs/mobilenet/M-V2-W1_dds_4gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: mobilenetv2 3 | NUM_CLASSES: 1000 4 | OPTIM: 5 | LR_POLICY: cos 6 | BASE_LR: 0.05 7 | MAX_EPOCH: 100 8 | MOMENTUM: 0.9 9 | WEIGHT_DECAY: 4e-5 10 | TRAIN: 11 | DATASET: imagenet 12 | IM_SIZE: 224 13 | BATCH_SIZE: 256 14 | TEST: 15 | DATASET: imagenet 16 | IM_SIZE: 256 17 | BATCH_SIZE: 200 18 | -------------------------------------------------------------------------------- /imagenet/configs/resnest/BAKE-S-101_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: resnest 3 | NUM_CLASSES: 1000 4 | RESNEST: 5 | STEM_W: 64 6 | DEPTHS: [3, 4, 23, 3] 7 | OPTIM: 8 | LR_POLICY: cos 9 | BASE_LR: 0.4 10 | MAX_EPOCH: 100 11 | MOMENTUM: 0.9 12 | WEIGHT_DECAY: 5e-5 13 | TRAIN: 14 | DATASET: imagenet 15 | IM_SIZE: 224 16 | BATCH_SIZE: 512 17 | INTRA_IMGS: 1 18 | SYNC_BN: True 19 | USE_BAKE: True 20 | OMEGA: 0.5 21 | TEMP: 4.0 22 | LAMBDA: 1.0 23 | GLOBAL_KE: True 24 | TEST: 25 | DATASET: imagenet 26 | IM_SIZE: 256 27 | BATCH_SIZE: 400 28 | 
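Note the pattern in each `BAKE-*` config relative to its baseline (e.g. `S-101_dds_8gpu.yaml` further below): `TRAIN.BATCH_SIZE` and `OPTIM.BASE_LR` are both doubled (linear LR scaling), `SYNC_BN` is switched on, and the BAKE-specific keys (`USE_BAKE`, `OMEGA`, `TEMP`, `LAMBDA`, `GLOBAL_KE`) are added. These YAML files are consumed by merging them over the defaults in `pycls/core/config.py`; a minimal sketch, assuming it is run from the `imagenet/` directory after installation:

```
from pycls.core.config import cfg

# Each YAML overrides only a subset of keys; everything else keeps the
# defaults defined in pycls/core/config.py.
cfg.merge_from_file("configs/resnest/BAKE-S-101_dds_8gpu.yaml")
assert cfg.TRAIN.USE_BAKE and cfg.TRAIN.SYNC_BN
assert cfg.TRAIN.BATCH_SIZE == 512  # 2x the 256 used by the S-101 baseline
```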
-------------------------------------------------------------------------------- /imagenet/configs/resnest/BAKE-S-50_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: resnest 3 | NUM_CLASSES: 1000 4 | RESNEST: 5 | STEM_W: 32 6 | DEPTHS: [3, 4, 6, 3] 7 | OPTIM: 8 | LR_POLICY: cos 9 | BASE_LR: 0.4 10 | MAX_EPOCH: 100 11 | MOMENTUM: 0.9 12 | WEIGHT_DECAY: 5e-5 13 | TRAIN: 14 | DATASET: imagenet 15 | IM_SIZE: 224 16 | BATCH_SIZE: 512 17 | INTRA_IMGS: 1 18 | SYNC_BN: True 19 | USE_BAKE: True 20 | OMEGA: 0.5 21 | TEMP: 4.0 22 | LAMBDA: 1.0 23 | GLOBAL_KE: True 24 | TEST: 25 | DATASET: imagenet 26 | IM_SIZE: 256 27 | BATCH_SIZE: 400 28 | -------------------------------------------------------------------------------- /imagenet/configs/resnest/S-101_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: resnest 3 | NUM_CLASSES: 1000 4 | RESNEST: 5 | STEM_W: 64 6 | DEPTHS: [3, 4, 23, 3] 7 | OPTIM: 8 | LR_POLICY: cos 9 | BASE_LR: 0.2 10 | MAX_EPOCH: 100 11 | MOMENTUM: 0.9 12 | WEIGHT_DECAY: 5e-5 13 | TRAIN: 14 | DATASET: imagenet 15 | IM_SIZE: 224 16 | BATCH_SIZE: 256 17 | TEST: 18 | DATASET: imagenet 19 | IM_SIZE: 256 20 | BATCH_SIZE: 200 21 | -------------------------------------------------------------------------------- /imagenet/configs/resnest/S-50_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: resnest 3 | NUM_CLASSES: 1000 4 | RESNEST: 5 | STEM_W: 32 6 | DEPTHS: [3, 4, 6, 3] 7 | OPTIM: 8 | LR_POLICY: cos 9 | BASE_LR: 0.2 10 | MAX_EPOCH: 100 11 | MOMENTUM: 0.9 12 | WEIGHT_DECAY: 5e-5 13 | TRAIN: 14 | DATASET: imagenet 15 | IM_SIZE: 224 16 | BATCH_SIZE: 256 17 | TEST: 18 | DATASET: imagenet 19 | IM_SIZE: 256 20 | BATCH_SIZE: 200 21 | -------------------------------------------------------------------------------- /imagenet/configs/resnet/BAKE-R-101-1x64d_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: anynet 3 | NUM_CLASSES: 1000 4 | ANYNET: 5 | STEM_TYPE: res_stem_in 6 | STEM_W: 64 7 | BLOCK_TYPE: res_bottleneck_block 8 | STRIDES: [1, 2, 2, 2] 9 | DEPTHS: [3, 4, 23, 3] 10 | WIDTHS: [256, 512, 1024, 2048] 11 | BOT_MULS: [0.25, 0.25, 0.25, 0.25] 12 | GROUP_WS: [64, 128, 256, 512] 13 | OPTIM: 14 | LR_POLICY: cos 15 | BASE_LR: 0.4 16 | MAX_EPOCH: 100 17 | MOMENTUM: 0.9 18 | WEIGHT_DECAY: 5e-5 19 | TRAIN: 20 | DATASET: imagenet 21 | IM_SIZE: 224 22 | BATCH_SIZE: 512 23 | INTRA_IMGS: 1 24 | SYNC_BN: True 25 | USE_BAKE: True 26 | OMEGA: 0.5 27 | TEMP: 4.0 28 | LAMBDA: 1.0 29 | GLOBAL_KE: True 30 | TEST: 31 | DATASET: imagenet 32 | IM_SIZE: 256 33 | BATCH_SIZE: 400 34 | -------------------------------------------------------------------------------- /imagenet/configs/resnet/BAKE-R-152-1x64d_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: anynet 3 | NUM_CLASSES: 1000 4 | ANYNET: 5 | STEM_TYPE: res_stem_in 6 | STEM_W: 64 7 | BLOCK_TYPE: res_bottleneck_block 8 | STRIDES: [1, 2, 2, 2] 9 | DEPTHS: [3, 8, 36, 3] 10 | WIDTHS: [256, 512, 1024, 2048] 11 | BOT_MULS: [0.25, 0.25, 0.25, 0.25] 12 | GROUP_WS: [64, 128, 256, 512] 13 | OPTIM: 14 | LR_POLICY: cos 15 | BASE_LR: 0.4 16 | MAX_EPOCH: 100 17 | MOMENTUM: 0.9 18 | WEIGHT_DECAY: 5e-5 19 | TRAIN: 20 | DATASET: imagenet 21 | IM_SIZE: 224 22 | BATCH_SIZE: 512 23 | INTRA_IMGS: 1 24 | SYNC_BN: True 25 | USE_BAKE: True 26 | OMEGA: 0.5 
27 | TEMP: 4.0 28 | LAMBDA: 1.0 29 | GLOBAL_KE: True 30 | TEST: 31 | DATASET: imagenet 32 | IM_SIZE: 256 33 | BATCH_SIZE: 400 34 | -------------------------------------------------------------------------------- /imagenet/configs/resnet/BAKE-R-50-1x64d_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: anynet 3 | NUM_CLASSES: 1000 4 | ANYNET: 5 | STEM_TYPE: res_stem_in 6 | STEM_W: 64 7 | BLOCK_TYPE: res_bottleneck_block 8 | STRIDES: [1, 2, 2, 2] 9 | DEPTHS: [3, 4, 6, 3] 10 | WIDTHS: [256, 512, 1024, 2048] 11 | BOT_MULS: [0.25, 0.25, 0.25, 0.25] 12 | GROUP_WS: [64, 128, 256, 512] 13 | OPTIM: 14 | LR_POLICY: cos 15 | BASE_LR: 0.4 16 | MAX_EPOCH: 100 17 | MOMENTUM: 0.9 18 | WEIGHT_DECAY: 5e-5 19 | TRAIN: 20 | DATASET: imagenet 21 | IM_SIZE: 224 22 | BATCH_SIZE: 512 23 | INTRA_IMGS: 1 24 | SYNC_BN: True 25 | USE_BAKE: True 26 | OMEGA: 0.5 27 | TEMP: 4.0 28 | LAMBDA: 1.0 29 | GLOBAL_KE: True 30 | TEST: 31 | DATASET: imagenet 32 | IM_SIZE: 256 33 | BATCH_SIZE: 400 34 | -------------------------------------------------------------------------------- /imagenet/configs/resnet/R-101-1x64d_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: anynet 3 | NUM_CLASSES: 1000 4 | ANYNET: 5 | STEM_TYPE: res_stem_in 6 | STEM_W: 64 7 | BLOCK_TYPE: res_bottleneck_block 8 | STRIDES: [1, 2, 2, 2] 9 | DEPTHS: [3, 4, 23, 3] 10 | WIDTHS: [256, 512, 1024, 2048] 11 | BOT_MULS: [0.25, 0.25, 0.25, 0.25] 12 | GROUP_WS: [64, 128, 256, 512] 13 | OPTIM: 14 | LR_POLICY: cos 15 | BASE_LR: 0.2 16 | MAX_EPOCH: 100 17 | MOMENTUM: 0.9 18 | WEIGHT_DECAY: 5e-5 19 | TRAIN: 20 | DATASET: imagenet 21 | IM_SIZE: 224 22 | BATCH_SIZE: 256 23 | TEST: 24 | DATASET: imagenet 25 | IM_SIZE: 256 26 | BATCH_SIZE: 200 27 | -------------------------------------------------------------------------------- /imagenet/configs/resnet/R-152-1x64d_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: anynet 3 | NUM_CLASSES: 1000 4 | ANYNET: 5 | STEM_TYPE: res_stem_in 6 | STEM_W: 64 7 | BLOCK_TYPE: res_bottleneck_block 8 | STRIDES: [1, 2, 2, 2] 9 | DEPTHS: [3, 8, 36, 3] 10 | WIDTHS: [256, 512, 1024, 2048] 11 | BOT_MULS: [0.25, 0.25, 0.25, 0.25] 12 | GROUP_WS: [64, 128, 256, 512] 13 | OPTIM: 14 | LR_POLICY: cos 15 | BASE_LR: 0.2 16 | MAX_EPOCH: 100 17 | MOMENTUM: 0.9 18 | WEIGHT_DECAY: 5e-5 19 | TRAIN: 20 | DATASET: imagenet 21 | IM_SIZE: 224 22 | BATCH_SIZE: 256 23 | TEST: 24 | DATASET: imagenet 25 | IM_SIZE: 256 26 | BATCH_SIZE: 200 27 | -------------------------------------------------------------------------------- /imagenet/configs/resnet/R-50-1x64d_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: anynet 3 | NUM_CLASSES: 1000 4 | ANYNET: 5 | STEM_TYPE: res_stem_in 6 | STEM_W: 64 7 | BLOCK_TYPE: res_bottleneck_block 8 | STRIDES: [1, 2, 2, 2] 9 | DEPTHS: [3, 4, 6, 3] 10 | WIDTHS: [256, 512, 1024, 2048] 11 | BOT_MULS: [0.25, 0.25, 0.25, 0.25] 12 | GROUP_WS: [64, 128, 256, 512] 13 | OPTIM: 14 | LR_POLICY: cos 15 | BASE_LR: 0.2 16 | MAX_EPOCH: 100 17 | MOMENTUM: 0.9 18 | WEIGHT_DECAY: 5e-5 19 | TRAIN: 20 | DATASET: imagenet 21 | IM_SIZE: 224 22 | BATCH_SIZE: 256 23 | TEST: 24 | DATASET: imagenet 25 | IM_SIZE: 256 26 | BATCH_SIZE: 200 27 | -------------------------------------------------------------------------------- /imagenet/configs/resnext/BAKE-X-101-32x4d_dds_8gpu.yaml: 
-------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: anynet 3 | NUM_CLASSES: 1000 4 | ANYNET: 5 | STEM_TYPE: res_stem_in 6 | STEM_W: 64 7 | BLOCK_TYPE: res_bottleneck_block 8 | STRIDES: [1, 2, 2, 2] 9 | DEPTHS: [3, 4, 23, 3] 10 | WIDTHS: [256, 512, 1024, 2048] 11 | BOT_MULS: [0.5, 0.5, 0.5, 0.5] 12 | GROUP_WS: [4, 8, 16, 32] 13 | OPTIM: 14 | LR_POLICY: cos 15 | BASE_LR: 0.4 16 | MAX_EPOCH: 100 17 | MOMENTUM: 0.9 18 | WEIGHT_DECAY: 5e-5 19 | TRAIN: 20 | DATASET: imagenet 21 | IM_SIZE: 224 22 | BATCH_SIZE: 512 23 | INTRA_IMGS: 1 24 | SYNC_BN: True 25 | USE_BAKE: True 26 | OMEGA: 0.5 27 | TEMP: 4.0 28 | LAMBDA: 1.0 29 | GLOBAL_KE: True 30 | TEST: 31 | DATASET: imagenet 32 | IM_SIZE: 256 33 | BATCH_SIZE: 400 34 | -------------------------------------------------------------------------------- /imagenet/configs/resnext/BAKE-X-152-32x4d_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: anynet 3 | NUM_CLASSES: 1000 4 | ANYNET: 5 | STEM_TYPE: res_stem_in 6 | STEM_W: 64 7 | BLOCK_TYPE: res_bottleneck_block 8 | STRIDES: [1, 2, 2, 2] 9 | DEPTHS: [3, 8, 36, 3] 10 | WIDTHS: [256, 512, 1024, 2048] 11 | BOT_MULS: [0.5, 0.5, 0.5, 0.5] 12 | GROUP_WS: [4, 8, 16, 32] 13 | OPTIM: 14 | LR_POLICY: cos 15 | BASE_LR: 0.4 16 | MAX_EPOCH: 100 17 | MOMENTUM: 0.9 18 | WEIGHT_DECAY: 5e-5 19 | TRAIN: 20 | DATASET: imagenet 21 | IM_SIZE: 224 22 | BATCH_SIZE: 512 23 | INTRA_IMGS: 1 24 | SYNC_BN: True 25 | USE_BAKE: True 26 | OMEGA: 0.5 27 | TEMP: 4.0 28 | LAMBDA: 1.0 29 | GLOBAL_KE: True 30 | TEST: 31 | DATASET: imagenet 32 | IM_SIZE: 256 33 | BATCH_SIZE: 400 34 | -------------------------------------------------------------------------------- /imagenet/configs/resnext/X-101-32x4d_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: anynet 3 | NUM_CLASSES: 1000 4 | ANYNET: 5 | STEM_TYPE: res_stem_in 6 | STEM_W: 64 7 | BLOCK_TYPE: res_bottleneck_block 8 | STRIDES: [1, 2, 2, 2] 9 | DEPTHS: [3, 4, 23, 3] 10 | WIDTHS: [256, 512, 1024, 2048] 11 | BOT_MULS: [0.5, 0.5, 0.5, 0.5] 12 | GROUP_WS: [4, 8, 16, 32] 13 | OPTIM: 14 | LR_POLICY: cos 15 | BASE_LR: 0.2 16 | MAX_EPOCH: 100 17 | MOMENTUM: 0.9 18 | WEIGHT_DECAY: 5e-5 19 | TRAIN: 20 | DATASET: imagenet 21 | IM_SIZE: 224 22 | BATCH_SIZE: 256 23 | TEST: 24 | DATASET: imagenet 25 | IM_SIZE: 256 26 | BATCH_SIZE: 200 27 | -------------------------------------------------------------------------------- /imagenet/configs/resnext/X-152-32x4d_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: anynet 3 | NUM_CLASSES: 1000 4 | ANYNET: 5 | STEM_TYPE: res_stem_in 6 | STEM_W: 64 7 | BLOCK_TYPE: res_bottleneck_block 8 | STRIDES: [1, 2, 2, 2] 9 | DEPTHS: [3, 8, 36, 3] 10 | WIDTHS: [256, 512, 1024, 2048] 11 | BOT_MULS: [0.5, 0.5, 0.5, 0.5] 12 | GROUP_WS: [4, 8, 16, 32] 13 | OPTIM: 14 | LR_POLICY: cos 15 | BASE_LR: 0.2 16 | MAX_EPOCH: 100 17 | MOMENTUM: 0.9 18 | WEIGHT_DECAY: 5e-5 19 | TRAIN: 20 | DATASET: imagenet 21 | IM_SIZE: 224 22 | BATCH_SIZE: 256 23 | TEST: 24 | DATASET: imagenet 25 | IM_SIZE: 256 26 | BATCH_SIZE: 200 27 | -------------------------------------------------------------------------------- /imagenet/pycls/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxgeee/BAKE/07c4f668ea19311d5b50121026e73d2f035d5765/imagenet/pycls/__init__.py 
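The ResNe(X)t config names can be read directly off the `ANYNET` keys above. A quick arithmetic check (plain Python, nothing repo-specific): the number of groups in a bottleneck is the stage width times the bottleneck multiplier, divided by the group width.

```
# "32x4d" = 32 groups, group width 4 in the first stage's bottleneck
widths = [256, 512, 1024, 2048]   # ANYNET.WIDTHS
bot_muls = [0.5, 0.5, 0.5, 0.5]   # ANYNET.BOT_MULS (ResNeXt configs)
group_ws = [4, 8, 16, 32]         # ANYNET.GROUP_WS
print([int(w * m) // g for w, m, g in zip(widths, bot_muls, group_ws)])
# -> [32, 32, 32, 32]: every stage uses 32 groups. The ResNet configs
#    (BOT_MULS 0.25, GROUP_WS [64, ...]) give 64/64 = 1 group of width 64,
#    hence "1x64d".
```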
-------------------------------------------------------------------------------- /imagenet/pycls/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxgeee/BAKE/07c4f668ea19311d5b50121026e73d2f035d5765/imagenet/pycls/core/__init__.py -------------------------------------------------------------------------------- /imagenet/pycls/core/benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Benchmarking functions.""" 9 | 10 | import pycls.core.logging as logging 11 | import pycls.datasets.loader as loader 12 | import torch 13 | from pycls.core.config import cfg 14 | from pycls.core.timer import Timer 15 | 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | 20 | @torch.no_grad() 21 | def compute_time_eval(model): 22 | """Computes precise model forward test time using dummy data.""" 23 | # Use eval mode 24 | model.eval() 25 | # Generate a dummy mini-batch and copy data to GPU 26 | im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS) 27 | inputs = torch.zeros(batch_size, 3, im_size, im_size).cuda(non_blocking=False) 28 | # Compute precise forward pass time 29 | timer = Timer() 30 | total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER 31 | for cur_iter in range(total_iter): 32 | # Reset the timers after the warmup phase 33 | if cur_iter == cfg.PREC_TIME.WARMUP_ITER: 34 | timer.reset() 35 | # Forward 36 | timer.tic() 37 | model(inputs) 38 | torch.cuda.synchronize() 39 | timer.toc() 40 | return timer.average_time 41 | 42 | 43 | def compute_time_train(model, loss_fun): 44 | """Computes precise model forward + backward time using dummy data.""" 45 | # Use train mode 46 | model.train() 47 | # Generate a dummy mini-batch and copy data to GPU 48 | im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS) 49 | inputs = torch.rand(batch_size, 3, im_size, im_size).cuda(non_blocking=False) 50 | labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False) 51 | # Cache BatchNorm2D running stats 52 | bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)] 53 | bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns] 54 | # Compute precise forward backward pass time 55 | fw_timer, bw_timer = Timer(), Timer() 56 | total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER 57 | for cur_iter in range(total_iter): 58 | # Reset the timers after the warmup phase 59 | if cur_iter == cfg.PREC_TIME.WARMUP_ITER: 60 | fw_timer.reset() 61 | bw_timer.reset() 62 | # Forward 63 | fw_timer.tic() 64 | _, preds = model(inputs) 65 | loss = loss_fun(preds, labels) 66 | torch.cuda.synchronize() 67 | fw_timer.toc() 68 | # Backward 69 | bw_timer.tic() 70 | loss.backward() 71 | torch.cuda.synchronize() 72 | bw_timer.toc() 73 | # Restore BatchNorm2D running stats 74 | for bn, (mean, var) in zip(bns, bn_stats): 75 | bn.running_mean, bn.running_var = mean, var 76 | return fw_timer.average_time, bw_timer.average_time 77 | 78 | 79 | def compute_time_loader(data_loader): 80 | """Computes loader time.""" 81 | timer = Timer() 82 | loader.shuffle(data_loader, 0) 83 | data_loader_iterator = iter(data_loader) 84 | total_iter = cfg.PREC_TIME.NUM_ITER + 
cfg.PREC_TIME.WARMUP_ITER 85 | total_iter = min(total_iter, len(data_loader)) 86 | for cur_iter in range(total_iter): 87 | if cur_iter == cfg.PREC_TIME.WARMUP_ITER: 88 | timer.reset() 89 | timer.tic() 90 | next(data_loader_iterator) 91 | timer.toc() 92 | return timer.average_time 93 | 94 | 95 | def compute_time_full(model, loss_fun, train_loader, test_loader): 96 | """Times model and data loader.""" 97 | logger.info("Computing model and loader timings...") 98 | # Compute timings 99 | test_fw_time = compute_time_eval(model) 100 | train_fw_time, train_bw_time = compute_time_train(model, loss_fun) 101 | train_fw_bw_time = train_fw_time + train_bw_time 102 | train_loader_time = compute_time_loader(train_loader) 103 | # Output iter timing 104 | iter_times = { 105 | "test_fw_time": test_fw_time, 106 | "train_fw_time": train_fw_time, 107 | "train_bw_time": train_bw_time, 108 | "train_fw_bw_time": train_fw_bw_time, 109 | "train_loader_time": train_loader_time, 110 | } 111 | logger.info(logging.dump_log_data(iter_times, "iter_times")) 112 | # Output epoch timing 113 | epoch_times = { 114 | "test_fw_time": test_fw_time * len(test_loader), 115 | "train_fw_time": train_fw_time * len(train_loader), 116 | "train_bw_time": train_bw_time * len(train_loader), 117 | "train_fw_bw_time": train_fw_bw_time * len(train_loader), 118 | "train_loader_time": train_loader_time * len(train_loader), 119 | } 120 | logger.info(logging.dump_log_data(epoch_times, "epoch_times")) 121 | # Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS>1) 122 | overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time 123 | logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100)) 124 | -------------------------------------------------------------------------------- /imagenet/pycls/core/builders.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Model and loss construction functions.""" 9 | 10 | import torch 11 | from pycls.core.config import cfg 12 | from pycls.models.anynet import AnyNet 13 | from pycls.models.effnet import EffNet 14 | from pycls.models.resnet import ResNet 15 | from pycls.models.mobilenetv2 import MobileNetV2 16 | from pycls.models.resnest import ResNeSt 17 | 18 | 19 | # Supported models 20 | _models = { 21 | "anynet": AnyNet, 22 | "effnet": EffNet, 23 | "resnet": ResNet, 24 | "mobilenetv2": MobileNetV2, 25 | "resnest": ResNeSt, 26 | } 27 | 28 | # Supported loss functions 29 | _loss_funs = {"cross_entropy": torch.nn.CrossEntropyLoss} 30 | 31 | 32 | def get_model(): 33 | """Gets the model class specified in the config.""" 34 | err_str = "Model type '{}' not supported" 35 | assert cfg.MODEL.TYPE in _models.keys(), err_str.format(cfg.MODEL.TYPE) 36 | return _models[cfg.MODEL.TYPE] 37 | 38 | 39 | def get_loss_fun(): 40 | """Gets the loss function class specified in the config.""" 41 | err_str = "Loss function type '{}' not supported" 42 | assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), err_str.format(cfg.MODEL.LOSS_FUN) 43 | return _loss_funs[cfg.MODEL.LOSS_FUN] 44 | 45 | 46 | def build_model(): 47 | """Builds the model.""" 48 | return get_model()() 49 | 50 | 51 | def build_loss_fun(): 52 | """Builds the loss function.""" 53 | return get_loss_fun()() 54 | 55 | 56 | def register_model(name, ctor): 57 | """Registers a model dynamically.""" 58 | _models[name] = ctor 59 | 60 | 61 | def register_loss_fun(name, ctor): 62 | """Registers a loss function dynamically.""" 63 | _loss_funs[name] = ctor 64 | -------------------------------------------------------------------------------- /imagenet/pycls/core/checkpoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree.
7 | 8 | """Functions that handle saving and loading of checkpoints.""" 9 | 10 | import os 11 | 12 | import pycls.core.distributed as dist 13 | import torch 14 | from pycls.core.config import cfg 15 | 16 | 17 | # Common prefix for checkpoint file names 18 | _NAME_PREFIX = "model_epoch_" 19 | # Checkpoints directory name 20 | _DIR_NAME = "checkpoints" 21 | 22 | 23 | def get_checkpoint_dir(): 24 | """Retrieves the location for storing checkpoints.""" 25 | return os.path.join(cfg.OUT_DIR, _DIR_NAME) 26 | 27 | 28 | def get_checkpoint(epoch): 29 | """Retrieves the path to a checkpoint file.""" 30 | name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch) 31 | return os.path.join(get_checkpoint_dir(), name) 32 | 33 | 34 | def get_last_checkpoint(): 35 | """Retrieves the most recent checkpoint (highest epoch number).""" 36 | checkpoint_dir = get_checkpoint_dir() 37 | # Checkpoint file names are in lexicographic order 38 | checkpoints = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f] 39 | last_checkpoint_name = sorted(checkpoints)[-1] 40 | return os.path.join(checkpoint_dir, last_checkpoint_name) 41 | 42 | 43 | def has_checkpoint(): 44 | """Determines if there are checkpoints available.""" 45 | checkpoint_dir = get_checkpoint_dir() 46 | if not os.path.exists(checkpoint_dir): 47 | return False 48 | return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir)) 49 | 50 | 51 | def save_checkpoint(model, optimizer, epoch): 52 | """Saves a checkpoint.""" 53 | # Save checkpoints only from the master process 54 | if not dist.is_master_proc(): 55 | return 56 | # Ensure that the checkpoint dir exists 57 | os.makedirs(get_checkpoint_dir(), exist_ok=True) 58 | # Omit the DDP wrapper in the multi-gpu setting 59 | sd = model.module.state_dict() 60 | # if cfg.NUM_GPUS > 1 else model.state_dict() 61 | # Record the state 62 | checkpoint = { 63 | "epoch": epoch, 64 | "model_state": sd, 65 | "optimizer_state": optimizer.state_dict(), 66 | "cfg": cfg.dump(), 67 | } 68 | # Write the checkpoint 69 | checkpoint_file = get_checkpoint(epoch + 1) 70 | torch.save(checkpoint, checkpoint_file) 71 | return checkpoint_file 72 | 73 | 74 | def load_checkpoint(checkpoint_file, model, optimizer=None): 75 | """Loads the checkpoint from the given file.""" 76 | err_str = "Checkpoint '{}' not found" 77 | assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file) 78 | # Load the checkpoint on CPU to avoid GPU mem spike 79 | checkpoint = torch.load(checkpoint_file, map_location="cpu") 80 | # Account for the DDP wrapper in the multi-gpu setting 81 | ms = model.module 82 | # if cfg.NUM_GPUS > 1 else model 83 | ms.load_state_dict(checkpoint["model_state"]) 84 | # Load the optimizer state (commonly not done when fine-tuning) 85 | if optimizer: 86 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 87 | return checkpoint["epoch"] 88 | -------------------------------------------------------------------------------- /imagenet/pycls/core/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Configuration file (powered by YACS).""" 9 | 10 | import argparse 11 | import os 12 | import sys 13 | 14 | from pycls.core.io import cache_url 15 | from yacs.config import CfgNode as CfgNode 16 | 17 | 18 | # Global config object 19 | _C = CfgNode() 20 | 21 | # Example usage: 22 | # from core.config import cfg 23 | cfg = _C 24 | 25 | 26 | # ------------------------------------------------------------------------------------ # 27 | # Model options 28 | # ------------------------------------------------------------------------------------ # 29 | _C.MODEL = CfgNode() 30 | 31 | # Model type 32 | _C.MODEL.TYPE = "" 33 | 34 | # Number of weight layers 35 | _C.MODEL.DEPTH = 0 36 | 37 | # Number of classes 38 | _C.MODEL.NUM_CLASSES = 10 39 | 40 | # Loss function (see pycls/models/loss.py for options) 41 | _C.MODEL.LOSS_FUN = "cross_entropy" 42 | 43 | 44 | # ------------------------------------------------------------------------------------ # 45 | # ResNet options 46 | # ------------------------------------------------------------------------------------ # 47 | _C.RESNET = CfgNode() 48 | 49 | # Transformation function (see pycls/models/resnet.py for options) 50 | _C.RESNET.TRANS_FUN = "basic_transform" 51 | 52 | # Number of groups to use (1 -> ResNet; > 1 -> ResNeXt) 53 | _C.RESNET.NUM_GROUPS = 1 54 | 55 | # Width of each group (64 -> ResNet; 4 -> ResNeXt) 56 | _C.RESNET.WIDTH_PER_GROUP = 64 57 | 58 | # Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch) 59 | _C.RESNET.STRIDE_1X1 = True 60 | 61 | 62 | # ------------------------------------------------------------------------------------ # 63 | # AnyNet options 64 | # ------------------------------------------------------------------------------------ # 65 | _C.ANYNET = CfgNode() 66 | 67 | # Stem type 68 | _C.ANYNET.STEM_TYPE = "simple_stem_in" 69 | 70 | # Stem width 71 | _C.ANYNET.STEM_W = 32 72 | 73 | # Block type 74 | _C.ANYNET.BLOCK_TYPE = "res_bottleneck_block" 75 | 76 | # Depth for each stage (number of blocks in the stage) 77 | _C.ANYNET.DEPTHS = [] 78 | 79 | # Width for each stage (width of each block in the stage) 80 | _C.ANYNET.WIDTHS = [] 81 | 82 | # Strides for each stage (applies to the first block of each stage) 83 | _C.ANYNET.STRIDES = [] 84 | 85 | # Bottleneck multipliers for each stage (applies to bottleneck block) 86 | _C.ANYNET.BOT_MULS = [] 87 | 88 | # Group widths for each stage (applies to bottleneck block) 89 | _C.ANYNET.GROUP_WS = [] 90 | 91 | # Whether SE is enabled for res_bottleneck_block 92 | _C.ANYNET.SE_ON = False 93 | 94 | # SE ratio 95 | _C.ANYNET.SE_R = 0.25 96 | 97 | 98 | # ------------------------------------------------------------------------------------ # 99 | # RegNet options 100 | # ------------------------------------------------------------------------------------ # 101 | _C.REGNET = CfgNode() 102 | 103 | # Stem type 104 | _C.REGNET.STEM_TYPE = "simple_stem_in" 105 | 106 | # Stem width 107 | _C.REGNET.STEM_W = 32 108 | 109 | # Block type 110 | _C.REGNET.BLOCK_TYPE = "res_bottleneck_block" 111 | 112 | # Stride of each stage 113 | _C.REGNET.STRIDE = 2 114 | 115 | # Squeeze-and-Excitation (RegNetY) 116 | _C.REGNET.SE_ON = False 117 | _C.REGNET.SE_R = 0.25 118 | 119 | # Depth 120 | _C.REGNET.DEPTH = 10 121 | 122 | # Initial width 123 | _C.REGNET.W0 = 32 124 | 125 | # Slope 126 | _C.REGNET.WA = 5.0 127 | 128 | # Quantization 129 | _C.REGNET.WM = 2.5 130 | 131 | # Group width 132 | _C.REGNET.GROUP_W = 16 133 | 134 | # Bottleneck multiplier (bm = 1 / b from the paper) 135 | _C.REGNET.BOT_MUL = 1.0 
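The `REGNET` options above (`DEPTH`, `W0`, `WA`, `WM`) encode the quantized linear width parameterization of the RegNet design space from "Designing Network Design Spaces". A sketch of what they generate, mirroring the paper's formula rather than the exact pycls implementation (which lives in model files not included here):

```
import numpy as np

def regnet_widths(depth=10, w0=32.0, wa=5.0, wm=2.5, q=8):
    """Per-block widths from a quantized linear function (paper's formula)."""
    u = w0 + wa * np.arange(depth)             # continuous widths: u_j = W0 + WA * j
    s = np.round(np.log(u / w0) / np.log(wm))  # quantize via powers of WM
    w = w0 * np.power(wm, s)
    return (np.round(w / q) * q).astype(int).tolist()  # snap to multiples of q

print(regnet_widths())  # defaults above -> [32, 32, 32, 32, 80, 80, 80, 80, 80, 80]
```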
136 | 137 | 138 | # ------------------------------------------------------------------------------------ # 139 | # ResNeSt options 140 | # ------------------------------------------------------------------------------------ # 141 | _C.RESNEST = CfgNode() 142 | 143 | # Depth for each stage (number of blocks in the stage) 144 | _C.RESNEST.DEPTHS = [] 145 | 146 | # Width for each stage (width of each block in the stage) 147 | _C.RESNEST.STEM_W = 32 148 | 149 | 150 | # ------------------------------------------------------------------------------------ # 151 | # EfficientNet options 152 | # ------------------------------------------------------------------------------------ # 153 | _C.EN = CfgNode() 154 | 155 | # Stem width 156 | _C.EN.STEM_W = 32 157 | 158 | # Depth for each stage (number of blocks in the stage) 159 | _C.EN.DEPTHS = [] 160 | 161 | # Width for each stage (width of each block in the stage) 162 | _C.EN.WIDTHS = [] 163 | 164 | # Expansion ratios for MBConv blocks in each stage 165 | _C.EN.EXP_RATIOS = [] 166 | 167 | # Squeeze-and-Excitation (SE) ratio 168 | _C.EN.SE_R = 0.25 169 | 170 | # Strides for each stage (applies to the first block of each stage) 171 | _C.EN.STRIDES = [] 172 | 173 | # Kernel sizes for each stage 174 | _C.EN.KERNELS = [] 175 | 176 | # Head width 177 | _C.EN.HEAD_W = 1280 178 | 179 | # Drop connect ratio 180 | _C.EN.DC_RATIO = 0.0 181 | 182 | # Dropout ratio 183 | _C.EN.DROPOUT_RATIO = 0.0 184 | 185 | 186 | # ------------------------------------------------------------------------------------ # 187 | # Batch norm options 188 | # ------------------------------------------------------------------------------------ # 189 | _C.BN = CfgNode() 190 | 191 | # BN epsilon 192 | _C.BN.EPS = 1e-5 193 | 194 | # BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2) 195 | _C.BN.MOM = 0.1 196 | 197 | # Precise BN stats 198 | _C.BN.USE_PRECISE_STATS = True 199 | _C.BN.NUM_SAMPLES_PRECISE = 8192 200 | 201 | # Initialize the gamma of the final BN of each block to zero 202 | _C.BN.ZERO_INIT_FINAL_GAMMA = False 203 | 204 | # Use a different weight decay for BN layers 205 | _C.BN.USE_CUSTOM_WEIGHT_DECAY = False 206 | _C.BN.CUSTOM_WEIGHT_DECAY = 0.0 207 | 208 | 209 | # ------------------------------------------------------------------------------------ # 210 | # Optimizer options 211 | # ------------------------------------------------------------------------------------ # 212 | _C.OPTIM = CfgNode() 213 | 214 | # Base learning rate 215 | _C.OPTIM.BASE_LR = 0.1 216 | 217 | # Learning rate policy select from {'cos', 'exp', 'steps'} 218 | _C.OPTIM.LR_POLICY = "cos" 219 | 220 | # Exponential decay factor 221 | _C.OPTIM.GAMMA = 0.1 222 | 223 | # Steps for 'steps' policy (in epochs) 224 | _C.OPTIM.STEPS = [] 225 | 226 | # Learning rate multiplier for 'steps' policy 227 | _C.OPTIM.LR_MULT = 0.1 228 | 229 | # Maximal number of epochs 230 | _C.OPTIM.MAX_EPOCH = 200 231 | 232 | # Momentum 233 | _C.OPTIM.MOMENTUM = 0.9 234 | 235 | # Momentum dampening 236 | _C.OPTIM.DAMPENING = 0.0 237 | 238 | # Nesterov momentum 239 | _C.OPTIM.NESTEROV = True 240 | 241 | # L2 regularization 242 | _C.OPTIM.WEIGHT_DECAY = 5e-4 243 | 244 | # Start the warm up from OPTIM.BASE_LR * OPTIM.WARMUP_FACTOR 245 | _C.OPTIM.WARMUP_FACTOR = 0.1 246 | 247 | # Gradually warm up the OPTIM.BASE_LR over this number of epochs 248 | _C.OPTIM.WARMUP_EPOCHS = 5 249 | 250 | 251 | # ------------------------------------------------------------------------------------ # 252 | # Training options 253 | # 
------------------------------------------------------------------------------------ # 254 | _C.TRAIN = CfgNode() 255 | 256 | # Dataset and split 257 | _C.TRAIN.DATASET = "" 258 | _C.TRAIN.SPLIT = "train" 259 | 260 | # Total mini-batch size 261 | _C.TRAIN.BATCH_SIZE = 128 262 | _C.TRAIN.INTRA_IMGS = 0 263 | 264 | # Image size 265 | _C.TRAIN.IM_SIZE = 224 266 | 267 | # Evaluate model on test data every eval period epochs 268 | _C.TRAIN.EVAL_PERIOD = 1 269 | 270 | # Save model checkpoint every checkpoint period epochs 271 | _C.TRAIN.CHECKPOINT_PERIOD = 1 272 | 273 | # Resume training from the latest checkpoint in the output directory 274 | _C.TRAIN.AUTO_RESUME = True 275 | 276 | # Weights to start training from 277 | _C.TRAIN.WEIGHTS = "" 278 | 279 | _C.TRAIN.SYNC_BN = False 280 | _C.TRAIN.SHUFFLE_BN = False 281 | 282 | _C.TRAIN.USE_BAKE = False 283 | _C.TRAIN.OMEGA = 0.5 284 | _C.TRAIN.TEMP = 4. 285 | _C.TRAIN.LAMBDA = 0. 286 | _C.TRAIN.GLOBAL_KE = True 287 | 288 | _C.TRAIN.IS_CUTMIX = False 289 | _C.TRAIN.CUTMIX = CfgNode() 290 | _C.TRAIN.CUTMIX.PROB = 1.0 291 | _C.TRAIN.CUTMIX.BETA = 1.0 292 | _C.TRAIN.CUTMIX.SMOOTHING = 0.0 293 | _C.TRAIN.CUTMIX.OFF_EPOCH = 5 294 | 295 | _C.TRAIN.JITTER = False 296 | 297 | # ------------------------------------------------------------------------------------ # 298 | # Testing options 299 | # ------------------------------------------------------------------------------------ # 300 | _C.TEST = CfgNode() 301 | 302 | # Dataset and split 303 | _C.TEST.DATASET = "" 304 | _C.TEST.SPLIT = "val" 305 | 306 | # Total mini-batch size 307 | _C.TEST.BATCH_SIZE = 200 308 | 309 | # Image size 310 | _C.TEST.IM_SIZE = 256 311 | 312 | # Weights to use for testing 313 | _C.TEST.WEIGHTS = "" 314 | 315 | 316 | # ------------------------------------------------------------------------------------ # 317 | # Common train/test data loader options 318 | # ------------------------------------------------------------------------------------ # 319 | _C.DATA_LOADER = CfgNode() 320 | 321 | # Number of data loader workers per process 322 | _C.DATA_LOADER.NUM_WORKERS = 6 323 | 324 | # Load data to pinned host memory 325 | _C.DATA_LOADER.PIN_MEMORY = True 326 | 327 | 328 | # ------------------------------------------------------------------------------------ # 329 | # Memory options 330 | # ------------------------------------------------------------------------------------ # 331 | _C.MEM = CfgNode() 332 | 333 | # Perform ReLU inplace 334 | _C.MEM.RELU_INPLACE = True 335 | 336 | 337 | # ------------------------------------------------------------------------------------ # 338 | # CUDNN options 339 | # ------------------------------------------------------------------------------------ # 340 | _C.CUDNN = CfgNode() 341 | 342 | # Perform benchmarking to select the fastest CUDNN algorithms to use 343 | # Note that this may increase the memory usage and will likely not result 344 | # in overall speedups when variable size inputs are used (e.g. 
COCO training) 345 | _C.CUDNN.BENCHMARK = True 346 | 347 | 348 | # ------------------------------------------------------------------------------------ # 349 | # Precise timing options 350 | # ------------------------------------------------------------------------------------ # 351 | _C.PREC_TIME = CfgNode() 352 | 353 | # Number of iterations to warm up the caches 354 | _C.PREC_TIME.WARMUP_ITER = 3 355 | 356 | # Number of iterations to compute avg time 357 | _C.PREC_TIME.NUM_ITER = 30 358 | 359 | 360 | # ------------------------------------------------------------------------------------ # 361 | # Misc options 362 | # ------------------------------------------------------------------------------------ # 363 | 364 | _C.LAUNCHER = "slurm" 365 | _C.PORT = 8080 366 | _C.RANK = 0 367 | _C.WORLD_SIZE = 1 368 | _C.NGPUS_PER_NODE = 1 369 | _C.GPU = 0 370 | 371 | # Number of GPUs to use (applies to both training and testing) 372 | _C.NUM_GPUS = 1 373 | 374 | # Output directory 375 | _C.OUT_DIR = "." 376 | 377 | # Config destination (in OUT_DIR) 378 | _C.CFG_DEST = "config.yaml" 379 | 380 | # Note that non-determinism may still be present due to non-deterministic 381 | # operator implementations in GPU operator libraries 382 | _C.RNG_SEED = 1 383 | 384 | # Log destination ('stdout' or 'file') 385 | _C.LOG_DEST = "stdout" 386 | 387 | # Log period in iters 388 | _C.LOG_PERIOD = 10 389 | 390 | # Distributed backend 391 | _C.DIST_BACKEND = "nccl" 392 | 393 | # Hostname and port range for multi-process groups (actual port selected randomly) 394 | _C.HOST = "localhost" 395 | _C.PORT_RANGE = [10000, 65000] 396 | 397 | # Models weights referred to by URL are downloaded to this local cache 398 | _C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache" 399 | 400 | 401 | # ------------------------------------------------------------------------------------ # 402 | # Deprecated keys 403 | # ------------------------------------------------------------------------------------ # 404 | 405 | _C.register_deprecated_key("PREC_TIME.BATCH_SIZE") 406 | _C.register_deprecated_key("PREC_TIME.ENABLED") 407 | # _C.register_deprecated_key("PORT") 408 | 409 | 410 | def assert_and_infer_cfg(cache_urls=True): 411 | """Checks config values invariants.""" 412 | err_str = "The first lr step must start at 0" 413 | assert not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0, err_str 414 | data_splits = ["train", "val", "test"] 415 | err_str = "Data split '{}' not supported" 416 | assert _C.TRAIN.SPLIT in data_splits, err_str.format(_C.TRAIN.SPLIT) 417 | assert _C.TEST.SPLIT in data_splits, err_str.format(_C.TEST.SPLIT) 418 | err_str = "Mini-batch size should be a multiple of NUM_GPUS." 
419 | assert _C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0, err_str 420 | assert _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0, err_str 421 | err_str = "Log destination '{}' not supported" 422 | assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST) 423 | if cache_urls: 424 | cache_cfg_urls() 425 | 426 | 427 | def cache_cfg_urls(): 428 | """Download URLs in config, cache them, and rewrite cfg to use cached file.""" 429 | _C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE) 430 | _C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE) 431 | 432 | 433 | def dump_cfg(): 434 | """Dumps the config to the output directory.""" 435 | cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST) 436 | with open(cfg_file, "w") as f: 437 | _C.dump(stream=f) 438 | 439 | 440 | def load_cfg(out_dir, cfg_dest="config.yaml"): 441 | """Loads config from specified output directory.""" 442 | cfg_file = os.path.join(out_dir, cfg_dest) 443 | _C.merge_from_file(cfg_file) 444 | 445 | 446 | def load_cfg_fom_args(description="Config file options."): 447 | """Load config from command line arguments and set any specified options.""" 448 | parser = argparse.ArgumentParser(description=description) 449 | help_s = "Config file location" 450 | parser.add_argument("--cfg", dest="cfg_file", help=help_s, required=True, type=str) 451 | help_s = "See pycls/core/config.py for all options" 452 | parser.add_argument("opts", help=help_s, default=None, nargs=argparse.REMAINDER) 453 | if len(sys.argv) == 1: 454 | parser.print_help() 455 | sys.exit(1) 456 | args = parser.parse_args() 457 | _C.merge_from_file(args.cfg_file) 458 | _C.merge_from_list(args.opts) 459 | -------------------------------------------------------------------------------- /imagenet/pycls/core/distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Distributed helpers.""" 9 | 10 | import multiprocessing 11 | import os 12 | import subprocess 13 | import random 14 | import signal 15 | import threading 16 | import traceback 17 | 18 | import torch 19 | import torch.distributed as dist 20 | import torch.multiprocessing as mp 21 | 22 | from pycls.core.config import cfg 23 | 24 | 25 | def is_master_proc(): 26 | """Determines if the current process is the master process. 27 | 28 | Master process is responsible for logging, writing and loading checkpoints. In 29 | the multi GPU setting, we assign the master role to the rank 0 process. When 30 | training using a single GPU, there is a single process which is considered master. 31 | """ 32 | return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() == 0 33 | 34 | 35 | def init_process_group(proc_rank, world_size, port): 36 | """Initializes the default process group.""" 37 | # Set the GPU to use 38 | torch.cuda.set_device(proc_rank) 39 | # Initialize the process group 40 | torch.distributed.init_process_group( 41 | backend=cfg.DIST_BACKEND, 42 | init_method="tcp://{}:{}".format(cfg.HOST, port), 43 | world_size=world_size, 44 | rank=proc_rank, 45 | ) 46 | 47 | 48 | def destroy_process_group(): 49 | """Destroys the default process group.""" 50 | torch.distributed.destroy_process_group() 51 | 52 | 53 | def scaled_all_reduce(tensors): 54 | """Performs the scaled all_reduce operation on the provided tensors. 
55 | 56 | The input tensors are modified in-place. Currently supports only the sum 57 | reduction operator. The reduced values are scaled by the inverse size of the 58 | process group (equivalent to cfg.NUM_GPUS). 59 | """ 60 | # There is no need for reduction in the single-proc case 61 | if cfg.NUM_GPUS == 1: 62 | return tensors 63 | # Queue the reductions 64 | reductions = [] 65 | for tensor in tensors: 66 | reduction = torch.distributed.all_reduce(tensor, async_op=True) 67 | reductions.append(reduction) 68 | # Wait for reductions to finish 69 | for reduction in reductions: 70 | reduction.wait() 71 | # Scale the results 72 | for tensor in tensors: 73 | tensor.mul_(1.0 / cfg.NUM_GPUS) 74 | return tensors 75 | 76 | 77 | class ChildException(Exception): 78 | """Wraps an exception from a child process.""" 79 | 80 | def __init__(self, child_trace): 81 | super(ChildException, self).__init__(child_trace) 82 | 83 | 84 | class ErrorHandler(object): 85 | """Multiprocessing error handler (based on fairseq's). 86 | 87 | Listens for errors in child processes and propagates the tracebacks to the parent. 88 | """ 89 | 90 | def __init__(self, error_queue): 91 | # Shared error queue 92 | self.error_queue = error_queue 93 | # Children processes sharing the error queue 94 | self.children_pids = [] 95 | # Start a thread listening to errors 96 | self.error_listener = threading.Thread(target=self.listen, daemon=True) 97 | self.error_listener.start() 98 | # Register the signal handler 99 | signal.signal(signal.SIGUSR1, self.signal_handler) 100 | 101 | def add_child(self, pid): 102 | """Registers a child process.""" 103 | self.children_pids.append(pid) 104 | 105 | def listen(self): 106 | """Listens for errors in the error queue.""" 107 | # Wait until there is an error in the queue 108 | child_trace = self.error_queue.get() 109 | # Put the error back for the signal handler 110 | self.error_queue.put(child_trace) 111 | # Invoke the signal handler 112 | os.kill(os.getpid(), signal.SIGUSR1) 113 | 114 | def signal_handler(self, _sig_num, _stack_frame): 115 | """Signal handler.""" 116 | # Kill children processes 117 | for pid in self.children_pids: 118 | os.kill(pid, signal.SIGINT) 119 | # Propagate the error from the child process 120 | raise ChildException(self.error_queue.get()) 121 | 122 | 123 | def run(proc_rank, world_size, port, error_queue, fun, fun_args, fun_kwargs): 124 | """Runs a function from a child process.""" 125 | try: 126 | # Initialize the process group 127 | init_process_group(proc_rank, world_size, port) 128 | # Run the function 129 | fun(*fun_args, **fun_kwargs) 130 | except KeyboardInterrupt: 131 | # Killed by the parent process 132 | pass 133 | except Exception: 134 | # Propagate exception to the parent process 135 | error_queue.put(traceback.format_exc()) 136 | finally: 137 | # Destroy the process group 138 | destroy_process_group() 139 | 140 | def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None): 141 | init_dist() 142 | cfg.freeze() 143 | fun_kwargs = fun_kwargs if fun_kwargs else {} 144 | fun(*fun_args, **fun_kwargs) 145 | # return 146 | 147 | def init_dist(backend="nccl"): 148 | if mp.get_start_method(allow_none=True) is None: 149 | mp.set_start_method("spawn") 150 | 151 | if cfg.LAUNCHER == "pytorch": 152 | init_dist_pytorch(backend) 153 | elif cfg.LAUNCHER == "slurm": 154 | init_dist_slurm(backend) 155 | else: 156 | raise ValueError("Invalid launcher type: {}".format(cfg.LAUNCHER)) 157 | 158 | 159 | def init_dist_pytorch(backend="nccl"): 160 | cfg.RANK = 
int(os.environ["LOCAL_RANK"]) 161 | if 'CUDA_VISIBLE_DEVICES' in os.environ.keys(): 162 | cfg.NGPUS_PER_NODE = len(os.environ['CUDA_VISIBLE_DEVICES'].split(',')) 163 | else: 164 | cfg.NGPUS_PER_NODE = torch.cuda.device_count() 165 | assert cfg.NGPUS_PER_NODE>0, "CUDA is not supported" 166 | cfg.GPU = cfg.RANK 167 | torch.cuda.set_device(cfg.GPU) 168 | dist.init_process_group(backend=backend) 169 | cfg.NUM_GPUS = dist.get_world_size() 170 | cfg.WORLD_SIZE = cfg.NUM_GPUS 171 | 172 | def init_dist_slurm(backend="nccl"): 173 | cfg.RANK = int(os.environ["SLURM_PROCID"]) 174 | cfg.WORLD_SIZE = int(os.environ["SLURM_NTASKS"]) 175 | node_list = os.environ["SLURM_NODELIST"] 176 | if 'CUDA_VISIBLE_DEVICES' in os.environ.keys(): 177 | cfg.NGPUS_PER_NODE = len(os.environ['CUDA_VISIBLE_DEVICES'].split(',')) 178 | else: 179 | cfg.NGPUS_PER_NODE = torch.cuda.device_count() 180 | assert cfg.NGPUS_PER_NODE>0, "CUDA is not supported" 181 | cfg.GPU = cfg.RANK % cfg.NGPUS_PER_NODE 182 | torch.cuda.set_device(cfg.GPU) 183 | addr = subprocess.getoutput( 184 | "scontrol show hostname {} | head -n1".format(node_list) 185 | ) 186 | os.environ["MASTER_PORT"] = str(cfg.PORT) 187 | os.environ["MASTER_ADDR"] = addr 188 | os.environ["WORLD_SIZE"] = str(cfg.WORLD_SIZE) 189 | os.environ["RANK"] = str(cfg.RANK) 190 | dist.init_process_group(backend=backend) 191 | cfg.NUM_GPUS = dist.get_world_size() 192 | -------------------------------------------------------------------------------- /imagenet/pycls/core/io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """IO utilities (adapted from Detectron)""" 9 | 10 | import logging 11 | import os 12 | import re 13 | import sys 14 | from urllib import request as urlrequest 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | _PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls" 20 | 21 | 22 | def cache_url(url_or_file, cache_dir): 23 | """Download the file specified by the URL to the cache_dir and return the path to 24 | the cached file. If the argument is not a URL, simply return it as is. 25 | """ 26 | is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None 27 | if not is_url: 28 | return url_or_file 29 | url = url_or_file 30 | err_str = "pycls only automatically caches URLs in the pycls S3 bucket: {}" 31 | assert url.startswith(_PYCLS_BASE_URL), err_str.format(_PYCLS_BASE_URL) 32 | cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir) 33 | if os.path.exists(cache_file_path): 34 | return cache_file_path 35 | cache_file_dir = os.path.dirname(cache_file_path) 36 | if not os.path.exists(cache_file_dir): 37 | os.makedirs(cache_file_dir) 38 | logger.info("Downloading remote file {} to {}".format(url, cache_file_path)) 39 | download_url(url, cache_file_path) 40 | return cache_file_path 41 | 42 | 43 | def _progress_bar(count, total): 44 | """Report download progress. 
Credit: 45 | https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113 46 | """ 47 | bar_len = 60 48 | filled_len = int(round(bar_len * count / float(total))) 49 | percents = round(100.0 * count / float(total), 1) 50 | bar = "=" * filled_len + "-" * (bar_len - filled_len) 51 | sys.stdout.write( 52 | " [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024) 53 | ) 54 | sys.stdout.flush() 55 | if count >= total: 56 | sys.stdout.write("\n") 57 | 58 | 59 | def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar): 60 | """Download url and write it to dst_file_path. Credit: 61 | https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook 62 | """ 63 | req = urlrequest.Request(url) 64 | response = urlrequest.urlopen(req) 65 | total_size = response.info().get("Content-Length").strip() 66 | total_size = int(total_size) 67 | bytes_so_far = 0 68 | with open(dst_file_path, "wb") as f: 69 | while 1: 70 | chunk = response.read(chunk_size) 71 | bytes_so_far += len(chunk) 72 | if not chunk: 73 | break 74 | if progress_hook: 75 | progress_hook(bytes_so_far, total_size) 76 | f.write(chunk) 77 | return bytes_so_far 78 | -------------------------------------------------------------------------------- /imagenet/pycls/core/logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Logging.""" 9 | 10 | import builtins 11 | import decimal 12 | import logging 13 | import os 14 | import sys 15 | 16 | import pycls.core.distributed as dist 17 | import simplejson 18 | from pycls.core.config import cfg 19 | 20 | 21 | # Show filename and line number in logs 22 | _FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s" 23 | 24 | # Log file name (for cfg.LOG_DEST = 'file') 25 | _LOG_FILE = "stdout.log" 26 | 27 | # Data output with dump_log_data(data, data_type) will be tagged w/ this 28 | _TAG = "json_stats: " 29 | 30 | # Data output with dump_log_data(data, data_type) will have data[_TYPE]=data_type 31 | _TYPE = "_type" 32 | 33 | 34 | def _suppress_print(): 35 | """Suppresses printing from the current process.""" 36 | 37 | def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False): 38 | pass 39 | 40 | builtins.print = ignore 41 | 42 | 43 | def setup_logging(): 44 | """Sets up the logging.""" 45 | # Enable logging only for the master process 46 | if dist.is_master_proc(): 47 | # Clear the root logger to prevent any existing logging config 48 | # (e.g. 
set by another module) from messing with our setup 49 | logging.root.handlers = [] 50 | # Construct logging configuration 51 | logging_config = {"level": logging.INFO, "format": _FORMAT} 52 | # Log either to stdout or to a file 53 | if cfg.LOG_DEST == "stdout": 54 | logging_config["stream"] = sys.stdout 55 | else: 56 | logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE) 57 | # Configure logging 58 | logging.basicConfig(**logging_config) 59 | else: 60 | _suppress_print() 61 | 62 | 63 | def get_logger(name): 64 | """Retrieves the logger.""" 65 | return logging.getLogger(name) 66 | 67 | 68 | def dump_log_data(data, data_type, prec=4): 69 | """Convert data (a dictionary) into a tagged json string for logging.""" 70 | data[_TYPE] = data_type 71 | data = float_to_decimal(data, prec) 72 | data_json = simplejson.dumps(data, sort_keys=True, use_decimal=True) 73 | return "{:s}{:s}".format(_TAG, data_json) 74 | 75 | 76 | def float_to_decimal(data, prec=4): 77 | """Convert floats to decimals which allows for fixed width json.""" 78 | if isinstance(data, dict): 79 | return {k: float_to_decimal(v, prec) for k, v in data.items()} 80 | if isinstance(data, float): 81 | return decimal.Decimal(("{:." + str(prec) + "f}").format(data)) 82 | else: 83 | return data 84 | 85 | 86 | def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE): 87 | """Get all log files in directory containing subdirs of trained models.""" 88 | names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n] 89 | files = [os.path.join(log_dir, n, log_file) for n in names] 90 | f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)] 91 | files, names = zip(*f_n_ps) if f_n_ps else ([], []) 92 | return files, names 93 | 94 | 95 | def load_log_data(log_file, data_types_to_skip=()): 96 | """Loads log data into a dictionary of the form data[data_type][metric][index].""" 97 | # Load log_file 98 | assert os.path.exists(log_file), "Log file not found: {}".format(log_file) 99 | with open(log_file, "r") as f: 100 | lines = f.readlines() 101 | # Extract and parse lines that start with _TAG and have a type specified 102 | lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l] 103 | lines = [simplejson.loads(l) for l in lines] 104 | lines = [l for l in lines if _TYPE in l and l[_TYPE] not in data_types_to_skip] 105 | # Generate data structure accessed by data[data_type][index][metric] 106 | data_types = [l[_TYPE] for l in lines] 107 | data = {t: [] for t in data_types} 108 | for t, line in zip(data_types, lines): 109 | del line[_TYPE] 110 | data[t].append(line) 111 | # Generate data structure accessed by data[data_type][metric][index] 112 | for t in data: 113 | metrics = sorted(data[t][0].keys()) 114 | err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics) 115 | assert all(sorted(d.keys()) == metrics for d in data[t]), err_str 116 | data[t] = {m: [d[m] for d in data[t]] for m in metrics} 117 | return data 118 | 119 | 120 | def sort_log_data(data): 121 | """Sort each data[data_type][metric] by epoch or keep only first instance.""" 122 | for t in data: 123 | if "epoch" in data[t]: 124 | assert "epoch_ind" not in data[t] and "epoch_max" not in data[t] 125 | data[t]["epoch_ind"] = [int(e.split("/")[0]) for e in data[t]["epoch"]] 126 | data[t]["epoch_max"] = [int(e.split("/")[1]) for e in data[t]["epoch"]] 127 | epoch = data[t]["epoch_ind"] 128 | if "iter" in data[t]: 129 | assert "iter_ind" not in data[t] and "iter_max" not in data[t] 130 | data[t]["iter_ind"] = [int(i.split("/")[0]) 
for i in data[t]["iter"]] 131 | data[t]["iter_max"] = [int(i.split("/")[1]) for i in data[t]["iter"]] 132 | itr = zip(epoch, data[t]["iter_ind"], data[t]["iter_max"]) 133 | epoch = [e + (i_ind - 1) / i_max for e, i_ind, i_max in itr] 134 | for m in data[t]: 135 | data[t][m] = [v for _, v in sorted(zip(epoch, data[t][m]))] 136 | else: 137 | data[t] = {m: d[0] for m, d in data[t].items()} 138 | return data 139 | -------------------------------------------------------------------------------- /imagenet/pycls/core/meters.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Meters.""" 9 | 10 | from collections import deque 11 | 12 | import numpy as np 13 | import pycls.core.logging as logging 14 | import torch 15 | from pycls.core.config import cfg 16 | from pycls.core.timer import Timer 17 | 18 | 19 | logger = logging.get_logger(__name__) 20 | 21 | 22 | def time_string(seconds): 23 | """Converts time in seconds to a fixed-width string format.""" 24 | days, rem = divmod(int(seconds), 24 * 3600) 25 | hrs, rem = divmod(rem, 3600) 26 | mins, secs = divmod(rem, 60) 27 | return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs) 28 | 29 | 30 | def topk_errors(preds, labels, ks): 31 | """Computes the top-k error for each k.""" 32 | err_str = "Batch dim of predictions and labels must match" 33 | assert preds.size(0) == labels.size(0), err_str 34 | # Find the top max_k predictions for each sample 35 | _top_max_k_vals, top_max_k_inds = torch.topk( 36 | preds, max(ks), dim=1, largest=True, sorted=True 37 | ) 38 | # (batch_size, max_k) -> (max_k, batch_size) 39 | top_max_k_inds = top_max_k_inds.t() 40 | # (batch_size, ) -> (max_k, batch_size) 41 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) 42 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct 43 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) 44 | # Compute the number of topk correct predictions for each k 45 | topks_correct = [top_max_k_correct[:k, :].contiguous().view(-1).float().sum() for k in ks] 46 | return [(1.0 - x / preds.size(0)) * 100.0 for x in topks_correct] 47 | 48 | 49 | def gpu_mem_usage(): 50 | """Computes the GPU memory usage for the current device (MB).""" 51 | mem_usage_bytes = torch.cuda.max_memory_allocated() 52 | return mem_usage_bytes / 1024 / 1024 53 | 54 | 55 | class ScalarMeter(object): 56 | """Measures a scalar value (adapted from Detectron).""" 57 | 58 | def __init__(self, window_size): 59 | self.deque = deque(maxlen=window_size) 60 | self.total = 0.0 61 | self.count = 0 62 | 63 | def reset(self): 64 | self.deque.clear() 65 | self.total = 0.0 66 | self.count = 0 67 | 68 | def add_value(self, value): 69 | self.deque.append(value) 70 | self.count += 1 71 | self.total += value 72 | 73 | def get_win_median(self): 74 | return np.median(self.deque) 75 | 76 | def get_win_avg(self): 77 | return np.mean(self.deque) 78 | 79 | def get_global_avg(self): 80 | return self.total / self.count 81 | 82 | 83 | class TrainMeter(object): 84 | """Measures training stats.""" 85 | 86 | def __init__(self, epoch_iters): 87 | self.epoch_iters = epoch_iters 88 | self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters 89 | self.iter_timer = Timer() 90 | self.loss = ScalarMeter(cfg.LOG_PERIOD) 91 | self.loss_total = 0.0 92 | self.lr = None 93 | 
# Current minibatch errors (smoothed over a window) 94 | self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) 95 | self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD) 96 | # Number of misclassified examples 97 | self.num_top1_mis = 0 98 | self.num_top5_mis = 0 99 | self.num_samples = 0 100 | 101 | def reset(self, timer=False): 102 | if timer: 103 | self.iter_timer.reset() 104 | self.loss.reset() 105 | self.loss_total = 0.0 106 | self.lr = None 107 | self.mb_top1_err.reset() 108 | self.mb_top5_err.reset() 109 | self.num_top1_mis = 0 110 | self.num_top5_mis = 0 111 | self.num_samples = 0 112 | 113 | def iter_tic(self): 114 | self.iter_timer.tic() 115 | 116 | def iter_toc(self): 117 | self.iter_timer.toc() 118 | 119 | def update_stats(self, top1_err, top5_err, loss, lr, mb_size): 120 | # Current minibatch stats 121 | self.mb_top1_err.add_value(top1_err) 122 | self.mb_top5_err.add_value(top5_err) 123 | self.loss.add_value(loss) 124 | self.lr = lr 125 | # Aggregate stats 126 | self.num_top1_mis += top1_err * mb_size 127 | self.num_top5_mis += top5_err * mb_size 128 | self.loss_total += loss * mb_size 129 | self.num_samples += mb_size 130 | 131 | def get_iter_stats(self, cur_epoch, cur_iter): 132 | cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1 133 | eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total) 134 | mem_usage = gpu_mem_usage() 135 | stats = { 136 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 137 | "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters), 138 | "time_avg": self.iter_timer.average_time, 139 | "time_diff": self.iter_timer.diff, 140 | "eta": time_string(eta_sec), 141 | "top1_err": self.mb_top1_err.get_win_median(), 142 | "top5_err": self.mb_top5_err.get_win_median(), 143 | "loss": self.loss.get_win_median(), 144 | "lr": self.lr, 145 | "mem": int(np.ceil(mem_usage)), 146 | } 147 | return stats 148 | 149 | def log_iter_stats(self, cur_epoch, cur_iter): 150 | if (cur_iter + 1) % cfg.LOG_PERIOD != 0: 151 | return 152 | stats = self.get_iter_stats(cur_epoch, cur_iter) 153 | logger.info(logging.dump_log_data(stats, "train_iter")) 154 | 155 | def get_epoch_stats(self, cur_epoch): 156 | cur_iter_total = (cur_epoch + 1) * self.epoch_iters 157 | eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total) 158 | mem_usage = gpu_mem_usage() 159 | top1_err = self.num_top1_mis / self.num_samples 160 | top5_err = self.num_top5_mis / self.num_samples 161 | avg_loss = self.loss_total / self.num_samples 162 | stats = { 163 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 164 | "time_avg": self.iter_timer.average_time, 165 | "eta": time_string(eta_sec), 166 | "top1_err": top1_err, 167 | "top5_err": top5_err, 168 | "loss": avg_loss, 169 | "lr": self.lr, 170 | "mem": int(np.ceil(mem_usage)), 171 | } 172 | return stats 173 | 174 | def log_epoch_stats(self, cur_epoch): 175 | stats = self.get_epoch_stats(cur_epoch) 176 | logger.info(logging.dump_log_data(stats, "train_epoch")) 177 | 178 | 179 | class TestMeter(object): 180 | """Measures testing stats.""" 181 | 182 | def __init__(self, max_iter): 183 | self.max_iter = max_iter 184 | self.iter_timer = Timer() 185 | # Current minibatch errors (smoothed over a window) 186 | self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) 187 | self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD) 188 | # Min errors (over the full test set) 189 | self.min_top1_err = 100.0 190 | self.min_top5_err = 100.0 191 | # Number of misclassified examples 192 | self.num_top1_mis = 0 193 | self.num_top5_mis = 0 194 | 
self.num_samples = 0 195 | 196 | def reset(self, min_errs=False): 197 | if min_errs: 198 | self.min_top1_err = 100.0 199 | self.min_top5_err = 100.0 200 | self.iter_timer.reset() 201 | self.mb_top1_err.reset() 202 | self.mb_top5_err.reset() 203 | self.num_top1_mis = 0 204 | self.num_top5_mis = 0 205 | self.num_samples = 0 206 | 207 | def iter_tic(self): 208 | self.iter_timer.tic() 209 | 210 | def iter_toc(self): 211 | self.iter_timer.toc() 212 | 213 | def update_stats(self, top1_err, top5_err, mb_size): 214 | self.mb_top1_err.add_value(top1_err) 215 | self.mb_top5_err.add_value(top5_err) 216 | self.num_top1_mis += top1_err * mb_size 217 | self.num_top5_mis += top5_err * mb_size 218 | self.num_samples += mb_size 219 | 220 | def get_iter_stats(self, cur_epoch, cur_iter): 221 | mem_usage = gpu_mem_usage() 222 | iter_stats = { 223 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 224 | "iter": "{}/{}".format(cur_iter + 1, self.max_iter), 225 | "time_avg": self.iter_timer.average_time, 226 | "time_diff": self.iter_timer.diff, 227 | "top1_err": self.mb_top1_err.get_win_median(), 228 | "top5_err": self.mb_top5_err.get_win_median(), 229 | "mem": int(np.ceil(mem_usage)), 230 | } 231 | return iter_stats 232 | 233 | def log_iter_stats(self, cur_epoch, cur_iter): 234 | if (cur_iter + 1) % cfg.LOG_PERIOD != 0: 235 | return 236 | stats = self.get_iter_stats(cur_epoch, cur_iter) 237 | logger.info(logging.dump_log_data(stats, "test_iter")) 238 | 239 | def get_epoch_stats(self, cur_epoch): 240 | top1_err = self.num_top1_mis / max(1, self.num_samples) 241 | top5_err = self.num_top5_mis / max(1, self.num_samples) 242 | self.min_top1_err = min(self.min_top1_err, top1_err) 243 | self.min_top5_err = min(self.min_top5_err, top5_err) 244 | mem_usage = gpu_mem_usage() 245 | stats = { 246 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 247 | "time_avg": self.iter_timer.average_time, 248 | "top1_err": top1_err, 249 | "top5_err": top5_err, 250 | "min_top1_err": self.min_top1_err, 251 | "min_top5_err": self.min_top5_err, 252 | "mem": int(np.ceil(mem_usage)), 253 | } 254 | return stats 255 | 256 | def log_epoch_stats(self, cur_epoch): 257 | stats = self.get_epoch_stats(cur_epoch) 258 | logger.info(logging.dump_log_data(stats, "test_epoch")) 259 | -------------------------------------------------------------------------------- /imagenet/pycls/core/net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
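A quick usage sketch for `topk_errors` from meters.py above (hypothetical logits and labels; assumes `pycls` is importable):

```python
import torch
from pycls.core.meters import topk_errors

# Four samples over three classes; labels chosen so that two samples are
# top-1 correct, one more is recovered at top-2, and one is always wrong.
preds = torch.tensor([
    [0.1, 0.7, 0.2],  # top-1 -> class 1 (correct)
    [0.6, 0.3, 0.1],  # label 2 is ranked last (wrong at k=1 and k=2)
    [0.2, 0.3, 0.5],  # top-1 -> class 2 (correct)
    [0.4, 0.5, 0.1],  # label 0 is ranked second (recovered at k=2)
])
labels = torch.tensor([1, 2, 2, 0])
top1_err, top2_err = topk_errors(preds, labels, ks=[1, 2])
print(top1_err, top2_err)  # tensor(50.) tensor(25.) -- errors in percent
```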
7 | 8 | """Functions for manipulating networks.""" 9 | 10 | import itertools 11 | import math 12 | 13 | import pycls.core.distributed as dist 14 | import torch 15 | import torch.nn as nn 16 | from pycls.core.config import cfg 17 | 18 | 19 | def init_weights(m): 20 | """Performs ResNet-style weight initialization.""" 21 | if isinstance(m, nn.Conv2d): 22 | # Note that there is no bias due to BN 23 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 24 | m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out)) 25 | elif isinstance(m, nn.BatchNorm2d): 26 | zero_init_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA 27 | zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma 28 | m.weight.data.fill_(0.0 if zero_init_gamma else 1.0) 29 | m.bias.data.zero_() 30 | elif isinstance(m, nn.Linear): 31 | m.weight.data.normal_(mean=0.0, std=0.01) 32 | m.bias.data.zero_() 33 | 34 | 35 | @torch.no_grad() 36 | def compute_precise_bn_stats(model, loader): 37 | """Computes precise BN stats on training data.""" 38 | # Compute the number of minibatches to use 39 | num_iter = int(cfg.BN.NUM_SAMPLES_PRECISE / loader.batch_size / cfg.NUM_GPUS) 40 | num_iter = min(num_iter, len(loader)) 41 | # Retrieve the BN layers 42 | bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)] 43 | # Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch)) 44 | running_means = [torch.zeros_like(bn.running_mean) for bn in bns] 45 | running_vars = [torch.zeros_like(bn.running_var) for bn in bns] 46 | # Remember momentum values 47 | momentums = [bn.momentum for bn in bns] 48 | # Set momentum to 1.0 to compute BN stats that only reflect the current batch 49 | for bn in bns: 50 | bn.momentum = 1.0 51 | # Average the BN stats for each BN layer over the batches 52 | for inputs, _labels in itertools.islice(loader, num_iter): 53 | model(inputs.cuda()) 54 | for i, bn in enumerate(bns): 55 | running_means[i] += bn.running_mean / num_iter 56 | running_vars[i] += bn.running_var / num_iter 57 | # Sync BN stats across GPUs (no reduction if 1 GPU used) 58 | running_means = dist.scaled_all_reduce(running_means) 59 | running_vars = dist.scaled_all_reduce(running_vars) 60 | # Set BN stats and restore original momentum values 61 | for i, bn in enumerate(bns): 62 | bn.running_mean = running_means[i] 63 | bn.running_var = running_vars[i] 64 | bn.momentum = momentums[i] 65 | 66 | 67 | def reset_bn_stats(model): 68 | """Resets running BN stats.""" 69 | for m in model.modules(): 70 | if isinstance(m, torch.nn.BatchNorm2d): 71 | m.reset_running_stats() 72 | 73 | 74 | def complexity_conv2d(cx, w_in, w_out, k, stride, padding, groups=1, bias=False): 75 | """Accumulates complexity of Conv2D into cx = (h, w, flops, params, acts).""" 76 | h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"] 77 | h = (h + 2 * padding - k) // stride + 1 78 | w = (w + 2 * padding - k) // stride + 1 79 | flops += k * k * w_in * w_out * h * w // groups 80 | params += k * k * w_in * w_out // groups 81 | flops += w_out if bias else 0 82 | params += w_out if bias else 0 83 | acts += w_out * h * w 84 | return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts} 85 | 86 | 87 | def complexity_batchnorm2d(cx, w_in): 88 | """Accumulates complexity of BatchNorm2D into cx = (h, w, flops, params, acts).""" 89 | h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"] 90 | params += 2 * w_in 91 | return {"h": h, "w": w, "flops": flops, "params": params, 
"acts": acts} 92 | 93 | 94 | def complexity_maxpool2d(cx, k, stride, padding): 95 | """Accumulates complexity of MaxPool2d into cx = (h, w, flops, params, acts).""" 96 | h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"] 97 | h = (h + 2 * padding - k) // stride + 1 98 | w = (w + 2 * padding - k) // stride + 1 99 | return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts} 100 | 101 | 102 | def complexity(model): 103 | """Compute model complexity (model can be model instance or model class).""" 104 | size = cfg.TRAIN.IM_SIZE 105 | cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0} 106 | cx = model.complexity(cx) 107 | return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]} 108 | 109 | 110 | def drop_connect(x, drop_ratio): 111 | """Drop connect (adapted from DARTS).""" 112 | keep_ratio = 1.0 - drop_ratio 113 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device) 114 | mask.bernoulli_(keep_ratio) 115 | x.div_(keep_ratio) 116 | x.mul_(mask) 117 | return x 118 | 119 | 120 | def get_flat_weights(model): 121 | """Gets all model weights as a single flat vector.""" 122 | return torch.cat([p.data.view(-1, 1) for p in model.parameters()], 0) 123 | 124 | 125 | def set_flat_weights(model, flat_weights): 126 | """Sets all model weights from a single flat vector.""" 127 | k = 0 128 | for p in model.parameters(): 129 | n = p.data.numel() 130 | p.data.copy_(flat_weights[k : (k + n)].view_as(p.data)) 131 | k += n 132 | assert k == flat_weights.numel() 133 | -------------------------------------------------------------------------------- /imagenet/pycls/core/optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Optimizer.""" 9 | 10 | import numpy as np 11 | import torch 12 | from pycls.core.config import cfg 13 | 14 | 15 | def construct_optimizer(model): 16 | """Constructs the optimizer. 17 | 18 | Note that the momentum update in PyTorch differs from the one in Caffe2. 19 | In particular, 20 | 21 | Caffe2: 22 | V := mu * V + lr * g 23 | p := p - V 24 | 25 | PyTorch: 26 | V := mu * V + g 27 | p := p - lr * V 28 | 29 | where V is the velocity, mu is the momentum factor, lr is the learning rate, 30 | g is the gradient and p are the parameters. 31 | 32 | Since V is defined independently of the learning rate in PyTorch, 33 | when the learning rate is changed there is no need to perform the 34 | momentum correction by scaling V (unlike in the Caffe2 case). 35 | """ 36 | if cfg.BN.USE_CUSTOM_WEIGHT_DECAY: 37 | # Apply different weight decay to Batchnorm and non-batchnorm parameters. 
38 | p_bn = [p for n, p in model.named_parameters() if "bn" in n] 39 | p_non_bn = [p for n, p in model.named_parameters() if "bn" not in n] 40 | optim_params = [ 41 | {"params": p_bn, "weight_decay": cfg.BN.CUSTOM_WEIGHT_DECAY}, 42 | {"params": p_non_bn, "weight_decay": cfg.OPTIM.WEIGHT_DECAY}, 43 | ] 44 | else: 45 | optim_params = model.parameters() 46 | 47 | return torch.optim.SGD( 48 | optim_params, 49 | lr=cfg.OPTIM.BASE_LR, 50 | momentum=cfg.OPTIM.MOMENTUM, 51 | weight_decay=cfg.OPTIM.WEIGHT_DECAY, 52 | dampening=cfg.OPTIM.DAMPENING, 53 | nesterov=cfg.OPTIM.NESTEROV, 54 | ) 55 | 56 | 57 | def lr_fun_steps(cur_epoch): 58 | """Steps schedule (cfg.OPTIM.LR_POLICY = 'steps').""" 59 | ind = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s][-1] 60 | return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** ind) 61 | 62 | 63 | def lr_fun_exp(cur_epoch): 64 | """Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp').""" 65 | return cfg.OPTIM.BASE_LR * (cfg.OPTIM.GAMMA ** cur_epoch) 66 | 67 | 68 | def lr_fun_cos(cur_epoch): 69 | """Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos').""" 70 | base_lr, max_epoch = cfg.OPTIM.BASE_LR, cfg.OPTIM.MAX_EPOCH 71 | return 0.5 * base_lr * (1.0 + np.cos(np.pi * cur_epoch / max_epoch)) 72 | 73 | 74 | def get_lr_fun(): 75 | """Retrieves the specified lr policy function""" 76 | lr_fun = "lr_fun_" + cfg.OPTIM.LR_POLICY 77 | if lr_fun not in globals(): 78 | raise NotImplementedError("Unknown LR policy:" + cfg.OPTIM.LR_POLICY) 79 | return globals()[lr_fun] 80 | 81 | 82 | def get_epoch_lr(cur_epoch): 83 | """Retrieves the lr for the given epoch according to the policy.""" 84 | lr = get_lr_fun()(cur_epoch) 85 | # Linear warmup 86 | if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS: 87 | alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS 88 | warmup_factor = cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha 89 | lr *= warmup_factor 90 | return lr 91 | 92 | 93 | def set_lr(optimizer, new_lr): 94 | """Sets the optimizer lr to the specified value.""" 95 | for param_group in optimizer.param_groups: 96 | param_group["lr"] = new_lr 97 | -------------------------------------------------------------------------------- /imagenet/pycls/core/plotting.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
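Before the plotting utilities, a standalone sketch of what the schedule functions in optimizer.py above compute for a cosine policy with linear warmup (the config values here are illustrative assumptions, not the repo's defaults):

```python
import numpy as np

# Mirrors lr_fun_cos() and get_epoch_lr() above with hard-coded settings.
BASE_LR, MAX_EPOCH = 0.1, 100
WARMUP_EPOCHS, WARMUP_FACTOR = 5, 0.1

def epoch_lr(cur_epoch):
    # Cosine annealing from BASE_LR down to 0 over MAX_EPOCH epochs
    lr = 0.5 * BASE_LR * (1.0 + np.cos(np.pi * cur_epoch / MAX_EPOCH))
    # Linear warmup: ramp the multiplier from WARMUP_FACTOR up to 1
    if cur_epoch < WARMUP_EPOCHS:
        alpha = cur_epoch / WARMUP_EPOCHS
        lr *= WARMUP_FACTOR * (1.0 - alpha) + alpha
    return lr

print(epoch_lr(0))   # ~0.01: BASE_LR * WARMUP_FACTOR at warmup start
print(epoch_lr(50))  # ~0.05: half of BASE_LR at the cosine midpoint
```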
7 | 8 | """Plotting functions.""" 9 | 10 | import colorlover as cl 11 | import matplotlib.pyplot as plt 12 | import plotly.graph_objs as go 13 | import plotly.offline as offline 14 | import pycls.core.logging as logging 15 | 16 | 17 | def get_plot_colors(max_colors, color_format="pyplot"): 18 | """Generate colors for plotting.""" 19 | colors = cl.scales["11"]["qual"]["Paired"] 20 | if max_colors > len(colors): 21 | colors = cl.to_rgb(cl.interp(colors, max_colors)) 22 | if color_format == "pyplot": 23 | return [[j / 255.0 for j in c] for c in cl.to_numeric(colors)] 24 | return colors 25 | 26 | 27 | def prepare_plot_data(log_files, names, metric="top1_err"): 28 | """Load logs and extract data for plotting error curves.""" 29 | plot_data = [] 30 | for file, name in zip(log_files, names): 31 | d, data = {}, logging.sort_log_data(logging.load_log_data(file)) 32 | for phase in ["train", "test"]: 33 | x = data[phase + "_epoch"]["epoch_ind"] 34 | y = data[phase + "_epoch"][metric] 35 | d["x_" + phase], d["y_" + phase] = x, y 36 | d[phase + "_label"] = "[{:5.2f}] ".format(min(y) if y else 0) + name 37 | plot_data.append(d) 38 | assert len(plot_data) > 0, "No data to plot" 39 | return plot_data 40 | 41 | 42 | def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"): 43 | """Plot error curves using plotly and save to file.""" 44 | plot_data = prepare_plot_data(log_files, names, metric) 45 | colors = get_plot_colors(len(plot_data), "plotly") 46 | # Prepare data for plots (3 sets, train duplicated w and w/o legend) 47 | data = [] 48 | for i, d in enumerate(plot_data): 49 | s = str(i) 50 | line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5} 51 | line_test = {"color": colors[i], "dash": "solid", "width": 1.5} 52 | data.append( 53 | go.Scatter( 54 | x=d["x_train"], 55 | y=d["y_train"], 56 | mode="lines", 57 | name=d["train_label"], 58 | line=line_train, 59 | legendgroup=s, 60 | visible=True, 61 | showlegend=False, 62 | ) 63 | ) 64 | data.append( 65 | go.Scatter( 66 | x=d["x_test"], 67 | y=d["y_test"], 68 | mode="lines", 69 | name=d["test_label"], 70 | line=line_test, 71 | legendgroup=s, 72 | visible=True, 73 | showlegend=True, 74 | ) 75 | ) 76 | data.append( 77 | go.Scatter( 78 | x=d["x_train"], 79 | y=d["y_train"], 80 | mode="lines", 81 | name=d["train_label"], 82 | line=line_train, 83 | legendgroup=s, 84 | visible=False, 85 | showlegend=True, 86 | ) 87 | ) 88 | # Prepare layout w ability to toggle 'all', 'train', 'test' 89 | titlefont = {"size": 18, "color": "#7f7f7f"} 90 | vis = [[True, True, False], [False, False, True], [False, True, False]] 91 | buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis]) 92 | buttons = [{"label": b, "args": v, "method": "update"} for b, v in buttons] 93 | layout = go.Layout( 94 | title=metric + " vs. epoch
[dash=train, solid=test]", 95 | xaxis={"title": "epoch", "titlefont": titlefont}, 96 | yaxis={"title": metric, "titlefont": titlefont}, 97 | showlegend=True, 98 | hoverlabel={"namelength": -1}, 99 | updatemenus=[ 100 | { 101 | "buttons": buttons, 102 | "direction": "down", 103 | "showactive": True, 104 | "x": 1.02, 105 | "xanchor": "left", 106 | "y": 1.08, 107 | "yanchor": "top", 108 | } 109 | ], 110 | ) 111 | # Create plotly plot 112 | offline.plot({"data": data, "layout": layout}, filename=filename) 113 | 114 | 115 | def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"): 116 | """Plot error curves using matplotlib.pyplot and save to file.""" 117 | plot_data = prepare_plot_data(log_files, names, metric) 118 | colors = get_plot_colors(len(names)) 119 | for ind, d in enumerate(plot_data): 120 | c, lbl = colors[ind], d["test_label"] 121 | plt.plot(d["x_train"], d["y_train"], "--", c=c, alpha=0.8) 122 | plt.plot(d["x_test"], d["y_test"], "-", c=c, alpha=0.8, label=lbl) 123 | plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14) 124 | plt.xlabel("epoch", fontsize=14) 125 | plt.ylabel(metric, fontsize=14) 126 | plt.grid(alpha=0.4) 127 | plt.legend() 128 | if filename: 129 | plt.savefig(filename) 130 | plt.clf() 131 | else: 132 | plt.show() 133 | -------------------------------------------------------------------------------- /imagenet/pycls/core/timer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Timer.""" 9 | 10 | import time 11 | 12 | 13 | class Timer(object): 14 | """A simple timer (adapted from Detectron).""" 15 | 16 | def __init__(self): 17 | self.total_time = None 18 | self.calls = None 19 | self.start_time = None 20 | self.diff = None 21 | self.average_time = None 22 | self.reset() 23 | 24 | def tic(self): 25 | # using time.time as time.clock does not normalize for multithreading 26 | self.start_time = time.time() 27 | 28 | def toc(self): 29 | self.diff = time.time() - self.start_time 30 | self.total_time += self.diff 31 | self.calls += 1 32 | self.average_time = self.total_time / self.calls 33 | 34 | def reset(self): 35 | self.total_time = 0.0 36 | self.calls = 0 37 | self.start_time = 0.0 38 | self.diff = 0.0 39 | self.average_time = 0.0 40 | -------------------------------------------------------------------------------- /imagenet/pycls/core/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
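trainer.py below contains the BAKE-specific pieces (knowledge_ensemble, KDLoss, and the training loop that uses them). As a standalone sketch of the batch knowledge ensembling step in the single-GPU case (toy random tensors; mirrors the closed-form propagation in knowledge_ensemble further down):

```python
import torch
import torch.nn.functional as F

def knowledge_ensemble_sketch(feats, logits, temp=4.0, omega=0.5):
    n = logits.size(0)
    feats = F.normalize(feats, p=2, dim=1)      # cosine-similarity space
    probs = F.softmax(logits / temp, dim=1)     # softened predictions
    eye = torch.eye(n)
    # Row-normalized affinity matrix with self-similarity masked out
    w = F.softmax(feats @ feats.t() - eye * 1e9, dim=1)
    # Closed form of the weighted random walk: (1 - omega) * (I - omega*W)^-1 @ P
    return (1 - omega) * torch.inverse(eye - omega * w) @ probs

feats, logits = torch.randn(8, 128), torch.randn(8, 10)
soft_targets = knowledge_ensemble_sketch(feats, logits)
print(soft_targets.sum(dim=1))  # each row remains a valid distribution (~1.0)
```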
7 | 8 | """Tools for training and testing a model.""" 9 | 10 | import os 11 | 12 | import numpy as np 13 | import pycls.core.benchmark as benchmark 14 | import pycls.core.builders as builders 15 | import pycls.core.checkpoint as checkpoint 16 | import pycls.core.config as config 17 | import pycls.core.distributed as dist 18 | import pycls.core.logging as logging 19 | import pycls.core.meters as meters 20 | import pycls.core.net as net 21 | import pycls.core.optimizer as optim 22 | import pycls.datasets.loader as loader 23 | from pycls.datasets.transforms import cutmix_batch 24 | import torch 25 | from pycls.core.config import cfg 26 | import torch.nn as nn 27 | import torch.distributed as tdist 28 | 29 | 30 | logger = logging.get_logger(__name__) 31 | 32 | 33 | def setup_env(): 34 | """Sets up environment for training or testing.""" 35 | if dist.is_master_proc(): 36 | # Ensure that the output dir exists 37 | os.makedirs(cfg.OUT_DIR, exist_ok=True) 38 | # Save the config 39 | config.dump_cfg() 40 | # Setup logging 41 | logging.setup_logging() 42 | # Log the config as both human readable and as a json 43 | logger.info("Config:\n{}".format(cfg)) 44 | logger.info(logging.dump_log_data(cfg, "cfg")) 45 | # Fix the RNG seeds (see RNG comment in core/config.py for discussion) 46 | np.random.seed(cfg.RNG_SEED) 47 | torch.manual_seed(cfg.RNG_SEED) 48 | # Configure the CUDNN backend 49 | torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK 50 | 51 | 52 | def setup_model(): 53 | """Sets up a model for training or testing and log the results.""" 54 | # Build the model 55 | model = builders.build_model() 56 | if cfg.TRAIN.SYNC_BN: 57 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 58 | logger.info("Model:\n{}".format(model)) 59 | # Log model complexity 60 | # logger.info(logging.dump_log_data(net.complexity(model), "complexity")) 61 | # Transfer the model to the current GPU device 62 | cur_device = torch.cuda.current_device() 63 | model = model.cuda(device=cur_device) 64 | # Use multi-process data parallel model in the multi-gpu setting 65 | # Make model replica operate on the current device 66 | model = torch.nn.parallel.DistributedDataParallel( 67 | module=model, device_ids=[cur_device], output_device=cur_device 68 | ) 69 | # Set complexity function to be module's complexity function 70 | # model.complexity = model.module.complexity 71 | return model 72 | 73 | class SoftCrossEntropyLoss(nn.Module): 74 | """SoftCrossEntropyLoss (useful for label smoothing and mixup). 
75 | Identical to torch.nn.CrossEntropyLoss if used with one-hot labels.""" 76 | 77 | def __init__(self): 78 | super(SoftCrossEntropyLoss, self).__init__() 79 | 80 | def forward(self, x, y): 81 | loss = -y * torch.nn.functional.log_softmax(x, -1) 82 | return torch.sum(loss) / x.shape[0] 83 | 84 | class KDLoss(nn.Module): 85 | def __init__(self, temp_factor=4.): 86 | super(KDLoss, self).__init__() 87 | self.temp_factor = temp_factor 88 | self.kl_div = nn.KLDivLoss(reduction="sum") 89 | 90 | def forward(self, input, target): 91 | log_p = torch.log_softmax(input/self.temp_factor, dim=1) 92 | loss = self.kl_div(log_p, target)*(self.temp_factor**2)/input.size(0) 93 | return loss 94 | 95 | def knowledge_ensemble(feats, logits, temp_factor=4., omega=0.5, cross_gpus=True): 96 | batch_size = logits.size(0) 97 | feats = nn.functional.normalize(feats, p=2, dim=1) 98 | logits = nn.functional.softmax(logits/temp_factor, dim=1) 99 | 100 | if cross_gpus: 101 | feats_large = [torch.zeros_like(feats) \ 102 | for _ in range(tdist.get_world_size())] 103 | tdist.all_gather(feats_large, feats) 104 | feats_large = torch.cat(feats_large, dim=0) 105 | logits_large = [torch.zeros_like(logits) \ 106 | for _ in range(tdist.get_world_size())] 107 | tdist.all_gather(logits_large, logits) 108 | logits_large = torch.cat(logits_large, dim=0) 109 | enlarged_batch_size = logits_large.size(0) 110 | labels_idx = torch.arange(batch_size) + tdist.get_rank() * batch_size 111 | else: 112 | feats_large = feats 113 | logits_large = logits 114 | enlarged_batch_size = batch_size 115 | labels_idx = torch.arange(batch_size) 116 | 117 | masks = torch.eye(enlarged_batch_size).cuda() 118 | W = torch.matmul(feats_large, feats_large.permute(1, 0)) - masks * 1e9 119 | W = nn.functional.softmax(W, dim=1) 120 | W = (1 - omega) * torch.inverse(masks - omega * W) 121 | logits_new = torch.matmul(W, logits_large) 122 | return logits_new[labels_idx] 123 | 124 | def to_one_hot_labels(labels): 125 | """Convert each label to a one-hot vector.""" 126 | n_classes = cfg.MODEL.NUM_CLASSES 127 | err_str = "Invalid input to one_hot_vector()" 128 | assert labels.ndim == 1 and labels.max() < n_classes, err_str 129 | shape = (labels.shape[0], n_classes) 130 | neg_val, pos_val = 0.0, 1.0 131 | labels_one_hot = torch.full(shape, neg_val, dtype=torch.float, device=labels.device) 132 | labels_one_hot.scatter_(1, labels.long().view(-1, 1), pos_val) 133 | return labels_one_hot 134 | 135 | @torch.no_grad() 136 | def concat_all_gather(tensor): 137 | """ 138 | Performs all_gather operation on the provided tensors. 139 | *** Warning ***: torch.distributed.all_gather has no gradient. 140 | """ 141 | tensors_gather = [torch.ones_like(tensor) 142 | for _ in range(torch.distributed.get_world_size())] 143 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 144 | 145 | output = torch.cat(tensors_gather, dim=0) 146 | return output 147 | 148 | @torch.no_grad() 149 | def _batch_shuffle_ddp(x, y): 150 | """ 151 | Batch shuffle, for making use of BatchNorm. 152 | *** Only support DistributedDataParallel (DDP) model. 
*** 153 | """ 154 | # gather from all gpus 155 | batch_size_this = x.shape[0] 156 | x_gather = concat_all_gather(x) 157 | y_gather = concat_all_gather(y) 158 | batch_size_all = x_gather.shape[0] 159 | 160 | num_gpus = batch_size_all // batch_size_this 161 | 162 | # random shuffle index 163 | idx_shuffle = torch.randperm(batch_size_all).cuda() 164 | 165 | # broadcast to all gpus 166 | torch.distributed.broadcast(idx_shuffle, src=0) 167 | 168 | # index for restoring 169 | idx_unshuffle = torch.argsort(idx_shuffle) 170 | 171 | # shuffled index for this gpu 172 | gpu_idx = torch.distributed.get_rank() 173 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 174 | 175 | return x_gather[idx_this], y_gather[idx_this] 176 | 177 | 178 | def train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch): 179 | """Performs one epoch of training.""" 180 | # Shuffle the data 181 | loader.shuffle(train_loader, cur_epoch) 182 | # Update the learning rate 183 | lr = optim.get_epoch_lr(cur_epoch) 184 | optim.set_lr(optimizer, lr) 185 | kdloss = KDLoss(cfg.TRAIN.TEMP).cuda() 186 | softceloss = SoftCrossEntropyLoss().cuda() 187 | # Enable training mode 188 | model.train() 189 | train_meter.iter_tic() 190 | for cur_iter, (inputs, labels) in enumerate(train_loader): 191 | # if (cur_iter>=train_meter.epoch_iters): break 192 | # Transfer the data to the current GPU device 193 | # import pdb; pdb.set_trace() 194 | inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True) 195 | if cfg.TRAIN.SHUFFLE_BN: 196 | inputs, labels = _batch_shuffle_ddp(inputs, labels) 197 | onehot_labels = to_one_hot_labels(labels) 198 | # apply cutmix augmentation 199 | if cfg.TRAIN.IS_CUTMIX: 200 | disable = (cur_epoch >= (cfg.OPTIM.MAX_EPOCH - cfg.TRAIN.CUTMIX.OFF_EPOCH)) 201 | inputs, onehot_labels = cutmix_batch(inputs, onehot_labels, cfg.MODEL.NUM_CLASSES, cfg.TRAIN.CUTMIX, disable) 202 | # Perform the forward pass 203 | feas, preds = model(inputs) 204 | # Compute the loss 205 | loss = softceloss(preds, onehot_labels) 206 | # random walk 207 | if cfg.TRAIN.USE_BAKE: 208 | with torch.no_grad(): 209 | kd_targets = knowledge_ensemble(feas.detach(), preds.detach(), 210 | temp_factor=cfg.TRAIN.TEMP, omega=cfg.TRAIN.OMEGA, cross_gpus=cfg.TRAIN.GLOBAL_KE) 211 | loss += kdloss(preds, kd_targets.detach())*cfg.TRAIN.LAMBDA 212 | # Perform the backward pass 213 | optimizer.zero_grad() 214 | loss.backward() 215 | # Update the parameters 216 | optimizer.step() 217 | # Compute the errors 218 | top1_err, top5_err = meters.topk_errors(preds, labels, [1, 5]) 219 | # Combine the stats across the GPUs (no reduction if 1 GPU used) 220 | loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err]) 221 | # Copy the stats from GPU to CPU (sync point) 222 | loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item() 223 | train_meter.iter_toc() 224 | # Update and log stats 225 | mb_size = inputs.size(0) * cfg.NUM_GPUS 226 | train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size) 227 | train_meter.log_iter_stats(cur_epoch, cur_iter) 228 | train_meter.iter_tic() 229 | # Log epoch stats 230 | train_meter.log_epoch_stats(cur_epoch) 231 | train_meter.reset() 232 | 233 | 234 | @torch.no_grad() 235 | def test_epoch(test_loader, model, test_meter, cur_epoch): 236 | """Evaluates the model on the test set.""" 237 | # Enable eval mode 238 | model.eval() 239 | test_meter.iter_tic() 240 | for cur_iter, (inputs, labels) in enumerate(test_loader): 241 | # Transfer the data to the current GPU device 242 | inputs, 
labels = inputs.cuda(), labels.cuda(non_blocking=True) 243 | # Compute the predictions 244 | _, preds = model(inputs) 245 | # Compute the errors 246 | top1_err, top5_err = meters.topk_errors(preds, labels, [1, 5]) 247 | # Combine the errors across the GPUs (no reduction if 1 GPU used) 248 | top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err]) 249 | # Copy the errors from GPU to CPU (sync point) 250 | top1_err, top5_err = top1_err.item(), top5_err.item() 251 | test_meter.iter_toc() 252 | # Update and log stats 253 | test_meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) 254 | test_meter.log_iter_stats(cur_epoch, cur_iter) 255 | test_meter.iter_tic() 256 | # Log epoch stats 257 | test_meter.log_epoch_stats(cur_epoch) 258 | test_meter.reset() 259 | 260 | 261 | def train_model(): 262 | """Trains the model.""" 263 | # Setup training/testing environment 264 | setup_env() 265 | # Construct the model, loss_fun, and optimizer 266 | model = setup_model() 267 | loss_fun = builders.build_loss_fun().cuda() 268 | optimizer = optim.construct_optimizer(model) 269 | # Load checkpoint or initial weights 270 | start_epoch = 0 271 | if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint(): 272 | last_checkpoint = checkpoint.get_last_checkpoint() 273 | checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer) 274 | logger.info("Loaded checkpoint from: {}".format(last_checkpoint)) 275 | start_epoch = checkpoint_epoch + 1 276 | elif cfg.TRAIN.WEIGHTS: 277 | checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model) 278 | logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS)) 279 | # Create data loaders and meters 280 | train_loader = loader.construct_train_loader() 281 | test_loader = loader.construct_test_loader() 282 | train_meter = meters.TrainMeter(len(train_loader)) 283 | test_meter = meters.TestMeter(len(test_loader)) 284 | # Compute model and loader timings 285 | # if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0: 286 | # benchmark.compute_time_full(model, loss_fun, train_loader, test_loader) 287 | # Perform the training loop 288 | logger.info("Start epoch: {}".format(start_epoch + 1)) 289 | for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH): 290 | # Train for one epoch 291 | train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch) 292 | # Compute precise BN stats 293 | if cfg.BN.USE_PRECISE_STATS: 294 | net.compute_precise_bn_stats(model, train_loader) 295 | # Save a checkpoint 296 | if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0: 297 | checkpoint_file = checkpoint.save_checkpoint(model, optimizer, cur_epoch) 298 | logger.info("Wrote checkpoint to: {}".format(checkpoint_file)) 299 | # Evaluate the model 300 | next_epoch = cur_epoch + 1 301 | if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH: 302 | test_epoch(test_loader, model, test_meter, cur_epoch) 303 | 304 | 305 | def test_model(): 306 | """Evaluates a trained model.""" 307 | # Setup training/testing environment 308 | setup_env() 309 | # Construct the model 310 | model = setup_model() 311 | # Load model weights 312 | checkpoint.load_checkpoint(cfg.TEST.WEIGHTS, model) 313 | logger.info("Loaded model weights from: {}".format(cfg.TEST.WEIGHTS)) 314 | # Create data loaders and meters 315 | test_loader = loader.construct_test_loader() 316 | test_meter = meters.TestMeter(len(test_loader)) 317 | # Evaluate the model 318 | test_epoch(test_loader, model, test_meter, 0) 319 | 320 | 321 | def time_model(): 322 | """Times model 
and data loader.""" 323 | # Setup training/testing environment 324 | setup_env() 325 | # Construct the model and loss_fun 326 | model = setup_model() 327 | loss_fun = builders.build_loss_fun().cuda() 328 | # Create data loaders 329 | train_loader = loader.construct_train_loader() 330 | test_loader = loader.construct_test_loader() 331 | # Compute model and loader timings 332 | benchmark.compute_time_full(model, loss_fun, train_loader, test_loader) 333 | -------------------------------------------------------------------------------- /imagenet/pycls/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxgeee/BAKE/07c4f668ea19311d5b50121026e73d2f035d5765/imagenet/pycls/datasets/__init__.py -------------------------------------------------------------------------------- /imagenet/pycls/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ImageNet dataset.""" 9 | 10 | import os 11 | import re 12 | 13 | import cv2 14 | from PIL import Image 15 | import numpy as np 16 | import pycls.core.logging as logging 17 | import pycls.datasets.transforms as transforms 18 | import torch.utils.data 19 | from pycls.core.config import cfg 20 | import torchvision.transforms as T 21 | 22 | 23 | logger = logging.get_logger(__name__) 24 | 25 | # Per-channel mean and SD values in BGR order 26 | _MEAN = [0.406, 0.456, 0.485] 27 | _SD = [0.225, 0.224, 0.229] 28 | 29 | # Eig vals and vecs of the cov mat 30 | _EIG_VALS = np.array([[0.2175, 0.0188, 0.0045]]) 31 | _EIG_VECS = np.array( 32 | [[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]] 33 | ) 34 | 35 | 36 | class ImageNet(torch.utils.data.Dataset): 37 | """ImageNet dataset.""" 38 | 39 | def __init__(self, data_path, split): 40 | assert os.path.exists(data_path), "Data path '{}' not found".format(data_path) 41 | splits = ["train", "val"] 42 | assert split in splits, "Split '{}' not supported for ImageNet".format(split) 43 | logger.info("Constructing ImageNet {}...".format(split)) 44 | self._data_path, self._split = data_path, split 45 | self._construct_imdb() 46 | 47 | def _construct_imdb(self): 48 | """Constructs the imdb.""" 49 | # Compile the split data path 50 | split_path = os.path.join(self._data_path, self._split) 51 | logger.info("{} data path: {}".format(self._split, split_path)) 52 | # Images are stored per class in subdirs (format: n) 53 | split_files = os.listdir(split_path) 54 | self._class_ids = sorted(f for f in split_files if re.match(r"^n[0-9]+$", f)) 55 | # Map ImageNet class ids to contiguous ids 56 | self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)} 57 | # Construct the image db 58 | self._imdb = [] 59 | for class_id in self._class_ids: 60 | cont_id = self._class_id_cont_id[class_id] 61 | im_dir = os.path.join(split_path, class_id) 62 | for im_name in os.listdir(im_dir): 63 | im_path = os.path.join(im_dir, im_name) 64 | self._imdb.append({"im_path": im_path, "class": cont_id}) 65 | logger.info("Number of images: {}".format(len(self._imdb))) 66 | logger.info("Number of classes: {}".format(len(self._class_ids))) 67 | 68 | def _prepare_im(self, im): 69 | """Prepares the image for network input.""" 70 | # Train and test setups differ 71 
| train_size = cfg.TRAIN.IM_SIZE 72 | if self._split == "train": 73 | # Scale and aspect ratio then horizontal flip 74 | im = transforms.random_sized_crop(im=im, size=train_size, area_frac=0.08) 75 | im = transforms.horizontal_flip(im=im, p=0.5, order="HWC") 76 | else: 77 | # Scale and center crop 78 | im = transforms.scale(cfg.TEST.IM_SIZE, im) 79 | im = transforms.center_crop(train_size, im) 80 | # HWC -> CHW 81 | im = im.transpose([2, 0, 1]) 82 | # [0, 255] -> [0, 1] 83 | im = im / 255.0 84 | # PCA jitter 85 | if self._split == "train": 86 | im = transforms.lighting(im, 0.1, _EIG_VALS, _EIG_VECS) 87 | # Color normalization 88 | im = transforms.color_norm(im, _MEAN, _SD) 89 | return im 90 | 91 | def __getitem__(self, index): 92 | # Load the image 93 | im = Image.open(self._imdb[index]["im_path"]).convert('RGB') 94 | if self._split == "train" and cfg.TRAIN.JITTER: 95 | im = T.ColorJitter(0.4, 0.4, 0.4)(im) 96 | im = cv2.cvtColor(np.asarray(im),cv2.COLOR_RGB2BGR) 97 | im = im.astype(np.float32, copy=False) 98 | # Prepare the image for training / testing 99 | im = self._prepare_im(im) 100 | # Retrieve the label 101 | label = self._imdb[index]["class"] 102 | return im, label 103 | 104 | def __len__(self): 105 | return len(self._imdb) 106 | -------------------------------------------------------------------------------- /imagenet/pycls/datasets/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Data loader.""" 9 | 10 | import os 11 | 12 | import torch 13 | from pycls.core.config import cfg 14 | from pycls.datasets.imagenet import ImageNet 15 | from pycls.datasets.sampler import DistributedClassSampler 16 | from torch.utils.data.distributed import DistributedSampler 17 | from torch.utils.data.sampler import RandomSampler 18 | 19 | 20 | # Supported datasets 21 | _DATASETS = {"imagenet": ImageNet} 22 | 23 | # Default data directory (/path/pycls/pycls/datasets/data) 24 | _DATA_DIR = os.path.join(os.path.dirname(__file__), "data") 25 | 26 | # Relative data paths to default data directory 27 | _PATHS = {"imagenet": "imagenet"} 28 | 29 | 30 | def _construct_loader(dataset_name, split, batch_size, shuffle, drop_last): 31 | """Constructs the data loader for the given dataset.""" 32 | err_str = "Dataset '{}' not supported".format(dataset_name) 33 | assert dataset_name in _DATASETS and dataset_name in _PATHS, err_str 34 | # Retrieve the data path for the dataset 35 | data_path = os.path.join(_DATA_DIR, _PATHS[dataset_name]) 36 | # Construct the dataset 37 | dataset = _DATASETS[dataset_name](data_path, split) 38 | # Create a sampler for multi-process training 39 | if (cfg.TRAIN.INTRA_IMGS > 0 and split == cfg.TRAIN.SPLIT): 40 | sampler = DistributedClassSampler(dataset._imdb, cfg.TRAIN.INTRA_IMGS+1) 41 | else: 42 | sampler = DistributedSampler(dataset) 43 | # Create a loader 44 | loader = torch.utils.data.DataLoader( 45 | dataset, 46 | batch_size=batch_size, 47 | shuffle=(False if sampler else shuffle), 48 | sampler=sampler, 49 | num_workers=cfg.DATA_LOADER.NUM_WORKERS, 50 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY, 51 | drop_last=drop_last, 52 | ) 53 | return loader 54 | 55 | 56 | def construct_train_loader(): 57 | """Train loader wrapper.""" 58 | return _construct_loader( 59 | dataset_name=cfg.TRAIN.DATASET, 60 | 
split=cfg.TRAIN.SPLIT, 61 | batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS), 62 | shuffle=True, 63 | drop_last=True, 64 | ) 65 | 66 | 67 | def construct_test_loader(): 68 | """Test loader wrapper.""" 69 | return _construct_loader( 70 | dataset_name=cfg.TEST.DATASET, 71 | split=cfg.TEST.SPLIT, 72 | batch_size=int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS), 73 | shuffle=False, 74 | drop_last=False, 75 | ) 76 | 77 | 78 | def shuffle(loader, cur_epoch): 79 | """Shuffles the data.""" 80 | err_str = "Sampler type '{}' not supported".format(type(loader.sampler)) 81 | assert isinstance(loader.sampler, (RandomSampler, DistributedSampler, DistributedClassSampler)), err_str 82 | # RandomSampler handles shuffling automatically 83 | if isinstance(loader.sampler, (DistributedSampler, DistributedClassSampler)): 84 | # DistributedSampler shuffles data based on epoch 85 | loader.sampler.set_epoch(cur_epoch) 86 | -------------------------------------------------------------------------------- /imagenet/pycls/datasets/sampler.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import defaultdict 3 | import numpy as np 4 | import math 5 | import copy 6 | 7 | import torch 8 | from torch.utils.data.sampler import Sampler 9 | from torch.utils.data.distributed import DistributedSampler 10 | import torch.distributed as dist 11 | 12 | __all__ = ["DistributedClassSampler"] 13 | 14 | 15 | class DistributedClassSampler(Sampler): 16 | def __init__(self, dataset, num_instances, seed=0): 17 | if not dist.is_available(): 18 | raise RuntimeError("Requires distributed package to be available") 19 | num_replicas = dist.get_world_size() 20 | rank = dist.get_rank() 21 | self.dataset = dataset 22 | self.num_replicas = num_replicas 23 | self.rank = rank 24 | self.epoch = 0 25 | self.seed = seed 26 | self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | self.num_instances = num_instances 30 | self.pid_index = defaultdict(list) 31 | for idx, item in enumerate(self.dataset): 32 | self.pid_index[item['class']].append(idx) 33 | 34 | 35 | def __len__(self): 36 | return self.num_samples * self.num_instances 37 | 38 | def __iter__(self): 39 | # deterministically shuffle based on epoch and seed 40 | g = torch.Generator() 41 | g.manual_seed(self.seed + self.epoch) 42 | indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore 43 | 44 | # add extra samples to make it evenly divisible 45 | indices += indices[:(self.total_size - len(indices))] 46 | assert len(indices) == self.total_size 47 | 48 | # subsample 49 | indices = indices[self.rank:self.total_size:self.num_replicas] 50 | assert len(indices) == self.num_samples 51 | 52 | ret = [] 53 | for i in indices: 54 | ret.append(i) 55 | select_indexes = [j for j in self.pid_index[self.dataset[i]['class']] if j != i] 56 | if (not select_indexes): 57 | continue 58 | if (len(select_indexes) >= self.num_instances - 1): 59 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=False) 60 | else: 61 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=True) 62 | ret.extend(ind_indexes) 63 | return iter(ret) 64 | 65 | def set_epoch(self, epoch): 66 | self.epoch = epoch 67 | -------------------------------------------------------------------------------- /imagenet/pycls/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Image transformations.""" 9 | 10 | import math 11 | 12 | import cv2 13 | import numpy as np 14 | import torch 15 | 16 | def color_norm(im, mean, std): 17 | """Performs per-channel normalization (CHW format).""" 18 | for i in range(im.shape[0]): 19 | im[i] = im[i] - mean[i] 20 | im[i] = im[i] / std[i] 21 | return im 22 | 23 | 24 | def zero_pad(im, pad_size): 25 | """Performs zero padding (CHW format).""" 26 | pad_width = ((0, 0), (pad_size, pad_size), (pad_size, pad_size)) 27 | return np.pad(im, pad_width, mode="constant") 28 | 29 | 30 | def horizontal_flip(im, p, order="CHW"): 31 | """Performs horizontal flip (CHW or HWC format).""" 32 | assert order in ["CHW", "HWC"] 33 | if np.random.uniform() < p: 34 | if order == "CHW": 35 | im = im[:, :, ::-1] 36 | else: 37 | im = im[:, ::-1, :] 38 | return im 39 | 40 | 41 | def random_crop(im, size, pad_size=0): 42 | """Performs random crop (CHW format).""" 43 | if pad_size > 0: 44 | im = zero_pad(im=im, pad_size=pad_size) 45 | h, w = im.shape[1:] 46 | y = np.random.randint(0, h - size) 47 | x = np.random.randint(0, w - size) 48 | im_crop = im[:, y : (y + size), x : (x + size)] 49 | assert im_crop.shape[1:] == (size, size) 50 | return im_crop 51 | 52 | 53 | def scale(size, im): 54 | """Performs scaling (HWC format).""" 55 | h, w = im.shape[:2] 56 | if (w <= h and w == size) or (h <= w and h == size): 57 | return im 58 | h_new, w_new = size, size 59 | if w < h: 60 | h_new = int(math.floor((float(h) / w) * size)) 61 | else: 62 | w_new = int(math.floor((float(w) / h) * size)) 63 | im = cv2.resize(im, (w_new, h_new), interpolation=cv2.INTER_LINEAR) 64 | return im.astype(np.float32) 65 | 66 | 67 | def center_crop(size, im): 68 | """Performs center cropping (HWC format).""" 69 | h, w = im.shape[:2] 70 | y = int(math.ceil((h - size) / 2)) 71 | x = int(math.ceil((w - size) / 2)) 72 | im_crop = im[y : (y + size), x : (x + size), :] 73 | assert im_crop.shape[:2] == (size, size) 74 | return im_crop 75 | 76 | 77 | def random_sized_crop(im, size, area_frac=0.08, max_iter=10): 78 | """Performs Inception-style cropping (HWC format).""" 79 | h, w = im.shape[:2] 80 | area = h * w 81 | for _ in range(max_iter): 82 | target_area = np.random.uniform(area_frac, 1.0) * area 83 | aspect_ratio = np.random.uniform(3.0 / 4.0, 4.0 / 3.0) 84 | w_crop = int(round(math.sqrt(float(target_area) * aspect_ratio))) 85 | h_crop = int(round(math.sqrt(float(target_area) / aspect_ratio))) 86 | if np.random.uniform() < 0.5: 87 | w_crop, h_crop = h_crop, w_crop 88 | if h_crop <= h and w_crop <= w: 89 | y = 0 if h_crop == h else np.random.randint(0, h - h_crop) 90 | x = 0 if w_crop == w else np.random.randint(0, w - w_crop) 91 | im_crop = im[y : (y + h_crop), x : (x + w_crop), :] 92 | assert im_crop.shape[:2] == (h_crop, w_crop) 93 | im_crop = cv2.resize(im_crop, (size, size), interpolation=cv2.INTER_LINEAR) 94 | return im_crop.astype(np.float32) 95 | return center_crop(size, scale(size, im)) 96 | 97 | 98 | def lighting(im, alpha_std, eig_val, eig_vec): 99 | """Performs AlexNet-style PCA jitter (CHW format).""" 100 | if alpha_std == 0: 101 | return im 102 | alpha = np.random.normal(0, alpha_std, size=(1, 3)) 103 | alpha = np.repeat(alpha, 3, axis=0) 104 | eig_val = np.repeat(eig_val, 3, axis=0) 105 | rgb = np.sum(eig_vec * alpha * eig_val, axis=1) 106 | for i in 
range(im.shape[0]): 107 | im[i] = im[i] + rgb[2 - i] 108 | return im 109 | 110 | # for cutmix 111 | 112 | def smooth_target(target, smoothing=0.1, num_classes=1000): 113 | target *= (1. - smoothing) 114 | target += (smoothing / num_classes) 115 | return target 116 | 117 | 118 | def mix_target(target_a, target_b, num_classes, lam=1., smoothing=0.0): 119 | y1 = smooth_target(target_a, smoothing, num_classes) 120 | y2 = smooth_target(target_b, smoothing, num_classes) 121 | return lam * y1 + (1. - lam) * y2 122 | 123 | 124 | def rand_bbox(size, lam): 125 | W = size[2] 126 | H = size[3] 127 | cut_rat = np.sqrt(1. - lam) 128 | cut_w = int(W * cut_rat) 129 | cut_h = int(H * cut_rat) 130 | 131 | # uniform 132 | cx = np.random.randint(W) 133 | cy = np.random.randint(H) 134 | 135 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 136 | bby1 = np.clip(cy - cut_h // 2, 0, H) 137 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 138 | bby2 = np.clip(cy + cut_h // 2, 0, H) 139 | 140 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (size[2] * size[3])) 141 | 142 | return bbx1, bby1, bbx2, bby2, lam 143 | 144 | 145 | def cutmix_batch(image, target, num_classes, mix_args, disable): 146 | if not disable and np.random.rand(1) < mix_args.PROB: 147 | lam = np.random.beta(mix_args.BETA, mix_args.BETA) 148 | else: 149 | target = smooth_target( 150 | target=target, 151 | smoothing=mix_args.SMOOTHING, 152 | num_classes=num_classes) 153 | return image, target 154 | 155 | rand_index = torch.randperm(image.size()[0], device=image.device) 156 | target_a = target 157 | target_b = target[rand_index] 158 | bbx1, bby1, bbx2, bby2, lam = rand_bbox(image.size(), lam) 159 | image[:, :, bbx1:bbx2, bby1:bby2] = image[rand_index, :, bbx1:bbx2, 160 | bby1:bby2] 161 | 162 | target = mix_target(target_a=target_a, 163 | target_b=target_b, 164 | num_classes=num_classes, 165 | lam=lam, 166 | smoothing=mix_args.SMOOTHING) 167 | return image, target 168 | -------------------------------------------------------------------------------- /imagenet/pycls/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yxgeee/BAKE/07c4f668ea19311d5b50121026e73d2f035d5765/imagenet/pycls/models/__init__.py -------------------------------------------------------------------------------- /imagenet/pycls/models/effnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """EfficientNet models.""" 9 | 10 | import pycls.core.net as net 11 | import torch 12 | import torch.nn as nn 13 | from pycls.core.config import cfg 14 | 15 | 16 | class EffHead(nn.Module): 17 | """EfficientNet head: 1x1, BN, Swish, AvgPool, Dropout, FC.""" 18 | 19 | def __init__(self, w_in, w_out, nc): 20 | super(EffHead, self).__init__() 21 | self.conv = nn.Conv2d(w_in, w_out, 1, stride=1, padding=0, bias=False) 22 | self.conv_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 23 | self.conv_swish = Swish() 24 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 25 | if cfg.EN.DROPOUT_RATIO > 0.0: 26 | self.dropout = nn.Dropout(p=cfg.EN.DROPOUT_RATIO) 27 | self.fc = nn.Linear(w_out, nc, bias=True) 28 | 29 | def forward(self, x): 30 | x = self.conv_swish(self.conv_bn(self.conv(x))) 31 | x = self.avg_pool(x) 32 | x = x.view(x.size(0), -1) 33 | x = self.dropout(x) if hasattr(self, "dropout") else x 34 | out = self.fc(x) 35 | return x, out 36 | 37 | @staticmethod 38 | def complexity(cx, w_in, w_out, nc): 39 | cx = net.complexity_conv2d(cx, w_in, w_out, 1, 1, 0) 40 | cx = net.complexity_batchnorm2d(cx, w_out) 41 | cx["h"], cx["w"] = 1, 1 42 | cx = net.complexity_conv2d(cx, w_out, nc, 1, 1, 0, bias=True) 43 | return cx 44 | 45 | 46 | class Swish(nn.Module): 47 | """Swish activation function: x * sigmoid(x).""" 48 | 49 | def __init__(self): 50 | super(Swish, self).__init__() 51 | 52 | def forward(self, x): 53 | return x * torch.sigmoid(x) 54 | 55 | 56 | class SE(nn.Module): 57 | """Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid.""" 58 | 59 | def __init__(self, w_in, w_se): 60 | super(SE, self).__init__() 61 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 62 | self.f_ex = nn.Sequential( 63 | nn.Conv2d(w_in, w_se, 1, bias=True), 64 | Swish(), 65 | nn.Conv2d(w_se, w_in, 1, bias=True), 66 | nn.Sigmoid(), 67 | ) 68 | 69 | def forward(self, x): 70 | return x * self.f_ex(self.avg_pool(x)) 71 | 72 | @staticmethod 73 | def complexity(cx, w_in, w_se): 74 | h, w = cx["h"], cx["w"] 75 | cx["h"], cx["w"] = 1, 1 76 | cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True) 77 | cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True) 78 | cx["h"], cx["w"] = h, w 79 | return cx 80 | 81 | 82 | class MBConv(nn.Module): 83 | """Mobile inverted bottleneck block w/ SE (MBConv).""" 84 | 85 | def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out): 86 | # expansion, 3x3 dwise, BN, Swish, SE, 1x1, BN, skip_connection 87 | super(MBConv, self).__init__() 88 | self.exp = None 89 | w_exp = int(w_in * exp_r) 90 | if w_exp != w_in: 91 | self.exp = nn.Conv2d(w_in, w_exp, 1, stride=1, padding=0, bias=False) 92 | self.exp_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 93 | self.exp_swish = Swish() 94 | dwise_args = {"groups": w_exp, "padding": (kernel - 1) // 2, "bias": False} 95 | self.dwise = nn.Conv2d(w_exp, w_exp, kernel, stride=stride, **dwise_args) 96 | self.dwise_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 97 | self.dwise_swish = Swish() 98 | self.se = SE(w_exp, int(w_in * se_r)) 99 | self.lin_proj = nn.Conv2d(w_exp, w_out, 1, stride=1, padding=0, bias=False) 100 | self.lin_proj_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 101 | # Skip connection if in and out shapes are the same (MN-V2 style) 102 | self.has_skip = stride == 1 and w_in == w_out 103 | 104 | def forward(self, x): 105 | f_x = x 106 | if self.exp: 107 | f_x = self.exp_swish(self.exp_bn(self.exp(f_x))) 108 | f_x = 
self.dwise_swish(self.dwise_bn(self.dwise(f_x))) 109 | f_x = self.se(f_x) 110 | f_x = self.lin_proj_bn(self.lin_proj(f_x)) 111 | if self.has_skip: 112 | if self.training and cfg.EN.DC_RATIO > 0.0: 113 | f_x = net.drop_connect(f_x, cfg.EN.DC_RATIO) 114 | f_x = x + f_x 115 | return f_x 116 | 117 | @staticmethod 118 | def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out): 119 | w_exp = int(w_in * exp_r) 120 | if w_exp != w_in: 121 | cx = net.complexity_conv2d(cx, w_in, w_exp, 1, 1, 0) 122 | cx = net.complexity_batchnorm2d(cx, w_exp) 123 | padding = (kernel - 1) // 2 124 | cx = net.complexity_conv2d(cx, w_exp, w_exp, kernel, stride, padding, w_exp) 125 | cx = net.complexity_batchnorm2d(cx, w_exp) 126 | cx = SE.complexity(cx, w_exp, int(w_in * se_r)) 127 | cx = net.complexity_conv2d(cx, w_exp, w_out, 1, 1, 0) 128 | cx = net.complexity_batchnorm2d(cx, w_out) 129 | return cx 130 | 131 | 132 | class EffStage(nn.Module): 133 | """EfficientNet stage.""" 134 | 135 | def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d): 136 | super(EffStage, self).__init__() 137 | for i in range(d): 138 | b_stride = stride if i == 0 else 1 139 | b_w_in = w_in if i == 0 else w_out 140 | name = "b{}".format(i + 1) 141 | self.add_module(name, MBConv(b_w_in, exp_r, kernel, b_stride, se_r, w_out)) 142 | 143 | def forward(self, x): 144 | for block in self.children(): 145 | x = block(x) 146 | return x 147 | 148 | @staticmethod 149 | def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out, d): 150 | for i in range(d): 151 | b_stride = stride if i == 0 else 1 152 | b_w_in = w_in if i == 0 else w_out 153 | cx = MBConv.complexity(cx, b_w_in, exp_r, kernel, b_stride, se_r, w_out) 154 | return cx 155 | 156 | 157 | class StemIN(nn.Module): 158 | """EfficientNet stem for ImageNet: 3x3, BN, Swish.""" 159 | 160 | def __init__(self, w_in, w_out): 161 | super(StemIN, self).__init__() 162 | self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False) 163 | self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 164 | self.swish = Swish() 165 | 166 | def forward(self, x): 167 | for layer in self.children(): 168 | x = layer(x) 169 | return x 170 | 171 | @staticmethod 172 | def complexity(cx, w_in, w_out): 173 | cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1) 174 | cx = net.complexity_batchnorm2d(cx, w_out) 175 | return cx 176 | 177 | 178 | class EffNet(nn.Module): 179 | """EfficientNet model.""" 180 | 181 | @staticmethod 182 | def get_args(): 183 | return { 184 | "stem_w": cfg.EN.STEM_W, 185 | "ds": cfg.EN.DEPTHS, 186 | "ws": cfg.EN.WIDTHS, 187 | "exp_rs": cfg.EN.EXP_RATIOS, 188 | "se_r": cfg.EN.SE_R, 189 | "ss": cfg.EN.STRIDES, 190 | "ks": cfg.EN.KERNELS, 191 | "head_w": cfg.EN.HEAD_W, 192 | "nc": cfg.MODEL.NUM_CLASSES, 193 | } 194 | 195 | def __init__(self): 196 | err_str = "Dataset {} is not supported" 197 | assert cfg.TRAIN.DATASET in ["imagenet"], err_str.format(cfg.TRAIN.DATASET) 198 | assert cfg.TEST.DATASET in ["imagenet"], err_str.format(cfg.TEST.DATASET) 199 | super(EffNet, self).__init__() 200 | self._construct(**EffNet.get_args()) 201 | self.apply(net.init_weights) 202 | 203 | def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc): 204 | stage_params = list(zip(ds, ws, exp_rs, ss, ks)) 205 | self.stem = StemIN(3, stem_w) 206 | prev_w = stem_w 207 | for i, (d, w, exp_r, stride, kernel) in enumerate(stage_params): 208 | name = "s{}".format(i + 1) 209 | self.add_module(name, EffStage(prev_w, exp_r, kernel, stride, se_r, w, d)) 210 | prev_w = w 211 | self.head = 
EffHead(prev_w, head_w, nc) 212 | 213 | def forward(self, x): 214 | # for module in self.children(): 215 | # x = module(x) 216 | # return x 217 | for name, module in self.named_children(): 218 | if (name=='head'): 219 | break 220 | x = module(x) 221 | x, out = self.head(x) 222 | return x, out 223 | 224 | @staticmethod 225 | def complexity(cx): 226 | """Computes model complexity. If you alter the model, make sure to update.""" 227 | return EffNet._complexity(cx, **EffNet.get_args()) 228 | 229 | @staticmethod 230 | def _complexity(cx, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc): 231 | stage_params = list(zip(ds, ws, exp_rs, ss, ks)) 232 | cx = StemIN.complexity(cx, 3, stem_w) 233 | prev_w = stem_w 234 | for d, w, exp_r, stride, kernel in stage_params: 235 | cx = EffStage.complexity(cx, prev_w, exp_r, kernel, stride, se_r, w, d) 236 | prev_w = w 237 | cx = EffHead.complexity(cx, prev_w, head_w, nc) 238 | return cx 239 | -------------------------------------------------------------------------------- /imagenet/pycls/models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | 5 | def conv_bn(inp, oup, stride): 6 | return nn.Sequential( 7 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 8 | nn.BatchNorm2d(oup), 9 | nn.ReLU6(inplace=True) 10 | ) 11 | 12 | 13 | def conv_1x1_bn(inp, oup): 14 | return nn.Sequential( 15 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 16 | nn.BatchNorm2d(oup), 17 | nn.ReLU6(inplace=True) 18 | ) 19 | 20 | 21 | def make_divisible(x, divisible_by=8): 22 | import numpy as np 23 | return int(np.ceil(x * 1. / divisible_by) * divisible_by) 24 | 25 | 26 | class InvertedResidual(nn.Module): 27 | def __init__(self, inp, oup, stride, expand_ratio): 28 | super(InvertedResidual, self).__init__() 29 | self.stride = stride 30 | assert stride in [1, 2] 31 | 32 | hidden_dim = int(inp * expand_ratio) 33 | self.use_res_connect = self.stride == 1 and inp == oup 34 | 35 | if expand_ratio == 1: 36 | self.conv = nn.Sequential( 37 | # dw 38 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 39 | nn.BatchNorm2d(hidden_dim), 40 | nn.ReLU6(inplace=True), 41 | # pw-linear 42 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 43 | nn.BatchNorm2d(oup), 44 | ) 45 | else: 46 | self.conv = nn.Sequential( 47 | # pw 48 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 49 | nn.BatchNorm2d(hidden_dim), 50 | nn.ReLU6(inplace=True), 51 | # dw 52 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 53 | nn.BatchNorm2d(hidden_dim), 54 | nn.ReLU6(inplace=True), 55 | # pw-linear 56 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 57 | nn.BatchNorm2d(oup), 58 | ) 59 | 60 | def forward(self, x): 61 | if self.use_res_connect: 62 | return x + self.conv(x) 63 | else: 64 | return self.conv(x) 65 | 66 | 67 | class MobileNetV2(nn.Module): 68 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 69 | super(MobileNetV2, self).__init__() 70 | block = InvertedResidual 71 | input_channel = 32 72 | last_channel = 1280 73 | interverted_residual_setting = [ 74 | # t, c, n, s 75 | [1, 16, 1, 1], 76 | [6, 24, 2, 2], 77 | [6, 32, 3, 2], 78 | [6, 64, 4, 2], 79 | [6, 96, 3, 1], 80 | [6, 160, 3, 2], 81 | [6, 320, 1, 1], 82 | ] 83 | 84 | # building first layer 85 | assert input_size % 32 == 0 86 | # input_channel = make_divisible(input_channel * width_mult) # first channel is always 32! 
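        # Worked example (annotation added for clarity; values are illustrative):
        # make_divisible rounds a channel count up to the nearest multiple of 8,
        # e.g. make_divisible(24 * 1.1) = int(np.ceil(26.4 / 8) * 8) = 32 and
        # make_divisible(1280 * 1.4) = 1792, so scaled widths stay hardware-friendly.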
87 | self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel 88 | self.features = [conv_bn(3, input_channel, 2)] 89 | # building inverted residual blocks 90 | for t, c, n, s in interverted_residual_setting: 91 | output_channel = make_divisible(c * width_mult) if t > 1 else c 92 | for i in range(n): 93 | if i == 0: 94 | self.features.append(block(input_channel, output_channel, s, expand_ratio=t)) 95 | else: 96 | self.features.append(block(input_channel, output_channel, 1, expand_ratio=t)) 97 | input_channel = output_channel 98 | # building last several layers 99 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 100 | # make it nn.Sequential 101 | self.features = nn.Sequential(*self.features) 102 | 103 | # building classifier 104 | self.classifier = nn.Linear(self.last_channel, n_class) 105 | 106 | self._initialize_weights() 107 | 108 | def forward(self, x): 109 | x = self.features(x) 110 | x = x.mean(3).mean(2) 111 | out = self.classifier(x) 112 | return x, out 113 | 114 | def _initialize_weights(self): 115 | for m in self.modules(): 116 | if isinstance(m, nn.Conv2d): 117 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 118 | m.weight.data.normal_(0, math.sqrt(2. / n)) 119 | if m.bias is not None: 120 | m.bias.data.zero_() 121 | elif isinstance(m, nn.BatchNorm2d): 122 | m.weight.data.fill_(1) 123 | m.bias.data.zero_() 124 | elif isinstance(m, nn.Linear): 125 | n = m.weight.size(1) 126 | m.weight.data.normal_(0, 0.01) 127 | m.bias.data.zero_() 128 | 129 | 130 | # def mobilenet_v2(pretrained=True): 131 | # model = MobileNetV2(width_mult=1) 132 | # return model 133 | -------------------------------------------------------------------------------- /imagenet/pycls/models/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """ResNe(X)t models.""" 9 | 10 | import pycls.core.net as net 11 | import torch.nn as nn 12 | from pycls.core.config import cfg 13 | 14 | 15 | # Stage depths for ImageNet models 16 | _IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)} 17 | 18 | 19 | def get_trans_fun(name): 20 | """Retrieves the transformation function by name.""" 21 | trans_funs = { 22 | "basic_transform": BasicTransform, 23 | "bottleneck_transform": BottleneckTransform, 24 | } 25 | err_str = "Transformation function '{}' not supported" 26 | assert name in trans_funs.keys(), err_str.format(name) 27 | return trans_funs[name] 28 | 29 | 30 | class ResHead(nn.Module): 31 | """ResNet head: AvgPool, 1x1.""" 32 | 33 | def __init__(self, w_in, nc): 34 | super(ResHead, self).__init__() 35 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 36 | self.fc = nn.Linear(w_in, nc, bias=True) 37 | 38 | def forward(self, x): 39 | x = self.avg_pool(x) 40 | x = x.view(x.size(0), -1) 41 | x = self.fc(x) 42 | return x 43 | 44 | @staticmethod 45 | def complexity(cx, w_in, nc): 46 | cx["h"], cx["w"] = 1, 1 47 | cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True) 48 | return cx 49 | 50 | 51 | class BasicTransform(nn.Module): 52 | """Basic transformation: 3x3, BN, ReLU, 3x3, BN.""" 53 | 54 | def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1): 55 | err_str = "Basic transform does not support w_b and num_gs options" 56 | assert w_b is None and num_gs == 1, err_str 57 | super(BasicTransform, self).__init__() 58 | self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False) 59 | self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 60 | self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) 61 | self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False) 62 | self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 63 | self.b_bn.final_bn = True 64 | 65 | def forward(self, x): 66 | for layer in self.children(): 67 | x = layer(x) 68 | return x 69 | 70 | @staticmethod 71 | def complexity(cx, w_in, w_out, stride, w_b=None, num_gs=1): 72 | err_str = "Basic transform does not support w_b and num_gs options" 73 | assert w_b is None and num_gs == 1, err_str 74 | cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1) 75 | cx = net.complexity_batchnorm2d(cx, w_out) 76 | cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1) 77 | cx = net.complexity_batchnorm2d(cx, w_out) 78 | return cx 79 | 80 | 81 | class BottleneckTransform(nn.Module): 82 | """Bottleneck transformation: 1x1, BN, ReLU, 3x3, BN, ReLU, 1x1, BN.""" 83 | 84 | def __init__(self, w_in, w_out, stride, w_b, num_gs): 85 | super(BottleneckTransform, self).__init__() 86 | # MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3 87 | (s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride) 88 | self.a = nn.Conv2d(w_in, w_b, 1, stride=s1, padding=0, bias=False) 89 | self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 90 | self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) 91 | self.b = nn.Conv2d(w_b, w_b, 3, stride=s3, padding=1, groups=num_gs, bias=False) 92 | self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 93 | self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE) 94 | self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False) 95 | self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 96 | self.c_bn.final_bn = True 97 | 98 | def forward(self, x): 99 | for layer in self.children(): 100 | x = layer(x) 101 | return x 102 | 103 | @staticmethod 
104 | def complexity(cx, w_in, w_out, stride, w_b, num_gs): 105 | (s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride) 106 | cx = net.complexity_conv2d(cx, w_in, w_b, 1, s1, 0) 107 | cx = net.complexity_batchnorm2d(cx, w_b) 108 | cx = net.complexity_conv2d(cx, w_b, w_b, 3, s3, 1, num_gs) 109 | cx = net.complexity_batchnorm2d(cx, w_b) 110 | cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0) 111 | cx = net.complexity_batchnorm2d(cx, w_out) 112 | return cx 113 | 114 | 115 | class ResBlock(nn.Module): 116 | """Residual block: x + F(x).""" 117 | 118 | def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1): 119 | super(ResBlock, self).__init__() 120 | # Use skip connection with projection if shape changes 121 | self.proj_block = (w_in != w_out) or (stride != 1) 122 | if self.proj_block: 123 | self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False) 124 | self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 125 | self.f = trans_fun(w_in, w_out, stride, w_b, num_gs) 126 | self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) 127 | 128 | def forward(self, x): 129 | if self.proj_block: 130 | x = self.bn(self.proj(x)) + self.f(x) 131 | else: 132 | x = x + self.f(x) 133 | x = self.relu(x) 134 | return x 135 | 136 | @staticmethod 137 | def complexity(cx, w_in, w_out, stride, trans_fun, w_b, num_gs): 138 | proj_block = (w_in != w_out) or (stride != 1) 139 | if proj_block: 140 | h, w = cx["h"], cx["w"] 141 | cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0) 142 | cx = net.complexity_batchnorm2d(cx, w_out) 143 | cx["h"], cx["w"] = h, w # parallel branch 144 | cx = trans_fun.complexity(cx, w_in, w_out, stride, w_b, num_gs) 145 | return cx 146 | 147 | 148 | class ResStage(nn.Module): 149 | """Stage of ResNet.""" 150 | 151 | def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1): 152 | super(ResStage, self).__init__() 153 | for i in range(d): 154 | b_stride = stride if i == 0 else 1 155 | b_w_in = w_in if i == 0 else w_out 156 | trans_fun = get_trans_fun(cfg.RESNET.TRANS_FUN) 157 | res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, num_gs) 158 | self.add_module("b{}".format(i + 1), res_block) 159 | 160 | def forward(self, x): 161 | for block in self.children(): 162 | x = block(x) 163 | return x 164 | 165 | @staticmethod 166 | def complexity(cx, w_in, w_out, stride, d, w_b=None, num_gs=1): 167 | for i in range(d): 168 | b_stride = stride if i == 0 else 1 169 | b_w_in = w_in if i == 0 else w_out 170 | trans_f = get_trans_fun(cfg.RESNET.TRANS_FUN) 171 | cx = ResBlock.complexity(cx, b_w_in, w_out, b_stride, trans_f, w_b, num_gs) 172 | return cx 173 | 174 | 175 | class ResStemCifar(nn.Module): 176 | """ResNet stem for CIFAR: 3x3, BN, ReLU.""" 177 | 178 | def __init__(self, w_in, w_out): 179 | super(ResStemCifar, self).__init__() 180 | self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False) 181 | self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 182 | self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) 183 | 184 | def forward(self, x): 185 | for layer in self.children(): 186 | x = layer(x) 187 | return x 188 | 189 | @staticmethod 190 | def complexity(cx, w_in, w_out): 191 | cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1) 192 | cx = net.complexity_batchnorm2d(cx, w_out) 193 | return cx 194 | 195 | 196 | class ResStemIN(nn.Module): 197 | """ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool.""" 198 | 199 | def __init__(self, w_in, w_out): 200 | super(ResStemIN, self).__init__() 201 | self.conv = nn.Conv2d(w_in, 
w_out, 7, stride=2, padding=3, bias=False) 202 | self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM) 203 | self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE) 204 | self.pool = nn.MaxPool2d(3, stride=2, padding=1) 205 | 206 | def forward(self, x): 207 | for layer in self.children(): 208 | x = layer(x) 209 | return x 210 | 211 | @staticmethod 212 | def complexity(cx, w_in, w_out): 213 | cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3) 214 | cx = net.complexity_batchnorm2d(cx, w_out) 215 | cx = net.complexity_maxpool2d(cx, 3, 2, 1) 216 | return cx 217 | 218 | 219 | class ResNet(nn.Module): 220 | """ResNet model.""" 221 | 222 | def __init__(self): 223 | datasets = ["cifar10", "imagenet"] 224 | err_str = "Dataset {} is not supported" 225 | assert cfg.TRAIN.DATASET in datasets, err_str.format(cfg.TRAIN.DATASET) 226 | assert cfg.TEST.DATASET in datasets, err_str.format(cfg.TEST.DATASET) 227 | super(ResNet, self).__init__() 228 | if "cifar" in cfg.TRAIN.DATASET: 229 | self._construct_cifar() 230 | else: 231 | self._construct_imagenet() 232 | self.apply(net.init_weights) 233 | 234 | def _construct_cifar(self): 235 | err_str = "Model depth should be of the format 6n + 2 for cifar" 236 | assert (cfg.MODEL.DEPTH - 2) % 6 == 0, err_str 237 | d = int((cfg.MODEL.DEPTH - 2) / 6) 238 | self.stem = ResStemCifar(3, 16) 239 | self.s1 = ResStage(16, 16, stride=1, d=d) 240 | self.s2 = ResStage(16, 32, stride=2, d=d) 241 | self.s3 = ResStage(32, 64, stride=2, d=d) 242 | self.head = ResHead(64, nc=cfg.MODEL.NUM_CLASSES) 243 | 244 | def _construct_imagenet(self): 245 | g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP 246 | (d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH] 247 | w_b = gw * g 248 | self.stem = ResStemIN(3, 64) 249 | self.s1 = ResStage(64, 256, stride=1, d=d1, w_b=w_b, num_gs=g) 250 | self.s2 = ResStage(256, 512, stride=2, d=d2, w_b=w_b * 2, num_gs=g) 251 | self.s3 = ResStage(512, 1024, stride=2, d=d3, w_b=w_b * 4, num_gs=g) 252 | self.s4 = ResStage(1024, 2048, stride=2, d=d4, w_b=w_b * 8, num_gs=g) 253 | self.head = ResHead(2048, nc=cfg.MODEL.NUM_CLASSES) 254 | 255 | def forward(self, x): 256 | for module in self.children(): 257 | x = module(x) 258 | return x 259 | 260 | @staticmethod 261 | def complexity(cx): 262 | """Computes model complexity. 
If you alter the model, make sure to update.""" 263 | if "cifar" in cfg.TRAIN.DATASET: 264 | d = int((cfg.MODEL.DEPTH - 2) / 6) 265 | cx = ResStemCifar.complexity(cx, 3, 16) 266 | cx = ResStage.complexity(cx, 16, 16, stride=1, d=d) 267 | cx = ResStage.complexity(cx, 16, 32, stride=2, d=d) 268 | cx = ResStage.complexity(cx, 32, 64, stride=2, d=d) 269 | cx = ResHead.complexity(cx, 64, nc=cfg.MODEL.NUM_CLASSES) 270 | else: 271 | g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP 272 | (d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH] 273 | w_b = gw * g 274 | cx = ResStemIN.complexity(cx, 3, 64) 275 | cx = ResStage.complexity(cx, 64, 256, 1, d=d1, w_b=w_b, num_gs=g) 276 | cx = ResStage.complexity(cx, 256, 512, 2, d=d2, w_b=w_b * 2, num_gs=g) 277 | cx = ResStage.complexity(cx, 512, 1024, 2, d=d3, w_b=w_b * 4, num_gs=g) 278 | cx = ResStage.complexity(cx, 1024, 2048, 2, d=d4, w_b=w_b * 8, num_gs=g) 279 | cx = ResHead.complexity(cx, 2048, nc=cfg.MODEL.NUM_CLASSES) 280 | return cx 281 | -------------------------------------------------------------------------------- /imagenet/requirements.txt: -------------------------------------------------------------------------------- 1 | black==19.3b0 2 | flake8 3 | isort 4 | numpy 5 | opencv-python 6 | parameterized 7 | setuptools 8 | simplejson 9 | yacs 10 | -------------------------------------------------------------------------------- /imagenet/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Setup pycls.""" 9 | 10 | from setuptools import setup 11 | 12 | 13 | setup(name="pycls", packages=["pycls"]) 14 | -------------------------------------------------------------------------------- /imagenet/tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PYTHON=${PYTHON:-"python"} 6 | 7 | CONFIG=$1 8 | CKPT=$2 9 | PY_ARGS=${@:3} 10 | 11 | GPUS=${GPUS:-8} 12 | 13 | while true # find unused tcp port 14 | do 15 | PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 )) 16 | status="$(nc -z 127.0.0.1 $PORT < /dev/null &>/dev/null; echo $?)" 17 | if [ "${status}" != "0" ]; then 18 | break; 19 | fi 20 | done 21 | 22 | $PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT --use_env \ 23 | tools/test_net.py --cfg ${CONFIG} \ 24 | TEST.WEIGHTS ${CKPT} \ 25 | LAUNCHER pytorch \ 26 | PORT ${PORT} \ 27 | ${PY_ARGS} 28 | -------------------------------------------------------------------------------- /imagenet/tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PYTHON=${PYTHON:-"python"} 6 | 7 | WORK_DIR=$1 8 | CONFIG=$2 9 | PY_ARGS=${@:3} 10 | 11 | GPUS=${GPUS:-8} 12 | 13 | while true # find unused tcp port 14 | do 15 | PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 )) 16 | status="$(nc -z 127.0.0.1 $PORT < /dev/null &>/dev/null; echo $?)" 17 | if [ "${status}" != "0" ]; then 18 | break; 19 | fi 20 | done 21 | 22 | mkdir -p $WORK_DIR 23 | 24 | $PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT --use_env \ 25 | tools/train_net.py --cfg ${CONFIG} \ 26 | LAUNCHER pytorch OUT_DIR ${WORK_DIR} PORT ${PORT} ${PY_ARGS} | tee ${WORK_DIR}/log.txt 27 | 
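# Example invocation (paths are illustrative; any trailing KEY VALUE pairs are
# forwarded to train_net.py as config overrides via PY_ARGS):
#   GPUS=8 bash tools/dist_train.sh work_dirs/bake_r50 configs/resnet/BAKE-R-50-1x64d_dds_8gpu.yaml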
-------------------------------------------------------------------------------- /imagenet/tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | CONFIG=$2 7 | CKPT=$3 8 | PY_ARGS=${@:4} 9 | 10 | GPUS=${GPUS:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | SRUN_ARGS=${SRUN_ARGS:-""} 13 | 14 | while true # find unused tcp port 15 | do 16 | PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 )) 17 | status="$(nc -z 127.0.0.1 $PORT < /dev/null &>/dev/null; echo $?)" 18 | if [ "${status}" != "0" ]; then 19 | break; 20 | fi 21 | done 22 | 23 | srun --mpi=pmi2 -p ${PARTITION} \ 24 | --job-name=test \ 25 | --gres=gpu:${GPUS} \ 26 | --ntasks=${GPUS} \ 27 | --ntasks-per-node=${GPUS} \ 28 | --cpus-per-task=${CPUS_PER_TASK} \ 29 | --kill-on-bad-exit=1 \ 30 | ${SRUN_ARGS} \ 31 | python -u tools/test_net.py --cfg ${CONFIG} \ 32 | TEST.WEIGHTS ${CKPT} \ 33 | LAUNCHER slurm \ 34 | PORT ${PORT} \ 35 | ${PY_ARGS} 36 | -------------------------------------------------------------------------------- /imagenet/tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | WORK_DIR=$3 8 | CONFIG=$4 9 | PY_ARGS=${@:5} 10 | 11 | GPUS=${GPUS:-8} 12 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 13 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 14 | SRUN_ARGS=${SRUN_ARGS:-""} 15 | 16 | while true # find unused tcp port 17 | do 18 | PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 )) 19 | status="$(nc -z 127.0.0.1 $PORT < /dev/null &>/dev/null; echo $?)" 20 | if [ "${status}" != "0" ]; then 21 | break; 22 | fi 23 | done 24 | 25 | mkdir -p $WORK_DIR 26 | 27 | srun --mpi=pmi2 -p ${PARTITION} \ 28 | --job-name=${JOB_NAME} \ 29 | --gres=gpu:${GPUS_PER_NODE} \ 30 | --ntasks=${GPUS} \ 31 | --ntasks-per-node=${GPUS_PER_NODE} \ 32 | --cpus-per-task=${CPUS_PER_TASK} \ 33 | --kill-on-bad-exit=1 \ 34 | ${SRUN_ARGS} \ 35 | python -u tools/train_net.py --cfg ${CONFIG} \ 36 | LAUNCHER slurm OUT_DIR ${WORK_DIR} PORT ${PORT} ${PY_ARGS} | tee ${WORK_DIR}/log.txt 37 | -------------------------------------------------------------------------------- /imagenet/tools/test_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """Test a trained classification model.""" 9 | 10 | import pycls.core.config as config 11 | import pycls.core.distributed as dist 12 | import pycls.core.trainer as trainer 13 | from pycls.core.config import cfg 14 | 15 | 16 | def main(): 17 | config.load_cfg_fom_args("Test a trained classification model.") 18 | config.assert_and_infer_cfg() 19 | # cfg.freeze() 20 | dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.test_model) 21 | 22 | 23 | if __name__ == "__main__": 24 | main() 25 | -------------------------------------------------------------------------------- /imagenet/tools/train_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """Train a classification model.""" 9 | 10 | import pycls.core.config as config 11 | import pycls.core.distributed as dist 12 | import pycls.core.trainer as trainer 13 | from pycls.core.config import cfg 14 | 15 | 16 | def main(): 17 | config.load_cfg_fom_args("Train a classification model.") 18 | config.assert_and_infer_cfg() 19 | # cfg.freeze() 20 | dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.train_model) 21 | 22 | 23 | if __name__ == "__main__": 24 | main() 25 | -------------------------------------------------------------------------------- /small_scale/README.md: -------------------------------------------------------------------------------- 1 | # BAKE on Small-scale Datasets 2 | 3 | PyTorch implementation of [Self-distillation with Batch Knowledge Ensembling Improves ImageNet Classification](https://arxiv.org/abs/2104.13298) on CIFAR-100, TinyImageNet, CUB-200-2011, Stanford Dogs, and MIT67. 4 | 5 | ## Requirements 6 | 7 | - torch >= 1.2.0 8 | - torchvision >= 0.4.0 9 | 10 | ## Datasets 11 | 12 | Please download the raw datasets and re-organize them according to [DATA](data/). The folder tree should be like 13 | ``` 14 | ├── data 15 | │   ├── README.md 16 | │   ├── cifar-100-python/ 17 | │   ├── CUB200/ 18 | │   ├── MIT67/ 19 | │   ├── STANFORD120/ 20 | │   ├── tinyimagenet/ 21 | │   ├── tools/ 22 | └── └── ... 23 | ``` 24 | 25 | ## BAKE Training 26 | 27 | ``` 28 | sh scripts/train_bake.sh 29 | ``` 30 | 31 | Specifically, 32 | + CIFAR-100 33 | 34 | ``` 35 | sh scripts/train_bake.sh 0 cifar100 CIFAR_ResNet18 128 3 0.5 36 | sh scripts/train_bake.sh 0 cifar100 CIFAR_DenseNet121 128 3 0.5 37 | ``` 38 | 39 | + TinyImageNet 40 | 41 | ``` 42 | sh scripts/train_bake.sh 0 tinyimagenet CIFAR_ResNet18 128 1 0.9 43 | sh scripts/train_bake.sh 0 tinyimagenet CIFAR_DenseNet121 128 1 0.9 44 | ``` 45 | 46 | + CUB-200-2011 47 | 48 | ``` 49 | sh scripts/train_bake.sh 0 CUB200 resnet18 32 3 0.5 50 | sh scripts/train_bake.sh 0 CUB200 densenet121 32 3 0.5 51 | ``` 52 | 53 | + Stanford Dogs 54 | 55 | ``` 56 | sh scripts/train_bake.sh 0 STANFORD120 resnet18 32 1 0.9 57 | sh scripts/train_bake.sh 0 STANFORD120 densenet121 32 1 0.9 58 | ``` 59 | 60 | + MIT67 61 | 62 | ``` 63 | sh scripts/train_bake.sh 0 MIT67 resnet18 32 1 0.9 64 | sh scripts/train_bake.sh 0 MIT67 densenet121 32 1 0.9 65 | ``` 66 | 67 | ## Baseline Training 68 | 69 | ``` 70 | sh scripts/train_baseline.sh 71 | ``` 72 | 73 | ## Validation 74 | 75 | ``` 76 | sh scripts/val.sh 77 | ``` 78 | 79 | ## Results (top-1 error) 80 | 81 | ||CIFAR-100|TinyImageNet|CUB-200-2011|Stanford Dogs|MIT67| 82 | |---|:--:|:--:|:--:|:--:|:--:| 83 | |ResNet-18|21.28|41.71|29.74|30.20|39.95| 84 | |DenseNet-121|20.74|37.07|28.79|27.66|39.15| 85 | 86 | ## Thanks 87 | The code is modified from [CS-KD](https://github.com/alinlab/cs-kd). 88 | -------------------------------------------------------------------------------- /small_scale/data/README.md: -------------------------------------------------------------------------------- 1 | ## Datasets 2 | 3 | ### Supported Datasets 4 | 5 | + CIFAR-100 6 | + [TinyImageNet](http://cs231n.stanford.edu/tiny-imagenet-200.zip) 7 | + [CUB-200-2011](https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view) 8 | + [Stanford Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/) 9 | + [MIT67](http://web.mit.edu/torralba/www/indoor.html) 10 | 11 | ### Preparation Steps 12 | 13 | + Download the raw datasets and save them under this directory. 
Note that there is no need to download CIFAR-100 manually, as it will be downloaded automatically by the code. 14 | + Process them by running the corresponding scripts in `tools/`, e.g. `cd tools && python cub.py`. 15 | -------------------------------------------------------------------------------- /small_scale/data/tools/cub.py: -------------------------------------------------------------------------------- 1 | # Please download the raw data of CUB_200_2011 dataset from 2 | # https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view 3 | 4 | import os 5 | import os.path as osp 6 | 7 | ori='../CUB_200_2011' 8 | tar='../CUB200' 9 | os.makedirs(tar) 10 | 11 | with open(osp.join(ori, 'images.txt'), 'r') as f: 12 | img_list = f.readlines() 13 | with open(osp.join(ori, 'train_test_split.txt'), 'r') as f: 14 | split_list = f.readlines() 15 | 16 | img_dict = {} 17 | for im in img_list: 18 | img_dict[im.split(' ')[0]]=im.strip().split(' ')[1] 19 | split_dict = {} 20 | for s in split_list: 21 | split_dict[s.split(' ')[0]]=int(s.strip().split(' ')[1]) 22 | 23 | for k in img_dict.keys(): 24 | if split_dict[k]==1: 25 | s='train' 26 | else: 27 | s='test' 28 | if not os.path.isdir(os.path.join(tar,s,osp.dirname(img_dict[k]))): 29 | os.makedirs(os.path.join(tar,s,osp.dirname(img_dict[k]))) 30 | os.symlink(osp.abspath(osp.join(ori,'images',img_dict[k])),os.path.join(tar,s,img_dict[k])) 31 | -------------------------------------------------------------------------------- /small_scale/data/tools/mit67.py: -------------------------------------------------------------------------------- 1 | # Please download the raw data of MIT67 dataset from 2 | # http://web.mit.edu/torralba/www/indoor.html 3 | 4 | import os 5 | import os.path as osp 6 | import scipy.io 7 | 8 | root = '../' 9 | ori = root+'Images' 10 | tar = root+'MIT67' 11 | os.makedirs(tar) 12 | 13 | with open(root+'TrainImages.txt', 'r') as f: 14 | train_list = f.readlines() 15 | with open(root+'TestImages.txt', 'r') as f: 16 | test_list = f.readlines() 17 | 18 | def link(file_list, tar_folder): 19 | for name in file_list: 20 | name=name.strip() 21 | assert(osp.isfile(osp.join(ori,name))) 22 | if not osp.isdir(osp.dirname(osp.join(tar_folder, name))): 23 | os.makedirs(osp.dirname(osp.join(tar_folder, name))) 24 | os.symlink(osp.abspath(osp.join(ori,name)), osp.join(tar_folder, name)) 25 | 26 | link(train_list, osp.join(tar,'train')) 27 | link(test_list, osp.join(tar,'test')) 28 | -------------------------------------------------------------------------------- /small_scale/data/tools/stanford_dogs.py: -------------------------------------------------------------------------------- 1 | # Please download the raw data of Stanford Dogs dataset from 2 | # http://vision.stanford.edu/aditya86/ImageNetDogs/ 3 | 4 | import os 5 | import os.path as osp 6 | import scipy.io 7 | 8 | root = '../' 9 | ori = root+'Images' 10 | tar = root+'STANFORD120' 11 | os.makedirs(tar) 12 | 13 | train_list = scipy.io.loadmat(root+'train_list.mat')['annotation_list'] 14 | test_list = scipy.io.loadmat(root+'test_list.mat')['annotation_list'] 15 | 16 | def link(file_list, tar_folder): 17 | for i in range(file_list.shape[0]): 18 | name = str(file_list[i][0][0]) 19 | if not name.endswith('.jpg'): 20 | name = name+'.jpg' 21 | assert(osp.isfile(osp.join(ori,name))) 22 | if not osp.isdir(osp.dirname(osp.join(tar_folder, name))): 23 | os.makedirs(osp.dirname(osp.join(tar_folder, name))) 24 | os.symlink(osp.abspath(osp.join(ori,name)), osp.join(tar_folder, name)) 25 | 26 | link(train_list, osp.join(tar,'train')) 27 | 
link(test_list, osp.join(tar,'test')) 28 | -------------------------------------------------------------------------------- /small_scale/data/tools/tinyimagenet.py: -------------------------------------------------------------------------------- 1 | # Please download the raw data of TinyImageNet dataset from 2 | # http://cs231n.stanford.edu/tiny-imagenet-200.zip 3 | 4 | import os 5 | import os.path as osp 6 | 7 | root = '../' 8 | ori = root+'tiny-imagenet-200/train' 9 | tar = root+'tinyimagenet/train' 10 | os.makedirs(tar) 11 | 12 | for dir in os.listdir(ori): 13 | os.makedirs(osp.join(tar,dir)) 14 | for file in os.listdir(osp.join(ori,dir,'images')): 15 | os.symlink(osp.abspath(osp.join(ori,dir,'images',file)), osp.join(tar,dir,file)) 16 | 17 | 18 | ori = root+'tiny-imagenet-200/val' 19 | tar = root+'tinyimagenet/val' 20 | os.makedirs(tar) 21 | 22 | with open(root+'tiny-imagenet-200/val/val_annotations.txt','r') as f: 23 | list_file = f.readlines() 24 | 25 | for item in list_file: 26 | item = item.strip() 27 | file = item.split('\t')[0] 28 | dir = item.split('\t')[1] 29 | if not osp.isdir(osp.join(tar,dir)): 30 | os.makedirs(osp.join(tar,dir)) 31 | os.symlink(osp.abspath(osp.join(ori,'images',file)), osp.join(tar,dir,file)) 32 | -------------------------------------------------------------------------------- /small_scale/datasets.py: -------------------------------------------------------------------------------- 1 | import csv, torchvision, numpy as np, random, os 2 | from PIL import Image 3 | import bisect, warnings 4 | import copy 5 | 6 | from torch.utils.data import Sampler, Dataset, DataLoader, BatchSampler, SequentialSampler, RandomSampler, Subset 7 | from torchvision import transforms, datasets 8 | from collections import defaultdict 9 | 10 | 11 | class IdentityBatchSampler(Sampler): 12 | def __init__(self, dataset, batch_size, num_instances, num_iterations=None): 13 | self.dataset = dataset 14 | self.batch_size = batch_size 15 | self.num_instances = num_instances 16 | self.num_iterations = num_iterations 17 | 18 | def __iter__(self): 19 | indices = list(range(len(self.dataset))) 20 | random.shuffle(indices) 21 | for k in range(len(self)): 22 | offset = k*self.batch_size%len(indices) 23 | batch_indices = indices[offset:offset+self.batch_size] 24 | 25 | pair_indices = [] 26 | for idx in batch_indices: 27 | y = self.dataset.get_class(idx) 28 | t = copy.deepcopy(self.dataset.classwise_indices[y]) 29 | t.pop(t.index(idx)) 30 | if len(t)>=(self.num_instances-1): 31 | class_indices = np.random.choice(t, size=self.num_instances-1, replace=False) 32 | else: 33 | class_indices = np.random.choice(t, size=self.num_instances-1, replace=True) 34 | pair_indices.extend(class_indices) 35 | 36 | yield batch_indices+pair_indices 37 | 38 | def __len__(self): 39 | if self.num_iterations is None: 40 | return (len(self.dataset)+self.batch_size-1) // (self.batch_size) 41 | else: 42 | return self.num_iterations 43 | 44 | 45 | class PairBatchSampler(Sampler): 46 | def __init__(self, dataset, batch_size, num_iterations=None): 47 | self.dataset = dataset 48 | self.batch_size = batch_size 49 | self.num_iterations = num_iterations 50 | 51 | def __iter__(self): 52 | indices = list(range(len(self.dataset))) 53 | random.shuffle(indices) 54 | for k in range(len(self)): 55 | if self.num_iterations is None: 56 | offset = k*self.batch_size 57 | batch_indices = indices[offset:offset+self.batch_size] 58 | else: 59 | batch_indices = random.sample(range(len(self.dataset)), 60 | self.batch_size) 61 | 62 | pair_indices = [] 63 | for idx in 
batch_indices: 64 | y = self.dataset.get_class(idx) 65 | pair_indices.append(random.choice(self.dataset.classwise_indices[y])) 66 | 67 | yield batch_indices + pair_indices 68 | 69 | def __len__(self): 70 | if self.num_iterations is None: 71 | return (len(self.dataset)+self.batch_size-1) // self.batch_size 72 | else: 73 | return self.num_iterations 74 | 75 | 76 | class DatasetWrapper(Dataset): 77 | # Additional attributes 78 | # - indices 79 | # - classwise_indices 80 | # - num_classes 81 | # - get_class 82 | 83 | def __init__(self, dataset, indices=None): 84 | self.base_dataset = dataset 85 | if indices is None: 86 | self.indices = list(range(len(dataset))) 87 | else: 88 | self.indices = indices 89 | 90 | # torchvision 0.2.0 compatibility 91 | if torchvision.__version__.startswith('0.2'): 92 | if isinstance(self.base_dataset, datasets.ImageFolder): 93 | self.base_dataset.targets = [s[1] for s in self.base_dataset.imgs] 94 | else: 95 | if self.base_dataset.train: 96 | self.base_dataset.targets = self.base_dataset.train_labels 97 | else: 98 | self.base_dataset.targets = self.base_dataset.test_labels 99 | 100 | self.classwise_indices = defaultdict(list) 101 | for i in range(len(self)): 102 | y = self.base_dataset.targets[self.indices[i]] 103 | self.classwise_indices[y].append(i) 104 | self.num_classes = max(self.classwise_indices.keys())+1 105 | 106 | def __getitem__(self, i): 107 | return self.base_dataset[self.indices[i]] 108 | 109 | def __len__(self): 110 | return len(self.indices) 111 | 112 | def get_class(self, i): 113 | return self.base_dataset.targets[self.indices[i]] 114 | 115 | 116 | class ConcatWrapper(Dataset): # TODO: Naming 117 | @staticmethod 118 | def cumsum(sequence): 119 | r, s = [], 0 120 | for e in sequence: 121 | l = len(e) 122 | r.append(l + s) 123 | s += l 124 | return r 125 | 126 | @staticmethod 127 | def numcls(sequence): 128 | s = 0 129 | for e in sequence: 130 | l = e.num_classes 131 | s += l 132 | return s 133 | 134 | @staticmethod 135 | def clsidx(sequence): 136 | r, s, n = defaultdict(list), 0, 0 137 | for e in sequence: 138 | l = e.classwise_indices 139 | for c in range(s, s + e.num_classes): 140 | t = np.asarray(l[c-s]) + n 141 | r[c] = t.tolist() 142 | s += e.num_classes 143 | n += len(e) 144 | return r 145 | 146 | def __init__(self, datasets): 147 | super(ConcatWrapper, self).__init__() 148 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 149 | self.datasets = list(datasets) 150 | # for d in self.datasets: 151 | # assert not isinstance(d, IterableDataset), "ConcatDataset does not support IterableDataset" 152 | self.cumulative_sizes = self.cumsum(self.datasets) 153 | 154 | self.num_classes = self.numcls(self.datasets) 155 | self.classwise_indices = self.clsidx(self.datasets) 156 | 157 | def __len__(self): 158 | return self.cumulative_sizes[-1] 159 | 160 | def __getitem__(self, idx): 161 | if idx < 0: 162 | if -idx > len(self): 163 | raise ValueError("absolute value of index should not exceed dataset length") 164 | idx = len(self) + idx 165 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 166 | if dataset_idx == 0: 167 | sample_idx = idx 168 | else: 169 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 170 | return self.datasets[dataset_idx][sample_idx] 171 | 172 | def get_class(self, idx): 173 | if idx < 0: 174 | if -idx > len(self): 175 | raise ValueError("absolute value of index should not exceed dataset length") 176 | idx = len(self) + idx 177 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 178 | 
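        # cumulative_sizes is sorted in increasing order, so bisect_right picks the
        # first sub-dataset whose cumulative length exceeds idx. Illustrative case:
        # with cumulative_sizes == [100, 250, 400], idx = 120 gives dataset_idx = 1
        # and sample_idx = 120 - 100 = 20 below.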
if dataset_idx == 0: 179 | sample_idx = idx 180 | else: 181 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 182 | true_class = self.datasets[dataset_idx].base_dataset.targets[self.datasets[dataset_idx].indices[sample_idx]] 183 | return self.datasets[dataset_idx].base_dataset.target_transform(true_class) 184 | 185 | @property 186 | def cummulative_sizes(self): 187 | warnings.warn("cummulative_sizes attribute is renamed to " 188 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 189 | return self.cumulative_sizes 190 | 191 | 192 | 193 | def load_dataset(name, root, sample='default', **kwargs): 194 | # Dataset 195 | if name in ['imagenet','tinyimagenet', 'CUB200', 'STANFORD120', 'MIT67']: 196 | # TODO 197 | if name == 'tinyimagenet': 198 | transform_train = transforms.Compose([ 199 | transforms.RandomResizedCrop(32), 200 | transforms.RandomHorizontalFlip(), 201 | transforms.ToTensor(), 202 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 203 | ]) 204 | transform_test = transforms.Compose([ 205 | transforms.Resize(32), 206 | transforms.ToTensor(), 207 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 208 | ]) 209 | 210 | train_val_dataset_dir = os.path.join(root, "tinyimagenet/train") 211 | test_dataset_dir = os.path.join(root, "tinyimagenet/val") 212 | 213 | trainset = DatasetWrapper(datasets.ImageFolder(root=train_val_dataset_dir, transform=transform_train)) 214 | valset = DatasetWrapper(datasets.ImageFolder(root=test_dataset_dir, transform=transform_test)) 215 | 216 | elif name == 'imagenet': 217 | transform_train = transforms.Compose([ 218 | transforms.RandomResizedCrop(224), 219 | transforms.RandomHorizontalFlip(), 220 | transforms.ToTensor(), 221 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 222 | ]) 223 | transform_test = transforms.Compose([ 224 | transforms.Resize(256), 225 | transforms.CenterCrop(224), 226 | transforms.ToTensor(), 227 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 228 | ]) 229 | train_val_dataset_dir = os.path.join(root, "train") 230 | test_dataset_dir = os.path.join(root, "val") 231 | 232 | trainset = DatasetWrapper(datasets.ImageFolder(root=train_val_dataset_dir, transform=transform_train)) 233 | valset = DatasetWrapper(datasets.ImageFolder(root=test_dataset_dir, transform=transform_test)) 234 | 235 | else: 236 | transform_train = transforms.Compose([ 237 | transforms.RandomResizedCrop(224), 238 | transforms.RandomHorizontalFlip(), 239 | transforms.ToTensor(), 240 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 241 | ]) 242 | transform_test = transforms.Compose([ 243 | transforms.Resize(256), 244 | transforms.CenterCrop(224), 245 | transforms.ToTensor(), 246 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 247 | ]) 248 | 249 | train_val_dataset_dir = os.path.join(root, name, "train") 250 | test_dataset_dir = os.path.join(root, name, "test") 251 | 252 | trainset = DatasetWrapper(datasets.ImageFolder(root=train_val_dataset_dir, transform=transform_train)) 253 | valset = DatasetWrapper(datasets.ImageFolder(root=test_dataset_dir, transform=transform_test)) 254 | 255 | elif name.startswith('cifar'): 256 | transform_train = transforms.Compose([ 257 | transforms.RandomCrop(32, padding=4), 258 | transforms.RandomHorizontalFlip(), 259 | transforms.ToTensor(), 260 | transforms.Normalize((0.4914, 0.4822, 0.4465), 261 | (0.2023, 0.1994, 0.2010)), 262 | ]) 263 | transform_test = transforms.Compose([ 264 | transforms.ToTensor(), 265 | 
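                # standard CIFAR channel-wise mean/std statistics (same values as in transform_train above)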
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 266 | ]) 267 | 268 | if name == 'cifar10': 269 | CIFAR = datasets.CIFAR10 270 | else: 271 | CIFAR = datasets.CIFAR100 272 | 273 | trainset = DatasetWrapper(CIFAR(root, train=True, download=True, transform=transform_train)) 274 | valset = DatasetWrapper(CIFAR(root, train=False, download=True, transform=transform_test)) 275 | else: 276 | raise Exception('Unknown dataset: {}'.format(name)) 277 | 278 | # Sampler 279 | if sample == 'default': 280 | #get_train_sampler = lambda d: BatchSampler(RandomSampler(d), kwargs['batch_size'], False) 281 | get_train_sampler = lambda d: IdentityBatchSampler(d, kwargs['batch_size'], kwargs['num_instances']) 282 | get_test_sampler = lambda d: BatchSampler(SequentialSampler(d), kwargs['batch_size'], False) 283 | 284 | elif sample == 'pair': 285 | get_train_sampler = lambda d: PairBatchSampler(d, kwargs['batch_size']) 286 | get_test_sampler = lambda d: BatchSampler(SequentialSampler(d), kwargs['batch_size'], False) 287 | 288 | else: 289 | raise Exception('Unknown sampling: {}'.format(sample)) 290 | 291 | trainloader = DataLoader(trainset, batch_sampler=get_train_sampler(trainset), num_workers=8) 292 | valloader = DataLoader(valset, batch_sampler=get_test_sampler(valset), num_workers=8) 293 | 294 | return trainloader, valloader 295 | -------------------------------------------------------------------------------- /small_scale/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .densenet import * 3 | from .densenet3 import * 4 | 5 | def load_model(name, num_classes=10, pretrained=False, **kwargs): 6 | model_dict = globals() 7 | model = model_dict[name](pretrained=pretrained, num_classes=num_classes, **kwargs) 8 | return model 9 | -------------------------------------------------------------------------------- /small_scale/models/densenet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint as cp 6 | from collections import OrderedDict 7 | # from .utils import load_state_dict_from_url 8 | try: 9 | from torch.hub import load_state_dict_from_url 10 | except ImportError: 11 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 12 | 13 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 14 | 15 | model_urls = { 16 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 17 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 18 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 19 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 20 | } 21 | 22 | 23 | def _bn_function_factory(norm, relu, conv): 24 | def bn_function(*inputs): 25 | concated_features = torch.cat(inputs, 1) 26 | bottleneck_output = conv(relu(norm(concated_features))) 27 | return bottleneck_output 28 | 29 | return bn_function 30 | 31 | 32 | class _DenseLayer(nn.Sequential): 33 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False): 34 | super(_DenseLayer, self).__init__() 35 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 36 | self.add_module('relu1', nn.ReLU(inplace=True)), 37 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 38 | growth_rate, 
kernel_size=1, stride=1, 39 | bias=False)), 40 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 41 | self.add_module('relu2', nn.ReLU(inplace=True)), 42 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 43 | kernel_size=3, stride=1, padding=1, 44 | bias=False)), 45 | self.drop_rate = drop_rate 46 | self.memory_efficient = memory_efficient 47 | 48 | def forward(self, *prev_features): 49 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) 50 | if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features): 51 | bottleneck_output = cp.checkpoint(bn_function, *prev_features) 52 | else: 53 | bottleneck_output = bn_function(*prev_features) 54 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 55 | if self.drop_rate > 0: 56 | new_features = F.dropout(new_features, p=self.drop_rate, 57 | training=self.training) 58 | return new_features 59 | 60 | 61 | class _DenseBlock(nn.Module): 62 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False): 63 | super(_DenseBlock, self).__init__() 64 | for i in range(num_layers): 65 | layer = _DenseLayer( 66 | num_input_features + i * growth_rate, 67 | growth_rate=growth_rate, 68 | bn_size=bn_size, 69 | drop_rate=drop_rate, 70 | memory_efficient=memory_efficient, 71 | ) 72 | self.add_module('denselayer%d' % (i + 1), layer) 73 | 74 | def forward(self, init_features): 75 | features = [init_features] 76 | for name, layer in self.named_children(): 77 | new_features = layer(*features) 78 | features.append(new_features) 79 | return torch.cat(features, 1) 80 | 81 | 82 | class _Transition(nn.Sequential): 83 | def __init__(self, num_input_features, num_output_features): 84 | super(_Transition, self).__init__() 85 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 86 | self.add_module('relu', nn.ReLU(inplace=True)) 87 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 88 | kernel_size=1, stride=1, bias=False)) 89 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 90 | 91 | 92 | class DenseNet(nn.Module): 93 | r"""Densenet-BC model class, based on 94 | `"Densely Connected Convolutional Networks" `_ 95 | Args: 96 | growth_rate (int) - how many filters to add each layer (`k` in paper) 97 | block_config (list of 4 ints) - how many layers in each pooling block 98 | num_init_features (int) - the number of filters to learn in the first convolution layer 99 | bn_size (int) - multiplicative factor for number of bottle neck layers 100 | (i.e. bn_size * k features in the bottleneck layer) 101 | drop_rate (float) - dropout rate after each dense layer 102 | num_classes (int) - number of classification classes 103 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 104 | but slower. Default: *False*. 
See `"paper" `_ 105 | """ 106 | 107 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 108 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False, bias=True): 109 | 110 | super(DenseNet, self).__init__() 111 | 112 | # First convolution 113 | self.features = nn.Sequential(OrderedDict([ 114 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, 115 | padding=3, bias=False)), 116 | ('norm0', nn.BatchNorm2d(num_init_features)), 117 | ('relu0', nn.ReLU(inplace=True)), 118 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 119 | ])) 120 | 121 | # Each denseblock 122 | num_features = num_init_features 123 | for i, num_layers in enumerate(block_config): 124 | block = _DenseBlock( 125 | num_layers=num_layers, 126 | num_input_features=num_features, 127 | bn_size=bn_size, 128 | growth_rate=growth_rate, 129 | drop_rate=drop_rate, 130 | memory_efficient=memory_efficient 131 | ) 132 | self.features.add_module('denseblock%d' % (i + 1), block) 133 | num_features = num_features + num_layers * growth_rate 134 | if i != len(block_config) - 1: 135 | trans = _Transition(num_input_features=num_features, 136 | num_output_features=num_features // 2) 137 | self.features.add_module('transition%d' % (i + 1), trans) 138 | num_features = num_features // 2 139 | 140 | # Final batch norm 141 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 142 | 143 | # Linear layer 144 | self.classifier = nn.Linear(num_features, num_classes, bias=bias) 145 | 146 | # Official init from torch repo. 147 | for m in self.modules(): 148 | if isinstance(m, nn.Conv2d): 149 | nn.init.kaiming_normal_(m.weight) 150 | elif isinstance(m, nn.BatchNorm2d): 151 | nn.init.constant_(m.weight, 1) 152 | nn.init.constant_(m.bias, 0) 153 | elif isinstance(m, nn.Linear): 154 | if bias: 155 | nn.init.constant_(m.bias, 0) 156 | 157 | def forward(self, x): 158 | features = self.features(x) 159 | out = F.relu(features, inplace=True) 160 | out = F.adaptive_avg_pool2d(out, (1, 1)) 161 | out = torch.flatten(out, 1) 162 | prob = self.classifier(out) 163 | return out, prob 164 | 165 | 166 | 167 | def _load_state_dict(model, model_url, progress): 168 | # '.'s are no longer allowed in module names, but previous _DenseLayer 169 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 170 | # They are also in the checkpoints in model_urls. This pattern is used 171 | # to find such keys. 
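    # Illustrative mapping performed below, e.g.:
    #   'features.denseblock1.denselayer1.norm.1.weight'
    #   -> 'features.denseblock1.denselayer1.norm1.weight'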
172 | pattern = re.compile( 173 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 174 | 175 | state_dict = load_state_dict_from_url(model_url, progress=progress) 176 | for key in list(state_dict.keys()): 177 | res = pattern.match(key) 178 | if res: 179 | new_key = res.group(1) + res.group(2) 180 | state_dict[new_key] = state_dict[key] 181 | del state_dict[key] 182 | model.load_state_dict(state_dict) 183 | 184 | 185 | def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress, num_classes=10, bias=True,  # defaults let densenet161/169/201 below omit them 186 | **kwargs): 187 | model = DenseNet(growth_rate, block_config, num_init_features, num_classes=num_classes, bias=bias, **kwargs) 188 | if pretrained: 189 | _load_state_dict(model, model_urls[arch], progress) 190 | return model 191 | 192 | 193 | def densenet121(pretrained=False, progress=True, num_classes=10, bias=True, **kwargs): 194 | r"""Densenet-121 model from 195 | `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | progress (bool): If True, displays a progress bar of the download to stderr 199 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 200 | but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_ 201 | """ 202 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, num_classes=num_classes, bias=bias, 203 | **kwargs) 204 | 205 | 206 | def densenet161(pretrained=False, progress=True, **kwargs): 207 | r"""Densenet-161 model from 208 | `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ 209 | Args: 210 | pretrained (bool): If True, returns a model pre-trained on ImageNet 211 | progress (bool): If True, displays a progress bar of the download to stderr 212 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 213 | but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_ 214 | """ 215 | return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, 216 | **kwargs) 217 | 218 | 219 | def densenet169(pretrained=False, progress=True, **kwargs): 220 | r"""Densenet-169 model from 221 | `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ 222 | Args: 223 | pretrained (bool): If True, returns a model pre-trained on ImageNet 224 | progress (bool): If True, displays a progress bar of the download to stderr 225 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 226 | but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_ 227 | """ 228 | return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, 229 | **kwargs) 230 | 231 | 232 | def densenet201(pretrained=False, progress=True, **kwargs): 233 | r"""Densenet-201 model from 234 | `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ 235 | Args: 236 | pretrained (bool): If True, returns a model pre-trained on ImageNet 237 | progress (bool): If True, displays a progress bar of the download to stderr 238 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 239 | but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_ 240 | """ 241 | return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, 242 | **kwargs) 243 | 
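A quick constructor check for the file above (a minimal sketch; the `models.densenet` import path is an assumption based on the small_scale/models/ layout):

import torch
from models.densenet import densenet121

net = densenet121(num_classes=100, bias=True, memory_efficient=True)  # CIFAR-100-sized head, checkpointed bottlenecks
features, logits = net(torch.randn(2, 3, 224, 224))  # this file's DenseNet expects 224x224 inputs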
-------------------------------------------------------------------------------- /small_scale/models/densenet3.py: -------------------------------------------------------------------------------- 1 | '''DenseNet in PyTorch.''' 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | def __init__(self, in_planes, growth_rate): 11 | super(Bottleneck, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(4*growth_rate) 15 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 16 | 17 | def forward(self, x): 18 | out = self.conv1(F.relu(self.bn1(x))) 19 | out = self.conv2(F.relu(self.bn2(out))) 20 | out = torch.cat([out,x], 1) 21 | return out 22 | 23 | 24 | class Transition(nn.Module): 25 | def __init__(self, in_planes, out_planes): 26 | super(Transition, self).__init__() 27 | self.bn = nn.BatchNorm2d(in_planes) 28 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 29 | 30 | def forward(self, x): 31 | out = self.conv(F.relu(self.bn(x))) 32 | out = F.avg_pool2d(out, 2) 33 | return out 34 | 35 | 36 | class CIFAR_DenseNet(nn.Module): 37 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10, bias=True): 38 | super(CIFAR_DenseNet, self).__init__() 39 | self.growth_rate = growth_rate 40 | 41 | num_planes = 2*growth_rate 42 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 43 | 44 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 45 | num_planes += nblocks[0]*growth_rate 46 | out_planes = int(math.floor(num_planes*reduction)) 47 | self.trans1 = Transition(num_planes, out_planes) 48 | num_planes = out_planes 49 | 50 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 51 | num_planes += nblocks[1]*growth_rate 52 | out_planes = int(math.floor(num_planes*reduction)) 53 | self.trans2 = Transition(num_planes, out_planes) 54 | num_planes = out_planes 55 | 56 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 57 | num_planes += nblocks[2]*growth_rate 58 | out_planes = int(math.floor(num_planes*reduction)) 59 | self.trans3 = Transition(num_planes, out_planes) 60 | num_planes = out_planes 61 | 62 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 63 | num_planes += nblocks[3]*growth_rate 64 | 65 | self.bn = nn.BatchNorm2d(num_planes) 66 | self.linear = nn.Linear(num_planes, num_classes, bias=bias) 67 | 68 | def _make_dense_layers(self, block, in_planes, nblock): 69 | layers = [] 70 | for i in range(nblock): 71 | layers.append(block(in_planes, self.growth_rate)) 72 | in_planes += self.growth_rate 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = self.conv1(x) 77 | out = self.trans1(self.dense1(out)) 78 | out = self.trans2(self.dense2(out)) 79 | out = self.trans3(self.dense3(out)) 80 | out = self.dense4(out) 81 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 82 | out = out.view(out.size(0), -1) 83 | prob = self.linear(out) 84 | return out, prob  # flattened features and class logits, as in the other models 85 | 86 | def CIFAR_DenseNet121(pretrained=False, num_classes=10, bias=True, **kwargs): 87 | return CIFAR_DenseNet(Bottleneck, [6,12,24,16], growth_rate=32, num_classes=num_classes, bias=bias) 88 | 89 | # def 
DenseNet169(): 90 | # return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32) 91 | 92 | # def DenseNet201(): 93 | # return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32) 94 | 95 | # def DenseNet161(): 96 | # return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48) 97 | 98 | # def densenet_cifar(): 99 | # return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12) 100 | 101 | # def test(): 102 | # net = densenet_cifar() 103 | # x = torch.randn(1,3,32,32) 104 | # y = net(x) 105 | # print(y) 106 | 107 | # test() 108 | -------------------------------------------------------------------------------- /small_scale/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | import torch.nn.functional as F 4 | 5 | __all__ = ['ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'CIFAR_ResNet', 'CIFAR_ResNet18', 'CIFAR_ResNet34', 'CIFAR_ResNet10'] 7 | 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, groups=groups, bias=False) 22 | 23 | 24 | def conv1x1(in_planes, out_planes, stride=1): 25 | """1x1 convolution""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 33 | base_width=64, norm_layer=None): 34 | super(BasicBlock, self).__init__() 35 | if norm_layer is None: 36 | norm_layer = nn.BatchNorm2d 37 | if groups != 1 or base_width != 64: 38 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 39 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 40 | self.conv1 = conv3x3(inplanes, planes, stride) 41 | self.bn1 = norm_layer(planes) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = conv3x3(planes, planes) 44 | self.bn2 = norm_layer(planes) 45 | self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | identity = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | 58 | if self.downsample is not None: 59 | identity = self.downsample(x) 60 | 61 | out += identity 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class Bottleneck(nn.Module): 68 | expansion = 4 69 | 70 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 71 | base_width=64, norm_layer=None): 72 | super(Bottleneck, self).__init__() 73 | if norm_layer is None: 74 | norm_layer = nn.BatchNorm2d 75 | width = int(planes * (base_width / 64.)) * groups 76 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 77 | self.conv1 = conv1x1(inplanes, width) 78 | self.bn1 = norm_layer(width) 79 | self.conv2 = conv3x3(width, width, stride, groups) 80 | self.bn2 = 
norm_layer(width) 81 | self.conv3 = conv1x1(width, planes * self.expansion) 82 | self.bn3 = norm_layer(planes * self.expansion) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.downsample = downsample 85 | self.stride = stride 86 | 87 | def forward(self, x): 88 | identity = x 89 | 90 | out = self.conv1(x) 91 | out = self.bn1(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv2(out) 95 | out = self.bn2(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv3(out) 99 | out = self.bn3(out) 100 | 101 | if self.downsample is not None: 102 | identity = self.downsample(x) 103 | 104 | out += identity 105 | out = self.relu(out) 106 | 107 | return out 108 | 109 | class PreActBlock(nn.Module): 110 | '''Pre-activation version of the BasicBlock.''' 111 | expansion = 1 112 | 113 | def __init__(self, in_planes, planes, stride=1): 114 | super(PreActBlock, self).__init__() 115 | self.bn1 = nn.BatchNorm2d(in_planes) 116 | self.conv1 = conv3x3(in_planes, planes, stride) 117 | self.bn2 = nn.BatchNorm2d(planes) 118 | self.conv2 = conv3x3(planes, planes) 119 | self.relu = nn.ReLU(inplace=True) 120 | 121 | self.shortcut = nn.Sequential() 122 | if stride != 1 or in_planes != self.expansion*planes: 123 | self.shortcut = nn.Sequential( 124 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 125 | ) 126 | 127 | def forward(self, x): 128 | out = self.relu(self.bn1(x)) 129 | shortcut = self.shortcut(out) 130 | out = self.conv1(out) 131 | out = self.conv2(self.relu(self.bn2(out))) 132 | out += shortcut 133 | return out 134 | 135 | class ResNet(nn.Module): 136 | 137 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 138 | groups=1, width_per_group=64, norm_layer=None): 139 | super(ResNet, self).__init__() 140 | if norm_layer is None: 141 | norm_layer = nn.BatchNorm2d 142 | 143 | self.inplanes = 64 144 | self.groups = groups 145 | self.base_width = width_per_group 146 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 147 | bias=False) 148 | self.bn1 = norm_layer(self.inplanes) 149 | self.relu = nn.ReLU(inplace=True) 150 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 151 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) 152 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer) 154 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer) 155 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 156 | self.fc = nn.Linear(512 * block.expansion, num_classes) 157 | for m in self.modules(): 158 | if isinstance(m, nn.Conv2d): 159 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 160 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 161 | nn.init.constant_(m.weight, 1) 162 | nn.init.constant_(m.bias, 0) 163 | 164 | # Zero-initialize the last BN in each residual branch, 165 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
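# (With zero-init, bn3.weight starts at 0, so the Bottleneck's residual branch
# contributes nothing and out = relu(identity) at initialization; bn2.weight
# plays the same role for BasicBlock.)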
166 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 167 | if zero_init_residual: 168 | for m in self.modules(): 169 | if isinstance(m, Bottleneck): 170 | nn.init.constant_(m.bn3.weight, 0) 171 | elif isinstance(m, BasicBlock): 172 | nn.init.constant_(m.bn2.weight, 0) 173 | 174 | def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None): 175 | if norm_layer is None: 176 | norm_layer = nn.BatchNorm2d 177 | downsample = None 178 | if stride != 1 or self.inplanes != planes * block.expansion: 179 | downsample = nn.Sequential( 180 | conv1x1(self.inplanes, planes * block.expansion, stride), 181 | norm_layer(planes * block.expansion), 182 | ) 183 | 184 | layers = [] 185 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 186 | self.base_width, norm_layer)) 187 | self.inplanes = planes * block.expansion 188 | for _ in range(1, blocks): 189 | layers.append(block(self.inplanes, planes, groups=self.groups, 190 | base_width=self.base_width, norm_layer=norm_layer)) 191 | 192 | return nn.Sequential(*layers) 193 | 194 | def forward(self, x): 195 | x = self.conv1(x) 196 | x = self.bn1(x) 197 | x = self.relu(x) 198 | x = self.maxpool(x) 199 | 200 | x = self.layer1(x) 201 | x = self.layer2(x) 202 | x = self.layer3(x) 203 | x = self.layer4(x) 204 | 205 | x = self.avgpool(x) 206 | x = x.view(x.size(0), -1) 207 | prob = self.fc(x) 208 | 209 | return x, prob 210 | 211 | class CIFAR_ResNet(nn.Module): 212 | def __init__(self, block, num_blocks, num_classes=10, bias=True): 213 | super(CIFAR_ResNet, self).__init__() 214 | self.in_planes = 64 215 | self.conv1 = conv3x3(3,64) 216 | self.bn1 = nn.BatchNorm2d(64) 217 | self.relu = nn.ReLU(inplace=True) 218 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 219 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 220 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 221 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 222 | self.gap = nn.AvgPool2d(4) 223 | self.linear = nn.Linear(512*block.expansion, num_classes, bias=bias) 224 | 225 | 226 | def _make_layer(self, block, planes, num_blocks, stride): 227 | strides = [stride] + [1]*(num_blocks-1) 228 | layers = [] 229 | for stride in strides: 230 | layers.append(block(self.in_planes, planes, stride)) 231 | self.in_planes = planes * block.expansion 232 | return nn.Sequential(*layers) 233 | 234 | def forward(self, x, lin=0, lout=5): 235 | out = x 236 | out = self.conv1(out) 237 | out = self.bn1(out) 238 | out = self.relu(out) 239 | out1 = self.layer1(out) 240 | out2 = self.layer2(out1) 241 | out3 = self.layer3(out2) 242 | out = self.layer4(out3) 243 | #out = F.avg_pool2d(out, 4) 244 | out = self.gap(out) 245 | out4 = out.view(out.size(0), -1) 246 | out = self.linear(out4) 247 | 248 | return out4, out 249 | 250 | 251 | def CIFAR_ResNet10(pretrained=False, **kwargs): 252 | return CIFAR_ResNet(PreActBlock, [1,1,1,1], **kwargs) 253 | 254 | def CIFAR_ResNet18(pretrained=False, **kwargs): 255 | return CIFAR_ResNet(PreActBlock, [2,2,2,2], **kwargs) 256 | 257 | def CIFAR_ResNet34(pretrained=False, **kwargs): 258 | return CIFAR_ResNet(PreActBlock, [3,4,6,3], **kwargs) 259 | 260 | def resnet10(pretrained=False, **kwargs): 261 | """Constructs a ResNet-10 model. 
262 | Args: 263 | pretrained (bool): If True, returns a model pre-trained on ImageNet 264 | """ 265 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 266 | return model 267 | 268 | def resnet18(pretrained=False, **kwargs): 269 | """Constructs a ResNet-18 model. 270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | """ 273 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 274 | if pretrained: 275 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 276 | return model 277 | 278 | 279 | # def my_resnet34(pretrained=False, **kwargs): 280 | #     model = my_ResNet(BasicBlock, [3, 4, 6, 3, 3], **kwargs)  # `my_ResNet` is not defined anywhere in this repo 281 | #     return model 282 | 283 | def resnet34(pretrained=False, **kwargs): 284 | """Constructs a ResNet-34 model. 285 | Args: 286 | pretrained (bool): If True, returns a model pre-trained on ImageNet 287 | """ 288 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 289 | if pretrained: 290 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 291 | return model 292 | 293 | 294 | def resnet50(pretrained=False, **kwargs): 295 | """Constructs a ResNet-50 model. 296 | Args: 297 | pretrained (bool): If True, returns a model pre-trained on ImageNet 298 | """ 299 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 300 | if pretrained: 301 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 302 | return model 303 | 304 | 305 | def resnet101(pretrained=False, **kwargs): 306 | """Constructs a ResNet-101 model. 307 | Args: 308 | pretrained (bool): If True, returns a model pre-trained on ImageNet 309 | """ 310 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 311 | if pretrained: 312 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 313 | return model 314 | 315 | 316 | def resnet152(pretrained=False, **kwargs): 317 | """Constructs a ResNet-152 model. 
318 | Args: 319 | pretrained (bool): If True, returns a model pre-trained on ImageNet 320 | """ 321 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 322 | if pretrained: 323 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 324 | return model 325 | 326 | 327 | def resnext50_32x4d(pretrained=False, **kwargs): 328 | model = ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs) 329 | # if pretrained: 330 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext50_32x4d'])) 331 | return model 332 | 333 | 334 | def resnext101_32x8d(pretrained=False, **kwargs): 335 | model = ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs) 336 | # if pretrained: 337 | # model.load_state_dict(model_zoo.load_url(model_urls['resnext101_32x8d'])) 338 | return model 339 | -------------------------------------------------------------------------------- /small_scale/scripts/train_bake.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | GPU=$1 3 | DATA=$2 4 | ARCH=$3 5 | N=$4 6 | M=$5 7 | OMEGA=$6 8 | 9 | if [ $# -ne 6 ] 10 | then 11 | echo "Arguments error: expected GPU DATA ARCH N M OMEGA" 12 | exit 1 13 | fi 14 | 15 | python train.py \ 16 | --lr 0.1 \ 17 | --decay 1e-4 \ 18 | --epoch 200 \ 19 | --lamda 1.0 \ 20 | --temp 4.0 \ 21 | --sgpu $GPU \ 22 | -d $DATA \ 23 | -a $ARCH \ 24 | -n $N \ 25 | -m $M \ 26 | --omega $OMEGA \ 27 | --name BAKE_${DATA}_${ARCH} 28 | -------------------------------------------------------------------------------- /small_scale/scripts/train_baseline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | GPU=$1 3 | DATA=$2 4 | ARCH=$3 5 | N=$4 6 | 7 | if [ $# -ne 4 ] 8 | then 9 | echo "Arguments error: expected GPU DATA ARCH N" 10 | exit 1 11 | fi 12 | 13 | python train.py \ 14 | --lr 0.1 \ 15 | --decay 1e-4 \ 16 | --epoch 200 \ 17 | --lamda 0.0 \ 18 | -m 0 \ 19 | --omega 0.0 \ 20 | --sgpu $GPU \ 21 | -d $DATA \ 22 | -a $ARCH \ 23 | -n $N \ 24 | --name baseline_${DATA}_${ARCH} 25 | -------------------------------------------------------------------------------- /small_scale/scripts/val.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | GPU=$1 3 | DATA=$2 4 | ARCH=$3 5 | CKPT=$4 6 | 7 | if [ $# -ne 4 ] 8 | then 9 | echo "Arguments error: expected GPU DATA ARCH CKPT" 10 | exit 1 11 | fi 12 | 13 | python train.py \ 14 | --eval \ 15 | --resume $CKPT \ 16 | --sgpu $GPU \ 17 | -d $DATA \ 18 | -a $ARCH \ 19 | -n 64 \ 20 | -m 0 \ 21 | --name test 22 | -------------------------------------------------------------------------------- /small_scale/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import csv 5 | import os, logging 6 | import random 7 | 8 | import numpy as np 9 | import torch 10 | from torch.autograd import Variable, grad 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | import torchvision.transforms as transforms 16 | 17 | import models 18 | from utils import progress_bar, set_logging_defaults 19 | from datasets import load_dataset 20 | 21 | parser = argparse.ArgumentParser(description='Small-scale Datasets Training') 22 | 23 | parser.add_argument('--name', default='cifar_res18_train', type=str, help='name of experiment') 24 | parser.add_argument('--seed', default=0, type=int, help='random seed') 25 | parser.add_argument('--arch', '-a', 
default="CIFAR_ResNet18", type=str, help='model type (32x32 inputs: CIFAR_ResNet18, CIFAR_DenseNet121; 224x224 inputs: resnet18, densenet121)') 26 | parser.add_argument('--resume', '-r', default="", help='resume from checkpoint') 27 | parser.add_argument('--eval', action='store_true', help='only evaluate') 28 | parser.add_argument('--sgpu', default=0, type=int, help='gpu index (start)') 29 | parser.add_argument('--ngpu', default=1, type=int, help='number of gpu') 30 | parser.add_argument('--dataroot', default='./data', type=str, help='data directory') 31 | parser.add_argument('--saveroot', default='./results', type=str, help='save directory') 32 | parser.add_argument('--dataset', '-d', default='cifar100', type=str, help='dataset name: cifar100 | tinyimagenet | CUB200 | STANFORD120 | MIT67') 33 | 34 | parser.add_argument('--epoch', default=200, type=int, help='total epochs to run') 35 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 36 | parser.add_argument('--decay', default=1e-4, type=float, help='weight decay') 37 | parser.add_argument('--batch-size', '-n', default=128, type=int, help='batch size, N') 38 | parser.add_argument('--intra-imgs', '-m', default=3, type=int, help='intra-class images, M') 39 | 40 | parser.add_argument('--temp', default=4.0, type=float, help='temperature scaling') 41 | parser.add_argument('--lamda', default=1.0, type=float, help='kd loss weight ratio') 42 | parser.add_argument('--omega', default=0.5, type=float, help='ensembling weight') 43 | 44 | args = parser.parse_args() 45 | use_cuda = torch.cuda.is_available() 46 | args.num_instances = args.intra_imgs + 1  # each sampled group holds one anchor image plus its M intra-class neighbors 47 | args.batch_size = args.batch_size // args.num_instances  # the sampler yields groups of num_instances images, keeping roughly N images per batch 48 | 49 | best_val = 0 # best validation accuracy 50 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 51 | 52 | random.seed(args.seed) 53 | np.random.seed(args.seed) 54 | torch.manual_seed(args.seed) 55 | torch.cuda.manual_seed_all(args.seed) 56 | 57 | cudnn.benchmark = True 58 | 59 | # Data 60 | print('==> Preparing dataset: {}'.format(args.dataset)) 61 | trainloader, valloader = load_dataset(args.dataset, args.dataroot, 62 | batch_size=args.batch_size, 63 | num_instances=args.num_instances) 64 | 65 | num_class = trainloader.dataset.num_classes 66 | print('Number of training samples:', len(trainloader.dataset)) 67 | print('Number of validation samples:', len(valloader.dataset)) 68 | 69 | # Model 70 | print('==> Building model: {}'.format(args.arch)) 71 | net = models.load_model(args.arch, num_class) 72 | 73 | if use_cuda: 74 | torch.cuda.set_device(args.sgpu) 75 | net.cuda() 76 | print('Visible CUDA devices:', torch.cuda.device_count()) 77 | print('Using CUDA..') 78 | 79 | if args.ngpu > 1: 80 | net = torch.nn.DataParallel(net, device_ids=list(range(args.sgpu, args.sgpu + args.ngpu))) 81 | 82 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.decay) 83 | 84 | logdir = os.path.join(args.saveroot, args.dataset, args.arch, args.name) 85 | set_logging_defaults(logdir, args) 86 | logger = logging.getLogger('main') 87 | logname = os.path.join(logdir, 'log.csv') 88 | 89 | # Resume 90 | if args.resume: 91 | # Load checkpoint. 
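# (ckpt.t7 is the dict written by checkpoint() further below:
#  'net', 'optimizer', 'acc', 'epoch' and 'rng_state'.)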
92 | print('==> Resuming from checkpoint..') 93 | checkpoint = torch.load(args.resume) 94 | net.load_state_dict(checkpoint['net']) 95 | optimizer.load_state_dict(checkpoint['optimizer']) 96 | best_val = checkpoint['acc']  # restore the best validation accuracy tracked by val() 97 | start_epoch = checkpoint['epoch'] + 1 98 | rng_state = checkpoint['rng_state'] 99 | torch.set_rng_state(rng_state) 100 | 101 | criterion = nn.CrossEntropyLoss() 102 | 103 | class KDLoss(nn.Module): 104 | def __init__(self, temp_factor): 105 | super(KDLoss, self).__init__() 106 | self.temp_factor = temp_factor 107 | self.kl_div = nn.KLDivLoss(reduction="sum") 108 | 109 | def forward(self, input, target): 110 | log_p = torch.log_softmax(input/self.temp_factor, dim=1) 111 | loss = self.kl_div(log_p, target)*(self.temp_factor**2)/input.size(0)  # T^2 rescaling keeps gradients comparable across temperatures; sum-KL averaged over the batch 112 | return loss 113 | 114 | kdloss = KDLoss(args.temp) 115 | softmax = nn.Softmax(dim=1) 116 | 117 | def knowledge_ensemble(feats, logits): 118 | batch_size = logits.size(0) 119 | masks = torch.eye(batch_size) 120 | if use_cuda: 121 | masks = masks.cuda() 122 | feats = nn.functional.normalize(feats, p=2, dim=1)  # so the dot products below are cosine similarities 123 | logits = nn.functional.softmax(logits/args.temp, dim=1)  # softened predictions 124 | W = torch.matmul(feats, feats.permute(1, 0)) - masks * 1e9  # pairwise affinities; -1e9 masks out self-affinity 125 | W = softmax(W)  # row-stochastic affinity matrix A 126 | W = (1 - args.omega) * torch.inverse(masks - args.omega * W)  # closed form of propagating A over infinitely many steps 127 | return torch.matmul(W, logits)  # batch-ensembled soft targets 128 | 129 | def train(epoch): 130 | print('\nEpoch: %d' % epoch) 131 | net.train() 132 | train_ce_loss = 0 133 | correct = 0 134 | total = 0 135 | train_kd_loss = 0 136 | for batch_idx, (inputs, targets) in enumerate(trainloader): 137 | if use_cuda: 138 | inputs, targets = inputs.cuda(), targets.cuda() 139 | 140 | batch_size = inputs.size(0) 141 | 142 | features, outputs = net(inputs) 143 | loss = criterion(outputs, targets) 144 | train_ce_loss += loss.item() 145 | 146 | ############ 147 | with torch.no_grad(): 148 | kd_targets = knowledge_ensemble(features.detach(), outputs.detach())  # teacher side: no gradients flow into the ensembled targets 149 | kd_loss = kdloss(outputs, kd_targets.detach()) 150 | loss += args.lamda * kd_loss 151 | train_kd_loss += kd_loss.item() 152 | ############ 153 | 154 | _, predicted = torch.max(outputs, 1) 155 | total += targets.size(0) 156 | correct += predicted.eq(targets.data).sum().float().cpu() 157 | 158 | optimizer.zero_grad() 159 | loss.backward() 160 | optimizer.step() 161 | progress_bar(batch_idx, len(trainloader), 162 | 'CE loss: %.3f | KD loss: %.3f | Acc: %.3f%% (%d/%d)' 163 | % (train_ce_loss/(batch_idx+1), train_kd_loss/(batch_idx+1), 100.*correct/total, correct, total)) 164 | 165 | logger = logging.getLogger('train') 166 | logger.info('[Epoch {}] [CE loss {:.3f}] [KD loss {:.3f}] [Acc {:.3f}]'.format( 167 | epoch, 168 | train_ce_loss/(batch_idx+1), 169 | train_kd_loss/(batch_idx+1), 170 | 100.*correct/total)) 171 | 172 | return 100.*correct/total 173 | 
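# A minimal sketch of what knowledge_ensemble() computes for one batch (B samples,
# C classes, omega = args.omega, temp = args.temp):
#
#   A  = softmax(feats @ feats.T - 1e9 * I)    # (B, B) affinities, zeroed diagonal
#   W  = (1 - omega) * (I - omega * A)^-1      # = (1 - omega) * sum_k (omega * A)^k,
#                                              #   i.e. affinities propagated over
#                                              #   infinitely many steps (omega < 1)
#   kd = W @ softmax(logits / temp)            # (B, C) batch-ensembled soft targets
#
# Each sample's target is thus a weighted average of the softened predictions of
# its feature-space neighbors in the batch, which train() above distills back
# into the network through KDLoss.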
174 | def val(epoch): 175 | global best_val 176 | net.eval() 177 | val_loss = 0.0 178 | correct = 0.0 179 | total = 0.0 180 | 181 | # Define a data loader for evaluating 182 | loader = valloader 183 | 184 | with torch.no_grad(): 185 | for batch_idx, (inputs, targets) in enumerate(loader): 186 | if use_cuda: 187 | inputs, targets = inputs.cuda(), targets.cuda() 188 | 189 | _, outputs = net(inputs) 190 | loss = torch.mean(criterion(outputs, targets)) 191 | 192 | val_loss += loss.item() 193 | _, predicted = torch.max(outputs, 1) 194 | total += targets.size(0) 195 | correct += predicted.eq(targets.data).cpu().sum().float() 196 | 197 | progress_bar(batch_idx, len(loader), 198 | 'Loss: %.3f | Acc: %.3f%% (%d/%d) ' 199 | % (val_loss/(batch_idx+1), 100.*correct/total, correct, total)) 200 | 201 | acc = 100.*correct/total 202 | if acc > best_val: 203 | best_val = acc 204 | checkpoint(acc, epoch) 205 | logger = logging.getLogger('val') 206 | logger.info('[Epoch {}] [Loss {:.3f}] [Acc {:.3f}] [Best Acc {:.3f}]'.format( 207 | epoch, 208 | val_loss/(batch_idx+1), 209 | acc, best_val)) 210 | 211 | return (val_loss/(batch_idx+1), acc) 212 | 213 | 214 | def checkpoint(acc, epoch): 215 | # Save checkpoint. 216 | print('Saving..') 217 | state = { 218 | 'net': net.state_dict(), 219 | 'optimizer': optimizer.state_dict(), 220 | 'acc': acc, 221 | 'epoch': epoch, 222 | 'rng_state': torch.get_rng_state() 223 | } 224 | torch.save(state, os.path.join(logdir, 'ckpt.t7')) 225 | 226 | 227 | def adjust_learning_rate(optimizer, epoch): 228 | """decay the learning rate by 10x at 50% and 75% of the total epochs (100 and 150 for the default 200)""" 229 | lr = args.lr 230 | if epoch >= 0.5 * args.epoch: 231 | lr /= 10 232 | if epoch >= 0.75 * args.epoch: 233 | lr /= 10 234 | for param_group in optimizer.param_groups: 235 | param_group['lr'] = lr 236 | 237 | if (not args.eval): 238 | # Logs 239 | for epoch in range(start_epoch, args.epoch): 240 | train_acc = train(epoch) 241 | val_loss, val_acc = val(epoch) 242 | adjust_learning_rate(optimizer, epoch) 243 | else: 244 | val_loss, val_acc = val(0) 245 | 246 | print("Best Accuracy: {}".format(best_val)) 247 | logger = logging.getLogger('best') 248 | logger.info('[Acc {:.3f}]'.format(best_val)) 249 | -------------------------------------------------------------------------------- /small_scale/utils.py: -------------------------------------------------------------------------------- 1 | import os, logging 2 | import sys 3 | import time 4 | import math 5 | import shutil 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.init as init 10 | 11 | 12 | def set_logging_defaults(logdir, args): 13 | if os.path.isdir(logdir): 14 | res = input('"{}" exists. Overwrite [Y/n]? '.format(logdir)) 15 | if res != 'Y': 16 | raise Exception('"{}" exists.'.format(logdir)) 17 | else: 18 | os.makedirs(logdir) 19 | 20 | # set basic configuration for logging 21 | logging.basicConfig(format="[%(asctime)s] [%(name)s] %(message)s", 22 | level=logging.INFO, 23 | handlers=[logging.FileHandler(os.path.join(logdir, 'log.txt')), 24 | logging.StreamHandler(os.sys.stdout)]) 25 | 26 | # log command-line arguments 27 | logger = logging.getLogger('main') 28 | logger.info(' '.join(os.sys.argv)) 29 | logger.info(args) 30 | 31 | # _, term_width = os.popen('stty size', 'r').read().split() 32 | term_width, _ = shutil.get_terminal_size()  # get_terminal_size() returns (columns, lines); columns come first 33 | term_width = int(term_width) 34 | 35 | TOTAL_BAR_LENGTH = 86. 36 | last_time = time.time() 37 | begin_time = last_time 38 | 39 | def progress_bar(current, total, msg=None): 40 | global last_time, begin_time 41 | if current == 0: 42 | begin_time = time.time() # Reset for new bar. 
43 | 44 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 45 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 46 | 47 | sys.stdout.write(' [') 48 | for i in range(cur_len): 49 | sys.stdout.write('=') 50 | sys.stdout.write('>') 51 | for i in range(rest_len): 52 | sys.stdout.write('.') 53 | sys.stdout.write(']') 54 | 55 | cur_time = time.time() 56 | step_time = cur_time - last_time 57 | last_time = cur_time 58 | tot_time = cur_time - begin_time 59 | 60 | L = [] 61 | L.append(' Step: %s' % format_time(step_time)) 62 | L.append(' | Tot: %s' % format_time(tot_time)) 63 | if msg: 64 | L.append(' | ' + msg) 65 | 66 | msg = ''.join(L) 67 | sys.stdout.write(msg) 68 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 69 | sys.stdout.write(' ') 70 | 71 | # Go back to the center of the bar. 72 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)): 73 | sys.stdout.write('\b') 74 | sys.stdout.write(' %d/%d ' % (current+1, total)) 75 | 76 | if current < total-1: 77 | sys.stdout.write('\r') 78 | else: 79 | sys.stdout.write('\n') 80 | sys.stdout.flush() 81 | 82 | def format_time(seconds): 83 | days = int(seconds / 3600/24) 84 | seconds = seconds - days*3600*24 85 | hours = int(seconds / 3600) 86 | seconds = seconds - hours*3600 87 | minutes = int(seconds / 60) 88 | seconds = seconds - minutes*60 89 | secondsf = int(seconds) 90 | seconds = seconds - secondsf 91 | millis = int(seconds*1000) 92 | 93 | f = '' 94 | i = 1 95 | if days > 0: 96 | f += str(days) + 'D' 97 | i += 1 98 | if hours > 0 and i <= 2: 99 | f += str(hours) + 'h' 100 | i += 1 101 | if minutes > 0 and i <= 2: 102 | f += str(minutes) + 'm' 103 | i += 1 104 | if secondsf > 0 and i <= 2: 105 | f += str(secondsf) + 's' 106 | i += 1 107 | if millis > 0 and i <= 2: 108 | f += str(millis) + 'ms' 109 | i += 1 110 | if f == '': 111 | f = '0ms' 112 | return f 113 | --------------------------------------------------------------------------------
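A quick smoke test for the small-scale models (a minimal sketch; `models.load_model` is assumed to dispatch on the architecture name exactly as train.py calls it, run from small_scale/):

import torch

import models

net = models.load_model('CIFAR_ResNet18', 100)  # e.g. 100 classes for cifar100
net.eval()
x = torch.randn(2, 3, 32, 32)  # CIFAR-sized dummy batch
features, logits = net(x)  # every model here returns (features, logits)
print(features.shape, logits.shape)  # expected: torch.Size([2, 512]) torch.Size([2, 100])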