├── LICENSE ├── README.md ├── assets ├── MiSLAS.PNG └── MiSLAS.pdf ├── config ├── cifar10 │ ├── cifar10_imb001_stage1_mixup.yaml │ ├── cifar10_imb001_stage2_mislas.yaml │ ├── cifar10_imb002_stage1_mixup.yaml │ ├── cifar10_imb002_stage2_mislas.yaml │ ├── cifar10_imb01_stage1_mixup.yaml │ └── cifar10_imb01_stage2_mislas.yaml ├── cifar100 │ ├── cifar100_imb001_stage1_mixup.yaml │ ├── cifar100_imb001_stage2_mislas.yaml │ ├── cifar100_imb002_stage1_mixup.yaml │ ├── cifar100_imb002_stage2_mislas.yaml │ ├── cifar100_imb01_stage1_mixup.yaml │ └── cifar100_imb01_stage2_mislas.yaml ├── imagenet │ ├── imagenet_resnet50_stage1_mixup.yaml │ └── imagenet_resnet50_stage2_mislas.yaml ├── ina2018 │ ├── ina2018_resnet50_stage1_mixup.yaml │ └── ina2018_resnet50_stage2_mislas.yaml └── places │ ├── places_resnet152_stage1_mixup.yaml │ └── places_resnet152_stage2_mislas.yaml ├── datasets ├── cifar10.py ├── cifar100.py ├── data_txt │ ├── ImageNet_LT_test.txt │ ├── ImageNet_LT_train.txt │ ├── Places_LT_test.txt │ ├── Places_LT_train.txt │ ├── iNaturalist18_train.txt │ └── iNaturalist18_val.txt ├── imagenet.py ├── ina2018.py ├── places.py └── sampler.py ├── eval.py ├── methods.py ├── models ├── resnet.py ├── resnet_cifar.py └── resnet_places.py ├── reliability_diagrams.py ├── requirements.txt ├── train_stage1.py ├── train_stage2.py └── utils ├── __init__.py ├── logger.py ├── meter.py └── metric.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Zhisheng Zhong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MiSLAS 2 | **Improving Calibration for Long-Tailed Recognition** 3 | 4 | **Authors**: Zhisheng Zhong, Jiequan Cui, Shu Liu, Jiaya Jia 5 | 6 | [[`arXiv`](https://arxiv.org/pdf/2104.00466.pdf)] [[`slide`]](./assets/MiSLAS.pdf) [[`BibTeX`](#Citation)] 7 | 8 |
9 | 10 |

11 | 12 | **Introduction**: This repository provides an implementation for the CVPR 2021 paper: "[Improving Calibration for Long-Tailed Recognition](https://arxiv.org/pdf/2104.00466.pdf)" based on [LDAM-DRW](https://github.com/kaidic/LDAM-DRW) and [Decoupling models](https://github.com/facebookresearch/classifier-balancing). *Our study shows, because of the extreme imbalanced composition ratio of each class, networks trained on long-tailed datasets are more miscalibrated and over-confident*. MiSLAS is a simple, and efficient two-stage framework for long-tailed recognition, which greatly improves recognition accuracy and markedly relieves over-confidence simultaneously. 13 | 14 | ## Installation 15 | 16 | **Requirements** 17 | 18 | * Python 3.7 19 | * torchvision 0.4.0 20 | * Pytorch 1.2.0 21 | * yacs 0.1.8 22 | 23 | **Virtual Environment** 24 | ``` 25 | conda create -n MiSLAS python==3.7 26 | source activate MiSLAS 27 | ``` 28 | 29 | **Install MiSLAS** 30 | ``` 31 | git clone https://github.com/Jia-Research-Lab/MiSLAS.git 32 | cd MiSLAS 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | **Dataset Preparation** 37 | * [CIFAR-10 & CIFAR-100](https://www.cs.toronto.edu/~kriz/cifar.html) 38 | * [ImageNet](http://image-net.org/index) 39 | * [iNaturalist 2018](https://github.com/visipedia/inat_comp/tree/master/2018) 40 | * [Places](http://places2.csail.mit.edu/download.html) 41 | 42 | Change the `data_path` in `config/*/*.yaml` accordingly. 43 | 44 | ## Training 45 | 46 | **Stage-1**: 47 | 48 | To train a model for Stage-1 with *mixup*, run: 49 | 50 | (one GPU for CIFAR-10-LT & CIFAR-100-LT, four GPUs for ImageNet-LT, iNaturalist 2018, and Places-LT) 51 | 52 | ``` 53 | python train_stage1.py --cfg ./config/DATASETNAME/DATASETNAME_ARCH_stage1_mixup.yaml 54 | ``` 55 | 56 | `DATASETNAME` can be selected from `cifar10`, `cifar100`, `imagenet`, `ina2018`, and `places`. 57 | 58 | `ARCH` can be `resnet32` for `cifar10/100`, `resnet50/101/152` for `imagenet`, `resnet50` for `ina2018`, and `resnet152` for `places`, respectively. 59 | 60 | **Stage-2**: 61 | 62 | To train a model for Stage-2 with *one GPU* (all the above datasets), run: 63 | 64 | ``` 65 | python train_stage2.py --cfg ./config/DATASETNAME/DATASETNAME_ARCH_stage2_mislas.yaml resume /path/to/checkpoint/stage1 66 | ``` 67 | 68 | The saved folder (including logs and checkpoints) is organized as follows. 69 | ``` 70 | MiSLAS 71 | ├── saved 72 | │ ├── modelname_date 73 | │ │ ├── ckps 74 | │ │ │ ├── current.pth.tar 75 | │ │ │ └── model_best.pth.tar 76 | │ │ └── logs 77 | │ │ └── modelname.txt 78 | │ ... 79 | ``` 80 | ## Evaluation 81 | 82 | To evaluate a trained model, run: 83 | 84 | ``` 85 | python eval.py --cfg ./config/DATASETNAME/DATASETNAME_ARCH_stage1_mixup.yaml resume /path/to/checkpoint/stage1 86 | python eval.py --cfg ./config/DATASETNAME/DATASETNAME_ARCH_stage2_mislas.yaml resume /path/to/checkpoint/stage2 87 | ``` 88 | 89 | ## Results and Models 90 | 91 | **1) CIFAR-10-LT and CIFAR-100-LT** 92 | 93 | * Stage-1 (*mixup*): 94 | 95 | | Dataset | Top-1 Accuracy | ECE (15 bins) | Model | 96 | | -------------------- | -------------- | ------------- | ----- | 97 | | CIFAR-10-LT IF=10 | 87.6% | 11.9% | [link](https://drive.google.com/file/d/1dV1hchsIR5kTSqSOhdEs6nnXApcH5wEG/view?usp=sharing) | 98 | | CIFAR-10-LT IF=50 | 78.1% | 2.49% | [link](https://drive.google.com/file/d/1LoczjQRK20u_HpFMLmzeT0pVCp3V-gyf/view?usp=sharing) | 99 | | CIFAR-10-LT IF=100 | 72.8% | 2.14% | [link](https://drive.google.com/file/d/1TFetlV4MT4zjKEAPKcZuzmY2Dgtcqmsd/view?usp=sharing) | 100 | | CIFAR-100-LT IF=10 | 59.1% | 5.24% | [link](https://drive.google.com/file/d/1BmLjPReBoH6LJwl5x8_zSPnm1f6N_Cp0/view?usp=sharing) | 101 | | CIFAR-100-LT IF=50 | 45.4% | 4.33% | [link](https://drive.google.com/file/d/1l0LfZozJxWgzKp2IgM9mSpfwjTsIC-Mg/view?usp=sharing) | 102 | | CIFAR-100-LT IF=100 | 39.5% | 8.82% | [link](https://drive.google.com/file/d/15dHVdkI8J-oKkeQqyj6FtrHtIpO_TYfq/view?usp=sharing) | 103 | 104 | * Stage-2 (*MiSLAS*): 105 | 106 | | Dataset | Top-1 Accuracy | ECE (15 bins) | Model | 107 | | -------------------- | -------------- | ------------- | ----- | 108 | | CIFAR-10-LT IF=10 | 90.0% | 1.20% | [link](https://drive.google.com/file/d/1iST8Tr2LQ8nIjTNT1CKiQ-1T-RKxAvqr/view?usp=sharing) | 109 | | CIFAR-10-LT IF=50 | 85.7% | 2.01% | [link](https://drive.google.com/file/d/15bfA7uJsyM8eTwoptwp452kStk6FYT7v/view?usp=sharing) | 110 | | CIFAR-10-LT IF=100 | 82.5% | 3.66% | [link](https://drive.google.com/file/d/1KOTkjTOhIP5UOhqvHGJzEqq4_kQGKSJY/view?usp=sharing) | 111 | | CIFAR-100-LT IF=10 | 63.2% | 1.73% | [link](https://drive.google.com/file/d/1N2ai-l1hsbXTp_25Hoh5BSoAmR1_0UVD/view?usp=sharing) | 112 | | CIFAR-100-LT IF=50 | 52.3% | 2.47% | [link](https://drive.google.com/file/d/1Z2nukCMTG0cMmGXzZip3zIwv2WB5cOiZ/view?usp=sharing) | 113 | | CIFAR-100-LT IF=100 | 47.0% | 4.83% | [link](https://drive.google.com/file/d/1bX3eM-hlxGvEGuHBcfNhuz6VNp32Y0IQ/view?usp=sharing) | 114 | 115 | *Note: To obtain better performance, we highly recommend changing the weight decay 2e-4 to 5e-4 on CIFAR-LT.* 116 | 117 | **2) Large-scale Datasets** 118 | 119 | * Stage-1 (*mixup*): 120 | 121 | | Dataset | Arch | Top-1 Accuracy | ECE (15 bins) | Model | 122 | | ----------- | ---------- | -------------- | ------------- | ----- | 123 | | ImageNet-LT | ResNet-50 | 45.5% | 7.98% | [link](https://drive.google.com/file/d/1QKVnK7n75q465ppf7wkK4jzZvZJE_BPi/view?usp=sharing) | 124 | | iNa'2018 | ResNet-50 | 66.9% | 5.37% | [link](https://drive.google.com/file/d/1wvj-cITz8Ps1TksLHi_KoGsq9CecXcVt/view?usp=sharing) | 125 | | Places-LT | ResNet-152 | 29.4% | 16.7% | [link](https://drive.google.com/file/d/1Tx-tY5Y8_-XuGn9ZdSxtAm0onOsKWhUH/view?usp=sharing) | 126 | 127 | * Stage-2 (*MiSLAS*): 128 | 129 | | Dataset | Arch | Top-1 Accuracy | ECE (15 bins) | Model | 130 | | ----------- | ---------- | -------------- | ------------- | ----- | 131 | | ImageNet-LT | ResNet-50 | 52.7% | 1.78% | [link](https://drive.google.com/file/d/1ofJKlUJZQjjkoFU9MLI08UP2uBvywRgF/view?usp=sharing) | 132 | | iNa'2018 | ResNet-50 | 71.6% | 7.67% | [link](https://drive.google.com/file/d/1crOo3INxqkz8ZzKZt9pH4aYb3-ep4lo-/view?usp=sharing) | 133 | | Places-LT | ResNet-152 | 40.4% | 3.41% | [link](https://drive.google.com/file/d/1DgL0aN3UadI3UoHU6TO7M6UD69QgvnbT/view?usp=sharing) | 134 | 135 | ## Citation 136 | 137 | Please consider citing MiSLAS in your publications if it helps your research. :) 138 | 139 | ```bib 140 | @inproceedings{zhong2021mislas, 141 | title={Improving Calibration for Long-Tailed Recognition}, 142 | author={Zhisheng Zhong, Jiequan Cui, Shu Liu, and Jiaya Jia}, 143 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 144 | year={2021}, 145 | } 146 | ``` 147 | 148 | ## Contact 149 | 150 | If you have any questions about our work, feel free to contact us through email (Zhisheng Zhong: zszhong@pku.edu.cn) or Github issues. 151 | -------------------------------------------------------------------------------- /assets/MiSLAS.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/MiSLAS/0fc381e333d8a11170851b7af2058c59840969e0/assets/MiSLAS.PNG -------------------------------------------------------------------------------- /assets/MiSLAS.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/MiSLAS/0fc381e333d8a11170851b7af2058c59840969e0/assets/MiSLAS.pdf -------------------------------------------------------------------------------- /config/cifar10/cifar10_imb001_stage1_mixup.yaml: -------------------------------------------------------------------------------- 1 | name: cifar10_imb001_stage1_mixup 2 | print_freq: 40 3 | workers: 16 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | # dataset & model setting 8 | dataset: 'cifar10' 9 | data_path: './data/cifar10' 10 | num_classes: 10 11 | imb_factor: 0.01 12 | backbone: 'resnet32_fe' 13 | resume: '' 14 | head_class_idx: 15 | - 0 16 | - 3 17 | med_class_idx: 18 | - 3 19 | - 7 20 | tail_class_idx: 21 | - 7 22 | - 10 23 | 24 | 25 | # distributed training 26 | deterministic: False 27 | distributed: False 28 | gpu: null 29 | world_size: -1 30 | rank: -1 31 | dist_url: 'tcp://224.66.41.62:23456' 32 | dist_backend: 'nccl' 33 | multiprocessing_distributed: False 34 | 35 | 36 | 37 | # Train 38 | mode: 'stage1' 39 | lr: 0.1 40 | batch_size: 128 41 | weight_decay: 2e-4 42 | num_epochs: 200 43 | momentum: 0.9 44 | cos: False 45 | mixup: True 46 | alpha: 1.0 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /config/cifar10/cifar10_imb001_stage2_mislas.yaml: -------------------------------------------------------------------------------- 1 | name: cifar10_imb001_stage2_mislas 2 | print_freq: 40 3 | workers: 16 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | # dataset & model setting 8 | dataset: 'cifar10' 9 | data_path: './data/cifar10' 10 | num_classes: 10 11 | imb_factor: 0.01 12 | backbone: 'resnet32_fe' 13 | resume: 'Path/to/Stage1_checkpoint.pth.tar' 14 | head_class_idx: 15 | - 0 16 | - 3 17 | med_class_idx: 18 | - 3 19 | - 7 20 | tail_class_idx: 21 | - 7 22 | - 10 23 | 24 | 25 | # distributed training 26 | deterministic: False 27 | distributed: False 28 | gpu: null 29 | world_size: -1 30 | rank: -1 31 | dist_url: 'tcp://224.66.41.62:23456' 32 | dist_backend: 'nccl' 33 | multiprocessing_distributed: False 34 | 35 | 36 | 37 | # Train 38 | mode: 'stage2' 39 | smooth_head: 0.3 40 | smooth_tail: 0.0 41 | shift_bn: False 42 | lr_factor: 0.5 43 | lr: 0.1 44 | batch_size: 128 45 | weight_decay: 2e-4 46 | num_epochs: 10 47 | momentum: 0.9 48 | mixup: False 49 | alpha: null 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/cifar10/cifar10_imb002_stage1_mixup.yaml: -------------------------------------------------------------------------------- 1 | name: cifar10_imb002_stage1_mixup 2 | print_freq: 40 3 | workers: 16 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | # dataset & model setting 8 | dataset: 'cifar10' 9 | data_path: './data/cifar10' 10 | num_classes: 10 11 | imb_factor: 0.02 12 | backbone: 'resnet32_fe' 13 | resume: '' 14 | head_class_idx: 15 | - 0 16 | - 3 17 | med_class_idx: 18 | - 3 19 | - 7 20 | tail_class_idx: 21 | - 7 22 | - 10 23 | 24 | 25 | # distributed training 26 | deterministic: False 27 | distributed: False 28 | gpu: null 29 | world_size: -1 30 | rank: -1 31 | dist_url: 'tcp://224.66.41.62:23456' 32 | dist_backend: 'nccl' 33 | multiprocessing_distributed: False 34 | 35 | 36 | 37 | # Train 38 | mode: 'stage1' 39 | lr: 0.1 40 | batch_size: 128 41 | weight_decay: 2e-4 42 | num_epochs: 200 43 | momentum: 0.9 44 | cos: False 45 | mixup: True 46 | alpha: 1.0 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /config/cifar10/cifar10_imb002_stage2_mislas.yaml: -------------------------------------------------------------------------------- 1 | name: cifar10_imb002_stage2_mislas 2 | print_freq: 40 3 | workers: 16 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | # dataset & model setting 8 | dataset: 'cifar10' 9 | data_path: './data/cifar10' 10 | num_classes: 10 11 | imb_factor: 0.02 12 | backbone: 'resnet32_fe' 13 | resume: 'Path/to/Stage1_checkpoint.pth.tar' 14 | head_class_idx: 15 | - 0 16 | - 3 17 | med_class_idx: 18 | - 3 19 | - 7 20 | tail_class_idx: 21 | - 7 22 | - 10 23 | 24 | 25 | # distributed training 26 | deterministic: False 27 | distributed: False 28 | gpu: null 29 | world_size: -1 30 | rank: -1 31 | dist_url: 'tcp://224.66.41.62:23456' 32 | dist_backend: 'nccl' 33 | multiprocessing_distributed: False 34 | 35 | 36 | 37 | # Train 38 | mode: 'stage2' 39 | smooth_head: 0.2 40 | smooth_tail: 0.0 41 | shift_bn: False 42 | lr_factor: 0.2 43 | lr: 0.1 44 | batch_size: 128 45 | weight_decay: 2e-4 46 | num_epochs: 10 47 | momentum: 0.9 48 | mixup: False 49 | alpha: null 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/cifar10/cifar10_imb01_stage1_mixup.yaml: -------------------------------------------------------------------------------- 1 | name: cifar10_imb01_stage1_mixup 2 | print_freq: 40 3 | workers: 16 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | # dataset & model setting 8 | dataset: 'cifar10' 9 | data_path: './data/cifar10' 10 | num_classes: 10 11 | imb_factor: 0.1 12 | backbone: 'resnet32_fe' 13 | resume: '' 14 | head_class_idx: 15 | - 0 16 | - 3 17 | med_class_idx: 18 | - 3 19 | - 7 20 | tail_class_idx: 21 | - 7 22 | - 10 23 | 24 | 25 | # distributed training 26 | deterministic: False 27 | distributed: False 28 | gpu: null 29 | world_size: -1 30 | rank: -1 31 | dist_url: 'tcp://224.66.41.62:23456' 32 | dist_backend: 'nccl' 33 | multiprocessing_distributed: False 34 | 35 | 36 | 37 | # Train 38 | mode: 'stage1' 39 | lr: 0.1 40 | batch_size: 128 41 | weight_decay: 2e-4 42 | num_epochs: 200 43 | momentum: 0.9 44 | cos: False 45 | mixup: True 46 | alpha: 1.0 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /config/cifar10/cifar10_imb01_stage2_mislas.yaml: -------------------------------------------------------------------------------- 1 | name: cifar10_imb01_stage2_mislas 2 | print_freq: 40 3 | workers: 16 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | # dataset & model setting 8 | dataset: 'cifar10' 9 | data_path: './data/cifar10' 10 | num_classes: 10 11 | imb_factor: 0.1 12 | backbone: 'resnet32_fe' 13 | resume: 'Path/to/Stage1_checkpoint.pth.tar' 14 | head_class_idx: 15 | - 0 16 | - 3 17 | med_class_idx: 18 | - 3 19 | - 7 20 | tail_class_idx: 21 | - 7 22 | - 10 23 | 24 | 25 | # distributed training 26 | deterministic: False 27 | distributed: False 28 | gpu: null 29 | world_size: -1 30 | rank: -1 31 | dist_url: 'tcp://224.66.41.62:23456' 32 | dist_backend: 'nccl' 33 | multiprocessing_distributed: False 34 | 35 | 36 | 37 | # Train 38 | mode: 'stage2' 39 | smooth_head: 0.1 40 | smooth_tail: 0.0 41 | shift_bn: False 42 | lr_factor: 0.2 43 | lr: 0.1 44 | batch_size: 128 45 | weight_decay: 2e-4 46 | num_epochs: 10 47 | momentum: 0.9 48 | mixup: False 49 | alpha: null 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/cifar100/cifar100_imb001_stage1_mixup.yaml: -------------------------------------------------------------------------------- 1 | name: cifar100_imb001_stage1_mixup 2 | print_freq: 40 3 | workers: 4 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | 8 | # dataset & model setting 9 | dataset: 'cifar100' 10 | data_path: './data/cifar100' 11 | num_classes: 100 12 | imb_factor: 0.01 13 | backbone: 'resnet32_fe' 14 | resume: '' 15 | head_class_idx: 16 | - 0 17 | - 36 18 | med_class_idx: 19 | - 36 20 | - 71 21 | tail_class_idx: 22 | - 71 23 | - 100 24 | 25 | 26 | # distributed training 27 | deterministic: True 28 | distributed: False 29 | gpu: null 30 | world_size: -1 31 | rank: -1 32 | dist_url: 'tcp://224.66.41.62:23456' 33 | dist_backend: 'nccl' 34 | multiprocessing_distributed: False 35 | 36 | 37 | # Train 38 | mode: 'stage1' 39 | lr: 0.1 40 | batch_size: 128 41 | weight_decay: 2e-4 42 | num_epochs: 200 43 | momentum: 0.9 44 | cos: False 45 | mixup: True 46 | alpha: 1.0 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /config/cifar100/cifar100_imb001_stage2_mislas.yaml: -------------------------------------------------------------------------------- 1 | name: cifar100_imb001_stage2_mislas 2 | print_freq: 40 3 | workers: 4 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | 8 | # dataset & model setting 9 | dataset: 'cifar100' 10 | data_path: './data/cifar100' 11 | num_classes: 100 12 | imb_factor: 0.01 13 | backbone: 'resnet32_fe' 14 | resume: 'Path/to/Stage1_checkpoint.pth.tar' 15 | head_class_idx: 16 | - 0 17 | - 36 18 | med_class_idx: 19 | - 36 20 | - 71 21 | tail_class_idx: 22 | - 71 23 | - 100 24 | 25 | 26 | # distributed training 27 | deterministic: False 28 | distributed: False 29 | gpu: null 30 | world_size: -1 31 | rank: -1 32 | dist_url: 'tcp://224.66.41.62:23456' 33 | dist_backend: 'nccl' 34 | multiprocessing_distributed: False 35 | 36 | 37 | # Train 38 | mode: 'stage2' 39 | smooth_head: 0.4 40 | smooth_tail: 0.1 41 | shift_bn: True 42 | lr_factor: 0.2 43 | lr: 0.1 44 | batch_size: 128 45 | weight_decay: 2e-4 46 | num_epochs: 10 47 | momentum: 0.9 48 | mixup: False 49 | alpha: null -------------------------------------------------------------------------------- /config/cifar100/cifar100_imb002_stage1_mixup.yaml: -------------------------------------------------------------------------------- 1 | name: cifar100_imb002_stage1_mixup 2 | print_freq: 40 3 | workers: 4 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | 8 | # dataset & model setting 9 | dataset: 'cifar100' 10 | data_path: './data/cifar100' 11 | num_classes: 100 12 | imb_factor: 0.02 13 | backbone: 'resnet32_fe' 14 | resume: '' 15 | head_class_idx: 16 | - 0 17 | - 36 18 | med_class_idx: 19 | - 36 20 | - 71 21 | tail_class_idx: 22 | - 71 23 | - 100 24 | 25 | 26 | # distributed training 27 | deterministic: False 28 | distributed: False 29 | gpu: null 30 | world_size: -1 31 | rank: -1 32 | dist_url: 'tcp://224.66.41.62:23456' 33 | dist_backend: 'nccl' 34 | multiprocessing_distributed: False 35 | 36 | 37 | # Train 38 | mode: 'stage1' 39 | lr: 0.1 40 | batch_size: 128 41 | weight_decay: 2e-4 42 | num_epochs: 200 43 | momentum: 0.9 44 | cos: False 45 | mixup: True 46 | alpha: 1.0 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /config/cifar100/cifar100_imb002_stage2_mislas.yaml: -------------------------------------------------------------------------------- 1 | name: cifar100_imb002_stage2_mislas 2 | print_freq: 40 3 | workers: 4 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | 8 | # dataset & model setting 9 | dataset: 'cifar100' 10 | data_path: './data/cifar100' 11 | num_classes: 100 12 | imb_factor: 0.02 13 | backbone: 'resnet32_fe' 14 | resume: 'Path/to/Stage1_checkpoint.pth.tar' 15 | head_class_idx: 16 | - 0 17 | - 36 18 | med_class_idx: 19 | - 36 20 | - 71 21 | tail_class_idx: 22 | - 71 23 | - 100 24 | 25 | 26 | # distributed training 27 | deterministic: False 28 | distributed: False 29 | gpu: null 30 | world_size: -1 31 | rank: -1 32 | dist_url: 'tcp://224.66.41.62:23456' 33 | dist_backend: 'nccl' 34 | multiprocessing_distributed: False 35 | 36 | 37 | # Train 38 | mode: 'stage2' 39 | smooth_head: 0.3 40 | smooth_tail: 0.0 41 | shift_bn: True 42 | lr_factor: 0.1 43 | lr: 0.1 44 | batch_size: 128 45 | weight_decay: 2e-4 46 | num_epochs: 10 47 | momentum: 0.9 48 | mixup: False 49 | alpha: null -------------------------------------------------------------------------------- /config/cifar100/cifar100_imb01_stage1_mixup.yaml: -------------------------------------------------------------------------------- 1 | name: cifar100_imb01_stage1_mixup 2 | print_freq: 40 3 | workers: 4 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | 8 | # dataset & model setting 9 | dataset: 'cifar100' 10 | data_path: './data/cifar100' 11 | num_classes: 100 12 | imb_factor: 0.1 13 | backbone: 'resnet32_fe' 14 | resume: '' 15 | head_class_idx: 16 | - 0 17 | - 36 18 | med_class_idx: 19 | - 36 20 | - 71 21 | tail_class_idx: 22 | - 71 23 | - 100 24 | 25 | 26 | # distributed training 27 | deterministic: False 28 | distributed: False 29 | gpu: null 30 | world_size: -1 31 | rank: -1 32 | dist_url: 'tcp://224.66.41.62:23456' 33 | dist_backend: 'nccl' 34 | multiprocessing_distributed: False 35 | 36 | 37 | # Train 38 | mode: 'stage1' 39 | lr: 0.1 40 | batch_size: 128 41 | weight_decay: 2e-4 42 | num_epochs: 200 43 | momentum: 0.9 44 | cos: False 45 | mixup: True 46 | alpha: 1.0 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /config/cifar100/cifar100_imb01_stage2_mislas.yaml: -------------------------------------------------------------------------------- 1 | name: cifar100_imb01_stage2_mislas 2 | print_freq: 40 3 | workers: 4 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | 8 | # dataset & model setting 9 | dataset: 'cifar100' 10 | data_path: './data/cifar100' 11 | num_classes: 100 12 | imb_factor: 0.1 13 | backbone: 'resnet32_fe' 14 | resume: 'Path/to/Stage1_checkpoint.pth.tar' 15 | head_class_idx: 16 | - 0 17 | - 36 18 | med_class_idx: 19 | - 36 20 | - 71 21 | tail_class_idx: 22 | - 71 23 | - 100 24 | 25 | 26 | # distributed training 27 | deterministic: False 28 | distributed: False 29 | gpu: null 30 | world_size: -1 31 | rank: -1 32 | dist_url: 'tcp://224.66.41.62:23456' 33 | dist_backend: 'nccl' 34 | multiprocessing_distributed: False 35 | 36 | 37 | # Train 38 | mode: 'stage2' 39 | smooth_head: 0.2 40 | smooth_tail: 0.0 41 | shift_bn: True 42 | lr_factor: 0.1 43 | lr: 0.1 44 | batch_size: 128 45 | weight_decay: 2e-4 46 | num_epochs: 10 47 | momentum: 0.9 48 | mixup: False 49 | alpha: null -------------------------------------------------------------------------------- /config/imagenet/imagenet_resnet50_stage1_mixup.yaml: -------------------------------------------------------------------------------- 1 | name: imagenet_resnet50_stage1_mixup 2 | print_freq: 100 3 | workers: 48 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | 8 | # dataset & model setting 9 | dataset: 'imagenet' 10 | data_path: 'Path/to/Data/ImageNet/' 11 | num_classes: 1000 12 | imb_factor: null 13 | backbone: 'resnet50_fe' 14 | resume: '' 15 | head_class_idx: 16 | - 0 17 | - 390 18 | med_class_idx: 19 | - 390 20 | - 835 21 | tail_class_idx: 22 | - 835 23 | - 1000 24 | 25 | 26 | # distributed training 27 | deterministic: False 28 | distributed: False 29 | gpu: null 30 | world_size: -1 31 | rank: -1 32 | dist_url: 'tcp://224.66.41.62:23456' 33 | dist_backend: 'nccl' 34 | multiprocessing_distributed: False 35 | 36 | 37 | # Train 38 | mode: 'stage1' 39 | lr: 0.1 40 | batch_size: 256 41 | weight_decay: 5e-4 42 | num_epochs: 180 43 | momentum: 0.9 44 | cos: True 45 | mixup: True 46 | alpha: 0.2 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /config/imagenet/imagenet_resnet50_stage2_mislas.yaml: -------------------------------------------------------------------------------- 1 | name: imagenet_resnet50_stage2_mislas 2 | print_freq: 100 3 | workers: 48 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | 8 | # dataset & model setting 9 | dataset: 'imagenet' 10 | data_path: 'Path/to/Data/ImageNet/' 11 | num_classes: 1000 12 | imb_factor: null 13 | backbone: 'resnet50_fe' 14 | resume: 'Path/to/Stage1_checkpoint.pth.tar' 15 | head_class_idx: 16 | - 0 17 | - 390 18 | med_class_idx: 19 | - 390 20 | - 835 21 | tail_class_idx: 22 | - 835 23 | - 1000 24 | 25 | 26 | # distributed training 27 | deterministic: False 28 | distributed: False 29 | gpu: null 30 | world_size: -1 31 | rank: -1 32 | dist_url: 'tcp://224.66.41.62:23456' 33 | dist_backend: 'nccl' 34 | multiprocessing_distributed: False 35 | 36 | 37 | # Train 38 | mode: 'stage2' 39 | smooth_head: 0.3 40 | smooth_tail: 0.0 41 | shift_bn: True 42 | lr_factor: 0.05 43 | lr: 0.1 44 | batch_size: 256 45 | weight_decay: 5e-4 46 | num_epochs: 10 47 | momentum: 0.9 48 | mixup: False 49 | alpha: null 50 | 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /config/ina2018/ina2018_resnet50_stage1_mixup.yaml: -------------------------------------------------------------------------------- 1 | name: ina2018_resnet50_stage1_mixup 2 | print_freq: 200 3 | workers: 48 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | 8 | # dataset & model setting 9 | dataset: 'ina2018' 10 | data_path: 'Path/to/Data/iNaturalist2018' 11 | num_classes: 8142 12 | imb_factor: null 13 | backbone: 'resnet50_fe' 14 | resume: '' 15 | head_class_idx: 16 | - 0 17 | - 842 18 | med_class_idx: 19 | - 842 20 | - 4543 21 | tail_class_idx: 22 | - 4543 23 | - 8142 24 | 25 | 26 | # distributed training 27 | deterministic: False 28 | distributed: False 29 | gpu: null 30 | world_size: -1 31 | rank: -1 32 | dist_url: 'tcp://224.66.41.62:23456' 33 | dist_backend: 'nccl' 34 | multiprocessing_distributed: False 35 | 36 | 37 | # Train 38 | mode: 'stage1' 39 | lr: 0.1 40 | batch_size: 256 41 | weight_decay: 1e-4 42 | num_epochs: 200 43 | momentum: 0.9 44 | cos: True 45 | mixup: True 46 | alpha: 0.2 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /config/ina2018/ina2018_resnet50_stage2_mislas.yaml: -------------------------------------------------------------------------------- 1 | name: ina2018_resnet50_stage2_mislas 2 | print_freq: 200 3 | workers: 48 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | 8 | # dataset & model setting 9 | dataset: 'ina2018' 10 | data_path: 'Path/to/Data/iNaturalist2018' 11 | num_classes: 8142 12 | imb_factor: null 13 | backbone: 'resnet50_fe' 14 | resume: 'Path/to/Stage1_checkpoint.pth.tar' 15 | head_class_idx: 16 | - 0 17 | - 842 18 | med_class_idx: 19 | - 842 20 | - 4543 21 | tail_class_idx: 22 | - 4543 23 | - 8142 24 | 25 | 26 | # distributed training 27 | deterministic: False 28 | distributed: False 29 | gpu: null 30 | world_size: -1 31 | rank: -1 32 | dist_url: 'tcp://224.66.41.62:23456' 33 | dist_backend: 'nccl' 34 | multiprocessing_distributed: False 35 | 36 | 37 | # Train 38 | mode: 'stage2' 39 | smooth_head: 0.3 40 | smooth_tail: 0.0 41 | shift_bn: True 42 | lr_factor: 0.05 43 | lr: 0.1 44 | batch_size: 256 45 | weight_decay: 1e-4 46 | num_epochs: 10 47 | momentum: 0.9 48 | mixup: False 49 | alpha: null 50 | 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /config/places/places_resnet152_stage1_mixup.yaml: -------------------------------------------------------------------------------- 1 | name: places_resnet152_stage1_mixup 2 | print_freq: 100 3 | workers: 48 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | 8 | # dataset & model setting 9 | dataset: 'places' 10 | data_path: 'Path/to/Data/Places365/' 11 | num_classes: 365 12 | imb_factor: null 13 | backbone: 'resnet152_fe' 14 | resume: '' 15 | head_class_idx: 16 | - 0 17 | - 131 18 | med_class_idx: 19 | - 131 20 | - 288 21 | tail_class_idx: 22 | - 288 23 | - 365 24 | 25 | 26 | # distributed training 27 | deterministic: False 28 | distributed: False 29 | gpu: null 30 | world_size: -1 31 | rank: -1 32 | dist_url: 'tcp://224.66.41.62:23456' 33 | dist_backend: 'nccl' 34 | multiprocessing_distributed: False 35 | 36 | 37 | # Train 38 | mode: 'stage1' 39 | lr: 0.1 40 | batch_size: 256 41 | weight_decay: 5e-4 42 | num_epochs: 90 43 | momentum: 0.9 44 | cos: True 45 | mixup: True 46 | alpha: 0.2 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /config/places/places_resnet152_stage2_mislas.yaml: -------------------------------------------------------------------------------- 1 | name: places_resnet152_stage2_mislas 2 | print_freq: 100 3 | workers: 48 4 | log_dir: 'logs' 5 | model_dir: 'ckps' 6 | 7 | 8 | # dataset & model setting 9 | dataset: 'places' 10 | data_path: 'Path/to/Data/Places365/' 11 | num_classes: 365 12 | imb_factor: null 13 | backbone: 'resnet152_fe' 14 | resume: 'Path/to/Stage1_checkpoint.pth.tar' 15 | head_class_idx: 16 | - 0 17 | - 131 18 | med_class_idx: 19 | - 131 20 | - 288 21 | tail_class_idx: 22 | - 288 23 | - 365 24 | 25 | 26 | # distributed training 27 | deterministic: False 28 | distributed: False 29 | gpu: null 30 | world_size: -1 31 | rank: -1 32 | dist_url: 'tcp://224.66.41.62:23456' 33 | dist_backend: 'nccl' 34 | multiprocessing_distributed: False 35 | 36 | 37 | # Train 38 | mode: 'stage2' 39 | smooth_head: 0.4 40 | smooth_tail: 0.1 41 | shift_bn: True 42 | lr_factor: 0.05 43 | lr: 0.1 44 | batch_size: 256 45 | weight_decay: 5e-4 46 | num_epochs: 10 47 | momentum: 0.9 48 | mixup: False 49 | alpha: null 50 | -------------------------------------------------------------------------------- /datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .sampler import ClassAwareSampler 3 | 4 | import torch 5 | import torchvision 6 | from torchvision import transforms 7 | import torchvision.datasets 8 | 9 | class IMBALANCECIFAR10(torchvision.datasets.CIFAR10): 10 | cls_num = 10 11 | 12 | def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True, 13 | transform=None, target_transform=None, 14 | download=False): 15 | super(IMBALANCECIFAR10, self).__init__(root, train, transform, target_transform, download) 16 | np.random.seed(rand_number) 17 | img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor) 18 | self.gen_imbalanced_data(img_num_list) 19 | 20 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor): 21 | img_max = len(self.data) / cls_num 22 | img_num_per_cls = [] 23 | if imb_type == 'exp': 24 | for cls_idx in range(cls_num): 25 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) 26 | img_num_per_cls.append(int(num)) 27 | elif imb_type == 'step': 28 | for cls_idx in range(cls_num // 2): 29 | img_num_per_cls.append(int(img_max)) 30 | for cls_idx in range(cls_num // 2): 31 | img_num_per_cls.append(int(img_max * imb_factor)) 32 | else: 33 | img_num_per_cls.extend([int(img_max)] * cls_num) 34 | return img_num_per_cls 35 | 36 | def gen_imbalanced_data(self, img_num_per_cls): 37 | new_data = [] 38 | new_targets = [] 39 | targets_np = np.array(self.targets, dtype=np.int64) 40 | classes = np.unique(targets_np) 41 | # np.random.shuffle(classes) 42 | self.num_per_cls_dict = dict() 43 | for the_class, the_img_num in zip(classes, img_num_per_cls): 44 | self.num_per_cls_dict[the_class] = the_img_num 45 | idx = np.where(targets_np == the_class)[0] 46 | np.random.shuffle(idx) 47 | selec_idx = idx[:the_img_num] 48 | new_data.append(self.data[selec_idx, ...]) 49 | new_targets.extend([the_class, ] * the_img_num) 50 | new_data = np.vstack(new_data) 51 | self.data = new_data 52 | self.targets = new_targets 53 | 54 | def get_cls_num_list(self): 55 | cls_num_list = [] 56 | for i in range(self.cls_num): 57 | cls_num_list.append(self.num_per_cls_dict[i]) 58 | return cls_num_list 59 | 60 | 61 | 62 | class CIFAR10_LT(object): 63 | 64 | def __init__(self, distributed, root='./data/cifar10', imb_type='exp', 65 | imb_factor=0.01, batch_size=128, num_works=40): 66 | 67 | train_transform = transforms.Compose([ 68 | transforms.RandomCrop(32, padding=4), 69 | transforms.RandomHorizontalFlip(), 70 | transforms.ToTensor(), 71 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 72 | ]) 73 | 74 | 75 | eval_transform = transforms.Compose([ 76 | transforms.ToTensor(), 77 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 78 | ]) 79 | 80 | 81 | train_dataset = IMBALANCECIFAR10(root=root, imb_type=imb_type, imb_factor=imb_factor, rand_number=0, train=True, download=True, transform=train_transform) 82 | eval_dataset = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=eval_transform) 83 | 84 | self.cls_num_list = train_dataset.get_cls_num_list() 85 | 86 | self.dist_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else None 87 | self.train_instance = torch.utils.data.DataLoader( 88 | train_dataset, 89 | batch_size=batch_size, shuffle=True, 90 | num_workers=num_works, pin_memory=True, sampler=self.dist_sampler) 91 | 92 | balance_sampler = ClassAwareSampler(train_dataset) 93 | self.train_balance = torch.utils.data.DataLoader( 94 | train_dataset, 95 | batch_size=batch_size, shuffle=False, 96 | num_workers=num_works, pin_memory=True, sampler=balance_sampler) 97 | 98 | self.eval = torch.utils.data.DataLoader( 99 | eval_dataset, 100 | batch_size=batch_size, shuffle=False, 101 | num_workers=num_works, pin_memory=True) 102 | -------------------------------------------------------------------------------- /datasets/cifar100.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .sampler import ClassAwareSampler 3 | 4 | import torch 5 | import torchvision 6 | from torchvision import transforms 7 | import torchvision.datasets 8 | 9 | class IMBALANCECIFAR100(torchvision.datasets.CIFAR100): 10 | cls_num = 100 11 | 12 | def __init__(self, root, imb_type='exp', imb_factor=0.01, rand_number=0, train=True, 13 | transform=None, target_transform=None, 14 | download=False): 15 | super(IMBALANCECIFAR100, self).__init__(root, train, transform, target_transform, download) 16 | np.random.seed(rand_number) 17 | img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imb_factor) 18 | self.gen_imbalanced_data(img_num_list) 19 | 20 | def get_img_num_per_cls(self, cls_num, imb_type, imb_factor): 21 | img_max = len(self.data) / cls_num 22 | img_num_per_cls = [] 23 | if imb_type == 'exp': 24 | for cls_idx in range(cls_num): 25 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) 26 | img_num_per_cls.append(int(num)) 27 | elif imb_type == 'step': 28 | for cls_idx in range(cls_num // 2): 29 | img_num_per_cls.append(int(img_max)) 30 | for cls_idx in range(cls_num // 2): 31 | img_num_per_cls.append(int(img_max * imb_factor)) 32 | else: 33 | img_num_per_cls.extend([int(img_max)] * cls_num) 34 | return img_num_per_cls 35 | 36 | def gen_imbalanced_data(self, img_num_per_cls): 37 | new_data = [] 38 | new_targets = [] 39 | targets_np = np.array(self.targets, dtype=np.int64) 40 | classes = np.unique(targets_np) 41 | # np.random.shuffle(classes) 42 | self.num_per_cls_dict = dict() 43 | for the_class, the_img_num in zip(classes, img_num_per_cls): 44 | self.num_per_cls_dict[the_class] = the_img_num 45 | idx = np.where(targets_np == the_class)[0] 46 | np.random.shuffle(idx) 47 | selec_idx = idx[:the_img_num] 48 | new_data.append(self.data[selec_idx, ...]) 49 | new_targets.extend([the_class, ] * the_img_num) 50 | new_data = np.vstack(new_data) 51 | self.data = new_data 52 | self.targets = new_targets 53 | 54 | def get_cls_num_list(self): 55 | cls_num_list = [] 56 | for i in range(self.cls_num): 57 | cls_num_list.append(self.num_per_cls_dict[i]) 58 | return cls_num_list 59 | 60 | 61 | 62 | 63 | 64 | 65 | class CIFAR100_LT(object): 66 | def __init__(self, distributed, root='./data/cifar100', imb_type='exp', 67 | imb_factor=0.01, batch_size=128, num_works=40): 68 | 69 | train_transform = transforms.Compose([ 70 | transforms.RandomCrop(32, padding=4), 71 | transforms.RandomHorizontalFlip(), 72 | transforms.ToTensor(), 73 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 74 | ]) 75 | 76 | 77 | 78 | eval_transform = transforms.Compose([ 79 | transforms.ToTensor(), 80 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 81 | ]) 82 | 83 | 84 | train_dataset = IMBALANCECIFAR100(root=root, imb_type=imb_type, imb_factor=imb_factor, rand_number=0, train=True, download=True, transform=train_transform) 85 | eval_dataset = torchvision.datasets.CIFAR100(root=root, train=False, download=True, transform=eval_transform) 86 | 87 | self.cls_num_list = train_dataset.get_cls_num_list() 88 | 89 | self.dist_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else None 90 | self.train_instance = torch.utils.data.DataLoader( 91 | train_dataset, 92 | batch_size=batch_size, shuffle=True, 93 | num_workers=num_works, pin_memory=True, sampler=self.dist_sampler) 94 | 95 | balance_sampler = ClassAwareSampler(train_dataset) 96 | self.train_balance = torch.utils.data.DataLoader( 97 | train_dataset, 98 | batch_size=batch_size, shuffle=False, 99 | num_workers=num_works, pin_memory=True, sampler=balance_sampler) 100 | 101 | self.eval = torch.utils.data.DataLoader( 102 | eval_dataset, 103 | batch_size=batch_size, shuffle=False, 104 | num_workers=num_works, pin_memory=True) -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | import torch 6 | import torchvision 7 | import torchvision.datasets 8 | from torchvision import transforms 9 | from torch.utils.data import Dataset 10 | 11 | from .sampler import ClassAwareSampler 12 | 13 | 14 | class LT_Dataset(Dataset): 15 | num_classes = 1000 16 | 17 | def __init__(self, root, txt, transform=None): 18 | self.img_path = [] 19 | self.targets = [] 20 | self.transform = transform 21 | with open(txt) as f: 22 | for line in f: 23 | self.img_path.append(os.path.join(root, line.split()[0])) 24 | self.targets.append(int(line.split()[1])) 25 | 26 | cls_num_list_old = [np.sum(np.array(self.targets) == i) for i in range(self.num_classes)] 27 | 28 | # generate class_map: class index sort by num (descending) 29 | sorted_classes = np.argsort(-np.array(cls_num_list_old)) 30 | self.class_map = [0 for i in range(self.num_classes)] 31 | for i in range(self.num_classes): 32 | self.class_map[sorted_classes[i]] = i 33 | 34 | self.targets = np.array(self.class_map)[self.targets].tolist() 35 | 36 | self.class_data = [[] for i in range(self.num_classes)] 37 | for i in range(len(self.targets)): 38 | j = self.targets[i] 39 | self.class_data[j].append(i) 40 | 41 | self.cls_num_list = [np.sum(np.array(self.targets)==i) for i in range(self.num_classes)] 42 | 43 | 44 | def __len__(self): 45 | return len(self.targets) 46 | 47 | def __getitem__(self, index): 48 | path = self.img_path[index] 49 | target = self.targets[index] 50 | 51 | with open(path, 'rb') as f: 52 | sample = Image.open(f).convert('RGB') 53 | if self.transform is not None: 54 | sample = self.transform(sample) 55 | return sample, target 56 | 57 | 58 | 59 | class LT_Dataset_Eval(Dataset): 60 | num_classes = 1000 61 | 62 | def __init__(self, root, txt, class_map, transform=None): 63 | self.img_path = [] 64 | self.targets = [] 65 | self.transform = transform 66 | self.class_map = class_map 67 | with open(txt) as f: 68 | for line in f: 69 | self.img_path.append(os.path.join(root, line.split()[0])) 70 | self.targets.append(int(line.split()[1])) 71 | 72 | self.targets = np.array(self.class_map)[self.targets].tolist() 73 | 74 | def __len__(self): 75 | return len(self.targets) 76 | 77 | def __getitem__(self, index): 78 | path = self.img_path[index] 79 | target = self.targets[index] 80 | 81 | with open(path, 'rb') as f: 82 | sample = Image.open(f).convert('RGB') 83 | if self.transform is not None: 84 | sample = self.transform(sample) 85 | return sample, target 86 | 87 | 88 | class ImageNet_LT(object): 89 | def __init__(self, distributed, root="", batch_size=60, num_works=40): 90 | 91 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 92 | 93 | transform_train = transforms.Compose([ 94 | transforms.RandomResizedCrop(224), 95 | transforms.RandomHorizontalFlip(), 96 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0), 97 | transforms.ToTensor(), 98 | normalize, 99 | ]) 100 | 101 | 102 | transform_test = transforms.Compose([ 103 | transforms.Resize(256), 104 | transforms.CenterCrop(224), 105 | transforms.ToTensor(), 106 | normalize, 107 | ]) 108 | 109 | 110 | train_txt = "./datasets/data_txt/ImageNet_LT_train.txt" 111 | eval_txt = "./datasets/data_txt/ImageNet_LT_test.txt" 112 | 113 | train_dataset = LT_Dataset(root, train_txt, transform=transform_train) 114 | eval_dataset = LT_Dataset_Eval(root, eval_txt, transform=transform_test, class_map=train_dataset.class_map) 115 | 116 | self.cls_num_list = train_dataset.cls_num_list 117 | 118 | self.dist_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else None 119 | self.train_instance = torch.utils.data.DataLoader( 120 | train_dataset, 121 | batch_size=batch_size, shuffle=True, 122 | num_workers=num_works, pin_memory=True, sampler=self.dist_sampler) 123 | 124 | balance_sampler = ClassAwareSampler(train_dataset) 125 | self.train_balance = torch.utils.data.DataLoader( 126 | train_dataset, 127 | batch_size=batch_size, shuffle=False, 128 | num_workers=num_works, pin_memory=True, sampler=balance_sampler) 129 | 130 | self.eval = torch.utils.data.DataLoader( 131 | eval_dataset, 132 | batch_size=batch_size, shuffle=False, 133 | num_workers=num_works, pin_memory=True) -------------------------------------------------------------------------------- /datasets/ina2018.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | import torch 6 | import torchvision 7 | import torchvision.datasets 8 | from torchvision import transforms 9 | from torch.utils.data import Dataset 10 | 11 | from .sampler import ClassAwareSampler 12 | 13 | 14 | class LT_Dataset(Dataset): 15 | num_classes = 8142 16 | 17 | def __init__(self, root, txt, transform=None): 18 | self.img_path = [] 19 | self.targets = [] 20 | self.transform = transform 21 | with open(txt) as f: 22 | for line in f: 23 | self.img_path.append(os.path.join(root, line.split()[0])) 24 | self.targets.append(int(line.split()[1])) 25 | 26 | cls_num_list_old = [np.sum(np.array(self.targets) == i) for i in range(self.num_classes)] 27 | 28 | # generate class_map: class index sort by num (descending) 29 | sorted_classes = np.argsort(-np.array(cls_num_list_old)) 30 | self.class_map = [0 for i in range(self.num_classes)] 31 | for i in range(self.num_classes): 32 | self.class_map[sorted_classes[i]] = i 33 | 34 | self.targets = np.array(self.class_map)[self.targets].tolist() 35 | 36 | self.class_data = [[] for i in range(self.num_classes)] 37 | for i in range(len(self.targets)): 38 | j = self.targets[i] 39 | self.class_data[j].append(i) 40 | 41 | self.cls_num_list = [np.sum(np.array(self.targets)==i) for i in range(self.num_classes)] 42 | 43 | 44 | def __len__(self): 45 | return len(self.targets) 46 | 47 | def __getitem__(self, index): 48 | path = self.img_path[index] 49 | target = self.targets[index] 50 | 51 | with open(path, 'rb') as f: 52 | sample = Image.open(f).convert('RGB') 53 | if self.transform is not None: 54 | sample = self.transform(sample) 55 | return sample, target 56 | 57 | 58 | 59 | class LT_Dataset_Eval(Dataset): 60 | num_classes = 8142 61 | 62 | def __init__(self, root, txt, class_map, transform=None): 63 | self.img_path = [] 64 | self.targets = [] 65 | self.transform = transform 66 | self.class_map = class_map 67 | with open(txt) as f: 68 | for line in f: 69 | self.img_path.append(os.path.join(root, line.split()[0])) 70 | self.targets.append(int(line.split()[1])) 71 | 72 | self.targets = np.array(self.class_map)[self.targets].tolist() 73 | 74 | def __len__(self): 75 | return len(self.targets) 76 | 77 | def __getitem__(self, index): 78 | path = self.img_path[index] 79 | target = self.targets[index] 80 | 81 | with open(path, 'rb') as f: 82 | sample = Image.open(f).convert('RGB') 83 | if self.transform is not None: 84 | sample = self.transform(sample) 85 | return sample, target 86 | 87 | 88 | class iNa2018(object): 89 | def __init__(self, distributed, root="", batch_size=60, num_works=40): 90 | 91 | normalize = transforms.Normalize(mean=[0.466, 0.471, 0.380], std=[0.195, 0.194, 0.192]) 92 | 93 | transform_train = transforms.Compose([ 94 | transforms.RandomResizedCrop(224), 95 | transforms.RandomHorizontalFlip(), 96 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0), 97 | transforms.ToTensor(), 98 | normalize, 99 | ]) 100 | 101 | 102 | transform_test = transforms.Compose([ 103 | transforms.Resize(256), 104 | transforms.CenterCrop(224), 105 | transforms.ToTensor(), 106 | normalize, 107 | ]) 108 | 109 | train_txt = "./datasets/data_txt/iNaturalist18_train.txt" 110 | eval_txt = "./datasets/data_txt/iNaturalist18_val.txt" 111 | 112 | train_dataset = LT_Dataset(root, train_txt, transform=transform_train) 113 | eval_dataset = LT_Dataset_Eval(root, eval_txt, transform=transform_test, class_map=train_dataset.class_map) 114 | 115 | self.cls_num_list = train_dataset.cls_num_list 116 | 117 | self.dist_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else None 118 | self.train_instance = torch.utils.data.DataLoader( 119 | train_dataset, 120 | batch_size=batch_size, shuffle=True, 121 | num_workers=num_works, pin_memory=True, sampler=self.dist_sampler) 122 | 123 | balance_sampler = ClassAwareSampler(train_dataset) 124 | self.train_balance = torch.utils.data.DataLoader( 125 | train_dataset, 126 | batch_size=batch_size, shuffle=False, 127 | num_workers=num_works, pin_memory=True, sampler=balance_sampler) 128 | 129 | self.eval = torch.utils.data.DataLoader( 130 | eval_dataset, 131 | batch_size=batch_size, shuffle=False, 132 | num_workers=num_works, pin_memory=True) -------------------------------------------------------------------------------- /datasets/places.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | import torch 6 | import torchvision 7 | import torchvision.datasets 8 | from torchvision import transforms 9 | from torch.utils.data import Dataset 10 | 11 | from .sampler import ClassAwareSampler 12 | 13 | 14 | class LT_Dataset(Dataset): 15 | num_classes = 365 16 | 17 | def __init__(self, root, txt, transform=None): 18 | self.img_path = [] 19 | self.targets = [] 20 | self.transform = transform 21 | with open(txt) as f: 22 | for line in f: 23 | self.img_path.append(os.path.join(root, line.split()[0])) 24 | self.targets.append(int(line.split()[1])) 25 | 26 | cls_num_list_old = [np.sum(np.array(self.targets) == i) for i in range(self.num_classes)] 27 | 28 | # generate class_map: class index sort by num (descending) 29 | sorted_classes = np.argsort(-np.array(cls_num_list_old)) 30 | self.class_map = [0 for i in range(self.num_classes)] 31 | for i in range(self.num_classes): 32 | self.class_map[sorted_classes[i]] = i 33 | 34 | self.targets = np.array(self.class_map)[self.targets].tolist() 35 | 36 | self.class_data = [[] for i in range(self.num_classes)] 37 | for i in range(len(self.targets)): 38 | j = self.targets[i] 39 | self.class_data[j].append(i) 40 | 41 | self.cls_num_list = [np.sum(np.array(self.targets)==i) for i in range(self.num_classes)] 42 | 43 | 44 | def __len__(self): 45 | return len(self.targets) 46 | 47 | def __getitem__(self, index): 48 | path = self.img_path[index] 49 | target = self.targets[index] 50 | 51 | with open(path, 'rb') as f: 52 | sample = Image.open(f).convert('RGB') 53 | if self.transform is not None: 54 | sample = self.transform(sample) 55 | return sample, target 56 | 57 | 58 | 59 | class LT_Dataset_Eval(Dataset): 60 | num_classes = 365 61 | 62 | def __init__(self, root, txt, class_map, transform=None): 63 | self.img_path = [] 64 | self.targets = [] 65 | self.transform = transform 66 | self.class_map = class_map 67 | with open(txt) as f: 68 | for line in f: 69 | self.img_path.append(os.path.join(root, line.split()[0])) 70 | self.targets.append(int(line.split()[1])) 71 | 72 | self.targets = np.array(self.class_map)[self.targets].tolist() 73 | 74 | def __len__(self): 75 | return len(self.targets) 76 | 77 | def __getitem__(self, index): 78 | path = self.img_path[index] 79 | target = self.targets[index] 80 | 81 | with open(path, 'rb') as f: 82 | sample = Image.open(f).convert('RGB') 83 | if self.transform is not None: 84 | sample = self.transform(sample) 85 | return sample, target 86 | 87 | 88 | class Places_LT(object): 89 | def __init__(self, distributed, root="", batch_size=60, num_works=40): 90 | 91 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 92 | 93 | transform_train = transforms.Compose([ 94 | transforms.RandomResizedCrop(224), 95 | transforms.RandomHorizontalFlip(), 96 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0), 97 | transforms.ToTensor(), 98 | normalize, 99 | ]) 100 | 101 | 102 | transform_test = transforms.Compose([ 103 | transforms.Resize(256), 104 | transforms.CenterCrop(224), 105 | transforms.ToTensor(), 106 | normalize, 107 | ]) 108 | 109 | train_txt = "./datasets/data_txt/Places_LT_train.txt" 110 | eval_txt = "./datasets/data_txt/Places_LT_test.txt" 111 | 112 | 113 | train_dataset = LT_Dataset(root, train_txt, transform=transform_train) 114 | eval_dataset = LT_Dataset_Eval(root, eval_txt, transform=transform_test, class_map=train_dataset.class_map) 115 | 116 | self.cls_num_list = train_dataset.cls_num_list 117 | 118 | self.dist_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else None 119 | self.train_instance = torch.utils.data.DataLoader( 120 | train_dataset, 121 | batch_size=batch_size, shuffle=True, 122 | num_workers=num_works, pin_memory=True, sampler=self.dist_sampler) 123 | 124 | balance_sampler = ClassAwareSampler(train_dataset) 125 | self.train_balance = torch.utils.data.DataLoader( 126 | train_dataset, 127 | batch_size=batch_size, shuffle=False, 128 | num_workers=num_works, pin_memory=True, sampler=balance_sampler) 129 | 130 | self.eval = torch.utils.data.DataLoader( 131 | eval_dataset, 132 | batch_size=batch_size, shuffle=False, 133 | num_workers=num_works, pin_memory=True) 134 | -------------------------------------------------------------------------------- /datasets/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | 4 | import torch 5 | 6 | class BalancedDatasetSampler(torch.utils.data.sampler.Sampler): 7 | 8 | def __init__(self, dataset, indices=None, num_samples=None): 9 | 10 | # if indices is not provided, 11 | # all elements in the dataset will be considered 12 | self.indices = list(range(len(dataset))) \ 13 | if indices is None else indices 14 | 15 | # if num_samples is not provided, 16 | # draw `len(indices)` samples in each iteration 17 | self.num_samples = len(self.indices) \ 18 | if num_samples is None else num_samples 19 | 20 | # distribution of classes in the dataset 21 | label_to_count = [0] * len(np.unique(dataset.targets)) 22 | for idx in self.indices: 23 | label = self._get_label(dataset, idx) 24 | label_to_count[label] += 1 25 | 26 | 27 | 28 | 29 | per_cls_weights = 1 / np.array(label_to_count) 30 | 31 | # weight for each sample 32 | weights = [per_cls_weights[self._get_label(dataset, idx)] 33 | for idx in self.indices] 34 | 35 | 36 | self.weights = torch.DoubleTensor(weights) 37 | 38 | def _get_label(self, dataset, idx): 39 | return dataset.targets[idx] 40 | 41 | def __iter__(self): 42 | return iter(torch.multinomial(self.weights, self.num_samples, replacement=True).tolist()) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | class EffectNumSampler(torch.utils.data.sampler.Sampler): 48 | 49 | def __init__(self, dataset, indices=None, num_samples=None): 50 | 51 | # if indices is not provided, 52 | # all elements in the dataset will be considered 53 | self.indices = list(range(len(dataset))) \ 54 | if indices is None else indices 55 | 56 | # if num_samples is not provided, 57 | # draw `len(indices)` samples in each iteration 58 | self.num_samples = len(self.indices) \ 59 | if num_samples is None else num_samples 60 | 61 | # distribution of classes in the dataset 62 | label_to_count = [0] * len(np.unique(dataset.targets)) 63 | for idx in self.indices: 64 | label = self._get_label(dataset, idx) 65 | label_to_count[label] += 1 66 | 67 | 68 | 69 | beta = 0.9999 70 | effective_num = 1.0 - np.power(beta, label_to_count) 71 | per_cls_weights = (1.0 - beta) / np.array(effective_num) 72 | 73 | # weight for each sample 74 | weights = [per_cls_weights[self._get_label(dataset, idx)] 75 | for idx in self.indices] 76 | 77 | 78 | self.weights = torch.DoubleTensor(weights) 79 | 80 | def _get_label(self, dataset, idx): 81 | return dataset.targets[idx] 82 | 83 | def __iter__(self): 84 | return iter(torch.multinomial(self.weights, self.num_samples, replacement=True).tolist()) 85 | 86 | def __len__(self): 87 | return self.num_samples 88 | 89 | class RandomCycleIter: 90 | 91 | def __init__ (self, data, test_mode=False): 92 | self.data_list = list(data) 93 | self.length = len(self.data_list) 94 | self.i = self.length - 1 95 | self.test_mode = test_mode 96 | 97 | def __iter__ (self): 98 | return self 99 | 100 | def __next__ (self): 101 | self.i += 1 102 | 103 | if self.i == self.length: 104 | self.i = 0 105 | if not self.test_mode: 106 | random.shuffle(self.data_list) 107 | 108 | return self.data_list[self.i] 109 | 110 | def class_aware_sample_generator(cls_iter, data_iter_list, n, num_samples_cls=1): 111 | 112 | i = 0 113 | j = 0 114 | while i < n: 115 | 116 | # yield next(data_iter_list[next(cls_iter)]) 117 | 118 | if j >= num_samples_cls: 119 | j = 0 120 | 121 | if j == 0: 122 | temp_tuple = next(zip(*[data_iter_list[next(cls_iter)]]*num_samples_cls)) 123 | yield temp_tuple[j] 124 | else: 125 | yield temp_tuple[j] 126 | 127 | i += 1 128 | j += 1 129 | 130 | class ClassAwareSampler(torch.utils.data.sampler.Sampler): 131 | def __init__(self, data_source, num_samples_cls=4,): 132 | # pdb.set_trace() 133 | num_classes = len(np.unique(data_source.targets)) 134 | self.class_iter = RandomCycleIter(range(num_classes)) 135 | cls_data_list = [list() for _ in range(num_classes)] 136 | for i, label in enumerate(data_source.targets): 137 | cls_data_list[label].append(i) 138 | self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list] 139 | self.num_samples = max([len(x) for x in cls_data_list]) * len(cls_data_list) 140 | self.num_samples_cls = num_samples_cls 141 | 142 | def __iter__ (self): 143 | return class_aware_sample_generator(self.class_iter, self.data_iter_list, 144 | self.num_samples, self.num_samples_cls) 145 | 146 | def __len__ (self): 147 | return self.num_samples 148 | 149 | def get_sampler(): 150 | return ClassAwareSampler -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import time 5 | import warnings 6 | import numpy as np 7 | import pprint 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.distributed as dist 13 | import torch.optim 14 | import torch.multiprocessing as mp 15 | import torch.utils.data 16 | import torch.utils.data.distributed 17 | import torch.nn.functional as F 18 | 19 | from datasets.cifar10 import CIFAR10_LT 20 | from datasets.cifar100 import CIFAR100_LT 21 | from datasets.places import Places_LT 22 | from datasets.imagenet import ImageNet_LT 23 | from datasets.ina2018 import iNa2018 24 | 25 | from models import resnet 26 | from models import resnet_places 27 | from models import resnet_cifar 28 | 29 | from utils import config, update_config, create_logger 30 | from utils import AverageMeter, ProgressMeter 31 | from utils import accuracy, calibration 32 | 33 | from methods import LearnableWeightScaling 34 | 35 | 36 | def parse_args(): 37 | parser = argparse.ArgumentParser(description='MiSLAS evaluation') 38 | parser.add_argument('--cfg', 39 | help='experiment configure file name', 40 | required=True, 41 | type=str) 42 | parser.add_argument('opts', 43 | help="Modify config options using the command-line", 44 | default=None, 45 | nargs=argparse.REMAINDER) 46 | 47 | args = parser.parse_args() 48 | update_config(config, args) 49 | 50 | return args 51 | 52 | 53 | best_acc1 = 0 54 | 55 | 56 | def main(): 57 | args = parse_args() 58 | logger, model_dir = create_logger(config, args.cfg) 59 | logger.info('\n' + pprint.pformat(args)) 60 | logger.info('\n' + str(config)) 61 | 62 | if config.deterministic: 63 | seed = 0 64 | torch.backends.cudnn.deterministic = True 65 | torch.backends.cudnn.benchmark = False 66 | random.seed(seed) 67 | np.random.seed(seed) 68 | os.environ['PYTHONHASHSEED'] = str(seed) 69 | torch.manual_seed(seed) 70 | torch.cuda.manual_seed(seed) 71 | torch.cuda.manual_seed_all(seed) 72 | 73 | if config.gpu is not None: 74 | warnings.warn('You have chosen a specific GPU. This will completely ' 75 | 'disable data parallelism.') 76 | 77 | if config.dist_url == "env://" and config.world_size == -1: 78 | config.world_size = int(os.environ["WORLD_SIZE"]) 79 | 80 | config.distributed = config.world_size > 1 or config.multiprocessing_distributed 81 | 82 | ngpus_per_node = torch.cuda.device_count() 83 | if config.multiprocessing_distributed: 84 | # Since we have ngpus_per_node processes per node, the total world_size 85 | # needs to be adjusted accordingly 86 | config.world_size = ngpus_per_node * config.world_size 87 | # Use torch.multiprocessing.spawn to launch distributed processes: the 88 | # main_worker process function 89 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config, logger)) 90 | else: 91 | # Simply call main_worker function 92 | main_worker(config.gpu, ngpus_per_node, config, logger, model_dir) 93 | 94 | 95 | def main_worker(gpu, ngpus_per_node, config, logger, model_dir): 96 | global best_acc1 97 | config.gpu = gpu 98 | # start_time = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 99 | 100 | if config.gpu is not None: 101 | logger.info("Use GPU: {} for training".format(config.gpu)) 102 | 103 | if config.distributed: 104 | if config.dist_url == "env://" and config.rank == -1: 105 | config.rank = int(os.environ["RANK"]) 106 | if config.multiprocessing_distributed: 107 | # For multiprocessing distributed training, rank needs to be the 108 | # global rank among all the processes 109 | config.rank = config.rank * ngpus_per_node + gpu 110 | dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url, 111 | world_size=config.world_size, rank=config.rank) 112 | 113 | if config.dataset == 'cifar10' or config.dataset == 'cifar100': 114 | model = getattr(resnet_cifar, config.backbone)() 115 | classifier = getattr(resnet_cifar, 'Classifier')(feat_in=64, num_classes=config.num_classes) 116 | 117 | elif config.dataset == 'imagenet' or config.dataset == 'ina2018': 118 | model = getattr(resnet, config.backbone)() 119 | classifier = getattr(resnet, 'Classifier')(feat_in=2048, num_classes=config.num_classes) 120 | 121 | elif config.dataset == 'places': 122 | model = getattr(resnet_places, config.backbone)(pretrained=True) 123 | classifier = getattr(resnet_places, 'Classifier')(feat_in=2048, num_classes=config.num_classes) 124 | block = getattr(resnet_places, 'Bottleneck')(2048, 512, groups=1, 125 | base_width=64, dilation=1, 126 | norm_layer=nn.BatchNorm2d) 127 | 128 | lws_model = LearnableWeightScaling(num_classes=config.num_classes) 129 | 130 | if not torch.cuda.is_available(): 131 | logger.info('using CPU, this will be slow') 132 | elif config.distributed: 133 | # For multiprocessing distributed, DistributedDataParallel constructor 134 | # should always set the single device scope, otherwise, 135 | # DistributedDataParallel will use all available devices. 136 | if config.gpu is not None: 137 | torch.cuda.set_device(config.gpu) 138 | model.cuda(config.gpu) 139 | classifier.cuda(config.gpu) 140 | lws_model.cuda(config.gpu) 141 | # When using a single GPU per process and per 142 | # DistributedDataParallel, we need to divide the batch size 143 | # ourselves based on the total number of GPUs we have 144 | config.batch_size = int(config.batch_size / ngpus_per_node) 145 | config.workers = int((config.workers + ngpus_per_node - 1) / ngpus_per_node) 146 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu]) 147 | classifier = torch.nn.parallel.DistributedDataParallel(classifier, device_ids=[config.gpu]) 148 | lws_model = torch.nn.parallel.DistributedDataParallel(lws_model, device_ids=[config.gpu]) 149 | 150 | if config.dataset == 'places': 151 | block.cuda(config.gpu) 152 | block = torch.nn.parallel.DistributedDataParallel(block, device_ids=[config.gpu]) 153 | else: 154 | model.cuda() 155 | classifier.cuda() 156 | lws_model.cuda() 157 | # DistributedDataParallel will divide and allocate batch_size to all 158 | # available GPUs if device_ids are not set 159 | model = torch.nn.parallel.DistributedDataParallel(model) 160 | classifier = torch.nn.parallel.DistributedDataParallel(classifier) 161 | lws_model = torch.nn.parallel.DistributedDataParallel(lws_model) 162 | if config.dataset == 'places': 163 | block.cuda() 164 | block = torch.nn.parallel.DistributedDataParallel(block) 165 | 166 | elif config.gpu is not None: 167 | torch.cuda.set_device(config.gpu) 168 | model = model.cuda(config.gpu) 169 | classifier = classifier.cuda(config.gpu) 170 | lws_model = lws_model.cuda(config.gpu) 171 | if config.dataset == 'places': 172 | block.cuda(config.gpu) 173 | else: 174 | # DataParallel will divide and allocate batch_size to all available GPUs 175 | model = torch.nn.DataParallel(model).cuda() 176 | classifier = torch.nn.DataParallel(classifier).cuda() 177 | lws_model = torch.nn.DataParallel(lws_model).cuda() 178 | if config.dataset == 'places': 179 | block = torch.nn.DataParallel(block).cuda() 180 | 181 | # optionally resume from a checkpoint 182 | if config.resume: 183 | if os.path.isfile(config.resume): 184 | logger.info("=> loading checkpoint '{}'".format(config.resume)) 185 | if config.gpu is None: 186 | checkpoint = torch.load(config.resume) 187 | else: 188 | # Map model to be loaded to specified single gpu. 189 | loc = 'cuda:{}'.format(config.gpu) 190 | checkpoint = torch.load(config.resume, map_location=loc) 191 | if config.gpu is not None: 192 | # best_acc1 may be from a checkpoint from a different GPU 193 | best_acc1 = best_acc1.to(config.gpu) 194 | model.load_state_dict(checkpoint['state_dict_model']) 195 | classifier.load_state_dict(checkpoint['state_dict_classifier']) 196 | if config.dataset == 'places': 197 | block.load_state_dict(checkpoint['state_dict_block']) 198 | if config.mode == 'stage2': 199 | lws_model.load_state_dict(checkpoint['state_dict_lws_model']) 200 | logger.info("=> loaded checkpoint '{}' (epoch {})" 201 | .format(config.resume, checkpoint['epoch'])) 202 | else: 203 | logger.info("=> no checkpoint found at '{}'".format(config.resume)) 204 | 205 | # Data loading code 206 | if config.dataset == 'cifar10': 207 | dataset = CIFAR10_LT(config.distributed, root=config.data_path, imb_factor=config.imb_factor, 208 | batch_size=config.batch_size, num_works=config.workers) 209 | 210 | elif config.dataset == 'cifar100': 211 | dataset = CIFAR100_LT(config.distributed, root=config.data_path, imb_factor=config.imb_factor, 212 | batch_size=config.batch_size, num_works=config.workers) 213 | 214 | elif config.dataset == 'places': 215 | dataset = Places_LT(config.distributed, root=config.data_path, 216 | batch_size=config.batch_size, num_works=config.workers) 217 | 218 | elif config.dataset == 'imagenet': 219 | dataset = ImageNet_LT(config.distributed, root=config.data_path, 220 | batch_size=config.batch_size, num_works=config.workers) 221 | 222 | elif config.dataset == 'ina2018': 223 | dataset = iNa2018(config.distributed, root=config.data_path, 224 | batch_size=config.batch_size, num_works=config.workers) 225 | 226 | val_loader = dataset.eval 227 | criterion = nn.CrossEntropyLoss().cuda(config.gpu) 228 | 229 | if config.dataset != 'places': 230 | block = None 231 | 232 | validate(val_loader, model, classifier, lws_model, criterion, config, logger, block) 233 | 234 | 235 | def validate(val_loader, model, classifier, lws_model, criterion, config, logger, block=None): 236 | batch_time = AverageMeter('Time', ':6.3f') 237 | losses = AverageMeter('Loss', ':.3f') 238 | top1 = AverageMeter('Acc@1', ':6.3f') 239 | top5 = AverageMeter('Acc@5', ':6.3f') 240 | progress = ProgressMeter( 241 | len(val_loader), 242 | [batch_time, losses, top1, top5], 243 | prefix='Eval: ') 244 | 245 | # switch to evaluate mode 246 | model.eval() 247 | if config.dataset == 'places': 248 | block.eval() 249 | classifier.eval() 250 | class_num = torch.zeros(config.num_classes).cuda() 251 | correct = torch.zeros(config.num_classes).cuda() 252 | 253 | confidence = np.array([]) 254 | pred_class = np.array([]) 255 | true_class = np.array([]) 256 | 257 | with torch.no_grad(): 258 | end = time.time() 259 | for i, (images, target) in enumerate(val_loader): 260 | if config.gpu is not None: 261 | images = images.cuda(config.gpu, non_blocking=True) 262 | if torch.cuda.is_available(): 263 | target = target.cuda(config.gpu, non_blocking=True) 264 | 265 | # compute output 266 | if config.dataset == 'places': 267 | feat = block(model(images)) 268 | else: 269 | feat = model(images) 270 | output = classifier(feat) 271 | output = lws_model(output) 272 | loss = criterion(output, target) 273 | 274 | # measure accuracy and record loss 275 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 276 | losses.update(loss.item(), images.size(0)) 277 | top1.update(acc1[0], images.size(0)) 278 | top5.update(acc5[0], images.size(0)) 279 | 280 | _, predicted = output.max(1) 281 | target_one_hot = F.one_hot(target, config.num_classes) 282 | predict_one_hot = F.one_hot(predicted, config.num_classes) 283 | class_num = class_num + target_one_hot.sum(dim=0).to(torch.float) 284 | correct = correct + (target_one_hot + predict_one_hot == 2).sum(dim=0).to(torch.float) 285 | 286 | prob = torch.softmax(output, dim=1) 287 | confidence_part, pred_class_part = torch.max(prob, dim=1) 288 | confidence = np.append(confidence, confidence_part.cpu().numpy()) 289 | pred_class = np.append(pred_class, pred_class_part.cpu().numpy()) 290 | true_class = np.append(true_class, target.cpu().numpy()) 291 | 292 | # measure elapsed time 293 | batch_time.update(time.time() - end) 294 | end = time.time() 295 | 296 | if i % config.print_freq == 0: 297 | progress.display(i, logger) 298 | 299 | acc_classes = correct / class_num 300 | head_acc = acc_classes[config.head_class_idx[0]:config.head_class_idx[1]].mean() * 100 301 | med_acc = acc_classes[config.med_class_idx[0]:config.med_class_idx[1]].mean() * 100 302 | tail_acc = acc_classes[config.tail_class_idx[0]:config.tail_class_idx[1]].mean() * 100 303 | 304 | logger.info('* Acc@1 {top1.avg:.3f}% Acc@5 {top5.avg:.3f}% HAcc {head_acc:.3f}% MAcc {med_acc:.3f}% TAcc {tail_acc:.3f}%.'.format(top1=top1, top5=top5, head_acc=head_acc, med_acc=med_acc, tail_acc=tail_acc)) 305 | 306 | cal = calibration(true_class, pred_class, confidence, num_bins=15) 307 | logger.info('* ECE {ece:.3f}%.'.format(ece=cal['expected_calibration_error'] * 100)) 308 | 309 | return top1.avg, cal['expected_calibration_error'] * 100 310 | 311 | 312 | if __name__ == '__main__': 313 | main() 314 | -------------------------------------------------------------------------------- /methods.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def mixup_data(x, y, alpha=1.0, use_cuda=True): 9 | '''Returns mixed inputs, pairs of targets, and lambda''' 10 | if alpha > 0: 11 | lam = np.random.beta(alpha, alpha) 12 | else: 13 | lam = 1 14 | 15 | batch_size = x.size()[0] 16 | if use_cuda: 17 | index = torch.randperm(batch_size).cuda() 18 | else: 19 | index = torch.randperm(batch_size) 20 | 21 | mixed_x = lam * x + (1 - lam) * x[index, :] 22 | y_a, y_b = y, y[index] 23 | 24 | return mixed_x, y_a, y_b, lam 25 | 26 | 27 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 28 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 29 | 30 | 31 | class LabelAwareSmoothing(nn.Module): 32 | def __init__(self, cls_num_list, smooth_head, smooth_tail, shape='concave', power=None): 33 | super(LabelAwareSmoothing, self).__init__() 34 | 35 | n_1 = max(cls_num_list) 36 | n_K = min(cls_num_list) 37 | 38 | if shape == 'concave': 39 | self.smooth = smooth_tail + (smooth_head - smooth_tail) * np.sin((np.array(cls_num_list) - n_K) * np.pi / (2 * (n_1 - n_K))) 40 | 41 | elif shape == 'linear': 42 | self.smooth = smooth_tail + (smooth_head - smooth_tail) * (np.array(cls_num_list) - n_K) / (n_1 - n_K) 43 | 44 | elif shape == 'convex': 45 | self.smooth = smooth_head + (smooth_head - smooth_tail) * np.sin(1.5 * np.pi + (np.array(cls_num_list) - n_K) * np.pi / (2 * (n_1 - n_K))) 46 | 47 | elif shape == 'exp' and power is not None: 48 | self.smooth = smooth_tail + (smooth_head - smooth_tail) * np.power((np.array(cls_num_list) - n_K) / (n_1 - n_K), power) 49 | 50 | self.smooth = torch.from_numpy(self.smooth) 51 | self.smooth = self.smooth.float() 52 | if torch.cuda.is_available(): 53 | self.smooth = self.smooth.cuda() 54 | 55 | def forward(self, x, target): 56 | smoothing = self.smooth[target] 57 | confidence = 1. - smoothing 58 | logprobs = F.log_softmax(x, dim=-1) 59 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 60 | nll_loss = nll_loss.squeeze(1) 61 | smooth_loss = -logprobs.mean(dim=-1) 62 | loss = confidence * nll_loss + smoothing * smooth_loss 63 | 64 | return loss.mean() 65 | 66 | 67 | class LearnableWeightScaling(nn.Module): 68 | def __init__(self, num_classes): 69 | super(LearnableWeightScaling, self).__init__() 70 | self.learned_norm = nn.Parameter(torch.ones(1, num_classes)) 71 | 72 | def forward(self, x): 73 | return self.learned_norm * x 74 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | try: 5 | from torch.hub import load_state_dict_from_url 6 | except ImportError: 7 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 8 | 9 | 10 | 11 | __all__ = ['Classifier', 'ResNet', 'resnet10', 'resnet10_fe', 'resnet18', 'resnet34', 'resnet50', 'resnet50_fe', 'resnet101', 'resnet101_fe', 12 | 'resnet152', 'resnet152_fe', 'resnext50_32x4d', 'resnext101_32x8d', 'resnext152_32x4d', 13 | 'wide_resnet50_2', 'wide_resnet101_2'] 14 | 15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 25 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 27 | } 28 | 29 | 30 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 31 | """3x3 convolution with padding""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 33 | padding=dilation, groups=groups, bias=False, dilation=dilation) 34 | 35 | 36 | def conv1x1(in_planes, out_planes, stride=1): 37 | """1x1 convolution""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 39 | 40 | 41 | class BasicBlock(nn.Module): 42 | expansion = 1 43 | 44 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 45 | base_width=64, dilation=1, norm_layer=None): 46 | super(BasicBlock, self).__init__() 47 | if norm_layer is None: 48 | norm_layer = nn.BatchNorm2d 49 | if groups != 1 or base_width != 64: 50 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 51 | if dilation > 1: 52 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 53 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 54 | self.conv1 = conv3x3(inplanes, planes, stride) 55 | self.bn1 = norm_layer(planes) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.conv2 = conv3x3(planes, planes) 58 | self.bn2 = norm_layer(planes) 59 | self.downsample = downsample 60 | self.stride = stride 61 | 62 | def forward(self, x): 63 | identity = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | 72 | if self.downsample is not None: 73 | identity = self.downsample(x) 74 | 75 | out += identity 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class Bottleneck(nn.Module): 82 | expansion = 4 83 | 84 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 85 | base_width=64, dilation=1, norm_layer=None): 86 | super(Bottleneck, self).__init__() 87 | if norm_layer is None: 88 | norm_layer = nn.BatchNorm2d 89 | width = int(planes * (base_width / 64.)) * groups 90 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 91 | self.conv1 = conv1x1(inplanes, width) 92 | self.bn1 = norm_layer(width) 93 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 94 | self.bn2 = norm_layer(width) 95 | self.conv3 = conv1x1(width, planes * self.expansion) 96 | self.bn3 = norm_layer(planes * self.expansion) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.downsample = downsample 99 | self.stride = stride 100 | 101 | def forward(self, x): 102 | identity = x 103 | 104 | out = self.conv1(x) 105 | out = self.bn1(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv2(out) 109 | out = self.bn2(out) 110 | out = self.relu(out) 111 | 112 | out = self.conv3(out) 113 | out = self.bn3(out) 114 | 115 | if self.downsample is not None: 116 | identity = self.downsample(x) 117 | 118 | out += identity 119 | out = self.relu(out) 120 | 121 | return out 122 | 123 | 124 | class ResNet(nn.Module): 125 | 126 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 127 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 128 | norm_layer=None): 129 | super(ResNet, self).__init__() 130 | if norm_layer is None: 131 | norm_layer = nn.BatchNorm2d 132 | self._norm_layer = norm_layer 133 | 134 | self.inplanes = 64 135 | self.dilation = 1 136 | if replace_stride_with_dilation is None: 137 | # each element in the tuple indicates if we should replace 138 | # the 2x2 stride with a dilated convolution instead 139 | replace_stride_with_dilation = [False, False, False] 140 | if len(replace_stride_with_dilation) != 3: 141 | raise ValueError("replace_stride_with_dilation should be None " 142 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 143 | self.groups = groups 144 | self.base_width = width_per_group 145 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 146 | bias=False) 147 | self.bn1 = norm_layer(self.inplanes) 148 | self.relu = nn.ReLU(inplace=True) 149 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 150 | self.layer1 = self._make_layer(block, 64, layers[0]) 151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 152 | dilate=replace_stride_with_dilation[0]) 153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 154 | dilate=replace_stride_with_dilation[1]) 155 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 156 | dilate=replace_stride_with_dilation[2]) 157 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 158 | self.fc = nn.Linear(512 * block.expansion, num_classes) 159 | 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 164 | nn.init.constant_(m.weight, 1) 165 | nn.init.constant_(m.bias, 0) 166 | 167 | # Zero-initialize the last BN in each residual branch, 168 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 169 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 170 | if zero_init_residual: 171 | for m in self.modules(): 172 | if isinstance(m, Bottleneck): 173 | nn.init.constant_(m.bn3.weight, 0) 174 | elif isinstance(m, BasicBlock): 175 | nn.init.constant_(m.bn2.weight, 0) 176 | 177 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 178 | norm_layer = self._norm_layer 179 | downsample = None 180 | previous_dilation = self.dilation 181 | if dilate: 182 | self.dilation *= stride 183 | stride = 1 184 | if stride != 1 or self.inplanes != planes * block.expansion: 185 | downsample = nn.Sequential( 186 | conv1x1(self.inplanes, planes * block.expansion, stride), 187 | norm_layer(planes * block.expansion), 188 | ) 189 | 190 | layers = [] 191 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 192 | self.base_width, previous_dilation, norm_layer)) 193 | self.inplanes = planes * block.expansion 194 | for _ in range(1, blocks): 195 | layers.append(block(self.inplanes, planes, groups=self.groups, 196 | base_width=self.base_width, dilation=self.dilation, 197 | norm_layer=norm_layer)) 198 | 199 | return nn.Sequential(*layers) 200 | 201 | def forward(self, x): 202 | x = self.conv1(x) 203 | x = self.bn1(x) 204 | x = self.relu(x) 205 | x = self.maxpool(x) 206 | 207 | x = self.layer1(x) 208 | x = self.layer2(x) 209 | x = self.layer3(x) 210 | x = self.layer4(x) 211 | 212 | x = self.avgpool(x) 213 | x = torch.flatten(x, 1) 214 | x = self.fc(x) 215 | 216 | return x 217 | 218 | class ResNet_FE(nn.Module): 219 | 220 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 221 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 222 | norm_layer=None): 223 | super(ResNet_FE, self).__init__() 224 | if norm_layer is None: 225 | norm_layer = nn.BatchNorm2d 226 | self._norm_layer = norm_layer 227 | 228 | self.inplanes = 64 229 | self.dilation = 1 230 | if replace_stride_with_dilation is None: 231 | # each element in the tuple indicates if we should replace 232 | # the 2x2 stride with a dilated convolution instead 233 | replace_stride_with_dilation = [False, False, False] 234 | if len(replace_stride_with_dilation) != 3: 235 | raise ValueError("replace_stride_with_dilation should be None " 236 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 237 | self.groups = groups 238 | self.base_width = width_per_group 239 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 240 | bias=False) 241 | self.bn1 = norm_layer(self.inplanes) 242 | self.relu = nn.ReLU(inplace=True) 243 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 244 | self.layer1 = self._make_layer(block, 64, layers[0]) 245 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 246 | dilate=replace_stride_with_dilation[0]) 247 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 248 | dilate=replace_stride_with_dilation[1]) 249 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 250 | dilate=replace_stride_with_dilation[2]) 251 | 252 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 253 | 254 | for m in self.modules(): 255 | if isinstance(m, nn.Conv2d): 256 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 257 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 258 | nn.init.constant_(m.weight, 1) 259 | nn.init.constant_(m.bias, 0) 260 | 261 | # Zero-initialize the last BN in each residual branch, 262 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 263 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 264 | if zero_init_residual: 265 | for m in self.modules(): 266 | if isinstance(m, Bottleneck): 267 | nn.init.constant_(m.bn3.weight, 0) 268 | elif isinstance(m, BasicBlock): 269 | nn.init.constant_(m.bn2.weight, 0) 270 | 271 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 272 | norm_layer = self._norm_layer 273 | downsample = None 274 | previous_dilation = self.dilation 275 | if dilate: 276 | self.dilation *= stride 277 | stride = 1 278 | if stride != 1 or self.inplanes != planes * block.expansion: 279 | downsample = nn.Sequential( 280 | conv1x1(self.inplanes, planes * block.expansion, stride), 281 | norm_layer(planes * block.expansion), 282 | ) 283 | 284 | layers = [] 285 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 286 | self.base_width, previous_dilation, norm_layer)) 287 | self.inplanes = planes * block.expansion 288 | for _ in range(1, blocks): 289 | layers.append(block(self.inplanes, planes, groups=self.groups, 290 | base_width=self.base_width, dilation=self.dilation, 291 | norm_layer=norm_layer)) 292 | 293 | return nn.Sequential(*layers) 294 | 295 | def forward(self, x): 296 | x = self.conv1(x) 297 | x = self.bn1(x) 298 | x = self.relu(x) 299 | x = self.maxpool(x) 300 | 301 | x = self.layer1(x) 302 | x = self.layer2(x) 303 | x = self.layer3(x) 304 | x = self.layer4(x) 305 | x = self.avgpool(x) 306 | x = torch.flatten(x, 1) 307 | 308 | return x 309 | 310 | 311 | 312 | class Classifier(nn.Module): 313 | def __init__(self, feat_in, num_classes): 314 | super(Classifier, self).__init__() 315 | self.fc = nn.Linear(feat_in, num_classes) 316 | 317 | def forward(self, x): 318 | x = self.fc(x) 319 | return x 320 | 321 | 322 | 323 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 324 | model = ResNet(block, layers, **kwargs) 325 | if pretrained: 326 | state_dict = load_state_dict_from_url(model_urls[arch], 327 | progress=progress) 328 | model.load_state_dict(state_dict) 329 | return model 330 | 331 | def _resnet_fe(arch, block, layers, pretrained, progress, **kwargs): 332 | model = ResNet_FE(block, layers, **kwargs) 333 | if pretrained: 334 | # import pdb; pdb.set_trace() 335 | state_dict = load_state_dict_from_url(model_urls[arch], 336 | progress=progress) 337 | model.load_state_dict(state_dict) 338 | return model 339 | 340 | 341 | def resnet10(pretrained=False, progress=True, **kwargs): 342 | r"""ResNet-18 model from 343 | `"Deep Residual Learning for Image Recognition" `_ 344 | 345 | Args: 346 | pretrained (bool): If True, returns a model pre-trained on ImageNet 347 | progress (bool): If True, displays a progress bar of the download to stderr 348 | """ 349 | return _resnet('resnet10', BasicBlock, [1, 1, 1, 1], pretrained, progress, 350 | **kwargs) 351 | 352 | 353 | def resnet10_fe(pretrained=False, progress=True, **kwargs): 354 | r"""ResNet-18 model from 355 | `"Deep Residual Learning for Image Recognition" `_ 356 | 357 | Args: 358 | pretrained (bool): If True, returns a model pre-trained on ImageNet 359 | progress (bool): If True, displays a progress bar of the download to stderr 360 | """ 361 | return _resnet_fe('resnet10_fe', BasicBlock, [1, 1, 1, 1], pretrained, progress, 362 | **kwargs) 363 | 364 | 365 | 366 | def resnet18(pretrained=False, progress=True, **kwargs): 367 | r"""ResNet-18 model from 368 | `"Deep Residual Learning for Image Recognition" `_ 369 | 370 | Args: 371 | pretrained (bool): If True, returns a model pre-trained on ImageNet 372 | progress (bool): If True, displays a progress bar of the download to stderr 373 | """ 374 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 375 | **kwargs) 376 | 377 | 378 | def resnet34(pretrained=False, progress=True, **kwargs): 379 | r"""ResNet-34 model from 380 | `"Deep Residual Learning for Image Recognition" `_ 381 | 382 | Args: 383 | pretrained (bool): If True, returns a model pre-trained on ImageNet 384 | progress (bool): If True, displays a progress bar of the download to stderr 385 | """ 386 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 387 | **kwargs) 388 | 389 | 390 | def resnet50(pretrained=False, progress=True, **kwargs): 391 | r"""ResNet-50 model from 392 | `"Deep Residual Learning for Image Recognition" `_ 393 | 394 | Args: 395 | pretrained (bool): If True, returns a model pre-trained on ImageNet 396 | progress (bool): If True, displays a progress bar of the download to stderr 397 | """ 398 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 399 | **kwargs) 400 | 401 | def resnet50_fe(pretrained=False, progress=True, **kwargs): 402 | r"""ResNet-18 model from 403 | `"Deep Residual Learning for Image Recognition" `_ 404 | 405 | Args: 406 | pretrained (bool): If True, returns a model pre-trained on ImageNet 407 | progress (bool): If True, displays a progress bar of the download to stderr 408 | """ 409 | return _resnet_fe('resnet50_fe', Bottleneck, [3, 4, 6, 3], pretrained, progress, 410 | **kwargs) 411 | 412 | 413 | def resnet101(pretrained=False, progress=True, **kwargs): 414 | r"""ResNet-101 model from 415 | `"Deep Residual Learning for Image Recognition" `_ 416 | 417 | Args: 418 | pretrained (bool): If True, returns a model pre-trained on ImageNet 419 | progress (bool): If True, displays a progress bar of the download to stderr 420 | """ 421 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 422 | **kwargs) 423 | 424 | 425 | def resnet101_fe(pretrained=False, progress=True, **kwargs): 426 | r"""ResNet-101 model from 427 | `"Deep Residual Learning for Image Recognition" `_ 428 | 429 | Args: 430 | pretrained (bool): If True, returns a model pre-trained on ImageNet 431 | progress (bool): If True, displays a progress bar of the download to stderr 432 | """ 433 | return _resnet_fe('resnet101_fe', Bottleneck, [3, 4, 23, 3], pretrained, progress, 434 | **kwargs) 435 | 436 | def resnet152(pretrained=False, progress=True, **kwargs): 437 | r"""ResNet-152 model from 438 | `"Deep Residual Learning for Image Recognition" `_ 439 | 440 | Args: 441 | pretrained (bool): If True, returns a model pre-trained on ImageNet 442 | progress (bool): If True, displays a progress bar of the download to stderr 443 | """ 444 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 445 | **kwargs) 446 | 447 | 448 | def resnet152_fe(pretrained=False, progress=True, **kwargs): 449 | r"""ResNet-152 model from 450 | `"Deep Residual Learning for Image Recognition" `_ 451 | 452 | Args: 453 | pretrained (bool): If True, returns a model pre-trained on ImageNet 454 | progress (bool): If True, displays a progress bar of the download to stderr 455 | 456 | """ 457 | return _resnet_fe('resnet152_fe', Bottleneck, [3, 8, 36, 3], pretrained, progress, 458 | **kwargs) 459 | 460 | 461 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 462 | r"""ResNeXt-50 32x4d model from 463 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 464 | 465 | Args: 466 | pretrained (bool): If True, returns a model pre-trained on ImageNet 467 | progress (bool): If True, displays a progress bar of the download to stderr 468 | """ 469 | kwargs['groups'] = 32 470 | kwargs['width_per_group'] = 4 471 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 472 | pretrained, progress, **kwargs) 473 | 474 | 475 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 476 | r"""ResNeXt-101 32x8d model from 477 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 478 | 479 | Args: 480 | pretrained (bool): If True, returns a model pre-trained on ImageNet 481 | progress (bool): If True, displays a progress bar of the download to stderr 482 | """ 483 | kwargs['groups'] = 32 484 | kwargs['width_per_group'] = 8 485 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 486 | pretrained, progress, **kwargs) 487 | 488 | def resnext152_32x4d(pretrained=False, progress=True, **kwargs): 489 | kwargs['groups'] = 32 490 | kwargs['width_per_group'] = 4 491 | return _resnet('resnext152_32x4d', Bottleneck, [3, 8, 36, 3], 492 | pretrained, progress, **kwargs) 493 | 494 | 495 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 496 | r"""Wide ResNet-50-2 model from 497 | `"Wide Residual Networks" `_ 498 | 499 | The model is the same as ResNet except for the bottleneck number of channels 500 | which is twice larger in every block. The number of channels in outer 1x1 501 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 502 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 503 | 504 | Args: 505 | pretrained (bool): If True, returns a model pre-trained on ImageNet 506 | progress (bool): If True, displays a progress bar of the download to stderr 507 | """ 508 | kwargs['width_per_group'] = 64 * 2 509 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 510 | pretrained, progress, **kwargs) 511 | 512 | 513 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 514 | r"""Wide ResNet-101-2 model from 515 | `"Wide Residual Networks" `_ 516 | 517 | The model is the same as ResNet except for the bottleneck number of channels 518 | which is twice larger in every block. The number of channels in outer 1x1 519 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 520 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 521 | 522 | Args: 523 | pretrained (bool): If True, returns a model pre-trained on ImageNet 524 | progress (bool): If True, displays a progress bar of the download to stderr 525 | """ 526 | kwargs['width_per_group'] = 64 * 2 527 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 528 | pretrained, progress, **kwargs) 529 | -------------------------------------------------------------------------------- /models/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Properly implemented ResNet for CIFAR10 as described in paper [1]. 3 | The implementation and structure of this file is hugely influenced by [2] 4 | which is implemented for ImageNet and doesn't have option A for identity. 5 | Moreover, most of the implementations on the web is copy-paste from 6 | torchvision's resnet and has wrong number of params. 7 | Proper ResNet-s for CIFAR10 (for fair comparision and etc.) has following 8 | number of layers and parameters: 9 | name | layers | params 10 | ResNet20 | 20 | 0.0017M 11 | ResNet32 | 32 | 0.46M 12 | ResNet44 | 44 | 0.66M 13 | ResNet56 | 56 | 0.85M 14 | ResNet110 | 110 | 1.7M 15 | ResNet1202| 1202 | 19.4m 16 | which this implementation indeed has. 17 | Reference: 18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 20 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 21 | If you use this implementation in you work, please don't forget to mention the 22 | author, Yerlan Idelbayev. 23 | ''' 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | import torch.nn.init as init 28 | from torch.nn import Parameter 29 | 30 | __all__ = ['ResNet_s', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 31 | 32 | def _weights_init(m): 33 | classname = m.__class__.__name__ 34 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 35 | init.kaiming_normal_(m.weight) 36 | 37 | class NormedLinear(nn.Module): 38 | 39 | def __init__(self, in_features, out_features): 40 | super(NormedLinear, self).__init__() 41 | self.weight = Parameter(torch.Tensor(in_features, out_features)) 42 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 43 | 44 | def forward(self, x): 45 | out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0)) 46 | return out 47 | 48 | class LambdaLayer(nn.Module): 49 | 50 | def __init__(self, lambd): 51 | super(LambdaLayer, self).__init__() 52 | self.lambd = lambd 53 | 54 | def forward(self, x): 55 | return self.lambd(x) 56 | 57 | 58 | class BasicBlock(nn.Module): 59 | expansion = 1 60 | 61 | def __init__(self, in_planes, planes, stride=1, option='A'): 62 | super(BasicBlock, self).__init__() 63 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(planes) 65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | 68 | self.shortcut = nn.Sequential() 69 | if stride != 1 or in_planes != planes: 70 | if option == 'A': 71 | """ 72 | For CIFAR10 ResNet paper uses option A. 73 | """ 74 | self.shortcut = LambdaLayer(lambda x: 75 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 76 | elif option == 'B': 77 | self.shortcut = nn.Sequential( 78 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 79 | nn.BatchNorm2d(self.expansion * planes) 80 | ) 81 | 82 | def forward(self, x): 83 | out = F.relu(self.bn1(self.conv1(x))) 84 | out = self.bn2(self.conv2(out)) 85 | out += self.shortcut(x) 86 | out = F.relu(out) 87 | return out 88 | 89 | 90 | class ResNet_s(nn.Module): 91 | 92 | def __init__(self, block, num_blocks, num_classes=10, use_norm=False): 93 | super(ResNet_s, self).__init__() 94 | self.in_planes = 16 95 | 96 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 97 | self.bn1 = nn.BatchNorm2d(16) 98 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 99 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 100 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 101 | if use_norm: 102 | self.linear = NormedLinear(64, num_classes) 103 | else: 104 | self.linear = nn.Linear(64, num_classes) 105 | self.apply(_weights_init) 106 | 107 | def _make_layer(self, block, planes, num_blocks, stride): 108 | strides = [stride] + [1]*(num_blocks-1) 109 | layers = [] 110 | for stride in strides: 111 | layers.append(block(self.in_planes, planes, stride)) 112 | self.in_planes = planes * block.expansion 113 | 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x): 117 | out = F.relu(self.bn1(self.conv1(x))) 118 | out = self.layer1(out) 119 | out = self.layer2(out) 120 | out = self.layer3(out) 121 | out = F.avg_pool2d(out, out.size()[3]) 122 | out = out.view(out.size(0), -1) 123 | out = self.linear(out) 124 | return out 125 | 126 | class ResNet_fe(nn.Module): 127 | 128 | def __init__(self, block, num_blocks): 129 | super(ResNet_fe, self).__init__() 130 | self.in_planes = 16 131 | 132 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 133 | self.bn1 = nn.BatchNorm2d(16) 134 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 135 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 136 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 137 | self.apply(_weights_init) 138 | 139 | def _make_layer(self, block, planes, num_blocks, stride): 140 | strides = [stride] + [1]*(num_blocks-1) 141 | layers = [] 142 | for stride in strides: 143 | layers.append(block(self.in_planes, planes, stride)) 144 | self.in_planes = planes * block.expansion 145 | 146 | return nn.Sequential(*layers) 147 | 148 | def forward(self, x): 149 | out = F.relu(self.bn1(self.conv1(x))) 150 | out = self.layer1(out) 151 | out = self.layer2(out) 152 | out = self.layer3(out) 153 | out = F.avg_pool2d(out, out.size()[3]) 154 | out = out.view(out.size(0), -1) 155 | return out 156 | 157 | class Classifier(nn.Module): 158 | def __init__(self, feat_in, num_classes): 159 | super(Classifier, self).__init__() 160 | self.fc = nn.Linear(feat_in, num_classes) 161 | self.apply(_weights_init) 162 | 163 | def forward(self, x): 164 | x = self.fc(x) 165 | return x 166 | 167 | 168 | def resnet20(): 169 | return ResNet_s(BasicBlock, [3, 3, 3]) 170 | 171 | def resnet32_fe(): 172 | return ResNet_fe(BasicBlock, [5, 5, 5]) 173 | 174 | def resnet32(num_classes=10, use_norm=False): 175 | return ResNet_s(BasicBlock, [5, 5, 5], num_classes=num_classes, use_norm=use_norm) 176 | 177 | 178 | def resnet44(): 179 | return ResNet_s(BasicBlock, [7, 7, 7]) 180 | 181 | 182 | def resnet56(): 183 | return ResNet_s(BasicBlock, [9, 9, 9]) 184 | 185 | 186 | def resnet110(): 187 | return ResNet_s(BasicBlock, [18, 18, 18]) 188 | 189 | 190 | def resnet1202(): 191 | return ResNet_s(BasicBlock, [200, 200, 200]) 192 | 193 | 194 | def test(net): 195 | import numpy as np 196 | total_params = 0 197 | 198 | for x in filter(lambda p: p.requires_grad, net.parameters()): 199 | total_params += np.prod(x.data.numpy().shape) 200 | print("Total number of params", total_params) 201 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 202 | 203 | 204 | if __name__ == "__main__": 205 | for net_name in __all__: 206 | if net_name.startswith('resnet'): 207 | print(net_name) 208 | test(globals()[net_name]()) 209 | print() -------------------------------------------------------------------------------- /models/resnet_places.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | try: 5 | from torch.hub import load_state_dict_from_url 6 | except ImportError: 7 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 8 | 9 | 10 | 11 | __all__ = ['Classifier', 'ResNet', 'resnet10', 'resnet10_fe', 'resnet18', 'resnet34', 'resnet50', 'resnet50_fe', 'resnet101', 'resnet101_fe', 12 | 'resnet152', 'resnet152_fe', 'resnext50_32x4d', 'resnext101_32x8d', 'resnext152_32x4d', 13 | 'wide_resnet50_2', 'wide_resnet101_2'] 14 | 15 | 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 23 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 25 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 26 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 27 | } 28 | 29 | 30 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 31 | """3x3 convolution with padding""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 33 | padding=dilation, groups=groups, bias=False, dilation=dilation) 34 | 35 | 36 | def conv1x1(in_planes, out_planes, stride=1): 37 | """1x1 convolution""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 39 | 40 | 41 | class BasicBlock(nn.Module): 42 | expansion = 1 43 | 44 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 45 | base_width=64, dilation=1, norm_layer=None): 46 | super(BasicBlock, self).__init__() 47 | if norm_layer is None: 48 | norm_layer = nn.BatchNorm2d 49 | if groups != 1 or base_width != 64: 50 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 51 | if dilation > 1: 52 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 53 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 54 | self.conv1 = conv3x3(inplanes, planes, stride) 55 | self.bn1 = norm_layer(planes) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.conv2 = conv3x3(planes, planes) 58 | self.bn2 = norm_layer(planes) 59 | self.downsample = downsample 60 | self.stride = stride 61 | 62 | def forward(self, x): 63 | identity = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | 72 | if self.downsample is not None: 73 | identity = self.downsample(x) 74 | 75 | out += identity 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class Bottleneck(nn.Module): 82 | expansion = 4 83 | 84 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 85 | base_width=64, dilation=1, norm_layer=None): 86 | super(Bottleneck, self).__init__() 87 | if norm_layer is None: 88 | norm_layer = nn.BatchNorm2d 89 | width = int(planes * (base_width / 64.)) * groups 90 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 91 | self.conv1 = conv1x1(inplanes, width) 92 | self.bn1 = norm_layer(width) 93 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 94 | self.bn2 = norm_layer(width) 95 | self.conv3 = conv1x1(width, planes * self.expansion) 96 | self.bn3 = norm_layer(planes * self.expansion) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.downsample = downsample 99 | self.stride = stride 100 | 101 | def forward(self, x): 102 | identity = x 103 | 104 | out = self.conv1(x) 105 | out = self.bn1(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv2(out) 109 | out = self.bn2(out) 110 | out = self.relu(out) 111 | 112 | out = self.conv3(out) 113 | out = self.bn3(out) 114 | 115 | if self.downsample is not None: 116 | identity = self.downsample(x) 117 | 118 | out += identity 119 | out = self.relu(out) 120 | 121 | return out 122 | 123 | 124 | class ResNet(nn.Module): 125 | 126 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 127 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 128 | norm_layer=None): 129 | super(ResNet, self).__init__() 130 | if norm_layer is None: 131 | norm_layer = nn.BatchNorm2d 132 | self._norm_layer = norm_layer 133 | 134 | self.inplanes = 64 135 | self.dilation = 1 136 | if replace_stride_with_dilation is None: 137 | # each element in the tuple indicates if we should replace 138 | # the 2x2 stride with a dilated convolution instead 139 | replace_stride_with_dilation = [False, False, False] 140 | if len(replace_stride_with_dilation) != 3: 141 | raise ValueError("replace_stride_with_dilation should be None " 142 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 143 | self.groups = groups 144 | self.base_width = width_per_group 145 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 146 | bias=False) 147 | self.bn1 = norm_layer(self.inplanes) 148 | self.relu = nn.ReLU(inplace=True) 149 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 150 | self.layer1 = self._make_layer(block, 64, layers[0]) 151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 152 | dilate=replace_stride_with_dilation[0]) 153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 154 | dilate=replace_stride_with_dilation[1]) 155 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 156 | dilate=replace_stride_with_dilation[2]) 157 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 158 | self.fc = nn.Linear(512 * block.expansion, num_classes) 159 | 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 164 | nn.init.constant_(m.weight, 1) 165 | nn.init.constant_(m.bias, 0) 166 | 167 | # Zero-initialize the last BN in each residual branch, 168 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 169 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 170 | if zero_init_residual: 171 | for m in self.modules(): 172 | if isinstance(m, Bottleneck): 173 | nn.init.constant_(m.bn3.weight, 0) 174 | elif isinstance(m, BasicBlock): 175 | nn.init.constant_(m.bn2.weight, 0) 176 | 177 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 178 | norm_layer = self._norm_layer 179 | downsample = None 180 | previous_dilation = self.dilation 181 | if dilate: 182 | self.dilation *= stride 183 | stride = 1 184 | if stride != 1 or self.inplanes != planes * block.expansion: 185 | downsample = nn.Sequential( 186 | conv1x1(self.inplanes, planes * block.expansion, stride), 187 | norm_layer(planes * block.expansion), 188 | ) 189 | 190 | layers = [] 191 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 192 | self.base_width, previous_dilation, norm_layer)) 193 | self.inplanes = planes * block.expansion 194 | for _ in range(1, blocks): 195 | layers.append(block(self.inplanes, planes, groups=self.groups, 196 | base_width=self.base_width, dilation=self.dilation, 197 | norm_layer=norm_layer)) 198 | 199 | return nn.Sequential(*layers) 200 | 201 | def forward(self, x): 202 | x = self.conv1(x) 203 | x = self.bn1(x) 204 | x = self.relu(x) 205 | x = self.maxpool(x) 206 | 207 | x = self.layer1(x) 208 | x = self.layer2(x) 209 | x = self.layer3(x) 210 | x = self.layer4(x) 211 | 212 | x = self.avgpool(x) 213 | x = torch.flatten(x, 1) 214 | x = self.fc(x) 215 | 216 | return x 217 | 218 | class ResNet_FE(nn.Module): 219 | 220 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 221 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 222 | norm_layer=None): 223 | super(ResNet_FE, self).__init__() 224 | if norm_layer is None: 225 | norm_layer = nn.BatchNorm2d 226 | self._norm_layer = norm_layer 227 | 228 | self.inplanes = 64 229 | self.dilation = 1 230 | if replace_stride_with_dilation is None: 231 | # each element in the tuple indicates if we should replace 232 | # the 2x2 stride with a dilated convolution instead 233 | replace_stride_with_dilation = [False, False, False] 234 | if len(replace_stride_with_dilation) != 3: 235 | raise ValueError("replace_stride_with_dilation should be None " 236 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 237 | self.groups = groups 238 | self.base_width = width_per_group 239 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 240 | bias=False) 241 | self.bn1 = norm_layer(self.inplanes) 242 | self.relu = nn.ReLU(inplace=True) 243 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 244 | self.layer1 = self._make_layer(block, 64, layers[0]) 245 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 246 | dilate=replace_stride_with_dilation[0]) 247 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 248 | dilate=replace_stride_with_dilation[1]) 249 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 250 | dilate=replace_stride_with_dilation[2]) 251 | 252 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 253 | 254 | for m in self.modules(): 255 | if isinstance(m, nn.Conv2d): 256 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 257 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 258 | nn.init.constant_(m.weight, 1) 259 | nn.init.constant_(m.bias, 0) 260 | 261 | # Zero-initialize the last BN in each residual branch, 262 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 263 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 264 | if zero_init_residual: 265 | for m in self.modules(): 266 | if isinstance(m, Bottleneck): 267 | nn.init.constant_(m.bn3.weight, 0) 268 | elif isinstance(m, BasicBlock): 269 | nn.init.constant_(m.bn2.weight, 0) 270 | 271 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 272 | norm_layer = self._norm_layer 273 | downsample = None 274 | previous_dilation = self.dilation 275 | if dilate: 276 | self.dilation *= stride 277 | stride = 1 278 | if stride != 1 or self.inplanes != planes * block.expansion: 279 | downsample = nn.Sequential( 280 | conv1x1(self.inplanes, planes * block.expansion, stride), 281 | norm_layer(planes * block.expansion), 282 | ) 283 | 284 | layers = [] 285 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 286 | self.base_width, previous_dilation, norm_layer)) 287 | self.inplanes = planes * block.expansion 288 | for _ in range(1, blocks): 289 | layers.append(block(self.inplanes, planes, groups=self.groups, 290 | base_width=self.base_width, dilation=self.dilation, 291 | norm_layer=norm_layer)) 292 | 293 | return nn.Sequential(*layers) 294 | 295 | def forward(self, x): 296 | x = self.conv1(x) 297 | x = self.bn1(x) 298 | x = self.relu(x) 299 | x = self.maxpool(x) 300 | 301 | x = self.layer1(x) 302 | x = self.layer2(x) 303 | x = self.layer3(x) 304 | x = self.layer4(x) 305 | 306 | 307 | return x 308 | 309 | 310 | 311 | class Classifier(nn.Module): 312 | def __init__(self, feat_in, num_classes): 313 | super(Classifier, self).__init__() 314 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 315 | self.fc = nn.Linear(feat_in, num_classes) 316 | 317 | def forward(self, x): 318 | x = self.avgpool(x) 319 | x = torch.flatten(x, 1) 320 | x = self.fc(x) 321 | return x 322 | 323 | 324 | 325 | 326 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 327 | model = ResNet(block, layers, **kwargs) 328 | if pretrained: 329 | state_dict = load_state_dict_from_url(model_urls[arch], 330 | progress=progress) 331 | model.load_state_dict(state_dict) 332 | return model 333 | 334 | def _resnet_fe(arch, block, layers, pretrained, progress, **kwargs): 335 | model = ResNet_FE(block, layers, **kwargs) 336 | if pretrained: 337 | state_dict = load_state_dict_from_url(model_urls[arch[:-3]], 338 | progress=progress) 339 | 340 | del state_dict['fc.weight'] 341 | del state_dict['fc.bias'] 342 | del state_dict["layer4.2.conv1.weight"] 343 | del state_dict["layer4.2.bn1.running_mean"] 344 | del state_dict["layer4.2.bn1.running_var"] 345 | del state_dict["layer4.2.bn1.weight"] 346 | del state_dict["layer4.2.bn1.bias"] 347 | del state_dict["layer4.2.conv2.weight"] 348 | del state_dict["layer4.2.bn2.running_mean"] 349 | del state_dict["layer4.2.bn2.running_var"] 350 | del state_dict["layer4.2.bn2.weight"] 351 | del state_dict["layer4.2.bn2.bias"] 352 | del state_dict["layer4.2.conv3.weight"] 353 | del state_dict["layer4.2.bn3.running_mean"] 354 | del state_dict["layer4.2.bn3.running_var"] 355 | del state_dict["layer4.2.bn3.weight"] 356 | del state_dict["layer4.2.bn3.bias"] 357 | 358 | model.load_state_dict(state_dict) 359 | 360 | return model 361 | 362 | 363 | def resnet10(pretrained=False, progress=True, **kwargs): 364 | r"""ResNet-18 model from 365 | `"Deep Residual Learning for Image Recognition" `_ 366 | 367 | Args: 368 | pretrained (bool): If True, returns a model pre-trained on ImageNet 369 | progress (bool): If True, displays a progress bar of the download to stderr 370 | """ 371 | return _resnet('resnet10', BasicBlock, [1, 1, 1, 1], pretrained, progress, 372 | **kwargs) 373 | 374 | 375 | def resnet10_fe(pretrained=False, progress=True, **kwargs): 376 | r"""ResNet-18 model from 377 | `"Deep Residual Learning for Image Recognition" `_ 378 | 379 | Args: 380 | pretrained (bool): If True, returns a model pre-trained on ImageNet 381 | progress (bool): If True, displays a progress bar of the download to stderr 382 | """ 383 | return _resnet_fe('resnet10_fe', BasicBlock, [1, 1, 1, 1], pretrained, progress, 384 | **kwargs) 385 | 386 | 387 | 388 | def resnet18(pretrained=False, progress=True, **kwargs): 389 | r"""ResNet-18 model from 390 | `"Deep Residual Learning for Image Recognition" `_ 391 | 392 | Args: 393 | pretrained (bool): If True, returns a model pre-trained on ImageNet 394 | progress (bool): If True, displays a progress bar of the download to stderr 395 | """ 396 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 397 | **kwargs) 398 | 399 | 400 | def resnet34(pretrained=False, progress=True, **kwargs): 401 | r"""ResNet-34 model from 402 | `"Deep Residual Learning for Image Recognition" `_ 403 | 404 | Args: 405 | pretrained (bool): If True, returns a model pre-trained on ImageNet 406 | progress (bool): If True, displays a progress bar of the download to stderr 407 | """ 408 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 409 | **kwargs) 410 | 411 | 412 | def resnet50(pretrained=False, progress=True, **kwargs): 413 | r"""ResNet-50 model from 414 | `"Deep Residual Learning for Image Recognition" `_ 415 | 416 | Args: 417 | pretrained (bool): If True, returns a model pre-trained on ImageNet 418 | progress (bool): If True, displays a progress bar of the download to stderr 419 | """ 420 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 421 | **kwargs) 422 | 423 | def resnet50_fe(pretrained=False, progress=True, **kwargs): 424 | r"""ResNet-18 model from 425 | `"Deep Residual Learning for Image Recognition" `_ 426 | 427 | Args: 428 | pretrained (bool): If True, returns a model pre-trained on ImageNet 429 | progress (bool): If True, displays a progress bar of the download to stderr 430 | """ 431 | return _resnet_fe('resnet50_fe', Bottleneck, [3, 4, 6, 3], pretrained, progress, 432 | **kwargs) 433 | 434 | 435 | def resnet101(pretrained=False, progress=True, **kwargs): 436 | r"""ResNet-101 model from 437 | `"Deep Residual Learning for Image Recognition" `_ 438 | 439 | Args: 440 | pretrained (bool): If True, returns a model pre-trained on ImageNet 441 | progress (bool): If True, displays a progress bar of the download to stderr 442 | """ 443 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 444 | **kwargs) 445 | 446 | 447 | def resnet101_fe(pretrained=False, progress=True, **kwargs): 448 | r"""ResNet-101 model from 449 | `"Deep Residual Learning for Image Recognition" `_ 450 | 451 | Args: 452 | pretrained (bool): If True, returns a model pre-trained on ImageNet 453 | progress (bool): If True, displays a progress bar of the download to stderr 454 | """ 455 | return _resnet_fe('resnet101_fe', Bottleneck, [3, 4, 23, 3], pretrained, progress, 456 | **kwargs) 457 | 458 | def resnet152(pretrained=False, progress=True, **kwargs): 459 | r"""ResNet-152 model from 460 | `"Deep Residual Learning for Image Recognition" `_ 461 | 462 | Args: 463 | pretrained (bool): If True, returns a model pre-trained on ImageNet 464 | progress (bool): If True, displays a progress bar of the download to stderr 465 | """ 466 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 467 | **kwargs) 468 | 469 | 470 | def resnet152_fe(pretrained=False, progress=True, **kwargs): 471 | r"""ResNet-152 model from 472 | `"Deep Residual Learning for Image Recognition" `_ 473 | 474 | Args: 475 | pretrained (bool): If True, returns a model pre-trained on ImageNet 476 | progress (bool): If True, displays a progress bar of the download to stderr 477 | 478 | Follow Liu et. al "Large-Scale Long-Tailed Recognition in an Open World", CVPR 2019 479 | just train the last block in ResNet-152 480 | 481 | """ 482 | return _resnet_fe('resnet152_fe', Bottleneck, [3, 8, 36, 2], pretrained, progress, 483 | **kwargs) 484 | 485 | 486 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 487 | r"""ResNeXt-50 32x4d model from 488 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 489 | 490 | Args: 491 | pretrained (bool): If True, returns a model pre-trained on ImageNet 492 | progress (bool): If True, displays a progress bar of the download to stderr 493 | """ 494 | kwargs['groups'] = 32 495 | kwargs['width_per_group'] = 4 496 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 497 | pretrained, progress, **kwargs) 498 | 499 | 500 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 501 | r"""ResNeXt-101 32x8d model from 502 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 503 | 504 | Args: 505 | pretrained (bool): If True, returns a model pre-trained on ImageNet 506 | progress (bool): If True, displays a progress bar of the download to stderr 507 | """ 508 | kwargs['groups'] = 32 509 | kwargs['width_per_group'] = 8 510 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 511 | pretrained, progress, **kwargs) 512 | 513 | def resnext152_32x4d(pretrained=False, progress=True, **kwargs): 514 | kwargs['groups'] = 32 515 | kwargs['width_per_group'] = 4 516 | return _resnet('resnext152_32x4d', Bottleneck, [3, 8, 36, 3], 517 | pretrained, progress, **kwargs) 518 | 519 | 520 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 521 | r"""Wide ResNet-50-2 model from 522 | `"Wide Residual Networks" `_ 523 | 524 | The model is the same as ResNet except for the bottleneck number of channels 525 | which is twice larger in every block. The number of channels in outer 1x1 526 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 527 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 528 | 529 | Args: 530 | pretrained (bool): If True, returns a model pre-trained on ImageNet 531 | progress (bool): If True, displays a progress bar of the download to stderr 532 | """ 533 | kwargs['width_per_group'] = 64 * 2 534 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 535 | pretrained, progress, **kwargs) 536 | 537 | 538 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 539 | r"""Wide ResNet-101-2 model from 540 | `"Wide Residual Networks" `_ 541 | 542 | The model is the same as ResNet except for the bottleneck number of channels 543 | which is twice larger in every block. The number of channels in outer 1x1 544 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 545 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 546 | 547 | Args: 548 | pretrained (bool): If True, returns a model pre-trained on ImageNet 549 | progress (bool): If True, displays a progress bar of the download to stderr 550 | """ 551 | kwargs['width_per_group'] = 64 * 2 552 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 553 | pretrained, progress, **kwargs) 554 | -------------------------------------------------------------------------------- /reliability_diagrams.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | import pdb 6 | sns.set_style("darkgrid") 7 | from matplotlib import rcParams 8 | # rcParams['font.family'] = 'Arial' 9 | rcParams['font.size'] = 75 10 | 11 | def compute_calibration(true_labels, pred_labels, confidences, num_bins=10): 12 | """Collects predictions into bins used to draw a reliability diagram. 13 | 14 | Arguments: 15 | true_labels: the true labels for the test examples 16 | pred_labels: the predicted labels for the test examples 17 | confidences: the predicted confidences for the test examples 18 | num_bins: number of bins 19 | 20 | The true_labels, pred_labels, confidences arguments must be NumPy arrays; 21 | pred_labels and true_labels may contain numeric or string labels. 22 | 23 | For a multi-class model, the predicted label and confidence should be those 24 | of the highest scoring class. 25 | 26 | Returns a dictionary containing the following NumPy arrays: 27 | accuracies: the average accuracy for each bin 28 | confidences: the average confidence for each bin 29 | counts: the number of examples in each bin 30 | bins: the confidence thresholds for each bin 31 | avg_accuracy: the accuracy over the entire test set 32 | avg_confidence: the average confidence over the entire test set 33 | expected_calibration_error: a weighted average of all calibration gaps 34 | max_calibration_error: the largest calibration gap across all bins 35 | """ 36 | assert(len(confidences) == len(pred_labels)) 37 | assert(len(confidences) == len(true_labels)) 38 | assert(num_bins > 0) 39 | 40 | bin_size = 1.0 / num_bins 41 | bins = np.linspace(0.0, 1.0, num_bins + 1) 42 | indices = np.digitize(confidences, bins, right=True) 43 | 44 | bin_accuracies = np.zeros(num_bins, dtype=np.float) 45 | bin_confidences = np.zeros(num_bins, dtype=np.float) 46 | bin_counts = np.zeros(num_bins, dtype=np.int) 47 | 48 | for b in range(num_bins): 49 | selected = np.where(indices == b + 1)[0] 50 | if len(selected) > 0: 51 | bin_accuracies[b] = np.mean(true_labels[selected] == pred_labels[selected]) 52 | bin_confidences[b] = np.mean(confidences[selected]) 53 | bin_counts[b] = len(selected) 54 | 55 | avg_acc = np.sum(bin_accuracies * bin_counts) / np.sum(bin_counts) 56 | avg_conf = np.sum(bin_confidences * bin_counts) / np.sum(bin_counts) 57 | 58 | gaps = np.abs(bin_accuracies - bin_confidences) 59 | ece = np.sum(gaps * bin_counts) / np.sum(bin_counts) 60 | mce = np.max(gaps) 61 | 62 | return { "accuracies": bin_accuracies, 63 | "confidences": bin_confidences, 64 | "counts": bin_counts, 65 | "bins": bins, 66 | "avg_accuracy": avg_acc, 67 | "avg_confidence": avg_conf, 68 | "expected_calibration_error": ece, 69 | "max_calibration_error": mce } 70 | 71 | 72 | def _reliability_diagram_subplot(ax, bin_data, draw_order=True, 73 | draw_ece=True, 74 | draw_bin_importance=False, 75 | title="Reliability Diagram", 76 | xlabel="Confidence", 77 | ylabel="Expected Accuracy"): 78 | """Draws a reliability diagram into a subplot.""" 79 | accuracies = bin_data["accuracies"] 80 | confidences = bin_data["confidences"] 81 | counts = bin_data["counts"] 82 | bins = bin_data["bins"] 83 | 84 | bin_size = 1.0 / len(counts) 85 | positions = bins[:-1] + bin_size/2.0 86 | 87 | widths = bin_size 88 | alphas = 0.3 89 | min_count = np.min(counts) 90 | max_count = np.max(counts) 91 | normalized_counts = (counts - min_count) / (max_count - min_count) 92 | 93 | if draw_bin_importance == "alpha": 94 | alphas = 0.2 + 0.8*normalized_counts 95 | elif draw_bin_importance == "width": 96 | widths = 0.1*bin_size + 0.9*bin_size*normalized_counts 97 | 98 | colors = np.zeros((len(counts), 4)) 99 | colors[:, 0] = 240 / 255. 100 | colors[:, 1] = 60 / 255. 101 | colors[:, 2] = 60 / 255. 102 | colors[:, 3] = alphas 103 | 104 | if draw_order == True: 105 | # pdb.set_trace() 106 | acc_plt = ax.bar(positions, accuracies, bottom=0, width=widths, 107 | color=sns.color_palette("Blues", 10)[4], edgecolor="white", alpha=1.0, linewidth=5, 108 | label="Accuracy") 109 | 110 | gap_plt = ax.bar(positions, np.abs(accuracies - confidences), 111 | bottom=np.minimum(accuracies, confidences), width=widths, 112 | color=sns.color_palette("Blues", 10)[7], edgecolor="white", linewidth=5, label="Gap") 113 | else: 114 | acc_plt = ax.bar(positions, accuracies, bottom=0, width=widths, 115 | color=sns.color_palette("Blues", 10)[4], edgecolor="white", alpha=1.0, linewidth=5, 116 | label="Accuracy") 117 | 118 | gap_plt = ax.bar(positions, np.abs(accuracies - confidences), 119 | bottom=np.minimum(accuracies, confidences), width=widths, 120 | color=sns.color_palette("Blues", 10)[7], edgecolor="white", linewidth=5, label="Gap") 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | ax.set_aspect("equal") 129 | ax.plot([0,1], [0,1], linestyle = "--", color="gray", linewidth=10) 130 | ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0]) 131 | ax.set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) 132 | 133 | if draw_ece: 134 | import matplotlib.patches as patches 135 | ax.add_patch( 136 | patches.Rectangle( 137 | (0.51, 0.015), # (x,y) 138 | 0.48, 0.2, # width and height 139 | # You can add rotation as well with 'angle' 140 | alpha=0.7, facecolor="white", edgecolor="white", linewidth=3, linestyle='solid') 141 | ) 142 | acc_avg = (bin_data["avg_accuracy"]) 143 | # ax.text(0.26, 0.79, "ECE=%.3f" % ece, color="black", 144 | # ha="right", va="bottom", transform=ax.transAxes) 145 | ax.text(0.98, 0.11, "ACC=%.3f" % acc_avg, color="black", 146 | ha="right", va="bottom", transform=ax.transAxes) 147 | ece = (bin_data["expected_calibration_error"]) 148 | # ax.text(0.26, 0.79, "ECE=%.3f" % ece, color="black", 149 | # ha="right", va="bottom", transform=ax.transAxes) 150 | ax.text(0.98, 0.03, "ECE=%.3f" % ece, color="black", 151 | ha="right", va="bottom", transform=ax.transAxes) 152 | 153 | ax.set_xlim(0, 1) 154 | ax.set_ylim(0, 1) 155 | #ax.set_xticks(bins) 156 | 157 | # ax.set_title(title, pad=40, fontsize=50) 158 | # ttl = ax.title 159 | # ttl.set_position([.45, 1.00]) 160 | ax.set_xlabel(xlabel) 161 | ax.set_ylabel(ylabel) 162 | 163 | ax.legend(handles=[gap_plt, acc_plt]) 164 | 165 | 166 | def _confidence_histogram_subplot(ax, bin_data, 167 | draw_averages=True, 168 | title="Examples per bin", 169 | xlabel="Confidence", 170 | ylabel="Count"): 171 | """Draws a confidence histogram into a subplot.""" 172 | counts = bin_data["counts"] 173 | bins = bin_data["bins"] 174 | 175 | bin_size = 1.0 / len(counts) 176 | positions = bins[:-1] + bin_size/2.0 177 | 178 | ax.bar(positions, counts, width=bin_size * 0.9) 179 | 180 | ax.set_xlim(0, 1) 181 | ax.set_title(title) 182 | ax.set_xlabel(xlabel) 183 | ax.set_ylabel(ylabel) 184 | 185 | if draw_averages: 186 | acc_plt = ax.axvline(x=bin_data["avg_accuracy"], ls="solid", lw=3, 187 | c="black", label="Accuracy") 188 | conf_plt = ax.axvline(x=bin_data["avg_confidence"], ls="dotted", lw=3, 189 | c="#444", label="Avg. confidence") 190 | ax.legend(handles=[acc_plt, conf_plt]) 191 | 192 | 193 | def _reliability_diagram_combined(bin_data, 194 | draw_ece, draw_bin_importance, draw_averages, 195 | title, figsize, dpi, return_fig): 196 | """Draws a reliability diagram and confidence histogram using the output 197 | from compute_calibration().""" 198 | figsize = (figsize[0], figsize[0]) 199 | 200 | fig, ax = plt.subplots(nrows=1, ncols=1, sharex=True, figsize=figsize, dpi=dpi) 201 | 202 | plt.tight_layout() 203 | plt.subplots_adjust(hspace=-0.1) 204 | 205 | _reliability_diagram_subplot(ax, bin_data, draw_ece, draw_bin_importance, 206 | title=title, xlabel="Confidence") 207 | 208 | # Draw the confidence histogram upside down. 209 | # orig_counts = bin_data["counts"] 210 | # bin_data["counts"] = -bin_data["counts"] 211 | # _confidence_histogram_subplot(ax[1], bin_data, draw_averages, title="") 212 | # bin_data["counts"] = orig_counts 213 | 214 | # # Also negate the ticks for the upside-down histogram. 215 | # new_ticks = np.abs(ax[1].get_yticks()).astype(np.int) 216 | # ax[1].set_yticklabels(new_ticks) 217 | 218 | plt.show() 219 | 220 | if return_fig: return fig 221 | 222 | 223 | def reliability_diagram(true_labels, pred_labels, confidences, num_bins=10, 224 | draw_ece=True, draw_bin_importance=False, 225 | draw_averages=True, title="Reliability Diagram", 226 | figsize=(20, 20), dpi=72, return_fig=False): 227 | """Draws a reliability diagram and confidence histogram in a single plot. 228 | 229 | First, the model's predictions are divided up into bins based on their 230 | confidence scores. 231 | 232 | The reliability diagram shows the gap between average accuracy and average 233 | confidence in each bin. These are the red bars. 234 | 235 | The black line is the accuracy, the other end of the bar is the confidence. 236 | 237 | Ideally, there is no gap and the black line is on the dotted diagonal. 238 | In that case, the model is properly calibrated and we can interpret the 239 | confidence scores as probabilities. 240 | 241 | The confidence histogram visualizes how many examples are in each bin. 242 | This is useful for judging how much each bin contributes to the calibration 243 | error. 244 | 245 | The confidence histogram also shows the overall accuracy and confidence. 246 | The closer these two lines are together, the better the calibration. 247 | 248 | The ECE or Expected Calibration Error is a summary statistic that gives the 249 | difference in expectation between confidence and accuracy. In other words, 250 | it's a weighted average of the gaps across all bins. A lower ECE is better. 251 | 252 | Arguments: 253 | true_labels: the true labels for the test examples 254 | pred_labels: the predicted labels for the test examples 255 | confidences: the predicted confidences for the test examples 256 | num_bins: number of bins 257 | draw_ece: whether to include the Expected Calibration Error 258 | draw_bin_importance: whether to represent how much each bin contributes 259 | to the total accuracy: False, "alpha", "widths" 260 | draw_averages: whether to draw the overall accuracy and confidence in 261 | the confidence histogram 262 | title: optional title for the plot 263 | figsize: setting for matplotlib; height is ignored 264 | dpi: setting for matplotlib 265 | return_fig: if True, returns the matplotlib Figure object 266 | """ 267 | bin_data = compute_calibration(true_labels, pred_labels, confidences, num_bins) 268 | return _reliability_diagram_combined(bin_data, draw_ece, draw_bin_importance, 269 | draw_averages, title, figsize=figsize, 270 | dpi=dpi, return_fig=return_fig) 271 | 272 | 273 | def reliability_diagrams(results, num_bins=10, 274 | draw_ece=True, draw_bin_importance=False, 275 | num_cols=4, dpi=72, return_fig=False): 276 | """Draws reliability diagrams for one or more models. 277 | 278 | Arguments: 279 | results: dictionary where the key is the model name and the value is 280 | a dictionary containing the true labels, predicated labels, and 281 | confidences for this model 282 | num_bins: number of bins 283 | draw_ece: whether to include the Expected Calibration Error 284 | draw_bin_importance: whether to represent how much each bin contributes 285 | to the total accuracy: False, "alpha", "widths" 286 | num_cols: how wide to make the plot 287 | dpi: setting for matplotlib 288 | return_fig: if True, returns the matplotlib Figure object 289 | """ 290 | ncols = num_cols 291 | 292 | nrows = (len(results) + ncols - 1) // ncols 293 | figsize = (ncols * 16, nrows * 16) 294 | 295 | fig, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True, 296 | figsize=figsize, dpi=dpi, constrained_layout=True) 297 | fig.text(0.5, 0.00, 'Confidence', ha='center') 298 | 299 | for i, (plot_name, data) in enumerate(results.items()): 300 | y_true = data["true_labels"] 301 | y_pred = data["pred_labels"] 302 | y_conf = data["confidences"] 303 | 304 | bin_data = compute_calibration(y_true, y_pred, y_conf, num_bins) 305 | 306 | row = i // ncols 307 | col = i % ncols 308 | # pdb.set_trace() 309 | draw_order = True 310 | if i == 3: 311 | draw_order = False 312 | # pdb.set_trace() 313 | _reliability_diagram_subplot(ax[col], bin_data, draw_order, draw_ece, 314 | draw_bin_importance, 315 | title="\n".join(plot_name.split()), 316 | xlabel=" " if row == nrows - 1 else "", 317 | ylabel="Accuracy" if col == 0 else "") 318 | 319 | for i in range(i + 1, nrows * ncols): 320 | row = i // ncols 321 | col = i % ncols 322 | ax[row, col].axis("off") 323 | 324 | # plt.show() 325 | 326 | if return_fig: return fig -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision==0.4.0 2 | torch==1.2.0 3 | numpy==1.19.2 4 | yacs==0.1.8 5 | -------------------------------------------------------------------------------- /train_stage1.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | import numpy as np 8 | import pprint 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.distributed as dist 15 | import torch.optim 16 | import torch.multiprocessing as mp 17 | import torch.utils.data 18 | import torch.utils.data.distributed 19 | import torch.nn.functional as F 20 | 21 | from datasets.cifar10 import CIFAR10_LT 22 | from datasets.cifar100 import CIFAR100_LT 23 | from datasets.places import Places_LT 24 | from datasets.imagenet import ImageNet_LT 25 | from datasets.ina2018 import iNa2018 26 | 27 | from models import resnet 28 | from models import resnet_places 29 | from models import resnet_cifar 30 | 31 | from utils import config, update_config, create_logger 32 | from utils import AverageMeter, ProgressMeter 33 | from utils import accuracy, calibration 34 | 35 | from methods import mixup_data, mixup_criterion 36 | 37 | 38 | def parse_args(): 39 | parser = argparse.ArgumentParser(description='MiSLAS training (Stage-1)') 40 | parser.add_argument('--cfg', 41 | help='experiment configure file name', 42 | required=True, 43 | type=str) 44 | parser.add_argument('opts', 45 | help="Modify config options using the command-line", 46 | default=None, 47 | nargs=argparse.REMAINDER) 48 | args = parser.parse_args() 49 | update_config(config, args) 50 | 51 | return args 52 | 53 | 54 | best_acc1 = 0 55 | its_ece = 100 56 | 57 | 58 | def main(): 59 | args = parse_args() 60 | logger, model_dir = create_logger(config, args.cfg) 61 | logger.info('\n' + pprint.pformat(args)) 62 | logger.info('\n' + str(config)) 63 | 64 | if config.deterministic: 65 | seed = 0 66 | torch.backends.cudnn.deterministic = True 67 | torch.backends.cudnn.benchmark = False 68 | random.seed(seed) 69 | np.random.seed(seed) 70 | os.environ['PYTHONHASHSEED'] = str(seed) 71 | torch.manual_seed(seed) 72 | torch.cuda.manual_seed(seed) 73 | torch.cuda.manual_seed_all(seed) 74 | 75 | if config.gpu is not None: 76 | warnings.warn('You have chosen a specific GPU. This will completely ' 77 | 'disable data parallelism.') 78 | 79 | if config.dist_url == "env://" and config.world_size == -1: 80 | config.world_size = int(os.environ["WORLD_SIZE"]) 81 | 82 | config.distributed = config.world_size > 1 or config.multiprocessing_distributed 83 | 84 | ngpus_per_node = torch.cuda.device_count() 85 | if config.multiprocessing_distributed: 86 | # Since we have ngpus_per_node processes per node, the total world_size 87 | # needs to be adjusted accordingly 88 | config.world_size = ngpus_per_node * config.world_size 89 | # Use torch.multiprocessing.spawn to launch distributed processes: the 90 | # main_worker process function 91 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config, logger)) 92 | else: 93 | # Simply call main_worker function 94 | main_worker(config.gpu, ngpus_per_node, config, logger, model_dir) 95 | 96 | 97 | def main_worker(gpu, ngpus_per_node, config, logger, model_dir): 98 | global best_acc1, its_ece 99 | config.gpu = gpu 100 | # start_time = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 101 | 102 | if config.gpu is not None: 103 | logger.info("Use GPU: {} for training".format(config.gpu)) 104 | 105 | if config.distributed: 106 | if config.dist_url == "env://" and config.rank == -1: 107 | config.rank = int(os.environ["RANK"]) 108 | if config.multiprocessing_distributed: 109 | # For multiprocessing distributed training, rank needs to be the 110 | # global rank among all the processes 111 | config.rank = config.rank * ngpus_per_node + gpu 112 | dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url, 113 | world_size=config.world_size, rank=config.rank) 114 | 115 | if config.dataset == 'cifar10' or config.dataset == 'cifar100': 116 | model = getattr(resnet_cifar, config.backbone)() 117 | classifier = getattr(resnet_cifar, 'Classifier')(feat_in=64, num_classes=config.num_classes) 118 | 119 | elif config.dataset == 'imagenet' or config.dataset == 'ina2018': 120 | model = getattr(resnet, config.backbone)() 121 | classifier = getattr(resnet, 'Classifier')(feat_in=2048, num_classes=config.num_classes) 122 | 123 | elif config.dataset == 'places': 124 | model = getattr(resnet_places, config.backbone)(pretrained=True) 125 | classifier = getattr(resnet_places, 'Classifier')(feat_in=2048, num_classes=config.num_classes) 126 | block = getattr(resnet_places, 'Bottleneck')(2048, 512, groups=1, base_width=64, dilation=1, norm_layer=nn.BatchNorm2d) 127 | 128 | if not torch.cuda.is_available(): 129 | logger.info('using CPU, this will be slow') 130 | elif config.distributed: 131 | # For multiprocessing distributed, DistributedDataParallel constructor 132 | # should always set the single device scope, otherwise, 133 | # DistributedDataParallel will use all available devices. 134 | if config.gpu is not None: 135 | torch.cuda.set_device(config.gpu) 136 | model.cuda(config.gpu) 137 | classifier.cuda(config.gpu) 138 | # When using a single GPU per process and per 139 | # DistributedDataParallel, we need to divide the batch size 140 | # ourselves based on the total number of GPUs we have 141 | config.batch_size = int(config.batch_size / ngpus_per_node) 142 | config.workers = int((config.workers + ngpus_per_node - 1) / ngpus_per_node) 143 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu]) 144 | classifier = torch.nn.parallel.DistributedDataParallel(classifier, device_ids=[config.gpu]) 145 | if config.dataset == 'places': 146 | block.cuda(config.gpu) 147 | block = torch.nn.parallel.DistributedDataParallel(block, device_ids=[config.gpu]) 148 | else: 149 | model.cuda() 150 | classifier.cuda() 151 | # DistributedDataParallel will divide and allocate batch_size to all 152 | # available GPUs if device_ids are not set 153 | model = torch.nn.parallel.DistributedDataParallel(model) 154 | classifier = torch.nn.parallel.DistributedDataParallel(classifier) 155 | if config.dataset == 'places': 156 | block.cuda() 157 | block = torch.nn.parallel.DistributedDataParallel(block) 158 | elif config.gpu is not None: 159 | torch.cuda.set_device(config.gpu) 160 | model = model.cuda(config.gpu) 161 | classifier = classifier.cuda(config.gpu) 162 | if config.dataset == 'places': 163 | block.cuda(config.gpu) 164 | else: 165 | # DataParallel will divide and allocate batch_size to all available GPUs 166 | model = torch.nn.DataParallel(model).cuda() 167 | classifier = torch.nn.DataParallel(classifier).cuda() 168 | if config.dataset == 'places': 169 | block = torch.nn.DataParallel(block).cuda() 170 | 171 | # optionally resume from a checkpoint 172 | if config.resume: 173 | if os.path.isfile(config.resume): 174 | logger.info("=> loading checkpoint '{}'".format(config.resume)) 175 | if config.gpu is None: 176 | checkpoint = torch.load(config.resume) 177 | else: 178 | # Map model to be loaded to specified single gpu. 179 | loc = 'cuda:{}'.format(config.gpu) 180 | checkpoint = torch.load(config.resume, map_location=loc) 181 | # config.start_epoch = checkpoint['epoch'] 182 | best_acc1 = checkpoint['best_acc1'] 183 | if config.gpu is not None: 184 | # best_acc1 may be from a checkpoint from a different GPU 185 | best_acc1 = best_acc1.to(config.gpu) 186 | model.load_state_dict(checkpoint['state_dict_model']) 187 | classifier.load_state_dict(checkpoint['state_dict_classifier']) 188 | logger.info("=> loaded checkpoint '{}' (epoch {})" 189 | .format(config.resume, checkpoint['epoch'])) 190 | else: 191 | logger.info("=> no checkpoint found at '{}'".format(config.resume)) 192 | 193 | # Data loading code 194 | if config.dataset == 'cifar10': 195 | dataset = CIFAR10_LT(config.distributed, root=config.data_path, imb_factor=config.imb_factor, 196 | batch_size=config.batch_size, num_works=config.workers) 197 | 198 | elif config.dataset == 'cifar100': 199 | dataset = CIFAR100_LT(config.distributed, root=config.data_path, imb_factor=config.imb_factor, 200 | batch_size=config.batch_size, num_works=config.workers) 201 | 202 | elif config.dataset == 'places': 203 | dataset = Places_LT(config.distributed, root=config.data_path, 204 | batch_size=config.batch_size, num_works=config.workers) 205 | 206 | elif config.dataset == 'imagenet': 207 | dataset = ImageNet_LT(config.distributed, root=config.data_path, 208 | batch_size=config.batch_size, num_works=config.workers) 209 | 210 | elif config.dataset == 'ina2018': 211 | dataset = iNa2018(config.distributed, root=config.data_path, 212 | batch_size=config.batch_size, num_works=config.workers) 213 | 214 | train_loader = dataset.train_instance 215 | val_loader = dataset.eval 216 | if config.distributed: 217 | train_sampler = dataset.dist_sampler 218 | 219 | # define loss function (criterion) and optimizer 220 | criterion = nn.CrossEntropyLoss().cuda(config.gpu) 221 | 222 | if config.dataset == 'places': 223 | optimizer = torch.optim.SGD([{"params": block.parameters()}, 224 | {"params": classifier.parameters()}], config.lr, 225 | momentum=config.momentum, 226 | weight_decay=config.weight_decay) 227 | else: 228 | optimizer = torch.optim.SGD([{"params": model.parameters()}, 229 | {"params": classifier.parameters()}], config.lr, 230 | momentum=config.momentum, 231 | weight_decay=config.weight_decay) 232 | 233 | for epoch in range(config.num_epochs): 234 | if config.distributed: 235 | train_sampler.set_epoch(epoch) 236 | 237 | adjust_learning_rate(optimizer, epoch, config) 238 | 239 | if config.dataset != 'places': 240 | block = None 241 | # train for one epoch 242 | train(train_loader, model, classifier, criterion, optimizer, epoch, config, logger, block) 243 | 244 | # evaluate on validation set 245 | acc1, ece = validate(val_loader, model, classifier, criterion, config, logger, block) 246 | 247 | # remember best acc@1 and save checkpoint 248 | is_best = acc1 > best_acc1 249 | best_acc1 = max(acc1, best_acc1) 250 | if is_best: 251 | its_ece = ece 252 | logger.info('Best Prec@1: %.3f%% ECE: %.3f%%\n' % (best_acc1, its_ece)) 253 | 254 | if not config.multiprocessing_distributed or (config.multiprocessing_distributed 255 | and config.rank % ngpus_per_node == 0): 256 | if config.dataset == 'places': 257 | save_checkpoint({ 258 | 'epoch': epoch + 1, 259 | 'state_dict_model': model.state_dict(), 260 | 'state_dict_classifier': classifier.state_dict(), 261 | 'state_dict_block': block.state_dict(), 262 | 'best_acc1': best_acc1, 263 | 'its_ece': its_ece, 264 | }, is_best, model_dir) 265 | 266 | else: 267 | save_checkpoint({ 268 | 'epoch': epoch + 1, 269 | 'state_dict_model': model.state_dict(), 270 | 'state_dict_classifier': classifier.state_dict(), 271 | 'best_acc1': best_acc1, 272 | 'its_ece': its_ece, 273 | }, is_best, model_dir) 274 | 275 | 276 | def train(train_loader, model, classifier, criterion, optimizer, epoch, config, logger, block=None): 277 | batch_time = AverageMeter('Time', ':6.3f') 278 | data_time = AverageMeter('Data', ':6.3f') 279 | losses = AverageMeter('Loss', ':.3f') 280 | top1 = AverageMeter('Acc@1', ':6.3f') 281 | top5 = AverageMeter('Acc@5', ':6.3f') 282 | progress = ProgressMeter( 283 | len(train_loader), 284 | [batch_time, losses, top1, top5], 285 | prefix="Epoch: [{}]".format(epoch)) 286 | 287 | # switch to train mode 288 | if config.dataset == 'places': 289 | model.eval() 290 | block.train() 291 | else: 292 | model.train() 293 | classifier.train() 294 | 295 | training_data_num = len(train_loader.dataset) 296 | end_steps = int(training_data_num / train_loader.batch_size) 297 | 298 | end = time.time() 299 | for i, (images, target) in enumerate(train_loader): 300 | if i > end_steps: 301 | break 302 | 303 | # measure data loading time 304 | data_time.update(time.time() - end) 305 | 306 | if torch.cuda.is_available(): 307 | images = images.cuda(config.gpu, non_blocking=True) 308 | target = target.cuda(config.gpu, non_blocking=True) 309 | 310 | if config.mixup is True: 311 | images, targets_a, targets_b, lam = mixup_data(images, target, alpha=config.alpha) 312 | if config.dataset == 'places': 313 | with torch.no_grad(): 314 | feat_a = model(images) 315 | feat = block(feat_a.detach()) 316 | output = classifier(feat) 317 | else: 318 | feat = model(images) 319 | output = classifier(feat) 320 | loss = mixup_criterion(criterion, output, targets_a, targets_b, lam) 321 | else: 322 | if config.dataset == 'places': 323 | with torch.no_grad(): 324 | feat_a = model(images) 325 | feat = block(feat_a.detach()) 326 | output = classifier(feat) 327 | else: 328 | feat = model(images) 329 | output = classifier(feat) 330 | 331 | loss = criterion(output, target) 332 | 333 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 334 | losses.update(loss.item(), images.size(0)) 335 | top1.update(acc1[0], images.size(0)) 336 | top5.update(acc5[0], images.size(0)) 337 | 338 | # compute gradient and do SGD step 339 | optimizer.zero_grad() 340 | loss.backward() 341 | optimizer.step() 342 | 343 | # measure elapsed time 344 | batch_time.update(time.time() - end) 345 | end = time.time() 346 | 347 | if i % config.print_freq == 0: 348 | progress.display(i, logger) 349 | 350 | 351 | def validate(val_loader, model, classifier, criterion, config, logger, block=None): 352 | batch_time = AverageMeter('Time', ':6.3f') 353 | losses = AverageMeter('Loss', ':.3f') 354 | top1 = AverageMeter('Acc@1', ':6.3f') 355 | top5 = AverageMeter('Acc@5', ':6.3f') 356 | progress = ProgressMeter( 357 | len(val_loader), 358 | [batch_time, losses, top1, top5], 359 | prefix='Eval: ') 360 | 361 | # switch to evaluate mode 362 | model.eval() 363 | if config.dataset == 'places': 364 | block.eval() 365 | classifier.eval() 366 | class_num = torch.zeros(config.num_classes).cuda() 367 | correct = torch.zeros(config.num_classes).cuda() 368 | 369 | confidence = np.array([]) 370 | pred_class = np.array([]) 371 | true_class = np.array([]) 372 | 373 | with torch.no_grad(): 374 | end = time.time() 375 | for i, (images, target) in enumerate(val_loader): 376 | if config.gpu is not None: 377 | images = images.cuda(config.gpu, non_blocking=True) 378 | if torch.cuda.is_available(): 379 | target = target.cuda(config.gpu, non_blocking=True) 380 | 381 | # compute output 382 | feat = model(images) 383 | if config.dataset == 'places': 384 | feat = block(feat) 385 | output = classifier(feat) 386 | loss = criterion(output, target) 387 | 388 | # measure accuracy and record loss 389 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 390 | losses.update(loss.item(), images.size(0)) 391 | top1.update(acc1[0], images.size(0)) 392 | top5.update(acc5[0], images.size(0)) 393 | 394 | _, predicted = output.max(1) 395 | target_one_hot = F.one_hot(target, config.num_classes) 396 | predict_one_hot = F.one_hot(predicted, config.num_classes) 397 | class_num = class_num + target_one_hot.sum(dim=0).to(torch.float) 398 | correct = correct + (target_one_hot + predict_one_hot == 2).sum(dim=0).to(torch.float) 399 | 400 | prob = torch.softmax(output, dim=1) 401 | confidence_part, pred_class_part = torch.max(prob, dim=1) 402 | confidence = np.append(confidence, confidence_part.cpu().numpy()) 403 | pred_class = np.append(pred_class, pred_class_part.cpu().numpy()) 404 | true_class = np.append(true_class, target.cpu().numpy()) 405 | 406 | # measure elapsed time 407 | batch_time.update(time.time() - end) 408 | end = time.time() 409 | 410 | if i % config.print_freq == 0: 411 | progress.display(i, logger) 412 | 413 | acc_classes = correct / class_num 414 | head_acc = acc_classes[config.head_class_idx[0]:config.head_class_idx[1]].mean() * 100 415 | 416 | med_acc = acc_classes[config.med_class_idx[0]:config.med_class_idx[1]].mean() * 100 417 | tail_acc = acc_classes[config.tail_class_idx[0]:config.tail_class_idx[1]].mean() * 100 418 | logger.info('* Acc@1 {top1.avg:.3f}% Acc@5 {top5.avg:.3f}% HAcc {head_acc:.3f}% MAcc {med_acc:.3f}% TAcc {tail_acc:.3f}%.'.format(top1=top1, top5=top5, head_acc=head_acc, med_acc=med_acc, tail_acc=tail_acc)) 419 | 420 | cal = calibration(true_class, pred_class, confidence, num_bins=15) 421 | logger.info('* ECE {ece:.3f}%.'.format(ece=cal['expected_calibration_error'] * 100)) 422 | 423 | return top1.avg, cal['expected_calibration_error'] * 100 424 | 425 | 426 | def save_checkpoint(state, is_best, model_dir): 427 | filename = model_dir + '/current.pth.tar' 428 | torch.save(state, filename) 429 | if is_best: 430 | shutil.copyfile(filename, model_dir + '/model_best.pth.tar') 431 | 432 | 433 | def adjust_learning_rate(optimizer, epoch, config): 434 | """Sets the learning rate""" 435 | if config.cos: 436 | lr_min = 0 437 | lr_max = config.lr 438 | lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(epoch / config.num_epochs * 3.1415926535)) 439 | else: 440 | epoch = epoch + 1 441 | if epoch <= 5: 442 | lr = config.lr * epoch / 5 443 | elif epoch > 180: 444 | lr = config.lr * 0.01 445 | elif epoch > 160: 446 | lr = config.lr * 0.1 447 | else: 448 | lr = config.lr 449 | 450 | for param_group in optimizer.param_groups: 451 | param_group['lr'] = lr 452 | 453 | 454 | if __name__ == '__main__': 455 | main() 456 | -------------------------------------------------------------------------------- /train_stage2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | import numpy as np 8 | import pprint 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.parallel 14 | import torch.distributed as dist 15 | import torch.optim 16 | import torch.multiprocessing as mp 17 | import torch.utils.data 18 | import torch.utils.data.distributed 19 | import torch.nn.functional as F 20 | 21 | from datasets.cifar10 import CIFAR10_LT 22 | from datasets.cifar100 import CIFAR100_LT 23 | from datasets.places import Places_LT 24 | from datasets.imagenet import ImageNet_LT 25 | from datasets.ina2018 import iNa2018 26 | 27 | from models import resnet 28 | from models import resnet_places 29 | from models import resnet_cifar 30 | 31 | from utils import config, update_config, create_logger 32 | from utils import AverageMeter, ProgressMeter 33 | from utils import accuracy, calibration 34 | 35 | from methods import mixup_data, mixup_criterion 36 | from methods import LabelAwareSmoothing, LearnableWeightScaling 37 | 38 | 39 | def parse_args(): 40 | parser = argparse.ArgumentParser(description='MiSLAS training (Stage-2)') 41 | parser.add_argument('--cfg', 42 | help='experiment configure file name', 43 | required=True, 44 | type=str) 45 | parser.add_argument('opts', 46 | help="Modify config options using the command-line", 47 | default=None, 48 | nargs=argparse.REMAINDER) 49 | args = parser.parse_args() 50 | update_config(config, args) 51 | 52 | return args 53 | 54 | 55 | best_acc1 = 0 56 | its_ece = 100 57 | 58 | 59 | def main(): 60 | 61 | args = parse_args() 62 | logger, model_dir = create_logger(config, args.cfg) 63 | logger.info('\n' + pprint.pformat(args)) 64 | logger.info('\n' + str(config)) 65 | 66 | if config.deterministic: 67 | seed = 0 68 | torch.backends.cudnn.deterministic = True 69 | torch.backends.cudnn.benchmark = False 70 | random.seed(seed) 71 | np.random.seed(seed) 72 | os.environ['PYTHONHASHSEED'] = str(seed) 73 | torch.manual_seed(seed) 74 | torch.cuda.manual_seed(seed) 75 | torch.cuda.manual_seed_all(seed) 76 | 77 | if config.gpu is not None: 78 | warnings.warn('You have chosen a specific GPU. This will completely ' 79 | 'disable data parallelism.') 80 | 81 | if config.dist_url == "env://" and config.world_size == -1: 82 | config.world_size = int(os.environ["WORLD_SIZE"]) 83 | 84 | config.distributed = config.world_size > 1 or config.multiprocessing_distributed 85 | 86 | ngpus_per_node = torch.cuda.device_count() 87 | if config.multiprocessing_distributed: 88 | # Since we have ngpus_per_node processes per node, the total world_size 89 | # needs to be adjusted accordingly 90 | config.world_size = ngpus_per_node * config.world_size 91 | # Use torch.multiprocessing.spawn to launch distributed processes: the 92 | # main_worker process function 93 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, config, logger)) 94 | else: 95 | # Simply call main_worker function 96 | main_worker(config.gpu, ngpus_per_node, config, logger, model_dir) 97 | 98 | 99 | def main_worker(gpu, ngpus_per_node, config, logger, model_dir): 100 | global best_acc1, its_ece 101 | config.gpu = gpu 102 | # start_time = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 103 | 104 | if config.gpu is not None: 105 | logger.info("Use GPU: {} for training".format(config.gpu)) 106 | 107 | if config.distributed: 108 | if config.dist_url == "env://" and config.rank == -1: 109 | config.rank = int(os.environ["RANK"]) 110 | if config.multiprocessing_distributed: 111 | # For multiprocessing distributed training, rank needs to be the 112 | # global rank among all the processes 113 | config.rank = config.rank * ngpus_per_node + gpu 114 | dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url, 115 | world_size=config.world_size, rank=config.rank) 116 | 117 | if config.dataset == 'cifar10' or config.dataset == 'cifar100': 118 | model = getattr(resnet_cifar, config.backbone)() 119 | classifier = getattr(resnet_cifar, 'Classifier')(feat_in=64, num_classes=config.num_classes) 120 | 121 | elif config.dataset == 'imagenet' or config.dataset == 'ina2018': 122 | model = getattr(resnet, config.backbone)() 123 | classifier = getattr(resnet, 'Classifier')(feat_in=2048, num_classes=config.num_classes) 124 | 125 | elif config.dataset == 'places': 126 | model = getattr(resnet_places, config.backbone)(pretrained=True) 127 | classifier = getattr(resnet_places, 'Classifier')(feat_in=2048, num_classes=config.num_classes) 128 | block = getattr(resnet_places, 'Bottleneck')(2048, 512, groups=1, base_width=64, 129 | dilation=1, norm_layer=nn.BatchNorm2d) 130 | 131 | lws_model = LearnableWeightScaling(num_classes=config.num_classes) 132 | 133 | if not torch.cuda.is_available(): 134 | logger.info('using CPU, this will be slow') 135 | elif config.distributed: 136 | # For multiprocessing distributed, DistributedDataParallel constructor 137 | # should always set the single device scope, otherwise, 138 | # DistributedDataParallel will use all available devices. 139 | if config.gpu is not None: 140 | torch.cuda.set_device(config.gpu) 141 | model.cuda(config.gpu) 142 | classifier.cuda(config.gpu) 143 | lws_model.cuda(config.gpu) 144 | # When using a single GPU per process and per 145 | # DistributedDataParallel, we need to divide the batch size 146 | # ourselves based on the total number of GPUs we have 147 | config.batch_size = int(config.batch_size / ngpus_per_node) 148 | config.workers = int((config.workers + ngpus_per_node - 1) / ngpus_per_node) 149 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu]) 150 | classifier = torch.nn.parallel.DistributedDataParallel(classifier, device_ids=[config.gpu]) 151 | lws_model = torch.nn.parallel.DistributedDataParallel(lws_model, device_ids=[config.gpu]) 152 | 153 | if config.dataset == 'places': 154 | block.cuda(config.gpu) 155 | block = torch.nn.parallel.DistributedDataParallel(block, device_ids=[config.gpu]) 156 | else: 157 | model.cuda() 158 | classifier.cuda() 159 | lws_model.cuda() 160 | # DistributedDataParallel will divide and allocate batch_size to all 161 | # available GPUs if device_ids are not set 162 | model = torch.nn.parallel.DistributedDataParallel(model) 163 | classifier = torch.nn.parallel.DistributedDataParallel(classifier) 164 | lws_model = torch.nn.parallel.DistributedDataParallel(lws_model) 165 | if config.dataset == 'places': 166 | block.cuda() 167 | block = torch.nn.parallel.DistributedDataParallel(block) 168 | 169 | elif config.gpu is not None: 170 | torch.cuda.set_device(config.gpu) 171 | model = model.cuda(config.gpu) 172 | classifier = classifier.cuda(config.gpu) 173 | lws_model = lws_model.cuda(config.gpu) 174 | if config.dataset == 'places': 175 | block.cuda(config.gpu) 176 | else: 177 | # DataParallel will divide and allocate batch_size to all available GPUs 178 | model = torch.nn.DataParallel(model).cuda() 179 | classifier = torch.nn.DataParallel(classifier).cuda() 180 | lws_model = torch.nn.DataParallel(lws_model).cuda() 181 | if config.dataset == 'places': 182 | block = torch.nn.DataParallel(block).cuda() 183 | 184 | # optionally resume from a checkpoint 185 | if config.resume: 186 | if os.path.isfile(config.resume): 187 | logger.info("=> loading checkpoint '{}'".format(config.resume)) 188 | if config.gpu is None: 189 | checkpoint = torch.load(config.resume) 190 | else: 191 | # Map model to be loaded to specified single gpu. 192 | loc = 'cuda:{}'.format(config.gpu) 193 | checkpoint = torch.load(config.resume, map_location=loc) 194 | # config.start_epoch = checkpoint['epoch'] 195 | best_acc1 = checkpoint['best_acc1'] 196 | its_ece = checkpoint['its_ece'] 197 | if config.gpu is not None: 198 | # best_acc1 may be from a checkpoint from a different GPU 199 | best_acc1 = best_acc1.to(config.gpu) 200 | model.load_state_dict(checkpoint['state_dict_model']) 201 | classifier.load_state_dict(checkpoint['state_dict_classifier']) 202 | if config.dataset == 'places': 203 | block.load_state_dict(checkpoint['state_dict_block']) 204 | logger.info("=> loaded checkpoint '{}' (epoch {})" 205 | .format(config.resume, checkpoint['epoch'])) 206 | else: 207 | logger.info("=> no checkpoint found at '{}'".format(config.resume)) 208 | 209 | # Data loading code 210 | if config.dataset == 'cifar10': 211 | dataset = CIFAR10_LT(config.distributed, root=config.data_path, imb_factor=config.imb_factor, 212 | batch_size=config.batch_size, num_works=config.workers) 213 | 214 | elif config.dataset == 'cifar100': 215 | dataset = CIFAR100_LT(config.distributed, root=config.data_path, imb_factor=config.imb_factor, 216 | batch_size=config.batch_size, num_works=config.workers) 217 | 218 | elif config.dataset == 'places': 219 | dataset = Places_LT(config.distributed, root=config.data_path, 220 | batch_size=config.batch_size, num_works=config.workers) 221 | 222 | elif config.dataset == 'imagenet': 223 | dataset = ImageNet_LT(config.distributed, root=config.data_path, 224 | batch_size=config.batch_size, num_works=config.workers) 225 | 226 | elif config.dataset == 'ina2018': 227 | dataset = iNa2018(config.distributed, root=config.data_path, 228 | batch_size=config.batch_size, num_works=config.workers) 229 | 230 | train_loader = dataset.train_balance 231 | val_loader = dataset.eval 232 | cls_num_list = dataset.cls_num_list 233 | if config.distributed: 234 | train_sampler = dataset.dist_sampler 235 | 236 | # define loss function (criterion) and optimizer 237 | 238 | criterion = LabelAwareSmoothing(cls_num_list=cls_num_list, smooth_head=config.smooth_head, 239 | smooth_tail=config.smooth_tail).cuda(config.gpu) 240 | 241 | optimizer = torch.optim.SGD([{"params": classifier.parameters()}, 242 | {'params': lws_model.parameters()}], config.lr, 243 | momentum=config.momentum, 244 | weight_decay=config.weight_decay) 245 | 246 | for epoch in range(config.num_epochs): 247 | if config.distributed: 248 | train_sampler.set_epoch(epoch) 249 | 250 | adjust_learning_rate(optimizer, epoch, config) 251 | 252 | if config.dataset != 'places': 253 | block = None 254 | # train for one epoch 255 | train(train_loader, model, classifier, lws_model, criterion, optimizer, epoch, config, logger, block) 256 | 257 | # evaluate on validation set 258 | acc1, ece = validate(val_loader, model, classifier, lws_model, criterion, config, logger, block) 259 | # remember best acc@1 and save checkpoint 260 | is_best = acc1 > best_acc1 261 | best_acc1 = max(acc1, best_acc1) 262 | if is_best: 263 | its_ece = ece 264 | logger.info('Best Prec@1: %.3f%% ECE: %.3f%%\n' % (best_acc1, its_ece)) 265 | if not config.multiprocessing_distributed or (config.multiprocessing_distributed 266 | and config.rank % ngpus_per_node == 0): 267 | if config.dataset == 'places': 268 | save_checkpoint({ 269 | 'epoch': epoch + 1, 270 | 'state_dict_model': model.state_dict(), 271 | 'state_dict_classifier': classifier.state_dict(), 272 | 'state_dict_block': block.state_dict(), 273 | 'state_dict_lws_model': lws_model.state_dict(), 274 | 'best_acc1': best_acc1, 275 | 'its_ece': its_ece, 276 | }, is_best, model_dir) 277 | else: 278 | save_checkpoint({ 279 | 'epoch': epoch + 1, 280 | 'state_dict_model': model.state_dict(), 281 | 'state_dict_classifier': classifier.state_dict(), 282 | 'state_dict_lws_model': lws_model.state_dict(), 283 | 'best_acc1': best_acc1, 284 | 'its_ece': its_ece, 285 | }, is_best, model_dir) 286 | 287 | 288 | def train(train_loader, model, classifier, lws_model, criterion, optimizer, epoch, config, logger, block=None): 289 | batch_time = AverageMeter('Time', ':6.3f') 290 | data_time = AverageMeter('Data', ':6.3f') 291 | losses = AverageMeter('Loss', ':.3f') 292 | top1 = AverageMeter('Acc@1', ':6.3f') 293 | top5 = AverageMeter('Acc@5', ':6.3f') 294 | training_data_num = len(train_loader.dataset) 295 | end_steps = int(np.ceil(float(training_data_num) / float(train_loader.batch_size))) 296 | progress = ProgressMeter( 297 | end_steps, 298 | [batch_time, losses, top1, top5], 299 | prefix="Epoch: [{}]".format(epoch)) 300 | 301 | # switch to train mode 302 | 303 | if config.dataset == 'places': 304 | model.eval() 305 | if config.shift_bn: 306 | block.train() 307 | else: 308 | block.eval() 309 | else: 310 | if config.shift_bn: 311 | model.train() 312 | else: 313 | model.eval() 314 | classifier.train() 315 | 316 | end = time.time() 317 | 318 | for i, (images, target) in enumerate(train_loader): 319 | if i > end_steps: 320 | break 321 | 322 | # measure data loading time 323 | data_time.update(time.time() - end) 324 | 325 | if torch.cuda.is_available(): 326 | images = images.cuda(config.gpu, non_blocking=True) 327 | target = target.cuda(config.gpu, non_blocking=True) 328 | 329 | if config.mixup is True: 330 | images, targets_a, targets_b, lam = mixup_data(images, target, alpha=config.alpha) 331 | with torch.no_grad(): 332 | if config.dataset == 'places': 333 | feat = block(model(images)) 334 | else: 335 | feat = model(images) 336 | output = classifier(feat.detach()) 337 | output = lws_model(output) 338 | loss = mixup_criterion(criterion, output, targets_a, targets_b, lam) 339 | else: 340 | # compute output 341 | with torch.no_grad(): 342 | if config.dataset == 'places': 343 | feat = block(model(images)) 344 | else: 345 | feat = model(images) 346 | output = classifier(feat.detach()) 347 | output = lws_model(output) 348 | loss = criterion(output, target) 349 | 350 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 351 | losses.update(loss.item(), images.size(0)) 352 | top1.update(acc1[0], images.size(0)) 353 | top5.update(acc5[0], images.size(0)) 354 | 355 | # compute gradient and do SGD step 356 | optimizer.zero_grad() 357 | loss.backward() 358 | optimizer.step() 359 | 360 | # measure elapsed time 361 | batch_time.update(time.time() - end) 362 | end = time.time() 363 | 364 | if i % config.print_freq == 0: 365 | progress.display(i, logger) 366 | 367 | 368 | def validate(val_loader, model, classifier, lws_model, criterion, config, logger, block=None): 369 | batch_time = AverageMeter('Time', ':6.3f') 370 | losses = AverageMeter('Loss', ':.3f') 371 | top1 = AverageMeter('Acc@1', ':6.3f') 372 | top5 = AverageMeter('Acc@5', ':6.3f') 373 | progress = ProgressMeter( 374 | len(val_loader), 375 | [batch_time, losses, top1, top5], 376 | prefix='Eval: ') 377 | 378 | # switch to evaluate mode 379 | model.eval() 380 | if config.dataset == 'places': 381 | block.eval() 382 | classifier.eval() 383 | class_num = torch.zeros(config.num_classes).cuda() 384 | correct = torch.zeros(config.num_classes).cuda() 385 | 386 | confidence = np.array([]) 387 | pred_class = np.array([]) 388 | true_class = np.array([]) 389 | 390 | with torch.no_grad(): 391 | end = time.time() 392 | for i, (images, target) in enumerate(val_loader): 393 | if config.gpu is not None: 394 | images = images.cuda(config.gpu, non_blocking=True) 395 | if torch.cuda.is_available(): 396 | target = target.cuda(config.gpu, non_blocking=True) 397 | 398 | # compute output 399 | if config.dataset == 'places': 400 | feat = block(model(images)) 401 | else: 402 | feat = model(images) 403 | output = classifier(feat) 404 | output = lws_model(output) 405 | loss = criterion(output, target) 406 | 407 | # measure accuracy and record loss 408 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 409 | losses.update(loss.item(), images.size(0)) 410 | top1.update(acc1[0], images.size(0)) 411 | top5.update(acc5[0], images.size(0)) 412 | 413 | _, predicted = output.max(1) 414 | target_one_hot = F.one_hot(target, config.num_classes) 415 | predict_one_hot = F.one_hot(predicted, config.num_classes) 416 | class_num = class_num + target_one_hot.sum(dim=0).to(torch.float) 417 | correct = correct + (target_one_hot + predict_one_hot == 2).sum(dim=0).to(torch.float) 418 | 419 | prob = torch.softmax(output, dim=1) 420 | confidence_part, pred_class_part = torch.max(prob, dim=1) 421 | confidence = np.append(confidence, confidence_part.cpu().numpy()) 422 | pred_class = np.append(pred_class, pred_class_part.cpu().numpy()) 423 | true_class = np.append(true_class, target.cpu().numpy()) 424 | 425 | # measure elapsed time 426 | batch_time.update(time.time() - end) 427 | end = time.time() 428 | 429 | if i % config.print_freq == 0: 430 | progress.display(i, logger) 431 | 432 | acc_classes = correct / class_num 433 | head_acc = acc_classes[config.head_class_idx[0]:config.head_class_idx[1]].mean() * 100 434 | med_acc = acc_classes[config.med_class_idx[0]:config.med_class_idx[1]].mean() * 100 435 | tail_acc = acc_classes[config.tail_class_idx[0]:config.tail_class_idx[1]].mean() * 100 436 | 437 | logger.info('* Acc@1 {top1.avg:.3f}% Acc@5 {top5.avg:.3f}% HAcc {head_acc:.3f}% MAcc {med_acc:.3f}% TAcc {tail_acc:.3f}%.'.format(top1=top1, top5=top5, head_acc=head_acc, med_acc=med_acc, tail_acc=tail_acc)) 438 | 439 | cal = calibration(true_class, pred_class, confidence, num_bins=15) 440 | logger.info('* ECE {ece:.3f}%.'.format(ece=cal['expected_calibration_error'] * 100)) 441 | 442 | return top1.avg, cal['expected_calibration_error'] * 100 443 | 444 | 445 | def save_checkpoint(state, is_best, model_dir): 446 | filename = model_dir + '/current.pth.tar' 447 | torch.save(state, filename) 448 | if is_best: 449 | shutil.copyfile(filename, model_dir + '/model_best.pth.tar') 450 | 451 | 452 | def adjust_learning_rate(optimizer, epoch, config): 453 | """Sets the learning rate""" 454 | lr_min = 0 455 | lr_max = config.lr 456 | lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(epoch / config.num_epochs * 3.1415926535)) 457 | 458 | for idx, param_group in enumerate(optimizer.param_groups): 459 | if idx == 0: 460 | param_group['lr'] = config.lr_factor * lr 461 | else: 462 | param_group['lr'] = 1.00 * lr 463 | 464 | 465 | if __name__ == '__main__': 466 | main() 467 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .logger import _C as config 6 | from .logger import update_config, create_logger 7 | 8 | from .metric import accuracy, calibration 9 | from .meter import AverageMeter, ProgressMeter -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from yacs.config import CfgNode as CN 3 | import os 4 | import time 5 | import logging 6 | 7 | _C = CN() 8 | _C.name = '' 9 | _C.print_freq = 40 10 | _C.workers = 16 11 | _C.log_dir = 'logs' 12 | _C.model_dir = 'ckps' 13 | 14 | 15 | _C.dataset = 'cifar10' 16 | _C.data_path = './data/cifar10' 17 | _C.num_classes = 100 18 | _C.imb_factor = 0.01 19 | _C.backbone = 'resnet32_fe' 20 | _C.resume = '' 21 | _C.head_class_idx = [0, 1] 22 | _C.med_class_idx = [0, 1] 23 | _C.tail_class_idx = [0, 1] 24 | 25 | _C.deterministic = True 26 | _C.gpu = 0 27 | _C.world_size = -1 28 | _C.rank = -1 29 | _C.dist_url = 'tcp://224.66.41.62:23456' 30 | _C.dist_backend = 'nccl' 31 | _C.multiprocessing_distributed = False 32 | _C.distributed = False 33 | 34 | _C.mode = None 35 | _C.smooth_tail = None 36 | _C.smooth_head = None 37 | _C.shift_bn = False 38 | _C.lr_factor = None 39 | _C.lr = 0.1 40 | _C.batch_size = 128 41 | _C.weight_decay = 0.002 42 | _C.num_epochs = 200 43 | _C.momentum = 0.9 44 | _C.cos = False 45 | _C.mixup = True 46 | _C.alpha = 1.0 47 | 48 | def update_config(cfg, args): 49 | cfg.defrost() 50 | 51 | cfg.merge_from_file(args.cfg) 52 | cfg.merge_from_list(args.opts) 53 | 54 | # cfg.freeze() 55 | 56 | def create_logger(cfg, cfg_name): 57 | time_str = time.strftime('%Y%m%d%H%M') 58 | 59 | cfg_name = os.path.basename(cfg_name).split('.')[0] 60 | 61 | log_dir = Path("saved") / (cfg_name + '_' + time_str) / Path(cfg.log_dir) 62 | print('=> creating {}'.format(log_dir)) 63 | log_dir.mkdir(parents=True, exist_ok=True) 64 | 65 | 66 | log_file = '{}.txt'.format(cfg_name) 67 | final_log_file = log_dir / log_file 68 | head = '%(asctime)-15s %(message)s' 69 | logging.basicConfig(filename=str(final_log_file), 70 | format=head) 71 | logger = logging.getLogger() 72 | logger.setLevel(logging.INFO) 73 | console = logging.StreamHandler() 74 | logging.getLogger('').addHandler(console) 75 | 76 | model_dir = Path("saved") / (cfg_name + '_' + time_str) / Path(cfg.model_dir) 77 | print('=> creating {}'.format(model_dir)) 78 | model_dir.mkdir(parents=True, exist_ok=True) 79 | 80 | return logger, str(model_dir) -------------------------------------------------------------------------------- /utils/meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | def __init__(self, name, fmt=':f'): 4 | self.name = name 5 | self.fmt = fmt 6 | self.reset() 7 | 8 | def reset(self): 9 | self.val = 0 10 | self.avg = 0 11 | self.sum = 0 12 | self.count = 0 13 | 14 | def update(self, val, n=1): 15 | self.val = val 16 | self.sum += val * n 17 | self.count += n 18 | self.avg = self.sum / self.count 19 | 20 | def __str__(self): 21 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 22 | return fmtstr.format(**self.__dict__) 23 | 24 | 25 | class ProgressMeter(object): 26 | def __init__(self, num_batches, meters, prefix=""): 27 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 28 | self.meters = meters 29 | self.prefix = prefix 30 | 31 | def display(self, batch, logger): 32 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 33 | entries += [str(meter) for meter in self.meters] 34 | logger.info('\t'.join(entries)) 35 | 36 | def _get_batch_fmtstr(self, num_batches): 37 | num_digits = len(str(num_batches // 1)) 38 | fmt = '{:' + str(num_digits) + 'd}' 39 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def accuracy(output, target, topk=(1,)): 5 | """Computes the accuracy over the k top predictions for the specified values of k""" 6 | with torch.no_grad(): 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res 19 | 20 | 21 | def calibration(true_labels, pred_labels, confidences, num_bins=15): 22 | """Collects predictions into bins used to draw a reliability diagram. 23 | 24 | Arguments: 25 | true_labels: the true labels for the test examples 26 | pred_labels: the predicted labels for the test examples 27 | confidences: the predicted confidences for the test examples 28 | num_bins: number of bins 29 | 30 | The true_labels, pred_labels, confidences arguments must be NumPy arrays; 31 | pred_labels and true_labels may contain numeric or string labels. 32 | 33 | For a multi-class model, the predicted label and confidence should be those 34 | of the highest scoring class. 35 | 36 | Returns a dictionary containing the following NumPy arrays: 37 | accuracies: the average accuracy for each bin 38 | confidences: the average confidence for each bin 39 | counts: the number of examples in each bin 40 | bins: the confidence thresholds for each bin 41 | avg_accuracy: the accuracy over the entire test set 42 | avg_confidence: the average confidence over the entire test set 43 | expected_calibration_error: a weighted average of all calibration gaps 44 | max_calibration_error: the largest calibration gap across all bins 45 | """ 46 | assert(len(confidences) == len(pred_labels)) 47 | assert(len(confidences) == len(true_labels)) 48 | assert(num_bins > 0) 49 | 50 | bin_size = 1.0 / num_bins 51 | bins = np.linspace(0.0, 1.0, num_bins + 1) 52 | indices = np.digitize(confidences, bins, right=True) 53 | 54 | bin_accuracies = np.zeros(num_bins, dtype=np.float) 55 | bin_confidences = np.zeros(num_bins, dtype=np.float) 56 | bin_counts = np.zeros(num_bins, dtype=np.int) 57 | 58 | for b in range(num_bins): 59 | selected = np.where(indices == b + 1)[0] 60 | if len(selected) > 0: 61 | bin_accuracies[b] = np.mean(true_labels[selected] == pred_labels[selected]) 62 | bin_confidences[b] = np.mean(confidences[selected]) 63 | bin_counts[b] = len(selected) 64 | 65 | avg_acc = np.sum(bin_accuracies * bin_counts) / np.sum(bin_counts) 66 | avg_conf = np.sum(bin_confidences * bin_counts) / np.sum(bin_counts) 67 | 68 | gaps = np.abs(bin_accuracies - bin_confidences) 69 | ece = np.sum(gaps * bin_counts) / np.sum(bin_counts) 70 | mce = np.max(gaps) 71 | 72 | return { "accuracies": bin_accuracies, 73 | "confidences": bin_confidences, 74 | "counts": bin_counts, 75 | "bins": bins, 76 | "avg_accuracy": avg_acc, 77 | "avg_confidence": avg_conf, 78 | "expected_calibration_error": ece, 79 | "max_calibration_error": mce } --------------------------------------------------------------------------------