├── .gitignore
├── LICENSE
├── README.md
├── bake.png
├── imagenet
├── LICENSE
├── README.md
├── configs
│ ├── effnet
│ │ ├── BAKE-EN-B0_dds_8gpu.yaml
│ │ └── EN-B0_dds_8gpu.yaml
│ ├── mobilenet
│ │ ├── BAKE-M-V2-W1_dds_4gpu.yaml
│ │ └── M-V2-W1_dds_4gpu.yaml
│ ├── resnest
│ │ ├── BAKE-S-101_dds_8gpu.yaml
│ │ ├── BAKE-S-50_dds_8gpu.yaml
│ │ ├── S-101_dds_8gpu.yaml
│ │ └── S-50_dds_8gpu.yaml
│ ├── resnet
│ │ ├── BAKE-R-101-1x64d_dds_8gpu.yaml
│ │ ├── BAKE-R-152-1x64d_dds_8gpu.yaml
│ │ ├── BAKE-R-50-1x64d_dds_8gpu.yaml
│ │ ├── R-101-1x64d_dds_8gpu.yaml
│ │ ├── R-152-1x64d_dds_8gpu.yaml
│ │ └── R-50-1x64d_dds_8gpu.yaml
│ └── resnext
│ │ ├── BAKE-X-101-32x4d_dds_8gpu.yaml
│ │ ├── BAKE-X-152-32x4d_dds_8gpu.yaml
│ │ ├── X-101-32x4d_dds_8gpu.yaml
│ │ └── X-152-32x4d_dds_8gpu.yaml
├── pycls
│ ├── __init__.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── benchmark.py
│ │ ├── builders.py
│ │ ├── checkpoint.py
│ │ ├── config.py
│ │ ├── distributed.py
│ │ ├── io.py
│ │ ├── logging.py
│ │ ├── meters.py
│ │ ├── net.py
│ │ ├── optimizer.py
│ │ ├── plotting.py
│ │ ├── timer.py
│ │ └── trainer.py
│ ├── datasets
│ │ ├── __init__.py
│ │ ├── imagenet.py
│ │ ├── loader.py
│ │ ├── sampler.py
│ │ └── transforms.py
│ └── models
│ │ ├── __init__.py
│ │ ├── anynet.py
│ │ ├── effnet.py
│ │ ├── mobilenetv2.py
│ │ ├── resnest.py
│ │ └── resnet.py
├── requirements.txt
├── setup.py
└── tools
│ ├── dist_test.sh
│ ├── dist_train.sh
│ ├── slurm_test.sh
│ ├── slurm_train.sh
│ ├── test_net.py
│ └── train_net.py
└── small_scale
├── README.md
├── data
├── README.md
└── tools
│ ├── cub.py
│ ├── mit67.py
│ ├── stanford_dogs.py
│ └── tinyimagenet.py
├── datasets.py
├── models
├── __init__.py
├── densenet.py
├── densenet3.py
└── resnet.py
├── scripts
├── train_bake.sh
├── train_baseline.sh
└── val.sh
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Shared objects
7 | *.so
8 |
9 | # Distribution / packaging
10 | build/
11 | *.egg-info/
12 | *.egg
13 |
14 | # Temporary files
15 | *.swn
16 | *.swo
17 | *.swp
18 |
19 | # PyCharm
20 | .idea/
21 |
22 | # Mac
23 | .DS_Store
24 |
25 | # Data symlinks
26 | pycls/datasets/data/
27 |
28 | # Other
29 | logs/
30 | scratch*
31 |
32 | *.un~
33 |
34 | arun_log/
35 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Yixiao Ge
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Self-distillation with Batch Knowledge Ensembling Improves ImageNet Classification
2 |
3 |
4 |
5 |

6 |
7 |
8 |
9 | ## Updates
10 |
11 | - [2021-06-08] The implementation of BAKE on small-scale datasets has been added, please refer to [small_scale](small_scale/).
12 | - [2021-06-09] The implementation of BAKE on ImageNet has been added, please refer to [imagenet](imagenet/).
13 |
14 |
15 | ## Citation
16 |
17 | If you find **BAKE** helpful in your research, please consider citing:
18 |
19 | ```
20 | @misc{ge2020bake,
21 | title={Self-distillation with Batch Knowledge Ensembling Improves ImageNet Classification},
22 | author={Yixiao Ge and Ching Lam Choi and Xiao Zhang and Peipei Zhao and Feng Zhu and Rui Zhao and Hongsheng Li},
23 | year={2021},
24 | eprint={2104.13298},
25 | archivePrefix={arXiv},
26 | primaryClass={cs.CV}
27 | }
28 |
--------------------------------------------------------------------------------
/bake.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxgeee/BAKE/07c4f668ea19311d5b50121026e73d2f035d5765/bake.png
--------------------------------------------------------------------------------
/imagenet/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Facebook, Inc. and its affiliates. Modified by Yixiao Ge.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/imagenet/README.md:
--------------------------------------------------------------------------------
1 | # BAKE on ImageNet
2 |
3 | PyTorch implementation of [Self-distillation with Batch Knowledge Ensembling Improves ImageNet Classification](https://arxiv.org/abs/2104.13298) on ImageNet.
4 |
5 |
6 | ## Installation
7 |
8 | - Install Python dependencies:
9 |
10 | ```
11 | pip install -r requirements.txt
12 | ```
13 |
14 | - Set up Python modules:
15 |
16 | ```
17 | python setup.py develop --user
18 | ```
19 |
20 | ## Dataset
21 |
22 | - Download ImageNet with an expected structure:
23 |
24 | ```
25 | imagenet
26 | |_ train
27 | | |_ n01440764
28 | | |_ ...
29 | | |_ n15075141
30 | |_ val
31 | | |_ n01440764
32 | | |_ ...
33 | | |_ n15075141
34 | |_ ...
35 | ```
36 |
37 | - Create a directory containing symlinks:
38 |
39 | ```
40 | mkdir -p pycls/datasets/data
41 | ```
42 |
43 | - Symlink ImageNet:
44 |
45 | ```
46 | ln -s /path/imagenet pycls/datasets/data/imagenet
47 | ```
48 |
49 |
50 | ## BAKE Training
51 |
52 | ```
53 | sh tools/dist_train.sh
54 | ```
55 |
56 | For example,
57 | ```
58 | sh tools/dist_train.sh logs/resnet50_bake configs/resnet/BAKE-R-50-1x64d_dds_8gpu.yaml
59 | ```
60 |
61 | **Note**: you could use `tools/slurm_train.sh` for distributed training with multiple machines.
62 |
63 |
64 | ## Validation
65 |
66 | ```
67 | sh tools/dist_test.sh
68 | ```
69 |
70 | For example,
71 | ```
72 | sh tools/dist_test.sh configs/resnet/BAKE-R-50-1x64d_dds_8gpu.yaml logs/resnet50_bake/checkpoints/model_epoch_0100.pyth
73 | ```
74 |
75 | ## Results
76 |
77 | |architecture|ImageNet top-1 acc.|config|download|
78 | |---|:--:|:--:|:--:|
79 | |ResNet-50|78.0|[config](configs/resnet/BAKE-R-50-1x64d_dds_8gpu.yaml)|[model](https://drive.google.com/file/d/1RJeUxXLHnSc6m3iiaEslVN2hx3Vsbldc/view?usp=sharing)|
80 | |ResNet-101|79.3|[config](configs/resnet/BAKE-R-101-1x64d_dds_8gpu.yaml)||
81 | |ResNet-152|79.6|[config](configs/resnet/BAKE-R-152-1x64d_dds_8gpu.yaml)||
82 | |ResNeSt-50|79.4|[config](configs/resnest/BAKE-S-50_dds_8gpu.yaml)||
83 | |ResNeSt-101|80.4|[config](configs/resnest/BAKE-S-101_dds_8gpu.yaml)||
84 | |ResNeXt-101(32x4d)|79.3|[config](configs/resnext/BAKE-X-101-32x4d_dds_8gpu.yaml)||
85 | |ResNeXt-152(32x4d)|79.7|[config](configs/resnext/BAKE-X-152-32x4d_dds_8gpu.yaml)||
86 | |MobileNet-V2|72.0|[config](configs/mobilenet/BAKE-M-V2-W1_dds_4gpu.yaml)||
87 | |EfficientNet-B0|76.2|[config](configs/effnet/BAKE-EN-B0_dds_8gpu.yaml)||
88 |
89 | ## Thanks
90 | The code is modified from [pycls](https://github.com/facebookresearch/pycls).
91 |
--------------------------------------------------------------------------------
/imagenet/configs/effnet/BAKE-EN-B0_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: effnet
3 | NUM_CLASSES: 1000
4 | EN:
5 | STEM_W: 32
6 | STRIDES: [1, 2, 2, 2, 1, 2, 1]
7 | DEPTHS: [1, 2, 2, 3, 3, 4, 1]
8 | WIDTHS: [16, 24, 40, 80, 112, 192, 320]
9 | EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6]
10 | KERNELS: [3, 3, 5, 3, 5, 5, 3]
11 | HEAD_W: 1280
12 | OPTIM:
13 | LR_POLICY: cos
14 | BASE_LR: 0.8
15 | MAX_EPOCH: 100
16 | MOMENTUM: 0.9
17 | WEIGHT_DECAY: 1e-5
18 | TRAIN:
19 | DATASET: imagenet
20 | IM_SIZE: 224
21 | BATCH_SIZE: 512
22 | INTRA_IMGS: 1
23 | SYNC_BN: True
24 | USE_BAKE: True
25 | OMEGA: 0.5
26 | TEMP: 4.0
27 | LAMBDA: 1.0
28 | GLOBAL_KE: True
29 | TEST:
30 | DATASET: imagenet
31 | IM_SIZE: 256
32 | BATCH_SIZE: 400
33 |
--------------------------------------------------------------------------------
/imagenet/configs/effnet/EN-B0_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: effnet
3 | NUM_CLASSES: 1000
4 | EN:
5 | STEM_W: 32
6 | STRIDES: [1, 2, 2, 2, 1, 2, 1]
7 | DEPTHS: [1, 2, 2, 3, 3, 4, 1]
8 | WIDTHS: [16, 24, 40, 80, 112, 192, 320]
9 | EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6]
10 | KERNELS: [3, 3, 5, 3, 5, 5, 3]
11 | HEAD_W: 1280
12 | OPTIM:
13 | LR_POLICY: cos
14 | BASE_LR: 0.4
15 | MAX_EPOCH: 100
16 | MOMENTUM: 0.9
17 | WEIGHT_DECAY: 1e-5
18 | TRAIN:
19 | DATASET: imagenet
20 | IM_SIZE: 224
21 | BATCH_SIZE: 256
22 | TEST:
23 | DATASET: imagenet
24 | IM_SIZE: 256
25 | BATCH_SIZE: 200
26 |
--------------------------------------------------------------------------------
/imagenet/configs/mobilenet/BAKE-M-V2-W1_dds_4gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: mobilenetv2
3 | NUM_CLASSES: 1000
4 | OPTIM:
5 | LR_POLICY: cos
6 | BASE_LR: 0.1
7 | MAX_EPOCH: 100
8 | MOMENTUM: 0.9
9 | WEIGHT_DECAY: 4e-5
10 | TRAIN:
11 | DATASET: imagenet
12 | IM_SIZE: 224
13 | BATCH_SIZE: 512
14 | INTRA_IMGS: 1
15 | SYNC_BN: True
16 | USE_BAKE: True
17 | OMEGA: 0.5
18 | TEMP: 4.0
19 | LAMBDA: 1.0
20 | GLOBAL_KE: True
21 | TEST:
22 | DATASET: imagenet
23 | IM_SIZE: 256
24 | BATCH_SIZE: 400
25 |
--------------------------------------------------------------------------------
/imagenet/configs/mobilenet/M-V2-W1_dds_4gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: mobilenetv2
3 | NUM_CLASSES: 1000
4 | OPTIM:
5 | LR_POLICY: cos
6 | BASE_LR: 0.05
7 | MAX_EPOCH: 100
8 | MOMENTUM: 0.9
9 | WEIGHT_DECAY: 4e-5
10 | TRAIN:
11 | DATASET: imagenet
12 | IM_SIZE: 224
13 | BATCH_SIZE: 256
14 | TEST:
15 | DATASET: imagenet
16 | IM_SIZE: 256
17 | BATCH_SIZE: 200
18 |
--------------------------------------------------------------------------------
/imagenet/configs/resnest/BAKE-S-101_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: resnest
3 | NUM_CLASSES: 1000
4 | RESNEST:
5 | STEM_W: 64
6 | DEPTHS: [3, 4, 23, 3]
7 | OPTIM:
8 | LR_POLICY: cos
9 | BASE_LR: 0.4
10 | MAX_EPOCH: 100
11 | MOMENTUM: 0.9
12 | WEIGHT_DECAY: 5e-5
13 | TRAIN:
14 | DATASET: imagenet
15 | IM_SIZE: 224
16 | BATCH_SIZE: 512
17 | INTRA_IMGS: 1
18 | SYNC_BN: True
19 | USE_BAKE: True
20 | OMEGA: 0.5
21 | TEMP: 4.0
22 | LAMBDA: 1.0
23 | GLOBAL_KE: True
24 | TEST:
25 | DATASET: imagenet
26 | IM_SIZE: 256
27 | BATCH_SIZE: 400
28 |
--------------------------------------------------------------------------------
/imagenet/configs/resnest/BAKE-S-50_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: resnest
3 | NUM_CLASSES: 1000
4 | RESNEST:
5 | STEM_W: 32
6 | DEPTHS: [3, 4, 6, 3]
7 | OPTIM:
8 | LR_POLICY: cos
9 | BASE_LR: 0.4
10 | MAX_EPOCH: 100
11 | MOMENTUM: 0.9
12 | WEIGHT_DECAY: 5e-5
13 | TRAIN:
14 | DATASET: imagenet
15 | IM_SIZE: 224
16 | BATCH_SIZE: 512
17 | INTRA_IMGS: 1
18 | SYNC_BN: True
19 | USE_BAKE: True
20 | OMEGA: 0.5
21 | TEMP: 4.0
22 | LAMBDA: 1.0
23 | GLOBAL_KE: True
24 | TEST:
25 | DATASET: imagenet
26 | IM_SIZE: 256
27 | BATCH_SIZE: 400
28 |
--------------------------------------------------------------------------------
/imagenet/configs/resnest/S-101_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: resnest
3 | NUM_CLASSES: 1000
4 | RESNEST:
5 | STEM_W: 64
6 | DEPTHS: [3, 4, 23, 3]
7 | OPTIM:
8 | LR_POLICY: cos
9 | BASE_LR: 0.2
10 | MAX_EPOCH: 100
11 | MOMENTUM: 0.9
12 | WEIGHT_DECAY: 5e-5
13 | TRAIN:
14 | DATASET: imagenet
15 | IM_SIZE: 224
16 | BATCH_SIZE: 256
17 | TEST:
18 | DATASET: imagenet
19 | IM_SIZE: 256
20 | BATCH_SIZE: 200
21 |
--------------------------------------------------------------------------------
/imagenet/configs/resnest/S-50_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: resnest
3 | NUM_CLASSES: 1000
4 | RESNEST:
5 | STEM_W: 32
6 | DEPTHS: [3, 4, 6, 3]
7 | OPTIM:
8 | LR_POLICY: cos
9 | BASE_LR: 0.2
10 | MAX_EPOCH: 100
11 | MOMENTUM: 0.9
12 | WEIGHT_DECAY: 5e-5
13 | TRAIN:
14 | DATASET: imagenet
15 | IM_SIZE: 224
16 | BATCH_SIZE: 256
17 | TEST:
18 | DATASET: imagenet
19 | IM_SIZE: 256
20 | BATCH_SIZE: 200
21 |
--------------------------------------------------------------------------------
/imagenet/configs/resnet/BAKE-R-101-1x64d_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: anynet
3 | NUM_CLASSES: 1000
4 | ANYNET:
5 | STEM_TYPE: res_stem_in
6 | STEM_W: 64
7 | BLOCK_TYPE: res_bottleneck_block
8 | STRIDES: [1, 2, 2, 2]
9 | DEPTHS: [3, 4, 23, 3]
10 | WIDTHS: [256, 512, 1024, 2048]
11 | BOT_MULS: [0.25, 0.25, 0.25, 0.25]
12 | GROUP_WS: [64, 128, 256, 512]
13 | OPTIM:
14 | LR_POLICY: cos
15 | BASE_LR: 0.4
16 | MAX_EPOCH: 100
17 | MOMENTUM: 0.9
18 | WEIGHT_DECAY: 5e-5
19 | TRAIN:
20 | DATASET: imagenet
21 | IM_SIZE: 224
22 | BATCH_SIZE: 512
23 | INTRA_IMGS: 1
24 | SYNC_BN: True
25 | USE_BAKE: True
26 | OMEGA: 0.5
27 | TEMP: 4.0
28 | LAMBDA: 1.0
29 | GLOBAL_KE: True
30 | TEST:
31 | DATASET: imagenet
32 | IM_SIZE: 256
33 | BATCH_SIZE: 400
34 |
--------------------------------------------------------------------------------
/imagenet/configs/resnet/BAKE-R-152-1x64d_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: anynet
3 | NUM_CLASSES: 1000
4 | ANYNET:
5 | STEM_TYPE: res_stem_in
6 | STEM_W: 64
7 | BLOCK_TYPE: res_bottleneck_block
8 | STRIDES: [1, 2, 2, 2]
9 | DEPTHS: [3, 8, 36, 3]
10 | WIDTHS: [256, 512, 1024, 2048]
11 | BOT_MULS: [0.25, 0.25, 0.25, 0.25]
12 | GROUP_WS: [64, 128, 256, 512]
13 | OPTIM:
14 | LR_POLICY: cos
15 | BASE_LR: 0.4
16 | MAX_EPOCH: 100
17 | MOMENTUM: 0.9
18 | WEIGHT_DECAY: 5e-5
19 | TRAIN:
20 | DATASET: imagenet
21 | IM_SIZE: 224
22 | BATCH_SIZE: 512
23 | INTRA_IMGS: 1
24 | SYNC_BN: True
25 | USE_BAKE: True
26 | OMEGA: 0.5
27 | TEMP: 4.0
28 | LAMBDA: 1.0
29 | GLOBAL_KE: True
30 | TEST:
31 | DATASET: imagenet
32 | IM_SIZE: 256
33 | BATCH_SIZE: 400
34 |
--------------------------------------------------------------------------------
/imagenet/configs/resnet/BAKE-R-50-1x64d_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: anynet
3 | NUM_CLASSES: 1000
4 | ANYNET:
5 | STEM_TYPE: res_stem_in
6 | STEM_W: 64
7 | BLOCK_TYPE: res_bottleneck_block
8 | STRIDES: [1, 2, 2, 2]
9 | DEPTHS: [3, 4, 6, 3]
10 | WIDTHS: [256, 512, 1024, 2048]
11 | BOT_MULS: [0.25, 0.25, 0.25, 0.25]
12 | GROUP_WS: [64, 128, 256, 512]
13 | OPTIM:
14 | LR_POLICY: cos
15 | BASE_LR: 0.4
16 | MAX_EPOCH: 100
17 | MOMENTUM: 0.9
18 | WEIGHT_DECAY: 5e-5
19 | TRAIN:
20 | DATASET: imagenet
21 | IM_SIZE: 224
22 | BATCH_SIZE: 512
23 | INTRA_IMGS: 1
24 | SYNC_BN: True
25 | USE_BAKE: True
26 | OMEGA: 0.5
27 | TEMP: 4.0
28 | LAMBDA: 1.0
29 | GLOBAL_KE: True
30 | TEST:
31 | DATASET: imagenet
32 | IM_SIZE: 256
33 | BATCH_SIZE: 400
34 |
--------------------------------------------------------------------------------
/imagenet/configs/resnet/R-101-1x64d_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: anynet
3 | NUM_CLASSES: 1000
4 | ANYNET:
5 | STEM_TYPE: res_stem_in
6 | STEM_W: 64
7 | BLOCK_TYPE: res_bottleneck_block
8 | STRIDES: [1, 2, 2, 2]
9 | DEPTHS: [3, 4, 23, 3]
10 | WIDTHS: [256, 512, 1024, 2048]
11 | BOT_MULS: [0.25, 0.25, 0.25, 0.25]
12 | GROUP_WS: [64, 128, 256, 512]
13 | OPTIM:
14 | LR_POLICY: cos
15 | BASE_LR: 0.2
16 | MAX_EPOCH: 100
17 | MOMENTUM: 0.9
18 | WEIGHT_DECAY: 5e-5
19 | TRAIN:
20 | DATASET: imagenet
21 | IM_SIZE: 224
22 | BATCH_SIZE: 256
23 | TEST:
24 | DATASET: imagenet
25 | IM_SIZE: 256
26 | BATCH_SIZE: 200
27 |
--------------------------------------------------------------------------------
/imagenet/configs/resnet/R-152-1x64d_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: anynet
3 | NUM_CLASSES: 1000
4 | ANYNET:
5 | STEM_TYPE: res_stem_in
6 | STEM_W: 64
7 | BLOCK_TYPE: res_bottleneck_block
8 | STRIDES: [1, 2, 2, 2]
9 | DEPTHS: [3, 8, 36, 3]
10 | WIDTHS: [256, 512, 1024, 2048]
11 | BOT_MULS: [0.25, 0.25, 0.25, 0.25]
12 | GROUP_WS: [64, 128, 256, 512]
13 | OPTIM:
14 | LR_POLICY: cos
15 | BASE_LR: 0.2
16 | MAX_EPOCH: 100
17 | MOMENTUM: 0.9
18 | WEIGHT_DECAY: 5e-5
19 | TRAIN:
20 | DATASET: imagenet
21 | IM_SIZE: 224
22 | BATCH_SIZE: 256
23 | TEST:
24 | DATASET: imagenet
25 | IM_SIZE: 256
26 | BATCH_SIZE: 200
27 |
--------------------------------------------------------------------------------
/imagenet/configs/resnet/R-50-1x64d_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: anynet
3 | NUM_CLASSES: 1000
4 | ANYNET:
5 | STEM_TYPE: res_stem_in
6 | STEM_W: 64
7 | BLOCK_TYPE: res_bottleneck_block
8 | STRIDES: [1, 2, 2, 2]
9 | DEPTHS: [3, 4, 6, 3]
10 | WIDTHS: [256, 512, 1024, 2048]
11 | BOT_MULS: [0.25, 0.25, 0.25, 0.25]
12 | GROUP_WS: [64, 128, 256, 512]
13 | OPTIM:
14 | LR_POLICY: cos
15 | BASE_LR: 0.2
16 | MAX_EPOCH: 100
17 | MOMENTUM: 0.9
18 | WEIGHT_DECAY: 5e-5
19 | TRAIN:
20 | DATASET: imagenet
21 | IM_SIZE: 224
22 | BATCH_SIZE: 256
23 | TEST:
24 | DATASET: imagenet
25 | IM_SIZE: 256
26 | BATCH_SIZE: 200
27 |
--------------------------------------------------------------------------------
/imagenet/configs/resnext/BAKE-X-101-32x4d_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: anynet
3 | NUM_CLASSES: 1000
4 | ANYNET:
5 | STEM_TYPE: res_stem_in
6 | STEM_W: 64
7 | BLOCK_TYPE: res_bottleneck_block
8 | STRIDES: [1, 2, 2, 2]
9 | DEPTHS: [3, 4, 23, 3]
10 | WIDTHS: [256, 512, 1024, 2048]
11 | BOT_MULS: [0.5, 0.5, 0.5, 0.5]
12 | GROUP_WS: [4, 8, 16, 32]
13 | OPTIM:
14 | LR_POLICY: cos
15 | BASE_LR: 0.4
16 | MAX_EPOCH: 100
17 | MOMENTUM: 0.9
18 | WEIGHT_DECAY: 5e-5
19 | TRAIN:
20 | DATASET: imagenet
21 | IM_SIZE: 224
22 | BATCH_SIZE: 512
23 | INTRA_IMGS: 1
24 | SYNC_BN: True
25 | USE_BAKE: True
26 | OMEGA: 0.5
27 | TEMP: 4.0
28 | LAMBDA: 1.0
29 | GLOBAL_KE: True
30 | TEST:
31 | DATASET: imagenet
32 | IM_SIZE: 256
33 | BATCH_SIZE: 400
34 |
--------------------------------------------------------------------------------
/imagenet/configs/resnext/BAKE-X-152-32x4d_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: anynet
3 | NUM_CLASSES: 1000
4 | ANYNET:
5 | STEM_TYPE: res_stem_in
6 | STEM_W: 64
7 | BLOCK_TYPE: res_bottleneck_block
8 | STRIDES: [1, 2, 2, 2]
9 | DEPTHS: [3, 8, 36, 3]
10 | WIDTHS: [256, 512, 1024, 2048]
11 | BOT_MULS: [0.5, 0.5, 0.5, 0.5]
12 | GROUP_WS: [4, 8, 16, 32]
13 | OPTIM:
14 | LR_POLICY: cos
15 | BASE_LR: 0.4
16 | MAX_EPOCH: 100
17 | MOMENTUM: 0.9
18 | WEIGHT_DECAY: 5e-5
19 | TRAIN:
20 | DATASET: imagenet
21 | IM_SIZE: 224
22 | BATCH_SIZE: 512
23 | INTRA_IMGS: 1
24 | SYNC_BN: True
25 | USE_BAKE: True
26 | OMEGA: 0.5
27 | TEMP: 4.0
28 | LAMBDA: 1.0
29 | GLOBAL_KE: True
30 | TEST:
31 | DATASET: imagenet
32 | IM_SIZE: 256
33 | BATCH_SIZE: 400
34 |
--------------------------------------------------------------------------------
/imagenet/configs/resnext/X-101-32x4d_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: anynet
3 | NUM_CLASSES: 1000
4 | ANYNET:
5 | STEM_TYPE: res_stem_in
6 | STEM_W: 64
7 | BLOCK_TYPE: res_bottleneck_block
8 | STRIDES: [1, 2, 2, 2]
9 | DEPTHS: [3, 4, 23, 3]
10 | WIDTHS: [256, 512, 1024, 2048]
11 | BOT_MULS: [0.5, 0.5, 0.5, 0.5]
12 | GROUP_WS: [4, 8, 16, 32]
13 | OPTIM:
14 | LR_POLICY: cos
15 | BASE_LR: 0.2
16 | MAX_EPOCH: 100
17 | MOMENTUM: 0.9
18 | WEIGHT_DECAY: 5e-5
19 | TRAIN:
20 | DATASET: imagenet
21 | IM_SIZE: 224
22 | BATCH_SIZE: 256
23 | TEST:
24 | DATASET: imagenet
25 | IM_SIZE: 256
26 | BATCH_SIZE: 200
27 |
--------------------------------------------------------------------------------
/imagenet/configs/resnext/X-152-32x4d_dds_8gpu.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: anynet
3 | NUM_CLASSES: 1000
4 | ANYNET:
5 | STEM_TYPE: res_stem_in
6 | STEM_W: 64
7 | BLOCK_TYPE: res_bottleneck_block
8 | STRIDES: [1, 2, 2, 2]
9 | DEPTHS: [3, 8, 36, 3]
10 | WIDTHS: [256, 512, 1024, 2048]
11 | BOT_MULS: [0.5, 0.5, 0.5, 0.5]
12 | GROUP_WS: [4, 8, 16, 32]
13 | OPTIM:
14 | LR_POLICY: cos
15 | BASE_LR: 0.2
16 | MAX_EPOCH: 100
17 | MOMENTUM: 0.9
18 | WEIGHT_DECAY: 5e-5
19 | TRAIN:
20 | DATASET: imagenet
21 | IM_SIZE: 224
22 | BATCH_SIZE: 256
23 | TEST:
24 | DATASET: imagenet
25 | IM_SIZE: 256
26 | BATCH_SIZE: 200
27 |
--------------------------------------------------------------------------------
/imagenet/pycls/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxgeee/BAKE/07c4f668ea19311d5b50121026e73d2f035d5765/imagenet/pycls/__init__.py
--------------------------------------------------------------------------------
/imagenet/pycls/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxgeee/BAKE/07c4f668ea19311d5b50121026e73d2f035d5765/imagenet/pycls/core/__init__.py
--------------------------------------------------------------------------------
/imagenet/pycls/core/benchmark.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Benchmarking functions."""
9 |
10 | import pycls.core.logging as logging
11 | import pycls.datasets.loader as loader
12 | import torch
13 | from pycls.core.config import cfg
14 | from pycls.core.timer import Timer
15 |
16 |
17 | logger = logging.get_logger(__name__)
18 |
19 |
@torch.no_grad()
def compute_time_eval(model):
    """Computes precise model forward test time using dummy data."""
    # Evaluation mode disables dropout and freezes BN statistics
    model.eval()
    # Build a zero-filled dummy batch directly on the GPU.
    # NOTE(review): uses TRAIN.IM_SIZE with TEST.BATCH_SIZE -- looks deliberate
    # (timing at the training resolution), but confirm.
    batch_size = int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
    im_size = cfg.TRAIN.IM_SIZE
    inputs = torch.zeros(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
    # Time the forward pass, discarding the warmup iterations
    timer = Timer()
    num_warmup = cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(cfg.PREC_TIME.NUM_ITER + num_warmup):
        # Drop warmup measurements once the steady state is reached
        if cur_iter == num_warmup:
            timer.reset()
        timer.tic()
        model(inputs)
        # Wait for the GPU kernels to finish before stopping the clock
        torch.cuda.synchronize()
        timer.toc()
    return timer.average_time
41 |
42 |
def compute_time_train(model, loss_fun):
    """Computes precise model forward + backward time using dummy data.

    Returns a (forward_time, backward_time) pair of average per-iteration
    seconds, measured after a warmup phase. BatchNorm running statistics are
    cached before timing and restored afterwards so the dummy batches do not
    pollute them.
    """
    # Use train mode
    model.train()
    # Generate a dummy mini-batch and copy data to GPU
    im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
    inputs = torch.rand(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
    labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
    # Cache BatchNorm2D running stats (cloned so in-place updates don't alias)
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
    # Compute precise forward backward pass time
    fw_timer, bw_timer = Timer(), Timer()
    total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
    for cur_iter in range(total_iter):
        # Reset the timers after the warmup phase
        if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
            fw_timer.reset()
            bw_timer.reset()
        # Forward (synchronize so async GPU kernels finish before toc)
        fw_timer.tic()
        _, preds = model(inputs)  # model returns a 2-tuple; second item feeds the loss -- TODO confirm contract
        loss = loss_fun(preds, labels)
        torch.cuda.synchronize()
        fw_timer.toc()
        # Backward
        bw_timer.tic()
        loss.backward()
        torch.cuda.synchronize()
        bw_timer.toc()
    # Restore BatchNorm2D running stats
    for bn, (mean, var) in zip(bns, bn_stats):
        bn.running_mean, bn.running_var = mean, var
    return fw_timer.average_time, bw_timer.average_time
77 |
78 |
def compute_time_loader(data_loader):
    """Computes loader time."""
    timer = Timer()
    # Shuffle first so the timing reflects a fresh epoch ordering
    loader.shuffle(data_loader, 0)
    batches = iter(data_loader)
    warmup = cfg.PREC_TIME.WARMUP_ITER
    # Never iterate past the end of the loader
    num_iters = min(cfg.PREC_TIME.NUM_ITER + warmup, len(data_loader))
    for i in range(num_iters):
        # Discard warmup measurements
        if i == warmup:
            timer.reset()
        timer.tic()
        next(batches)
        timer.toc()
    return timer.average_time
93 |
94 |
def compute_time_full(model, loss_fun, train_loader, test_loader):
    """Times model and data loader, logging per-iteration and per-epoch numbers."""
    logger.info("Computing model and loader timings...")
    # Per-iteration timings
    test_fw_time = compute_time_eval(model)
    train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
    train_fw_bw_time = train_fw_time + train_bw_time
    train_loader_time = compute_time_loader(train_loader)
    # Output iter timing
    iter_times = {
        "test_fw_time": test_fw_time,
        "train_fw_time": train_fw_time,
        "train_bw_time": train_bw_time,
        "train_fw_bw_time": train_fw_bw_time,
        "train_loader_time": train_loader_time,
    }
    logger.info(logging.dump_log_data(iter_times, "iter_times"))
    # Output epoch timing (scale per-iteration numbers by loader lengths)
    num_train_iters = len(train_loader)
    num_test_iters = len(test_loader)
    epoch_times = {
        "test_fw_time": test_fw_time * num_test_iters,
        "train_fw_time": train_fw_time * num_train_iters,
        "train_bw_time": train_bw_time * num_train_iters,
        "train_fw_bw_time": train_fw_bw_time * num_train_iters,
        "train_loader_time": train_loader_time * num_train_iters,
    }
    logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
    # Compute data loader overhead (assuming DATA_LOADER.NUM_WORKERS>1)
    overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time
    logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))
124 |
--------------------------------------------------------------------------------
/imagenet/pycls/core/builders.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Model and loss construction functions."""
9 |
10 | import torch
11 | from pycls.core.config import cfg
12 | from pycls.models.anynet import AnyNet
13 | from pycls.models.effnet import EffNet
14 | from pycls.models.resnet import ResNet
15 | from pycls.models.mobilenetv2 import MobileNetV2
16 | from pycls.models.resnest import ResNeSt
17 |
18 |
# Supported models: maps cfg.MODEL.TYPE strings to model constructors
_models = {
    "anynet": AnyNet,
    "effnet": EffNet,
    "resnet": ResNet,
    "mobilenetv2": MobileNetV2,
    "resnest": ResNeSt,
}

# Supported loss functions: maps cfg.MODEL.LOSS_FUN strings to loss constructors
_loss_funs = {"cross_entropy": torch.nn.CrossEntropyLoss}
30 |
31 |
def get_model():
    """Gets the model class specified in the config."""
    model_type = cfg.MODEL.TYPE
    # Fail loudly on an unknown type rather than a later KeyError
    assert model_type in _models.keys(), "Model type '{}' not supported".format(model_type)
    return _models[model_type]
37 |
38 |
def get_loss_fun():
    """Gets the loss function class specified in the config.

    Raises:
        AssertionError: if cfg.MODEL.LOSS_FUN names an unsupported loss.
    """
    err_str = "Loss function type '{}' not supported"
    # Bug fix: the error message previously formatted cfg.TRAIN.LOSS, which is
    # not the key being validated; report the actual offending value instead.
    assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), err_str.format(cfg.MODEL.LOSS_FUN)
    return _loss_funs[cfg.MODEL.LOSS_FUN]
44 |
45 |
def build_model():
    """Builds the model."""
    model_cls = get_model()
    return model_cls()
49 |
50 |
def build_loss_fun():
    """Build the loss function."""
    loss_cls = get_loss_fun()
    return loss_cls()
54 |
55 |
def register_model(name, ctor):
    """Registers a model dynamically under the given config name."""
    _models.update({name: ctor})
59 |
60 |
def register_loss_fun(name, ctor):
    """Registers a loss function dynamically under the given config name."""
    _loss_funs.update({name: ctor})
64 |
--------------------------------------------------------------------------------
/imagenet/pycls/core/checkpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Functions that handle saving and loading of checkpoints."""
9 |
10 | import os
11 |
12 | import pycls.core.distributed as dist
13 | import torch
14 | from pycls.core.config import cfg
15 |
16 |
# Common prefix for checkpoint file names, e.g. "model_epoch_0100.pyth"
_NAME_PREFIX = "model_epoch_"
# Checkpoints directory name (created under cfg.OUT_DIR)
_DIR_NAME = "checkpoints"
21 |
22 |
def get_checkpoint_dir():
    """Returns the directory where checkpoints are stored (<OUT_DIR>/checkpoints)."""
    checkpoint_dir = os.path.join(cfg.OUT_DIR, _DIR_NAME)
    return checkpoint_dir
26 |
27 |
def get_checkpoint(epoch):
    """Retrieves the path to the checkpoint file for the given epoch."""
    # Zero-pad the epoch so names sort lexicographically by epoch number
    file_name = _NAME_PREFIX + "{:04d}.pyth".format(epoch)
    return os.path.join(get_checkpoint_dir(), file_name)
32 |
33 |
def get_last_checkpoint():
    """Retrieves the most recent checkpoint (highest epoch number)."""
    checkpoint_dir = get_checkpoint_dir()
    # Zero-padded epoch numbers make lexicographic order equal numeric order
    names = [f for f in os.listdir(checkpoint_dir) if _NAME_PREFIX in f]
    return os.path.join(checkpoint_dir, max(names))
41 |
42 |
def has_checkpoint():
    """Returns True iff the checkpoint directory contains any checkpoint file."""
    checkpoint_dir = get_checkpoint_dir()
    if os.path.exists(checkpoint_dir):
        return any(_NAME_PREFIX in f for f in os.listdir(checkpoint_dir))
    # Directory not created yet, so no checkpoints exist
    return False
49 |
50 |
def save_checkpoint(model, optimizer, epoch):
    """Saves a checkpoint of the model and optimizer state.

    Only the master process writes; returns the checkpoint file path on the
    master process and None elsewhere.
    """
    # Save checkpoints only from the master process
    if not dist.is_master_proc():
        return
    # Ensure that the checkpoint dir exists
    os.makedirs(get_checkpoint_dir(), exist_ok=True)
    # Unwrap the DDP wrapper if present; a plain (unwrapped) model is saved
    # as-is instead of crashing on the missing `.module` attribute
    raw_model = getattr(model, "module", model)
    # Record the state
    checkpoint = {
        "epoch": epoch,
        "model_state": raw_model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "cfg": cfg.dump(),
    }
    # Checkpoint is named by the epoch training would resume from
    checkpoint_file = get_checkpoint(epoch + 1)
    torch.save(checkpoint, checkpoint_file)
    return checkpoint_file
72 |
73 |
def load_checkpoint(checkpoint_file, model, optimizer=None):
    """Loads the checkpoint from the given file.

    Restores model weights (and optionally the optimizer state) and returns
    the epoch stored in the checkpoint.
    """
    err_str = "Checkpoint '{}' not found"
    assert os.path.exists(checkpoint_file), err_str.format(checkpoint_file)
    # Load the checkpoint on CPU to avoid GPU mem spike
    checkpoint = torch.load(checkpoint_file, map_location="cpu")
    # Unwrap the DDP wrapper if present; plain models load directly instead
    # of crashing on the missing `.module` attribute
    raw_model = getattr(model, "module", model)
    raw_model.load_state_dict(checkpoint["model_state"])
    # Load the optimizer state (commonly not done when fine-tuning)
    if optimizer:
        optimizer.load_state_dict(checkpoint["optimizer_state"])
    return checkpoint["epoch"]
88 |
--------------------------------------------------------------------------------
/imagenet/pycls/core/config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Configuration file (powered by YACS)."""
9 |
10 | import argparse
11 | import os
12 | import sys
13 |
14 | from pycls.core.io import cache_url
15 | from yacs.config import CfgNode as CfgNode
16 |
17 |
18 | # Global config object
19 | _C = CfgNode()
20 |
21 | # Example usage:
22 | # from core.config import cfg
23 | cfg = _C
24 |
25 |
26 | # ------------------------------------------------------------------------------------ #
27 | # Model options
28 | # ------------------------------------------------------------------------------------ #
29 | _C.MODEL = CfgNode()
30 |
31 | # Model type
32 | _C.MODEL.TYPE = ""
33 |
34 | # Number of weight layers
35 | _C.MODEL.DEPTH = 0
36 |
37 | # Number of classes
38 | _C.MODEL.NUM_CLASSES = 10
39 |
40 | # Loss function (see pycls/models/loss.py for options)
41 | _C.MODEL.LOSS_FUN = "cross_entropy"
42 |
43 |
44 | # ------------------------------------------------------------------------------------ #
45 | # ResNet options
46 | # ------------------------------------------------------------------------------------ #
47 | _C.RESNET = CfgNode()
48 |
49 | # Transformation function (see pycls/models/resnet.py for options)
50 | _C.RESNET.TRANS_FUN = "basic_transform"
51 |
52 | # Number of groups to use (1 -> ResNet; > 1 -> ResNeXt)
53 | _C.RESNET.NUM_GROUPS = 1
54 |
55 | # Width of each group (64 -> ResNet; 4 -> ResNeXt)
56 | _C.RESNET.WIDTH_PER_GROUP = 64
57 |
58 | # Apply stride to 1x1 conv (True -> MSRA; False -> fb.torch)
59 | _C.RESNET.STRIDE_1X1 = True
60 |
61 |
62 | # ------------------------------------------------------------------------------------ #
63 | # AnyNet options
64 | # ------------------------------------------------------------------------------------ #
65 | _C.ANYNET = CfgNode()
66 |
67 | # Stem type
68 | _C.ANYNET.STEM_TYPE = "simple_stem_in"
69 |
70 | # Stem width
71 | _C.ANYNET.STEM_W = 32
72 |
73 | # Block type
74 | _C.ANYNET.BLOCK_TYPE = "res_bottleneck_block"
75 |
76 | # Depth for each stage (number of blocks in the stage)
77 | _C.ANYNET.DEPTHS = []
78 |
79 | # Width for each stage (width of each block in the stage)
80 | _C.ANYNET.WIDTHS = []
81 |
82 | # Strides for each stage (applies to the first block of each stage)
83 | _C.ANYNET.STRIDES = []
84 |
85 | # Bottleneck multipliers for each stage (applies to bottleneck block)
86 | _C.ANYNET.BOT_MULS = []
87 |
88 | # Group widths for each stage (applies to bottleneck block)
89 | _C.ANYNET.GROUP_WS = []
90 |
91 | # Whether SE is enabled for res_bottleneck_block
92 | _C.ANYNET.SE_ON = False
93 |
94 | # SE ratio
95 | _C.ANYNET.SE_R = 0.25
96 |
97 |
98 | # ------------------------------------------------------------------------------------ #
99 | # RegNet options
100 | # ------------------------------------------------------------------------------------ #
101 | _C.REGNET = CfgNode()
102 |
103 | # Stem type
104 | _C.REGNET.STEM_TYPE = "simple_stem_in"
105 |
106 | # Stem width
107 | _C.REGNET.STEM_W = 32
108 |
109 | # Block type
110 | _C.REGNET.BLOCK_TYPE = "res_bottleneck_block"
111 |
112 | # Stride of each stage
113 | _C.REGNET.STRIDE = 2
114 |
115 | # Squeeze-and-Excitation (RegNetY)
116 | _C.REGNET.SE_ON = False
117 | _C.REGNET.SE_R = 0.25
118 |
119 | # Depth
120 | _C.REGNET.DEPTH = 10
121 |
122 | # Initial width
123 | _C.REGNET.W0 = 32
124 |
125 | # Slope
126 | _C.REGNET.WA = 5.0
127 |
128 | # Quantization
129 | _C.REGNET.WM = 2.5
130 |
131 | # Group width
132 | _C.REGNET.GROUP_W = 16
133 |
134 | # Bottleneck multiplier (bm = 1 / b from the paper)
135 | _C.REGNET.BOT_MUL = 1.0
136 |
137 |
138 | # ------------------------------------------------------------------------------------ #
139 | # ResNeSt options
140 | # ------------------------------------------------------------------------------------ #
141 | _C.RESNEST = CfgNode()
142 |
143 | # Depth for each stage (number of blocks in the stage)
144 | _C.RESNEST.DEPTHS = []
145 |
# Stem width
147 | _C.RESNEST.STEM_W = 32
148 |
149 |
150 | # ------------------------------------------------------------------------------------ #
151 | # EfficientNet options
152 | # ------------------------------------------------------------------------------------ #
153 | _C.EN = CfgNode()
154 |
155 | # Stem width
156 | _C.EN.STEM_W = 32
157 |
158 | # Depth for each stage (number of blocks in the stage)
159 | _C.EN.DEPTHS = []
160 |
161 | # Width for each stage (width of each block in the stage)
162 | _C.EN.WIDTHS = []
163 |
164 | # Expansion ratios for MBConv blocks in each stage
165 | _C.EN.EXP_RATIOS = []
166 |
167 | # Squeeze-and-Excitation (SE) ratio
168 | _C.EN.SE_R = 0.25
169 |
170 | # Strides for each stage (applies to the first block of each stage)
171 | _C.EN.STRIDES = []
172 |
173 | # Kernel sizes for each stage
174 | _C.EN.KERNELS = []
175 |
176 | # Head width
177 | _C.EN.HEAD_W = 1280
178 |
179 | # Drop connect ratio
180 | _C.EN.DC_RATIO = 0.0
181 |
182 | # Dropout ratio
183 | _C.EN.DROPOUT_RATIO = 0.0
184 |
185 |
186 | # ------------------------------------------------------------------------------------ #
187 | # Batch norm options
188 | # ------------------------------------------------------------------------------------ #
189 | _C.BN = CfgNode()
190 |
191 | # BN epsilon
192 | _C.BN.EPS = 1e-5
193 |
194 | # BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2)
195 | _C.BN.MOM = 0.1
196 |
197 | # Precise BN stats
198 | _C.BN.USE_PRECISE_STATS = True
199 | _C.BN.NUM_SAMPLES_PRECISE = 8192
200 |
201 | # Initialize the gamma of the final BN of each block to zero
202 | _C.BN.ZERO_INIT_FINAL_GAMMA = False
203 |
204 | # Use a different weight decay for BN layers
205 | _C.BN.USE_CUSTOM_WEIGHT_DECAY = False
206 | _C.BN.CUSTOM_WEIGHT_DECAY = 0.0
207 |
208 |
209 | # ------------------------------------------------------------------------------------ #
210 | # Optimizer options
211 | # ------------------------------------------------------------------------------------ #
212 | _C.OPTIM = CfgNode()
213 |
214 | # Base learning rate
215 | _C.OPTIM.BASE_LR = 0.1
216 |
217 | # Learning rate policy select from {'cos', 'exp', 'steps'}
218 | _C.OPTIM.LR_POLICY = "cos"
219 |
220 | # Exponential decay factor
221 | _C.OPTIM.GAMMA = 0.1
222 |
223 | # Steps for 'steps' policy (in epochs)
224 | _C.OPTIM.STEPS = []
225 |
226 | # Learning rate multiplier for 'steps' policy
227 | _C.OPTIM.LR_MULT = 0.1
228 |
229 | # Maximal number of epochs
230 | _C.OPTIM.MAX_EPOCH = 200
231 |
232 | # Momentum
233 | _C.OPTIM.MOMENTUM = 0.9
234 |
235 | # Momentum dampening
236 | _C.OPTIM.DAMPENING = 0.0
237 |
238 | # Nesterov momentum
239 | _C.OPTIM.NESTEROV = True
240 |
241 | # L2 regularization
242 | _C.OPTIM.WEIGHT_DECAY = 5e-4
243 |
244 | # Start the warm up from OPTIM.BASE_LR * OPTIM.WARMUP_FACTOR
245 | _C.OPTIM.WARMUP_FACTOR = 0.1
246 |
247 | # Gradually warm up the OPTIM.BASE_LR over this number of epochs
248 | _C.OPTIM.WARMUP_EPOCHS = 5
249 |
250 |
251 | # ------------------------------------------------------------------------------------ #
252 | # Training options
253 | # ------------------------------------------------------------------------------------ #
254 | _C.TRAIN = CfgNode()
255 |
256 | # Dataset and split
257 | _C.TRAIN.DATASET = ""
258 | _C.TRAIN.SPLIT = "train"
259 |
260 | # Total mini-batch size
261 | _C.TRAIN.BATCH_SIZE = 128
262 | _C.TRAIN.INTRA_IMGS = 0
263 |
264 | # Image size
265 | _C.TRAIN.IM_SIZE = 224
266 |
267 | # Evaluate model on test data every eval period epochs
268 | _C.TRAIN.EVAL_PERIOD = 1
269 |
270 | # Save model checkpoint every checkpoint period epochs
271 | _C.TRAIN.CHECKPOINT_PERIOD = 1
272 |
273 | # Resume training from the latest checkpoint in the output directory
274 | _C.TRAIN.AUTO_RESUME = True
275 |
276 | # Weights to start training from
277 | _C.TRAIN.WEIGHTS = ""
278 |
279 | _C.TRAIN.SYNC_BN = False
280 | _C.TRAIN.SHUFFLE_BN = False
281 |
282 | _C.TRAIN.USE_BAKE = False
283 | _C.TRAIN.OMEGA = 0.5
284 | _C.TRAIN.TEMP = 4.
285 | _C.TRAIN.LAMBDA = 0.
286 | _C.TRAIN.GLOBAL_KE = True
287 |
288 | _C.TRAIN.IS_CUTMIX = False
289 | _C.TRAIN.CUTMIX = CfgNode()
290 | _C.TRAIN.CUTMIX.PROB = 1.0
291 | _C.TRAIN.CUTMIX.BETA = 1.0
292 | _C.TRAIN.CUTMIX.SMOOTHING = 0.0
293 | _C.TRAIN.CUTMIX.OFF_EPOCH = 5
294 |
295 | _C.TRAIN.JITTER = False
296 |
297 | # ------------------------------------------------------------------------------------ #
298 | # Testing options
299 | # ------------------------------------------------------------------------------------ #
300 | _C.TEST = CfgNode()
301 |
302 | # Dataset and split
303 | _C.TEST.DATASET = ""
304 | _C.TEST.SPLIT = "val"
305 |
306 | # Total mini-batch size
307 | _C.TEST.BATCH_SIZE = 200
308 |
309 | # Image size
310 | _C.TEST.IM_SIZE = 256
311 |
312 | # Weights to use for testing
313 | _C.TEST.WEIGHTS = ""
314 |
315 |
316 | # ------------------------------------------------------------------------------------ #
317 | # Common train/test data loader options
318 | # ------------------------------------------------------------------------------------ #
319 | _C.DATA_LOADER = CfgNode()
320 |
321 | # Number of data loader workers per process
322 | _C.DATA_LOADER.NUM_WORKERS = 6
323 |
324 | # Load data to pinned host memory
325 | _C.DATA_LOADER.PIN_MEMORY = True
326 |
327 |
328 | # ------------------------------------------------------------------------------------ #
329 | # Memory options
330 | # ------------------------------------------------------------------------------------ #
331 | _C.MEM = CfgNode()
332 |
333 | # Perform ReLU inplace
334 | _C.MEM.RELU_INPLACE = True
335 |
336 |
337 | # ------------------------------------------------------------------------------------ #
338 | # CUDNN options
339 | # ------------------------------------------------------------------------------------ #
340 | _C.CUDNN = CfgNode()
341 |
342 | # Perform benchmarking to select the fastest CUDNN algorithms to use
343 | # Note that this may increase the memory usage and will likely not result
344 | # in overall speedups when variable size inputs are used (e.g. COCO training)
345 | _C.CUDNN.BENCHMARK = True
346 |
347 |
348 | # ------------------------------------------------------------------------------------ #
349 | # Precise timing options
350 | # ------------------------------------------------------------------------------------ #
351 | _C.PREC_TIME = CfgNode()
352 |
353 | # Number of iterations to warm up the caches
354 | _C.PREC_TIME.WARMUP_ITER = 3
355 |
356 | # Number of iterations to compute avg time
357 | _C.PREC_TIME.NUM_ITER = 30
358 |
359 |
360 | # ------------------------------------------------------------------------------------ #
361 | # Misc options
362 | # ------------------------------------------------------------------------------------ #
363 |
364 | _C.LAUNCHER = "slurm"
365 | _C.PORT = 8080
366 | _C.RANK = 0
367 | _C.WORLD_SIZE = 1
368 | _C.NGPUS_PER_NODE = 1
369 | _C.GPU = 0
370 |
371 | # Number of GPUs to use (applies to both training and testing)
372 | _C.NUM_GPUS = 1
373 |
374 | # Output directory
375 | _C.OUT_DIR = "."
376 |
377 | # Config destination (in OUT_DIR)
378 | _C.CFG_DEST = "config.yaml"
379 |
380 | # Note that non-determinism may still be present due to non-deterministic
381 | # operator implementations in GPU operator libraries
382 | _C.RNG_SEED = 1
383 |
384 | # Log destination ('stdout' or 'file')
385 | _C.LOG_DEST = "stdout"
386 |
387 | # Log period in iters
388 | _C.LOG_PERIOD = 10
389 |
390 | # Distributed backend
391 | _C.DIST_BACKEND = "nccl"
392 |
393 | # Hostname and port range for multi-process groups (actual port selected randomly)
394 | _C.HOST = "localhost"
395 | _C.PORT_RANGE = [10000, 65000]
396 |
397 | # Models weights referred to by URL are downloaded to this local cache
398 | _C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"
399 |
400 |
401 | # ------------------------------------------------------------------------------------ #
402 | # Deprecated keys
403 | # ------------------------------------------------------------------------------------ #
404 |
405 | _C.register_deprecated_key("PREC_TIME.BATCH_SIZE")
406 | _C.register_deprecated_key("PREC_TIME.ENABLED")
407 | # _C.register_deprecated_key("PORT")
408 |
409 |
def assert_and_infer_cfg(cache_urls=True):
    """Checks config values invariants."""
    assert not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0, "The first lr step must start at 0"
    valid_splits = ["train", "val", "test"]
    split_err = "Data split '{}' not supported"
    assert _C.TRAIN.SPLIT in valid_splits, split_err.format(_C.TRAIN.SPLIT)
    assert _C.TEST.SPLIT in valid_splits, split_err.format(_C.TEST.SPLIT)
    batch_err = "Mini-batch size should be a multiple of NUM_GPUS."
    assert _C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0, batch_err
    assert _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0, batch_err
    log_err = "Log destination '{}' not supported"
    assert _C.LOG_DEST in ["stdout", "file"], log_err.format(_C.LOG_DEST)
    if cache_urls:
        cache_cfg_urls()
425 |
426 |
def cache_cfg_urls():
    """Download URLs in config, cache them, and rewrite cfg to use cached file."""
    # cache_url is a pass-through for non-URL strings (including ""), so this
    # is safe when WEIGHTS already points at a local file or is unset
    _C.TRAIN.WEIGHTS = cache_url(_C.TRAIN.WEIGHTS, _C.DOWNLOAD_CACHE)
    _C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)
431 |
432 |
def dump_cfg():
    """Dumps the config to the output directory (OUT_DIR/CFG_DEST)."""
    with open(os.path.join(_C.OUT_DIR, _C.CFG_DEST), "w") as f:
        _C.dump(stream=f)
438 |
439 |
def load_cfg(out_dir, cfg_dest="config.yaml"):
    """Loads config from specified output directory into the global config."""
    _C.merge_from_file(os.path.join(out_dir, cfg_dest))
444 |
445 |
def load_cfg_fom_args(description="Config file options."):
    """Load config from command line arguments and set any specified options.

    NOTE: the 'fom' typo in the name is kept for caller compatibility.
    """
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "--cfg", dest="cfg_file", help="Config file location", required=True, type=str
    )
    parser.add_argument(
        "opts",
        help="See pycls/core/config.py for all options",
        default=None,
        nargs=argparse.REMAINDER,
    )
    # With no arguments at all, print usage and bail out
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)
    args = parser.parse_args()
    _C.merge_from_file(args.cfg_file)
    _C.merge_from_list(args.opts)
459 |
--------------------------------------------------------------------------------
/imagenet/pycls/core/distributed.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Distributed helpers."""
9 |
10 | import multiprocessing
11 | import os
12 | import subprocess
13 | import random
14 | import signal
15 | import threading
16 | import traceback
17 |
18 | import torch
19 | import torch.distributed as dist
20 | import torch.multiprocessing as mp
21 |
22 | from pycls.core.config import cfg
23 |
24 |
def is_master_proc():
    """Determines if the current process is the master process.

    Master process is responsible for logging, writing and loading checkpoints. In
    the multi GPU setting, we assign the master role to the rank 0 process. When
    training using a single GPU, there is a single process which is considered master.
    """
    if cfg.NUM_GPUS == 1:
        return True
    return torch.distributed.get_rank() == 0
33 |
34 |
def init_process_group(proc_rank, world_size, port):
    """Initializes the default process group.

    Pins the process to GPU *proc_rank* and joins a TCP-rendezvous group at
    cfg.HOST:*port* with the cfg.DIST_BACKEND backend.
    """
    # Set the GPU to use
    torch.cuda.set_device(proc_rank)
    # Initialize the process group
    torch.distributed.init_process_group(
        backend=cfg.DIST_BACKEND,
        init_method="tcp://{}:{}".format(cfg.HOST, port),
        world_size=world_size,
        rank=proc_rank,
    )
46 |
47 |
def destroy_process_group():
    """Destroys the default process group.

    Thin wrapper kept so callers depend on this module, not torch directly.
    """
    torch.distributed.destroy_process_group()
51 |
52 |
def scaled_all_reduce(tensors):
    """Performs the scaled all_reduce operation on the provided tensors.

    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by the inverse size of the
    process group (equivalent to cfg.NUM_GPUS).
    """
    # Single-process case: nothing to reduce
    if cfg.NUM_GPUS == 1:
        return tensors
    # Launch all reductions asynchronously, then wait for completion
    handles = [torch.distributed.all_reduce(t, async_op=True) for t in tensors]
    for handle in handles:
        handle.wait()
    # Scale each reduced tensor by the inverse world size, in-place
    scale = 1.0 / cfg.NUM_GPUS
    for t in tensors:
        t.mul_(scale)
    return tensors
75 |
76 |
class ChildException(Exception):
    """Wraps a traceback string raised in a child process."""

    def __init__(self, child_trace):
        super().__init__(child_trace)
82 |
83 |
class ErrorHandler(object):
    """Multiprocessing error handler (based on fairseq's).

    Listens for errors in child processes and propagates the tracebacks to the parent.
    A daemon thread blocks on the shared error queue; when a child reports a
    traceback, the thread signals this (parent) process with SIGUSR1, whose
    handler kills the registered children and re-raises the child's error.
    """

    def __init__(self, error_queue):
        # Shared error queue (written by children, read by the listener)
        self.error_queue = error_queue
        # Children processes sharing the error queue
        self.children_pids = []
        # Start a daemon thread listening to errors
        self.error_listener = threading.Thread(target=self.listen, daemon=True)
        self.error_listener.start()
        # Register the signal handler (runs on the main thread)
        signal.signal(signal.SIGUSR1, self.signal_handler)

    def add_child(self, pid):
        """Registers a child process."""
        self.children_pids.append(pid)

    def listen(self):
        """Listens for errors in the error queue."""
        # Wait until there is an error in the queue
        child_trace = self.error_queue.get()
        # Put the error back for the signal handler to pick up
        self.error_queue.put(child_trace)
        # Invoke the signal handler; raising here would be lost in the thread
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, _sig_num, _stack_frame):
        """Signal handler."""
        # Kill children processes
        for pid in self.children_pids:
            os.kill(pid, signal.SIGINT)
        # Propagate the error from the child process
        raise ChildException(self.error_queue.get())
121 |
122 |
def run(proc_rank, world_size, port, error_queue, fun, fun_args, fun_kwargs):
    """Runs a function from a child process.

    Initializes the process group, invokes *fun*, and forwards any traceback
    to the parent through *error_queue*. The group is torn down on exit.
    """
    try:
        # Initialize the process group
        init_process_group(proc_rank, world_size, port)
        # Run the function
        fun(*fun_args, **fun_kwargs)
    except KeyboardInterrupt:
        # Killed by the parent process
        pass
    except Exception:
        # Propagate exception to the parent process
        error_queue.put(traceback.format_exc())
    finally:
        # Destroy the group only if it was actually created; destroying an
        # uninitialized group raises and would mask the error recorded above
        if dist.is_initialized():
            destroy_process_group()
139 |
def multi_proc_run(num_proc, fun, fun_args=(), fun_kwargs=None):
    """Initializes distributed state in-process, then runs *fun*.

    NOTE(review): despite the name, this no longer spawns child processes;
    process creation is delegated to the external launcher (slurm / pytorch).
    *num_proc* is unused but kept for interface compatibility with callers.
    """
    init_dist()
    # Freeze the config so workers cannot diverge after setup
    cfg.freeze()
    fun_kwargs = fun_kwargs if fun_kwargs else {}
    fun(*fun_args, **fun_kwargs)
146 |
def init_dist(backend="nccl"):
    """Initializes distributed training according to cfg.LAUNCHER."""
    # Force the 'spawn' start method unless one was already chosen
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method("spawn")
    # Dispatch on launcher type
    launchers = {"pytorch": init_dist_pytorch, "slurm": init_dist_slurm}
    if cfg.LAUNCHER not in launchers:
        raise ValueError("Invalid launcher type: {}".format(cfg.LAUNCHER))
    launchers[cfg.LAUNCHER](backend)
157 |
158 |
def init_dist_pytorch(backend="nccl"):
    """Initializes distributed state for the torch.distributed launcher.

    Reads LOCAL_RANK from the environment, pins this process to one GPU, and
    joins the default process group via the env:// rendezvous.
    """
    cfg.RANK = int(os.environ["LOCAL_RANK"])
    # Prefer the explicit device list when set; else count all visible GPUs
    if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
        cfg.NGPUS_PER_NODE = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
    else:
        cfg.NGPUS_PER_NODE = torch.cuda.device_count()
    assert cfg.NGPUS_PER_NODE>0, "CUDA is not supported"
    # With the pytorch launcher the local rank is the device index
    cfg.GPU = cfg.RANK
    torch.cuda.set_device(cfg.GPU)
    dist.init_process_group(backend=backend)
    cfg.NUM_GPUS = dist.get_world_size()
    cfg.WORLD_SIZE = cfg.NUM_GPUS
171 |
def init_dist_slurm(backend="nccl"):
    """Initializes distributed state for the SLURM launcher.

    Derives rank/world-size from SLURM environment variables, resolves the
    master address via scontrol, and joins the default process group through
    the env:// rendezvous.
    """
    cfg.RANK = int(os.environ["SLURM_PROCID"])
    cfg.WORLD_SIZE = int(os.environ["SLURM_NTASKS"])
    node_list = os.environ["SLURM_NODELIST"]
    # Prefer the explicit device list when set; else count all visible GPUs
    if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
        cfg.NGPUS_PER_NODE = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
    else:
        cfg.NGPUS_PER_NODE = torch.cuda.device_count()
    assert cfg.NGPUS_PER_NODE>0, "CUDA is not supported"
    # Global rank -> local device index on this node
    cfg.GPU = cfg.RANK % cfg.NGPUS_PER_NODE
    torch.cuda.set_device(cfg.GPU)
    # First hostname in the allocation serves as the rendezvous master
    addr = subprocess.getoutput(
        "scontrol show hostname {} | head -n1".format(node_list)
    )
    os.environ["MASTER_PORT"] = str(cfg.PORT)
    os.environ["MASTER_ADDR"] = addr
    os.environ["WORLD_SIZE"] = str(cfg.WORLD_SIZE)
    os.environ["RANK"] = str(cfg.RANK)
    dist.init_process_group(backend=backend)
    cfg.NUM_GPUS = dist.get_world_size()
192 |
--------------------------------------------------------------------------------
/imagenet/pycls/core/io.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """IO utilities (adapted from Detectron)"""
9 |
10 | import logging
11 | import os
12 | import re
13 | import sys
14 | from urllib import request as urlrequest
15 |
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | _PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"
20 |
21 |
def cache_url(url_or_file, cache_dir):
    """Download the file specified by the URL to the cache_dir and return the path to
    the cached file. If the argument is not a URL, simply return it as is.
    """
    is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None
    if not is_url:
        return url_or_file
    url = url_or_file
    err_str = "pycls only automatically caches URLs in the pycls S3 bucket: {}"
    assert url.startswith(_PYCLS_BASE_URL), err_str.format(_PYCLS_BASE_URL)
    # Map the URL into the local cache directory
    cache_file_path = url.replace(_PYCLS_BASE_URL, cache_dir)
    if os.path.exists(cache_file_path):
        return cache_file_path
    # exist_ok avoids the check-then-create race when several processes
    # cache the same file concurrently
    os.makedirs(os.path.dirname(cache_file_path), exist_ok=True)
    logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
    download_url(url, cache_file_path)
    return cache_file_path
41 |
42 |
43 | def _progress_bar(count, total):
44 | """Report download progress. Credit:
45 | https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
46 | """
47 | bar_len = 60
48 | filled_len = int(round(bar_len * count / float(total)))
49 | percents = round(100.0 * count / float(total), 1)
50 | bar = "=" * filled_len + "-" * (bar_len - filled_len)
51 | sys.stdout.write(
52 | " [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024)
53 | )
54 | sys.stdout.flush()
55 | if count >= total:
56 | sys.stdout.write("\n")
57 |
58 |
def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
    """Download url and write it to dst_file_path. Credit:
    https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook

    Returns the number of bytes downloaded. NOTE(review): assumes the server
    sends a Content-Length header — confirm for all download endpoints, as a
    missing header would raise AttributeError on .strip().
    """
    req = urlrequest.Request(url)
    response = urlrequest.urlopen(req)
    # Total size is needed to report progress percentages
    total_size = response.info().get("Content-Length").strip()
    total_size = int(total_size)
    bytes_so_far = 0
    with open(dst_file_path, "wb") as f:
        # Stream the body in fixed-size chunks until EOF (empty read)
        while 1:
            chunk = response.read(chunk_size)
            bytes_so_far += len(chunk)
            if not chunk:
                break
            if progress_hook:
                progress_hook(bytes_so_far, total_size)
            f.write(chunk)
    return bytes_so_far
78 |
--------------------------------------------------------------------------------
/imagenet/pycls/core/logging.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Logging."""
9 |
10 | import builtins
11 | import decimal
12 | import logging
13 | import os
14 | import sys
15 |
16 | import pycls.core.distributed as dist
17 | import simplejson
18 | from pycls.core.config import cfg
19 |
20 |
21 | # Show filename and line number in logs
22 | _FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"
23 |
24 | # Log file name (for cfg.LOG_DEST = 'file')
25 | _LOG_FILE = "stdout.log"
26 |
27 | # Data output with dump_log_data(data, data_type) will be tagged w/ this
28 | _TAG = "json_stats: "
29 |
30 | # Data output with dump_log_data(data, data_type) will have data[_TYPE]=data_type
31 | _TYPE = "_type"
32 |
33 |
34 | def _suppress_print():
35 | """Suppresses printing from the current process."""
36 |
37 | def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False):
38 | pass
39 |
40 | builtins.print = ignore
41 |
42 |
def setup_logging():
    """Sets up the logging.

    The master process logs to stdout or a file per cfg.LOG_DEST; all other
    processes have print() replaced with a no-op so they stay silent.
    """
    # Enable logging only for the master process
    if dist.is_master_proc():
        # Clear the root logger to prevent any existing logging config
        # (e.g. set by another module) from messing with our setup
        logging.root.handlers = []
        # Construct logging configuration
        logging_config = {"level": logging.INFO, "format": _FORMAT}
        # Log either to stdout or to a file
        if cfg.LOG_DEST == "stdout":
            logging_config["stream"] = sys.stdout
        else:
            logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
        # Configure logging
        logging.basicConfig(**logging_config)
    else:
        _suppress_print()
61 |
62 |
def get_logger(name):
    """Return the logging.Logger instance registered under *name*."""
    return logging.getLogger(name)
66 |
67 |
def dump_log_data(data, data_type, prec=4):
    """Tag *data* (a dict) with its type and serialize it as a tagged json string."""
    data[_TYPE] = data_type
    # Fixed-precision decimals keep the json output at a stable width
    converted = float_to_decimal(data, prec)
    payload = simplejson.dumps(converted, sort_keys=True, use_decimal=True)
    return _TAG + payload
74 |
75 |
def float_to_decimal(data, prec=4):
    """Recursively convert floats to fixed-precision Decimals for fixed width json."""
    if isinstance(data, dict):
        return {key: float_to_decimal(val, prec) for key, val in data.items()}
    if isinstance(data, float):
        return decimal.Decimal("{:.{p}f}".format(data, p=prec))
    # Non-float leaves pass through untouched
    return data
84 |
85 |
def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
    """Get all log files in directory containing subdirs of trained models."""
    # One candidate (path, subdir-name) pair per matching subdirectory
    candidates = [
        (os.path.join(log_dir, name, log_file), name)
        for name in sorted(os.listdir(log_dir))
        if name_filter in name
    ]
    existing = [(f, n) for f, n in candidates if os.path.exists(f)]
    files, names = zip(*existing) if existing else ([], [])
    return files, names
93 |
94 |
def load_log_data(log_file, data_types_to_skip=()):
    """Loads log data into a dictionary of the form data[data_type][metric][index]."""
    # Load log_file
    assert os.path.exists(log_file), "Log file not found: {}".format(log_file)
    with open(log_file, "r") as f:
        lines = f.readlines()
    # Extract and parse lines that start with _TAG and have a type specified
    lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
    lines = [simplejson.loads(l) for l in lines]
    lines = [l for l in lines if _TYPE in l and not l[_TYPE] in data_types_to_skip]
    # Generate data structure accessed by data[data_type][index][metric]
    data_types = [l[_TYPE] for l in lines]
    data = {t: [] for t in data_types}
    for t, line in zip(data_types, lines):
        # The type marker is metadata, not a metric; drop it from the record
        del line[_TYPE]
        data[t].append(line)
    # Generate data structure accessed by data[data_type][metric][index]
    for t in data:
        # Every record of a given type must report the same set of metrics
        metrics = sorted(data[t][0].keys())
        err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics)
        assert all(sorted(d.keys()) == metrics for d in data[t]), err_str
        data[t] = {m: [d[m] for d in data[t]] for m in metrics}
    return data
118 |
119 |
def sort_log_data(data):
    """Sort each data[data_type][metric] by epoch or keep only first instance."""
    for t in data:
        if "epoch" in data[t]:
            # Split "ind/max" epoch strings into separate integer metrics
            assert "epoch_ind" not in data[t] and "epoch_max" not in data[t]
            data[t]["epoch_ind"] = [int(e.split("/")[0]) for e in data[t]["epoch"]]
            data[t]["epoch_max"] = [int(e.split("/")[1]) for e in data[t]["epoch"]]
            epoch = data[t]["epoch_ind"]
            if "iter" in data[t]:
                # Same "ind/max" split for per-iteration entries
                assert "iter_ind" not in data[t] and "iter_max" not in data[t]
                data[t]["iter_ind"] = [int(i.split("/")[0]) for i in data[t]["iter"]]
                data[t]["iter_max"] = [int(i.split("/")[1]) for i in data[t]["iter"]]
                # Fractional epoch = epoch + progress within the epoch
                itr = zip(epoch, data[t]["iter_ind"], data[t]["iter_max"])
                epoch = [e + (i_ind - 1) / i_max for e, i_ind, i_max in itr]
            # Reorder every metric by the (possibly fractional) epoch key
            for m in data[t]:
                data[t][m] = [v for _, v in sorted(zip(epoch, data[t][m]))]
        else:
            # No epoch info: collapse each metric to its first logged value
            data[t] = {m: d[0] for m, d in data[t].items()}
    return data
139 |
--------------------------------------------------------------------------------
/imagenet/pycls/core/meters.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Meters."""
9 |
10 | from collections import deque
11 |
12 | import numpy as np
13 | import pycls.core.logging as logging
14 | import torch
15 | from pycls.core.config import cfg
16 | from pycls.core.timer import Timer
17 |
18 |
19 | logger = logging.get_logger(__name__)
20 |
21 |
def time_string(seconds):
    """Formats a duration in seconds as a fixed-width 'DD,HH:MM:SS' string."""
    total = int(seconds)
    secs = total % 60
    total //= 60
    mins = total % 60
    total //= 60
    hrs = total % 24
    days = total // 24
    return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs)
28 |
29 |
def topk_errors(preds, labels, ks):
    """Computes the top-k error (as a percentage) for each k in ks."""
    err_str = "Batch dim of predictions and labels must match"
    assert preds.size(0) == labels.size(0), err_str
    batch_size = preds.size(0)
    # Indices of the max(ks) highest-scoring classes per sample: (batch, max_k)
    _vals, top_inds = torch.topk(preds, max(ks), dim=1, largest=True, sorted=True)
    # Work in (max_k, batch) layout so correct[:k] selects the top-k rows
    top_inds = top_inds.t()
    # correct[i, j] is True iff the i-th ranked prediction for sample j is right
    correct = top_inds.eq(labels.view(1, -1).expand_as(top_inds))
    errors = []
    for k in ks:
        num_correct = correct[:k, :].contiguous().view(-1).float().sum()
        errors.append((1.0 - num_correct / batch_size) * 100.0)
    return errors
47 |
48 |
def gpu_mem_usage():
    """Returns the peak GPU memory allocated on the current device, in MB."""
    peak_bytes = torch.cuda.max_memory_allocated()
    return peak_bytes / (1024 * 1024)
53 |
54 |
class ScalarMeter(object):
    """Tracks a scalar over a sliding window and globally (adapted from Detectron)."""

    def __init__(self, window_size):
        # Sliding window of the most recent values
        self.deque = deque(maxlen=window_size)
        # Running totals over the full history since the last reset
        self.total = 0.0
        self.count = 0

    def reset(self):
        """Clears the window and the global totals."""
        self.deque.clear()
        self.count = 0
        self.total = 0.0

    def add_value(self, value):
        """Records a new scalar value."""
        self.total += value
        self.count += 1
        self.deque.append(value)

    def get_win_median(self):
        """Median over the current window."""
        return np.median(self.deque)

    def get_win_avg(self):
        """Mean over the current window."""
        return np.mean(self.deque)

    def get_global_avg(self):
        """Mean over every value added since the last reset."""
        return self.total / self.count
81 |
82 |
class TrainMeter(object):
    """Measures training stats (timing, lr, loss, and top-1/top-5 errors)."""

    def __init__(self, epoch_iters):
        # Iterations per epoch and over the whole training schedule
        self.epoch_iters = epoch_iters
        self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
        self.iter_timer = Timer()
        # Loss smoothed over a window, plus a running total for the epoch average
        self.loss = ScalarMeter(cfg.LOG_PERIOD)
        self.loss_total = 0.0
        self.lr = None
        # Current minibatch errors (smoothed over a window)
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
        self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
        # Number of misclassified examples (error rates weighted by batch size)
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def reset(self, timer=False):
        """Resets the epoch accumulators; resets the timer only when timer=True."""
        if timer:
            self.iter_timer.reset()
        self.loss.reset()
        self.loss_total = 0.0
        self.lr = None
        self.mb_top1_err.reset()
        self.mb_top5_err.reset()
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def iter_tic(self):
        """Starts timing the current iteration."""
        self.iter_timer.tic()

    def iter_toc(self):
        """Stops timing the current iteration."""
        self.iter_timer.toc()

    def update_stats(self, top1_err, top5_err, loss, lr, mb_size):
        """Records the stats of one minibatch of mb_size samples."""
        # Current minibatch stats (windowed)
        self.mb_top1_err.add_value(top1_err)
        self.mb_top5_err.add_value(top5_err)
        self.loss.add_value(loss)
        self.lr = lr
        # Aggregate stats; errors/loss are per-sample rates, so weight by mb_size
        self.num_top1_mis += top1_err * mb_size
        self.num_top5_mis += top5_err * mb_size
        self.loss_total += loss * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        """Returns the window-median-smoothed stats dict for an iteration."""
        cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1
        eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "eta": time_string(eta_sec),
            "top1_err": self.mb_top1_err.get_win_median(),
            "top5_err": self.mb_top5_err.get_win_median(),
            "loss": self.loss.get_win_median(),
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        """Logs iteration stats as json every cfg.LOG_PERIOD iterations."""
        if (cur_iter + 1) % cfg.LOG_PERIOD != 0:
            return
        stats = self.get_iter_stats(cur_epoch, cur_iter)
        logger.info(logging.dump_log_data(stats, "train_iter"))

    def get_epoch_stats(self, cur_epoch):
        """Returns the aggregated stats dict for a full epoch."""
        cur_iter_total = (cur_epoch + 1) * self.epoch_iters
        eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total)
        mem_usage = gpu_mem_usage()
        # Guard the divisions so an empty epoch cannot crash
        # (consistent with TestMeter.get_epoch_stats)
        num_samples = max(1, self.num_samples)
        top1_err = self.num_top1_mis / num_samples
        top5_err = self.num_top5_mis / num_samples
        avg_loss = self.loss_total / num_samples
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "eta": time_string(eta_sec),
            "top1_err": top1_err,
            "top5_err": top5_err,
            "loss": avg_loss,
            "lr": self.lr,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        """Logs the epoch stats as json."""
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "train_epoch"))
177 |
178 |
class TestMeter(object):
    """Measures testing stats (timing and top-1/top-5 errors)."""

    def __init__(self, max_iter):
        # Number of test iterations per evaluation pass
        self.max_iter = max_iter
        self.iter_timer = Timer()
        # Current minibatch errors (smoothed over a window)
        self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
        self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD)
        # Best (lowest) errors seen over any full test pass so far
        self.min_top1_err = 100.0
        self.min_top5_err = 100.0
        # Number of misclassified examples (error rates weighted by batch size)
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def reset(self, min_errs=False):
        """Resets per-pass state; also resets the best errors when min_errs."""
        if min_errs:
            self.min_top1_err = 100.0
            self.min_top5_err = 100.0
        self.iter_timer.reset()
        self.mb_top1_err.reset()
        self.mb_top5_err.reset()
        self.num_top1_mis = 0
        self.num_top5_mis = 0
        self.num_samples = 0

    def iter_tic(self):
        """Starts timing the current iteration."""
        self.iter_timer.tic()

    def iter_toc(self):
        """Stops timing the current iteration."""
        self.iter_timer.toc()

    def update_stats(self, top1_err, top5_err, mb_size):
        """Records the errors of one minibatch of mb_size samples."""
        self.mb_top1_err.add_value(top1_err)
        self.mb_top5_err.add_value(top5_err)
        # Errors are per-sample rates, so weight them by the minibatch size
        self.num_top1_mis += top1_err * mb_size
        self.num_top5_mis += top5_err * mb_size
        self.num_samples += mb_size

    def get_iter_stats(self, cur_epoch, cur_iter):
        """Returns the window-median-smoothed stats dict for an iteration."""
        mem_usage = gpu_mem_usage()
        iter_stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "iter": "{}/{}".format(cur_iter + 1, self.max_iter),
            "time_avg": self.iter_timer.average_time,
            "time_diff": self.iter_timer.diff,
            "top1_err": self.mb_top1_err.get_win_median(),
            "top5_err": self.mb_top5_err.get_win_median(),
            "mem": int(np.ceil(mem_usage)),
        }
        return iter_stats

    def log_iter_stats(self, cur_epoch, cur_iter):
        """Logs iteration stats as json every cfg.LOG_PERIOD iterations."""
        if (cur_iter + 1) % cfg.LOG_PERIOD == 0:
            stats = self.get_iter_stats(cur_epoch, cur_iter)
            logger.info(logging.dump_log_data(stats, "test_iter"))

    def get_epoch_stats(self, cur_epoch):
        """Returns the aggregated stats for a full test pass and updates bests."""
        # Guard the division so an empty pass cannot crash
        denom = max(1, self.num_samples)
        top1_err = self.num_top1_mis / denom
        top5_err = self.num_top5_mis / denom
        # Track the best errors seen across evaluations
        self.min_top1_err = min(self.min_top1_err, top1_err)
        self.min_top5_err = min(self.min_top5_err, top5_err)
        mem_usage = gpu_mem_usage()
        stats = {
            "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH),
            "time_avg": self.iter_timer.average_time,
            "top1_err": top1_err,
            "top5_err": top5_err,
            "min_top1_err": self.min_top1_err,
            "min_top5_err": self.min_top5_err,
            "mem": int(np.ceil(mem_usage)),
        }
        return stats

    def log_epoch_stats(self, cur_epoch):
        """Logs the epoch stats as json."""
        stats = self.get_epoch_stats(cur_epoch)
        logger.info(logging.dump_log_data(stats, "test_epoch"))
259 |
--------------------------------------------------------------------------------
/imagenet/pycls/core/net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Functions for manipulating networks."""
9 |
10 | import itertools
11 | import math
12 |
13 | import pycls.core.distributed as dist
14 | import torch
15 | import torch.nn as nn
16 | from pycls.core.config import cfg
17 |
18 |
def init_weights(m):
    """Performs ResNet-style weight initialization on a single module."""
    if isinstance(m, nn.Conv2d):
        # He initialization; conv layers carry no bias because BN follows
        fan_out = m.out_channels * m.kernel_size[0] * m.kernel_size[1]
        m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
    elif isinstance(m, nn.BatchNorm2d):
        # Optionally zero-init gamma on BNs tagged as the last in a residual block
        zero_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA and getattr(m, "final_bn", False)
        m.weight.data.fill_(0.0 if zero_gamma else 1.0)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(mean=0.0, std=0.01)
        m.bias.data.zero_()
33 |
34 |
@torch.no_grad()
def compute_precise_bn_stats(model, loader):
    """Computes precise BN stats on training data.

    Re-estimates the running mean/var of every BatchNorm2d layer by averaging
    per-batch statistics over up to cfg.BN.NUM_SAMPLES_PRECISE training samples
    (split across cfg.NUM_GPUS), syncs the averages across GPUs, then restores
    each layer's original momentum. The model is run in whatever mode the
    caller left it in; inputs are moved to the current CUDA device.
    """
    # Compute the number of minibatches to use
    num_iter = int(cfg.BN.NUM_SAMPLES_PRECISE / loader.batch_size / cfg.NUM_GPUS)
    num_iter = min(num_iter, len(loader))
    # Retrieve the BN layers
    bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
    # Initialize BN stats storage for computing mean(mean(batch)) and mean(var(batch))
    running_means = [torch.zeros_like(bn.running_mean) for bn in bns]
    running_vars = [torch.zeros_like(bn.running_var) for bn in bns]
    # Remember momentum values so they can be restored at the end
    momentums = [bn.momentum for bn in bns]
    # Set momentum to 1.0 so each forward pass overwrites running stats with
    # the stats of the current batch only
    for bn in bns:
        bn.momentum = 1.0
    # Average the BN stats for each BN layer over the batches
    for inputs, _labels in itertools.islice(loader, num_iter):
        model(inputs.cuda())
        for i, bn in enumerate(bns):
            running_means[i] += bn.running_mean / num_iter
            running_vars[i] += bn.running_var / num_iter
    # Sync BN stats across GPUs (no reduction if 1 GPU used)
    running_means = dist.scaled_all_reduce(running_means)
    running_vars = dist.scaled_all_reduce(running_vars)
    # Set BN stats and restore original momentum values
    for i, bn in enumerate(bns):
        bn.running_mean = running_means[i]
        bn.running_var = running_vars[i]
        bn.momentum = momentums[i]
65 |
66 |
def reset_bn_stats(model):
    """Resets the running stats of every BatchNorm2d layer in the model."""
    bn_layers = (m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d))
    for bn in bn_layers:
        bn.reset_running_stats()
72 |
73 |
def complexity_conv2d(cx, w_in, w_out, k, stride, padding, groups=1, bias=False):
    """Accumulates complexity of Conv2D into cx = (h, w, flops, params, acts)."""
    # Standard conv output-size formula
    out_h = (cx["h"] + 2 * padding - k) // stride + 1
    out_w = (cx["w"] + 2 * padding - k) // stride + 1
    flops = cx["flops"] + k * k * w_in * w_out * out_h * out_w // groups
    params = cx["params"] + k * k * w_in * w_out // groups
    if bias:
        # One bias add per output channel
        flops += w_out
        params += w_out
    acts = cx["acts"] + w_out * out_h * out_w
    return {"h": out_h, "w": out_w, "flops": flops, "params": params, "acts": acts}
85 |
86 |
def complexity_batchnorm2d(cx, w_in):
    """Accumulates complexity of BatchNorm2D into cx = (h, w, flops, params, acts)."""
    # BN adds a scale and a shift parameter per channel; everything else is unchanged
    return {
        "h": cx["h"],
        "w": cx["w"],
        "flops": cx["flops"],
        "params": cx["params"] + 2 * w_in,
        "acts": cx["acts"],
    }
92 |
93 |
def complexity_maxpool2d(cx, k, stride, padding):
    """Accumulates complexity of MaxPool2d into cx = (h, w, flops, params, acts)."""
    # Pooling only shrinks the spatial dims; no params, and flops/acts not counted
    def out_dim(d):
        return (d + 2 * padding - k) // stride + 1

    return {
        "h": out_dim(cx["h"]),
        "w": out_dim(cx["w"]),
        "flops": cx["flops"],
        "params": cx["params"],
        "acts": cx["acts"],
    }
100 |
101 |
def complexity(model):
    """Compute model complexity (model can be model instance or model class)."""
    size = cfg.TRAIN.IM_SIZE
    # Seed the accumulator with the training resolution and zeroed counters
    cx = model.complexity({"h": size, "w": size, "flops": 0, "params": 0, "acts": 0})
    return {k: cx[k] for k in ("flops", "params", "acts")}
108 |
109 |
def drop_connect(x, drop_ratio):
    """Drop connect, applied in place to x (adapted from DARTS)."""
    keep_ratio = 1.0 - drop_ratio
    # One Bernoulli(keep_ratio) keep/drop decision per example in the batch
    mask_shape = [x.shape[0], 1, 1, 1]
    mask = torch.empty(mask_shape, dtype=x.dtype, device=x.device)
    mask.bernoulli_(keep_ratio)
    # Scale survivors so the expected activation is unchanged, then zero the rest
    x.div_(keep_ratio)
    x.mul_(mask)
    return x
118 |
119 |
def get_flat_weights(model):
    """Gets all model weights as a single flat (N, 1) vector."""
    pieces = [p.data.view(-1, 1) for p in model.parameters()]
    return torch.cat(pieces, 0)
123 |
124 |
def set_flat_weights(model, flat_weights):
    """Sets all model weights from a single flat vector."""
    offset = 0
    for p in model.parameters():
        count = p.data.numel()
        p.data.copy_(flat_weights[offset : offset + count].view_as(p.data))
        offset += count
    # Every element of the flat vector must be consumed exactly once
    assert offset == flat_weights.numel()
133 |
--------------------------------------------------------------------------------
/imagenet/pycls/core/optimizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Optimizer."""
9 |
10 | import numpy as np
11 | import torch
12 | from pycls.core.config import cfg
13 |
14 |
def construct_optimizer(model):
    """Constructs the SGD optimizer for the model.

    Note that the momentum update in PyTorch differs from the one in Caffe2:

        Caffe2:  V := mu * V + lr * g ;  p := p - V
        PyTorch: V := mu * V + g     ;  p := p - lr * V

    where V is the velocity, mu the momentum factor, lr the learning rate,
    g the gradient and p the parameters. Since V is defined independently of
    the learning rate in PyTorch, no momentum correction (rescaling of V) is
    needed when the learning rate changes, unlike in the Caffe2 case.
    """
    if cfg.BN.USE_CUSTOM_WEIGHT_DECAY:
        # Apply different weight decay to Batchnorm and non-batchnorm parameters.
        named = list(model.named_parameters())
        bn_params = [p for n, p in named if "bn" in n]
        non_bn_params = [p for n, p in named if "bn" not in n]
        optim_params = [
            {"params": bn_params, "weight_decay": cfg.BN.CUSTOM_WEIGHT_DECAY},
            {"params": non_bn_params, "weight_decay": cfg.OPTIM.WEIGHT_DECAY},
        ]
    else:
        optim_params = model.parameters()
    return torch.optim.SGD(
        optim_params,
        lr=cfg.OPTIM.BASE_LR,
        momentum=cfg.OPTIM.MOMENTUM,
        weight_decay=cfg.OPTIM.WEIGHT_DECAY,
        dampening=cfg.OPTIM.DAMPENING,
        nesterov=cfg.OPTIM.NESTEROV,
    )
55 |
56 |
def lr_fun_steps(cur_epoch):
    """Steps schedule (cfg.OPTIM.LR_POLICY = 'steps')."""
    # Last step boundary already reached determines the decay exponent
    passed = [i for i, s in enumerate(cfg.OPTIM.STEPS) if cur_epoch >= s]
    return cfg.OPTIM.BASE_LR * (cfg.OPTIM.LR_MULT ** passed[-1])
61 |
62 |
def lr_fun_exp(cur_epoch):
    """Exponential schedule (cfg.OPTIM.LR_POLICY = 'exp')."""
    decay = cfg.OPTIM.GAMMA ** cur_epoch
    return cfg.OPTIM.BASE_LR * decay
66 |
67 |
def lr_fun_cos(cur_epoch):
    """Cosine schedule (cfg.OPTIM.LR_POLICY = 'cos')."""
    # Anneals from BASE_LR at epoch 0 toward 0 at MAX_EPOCH
    progress = cur_epoch / cfg.OPTIM.MAX_EPOCH
    return 0.5 * cfg.OPTIM.BASE_LR * (1.0 + np.cos(np.pi * progress))
72 |
73 |
def get_lr_fun():
    """Retrieves the specified lr policy function"""
    # Policy functions follow the naming convention lr_fun_<policy>
    fun_name = "lr_fun_" + cfg.OPTIM.LR_POLICY
    if fun_name in globals():
        return globals()[fun_name]
    raise NotImplementedError("Unknown LR policy:" + cfg.OPTIM.LR_POLICY)
80 |
81 |
def get_epoch_lr(cur_epoch):
    """Retrieves the lr for the given epoch according to the policy."""
    lr = get_lr_fun()(cur_epoch)
    # Linear warmup: ramp from WARMUP_FACTOR * lr up to lr over WARMUP_EPOCHS
    if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS:
        alpha = cur_epoch / cfg.OPTIM.WARMUP_EPOCHS
        lr *= cfg.OPTIM.WARMUP_FACTOR * (1.0 - alpha) + alpha
    return lr
91 |
92 |
def set_lr(optimizer, new_lr):
    """Sets the lr of every param group of the optimizer to the given value."""
    for group in optimizer.param_groups:
        group["lr"] = new_lr
97 |
--------------------------------------------------------------------------------
/imagenet/pycls/core/plotting.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Plotting functions."""
9 |
10 | import colorlover as cl
11 | import matplotlib.pyplot as plt
12 | import plotly.graph_objs as go
13 | import plotly.offline as offline
14 | import pycls.core.logging as logging
15 |
16 |
def get_plot_colors(max_colors, color_format="pyplot"):
    """Generate colors for plotting."""
    colors = cl.scales["11"]["qual"]["Paired"]
    # Interpolate additional colors when the base palette is too small
    if max_colors > len(colors):
        colors = cl.to_rgb(cl.interp(colors, max_colors))
    if color_format != "pyplot":
        return colors
    # Convert to 0-1 RGB triples for matplotlib
    return [[channel / 255.0 for channel in c] for c in cl.to_numeric(colors)]
25 |
26 |
def prepare_plot_data(log_files, names, metric="top1_err"):
    """Load logs and extract data for plotting error curves."""
    plot_data = []
    for file, name in zip(log_files, names):
        data = logging.sort_log_data(logging.load_log_data(file))
        d = {}
        for phase in ["train", "test"]:
            epochs = data[phase + "_epoch"]["epoch_ind"]
            errs = data[phase + "_epoch"][metric]
            d["x_" + phase] = epochs
            d["y_" + phase] = errs
            # Prefix the curve label with the best (lowest) error achieved
            best = min(errs) if errs else 0
            d[phase + "_label"] = "[{:5.2f}] ".format(best) + name
        plot_data.append(d)
    assert len(plot_data) > 0, "No data to plot"
    return plot_data
40 |
41 |
def plot_error_curves_plotly(log_files, names, filename, metric="top1_err"):
    """Plot error curves using plotly and save to file.

    Produces an interactive HTML plot with a dropdown that toggles between
    showing 'all', 'train', or 'test' curves. Train curves are drawn dash-dot,
    test curves solid.
    """
    plot_data = prepare_plot_data(log_files, names, metric)
    colors = get_plot_colors(len(plot_data), "plotly")
    # Prepare data for plots (3 sets, train duplicated w and w/o legend)
    data = []
    for i, d in enumerate(plot_data):
        s = str(i)
        line_train = {"color": colors[i], "dash": "dashdot", "width": 1.5}
        line_test = {"color": colors[i], "dash": "solid", "width": 1.5}
        # Train curve without its own legend entry (used in the 'all' view)
        data.append(
            go.Scatter(
                x=d["x_train"],
                y=d["y_train"],
                mode="lines",
                name=d["train_label"],
                line=line_train,
                legendgroup=s,
                visible=True,
                showlegend=False,
            )
        )
        # Test curve with a legend entry
        data.append(
            go.Scatter(
                x=d["x_test"],
                y=d["y_test"],
                mode="lines",
                name=d["test_label"],
                line=line_test,
                legendgroup=s,
                visible=True,
                showlegend=True,
            )
        )
        # Duplicate train curve with a legend entry (used in the 'train' view)
        data.append(
            go.Scatter(
                x=d["x_train"],
                y=d["y_train"],
                mode="lines",
                name=d["train_label"],
                line=line_train,
                legendgroup=s,
                visible=False,
                showlegend=True,
            )
        )
    # Prepare layout w ability to toggle 'all', 'train', 'test'
    titlefont = {"size": 18, "color": "#7f7f7f"}
    vis = [[True, True, False], [False, False, True], [False, True, False]]
    buttons = zip(["all", "train", "test"], [[{"visible": v}] for v in vis])
    buttons = [{"label": b, "args": v, "method": "update"} for b, v in buttons]
    layout = go.Layout(
        # Fix: the title string was truncated (extraction stripped the HTML
        # line break); plotly titles use <br> for newlines
        title=metric + " vs. epoch<br>[dash=train, solid=test]",
        xaxis={"title": "epoch", "titlefont": titlefont},
        yaxis={"title": metric, "titlefont": titlefont},
        showlegend=True,
        hoverlabel={"namelength": -1},
        updatemenus=[
            {
                "buttons": buttons,
                "direction": "down",
                "showactive": True,
                "x": 1.02,
                "xanchor": "left",
                "y": 1.08,
                "yanchor": "top",
            }
        ],
    )
    # Create plotly plot
    offline.plot({"data": data, "layout": layout}, filename=filename)
114 |
def plot_error_curves_pyplot(log_files, names, filename=None, metric="top1_err"):
    """Plot error curves using matplotlib.pyplot and save to file."""
    plot_data = prepare_plot_data(log_files, names, metric)
    colors = get_plot_colors(len(names))
    for color, d in zip(colors, plot_data):
        # Train curves are dashed; test curves are solid and carry the legend label
        plt.plot(d["x_train"], d["y_train"], "--", c=color, alpha=0.8)
        plt.plot(d["x_test"], d["y_test"], "-", c=color, alpha=0.8, label=d["test_label"])
    plt.title(metric + " vs. epoch\n[dash=train, solid=test]", fontsize=14)
    plt.xlabel("epoch", fontsize=14)
    plt.ylabel(metric, fontsize=14)
    plt.grid(alpha=0.4)
    plt.legend()
    if filename:
        plt.savefig(filename)
        plt.clf()
    else:
        plt.show()
133 |
--------------------------------------------------------------------------------
/imagenet/pycls/core/timer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Timer."""
9 |
10 | import time
11 |
12 |
class Timer(object):
    """A simple tic/toc timer tracking total, last, and average durations
    (adapted from Detectron)."""

    def __init__(self):
        # All timing state is established by reset()
        self.total_time = None
        self.calls = None
        self.start_time = None
        self.diff = None
        self.average_time = None
        self.reset()

    def reset(self):
        """Zeroes all timing state."""
        self.total_time = 0.0
        self.calls = 0
        self.start_time = 0.0
        self.diff = 0.0
        self.average_time = 0.0

    def tic(self):
        """Starts (or restarts) the timer."""
        # time.time is used since time.clock does not normalize for multithreading
        self.start_time = time.time()

    def toc(self):
        """Stops the timer and folds the elapsed time into the running stats."""
        now = time.time()
        self.diff = now - self.start_time
        self.total_time += self.diff
        self.calls += 1
        self.average_time = self.total_time / self.calls
40 |
--------------------------------------------------------------------------------
/imagenet/pycls/core/trainer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Tools for training and testing a model."""
9 |
10 | import os
11 |
12 | import numpy as np
13 | import pycls.core.benchmark as benchmark
14 | import pycls.core.builders as builders
15 | import pycls.core.checkpoint as checkpoint
16 | import pycls.core.config as config
17 | import pycls.core.distributed as dist
18 | import pycls.core.logging as logging
19 | import pycls.core.meters as meters
20 | import pycls.core.net as net
21 | import pycls.core.optimizer as optim
22 | import pycls.datasets.loader as loader
23 | from pycls.datasets.transforms import cutmix_batch
24 | import torch
25 | from pycls.core.config import cfg
26 | import torch.nn as nn
27 | import torch.distributed as tdist
28 |
29 |
30 | logger = logging.get_logger(__name__)
31 |
32 |
def setup_env():
    """Sets up environment for training or testing."""
    if dist.is_master_proc():
        # The master process owns the output dir, the config dump, and logging
        os.makedirs(cfg.OUT_DIR, exist_ok=True)
        config.dump_cfg()
        logging.setup_logging()
        # Log the config as both human readable and as a json
        logger.info("Config:\n{}".format(cfg))
        logger.info(logging.dump_log_data(cfg, "cfg"))
    # Fix the RNG seeds (see RNG comment in core/config.py for discussion)
    np.random.seed(cfg.RNG_SEED)
    torch.manual_seed(cfg.RNG_SEED)
    # Configure the CUDNN backend
    torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
50 |
51 |
def setup_model():
    """Sets up a model for training or testing and log the results."""
    model = builders.build_model()
    # Optionally convert BN layers to their cross-GPU synchronized variant
    if cfg.TRAIN.SYNC_BN:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    logger.info("Model:\n{}".format(model))
    # Complexity logging is disabled:
    # logger.info(logging.dump_log_data(net.complexity(model), "complexity"))
    # Transfer the model to the current GPU device
    cur_device = torch.cuda.current_device()
    model = model.cuda(device=cur_device)
    # Wrap in DDP so each replica operates on (and is pinned to) its own device
    model = torch.nn.parallel.DistributedDataParallel(
        module=model, device_ids=[cur_device], output_device=cur_device
    )
    # model.complexity = model.module.complexity
    return model
72 |
class SoftCrossEntropyLoss(nn.Module):
    """SoftCrossEntropyLoss (useful for label smoothing and mixup).
    Identical to torch.nn.CrossEntropyLoss if used with one-hot labels."""

    def __init__(self):
        super(SoftCrossEntropyLoss, self).__init__()

    def forward(self, x, y):
        # Mean over the batch of the per-sample soft cross-entropy
        log_probs = torch.nn.functional.log_softmax(x, -1)
        return torch.sum(-y * log_probs) / x.shape[0]
83 |
class KDLoss(nn.Module):
    """Knowledge-distillation loss: KL divergence between the temperature-
    softened student log-probabilities and the given target probabilities,
    averaged over the batch and rescaled by temp_factor**2."""

    def __init__(self, temp_factor=4.):
        super(KDLoss, self).__init__()
        self.temp_factor = temp_factor
        self.kl_div = nn.KLDivLoss(reduction="sum")

    def forward(self, input, target):
        # KLDivLoss expects log-probabilities for the input argument
        log_p = torch.log_softmax(input / self.temp_factor, dim=1)
        scale = (self.temp_factor ** 2) / input.size(0)
        return self.kl_div(log_p, target) * scale
94 |
def knowledge_ensemble(feats, logits, temp_factor=4., omega=0.5, cross_gpus=True):
    """Builds soft distillation targets by propagating predictions among the
    samples of a batch (BAKE-style knowledge ensembling).

    Args:
        feats: (batch, dim) sample features; L2-normalized below and used to
            measure pairwise sample affinity.
        logits: (batch, num_classes) raw predictions for the same samples.
        temp_factor: softmax temperature applied to the logits.
        omega: propagation weight; the final mixing matrix
            (1 - omega) * (I - omega * A)^-1 is the closed form of the
            infinite propagation series (1 - omega) * sum_k (omega * A)^k.
        cross_gpus: if True, all_gather feats/logits from every rank so the
            ensemble spans the global batch.

    Returns:
        (batch, num_classes) refined soft targets for this rank's samples.
    """
    batch_size = logits.size(0)
    feats = nn.functional.normalize(feats, p=2, dim=1)
    logits = nn.functional.softmax(logits/temp_factor, dim=1)

    if cross_gpus:
        # Gather the full global batch; labels_idx locates this rank's slice
        feats_large = [torch.zeros_like(feats) \
            for _ in range(tdist.get_world_size())]
        tdist.all_gather(feats_large, feats)
        feats_large = torch.cat(feats_large, dim=0)
        logits_large = [torch.zeros_like(logits) \
            for _ in range(tdist.get_world_size())]
        tdist.all_gather(logits_large, logits)
        logits_large = torch.cat(logits_large, dim=0)
        enlarged_batch_size = logits_large.size(0)
        labels_idx = torch.arange(batch_size) + tdist.get_rank() * batch_size
    else:
        feats_large = feats
        logits_large = logits
        enlarged_batch_size = batch_size
        labels_idx = torch.arange(batch_size)

    # Affinity matrix: cosine similarities with the diagonal masked out
    # (the -1e9 self-similarity vanishes under the row-wise softmax)
    masks = torch.eye(enlarged_batch_size).cuda()
    W = torch.matmul(feats_large, feats_large.permute(1, 0)) - masks * 1e9
    W = nn.functional.softmax(W, dim=1)
    # Closed-form infinite propagation (masks is the identity matrix here)
    W = (1 - omega) * torch.inverse(masks - omega * W)
    logits_new = torch.matmul(W, logits_large)
    return logits_new[labels_idx]
123 |
def to_one_hot_labels(labels, n_classes=None):
    """Convert each integer label to a one-hot vector.

    Args:
        labels: 1D tensor of class indices.
        n_classes: number of classes; defaults to cfg.MODEL.NUM_CLASSES when
            None (generalized so the helper can be reused outside the config).

    Returns:
        Float tensor of shape (len(labels), n_classes) on labels' device.
    """
    if n_classes is None:
        n_classes = cfg.MODEL.NUM_CLASSES
    err_str = "Invalid input to one_hot_vector()"
    assert labels.ndim == 1 and labels.max() < n_classes, err_str
    shape = (labels.shape[0], n_classes)
    labels_one_hot = torch.full(shape, 0.0, dtype=torch.float, device=labels.device)
    # Place a 1.0 at each sample's class index
    labels_one_hot.scatter_(1, labels.long().view(-1, 1), 1.0)
    return labels_one_hot
134 |
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)
    # Stack every rank's batch along dim 0
    return torch.cat(gathered, dim=0)
147 |
@torch.no_grad()
def _batch_shuffle_ddp(x, y):
    """
    Batch shuffle, for making use of BatchNorm.
    *** Only support DistributedDataParallel (DDP) model. ***

    Gathers the (x, y) minibatches from all ranks, applies one global random
    permutation (generated on rank 0 and broadcast so every rank agrees), and
    returns the slice of the shuffled global batch assigned to this rank.
    Note: idx_unshuffle (the inverse permutation) is computed but not returned.
    """
    # gather from all gpus
    batch_size_this = x.shape[0]
    x_gather = concat_all_gather(x)
    y_gather = concat_all_gather(y)
    batch_size_all = x_gather.shape[0]

    num_gpus = batch_size_all // batch_size_this

    # random shuffle index over the full global batch
    idx_shuffle = torch.randperm(batch_size_all).cuda()

    # broadcast to all gpus so every rank uses the same permutation
    torch.distributed.broadcast(idx_shuffle, src=0)

    # index for restoring the original order (currently unused by callers here)
    idx_unshuffle = torch.argsort(idx_shuffle)

    # shuffled index for this gpu: row gpu_idx of the (num_gpus, batch) view
    gpu_idx = torch.distributed.get_rank()
    idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

    return x_gather[idx_this], y_gather[idx_this]
176 |
177 |
def train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch):
    """Performs one epoch of training.

    BAKE training loop: the model returns (features, logits); the loss is a
    soft cross-entropy against (optionally cutmix-mixed) one-hot labels, plus
    an optional self-distillation term whose soft targets are produced by
    knowledge_ensemble() over the (possibly global, cross-GPU) batch.

    NOTE(review): the loss_fun argument is not used in this body -- the soft
    cross-entropy and KD losses are constructed locally below. Confirm whether
    the parameter is kept only for interface compatibility with the trainer.
    """
    # Shuffle the data
    loader.shuffle(train_loader, cur_epoch)
    # Update the learning rate
    lr = optim.get_epoch_lr(cur_epoch)
    optim.set_lr(optimizer, lr)
    # Losses used by this loop (KD temperature comes from the config)
    kdloss = KDLoss(cfg.TRAIN.TEMP).cuda()
    softceloss = SoftCrossEntropyLoss().cuda()
    # Enable training mode
    model.train()
    train_meter.iter_tic()
    for cur_iter, (inputs, labels) in enumerate(train_loader):
        # if (cur_iter>=train_meter.epoch_iters): break
        # Transfer the data to the current GPU device
        # import pdb; pdb.set_trace()
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Optionally shuffle samples across GPUs so BN batches are decorrelated
        if cfg.TRAIN.SHUFFLE_BN:
            inputs, labels = _batch_shuffle_ddp(inputs, labels)
        onehot_labels = to_one_hot_labels(labels)
        # apply cutmix augmentation (disabled for the final OFF_EPOCH epochs)
        if cfg.TRAIN.IS_CUTMIX:
            disable = (cur_epoch >= (cfg.OPTIM.MAX_EPOCH - cfg.TRAIN.CUTMIX.OFF_EPOCH))
            inputs, onehot_labels = cutmix_batch(inputs, onehot_labels, cfg.MODEL.NUM_CLASSES, cfg.TRAIN.CUTMIX, disable)
        # Perform the forward pass; the model yields (features, logits)
        feas, preds = model(inputs)
        # Compute the loss
        loss = softceloss(preds, onehot_labels)
        # BAKE: refine soft targets by propagating the batch's own predictions;
        # no gradient flows through the target construction
        if cfg.TRAIN.USE_BAKE:
            with torch.no_grad():
                kd_targets = knowledge_ensemble(feas.detach(), preds.detach(),
                    temp_factor=cfg.TRAIN.TEMP, omega=cfg.TRAIN.OMEGA, cross_gpus=cfg.TRAIN.GLOBAL_KE)
            loss += kdloss(preds, kd_targets.detach())*cfg.TRAIN.LAMBDA
        # Perform the backward pass
        optimizer.zero_grad()
        loss.backward()
        # Update the parameters
        optimizer.step()
        # Compute the errors
        top1_err, top5_err = meters.topk_errors(preds, labels, [1, 5])
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
        # Copy the stats from GPU to CPU (sync point)
        loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
        train_meter.iter_toc()
        # Update and log stats; mb_size is the effective global batch size
        mb_size = inputs.size(0) * cfg.NUM_GPUS
        train_meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()
    # Log epoch stats
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()
232 |
233 |
@torch.no_grad()
def test_epoch(test_loader, model, test_meter, cur_epoch):
    """Runs one full evaluation pass over the test set and logs the results."""
    # Put the network into inference mode.
    model.eval()
    test_meter.iter_tic()
    for iter_idx, (ims, targets) in enumerate(test_loader):
        # Move the mini-batch onto the current GPU.
        ims = ims.cuda()
        targets = targets.cuda(non_blocking=True)
        # The model returns (features, logits); only the logits are needed.
        _, logits = model(ims)
        # Per-batch top-1 / top-5 error rates.
        err1, err5 = meters.topk_errors(logits, targets, [1, 5])
        # Average the errors over all GPUs (no-op for a single GPU).
        err1, err5 = dist.scaled_all_reduce([err1, err5])
        # .item() forces a GPU->CPU synchronization.
        err1 = err1.item()
        err5 = err5.item()
        test_meter.iter_toc()
        # Record and report the per-iteration statistics.
        test_meter.update_stats(err1, err5, ims.size(0) * cfg.NUM_GPUS)
        test_meter.log_iter_stats(cur_epoch, iter_idx)
        test_meter.iter_tic()
    # Report the aggregated epoch statistics.
    test_meter.log_epoch_stats(cur_epoch)
    test_meter.reset()
259 |
260 |
def train_model():
    """Trains the model.

    Builds the model/loss/optimizer, optionally resumes from the latest
    checkpoint (or loads initial weights), then alternates training epochs
    with precise-BN recomputation, periodic checkpointing, and evaluation.
    """
    # Setup training/testing environment
    setup_env()
    # Construct the model, loss_fun, and optimizer
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    optimizer = optim.construct_optimizer(model)
    # Load checkpoint or initial weights
    start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and checkpoint.has_checkpoint():
        # Auto-resume takes precedence over TRAIN.WEIGHTS.
        last_checkpoint = checkpoint.get_last_checkpoint()
        checkpoint_epoch = checkpoint.load_checkpoint(last_checkpoint, model, optimizer)
        logger.info("Loaded checkpoint from: {}".format(last_checkpoint))
        start_epoch = checkpoint_epoch + 1
    elif cfg.TRAIN.WEIGHTS:
        checkpoint.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
        logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
    # Create data loaders and meters
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    train_meter = meters.TrainMeter(len(train_loader))
    test_meter = meters.TestMeter(len(test_loader))
    # Compute model and loader timings
    # if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
    #     benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
    # Perform the training loop
    logger.info("Start epoch: {}".format(start_epoch + 1))
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train for one epoch
        train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
        # Compute precise BN stats
        if cfg.BN.USE_PRECISE_STATS:
            net.compute_precise_bn_stats(model, train_loader)
        # Save a checkpoint
        if (cur_epoch + 1) % cfg.TRAIN.CHECKPOINT_PERIOD == 0:
            checkpoint_file = checkpoint.save_checkpoint(model, optimizer, cur_epoch)
            logger.info("Wrote checkpoint to: {}".format(checkpoint_file))
        # Evaluate the model (every EVAL_PERIOD epochs and on the final epoch)
        next_epoch = cur_epoch + 1
        if next_epoch % cfg.TRAIN.EVAL_PERIOD == 0 or next_epoch == cfg.OPTIM.MAX_EPOCH:
            test_epoch(test_loader, model, test_meter, cur_epoch)
303 |
304 |
def test_model():
    """Evaluates a trained model."""
    # Prepare the environment and build the network.
    setup_env()
    model = setup_model()
    # Restore the weights selected for evaluation.
    weights = cfg.TEST.WEIGHTS
    checkpoint.load_checkpoint(weights, model)
    logger.info("Loaded model weights from: {}".format(weights))
    # Build the evaluation loader and its meter, then run a single pass.
    test_loader = loader.construct_test_loader()
    test_meter = meters.TestMeter(len(test_loader))
    test_epoch(test_loader, model, test_meter, 0)
319 |
320 |
def time_model():
    """Times model and data loader."""
    # Prepare the environment, then build the network and its loss.
    setup_env()
    model = setup_model()
    loss_fun = builders.build_loss_fun().cuda()
    # Both loaders are needed to time the full train + test pipeline.
    train_loader = loader.construct_train_loader()
    test_loader = loader.construct_test_loader()
    # Benchmark the model together with the loaders.
    benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
333 |
--------------------------------------------------------------------------------
/imagenet/pycls/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxgeee/BAKE/07c4f668ea19311d5b50121026e73d2f035d5765/imagenet/pycls/datasets/__init__.py
--------------------------------------------------------------------------------
/imagenet/pycls/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """ImageNet dataset."""
9 |
10 | import os
11 | import re
12 |
13 | import cv2
14 | from PIL import Image
15 | import numpy as np
16 | import pycls.core.logging as logging
17 | import pycls.datasets.transforms as transforms
18 | import torch.utils.data
19 | from pycls.core.config import cfg
20 | import torchvision.transforms as T
21 |
22 |
23 | logger = logging.get_logger(__name__)
24 |
25 | # Per-channel mean and SD values in BGR order
26 | _MEAN = [0.406, 0.456, 0.485]
27 | _SD = [0.225, 0.224, 0.229]
28 |
29 | # Eig vals and vecs of the cov mat
30 | _EIG_VALS = np.array([[0.2175, 0.0188, 0.0045]])
31 | _EIG_VECS = np.array(
32 | [[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]]
33 | )
34 |
35 |
class ImageNet(torch.utils.data.Dataset):
    """ImageNet dataset.

    Images are loaded with PIL as RGB, optionally color-jittered, converted to
    BGR float32 (matching the BGR order of _MEAN/_SD above), and then run
    through the train/test transform pipeline in _prepare_im.
    """

    def __init__(self, data_path, split):
        # data_path: root directory containing 'train'/'val' subdirectories.
        # split: one of 'train' or 'val'.
        assert os.path.exists(data_path), "Data path '{}' not found".format(data_path)
        splits = ["train", "val"]
        assert split in splits, "Split '{}' not supported for ImageNet".format(split)
        logger.info("Constructing ImageNet {}...".format(split))
        self._data_path, self._split = data_path, split
        self._construct_imdb()

    def _construct_imdb(self):
        """Constructs the imdb: a list of {im_path, class} records."""
        # Compile the split data path
        split_path = os.path.join(self._data_path, self._split)
        logger.info("{} data path: {}".format(self._split, split_path))
        # Images are stored per class in subdirs named like the wordnet ids (n########)
        split_files = os.listdir(split_path)
        self._class_ids = sorted(f for f in split_files if re.match(r"^n[0-9]+$", f))
        # Map ImageNet class ids to contiguous ids (sorted order -> 0..C-1)
        self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)}
        # Construct the image db
        self._imdb = []
        for class_id in self._class_ids:
            cont_id = self._class_id_cont_id[class_id]
            im_dir = os.path.join(split_path, class_id)
            for im_name in os.listdir(im_dir):
                im_path = os.path.join(im_dir, im_name)
                self._imdb.append({"im_path": im_path, "class": cont_id})
        logger.info("Number of images: {}".format(len(self._imdb)))
        logger.info("Number of classes: {}".format(len(self._class_ids)))

    def _prepare_im(self, im):
        """Prepares an HWC BGR float32 image for network input (returns CHW)."""
        # Train and test setups differ
        train_size = cfg.TRAIN.IM_SIZE
        if self._split == "train":
            # Scale and aspect ratio then horizontal flip
            im = transforms.random_sized_crop(im=im, size=train_size, area_frac=0.08)
            im = transforms.horizontal_flip(im=im, p=0.5, order="HWC")
        else:
            # Scale and center crop
            # NOTE(review): the crop size is TRAIN.IM_SIZE even at test time,
            # with TEST.IM_SIZE used only as the pre-crop scale — confirm that
            # this matches the intended config convention.
            im = transforms.scale(cfg.TEST.IM_SIZE, im)
            im = transforms.center_crop(train_size, im)
        # HWC -> CHW
        im = im.transpose([2, 0, 1])
        # [0, 255] -> [0, 1]
        im = im / 255.0
        # PCA jitter (train only)
        if self._split == "train":
            im = transforms.lighting(im, 0.1, _EIG_VALS, _EIG_VECS)
        # Color normalization
        im = transforms.color_norm(im, _MEAN, _SD)
        return im

    def __getitem__(self, index):
        # Load the image via PIL, forcing 3-channel RGB
        im = Image.open(self._imdb[index]["im_path"]).convert('RGB')
        # Optional brightness/contrast/saturation jitter (train only)
        if self._split == "train" and cfg.TRAIN.JITTER:
            im = T.ColorJitter(0.4, 0.4, 0.4)(im)
        # RGB -> BGR so the channel order matches _MEAN/_SD
        im = cv2.cvtColor(np.asarray(im),cv2.COLOR_RGB2BGR)
        im = im.astype(np.float32, copy=False)
        # Prepare the image for training / testing
        im = self._prepare_im(im)
        # Retrieve the label
        label = self._imdb[index]["class"]
        return im, label

    def __len__(self):
        return len(self._imdb)
106 |
--------------------------------------------------------------------------------
/imagenet/pycls/datasets/loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Data loader."""
9 |
10 | import os
11 |
12 | import torch
13 | from pycls.core.config import cfg
14 | from pycls.datasets.imagenet import ImageNet
15 | from pycls.datasets.sampler import DistributedClassSampler
16 | from torch.utils.data.distributed import DistributedSampler
17 | from torch.utils.data.sampler import RandomSampler
18 |
19 |
20 | # Supported datasets
21 | _DATASETS = {"imagenet": ImageNet}
22 |
23 | # Default data directory (/path/pycls/pycls/datasets/data)
24 | _DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
25 |
26 | # Relative data paths to default data directory
27 | _PATHS = {"imagenet": "imagenet"}
28 |
29 |
def _construct_loader(dataset_name, split, batch_size, shuffle, drop_last):
    """Builds a DataLoader over the named dataset split.

    A DistributedClassSampler is used when intra-class sampling is enabled for
    this split (BAKE); otherwise a plain DistributedSampler is used.
    """
    err_str = "Dataset '{}' not supported".format(dataset_name)
    assert dataset_name in _DATASETS and dataset_name in _PATHS, err_str
    # Resolve the on-disk location and build the dataset object.
    data_path = os.path.join(_DATA_DIR, _PATHS[dataset_name])
    dataset = _DATASETS[dataset_name](data_path, split)
    # Pick the sampler: class-aware for BAKE training, standard otherwise.
    use_class_sampler = cfg.TRAIN.INTRA_IMGS > 0 and split == cfg.TRAIN.SPLIT
    if use_class_sampler:
        sampler = DistributedClassSampler(dataset._imdb, cfg.TRAIN.INTRA_IMGS + 1)
    else:
        sampler = DistributedSampler(dataset)
    # The sampler owns the ordering, so the loader itself never shuffles.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(False if sampler else shuffle),
        sampler=sampler,
        num_workers=cfg.DATA_LOADER.NUM_WORKERS,
        pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
        drop_last=drop_last,
    )
54 |
55 |
def construct_train_loader():
    """Train loader wrapper."""
    # Split the global batch size evenly across the GPUs.
    per_gpu_batch = int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
    return _construct_loader(
        dataset_name=cfg.TRAIN.DATASET,
        split=cfg.TRAIN.SPLIT,
        batch_size=per_gpu_batch,
        shuffle=True,
        drop_last=True,
    )
65 |
66 |
def construct_test_loader():
    """Test loader wrapper."""
    # Split the global batch size evenly across the GPUs.
    per_gpu_batch = int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
    return _construct_loader(
        dataset_name=cfg.TEST.DATASET,
        split=cfg.TEST.SPLIT,
        batch_size=per_gpu_batch,
        shuffle=False,
        drop_last=False,
    )
76 |
77 |
def shuffle(loader, cur_epoch):
    """Shuffles the data for the given epoch."""
    sampler = loader.sampler
    err_str = "Sampler type '{}' not supported".format(type(sampler))
    assert isinstance(
        sampler, (RandomSampler, DistributedSampler, DistributedClassSampler)
    ), err_str
    # Distributed samplers derive their permutation from the epoch number;
    # RandomSampler reshuffles on its own and needs no epoch.
    if isinstance(sampler, (DistributedSampler, DistributedClassSampler)):
        sampler.set_epoch(cur_epoch)
86 |
--------------------------------------------------------------------------------
/imagenet/pycls/datasets/sampler.py:
--------------------------------------------------------------------------------
1 |
2 | from collections import defaultdict
3 | import numpy as np
4 | import math
5 | import copy
6 |
7 | import torch
8 | from torch.utils.data.sampler import Sampler
9 | from torch.utils.data.distributed import DistributedSampler
10 | import torch.distributed as dist
11 |
12 | __all__ = ["DistributedClassSampler"]
13 |
14 |
class DistributedClassSampler(Sampler):
    """Distributed sampler that pairs every sampled image with same-class peers.

    Each index of this rank's share of a deterministic global permutation is
    followed by (num_instances - 1) other images of the same class, so every
    mini-batch contains small intra-class groups (used by BAKE).

    Fix: companion selection previously used the *global* numpy RNG, so epochs
    were not reproducible across restarts even though the base permutation is
    carefully seeded by (seed + epoch). Companions are now drawn from a
    RandomState seeded by (seed + epoch + rank).

    Args:
        dataset: imdb-style list of dicts, each with a 'class' entry.
        num_instances: group size per anchor (anchor + companions).
        seed: base seed combined with the epoch for deterministic shuffling.
    """

    def __init__(self, dataset, num_instances, seed=0):
        if not dist.is_available():
            raise RuntimeError("Requires distributed package to be available")
        num_replicas = dist.get_world_size()
        rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.seed = seed
        # Per-rank share of the (padded) dataset.
        self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

        self.num_instances = num_instances
        # Map class id -> list of image indices, for companion sampling.
        self.pid_index = defaultdict(list)
        for idx, item in enumerate(self.dataset):
            self.pid_index[item['class']].append(idx)

    def __len__(self):
        # Upper bound: anchors plus companions (singleton classes contribute
        # fewer, see __iter__).
        return self.num_samples * self.num_instances

    def __iter__(self):
        # Deterministically shuffle based on epoch and seed (same on all ranks).
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore
        # Companion selection is seeded too, so restarts reproduce the exact
        # epoch ordering; the rank offset decorrelates the replicas.
        rng = np.random.RandomState(self.seed + self.epoch + self.rank)

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample this rank's share (round-robin over the permutation)
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        ret = []
        for i in indices:
            ret.append(i)
            select_indexes = [j for j in self.pid_index[self.dataset[i]['class']] if j != i]
            if not select_indexes:
                # Singleton class: no companions available for this anchor.
                continue
            # Sample without replacement when the class is large enough.
            replace = len(select_indexes) < self.num_instances - 1
            ind_indexes = rng.choice(select_indexes, size=self.num_instances - 1, replace=replace)
            ret.extend(ind_indexes)
        return iter(ret)

    def set_epoch(self, epoch):
        """Sets the epoch that seeds the next __iter__ call."""
        self.epoch = epoch
67 |
--------------------------------------------------------------------------------
/imagenet/pycls/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Image transformations."""
9 |
10 | import math
11 |
12 | import cv2
13 | import numpy as np
14 | import torch
15 |
def color_norm(im, mean, std):
    """Subtracts the per-channel mean and divides by the per-channel SD, in
    place (CHW format)."""
    n_channels = im.shape[0]
    for c in range(n_channels):
        im[c] = (im[c] - mean[c]) / std[c]
    return im
22 |
23 |
def zero_pad(im, pad_size):
    """Pads both spatial dims with pad_size zeros on every side (CHW format)."""
    spatial = (pad_size, pad_size)
    # Channel axis is left untouched; only H and W grow.
    return np.pad(im, ((0, 0), spatial, spatial), mode="constant")
28 |
29 |
def horizontal_flip(im, p, order="CHW"):
    """Flips the image left-right with probability p (CHW or HWC format)."""
    assert order in ["CHW", "HWC"]
    # One uniform draw decides the flip, exactly as before.
    if np.random.uniform() >= p:
        return im
    # Width is the last axis in CHW and the middle axis in HWC.
    return im[:, :, ::-1] if order == "CHW" else im[:, ::-1, :]
39 |
40 |
def random_crop(im, size, pad_size=0):
    """Crops a random size x size patch (CHW format).

    Optionally zero-pads the spatial dims first. Fix: the original called
    np.random.randint(0, 0) and crashed whenever a spatial dim equals the
    crop size; the offset is now 0 in that case, matching the guard already
    used by random_sized_crop.
    """
    if pad_size > 0:
        im = zero_pad(im=im, pad_size=pad_size)
    h, w = im.shape[1:]
    # Guard the degenerate case where no offset is possible.
    y = 0 if h == size else np.random.randint(0, h - size)
    x = 0 if w == size else np.random.randint(0, w - size)
    im_crop = im[:, y : (y + size), x : (x + size)]
    assert im_crop.shape[1:] == (size, size)
    return im_crop
51 |
52 |
def scale(size, im):
    """Resizes so the shorter side equals size, preserving aspect (HWC)."""
    h, w = im.shape[:2]
    # Shorter side already at the target scale: nothing to do.
    if (w <= h and w == size) or (h <= w and h == size):
        return im
    if w < h:
        new_w, new_h = size, int(math.floor((float(h) / w) * size))
    else:
        new_w, new_h = int(math.floor((float(w) / h) * size)), size
    resized = cv2.resize(im, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    return resized.astype(np.float32)
65 |
66 |
def center_crop(size, im):
    """Crops the centered size x size patch (HWC format)."""
    h, w = im.shape[:2]
    # Ceil matches the original rounding when (dim - size) is odd.
    top = int(math.ceil((h - size) / 2))
    left = int(math.ceil((w - size) / 2))
    patch = im[top : top + size, left : left + size, :]
    assert patch.shape[:2] == (size, size)
    return patch
75 |
76 |
def random_sized_crop(im, size, area_frac=0.08, max_iter=10):
    """Inception-style crop: random area and aspect, resized to size (HWC)."""
    h, w = im.shape[:2]
    full_area = h * w
    for _ in range(max_iter):
        # Sample a target area and aspect ratio, then derive crop dims.
        target_area = np.random.uniform(area_frac, 1.0) * full_area
        aspect = np.random.uniform(3.0 / 4.0, 4.0 / 3.0)
        w_crop = int(round(math.sqrt(float(target_area) * aspect)))
        h_crop = int(round(math.sqrt(float(target_area) / aspect)))
        # Swap with probability 1/2 to cover both orientations.
        if np.random.uniform() < 0.5:
            w_crop, h_crop = h_crop, w_crop
        if h_crop > h or w_crop > w:
            # Crop does not fit; try again.
            continue
        y = 0 if h_crop == h else np.random.randint(0, h - h_crop)
        x = 0 if w_crop == w else np.random.randint(0, w - w_crop)
        patch = im[y : (y + h_crop), x : (x + w_crop), :]
        assert patch.shape[:2] == (h_crop, w_crop)
        patch = cv2.resize(patch, (size, size), interpolation=cv2.INTER_LINEAR)
        return patch.astype(np.float32)
    # All attempts failed: deterministic scale + center crop fallback.
    return center_crop(size, scale(size, im))
96 |
97 |
98 | def lighting(im, alpha_std, eig_val, eig_vec):
99 | """Performs AlexNet-style PCA jitter (CHW format)."""
100 | if alpha_std == 0:
101 | return im
102 | alpha = np.random.normal(0, alpha_std, size=(1, 3))
103 | alpha = np.repeat(alpha, 3, axis=0)
104 | eig_val = np.repeat(eig_val, 3, axis=0)
105 | rgb = np.sum(eig_vec * alpha * eig_val, axis=1)
106 | for i in range(im.shape[0]):
107 | im[i] = im[i] + rgb[2 - i]
108 | return im
109 |
110 | # for cutmix
111 |
def smooth_target(target, smoothing=0.1, num_classes=1000):
    """Applies label smoothing to target in place and returns it."""
    keep_frac = 1. - smoothing
    # In-place ops: callers rely on `target` itself being modified.
    target *= keep_frac
    target += smoothing / num_classes
    return target
116 |
117 |
def mix_target(target_a, target_b, num_classes, lam=1., smoothing=0.0):
    """Label-smooths both targets (in place) and blends them with weight lam."""
    # Smoothing inlined: keep (1 - smoothing) of the mass, spread the rest
    # uniformly over the classes. Both targets are mutated in place.
    off_value = smoothing / num_classes
    target_a *= (1. - smoothing)
    target_a += off_value
    target_b *= (1. - smoothing)
    target_b += off_value
    return lam * target_a + (1. - lam) * target_b
122 |
123 |
def rand_bbox(size, lam):
    """Samples a random box covering roughly (1 - lam) of the spatial area.

    Args:
        size: 4-element tensor shape; size[2] and size[3] are the spatial dims.
        lam: target mixing ratio (box area ~ (1 - lam) * area).

    Returns:
        (bbx1, bby1, bbx2, bby2, lam_adj) where lam_adj is recomputed from the
        clipped box so it reflects the true surviving-area fraction.
    """
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # Fix: int() replaces np.int, which was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # Uniformly sample the box center; the box may be clipped at the borders.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    # Adjust lam to the exact area that survived clipping.
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (size[2] * size[3]))

    return bbx1, bby1, bbx2, bby2, lam
143 |
144 |
def cutmix_batch(image, target, num_classes, mix_args, disable):
    """Applies CutMix to a batch in place; returns (image, soft_target).

    With probability mix_args.PROB (unless disabled) a random box from a
    permuted view of the batch is pasted into every image and the targets are
    mixed by the surviving area; otherwise only label smoothing is applied.
    """
    # Short-circuit keeps the RNG untouched when mixing is disabled.
    do_mix = (not disable) and np.random.rand(1) < mix_args.PROB
    if not do_mix:
        # No mixing this batch: just smooth the labels.
        return image, smooth_target(
            target=target,
            smoothing=mix_args.SMOOTHING,
            num_classes=num_classes)

    lam = np.random.beta(mix_args.BETA, mix_args.BETA)
    perm = torch.randperm(image.size()[0], device=image.device)
    bbx1, bby1, bbx2, bby2, lam = rand_bbox(image.size(), lam)
    # Paste the sampled box from the permuted batch into each image (in place).
    image[:, :, bbx1:bbx2, bby1:bby2] = image[perm, :, bbx1:bbx2, bby1:bby2]
    # Mix the targets by the exact surviving-area fraction.
    mixed = mix_target(target_a=target,
                       target_b=target[perm],
                       num_classes=num_classes,
                       lam=lam,
                       smoothing=mix_args.SMOOTHING)
    return image, mixed
168 |
--------------------------------------------------------------------------------
/imagenet/pycls/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yxgeee/BAKE/07c4f668ea19311d5b50121026e73d2f035d5765/imagenet/pycls/models/__init__.py
--------------------------------------------------------------------------------
/imagenet/pycls/models/effnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """EfficientNet models."""
9 |
10 | import pycls.core.net as net
11 | import torch
12 | import torch.nn as nn
13 | from pycls.core.config import cfg
14 |
15 |
class EffHead(nn.Module):
    """EfficientNet head: 1x1, BN, Swish, AvgPool, Dropout, FC.

    forward returns (pooled_features, logits) so callers can also use the
    pre-classifier embedding (e.g. for BAKE knowledge ensembling).
    """

    def __init__(self, w_in, w_out, nc):
        super(EffHead, self).__init__()
        self.conv = nn.Conv2d(w_in, w_out, 1, stride=1, padding=0, bias=False)
        self.conv_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.conv_swish = Swish()
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Dropout is only instantiated when a positive ratio is configured.
        if cfg.EN.DROPOUT_RATIO > 0.0:
            self.dropout = nn.Dropout(p=cfg.EN.DROPOUT_RATIO)
        self.fc = nn.Linear(w_out, nc, bias=True)

    def forward(self, x):
        x = self.conv_swish(self.conv_bn(self.conv(x)))
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        # Dropout may be absent (see __init__); pass through if so.
        x = self.dropout(x) if hasattr(self, "dropout") else x
        out = self.fc(x)
        return x, out

    @staticmethod
    def complexity(cx, w_in, w_out, nc):
        """Accumulates head complexity into cx (the FC is modeled as a 1x1 conv)."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx["h"], cx["w"] = 1, 1
        cx = net.complexity_conv2d(cx, w_out, nc, 1, 1, 0, bias=True)
        return cx
44 |
45 |
class Swish(nn.Module):
    """Swish activation: f(x) = x * sigmoid(x)."""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        """Applies the activation elementwise."""
        return torch.sigmoid(x) * x
54 |
55 |
class SE(nn.Module):
    """Squeeze-and-Excitation block w/ Swish: AvgPool, FC, Swish, FC, Sigmoid."""

    def __init__(self, w_in, w_se):
        super(SE, self).__init__()
        # Squeeze: global average pooling down to 1x1.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Excite: bottleneck MLP (as 1x1 convs) producing per-channel gates.
        self.f_ex = nn.Sequential(
            nn.Conv2d(w_in, w_se, 1, bias=True),
            Swish(),
            nn.Conv2d(w_se, w_in, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gates = self.f_ex(self.avg_pool(x))
        return gates * x

    @staticmethod
    def complexity(cx, w_in, w_se):
        """Accumulates the block's complexity into cx (excite runs at 1x1)."""
        h, w = cx["h"], cx["w"]
        cx["h"], cx["w"] = 1, 1
        cx = net.complexity_conv2d(cx, w_in, w_se, 1, 1, 0, bias=True)
        cx = net.complexity_conv2d(cx, w_se, w_in, 1, 1, 0, bias=True)
        cx["h"], cx["w"] = h, w
        return cx
80 |
81 |
class MBConv(nn.Module):
    """Mobile inverted bottleneck block w/ SE (MBConv)."""

    def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out):
        # expansion, 3x3 dwise, BN, Swish, SE, 1x1, BN, skip_connection
        super(MBConv, self).__init__()
        # The expansion conv only exists when exp_r actually changes the width.
        self.exp = None
        w_exp = int(w_in * exp_r)
        if w_exp != w_in:
            self.exp = nn.Conv2d(w_in, w_exp, 1, stride=1, padding=0, bias=False)
            self.exp_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
            self.exp_swish = Swish()
        # Depthwise conv: one group per channel.
        dwise_args = {"groups": w_exp, "padding": (kernel - 1) // 2, "bias": False}
        self.dwise = nn.Conv2d(w_exp, w_exp, kernel, stride=stride, **dwise_args)
        self.dwise_bn = nn.BatchNorm2d(w_exp, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.dwise_swish = Swish()
        # SE width is a ratio of the block INPUT width, not the expanded width.
        self.se = SE(w_exp, int(w_in * se_r))
        self.lin_proj = nn.Conv2d(w_exp, w_out, 1, stride=1, padding=0, bias=False)
        self.lin_proj_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        # Skip connection if in and out shapes are the same (MN-V2 style)
        self.has_skip = stride == 1 and w_in == w_out

    def forward(self, x):
        f_x = x
        if self.exp:
            f_x = self.exp_swish(self.exp_bn(self.exp(f_x)))
        f_x = self.dwise_swish(self.dwise_bn(self.dwise(f_x)))
        f_x = self.se(f_x)
        f_x = self.lin_proj_bn(self.lin_proj(f_x))
        if self.has_skip:
            # Drop-connect regularization on the residual branch (train only).
            if self.training and cfg.EN.DC_RATIO > 0.0:
                f_x = net.drop_connect(f_x, cfg.EN.DC_RATIO)
            f_x = x + f_x
        return f_x

    @staticmethod
    def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out):
        """Accumulates the block's complexity into cx (mirrors __init__)."""
        w_exp = int(w_in * exp_r)
        if w_exp != w_in:
            cx = net.complexity_conv2d(cx, w_in, w_exp, 1, 1, 0)
            cx = net.complexity_batchnorm2d(cx, w_exp)
        padding = (kernel - 1) // 2
        cx = net.complexity_conv2d(cx, w_exp, w_exp, kernel, stride, padding, w_exp)
        cx = net.complexity_batchnorm2d(cx, w_exp)
        cx = SE.complexity(cx, w_exp, int(w_in * se_r))
        cx = net.complexity_conv2d(cx, w_exp, w_out, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx
130 |
131 |
class EffStage(nn.Module):
    """EfficientNet stage: a sequence of d MBConv blocks."""

    def __init__(self, w_in, exp_r, kernel, stride, se_r, w_out, d):
        super(EffStage, self).__init__()
        for block_idx in range(d):
            # Only the first block changes stride and input width.
            first = block_idx == 0
            block = MBConv(
                w_in if first else w_out,
                exp_r,
                kernel,
                stride if first else 1,
                se_r,
                w_out,
            )
            self.add_module("b{}".format(block_idx + 1), block)

    def forward(self, x):
        # Blocks run in registration order: b1, b2, ...
        for block in self.children():
            x = block(x)
        return x

    @staticmethod
    def complexity(cx, w_in, exp_r, kernel, stride, se_r, w_out, d):
        """Accumulates complexity of all d blocks into cx (mirrors __init__)."""
        for block_idx in range(d):
            first = block_idx == 0
            cx = MBConv.complexity(
                cx,
                w_in if first else w_out,
                exp_r,
                kernel,
                stride if first else 1,
                se_r,
                w_out,
            )
        return cx
155 |
156 |
class StemIN(nn.Module):
    """EfficientNet stem for ImageNet: 3x3, BN, Swish."""

    def __init__(self, w_in, w_out):
        super(StemIN, self).__init__()
        # Stride-2 conv halves the spatial resolution.
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=2, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.swish = Swish()

    def forward(self, x):
        # Children run in registration order: conv -> bn -> swish.
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        """Accumulates stem complexity into cx."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 2, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx
176 |
177 |
class EffNet(nn.Module):
    """EfficientNet model.

    forward returns (features, logits): all children up to (but excluding)
    the head are run in order, then the head yields both the pooled feature
    vector and the classifier output (used by BAKE).
    """

    @staticmethod
    def get_args():
        """Collects constructor kwargs from the global config."""
        return {
            "stem_w": cfg.EN.STEM_W,
            "ds": cfg.EN.DEPTHS,
            "ws": cfg.EN.WIDTHS,
            "exp_rs": cfg.EN.EXP_RATIOS,
            "se_r": cfg.EN.SE_R,
            "ss": cfg.EN.STRIDES,
            "ks": cfg.EN.KERNELS,
            "head_w": cfg.EN.HEAD_W,
            "nc": cfg.MODEL.NUM_CLASSES,
        }

    def __init__(self):
        err_str = "Dataset {} is not supported"
        assert cfg.TRAIN.DATASET in ["imagenet"], err_str.format(cfg.TRAIN.DATASET)
        assert cfg.TEST.DATASET in ["imagenet"], err_str.format(cfg.TEST.DATASET)
        super(EffNet, self).__init__()
        self._construct(**EffNet.get_args())
        self.apply(net.init_weights)

    def _construct(self, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
        """Builds the stem, stages s1..sN, and head from per-stage param lists."""
        stage_params = list(zip(ds, ws, exp_rs, ss, ks))
        self.stem = StemIN(3, stem_w)
        prev_w = stem_w
        for i, (d, w, exp_r, stride, kernel) in enumerate(stage_params):
            name = "s{}".format(i + 1)
            self.add_module(name, EffStage(prev_w, exp_r, kernel, stride, se_r, w, d))
            prev_w = w
        # The head must stay the last registered child (forward relies on it).
        self.head = EffHead(prev_w, head_w, nc)

    def forward(self, x):
        # for module in self.children():
        #     x = module(x)
        # return x
        # Run every child up to the head, then split out (features, logits).
        for name, module in self.named_children():
            if (name=='head'):
                break
            x = module(x)
        x, out = self.head(x)
        return x, out

    @staticmethod
    def complexity(cx):
        """Computes model complexity. If you alter the model, make sure to update."""
        return EffNet._complexity(cx, **EffNet.get_args())

    @staticmethod
    def _complexity(cx, stem_w, ds, ws, exp_rs, se_r, ss, ks, head_w, nc):
        """Mirrors _construct to accumulate full-model complexity into cx."""
        stage_params = list(zip(ds, ws, exp_rs, ss, ks))
        cx = StemIN.complexity(cx, 3, stem_w)
        prev_w = stem_w
        for d, w, exp_r, stride, kernel in stage_params:
            cx = EffStage.complexity(cx, prev_w, exp_r, kernel, stride, se_r, w, d)
            prev_w = w
        cx = EffHead.complexity(cx, prev_w, head_w, nc)
        return cx
239 |
--------------------------------------------------------------------------------
/imagenet/pycls/models/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 |
def conv_bn(inp, oup, stride):
    """3x3 convolution (padding 1) followed by BatchNorm and ReLU6."""
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
11 |
12 |
def conv_1x1_bn(inp, oup):
    """Pointwise (1x1) convolution followed by BatchNorm and ReLU6."""
    layers = [
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True),
    ]
    return nn.Sequential(*layers)
19 |
20 |
def make_divisible(x, divisible_by=8):
    """Rounds x up to the nearest multiple of divisible_by.

    Fix: uses math.ceil (math is already imported at module level) instead of
    importing numpy inside the function on every call.
    """
    return int(math.ceil(x / divisible_by) * divisible_by)
24 |
25 |
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted residual block (expand -> depthwise -> project)."""

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]
        self.stride = stride

        hidden_dim = int(inp * expand_ratio)
        # Residual shortcut only when the block keeps resolution and width.
        self.use_res_connect = self.stride == 1 and inp == oup

        # Layers are created in the same order as the reference implementation
        # so parameter initialization is reproducible.
        if expand_ratio == 1:
            layers = [
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            ]
        else:
            layers = [
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out
65 |
66 |
class MobileNetV2(nn.Module):
    """MobileNetV2 backbone.

    forward returns (pooled_features, logits): the globally averaged feature
    vector is exposed alongside the classifier output (used by BAKE).
    """

    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        # n_class: number of classifier outputs.
        # input_size: input resolution; must be a multiple of 32.
        # width_mult: channel width multiplier.
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        # t = expansion ratio, c = output channels, n = number of blocks,
        # s = stride of the first block of the stage.
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        # input_channel = make_divisible(input_channel * width_mult) # first channel is always 32!
        # The head width grows with width_mult but never shrinks below 1280.
        self.last_channel = make_divisible(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = make_divisible(c * width_mult) if t > 1 else c
            for i in range(n):
                # Only the first block of each stage applies the stride.
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # building classifier
        self.classifier = nn.Linear(self.last_channel, n_class)

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        # Global average pooling over the two spatial dims.
        x = x.mean(3).mean(2)
        out = self.classifier(x)
        return x, out

    def _initialize_weights(self):
        """He init for convs, constant init for BN, normal(0, 0.01) for linear."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Fan-out based He initialization.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)  # NOTE: unused; kept for parity with the reference impl
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
128 |
129 |
130 | # def mobilenet_v2(pretrained=True):
131 | # model = MobileNetV2(width_mult=1)
132 | # return model
133 |
--------------------------------------------------------------------------------
/imagenet/pycls/models/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """ResNe(X)t models."""
9 |
10 | import pycls.core.net as net
11 | import torch.nn as nn
12 | from pycls.core.config import cfg
13 |
14 |
15 | # Stage depths for ImageNet models
16 | _IN_STAGE_DS = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)}
17 |
18 |
def get_trans_fun(name):
    """Retrieve the residual transformation class registered under `name`."""
    registry = {
        "basic_transform": BasicTransform,
        "bottleneck_transform": BottleneckTransform,
    }
    assert name in registry, "Transformation function '{}' not supported".format(name)
    return registry[name]
28 |
29 |
class ResHead(nn.Module):
    """ResNet head: AvgPool, 1x1."""

    def __init__(self, w_in, nc):
        super(ResHead, self).__init__()
        # Pool any spatial size down to a single 1x1 cell before the FC.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(w_in, nc, bias=True)

    def forward(self, x):
        pooled = self.avg_pool(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

    @staticmethod
    def complexity(cx, w_in, nc):
        # After global pooling the map is 1x1; the FC layer is counted
        # as an equivalent 1x1 convolution.
        cx["h"], cx["w"] = 1, 1
        cx = net.complexity_conv2d(cx, w_in, nc, 1, 1, 0, bias=True)
        return cx
49 |
50 |
class BasicTransform(nn.Module):
    """Basic transformation: 3x3, BN, ReLU, 3x3, BN."""

    def __init__(self, w_in, w_out, stride, w_b=None, num_gs=1):
        # w_b (bottleneck width) and num_gs (groups) only apply to the
        # bottleneck transform; reject them here to catch config mistakes.
        err_str = "Basic transform does not support w_b and num_gs options"
        assert w_b is None and num_gs == 1, err_str
        super(BasicTransform, self).__init__()
        # First 3x3 conv carries the (possibly strided) downsampling.
        self.a = nn.Conv2d(w_in, w_out, 3, stride=stride, padding=1, bias=False)
        self.a_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        self.b = nn.Conv2d(w_out, w_out, 3, stride=1, padding=1, bias=False)
        self.b_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        # NOTE(review): flag on the branch's last BN; presumably consumed by
        # net.init_weights (zero-gamma init) -- confirm in pycls/core/net.py.
        self.b_bn.final_bn = True

    def forward(self, x):
        # Submodules run in registration order, matching the
        # 3x3 -> BN -> ReLU -> 3x3 -> BN sequence above.
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, w_b=None, num_gs=1):
        """Accumulate this transform's complexity (flops/params/acts) into cx."""
        err_str = "Basic transform does not support w_b and num_gs options"
        assert w_b is None and num_gs == 1, err_str
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, stride, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx = net.complexity_conv2d(cx, w_out, w_out, 3, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx
79 |
80 |
class BottleneckTransform(nn.Module):
    """Bottleneck transformation: 1x1, BN, ReLU, 3x3, BN, ReLU, 1x1, BN."""

    def __init__(self, w_in, w_out, stride, w_b, num_gs):
        super(BottleneckTransform, self).__init__()
        # MSRA -> stride=2 is on 1x1; TH/C2 -> stride=2 is on 3x3
        (s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
        # 1x1 reduce to the bottleneck width w_b.
        self.a = nn.Conv2d(w_in, w_b, 1, stride=s1, padding=0, bias=False)
        self.a_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.a_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        # 3x3 conv; grouped when num_gs > 1 (ResNeXt-style models).
        self.b = nn.Conv2d(w_b, w_b, 3, stride=s3, padding=1, groups=num_gs, bias=False)
        self.b_bn = nn.BatchNorm2d(w_b, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.b_relu = nn.ReLU(inplace=cfg.MEM.RELU_INPLACE)
        # 1x1 expand back to w_out; no ReLU before the residual add.
        self.c = nn.Conv2d(w_b, w_out, 1, stride=1, padding=0, bias=False)
        self.c_bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        # NOTE(review): flag on the branch's last BN; presumably consumed by
        # net.init_weights (zero-gamma init) -- confirm in pycls/core/net.py.
        self.c_bn.final_bn = True

    def forward(self, x):
        # Submodules run in registration order, matching the
        # 1x1 -> 3x3 -> 1x1 sequence above.
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, w_b, num_gs):
        """Accumulate this transform's complexity (flops/params/acts) into cx."""
        (s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride)
        cx = net.complexity_conv2d(cx, w_in, w_b, 1, s1, 0)
        cx = net.complexity_batchnorm2d(cx, w_b)
        cx = net.complexity_conv2d(cx, w_b, w_b, 3, s3, 1, num_gs)
        cx = net.complexity_batchnorm2d(cx, w_b)
        cx = net.complexity_conv2d(cx, w_b, w_out, 1, 1, 0)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx
113 |
114 |
class ResBlock(nn.Module):
    """Residual block: x + F(x)."""

    def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, num_gs=1):
        super(ResBlock, self).__init__()
        # A projection shortcut (1x1 conv + BN) is required whenever the
        # residual branch changes the channel count or the spatial size.
        self.proj_block = (w_in != w_out) or (stride != 1)
        if self.proj_block:
            self.proj = nn.Conv2d(w_in, w_out, 1, stride=stride, padding=0, bias=False)
            self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.f = trans_fun(w_in, w_out, stride, w_b, num_gs)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        shortcut = self.bn(self.proj(x)) if self.proj_block else x
        return self.relu(shortcut + self.f(x))

    @staticmethod
    def complexity(cx, w_in, w_out, stride, trans_fun, w_b, num_gs):
        """Accumulate the shortcut (if any) and residual-branch complexity."""
        if (w_in != w_out) or (stride != 1):
            h, w = cx["h"], cx["w"]
            cx = net.complexity_conv2d(cx, w_in, w_out, 1, stride, 0)
            cx = net.complexity_batchnorm2d(cx, w_out)
            # The projection runs in parallel with F(x); restore the spatial
            # size so the residual branch is measured from the same input.
            cx["h"], cx["w"] = h, w
        cx = trans_fun.complexity(cx, w_in, w_out, stride, w_b, num_gs)
        return cx
146 |
147 |
class ResStage(nn.Module):
    """Stage of ResNet."""

    def __init__(self, w_in, w_out, stride, d, w_b=None, num_gs=1):
        super(ResStage, self).__init__()
        # Block 1 performs the stage's stride/width change; the remaining
        # d - 1 blocks keep width w_out and stride 1.
        for i in range(d):
            block = ResBlock(
                w_in if i == 0 else w_out,
                w_out,
                stride if i == 0 else 1,
                get_trans_fun(cfg.RESNET.TRANS_FUN),
                w_b,
                num_gs,
            )
            self.add_module("b{}".format(i + 1), block)

    def forward(self, x):
        for block in self.children():
            x = block(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out, stride, d, w_b=None, num_gs=1):
        """Accumulate the complexity of the stage's d blocks into cx."""
        for i in range(d):
            cx = ResBlock.complexity(
                cx,
                w_in if i == 0 else w_out,
                w_out,
                stride if i == 0 else 1,
                get_trans_fun(cfg.RESNET.TRANS_FUN),
                w_b,
                num_gs,
            )
        return cx
173 |
174 |
class ResStemCifar(nn.Module):
    """ResNet stem for CIFAR: 3x3, BN, ReLU."""

    def __init__(self, w_in, w_out):
        super(ResStemCifar, self).__init__()
        # No downsampling in the stem: CIFAR inputs are small (32x32).
        self.conv = nn.Conv2d(w_in, w_out, 3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)

    def forward(self, x):
        # Submodules run in registration order: conv -> bn -> relu.
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        """Accumulate the stem's complexity (flops/params/acts) into cx."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 3, 1, 1)
        cx = net.complexity_batchnorm2d(cx, w_out)
        return cx
194 |
195 |
class ResStemIN(nn.Module):
    """ResNet stem for ImageNet: 7x7, BN, ReLU, MaxPool."""

    def __init__(self, w_in, w_out):
        super(ResStemIN, self).__init__()
        # Stride-2 conv plus stride-2 pool: 4x spatial reduction overall.
        self.conv = nn.Conv2d(w_in, w_out, 7, stride=2, padding=3, bias=False)
        self.bn = nn.BatchNorm2d(w_out, eps=cfg.BN.EPS, momentum=cfg.BN.MOM)
        self.relu = nn.ReLU(cfg.MEM.RELU_INPLACE)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        # Submodules run in registration order: conv -> bn -> relu -> pool.
        for layer in self.children():
            x = layer(x)
        return x

    @staticmethod
    def complexity(cx, w_in, w_out):
        """Accumulate the stem's complexity (flops/params/acts) into cx."""
        cx = net.complexity_conv2d(cx, w_in, w_out, 7, 2, 3)
        cx = net.complexity_batchnorm2d(cx, w_out)
        cx = net.complexity_maxpool2d(cx, 3, 2, 1)
        return cx
217 |
218 |
class ResNet(nn.Module):
    """ResNet model. Architecture is driven entirely by the global cfg."""

    def __init__(self):
        # Only these datasets have a matching stem/stage construction below.
        datasets = ["cifar10", "imagenet"]
        err_str = "Dataset {} is not supported"
        assert cfg.TRAIN.DATASET in datasets, err_str.format(cfg.TRAIN.DATASET)
        assert cfg.TEST.DATASET in datasets, err_str.format(cfg.TEST.DATASET)
        super(ResNet, self).__init__()
        if "cifar" in cfg.TRAIN.DATASET:
            self._construct_cifar()
        else:
            self._construct_imagenet()
        self.apply(net.init_weights)

    def _construct_cifar(self):
        """Build the CIFAR variant: 3x3 stem + three stages + head."""
        err_str = "Model depth should be of the format 6n + 2 for cifar"
        assert (cfg.MODEL.DEPTH - 2) % 6 == 0, err_str
        # d = blocks per stage (depth 6n + 2 => n blocks per stage).
        d = int((cfg.MODEL.DEPTH - 2) / 6)
        self.stem = ResStemCifar(3, 16)
        self.s1 = ResStage(16, 16, stride=1, d=d)
        self.s2 = ResStage(16, 32, stride=2, d=d)
        self.s3 = ResStage(32, 64, stride=2, d=d)
        self.head = ResHead(64, nc=cfg.MODEL.NUM_CLASSES)

    def _construct_imagenet(self):
        """Build the ImageNet variant: 7x7 stem + four stages + head."""
        g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP
        (d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH]
        # Bottleneck width of the first stage; doubled at every later stage.
        w_b = gw * g
        self.stem = ResStemIN(3, 64)
        self.s1 = ResStage(64, 256, stride=1, d=d1, w_b=w_b, num_gs=g)
        self.s2 = ResStage(256, 512, stride=2, d=d2, w_b=w_b * 2, num_gs=g)
        self.s3 = ResStage(512, 1024, stride=2, d=d3, w_b=w_b * 4, num_gs=g)
        self.s4 = ResStage(1024, 2048, stride=2, d=d4, w_b=w_b * 8, num_gs=g)
        self.head = ResHead(2048, nc=cfg.MODEL.NUM_CLASSES)

    def forward(self, x):
        # Child modules run in registration order: stem, stages, head.
        for module in self.children():
            x = module(x)
        return x

    @staticmethod
    def complexity(cx):
        """Computes model complexity. If you alter the model, make sure to update."""
        if "cifar" in cfg.TRAIN.DATASET:
            d = int((cfg.MODEL.DEPTH - 2) / 6)
            cx = ResStemCifar.complexity(cx, 3, 16)
            cx = ResStage.complexity(cx, 16, 16, stride=1, d=d)
            cx = ResStage.complexity(cx, 16, 32, stride=2, d=d)
            cx = ResStage.complexity(cx, 32, 64, stride=2, d=d)
            cx = ResHead.complexity(cx, 64, nc=cfg.MODEL.NUM_CLASSES)
        else:
            g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP
            (d1, d2, d3, d4) = _IN_STAGE_DS[cfg.MODEL.DEPTH]
            w_b = gw * g
            cx = ResStemIN.complexity(cx, 3, 64)
            cx = ResStage.complexity(cx, 64, 256, 1, d=d1, w_b=w_b, num_gs=g)
            cx = ResStage.complexity(cx, 256, 512, 2, d=d2, w_b=w_b * 2, num_gs=g)
            cx = ResStage.complexity(cx, 512, 1024, 2, d=d3, w_b=w_b * 4, num_gs=g)
            cx = ResStage.complexity(cx, 1024, 2048, 2, d=d4, w_b=w_b * 8, num_gs=g)
            cx = ResHead.complexity(cx, 2048, nc=cfg.MODEL.NUM_CLASSES)
        return cx
281 |
--------------------------------------------------------------------------------
/imagenet/requirements.txt:
--------------------------------------------------------------------------------
1 | black==19.3b0
2 | flake8
3 | isort
4 | numpy
5 | opencv-python
6 | parameterized
7 | setuptools
8 | simplejson
9 | yacs
10 |
--------------------------------------------------------------------------------
/imagenet/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Setup pycls."""
9 |
10 | from setuptools import setup
11 |
12 |
13 | setup(name="pycls", packages=["pycls"])
14 |
--------------------------------------------------------------------------------
/imagenet/tools/dist_test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Single-node multi-GPU evaluation launcher.
# Usage: GPUS=<n> ./dist_test.sh <config.yaml> <checkpoint> [extra cfg opts...]

set -x

PYTHON=${PYTHON:-"python"}

CONFIG=$1
CKPT=$2
PY_ARGS=${@:3}

GPUS=${GPUS:-8}

while true # find unused tcp port
do
    # Random port in [10000, 59151]; nc exit status 0 means the port is taken.
    PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 ))
    status="$(nc -z 127.0.0.1 $PORT < /dev/null &>/dev/null; echo $?)"
    if [ "${status}" != "0" ]; then
        break;
    fi
done

$PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT --use_env \
    tools/test_net.py --cfg ${CONFIG} \
    TEST.WEIGHTS ${CKPT} \
    LAUNCHER pytorch \
    PORT ${PORT} \
    ${PY_ARGS}
28 |
--------------------------------------------------------------------------------
/imagenet/tools/dist_train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Single-node multi-GPU training launcher; stdout is also saved to <work_dir>/log.txt.
# Usage: GPUS=<n> ./dist_train.sh <work_dir> <config.yaml> [extra cfg opts...]

set -x

PYTHON=${PYTHON:-"python"}

WORK_DIR=$1
CONFIG=$2
PY_ARGS=${@:3}

GPUS=${GPUS:-8}

while true # find unused tcp port
do
    # Random port in [10000, 59151]; nc exit status 0 means the port is taken.
    PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 ))
    status="$(nc -z 127.0.0.1 $PORT < /dev/null &>/dev/null; echo $?)"
    if [ "${status}" != "0" ]; then
        break;
    fi
done

mkdir -p $WORK_DIR

$PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT --use_env \
    tools/train_net.py --cfg ${CONFIG} \
    LAUNCHER pytorch OUT_DIR ${WORK_DIR} PORT ${PORT} ${PY_ARGS} | tee ${WORK_DIR}/log.txt
27 |
--------------------------------------------------------------------------------
/imagenet/tools/slurm_test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# SLURM evaluation launcher (one task per GPU on a single node).
# Usage: GPUS=<n> ./slurm_test.sh <partition> <config.yaml> <checkpoint> [extra cfg opts...]

set -x

PARTITION=$1
CONFIG=$2
CKPT=$3
PY_ARGS=${@:4}

GPUS=${GPUS:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}

while true # find unused tcp port
do
    # Random port in [10000, 59151]; nc exit status 0 means the port is taken.
    PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 ))
    status="$(nc -z 127.0.0.1 $PORT < /dev/null &>/dev/null; echo $?)"
    if [ "${status}" != "0" ]; then
        break;
    fi
done

srun --mpi=pmi2 -p ${PARTITION} \
    --job-name=test \
    --gres=gpu:${GPUS} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/test_net.py --cfg ${CONFIG} \
    TEST.WEIGHTS ${CKPT} \
    LAUNCHER slurm \
    PORT ${PORT} \
    ${PY_ARGS}
36 |
--------------------------------------------------------------------------------
/imagenet/tools/slurm_train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# SLURM training launcher (supports multi-node via GPUS / GPUS_PER_NODE);
# stdout is also saved to <work_dir>/log.txt.
# Usage: ./slurm_train.sh <partition> <job_name> <work_dir> <config.yaml> [extra cfg opts...]

set -x

PARTITION=$1
JOB_NAME=$2
WORK_DIR=$3
CONFIG=$4
PY_ARGS=${@:5}

GPUS=${GPUS:-8}
GPUS_PER_NODE=${GPUS_PER_NODE:-8}
CPUS_PER_TASK=${CPUS_PER_TASK:-5}
SRUN_ARGS=${SRUN_ARGS:-""}

while true # find unused tcp port
do
    # Random port in [10000, 59151]; nc exit status 0 means the port is taken.
    PORT=$(( ((RANDOM<<15)|RANDOM) % 49152 + 10000 ))
    status="$(nc -z 127.0.0.1 $PORT < /dev/null &>/dev/null; echo $?)"
    if [ "${status}" != "0" ]; then
        break;
    fi
done

mkdir -p $WORK_DIR

srun --mpi=pmi2 -p ${PARTITION} \
    --job-name=${JOB_NAME} \
    --gres=gpu:${GPUS_PER_NODE} \
    --ntasks=${GPUS} \
    --ntasks-per-node=${GPUS_PER_NODE} \
    --cpus-per-task=${CPUS_PER_TASK} \
    --kill-on-bad-exit=1 \
    ${SRUN_ARGS} \
    python -u tools/train_net.py --cfg ${CONFIG} \
    LAUNCHER slurm OUT_DIR ${WORK_DIR} PORT ${PORT} ${PY_ARGS} | tee ${WORK_DIR}/log.txt
37 |
--------------------------------------------------------------------------------
/imagenet/tools/test_net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Test a trained classification model."""
9 |
10 | import pycls.core.config as config
11 | import pycls.core.distributed as dist
12 | import pycls.core.trainer as trainer
13 | from pycls.core.config import cfg
14 |
15 |
def main():
    """Parse the config from CLI args, validate it, and launch evaluation."""
    config.load_cfg_fom_args("Test a trained classification model.")
    config.assert_and_infer_cfg()
    # cfg.freeze()
    # One process per GPU; each runs trainer.test_model.
    dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.test_model)


if __name__ == "__main__":
    main()
25 |
--------------------------------------------------------------------------------
/imagenet/tools/train_net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | """Train a classification model."""
9 |
10 | import pycls.core.config as config
11 | import pycls.core.distributed as dist
12 | import pycls.core.trainer as trainer
13 | from pycls.core.config import cfg
14 |
15 |
def main():
    """Parse the config from CLI args, validate it, and launch training."""
    config.load_cfg_fom_args("Train a classification model.")
    config.assert_and_infer_cfg()
    # cfg.freeze()
    # One process per GPU; each runs trainer.train_model.
    dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.train_model)


if __name__ == "__main__":
    main()
25 |
--------------------------------------------------------------------------------
/small_scale/README.md:
--------------------------------------------------------------------------------
1 | # BAKE on Small-scale Datasets
2 |
3 | PyTorch implementation of [Self-distillation with Batch Knowledge Ensembling Improves ImageNet Classification](https://arxiv.org/abs/2104.13298) on CIFAR-100, TinyImageNet, CUB-200-2011, Stanford Dogs, and MIT67.
4 |
5 | ## Requirements
6 |
7 | - torch >= 1.2.0
8 | - torchvision >= 0.4.0
9 |
10 | ## Datasets
11 |
12 | Please download the raw datasets and re-organize them according to [DATA](data/). The folder tree should be like
13 | ```
14 | ├── data
15 | │ ├── README.md
16 | │ ├── cifar-100-python/
17 | │ ├── CUB200/
18 | │ ├── MIT67/
19 | │ ├── STANFORD120/
20 | │ ├── tinyimagenet/
21 | │ ├── tools/
22 | └── └── ...
23 | ```
24 |
25 | ## BAKE Training
26 |
27 | ```
28 | sh scripts/train_bake.sh
29 | ```
30 |
31 | Specifically,
32 | + CIFAR-100
33 |
34 | ```
35 | sh scripts/train_bake.sh 0 cifar100 CIFAR_ResNet18 128 3 0.5
36 | sh scripts/train_bake.sh 0 cifar100 CIFAR_DenseNet121 128 3 0.5
37 | ```
38 |
39 | + TinyImageNet
40 |
41 | ```
42 | sh scripts/train_bake.sh 0 tinyimagenet CIFAR_ResNet18 128 1 0.9
43 | sh scripts/train_bake.sh 0 tinyimagenet CIFAR_DenseNet121 128 1 0.9
44 | ```
45 |
46 | + CUB-200-2011
47 |
48 | ```
49 | sh scripts/train_bake.sh 0 CUB200 resnet18 32 3 0.5
50 | sh scripts/train_bake.sh 0 CUB200 densenet121 32 3 0.5
51 | ```
52 |
53 | + Stanford Dogs
54 |
55 | ```
56 | sh scripts/train_bake.sh 0 STANFORD120 resnet18 32 1 0.9
57 | sh scripts/train_bake.sh 0 STANFORD120 densenet121 32 1 0.9
58 | ```
59 |
60 | + MIT67
61 |
62 | ```
63 | sh scripts/train_bake.sh 0 MIT67 resnet18 32 1 0.9
64 | sh scripts/train_bake.sh 0 MIT67 densenet121 32 1 0.9
65 | ```
66 |
67 | ## Baseline Training
68 |
69 | ```
70 | sh scripts/train_baseline.sh
71 | ```
72 |
73 | ## Validation
74 |
75 | ```
76 | sh scripts/val.sh
77 | ```
78 |
79 | ## Results (top-1 error)
80 |
81 | ||CIFAR-100|TinyImageNet|CUB-200-2011|Stanford Dogs|MIT67|
82 | |---|:--:|:--:|:--:|:--:|:--:|
83 | |ResNet-18|21.28|41.71|29.74|30.20|39.95|
84 | |DenseNet-121|20.74|37.07|28.79|27.66|39.15|
85 |
86 | ## Thanks
87 | The code is modified from [CS-KD](https://github.com/alinlab/cs-kd).
88 |
--------------------------------------------------------------------------------
/small_scale/data/README.md:
--------------------------------------------------------------------------------
1 | ## Datasets
2 |
3 | ### Supported Datasets
4 |
5 | + CIFAR-100
6 | + [TinyImageNet](http://cs231n.stanford.edu/tiny-imagenet-200.zip)
7 | + [CUB-200-2011](https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view)
8 | + [Stanford Dogs](http://vision.stanford.edu/aditya86/ImageNetDogs/)
9 | + [MIT67](http://web.mit.edu/torralba/www/indoor.html)
10 |
11 | ### Preparation Steps
12 |
13 | + Download the raw datasets and save them under this directory. Note that there's no need to download CIFAR-100, as it could be downloaded automatically by the code.
14 | + Process them by running the corresponding scripts in `tools/`, e.g. `cd tools && python cub.py`.
15 |
--------------------------------------------------------------------------------
/small_scale/data/tools/cub.py:
--------------------------------------------------------------------------------
1 | # Please download the raw data of CUB_200_2011 dataset from
2 | # https://drive.google.com/file/d/1hbzc_P1FuxMkcabkgn9ZKinBwW683j45/view
3 |
import os
import os.path as osp

ori = '../CUB_200_2011'
tar = '../CUB200'
os.makedirs(tar)

# images.txt maps image id -> relative path;
# train_test_split.txt maps the same id -> 1 (train) or 0 (test).
with open(osp.join(ori, 'images.txt'), 'r') as f:
    img_list = f.readlines()
with open(osp.join(ori, 'train_test_split.txt'), 'r') as f:
    split_list = f.readlines()

img_dict = {}
for im in img_list:
    img_dict[im.split(' ')[0]] = im.strip().split(' ')[1]
split_dict = {}
for s in split_list:
    split_dict[s.split(' ')[0]] = int(s.strip().split(' ')[1])

for k in img_dict.keys():
    s = 'train' if split_dict[k] == 1 else 'test'
    # BUGFIX: os.mkdir failed here because the intermediate '<tar>/<split>'
    # directory was never created; makedirs builds the whole chain.
    os.makedirs(osp.join(tar, s, osp.dirname(img_dict[k])), exist_ok=True)
    os.symlink(osp.join(ori, 'images', img_dict[k]), osp.join(tar, s, img_dict[k]))
31 |
--------------------------------------------------------------------------------
/small_scale/data/tools/mit67.py:
--------------------------------------------------------------------------------
1 | # Please download the raw data of MIT67 dataset from
2 | # http://web.mit.edu/torralba/www/indoor.html
3 |
import os
import os.path as osp
import scipy.io  # NOTE(review): unused in this script -- likely copied from stanford_dogs.py

root = '../'
ori = root+'Images'
tar = root+'MIT67'
os.makedirs(tar)

# The official split files list one relative image path per line.
with open(root+'TrainImages.txt', 'r') as f:
    train_list = f.readlines()
with open(root+'TestImages.txt', 'r') as f:
    test_list = f.readlines()

def link(file_list, tar_folder):
    # Symlink each listed image into tar_folder, preserving its
    # <class>/<image> relative path.
    for name in file_list:
        name=name.strip()
        assert(osp.isfile(osp.join(ori,name)))
        if not osp.isdir(osp.dirname(osp.join(tar_folder, name))):
            os.makedirs(osp.dirname(osp.join(tar_folder, name)))
        os.symlink(osp.join(ori,name), osp.join(tar_folder, name))

link(train_list, osp.join(tar,'train'))
link(test_list, osp.join(tar,'test'))
28 |
--------------------------------------------------------------------------------
/small_scale/data/tools/stanford_dogs.py:
--------------------------------------------------------------------------------
1 | # Please download the raw data of Stanford Dogs dataset from
2 | # http://vision.stanford.edu/aditya86/ImageNetDogs/
3 |
import os
import os.path as osp
import scipy.io

root = '../'
ori = root+'Images'
tar = root+'STANFORD120'
os.makedirs(tar)

# The official .mat split files store image names under 'annotation_list'.
train_list = scipy.io.loadmat(root+'train_list.mat')['annotation_list']
test_list = scipy.io.loadmat(root+'test_list.mat')['annotation_list']

def link(file_list, tar_folder):
    # Symlink each listed image into tar_folder, preserving its relative path.
    for i in range(file_list.shape[0]):
        name = str(file_list[i][0][0])
        # Some annotation_list entries lack the image extension.
        if not name.endswith('.jpg'):
            name = name+'.jpg'
        assert(osp.isfile(osp.join(ori,name)))
        if not osp.isdir(osp.dirname(osp.join(tar_folder, name))):
            os.makedirs(osp.dirname(osp.join(tar_folder, name)))
        os.symlink(osp.join(ori,name), osp.join(tar_folder, name))

link(train_list, osp.join(tar,'train'))
link(test_list, osp.join(tar,'test'))
28 |
--------------------------------------------------------------------------------
/small_scale/data/tools/tinyimagenet.py:
--------------------------------------------------------------------------------
1 | # Please download the raw data of TinyImageNet dataset from
2 | # http://cs231n.stanford.edu/tiny-imagenet-200.zip
3 |
import os
import os.path as osp

root = '../'
ori = root+'tiny-imagenet-200/train'
tar = root+'tinyimagenet/train'
os.makedirs(tar)

# Train split: flatten each class's images/ subfolder into <tar>/<class>/.
for cls_name in os.listdir(ori):
    os.makedirs(osp.join(tar, cls_name))
    for img_name in os.listdir(osp.join(ori, cls_name, 'images')):
        os.symlink(osp.join(ori, cls_name, 'images', img_name), osp.join(tar, cls_name, img_name))


ori = root+'tiny-imagenet-200/val'
tar = root+'tinyimagenet/val'
os.makedirs(tar)

# Val split: class labels come from val_annotations.txt ("<file>\t<class>\t...").
with open(root+'tiny-imagenet-200/val/val_annotations.txt','r') as f:
    list_file = f.readlines()

for item in list_file:
    fields = item.strip().split('\t')
    img_name, cls_name = fields[0], fields[1]
    if not osp.isdir(osp.join(tar, cls_name)):
        os.makedirs(osp.join(tar, cls_name))
    os.symlink(osp.join(ori, 'images', img_name), osp.join(tar, cls_name, img_name))
32 |
--------------------------------------------------------------------------------
/small_scale/datasets.py:
--------------------------------------------------------------------------------
1 | import csv, torchvision, numpy as np, random, os
2 | from PIL import Image
3 | import numpy as np
4 | import copy
5 |
6 | from torch.utils.data import Sampler, Dataset, DataLoader, BatchSampler, SequentialSampler, RandomSampler, Subset
7 | from torchvision import transforms, datasets
8 | from collections import defaultdict
9 |
10 |
class IdentityBatchSampler(Sampler):
    """Batch sampler that appends, for every anchor in the batch,
    (num_instances - 1) other sample indices drawn from the same class."""

    def __init__(self, dataset, batch_size, num_instances, num_iterations=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_iterations = num_iterations

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        random.shuffle(indices)
        for k in range(len(self)):
            # Wrap around the shuffled index list when num_iterations
            # exceeds one epoch.
            start = k * self.batch_size % len(indices)
            batch_indices = indices[start:start + self.batch_size]

            pair_indices = []
            for anchor in batch_indices:
                label = self.dataset.get_class(anchor)
                candidates = copy.deepcopy(self.dataset.classwise_indices[label])
                candidates.pop(candidates.index(anchor))
                # Sample without replacement when the class is large enough,
                # otherwise allow repeats.
                allow_repeats = len(candidates) < (self.num_instances - 1)
                picked = np.random.choice(candidates, size=self.num_instances - 1, replace=allow_repeats)
                pair_indices.extend(picked)

            yield batch_indices + pair_indices

    def __len__(self):
        if self.num_iterations is None:
            return (len(self.dataset) + self.batch_size - 1) // self.batch_size
        else:
            return self.num_iterations
43 |
44 |
class PairBatchSampler(Sampler):
    """Batch sampler that appends, for every anchor in the batch, one random
    index of the same class (which may be the anchor itself)."""

    def __init__(self, dataset, batch_size, num_iterations=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_iterations = num_iterations

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        random.shuffle(indices)
        for k in range(len(self)):
            if self.num_iterations is None:
                # One pass over the shuffled epoch.
                start = k * self.batch_size
                batch_indices = indices[start:start + self.batch_size]
            else:
                # Fixed iteration count: draw each batch independently.
                batch_indices = random.sample(range(len(self.dataset)),
                                              self.batch_size)

            pair_indices = []
            for anchor in batch_indices:
                label = self.dataset.get_class(anchor)
                pair_indices.append(random.choice(self.dataset.classwise_indices[label]))

            yield batch_indices + pair_indices

    def __len__(self):
        if self.num_iterations is None:
            return (len(self.dataset) + self.batch_size - 1) // self.batch_size
        return self.num_iterations
74 |
75 |
class DatasetWrapper(Dataset):
    # Additional attributes
    # - indices
    # - classwise_indices
    # - num_classes
    # - get_class

    def __init__(self, dataset, indices=None):
        self.base_dataset = dataset
        # Optional subsetting: wrapper index i maps to base index indices[i].
        if indices is None:
            self.indices = list(range(len(dataset)))
        else:
            self.indices = indices

        # torchvision 0.2.0 compatibility
        # Old torchvision datasets expose train_labels/test_labels instead of
        # the unified `targets` attribute; reconstruct `targets` for them.
        if torchvision.__version__.startswith('0.2'):
            if isinstance(self.base_dataset, datasets.ImageFolder):
                self.base_dataset.targets = [s[1] for s in self.base_dataset.imgs]
            else:
                if self.base_dataset.train:
                    self.base_dataset.targets = self.base_dataset.train_labels
                else:
                    self.base_dataset.targets = self.base_dataset.test_labels

        # Map each class label to the wrapper-relative indices of its samples
        # (consumed by the pair/identity batch samplers above).
        self.classwise_indices = defaultdict(list)
        for i in range(len(self)):
            y = self.base_dataset.targets[self.indices[i]]
            self.classwise_indices[y].append(i)
        # NOTE(review): assumes labels are contiguous ints starting at 0;
        # this is max label + 1, not the number of distinct labels.
        self.num_classes = max(self.classwise_indices.keys())+1

    def __getitem__(self, i):
        return self.base_dataset[self.indices[i]]

    def __len__(self):
        return len(self.indices)

    def get_class(self, i):
        # Label of sample i without loading/transforming the image.
        return self.base_dataset.targets[self.indices[i]]
114 |
115 |
class ConcatWrapper(Dataset):  # TODO: Naming
    """Concatenation of several DatasetWrapper-like datasets that remaps
    class ids and classwise sample indices into one contiguous space."""

    @staticmethod
    def cumsum(sequence):
        """Running total of dataset lengths, e.g. lengths [2, 1] -> [2, 3]."""
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    @staticmethod
    def numcls(sequence):
        """Total number of classes across all datasets."""
        return sum(e.num_classes for e in sequence)

    @staticmethod
    def clsidx(sequence):
        """Merge classwise_indices, offsetting class ids by the number of
        classes and sample ids by the number of samples seen so far."""
        r, s, n = defaultdict(list), 0, 0
        for e in sequence:
            l = e.classwise_indices
            for c in range(s, s + e.num_classes):
                t = np.asarray(l[c - s]) + n
                r[c] = t.tolist()
            s += e.num_classes
            n += len(e)
        return r

    def __init__(self, datasets):
        super(ConcatWrapper, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cumulative_sizes = self.cumsum(self.datasets)
        self.num_classes = self.numcls(self.datasets)
        self.classwise_indices = self.clsidx(self.datasets)

    def _locate(self, idx):
        """Map a global (possibly negative) index to (dataset_idx, sample_idx)."""
        # BUGFIX: `bisect` was never imported at module level, so indexing
        # raised NameError; import it locally to keep the fix self-contained.
        import bisect
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return dataset_idx, sample_idx

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._locate(idx)
        return self.datasets[dataset_idx][sample_idx]

    def get_class(self, idx):
        dataset_idx, sample_idx = self._locate(idx)
        d = self.datasets[dataset_idx]
        true_class = d.base_dataset.targets[d.indices[sample_idx]]
        # NOTE(review): assumes base_dataset.target_transform is callable;
        # torchvision datasets default it to None -- confirm at call sites.
        return d.base_dataset.target_transform(true_class)

    @property
    def cummulative_sizes(self):
        # BUGFIX: `warnings` was never imported at module level.
        import warnings
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
190 |
191 |
192 |
def load_dataset(name, root, sample='default', **kwargs):
    """Build train/val DataLoaders for the named dataset.

    Args:
        name: one of 'imagenet', 'tinyimagenet', 'CUB200', 'STANFORD120',
            'MIT67', 'cifar10', 'cifar100'.
        root: dataset root directory.
        sample: batch sampling strategy, 'default' (intra-class identity
            sampler) or 'pair'.
        **kwargs: must contain 'batch_size'; 'default' also needs
            'num_instances'.

    Returns:
        (trainloader, valloader) tuple of DataLoaders.

    Raises:
        Exception: on an unknown dataset name or sampling strategy.
    """
    # Dataset
    if name in ['imagenet', 'tinyimagenet', 'CUB200', 'STANFORD120', 'MIT67']:
        # TODO
        if name == 'tinyimagenet':
            transform_train = transforms.Compose([
                transforms.RandomResizedCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])
            transform_test = transforms.Compose([
                transforms.Resize(32),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])

            train_val_dataset_dir = os.path.join(root, "tinyimagenet/train")
            test_dataset_dir = os.path.join(root, "tinyimagenet/val")

            trainset = DatasetWrapper(datasets.ImageFolder(root=train_val_dataset_dir, transform=transform_train))
            valset = DatasetWrapper(datasets.ImageFolder(root=test_dataset_dir, transform=transform_test))

        elif name == 'imagenet':
            transform_train = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])
            transform_test = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])
            train_val_dataset_dir = os.path.join(root, "train")
            test_dataset_dir = os.path.join(root, "val")

            trainset = DatasetWrapper(datasets.ImageFolder(root=train_val_dataset_dir, transform=transform_train))
            valset = DatasetWrapper(datasets.ImageFolder(root=test_dataset_dir, transform=transform_test))

        else:
            # Fine-grained datasets (CUB200 / STANFORD120 / MIT67) live under
            # <root>/<name>/{train,test} and use ImageNet-style preprocessing.
            transform_train = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])
            transform_test = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            ])

            train_val_dataset_dir = os.path.join(root, name, "train")
            test_dataset_dir = os.path.join(root, name, "test")

            trainset = DatasetWrapper(datasets.ImageFolder(root=train_val_dataset_dir, transform=transform_train))
            valset = DatasetWrapper(datasets.ImageFolder(root=test_dataset_dir, transform=transform_test))

    elif name.startswith('cifar'):
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        if name == 'cifar10':
            CIFAR = datasets.CIFAR10
        else:
            CIFAR = datasets.CIFAR100

        trainset = DatasetWrapper(CIFAR(root, train=True, download=True, transform=transform_train))
        valset = DatasetWrapper(CIFAR(root, train=False, download=True, transform=transform_test))
    else:
        raise Exception('Unknown dataset: {}'.format(name))

    # Sampler
    if sample == 'default':
        get_train_sampler = lambda d: IdentityBatchSampler(d, kwargs['batch_size'], kwargs['num_instances'])
        get_test_sampler = lambda d: BatchSampler(SequentialSampler(d), kwargs['batch_size'], False)

    elif sample == 'pair':
        get_train_sampler = lambda d: PairBatchSampler(d, kwargs['batch_size'])
        get_test_sampler = lambda d: BatchSampler(SequentialSampler(d), kwargs['batch_size'], False)

    else:
        # BUG FIX: the original raised with the undefined name `sampling`,
        # causing a NameError instead of the intended message.
        raise Exception('Unknown sampling: {}'.format(sample))

    trainloader = DataLoader(trainset, batch_sampler=get_train_sampler(trainset), num_workers=8)
    valloader = DataLoader(valset, batch_sampler=get_test_sampler(valset), num_workers=8)

    return trainloader, valloader
295 |
--------------------------------------------------------------------------------
/small_scale/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .densenet import *
3 | from .densenet3 import *
4 |
def load_model(name, num_classes=10, pretrained=False, **kwargs):
    """Instantiate one of this package's model constructors by name."""
    factory = globals()[name]  # raises KeyError for unknown model names
    return factory(pretrained=pretrained, num_classes=num_classes, **kwargs)
9 |
--------------------------------------------------------------------------------
/small_scale/models/densenet.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.checkpoint as cp
6 | from collections import OrderedDict
7 | # from .utils import load_state_dict_from_url
8 | try:
9 | from torch.hub import load_state_dict_from_url
10 | except ImportError:
11 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
12 |
# Public API of this module.
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']

# Download locations of the official torchvision ImageNet checkpoints.
model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}
21 |
22 |
23 | def _bn_function_factory(norm, relu, conv):
24 | def bn_function(*inputs):
25 | concated_features = torch.cat(inputs, 1)
26 | bottleneck_output = conv(relu(norm(concated_features)))
27 | return bottleneck_output
28 |
29 | return bn_function
30 |
31 |
class _DenseLayer(nn.Sequential):
    """One dense layer: BN-ReLU-Conv1x1 bottleneck, then BN-ReLU-Conv3x3
    producing exactly `growth_rate` new feature maps."""

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False):
        super(_DenseLayer, self).__init__()
        # Bottleneck: compress the concatenated inputs to bn_size*growth_rate.
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        # 3x3 conv emitting growth_rate channels.
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate
        self.memory_efficient = memory_efficient

    def forward(self, *prev_features):
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        # Checkpointing trades compute for memory by recomputing the
        # bottleneck during backward; only useful when grads will flow.
        use_checkpoint = self.memory_efficient and any(
            feat.requires_grad for feat in prev_features)
        if use_checkpoint:
            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        else:
            bottleneck_output = bn_function(*prev_features)
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate,
                                     training=self.training)
        return new_features
59 |
60 |
class _DenseBlock(nn.Module):
    """Stack of _DenseLayers; every layer sees all previous feature maps."""

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False):
        super(_DenseBlock, self).__init__()
        for idx in range(num_layers):
            self.add_module(
                'denselayer%d' % (idx + 1),
                _DenseLayer(
                    num_input_features + idx * growth_rate,
                    growth_rate=growth_rate,
                    bn_size=bn_size,
                    drop_rate=drop_rate,
                    memory_efficient=memory_efficient,
                ),
            )

    def forward(self, init_features):
        # Pass the growing list of feature maps to each layer in turn.
        features = [init_features]
        for _, layer in self.named_children():
            features.append(layer(*features))
        return torch.cat(features, 1)
80 |
81 |
82 | class _Transition(nn.Sequential):
83 | def __init__(self, num_input_features, num_output_features):
84 | super(_Transition, self).__init__()
85 | self.add_module('norm', nn.BatchNorm2d(num_input_features))
86 | self.add_module('relu', nn.ReLU(inplace=True))
87 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
88 | kernel_size=1, stride=1, bias=False))
89 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
90 |
91 |
class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" `_
    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
        memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
          but slower. Default: *False*. See `"paper" `_
        bias (bool) - whether the final Linear classifier has a bias term

    Note:
        forward() returns a (features, logits) pair, not just the logits.
        Module names/registration order must match torchvision's DenseNet so
        pretrained state_dicts load correctly.
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False, bias=True):

        super(DenseNet, self).__init__()

        # First convolution (ImageNet-style 7x7 stem with stride-2 pool)
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
                                padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock, followed by a channel/space-halving transition
        # (except after the last block).
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                memory_efficient=memory_efficient
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes, bias=bias)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                # m.bias exists only when the classifier was built with bias.
                if bias:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return (pooled_features, logits)."""
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        prob = self.classifier(out)
        return out, prob
164 |
165 |
166 |
def _load_state_dict(model, model_url, progress):
    """Download a legacy checkpoint and load it, fixing dotted layer keys."""
    # Old checkpoints contain keys like 'denselayer1.norm.1.weight', but '.'
    # is no longer allowed inside module names, so rewrite them to the
    # current 'denselayer1.norm1.weight' form before loading.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    for key in list(state_dict.keys()):
        match = pattern.match(key)
        if match is None:
            continue
        state_dict[match.group(1) + match.group(2)] = state_dict.pop(key)
    model.load_state_dict(state_dict)
183 |
184 |
def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress, num_classes, bias,
              **kwargs):
    """Build a DenseNet variant and optionally load its pretrained weights."""
    model = DenseNet(growth_rate, block_config, num_init_features, num_classes=num_classes, bias=bias, **kwargs)
    if not pretrained:
        return model
    _load_state_dict(model, model_urls[arch], progress)
    return model
191 |
192 |
def densenet121(pretrained=False, progress=True, num_classes=10, bias=True, **kwargs):
    r"""Densenet-121 model from "Densely Connected Convolutional Networks".

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a download progress bar on stderr
        num_classes (int): number of classifier outputs
        bias (bool): whether the classifier Linear layer has a bias term
    """
    return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress,
                     num_classes=num_classes, bias=bias, **kwargs)
204 |
205 |
def densenet161(pretrained=False, progress=True, num_classes=10, bias=True, **kwargs):
    r"""Densenet-161 model from "Densely Connected Convolutional Networks".

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a download progress bar on stderr
        num_classes (int): number of classifier outputs
        bias (bool): whether the classifier Linear layer has a bias term
    """
    # BUG FIX: _densenet() requires num_classes and bias, but this wrapper
    # neither accepted nor forwarded them, so densenet161() raised TypeError.
    # Defaults mirror densenet121 for consistency.
    return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress,
                     num_classes=num_classes, bias=bias, **kwargs)
217 |
218 |
def densenet169(pretrained=False, progress=True, num_classes=10, bias=True, **kwargs):
    r"""Densenet-169 model from "Densely Connected Convolutional Networks".

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a download progress bar on stderr
        num_classes (int): number of classifier outputs
        bias (bool): whether the classifier Linear layer has a bias term
    """
    # BUG FIX: _densenet() requires num_classes and bias, but this wrapper
    # neither accepted nor forwarded them, so densenet169() raised TypeError.
    # Defaults mirror densenet121 for consistency.
    return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress,
                     num_classes=num_classes, bias=bias, **kwargs)
230 |
231 |
def densenet201(pretrained=False, progress=True, num_classes=10, bias=True, **kwargs):
    r"""Densenet-201 model from "Densely Connected Convolutional Networks".

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a download progress bar on stderr
        num_classes (int): number of classifier outputs
        bias (bool): whether the classifier Linear layer has a bias term
    """
    # BUG FIX: _densenet() requires num_classes and bias, but this wrapper
    # neither accepted nor forwarded them, so densenet201() raised TypeError.
    # Defaults mirror densenet121 for consistency.
    return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress,
                     num_classes=num_classes, bias=bias, **kwargs)
243 |
--------------------------------------------------------------------------------
/small_scale/models/densenet3.py:
--------------------------------------------------------------------------------
1 | '''DenseNet in PyTorch.'''
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
class Bottleneck(nn.Module):
    """Pre-activation DenseNet bottleneck: BN-ReLU-1x1 then BN-ReLU-3x3;
    the new maps are concatenated with the input (dense connectivity)."""

    def __init__(self, in_planes, growth_rate):
        super(Bottleneck, self).__init__()
        inner = 4 * growth_rate  # bottleneck width
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, inner, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(inner)
        self.conv2 = nn.Conv2d(inner, growth_rate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        y = self.conv1(F.relu(self.bn1(x)))
        y = self.conv2(F.relu(self.bn2(y)))
        # New features come first, then the untouched input maps.
        return torch.cat([y, x], 1)
22 |
23 |
class Transition(nn.Module):
    """Transition layer: 1x1 conv to change channels, then 2x2 average pool
    to halve the spatial resolution."""

    def __init__(self, in_planes, out_planes):
        super(Transition, self).__init__()
        self.bn = nn.BatchNorm2d(in_planes)
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)

    def forward(self, x):
        compressed = self.conv(F.relu(self.bn(x)))
        return F.avg_pool2d(compressed, 2)
34 |
35 |
class CIFAR_DenseNet(nn.Module):
    """DenseNet-BC for 32x32 inputs: 4 dense blocks with compressing
    transitions in between; forward() returns (embedding, logits)."""

    def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10, bias=True):
        super(CIFAR_DenseNet, self).__init__()
        self.growth_rate = growth_rate

        planes = 2 * growth_rate
        self.conv1 = nn.Conv2d(3, planes, kernel_size=3, padding=1, bias=False)

        # Blocks 1-3 are each followed by a channel-compressing transition.
        self.dense1 = self._make_dense_layers(block, planes, nblocks[0])
        planes += nblocks[0] * growth_rate
        compressed = int(math.floor(planes * reduction))
        self.trans1 = Transition(planes, compressed)
        planes = compressed

        self.dense2 = self._make_dense_layers(block, planes, nblocks[1])
        planes += nblocks[1] * growth_rate
        compressed = int(math.floor(planes * reduction))
        self.trans2 = Transition(planes, compressed)
        planes = compressed

        self.dense3 = self._make_dense_layers(block, planes, nblocks[2])
        planes += nblocks[2] * growth_rate
        compressed = int(math.floor(planes * reduction))
        self.trans3 = Transition(planes, compressed)
        planes = compressed

        # Final block has no transition after it.
        self.dense4 = self._make_dense_layers(block, planes, nblocks[3])
        planes += nblocks[3] * growth_rate

        self.bn = nn.BatchNorm2d(planes)
        self.linear = nn.Linear(planes, num_classes, bias=bias)

    def _make_dense_layers(self, block, in_planes, nblock):
        # Each layer consumes everything produced so far, adding growth_rate
        # channels to the running total.
        layers = []
        for _ in range(nblock):
            layers.append(block(in_planes, self.growth_rate))
            in_planes += self.growth_rate
        return nn.Sequential(*layers)

    def forward(self, x):
        feat = self.conv1(x)
        for dense, trans in ((self.dense1, self.trans1),
                             (self.dense2, self.trans2),
                             (self.dense3, self.trans3)):
            feat = trans(dense(feat))
        feat = self.dense4(feat)
        feat = F.avg_pool2d(F.relu(self.bn(feat)), 4)
        feat = feat.view(feat.size(0), -1)
        # Return both the pooled embedding and the classifier logits.
        return feat, self.linear(feat)
85 |
def CIFAR_DenseNet121(pretrained=False, num_classes=10, bias=True, **kwargs):
    """DenseNet-121 configuration for CIFAR-sized inputs.

    Note: `pretrained` and any extra kwargs are accepted but ignored.
    """
    return CIFAR_DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=32,
                          num_classes=num_classes, bias=bias)
88 |
89 | # def DenseNet169():
90 | # return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32)
91 |
92 | # def DenseNet201():
93 | # return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32)
94 |
95 | # def DenseNet161():
96 | # return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48)
97 |
98 | # def densenet_cifar():
99 | # return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12)
100 |
101 | # def test():
102 | # net = densenet_cifar()
103 | # x = torch.randn(1,3,32,32)
104 | # y = net(x)
105 | # print(y)
106 |
107 | # test()
108 |
--------------------------------------------------------------------------------
/small_scale/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.model_zoo as model_zoo
3 | import torch.nn.functional as F
4 |
# Public API of this module.
__all__ = ['ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'CIFAR_ResNet', 'CIFAR_ResNet18', 'CIFAR_ResNet34', 'CIFAR_ResNet10']


# Download locations of the official torchvision ImageNet checkpoints.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
16 |
17 |
def conv3x3(in_planes, out_planes, stride=1, groups=1):
    """3x3 convolution with padding (no bias; BN follows it)."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=groups,
        bias=False,
    )
22 |
23 |
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 (pointwise) convolution, bias-free."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
27 |
28 |
class BasicBlock(nn.Module):
    """Standard two-conv residual block (used by ResNet-18/34)."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        # conv1 (and the optional downsample) perform any spatial striding.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Projection shortcut when shape changes, identity otherwise.
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)
65 |
66 |
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck (ResNet-50+ / ResNeXt)."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # ResNeXt-style width scaling: the grouped 3x3 conv runs on `width` channels.
        width = int(planes * (base_width / 64.)) * groups
        # conv2 (and the optional downsample) carry the spatial stride.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += shortcut
        return self.relu(out)
108 |
class PreActBlock(nn.Module):
    """Pre-activation version of the BasicBlock (BN-ReLU before each conv)."""
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.relu = nn.ReLU(inplace=True)

        # Projection shortcut only when the output shape differs from the input.
        needs_projection = stride != 1 or in_planes != self.expansion * planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False))
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        preact = self.relu(self.bn1(x))
        # The shortcut branches off AFTER the first pre-activation.
        residual = self.shortcut(preact)
        out = self.conv1(preact)
        out = self.conv2(self.relu(self.bn2(out)))
        out += residual
        return out
134 |
class ResNet(nn.Module):
    """ImageNet-style ResNet backbone.

    forward() returns a (features, logits) pair: the pooled
    512*block.expansion-dim feature vector and the classifier output.
    Attribute names/order must match torchvision's ResNet so pretrained
    state_dicts load correctly.
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        self.inplanes = 64
        self.groups = groups
        self.base_width = width_per_group
        # 7x7 stride-2 stem + stride-2 max-pool: input resolution / 4.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages; stages 2-4 halve resolution and double width.
        self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, norm_layer=norm_layer)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, norm_layer=norm_layer)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, norm_layer=None):
        """Build one stage of `blocks` residual blocks; only the first may stride."""
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        downsample = None
        # Projection shortcut for the first block when shape changes.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return (pooled_features, logits)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        prob = self.fc(x)

        return x, prob
210 |
class CIFAR_ResNet(nn.Module):
    """ResNet backbone for 32x32 inputs; forward() returns (embedding, logits)."""

    def __init__(self, block, num_blocks, num_classes=10, bias=True):
        super(CIFAR_ResNet, self).__init__()
        self.in_planes = 64
        # 3x3 stem (no striding/pooling — inputs are already small).
        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.gap = nn.AvgPool2d(4)
        self.linear = nn.Linear(512 * block.expansion, num_classes, bias=bias)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage strides; the rest keep resolution.
        stage = []
        for s in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x, lin=0, lout=5):
        # NOTE(review): lin/lout are accepted but unused here — presumably
        # kept for signature compatibility with partial-forward variants.
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.layer2(self.layer1(h))
        h = self.layer4(self.layer3(h))
        h = self.gap(h)
        embedding = h.view(h.size(0), -1)
        logits = self.linear(embedding)
        return embedding, logits
249 |
250 |
def CIFAR_ResNet10(pretrained=False, **kwargs):
    """ResNet-10 for CIFAR: one pre-activation block per stage."""
    return CIFAR_ResNet(PreActBlock, [1, 1, 1, 1], **kwargs)
253 |
def CIFAR_ResNet18(pretrained=False, **kwargs):
    """ResNet-18 for CIFAR: two pre-activation blocks per stage."""
    return CIFAR_ResNet(PreActBlock, [2, 2, 2, 2], **kwargs)
256 |
def CIFAR_ResNet34(pretrained=False, **kwargs):
    """ResNet-34 for CIFAR: the standard 3/4/6/3 stage configuration."""
    return CIFAR_ResNet(PreActBlock, [3, 4, 6, 3], **kwargs)
259 |
def resnet10(pretrained=False, **kwargs):
    """Construct a ResNet-10 model (one BasicBlock per stage).

    Args:
        pretrained (bool): accepted for API symmetry; no pretrained weights
            are available for this depth.
    """
    return ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)
267 |
def resnet18(pretrained=False, **kwargs):
    """Construct a ResNet-18 model.

    Args:
        pretrained (bool): if True, load ImageNet-pretrained weights.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        state = model_zoo.load_url(model_urls['resnet18'])
        model.load_state_dict(state)
    return model
277 |
278 |
def my_resnet34(pretrained=False, **kwargs):
    # NOTE(review): `my_ResNet` is not defined anywhere in this module, so
    # calling this function raises NameError.  It also passes a 5-entry layer
    # config where ResNet consumes 4 — this looks like leftover experimental
    # code (it is not listed in __all__); confirm before using or removing.
    model = my_ResNet(BasicBlock, [3, 4, 6, 3, 3], **kwargs)
    return model
282 |
def resnet34(pretrained=False, **kwargs):
    """Construct a ResNet-34 model.

    Args:
        pretrained (bool): if True, load ImageNet-pretrained weights.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return model
    model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model
292 |
293 |
def resnet50(pretrained=False, **kwargs):
    """Construct a ResNet-50 model.

    Args:
        pretrained (bool): if True, load ImageNet-pretrained weights.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return model
    model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model
303 |
304 |
def resnet101(pretrained=False, **kwargs):
    """Construct a ResNet-101 model.

    Args:
        pretrained (bool): if True, load ImageNet-pretrained weights.
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if not pretrained:
        return model
    model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model
314 |
315 |
def resnet152(pretrained=False, **kwargs):
    """Construct a ResNet-152 model.

    Args:
        pretrained (bool): if True, load ImageNet-pretrained weights.
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if not pretrained:
        return model
    model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model
325 |
326 |
def resnext50_32x4d(pretrained=False, **kwargs):
    """Construct a ResNeXt-50 (32x4d) model.

    Note: pretrained-weight loading is intentionally not wired up
    (no URL entry in model_urls); `pretrained` is accepted but ignored.
    """
    return ResNet(Bottleneck, [3, 4, 6, 3], groups=32, width_per_group=4, **kwargs)
332 |
333 |
def resnext101_32x8d(pretrained=False, **kwargs):
    """Construct a ResNeXt-101 (32x8d) model.

    Note: pretrained-weight loading is intentionally not wired up
    (no URL entry in model_urls); `pretrained` is accepted but ignored.
    """
    return ResNet(Bottleneck, [3, 4, 23, 3], groups=32, width_per_group=8, **kwargs)
339 |
--------------------------------------------------------------------------------
/small_scale/scripts/train_bake.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Train with BAKE (batch knowledge ensembling) on a small-scale dataset.
# Usage: train_bake.sh <GPU> <DATA> <ARCH> <N> <M> <OMEGA>
GPU=$1
DATA=$2
ARCH=$3
N=$4        # batch size N
M=$5        # intra-class images M
OMEGA=$6    # ensembling weight

if [ $# -ne 6 ]
then
    echo "Arguments error: expected <GPU> <DATA> <ARCH> <N> <M> <OMEGA>"
    exit 1
fi

# BUG FIX: BAKE_$DATA_$ARCH expanded the undefined variable $DATA_, silently
# dropping the dataset from the experiment name; braces delimit the names.
python train.py \
    --lr 0.1 \
    --decay 1e-4 \
    --epoch 200 \
    --lamda 1.0 \
    --temp 4.0 \
    --sgpu $GPU \
    -d $DATA \
    -a $ARCH \
    -n $N \
    -m $M \
    --omega $OMEGA \
    --name BAKE_${DATA}_${ARCH}
28 |
--------------------------------------------------------------------------------
/small_scale/scripts/train_baseline.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Train a plain (no-BAKE) baseline: lamda/omega 0 disable the KD terms.
# Usage: train_baseline.sh <GPU> <DATA> <ARCH> <N>
GPU=$1
DATA=$2
ARCH=$3
N=$4    # batch size N

if [ $# -ne 4 ]
then
    echo "Arguments error: expected <GPU> <DATA> <ARCH> <N>"
    exit 1
fi

# BUG FIX: baseline_$DATA_$ARCH expanded the undefined variable $DATA_,
# silently dropping the dataset from the experiment name; braces fix it.
python train.py \
    --lr 0.1 \
    --decay 1e-4 \
    --epoch 200 \
    --lamda 0.0 \
    -m 0 \
    --omega 0.0 \
    --sgpu $GPU \
    -d $DATA \
    -a $ARCH \
    -n $N \
    --name baseline_${DATA}_${ARCH}
25 |
--------------------------------------------------------------------------------
/small_scale/scripts/val.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Evaluate a saved checkpoint (--eval runs a single validation pass).
# Usage: val.sh <GPU> <DATA> <ARCH> <CKPT>
GPU=$1
DATA=$2
ARCH=$3
CKPT=$4

if [ $# -ne 4 ]
then
    echo "Arguments error: "
    exit 1
fi

# -n 64 / -m 0: plain batches of 64, no intra-class sampling at eval time.
python train.py \
    --eval \
    --resume $CKPT \
    --sgpu $GPU \
    -d $DATA \
    -a $ARCH \
    -n 64 \
    -m 0 \
    --name test
22 |
--------------------------------------------------------------------------------
/small_scale/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import argparse
4 | import csv
5 | import os, logging
6 | import random
7 |
8 | import numpy as np
9 | import torch
10 | from torch.autograd import Variable, grad
11 | import torch.backends.cudnn as cudnn
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.optim as optim
15 | import torchvision.transforms as transforms
16 |
17 | import models
18 | from utils import progress_bar, set_logging_defaults
19 | from datasets import load_dataset
20 |
parser = argparse.ArgumentParser(description='Small-scale Datasets Training')

parser.add_argument('--name', default='cifar_res18_train', type=str, help='name of experiment')
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--arch', '-a', default="CIFAR_ResNet18", type=str, help='model type (32x32: CIFAR_ResNet18, CIFAR_DenseNet121, 224x224: resnet18, densenet121)')
parser.add_argument('--resume', '-r', default="", help='resume from checkpoint')
parser.add_argument('--eval', action='store_true', help='only evaluate')
parser.add_argument('--sgpu', default=0, type=int, help='gpu index (start)')
parser.add_argument('--ngpu', default=1, type=int, help='number of gpu')
parser.add_argument('--dataroot', default='./data', type=str, help='data directory')
parser.add_argument('--saveroot', default='./results', type=str, help='save directory')
parser.add_argument('--dataset', '-d', default='cifar100', type=str, help='the name for dataset cifar100 | tinyimagenet | CUB200 | STANFORD120 | MIT67')

parser.add_argument('--epoch', default=200, type=int, help='total epochs to run')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--decay', default=1e-4, type=float, help='weight decay')
parser.add_argument('--batch-size', '-n', default=128, type=int, help='batch size, N')
parser.add_argument('--intra-imgs', '-m', default=3, type=int, help='intra-class images, M')

parser.add_argument('--temp', default=4.0, type=float, help='temperature scaling')
parser.add_argument('--lamda', default=1.0, type=float, help='kd loss weight ratio')
parser.add_argument('--omega', default=0.5, type=float, help='ensembling weight')

args = parser.parse_args()
use_cuda = torch.cuda.is_available()
# Each sampled class contributes 1 anchor + M intra-class images, so divide
# the requested batch size N by (M + 1) to keep ~N images per batch.
args.num_instances = args.intra_imgs + 1
args.batch_size = args.batch_size // args.num_instances

best_val = 0  # best validation accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch

# Seed every RNG source for reproducibility.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

cudnn.benchmark = True

# Data
print('==> Preparing dataset: {}'.format(args.dataset))
trainloader, valloader = load_dataset(args.dataset, args.dataroot,
                                      batch_size=args.batch_size,
                                      num_instances=args.num_instances)

num_class = trainloader.dataset.num_classes
print('Number of train dataset: ' ,len(trainloader.dataset))
print('Number of validation dataset: ' ,len(valloader.dataset))

# Model
print('==> Building model: {}'.format(args.arch))
net = models.load_model(args.arch, num_class)

if use_cuda:
    torch.cuda.set_device(args.sgpu)
    net.cuda()
    print(torch.cuda.device_count())
    print('Using CUDA..')

if args.ngpu > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(args.sgpu, args.sgpu + args.ngpu)))

optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.decay)

logdir = os.path.join(args.saveroot, args.dataset, args.arch, args.name)
set_logging_defaults(logdir, args)
logger = logging.getLogger('main')
logname = os.path.join(logdir, 'log.csv')

# Resume
if args.resume:
    # Load checkpoint. Renamed local from `checkpoint` to `ckpt` so it no
    # longer shadows the module-level checkpoint() function defined below.
    print('==> Resuming from checkpoint..')
    ckpt = torch.load(args.resume)
    net.load_state_dict(ckpt['net'])
    optimizer.load_state_dict(ckpt['optimizer'])
    # BUG FIX: was `best_acc = checkpoint['acc']`, an unused name — resuming
    # left best_val at 0, so the first post-resume epoch always overwrote the
    # saved best checkpoint.
    best_val = ckpt['acc']
    start_epoch = ckpt['epoch'] + 1
    rng_state = ckpt['rng_state']
    torch.set_rng_state(rng_state)

criterion = nn.CrossEntropyLoss()
102 |
class KDLoss(nn.Module):
    """Temperature-scaled knowledge-distillation loss.

    Computes KL(target || softmax(input / T)) summed over the batch, scaled
    by T^2 (standard KD gradient correction) and averaged over batch size.
    `target` is expected to already be a probability distribution per row.
    """

    def __init__(self, temp_factor):
        super(KDLoss, self).__init__()
        self.temp_factor = temp_factor
        self.kl_div = nn.KLDivLoss(reduction="sum")

    def forward(self, input, target):
        """Return the scalar KD loss for a batch of logits vs. soft targets."""
        scaled_log_probs = torch.log_softmax(input / self.temp_factor, dim=1)
        return self.kl_div(scaled_log_probs, target) * (self.temp_factor ** 2) / input.size(0)
113 |
# Distillation loss against the batch-ensembled soft targets (--temp).
kdloss = KDLoss(args.temp)
# Row-wise softmax reused on the affinity matrix in knowledge_ensemble().
softmax = nn.Softmax(dim=1)
116 |
def knowledge_ensemble(feats, logits):
    """Build soft targets by propagating batch predictions over a
    feature-similarity graph (BAKE knowledge ensembling).

    feats:  per-sample feature vectors, one row per sample.
    logits: per-sample classification logits (same batch dimension).
    Returns a (batch, num_classes) tensor of ensembled soft labels.

    Uses module globals: args.temp, args.omega, softmax, use_cuda.
    """
    n = logits.size(0)
    eye = torch.eye(n)
    if use_cuda:
        eye = eye.cuda()
    normalized = nn.functional.normalize(feats, p=2, dim=1)
    soft_preds = nn.functional.softmax(logits / args.temp, dim=1)
    # Cosine-similarity affinities; subtract a huge value on the diagonal so
    # the row softmax gives each sample zero weight on itself.
    affinity = torch.matmul(normalized, normalized.permute(1, 0)) - eye * 1e9
    affinity = softmax(affinity)
    # Closed form of the geometric propagation series:
    # (1 - w) * (I - w*A)^-1 = (1 - w) * sum_k (w*A)^k.
    propagation = (1 - args.omega) * torch.inverse(eye - args.omega * affinity)
    return torch.matmul(propagation, soft_preds)
128 |
def train(epoch):
    """Run one training epoch over `trainloader`.

    Combines cross-entropy with the BAKE distillation loss: soft targets are
    built from the current batch by knowledge_ensemble() (with no gradient
    flowing through them) and distilled into the network via `kdloss`,
    weighted by args.lamda.

    Returns the epoch training accuracy in percent.

    Uses module globals: net, trainloader, criterion, kdloss, optimizer,
    args, use_cuda.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_ce_loss = 0
    correct = 0
    total = 0
    train_kd_loss = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()

        batch_size = inputs.size(0)  # NOTE(review): unused; kept as-is

        # net returns a (features, logits) pair — presumably penultimate
        # features plus classifier output; confirm against models/.
        features, outputs = net(inputs)
        loss = criterion(outputs, targets)
        train_ce_loss += loss.item()

        ############
        # BAKE: ensemble the batch's own predictions into soft targets.
        # detach() + no_grad ensure the teacher signal carries no gradient.
        with torch.no_grad():
            kd_targets = knowledge_ensemble(features.detach(), outputs.detach())
        kd_loss = kdloss(outputs, kd_targets.detach())
        loss += args.lamda * kd_loss
        train_kd_loss += kd_loss.item()
        ############

        _, predicted = torch.max(outputs, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).sum().float().cpu()

        # Single backward pass through the combined CE + lamda*KD loss.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        progress_bar(batch_idx, len(trainloader),
            'CE loss: %.3f | KD loss: %.3f | Acc: %.3f%% (%d/%d)'
            % (train_ce_loss/(batch_idx+1), train_kd_loss/(batch_idx+1), 100.*correct/total, correct, total))

    logger = logging.getLogger('train')
    logger.info('[Epoch {}] [CE loss {:.3f}] [KD loss {:.3f}] [Acc {:.3f}]'.format(
        epoch,
        train_ce_loss/(batch_idx+1),
        train_kd_loss/(batch_idx+1),
        100.*correct/total))

    return 100.*correct/total
173 |
def val(epoch):
    """Evaluate on the validation set; checkpoint when a new best is reached.

    Returns (mean validation loss, accuracy in percent).

    Uses module globals: net, valloader, criterion, use_cuda, best_val.
    """
    global best_val
    net.eval()
    total_loss = 0.0
    num_correct = 0.0
    num_seen = 0.0

    # Define a data loader for evaluating
    loader = valloader

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(loader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()

            _, outputs = net(inputs)
            # torch.mean kept from the original; criterion presumably already
            # returns a scalar mean, making this a no-op — confirm.
            batch_loss = torch.mean(criterion(outputs, targets))

            total_loss += batch_loss.item()
            predictions = outputs.max(1)[1]
            num_seen += targets.size(0)
            num_correct += predictions.eq(targets.data).cpu().sum().float()

            progress_bar(batch_idx, len(loader),
                'Loss: %.3f | Acc: %.3f%% (%d/%d) '
                % (total_loss/(batch_idx+1), 100.*num_correct/num_seen, num_correct, num_seen))

    acc = 100.*num_correct/num_seen
    if acc > best_val:
        best_val = acc
        checkpoint(acc, epoch)
    logger = logging.getLogger('val')
    logger.info('[Epoch {}] [Loss {:.3f}] [Acc {:.3f}] [Best Acc {:.3f}]'.format(
        epoch,
        total_loss/(batch_idx+1),
        acc, best_val))

    return (total_loss/(batch_idx+1), acc)
212 |
213 |
def checkpoint(acc, epoch):
    """Save model, optimizer, accuracy, epoch and RNG state to <logdir>/ckpt.t7."""
    print('Saving..')
    snapshot = {
        'net': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'acc': acc,
        'epoch': epoch,
        'rng_state': torch.get_rng_state(),
    }
    torch.save(snapshot, os.path.join(logdir, 'ckpt.t7'))
225 |
226 |
def adjust_learning_rate(optimizer, epoch):
    """Step-decay schedule: divide the base args.lr by 10 once epoch reaches
    50% of args.epoch and again at 75% (i.e. epochs 100 and 150 for the
    default 200-epoch run)."""
    lr = args.lr
    for milestone in (0.5 * args.epoch, 0.75 * args.epoch):
        if epoch >= milestone:
            lr /= 10
    for group in optimizer.param_groups:
        group['lr'] = lr
236 |
if (not args.eval):
    # Train for args.epoch epochs; val() checkpoints whenever validation
    # accuracy improves, and the LR is stepped after each epoch.
    for epoch in range(start_epoch, args.epoch):
        train_acc = train(epoch)
        val_loss, val_acc = val(epoch)
        adjust_learning_rate(optimizer, epoch)
else:
    # --eval: a single validation pass over the resumed checkpoint.
    val_loss, val_acc = val(0)

print("Best Accuracy : {}".format(best_val))
logger = logging.getLogger('best')
logger.info('[Acc {:.3f}]'.format(best_val))
249 |
--------------------------------------------------------------------------------
/small_scale/utils.py:
--------------------------------------------------------------------------------
1 | import os, logging
2 | import sys
3 | import time
4 | import math
5 | import shutil
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.init as init
10 |
11 |
def set_logging_defaults(logdir, args):
    """Create `logdir` (prompting before reusing an existing one) and point
    root logging at both stdout and `<logdir>/log.txt`.

    Args:
        logdir: output directory; created if missing.
        args: parsed command-line namespace, logged verbatim for provenance.

    Raises:
        Exception: if `logdir` exists and the user does not confirm overwrite.
    """
    if os.path.isdir(logdir):
        res = input('"{}" exists. Overwrite [Y/n]? '.format(logdir))
        # Accept 'y' as well as 'Y', matching the displayed [Y/n] prompt.
        if res.lower() != 'y':
            raise Exception('"{}" exists.'.format(logdir))
    else:
        os.makedirs(logdir)

    # set basic configuration for logging
    # (sys.stdout / sys.argv replace the nonstandard os.sys.* aliases)
    logging.basicConfig(format="[%(asctime)s] [%(name)s] %(message)s",
                        level=logging.INFO,
                        handlers=[logging.FileHandler(os.path.join(logdir, 'log.txt')),
                                  logging.StreamHandler(sys.stdout)])

    # log cmdline arguments
    logger = logging.getLogger('main')
    logger.info(' '.join(sys.argv))
    logger.info(args)
30 |
# BUG FIX: shutil.get_terminal_size() returns (columns, lines), so the old
# `_, term_width = shutil.get_terminal_size()` unpack stored the *line* count
# (e.g. 24) as the width; the stty-based code it replaced yielded columns.
term_width = shutil.get_terminal_size().columns
term_width = int(term_width)

TOTAL_BAR_LENGTH = 86.  # character budget for the bar itself
last_time = time.time()
begin_time = last_time
38 |
def progress_bar(current, total, msg=None):
    """Render a single-line console progress bar with step/total timing.

    current: zero-based index of the step just finished.
    total:   total number of steps.
    msg:     optional extra status text appended after the timings.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    filled = int(TOTAL_BAR_LENGTH * current / total)
    rest = int(TOTAL_BAR_LENGTH - filled) - 1

    sys.stdout.write(' [' + '=' * filled + '>' + '.' * rest + ']')

    now = time.time()
    step_time = now - last_time
    last_time = now
    tot_time = now - begin_time

    pieces = [' Step: %s' % format_time(step_time),
              ' | Tot: %s' % format_time(tot_time)]
    if msg:
        pieces.append(' | ' + msg)

    msg = ''.join(pieces)
    sys.stdout.write(msg)
    # Pad out to the terminal edge so leftovers from longer lines are erased.
    sys.stdout.write(' ' * (term_width - int(TOTAL_BAR_LENGTH) - len(msg) - 3))

    # Backspace to the middle of the bar and print the step counter there.
    sys.stdout.write('\b' * (term_width - int(TOTAL_BAR_LENGTH / 2)))
    sys.stdout.write(' %d/%d ' % (current + 1, total))

    # Carriage-return to overwrite on the next call; newline when done.
    sys.stdout.write('\r' if current < total - 1 else '\n')
    sys.stdout.flush()
81 |
def format_time(seconds):
    """Format a duration in seconds as a compact string such as '1h1m' or
    '1s500ms', keeping at most the two most significant nonzero units
    (days, hours, minutes, seconds, milliseconds). Returns '0ms' for zero."""
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    whole_secs = int(seconds)
    millis = int((seconds - whole_secs) * 1000)

    parts = []
    for value, suffix in ((days, 'D'), (hours, 'h'), (minutes, 'm'),
                          (whole_secs, 's'), (millis, 'ms')):
        if value > 0 and len(parts) < 2:
            parts.append('%d%s' % (value, suffix))
    return ''.join(parts) or '0ms'
113 |
--------------------------------------------------------------------------------