├── LICENSE
├── README.md
├── assets
│   ├── MiSLAS.PNG
│   └── MiSLAS.pdf
├── config
│   ├── cifar10
│   │   ├── cifar10_imb001_stage1_mixup.yaml
│   │   ├── cifar10_imb001_stage2_mislas.yaml
│   │   ├── cifar10_imb002_stage1_mixup.yaml
│   │   ├── cifar10_imb002_stage2_mislas.yaml
│   │   ├── cifar10_imb01_stage1_mixup.yaml
│   │   └── cifar10_imb01_stage2_mislas.yaml
│   ├── cifar100
│   │   ├── cifar100_imb001_stage1_mixup.yaml
│   │   ├── cifar100_imb001_stage2_mislas.yaml
│   │   ├── cifar100_imb002_stage1_mixup.yaml
│   │   ├── cifar100_imb002_stage2_mislas.yaml
│   │   ├── cifar100_imb01_stage1_mixup.yaml
│   │   └── cifar100_imb01_stage2_mislas.yaml
│   ├── imagenet
│   │   ├── imagenet_resnet50_stage1_mixup.yaml
│   │   └── imagenet_resnet50_stage2_mislas.yaml
│   ├── ina2018
│   │   ├── ina2018_resnet50_stage1_mixup.yaml
│   │   └── ina2018_resnet50_stage2_mislas.yaml
│   └── places
│       ├── places_resnet152_stage1_mixup.yaml
│       └── places_resnet152_stage2_mislas.yaml
├── datasets
│   ├── cifar10.py
│   ├── cifar100.py
│   ├── data_txt
│   │   ├── ImageNet_LT_test.txt
│   │   ├── ImageNet_LT_train.txt
│   │   ├── Places_LT_test.txt
│   │   ├── Places_LT_train.txt
│   │   ├── iNaturalist18_train.txt
│   │   └── iNaturalist18_val.txt
│   ├── imagenet.py
│   ├── ina2018.py
│   ├── places.py
│   └── sampler.py
├── eval.py
├── methods.py
├── models
│   ├── resnet.py
│   ├── resnet_cifar.py
│   └── resnet_places.py
├── reliability_diagrams.py
├── requirements.txt
├── train_stage1.py
├── train_stage2.py
└── utils
    ├── __init__.py
    ├── logger.py
    ├── meter.py
    └── metric.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Zhisheng Zhong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MiSLAS
2 | **Improving Calibration for Long-Tailed Recognition**
3 |
4 | **Authors**: Zhisheng Zhong, Jiequan Cui, Shu Liu, Jiaya Jia
5 |
6 | [[`arXiv`](https://arxiv.org/pdf/2104.00466.pdf)] [[`slide`](./assets/MiSLAS.pdf)] [[`BibTeX`](#Citation)]
7 |
8 |
9 | ![MiSLAS overview](./assets/MiSLAS.PNG)
10 |
11 |
12 | **Introduction**: This repository provides an implementation for the CVPR 2021 paper "[Improving Calibration for Long-Tailed Recognition](https://arxiv.org/pdf/2104.00466.pdf)", based on [LDAM-DRW](https://github.com/kaidic/LDAM-DRW) and [Decoupling models](https://github.com/facebookresearch/classifier-balancing). *Our study shows that, because of the extremely imbalanced class composition, networks trained on long-tailed datasets are more miscalibrated and over-confident.* MiSLAS is a simple and efficient two-stage framework for long-tailed recognition that greatly improves recognition accuracy and markedly relieves over-confidence at the same time.
13 |
14 | ## Installation
15 |
16 | **Requirements**
17 |
18 | * Python 3.7
19 | * torchvision 0.4.0
20 | * PyTorch 1.2.0
21 | * yacs 0.1.8
22 |
23 | **Virtual Environment**
24 | ```
25 | conda create -n MiSLAS python=3.7
26 | source activate MiSLAS
27 | ```
28 |
29 | **Install MiSLAS**
30 | ```
31 | git clone https://github.com/Jia-Research-Lab/MiSLAS.git
32 | cd MiSLAS
33 | pip install -r requirements.txt
34 | ```
35 |
36 | **Dataset Preparation**
37 | * [CIFAR-10 & CIFAR-100](https://www.cs.toronto.edu/~kriz/cifar.html)
38 | * [ImageNet](http://image-net.org/index)
39 | * [iNaturalist 2018](https://github.com/visipedia/inat_comp/tree/master/2018)
40 | * [Places](http://places2.csail.mit.edu/download.html)
41 |
42 | Change the `data_path` in `config/*/*.yaml` accordingly.
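
For example, in `config/imagenet/imagenet_resnet50_stage1_mixup.yaml` (the path below is a placeholder for your local copy):

```yaml
data_path: '/your/path/to/ImageNet/'
```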
43 |
44 | ## Training
45 |
46 | **Stage-1**:
47 |
48 | To train a model for Stage-1 with *mixup*, run:
49 |
50 | (one GPU for CIFAR-10-LT & CIFAR-100-LT, four GPUs for ImageNet-LT, iNaturalist 2018, and Places-LT)
51 |
52 | ```
53 | python train_stage1.py --cfg ./config/DATASETNAME/DATASETNAME_ARCH_stage1_mixup.yaml
54 | ```
55 |
56 | `DATASETNAME` can be selected from `cifar10`, `cifar100`, `imagenet`, `ina2018`, and `places`.
57 |
58 | `ARCH` can be `resnet32` for `cifar10/100`, `resnet50/101/152` for `imagenet`, `resnet50` for `ina2018`, and `resnet152` for `places`.
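
For example, to run Stage-1 on CIFAR-100-LT with imbalance factor 100 (`imb001`, i.e. `imb_factor: 0.01`):

```
python train_stage1.py --cfg ./config/cifar100/cifar100_imb001_stage1_mixup.yaml
```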
59 |
60 | **Stage-2**:
61 |
62 | To train a model for Stage-2 with *one GPU* (for all of the above datasets), run:
63 |
64 | ```
65 | python train_stage2.py --cfg ./config/DATASETNAME/DATASETNAME_ARCH_stage2_mislas.yaml resume /path/to/checkpoint/stage1
66 | ```
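
`resume` here is a command-line config override: trailing `KEY VALUE` pairs are merged into the YAML config (the `opts` / `update_config` pattern visible in `eval.py` below). A concrete invocation might look like the following, where the checkpoint path is a hypothetical example:

```
python train_stage2.py --cfg ./config/cifar100/cifar100_imb001_stage2_mislas.yaml resume ./saved/cifar100_imb001_stage1_mixup_DATE/ckps/model_best.pth.tar
```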
67 |
68 | The saved folder (including logs and checkpoints) is organized as follows.
69 | ```
70 | MiSLAS
71 | ├── saved
72 | │   ├── modelname_date
73 | │   │   ├── ckps
74 | │   │   │   ├── current.pth.tar
75 | │   │   │   └── model_best.pth.tar
76 | │   │   └── logs
77 | │   │       └── modelname.txt
78 | │   ...
79 | ```
80 | ## Evaluation
81 |
82 | To evaluate a trained model, run:
83 |
84 | ```
85 | python eval.py --cfg ./config/DATASETNAME/DATASETNAME_ARCH_stage1_mixup.yaml resume /path/to/checkpoint/stage1
86 | python eval.py --cfg ./config/DATASETNAME/DATASETNAME_ARCH_stage2_mislas.yaml resume /path/to/checkpoint/stage2
87 | ```
88 |
89 | ## Results and Models
90 |
91 | **1) CIFAR-10-LT and CIFAR-100-LT**
92 |
93 | * Stage-1 (*mixup*):
94 |
95 | | Dataset | Top-1 Accuracy | ECE (15 bins) | Model |
96 | | -------------------- | -------------- | ------------- | ----- |
97 | | CIFAR-10-LT IF=10 | 87.6% | 11.9% | [link](https://drive.google.com/file/d/1dV1hchsIR5kTSqSOhdEs6nnXApcH5wEG/view?usp=sharing) |
98 | | CIFAR-10-LT IF=50 | 78.1% | 2.49% | [link](https://drive.google.com/file/d/1LoczjQRK20u_HpFMLmzeT0pVCp3V-gyf/view?usp=sharing) |
99 | | CIFAR-10-LT IF=100 | 72.8% | 2.14% | [link](https://drive.google.com/file/d/1TFetlV4MT4zjKEAPKcZuzmY2Dgtcqmsd/view?usp=sharing) |
100 | | CIFAR-100-LT IF=10 | 59.1% | 5.24% | [link](https://drive.google.com/file/d/1BmLjPReBoH6LJwl5x8_zSPnm1f6N_Cp0/view?usp=sharing) |
101 | | CIFAR-100-LT IF=50 | 45.4% | 4.33% | [link](https://drive.google.com/file/d/1l0LfZozJxWgzKp2IgM9mSpfwjTsIC-Mg/view?usp=sharing) |
102 | | CIFAR-100-LT IF=100 | 39.5% | 8.82% | [link](https://drive.google.com/file/d/15dHVdkI8J-oKkeQqyj6FtrHtIpO_TYfq/view?usp=sharing) |
103 |
104 | * Stage-2 (*MiSLAS*):
105 |
106 | | Dataset | Top-1 Accuracy | ECE (15 bins) | Model |
107 | | -------------------- | -------------- | ------------- | ----- |
108 | | CIFAR-10-LT IF=10 | 90.0% | 1.20% | [link](https://drive.google.com/file/d/1iST8Tr2LQ8nIjTNT1CKiQ-1T-RKxAvqr/view?usp=sharing) |
109 | | CIFAR-10-LT IF=50 | 85.7% | 2.01% | [link](https://drive.google.com/file/d/15bfA7uJsyM8eTwoptwp452kStk6FYT7v/view?usp=sharing) |
110 | | CIFAR-10-LT IF=100 | 82.5% | 3.66% | [link](https://drive.google.com/file/d/1KOTkjTOhIP5UOhqvHGJzEqq4_kQGKSJY/view?usp=sharing) |
111 | | CIFAR-100-LT IF=10 | 63.2% | 1.73% | [link](https://drive.google.com/file/d/1N2ai-l1hsbXTp_25Hoh5BSoAmR1_0UVD/view?usp=sharing) |
112 | | CIFAR-100-LT IF=50 | 52.3% | 2.47% | [link](https://drive.google.com/file/d/1Z2nukCMTG0cMmGXzZip3zIwv2WB5cOiZ/view?usp=sharing) |
113 | | CIFAR-100-LT IF=100 | 47.0% | 4.83% | [link](https://drive.google.com/file/d/1bX3eM-hlxGvEGuHBcfNhuz6VNp32Y0IQ/view?usp=sharing) |
114 |
115 | *Note: To obtain better performance, we highly recommend changing the weight decay from 2e-4 to 5e-4 on CIFAR-LT.*
116 |
117 | **2) Large-scale Datasets**
118 |
119 | * Stage-1 (*mixup*):
120 |
121 | | Dataset | Arch | Top-1 Accuracy | ECE (15 bins) | Model |
122 | | ----------- | ---------- | -------------- | ------------- | ----- |
123 | | ImageNet-LT | ResNet-50 | 45.5% | 7.98% | [link](https://drive.google.com/file/d/1QKVnK7n75q465ppf7wkK4jzZvZJE_BPi/view?usp=sharing) |
124 | | iNa'2018 | ResNet-50 | 66.9% | 5.37% | [link](https://drive.google.com/file/d/1wvj-cITz8Ps1TksLHi_KoGsq9CecXcVt/view?usp=sharing) |
125 | | Places-LT | ResNet-152 | 29.4% | 16.7% | [link](https://drive.google.com/file/d/1Tx-tY5Y8_-XuGn9ZdSxtAm0onOsKWhUH/view?usp=sharing) |
126 |
127 | * Stage-2 (*MiSLAS*):
128 |
129 | | Dataset | Arch | Top-1 Accuracy | ECE (15 bins) | Model |
130 | | ----------- | ---------- | -------------- | ------------- | ----- |
131 | | ImageNet-LT | ResNet-50 | 52.7% | 1.78% | [link](https://drive.google.com/file/d/1ofJKlUJZQjjkoFU9MLI08UP2uBvywRgF/view?usp=sharing) |
132 | | iNa'2018 | ResNet-50 | 71.6% | 7.67% | [link](https://drive.google.com/file/d/1crOo3INxqkz8ZzKZt9pH4aYb3-ep4lo-/view?usp=sharing) |
133 | | Places-LT | ResNet-152 | 40.4% | 3.41% | [link](https://drive.google.com/file/d/1DgL0aN3UadI3UoHU6TO7M6UD69QgvnbT/view?usp=sharing) |
134 |
135 | ## Citation
136 |
137 | Please consider citing MiSLAS in your publications if it helps your research. :)
138 |
139 | ```bib
140 | @inproceedings{zhong2021mislas,
141 | title={Improving Calibration for Long-Tailed Recognition},
142 | author={Zhong, Zhisheng and Cui, Jiequan and Liu, Shu and Jia, Jiaya},
143 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
144 | year={2021},
145 | }
146 | ```
147 |
148 | ## Contact
149 |
150 | If you have any questions about our work, feel free to contact us by email (Zhisheng Zhong: zszhong@pku.edu.cn) or through GitHub issues.
151 |
--------------------------------------------------------------------------------
/assets/MiSLAS.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/MiSLAS/0fc381e333d8a11170851b7af2058c59840969e0/assets/MiSLAS.PNG
--------------------------------------------------------------------------------
/assets/MiSLAS.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/MiSLAS/0fc381e333d8a11170851b7af2058c59840969e0/assets/MiSLAS.pdf
--------------------------------------------------------------------------------
/config/cifar10/cifar10_imb001_stage1_mixup.yaml:
--------------------------------------------------------------------------------
1 | name: cifar10_imb001_stage1_mixup
2 | print_freq: 40
3 | workers: 16
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 | # dataset & model setting
8 | dataset: 'cifar10'
9 | data_path: './data/cifar10'
10 | num_classes: 10
11 | imb_factor: 0.01
12 | backbone: 'resnet32_fe'
13 | resume: ''
14 | head_class_idx:
15 | - 0
16 | - 3
17 | med_class_idx:
18 | - 3
19 | - 7
20 | tail_class_idx:
21 | - 7
22 | - 10
23 |
24 |
25 | # distributed training
26 | deterministic: False
27 | distributed: False
28 | gpu: null
29 | world_size: -1
30 | rank: -1
31 | dist_url: 'tcp://224.66.41.62:23456'
32 | dist_backend: 'nccl'
33 | multiprocessing_distributed: False
34 |
35 |
36 |
37 | # Train
38 | mode: 'stage1'
39 | lr: 0.1
40 | batch_size: 128
41 | weight_decay: 2e-4
42 | num_epochs: 200
43 | momentum: 0.9
44 | cos: False
45 | mixup: True
46 | alpha: 1.0
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
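The `mixup: True` / `alpha: 1.0` settings above enable mixup in Stage-1, with `alpha` the Beta-distribution parameter. As a reference for what these knobs control, here is a sketch of the standard mixup training step (Zhang et al., 2018); it is not a copy of the repository's `train_stage1.py` (which is not included in this listing), so `model`, `x`, and `y` are assumed names:

```python
import numpy as np
import torch
import torch.nn.functional as F

def mixup_step(model, x, y, alpha=1.0):
    """One standard mixup step: lam ~ Beta(alpha, alpha) mixes inputs and losses."""
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    perm = torch.randperm(x.size(0), device=x.device)   # random pairing permutation
    mixed_x = lam * x + (1.0 - lam) * x[perm]           # convex combination of inputs
    logits = model(mixed_x)
    # The loss mixes the two targets with the same coefficient.
    return lam * F.cross_entropy(logits, y) + (1.0 - lam) * F.cross_entropy(logits, y[perm])
```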
/config/cifar10/cifar10_imb001_stage2_mislas.yaml:
--------------------------------------------------------------------------------
1 | name: cifar10_imb001_stage2_mislas
2 | print_freq: 40
3 | workers: 16
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 | # dataset & model setting
8 | dataset: 'cifar10'
9 | data_path: './data/cifar10'
10 | num_classes: 10
11 | imb_factor: 0.01
12 | backbone: 'resnet32_fe'
13 | resume: 'Path/to/Stage1_checkpoint.pth.tar'
14 | head_class_idx:
15 | - 0
16 | - 3
17 | med_class_idx:
18 | - 3
19 | - 7
20 | tail_class_idx:
21 | - 7
22 | - 10
23 |
24 |
25 | # distributed training
26 | deterministic: False
27 | distributed: False
28 | gpu: null
29 | world_size: -1
30 | rank: -1
31 | dist_url: 'tcp://224.66.41.62:23456'
32 | dist_backend: 'nccl'
33 | multiprocessing_distributed: False
34 |
35 |
36 |
37 | # Train
38 | mode: 'stage2'
39 | smooth_head: 0.3
40 | smooth_tail: 0.0
41 | shift_bn: False
42 | lr_factor: 0.5
43 | lr: 0.1
44 | batch_size: 128
45 | weight_decay: 2e-4
46 | num_epochs: 10
47 | momentum: 0.9
48 | mixup: False
49 | alpha: null
50 |
51 |
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
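The `smooth_head` and `smooth_tail` keys above parameterize MiSLAS's label-aware smoothing: instance-rich (head) classes receive a larger label-smoothing factor than tail classes. The actual schedule is implemented in `methods.py` (not shown in this listing), and the paper explores several interpolation shapes; the linear variant below is an illustrative sketch under that assumption:

```python
import numpy as np

def per_class_smoothing(cls_num_list, smooth_head, smooth_tail):
    """Linear interpolation of the smoothing factor by class frequency:
    the most frequent class gets smooth_head, the rarest gets smooth_tail."""
    n = np.asarray(cls_num_list, dtype=float)
    return smooth_tail + (smooth_head - smooth_tail) * (n - n.min()) / (n.max() - n.min())
```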
/config/cifar10/cifar10_imb002_stage1_mixup.yaml:
--------------------------------------------------------------------------------
1 | name: cifar10_imb002_stage1_mixup
2 | print_freq: 40
3 | workers: 16
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 | # dataset & model setting
8 | dataset: 'cifar10'
9 | data_path: './data/cifar10'
10 | num_classes: 10
11 | imb_factor: 0.02
12 | backbone: 'resnet32_fe'
13 | resume: ''
14 | head_class_idx:
15 | - 0
16 | - 3
17 | med_class_idx:
18 | - 3
19 | - 7
20 | tail_class_idx:
21 | - 7
22 | - 10
23 |
24 |
25 | # distributed training
26 | deterministic: False
27 | distributed: False
28 | gpu: null
29 | world_size: -1
30 | rank: -1
31 | dist_url: 'tcp://224.66.41.62:23456'
32 | dist_backend: 'nccl'
33 | multiprocessing_distributed: False
34 |
35 |
36 |
37 | # Train
38 | mode: 'stage1'
39 | lr: 0.1
40 | batch_size: 128
41 | weight_decay: 2e-4
42 | num_epochs: 200
43 | momentum: 0.9
44 | cos: False
45 | mixup: True
46 | alpha: 1.0
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/config/cifar10/cifar10_imb002_stage2_mislas.yaml:
--------------------------------------------------------------------------------
1 | name: cifar10_imb002_stage2_mislas
2 | print_freq: 40
3 | workers: 16
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 | # dataset & model setting
8 | dataset: 'cifar10'
9 | data_path: './data/cifar10'
10 | num_classes: 10
11 | imb_factor: 0.02
12 | backbone: 'resnet32_fe'
13 | resume: 'Path/to/Stage1_checkpoint.pth.tar'
14 | head_class_idx:
15 | - 0
16 | - 3
17 | med_class_idx:
18 | - 3
19 | - 7
20 | tail_class_idx:
21 | - 7
22 | - 10
23 |
24 |
25 | # distributed training
26 | deterministic: False
27 | distributed: False
28 | gpu: null
29 | world_size: -1
30 | rank: -1
31 | dist_url: 'tcp://224.66.41.62:23456'
32 | dist_backend: 'nccl'
33 | multiprocessing_distributed: False
34 |
35 |
36 |
37 | # Train
38 | mode: 'stage2'
39 | smooth_head: 0.2
40 | smooth_tail: 0.0
41 | shift_bn: False
42 | lr_factor: 0.2
43 | lr: 0.1
44 | batch_size: 128
45 | weight_decay: 2e-4
46 | num_epochs: 10
47 | momentum: 0.9
48 | mixup: False
49 | alpha: null
50 |
51 |
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/config/cifar10/cifar10_imb01_stage1_mixup.yaml:
--------------------------------------------------------------------------------
1 | name: cifar10_imb01_stage1_mixup
2 | print_freq: 40
3 | workers: 16
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 | # dataset & model setting
8 | dataset: 'cifar10'
9 | data_path: './data/cifar10'
10 | num_classes: 10
11 | imb_factor: 0.1
12 | backbone: 'resnet32_fe'
13 | resume: ''
14 | head_class_idx:
15 | - 0
16 | - 3
17 | med_class_idx:
18 | - 3
19 | - 7
20 | tail_class_idx:
21 | - 7
22 | - 10
23 |
24 |
25 | # distributed training
26 | deterministic: False
27 | distributed: False
28 | gpu: null
29 | world_size: -1
30 | rank: -1
31 | dist_url: 'tcp://224.66.41.62:23456'
32 | dist_backend: 'nccl'
33 | multiprocessing_distributed: False
34 |
35 |
36 |
37 | # Train
38 | mode: 'stage1'
39 | lr: 0.1
40 | batch_size: 128
41 | weight_decay: 2e-4
42 | num_epochs: 200
43 | momentum: 0.9
44 | cos: False
45 | mixup: True
46 | alpha: 1.0
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/config/cifar10/cifar10_imb01_stage2_mislas.yaml:
--------------------------------------------------------------------------------
1 | name: cifar10_imb01_stage2_mislas
2 | print_freq: 40
3 | workers: 16
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 | # dataset & model setting
8 | dataset: 'cifar10'
9 | data_path: './data/cifar10'
10 | num_classes: 10
11 | imb_factor: 0.1
12 | backbone: 'resnet32_fe'
13 | resume: 'Path/to/Stage1_checkpoint.pth.tar'
14 | head_class_idx:
15 | - 0
16 | - 3
17 | med_class_idx:
18 | - 3
19 | - 7
20 | tail_class_idx:
21 | - 7
22 | - 10
23 |
24 |
25 | # distributed training
26 | deterministic: False
27 | distributed: False
28 | gpu: null
29 | world_size: -1
30 | rank: -1
31 | dist_url: 'tcp://224.66.41.62:23456'
32 | dist_backend: 'nccl'
33 | multiprocessing_distributed: False
34 |
35 |
36 |
37 | # Train
38 | mode: 'stage2'
39 | smooth_head: 0.1
40 | smooth_tail: 0.0
41 | shift_bn: False
42 | lr_factor: 0.2
43 | lr: 0.1
44 | batch_size: 128
45 | weight_decay: 2e-4
46 | num_epochs: 10
47 | momentum: 0.9
48 | mixup: False
49 | alpha: null
50 |
51 |
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/config/cifar100/cifar100_imb001_stage1_mixup.yaml:
--------------------------------------------------------------------------------
1 | name: cifar100_imb001_stage1_mixup
2 | print_freq: 40
3 | workers: 4
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 |
8 | # dataset & model setting
9 | dataset: 'cifar100'
10 | data_path: './data/cifar100'
11 | num_classes: 100
12 | imb_factor: 0.01
13 | backbone: 'resnet32_fe'
14 | resume: ''
15 | head_class_idx:
16 | - 0
17 | - 36
18 | med_class_idx:
19 | - 36
20 | - 71
21 | tail_class_idx:
22 | - 71
23 | - 100
24 |
25 |
26 | # distributed training
27 | deterministic: True
28 | distributed: False
29 | gpu: null
30 | world_size: -1
31 | rank: -1
32 | dist_url: 'tcp://224.66.41.62:23456'
33 | dist_backend: 'nccl'
34 | multiprocessing_distributed: False
35 |
36 |
37 | # Train
38 | mode: 'stage1'
39 | lr: 0.1
40 | batch_size: 128
41 | weight_decay: 2e-4
42 | num_epochs: 200
43 | momentum: 0.9
44 | cos: False
45 | mixup: True
46 | alpha: 1.0
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/config/cifar100/cifar100_imb001_stage2_mislas.yaml:
--------------------------------------------------------------------------------
1 | name: cifar100_imb001_stage2_mislas
2 | print_freq: 40
3 | workers: 4
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 |
8 | # dataset & model setting
9 | dataset: 'cifar100'
10 | data_path: './data/cifar100'
11 | num_classes: 100
12 | imb_factor: 0.01
13 | backbone: 'resnet32_fe'
14 | resume: 'Path/to/Stage1_checkpoint.pth.tar'
15 | head_class_idx:
16 | - 0
17 | - 36
18 | med_class_idx:
19 | - 36
20 | - 71
21 | tail_class_idx:
22 | - 71
23 | - 100
24 |
25 |
26 | # distributed training
27 | deterministic: False
28 | distributed: False
29 | gpu: null
30 | world_size: -1
31 | rank: -1
32 | dist_url: 'tcp://224.66.41.62:23456'
33 | dist_backend: 'nccl'
34 | multiprocessing_distributed: False
35 |
36 |
37 | # Train
38 | mode: 'stage2'
39 | smooth_head: 0.4
40 | smooth_tail: 0.1
41 | shift_bn: True
42 | lr_factor: 0.2
43 | lr: 0.1
44 | batch_size: 128
45 | weight_decay: 2e-4
46 | num_epochs: 10
47 | momentum: 0.9
48 | mixup: False
49 | alpha: null
--------------------------------------------------------------------------------
/config/cifar100/cifar100_imb002_stage1_mixup.yaml:
--------------------------------------------------------------------------------
1 | name: cifar100_imb002_stage1_mixup
2 | print_freq: 40
3 | workers: 4
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 |
8 | # dataset & model setting
9 | dataset: 'cifar100'
10 | data_path: './data/cifar100'
11 | num_classes: 100
12 | imb_factor: 0.02
13 | backbone: 'resnet32_fe'
14 | resume: ''
15 | head_class_idx:
16 | - 0
17 | - 36
18 | med_class_idx:
19 | - 36
20 | - 71
21 | tail_class_idx:
22 | - 71
23 | - 100
24 |
25 |
26 | # distributed training
27 | deterministic: False
28 | distributed: False
29 | gpu: null
30 | world_size: -1
31 | rank: -1
32 | dist_url: 'tcp://224.66.41.62:23456'
33 | dist_backend: 'nccl'
34 | multiprocessing_distributed: False
35 |
36 |
37 | # Train
38 | mode: 'stage1'
39 | lr: 0.1
40 | batch_size: 128
41 | weight_decay: 2e-4
42 | num_epochs: 200
43 | momentum: 0.9
44 | cos: False
45 | mixup: True
46 | alpha: 1.0
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/config/cifar100/cifar100_imb002_stage2_mislas.yaml:
--------------------------------------------------------------------------------
1 | name: cifar100_imb002_stage2_mislas
2 | print_freq: 40
3 | workers: 4
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 |
8 | # dataset & model setting
9 | dataset: 'cifar100'
10 | data_path: './data/cifar100'
11 | num_classes: 100
12 | imb_factor: 0.02
13 | backbone: 'resnet32_fe'
14 | resume: 'Path/to/Stage1_checkpoint.pth.tar'
15 | head_class_idx:
16 | - 0
17 | - 36
18 | med_class_idx:
19 | - 36
20 | - 71
21 | tail_class_idx:
22 | - 71
23 | - 100
24 |
25 |
26 | # distributed training
27 | deterministic: False
28 | distributed: False
29 | gpu: null
30 | world_size: -1
31 | rank: -1
32 | dist_url: 'tcp://224.66.41.62:23456'
33 | dist_backend: 'nccl'
34 | multiprocessing_distributed: False
35 |
36 |
37 | # Train
38 | mode: 'stage2'
39 | smooth_head: 0.3
40 | smooth_tail: 0.0
41 | shift_bn: True
42 | lr_factor: 0.1
43 | lr: 0.1
44 | batch_size: 128
45 | weight_decay: 2e-4
46 | num_epochs: 10
47 | momentum: 0.9
48 | mixup: False
49 | alpha: null
--------------------------------------------------------------------------------
/config/cifar100/cifar100_imb01_stage1_mixup.yaml:
--------------------------------------------------------------------------------
1 | name: cifar100_imb01_stage1_mixup
2 | print_freq: 40
3 | workers: 4
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 |
8 | # dataset & model setting
9 | dataset: 'cifar100'
10 | data_path: './data/cifar100'
11 | num_classes: 100
12 | imb_factor: 0.1
13 | backbone: 'resnet32_fe'
14 | resume: ''
15 | head_class_idx:
16 | - 0
17 | - 36
18 | med_class_idx:
19 | - 36
20 | - 71
21 | tail_class_idx:
22 | - 71
23 | - 100
24 |
25 |
26 | # distributed training
27 | deterministic: False
28 | distributed: False
29 | gpu: null
30 | world_size: -1
31 | rank: -1
32 | dist_url: 'tcp://224.66.41.62:23456'
33 | dist_backend: 'nccl'
34 | multiprocessing_distributed: False
35 |
36 |
37 | # Train
38 | mode: 'stage1'
39 | lr: 0.1
40 | batch_size: 128
41 | weight_decay: 2e-4
42 | num_epochs: 200
43 | momentum: 0.9
44 | cos: False
45 | mixup: True
46 | alpha: 1.0
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/config/cifar100/cifar100_imb01_stage2_mislas.yaml:
--------------------------------------------------------------------------------
1 | name: cifar100_imb01_stage2_mislas
2 | print_freq: 40
3 | workers: 4
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 |
8 | # dataset & model setting
9 | dataset: 'cifar100'
10 | data_path: './data/cifar100'
11 | num_classes: 100
12 | imb_factor: 0.1
13 | backbone: 'resnet32_fe'
14 | resume: 'Path/to/Stage1_checkpoint.pth.tar'
15 | head_class_idx:
16 | - 0
17 | - 36
18 | med_class_idx:
19 | - 36
20 | - 71
21 | tail_class_idx:
22 | - 71
23 | - 100
24 |
25 |
26 | # distributed training
27 | deterministic: False
28 | distributed: False
29 | gpu: null
30 | world_size: -1
31 | rank: -1
32 | dist_url: 'tcp://224.66.41.62:23456'
33 | dist_backend: 'nccl'
34 | multiprocessing_distributed: False
35 |
36 |
37 | # Train
38 | mode: 'stage2'
39 | smooth_head: 0.2
40 | smooth_tail: 0.0
41 | shift_bn: True
42 | lr_factor: 0.1
43 | lr: 0.1
44 | batch_size: 128
45 | weight_decay: 2e-4
46 | num_epochs: 10
47 | momentum: 0.9
48 | mixup: False
49 | alpha: null
--------------------------------------------------------------------------------
/config/imagenet/imagenet_resnet50_stage1_mixup.yaml:
--------------------------------------------------------------------------------
1 | name: imagenet_resnet50_stage1_mixup
2 | print_freq: 100
3 | workers: 48
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 |
8 | # dataset & model setting
9 | dataset: 'imagenet'
10 | data_path: 'Path/to/Data/ImageNet/'
11 | num_classes: 1000
12 | imb_factor: null
13 | backbone: 'resnet50_fe'
14 | resume: ''
15 | head_class_idx:
16 | - 0
17 | - 390
18 | med_class_idx:
19 | - 390
20 | - 835
21 | tail_class_idx:
22 | - 835
23 | - 1000
24 |
25 |
26 | # distributed training
27 | deterministic: False
28 | distributed: False
29 | gpu: null
30 | world_size: -1
31 | rank: -1
32 | dist_url: 'tcp://224.66.41.62:23456'
33 | dist_backend: 'nccl'
34 | multiprocessing_distributed: False
35 |
36 |
37 | # Train
38 | mode: 'stage1'
39 | lr: 0.1
40 | batch_size: 256
41 | weight_decay: 5e-4
42 | num_epochs: 180
43 | momentum: 0.9
44 | cos: True
45 | mixup: True
46 | alpha: 0.2
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/config/imagenet/imagenet_resnet50_stage2_mislas.yaml:
--------------------------------------------------------------------------------
1 | name: imagenet_resnet50_stage2_mislas
2 | print_freq: 100
3 | workers: 48
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 |
8 | # dataset & model setting
9 | dataset: 'imagenet'
10 | data_path: 'Path/to/Data/ImageNet/'
11 | num_classes: 1000
12 | imb_factor: null
13 | backbone: 'resnet50_fe'
14 | resume: 'Path/to/Stage1_checkpoint.pth.tar'
15 | head_class_idx:
16 | - 0
17 | - 390
18 | med_class_idx:
19 | - 390
20 | - 835
21 | tail_class_idx:
22 | - 835
23 | - 1000
24 |
25 |
26 | # distributed training
27 | deterministic: False
28 | distributed: False
29 | gpu: null
30 | world_size: -1
31 | rank: -1
32 | dist_url: 'tcp://224.66.41.62:23456'
33 | dist_backend: 'nccl'
34 | multiprocessing_distributed: False
35 |
36 |
37 | # Train
38 | mode: 'stage2'
39 | smooth_head: 0.3
40 | smooth_tail: 0.0
41 | shift_bn: True
42 | lr_factor: 0.05
43 | lr: 0.1
44 | batch_size: 256
45 | weight_decay: 5e-4
46 | num_epochs: 10
47 | momentum: 0.9
48 | mixup: False
49 | alpha: null
50 |
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/config/ina2018/ina2018_resnet50_stage1_mixup.yaml:
--------------------------------------------------------------------------------
1 | name: ina2018_resnet50_stage1_mixup
2 | print_freq: 200
3 | workers: 48
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 |
8 | # dataset & model setting
9 | dataset: 'ina2018'
10 | data_path: 'Path/to/Data/iNaturalist2018'
11 | num_classes: 8142
12 | imb_factor: null
13 | backbone: 'resnet50_fe'
14 | resume: ''
15 | head_class_idx:
16 | - 0
17 | - 842
18 | med_class_idx:
19 | - 842
20 | - 4543
21 | tail_class_idx:
22 | - 4543
23 | - 8142
24 |
25 |
26 | # distributed training
27 | deterministic: False
28 | distributed: False
29 | gpu: null
30 | world_size: -1
31 | rank: -1
32 | dist_url: 'tcp://224.66.41.62:23456'
33 | dist_backend: 'nccl'
34 | multiprocessing_distributed: False
35 |
36 |
37 | # Train
38 | mode: 'stage1'
39 | lr: 0.1
40 | batch_size: 256
41 | weight_decay: 1e-4
42 | num_epochs: 200
43 | momentum: 0.9
44 | cos: True
45 | mixup: True
46 | alpha: 0.2
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/config/ina2018/ina2018_resnet50_stage2_mislas.yaml:
--------------------------------------------------------------------------------
1 | name: ina2018_resnet50_stage2_mislas
2 | print_freq: 200
3 | workers: 48
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 |
8 | # dataset & model setting
9 | dataset: 'ina2018'
10 | data_path: 'Path/to/Data/iNaturalist2018'
11 | num_classes: 8142
12 | imb_factor: null
13 | backbone: 'resnet50_fe'
14 | resume: 'Path/to/Stage1_checkpoint.pth.tar'
15 | head_class_idx:
16 | - 0
17 | - 842
18 | med_class_idx:
19 | - 842
20 | - 4543
21 | tail_class_idx:
22 | - 4543
23 | - 8142
24 |
25 |
26 | # distributed training
27 | deterministic: False
28 | distributed: False
29 | gpu: null
30 | world_size: -1
31 | rank: -1
32 | dist_url: 'tcp://224.66.41.62:23456'
33 | dist_backend: 'nccl'
34 | multiprocessing_distributed: False
35 |
36 |
37 | # Train
38 | mode: 'stage2'
39 | smooth_head: 0.3
40 | smooth_tail: 0.0
41 | shift_bn: True
42 | lr_factor: 0.05
43 | lr: 0.1
44 | batch_size: 256
45 | weight_decay: 1e-4
46 | num_epochs: 10
47 | momentum: 0.9
48 | mixup: False
49 | alpha: null
50 |
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/config/places/places_resnet152_stage1_mixup.yaml:
--------------------------------------------------------------------------------
1 | name: places_resnet152_stage1_mixup
2 | print_freq: 100
3 | workers: 48
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 |
8 | # dataset & model setting
9 | dataset: 'places'
10 | data_path: 'Path/to/Data/Places365/'
11 | num_classes: 365
12 | imb_factor: null
13 | backbone: 'resnet152_fe'
14 | resume: ''
15 | head_class_idx:
16 | - 0
17 | - 131
18 | med_class_idx:
19 | - 131
20 | - 288
21 | tail_class_idx:
22 | - 288
23 | - 365
24 |
25 |
26 | # distributed training
27 | deterministic: False
28 | distributed: False
29 | gpu: null
30 | world_size: -1
31 | rank: -1
32 | dist_url: 'tcp://224.66.41.62:23456'
33 | dist_backend: 'nccl'
34 | multiprocessing_distributed: False
35 |
36 |
37 | # Train
38 | mode: 'stage1'
39 | lr: 0.1
40 | batch_size: 256
41 | weight_decay: 5e-4
42 | num_epochs: 90
43 | momentum: 0.9
44 | cos: True
45 | mixup: True
46 | alpha: 0.2
47 |
48 |
49 |
50 |
51 |
--------------------------------------------------------------------------------
/config/places/places_resnet152_stage2_mislas.yaml:
--------------------------------------------------------------------------------
1 | name: places_resnet152_stage2_mislas
2 | print_freq: 100
3 | workers: 48
4 | log_dir: 'logs'
5 | model_dir: 'ckps'
6 |
7 |
8 | # dataset & model setting
9 | dataset: 'places'
10 | data_path: 'Path/to/Data/Places365/'
11 | num_classes: 365
12 | imb_factor: null
13 | backbone: 'resnet152_fe'
14 | resume: 'Path/to/Stage1_checkpoint.pth.tar'
15 | head_class_idx:
16 | - 0
17 | - 131
18 | med_class_idx:
19 | - 131
20 | - 288
21 | tail_class_idx:
22 | - 288
23 | - 365
24 |
25 |
26 | # distributed training
27 | deterministic: False
28 | distributed: False
29 | gpu: null
30 | world_size: -1
31 | rank: -1
32 | dist_url: 'tcp://224.66.41.62:23456'
33 | dist_backend: 'nccl'
34 | multiprocessing_distributed: False
35 |
36 |
37 | # Train
38 | mode: 'stage2'
39 | smooth_head: 0.4
40 | smooth_tail: 0.1
41 | shift_bn: True
42 | lr_factor: 0.05
43 | lr: 0.1
44 | batch_size: 256
45 | weight_decay: 5e-4
46 | num_epochs: 10
47 | momentum: 0.9
48 | mixup: False
49 | alpha: null
50 |
--------------------------------------------------------------------------------
/datasets/cifar10.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .sampler import ClassAwareSampler
3 |
4 | import torch
5 | import torchvision
6 | from torchvision import transforms
7 | import torchvision.datasets
8 |
9 | class IMBALANCECIFAR10(torchvision.datasets.CIFAR10):
10 | cls_num = 10
11 |
12 | def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True,
13 | transform=None, target_transform=None,
14 | download=False):
15 | super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download)
16 | np.random.seed(rand_number)
17 | img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor)
18 | self.gen_imbalanced_data(img_num_list)
19 |
20 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
21 | img_max = len(self.data) / cls_num
22 | img_num_per_cls = []
23 | if imb_type == 'exp':
24 | for cls_idx in range(cls_num):
25 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0)))
26 | img_num_per_cls.append(int(num))
27 | elif imb_type == 'step':
28 | for cls_idx in range(cls_num // 2):
29 | img_num_per_cls.append(int(img_max))
30 | for cls_idx in range(cls_num // 2):
31 | img_num_per_cls.append(int(img_max * imb_factor))
32 | else:
33 | img_num_per_cls.extend([int(img_max)] * cls_num)
34 | return img_num_per_cls
35 |
36 | def gen_imbalanced_data(self, img_num_per_cls):
37 | new_data = []
38 | new_targets = []
39 | targets_np = np.array(self.targets, dtype=np.int64)
40 | classes = np.unique(targets_np)
41 | # np.random.shuffle(classes)
42 | self.num_per_cls_dict = dict()
43 | for the_class, the_img_num in zip(classes, img_num_per_cls):
44 | self.num_per_cls_dict[the_class] = the_img_num
45 | idx = np.where(targets_np == the_class)[0]
46 | np.random.shuffle(idx)
47 | selec_idx = idx[:the_img_num]
48 | new_data.append(self.data[selec_idx, ...])
49 | new_targets.extend([the_class, ] * the_img_num)
50 | new_data = np.vstack(new_data)
51 | self.data = new_data
52 | self.targets = new_targets
53 |
54 | def get_cls_num_list(self):
55 | cls_num_list = []
56 | for i in range(self.cls_num):
57 | cls_num_list.append(self.num_per_cls_dict[i])
58 | return cls_num_list
59 |
60 |
61 |
62 | class CIFAR10_LT(object):
63 |
64 | def __init__(self, distributed, root='./data/cifar10', imb_type='exp',
65 | imb_factor=0.01, batch_size=128, num_works=40):
66 |
67 | train_transform = transforms.Compose([
68 | transforms.RandomCrop(32, padding=4),
69 | transforms.RandomHorizontalFlip(),
70 | transforms.ToTensor(),
71 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
72 | ])
73 |
74 |
75 | eval_transform = transforms.Compose([
76 | transforms.ToTensor(),
77 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
78 | ])
79 |
80 |
81 | train_dataset = IMBALANCECIFAR10(root=root, imb_type=imb_type, imb_factor=imb_factor, rand_number=0, train=True, download=True, transform=train_transform)
82 | eval_dataset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=eval_transform)
83 |
84 | self.cls_num_list = train_dataset.get_cls_num_list()
85 |
86 | self.dist_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else None
87 | self.train_instance = torch.utils.data.DataLoader(
88 | train_dataset,
89 | batch_size=batch_size, shuffle=(self.dist_sampler is None),  # shuffle is mutually exclusive with a sampler
90 | num_workers=num_works, pin_memory=True, sampler=self.dist_sampler)
91 |
92 | balance_sampler = ClassAwareSampler(train_dataset)
93 | self.train_balance = torch.utils.data.DataLoader(
94 | train_dataset,
95 | batch_size=batch_size, shuffle=False,
96 | num_workers=num_works, pin_memory=True, sampler=balance_sampler)
97 |
98 | self.eval = torch.utils.data.DataLoader(
99 | eval_dataset,
100 | batch_size=batch_size, shuffle=False,
101 | num_workers=num_works, pin_memory=True)
102 |
--------------------------------------------------------------------------------
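In `get_img_num_per_cls` above, the `'exp'` mode draws per-class counts from a geometric profile, `n_i = img_max * imb_factor ** (i / (C - 1))`. A standalone check of the schedule this produces for CIFAR-10-LT with `imb_factor=0.01` (IF=100):

```python
# Reproduces the 'exp' schedule from IMBALANCECIFAR10.get_img_num_per_cls.
img_max, cls_num, imb_factor = 5000, 10, 0.01
counts = [int(img_max * imb_factor ** (i / (cls_num - 1.0))) for i in range(cls_num)]
print(counts)  # [5000, 2997, 1796, 1077, 645, 387, 232, 139, 83, 50]
```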
/datasets/cifar100.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .sampler import ClassAwareSampler
3 |
4 | import torch
5 | import torchvision
6 | from torchvision import transforms
7 | import torchvision.datasets
8 |
9 | class IMBALANCECIFAR100(torchvision.datasets.CIFAR100):
10 | cls_num = 100
11 |
12 | def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True,
13 | transform=None, target_transform=None,
14 | download=False):
15 | super(IMBALANCECIFAR100, self).__init__(root, train, transform, target_transform, download)
16 | np.random.seed(rand_number)
17 | img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor)
18 | self.gen_imbalanced_data(img_num_list)
19 |
20 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
21 | img_max = len(self.data) / cls_num
22 | img_num_per_cls = []
23 | if imb_type == 'exp':
24 | for cls_idx in range(cls_num):
25 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0)))
26 | img_num_per_cls.append(int(num))
27 | elif imb_type == 'step':
28 | for cls_idx in range(cls_num // 2):
29 | img_num_per_cls.append(int(img_max))
30 | for cls_idx in range(cls_num // 2):
31 | img_num_per_cls.append(int(img_max * imb_factor))
32 | else:
33 | img_num_per_cls.extend([int(img_max)] * cls_num)
34 | return img_num_per_cls
35 |
36 | def gen_imbalanced_data(self, img_num_per_cls):
37 | new_data = []
38 | new_targets = []
39 | targets_np = np.array(self.targets, dtype=np.int64)
40 | classes = np.unique(targets_np)
41 | # np.random.shuffle(classes)
42 | self.num_per_cls_dict = dict()
43 | for the_class, the_img_num in zip(classes, img_num_per_cls):
44 | self.num_per_cls_dict[the_class] = the_img_num
45 | idx = np.where(targets_np == the_class)[0]
46 | np.random.shuffle(idx)
47 | selec_idx = idx[:the_img_num]
48 | new_data.append(self.data[selec_idx, ...])
49 | new_targets.extend([the_class, ] * the_img_num)
50 | new_data = np.vstack(new_data)
51 | self.data = new_data
52 | self.targets = new_targets
53 |
54 | def get_cls_num_list(self):
55 | cls_num_list = []
56 | for i in range(self.cls_num):
57 | cls_num_list.append(self.num_per_cls_dict[i])
58 | return cls_num_list
59 |
60 |
61 |
62 |
63 |
64 |
65 | class CIFAR100_LT(object):
66 | def __init__(self, distributed, root='./data/cifar100', imb_type='exp',
67 | imb_factor=0.01, batch_size=128, num_works=40):
68 |
69 | train_transform = transforms.Compose([
70 | transforms.RandomCrop(32, padding=4),
71 | transforms.RandomHorizontalFlip(),
72 | transforms.ToTensor(),
73 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
74 | ])
75 |
76 |
77 |
78 | eval_transform = transforms.Compose([
79 | transforms.ToTensor(),
80 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
81 | ])
82 |
83 |
84 | train_dataset = IMBALANCECIFAR100(root=root, imb_type=imb_type, imb_factor=imb_factor, rand_number=0, train=True, download=True, transform=train_transform)
85 | eval_dataset = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=eval_transform)
86 |
87 | self.cls_num_list = train_dataset.get_cls_num_list()
88 |
89 | self.dist_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else None
90 | self.train_instance = torch.utils.data.DataLoader(
91 | train_dataset,
92 | batch_size=batch_size, shuffle=(self.dist_sampler is None),  # shuffle is mutually exclusive with a sampler
93 | num_workers=num_works, pin_memory=True, sampler=self.dist_sampler)
94 |
95 | balance_sampler = ClassAwareSampler(train_dataset)
96 | self.train_balance = torch.utils.data.DataLoader(
97 | train_dataset,
98 | batch_size=batch_size, shuffle=False,
99 | num_workers=num_works, pin_memory=True, sampler=balance_sampler)
100 |
101 | self.eval = torch.utils.data.DataLoader(
102 | eval_dataset,
103 | batch_size=batch_size, shuffle=False,
104 | num_workers=num_works, pin_memory=True)
--------------------------------------------------------------------------------
/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 |
5 | import torch
6 | import torchvision
7 | import torchvision.datasets
8 | from torchvision import transforms
9 | from torch.utils.data import Dataset
10 |
11 | from .sampler import ClassAwareSampler
12 |
13 |
14 | class LT_Dataset(Dataset):
15 | num_classes = 1000
16 |
17 | def __init__(self, root, txt, transform=None):
18 | self.img_path = []
19 | self.targets = []
20 | self.transform = transform
21 | with open(txt) as f:
22 | for line in f:
23 | self.img_path.append(os.path.join(root, line.split()[0]))
24 | self.targets.append(int(line.split()[1]))
25 |
26 | cls_num_list_old = [np.sum(np.array(self.targets) == i) for i in range(self.num_classes)]
27 |
28 | # generate class_map: class indices sorted by sample count (descending)
29 | sorted_classes = np.argsort(-np.array(cls_num_list_old))
30 | self.class_map = [0 for i in range(self.num_classes)]
31 | for i in range(self.num_classes):
32 | self.class_map[sorted_classes[i]] = i
33 |
34 | self.targets = np.array(self.class_map)[self.targets].tolist()
35 |
36 | self.class_data = [[] for i in range(self.num_classes)]
37 | for i in range(len(self.targets)):
38 | j = self.targets[i]
39 | self.class_data[j].append(i)
40 |
41 | self.cls_num_list = [np.sum(np.array(self.targets)==i) for i in range(self.num_classes)]
42 |
43 |
44 | def __len__(self):
45 | return len(self.targets)
46 |
47 | def __getitem__(self, index):
48 | path = self.img_path[index]
49 | target = self.targets[index]
50 |
51 | with open(path, 'rb') as f:
52 | sample = Image.open(f).convert('RGB')
53 | if self.transform is not None:
54 | sample = self.transform(sample)
55 | return sample, target
56 |
57 |
58 |
59 | class LT_Dataset_Eval(Dataset):
60 | num_classes = 1000
61 |
62 | def __init__(self, root, txt, class_map, transform=None):
63 | self.img_path = []
64 | self.targets = []
65 | self.transform = transform
66 | self.class_map = class_map
67 | with open(txt) as f:
68 | for line in f:
69 | self.img_path.append(os.path.join(root, line.split()[0]))
70 | self.targets.append(int(line.split()[1]))
71 |
72 | self.targets = np.array(self.class_map)[self.targets].tolist()
73 |
74 | def __len__(self):
75 | return len(self.targets)
76 |
77 | def __getitem__(self, index):
78 | path = self.img_path[index]
79 | target = self.targets[index]
80 |
81 | with open(path, 'rb') as f:
82 | sample = Image.open(f).convert('RGB')
83 | if self.transform is not None:
84 | sample = self.transform(sample)
85 | return sample, target
86 |
87 |
88 | class ImageNet_LT(object):
89 | def __init__(self, distributed, root="", batch_size=60, num_works=40):
90 |
91 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
92 |
93 | transform_train = transforms.Compose([
94 | transforms.RandomResizedCrop(224),
95 | transforms.RandomHorizontalFlip(),
96 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0),
97 | transforms.ToTensor(),
98 | normalize,
99 | ])
100 |
101 |
102 | transform_test = transforms.Compose([
103 | transforms.Resize(256),
104 | transforms.CenterCrop(224),
105 | transforms.ToTensor(),
106 | normalize,
107 | ])
108 |
109 |
110 | train_txt = "./datasets/data_txt/ImageNet_LT_train.txt"
111 | eval_txt = "./datasets/data_txt/ImageNet_LT_test.txt"
112 |
113 | train_dataset = LT_Dataset(root, train_txt, transform=transform_train)
114 | eval_dataset = LT_Dataset_Eval(root, eval_txt, transform=transform_test, class_map=train_dataset.class_map)
115 |
116 | self.cls_num_list = train_dataset.cls_num_list
117 |
118 | self.dist_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else None
119 | self.train_instance = torch.utils.data.DataLoader(
120 | train_dataset,
121 | batch_size=batch_size, shuffle=(self.dist_sampler is None),  # shuffle is mutually exclusive with a sampler
122 | num_workers=num_works, pin_memory=True, sampler=self.dist_sampler)
123 |
124 | balance_sampler = ClassAwareSampler(train_dataset)
125 | self.train_balance = torch.utils.data.DataLoader(
126 | train_dataset,
127 | batch_size=batch_size, shuffle=False,
128 | num_workers=num_works, pin_memory=True, sampler=balance_sampler)
129 |
130 | self.eval = torch.utils.data.DataLoader(
131 | eval_dataset,
132 | batch_size=batch_size, shuffle=False,
133 | num_workers=num_works, pin_memory=True)
--------------------------------------------------------------------------------
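A detail worth noting in `LT_Dataset` above: labels are remapped so that class 0 is always the most frequent training class (which is what makes the `head_class_idx`/`med_class_idx`/`tail_class_idx` ranges in the configs meaningful). A tiny sketch of that `class_map` construction with toy counts:

```python
import numpy as np

cls_num_list_old = [5, 50, 20]                               # toy per-class training counts
sorted_classes = np.argsort(-np.array(cls_num_list_old))     # [1, 2, 0]: classes by descending count
class_map = np.empty(len(cls_num_list_old), dtype=int)
class_map[sorted_classes] = np.arange(len(cls_num_list_old))
print(class_map.tolist())                                    # [2, 0, 1]: original label -> frequency rank
```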
/datasets/ina2018.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 |
5 | import torch
6 | import torchvision
7 | import torchvision.datasets
8 | from torchvision import transforms
9 | from torch.utils.data import Dataset
10 |
11 | from .sampler import ClassAwareSampler
12 |
13 |
14 | class LT_Dataset(Dataset):
15 | num_classes = 8142
16 |
17 | def __init__(self, root, txt, transform=None):
18 | self.img_path = []
19 | self.targets = []
20 | self.transform = transform
21 | with open(txt) as f:
22 | for line in f:
23 | self.img_path.append(os.path.join(root, line.split()[0]))
24 | self.targets.append(int(line.split()[1]))
25 |
26 | cls_num_list_old = [np.sum(np.array(self.targets) == i) for i in range(self.num_classes)]
27 |
28 | # generate class_map: class indices sorted by sample count (descending)
29 | sorted_classes = np.argsort(-np.array(cls_num_list_old))
30 | self.class_map = [0 for i in range(self.num_classes)]
31 | for i in range(self.num_classes):
32 | self.class_map[sorted_classes[i]] = i
33 |
34 | self.targets = np.array(self.class_map)[self.targets].tolist()
35 |
36 | self.class_data = [[] for i in range(self.num_classes)]
37 | for i in range(len(self.targets)):
38 | j = self.targets[i]
39 | self.class_data[j].append(i)
40 |
41 | self.cls_num_list = [np.sum(np.array(self.targets)==i) for i in range(self.num_classes)]
42 |
43 |
44 | def __len__(self):
45 | return len(self.targets)
46 |
47 | def __getitem__(self, index):
48 | path = self.img_path[index]
49 | target = self.targets[index]
50 |
51 | with open(path, 'rb') as f:
52 | sample = Image.open(f).convert('RGB')
53 | if self.transform is not None:
54 | sample = self.transform(sample)
55 | return sample, target
56 |
57 |
58 |
59 | class LT_Dataset_Eval(Dataset):
60 | num_classes = 8142
61 |
62 | def __init__(self, root, txt, class_map, transform=None):
63 | self.img_path = []
64 | self.targets = []
65 | self.transform = transform
66 | self.class_map = class_map
67 | with open(txt) as f:
68 | for line in f:
69 | self.img_path.append(os.path.join(root, line.split()[0]))
70 | self.targets.append(int(line.split()[1]))
71 |
72 | self.targets = np.array(self.class_map)[self.targets].tolist()
73 |
74 | def __len__(self):
75 | return len(self.targets)
76 |
77 | def __getitem__(self, index):
78 | path = self.img_path[index]
79 | target = self.targets[index]
80 |
81 | with open(path, 'rb') as f:
82 | sample = Image.open(f).convert('RGB')
83 | if self.transform is not None:
84 | sample = self.transform(sample)
85 | return sample, target
86 |
87 |
88 | class iNa2018(object):
89 | def __init__(self, distributed, root="", batch_size=60, num_works=40):
90 |
91 | normalize = transforms.Normalize(mean=[0.466, 0.471, 0.380], std=[0.195, 0.194, 0.192])
92 |
93 | transform_train = transforms.Compose([
94 | transforms.RandomResizedCrop(224),
95 | transforms.RandomHorizontalFlip(),
96 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0),
97 | transforms.ToTensor(),
98 | normalize,
99 | ])
100 |
101 |
102 | transform_test = transforms.Compose([
103 | transforms.Resize(256),
104 | transforms.CenterCrop(224),
105 | transforms.ToTensor(),
106 | normalize,
107 | ])
108 |
109 | train_txt = "./datasets/data_txt/iNaturalist18_train.txt"
110 | eval_txt = "./datasets/data_txt/iNaturalist18_val.txt"
111 |
112 | train_dataset = LT_Dataset(root, train_txt, transform=transform_train)
113 | eval_dataset = LT_Dataset_Eval(root, eval_txt, transform=transform_test, class_map=train_dataset.class_map)
114 |
115 | self.cls_num_list = train_dataset.cls_num_list
116 |
117 | self.dist_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else None
118 | self.train_instance = torch.utils.data.DataLoader(
119 | train_dataset,
120 | batch_size=batch_size, shuffle=(self.dist_sampler is None),  # shuffle is mutually exclusive with a sampler
121 | num_workers=num_works, pin_memory=True, sampler=self.dist_sampler)
122 |
123 | balance_sampler = ClassAwareSampler(train_dataset)
124 | self.train_balance = torch.utils.data.DataLoader(
125 | train_dataset,
126 | batch_size=batch_size, shuffle=False,
127 | num_workers=num_works, pin_memory=True, sampler=balance_sampler)
128 |
129 | self.eval = torch.utils.data.DataLoader(
130 | eval_dataset,
131 | batch_size=batch_size, shuffle=False,
132 | num_workers=num_works, pin_memory=True)
--------------------------------------------------------------------------------
/datasets/places.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 |
5 | import torch
6 | import torchvision
7 | import torchvision.datasets
8 | from torchvision import transforms
9 | from torch.utils.data import Dataset
10 |
11 | from .sampler import ClassAwareSampler
12 |
13 |
14 | class LT_Dataset(Dataset):
15 | num_classes = 365
16 |
17 | def __init__(self, root, txt, transform=None):
18 | self.img_path = []
19 | self.targets = []
20 | self.transform = transform
21 | with open(txt) as f:
22 | for line in f:
23 | self.img_path.append(os.path.join(root, line.split()[0]))
24 | self.targets.append(int(line.split()[1]))
25 |
26 | cls_num_list_old = [np.sum(np.array(self.targets) == i) for i in range(self.num_classes)]
27 |
28 | # generate class_map: class indices sorted by sample count (descending)
29 | sorted_classes = np.argsort(-np.array(cls_num_list_old))
30 | self.class_map = [0 for i in range(self.num_classes)]
31 | for i in range(self.num_classes):
32 | self.class_map[sorted_classes[i]] = i
33 |
34 | self.targets = np.array(self.class_map)[self.targets].tolist()
35 |
36 | self.class_data = [[] for i in range(self.num_classes)]
37 | for i in range(len(self.targets)):
38 | j = self.targets[i]
39 | self.class_data[j].append(i)
40 |
41 | self.cls_num_list = [np.sum(np.array(self.targets)==i) for i in range(self.num_classes)]
42 |
43 |
44 | def __len__(self):
45 | return len(self.targets)
46 |
47 | def __getitem__(self, index):
48 | path = self.img_path[index]
49 | target = self.targets[index]
50 |
51 | with open(path, 'rb') as f:
52 | sample = Image.open(f).convert('RGB')
53 | if self.transform is not None:
54 | sample = self.transform(sample)
55 | return sample, target
56 |
57 |
58 |
59 | class LT_Dataset_Eval(Dataset):
60 | num_classes = 365
61 |
62 | def __init__(self, root, txt, class_map, transform=None):
63 | self.img_path = []
64 | self.targets = []
65 | self.transform = transform
66 | self.class_map = class_map
67 | with open(txt) as f:
68 | for line in f:
69 | self.img_path.append(os.path.join(root, line.split()[0]))
70 | self.targets.append(int(line.split()[1]))
71 |
72 | self.targets = np.array(self.class_map)[self.targets].tolist()
73 |
74 | def __len__(self):
75 | return len(self.targets)
76 |
77 | def __getitem__(self, index):
78 | path = self.img_path[index]
79 | target = self.targets[index]
80 |
81 | with open(path, 'rb') as f:
82 | sample = Image.open(f).convert('RGB')
83 | if self.transform is not None:
84 | sample = self.transform(sample)
85 | return sample, target
86 |
87 |
88 | class Places_LT(object):
89 | def __init__(self, distributed, root="", batch_size=60, num_works=40):
90 |
91 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
92 |
93 | transform_train = transforms.Compose([
94 | transforms.RandomResizedCrop(224),
95 | transforms.RandomHorizontalFlip(),
96 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0),
97 | transforms.ToTensor(),
98 | normalize,
99 | ])
100 |
101 |
102 | transform_test = transforms.Compose([
103 | transforms.Resize(256),
104 | transforms.CenterCrop(224),
105 | transforms.ToTensor(),
106 | normalize,
107 | ])
108 |
109 | train_txt = "./datasets/data_txt/Places_LT_train.txt"
110 | eval_txt = "./datasets/data_txt/Places_LT_test.txt"
111 |
112 |
113 | train_dataset = LT_Dataset(root, train_txt, transform=transform_train)
114 | eval_dataset = LT_Dataset_Eval(root, eval_txt, transform=transform_test, class_map=train_dataset.class_map)
115 |
116 | self.cls_num_list = train_dataset.cls_num_list
117 |
118 | self.dist_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else None
119 | self.train_instance = torch.utils.data.DataLoader(
120 | train_dataset,
121 | batch_size=batch_size, shuffle=(self.dist_sampler is None),  # shuffle is mutually exclusive with a sampler
122 | num_workers=num_works, pin_memory=True, sampler=self.dist_sampler)
123 |
124 | balance_sampler = ClassAwareSampler(train_dataset)
125 | self.train_balance = torch.utils.data.DataLoader(
126 | train_dataset,
127 | batch_size=batch_size, shuffle=False,
128 | num_workers=num_works, pin_memory=True, sampler=balance_sampler)
129 |
130 | self.eval = torch.utils.data.DataLoader(
131 | eval_dataset,
132 | batch_size=batch_size, shuffle=False,
133 | num_workers=num_works, pin_memory=True)
134 |
--------------------------------------------------------------------------------
/datasets/sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 |
4 | import torch
5 |
6 | class BalancedDatasetSampler(torch.utils.data.sampler.Sampler):
7 |
8 | def __init__(self, dataset, indices=None, num_samples=None):
9 |
10 | # if indices is not provided,
11 | # all elements in the dataset will be considered
12 | self.indices = list(range(len(dataset))) \
13 | if indices is None else indices
14 |
15 | # if num_samples is not provided,
16 | # draw `len(indices)` samples in each iteration
17 | self.num_samples = len(self.indices) \
18 | if num_samples is None else num_samples
19 |
20 | # distribution of classes in the dataset
21 | label_to_count = [0] * len(np.unique(dataset.targets))
22 | for idx in self.indices:
23 | label = self._get_label(dataset, idx)
24 | label_to_count[label] += 1
25 |
26 |
27 |
28 |
29 | per_cls_weights = 1 / np.array(label_to_count)
30 |
31 | # weight for each sample
32 | weights = [per_cls_weights[self._get_label(dataset, idx)]
33 | for idx in self.indices]
34 |
35 |
36 | self.weights = torch.DoubleTensor(weights)
37 |
38 | def _get_label(self, dataset, idx):
39 | return dataset.targets[idx]
40 |
41 | def __iter__(self):
42 | return iter(torch.multinomial(self.weights, self.num_samples, replacement=True).tolist())
43 |
44 | def __len__(self):
45 | return self.num_samples
46 |
47 | class EffectNumSampler(torch.utils.data.sampler.Sampler):
48 |
49 | def __init__(self, dataset, indices=None, num_samples=None):
50 |
51 | # if indices is not provided,
52 | # all elements in the dataset will be considered
53 | self.indices = list(range(len(dataset))) \
54 | if indices is None else indices
55 |
56 | # if num_samples is not provided,
57 | # draw `len(indices)` samples in each iteration
58 | self.num_samples = len(self.indices) \
59 | if num_samples is None else num_samples
60 |
61 | # distribution of classes in the dataset
62 | label_to_count = [0] * len(np.unique(dataset.targets))
63 | for idx in self.indices:
64 | label = self._get_label(dataset, idx)
65 | label_to_count[label] += 1
66 |
67 |
68 |
69 | beta = 0.9999
70 | effective_num = 1.0 - np.power(beta, label_to_count)
71 | per_cls_weights = (1.0 - beta) / np.array(effective_num)
72 |
73 | # weight for each sample
74 | weights = [per_cls_weights[self._get_label(dataset, idx)]
75 | for idx in self.indices]
76 |
77 |
78 | self.weights = torch.DoubleTensor(weights)
79 |
80 | def _get_label(self, dataset, idx):
81 | return dataset.targets[idx]
82 |
83 | def __iter__(self):
84 | return iter(torch.multinomial(self.weights, self.num_samples, replacement=True).tolist())
85 |
86 | def __len__(self):
87 | return self.num_samples
88 |
89 | class RandomCycleIter:
90 |
91 | def __init__ (self, data, test_mode=False):
92 | self.data_list = list(data)
93 | self.length = len(self.data_list)
94 | self.i = self.length - 1
95 | self.test_mode = test_mode
96 |
97 | def __iter__ (self):
98 | return self
99 |
100 | def __next__ (self):
101 | self.i += 1
102 |
103 | if self.i == self.length:
104 | self.i = 0
105 | if not self.test_mode:
106 | random.shuffle(self.data_list)
107 |
108 | return self.data_list[self.i]
109 |
110 | def class_aware_sample_generator(cls_iter, data_iter_list, n, num_samples_cls=1):
111 |
112 | i = 0
113 | j = 0
114 | while i < n:
115 |
116 | # yield next(data_iter_list[next(cls_iter)])
117 |
118 | if j >= num_samples_cls:
119 | j = 0
120 |
121 | if j == 0:
122 | temp_tuple = next(zip(*[data_iter_list[next(cls_iter)]]*num_samples_cls))
123 | yield temp_tuple[j]
124 | else:
125 | yield temp_tuple[j]
126 |
127 | i += 1
128 | j += 1
129 |
130 | class ClassAwareSampler(torch.utils.data.sampler.Sampler):
131 | def __init__(self, data_source, num_samples_cls=4,):
132 | # pdb.set_trace()
133 | num_classes = len(np.unique(data_source.targets))
134 | self.class_iter = RandomCycleIter(range(num_classes))
135 | cls_data_list = [list() for _ in range(num_classes)]
136 | for i, label in enumerate(data_source.targets):
137 | cls_data_list[label].append(i)
138 | self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list]
139 | self.num_samples = max([len(x) for x in cls_data_list]) * len(cls_data_list)
140 | self.num_samples_cls = num_samples_cls
141 |
142 | def __iter__ (self):
143 | return class_aware_sample_generator(self.class_iter, self.data_iter_list,
144 | self.num_samples, self.num_samples_cls)
145 |
146 | def __len__ (self):
147 | return self.num_samples
148 |
149 | def get_sampler():
150 | return ClassAwareSampler
--------------------------------------------------------------------------------
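
The samplers above drop into a standard `torch.utils.data.DataLoader`. The following is a minimal usage sketch (not part of the repository), assuming a torchvision-style dataset that exposes a `targets` list; with `beta = 0.9999`, the effective-number weight `(1 - beta) / (1 - beta**n_c)` behaves like inverse class frequency for classes much smaller than `1 / (1 - beta) = 10000` samples, so rare classes are over-sampled:

```python
import torch
import torchvision
import torchvision.transforms as T

from datasets.sampler import EffectNumSampler

# any dataset with a `targets` attribute works; CIFAR-10 is used here only
# because it is small and ships with torchvision
train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=T.ToTensor())

# the sampler draws indices with replacement, weighted per class, so each
# epoch sees a re-balanced sample of len(train_set) images
sampler = EffectNumSampler(train_set)
loader = torch.utils.data.DataLoader(
    train_set, batch_size=128, sampler=sampler, num_workers=4)

images, targets = next(iter(loader))
print(images.shape, torch.bincount(targets, minlength=10))
```

`ClassAwareSampler` is used the same way but cycles over classes explicitly, yielding `num_samples_cls` consecutive indices per class; `get_sampler()` simply returns that class for callers that look samplers up by name.
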
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import time
5 | import warnings
6 | import numpy as np
7 | import pprint
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.parallel
12 | import torch.distributed as dist
13 | import torch.optim
14 | import torch.multiprocessing as mp
15 | import torch.utils.data
16 | import torch.utils.data.distributed
17 | import torch.nn.functional as F
18 |
19 | from datasets.cifar10 import CIFAR10_LT
20 | from datasets.cifar100 import CIFAR100_LT
21 | from datasets.places import Places_LT
22 | from datasets.imagenet import ImageNet_LT
23 | from datasets.ina2018 import iNa2018
24 |
25 | from models import resnet
26 | from models import resnet_places
27 | from models import resnet_cifar
28 |
29 | from utils import config, update_config, create_logger
30 | from utils import AverageMeter, ProgressMeter
31 | from utils import accuracy, calibration
32 |
33 | from methods import LearnableWeightScaling
34 |
35 |
36 | def parse_args():
37 | parser = argparse.ArgumentParser(description='MiSLAS evaluation')
38 | parser.add_argument('--cfg',
39 | help='experiment configure file name',
40 | required=True,
41 | type=str)
42 | parser.add_argument('opts',
43 | help="Modify config options using the command-line",
44 | default=None,
45 | nargs=argparse.REMAINDER)
46 |
47 | args = parser.parse_args()
48 | update_config(config, args)
49 |
50 | return args
51 |
52 |
53 | best_acc1 = 0
54 |
55 |
56 | def main():
57 | args = parse_args()
58 | logger, model_dir = create_logger(config, args.cfg)
59 | logger.info('\n' + pprint.pformat(args))
60 | logger.info('\n' + str(config))
61 |
62 | if config.deterministic:
63 | seed = 0
64 | torch.backends.cudnn.deterministic = True
65 | torch.backends.cudnn.benchmark = False
66 | random.seed(seed)
67 | np.random.seed(seed)
68 | os.environ['PYTHONHASHSEED'] = str(seed)
69 | torch.manual_seed(seed)
70 | torch.cuda.manual_seed(seed)
71 | torch.cuda.manual_seed_all(seed)
72 |
73 | if config.gpu is not None:
74 | warnings.warn('You have chosen a specific GPU. This will completely '
75 | 'disable data parallelism.')
76 |
77 | if config.dist_url == "env://" and config.world_size == -1:
78 | config.world_size = int(os.environ["WORLD_SIZE"])
79 |
80 | config.distributed = config.world_size > 1 or config.multiprocessing_distributed
81 |
82 | ngpus_per_node = torch.cuda.device_count()
83 | if config.multiprocessing_distributed:
84 | # Since we have ngpus_per_node processes per node, the total world_size
85 | # needs to be adjusted accordingly
86 | config.world_size = ngpus_per_node * config.world_size
87 | # Use torch.multiprocessing.spawn to launch distributed processes: the
88 | # main_worker process function
89 |         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config, logger, model_dir))
90 | else:
91 | # Simply call main_worker function
92 | main_worker(config.gpu, ngpus_per_node, config, logger, model_dir)
93 |
94 |
95 | def main_worker(gpu, ngpus_per_node, config, logger, model_dir):
96 | global best_acc1
97 | config.gpu = gpu
98 | # start_time = time.strftime("%Y%m%d_%H%M%S", time.localtime())
99 |
100 | if config.gpu is not None:
101 | logger.info("Use GPU: {} for training".format(config.gpu))
102 |
103 | if config.distributed:
104 | if config.dist_url == "env://" and config.rank == -1:
105 | config.rank = int(os.environ["RANK"])
106 | if config.multiprocessing_distributed:
107 | # For multiprocessing distributed training, rank needs to be the
108 | # global rank among all the processes
109 | config.rank = config.rank * ngpus_per_node + gpu
110 | dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url,
111 | world_size=config.world_size, rank=config.rank)
112 |
113 | if config.dataset == 'cifar10' or config.dataset == 'cifar100':
114 | model = getattr(resnet_cifar, config.backbone)()
115 | classifier = getattr(resnet_cifar, 'Classifier')(feat_in=64, num_classes=config.num_classes)
116 |
117 | elif config.dataset == 'imagenet' or config.dataset == 'ina2018':
118 | model = getattr(resnet, config.backbone)()
119 | classifier = getattr(resnet, 'Classifier')(feat_in=2048, num_classes=config.num_classes)
120 |
121 | elif config.dataset == 'places':
122 | model = getattr(resnet_places, config.backbone)(pretrained=True)
123 | classifier = getattr(resnet_places, 'Classifier')(feat_in=2048, num_classes=config.num_classes)
124 | block = getattr(resnet_places, 'Bottleneck')(2048, 512, groups=1,
125 | base_width=64, dilation=1,
126 | norm_layer=nn.BatchNorm2d)
127 |
128 | lws_model = LearnableWeightScaling(num_classes=config.num_classes)
129 |
130 | if not torch.cuda.is_available():
131 | logger.info('using CPU, this will be slow')
132 | elif config.distributed:
133 | # For multiprocessing distributed, DistributedDataParallel constructor
134 | # should always set the single device scope, otherwise,
135 | # DistributedDataParallel will use all available devices.
136 | if config.gpu is not None:
137 | torch.cuda.set_device(config.gpu)
138 | model.cuda(config.gpu)
139 | classifier.cuda(config.gpu)
140 | lws_model.cuda(config.gpu)
141 | # When using a single GPU per process and per
142 | # DistributedDataParallel, we need to divide the batch size
143 | # ourselves based on the total number of GPUs we have
144 | config.batch_size = int(config.batch_size / ngpus_per_node)
145 | config.workers = int((config.workers + ngpus_per_node - 1) / ngpus_per_node)
146 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu])
147 | classifier = torch.nn.parallel.DistributedDataParallel(classifier, device_ids=[config.gpu])
148 | lws_model = torch.nn.parallel.DistributedDataParallel(lws_model, device_ids=[config.gpu])
149 |
150 | if config.dataset == 'places':
151 | block.cuda(config.gpu)
152 | block = torch.nn.parallel.DistributedDataParallel(block, device_ids=[config.gpu])
153 | else:
154 | model.cuda()
155 | classifier.cuda()
156 | lws_model.cuda()
157 | # DistributedDataParallel will divide and allocate batch_size to all
158 | # available GPUs if device_ids are not set
159 | model = torch.nn.parallel.DistributedDataParallel(model)
160 | classifier = torch.nn.parallel.DistributedDataParallel(classifier)
161 | lws_model = torch.nn.parallel.DistributedDataParallel(lws_model)
162 | if config.dataset == 'places':
163 | block.cuda()
164 | block = torch.nn.parallel.DistributedDataParallel(block)
165 |
166 | elif config.gpu is not None:
167 | torch.cuda.set_device(config.gpu)
168 | model = model.cuda(config.gpu)
169 | classifier = classifier.cuda(config.gpu)
170 | lws_model = lws_model.cuda(config.gpu)
171 | if config.dataset == 'places':
172 | block.cuda(config.gpu)
173 | else:
174 | # DataParallel will divide and allocate batch_size to all available GPUs
175 | model = torch.nn.DataParallel(model).cuda()
176 | classifier = torch.nn.DataParallel(classifier).cuda()
177 | lws_model = torch.nn.DataParallel(lws_model).cuda()
178 | if config.dataset == 'places':
179 | block = torch.nn.DataParallel(block).cuda()
180 |
181 | # optionally resume from a checkpoint
182 | if config.resume:
183 | if os.path.isfile(config.resume):
184 | logger.info("=> loading checkpoint '{}'".format(config.resume))
185 | if config.gpu is None:
186 | checkpoint = torch.load(config.resume)
187 | else:
188 | # Map model to be loaded to specified single gpu.
189 | loc = 'cuda:{}'.format(config.gpu)
190 | checkpoint = torch.load(config.resume, map_location=loc)
191 |             if config.gpu is not None and torch.is_tensor(best_acc1):
192 |                 # best_acc1 may be a tensor saved on a different GPU
193 |                 best_acc1 = best_acc1.to(config.gpu)
194 | model.load_state_dict(checkpoint['state_dict_model'])
195 | classifier.load_state_dict(checkpoint['state_dict_classifier'])
196 | if config.dataset == 'places':
197 | block.load_state_dict(checkpoint['state_dict_block'])
198 | if config.mode == 'stage2':
199 | lws_model.load_state_dict(checkpoint['state_dict_lws_model'])
200 | logger.info("=> loaded checkpoint '{}' (epoch {})"
201 | .format(config.resume, checkpoint['epoch']))
202 | else:
203 | logger.info("=> no checkpoint found at '{}'".format(config.resume))
204 |
205 | # Data loading code
206 | if config.dataset == 'cifar10':
207 | dataset = CIFAR10_LT(config.distributed, root=config.data_path, imb_factor=config.imb_factor,
208 | batch_size=config.batch_size, num_works=config.workers)
209 |
210 | elif config.dataset == 'cifar100':
211 | dataset = CIFAR100_LT(config.distributed, root=config.data_path, imb_factor=config.imb_factor,
212 | batch_size=config.batch_size, num_works=config.workers)
213 |
214 | elif config.dataset == 'places':
215 | dataset = Places_LT(config.distributed, root=config.data_path,
216 | batch_size=config.batch_size, num_works=config.workers)
217 |
218 | elif config.dataset == 'imagenet':
219 | dataset = ImageNet_LT(config.distributed, root=config.data_path,
220 | batch_size=config.batch_size, num_works=config.workers)
221 |
222 | elif config.dataset == 'ina2018':
223 | dataset = iNa2018(config.distributed, root=config.data_path,
224 | batch_size=config.batch_size, num_works=config.workers)
225 |
226 | val_loader = dataset.eval
227 | criterion = nn.CrossEntropyLoss().cuda(config.gpu)
228 |
229 | if config.dataset != 'places':
230 | block = None
231 |
232 | validate(val_loader, model, classifier, lws_model, criterion, config, logger, block)
233 |
234 |
235 | def validate(val_loader, model, classifier, lws_model, criterion, config, logger, block=None):
236 | batch_time = AverageMeter('Time', ':6.3f')
237 | losses = AverageMeter('Loss', ':.3f')
238 | top1 = AverageMeter('Acc@1', ':6.3f')
239 | top5 = AverageMeter('Acc@5', ':6.3f')
240 | progress = ProgressMeter(
241 | len(val_loader),
242 | [batch_time, losses, top1, top5],
243 | prefix='Eval: ')
244 |
245 | # switch to evaluate mode
246 | model.eval()
247 | if config.dataset == 'places':
248 | block.eval()
249 | classifier.eval()
250 | class_num = torch.zeros(config.num_classes).cuda()
251 | correct = torch.zeros(config.num_classes).cuda()
252 |
253 | confidence = np.array([])
254 | pred_class = np.array([])
255 | true_class = np.array([])
256 |
257 | with torch.no_grad():
258 | end = time.time()
259 | for i, (images, target) in enumerate(val_loader):
260 | if config.gpu is not None:
261 | images = images.cuda(config.gpu, non_blocking=True)
262 | if torch.cuda.is_available():
263 | target = target.cuda(config.gpu, non_blocking=True)
264 |
265 | # compute output
266 | if config.dataset == 'places':
267 | feat = block(model(images))
268 | else:
269 | feat = model(images)
270 | output = classifier(feat)
271 | output = lws_model(output)
272 | loss = criterion(output, target)
273 |
274 | # measure accuracy and record loss
275 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
276 | losses.update(loss.item(), images.size(0))
277 | top1.update(acc1[0], images.size(0))
278 | top5.update(acc5[0], images.size(0))
279 |
280 | _, predicted = output.max(1)
281 | target_one_hot = F.one_hot(target, config.num_classes)
282 | predict_one_hot = F.one_hot(predicted, config.num_classes)
283 | class_num = class_num + target_one_hot.sum(dim=0).to(torch.float)
284 | correct = correct + (target_one_hot + predict_one_hot == 2).sum(dim=0).to(torch.float)
285 |
286 | prob = torch.softmax(output, dim=1)
287 | confidence_part, pred_class_part = torch.max(prob, dim=1)
288 | confidence = np.append(confidence, confidence_part.cpu().numpy())
289 | pred_class = np.append(pred_class, pred_class_part.cpu().numpy())
290 | true_class = np.append(true_class, target.cpu().numpy())
291 |
292 | # measure elapsed time
293 | batch_time.update(time.time() - end)
294 | end = time.time()
295 |
296 | if i % config.print_freq == 0:
297 | progress.display(i, logger)
298 |
299 | acc_classes = correct / class_num
300 | head_acc = acc_classes[config.head_class_idx[0]:config.head_class_idx[1]].mean() * 100
301 | med_acc = acc_classes[config.med_class_idx[0]:config.med_class_idx[1]].mean() * 100
302 | tail_acc = acc_classes[config.tail_class_idx[0]:config.tail_class_idx[1]].mean() * 100
303 |
304 | logger.info('* Acc@1 {top1.avg:.3f}% Acc@5 {top5.avg:.3f}% HAcc {head_acc:.3f}% MAcc {med_acc:.3f}% TAcc {tail_acc:.3f}%.'.format(top1=top1, top5=top5, head_acc=head_acc, med_acc=med_acc, tail_acc=tail_acc))
305 |
306 | cal = calibration(true_class, pred_class, confidence, num_bins=15)
307 | logger.info('* ECE {ece:.3f}%.'.format(ece=cal['expected_calibration_error'] * 100))
308 |
309 | return top1.avg, cal['expected_calibration_error'] * 100
310 |
311 |
312 | if __name__ == '__main__':
313 | main()
314 |
--------------------------------------------------------------------------------
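
`validate` accumulates three flat arrays — the winning softmax probability, the predicted class, and the true class — and hands them to `calibration(...)` with 15 bins. As a rough stand-in for that utility (the repository's own implementation lives in `utils/metric.py`, so this sketch may differ in detail), expected calibration error can be computed by binning predictions on confidence and averaging the per-bin gap between accuracy and confidence:

```python
import numpy as np

def expected_calibration_error(true_class, pred_class, confidence, num_bins=15):
    """ECE = sum_b (|B_b| / N) * |acc(B_b) - conf(B_b)| over equal-width bins."""
    edges = np.linspace(0.0, 1.0, num_bins + 1)
    correct = (true_class == pred_class).astype(float)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidence > lo) & (confidence <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidence[in_bin].mean())
            ece += in_bin.mean() * gap
    return ece

# an over-confident toy model: ~95% average confidence at 60% accuracy
rng = np.random.default_rng(0)
true = rng.integers(0, 10, size=1000)
pred = np.where(rng.random(1000) < 0.6, true, (true + 1) % 10)
conf = rng.uniform(0.9, 1.0, size=1000)
print(f'ECE = {expected_calibration_error(true, pred, conf) * 100:.1f}%')  # ~35%
```
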
/methods.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | def mixup_data(x, y, alpha=1.0, use_cuda=True):
9 | '''Returns mixed inputs, pairs of targets, and lambda'''
10 | if alpha > 0:
11 | lam = np.random.beta(alpha, alpha)
12 | else:
13 | lam = 1
14 |
15 | batch_size = x.size()[0]
16 | if use_cuda:
17 | index = torch.randperm(batch_size).cuda()
18 | else:
19 | index = torch.randperm(batch_size)
20 |
21 | mixed_x = lam * x + (1 - lam) * x[index, :]
22 | y_a, y_b = y, y[index]
23 |
24 | return mixed_x, y_a, y_b, lam
25 |
26 |
27 | def mixup_criterion(criterion, pred, y_a, y_b, lam):
28 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
29 |
30 |
31 | class LabelAwareSmoothing(nn.Module):
32 | def __init__(self, cls_num_list, smooth_head, smooth_tail, shape='concave', power=None):
33 | super(LabelAwareSmoothing, self).__init__()
34 |
35 | n_1 = max(cls_num_list)
36 | n_K = min(cls_num_list)
37 |
38 | if shape == 'concave':
39 | self.smooth = smooth_tail + (smooth_head - smooth_tail) * np.sin((np.array(cls_num_list) - n_K) * np.pi / (2 * (n_1 - n_K)))
40 |
41 | elif shape == 'linear':
42 | self.smooth = smooth_tail + (smooth_head - smooth_tail) * (np.array(cls_num_list) - n_K) / (n_1 - n_K)
43 |
44 | elif shape == 'convex':
45 | self.smooth = smooth_head + (smooth_head - smooth_tail) * np.sin(1.5 * np.pi + (np.array(cls_num_list) - n_K) * np.pi / (2 * (n_1 - n_K)))
46 |
47 | elif shape == 'exp' and power is not None:
48 | self.smooth = smooth_tail + (smooth_head - smooth_tail) * np.power((np.array(cls_num_list) - n_K) / (n_1 - n_K), power)
49 |
50 | self.smooth = torch.from_numpy(self.smooth)
51 | self.smooth = self.smooth.float()
52 | if torch.cuda.is_available():
53 | self.smooth = self.smooth.cuda()
54 |
55 | def forward(self, x, target):
56 | smoothing = self.smooth[target]
57 | confidence = 1. - smoothing
58 | logprobs = F.log_softmax(x, dim=-1)
59 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
60 | nll_loss = nll_loss.squeeze(1)
61 | smooth_loss = -logprobs.mean(dim=-1)
62 | loss = confidence * nll_loss + smoothing * smooth_loss
63 |
64 | return loss.mean()
65 |
66 |
67 | class LearnableWeightScaling(nn.Module):
68 | def __init__(self, num_classes):
69 | super(LearnableWeightScaling, self).__init__()
70 | self.learned_norm = nn.Parameter(torch.ones(1, num_classes))
71 |
72 | def forward(self, x):
73 | return self.learned_norm * x
74 |
--------------------------------------------------------------------------------
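
`LabelAwareSmoothing` assigns each class its own label-smoothing factor, interpolated between `smooth_head` (for the most frequent class) and `smooth_tail` (for the rarest) according to the chosen shape. A small numeric sketch with made-up counts shows the `'concave'` schedule; head classes, which are the most over-confident, receive the strongest smoothing:

```python
import numpy as np

cls_num_list = [5000, 2000, 500, 100, 10]   # hypothetical long-tailed counts
smooth_head, smooth_tail = 0.4, 0.1
n_1, n_K = max(cls_num_list), min(cls_num_list)

# 'concave' shape: the sine rises steeply near the tail, flattens near the head
smooth = smooth_tail + (smooth_head - smooth_tail) * np.sin(
    (np.array(cls_num_list) - n_K) * np.pi / (2 * (n_1 - n_K)))
print(np.round(smooth, 3))   # ≈ [0.4, 0.28, 0.15, 0.11, 0.1]
```

In `forward`, the per-sample factor is simply `self.smooth[target]`, so a head-class sample pays `(1 - smooth) * NLL + smooth * uniform` while a tail-class sample is trained almost entirely on its hard label.
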
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | try:
5 | from torch.hub import load_state_dict_from_url
6 | except ImportError:
7 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
8 |
9 |
10 |
11 | __all__ = ['Classifier', 'ResNet', 'resnet10', 'resnet10_fe', 'resnet18', 'resnet34', 'resnet50', 'resnet50_fe', 'resnet101', 'resnet101_fe',
12 | 'resnet152', 'resnet152_fe', 'resnext50_32x4d', 'resnext101_32x8d', 'resnext152_32x4d',
13 | 'wide_resnet50_2', 'wide_resnet101_2']
14 |
15 |
16 |
17 | model_urls = {
18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
25 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
27 | }
28 |
29 |
30 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
31 | """3x3 convolution with padding"""
32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
33 | padding=dilation, groups=groups, bias=False, dilation=dilation)
34 |
35 |
36 | def conv1x1(in_planes, out_planes, stride=1):
37 | """1x1 convolution"""
38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
39 |
40 |
41 | class BasicBlock(nn.Module):
42 | expansion = 1
43 |
44 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
45 | base_width=64, dilation=1, norm_layer=None):
46 | super(BasicBlock, self).__init__()
47 | if norm_layer is None:
48 | norm_layer = nn.BatchNorm2d
49 | if groups != 1 or base_width != 64:
50 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
51 | if dilation > 1:
52 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
53 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
54 | self.conv1 = conv3x3(inplanes, planes, stride)
55 | self.bn1 = norm_layer(planes)
56 | self.relu = nn.ReLU(inplace=True)
57 | self.conv2 = conv3x3(planes, planes)
58 | self.bn2 = norm_layer(planes)
59 | self.downsample = downsample
60 | self.stride = stride
61 |
62 | def forward(self, x):
63 | identity = x
64 |
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 |
69 | out = self.conv2(out)
70 | out = self.bn2(out)
71 |
72 | if self.downsample is not None:
73 | identity = self.downsample(x)
74 |
75 | out += identity
76 | out = self.relu(out)
77 |
78 | return out
79 |
80 |
81 | class Bottleneck(nn.Module):
82 | expansion = 4
83 |
84 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
85 | base_width=64, dilation=1, norm_layer=None):
86 | super(Bottleneck, self).__init__()
87 | if norm_layer is None:
88 | norm_layer = nn.BatchNorm2d
89 | width = int(planes * (base_width / 64.)) * groups
90 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
91 | self.conv1 = conv1x1(inplanes, width)
92 | self.bn1 = norm_layer(width)
93 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
94 | self.bn2 = norm_layer(width)
95 | self.conv3 = conv1x1(width, planes * self.expansion)
96 | self.bn3 = norm_layer(planes * self.expansion)
97 | self.relu = nn.ReLU(inplace=True)
98 | self.downsample = downsample
99 | self.stride = stride
100 |
101 | def forward(self, x):
102 | identity = x
103 |
104 | out = self.conv1(x)
105 | out = self.bn1(out)
106 | out = self.relu(out)
107 |
108 | out = self.conv2(out)
109 | out = self.bn2(out)
110 | out = self.relu(out)
111 |
112 | out = self.conv3(out)
113 | out = self.bn3(out)
114 |
115 | if self.downsample is not None:
116 | identity = self.downsample(x)
117 |
118 | out += identity
119 | out = self.relu(out)
120 |
121 | return out
122 |
123 |
124 | class ResNet(nn.Module):
125 |
126 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
127 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
128 | norm_layer=None):
129 | super(ResNet, self).__init__()
130 | if norm_layer is None:
131 | norm_layer = nn.BatchNorm2d
132 | self._norm_layer = norm_layer
133 |
134 | self.inplanes = 64
135 | self.dilation = 1
136 | if replace_stride_with_dilation is None:
137 | # each element in the tuple indicates if we should replace
138 | # the 2x2 stride with a dilated convolution instead
139 | replace_stride_with_dilation = [False, False, False]
140 | if len(replace_stride_with_dilation) != 3:
141 | raise ValueError("replace_stride_with_dilation should be None "
142 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
143 | self.groups = groups
144 | self.base_width = width_per_group
145 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
146 | bias=False)
147 | self.bn1 = norm_layer(self.inplanes)
148 | self.relu = nn.ReLU(inplace=True)
149 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
150 | self.layer1 = self._make_layer(block, 64, layers[0])
151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
152 | dilate=replace_stride_with_dilation[0])
153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
154 | dilate=replace_stride_with_dilation[1])
155 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
156 | dilate=replace_stride_with_dilation[2])
157 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
158 | self.fc = nn.Linear(512 * block.expansion, num_classes)
159 |
160 | for m in self.modules():
161 | if isinstance(m, nn.Conv2d):
162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
164 | nn.init.constant_(m.weight, 1)
165 | nn.init.constant_(m.bias, 0)
166 |
167 | # Zero-initialize the last BN in each residual branch,
168 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
169 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
170 | if zero_init_residual:
171 | for m in self.modules():
172 | if isinstance(m, Bottleneck):
173 | nn.init.constant_(m.bn3.weight, 0)
174 | elif isinstance(m, BasicBlock):
175 | nn.init.constant_(m.bn2.weight, 0)
176 |
177 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
178 | norm_layer = self._norm_layer
179 | downsample = None
180 | previous_dilation = self.dilation
181 | if dilate:
182 | self.dilation *= stride
183 | stride = 1
184 | if stride != 1 or self.inplanes != planes * block.expansion:
185 | downsample = nn.Sequential(
186 | conv1x1(self.inplanes, planes * block.expansion, stride),
187 | norm_layer(planes * block.expansion),
188 | )
189 |
190 | layers = []
191 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
192 | self.base_width, previous_dilation, norm_layer))
193 | self.inplanes = planes * block.expansion
194 | for _ in range(1, blocks):
195 | layers.append(block(self.inplanes, planes, groups=self.groups,
196 | base_width=self.base_width, dilation=self.dilation,
197 | norm_layer=norm_layer))
198 |
199 | return nn.Sequential(*layers)
200 |
201 | def forward(self, x):
202 | x = self.conv1(x)
203 | x = self.bn1(x)
204 | x = self.relu(x)
205 | x = self.maxpool(x)
206 |
207 | x = self.layer1(x)
208 | x = self.layer2(x)
209 | x = self.layer3(x)
210 | x = self.layer4(x)
211 |
212 | x = self.avgpool(x)
213 | x = torch.flatten(x, 1)
214 | x = self.fc(x)
215 |
216 | return x
217 |
218 | class ResNet_FE(nn.Module):
219 |
220 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
221 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
222 | norm_layer=None):
223 | super(ResNet_FE, self).__init__()
224 | if norm_layer is None:
225 | norm_layer = nn.BatchNorm2d
226 | self._norm_layer = norm_layer
227 |
228 | self.inplanes = 64
229 | self.dilation = 1
230 | if replace_stride_with_dilation is None:
231 | # each element in the tuple indicates if we should replace
232 | # the 2x2 stride with a dilated convolution instead
233 | replace_stride_with_dilation = [False, False, False]
234 | if len(replace_stride_with_dilation) != 3:
235 | raise ValueError("replace_stride_with_dilation should be None "
236 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
237 | self.groups = groups
238 | self.base_width = width_per_group
239 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
240 | bias=False)
241 | self.bn1 = norm_layer(self.inplanes)
242 | self.relu = nn.ReLU(inplace=True)
243 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
244 | self.layer1 = self._make_layer(block, 64, layers[0])
245 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
246 | dilate=replace_stride_with_dilation[0])
247 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
248 | dilate=replace_stride_with_dilation[1])
249 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
250 | dilate=replace_stride_with_dilation[2])
251 |
252 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
253 |
254 | for m in self.modules():
255 | if isinstance(m, nn.Conv2d):
256 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
257 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
258 | nn.init.constant_(m.weight, 1)
259 | nn.init.constant_(m.bias, 0)
260 |
261 | # Zero-initialize the last BN in each residual branch,
262 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
263 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
264 | if zero_init_residual:
265 | for m in self.modules():
266 | if isinstance(m, Bottleneck):
267 | nn.init.constant_(m.bn3.weight, 0)
268 | elif isinstance(m, BasicBlock):
269 | nn.init.constant_(m.bn2.weight, 0)
270 |
271 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
272 | norm_layer = self._norm_layer
273 | downsample = None
274 | previous_dilation = self.dilation
275 | if dilate:
276 | self.dilation *= stride
277 | stride = 1
278 | if stride != 1 or self.inplanes != planes * block.expansion:
279 | downsample = nn.Sequential(
280 | conv1x1(self.inplanes, planes * block.expansion, stride),
281 | norm_layer(planes * block.expansion),
282 | )
283 |
284 | layers = []
285 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
286 | self.base_width, previous_dilation, norm_layer))
287 | self.inplanes = planes * block.expansion
288 | for _ in range(1, blocks):
289 | layers.append(block(self.inplanes, planes, groups=self.groups,
290 | base_width=self.base_width, dilation=self.dilation,
291 | norm_layer=norm_layer))
292 |
293 | return nn.Sequential(*layers)
294 |
295 | def forward(self, x):
296 | x = self.conv1(x)
297 | x = self.bn1(x)
298 | x = self.relu(x)
299 | x = self.maxpool(x)
300 |
301 | x = self.layer1(x)
302 | x = self.layer2(x)
303 | x = self.layer3(x)
304 | x = self.layer4(x)
305 | x = self.avgpool(x)
306 | x = torch.flatten(x, 1)
307 |
308 | return x
309 |
310 |
311 |
312 | class Classifier(nn.Module):
313 | def __init__(self, feat_in, num_classes):
314 | super(Classifier, self).__init__()
315 | self.fc = nn.Linear(feat_in, num_classes)
316 |
317 | def forward(self, x):
318 | x = self.fc(x)
319 | return x
320 |
321 |
322 |
323 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
324 | model = ResNet(block, layers, **kwargs)
325 | if pretrained:
326 | state_dict = load_state_dict_from_url(model_urls[arch],
327 | progress=progress)
328 | model.load_state_dict(state_dict)
329 | return model
330 |
331 | def _resnet_fe(arch, block, layers, pretrained, progress, **kwargs):
332 | model = ResNet_FE(block, layers, **kwargs)
333 | if pretrained:
334 | # import pdb; pdb.set_trace()
335 | state_dict = load_state_dict_from_url(model_urls[arch],
336 | progress=progress)
337 | model.load_state_dict(state_dict)
338 | return model
339 |
340 |
341 | def resnet10(pretrained=False, progress=True, **kwargs):
342 |     r"""ResNet-10 model in the style of
343 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
344 |
345 | Args:
346 | pretrained (bool): If True, returns a model pre-trained on ImageNet
347 | progress (bool): If True, displays a progress bar of the download to stderr
348 | """
349 | return _resnet('resnet10', BasicBlock, [1, 1, 1, 1], pretrained, progress,
350 | **kwargs)
351 |
352 |
353 | def resnet10_fe(pretrained=False, progress=True, **kwargs):
354 |     r"""ResNet-10 feature extractor in the style of
355 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
356 |
357 | Args:
358 | pretrained (bool): If True, returns a model pre-trained on ImageNet
359 | progress (bool): If True, displays a progress bar of the download to stderr
360 | """
361 | return _resnet_fe('resnet10_fe', BasicBlock, [1, 1, 1, 1], pretrained, progress,
362 | **kwargs)
363 |
364 |
365 |
366 | def resnet18(pretrained=False, progress=True, **kwargs):
367 | r"""ResNet-18 model from
368 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
369 |
370 | Args:
371 | pretrained (bool): If True, returns a model pre-trained on ImageNet
372 | progress (bool): If True, displays a progress bar of the download to stderr
373 | """
374 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
375 | **kwargs)
376 |
377 |
378 | def resnet34(pretrained=False, progress=True, **kwargs):
379 | r"""ResNet-34 model from
380 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
381 |
382 | Args:
383 | pretrained (bool): If True, returns a model pre-trained on ImageNet
384 | progress (bool): If True, displays a progress bar of the download to stderr
385 | """
386 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
387 | **kwargs)
388 |
389 |
390 | def resnet50(pretrained=False, progress=True, **kwargs):
391 | r"""ResNet-50 model from
392 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
393 |
394 | Args:
395 | pretrained (bool): If True, returns a model pre-trained on ImageNet
396 | progress (bool): If True, displays a progress bar of the download to stderr
397 | """
398 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
399 | **kwargs)
400 |
401 | def resnet50_fe(pretrained=False, progress=True, **kwargs):
402 |     r"""ResNet-50 feature extractor from
403 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
404 |
405 | Args:
406 | pretrained (bool): If True, returns a model pre-trained on ImageNet
407 | progress (bool): If True, displays a progress bar of the download to stderr
408 | """
409 | return _resnet_fe('resnet50_fe', Bottleneck, [3, 4, 6, 3], pretrained, progress,
410 | **kwargs)
411 |
412 |
413 | def resnet101(pretrained=False, progress=True, **kwargs):
414 | r"""ResNet-101 model from
415 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
416 |
417 | Args:
418 | pretrained (bool): If True, returns a model pre-trained on ImageNet
419 | progress (bool): If True, displays a progress bar of the download to stderr
420 | """
421 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
422 | **kwargs)
423 |
424 |
425 | def resnet101_fe(pretrained=False, progress=True, **kwargs):
426 |     r"""ResNet-101 feature extractor from
427 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
428 |
429 | Args:
430 | pretrained (bool): If True, returns a model pre-trained on ImageNet
431 | progress (bool): If True, displays a progress bar of the download to stderr
432 | """
433 | return _resnet_fe('resnet101_fe', Bottleneck, [3, 4, 23, 3], pretrained, progress,
434 | **kwargs)
435 |
436 | def resnet152(pretrained=False, progress=True, **kwargs):
437 | r"""ResNet-152 model from
438 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
439 |
440 | Args:
441 | pretrained (bool): If True, returns a model pre-trained on ImageNet
442 | progress (bool): If True, displays a progress bar of the download to stderr
443 | """
444 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
445 | **kwargs)
446 |
447 |
448 | def resnet152_fe(pretrained=False, progress=True, **kwargs):
449 |     r"""ResNet-152 feature extractor from
450 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
451 |
452 | Args:
453 | pretrained (bool): If True, returns a model pre-trained on ImageNet
454 | progress (bool): If True, displays a progress bar of the download to stderr
455 |
456 | """
457 | return _resnet_fe('resnet152_fe', Bottleneck, [3, 8, 36, 3], pretrained, progress,
458 | **kwargs)
459 |
460 |
461 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
462 | r"""ResNeXt-50 32x4d model from
463 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
464 |
465 | Args:
466 | pretrained (bool): If True, returns a model pre-trained on ImageNet
467 | progress (bool): If True, displays a progress bar of the download to stderr
468 | """
469 | kwargs['groups'] = 32
470 | kwargs['width_per_group'] = 4
471 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
472 | pretrained, progress, **kwargs)
473 |
474 |
475 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
476 | r"""ResNeXt-101 32x8d model from
477 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_
478 |
479 | Args:
480 | pretrained (bool): If True, returns a model pre-trained on ImageNet
481 | progress (bool): If True, displays a progress bar of the download to stderr
482 | """
483 | kwargs['groups'] = 32
484 | kwargs['width_per_group'] = 8
485 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
486 | pretrained, progress, **kwargs)
487 |
488 | def resnext152_32x4d(pretrained=False, progress=True, **kwargs):
489 | kwargs['groups'] = 32
490 | kwargs['width_per_group'] = 4
491 | return _resnet('resnext152_32x4d', Bottleneck, [3, 8, 36, 3],
492 | pretrained, progress, **kwargs)
493 |
494 |
495 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
496 |     r"""Wide ResNet-50-2 model from
497 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
498 | 
499 |     The model is the same as ResNet except that the number of bottleneck channels
500 |     is twice as large in every block. The number of channels in the outer 1x1
501 |     convolutions is unchanged, e.g. the last block in ResNet-50 has 2048-512-2048
502 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
503 |
504 | Args:
505 | pretrained (bool): If True, returns a model pre-trained on ImageNet
506 | progress (bool): If True, displays a progress bar of the download to stderr
507 | """
508 | kwargs['width_per_group'] = 64 * 2
509 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
510 | pretrained, progress, **kwargs)
511 |
512 |
513 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
514 |     r"""Wide ResNet-101-2 model from
515 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_
516 | 
517 |     The model is the same as ResNet except that the number of bottleneck channels
518 |     is twice as large in every block. The number of channels in the outer 1x1
519 |     convolutions is unchanged, e.g. the last block in ResNet-50 has 2048-512-2048
520 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
521 |
522 | Args:
523 | pretrained (bool): If True, returns a model pre-trained on ImageNet
524 | progress (bool): If True, displays a progress bar of the download to stderr
525 | """
526 | kwargs['width_per_group'] = 64 * 2
527 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
528 | pretrained, progress, **kwargs)
529 |
--------------------------------------------------------------------------------
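
The `*_fe` variants end at the global average pool and return features instead of logits; the logits come from the separate `Classifier`, optionally rescaled per class by `LearnableWeightScaling` as in `eval.py`. A minimal shape-check sketch (input sizes assumed, not prescribed by the repo):

```python
import torch

from models import resnet
from methods import LearnableWeightScaling

backbone = resnet.resnet50_fe()                       # features only, no fc
head = resnet.Classifier(feat_in=2048, num_classes=1000)
lws = LearnableWeightScaling(num_classes=1000)

x = torch.randn(2, 3, 224, 224)
feats = backbone(x)        # -> (2, 2048) after adaptive average pooling
logits = lws(head(feats))  # -> (2, 1000), per-class learnable scaling
print(feats.shape, logits.shape)
```

Splitting the backbone from the head this way is what lets stage 2 re-train the classifier (and the LWS scales) on top of the largely frozen stage-1 features.
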
/models/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | '''
2 | Properly implemented ResNet for CIFAR10 as described in paper [1].
3 | The implementation and structure of this file are hugely influenced by [2],
4 | which is implemented for ImageNet and doesn't have option A for the identity
5 | shortcut. Moreover, most implementations on the web are copy-pasted from
6 | torchvision's resnet and have the wrong number of params.
7 | Proper ResNet-s for CIFAR10 (for fair comparison etc.) have the following
8 | numbers of layers and parameters:
9 | name      | layers | params
10 | ResNet20  |   20   | 0.27M
11 | ResNet32  |   32   | 0.46M
12 | ResNet44  |   44   | 0.66M
13 | ResNet56  |   56   | 0.85M
14 | ResNet110 |  110   | 1.7M
15 | ResNet1202|  1202  | 19.4M
16 | which this implementation indeed has.
17 | Reference:
18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
20 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
21 | If you use this implementation in your work, please don't forget to mention the
22 | author, Yerlan Idelbayev.
23 | '''
24 | import torch
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 | import torch.nn.init as init
28 | from torch.nn import Parameter
29 |
30 | __all__ = ['ResNet_s', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']
31 |
32 | def _weights_init(m):
33 | classname = m.__class__.__name__
34 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
35 | init.kaiming_normal_(m.weight)
36 |
37 | class NormedLinear(nn.Module):
38 |
39 | def __init__(self, in_features, out_features):
40 | super(NormedLinear, self).__init__()
41 | self.weight = Parameter(torch.Tensor(in_features, out_features))
42 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
43 |
44 | def forward(self, x):
45 | out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0))
46 | return out
47 |
48 | class LambdaLayer(nn.Module):
49 |
50 | def __init__(self, lambd):
51 | super(LambdaLayer, self).__init__()
52 | self.lambd = lambd
53 |
54 | def forward(self, x):
55 | return self.lambd(x)
56 |
57 |
58 | class BasicBlock(nn.Module):
59 | expansion = 1
60 |
61 | def __init__(self, in_planes, planes, stride=1, option='A'):
62 | super(BasicBlock, self).__init__()
63 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
64 | self.bn1 = nn.BatchNorm2d(planes)
65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
66 | self.bn2 = nn.BatchNorm2d(planes)
67 |
68 | self.shortcut = nn.Sequential()
69 | if stride != 1 or in_planes != planes:
70 | if option == 'A':
71 | """
72 |                 For CIFAR10, the ResNet paper uses option A.
73 | """
74 | self.shortcut = LambdaLayer(lambda x:
75 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
76 | elif option == 'B':
77 | self.shortcut = nn.Sequential(
78 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
79 | nn.BatchNorm2d(self.expansion * planes)
80 | )
81 |
82 | def forward(self, x):
83 | out = F.relu(self.bn1(self.conv1(x)))
84 | out = self.bn2(self.conv2(out))
85 | out += self.shortcut(x)
86 | out = F.relu(out)
87 | return out
88 |
89 |
90 | class ResNet_s(nn.Module):
91 |
92 | def __init__(self, block, num_blocks, num_classes=10, use_norm=False):
93 | super(ResNet_s, self).__init__()
94 | self.in_planes = 16
95 |
96 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
97 | self.bn1 = nn.BatchNorm2d(16)
98 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
99 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
100 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
101 | if use_norm:
102 | self.linear = NormedLinear(64, num_classes)
103 | else:
104 | self.linear = nn.Linear(64, num_classes)
105 | self.apply(_weights_init)
106 |
107 | def _make_layer(self, block, planes, num_blocks, stride):
108 | strides = [stride] + [1]*(num_blocks-1)
109 | layers = []
110 | for stride in strides:
111 | layers.append(block(self.in_planes, planes, stride))
112 | self.in_planes = planes * block.expansion
113 |
114 | return nn.Sequential(*layers)
115 |
116 | def forward(self, x):
117 | out = F.relu(self.bn1(self.conv1(x)))
118 | out = self.layer1(out)
119 | out = self.layer2(out)
120 | out = self.layer3(out)
121 | out = F.avg_pool2d(out, out.size()[3])
122 | out = out.view(out.size(0), -1)
123 | out = self.linear(out)
124 | return out
125 |
126 | class ResNet_fe(nn.Module):
127 |
128 | def __init__(self, block, num_blocks):
129 | super(ResNet_fe, self).__init__()
130 | self.in_planes = 16
131 |
132 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
133 | self.bn1 = nn.BatchNorm2d(16)
134 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
135 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
136 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
137 | self.apply(_weights_init)
138 |
139 | def _make_layer(self, block, planes, num_blocks, stride):
140 | strides = [stride] + [1]*(num_blocks-1)
141 | layers = []
142 | for stride in strides:
143 | layers.append(block(self.in_planes, planes, stride))
144 | self.in_planes = planes * block.expansion
145 |
146 | return nn.Sequential(*layers)
147 |
148 | def forward(self, x):
149 | out = F.relu(self.bn1(self.conv1(x)))
150 | out = self.layer1(out)
151 | out = self.layer2(out)
152 | out = self.layer3(out)
153 | out = F.avg_pool2d(out, out.size()[3])
154 | out = out.view(out.size(0), -1)
155 | return out
156 |
157 | class Classifier(nn.Module):
158 | def __init__(self, feat_in, num_classes):
159 | super(Classifier, self).__init__()
160 | self.fc = nn.Linear(feat_in, num_classes)
161 | self.apply(_weights_init)
162 |
163 | def forward(self, x):
164 | x = self.fc(x)
165 | return x
166 |
167 |
168 | def resnet20():
169 | return ResNet_s(BasicBlock, [3, 3, 3])
170 |
171 | def resnet32_fe():
172 | return ResNet_fe(BasicBlock, [5, 5, 5])
173 |
174 | def resnet32(num_classes=10, use_norm=False):
175 | return ResNet_s(BasicBlock, [5, 5, 5], num_classes=num_classes, use_norm=use_norm)
176 |
177 |
178 | def resnet44():
179 | return ResNet_s(BasicBlock, [7, 7, 7])
180 |
181 |
182 | def resnet56():
183 | return ResNet_s(BasicBlock, [9, 9, 9])
184 |
185 |
186 | def resnet110():
187 | return ResNet_s(BasicBlock, [18, 18, 18])
188 |
189 |
190 | def resnet1202():
191 | return ResNet_s(BasicBlock, [200, 200, 200])
192 |
193 |
194 | def test(net):
195 | import numpy as np
196 | total_params = 0
197 |
198 | for x in filter(lambda p: p.requires_grad, net.parameters()):
199 | total_params += np.prod(x.data.numpy().shape)
200 | print("Total number of params", total_params)
201 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters()))))
202 |
203 |
204 | if __name__ == "__main__":
205 | for net_name in __all__:
206 | if net_name.startswith('resnet'):
207 | print(net_name)
208 | test(globals()[net_name]())
209 | print()
--------------------------------------------------------------------------------
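
The CIFAR path mirrors the split above at a smaller scale: `resnet32_fe` pools down to a 64-dimensional feature, matching the `Classifier(feat_in=64, ...)` that `eval.py` builds for cifar10/cifar100. A quick shape check:

```python
import torch

from models import resnet_cifar

backbone = resnet_cifar.resnet32_fe()
head = resnet_cifar.Classifier(feat_in=64, num_classes=100)

x = torch.randn(4, 3, 32, 32)
feats = backbone(x)     # -> (4, 64): 16/32/64-channel stages, then 8x8 avg pool
logits = head(feats)    # -> (4, 100)
print(feats.shape, logits.shape)
```
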
/models/resnet_places.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | try:
5 | from torch.hub import load_state_dict_from_url
6 | except ImportError:
7 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
8 |
9 |
10 |
11 | __all__ = ['Classifier', 'ResNet', 'resnet10', 'resnet10_fe', 'resnet18', 'resnet34', 'resnet50', 'resnet50_fe', 'resnet101', 'resnet101_fe',
12 | 'resnet152', 'resnet152_fe', 'resnext50_32x4d', 'resnext101_32x8d', 'resnext152_32x4d',
13 | 'wide_resnet50_2', 'wide_resnet101_2']
14 |
15 |
16 |
17 | model_urls = {
18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
25 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
27 | }
28 |
29 |
30 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
31 | """3x3 convolution with padding"""
32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
33 | padding=dilation, groups=groups, bias=False, dilation=dilation)
34 |
35 |
36 | def conv1x1(in_planes, out_planes, stride=1):
37 | """1x1 convolution"""
38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
39 |
40 |
41 | class BasicBlock(nn.Module):
42 | expansion = 1
43 |
44 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
45 | base_width=64, dilation=1, norm_layer=None):
46 | super(BasicBlock, self).__init__()
47 | if norm_layer is None:
48 | norm_layer = nn.BatchNorm2d
49 | if groups != 1 or base_width != 64:
50 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
51 | if dilation > 1:
52 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
53 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
54 | self.conv1 = conv3x3(inplanes, planes, stride)
55 | self.bn1 = norm_layer(planes)
56 | self.relu = nn.ReLU(inplace=True)
57 | self.conv2 = conv3x3(planes, planes)
58 | self.bn2 = norm_layer(planes)
59 | self.downsample = downsample
60 | self.stride = stride
61 |
62 | def forward(self, x):
63 | identity = x
64 |
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 |
69 | out = self.conv2(out)
70 | out = self.bn2(out)
71 |
72 | if self.downsample is not None:
73 | identity = self.downsample(x)
74 |
75 | out += identity
76 | out = self.relu(out)
77 |
78 | return out
79 |
80 |
81 | class Bottleneck(nn.Module):
82 | expansion = 4
83 |
84 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
85 | base_width=64, dilation=1, norm_layer=None):
86 | super(Bottleneck, self).__init__()
87 | if norm_layer is None:
88 | norm_layer = nn.BatchNorm2d
89 | width = int(planes * (base_width / 64.)) * groups
90 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
91 | self.conv1 = conv1x1(inplanes, width)
92 | self.bn1 = norm_layer(width)
93 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
94 | self.bn2 = norm_layer(width)
95 | self.conv3 = conv1x1(width, planes * self.expansion)
96 | self.bn3 = norm_layer(planes * self.expansion)
97 | self.relu = nn.ReLU(inplace=True)
98 | self.downsample = downsample
99 | self.stride = stride
100 |
101 | def forward(self, x):
102 | identity = x
103 |
104 | out = self.conv1(x)
105 | out = self.bn1(out)
106 | out = self.relu(out)
107 |
108 | out = self.conv2(out)
109 | out = self.bn2(out)
110 | out = self.relu(out)
111 |
112 | out = self.conv3(out)
113 | out = self.bn3(out)
114 |
115 | if self.downsample is not None:
116 | identity = self.downsample(x)
117 |
118 | out += identity
119 | out = self.relu(out)
120 |
121 | return out
122 |
123 |
124 | class ResNet(nn.Module):
125 |
126 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
127 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
128 | norm_layer=None):
129 | super(ResNet, self).__init__()
130 | if norm_layer is None:
131 | norm_layer = nn.BatchNorm2d
132 | self._norm_layer = norm_layer
133 |
134 | self.inplanes = 64
135 | self.dilation = 1
136 | if replace_stride_with_dilation is None:
137 | # each element in the tuple indicates if we should replace
138 | # the 2x2 stride with a dilated convolution instead
139 | replace_stride_with_dilation = [False, False, False]
140 | if len(replace_stride_with_dilation) != 3:
141 | raise ValueError("replace_stride_with_dilation should be None "
142 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
143 | self.groups = groups
144 | self.base_width = width_per_group
145 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
146 | bias=False)
147 | self.bn1 = norm_layer(self.inplanes)
148 | self.relu = nn.ReLU(inplace=True)
149 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
150 | self.layer1 = self._make_layer(block, 64, layers[0])
151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
152 | dilate=replace_stride_with_dilation[0])
153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
154 | dilate=replace_stride_with_dilation[1])
155 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
156 | dilate=replace_stride_with_dilation[2])
157 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
158 | self.fc = nn.Linear(512 * block.expansion, num_classes)
159 |
160 | for m in self.modules():
161 | if isinstance(m, nn.Conv2d):
162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
164 | nn.init.constant_(m.weight, 1)
165 | nn.init.constant_(m.bias, 0)
166 |
167 | # Zero-initialize the last BN in each residual branch,
168 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
169 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
170 | if zero_init_residual:
171 | for m in self.modules():
172 | if isinstance(m, Bottleneck):
173 | nn.init.constant_(m.bn3.weight, 0)
174 | elif isinstance(m, BasicBlock):
175 | nn.init.constant_(m.bn2.weight, 0)
176 |
177 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
178 | norm_layer = self._norm_layer
179 | downsample = None
180 | previous_dilation = self.dilation
181 | if dilate:
182 | self.dilation *= stride
183 | stride = 1
184 | if stride != 1 or self.inplanes != planes * block.expansion:
185 | downsample = nn.Sequential(
186 | conv1x1(self.inplanes, planes * block.expansion, stride),
187 | norm_layer(planes * block.expansion),
188 | )
189 |
190 | layers = []
191 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
192 | self.base_width, previous_dilation, norm_layer))
193 | self.inplanes = planes * block.expansion
194 | for _ in range(1, blocks):
195 | layers.append(block(self.inplanes, planes, groups=self.groups,
196 | base_width=self.base_width, dilation=self.dilation,
197 | norm_layer=norm_layer))
198 |
199 | return nn.Sequential(*layers)
200 |
201 | def forward(self, x):
202 | x = self.conv1(x)
203 | x = self.bn1(x)
204 | x = self.relu(x)
205 | x = self.maxpool(x)
206 |
207 | x = self.layer1(x)
208 | x = self.layer2(x)
209 | x = self.layer3(x)
210 | x = self.layer4(x)
211 |
212 | x = self.avgpool(x)
213 | x = torch.flatten(x, 1)
214 | x = self.fc(x)
215 |
216 | return x
217 |
218 | class ResNet_FE(nn.Module):
219 |
220 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
221 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
222 | norm_layer=None):
223 | super(ResNet_FE, self).__init__()
224 | if norm_layer is None:
225 | norm_layer = nn.BatchNorm2d
226 | self._norm_layer = norm_layer
227 |
228 | self.inplanes = 64
229 | self.dilation = 1
230 | if replace_stride_with_dilation is None:
231 | # each element in the tuple indicates if we should replace
232 | # the 2x2 stride with a dilated convolution instead
233 | replace_stride_with_dilation = [False, False, False]
234 | if len(replace_stride_with_dilation) != 3:
235 | raise ValueError("replace_stride_with_dilation should be None "
236 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
237 | self.groups = groups
238 | self.base_width = width_per_group
239 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
240 | bias=False)
241 | self.bn1 = norm_layer(self.inplanes)
242 | self.relu = nn.ReLU(inplace=True)
243 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
244 | self.layer1 = self._make_layer(block, 64, layers[0])
245 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
246 | dilate=replace_stride_with_dilation[0])
247 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
248 | dilate=replace_stride_with_dilation[1])
249 |         # one bottleneck fewer than the stock layer4: the last block is the separately trained `block` in eval.py
250 |         self.layer4 = self._make_layer(block, 512, layers[3] - 1, stride=2, dilate=replace_stride_with_dilation[2])
251 |
252 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
253 |
254 | for m in self.modules():
255 | if isinstance(m, nn.Conv2d):
256 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
257 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
258 | nn.init.constant_(m.weight, 1)
259 | nn.init.constant_(m.bias, 0)
260 |
261 | # Zero-initialize the last BN in each residual branch,
262 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
263 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
264 | if zero_init_residual:
265 | for m in self.modules():
266 | if isinstance(m, Bottleneck):
267 | nn.init.constant_(m.bn3.weight, 0)
268 | elif isinstance(m, BasicBlock):
269 | nn.init.constant_(m.bn2.weight, 0)
270 |
271 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
272 | norm_layer = self._norm_layer
273 | downsample = None
274 | previous_dilation = self.dilation
275 | if dilate:
276 | self.dilation *= stride
277 | stride = 1
278 | if stride != 1 or self.inplanes != planes * block.expansion:
279 | downsample = nn.Sequential(
280 | conv1x1(self.inplanes, planes * block.expansion, stride),
281 | norm_layer(planes * block.expansion),
282 | )
283 |
284 | layers = []
285 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
286 | self.base_width, previous_dilation, norm_layer))
287 | self.inplanes = planes * block.expansion
288 | for _ in range(1, blocks):
289 | layers.append(block(self.inplanes, planes, groups=self.groups,
290 | base_width=self.base_width, dilation=self.dilation,
291 | norm_layer=norm_layer))
292 |
293 | return nn.Sequential(*layers)
294 |
295 | def forward(self, x):
296 | x = self.conv1(x)
297 | x = self.bn1(x)
298 | x = self.relu(x)
299 | x = self.maxpool(x)
300 |
301 | x = self.layer1(x)
302 | x = self.layer2(x)
303 | x = self.layer3(x)
304 | x = self.layer4(x)
305 |
306 |
307 | return x
308 |
309 |
310 |
311 | class Classifier(nn.Module):
312 | def __init__(self, feat_in, num_classes):
313 | super(Classifier, self).__init__()
314 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
315 | self.fc = nn.Linear(feat_in, num_classes)
316 |
317 | def forward(self, x):
318 | x = self.avgpool(x)
319 | x = torch.flatten(x, 1)
320 | x = self.fc(x)
321 | return x
322 |
323 |
324 |
325 |
326 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
327 | model = ResNet(block, layers, **kwargs)
328 | if pretrained:
329 | state_dict = load_state_dict_from_url(model_urls[arch],
330 | progress=progress)
331 | model.load_state_dict(state_dict)
332 | return model
333 |
334 | def _resnet_fe(arch, block, layers, pretrained, progress, **kwargs):
335 | model = ResNet_FE(block, layers, **kwargs)
336 | if pretrained:
337 | state_dict = load_state_dict_from_url(model_urls[arch[:-3]],
338 | progress=progress)
339 |         # drop the fc head and layer4's final bottleneck (trained separately)
340 | del state_dict['fc.weight']
341 | del state_dict['fc.bias']
342 | del state_dict["layer4.2.conv1.weight"]
343 | del state_dict["layer4.2.bn1.running_mean"]
344 | del state_dict["layer4.2.bn1.running_var"]
345 | del state_dict["layer4.2.bn1.weight"]
346 | del state_dict["layer4.2.bn1.bias"]
347 | del state_dict["layer4.2.conv2.weight"]
348 | del state_dict["layer4.2.bn2.running_mean"]
349 | del state_dict["layer4.2.bn2.running_var"]
350 | del state_dict["layer4.2.bn2.weight"]
351 | del state_dict["layer4.2.bn2.bias"]
352 | del state_dict["layer4.2.conv3.weight"]
353 | del state_dict["layer4.2.bn3.running_mean"]
354 | del state_dict["layer4.2.bn3.running_var"]
355 | del state_dict["layer4.2.bn3.weight"]
356 | del state_dict["layer4.2.bn3.bias"]
357 |
358 | model.load_state_dict(state_dict)
359 |
360 | return model
361 |
362 |
363 | def resnet10(pretrained=False, progress=True, **kwargs):
364 |     r"""ResNet-10 model in the style of
365 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
366 |
367 | Args:
368 | pretrained (bool): If True, returns a model pre-trained on ImageNet
369 | progress (bool): If True, displays a progress bar of the download to stderr
370 | """
371 | return _resnet('resnet10', BasicBlock, [1, 1, 1, 1], pretrained, progress,
372 | **kwargs)
373 |
374 |
375 | def resnet10_fe(pretrained=False, progress=True, **kwargs):
376 |     r"""ResNet-10 model (feature-extractor variant) from
377 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
378 |
379 | Args:
380 | pretrained (bool): If True, returns a model pre-trained on ImageNet
381 | progress (bool): If True, displays a progress bar of the download to stderr
382 | """
383 | return _resnet_fe('resnet10_fe', BasicBlock, [1, 1, 1, 1], pretrained, progress,
384 | **kwargs)
385 |
386 |
387 |
388 | def resnet18(pretrained=False, progress=True, **kwargs):
389 |     r"""ResNet-18 model from
390 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
391 |
392 | Args:
393 | pretrained (bool): If True, returns a model pre-trained on ImageNet
394 | progress (bool): If True, displays a progress bar of the download to stderr
395 | """
396 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
397 | **kwargs)
398 |
399 |
400 | def resnet34(pretrained=False, progress=True, **kwargs):
401 |     r"""ResNet-34 model from
402 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
403 |
404 | Args:
405 | pretrained (bool): If True, returns a model pre-trained on ImageNet
406 | progress (bool): If True, displays a progress bar of the download to stderr
407 | """
408 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
409 | **kwargs)
410 |
411 |
412 | def resnet50(pretrained=False, progress=True, **kwargs):
413 |     r"""ResNet-50 model from
414 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
415 |
416 | Args:
417 | pretrained (bool): If True, returns a model pre-trained on ImageNet
418 | progress (bool): If True, displays a progress bar of the download to stderr
419 | """
420 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
421 | **kwargs)
422 |
423 | def resnet50_fe(pretrained=False, progress=True, **kwargs):
424 |     r"""ResNet-50 model (feature-extractor variant) from
425 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
426 |
427 | Args:
428 | pretrained (bool): If True, returns a model pre-trained on ImageNet
429 | progress (bool): If True, displays a progress bar of the download to stderr
430 | """
431 | return _resnet_fe('resnet50_fe', Bottleneck, [3, 4, 6, 3], pretrained, progress,
432 | **kwargs)
433 |
434 |
435 | def resnet101(pretrained=False, progress=True, **kwargs):
436 |     r"""ResNet-101 model from
437 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
438 |
439 | Args:
440 | pretrained (bool): If True, returns a model pre-trained on ImageNet
441 | progress (bool): If True, displays a progress bar of the download to stderr
442 | """
443 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
444 | **kwargs)
445 |
446 |
447 | def resnet101_fe(pretrained=False, progress=True, **kwargs):
448 |     r"""ResNet-101 model (feature-extractor variant) from
449 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
450 |
451 | Args:
452 | pretrained (bool): If True, returns a model pre-trained on ImageNet
453 | progress (bool): If True, displays a progress bar of the download to stderr
454 | """
455 | return _resnet_fe('resnet101_fe', Bottleneck, [3, 4, 23, 3], pretrained, progress,
456 | **kwargs)
457 |
458 | def resnet152(pretrained=False, progress=True, **kwargs):
459 |     r"""ResNet-152 model from
460 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
461 |
462 | Args:
463 | pretrained (bool): If True, returns a model pre-trained on ImageNet
464 | progress (bool): If True, displays a progress bar of the download to stderr
465 | """
466 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
467 | **kwargs)
468 |
469 |
470 | def resnet152_fe(pretrained=False, progress=True, **kwargs):
471 |     r"""ResNet-152 model (feature-extractor variant) from
472 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
473 |
474 | Args:
475 | pretrained (bool): If True, returns a model pre-trained on ImageNet
476 | progress (bool): If True, displays a progress bar of the download to stderr
477 |
478 |     Following Liu et al., "Large-Scale Long-Tailed Recognition in an Open World" (CVPR 2019),
479 |     only the last block of ResNet-152 is trained; layers = [3, 8, 36, 2] splits that block out here.
480 |
481 | """
482 | return _resnet_fe('resnet152_fe', Bottleneck, [3, 8, 36, 2], pretrained, progress,
483 | **kwargs)
484 |
485 |
486 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
487 |     r"""ResNeXt-50 32x4d model from
488 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
489 |
490 | Args:
491 | pretrained (bool): If True, returns a model pre-trained on ImageNet
492 | progress (bool): If True, displays a progress bar of the download to stderr
493 | """
494 | kwargs['groups'] = 32
495 | kwargs['width_per_group'] = 4
496 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
497 | pretrained, progress, **kwargs)
498 |
499 |
500 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
501 |     r"""ResNeXt-101 32x8d model from
502 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
503 |
504 | Args:
505 | pretrained (bool): If True, returns a model pre-trained on ImageNet
506 | progress (bool): If True, displays a progress bar of the download to stderr
507 | """
508 | kwargs['groups'] = 32
509 | kwargs['width_per_group'] = 8
510 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
511 | pretrained, progress, **kwargs)
512 |
513 | def resnext152_32x4d(pretrained=False, progress=True, **kwargs):
514 | kwargs['groups'] = 32
515 | kwargs['width_per_group'] = 4
516 | return _resnet('resnext152_32x4d', Bottleneck, [3, 8, 36, 3],
517 | pretrained, progress, **kwargs)
518 |
519 |
520 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
521 |     r"""Wide ResNet-50-2 model from
522 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
523 |
524 |     The model is the same as ResNet except that the number of channels in the
525 |     bottleneck is twice as large in every block. The number of channels in outer
526 |     1x1 convolutions is the same, e.g. the last block in ResNet-50 has
527 |     2048-512-2048 channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
528 |
529 | Args:
530 | pretrained (bool): If True, returns a model pre-trained on ImageNet
531 | progress (bool): If True, displays a progress bar of the download to stderr
532 | """
533 | kwargs['width_per_group'] = 64 * 2
534 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
535 | pretrained, progress, **kwargs)
536 |
537 |
538 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
539 |     r"""Wide ResNet-101-2 model from
540 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
541 |
542 |     The model is the same as ResNet except that the number of channels in the
543 |     bottleneck is twice as large in every block. The number of channels in outer
544 |     1x1 convolutions is the same, e.g. the last block in ResNet-50 has
545 |     2048-512-2048 channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
546 |
547 | Args:
548 | pretrained (bool): If True, returns a model pre-trained on ImageNet
549 | progress (bool): If True, displays a progress bar of the download to stderr
550 | """
551 | kwargs['width_per_group'] = 64 * 2
552 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
553 | pretrained, progress, **kwargs)
554 |
--------------------------------------------------------------------------------
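
A minimal sketch of how the pieces in models/resnet.py compose: the backbone's forward (see `ResNet.forward` above) returns the pre-pooling feature map, and `Classifier` applies avgpool -> flatten -> fc. The `feat_in=2048` wiring matches the call sites in train_stage1.py; the input shape and `num_classes` are illustrative.

```python
import torch

from models import resnet

backbone = resnet.resnet50(pretrained=False)                   # feature extractor; no fc layer
classifier = resnet.Classifier(feat_in=2048, num_classes=1000)

x = torch.randn(2, 3, 224, 224)   # illustrative input batch
feat = backbone(x)                # (2, 2048, 7, 7) pre-pooling feature map
logits = classifier(feat)         # (2, 1000) class logits
```
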
/reliability_diagrams.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import seaborn as sns
5 | import pdb
6 | sns.set_style("darkgrid")
7 | from matplotlib import rcParams
8 | # rcParams['font.family'] = 'Arial'
9 | rcParams['font.size'] = 75
10 |
11 | def compute_calibration(true_labels, pred_labels, confidences, num_bins=10):
12 | """Collects predictions into bins used to draw a reliability diagram.
13 |
14 | Arguments:
15 | true_labels: the true labels for the test examples
16 | pred_labels: the predicted labels for the test examples
17 | confidences: the predicted confidences for the test examples
18 | num_bins: number of bins
19 |
20 | The true_labels, pred_labels, confidences arguments must be NumPy arrays;
21 | pred_labels and true_labels may contain numeric or string labels.
22 |
23 | For a multi-class model, the predicted label and confidence should be those
24 | of the highest scoring class.
25 |
26 | Returns a dictionary containing the following NumPy arrays:
27 | accuracies: the average accuracy for each bin
28 | confidences: the average confidence for each bin
29 | counts: the number of examples in each bin
30 | bins: the confidence thresholds for each bin
31 | avg_accuracy: the accuracy over the entire test set
32 | avg_confidence: the average confidence over the entire test set
33 | expected_calibration_error: a weighted average of all calibration gaps
34 | max_calibration_error: the largest calibration gap across all bins
35 | """
36 | assert(len(confidences) == len(pred_labels))
37 | assert(len(confidences) == len(true_labels))
38 | assert(num_bins > 0)
39 |
40 | bin_size = 1.0 / num_bins
41 | bins = np.linspace(0.0, 1.0, num_bins + 1)
42 | indices = np.digitize(confidences, bins, right=True)
43 |
44 |     bin_accuracies = np.zeros(num_bins, dtype=float)  # np.float was removed in NumPy 1.24
45 |     bin_confidences = np.zeros(num_bins, dtype=float)
46 |     bin_counts = np.zeros(num_bins, dtype=int)
47 |
48 | for b in range(num_bins):
49 | selected = np.where(indices == b + 1)[0]
50 | if len(selected) > 0:
51 | bin_accuracies[b] = np.mean(true_labels[selected] == pred_labels[selected])
52 | bin_confidences[b] = np.mean(confidences[selected])
53 | bin_counts[b] = len(selected)
54 |
55 | avg_acc = np.sum(bin_accuracies * bin_counts) / np.sum(bin_counts)
56 | avg_conf = np.sum(bin_confidences * bin_counts) / np.sum(bin_counts)
57 |
58 | gaps = np.abs(bin_accuracies - bin_confidences)
59 | ece = np.sum(gaps * bin_counts) / np.sum(bin_counts)
60 | mce = np.max(gaps)
61 |
62 | return { "accuracies": bin_accuracies,
63 | "confidences": bin_confidences,
64 | "counts": bin_counts,
65 | "bins": bins,
66 | "avg_accuracy": avg_acc,
67 | "avg_confidence": avg_conf,
68 | "expected_calibration_error": ece,
69 | "max_calibration_error": mce }
70 |
71 |
72 | def _reliability_diagram_subplot(ax, bin_data, draw_order=True,
73 | draw_ece=True,
74 | draw_bin_importance=False,
75 | title="Reliability Diagram",
76 | xlabel="Confidence",
77 | ylabel="Expected Accuracy"):
78 | """Draws a reliability diagram into a subplot."""
79 | accuracies = bin_data["accuracies"]
80 | confidences = bin_data["confidences"]
81 | counts = bin_data["counts"]
82 | bins = bin_data["bins"]
83 |
84 | bin_size = 1.0 / len(counts)
85 | positions = bins[:-1] + bin_size/2.0
86 |
87 | widths = bin_size
88 | alphas = 0.3
89 | min_count = np.min(counts)
90 | max_count = np.max(counts)
91 | normalized_counts = (counts - min_count) / (max_count - min_count)
92 |
93 | if draw_bin_importance == "alpha":
94 | alphas = 0.2 + 0.8*normalized_counts
95 | elif draw_bin_importance == "width":
96 | widths = 0.1*bin_size + 0.9*bin_size*normalized_counts
97 |
98 | colors = np.zeros((len(counts), 4))
99 | colors[:, 0] = 240 / 255.
100 | colors[:, 1] = 60 / 255.
101 | colors[:, 2] = 60 / 255.
102 | colors[:, 3] = alphas
103 |
104 |     if draw_order:  # NOTE: the two branches below draw identical bars
105 | # pdb.set_trace()
106 | acc_plt = ax.bar(positions, accuracies, bottom=0, width=widths,
107 | color=sns.color_palette("Blues", 10)[4], edgecolor="white", alpha=1.0, linewidth=5,
108 | label="Accuracy")
109 |
110 | gap_plt = ax.bar(positions, np.abs(accuracies - confidences),
111 | bottom=np.minimum(accuracies, confidences), width=widths,
112 | color=sns.color_palette("Blues", 10)[7], edgecolor="white", linewidth=5, label="Gap")
113 | else:
114 | acc_plt = ax.bar(positions, accuracies, bottom=0, width=widths,
115 | color=sns.color_palette("Blues", 10)[4], edgecolor="white", alpha=1.0, linewidth=5,
116 | label="Accuracy")
117 |
118 | gap_plt = ax.bar(positions, np.abs(accuracies - confidences),
119 | bottom=np.minimum(accuracies, confidences), width=widths,
120 | color=sns.color_palette("Blues", 10)[7], edgecolor="white", linewidth=5, label="Gap")
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 | ax.set_aspect("equal")
129 | ax.plot([0,1], [0,1], linestyle = "--", color="gray", linewidth=10)
130 | ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
131 | ax.set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
132 |
133 | if draw_ece:
134 | import matplotlib.patches as patches
135 | ax.add_patch(
136 | patches.Rectangle(
137 | (0.51, 0.015), # (x,y)
138 | 0.48, 0.2, # width and height
139 | # You can add rotation as well with 'angle'
140 | alpha=0.7, facecolor="white", edgecolor="white", linewidth=3, linestyle='solid')
141 | )
142 | acc_avg = (bin_data["avg_accuracy"])
143 | # ax.text(0.26, 0.79, "ECE=%.3f" % ece, color="black",
144 | # ha="right", va="bottom", transform=ax.transAxes)
145 | ax.text(0.98, 0.11, "ACC=%.3f" % acc_avg, color="black",
146 | ha="right", va="bottom", transform=ax.transAxes)
147 | ece = (bin_data["expected_calibration_error"])
148 | # ax.text(0.26, 0.79, "ECE=%.3f" % ece, color="black",
149 | # ha="right", va="bottom", transform=ax.transAxes)
150 | ax.text(0.98, 0.03, "ECE=%.3f" % ece, color="black",
151 | ha="right", va="bottom", transform=ax.transAxes)
152 |
153 | ax.set_xlim(0, 1)
154 | ax.set_ylim(0, 1)
155 | #ax.set_xticks(bins)
156 |
157 | # ax.set_title(title, pad=40, fontsize=50)
158 | # ttl = ax.title
159 | # ttl.set_position([.45, 1.00])
160 | ax.set_xlabel(xlabel)
161 | ax.set_ylabel(ylabel)
162 |
163 | ax.legend(handles=[gap_plt, acc_plt])
164 |
165 |
166 | def _confidence_histogram_subplot(ax, bin_data,
167 | draw_averages=True,
168 | title="Examples per bin",
169 | xlabel="Confidence",
170 | ylabel="Count"):
171 | """Draws a confidence histogram into a subplot."""
172 | counts = bin_data["counts"]
173 | bins = bin_data["bins"]
174 |
175 | bin_size = 1.0 / len(counts)
176 | positions = bins[:-1] + bin_size/2.0
177 |
178 | ax.bar(positions, counts, width=bin_size * 0.9)
179 |
180 | ax.set_xlim(0, 1)
181 | ax.set_title(title)
182 | ax.set_xlabel(xlabel)
183 | ax.set_ylabel(ylabel)
184 |
185 | if draw_averages:
186 | acc_plt = ax.axvline(x=bin_data["avg_accuracy"], ls="solid", lw=3,
187 | c="black", label="Accuracy")
188 | conf_plt = ax.axvline(x=bin_data["avg_confidence"], ls="dotted", lw=3,
189 | c="#444", label="Avg. confidence")
190 | ax.legend(handles=[acc_plt, conf_plt])
191 |
192 |
193 | def _reliability_diagram_combined(bin_data,
194 | draw_ece, draw_bin_importance, draw_averages,
195 | title, figsize, dpi, return_fig):
196 | """Draws a reliability diagram and confidence histogram using the output
197 | from compute_calibration()."""
198 | figsize = (figsize[0], figsize[0])
199 |
200 | fig, ax = plt.subplots(nrows=1, ncols=1, sharex=True, figsize=figsize, dpi=dpi)
201 |
202 | plt.tight_layout()
203 | plt.subplots_adjust(hspace=-0.1)
204 |
205 |     _reliability_diagram_subplot(ax, bin_data, draw_ece=draw_ece,
206 |                                  draw_bin_importance=draw_bin_importance, title=title, xlabel="Confidence")
207 |
208 | # Draw the confidence histogram upside down.
209 | # orig_counts = bin_data["counts"]
210 | # bin_data["counts"] = -bin_data["counts"]
211 | # _confidence_histogram_subplot(ax[1], bin_data, draw_averages, title="")
212 | # bin_data["counts"] = orig_counts
213 |
214 | # # Also negate the ticks for the upside-down histogram.
215 | # new_ticks = np.abs(ax[1].get_yticks()).astype(np.int)
216 | # ax[1].set_yticklabels(new_ticks)
217 |
218 | plt.show()
219 |
220 | if return_fig: return fig
221 |
222 |
223 | def reliability_diagram(true_labels, pred_labels, confidences, num_bins=10,
224 | draw_ece=True, draw_bin_importance=False,
225 | draw_averages=True, title="Reliability Diagram",
226 | figsize=(20, 20), dpi=72, return_fig=False):
227 | """Draws a reliability diagram and confidence histogram in a single plot.
228 |
229 | First, the model's predictions are divided up into bins based on their
230 | confidence scores.
231 |
232 |     The reliability diagram shows the gap between average accuracy and average
233 |     confidence in each bin. These are the "Gap" bars.
234 |
235 |     One end of each gap bar is the accuracy; the other end is the confidence.
236 |
237 |     Ideally, there is no gap and the accuracy bars reach the dashed diagonal.
238 | In that case, the model is properly calibrated and we can interpret the
239 | confidence scores as probabilities.
240 |
241 | The confidence histogram visualizes how many examples are in each bin.
242 | This is useful for judging how much each bin contributes to the calibration
243 | error.
244 |
245 | The confidence histogram also shows the overall accuracy and confidence.
246 | The closer these two lines are together, the better the calibration.
247 |
248 | The ECE or Expected Calibration Error is a summary statistic that gives the
249 | difference in expectation between confidence and accuracy. In other words,
250 | it's a weighted average of the gaps across all bins. A lower ECE is better.
251 |
252 | Arguments:
253 | true_labels: the true labels for the test examples
254 | pred_labels: the predicted labels for the test examples
255 | confidences: the predicted confidences for the test examples
256 | num_bins: number of bins
257 | draw_ece: whether to include the Expected Calibration Error
258 | draw_bin_importance: whether to represent how much each bin contributes
259 |             to the total accuracy: False, "alpha", "width"
260 | draw_averages: whether to draw the overall accuracy and confidence in
261 | the confidence histogram
262 | title: optional title for the plot
263 | figsize: setting for matplotlib; height is ignored
264 | dpi: setting for matplotlib
265 | return_fig: if True, returns the matplotlib Figure object
266 | """
267 | bin_data = compute_calibration(true_labels, pred_labels, confidences, num_bins)
268 | return _reliability_diagram_combined(bin_data, draw_ece, draw_bin_importance,
269 | draw_averages, title, figsize=figsize,
270 | dpi=dpi, return_fig=return_fig)
271 |
272 |
273 | def reliability_diagrams(results, num_bins=10,
274 | draw_ece=True, draw_bin_importance=False,
275 | num_cols=4, dpi=72, return_fig=False):
276 | """Draws reliability diagrams for one or more models.
277 |
278 | Arguments:
279 | results: dictionary where the key is the model name and the value is
280 |             a dictionary containing the true labels, predicted labels, and
281 | confidences for this model
282 | num_bins: number of bins
283 | draw_ece: whether to include the Expected Calibration Error
284 | draw_bin_importance: whether to represent how much each bin contributes
285 | to the total accuracy: False, "alpha", "widths"
286 | num_cols: how wide to make the plot
287 | dpi: setting for matplotlib
288 | return_fig: if True, returns the matplotlib Figure object
289 | """
290 | ncols = num_cols
291 |
292 | nrows = (len(results) + ncols - 1) // ncols
293 | figsize = (ncols * 16, nrows * 16)
294 |
295 |     fig, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,
296 |                            figsize=figsize, dpi=dpi, constrained_layout=True, squeeze=False)
297 | fig.text(0.5, 0.00, 'Confidence', ha='center')
298 |
299 | for i, (plot_name, data) in enumerate(results.items()):
300 | y_true = data["true_labels"]
301 | y_pred = data["pred_labels"]
302 | y_conf = data["confidences"]
303 |
304 | bin_data = compute_calibration(y_true, y_pred, y_conf, num_bins)
305 |
306 | row = i // ncols
307 | col = i % ncols
308 | # pdb.set_trace()
309 | draw_order = True
310 | if i == 3:
311 | draw_order = False
312 | # pdb.set_trace()
313 |         _reliability_diagram_subplot(ax[row, col], bin_data, draw_order, draw_ece,
314 | draw_bin_importance,
315 | title="\n".join(plot_name.split()),
316 | xlabel=" " if row == nrows - 1 else "",
317 | ylabel="Accuracy" if col == 0 else "")
318 |
319 |     for j in range(i + 1, nrows * ncols):
320 |         row = j // ncols
321 |         col = j % ncols
322 |         ax[row, col].axis("off")
323 |
324 | # plt.show()
325 |
326 | if return_fig: return fig
--------------------------------------------------------------------------------
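
compute_calibration implements the standard binned ECE: with N samples total and n_b samples, accuracy acc_b, and mean confidence conf_b in bin b, ECE = sum_b (n_b / N) * |acc_b - conf_b|. A quick self-contained check on synthetic, deliberately over-confident predictions (values are illustrative):

```python
import numpy as np

from reliability_diagrams import compute_calibration

rng = np.random.default_rng(0)
n = 1000
true_labels = rng.integers(0, 10, size=n)
pred_labels = true_labels.copy()
wrong = rng.random(n) < 0.3                        # flip ~30% of predictions
pred_labels[wrong] = (pred_labels[wrong] + 1) % 10
confidences = rng.uniform(0.85, 1.0, size=n)       # ~0.93 mean confidence vs ~0.70 accuracy

cal = compute_calibration(true_labels, pred_labels, confidences, num_bins=10)
print(cal["avg_accuracy"], cal["avg_confidence"])  # ~0.70 vs ~0.93
print(cal["expected_calibration_error"])           # large gap => poorly calibrated
```
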
/requirements.txt:
--------------------------------------------------------------------------------
1 | torchvision==0.4.0
2 | torch==1.2.0
3 | numpy==1.19.2
4 | yacs==0.1.8
5 |
--------------------------------------------------------------------------------
/train_stage1.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 | import numpy as np
8 | import pprint
9 | import math
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.parallel
14 | import torch.distributed as dist
15 | import torch.optim
16 | import torch.multiprocessing as mp
17 | import torch.utils.data
18 | import torch.utils.data.distributed
19 | import torch.nn.functional as F
20 |
21 | from datasets.cifar10 import CIFAR10_LT
22 | from datasets.cifar100 import CIFAR100_LT
23 | from datasets.places import Places_LT
24 | from datasets.imagenet import ImageNet_LT
25 | from datasets.ina2018 import iNa2018
26 |
27 | from models import resnet
28 | from models import resnet_places
29 | from models import resnet_cifar
30 |
31 | from utils import config, update_config, create_logger
32 | from utils import AverageMeter, ProgressMeter
33 | from utils import accuracy, calibration
34 |
35 | from methods import mixup_data, mixup_criterion
36 |
37 |
38 | def parse_args():
39 | parser = argparse.ArgumentParser(description='MiSLAS training (Stage-1)')
40 | parser.add_argument('--cfg',
41 | help='experiment configure file name',
42 | required=True,
43 | type=str)
44 | parser.add_argument('opts',
45 | help="Modify config options using the command-line",
46 | default=None,
47 | nargs=argparse.REMAINDER)
48 | args = parser.parse_args()
49 | update_config(config, args)
50 |
51 | return args
52 |
53 |
54 | best_acc1 = 0
55 | its_ece = 100
56 |
57 |
58 | def main():
59 | args = parse_args()
60 | logger, model_dir = create_logger(config, args.cfg)
61 | logger.info('\n' + pprint.pformat(args))
62 | logger.info('\n' + str(config))
63 |
64 | if config.deterministic:
65 | seed = 0
66 | torch.backends.cudnn.deterministic = True
67 | torch.backends.cudnn.benchmark = False
68 | random.seed(seed)
69 | np.random.seed(seed)
70 | os.environ['PYTHONHASHSEED'] = str(seed)
71 | torch.manual_seed(seed)
72 | torch.cuda.manual_seed(seed)
73 | torch.cuda.manual_seed_all(seed)
74 |
75 | if config.gpu is not None:
76 | warnings.warn('You have chosen a specific GPU. This will completely '
77 | 'disable data parallelism.')
78 |
79 | if config.dist_url == "env://" and config.world_size == -1:
80 | config.world_size = int(os.environ["WORLD_SIZE"])
81 |
82 | config.distributed = config.world_size > 1 or config.multiprocessing_distributed
83 |
84 | ngpus_per_node = torch.cuda.device_count()
85 | if config.multiprocessing_distributed:
86 | # Since we have ngpus_per_node processes per node, the total world_size
87 | # needs to be adjusted accordingly
88 | config.world_size = ngpus_per_node * config.world_size
89 | # Use torch.multiprocessing.spawn to launch distributed processes: the
90 | # main_worker process function
91 |         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config, logger, model_dir))
92 | else:
93 | # Simply call main_worker function
94 | main_worker(config.gpu, ngpus_per_node, config, logger, model_dir)
95 |
96 |
97 | def main_worker(gpu, ngpus_per_node, config, logger, model_dir):
98 | global best_acc1, its_ece
99 | config.gpu = gpu
100 | # start_time = time.strftime("%Y%m%d_%H%M%S", time.localtime())
101 |
102 | if config.gpu is not None:
103 | logger.info("Use GPU: {} for training".format(config.gpu))
104 |
105 | if config.distributed:
106 | if config.dist_url == "env://" and config.rank == -1:
107 | config.rank = int(os.environ["RANK"])
108 | if config.multiprocessing_distributed:
109 | # For multiprocessing distributed training, rank needs to be the
110 | # global rank among all the processes
111 | config.rank = config.rank * ngpus_per_node + gpu
112 | dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url,
113 | world_size=config.world_size, rank=config.rank)
114 |
115 | if config.dataset == 'cifar10' or config.dataset == 'cifar100':
116 | model = getattr(resnet_cifar, config.backbone)()
117 | classifier = getattr(resnet_cifar, 'Classifier')(feat_in=64, num_classes=config.num_classes)
118 |
119 | elif config.dataset == 'imagenet' or config.dataset == 'ina2018':
120 | model = getattr(resnet, config.backbone)()
121 | classifier = getattr(resnet, 'Classifier')(feat_in=2048, num_classes=config.num_classes)
122 |
123 | elif config.dataset == 'places':
124 | model = getattr(resnet_places, config.backbone)(pretrained=True)
125 | classifier = getattr(resnet_places, 'Classifier')(feat_in=2048, num_classes=config.num_classes)
126 | block = getattr(resnet_places, 'Bottleneck')(2048, 512, groups=1, base_width=64, dilation=1, norm_layer=nn.BatchNorm2d)
127 |
128 | if not torch.cuda.is_available():
129 | logger.info('using CPU, this will be slow')
130 | elif config.distributed:
131 | # For multiprocessing distributed, DistributedDataParallel constructor
132 | # should always set the single device scope, otherwise,
133 | # DistributedDataParallel will use all available devices.
134 | if config.gpu is not None:
135 | torch.cuda.set_device(config.gpu)
136 | model.cuda(config.gpu)
137 | classifier.cuda(config.gpu)
138 | # When using a single GPU per process and per
139 | # DistributedDataParallel, we need to divide the batch size
140 | # ourselves based on the total number of GPUs we have
141 | config.batch_size = int(config.batch_size / ngpus_per_node)
142 | config.workers = int((config.workers + ngpus_per_node - 1) / ngpus_per_node)
143 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu])
144 | classifier = torch.nn.parallel.DistributedDataParallel(classifier, device_ids=[config.gpu])
145 | if config.dataset == 'places':
146 | block.cuda(config.gpu)
147 | block = torch.nn.parallel.DistributedDataParallel(block, device_ids=[config.gpu])
148 | else:
149 | model.cuda()
150 | classifier.cuda()
151 | # DistributedDataParallel will divide and allocate batch_size to all
152 | # available GPUs if device_ids are not set
153 | model = torch.nn.parallel.DistributedDataParallel(model)
154 | classifier = torch.nn.parallel.DistributedDataParallel(classifier)
155 | if config.dataset == 'places':
156 | block.cuda()
157 | block = torch.nn.parallel.DistributedDataParallel(block)
158 | elif config.gpu is not None:
159 | torch.cuda.set_device(config.gpu)
160 | model = model.cuda(config.gpu)
161 | classifier = classifier.cuda(config.gpu)
162 | if config.dataset == 'places':
163 | block.cuda(config.gpu)
164 | else:
165 | # DataParallel will divide and allocate batch_size to all available GPUs
166 | model = torch.nn.DataParallel(model).cuda()
167 | classifier = torch.nn.DataParallel(classifier).cuda()
168 | if config.dataset == 'places':
169 | block = torch.nn.DataParallel(block).cuda()
170 |
171 | # optionally resume from a checkpoint
172 | if config.resume:
173 | if os.path.isfile(config.resume):
174 | logger.info("=> loading checkpoint '{}'".format(config.resume))
175 | if config.gpu is None:
176 | checkpoint = torch.load(config.resume)
177 | else:
178 | # Map model to be loaded to specified single gpu.
179 | loc = 'cuda:{}'.format(config.gpu)
180 | checkpoint = torch.load(config.resume, map_location=loc)
181 | # config.start_epoch = checkpoint['epoch']
182 | best_acc1 = checkpoint['best_acc1']
183 | if config.gpu is not None:
184 | # best_acc1 may be from a checkpoint from a different GPU
185 | best_acc1 = best_acc1.to(config.gpu)
186 | model.load_state_dict(checkpoint['state_dict_model'])
187 | classifier.load_state_dict(checkpoint['state_dict_classifier'])
188 | logger.info("=> loaded checkpoint '{}' (epoch {})"
189 | .format(config.resume, checkpoint['epoch']))
190 | else:
191 | logger.info("=> no checkpoint found at '{}'".format(config.resume))
192 |
193 | # Data loading code
194 | if config.dataset == 'cifar10':
195 | dataset = CIFAR10_LT(config.distributed, root=config.data_path, imb_factor=config.imb_factor,
196 | batch_size=config.batch_size, num_works=config.workers)
197 |
198 | elif config.dataset == 'cifar100':
199 | dataset = CIFAR100_LT(config.distributed, root=config.data_path, imb_factor=config.imb_factor,
200 | batch_size=config.batch_size, num_works=config.workers)
201 |
202 | elif config.dataset == 'places':
203 | dataset = Places_LT(config.distributed, root=config.data_path,
204 | batch_size=config.batch_size, num_works=config.workers)
205 |
206 | elif config.dataset == 'imagenet':
207 | dataset = ImageNet_LT(config.distributed, root=config.data_path,
208 | batch_size=config.batch_size, num_works=config.workers)
209 |
210 | elif config.dataset == 'ina2018':
211 | dataset = iNa2018(config.distributed, root=config.data_path,
212 | batch_size=config.batch_size, num_works=config.workers)
213 |
214 | train_loader = dataset.train_instance
215 | val_loader = dataset.eval
216 | if config.distributed:
217 | train_sampler = dataset.dist_sampler
218 |
219 | # define loss function (criterion) and optimizer
220 | criterion = nn.CrossEntropyLoss().cuda(config.gpu)
221 |
222 | if config.dataset == 'places':
223 | optimizer = torch.optim.SGD([{"params": block.parameters()},
224 | {"params": classifier.parameters()}], config.lr,
225 | momentum=config.momentum,
226 | weight_decay=config.weight_decay)
227 | else:
228 | optimizer = torch.optim.SGD([{"params": model.parameters()},
229 | {"params": classifier.parameters()}], config.lr,
230 | momentum=config.momentum,
231 | weight_decay=config.weight_decay)
232 |
233 | for epoch in range(config.num_epochs):
234 | if config.distributed:
235 | train_sampler.set_epoch(epoch)
236 |
237 | adjust_learning_rate(optimizer, epoch, config)
238 |
239 | if config.dataset != 'places':
240 | block = None
241 | # train for one epoch
242 | train(train_loader, model, classifier, criterion, optimizer, epoch, config, logger, block)
243 |
244 | # evaluate on validation set
245 | acc1, ece = validate(val_loader, model, classifier, criterion, config, logger, block)
246 |
247 | # remember best acc@1 and save checkpoint
248 | is_best = acc1 > best_acc1
249 | best_acc1 = max(acc1, best_acc1)
250 | if is_best:
251 | its_ece = ece
252 | logger.info('Best Prec@1: %.3f%% ECE: %.3f%%\n' % (best_acc1, its_ece))
253 |
254 | if not config.multiprocessing_distributed or (config.multiprocessing_distributed
255 | and config.rank % ngpus_per_node == 0):
256 | if config.dataset == 'places':
257 | save_checkpoint({
258 | 'epoch': epoch + 1,
259 | 'state_dict_model': model.state_dict(),
260 | 'state_dict_classifier': classifier.state_dict(),
261 | 'state_dict_block': block.state_dict(),
262 | 'best_acc1': best_acc1,
263 | 'its_ece': its_ece,
264 | }, is_best, model_dir)
265 |
266 | else:
267 | save_checkpoint({
268 | 'epoch': epoch + 1,
269 | 'state_dict_model': model.state_dict(),
270 | 'state_dict_classifier': classifier.state_dict(),
271 | 'best_acc1': best_acc1,
272 | 'its_ece': its_ece,
273 | }, is_best, model_dir)
274 |
275 |
276 | def train(train_loader, model, classifier, criterion, optimizer, epoch, config, logger, block=None):
277 | batch_time = AverageMeter('Time', ':6.3f')
278 | data_time = AverageMeter('Data', ':6.3f')
279 | losses = AverageMeter('Loss', ':.3f')
280 | top1 = AverageMeter('Acc@1', ':6.3f')
281 | top5 = AverageMeter('Acc@5', ':6.3f')
282 | progress = ProgressMeter(
283 | len(train_loader),
284 | [batch_time, losses, top1, top5],
285 | prefix="Epoch: [{}]".format(epoch))
286 |
287 | # switch to train mode
288 | if config.dataset == 'places':
289 | model.eval()
290 | block.train()
291 | else:
292 | model.train()
293 | classifier.train()
294 |
295 | training_data_num = len(train_loader.dataset)
296 | end_steps = int(training_data_num / train_loader.batch_size)
297 |
298 | end = time.time()
299 | for i, (images, target) in enumerate(train_loader):
300 | if i > end_steps:
301 | break
302 |
303 | # measure data loading time
304 | data_time.update(time.time() - end)
305 |
306 | if torch.cuda.is_available():
307 | images = images.cuda(config.gpu, non_blocking=True)
308 | target = target.cuda(config.gpu, non_blocking=True)
309 |
310 | if config.mixup is True:
311 | images, targets_a, targets_b, lam = mixup_data(images, target, alpha=config.alpha)
312 | if config.dataset == 'places':
313 | with torch.no_grad():
314 | feat_a = model(images)
315 | feat = block(feat_a.detach())
316 | output = classifier(feat)
317 | else:
318 | feat = model(images)
319 | output = classifier(feat)
320 | loss = mixup_criterion(criterion, output, targets_a, targets_b, lam)
321 | else:
322 | if config.dataset == 'places':
323 | with torch.no_grad():
324 | feat_a = model(images)
325 | feat = block(feat_a.detach())
326 | output = classifier(feat)
327 | else:
328 | feat = model(images)
329 | output = classifier(feat)
330 |
331 | loss = criterion(output, target)
332 |
333 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
334 | losses.update(loss.item(), images.size(0))
335 | top1.update(acc1[0], images.size(0))
336 | top5.update(acc5[0], images.size(0))
337 |
338 | # compute gradient and do SGD step
339 | optimizer.zero_grad()
340 | loss.backward()
341 | optimizer.step()
342 |
343 | # measure elapsed time
344 | batch_time.update(time.time() - end)
345 | end = time.time()
346 |
347 | if i % config.print_freq == 0:
348 | progress.display(i, logger)
349 |
350 |
351 | def validate(val_loader, model, classifier, criterion, config, logger, block=None):
352 | batch_time = AverageMeter('Time', ':6.3f')
353 | losses = AverageMeter('Loss', ':.3f')
354 | top1 = AverageMeter('Acc@1', ':6.3f')
355 | top5 = AverageMeter('Acc@5', ':6.3f')
356 | progress = ProgressMeter(
357 | len(val_loader),
358 | [batch_time, losses, top1, top5],
359 | prefix='Eval: ')
360 |
361 | # switch to evaluate mode
362 | model.eval()
363 | if config.dataset == 'places':
364 | block.eval()
365 | classifier.eval()
366 | class_num = torch.zeros(config.num_classes).cuda()
367 | correct = torch.zeros(config.num_classes).cuda()
368 |
369 | confidence = np.array([])
370 | pred_class = np.array([])
371 | true_class = np.array([])
372 |
373 | with torch.no_grad():
374 | end = time.time()
375 | for i, (images, target) in enumerate(val_loader):
376 | if config.gpu is not None:
377 | images = images.cuda(config.gpu, non_blocking=True)
378 | if torch.cuda.is_available():
379 | target = target.cuda(config.gpu, non_blocking=True)
380 |
381 | # compute output
382 | feat = model(images)
383 | if config.dataset == 'places':
384 | feat = block(feat)
385 | output = classifier(feat)
386 | loss = criterion(output, target)
387 |
388 | # measure accuracy and record loss
389 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
390 | losses.update(loss.item(), images.size(0))
391 | top1.update(acc1[0], images.size(0))
392 | top5.update(acc5[0], images.size(0))
393 |
394 | _, predicted = output.max(1)
395 | target_one_hot = F.one_hot(target, config.num_classes)
396 | predict_one_hot = F.one_hot(predicted, config.num_classes)
397 | class_num = class_num + target_one_hot.sum(dim=0).to(torch.float)
398 | correct = correct + (target_one_hot + predict_one_hot == 2).sum(dim=0).to(torch.float)
399 |
400 | prob = torch.softmax(output, dim=1)
401 | confidence_part, pred_class_part = torch.max(prob, dim=1)
402 | confidence = np.append(confidence, confidence_part.cpu().numpy())
403 | pred_class = np.append(pred_class, pred_class_part.cpu().numpy())
404 | true_class = np.append(true_class, target.cpu().numpy())
405 |
406 | # measure elapsed time
407 | batch_time.update(time.time() - end)
408 | end = time.time()
409 |
410 | if i % config.print_freq == 0:
411 | progress.display(i, logger)
412 |
413 | acc_classes = correct / class_num
414 | head_acc = acc_classes[config.head_class_idx[0]:config.head_class_idx[1]].mean() * 100
415 |
416 | med_acc = acc_classes[config.med_class_idx[0]:config.med_class_idx[1]].mean() * 100
417 | tail_acc = acc_classes[config.tail_class_idx[0]:config.tail_class_idx[1]].mean() * 100
418 | logger.info('* Acc@1 {top1.avg:.3f}% Acc@5 {top5.avg:.3f}% HAcc {head_acc:.3f}% MAcc {med_acc:.3f}% TAcc {tail_acc:.3f}%.'.format(top1=top1, top5=top5, head_acc=head_acc, med_acc=med_acc, tail_acc=tail_acc))
419 |
420 | cal = calibration(true_class, pred_class, confidence, num_bins=15)
421 | logger.info('* ECE {ece:.3f}%.'.format(ece=cal['expected_calibration_error'] * 100))
422 |
423 | return top1.avg, cal['expected_calibration_error'] * 100
424 |
425 |
426 | def save_checkpoint(state, is_best, model_dir):
427 | filename = model_dir + '/current.pth.tar'
428 | torch.save(state, filename)
429 | if is_best:
430 | shutil.copyfile(filename, model_dir + '/model_best.pth.tar')
431 |
432 |
433 | def adjust_learning_rate(optimizer, epoch, config):
434 | """Sets the learning rate"""
435 | if config.cos:
436 | lr_min = 0
437 | lr_max = config.lr
438 |         lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(epoch / config.num_epochs * math.pi))
439 | else:
440 | epoch = epoch + 1
441 | if epoch <= 5:
442 | lr = config.lr * epoch / 5
443 | elif epoch > 180:
444 | lr = config.lr * 0.01
445 | elif epoch > 160:
446 | lr = config.lr * 0.1
447 | else:
448 | lr = config.lr
449 |
450 | for param_group in optimizer.param_groups:
451 | param_group['lr'] = lr
452 |
453 |
454 | if __name__ == '__main__':
455 | main()
456 |
--------------------------------------------------------------------------------
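
train_stage1.py relies on `mixup_data` and `mixup_criterion` from methods.py, which is not reproduced in this listing. A sketch consistent with the call sites above, following the reference formulation of Zhang et al., "mixup: Beyond Empirical Risk Minimization" (the actual code in methods.py may differ in detail):

```python
import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    """Mix a batch with a shuffled copy of itself; lam ~ Beta(alpha, alpha)."""
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    return mixed_x, y, y[index], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """The loss is the same convex combination applied to both label sets."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
```
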
/train_stage2.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 | import numpy as np
8 | import pprint
9 | import math
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.parallel
14 | import torch.distributed as dist
15 | import torch.optim
16 | import torch.multiprocessing as mp
17 | import torch.utils.data
18 | import torch.utils.data.distributed
19 | import torch.nn.functional as F
20 |
21 | from datasets.cifar10 import CIFAR10_LT
22 | from datasets.cifar100 import CIFAR100_LT
23 | from datasets.places import Places_LT
24 | from datasets.imagenet import ImageNet_LT
25 | from datasets.ina2018 import iNa2018
26 |
27 | from models import resnet
28 | from models import resnet_places
29 | from models import resnet_cifar
30 |
31 | from utils import config, update_config, create_logger
32 | from utils import AverageMeter, ProgressMeter
33 | from utils import accuracy, calibration
34 |
35 | from methods import mixup_data, mixup_criterion
36 | from methods import LabelAwareSmoothing, LearnableWeightScaling
37 |
38 |
39 | def parse_args():
40 | parser = argparse.ArgumentParser(description='MiSLAS training (Stage-2)')
41 | parser.add_argument('--cfg',
42 | help='experiment configure file name',
43 | required=True,
44 | type=str)
45 | parser.add_argument('opts',
46 | help="Modify config options using the command-line",
47 | default=None,
48 | nargs=argparse.REMAINDER)
49 | args = parser.parse_args()
50 | update_config(config, args)
51 |
52 | return args
53 |
54 |
55 | best_acc1 = 0
56 | its_ece = 100
57 |
58 |
59 | def main():
60 |
61 | args = parse_args()
62 | logger, model_dir = create_logger(config, args.cfg)
63 | logger.info('\n' + pprint.pformat(args))
64 | logger.info('\n' + str(config))
65 |
66 | if config.deterministic:
67 | seed = 0
68 | torch.backends.cudnn.deterministic = True
69 | torch.backends.cudnn.benchmark = False
70 | random.seed(seed)
71 | np.random.seed(seed)
72 | os.environ['PYTHONHASHSEED'] = str(seed)
73 | torch.manual_seed(seed)
74 | torch.cuda.manual_seed(seed)
75 | torch.cuda.manual_seed_all(seed)
76 |
77 | if config.gpu is not None:
78 | warnings.warn('You have chosen a specific GPU. This will completely '
79 | 'disable data parallelism.')
80 |
81 | if config.dist_url == "env://" and config.world_size == -1:
82 | config.world_size = int(os.environ["WORLD_SIZE"])
83 |
84 | config.distributed = config.world_size > 1 or config.multiprocessing_distributed
85 |
86 | ngpus_per_node = torch.cuda.device_count()
87 | if config.multiprocessing_distributed:
88 | # Since we have ngpus_per_node processes per node, the total world_size
89 | # needs to be adjusted accordingly
90 | config.world_size = ngpus_per_node * config.world_size
91 | # Use torch.multiprocessing.spawn to launch distributed processes: the
92 | # main_worker process function
93 |         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config, logger, model_dir))
94 | else:
95 | # Simply call main_worker function
96 | main_worker(config.gpu, ngpus_per_node, config, logger, model_dir)
97 |
98 |
99 | def main_worker(gpu, ngpus_per_node, config, logger, model_dir):
100 | global best_acc1, its_ece
101 | config.gpu = gpu
102 | # start_time = time.strftime("%Y%m%d_%H%M%S", time.localtime())
103 |
104 | if config.gpu is not None:
105 | logger.info("Use GPU: {} for training".format(config.gpu))
106 |
107 | if config.distributed:
108 | if config.dist_url == "env://" and config.rank == -1:
109 | config.rank = int(os.environ["RANK"])
110 | if config.multiprocessing_distributed:
111 | # For multiprocessing distributed training, rank needs to be the
112 | # global rank among all the processes
113 | config.rank = config.rank * ngpus_per_node + gpu
114 | dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url,
115 | world_size=config.world_size, rank=config.rank)
116 |
117 | if config.dataset == 'cifar10' or config.dataset == 'cifar100':
118 | model = getattr(resnet_cifar, config.backbone)()
119 | classifier = getattr(resnet_cifar, 'Classifier')(feat_in=64, num_classes=config.num_classes)
120 |
121 | elif config.dataset == 'imagenet' or config.dataset == 'ina2018':
122 | model = getattr(resnet, config.backbone)()
123 | classifier = getattr(resnet, 'Classifier')(feat_in=2048, num_classes=config.num_classes)
124 |
125 | elif config.dataset == 'places':
126 | model = getattr(resnet_places, config.backbone)(pretrained=True)
127 | classifier = getattr(resnet_places, 'Classifier')(feat_in=2048, num_classes=config.num_classes)
128 | block = getattr(resnet_places, 'Bottleneck')(2048, 512, groups=1, base_width=64,
129 | dilation=1, norm_layer=nn.BatchNorm2d)
130 |
131 | lws_model = LearnableWeightScaling(num_classes=config.num_classes)
132 |
133 | if not torch.cuda.is_available():
134 | logger.info('using CPU, this will be slow')
135 | elif config.distributed:
136 | # For multiprocessing distributed, DistributedDataParallel constructor
137 | # should always set the single device scope, otherwise,
138 | # DistributedDataParallel will use all available devices.
139 | if config.gpu is not None:
140 | torch.cuda.set_device(config.gpu)
141 | model.cuda(config.gpu)
142 | classifier.cuda(config.gpu)
143 | lws_model.cuda(config.gpu)
144 | # When using a single GPU per process and per
145 | # DistributedDataParallel, we need to divide the batch size
146 | # ourselves based on the total number of GPUs we have
147 | config.batch_size = int(config.batch_size / ngpus_per_node)
148 | config.workers = int((config.workers + ngpus_per_node - 1) / ngpus_per_node)
149 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu])
150 | classifier = torch.nn.parallel.DistributedDataParallel(classifier, device_ids=[config.gpu])
151 | lws_model = torch.nn.parallel.DistributedDataParallel(lws_model, device_ids=[config.gpu])
152 |
153 | if config.dataset == 'places':
154 | block.cuda(config.gpu)
155 | block = torch.nn.parallel.DistributedDataParallel(block, device_ids=[config.gpu])
156 | else:
157 | model.cuda()
158 | classifier.cuda()
159 | lws_model.cuda()
160 | # DistributedDataParallel will divide and allocate batch_size to all
161 | # available GPUs if device_ids are not set
162 | model = torch.nn.parallel.DistributedDataParallel(model)
163 | classifier = torch.nn.parallel.DistributedDataParallel(classifier)
164 | lws_model = torch.nn.parallel.DistributedDataParallel(lws_model)
165 | if config.dataset == 'places':
166 | block.cuda()
167 | block = torch.nn.parallel.DistributedDataParallel(block)
168 |
169 | elif config.gpu is not None:
170 | torch.cuda.set_device(config.gpu)
171 | model = model.cuda(config.gpu)
172 | classifier = classifier.cuda(config.gpu)
173 | lws_model = lws_model.cuda(config.gpu)
174 | if config.dataset == 'places':
175 | block.cuda(config.gpu)
176 | else:
177 | # DataParallel will divide and allocate batch_size to all available GPUs
178 | model = torch.nn.DataParallel(model).cuda()
179 | classifier = torch.nn.DataParallel(classifier).cuda()
180 | lws_model = torch.nn.DataParallel(lws_model).cuda()
181 | if config.dataset == 'places':
182 | block = torch.nn.DataParallel(block).cuda()
183 |
184 | # optionally resume from a checkpoint
185 | if config.resume:
186 | if os.path.isfile(config.resume):
187 | logger.info("=> loading checkpoint '{}'".format(config.resume))
188 | if config.gpu is None:
189 | checkpoint = torch.load(config.resume)
190 | else:
191 | # Map model to be loaded to specified single gpu.
192 | loc = 'cuda:{}'.format(config.gpu)
193 | checkpoint = torch.load(config.resume, map_location=loc)
194 | # config.start_epoch = checkpoint['epoch']
195 | best_acc1 = checkpoint['best_acc1']
196 | its_ece = checkpoint['its_ece']
197 | if config.gpu is not None:
198 | # best_acc1 may be from a checkpoint from a different GPU
199 | best_acc1 = best_acc1.to(config.gpu)
200 | model.load_state_dict(checkpoint['state_dict_model'])
201 | classifier.load_state_dict(checkpoint['state_dict_classifier'])
202 | if config.dataset == 'places':
203 | block.load_state_dict(checkpoint['state_dict_block'])
204 | logger.info("=> loaded checkpoint '{}' (epoch {})"
205 | .format(config.resume, checkpoint['epoch']))
206 | else:
207 | logger.info("=> no checkpoint found at '{}'".format(config.resume))
208 |
209 | # Data loading code
210 | if config.dataset == 'cifar10':
211 | dataset = CIFAR10_LT(config.distributed, root=config.data_path, imb_factor=config.imb_factor,
212 | batch_size=config.batch_size, num_works=config.workers)
213 |
214 | elif config.dataset == 'cifar100':
215 | dataset = CIFAR100_LT(config.distributed, root=config.data_path, imb_factor=config.imb_factor,
216 | batch_size=config.batch_size, num_works=config.workers)
217 |
218 | elif config.dataset == 'places':
219 | dataset = Places_LT(config.distributed, root=config.data_path,
220 | batch_size=config.batch_size, num_works=config.workers)
221 |
222 | elif config.dataset == 'imagenet':
223 | dataset = ImageNet_LT(config.distributed, root=config.data_path,
224 | batch_size=config.batch_size, num_works=config.workers)
225 |
226 | elif config.dataset == 'ina2018':
227 | dataset = iNa2018(config.distributed, root=config.data_path,
228 | batch_size=config.batch_size, num_works=config.workers)
229 |
230 | train_loader = dataset.train_balance
231 | val_loader = dataset.eval
232 | cls_num_list = dataset.cls_num_list
233 | if config.distributed:
234 | train_sampler = dataset.dist_sampler
235 |
236 | # define loss function (criterion) and optimizer
237 |
238 | criterion = LabelAwareSmoothing(cls_num_list=cls_num_list, smooth_head=config.smooth_head,
239 | smooth_tail=config.smooth_tail).cuda(config.gpu)
240 |
241 | optimizer = torch.optim.SGD([{"params": classifier.parameters()},
242 | {'params': lws_model.parameters()}], config.lr,
243 | momentum=config.momentum,
244 | weight_decay=config.weight_decay)
245 |
246 | for epoch in range(config.num_epochs):
247 | if config.distributed:
248 | train_sampler.set_epoch(epoch)
249 |
250 | adjust_learning_rate(optimizer, epoch, config)
251 |
252 | if config.dataset != 'places':
253 | block = None
254 | # train for one epoch
255 | train(train_loader, model, classifier, lws_model, criterion, optimizer, epoch, config, logger, block)
256 |
257 | # evaluate on validation set
258 | acc1, ece = validate(val_loader, model, classifier, lws_model, criterion, config, logger, block)
259 | # remember best acc@1 and save checkpoint
260 | is_best = acc1 > best_acc1
261 | best_acc1 = max(acc1, best_acc1)
262 | if is_best:
263 | its_ece = ece
264 | logger.info('Best Prec@1: %.3f%% ECE: %.3f%%\n' % (best_acc1, its_ece))
265 | if not config.multiprocessing_distributed or (config.multiprocessing_distributed
266 | and config.rank % ngpus_per_node == 0):
267 | if config.dataset == 'places':
268 | save_checkpoint({
269 | 'epoch': epoch + 1,
270 | 'state_dict_model': model.state_dict(),
271 | 'state_dict_classifier': classifier.state_dict(),
272 | 'state_dict_block': block.state_dict(),
273 | 'state_dict_lws_model': lws_model.state_dict(),
274 | 'best_acc1': best_acc1,
275 | 'its_ece': its_ece,
276 | }, is_best, model_dir)
277 | else:
278 | save_checkpoint({
279 | 'epoch': epoch + 1,
280 | 'state_dict_model': model.state_dict(),
281 | 'state_dict_classifier': classifier.state_dict(),
282 | 'state_dict_lws_model': lws_model.state_dict(),
283 | 'best_acc1': best_acc1,
284 | 'its_ece': its_ece,
285 | }, is_best, model_dir)
286 |
287 |
288 | def train(train_loader, model, classifier, lws_model, criterion, optimizer, epoch, config, logger, block=None):
289 | batch_time = AverageMeter('Time', ':6.3f')
290 | data_time = AverageMeter('Data', ':6.3f')
291 | losses = AverageMeter('Loss', ':.3f')
292 | top1 = AverageMeter('Acc@1', ':6.3f')
293 | top5 = AverageMeter('Acc@5', ':6.3f')
294 | training_data_num = len(train_loader.dataset)
295 | end_steps = int(np.ceil(float(training_data_num) / float(train_loader.batch_size)))
296 | progress = ProgressMeter(
297 | end_steps,
298 | [batch_time, losses, top1, top5],
299 | prefix="Epoch: [{}]".format(epoch))
300 |
301 | # switch to train mode
302 |
303 | if config.dataset == 'places':
304 | model.eval()
305 | if config.shift_bn:
306 | block.train()
307 | else:
308 | block.eval()
309 | else:
310 | if config.shift_bn:
311 | model.train()
312 | else:
313 | model.eval()
314 | classifier.train()
315 |
316 | end = time.time()
317 |
318 | for i, (images, target) in enumerate(train_loader):
319 | if i > end_steps:
320 | break
321 |
322 | # measure data loading time
323 | data_time.update(time.time() - end)
324 |
325 | if torch.cuda.is_available():
326 | images = images.cuda(config.gpu, non_blocking=True)
327 | target = target.cuda(config.gpu, non_blocking=True)
328 |
329 | if config.mixup is True:
330 | images, targets_a, targets_b, lam = mixup_data(images, target, alpha=config.alpha)
331 | with torch.no_grad():
332 | if config.dataset == 'places':
333 | feat = block(model(images))
334 | else:
335 | feat = model(images)
336 | output = classifier(feat.detach())
337 | output = lws_model(output)
338 | loss = mixup_criterion(criterion, output, targets_a, targets_b, lam)
339 | else:
340 | # compute output
341 | with torch.no_grad():
342 | if config.dataset == 'places':
343 | feat = block(model(images))
344 | else:
345 | feat = model(images)
346 | output = classifier(feat.detach())
347 | output = lws_model(output)
348 | loss = criterion(output, target)
349 |
350 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
351 | losses.update(loss.item(), images.size(0))
352 | top1.update(acc1[0], images.size(0))
353 | top5.update(acc5[0], images.size(0))
354 |
355 | # compute gradient and do SGD step
356 | optimizer.zero_grad()
357 | loss.backward()
358 | optimizer.step()
359 |
360 | # measure elapsed time
361 | batch_time.update(time.time() - end)
362 | end = time.time()
363 |
364 | if i % config.print_freq == 0:
365 | progress.display(i, logger)
366 |
367 |
368 | def validate(val_loader, model, classifier, lws_model, criterion, config, logger, block=None):
369 | batch_time = AverageMeter('Time', ':6.3f')
370 | losses = AverageMeter('Loss', ':.3f')
371 | top1 = AverageMeter('Acc@1', ':6.3f')
372 | top5 = AverageMeter('Acc@5', ':6.3f')
373 | progress = ProgressMeter(
374 | len(val_loader),
375 | [batch_time, losses, top1, top5],
376 | prefix='Eval: ')
377 |
378 | # switch to evaluate mode
379 | model.eval()
380 | if config.dataset == 'places':
381 | block.eval()
382 | classifier.eval()
383 | class_num = torch.zeros(config.num_classes).cuda()
384 | correct = torch.zeros(config.num_classes).cuda()
385 |
386 | confidence = np.array([])
387 | pred_class = np.array([])
388 | true_class = np.array([])
389 |
390 | with torch.no_grad():
391 | end = time.time()
392 | for i, (images, target) in enumerate(val_loader):
393 | if config.gpu is not None:
394 | images = images.cuda(config.gpu, non_blocking=True)
395 | if torch.cuda.is_available():
396 | target = target.cuda(config.gpu, non_blocking=True)
397 |
398 | # compute output
399 | if config.dataset == 'places':
400 | feat = block(model(images))
401 | else:
402 | feat = model(images)
403 | output = classifier(feat)
404 | output = lws_model(output)
405 | loss = criterion(output, target)
406 |
407 | # measure accuracy and record loss
408 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
409 | losses.update(loss.item(), images.size(0))
410 | top1.update(acc1[0], images.size(0))
411 | top5.update(acc5[0], images.size(0))
412 |
413 | _, predicted = output.max(1)
414 | target_one_hot = F.one_hot(target, config.num_classes)
415 | predict_one_hot = F.one_hot(predicted, config.num_classes)
416 | class_num = class_num + target_one_hot.sum(dim=0).to(torch.float)
417 | correct = correct + (target_one_hot + predict_one_hot == 2).sum(dim=0).to(torch.float)
418 |
419 | prob = torch.softmax(output, dim=1)
420 | confidence_part, pred_class_part = torch.max(prob, dim=1)
421 | confidence = np.append(confidence, confidence_part.cpu().numpy())
422 | pred_class = np.append(pred_class, pred_class_part.cpu().numpy())
423 | true_class = np.append(true_class, target.cpu().numpy())
424 |
425 | # measure elapsed time
426 | batch_time.update(time.time() - end)
427 | end = time.time()
428 |
429 | if i % config.print_freq == 0:
430 | progress.display(i, logger)
431 |
432 | acc_classes = correct / class_num
433 | head_acc = acc_classes[config.head_class_idx[0]:config.head_class_idx[1]].mean() * 100
434 | med_acc = acc_classes[config.med_class_idx[0]:config.med_class_idx[1]].mean() * 100
435 | tail_acc = acc_classes[config.tail_class_idx[0]:config.tail_class_idx[1]].mean() * 100
436 |
437 | logger.info('* Acc@1 {top1.avg:.3f}% Acc@5 {top5.avg:.3f}% HAcc {head_acc:.3f}% MAcc {med_acc:.3f}% TAcc {tail_acc:.3f}%.'.format(top1=top1, top5=top5, head_acc=head_acc, med_acc=med_acc, tail_acc=tail_acc))
438 |
439 | cal = calibration(true_class, pred_class, confidence, num_bins=15)
440 | logger.info('* ECE {ece:.3f}%.'.format(ece=cal['expected_calibration_error'] * 100))
441 |
442 | return top1.avg, cal['expected_calibration_error'] * 100
443 |
444 |
445 | def save_checkpoint(state, is_best, model_dir):
446 | filename = model_dir + '/current.pth.tar'
447 | torch.save(state, filename)
448 | if is_best:
449 | shutil.copyfile(filename, model_dir + '/model_best.pth.tar')
450 |
451 |
452 | def adjust_learning_rate(optimizer, epoch, config):
453 | """Sets the learning rate"""
454 | lr_min = 0
455 | lr_max = config.lr
456 |     lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(epoch / config.num_epochs * math.pi))
457 |
458 | for idx, param_group in enumerate(optimizer.param_groups):
459 | if idx == 0:
460 | param_group['lr'] = config.lr_factor * lr
461 | else:
462 | param_group['lr'] = 1.00 * lr
463 |
464 |
465 | if __name__ == '__main__':
466 | main()
467 |
--------------------------------------------------------------------------------
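
Stage 2 additionally uses `LabelAwareSmoothing` and `LearnableWeightScaling` from methods.py (also not reproduced in this listing). Per the MiSLAS paper, label-aware smoothing interpolates the smoothing factor between `smooth_head` and `smooth_tail` according to class frequency, so frequent, over-confident classes receive stronger smoothing; the weight-scaling module learns one multiplicative scale per class on the frozen classifier's logits, in the spirit of the LWS classifier of Kang et al., "Decoupling Representation and Classifier". A sketch of the scaling module consistent with its call sites above (internal names are illustrative):

```python
import torch
import torch.nn as nn

class LearnableWeightScaling(nn.Module):
    """Per-class multiplicative scaling of classifier logits (sketch only)."""

    def __init__(self, num_classes):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1, num_classes))

    def forward(self, logits):
        # logits: (batch, num_classes); the scale broadcasts over the batch
        return self.scale * logits
```
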
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from .logger import _C as config
6 | from .logger import update_config, create_logger
7 |
8 | from .metric import accuracy, calibration
9 | from .meter import AverageMeter, ProgressMeter
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from yacs.config import CfgNode as CN
3 | import os
4 | import time
5 | import logging
6 |
7 | _C = CN()
8 | _C.name = ''
9 | _C.print_freq = 40
10 | _C.workers = 16
11 | _C.log_dir = 'logs'
12 | _C.model_dir = 'ckps'
13 |
14 |
15 | _C.dataset = 'cifar10'
16 | _C.data_path = './data/cifar10'
17 | _C.num_classes = 100
18 | _C.imb_factor = 0.01
19 | _C.backbone = 'resnet32_fe'
20 | _C.resume = ''
21 | _C.head_class_idx = [0, 1]
22 | _C.med_class_idx = [0, 1]
23 | _C.tail_class_idx = [0, 1]
24 |
25 | _C.deterministic = True
26 | _C.gpu = 0
27 | _C.world_size = -1
28 | _C.rank = -1
29 | _C.dist_url = 'tcp://224.66.41.62:23456'
30 | _C.dist_backend = 'nccl'
31 | _C.multiprocessing_distributed = False
32 | _C.distributed = False
33 |
34 | _C.mode = None
35 | _C.smooth_tail = None
36 | _C.smooth_head = None
37 | _C.shift_bn = False
38 | _C.lr_factor = None
39 | _C.lr = 0.1
40 | _C.batch_size = 128
41 | _C.weight_decay = 0.002
42 | _C.num_epochs = 200
43 | _C.momentum = 0.9
44 | _C.cos = False
45 | _C.mixup = True
46 | _C.alpha = 1.0
47 |
48 | def update_config(cfg, args):
49 | cfg.defrost()
50 |
51 | cfg.merge_from_file(args.cfg)
52 | cfg.merge_from_list(args.opts)
53 |
54 | # cfg.freeze()
55 |
56 | def create_logger(cfg, cfg_name):
57 | time_str = time.strftime('%Y%m%d%H%M')
58 |
59 | cfg_name = os.path.basename(cfg_name).split('.')[0]
60 |
61 | log_dir = Path("saved") / (cfg_name + '_' + time_str) / Path(cfg.log_dir)
62 | print('=> creating {}'.format(log_dir))
63 | log_dir.mkdir(parents=True, exist_ok=True)
64 |
65 |
66 | log_file = '{}.txt'.format(cfg_name)
67 | final_log_file = log_dir / log_file
68 | head = '%(asctime)-15s %(message)s'
69 | logging.basicConfig(filename=str(final_log_file),
70 | format=head)
71 | logger = logging.getLogger()
72 | logger.setLevel(logging.INFO)
73 | console = logging.StreamHandler()
74 | logging.getLogger('').addHandler(console)
75 |
76 | model_dir = Path("saved") / (cfg_name + '_' + time_str) / Path(cfg.model_dir)
77 | print('=> creating {}'.format(model_dir))
78 | model_dir.mkdir(parents=True, exist_ok=True)
79 |
80 | return logger, str(model_dir)
--------------------------------------------------------------------------------
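For reference, a minimal sketch of how `update_config` and `create_logger` are typically wired together from a training script. The argparse setup here (`--cfg` plus trailing KEY VALUE override pairs consumed as `opts`) is an illustrative assumption, not a verbatim copy of the training scripts:

    import argparse

    from utils import config, update_config, create_logger

    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', required=True, help='path to a YAML config file')
    parser.add_argument('opts', nargs=argparse.REMAINDER, default=[],
                        help='optional KEY VALUE pairs overriding the YAML')
    args = parser.parse_args()

    update_config(config, args)          # merge YAML and CLI overrides into _C
    logger, model_dir = create_logger(config, args.cfg)
    logger.info('checkpoints will be written to {}'.format(model_dir))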
/utils/meter.py:
--------------------------------------------------------------------------------
1 | class AverageMeter(object):
2 | """Computes and stores the average and current value"""
3 | def __init__(self, name, fmt=':f'):
4 | self.name = name
5 | self.fmt = fmt
6 | self.reset()
7 |
8 | def reset(self):
9 | self.val = 0
10 | self.avg = 0
11 | self.sum = 0
12 | self.count = 0
13 |
14 | def update(self, val, n=1):
15 | self.val = val
16 | self.sum += val * n
17 | self.count += n
18 | self.avg = self.sum / self.count
19 |
20 | def __str__(self):
21 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
22 | return fmtstr.format(**self.__dict__)
23 |
24 |
25 | class ProgressMeter(object):
26 | def __init__(self, num_batches, meters, prefix=""):
27 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
28 | self.meters = meters
29 | self.prefix = prefix
30 |
31 | def display(self, batch, logger):
32 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
33 | entries += [str(meter) for meter in self.meters]
34 | logger.info('\t'.join(entries))
35 |
36 | def _get_batch_fmtstr(self, num_batches):
37 |         num_digits = len(str(num_batches))
38 | fmt = '{:' + str(num_digits) + 'd}'
39 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
--------------------------------------------------------------------------------
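A self-contained sketch of how `AverageMeter` and `ProgressMeter` cooperate inside a loop, mirroring the validation code earlier in this dump; the loss and accuracy values below are placeholders, not real training statistics:

    import logging

    from utils import AverageMeter, ProgressMeter

    logging.basicConfig(level=logging.INFO, format='%(message)s')
    logger = logging.getLogger(__name__)

    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(100, [losses, top1], prefix='Epoch: [0]')

    for i in range(100):
        # placeholder statistics; in training these come from the criterion
        # and the accuracy() helper in utils/metric.py
        losses.update(1.0 / (i + 1), n=128)
        top1.update(min(50.0 + 0.4 * i, 99.0), n=128)
        if i % 40 == 0:
            progress.display(i, logger)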
/utils/metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | def accuracy(output, target, topk=(1,)):
5 | """Computes the accuracy over the k top predictions for the specified values of k"""
6 | with torch.no_grad():
7 | maxk = max(topk)
8 | batch_size = target.size(0)
9 |
10 | _, pred = output.topk(maxk, 1, True, True)
11 | pred = pred.t()
12 | correct = pred.eq(target.view(1, -1).expand_as(pred))
13 |
14 | res = []
15 | for k in topk:
16 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
17 | res.append(correct_k.mul_(100.0 / batch_size))
18 | return res
19 |
20 |
21 | def calibration(true_labels, pred_labels, confidences, num_bins=15):
22 | """Collects predictions into bins used to draw a reliability diagram.
23 |
24 | Arguments:
25 | true_labels: the true labels for the test examples
26 | pred_labels: the predicted labels for the test examples
27 | confidences: the predicted confidences for the test examples
28 | num_bins: number of bins
29 |
30 | The true_labels, pred_labels, confidences arguments must be NumPy arrays;
31 | pred_labels and true_labels may contain numeric or string labels.
32 |
33 | For a multi-class model, the predicted label and confidence should be those
34 | of the highest scoring class.
35 |
36 | Returns a dictionary containing the following NumPy arrays:
37 | accuracies: the average accuracy for each bin
38 | confidences: the average confidence for each bin
39 | counts: the number of examples in each bin
40 | bins: the confidence thresholds for each bin
41 | avg_accuracy: the accuracy over the entire test set
42 | avg_confidence: the average confidence over the entire test set
43 | expected_calibration_error: a weighted average of all calibration gaps
44 | max_calibration_error: the largest calibration gap across all bins
45 | """
46 |     assert len(confidences) == len(pred_labels)
47 |     assert len(confidences) == len(true_labels)
48 |     assert num_bins > 0
49 |
50 | bin_size = 1.0 / num_bins
51 | bins = np.linspace(0.0, 1.0, num_bins + 1)
52 | indices = np.digitize(confidences, bins, right=True)
53 |
54 |     bin_accuracies = np.zeros(num_bins, dtype=np.float64)
55 |     bin_confidences = np.zeros(num_bins, dtype=np.float64)
56 |     bin_counts = np.zeros(num_bins, dtype=np.int64)
57 |
58 | for b in range(num_bins):
59 | selected = np.where(indices == b + 1)[0]
60 | if len(selected) > 0:
61 | bin_accuracies[b] = np.mean(true_labels[selected] == pred_labels[selected])
62 | bin_confidences[b] = np.mean(confidences[selected])
63 | bin_counts[b] = len(selected)
64 |
65 | avg_acc = np.sum(bin_accuracies * bin_counts) / np.sum(bin_counts)
66 | avg_conf = np.sum(bin_confidences * bin_counts) / np.sum(bin_counts)
67 |
68 | gaps = np.abs(bin_accuracies - bin_confidences)
69 | ece = np.sum(gaps * bin_counts) / np.sum(bin_counts)
70 | mce = np.max(gaps)
71 |
72 | return { "accuracies": bin_accuracies,
73 | "confidences": bin_confidences,
74 | "counts": bin_counts,
75 | "bins": bins,
76 | "avg_accuracy": avg_acc,
77 | "avg_confidence": avg_conf,
78 | "expected_calibration_error": ece,
79 | "max_calibration_error": mce }
--------------------------------------------------------------------------------
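A toy check of `calibration` on synthetic predictions (the arrays below are made up for illustration): eight of ten predictions are correct, but average confidence sits at 0.88, so the reported ECE is clearly non-zero:

    import numpy as np

    from utils import calibration

    true_class = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
    pred_class = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0])   # 8/10 correct
    confidence = np.array([0.95, 0.90, 0.85, 0.80, 0.90,
                           0.95, 0.90, 0.85, 0.80, 0.90])

    cal = calibration(true_class, pred_class, confidence, num_bins=15)
    print('avg accuracy   {:.3f}'.format(cal['avg_accuracy']))               # 0.800
    print('avg confidence {:.3f}'.format(cal['avg_confidence']))             # 0.880
    print('ECE            {:.3f}'.format(cal['expected_calibration_error'])) # 0.240

Note that the ECE (0.240) is larger than the gap between aggregate accuracy and aggregate confidence (0.080): ECE weights the absolute per-bin gaps, so over- and under-confident bins cannot cancel each other out.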