├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE.md
├── README.md
├── README_FixEfficientNet.md
├── hubconf.py
├── image
│   ├── Fix-Efficient-Net.png
│   └── image2.png
├── imnet_evaluate
│   ├── Res.py
│   ├── __init__.py
│   ├── config.py
│   ├── pnasnet.py
│   ├── resnext_wsl.py
│   ├── samplers.py
│   ├── train.py
│   └── transforms.py
├── imnet_extract
│   ├── Res.py
│   ├── __init__.py
│   ├── config.py
│   ├── pnasnet.py
│   ├── resnext_wsl.py
│   ├── samplers.py
│   ├── train.py
│   └── transforms.py
├── imnet_finetune
│   ├── Res.py
│   ├── __init__.py
│   ├── config.py
│   ├── pnasnet.py
│   ├── resnext_wsl.py
│   ├── samplers.py
│   ├── train.py
│   └── transforms.py
├── imnet_resnet50_scratch
│   ├── __init__.py
│   ├── config.py
│   ├── samplers.py
│   ├── train.py
│   └── transforms.py
├── main_evaluate_imnet.py
├── main_evaluate_softmax.py
├── main_extract.py
├── main_finetune.py
├── main_resnet50_scratch.py
├── requirements.txt
├── setup.py
└── transforms_v2.py
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Code of Conduct

Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
Please read the [full text](https://code.fb.com/codeofconduct/)
so that you can understand what actions will and will not be tolerated.
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# Contributing to FixRes
We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing to FixRes, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FixRes



FixRes is a simple method for fixing the train-test resolution discrepancy.
It can improve the performance of any convolutional neural network architecture.

The method is described in "Fixing the train-test resolution discrepancy" (links: [arXiv](https://arxiv.org/abs/1906.06423), [NeurIPS](https://papers.nips.cc/paper/9035-fixing-the-train-test-resolution-discrepancy)).

BibTeX reference to cite, if you use it:
```bibtex
@inproceedings{touvron2019FixRes,
       author = {Touvron, Hugo and Vedaldi, Andrea and Douze, Matthijs and J{\'e}gou, Herv{\'e}},
       title = {Fixing the train-test resolution discrepancy},
       booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
       year = {2019},
}
```

Please note that our models depend on previously trained models; see [References to other models](#references-to-other-models).

# Installation

The FixRes code requires
* Python 3.6 or higher
* PyTorch 1.0 or higher

and the requirements highlighted in [requirements.txt](requirements.txt) (for Anaconda).

# Cluster settings

Our code was executed on a cluster with several GPUs. Since configurations differ from one cluster to another, we provide a generic implementation: you must launch the code on each GPU yourself, specifying job-id, local-rank, global-rank, and num-tasks, which is not very convenient. We therefore strongly recommend adapting our code to the configuration of your cluster.

# Using the code

The configurations given in the examples reproduce the results of the pre-trained networks table (Table 2 in the article).
The training and fine-tuning scripts save the learned model in a checkpoint.pth file.

## Extracting features with pre-trained networks

### Pre-trained networks

We provide pre-trained networks with different trunks; the table reports the validation resolution and the Top-1 / Top-5 accuracy on the ImageNet validation set:

| Models | Resolution | #Parameters | Top-1 / Top-5 | Weights |
|:---:|:-:|:------------:|:------:|:---------------------------------------------------------------------------------------:|
| ResNet-50 Baseline | 224 | 25.6M | 77.0 / 93.4 | [FixResNet50_no_adaptation.pth](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNet_no_adaptation.pth) |
| FixResNet-50 | 384 | 25.6M | 79.0 / 94.6 | [FixResNet50.pth](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNetFinetune.pth) |
| FixResNet-50 (*) | 384 | 25.6M | 79.1 / 94.6 | [FixResNet50_v2.pth](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNet50_v2.pth) |
| FixResNet-50 CutMix | 320 | 25.6M | 79.7 / 94.9 | [FixResNet50CutMix.pth](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNetCutMix.pth) |
| FixResNet-50 CutMix (*) | 320 | 25.6M | 79.8 / 94.9 | [FixResNet50CutMix_v2.pth](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNet50_CutMix_v2.pth) |
| FixPNASNet-5 | 480 | 86.1M | 83.7 / 96.8 | [FixPNASNet.pth](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/PNASNet.pth) |
| FixResNeXt-101 32x48d | 320 | 829M | 86.3 / 97.9 | [FixResNeXt101_32x48d.pth](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNeXt_101_32x48d.pth) |
| FixResNeXt-101 32x48d (*) | 320 | 829M | 86.4 / 98.0 | [FixResNeXt101_32x48d_v2.pth](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNext101_32x48d_v2.pth) |
| FixEfficientNet-B0 (+) | 320 | 5.3M | 80.2 / 95.4 | [FixEfficientNet](README_FixEfficientNet.md) |
| FixEfficientNet-L2 (+) | 600 | 480M | 88.5 / 98.7 | [FixEfficientNet](README_FixEfficientNet.md) |

(*) We use horizontal flip, shifted center crop and color jittering for fine-tuning (described in [transforms_v2.py](transforms_v2.py))

(+) We report different results with our FixEfficientNet (see [FixEfficientNet](README_FixEfficientNet.md) for more details)

To load a network, use the following PyTorch code:

```python
import torch
from .resnext_wsl import resnext101_32x48d_wsl

model = resnext101_32x48d_wsl(progress=True)  # example with the ResNeXt-101 32x48d

pretrained_dict = torch.load('ResNeXt101_32x48d.pth', map_location='cpu')['model']

model_dict = model.state_dict()
for k in model_dict.keys():
    if ('module.' + k) in pretrained_dict.keys():
        model_dict[k] = pretrained_dict.get('module.' + k)
model.load_state_dict(model_dict)
```

The network accepts images at any resolution.
A normalization pre-processing step is applied, with mean `[0.485, 0.456, 0.406]`
and standard deviation `[0.229, 0.224, 0.225]` for ResNet-50 and ResNeXt-101 32x48d;
use mean `[0.5, 0.5, 0.5]` and standard deviation `[0.5, 0.5, 0.5]` with PNASNet.
You can find the code in transforms.py.
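For orientation, here is a minimal sketch of such an evaluation pre-processing with plain torchvision transforms. It only illustrates the statistics quoted above: the resize ratio (the usual ImageNet 256/224 convention) and the example file name are assumptions, so refer to transforms.py for the exact pipeline used in our experiments.

```python
import torch
from PIL import Image
from torchvision import transforms

test_size = 384  # validation resolution from the table above, e.g. FixResNet-50

# Sketch only: transforms.py in this repository is the authoritative version.
eval_transform = transforms.Compose([
    transforms.Resize(int(test_size * 256 / 224)),    # resize ratio is an assumption
    transforms.CenterCrop(test_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ResNet-50 / ResNeXt-101 32x48d
                         std=[0.229, 0.224, 0.225]),  # use 0.5s everywhere for PNASNet
])

img = eval_transform(Image.open('some_image.jpg').convert('RGB')).unsqueeze(0)
with torch.no_grad():
    logits = model(img)  # `model` loaded as in the snippet above
```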
### Features extracted from the ImageNet validation set

We provide the probabilities, embeddings and [labels](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Extracted_Features/labels.npy) of each image in the ImageNet validation set, so that the results can be reproduced easily.

Embedding files are matrices of size 50000 by 2048 for all models except PNASNet, for which the size is 50000 by 4320; embeddings are extracted after the last spatial pooling. The softmax files are matrices of size 50000 by 1000, representing the probability of each class for each image.

| Model | Softmax | Embedding |
|:---:|:------------------------------------------------------------:|:------------------------------------------------------------:|
| FixResNet-50 | [FixResNet50_Softmax.npy](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Extracted_Features/ResNet50_softmax.npy) | [FixResNet50Embedding.npy](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Extracted_Features/ResNet50_embedding.npy) |
| FixResNet-50 (*) | [FixResNet50_Softmax_v2.npy](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Extracted_Features/ResNet50_softmax_v2.npy) | [FixResNet50Embedding_v2.npy](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Extracted_Features/ResNet50_embedding_v2.npy) |
| FixResNet-50 CutMix | [FixResNet50_CutMix_Softmax.npy](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Extracted_Features/ResNet50CutMix_softmax.npy) | [FixResNet50_CutMix_Embedding.npy](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Extracted_Features/ResNet50CutMix_embedding.npy) |
| FixResNet-50 CutMix (*) | [FixResNet50_CutMix_Softmax_v2.npy](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Extracted_Features/ResNet50CutMix_softmax_v2.npy) | [FixResNet50_CutMix_Embedding_v2.npy](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Extracted_Features/ResNet50CutMix_embedding_v2.npy) |
| FixPNASNet-5 | [FixPNASNet_Softmax.npy](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Extracted_Features/PNASNet_softmax.npy) | [FixPNASNet_Embedding.npy](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Extracted_Features/PNASNet_embedding.npy) |
| FixResNeXt-101 32x48d | [FixResNeXt101_32x48d_Softmax.npy](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Extracted_Features/IGAM_Resnext101_32x48d_softmax.npy) | [FixResNeXt101_32x48d_Embedding.npy](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Extracted_Features/IGAM_Resnext101_32x48d_embedding.npy) |
| FixResNeXt-101 32x48d (*) | [FixResNeXt101_32x48d_Softmax_v2.npy](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Extracted_Features/IGAM_Resnext101_32x48d_softmax_v2.npy) | [FixResNeXt101_32x48d_Embedding_v2.npy](https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Extracted_Features/IGAM_Resnext101_32x48d_embedding_v2.npy) |

(*) We use horizontal flip, shifted center crop and color jittering for fine-tuning (described in [transforms_v2.py](transforms_v2.py))
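As a quick illustration of how these files fit together, the sketch below recomputes Top-1 / Top-5 accuracy from a downloaded softmax matrix and the labels file (this is essentially what `main_evaluate_softmax.py`, described below, automates). The file names are the ones linked above; that the rows of both arrays follow the same image order is an assumption this sketch relies on.

```python
import numpy as np

softmax = np.load('ResNet50_softmax.npy')  # (50000, 1000): class probabilities per image
labels = np.load('labels.npy')             # (50000,): ground-truth class indices

# Top-1: the argmax class must match the label.
top1 = (softmax.argmax(axis=1) == labels).mean() * 100.0

# Top-5: the label must be among the five highest-scoring classes.
top5_pred = np.argsort(softmax, axis=1)[:, -5:]
top5 = (top5_pred == labels[:, None]).any(axis=1).mean() * 100.0

print(f'Top-1: {top1:.1f}%  Top-5: {top5:.1f}%')
```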

## Evaluation of the networks

See help (`-h` flag) for the detailed parameter list of each script before executing the code.


### Classification results

`main_evaluate_imnet.py` evaluates the network on standard benchmarks.

`main_evaluate_softmax.py` evaluates the network on ImageNet-val from already-extracted softmax outputs (much faster to execute).

### Example evaluation procedure

```bash
# FixResNeXt-101 32x48d
python main_evaluate_imnet.py --input-size 320 --architecture 'IGAM_Resnext101_32x48d' --weight-path 'ResNext101_32x48d.pth'
# FixResNet-50
python main_evaluate_imnet.py --input-size 384 --architecture 'ResNet50' --weight-path 'ResNet50.pth'

# FixPNASNet-5
python main_evaluate_imnet.py --input-size 480 --architecture 'PNASNet' --weight-path 'PNASNet.pth'
```

The following commands give the results corresponding to Table 2 in the paper:

```bash
# FixResNeXt-101 32x48d
python main_evaluate_softmax.py --architecture 'IGAM_Resnext101_32x48d' --save-path 'where_softmax_and_labels_are_saved'

# FixPNASNet-5
python main_evaluate_softmax.py --architecture 'PNASNet' --save-path 'where_softmax_and_labels_are_saved'

# FixResNet50
python main_evaluate_softmax.py --architecture 'ResNet50' --save-path 'where_softmax_and_labels_are_saved'
```
### Features extraction

`main_extract.py` extracts embeddings, labels and probabilities with the networks.

### Example extraction procedure

```bash
# FixResNeXt-101 32x48d
python main_extract.py --input-size 320 --architecture 'IGAM_Resnext101_32x48d' --weight-path 'ResNeXt101_32x48d.pth' --save-path 'where_output_will_be_saved'
# FixResNet-50
python main_extract.py --input-size 384 --architecture 'ResNet50' --weight-path 'ResNet50.pth' --save-path 'where_output_will_be_saved'

# FixPNASNet-5
python main_extract.py --input-size 480 --architecture 'PNASNet' --weight-path 'PNASNet.pth' --save-path 'where_output_will_be_saved'
```

## Fine-tuning an existing network with our method
See help (`-h` flag) for the detailed parameter list of each script before executing the code.

### Classifier and Batch-norm fine-tuning

`main_finetune.py` fine-tunes the network on standard benchmarks (a minimal sketch of the underlying idea is given at the end of this section).

### Example fine-tuning procedure

```bash
# FixResNeXt-101 32x48d
python main_finetune.py --input-size 320 --architecture 'IGAM_Resnext101_32x48d' --epochs 1 --batch 8 --num-tasks 32 --learning-rate 1e-3

# FixResNet-50
python main_finetune.py --input-size 384 --architecture 'ResNet50' --epochs 56 --batch 64 --num-tasks 8 --learning-rate 1e-3

# FixPNASNet-5
python main_finetune.py --input-size 480 --architecture 'PNASNet' --epochs 1 --batch 64 --num-tasks 8 --learning-rate 1e-4
```
### Using transforms_v2 for fine-tuning
To reproduce our best results, use the data augmentation of transforms_v2 with almost the same parameters as for the classic data augmentation; the only changes are the learning rate, which must be 1e-4, and the number of epochs, which must be 11. For the FixResNet-50 fine-tuning, use 31 epochs and a learning rate of 1e-3; for FixResNet-50 CutMix, use 11 epochs and a learning rate of 1e-3.
Here is how to use transforms_v2:

```python
from torchvision import datasets
from .transforms_v2 import get_transforms

transform = get_transforms(input_size=Train_size, test_size=Test_size, kind='full', crop=True, need=('train', 'val'), backbone=None)
train_set = datasets.ImageFolder(train_path, transform=transform['val_train'])
test_set = datasets.ImageFolder(val_path, transform=transform['val_test'])
```
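To make the classifier and Batch-norm fine-tuning above concrete, here is a minimal sketch: the helper name is ours, and `model.fc` as the classifier attribute is an assumption that holds for the torchvision-style ResNet/ResNeXt used here. `main_finetune.py` remains the reference implementation (distributed training, checkpointing, etc.).

```python
import torch.nn as nn

def prepare_for_fixres_finetune(model: nn.Module):
    """Freeze everything except the final classifier and the BatchNorm layers."""
    for p in model.parameters():
        p.requires_grad = False          # freeze the whole trunk
    trainable = []
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.train()                    # let running statistics adapt to the new resolution
            for p in m.parameters():     # BN affine terms stay trainable in this sketch
                p.requires_grad = True
                trainable.append(p)
    for p in model.fc.parameters():      # final classifier
        p.requires_grad = True
        trainable.append(p)
    return trainable
```

The returned parameters can then be passed to the optimizer, e.g. `torch.optim.SGD(prepare_for_fixres_finetune(model), lr=1e-3)`, matching the learning rates given above.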

## Training

See help (`-h` flag) for the detailed parameter list of each script before executing the code.

### Train ResNet-50 from scratch

`main_resnet50_scratch.py` trains ResNet-50 on standard benchmarks.

### Example training procedure

```bash
# ResNet50
python main_resnet50_scratch.py --batch 64 --num-tasks 8 --learning-rate 2e-2
```

## Contributing
See the [CONTRIBUTING](CONTRIBUTING.md) file for how to help out.

## References to other models

Model definition scripts are based on https://github.com/pytorch/vision/ and https://github.com/Cadene/pretrained-models.pytorch.

The training-from-scratch implementation is based on https://github.com/facebookresearch/multigrain.

Our FixResNet-50 CutMix is fine-tuned from the weights of the GitHub page: https://github.com/clovaai/CutMix-PyTorch.
The corresponding paper is
```
@inproceedings{2019arXivCutMix,
    author = {Sangdoo Yun and Dongyoon Han and Seong Joon Oh and Sanghyuk Chun and Junsuk Choe and Youngjoon Yoo},
    title = "{CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features}",
    journal = {arXiv e-prints},
    year = "2019"}
```

Our FixResNeXt-101 32x48d is fine-tuned from the weights of the PyTorch Hub page: https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/

The corresponding paper is
```
@inproceedings{mahajan2018exploring,
    author = {Mahajan, Dhruv and Girshick, Ross and Ramanathan, Vignesh and He, Kaiming and Paluri, Manohar and Li, Yixuan and Bharambe, Ashwin and van der Maaten, Laurens},
    title = "{Exploring the limits of weakly supervised pretraining}",
    journal = {European Conference on Computer Vision},
    year = "2018"}
```

For FixEfficientNet we used model definition scripts and pretrained weights from https://github.com/rwightman/pytorch-image-models.

The corresponding papers are:

For models with extra-training data:

```
@misc{xie2019selftraining,
    author = {Qizhe Xie and Minh-Thang Luong and Eduard Hovy and Quoc V. Le},
    title = "{Self-training with Noisy Student improves ImageNet classification}",
    journal = {arXiv e-prints},
    year = 2019}
```

For models without extra-training data:

```
@misc{xie2019adversarial,
    author = {Cihang Xie and Mingxing Tan and Boqing Gong and Jiang Wang and Alan Yuille and Quoc V. Le},
    title = "{Adversarial Examples Improve Image Recognition}",
    journal = {arXiv e-prints},
    year = "2019"}
```
## License
FixRes is [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) licensed, as found in the LICENSE file.
--------------------------------------------------------------------------------
/README_FixEfficientNet.md:
--------------------------------------------------------------------------------
# FixEfficientNet


4 | 5 |


[FixRes](https://github.com/facebookresearch/FixRes) is a simple method for fixing the train-test resolution discrepancy.
It improves the performance of any convolutional neural network architecture.
The method is described in the NeurIPS paper "[Fixing the train-test resolution discrepancy](https://papers.nips.cc/paper/9035-fixing-the-train-test-resolution-discrepancy)" ([more results on arXiv](https://arxiv.org/abs/1906.06423)).

Hereafter we provide some results reported in [this note](https://arxiv.org/abs/2003.08237) for EfficientNet models.
These models depend on and improve previously trained models; see the [references to other models](#references-to-other-models).

## ImageNet Results

| Models | Resolution | #Parameters | Top-1 / Top-5 | Extra training data |
|:---:|:-:|:------------:|:------:|:-----:|
| FixEfficientNet-B0 | 320 | 5.3M | 79.3 / 94.6 | |
| FixEfficientNet-B0 | 320 | 5.3M | 80.2 / 95.4 | x |
| FixEfficientNet-B1 | 384 | 7.8M | 81.3 / 95.7 | |
| FixEfficientNet-B1 | 384 | 7.8M | 82.6 / 96.4 | x |
| FixEfficientNet-B2 | 420 | 9.2M | 82.0 / 96.0 | |
| FixEfficientNet-B2 | 420 | 9.2M | 83.6 / 96.9 | x |
| FixEfficientNet-B3 | 472 | 12M | 83.0 / 96.4 | |
| FixEfficientNet-B3 | 472 | 12M | 85.0 / 97.4 | x |
| FixEfficientNet-B4 | 512 | 19M | 84.0 / 97.0 | |
| FixEfficientNet-B4 | 472 | 19M | 85.9 / 97.7 | x |
| FixEfficientNet-B5 | 576 | 30M | 84.7 / 97.2 | |
| FixEfficientNet-B5 | 576 | 30M | 86.4 / 97.9 | x |
| FixEfficientNet-B6 | 576 | 43M | 84.9 / 97.3 | |
| FixEfficientNet-B6 | 680 | 43M | 86.7 / 98.0 | x |
| FixEfficientNet-B7 | 632 | 66M | 85.3 / 97.4 | |
| FixEfficientNet-B7 | 632 | 66M | 87.1 / 98.2 | x |
| FixEfficientNet-B8 | 800 | 87.4M | 85.7 / 97.6 | |
| FixEfficientNet-L2 | 600 | 480M | 88.5 / 98.7 | x |

```bibtex
@inproceedings{touvron2019FixRes,
    author = {Touvron, Hugo and Vedaldi, Andrea and Douze, Matthijs and J{\'e}gou, Herv{\'e}},
    title = {Fixing the train-test resolution discrepancy},
    booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
    year = {2019},
}
```

```
@misc{touvron2020FixEfficientNet,
    author = {Touvron, Hugo and Vedaldi, Andrea and Douze, Matthijs and J{\'e}gou, Herv{\'e}},
    title = {Fixing the train-test resolution discrepancy: FixEfficientNet},
    journal = {arXiv preprint arXiv:2003.08237},
    year = {2020},
}
```

## References to other models

Model definition scripts and pretrained weights are from https://github.com/rwightman/pytorch-image-models.

The corresponding papers are as follows.

For models with extra-training data:

```
@misc{xie2019selftraining,
    author = {Qizhe Xie and Minh-Thang Luong and Eduard Hovy and Quoc V. Le},
    title = "{Self-training with Noisy Student improves ImageNet classification}",
    journal = {arXiv preprint arXiv:1911.04252},
    year = {2019},
}
```

For models without extra-training data:

```
@misc{xie2019adversarial,
    author = {Cihang Xie and Mingxing Tan and Boqing Gong and Jiang Wang and Alan Yuille and Quoc V. Le},
    title = "{Adversarial Examples Improve Image Recognition}",
    journal = {arXiv preprint arXiv:1911.09665},
    year = {2019},
}
```

```
@misc{tan2019efficientnet,
    author = {Mingxing Tan and Quoc V. 
Le}, 86 | title = "{EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks}", 87 | journal = {arXiv preprint arXiv:1905.11946}, 88 | year= "2019", 89 | } 90 | ``` 91 | 92 | ## License 93 | FixRes is [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) licensed, as found in the LICENSE file. 94 | 95 | 96 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from tqdm import tqdm 8 | import torch 9 | import hashlib 10 | import os 11 | import re 12 | import shutil 13 | import sys 14 | import tempfile 15 | 16 | try: 17 | from requests.utils import urlparse 18 | from requests import get as urlopen 19 | requests_available = True 20 | except ImportError: 21 | requests_available = False 22 | if sys.version_info[0] == 2: 23 | from urlparse import urlparse # noqa f811 24 | from urllib2 import urlopen # noqa f811 25 | else: 26 | from urllib.request import urlopen 27 | from urllib.parse import urlparse 28 | 29 | dependencies = ['torch', 'torchvision'] 30 | 31 | from torchvision.models.resnet import ResNet, Bottleneck 32 | 33 | def _download_url_to_file(url, dst, hash_prefix, progress): 34 | r""" 35 | function from https://pytorch.org/docs/stable/model_zoo.html 36 | """ 37 | if requests_available: 38 | u = urlopen(url, stream=True) 39 | file_size = int(u.headers["Content-Length"]) 40 | u = u.raw 41 | else: 42 | u = urlopen(url) 43 | meta = u.info() 44 | if hasattr(meta, 'getheaders'): 45 | file_size = int(meta.getheaders("Content-Length")[0]) 46 | else: 47 | file_size = int(meta.get_all("Content-Length")[0]) 48 | 49 | f = tempfile.NamedTemporaryFile(delete=False) 50 | try: 51 | if hash_prefix is not None: 52 | sha256 = hashlib.sha256() 53 | with tqdm(total=file_size, disable=not progress) as pbar: 54 | while True: 55 | buffer = u.read(8192) 56 | if len(buffer) == 0: 57 | break 58 | f.write(buffer) 59 | if hash_prefix is not None: 60 | sha256.update(buffer) 61 | pbar.update(len(buffer)) 62 | 63 | f.close() 64 | if hash_prefix is not None: 65 | digest = sha256.hexdigest() 66 | if digest[:len(hash_prefix)] != hash_prefix: 67 | raise RuntimeError('invalid hash value (expected "{}", got "{}")' 68 | .format(hash_prefix, digest)) 69 | shutil.move(f.name, dst) 70 | finally: 71 | f.close() 72 | if os.path.exists(f.name): 73 | os.remove(f.name) 74 | 75 | def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True): 76 | r""" 77 | function from https://pytorch.org/docs/stable/model_zoo.html 78 | """ 79 | if model_dir is None: 80 | torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch')) 81 | model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models')) 82 | if not os.path.exists(model_dir): 83 | os.makedirs(model_dir) 84 | parts = urlparse(url) 85 | filename = os.path.basename(parts.path) 86 | cached_file = os.path.join(model_dir, filename) 87 | if not os.path.exists(cached_file): 88 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 89 | hash_prefix = None 90 | _download_url_to_file(url, cached_file, hash_prefix, progress=progress) 91 | return torch.load(cached_file, map_location=map_location) 92 | 93 | 94 | model_urls = { 95 | 'FixResNet50': 
'https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNet50_v2.pth', 96 | 'FixResNet50CutMix': 'https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNet50_CutMix_v2.pth', 97 | 'FixResNeXt101_32x48d': 'https://dl.fbaipublicfiles.com/FixRes_data/FixRes_Pretrained_Models/ResNext101_32x48d_v2.pth', 98 | } 99 | 100 | 101 | def _fixmodel(arch, block, layers, pretrained, progress, **kwargs): 102 | model = ResNet(block, layers, **kwargs) 103 | pretrained_dict = load_state_dict_from_url(model_urls[arch], progress=progress, map_location='cpu')['model'] 104 | model_dict = model.state_dict() 105 | count=0 106 | count2=0 107 | for k in model_dict.keys(): 108 | count=count+1.0 109 | if(('module.'+k) in pretrained_dict.keys()): 110 | count2=count2+1.0 111 | model_dict[k]=pretrained_dict.get(('module.'+k)) 112 | 113 | assert int(count2*100/count)== 100,"model loading error" 114 | 115 | model.load_state_dict(model_dict) 116 | return model 117 | 118 | def fixresnet_50(progress=True, **kwargs): 119 | """Constructs a FixResNet-50 120 | `"Fixing the train-test resolution discrepancy" `_ 121 | Args: 122 | progress (bool): If True, displays a progress bar of the download to stderr. 123 | """ 124 | 125 | return _fixmodel('FixResNet50', Bottleneck, [3, 4, 6, 3], True, progress, **kwargs) 126 | 127 | def fixresnet_50_CutMix(progress=True, **kwargs): 128 | """Constructs a FixRes-50 CutMix 129 | `"Fixing the train-test resolution discrepancy" `_ 130 | Args: 131 | progress (bool): If True, displays a progress bar of the download to stderr. 132 | """ 133 | return _fixmodel('FixResNet50CutMix', Bottleneck, [3, 4, 6, 3], True, progress, **kwargs) 134 | 135 | def fixresnext101_32x48d(progress=True, **kwargs): 136 | """Constructs a FixResNeXt-101 32x48 137 | `"Fixing the train-test resolution discrepancy" `_ 138 | Args: 139 | progress (bool): If True, displays a progress bar of the download to stderr. 140 | """ 141 | kwargs['groups'] = 32 142 | kwargs['width_per_group'] = 48 143 | return _fixmodel('FixResNeXt101_32x48d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 144 | -------------------------------------------------------------------------------- /image/Fix-Efficient-Net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/FixRes/c9be6acc7a6b32f896e62c28a97c20c2348327d3/image/Fix-Efficient-Net.png -------------------------------------------------------------------------------- /image/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/FixRes/c9be6acc7a6b32f896e62c28a97c20c2348327d3/image/image2.png -------------------------------------------------------------------------------- /imnet_evaluate/Res.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | ''' 8 | Code from : https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 9 | ''' 10 | 11 | import torch.nn as nn 12 | try: 13 | from torch.hub import load_state_dict_from_url 14 | except ImportError: 15 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 16 | 17 | 18 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 19 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 20 | 21 | 22 | model_urls = { 23 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 24 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 25 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 26 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 27 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 28 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 29 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 30 | } 31 | 32 | 33 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 34 | """3x3 convolution with padding""" 35 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 36 | padding=dilation, groups=groups, bias=False, dilation=dilation) 37 | 38 | 39 | def conv1x1(in_planes, out_planes, stride=1): 40 | """1x1 convolution""" 41 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 42 | 43 | 44 | class BasicBlock(nn.Module): 45 | expansion = 1 46 | 47 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 48 | base_width=64, dilation=1, norm_layer=None): 49 | super(BasicBlock, self).__init__() 50 | if norm_layer is None: 51 | norm_layer = nn.BatchNorm2d 52 | if groups != 1 or base_width != 64: 53 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 54 | if dilation > 1: 55 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 56 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 57 | self.conv1 = conv3x3(inplanes, planes, stride) 58 | self.bn1 = norm_layer(planes) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.conv2 = conv3x3(planes, planes) 61 | self.bn2 = norm_layer(planes) 62 | self.downsample = downsample 63 | self.stride = stride 64 | 65 | def forward(self, x): 66 | identity = x 67 | 68 | out = self.conv1(x) 69 | out = self.bn1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | 75 | if self.downsample is not None: 76 | identity = self.downsample(x) 77 | 78 | out += identity 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | 84 | class Bottleneck(nn.Module): 85 | expansion = 4 86 | 87 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 88 | base_width=64, dilation=1, norm_layer=None): 89 | super(Bottleneck, self).__init__() 90 | if norm_layer is None: 91 | norm_layer = nn.BatchNorm2d 92 | width = int(planes * (base_width / 64.)) * groups 93 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 94 | self.conv1 = conv1x1(inplanes, width) 95 | self.bn1 = norm_layer(width) 96 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 97 | self.bn2 = norm_layer(width) 98 | self.conv3 = conv1x1(width, planes * self.expansion) 99 | self.bn3 = norm_layer(planes * self.expansion) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.downsample = downsample 102 | self.stride = stride 103 | 104 | 
def forward(self, x): 105 | identity = x 106 | 107 | out = self.conv1(x) 108 | out = self.bn1(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv2(out) 112 | out = self.bn2(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv3(out) 116 | out = self.bn3(out) 117 | 118 | if self.downsample is not None: 119 | identity = self.downsample(x) 120 | 121 | out += identity 122 | out = self.relu(out) 123 | 124 | return out 125 | 126 | 127 | class ResNet(nn.Module): 128 | 129 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 130 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 131 | norm_layer=None): 132 | super(ResNet, self).__init__() 133 | if norm_layer is None: 134 | norm_layer = nn.BatchNorm2d 135 | self._norm_layer = norm_layer 136 | 137 | self.inplanes = 64 138 | self.dilation = 1 139 | if replace_stride_with_dilation is None: 140 | # each element in the tuple indicates if we should replace 141 | # the 2x2 stride with a dilated convolution instead 142 | replace_stride_with_dilation = [False, False, False] 143 | if len(replace_stride_with_dilation) != 3: 144 | raise ValueError("replace_stride_with_dilation should be None " 145 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 146 | self.groups = groups 147 | self.base_width = width_per_group 148 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 149 | bias=False) 150 | self.bn1 = norm_layer(self.inplanes) 151 | self.relu = nn.ReLU(inplace=True) 152 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 153 | self.layer1 = self._make_layer(block, 64, layers[0]) 154 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 155 | dilate=replace_stride_with_dilation[0]) 156 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 157 | dilate=replace_stride_with_dilation[1]) 158 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 159 | dilate=replace_stride_with_dilation[2]) 160 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 161 | self.fc = nn.Linear(512 * block.expansion, num_classes) 162 | 163 | for m in self.modules(): 164 | if isinstance(m, nn.Conv2d): 165 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 166 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 167 | nn.init.constant_(m.weight, 1) 168 | nn.init.constant_(m.bias, 0) 169 | 170 | # Zero-initialize the last BN in each residual branch, 171 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
172 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 173 | if zero_init_residual: 174 | for m in self.modules(): 175 | if isinstance(m, Bottleneck): 176 | nn.init.constant_(m.bn3.weight, 0) 177 | elif isinstance(m, BasicBlock): 178 | nn.init.constant_(m.bn2.weight, 0) 179 | 180 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 181 | norm_layer = self._norm_layer 182 | downsample = None 183 | previous_dilation = self.dilation 184 | if dilate: 185 | self.dilation *= stride 186 | stride = 1 187 | if stride != 1 or self.inplanes != planes * block.expansion: 188 | downsample = nn.Sequential( 189 | conv1x1(self.inplanes, planes * block.expansion, stride), 190 | norm_layer(planes * block.expansion), 191 | ) 192 | 193 | layers = [] 194 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 195 | self.base_width, previous_dilation, norm_layer)) 196 | self.inplanes = planes * block.expansion 197 | for _ in range(1, blocks): 198 | layers.append(block(self.inplanes, planes, groups=self.groups, 199 | base_width=self.base_width, dilation=self.dilation, 200 | norm_layer=norm_layer)) 201 | 202 | return nn.Sequential(*layers) 203 | 204 | def forward(self, x): 205 | x = self.conv1(x) 206 | x = self.bn1(x) 207 | x = self.relu(x) 208 | x = self.maxpool(x) 209 | 210 | x = self.layer1(x) 211 | x = self.layer2(x) 212 | x = self.layer3(x) 213 | x = self.layer4(x) 214 | 215 | x = self.avgpool(x) 216 | x = x.reshape(x.size(0), -1) 217 | x = self.fc(x) 218 | 219 | return x 220 | 221 | 222 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 223 | model = ResNet(block, layers, **kwargs) 224 | if pretrained: 225 | state_dict = load_state_dict_from_url(model_urls[arch], 226 | progress=progress) 227 | model.load_state_dict(state_dict) 228 | return model 229 | 230 | 231 | def resnet18(pretrained=False, progress=True, **kwargs): 232 | """Constructs a ResNet-18 model. 233 | Args: 234 | pretrained (bool): If True, returns a model pre-trained on ImageNet 235 | progress (bool): If True, displays a progress bar of the download to stderr 236 | """ 237 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 238 | **kwargs) 239 | 240 | 241 | def resnet34(pretrained=False, progress=True, **kwargs): 242 | """Constructs a ResNet-34 model. 243 | Args: 244 | pretrained (bool): If True, returns a model pre-trained on ImageNet 245 | progress (bool): If True, displays a progress bar of the download to stderr 246 | """ 247 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 248 | **kwargs) 249 | 250 | 251 | def resnet50(pretrained=False, progress=True, **kwargs): 252 | """Constructs a ResNet-50 model. 253 | Args: 254 | pretrained (bool): If True, returns a model pre-trained on ImageNet 255 | progress (bool): If True, displays a progress bar of the download to stderr 256 | """ 257 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 258 | **kwargs) 259 | 260 | 261 | def resnet101(pretrained=False, progress=True, **kwargs): 262 | """Constructs a ResNet-101 model. 263 | Args: 264 | pretrained (bool): If True, returns a model pre-trained on ImageNet 265 | progress (bool): If True, displays a progress bar of the download to stderr 266 | """ 267 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 268 | **kwargs) 269 | 270 | 271 | def resnet152(pretrained=False, progress=True, **kwargs): 272 | """Constructs a ResNet-152 model. 
273 | Args: 274 | pretrained (bool): If True, returns a model pre-trained on ImageNet 275 | progress (bool): If True, displays a progress bar of the download to stderr 276 | """ 277 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 278 | **kwargs) 279 | 280 | 281 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 282 | """Constructs a ResNeXt-50 32x4d model. 283 | Args: 284 | pretrained (bool): If True, returns a model pre-trained on ImageNet 285 | progress (bool): If True, displays a progress bar of the download to stderr 286 | """ 287 | kwargs['groups'] = 32 288 | kwargs['width_per_group'] = 4 289 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 290 | pretrained, progress, **kwargs) 291 | 292 | 293 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 294 | """Constructs a ResNeXt-101 32x8d model. 295 | Args: 296 | pretrained (bool): If True, returns a model pre-trained on ImageNet 297 | progress (bool): If True, displays a progress bar of the download to stderr 298 | """ 299 | kwargs['groups'] = 32 300 | kwargs['width_per_group'] = 8 301 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 302 | pretrained, progress, **kwargs) 303 | -------------------------------------------------------------------------------- /imnet_evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .train import Trainer 8 | from .config import TrainerConfig, ClusterConfig 9 | -------------------------------------------------------------------------------- /imnet_evaluate/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from typing import NamedTuple 8 | 9 | 10 | class ClusterConfig(NamedTuple): 11 | dist_backend: str 12 | dist_url: str 13 | 14 | 15 | class TrainerConfig(NamedTuple): 16 | data_folder: str 17 | architecture: str 18 | weight_path: str 19 | imnet_path: str 20 | workers: int 21 | input_size: int 22 | batch_per_gpu: int 23 | local_rank: int 24 | global_rank: int 25 | num_tasks: int 26 | job_id: str 27 | save_folder: str 28 | -------------------------------------------------------------------------------- /imnet_evaluate/pnasnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | ''' 8 | Code from https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py 9 | with some adaptations 10 | ''' 11 | 12 | 13 | 14 | from __future__ import print_function, division, absolute_import 15 | from collections import OrderedDict 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.utils.model_zoo as model_zoo 20 | 21 | 22 | pretrained_settings = { 23 | 'pnasnet5large': { 24 | 'imagenet': { 25 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth', 26 | 'input_space': 'RGB', 27 | 'input_size': [3, 331, 331], 28 | 'input_range': [0, 1], 29 | 'mean': [0.5, 0.5, 0.5], 30 | 'std': [0.5, 0.5, 0.5], 31 | 'num_classes': 1000 32 | }, 33 | 'imagenet+background': { 34 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth', 35 | 'input_space': 'RGB', 36 | 'input_size': [3, 331, 331], 37 | 'input_range': [0, 1], 38 | 'mean': [0.5, 0.5, 0.5], 39 | 'std': [0.5, 0.5, 0.5], 40 | 'num_classes': 1001 41 | } 42 | } 43 | } 44 | 45 | 46 | class MaxPool(nn.Module): 47 | 48 | def __init__(self, kernel_size, stride=1, padding=1, zero_pad=False): 49 | super(MaxPool, self).__init__() 50 | self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None 51 | self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding) 52 | 53 | def forward(self, x): 54 | if self.zero_pad: 55 | x = self.zero_pad(x) 56 | x = self.pool(x) 57 | if self.zero_pad: 58 | x = x[:, :, 1:, 1:] 59 | return x 60 | 61 | 62 | class SeparableConv2d(nn.Module): 63 | 64 | def __init__(self, in_channels, out_channels, dw_kernel_size, dw_stride, 65 | dw_padding): 66 | super(SeparableConv2d, self).__init__() 67 | self.depthwise_conv2d = nn.Conv2d(in_channels, in_channels, 68 | kernel_size=dw_kernel_size, 69 | stride=dw_stride, padding=dw_padding, 70 | groups=in_channels, bias=False) 71 | self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 72 | kernel_size=1, bias=False) 73 | 74 | def forward(self, x): 75 | x = self.depthwise_conv2d(x) 76 | x = self.pointwise_conv2d(x) 77 | return x 78 | 79 | 80 | class BranchSeparables(nn.Module): 81 | 82 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 83 | stem_cell=False, zero_pad=False): 84 | super(BranchSeparables, self).__init__() 85 | padding = kernel_size // 2 86 | middle_channels = out_channels if stem_cell else in_channels 87 | self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None 88 | self.relu_1 = nn.ReLU() 89 | self.separable_1 = SeparableConv2d(in_channels, middle_channels, 90 | kernel_size, dw_stride=stride, 91 | dw_padding=padding) 92 | self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001) 93 | self.relu_2 = nn.ReLU() 94 | self.separable_2 = SeparableConv2d(middle_channels, out_channels, 95 | kernel_size, dw_stride=1, 96 | dw_padding=padding) 97 | self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001) 98 | 99 | def forward(self, x): 100 | x = self.relu_1(x) 101 | if self.zero_pad: 102 | x = self.zero_pad(x) 103 | x = self.separable_1(x) 104 | if self.zero_pad: 105 | x = x[:, :, 1:, 1:].contiguous() 106 | x = self.bn_sep_1(x) 107 | x = self.relu_2(x) 108 | x = self.separable_2(x) 109 | x = self.bn_sep_2(x) 110 | return x 111 | 112 | 113 | class ReluConvBn(nn.Module): 114 | 115 | def __init__(self, in_channels, out_channels, kernel_size, stride=1): 116 | super(ReluConvBn, self).__init__() 117 | self.relu = nn.ReLU() 118 | self.conv = nn.Conv2d(in_channels, out_channels, 119 | kernel_size=kernel_size, stride=stride, 120 | 
bias=False) 121 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 122 | 123 | def forward(self, x): 124 | x = self.relu(x) 125 | x = self.conv(x) 126 | x = self.bn(x) 127 | return x 128 | 129 | 130 | class FactorizedReduction(nn.Module): 131 | 132 | def __init__(self, in_channels, out_channels): 133 | super(FactorizedReduction, self).__init__() 134 | self.relu = nn.ReLU() 135 | self.path_1 = nn.Sequential(OrderedDict([ 136 | ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), 137 | ('conv', nn.Conv2d(in_channels, out_channels // 2, 138 | kernel_size=1, bias=False)), 139 | ])) 140 | self.path_2 = nn.Sequential(OrderedDict([ 141 | ('pad', nn.ZeroPad2d((0, 1, 0, 1))), 142 | ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), 143 | ('conv', nn.Conv2d(in_channels, out_channels // 2, 144 | kernel_size=1, bias=False)), 145 | ])) 146 | self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001) 147 | 148 | def forward(self, x): 149 | x = self.relu(x) 150 | 151 | x_path1 = self.path_1(x) 152 | 153 | x_path2 = self.path_2.pad(x) 154 | x_path2 = x_path2[:, :, 1:, 1:] 155 | x_path2 = self.path_2.avgpool(x_path2) 156 | x_path2 = self.path_2.conv(x_path2) 157 | 158 | out = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) 159 | return out 160 | 161 | 162 | class CellBase(nn.Module): 163 | 164 | def cell_forward(self, x_left, x_right): 165 | x_comb_iter_0_left = self.comb_iter_0_left(x_left) 166 | x_comb_iter_0_right = self.comb_iter_0_right(x_left) 167 | x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right 168 | 169 | x_comb_iter_1_left = self.comb_iter_1_left(x_right) 170 | x_comb_iter_1_right = self.comb_iter_1_right(x_right) 171 | x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right 172 | 173 | x_comb_iter_2_left = self.comb_iter_2_left(x_right) 174 | x_comb_iter_2_right = self.comb_iter_2_right(x_right) 175 | x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right 176 | 177 | x_comb_iter_3_left = self.comb_iter_3_left(x_comb_iter_2) 178 | x_comb_iter_3_right = self.comb_iter_3_right(x_right) 179 | x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right 180 | 181 | x_comb_iter_4_left = self.comb_iter_4_left(x_left) 182 | if self.comb_iter_4_right: 183 | x_comb_iter_4_right = self.comb_iter_4_right(x_right) 184 | else: 185 | x_comb_iter_4_right = x_right 186 | x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right 187 | 188 | x_out = torch.cat( 189 | [x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, 190 | x_comb_iter_4], 1) 191 | return x_out 192 | 193 | 194 | class CellStem0(CellBase): 195 | 196 | def __init__(self, in_channels_left, out_channels_left, in_channels_right, 197 | out_channels_right): 198 | super(CellStem0, self).__init__() 199 | self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right, 200 | kernel_size=1) 201 | self.comb_iter_0_left = BranchSeparables(in_channels_left, 202 | out_channels_left, 203 | kernel_size=5, stride=2, 204 | stem_cell=True) 205 | self.comb_iter_0_right = nn.Sequential(OrderedDict([ 206 | ('max_pool', MaxPool(3, stride=2)), 207 | ('conv', nn.Conv2d(in_channels_left, out_channels_left, 208 | kernel_size=1, bias=False)), 209 | ('bn', nn.BatchNorm2d(out_channels_left, eps=0.001)), 210 | ])) 211 | self.comb_iter_1_left = BranchSeparables(out_channels_right, 212 | out_channels_right, 213 | kernel_size=7, stride=2) 214 | self.comb_iter_1_right = MaxPool(3, stride=2) 215 | self.comb_iter_2_left = BranchSeparables(out_channels_right, 216 | out_channels_right, 217 | kernel_size=5, stride=2) 218 | 
self.comb_iter_2_right = BranchSeparables(out_channels_right, 219 | out_channels_right, 220 | kernel_size=3, stride=2) 221 | self.comb_iter_3_left = BranchSeparables(out_channels_right, 222 | out_channels_right, 223 | kernel_size=3) 224 | self.comb_iter_3_right = MaxPool(3, stride=2) 225 | self.comb_iter_4_left = BranchSeparables(in_channels_right, 226 | out_channels_right, 227 | kernel_size=3, stride=2, 228 | stem_cell=True) 229 | self.comb_iter_4_right = ReluConvBn(out_channels_right, 230 | out_channels_right, 231 | kernel_size=1, stride=2) 232 | 233 | def forward(self, x_left): 234 | x_right = self.conv_1x1(x_left) 235 | x_out = self.cell_forward(x_left, x_right) 236 | return x_out 237 | 238 | 239 | class Cell(CellBase): 240 | 241 | def __init__(self, in_channels_left, out_channels_left, in_channels_right, 242 | out_channels_right, is_reduction=False, zero_pad=False, 243 | match_prev_layer_dimensions=False): 244 | super(Cell, self).__init__() 245 | 246 | # If `is_reduction` is set to `True` stride 2 is used for 247 | # convolutional and pooling layers to reduce the spatial size of 248 | # the output of a cell approximately by a factor of 2. 249 | stride = 2 if is_reduction else 1 250 | 251 | # If `match_prev_layer_dimensions` is set to `True` 252 | # `FactorizedReduction` is used to reduce the spatial size 253 | # of the left input of a cell approximately by a factor of 2. 254 | self.match_prev_layer_dimensions = match_prev_layer_dimensions 255 | if match_prev_layer_dimensions: 256 | self.conv_prev_1x1 = FactorizedReduction(in_channels_left, 257 | out_channels_left) 258 | else: 259 | self.conv_prev_1x1 = ReluConvBn(in_channels_left, 260 | out_channels_left, kernel_size=1) 261 | 262 | self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right, 263 | kernel_size=1) 264 | self.comb_iter_0_left = BranchSeparables(out_channels_left, 265 | out_channels_left, 266 | kernel_size=5, stride=stride, 267 | zero_pad=zero_pad) 268 | self.comb_iter_0_right = MaxPool(3, stride=stride, zero_pad=zero_pad) 269 | self.comb_iter_1_left = BranchSeparables(out_channels_right, 270 | out_channels_right, 271 | kernel_size=7, stride=stride, 272 | zero_pad=zero_pad) 273 | self.comb_iter_1_right = MaxPool(3, stride=stride, zero_pad=zero_pad) 274 | self.comb_iter_2_left = BranchSeparables(out_channels_right, 275 | out_channels_right, 276 | kernel_size=5, stride=stride, 277 | zero_pad=zero_pad) 278 | self.comb_iter_2_right = BranchSeparables(out_channels_right, 279 | out_channels_right, 280 | kernel_size=3, stride=stride, 281 | zero_pad=zero_pad) 282 | self.comb_iter_3_left = BranchSeparables(out_channels_right, 283 | out_channels_right, 284 | kernel_size=3) 285 | self.comb_iter_3_right = MaxPool(3, stride=stride, zero_pad=zero_pad) 286 | self.comb_iter_4_left = BranchSeparables(out_channels_left, 287 | out_channels_left, 288 | kernel_size=3, stride=stride, 289 | zero_pad=zero_pad) 290 | if is_reduction: 291 | self.comb_iter_4_right = ReluConvBn(out_channels_right, 292 | out_channels_right, 293 | kernel_size=1, stride=stride) 294 | else: 295 | self.comb_iter_4_right = None 296 | 297 | def forward(self, x_left, x_right): 298 | x_left = self.conv_prev_1x1(x_left) 299 | x_right = self.conv_1x1(x_right) 300 | x_out = self.cell_forward(x_left, x_right) 301 | return x_out 302 | 303 | 304 | class PNASNet5Large(nn.Module): 305 | def __init__(self, num_classes=1001): 306 | super(PNASNet5Large, self).__init__() 307 | self.num_classes = num_classes 308 | self.conv_0 = nn.Sequential(OrderedDict([ 309 | ('conv', nn.Conv2d(3, 
96, kernel_size=3, stride=2, bias=False)), 310 | ('bn', nn.BatchNorm2d(96, eps=0.001)) 311 | ])) 312 | self.cell_stem_0 = CellStem0(in_channels_left=96, out_channels_left=54, 313 | in_channels_right=96, 314 | out_channels_right=54) 315 | self.cell_stem_1 = Cell(in_channels_left=96, out_channels_left=108, 316 | in_channels_right=270, out_channels_right=108, 317 | match_prev_layer_dimensions=True, 318 | is_reduction=True) 319 | self.cell_0 = Cell(in_channels_left=270, out_channels_left=216, 320 | in_channels_right=540, out_channels_right=216, 321 | match_prev_layer_dimensions=True) 322 | self.cell_1 = Cell(in_channels_left=540, out_channels_left=216, 323 | in_channels_right=1080, out_channels_right=216) 324 | self.cell_2 = Cell(in_channels_left=1080, out_channels_left=216, 325 | in_channels_right=1080, out_channels_right=216) 326 | self.cell_3 = Cell(in_channels_left=1080, out_channels_left=216, 327 | in_channels_right=1080, out_channels_right=216) 328 | self.cell_4 = Cell(in_channels_left=1080, out_channels_left=432, 329 | in_channels_right=1080, out_channels_right=432, 330 | is_reduction=True, zero_pad=True) 331 | self.cell_5 = Cell(in_channels_left=1080, out_channels_left=432, 332 | in_channels_right=2160, out_channels_right=432, 333 | match_prev_layer_dimensions=True) 334 | self.cell_6 = Cell(in_channels_left=2160, out_channels_left=432, 335 | in_channels_right=2160, out_channels_right=432) 336 | self.cell_7 = Cell(in_channels_left=2160, out_channels_left=432, 337 | in_channels_right=2160, out_channels_right=432) 338 | self.cell_8 = Cell(in_channels_left=2160, out_channels_left=864, 339 | in_channels_right=2160, out_channels_right=864, 340 | is_reduction=True) 341 | self.cell_9 = Cell(in_channels_left=2160, out_channels_left=864, 342 | in_channels_right=4320, out_channels_right=864, 343 | match_prev_layer_dimensions=True) 344 | self.cell_10 = Cell(in_channels_left=4320, out_channels_left=864, 345 | in_channels_right=4320, out_channels_right=864) 346 | self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864, 347 | in_channels_right=4320, out_channels_right=864) 348 | self.relu = nn.ReLU() 349 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 350 | self.dropout = nn.Dropout(0.5) 351 | self.last_linear = nn.Linear(4320, num_classes) 352 | 353 | def features(self, x): 354 | x_conv_0 = self.conv_0(x) 355 | x_stem_0 = self.cell_stem_0(x_conv_0) 356 | x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0) 357 | x_cell_0 = self.cell_0(x_stem_0, x_stem_1) 358 | x_cell_1 = self.cell_1(x_stem_1, x_cell_0) 359 | x_cell_2 = self.cell_2(x_cell_0, x_cell_1) 360 | x_cell_3 = self.cell_3(x_cell_1, x_cell_2) 361 | x_cell_4 = self.cell_4(x_cell_2, x_cell_3) 362 | x_cell_5 = self.cell_5(x_cell_3, x_cell_4) 363 | x_cell_6 = self.cell_6(x_cell_4, x_cell_5) 364 | x_cell_7 = self.cell_7(x_cell_5, x_cell_6) 365 | x_cell_8 = self.cell_8(x_cell_6, x_cell_7) 366 | x_cell_9 = self.cell_9(x_cell_7, x_cell_8) 367 | x_cell_10 = self.cell_10(x_cell_8, x_cell_9) 368 | x_cell_11 = self.cell_11(x_cell_9, x_cell_10) 369 | return x_cell_11 370 | 371 | def logits(self, features): 372 | x = self.relu(features) 373 | x = self.avg_pool(x) 374 | x = x.view(x.size(0), -1) 375 | x = self.dropout(x) 376 | x = self.last_linear(x) 377 | return x 378 | 379 | def forward(self, input): 380 | x = self.features(input) 381 | x = self.logits(x) 382 | return x 383 | 384 | 385 | def pnasnet5large(num_classes=1000, pretrained='imagenet'): 386 | r"""PNASNet-5 model architecture from the 387 | `"Progressive Neural Architecture Search" 388 | `_ 
paper. 389 | """ 390 | if pretrained: 391 | settings = pretrained_settings['pnasnet5large'][pretrained] 392 | assert num_classes == settings[ 393 | 'num_classes'], 'num_classes should be {}, but is {}'.format( 394 | settings['num_classes'], num_classes) 395 | 396 | # both 'imagenet'&'imagenet+background' are loaded from same parameters 397 | model = PNASNet5Large(num_classes=1001) 398 | model.load_state_dict(model_zoo.load_url(settings['url'])) 399 | 400 | if pretrained == 'imagenet': 401 | new_last_linear = nn.Linear(model.last_linear.in_features, 1000) 402 | new_last_linear.weight.data = model.last_linear.weight.data[1:] 403 | new_last_linear.bias.data = model.last_linear.bias.data[1:] 404 | model.last_linear = new_last_linear 405 | 406 | model.input_space = settings['input_space'] 407 | model.input_size = settings['input_size'] 408 | model.input_range = settings['input_range'] 409 | 410 | model.mean = settings['mean'] 411 | model.std = settings['std'] 412 | else: 413 | model = PNASNet5Large(num_classes=num_classes) 414 | return model 415 | -------------------------------------------------------------------------------- /imnet_evaluate/resnext_wsl.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # Optional list of dependencies required by the package 9 | 10 | ''' 11 | Code From : https://github.com/facebookresearch/WSL-Images/blob/master/hubconf.py 12 | ''' 13 | dependencies = ['torch', 'torchvision'] 14 | 15 | try: 16 | from torch.hub import load_state_dict_from_url 17 | except ImportError: 18 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 19 | 20 | from .Res import ResNet, Bottleneck 21 | 22 | 23 | model_urls = { 24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth', 25 | 'resnext101_32x16d': 'https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth', 26 | 'resnext101_32x32d': 'https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth', 27 | 'resnext101_32x48d': 'https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth', 28 | } 29 | 30 | 31 | def _resnext(arch, block, layers, pretrained, progress, **kwargs): 32 | model = ResNet(block, layers, **kwargs) 33 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 34 | model.load_state_dict(state_dict) 35 | return model 36 | 37 | 38 | def resnext101_32x8d_wsl(progress=True, **kwargs): 39 | """Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data 40 | and finetuned on ImageNet from Figure 5 in 41 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 42 | Args: 43 | progress (bool): If True, displays a progress bar of the download to stderr. 44 | """ 45 | kwargs['groups'] = 32 46 | kwargs['width_per_group'] = 8 47 | return _resnext('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 48 | 49 | 50 | def resnext101_32x16d_wsl(progress=True, **kwargs): 51 | """Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data 52 | and finetuned on ImageNet from Figure 5 in 53 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 54 | Args: 55 | progress (bool): If True, displays a progress bar of the download to stderr. 
56 | """ 57 | kwargs['groups'] = 32 58 | kwargs['width_per_group'] = 16 59 | return _resnext('resnext101_32x16d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 60 | 61 | 62 | def resnext101_32x32d_wsl(progress=True, **kwargs): 63 | """Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data 64 | and finetuned on ImageNet from Figure 5 in 65 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 66 | Args: 67 | progress (bool): If True, displays a progress bar of the download to stderr. 68 | """ 69 | kwargs['groups'] = 32 70 | kwargs['width_per_group'] = 32 71 | return _resnext('resnext101_32x32d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 72 | 73 | 74 | def resnext101_32x48d_wsl(progress=True, **kwargs): 75 | """Constructs a ResNeXt-101 32x48 model pre-trained on weakly-supervised data 76 | and finetuned on ImageNet from Figure 5 in 77 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 78 | Args: 79 | progress (bool): If True, displays a progress bar of the download to stderr. 80 | """ 81 | kwargs['groups'] = 32 82 | kwargs['width_per_group'] = 48 83 | return _resnext('resnext101_32x48d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 84 | -------------------------------------------------------------------------------- /imnet_evaluate/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from torch.utils.data.sampler import BatchSampler 8 | import torch 9 | import numpy as np 10 | from torch.utils.data.dataloader import default_collate 11 | from collections.abc import Mapping, Sequence 12 | import math 13 | import torch.distributed as dist 14 | 15 | class RASampler(torch.utils.data.Sampler): 16 | """ 17 | Batch Sampler with Repeated Augmentations (RA) 18 | - dataset_len: original length of the dataset 19 | - batch_size 20 | - repetitions: instances per image 21 | - len_factor: multiplicative factor for epoch size 22 | """ 23 | 24 | def __init__(self,dataset,num_replicas, rank, dataset_len, batch_size, repetitions=1, len_factor=1.0, shuffle=False, drop_last=False): 25 | self.dataset=dataset 26 | self.dataset_len = dataset_len 27 | self.batch_size = batch_size 28 | self.repetitions = repetitions 29 | self.len_images = int(dataset_len * len_factor) 30 | self.shuffle = shuffle 31 | self.drop_last = drop_last 32 | if num_replicas is None: 33 | if not dist.is_available(): 34 | raise RuntimeError("Requires distributed package to be available") 35 | num_replicas = dist.get_world_size() 36 | if rank is None: 37 | if not dist.is_available(): 38 | raise RuntimeError("Requires distributed package to be available") 39 | rank = dist.get_rank() 40 | self.dataset = dataset 41 | self.num_replicas = num_replicas 42 | self.rank = rank 43 | self.epoch = 0 44 | self.num_samples = int(math.ceil(len(self.dataset) * self.repetitions * 1.0 / self.num_replicas)) 45 | self.total_size = self.num_samples * self.num_replicas 46 | 47 | 48 | def shuffler(self): 49 | if self.shuffle: 50 | new_perm = lambda: iter(np.random.permutation(self.dataset_len)) 51 | else: 52 | new_perm = lambda: iter(np.arange(self.dataset_len)) 53 | shuffle = new_perm() 54 | while True: 55 | try: 56 | index = next(shuffle) 57 | except StopIteration: 58 | shuffle = new_perm() 59 | index = next(shuffle) 60 | for repetition in 
range(self.repetitions): 61 | yield index 62 | 63 | def __iter__(self): 64 | shuffle = iter(self.shuffler()) 65 | seen = 0 66 | indices=[] 67 | for _ in range(self.len_images): 68 | index = next(shuffle) 69 | indices.append(index) 70 | indices += indices[:(self.total_size - len(indices))] 71 | assert len(indices) == self.total_size 72 | # subsample 73 | indices = indices[self.rank:self.total_size:self.num_replicas] 74 | assert len(indices) == self.num_samples 75 | 76 | return iter(indices) 77 | 78 | 79 | def __len__(self): 80 | return self.num_samples 81 | 82 | def set_epoch(self, epoch): 83 | self.epoch = epoch 84 | 85 | def list_collate(batch): 86 | """ 87 | Collate into a list instead of a tensor to deal with variable-sized inputs 88 | """ 89 | elem_type = type(batch[0]) 90 | if isinstance(batch[0], torch.Tensor): 91 | return batch 92 | elif elem_type.__module__ == 'numpy': 93 | if elem_type.__name__ == 'ndarray': 94 | return list_collate([torch.from_numpy(b) for b in batch]) 95 | elif isinstance(batch[0], Mapping): 96 | return {key: list_collate([d[key] for d in batch]) for key in batch[0]} 97 | elif isinstance(batch[0], Sequence): 98 | transposed = zip(*batch) 99 | return [list_collate(samples) for samples in transposed] 100 | return default_collate(batch) 101 | -------------------------------------------------------------------------------- /imnet_evaluate/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import os 8 | import os.path as osp 9 | from typing import Optional 10 | import torch 11 | import torch.distributed 12 | import torch.nn as nn 13 | import attr 14 | from torchvision import datasets 15 | import tqdm 16 | import torchvision.models as models 17 | import numpy as np 18 | from .config import TrainerConfig, ClusterConfig 19 | from .transforms import get_transforms 20 | from .resnext_wsl import resnext101_32x48d_wsl 21 | from collections import defaultdict 22 | from .pnasnet import pnasnet5large 23 | 24 | 25 | def accuracy_sp(output, target, topk=(1,)): 26 | """Computes the precision@k for the specified values of k""" 27 | maxk = max(topk) 28 | batch_size = target.size(0) 29 | 30 | _, pred = output.topk(maxk, 1, True, True) 31 | pred = pred.t() 32 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 33 | 34 | res = [] 35 | for k in topk: 36 | correct_k = correct[:k].view(-1).float().sum(0).item() 37 | res.append(correct_k * (100.0 / batch_size)) 38 | return res 39 | 40 | class HistoryMeter(object): 41 | """Remember all values""" 42 | def __init__(self): 43 | self.reset() 44 | 45 | def reset(self): 46 | self.hist = [] 47 | self.partials = [] 48 | self.count = 0 49 | self.val = 0 50 | 51 | def update(self, x, n=1): 52 | self.val = x 53 | self.hist.append(x) 54 | x = n * x 55 | self.count += n 56 | # full precision summation based on http://code.activestate.com/recipes/393090/ 57 | i = 0 58 | for y in self.partials: 59 | if abs(x) < abs(y): 60 | x, y = y, x 61 | hi = x + y 62 | lo = y - (hi - x) 63 | if lo: 64 | self.partials[i] = lo 65 | i += 1 66 | x = hi 67 | self.partials[i:] = [x] 68 | 69 | @property 70 | def avg(self): 71 | """ 72 | Alternative to AverageMeter without floating point errors 73 | """ 74 | return sum(self.partials, 0.0) / self.count if self.partials else 0 75 | 76 | 77 | 78 | 79 | 
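# Illustrative sketch: a hypothetical helper (not part of the trainer) showing how
# accuracy_sp and HistoryMeter above work together. accuracy_sp returns per-batch
# percentages for each requested k, and HistoryMeter accumulates them with
# compensated summation so that .avg carries no floating-point drift over many batches.
def _accuracy_meter_demo():
    # Three samples, two classes: argmax predictions are [1, 0, 1] against labels
    # [1, 0, 0], so top-1 accuracy is 2/3 of the batch, i.e. ~66.67 percent.
    logits = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
    labels = torch.tensor([1, 0, 0])
    top1 = accuracy_sp(logits, labels, topk=(1,))[0]
    meter = HistoryMeter()
    meter.update(top1, n=labels.size(0))  # weight the update by batch size
    return meter.avg                      # ~66.67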
@attr.s(auto_attribs=True) 80 | class TrainerState: 81 | """ 82 | Contains the state of the Trainer. 83 | It can be saved to checkpoint the training and loaded to resume it. 84 | """ 85 | 86 | model: nn.Module 87 | 88 | def save(self, filename: str) -> None: 89 | data = attr.asdict(self) 90 | # store only the state dict 91 | data["model"] = self.model.state_dict() 92 | 93 | torch.save(data, filename) 94 | 95 | @classmethod 96 | def load(cls, filename: str, default: "TrainerState") -> "TrainerState": 97 | data = torch.load(filename) 98 | # We need this default to load the state dict 99 | model = default.model 100 | model.load_state_dict(data["model"]) 101 | data["model"] = model 102 | 103 | return cls(**data) 104 | 105 | 106 | class Trainer: 107 | def __init__(self, train_cfg: TrainerConfig, cluster_cfg: ClusterConfig) -> None: 108 | self._train_cfg = train_cfg 109 | self._cluster_cfg = cluster_cfg 110 | 111 | def __call__(self) -> Optional[float]: 112 | """ 113 | Called for each task. 114 | 115 | :return: The master task return the final accuracy of the model. 116 | """ 117 | self._setup_process_group() 118 | self._init_state() 119 | final_acc = self._train() 120 | return final_acc 121 | 122 | def checkpoint(self, rm_init=True): 123 | 124 | save_dir = osp.join(self._train_cfg.save_folder, str(self._train_cfg.job_id)) 125 | os.makedirs(save_dir, exist_ok=True) 126 | self._state.save(osp.join(save_dir, "checkpoint.pth")) 127 | self._state.save(osp.join(save_dir, "checkpoint_"+str(self._state.epoch)+".pth")) 128 | if rm_init: 129 | os.remove(self._cluster_cfg.dist_url[7:]) 130 | empty_trainer = Trainer(self._train_cfg, self._cluster_cfg) 131 | return empty_trainer 132 | 133 | def _setup_process_group(self) -> None: 134 | 135 | torch.cuda.set_device(self._train_cfg.local_rank) 136 | torch.distributed.init_process_group( 137 | backend=self._cluster_cfg.dist_backend, 138 | init_method=self._cluster_cfg.dist_url, 139 | world_size=self._train_cfg.num_tasks, 140 | rank=self._train_cfg.global_rank, 141 | ) 142 | print(f"Process group: {self._train_cfg.num_tasks} tasks, rank: {self._train_cfg.global_rank}") 143 | 144 | def _init_state(self) -> None: 145 | """ 146 | Initialize the state and load it from an existing checkpoint if any 147 | """ 148 | torch.manual_seed(0) 149 | np.random.seed(0) 150 | 151 | print("Create data loaders", flush=True) 152 | print("Input size : "+str(self._train_cfg.input_size)) 153 | print("Model : " + str(self._train_cfg.architecture) ) 154 | backbone_architecture=None 155 | if self._train_cfg.architecture=='PNASNet' : 156 | backbone_architecture='pnasnet5large' 157 | 158 | 159 | transformation=get_transforms(input_size=self._train_cfg.input_size,test_size=self._train_cfg.input_size, kind='full', crop=True, need=('train', 'val'), backbone=backbone_architecture) 160 | transform_test = transformation['val'] 161 | 162 | test_set = datasets.ImageFolder(self._train_cfg.imnet_path + '/val',transform=transform_test) 163 | 164 | self._test_loader = torch.utils.data.DataLoader( 165 | test_set, batch_size=self._train_cfg.batch_per_gpu, shuffle=False, num_workers=(self._train_cfg.workers-1), 166 | ) 167 | 168 | 169 | print("Create distributed model", flush=True) 170 | 171 | if self._train_cfg.architecture=='PNASNet' : 172 | model= pnasnet5large(pretrained='imagenet') 173 | 174 | if self._train_cfg.architecture=='ResNet50' : 175 | model=models.resnet50(pretrained=False) 176 | 177 | if self._train_cfg.architecture=='IGAM_Resnext101_32x48d' : 178 | 
model=resnext101_32x48d_wsl(progress=True) 179 | 180 | pretrained_dict=torch.load(self._train_cfg.weight_path,map_location='cpu')['model'] 181 | model_dict = model.state_dict() 182 | count=0 183 | count2=0 184 | for k in model_dict.keys(): 185 | count=count+1.0 186 | if(('module.'+k) in pretrained_dict.keys()): 187 | count2=count2+1.0 188 | model_dict[k]=pretrained_dict.get(('module.'+k)) 189 | model.load_state_dict(model_dict) 190 | print("load "+str(count2*100/count)+" %") 191 | 192 | assert int(count2*100/count)== 100,"model loading error" 193 | 194 | for name, child in model.named_children(): 195 | for name2, params in child.named_parameters(): 196 | params.requires_grad = False 197 | 198 | print('model_load') 199 | if torch.cuda.is_available(): 200 | model.cuda(self._train_cfg.local_rank) 201 | model = torch.nn.parallel.DistributedDataParallel( 202 | model, device_ids=[self._train_cfg.local_rank], output_device=self._train_cfg.local_rank 203 | ) 204 | 205 | self._state = TrainerState( 206 | model=model 207 | ) 208 | checkpoint_fn = osp.join(self._train_cfg.save_folder, str(self._train_cfg.job_id), "checkpoint.pth") 209 | if os.path.isfile(checkpoint_fn): 210 | print(f"Load existing checkpoint from {checkpoint_fn}", flush=True) 211 | self._state = TrainerState.load(checkpoint_fn, default=self._state) 212 | 213 | def _train(self) -> Optional[float]: 214 | 215 | self._state.model.eval() 216 | metrics = defaultdict(HistoryMeter) 217 | with torch.no_grad(): 218 | for data in tqdm.tqdm(self._test_loader): 219 | images, labels = data 220 | images = images.cuda(self._train_cfg.local_rank, non_blocking=True) 221 | labels = labels.cuda(self._train_cfg.local_rank, non_blocking=True) 222 | outputs = self._state.model(images) 223 | top1, top5 = accuracy_sp(outputs, labels, topk=(1, 5)) 224 | metrics["val_top1"].update(top1, n=images.size(0)) 225 | metrics["val_top5"].update(top5, n=images.size(0)) 226 | for k in metrics: metrics[k] = metrics[k].avg 227 | print(metrics) 228 | return metrics['val_top1'],metrics['val_top5'] 229 | 230 | 231 | 232 | -------------------------------------------------------------------------------- /imnet_evaluate/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | import torch 8 | import torchvision.transforms.functional as F 9 | from torchvision import transforms 10 | 11 | 12 | class Resize(transforms.Resize): 13 | """ 14 | Resize with a ``largest=False'' argument 15 | allowing to resize to a common largest side without cropping 16 | """ 17 | 18 | 19 | def __init__(self, size, largest=False, **kwargs): 20 | super().__init__(size, **kwargs) 21 | self.largest = largest 22 | 23 | @staticmethod 24 | def target_size(w, h, size, largest=False): 25 | if h < w and largest: 26 | w, h = size, int(size * h / w) 27 | else: 28 | w, h = int(size * w / h), size 29 | size = (h, w) 30 | return size 31 | 32 | def __call__(self, img): 33 | size = self.size 34 | w, h = img.size 35 | target_size = self.target_size(w, h, size, self.largest) 36 | return F.resize(img, target_size, self.interpolation) 37 | 38 | def __repr__(self): 39 | r = super().__repr__() 40 | return r[:-1] + ', largest={})'.format(self.largest) 41 | 42 | 43 | 44 | 45 | 46 | def get_transforms(input_size=224,test_size=224, kind='full', crop=True, need=('train', 'val'), backbone=None): 47 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 48 | if backbone is not None and backbone in ['pnasnet5large', 'nasnetamobile']: 49 | mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] 50 | 51 | transformations = {} 52 | if 'train' in need: 53 | if kind == 'torch': 54 | transformations['train'] = transforms.Compose([ 55 | transforms.RandomResizedCrop(input_size), 56 | transforms.RandomHorizontalFlip(), 57 | transforms.ToTensor(), 58 | transforms.Normalize(mean, std), 59 | ]) 60 | elif kind == 'full': 61 | transformations['train'] = transforms.Compose([ 62 | transforms.RandomResizedCrop(input_size), 63 | transforms.RandomHorizontalFlip(), 64 | transforms.ColorJitter(0.3, 0.3, 0.3), 65 | transforms.ToTensor(), 66 | transforms.Normalize(mean, std), 67 | ]) 68 | 69 | else: 70 | raise ValueError('Transforms kind {} unknown'.format(kind)) 71 | if 'val' in need: 72 | if crop: 73 | transformations['val'] = transforms.Compose( 74 | [Resize(int((256 / 224) * test_size)), # to maintain same ratio w.r.t. 224 images 75 | transforms.CenterCrop(test_size), 76 | transforms.ToTensor(), 77 | transforms.Normalize(mean, std)]) 78 | else: 79 | transformations['val'] = transforms.Compose( 80 | [Resize(test_size, largest=True), # to maintain same ratio w.r.t. 224 images 81 | transforms.ToTensor(), 82 | transforms.Normalize(mean, std)]) 83 | return transformations 84 | 85 | transforms_list = ['torch', 'full'] 86 | -------------------------------------------------------------------------------- /imnet_extract/Res.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | ''' 8 | Code from : https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 9 | ''' 10 | 11 | import torch.nn as nn 12 | try: 13 | from torch.hub import load_state_dict_from_url 14 | except ImportError: 15 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 16 | 17 | 18 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 19 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 20 | 21 | 22 | model_urls = { 23 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 24 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 25 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 26 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 27 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 28 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 29 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 30 | } 31 | 32 | 33 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 34 | """3x3 convolution with padding""" 35 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 36 | padding=dilation, groups=groups, bias=False, dilation=dilation) 37 | 38 | 39 | def conv1x1(in_planes, out_planes, stride=1): 40 | """1x1 convolution""" 41 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 42 | 43 | 44 | class BasicBlock(nn.Module): 45 | expansion = 1 46 | 47 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 48 | base_width=64, dilation=1, norm_layer=None): 49 | super(BasicBlock, self).__init__() 50 | if norm_layer is None: 51 | norm_layer = nn.BatchNorm2d 52 | if groups != 1 or base_width != 64: 53 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 54 | if dilation > 1: 55 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 56 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 57 | self.conv1 = conv3x3(inplanes, planes, stride) 58 | self.bn1 = norm_layer(planes) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.conv2 = conv3x3(planes, planes) 61 | self.bn2 = norm_layer(planes) 62 | self.downsample = downsample 63 | self.stride = stride 64 | 65 | def forward(self, x): 66 | identity = x 67 | 68 | out = self.conv1(x) 69 | out = self.bn1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | 75 | if self.downsample is not None: 76 | identity = self.downsample(x) 77 | 78 | out += identity 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | 84 | class Bottleneck(nn.Module): 85 | expansion = 4 86 | 87 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 88 | base_width=64, dilation=1, norm_layer=None): 89 | super(Bottleneck, self).__init__() 90 | if norm_layer is None: 91 | norm_layer = nn.BatchNorm2d 92 | width = int(planes * (base_width / 64.)) * groups 93 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 94 | self.conv1 = conv1x1(inplanes, width) 95 | self.bn1 = norm_layer(width) 96 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 97 | self.bn2 = norm_layer(width) 98 | self.conv3 = conv1x1(width, planes * self.expansion) 99 | self.bn3 = norm_layer(planes * self.expansion) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.downsample = downsample 102 | self.stride = stride 103 | 104 | 
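    # Worked example of the `width` computation above, assuming the ResNeXt-101
    # 32x48d configuration used elsewhere in this repo (groups=32, width_per_group=48):
    # in the first stage planes=64, so width = int(64 * (48 / 64.)) * 32 = 48 * 32 = 1536
    # channels for the grouped 3x3 convolution; this widening is why the 32x48d trunk
    # reaches roughly 829M parameters.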
def forward(self, x): 105 | identity = x 106 | 107 | out = self.conv1(x) 108 | out = self.bn1(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv2(out) 112 | out = self.bn2(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv3(out) 116 | out = self.bn3(out) 117 | 118 | if self.downsample is not None: 119 | identity = self.downsample(x) 120 | 121 | out += identity 122 | out = self.relu(out) 123 | 124 | return out 125 | 126 | 127 | class ResNet(nn.Module): 128 | 129 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 130 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 131 | norm_layer=None): 132 | super(ResNet, self).__init__() 133 | if norm_layer is None: 134 | norm_layer = nn.BatchNorm2d 135 | self._norm_layer = norm_layer 136 | 137 | self.inplanes = 64 138 | self.dilation = 1 139 | if replace_stride_with_dilation is None: 140 | # each element in the tuple indicates if we should replace 141 | # the 2x2 stride with a dilated convolution instead 142 | replace_stride_with_dilation = [False, False, False] 143 | if len(replace_stride_with_dilation) != 3: 144 | raise ValueError("replace_stride_with_dilation should be None " 145 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 146 | self.groups = groups 147 | self.base_width = width_per_group 148 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 149 | bias=False) 150 | self.bn1 = norm_layer(self.inplanes) 151 | self.relu = nn.ReLU(inplace=True) 152 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 153 | self.layer1 = self._make_layer(block, 64, layers[0]) 154 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 155 | dilate=replace_stride_with_dilation[0]) 156 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 157 | dilate=replace_stride_with_dilation[1]) 158 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 159 | dilate=replace_stride_with_dilation[2]) 160 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 161 | self.fc = nn.Linear(512 * block.expansion, num_classes) 162 | 163 | for m in self.modules(): 164 | if isinstance(m, nn.Conv2d): 165 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 166 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 167 | nn.init.constant_(m.weight, 1) 168 | nn.init.constant_(m.bias, 0) 169 | 170 | # Zero-initialize the last BN in each residual branch, 171 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
172 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 173 | if zero_init_residual: 174 | for m in self.modules(): 175 | if isinstance(m, Bottleneck): 176 | nn.init.constant_(m.bn3.weight, 0) 177 | elif isinstance(m, BasicBlock): 178 | nn.init.constant_(m.bn2.weight, 0) 179 | 180 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 181 | norm_layer = self._norm_layer 182 | downsample = None 183 | previous_dilation = self.dilation 184 | if dilate: 185 | self.dilation *= stride 186 | stride = 1 187 | if stride != 1 or self.inplanes != planes * block.expansion: 188 | downsample = nn.Sequential( 189 | conv1x1(self.inplanes, planes * block.expansion, stride), 190 | norm_layer(planes * block.expansion), 191 | ) 192 | 193 | layers = [] 194 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 195 | self.base_width, previous_dilation, norm_layer)) 196 | self.inplanes = planes * block.expansion 197 | for _ in range(1, blocks): 198 | layers.append(block(self.inplanes, planes, groups=self.groups, 199 | base_width=self.base_width, dilation=self.dilation, 200 | norm_layer=norm_layer)) 201 | 202 | return nn.Sequential(*layers) 203 | 204 | def forward(self, x): 205 | x = self.conv1(x) 206 | x = self.bn1(x) 207 | x = self.relu(x) 208 | x = self.maxpool(x) 209 | 210 | x = self.layer1(x) 211 | x = self.layer2(x) 212 | x = self.layer3(x) 213 | x = self.layer4(x) 214 | 215 | x = self.avgpool(x) 216 | x1 = x.reshape(x.size(0), -1) 217 | x = self.fc(x1) 218 | 219 | return x,x1 220 | 221 | 222 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 223 | model = ResNet(block, layers, **kwargs) 224 | if pretrained: 225 | state_dict = load_state_dict_from_url(model_urls[arch], 226 | progress=progress) 227 | model.load_state_dict(state_dict) 228 | return model 229 | 230 | 231 | def resnet18(pretrained=False, progress=True, **kwargs): 232 | """Constructs a ResNet-18 model. 233 | Args: 234 | pretrained (bool): If True, returns a model pre-trained on ImageNet 235 | progress (bool): If True, displays a progress bar of the download to stderr 236 | """ 237 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 238 | **kwargs) 239 | 240 | 241 | def resnet34(pretrained=False, progress=True, **kwargs): 242 | """Constructs a ResNet-34 model. 243 | Args: 244 | pretrained (bool): If True, returns a model pre-trained on ImageNet 245 | progress (bool): If True, displays a progress bar of the download to stderr 246 | """ 247 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 248 | **kwargs) 249 | 250 | 251 | def resnet50(pretrained=False, progress=True, **kwargs): 252 | """Constructs a ResNet-50 model. 253 | Args: 254 | pretrained (bool): If True, returns a model pre-trained on ImageNet 255 | progress (bool): If True, displays a progress bar of the download to stderr 256 | """ 257 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 258 | **kwargs) 259 | 260 | 261 | def resnet101(pretrained=False, progress=True, **kwargs): 262 | """Constructs a ResNet-101 model. 263 | Args: 264 | pretrained (bool): If True, returns a model pre-trained on ImageNet 265 | progress (bool): If True, displays a progress bar of the download to stderr 266 | """ 267 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 268 | **kwargs) 269 | 270 | 271 | def resnet152(pretrained=False, progress=True, **kwargs): 272 | """Constructs a ResNet-152 model. 
273 | Args: 274 | pretrained (bool): If True, returns a model pre-trained on ImageNet 275 | progress (bool): If True, displays a progress bar of the download to stderr 276 | """ 277 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 278 | **kwargs) 279 | 280 | 281 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 282 | """Constructs a ResNeXt-50 32x4d model. 283 | Args: 284 | pretrained (bool): If True, returns a model pre-trained on ImageNet 285 | progress (bool): If True, displays a progress bar of the download to stderr 286 | """ 287 | kwargs['groups'] = 32 288 | kwargs['width_per_group'] = 4 289 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 290 | pretrained, progress, **kwargs) 291 | 292 | 293 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 294 | """Constructs a ResNeXt-101 32x8d model. 295 | Args: 296 | pretrained (bool): If True, returns a model pre-trained on ImageNet 297 | progress (bool): If True, displays a progress bar of the download to stderr 298 | """ 299 | kwargs['groups'] = 32 300 | kwargs['width_per_group'] = 8 301 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 302 | pretrained, progress, **kwargs) 303 | -------------------------------------------------------------------------------- /imnet_extract/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .train import Trainer 8 | from .config import TrainerConfig, ClusterConfig 9 | -------------------------------------------------------------------------------- /imnet_extract/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from typing import NamedTuple 8 | 9 | 10 | class ClusterConfig(NamedTuple): 11 | dist_backend: str 12 | dist_url: str 13 | 14 | 15 | class TrainerConfig(NamedTuple): 16 | data_folder: str 17 | architecture: str 18 | weight_path: str 19 | dataset_path: str 20 | save_path: str 21 | workers: int 22 | input_size: int 23 | batch_per_gpu: int 24 | local_rank: int 25 | global_rank: int 26 | num_tasks: int 27 | job_id: str 28 | save_folder: str 29 | -------------------------------------------------------------------------------- /imnet_extract/resnext_wsl.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
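# Usage sketch (illustrative, kept as comments since this module only defines
# constructors): the *_wsl factories below download the weakly-supervised
# pre-trained weights on construction, and because they build on the modified
# ResNet in Res.py, the forward pass returns a (logits, embedding) pair rather
# than logits alone:
#
#     import torch
#     model = resnext101_32x48d_wsl(progress=True)
#     model.eval()
#     with torch.no_grad():
#         logits, embedding = model(torch.randn(1, 3, 224, 224))
#     # logits: (1, 1000), embedding: (1, 2048)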
7 | 8 | # Optional list of dependencies required by the package 9 | 10 | ''' 11 | Code From : https://github.com/facebookresearch/WSL-Images/blob/master/hubconf.py 12 | ''' 13 | dependencies = ['torch', 'torchvision'] 14 | 15 | try: 16 | from torch.hub import load_state_dict_from_url 17 | except ImportError: 18 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 19 | 20 | from .Res import ResNet, Bottleneck 21 | 22 | 23 | model_urls = { 24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth', 25 | 'resnext101_32x16d': 'https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth', 26 | 'resnext101_32x32d': 'https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth', 27 | 'resnext101_32x48d': 'https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth', 28 | } 29 | 30 | 31 | def _resnext(arch, block, layers, pretrained, progress, **kwargs): 32 | model = ResNet(block, layers, **kwargs) 33 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 34 | model.load_state_dict(state_dict) 35 | return model 36 | 37 | 38 | def resnext101_32x8d_wsl(progress=True, **kwargs): 39 | """Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data 40 | and finetuned on ImageNet from Figure 5 in 41 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 42 | Args: 43 | progress (bool): If True, displays a progress bar of the download to stderr. 44 | """ 45 | kwargs['groups'] = 32 46 | kwargs['width_per_group'] = 8 47 | return _resnext('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 48 | 49 | 50 | def resnext101_32x16d_wsl(progress=True, **kwargs): 51 | """Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data 52 | and finetuned on ImageNet from Figure 5 in 53 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 54 | Args: 55 | progress (bool): If True, displays a progress bar of the download to stderr. 56 | """ 57 | kwargs['groups'] = 32 58 | kwargs['width_per_group'] = 16 59 | return _resnext('resnext101_32x16d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 60 | 61 | 62 | def resnext101_32x32d_wsl(progress=True, **kwargs): 63 | """Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data 64 | and finetuned on ImageNet from Figure 5 in 65 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 66 | Args: 67 | progress (bool): If True, displays a progress bar of the download to stderr. 68 | """ 69 | kwargs['groups'] = 32 70 | kwargs['width_per_group'] = 32 71 | return _resnext('resnext101_32x32d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 72 | 73 | 74 | def resnext101_32x48d_wsl(progress=True, **kwargs): 75 | """Constructs a ResNeXt-101 32x48 model pre-trained on weakly-supervised data 76 | and finetuned on ImageNet from Figure 5 in 77 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 78 | Args: 79 | progress (bool): If True, displays a progress bar of the download to stderr. 80 | """ 81 | kwargs['groups'] = 32 82 | kwargs['width_per_group'] = 48 83 | return _resnext('resnext101_32x48d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 84 | -------------------------------------------------------------------------------- /imnet_extract/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from torch.utils.data.sampler import BatchSampler 8 | import torch 9 | import numpy as np 10 | from torch.utils.data.dataloader import default_collate 11 | from collections.abc import Mapping, Sequence 12 | import math 13 | import torch.distributed as dist 14 | 15 | class RASampler(torch.utils.data.Sampler): 16 | """ 17 | Batch Sampler with Repeated Augmentations (RA) 18 | - dataset_len: original length of the dataset 19 | - batch_size 20 | - repetitions: instances per image 21 | - len_factor: multiplicative factor for epoch size 22 | """ 23 | 24 | def __init__(self,dataset,num_replicas, rank, dataset_len, batch_size, repetitions=1, len_factor=1.0, shuffle=False, drop_last=False): 25 | self.dataset=dataset 26 | self.dataset_len = dataset_len 27 | self.batch_size = batch_size 28 | self.repetitions = repetitions 29 | self.len_images = int(dataset_len * len_factor) 30 | self.shuffle = shuffle 31 | self.drop_last = drop_last 32 | if num_replicas is None: 33 | if not dist.is_available(): 34 | raise RuntimeError("Requires distributed package to be available") 35 | num_replicas = dist.get_world_size() 36 | if rank is None: 37 | if not dist.is_available(): 38 | raise RuntimeError("Requires distributed package to be available") 39 | rank = dist.get_rank() 40 | self.dataset = dataset 41 | self.num_replicas = num_replicas 42 | self.rank = rank 43 | self.epoch = 0 44 | self.num_samples = int(math.ceil(len(self.dataset) * self.repetitions * 1.0 / self.num_replicas)) 45 | self.total_size = self.num_samples * self.num_replicas 46 | 47 | 48 | def shuffler(self): 49 | if self.shuffle: 50 | new_perm = lambda: iter(np.random.permutation(self.dataset_len)) 51 | else: 52 | new_perm = lambda: iter(np.arange(self.dataset_len)) 53 | shuffle = new_perm() 54 | while True: 55 | try: 56 | index = next(shuffle) 57 | except StopIteration: 58 | shuffle = new_perm() 59 | index = next(shuffle) 60 | for repetition in range(self.repetitions): 61 | yield index 62 | 63 | def __iter__(self): 64 | shuffle = iter(self.shuffler()) 65 | seen = 0 66 | indices=[] 67 | for _ in range(self.len_images): 68 | index = next(shuffle) 69 | indices.append(index) 70 | indices += indices[:(self.total_size - len(indices))] 71 | assert len(indices) == self.total_size 72 | # subsample 73 | indices = indices[self.rank:self.total_size:self.num_replicas] 74 | assert len(indices) == self.num_samples 75 | 76 | return iter(indices) 77 | 78 | 79 | def __len__(self): 80 | return self.num_samples 81 | 82 | def set_epoch(self, epoch): 83 | self.epoch = epoch 84 | 85 | def list_collate(batch): 86 | """ 87 | Collate into a list instead of a tensor to deal with variable-sized inputs 88 | """ 89 | elem_type = type(batch[0]) 90 | if isinstance(batch[0], torch.Tensor): 91 | return batch 92 | elif elem_type.__module__ == 'numpy': 93 | if elem_type.__name__ == 'ndarray': 94 | return list_collate([torch.from_numpy(b) for b in batch]) 95 | elif isinstance(batch[0], Mapping): 96 | return {key: list_collate([d[key] for d in batch]) for key in batch[0]} 97 | elif isinstance(batch[0], Sequence): 98 | transposed = zip(*batch) 99 | return [list_collate(samples) for samples in transposed] 100 | return default_collate(batch) 101 | -------------------------------------------------------------------------------- /imnet_extract/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 
(c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import os 8 | import os.path as osp 9 | from typing import Optional 10 | import torch 11 | import torch.distributed 12 | import torch.nn as nn 13 | import attr 14 | from torchvision import datasets 15 | import tqdm 16 | import numpy as np 17 | from .config import TrainerConfig, ClusterConfig 18 | from .transforms import get_transforms 19 | from .resnext_wsl import resnext101_32x48d_wsl 20 | from .pnasnet import pnasnet5large 21 | from .Res import resnet50 22 | 23 | def conv_numpy_tensor(output): 24 | """Convert a CUDA tensor to a NumPy array""" 25 | return output.data.cpu().numpy() 26 | 27 | @attr.s(auto_attribs=True) 28 | class TrainerState: 29 | """ 30 | Contains the state of the Trainer. 31 | It can be saved to checkpoint the training and loaded to resume it. 32 | """ 33 | 34 | model: nn.Module 35 | 36 | def save(self, filename: str) -> None: 37 | data = attr.asdict(self) 38 | # store only the state dict 39 | data["model"] = self.model.state_dict() 40 | 41 | torch.save(data, filename) 42 | 43 | @classmethod 44 | def load(cls, filename: str, default: "TrainerState") -> "TrainerState": 45 | data = torch.load(filename) 46 | # We need this default to load the state dict 47 | model = default.model 48 | model.load_state_dict(data["model"]) 49 | data["model"] = model 50 | 51 | return cls(**data) 52 | 53 | 54 | class Trainer: 55 | def __init__(self, train_cfg: TrainerConfig, cluster_cfg: ClusterConfig) -> None: 56 | self._train_cfg = train_cfg 57 | self._cluster_cfg = cluster_cfg 58 | 59 | def __call__(self) -> Optional[float]: 60 | """ 61 | Called for each task. 62 | 63 | :return: The master task returns the final accuracy of the model. 64 | """ 65 | self._setup_process_group() 66 | self._init_state() 67 | final_acc = self._train() 68 | return final_acc 69 | 70 | def checkpoint(self, rm_init=True): 71 | # will be called by submitit in case of preemption 72 | 73 | save_dir = osp.join(self._train_cfg.save_folder, str(self._train_cfg.job_id)) 74 | os.makedirs(save_dir, exist_ok=True) 75 | self._state.save(osp.join(save_dir, "checkpoint.pth")) 76 | self._state.save(osp.join(save_dir, "checkpoint_"+str(self._state.epoch)+".pth")) 77 | # Trick here: when the job is requeued, we will reuse the same init file, 78 | # but it must not exist when we initialize the process group, 79 | # so we delete it, but only when this method is called by submitit for requeuing 80 | if rm_init: 81 | os.remove(self._cluster_cfg.dist_url[7:]) # remove file:// at the beginning 82 | # This allows removing any non-picklable part of the Trainer instance.
83 | empty_trainer = Trainer(self._train_cfg, self._cluster_cfg) 84 | return empty_trainer 85 | 86 | def _setup_process_group(self) -> None: 87 | torch.cuda.set_device(self._train_cfg.local_rank) 88 | torch.distributed.init_process_group( 89 | backend=self._cluster_cfg.dist_backend, 90 | init_method=self._cluster_cfg.dist_url, 91 | world_size=self._train_cfg.num_tasks, 92 | rank=self._train_cfg.global_rank, 93 | ) 94 | print(f"Process group: {self._train_cfg.num_tasks} tasks, rank: {self._train_cfg.global_rank}") 95 | 96 | def _init_state(self) -> None: 97 | """ 98 | Initialize the state and load it from an existing checkpoint if any 99 | """ 100 | torch.manual_seed(0) 101 | np.random.seed(0) 102 | 103 | print("Create data loaders", flush=True) 104 | print("Input size : "+str(self._train_cfg.input_size)) 105 | print("Model : " + str(self._train_cfg.architecture) ) 106 | backbone_architecture=None 107 | if self._train_cfg.architecture=='PNASNet' : 108 | backbone_architecture='pnasnet5large' 109 | 110 | 111 | transformation=get_transforms(input_size=self._train_cfg.input_size,test_size=self._train_cfg.input_size, kind='full', crop=True, need=('train', 'val'), backbone=backbone_architecture) 112 | transform_test = transformation['val'] 113 | 114 | test_set = datasets.ImageFolder(self._train_cfg.dataset_path,transform=transform_test) 115 | 116 | self._test_loader = torch.utils.data.DataLoader( 117 | test_set, batch_size=self._train_cfg.batch_per_gpu, shuffle=False, num_workers=(self._train_cfg.workers-1), 118 | ) 119 | 120 | 121 | print("Create distributed model", flush=True) 122 | 123 | if self._train_cfg.architecture=='PNASNet' : 124 | model= pnasnet5large(pretrained='imagenet') 125 | 126 | if self._train_cfg.architecture=='ResNet50' : 127 | model=resnet50(pretrained=False) 128 | 129 | if self._train_cfg.architecture=='IGAM_Resnext101_32x48d' : 130 | model=resnext101_32x48d_wsl(progress=True) 131 | 132 | pretrained_dict=torch.load(self._train_cfg.weight_path,map_location='cpu')['model'] 133 | model_dict = model.state_dict() 134 | count=0 135 | count2=0 136 | for k in model_dict.keys(): 137 | count=count+1.0 138 | if(('module.'+k) in pretrained_dict.keys()): 139 | count2=count2+1.0 140 | model_dict[k]=pretrained_dict.get(('module.'+k)) 141 | model.load_state_dict(model_dict) 142 | print("load "+str(count2*100/count)+" %") 143 | 144 | assert int(count2*100/count)== 100,"model loading error" 145 | 146 | for name, child in model.named_children(): 147 | for name2, params in child.named_parameters(): 148 | params.requires_grad = False 149 | 150 | print('model_load') 151 | if torch.cuda.is_available(): 152 | 153 | model.cuda(self._train_cfg.local_rank) 154 | model = torch.nn.parallel.DistributedDataParallel( 155 | model, device_ids=[self._train_cfg.local_rank], output_device=self._train_cfg.local_rank 156 | ) 157 | 158 | self._state = TrainerState( 159 | model=model 160 | ) 161 | checkpoint_fn = osp.join(self._train_cfg.save_folder, str(self._train_cfg.job_id), "checkpoint.pth") 162 | if os.path.isfile(checkpoint_fn): 163 | print(f"Load existing checkpoint from {checkpoint_fn}", flush=True) 164 | self._state = TrainerState.load(checkpoint_fn, default=self._state) 165 | 166 | def _train(self) -> Optional[float]: 167 | self._state.model.eval() 168 | 169 | embedding=None 170 | softmax_probability=None 171 | exctract_label=None 172 | with torch.no_grad(): 173 | for data in tqdm.tqdm(self._test_loader): 174 | images, labels = data 175 | images = images.cuda(self._train_cfg.local_rank, non_blocking=True) 
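                # The modified forward pass in Res.py returns a (logits, embedding) pair;
                # each batch's softmax probabilities, pooled embeddings and labels are
                # converted to NumPy and concatenated below, then written out as .npy
                # files once the loop finishes.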
176 | labels = labels.cuda(self._train_cfg.local_rank, non_blocking=True) 177 | outputs , embed = self._state.model(images) 178 | if embedding is None: 179 | softmax_probability=conv_numpy_tensor(nn.Softmax()(outputs)) 180 | embedding=conv_numpy_tensor((embed)) 181 | exctract_label=conv_numpy_tensor(labels) 182 | else: 183 | softmax_probability=np.concatenate((softmax_probability,conv_numpy_tensor(nn.Softmax()(outputs)))) 184 | embedding=np.concatenate((embedding,conv_numpy_tensor((embed)))) 185 | exctract_label=np.concatenate((exctract_label,conv_numpy_tensor(labels))) 186 | 187 | 188 | np.save(str(self._train_cfg.save_path)+'labels.npy', exctract_label) 189 | np.save(str(self._train_cfg.save_path)+str(self._train_cfg.architecture)+'_embedding.npy', embedding) 190 | np.save(str(self._train_cfg.save_path)+str(self._train_cfg.architecture)+'_softmax.npy', softmax_probability) 191 | return 0.0 192 | 193 | 194 | 195 | -------------------------------------------------------------------------------- /imnet_extract/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import torch 8 | import torchvision.transforms.functional as F 9 | from torchvision import transforms 10 | 11 | class Resize(transforms.Resize): 12 | """ 13 | Resize with a ``largest=False'' argument 14 | allowing to resize to a common largest side without cropping 15 | """ 16 | 17 | 18 | def __init__(self, size, largest=False, **kwargs): 19 | super().__init__(size, **kwargs) 20 | self.largest = largest 21 | 22 | @staticmethod 23 | def target_size(w, h, size, largest=False): 24 | if h < w and largest: 25 | w, h = size, int(size * h / w) 26 | else: 27 | w, h = int(size * w / h), size 28 | size = (h, w) 29 | return size 30 | 31 | def __call__(self, img): 32 | size = self.size 33 | w, h = img.size 34 | target_size = self.target_size(w, h, size, self.largest) 35 | return F.resize(img, target_size, self.interpolation) 36 | 37 | def __repr__(self): 38 | r = super().__repr__() 39 | return r[:-1] + ', largest={})'.format(self.largest) 40 | 41 | 42 | 43 | 44 | 45 | def get_transforms(input_size=224,test_size=224, kind='full', crop=True, need=('train', 'val'), backbone=None): 46 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 47 | if backbone is not None and backbone in ['pnasnet5large', 'nasnetamobile']: 48 | mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] 49 | 50 | transformations = {} 51 | if 'train' in need: 52 | if kind == 'torch': 53 | transformations['train'] = transforms.Compose([ 54 | transforms.RandomResizedCrop(input_size), 55 | transforms.RandomHorizontalFlip(), 56 | transforms.ToTensor(), 57 | transforms.Normalize(mean, std), 58 | ]) 59 | elif kind == 'full': 60 | transformations['train'] = transforms.Compose([ 61 | transforms.RandomResizedCrop(input_size), 62 | transforms.RandomHorizontalFlip(), 63 | transforms.ColorJitter(0.3, 0.3, 0.3), 64 | transforms.ToTensor(), 65 | transforms.Normalize(mean, std), 66 | ]) 67 | 68 | else: 69 | raise ValueError('Transforms kind {} unknown'.format(kind)) 70 | if 'val' in need: 71 | if crop: 72 | transformations['val'] = transforms.Compose( 73 | [Resize(int((256 / 224) * test_size)), # to maintain same ratio w.r.t. 
224 images 74 | transforms.CenterCrop(test_size), 75 | transforms.ToTensor(), 76 | transforms.Normalize(mean, std)]) 77 | else: 78 | transformations['val'] = transforms.Compose( 79 | [Resize(test_size, largest=True), # to maintain same ratio w.r.t. 224 images 80 | transforms.ToTensor(), 81 | transforms.Normalize(mean, std)]) 82 | return transformations 83 | 84 | transforms_list = ['torch', 'full'] 85 | -------------------------------------------------------------------------------- /imnet_finetune/Res.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import torch.nn as nn 9 | try: 10 | from torch.hub import load_state_dict_from_url 11 | except ImportError: 12 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 13 | 14 | 15 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 16 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 17 | 18 | 19 | model_urls = { 20 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 21 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 22 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 23 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 24 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 25 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 26 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 27 | } 28 | 29 | 30 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 31 | """3x3 convolution with padding""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 33 | padding=dilation, groups=groups, bias=False, dilation=dilation) 34 | 35 | 36 | def conv1x1(in_planes, out_planes, stride=1): 37 | """1x1 convolution""" 38 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 39 | 40 | 41 | class BasicBlock(nn.Module): 42 | expansion = 1 43 | 44 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 45 | base_width=64, dilation=1, norm_layer=None): 46 | super(BasicBlock, self).__init__() 47 | if norm_layer is None: 48 | norm_layer = nn.BatchNorm2d 49 | if groups != 1 or base_width != 64: 50 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 51 | if dilation > 1: 52 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 53 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 54 | self.conv1 = conv3x3(inplanes, planes, stride) 55 | self.bn1 = norm_layer(planes) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.conv2 = conv3x3(planes, planes) 58 | self.bn2 = norm_layer(planes) 59 | self.downsample = downsample 60 | self.stride = stride 61 | 62 | def forward(self, x): 63 | identity = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | 72 | if self.downsample is not None: 73 | identity = self.downsample(x) 74 | 75 | out += identity 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class Bottleneck(nn.Module): 82 | expansion = 4 83 | 84 | def __init__(self, inplanes, planes, stride=1, downsample=None, 
groups=1, 85 | base_width=64, dilation=1, norm_layer=None): 86 | super(Bottleneck, self).__init__() 87 | if norm_layer is None: 88 | norm_layer = nn.BatchNorm2d 89 | width = int(planes * (base_width / 64.)) * groups 90 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 91 | self.conv1 = conv1x1(inplanes, width) 92 | self.bn1 = norm_layer(width) 93 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 94 | self.bn2 = norm_layer(width) 95 | self.conv3 = conv1x1(width, planes * self.expansion) 96 | self.bn3 = norm_layer(planes * self.expansion) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.downsample = downsample 99 | self.stride = stride 100 | 101 | def forward(self, x): 102 | identity = x 103 | 104 | out = self.conv1(x) 105 | out = self.bn1(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv2(out) 109 | out = self.bn2(out) 110 | out = self.relu(out) 111 | 112 | out = self.conv3(out) 113 | out = self.bn3(out) 114 | 115 | if self.downsample is not None: 116 | identity = self.downsample(x) 117 | 118 | out += identity 119 | out = self.relu(out) 120 | 121 | return out 122 | 123 | 124 | class ResNet(nn.Module): 125 | 126 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 127 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 128 | norm_layer=None): 129 | super(ResNet, self).__init__() 130 | if norm_layer is None: 131 | norm_layer = nn.BatchNorm2d 132 | self._norm_layer = norm_layer 133 | 134 | self.inplanes = 64 135 | self.dilation = 1 136 | if replace_stride_with_dilation is None: 137 | # each element in the tuple indicates if we should replace 138 | # the 2x2 stride with a dilated convolution instead 139 | replace_stride_with_dilation = [False, False, False] 140 | if len(replace_stride_with_dilation) != 3: 141 | raise ValueError("replace_stride_with_dilation should be None " 142 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 143 | self.groups = groups 144 | self.base_width = width_per_group 145 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 146 | bias=False) 147 | self.bn1 = norm_layer(self.inplanes) 148 | self.relu = nn.ReLU(inplace=True) 149 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 150 | self.layer1 = self._make_layer(block, 64, layers[0]) 151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 152 | dilate=replace_stride_with_dilation[0]) 153 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 154 | dilate=replace_stride_with_dilation[1]) 155 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 156 | dilate=replace_stride_with_dilation[2]) 157 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 158 | self.fc = nn.Linear(512 * block.expansion, num_classes) 159 | 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 163 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 164 | nn.init.constant_(m.weight, 1) 165 | nn.init.constant_(m.bias, 0) 166 | 167 | # Zero-initialize the last BN in each residual branch, 168 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
169 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 170 | if zero_init_residual: 171 | for m in self.modules(): 172 | if isinstance(m, Bottleneck): 173 | nn.init.constant_(m.bn3.weight, 0) 174 | elif isinstance(m, BasicBlock): 175 | nn.init.constant_(m.bn2.weight, 0) 176 | 177 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 178 | norm_layer = self._norm_layer 179 | downsample = None 180 | previous_dilation = self.dilation 181 | if dilate: 182 | self.dilation *= stride 183 | stride = 1 184 | if stride != 1 or self.inplanes != planes * block.expansion: 185 | downsample = nn.Sequential( 186 | conv1x1(self.inplanes, planes * block.expansion, stride), 187 | norm_layer(planes * block.expansion), 188 | ) 189 | 190 | layers = [] 191 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 192 | self.base_width, previous_dilation, norm_layer)) 193 | self.inplanes = planes * block.expansion 194 | for _ in range(1, blocks): 195 | layers.append(block(self.inplanes, planes, groups=self.groups, 196 | base_width=self.base_width, dilation=self.dilation, 197 | norm_layer=norm_layer)) 198 | 199 | return nn.Sequential(*layers) 200 | 201 | def forward(self, x): 202 | x = self.conv1(x) 203 | x = self.bn1(x) 204 | x = self.relu(x) 205 | x = self.maxpool(x) 206 | 207 | x = self.layer1(x) 208 | x = self.layer2(x) 209 | x = self.layer3(x) 210 | x = self.layer4(x) 211 | 212 | x = self.avgpool(x) 213 | x = x.reshape(x.size(0), -1) 214 | x = self.fc(x) 215 | 216 | return x 217 | 218 | 219 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 220 | model = ResNet(block, layers, **kwargs) 221 | if pretrained: 222 | state_dict = load_state_dict_from_url(model_urls[arch], 223 | progress=progress) 224 | model.load_state_dict(state_dict) 225 | return model 226 | 227 | 228 | def resnet18(pretrained=False, progress=True, **kwargs): 229 | """Constructs a ResNet-18 model. 230 | Args: 231 | pretrained (bool): If True, returns a model pre-trained on ImageNet 232 | progress (bool): If True, displays a progress bar of the download to stderr 233 | """ 234 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 235 | **kwargs) 236 | 237 | 238 | def resnet34(pretrained=False, progress=True, **kwargs): 239 | """Constructs a ResNet-34 model. 240 | Args: 241 | pretrained (bool): If True, returns a model pre-trained on ImageNet 242 | progress (bool): If True, displays a progress bar of the download to stderr 243 | """ 244 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 245 | **kwargs) 246 | 247 | 248 | def resnet50(pretrained=False, progress=True, **kwargs): 249 | """Constructs a ResNet-50 model. 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | progress (bool): If True, displays a progress bar of the download to stderr 253 | """ 254 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 255 | **kwargs) 256 | 257 | 258 | def resnet101(pretrained=False, progress=True, **kwargs): 259 | """Constructs a ResNet-101 model. 260 | Args: 261 | pretrained (bool): If True, returns a model pre-trained on ImageNet 262 | progress (bool): If True, displays a progress bar of the download to stderr 263 | """ 264 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 265 | **kwargs) 266 | 267 | 268 | def resnet152(pretrained=False, progress=True, **kwargs): 269 | """Constructs a ResNet-152 model. 
270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | progress (bool): If True, displays a progress bar of the download to stderr 273 | """ 274 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 275 | **kwargs) 276 | 277 | 278 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 279 | """Constructs a ResNeXt-50 32x4d model. 280 | Args: 281 | pretrained (bool): If True, returns a model pre-trained on ImageNet 282 | progress (bool): If True, displays a progress bar of the download to stderr 283 | """ 284 | kwargs['groups'] = 32 285 | kwargs['width_per_group'] = 4 286 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 287 | pretrained, progress, **kwargs) 288 | 289 | 290 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 291 | """Constructs a ResNeXt-101 32x8d model. 292 | Args: 293 | pretrained (bool): If True, returns a model pre-trained on ImageNet 294 | progress (bool): If True, displays a progress bar of the download to stderr 295 | """ 296 | kwargs['groups'] = 32 297 | kwargs['width_per_group'] = 8 298 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 299 | pretrained, progress, **kwargs) 300 | -------------------------------------------------------------------------------- /imnet_finetune/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .train import Trainer 8 | from .config import TrainerConfig, ClusterConfig 9 | -------------------------------------------------------------------------------- /imnet_finetune/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from typing import NamedTuple 8 | 9 | 10 | class ClusterConfig(NamedTuple): 11 | dist_backend: str 12 | dist_url: str 13 | 14 | 15 | class TrainerConfig(NamedTuple): 16 | data_folder: str 17 | epochs: int 18 | lr: float 19 | input_size: int 20 | batch_per_gpu: int 21 | save_folder: str 22 | imnet_path: str 23 | architecture: str 24 | resnet_weight_path: str 25 | workers: int 26 | local_rank: int 27 | global_rank: int 28 | num_tasks: int 29 | job_id: str 30 | EfficientNet_models: str 31 | -------------------------------------------------------------------------------- /imnet_finetune/pnasnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | ''' 8 | Code from https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py 9 | with some adaptations 10 | ''' 11 | 12 | 13 | 14 | from __future__ import print_function, division, absolute_import 15 | from collections import OrderedDict 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.utils.model_zoo as model_zoo 20 | 21 | 22 | pretrained_settings = { 23 | 'pnasnet5large': { 24 | 'imagenet': { 25 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth', 26 | 'input_space': 'RGB', 27 | 'input_size': [3, 331, 331], 28 | 'input_range': [0, 1], 29 | 'mean': [0.5, 0.5, 0.5], 30 | 'std': [0.5, 0.5, 0.5], 31 | 'num_classes': 1000 32 | }, 33 | 'imagenet+background': { 34 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth', 35 | 'input_space': 'RGB', 36 | 'input_size': [3, 331, 331], 37 | 'input_range': [0, 1], 38 | 'mean': [0.5, 0.5, 0.5], 39 | 'std': [0.5, 0.5, 0.5], 40 | 'num_classes': 1001 41 | } 42 | } 43 | } 44 | 45 | 46 | class MaxPool(nn.Module): 47 | 48 | def __init__(self, kernel_size, stride=1, padding=1, zero_pad=False): 49 | super(MaxPool, self).__init__() 50 | self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None 51 | self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding) 52 | 53 | def forward(self, x): 54 | if self.zero_pad: 55 | x = self.zero_pad(x) 56 | x = self.pool(x) 57 | if self.zero_pad: 58 | x = x[:, :, 1:, 1:] 59 | return x 60 | 61 | 62 | class SeparableConv2d(nn.Module): 63 | 64 | def __init__(self, in_channels, out_channels, dw_kernel_size, dw_stride, 65 | dw_padding): 66 | super(SeparableConv2d, self).__init__() 67 | self.depthwise_conv2d = nn.Conv2d(in_channels, in_channels, 68 | kernel_size=dw_kernel_size, 69 | stride=dw_stride, padding=dw_padding, 70 | groups=in_channels, bias=False) 71 | self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 72 | kernel_size=1, bias=False) 73 | 74 | def forward(self, x): 75 | x = self.depthwise_conv2d(x) 76 | x = self.pointwise_conv2d(x) 77 | return x 78 | 79 | 80 | class BranchSeparables(nn.Module): 81 | 82 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 83 | stem_cell=False, zero_pad=False): 84 | super(BranchSeparables, self).__init__() 85 | padding = kernel_size // 2 86 | middle_channels = out_channels if stem_cell else in_channels 87 | self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None 88 | self.relu_1 = nn.ReLU() 89 | self.separable_1 = SeparableConv2d(in_channels, middle_channels, 90 | kernel_size, dw_stride=stride, 91 | dw_padding=padding) 92 | self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001) 93 | self.relu_2 = nn.ReLU() 94 | self.separable_2 = SeparableConv2d(middle_channels, out_channels, 95 | kernel_size, dw_stride=1, 96 | dw_padding=padding) 97 | self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001) 98 | 99 | def forward(self, x): 100 | x = self.relu_1(x) 101 | if self.zero_pad: 102 | x = self.zero_pad(x) 103 | x = self.separable_1(x) 104 | if self.zero_pad: 105 | x = x[:, :, 1:, 1:].contiguous() 106 | x = self.bn_sep_1(x) 107 | x = self.relu_2(x) 108 | x = self.separable_2(x) 109 | x = self.bn_sep_2(x) 110 | return x 111 | 112 | 113 | class ReluConvBn(nn.Module): 114 | 115 | def __init__(self, in_channels, out_channels, kernel_size, stride=1): 116 | super(ReluConvBn, self).__init__() 117 | self.relu = nn.ReLU() 118 | self.conv = nn.Conv2d(in_channels, out_channels, 119 | kernel_size=kernel_size, stride=stride, 120 | 
bias=False) 121 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 122 | 123 | def forward(self, x): 124 | x = self.relu(x) 125 | x = self.conv(x) 126 | x = self.bn(x) 127 | return x 128 | 129 | 130 | class FactorizedReduction(nn.Module): 131 | 132 | def __init__(self, in_channels, out_channels): 133 | super(FactorizedReduction, self).__init__() 134 | self.relu = nn.ReLU() 135 | self.path_1 = nn.Sequential(OrderedDict([ 136 | ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), 137 | ('conv', nn.Conv2d(in_channels, out_channels // 2, 138 | kernel_size=1, bias=False)), 139 | ])) 140 | self.path_2 = nn.Sequential(OrderedDict([ 141 | ('pad', nn.ZeroPad2d((0, 1, 0, 1))), 142 | ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)), 143 | ('conv', nn.Conv2d(in_channels, out_channels // 2, 144 | kernel_size=1, bias=False)), 145 | ])) 146 | self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001) 147 | 148 | def forward(self, x): 149 | x = self.relu(x) 150 | 151 | x_path1 = self.path_1(x) 152 | 153 | x_path2 = self.path_2.pad(x) 154 | x_path2 = x_path2[:, :, 1:, 1:] 155 | x_path2 = self.path_2.avgpool(x_path2) 156 | x_path2 = self.path_2.conv(x_path2) 157 | 158 | out = self.final_path_bn(torch.cat([x_path1, x_path2], 1)) 159 | return out 160 | 161 | 162 | class CellBase(nn.Module): 163 | 164 | def cell_forward(self, x_left, x_right): 165 | x_comb_iter_0_left = self.comb_iter_0_left(x_left) 166 | x_comb_iter_0_right = self.comb_iter_0_right(x_left) 167 | x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right 168 | 169 | x_comb_iter_1_left = self.comb_iter_1_left(x_right) 170 | x_comb_iter_1_right = self.comb_iter_1_right(x_right) 171 | x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right 172 | 173 | x_comb_iter_2_left = self.comb_iter_2_left(x_right) 174 | x_comb_iter_2_right = self.comb_iter_2_right(x_right) 175 | x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right 176 | 177 | x_comb_iter_3_left = self.comb_iter_3_left(x_comb_iter_2) 178 | x_comb_iter_3_right = self.comb_iter_3_right(x_right) 179 | x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right 180 | 181 | x_comb_iter_4_left = self.comb_iter_4_left(x_left) 182 | if self.comb_iter_4_right: 183 | x_comb_iter_4_right = self.comb_iter_4_right(x_right) 184 | else: 185 | x_comb_iter_4_right = x_right 186 | x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right 187 | 188 | x_out = torch.cat( 189 | [x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, 190 | x_comb_iter_4], 1) 191 | return x_out 192 | 193 | 194 | class CellStem0(CellBase): 195 | 196 | def __init__(self, in_channels_left, out_channels_left, in_channels_right, 197 | out_channels_right): 198 | super(CellStem0, self).__init__() 199 | self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right, 200 | kernel_size=1) 201 | self.comb_iter_0_left = BranchSeparables(in_channels_left, 202 | out_channels_left, 203 | kernel_size=5, stride=2, 204 | stem_cell=True) 205 | self.comb_iter_0_right = nn.Sequential(OrderedDict([ 206 | ('max_pool', MaxPool(3, stride=2)), 207 | ('conv', nn.Conv2d(in_channels_left, out_channels_left, 208 | kernel_size=1, bias=False)), 209 | ('bn', nn.BatchNorm2d(out_channels_left, eps=0.001)), 210 | ])) 211 | self.comb_iter_1_left = BranchSeparables(out_channels_right, 212 | out_channels_right, 213 | kernel_size=7, stride=2) 214 | self.comb_iter_1_right = MaxPool(3, stride=2) 215 | self.comb_iter_2_left = BranchSeparables(out_channels_right, 216 | out_channels_right, 217 | kernel_size=5, stride=2) 218 | 
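        # Each `comb_iter_k` pair in this constructor defines the left/right
        # branches of one of the five combination nodes summed in
        # `CellBase.cell_forward`; the five sums are then concatenated along
        # the channel dimension to form the cell output.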
self.comb_iter_2_right = BranchSeparables(out_channels_right, 219 | out_channels_right, 220 | kernel_size=3, stride=2) 221 | self.comb_iter_3_left = BranchSeparables(out_channels_right, 222 | out_channels_right, 223 | kernel_size=3) 224 | self.comb_iter_3_right = MaxPool(3, stride=2) 225 | self.comb_iter_4_left = BranchSeparables(in_channels_right, 226 | out_channels_right, 227 | kernel_size=3, stride=2, 228 | stem_cell=True) 229 | self.comb_iter_4_right = ReluConvBn(out_channels_right, 230 | out_channels_right, 231 | kernel_size=1, stride=2) 232 | 233 | def forward(self, x_left): 234 | x_right = self.conv_1x1(x_left) 235 | x_out = self.cell_forward(x_left, x_right) 236 | return x_out 237 | 238 | 239 | class Cell(CellBase): 240 | 241 | def __init__(self, in_channels_left, out_channels_left, in_channels_right, 242 | out_channels_right, is_reduction=False, zero_pad=False, 243 | match_prev_layer_dimensions=False): 244 | super(Cell, self).__init__() 245 | 246 | # If `is_reduction` is set to `True` stride 2 is used for 247 | # convolutional and pooling layers to reduce the spatial size of 248 | # the output of a cell approximately by a factor of 2. 249 | stride = 2 if is_reduction else 1 250 | 251 | # If `match_prev_layer_dimensions` is set to `True` 252 | # `FactorizedReduction` is used to reduce the spatial size 253 | # of the left input of a cell approximately by a factor of 2. 254 | self.match_prev_layer_dimensions = match_prev_layer_dimensions 255 | if match_prev_layer_dimensions: 256 | self.conv_prev_1x1 = FactorizedReduction(in_channels_left, 257 | out_channels_left) 258 | else: 259 | self.conv_prev_1x1 = ReluConvBn(in_channels_left, 260 | out_channels_left, kernel_size=1) 261 | 262 | self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right, 263 | kernel_size=1) 264 | self.comb_iter_0_left = BranchSeparables(out_channels_left, 265 | out_channels_left, 266 | kernel_size=5, stride=stride, 267 | zero_pad=zero_pad) 268 | self.comb_iter_0_right = MaxPool(3, stride=stride, zero_pad=zero_pad) 269 | self.comb_iter_1_left = BranchSeparables(out_channels_right, 270 | out_channels_right, 271 | kernel_size=7, stride=stride, 272 | zero_pad=zero_pad) 273 | self.comb_iter_1_right = MaxPool(3, stride=stride, zero_pad=zero_pad) 274 | self.comb_iter_2_left = BranchSeparables(out_channels_right, 275 | out_channels_right, 276 | kernel_size=5, stride=stride, 277 | zero_pad=zero_pad) 278 | self.comb_iter_2_right = BranchSeparables(out_channels_right, 279 | out_channels_right, 280 | kernel_size=3, stride=stride, 281 | zero_pad=zero_pad) 282 | self.comb_iter_3_left = BranchSeparables(out_channels_right, 283 | out_channels_right, 284 | kernel_size=3) 285 | self.comb_iter_3_right = MaxPool(3, stride=stride, zero_pad=zero_pad) 286 | self.comb_iter_4_left = BranchSeparables(out_channels_left, 287 | out_channels_left, 288 | kernel_size=3, stride=stride, 289 | zero_pad=zero_pad) 290 | if is_reduction: 291 | self.comb_iter_4_right = ReluConvBn(out_channels_right, 292 | out_channels_right, 293 | kernel_size=1, stride=stride) 294 | else: 295 | self.comb_iter_4_right = None 296 | 297 | def forward(self, x_left, x_right): 298 | x_left = self.conv_prev_1x1(x_left) 299 | x_right = self.conv_1x1(x_right) 300 | x_out = self.cell_forward(x_left, x_right) 301 | return x_out 302 | 303 | 304 | class PNASNet5Large(nn.Module): 305 | def __init__(self, num_classes=1001): 306 | super(PNASNet5Large, self).__init__() 307 | self.num_classes = num_classes 308 | self.conv_0 = nn.Sequential(OrderedDict([ 309 | ('conv', nn.Conv2d(3, 
96, kernel_size=3, stride=2, bias=False)), 310 | ('bn', nn.BatchNorm2d(96, eps=0.001)) 311 | ])) 312 | self.cell_stem_0 = CellStem0(in_channels_left=96, out_channels_left=54, 313 | in_channels_right=96, 314 | out_channels_right=54) 315 | self.cell_stem_1 = Cell(in_channels_left=96, out_channels_left=108, 316 | in_channels_right=270, out_channels_right=108, 317 | match_prev_layer_dimensions=True, 318 | is_reduction=True) 319 | self.cell_0 = Cell(in_channels_left=270, out_channels_left=216, 320 | in_channels_right=540, out_channels_right=216, 321 | match_prev_layer_dimensions=True) 322 | self.cell_1 = Cell(in_channels_left=540, out_channels_left=216, 323 | in_channels_right=1080, out_channels_right=216) 324 | self.cell_2 = Cell(in_channels_left=1080, out_channels_left=216, 325 | in_channels_right=1080, out_channels_right=216) 326 | self.cell_3 = Cell(in_channels_left=1080, out_channels_left=216, 327 | in_channels_right=1080, out_channels_right=216) 328 | self.cell_4 = Cell(in_channels_left=1080, out_channels_left=432, 329 | in_channels_right=1080, out_channels_right=432, 330 | is_reduction=True, zero_pad=True) 331 | self.cell_5 = Cell(in_channels_left=1080, out_channels_left=432, 332 | in_channels_right=2160, out_channels_right=432, 333 | match_prev_layer_dimensions=True) 334 | self.cell_6 = Cell(in_channels_left=2160, out_channels_left=432, 335 | in_channels_right=2160, out_channels_right=432) 336 | self.cell_7 = Cell(in_channels_left=2160, out_channels_left=432, 337 | in_channels_right=2160, out_channels_right=432) 338 | self.cell_8 = Cell(in_channels_left=2160, out_channels_left=864, 339 | in_channels_right=2160, out_channels_right=864, 340 | is_reduction=True) 341 | self.cell_9 = Cell(in_channels_left=2160, out_channels_left=864, 342 | in_channels_right=4320, out_channels_right=864, 343 | match_prev_layer_dimensions=True) 344 | self.cell_10 = Cell(in_channels_left=4320, out_channels_left=864, 345 | in_channels_right=4320, out_channels_right=864) 346 | self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864, 347 | in_channels_right=4320, out_channels_right=864) 348 | self.relu = nn.ReLU() 349 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 350 | self.dropout = nn.Dropout(0.5) 351 | self.last_linear = nn.Linear(4320, num_classes) 352 | 353 | def features(self, x): 354 | x_conv_0 = self.conv_0(x) 355 | x_stem_0 = self.cell_stem_0(x_conv_0) 356 | x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0) 357 | x_cell_0 = self.cell_0(x_stem_0, x_stem_1) 358 | x_cell_1 = self.cell_1(x_stem_1, x_cell_0) 359 | x_cell_2 = self.cell_2(x_cell_0, x_cell_1) 360 | x_cell_3 = self.cell_3(x_cell_1, x_cell_2) 361 | x_cell_4 = self.cell_4(x_cell_2, x_cell_3) 362 | x_cell_5 = self.cell_5(x_cell_3, x_cell_4) 363 | x_cell_6 = self.cell_6(x_cell_4, x_cell_5) 364 | x_cell_7 = self.cell_7(x_cell_5, x_cell_6) 365 | x_cell_8 = self.cell_8(x_cell_6, x_cell_7) 366 | x_cell_9 = self.cell_9(x_cell_7, x_cell_8) 367 | x_cell_10 = self.cell_10(x_cell_8, x_cell_9) 368 | x_cell_11 = self.cell_11(x_cell_9, x_cell_10) 369 | return x_cell_11 370 | 371 | def logits(self, features): 372 | x = self.relu(features) 373 | x = self.avg_pool(x) 374 | x = x.view(x.size(0), -1) 375 | x = self.dropout(x) 376 | x = self.last_linear(x) 377 | return x 378 | 379 | def forward(self, input): 380 | x = self.features(input) 381 | x = self.logits(x) 382 | return x 383 | 384 | 385 | def pnasnet5large(num_classes=1000, pretrained='imagenet'): 386 | r"""PNASNet-5 model architecture from the 387 | `"Progressive Neural Architecture Search" 388 | `_ 
paper. 389 | """ 390 | if pretrained: 391 | settings = pretrained_settings['pnasnet5large'][pretrained] 392 | assert num_classes == settings[ 393 | 'num_classes'], 'num_classes should be {}, but is {}'.format( 394 | settings['num_classes'], num_classes) 395 | 396 | # both 'imagenet'&'imagenet+background' are loaded from same parameters 397 | model = PNASNet5Large(num_classes=1001) 398 | model.load_state_dict(model_zoo.load_url(settings['url'])) 399 | 400 | if pretrained == 'imagenet': 401 | new_last_linear = nn.Linear(model.last_linear.in_features, 1000) 402 | new_last_linear.weight.data = model.last_linear.weight.data[1:] 403 | new_last_linear.bias.data = model.last_linear.bias.data[1:] 404 | model.last_linear = new_last_linear 405 | 406 | model.input_space = settings['input_space'] 407 | model.input_size = settings['input_size'] 408 | model.input_range = settings['input_range'] 409 | 410 | model.mean = settings['mean'] 411 | model.std = settings['std'] 412 | else: 413 | model = PNASNet5Large(num_classes=num_classes) 414 | return model 415 | -------------------------------------------------------------------------------- /imnet_finetune/resnext_wsl.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # Optional list of dependencies required by the package 9 | 10 | ''' 11 | Code From : https://github.com/facebookresearch/WSL-Images/blob/master/hubconf.py 12 | ''' 13 | dependencies = ['torch', 'torchvision'] 14 | 15 | try: 16 | from torch.hub import load_state_dict_from_url 17 | except ImportError: 18 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 19 | 20 | from .Res import ResNet, Bottleneck 21 | 22 | 23 | model_urls = { 24 | 'resnext101_32x8d': 'https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth', 25 | 'resnext101_32x16d': 'https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth', 26 | 'resnext101_32x32d': 'https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth', 27 | 'resnext101_32x48d': 'https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth', 28 | } 29 | 30 | 31 | def _resnext(arch, block, layers, pretrained, progress, **kwargs): 32 | model = ResNet(block, layers, **kwargs) 33 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 34 | model.load_state_dict(state_dict) 35 | return model 36 | 37 | 38 | def resnext101_32x8d_wsl(progress=True, **kwargs): 39 | """Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data 40 | and finetuned on ImageNet from Figure 5 in 41 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 42 | Args: 43 | progress (bool): If True, displays a progress bar of the download to stderr. 44 | """ 45 | kwargs['groups'] = 32 46 | kwargs['width_per_group'] = 8 47 | return _resnext('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 48 | 49 | 50 | def resnext101_32x16d_wsl(progress=True, **kwargs): 51 | """Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data 52 | and finetuned on ImageNet from Figure 5 in 53 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 54 | Args: 55 | progress (bool): If True, displays a progress bar of the download to stderr. 
56 | """ 57 | kwargs['groups'] = 32 58 | kwargs['width_per_group'] = 16 59 | return _resnext('resnext101_32x16d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 60 | 61 | 62 | def resnext101_32x32d_wsl(progress=True, **kwargs): 63 | """Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data 64 | and finetuned on ImageNet from Figure 5 in 65 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 66 | Args: 67 | progress (bool): If True, displays a progress bar of the download to stderr. 68 | """ 69 | kwargs['groups'] = 32 70 | kwargs['width_per_group'] = 32 71 | return _resnext('resnext101_32x32d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 72 | 73 | 74 | def resnext101_32x48d_wsl(progress=True, **kwargs): 75 | """Constructs a ResNeXt-101 32x48 model pre-trained on weakly-supervised data 76 | and finetuned on ImageNet from Figure 5 in 77 | `"Exploring the Limits of Weakly Supervised Pretraining" `_ 78 | Args: 79 | progress (bool): If True, displays a progress bar of the download to stderr. 80 | """ 81 | kwargs['groups'] = 32 82 | kwargs['width_per_group'] = 48 83 | return _resnext('resnext101_32x48d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 84 | -------------------------------------------------------------------------------- /imnet_finetune/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from torch.utils.data.sampler import BatchSampler 8 | import torch 9 | import numpy as np 10 | from torch.utils.data.dataloader import default_collate 11 | from collections.abc import Mapping, Sequence 12 | import math 13 | import torch.distributed as dist 14 | 15 | class RASampler(torch.utils.data.Sampler): 16 | """ 17 | Batch Sampler with Repeated Augmentations (RA) 18 | - dataset_len: original length of the dataset 19 | - batch_size 20 | - repetitions: instances per image 21 | - len_factor: multiplicative factor for epoch size 22 | """ 23 | 24 | def __init__(self,dataset,num_replicas, rank, dataset_len, batch_size, repetitions=1, len_factor=1.0, shuffle=False, drop_last=False): 25 | self.dataset=dataset 26 | self.dataset_len = dataset_len 27 | self.batch_size = batch_size 28 | self.repetitions = repetitions 29 | self.len_images = int(dataset_len * len_factor) 30 | self.shuffle = shuffle 31 | self.drop_last = drop_last 32 | if num_replicas is None: 33 | if not dist.is_available(): 34 | raise RuntimeError("Requires distributed package to be available") 35 | num_replicas = dist.get_world_size() 36 | if rank is None: 37 | if not dist.is_available(): 38 | raise RuntimeError("Requires distributed package to be available") 39 | rank = dist.get_rank() 40 | self.dataset = dataset 41 | self.num_replicas = num_replicas 42 | self.rank = rank 43 | self.epoch = 0 44 | self.num_samples = int(math.ceil(len(self.dataset) * self.repetitions * 1.0 / self.num_replicas)) 45 | self.total_size = self.num_samples * self.num_replicas 46 | 47 | 48 | def shuffler(self): 49 | if self.shuffle: 50 | new_perm = lambda: iter(np.random.permutation(self.dataset_len)) 51 | else: 52 | new_perm = lambda: iter(np.arange(self.dataset_len)) 53 | shuffle = new_perm() 54 | while True: 55 | try: 56 | index = next(shuffle) 57 | except StopIteration: 58 | shuffle = new_perm() 59 | index = next(shuffle) 60 | for repetition in 
range(self.repetitions): 61 | yield index 62 | 63 | def __iter__(self): 64 | shuffle = iter(self.shuffler()) 65 | seen = 0 66 | indices=[] 67 | for _ in range(self.len_images): 68 | index = next(shuffle) 69 | indices.append(index) 70 | indices += indices[:(self.total_size - len(indices))] 71 | assert len(indices) == self.total_size 72 | # subsample 73 | indices = indices[self.rank:self.total_size:self.num_replicas] 74 | assert len(indices) == self.num_samples 75 | 76 | return iter(indices) 77 | 78 | 79 | def __len__(self): 80 | return self.num_samples 81 | 82 | def set_epoch(self, epoch): 83 | self.epoch = epoch 84 | 85 | def list_collate(batch): 86 | """ 87 | Collate into a list instead of a tensor to deal with variable-sized inputs 88 | """ 89 | elem_type = type(batch[0]) 90 | if isinstance(batch[0], torch.Tensor): 91 | return batch 92 | elif elem_type.__module__ == 'numpy': 93 | if elem_type.__name__ == 'ndarray': 94 | return list_collate([torch.from_numpy(b) for b in batch]) 95 | elif isinstance(batch[0], Mapping): 96 | return {key: list_collate([d[key] for d in batch]) for key in batch[0]} 97 | elif isinstance(batch[0], Sequence): 98 | transposed = zip(*batch) 99 | return [list_collate(samples) for samples in transposed] 100 | return default_collate(batch) 101 | -------------------------------------------------------------------------------- /imnet_finetune/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import os 8 | import os.path as osp 9 | from typing import Optional 10 | import torch 11 | import torch.distributed 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import attr 15 | from torchvision import datasets 16 | import torchvision.models as models 17 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 18 | import numpy as np 19 | from .config import TrainerConfig, ClusterConfig 20 | from .transforms import get_transforms 21 | from .resnext_wsl import resnext101_32x48d_wsl 22 | from .pnasnet import pnasnet5large 23 | try: 24 | from timm.models import create_model #From: https://github.com/rwightman/pytorch-image-models 25 | from timm.models.efficientnet import default_cfgs 26 | has_timm = True 27 | except ImportError: 28 | has_timm = False 29 | 30 | 31 | 32 | @attr.s(auto_attribs=True) 33 | class TrainerState: 34 | """ 35 | Contains the state of the Trainer. 36 | It can be saved to checkpoint the training and loaded to resume it. 
37 | """ 38 | 39 | epoch: int 40 | accuracy:float 41 | model: nn.Module 42 | optimizer: optim.Optimizer 43 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler 44 | 45 | def save(self, filename: str) -> None: 46 | data = attr.asdict(self) 47 | # store only the state dict 48 | data["model"] = self.model.state_dict() 49 | data["optimizer"] = self.optimizer.state_dict() 50 | data["lr_scheduler"] = self.lr_scheduler.state_dict() 51 | data["accuracy"] = self.accuracy 52 | torch.save(data, filename) 53 | 54 | @classmethod 55 | def load(cls, filename: str, default: "TrainerState") -> "TrainerState": 56 | data = torch.load(filename) 57 | # We need this default to load the state dict 58 | model = default.model 59 | model.load_state_dict(data["model"]) 60 | data["model"] = model 61 | 62 | optimizer = default.optimizer 63 | optimizer.load_state_dict(data["optimizer"]) 64 | data["optimizer"] = optimizer 65 | 66 | lr_scheduler = default.lr_scheduler 67 | lr_scheduler.load_state_dict(data["lr_scheduler"]) 68 | data["lr_scheduler"] = lr_scheduler 69 | return cls(**data) 70 | 71 | 72 | class Trainer: 73 | def __init__(self, train_cfg: TrainerConfig, cluster_cfg: ClusterConfig) -> None: 74 | self._train_cfg = train_cfg 75 | self._cluster_cfg = cluster_cfg 76 | 77 | def __call__(self) -> Optional[float]: 78 | """ 79 | Called for each task. 80 | 81 | :return: The master task return the final accuracy of the model. 82 | """ 83 | self._setup_process_group() 84 | self._init_state() 85 | final_acc = self._train() 86 | return final_acc 87 | 88 | def checkpoint(self, rm_init=True): 89 | save_dir = osp.join(self._train_cfg.save_folder, str(self._train_cfg.job_id)) 90 | os.makedirs(save_dir, exist_ok=True) 91 | self._state.save(osp.join(save_dir, "checkpoint.pth")) 92 | self._state.save(osp.join(save_dir, "checkpoint_"+str(self._state.epoch)+".pth")) 93 | 94 | if rm_init: 95 | os.remove(self._cluster_cfg.dist_url[7:]) 96 | empty_trainer = Trainer(self._train_cfg, self._cluster_cfg) 97 | return empty_trainer 98 | 99 | def _setup_process_group(self) -> None: 100 | torch.cuda.set_device(self._train_cfg.local_rank) 101 | torch.distributed.init_process_group( 102 | backend=self._cluster_cfg.dist_backend, 103 | init_method=self._cluster_cfg.dist_url, 104 | world_size=self._train_cfg.num_tasks, 105 | rank=self._train_cfg.global_rank, 106 | ) 107 | print(f"Process group: {self._train_cfg.num_tasks} tasks, rank: {self._train_cfg.global_rank}") 108 | 109 | def _init_state(self) -> None: 110 | """ 111 | Initialize the state and load it from an existing checkpoint if any 112 | """ 113 | torch.manual_seed(0) 114 | np.random.seed(0) 115 | print("Create data loaders", flush=True) 116 | 117 | Input_size_Image=self._train_cfg.input_size 118 | 119 | print("Input size : "+str(Input_size_Image)) 120 | print("Model : " + str(self._train_cfg.architecture) ) 121 | backbone_architecture=None 122 | if self._train_cfg.architecture=='PNASNet' : 123 | backbone_architecture='pnasnet5large' 124 | 125 | transformation=get_transforms(input_size=self._train_cfg.input_size,test_size=self._train_cfg.input_size, kind='full', crop=True, need=('train', 'val'), backbone=backbone_architecture) 126 | transform_test = transformation['val'] 127 | 128 | 129 | train_set = datasets.ImageFolder(self._train_cfg.imnet_path+ '/train',transform=transform_test) 130 | 131 | train_sampler = torch.utils.data.distributed.DistributedSampler( 132 | train_set,num_replicas=self._train_cfg.num_tasks, rank=self._train_cfg.global_rank 133 | ) 134 | 135 | self._train_loader = 
torch.utils.data.DataLoader( 136 | train_set, 137 | batch_size=self._train_cfg.batch_per_gpu, 138 | num_workers=(self._train_cfg.workers-1), 139 | sampler=train_sampler, 140 | ) 141 | test_set = datasets.ImageFolder(self._train_cfg.imnet_path + '/val',transform=transform_test) 142 | 143 | 144 | self._test_loader = torch.utils.data.DataLoader( 145 | test_set, batch_size=self._train_cfg.batch_per_gpu, shuffle=False, num_workers=(self._train_cfg.workers-1), 146 | ) 147 | 148 | print(f"Total batch_size: {self._train_cfg.batch_per_gpu * self._train_cfg.num_tasks}", flush=True) 149 | 150 | print("Create distributed model", flush=True) 151 | 152 | if self._train_cfg.architecture=='PNASNet' : 153 | model= pnasnet5large(pretrained='imagenet') 154 | 155 | if self._train_cfg.architecture=='ResNet50' : 156 | model=models.resnet50(pretrained=False) 157 | pretrained_dict=torch.load(self._train_cfg.resnet_weight_path,map_location='cpu')['model'] 158 | model_dict = model.state_dict() 159 | count=0 160 | count2=0 161 | for k in model_dict.keys(): 162 | count=count+1.0 163 | if(('module.'+k) in pretrained_dict.keys()): 164 | count2=count2+1.0 165 | model_dict[k]=pretrained_dict.get(('module.'+k)) 166 | model.load_state_dict(model_dict) 167 | print("load "+str(count2*100/count)+" %") 168 | 169 | assert int(count2*100/count)== 100,"model loading error" 170 | 171 | if self._train_cfg.architecture=='IGAM_Resnext101_32x48d' : 172 | model=resnext101_32x48d_wsl(progress=True) 173 | 174 | if self._train_cfg.architecture=='PNASNet' : 175 | for name, child in model.named_children(): 176 | if 'last_linear' not in name and 'cell_11' not in name and 'cell_10' not in name and 'cell_9' not in name: 177 | for name2, params in child.named_parameters(): 178 | params.requires_grad = False 179 | elif not self._train_cfg.architecture=='EfficientNet' : 180 | 181 | for name, child in model.named_children(): 182 | if 'fc' not in name: 183 | for name2, params in child.named_parameters(): 184 | params.requires_grad = False 185 | 186 | if self._train_cfg.architecture=='EfficientNet' : 187 | assert has_timm 188 | model = create_model(self._train_cfg.EfficientNet_models,pretrained=False,num_classes=1000) #see https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/efficientnet.py for name 189 | for name, child in model.named_children(): 190 | if 'classifier' not in name: 191 | for name2, params in child.named_parameters(): 192 | params.requires_grad = False 193 | 194 | pretrained_dict=load_state_dict_from_url(default_cfgs[self._train_cfg.EfficientNet_models]['url'],map_location='cpu') 195 | model_dict = model.state_dict() 196 | for k in model_dict.keys(): 197 | if(k in pretrained_dict.keys()): 198 | model_dict[k]=pretrained_dict.get(k) 199 | model.load_state_dict(model_dict) 200 | torch.cuda.empty_cache() 201 | model.classifier.requires_grad=True 202 | model.conv_head.requires_grad=True 203 | 204 | model.cuda(self._train_cfg.local_rank) 205 | model = torch.nn.parallel.DistributedDataParallel( 206 | model, device_ids=[self._train_cfg.local_rank], output_device=self._train_cfg.local_rank 207 | ) 208 | linear_scaled_lr = 8.0 * self._train_cfg.lr * self._train_cfg.batch_per_gpu * self._train_cfg.num_tasks /512.0 209 | optimizer = optim.SGD(model.parameters(), lr=linear_scaled_lr, momentum=0.9,weight_decay=1e-4) 210 | lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30) 211 | self._state = TrainerState( 212 | epoch=0,accuracy=0.0, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler 213 | ) 214 | 
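        # Note on the learning rate: `linear_scaled_lr` applies the linear
        # scaling rule, multiplying the base LR by the total batch size
        # (8 * batch_per_gpu * num_tasks) relative to a reference batch of 512
        # (cf. https://arxiv.org/abs/1706.02677, also cited in Res.py).
        #
        # Caveat (a suggested fix, not part of the original code): in the
        # EfficientNet branch above, `model.classifier.requires_grad = True`
        # and `model.conv_head.requires_grad = True` only set plain Python
        # attributes on nn.Module objects and do not re-enable gradients on
        # their parameters, so `conv_head` stays frozen. The likely intent,
        # applied before the DistributedDataParallel wrapping, would be:
        #
        #     for p in model.classifier.parameters():
        #         p.requires_grad = True
        #     for p in model.conv_head.parameters():
        #         p.requires_grad = True
        #
        # (`classifier` is unaffected in practice, since the freezing loop
        # above never froze it.)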
checkpoint_fn = osp.join(self._train_cfg.save_folder, str(self._train_cfg.job_id), "checkpoint.pth") 215 | if os.path.isfile(checkpoint_fn): 216 | print(f"Load existing checkpoint from {checkpoint_fn}", flush=True) 217 | self._state = TrainerState.load(checkpoint_fn, default=self._state) 218 | 219 | def _train(self) -> Optional[float]: 220 | criterion = nn.CrossEntropyLoss() 221 | print_freq = 10 222 | acc = None 223 | max_accuracy=0.0 224 | 225 | print("Evaluation before fine-tuning") 226 | correct = 0 227 | total = 0 228 | count=0.0 229 | running_val_loss = 0.0 230 | self._state.model.eval() 231 | 232 | if self._train_cfg.architecture=='PNASNet' : 233 | self._state.model.module.cell_11.eval() 234 | self._state.model.module.cell_10.eval() 235 | self._state.model.module.cell_9.eval() 236 | self._state.model.module.dropout.eval() 237 | elif self._train_cfg.architecture=='EfficientNet' : 238 | self._state.model.module.classifier.eval() 239 | self._state.model.module.conv_head.eval() 240 | self._state.model.module.bn2.eval() 241 | 242 | else: 243 | self._state.model.module.layer4[2].bn3.eval() 244 | 245 | 246 | with torch.no_grad(): 247 | for data in self._test_loader: 248 | images, labels = data 249 | images = images.cuda(self._train_cfg.local_rank, non_blocking=True) 250 | labels = labels.cuda(self._train_cfg.local_rank, non_blocking=True) 251 | outputs = self._state.model(images) 252 | loss_val = criterion(outputs, labels) 253 | _, predicted = torch.max(outputs.data, 1) 254 | total += labels.size(0) 255 | correct += (predicted == labels).sum().item() 256 | running_val_loss += loss_val.item() 257 | count=count+1.0 258 | 259 | acc = correct / total 260 | ls_nm=running_val_loss/count 261 | print(f"Accuracy of the network on the 50000 test images: {acc:.1%}", flush=True) 262 | print(f"Loss of the network on the 50000 test images: {ls_nm:.3f}", flush=True) 263 | print("Accuracy before fine-tuning : "+str(acc)) 264 | max_accuracy=np.max((max_accuracy,acc)) 265 | start_epoch = self._state.epoch 266 | # Start from the loaded epoch 267 | for epoch in range(start_epoch, self._train_cfg.epochs): 268 | print(f"Start epoch {epoch}", flush=True) 269 | self._state.model.eval() 270 | if self._train_cfg.architecture=='PNASNet' : 271 | self._state.model.module.cell_11.train() 272 | self._state.model.module.cell_10.train() 273 | self._state.model.module.cell_9.train() 274 | self._state.model.module.dropout.train() 275 | elif self._train_cfg.architecture=='EfficientNet' : 276 | self._state.model.module.classifier.train() 277 | self._state.model.module.conv_head.train() 278 | self._state.model.module.bn2.train() 279 | else: 280 | self._state.model.module.layer4[2].bn3.train() 281 | 282 | 283 | self._state.lr_scheduler.step(epoch) 284 | self._state.epoch = epoch 285 | running_loss = 0.0 286 | for i, data in enumerate(self._train_loader): 287 | inputs, labels = data 288 | inputs = inputs.cuda(self._train_cfg.local_rank, non_blocking=True) 289 | labels = labels.cuda(self._train_cfg.local_rank, non_blocking=True) 290 | 291 | outputs = self._state.model(inputs) 292 | loss = criterion(outputs, labels) 293 | 294 | self._state.optimizer.zero_grad() 295 | loss.backward() 296 | self._state.optimizer.step() 297 | 298 | running_loss += loss.item() 299 | if i % print_freq == print_freq - 1: 300 | print(f"[{epoch:02d}, {i:05d}] loss: {running_loss/print_freq:.3f}", flush=True) 301 | running_loss = 0.0 302 | 303 | 304 | if epoch==self._train_cfg.epochs-1: 305 | print("Start evaluation of the model", flush=True) 306 | 307 | 
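            # Final-epoch evaluation mirrors the pre-fine-tuning pass: the
            # network stays in eval() mode everywhere (frozen batch-norm
            # statistics and dropout) except the fine-tuned blocks, and top-1
            # accuracy is computed over the full 50000-image validation set.
            # Side note: `lr_scheduler.step(epoch)` as called above is
            # deprecated in PyTorch >= 1.4; a plain per-epoch `step()` is the
            # modern equivalent.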
correct = 0 308 | total = 0 309 | count=0.0 310 | running_val_loss = 0.0 311 | self._state.model.eval() 312 | 313 | if self._train_cfg.architecture=='PNASNet' : 314 | self._state.model.module.cell_11.eval() 315 | self._state.model.module.cell_10.eval() 316 | self._state.model.module.cell_9.eval() 317 | self._state.model.module.dropout.eval() 318 | elif self._train_cfg.architecture=='EfficientNet' : 319 | self._state.model.module.classifier.eval() 320 | self._state.model.module.conv_head.eval() 321 | self._state.model.module.bn2.eval() 322 | else: 323 | self._state.model.module.layer4[2].bn3.eval() 324 | 325 | with torch.no_grad(): 326 | for data in self._test_loader: 327 | images, labels = data 328 | images = images.cuda(self._train_cfg.local_rank, non_blocking=True) 329 | labels = labels.cuda(self._train_cfg.local_rank, non_blocking=True) 330 | outputs = self._state.model(images) 331 | loss_val = criterion(outputs, labels) 332 | _, predicted = torch.max(outputs.data, 1) 333 | total += labels.size(0) 334 | correct += (predicted == labels).sum().item() 335 | running_val_loss += loss_val.item() 336 | count=count+1.0 337 | 338 | acc = correct / total 339 | ls_nm=running_val_loss/count 340 | print(f"Accuracy of the network on the 50000 test images: {acc:.1%}", flush=True) 341 | print(f"Loss of the network on the 50000 test images: {ls_nm:.3f}", flush=True) 342 | self._state.accuracy = acc 343 | if self._train_cfg.global_rank == 0: 344 | self.checkpoint(rm_init=False) 345 | if epoch==self._train_cfg.epochs-1: 346 | return acc 347 | 348 | 349 | 350 | -------------------------------------------------------------------------------- /imnet_finetune/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | import torch 8 | import torchvision.transforms.functional as F 9 | from torchvision import transforms 10 | 11 | import numpy as np 12 | 13 | class Resize(transforms.Resize): 14 | """ 15 | Resize with a ``largest=False'' argument 16 | allowing to resize to a common largest side without cropping 17 | """ 18 | 19 | 20 | def __init__(self, size, largest=False, **kwargs): 21 | super().__init__(size, **kwargs) 22 | self.largest = largest 23 | 24 | @staticmethod 25 | def target_size(w, h, size, largest=False): 26 | if h < w and largest: 27 | w, h = size, int(size * h / w) 28 | else: 29 | w, h = int(size * w / h), size 30 | size = (h, w) 31 | return size 32 | 33 | def __call__(self, img): 34 | size = self.size 35 | w, h = img.size 36 | target_size = self.target_size(w, h, size, self.largest) 37 | return F.resize(img, target_size, self.interpolation) 38 | 39 | def __repr__(self): 40 | r = super().__repr__() 41 | return r[:-1] + ', largest={})'.format(self.largest) 42 | 43 | 44 | 45 | 46 | def get_transforms(input_size=224,test_size=224, kind='full', crop=True, need=('train', 'val'), backbone=None): 47 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 48 | if backbone is not None and backbone in ['pnasnet5large', 'nasnetamobile']: 49 | mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] 50 | 51 | transformations = {} 52 | if 'train' in need: 53 | if kind == 'torch': 54 | transformations['train'] = transforms.Compose([ 55 | transforms.RandomResizedCrop(input_size), 56 | transforms.RandomHorizontalFlip(), 57 | transforms.ToTensor(), 58 | transforms.Normalize(mean, std), 59 | ]) 60 | elif kind == 'full': 61 | transformations['train'] = transforms.Compose([ 62 | transforms.RandomResizedCrop(input_size), 63 | transforms.RandomHorizontalFlip(), 64 | transforms.ColorJitter(0.3, 0.3, 0.3), 65 | transforms.ToTensor(), 66 | transforms.Normalize(mean, std), 67 | ]) 68 | 69 | else: 70 | raise ValueError('Transforms kind {} unknown'.format(kind)) 71 | if 'val' in need: 72 | if crop: 73 | transformations['val'] = transforms.Compose( 74 | [Resize(int((256 / 224) * test_size)), # to maintain same ratio w.r.t. 224 images 75 | transforms.CenterCrop(test_size), 76 | transforms.ToTensor(), 77 | transforms.Normalize(mean, std)]) 78 | else: 79 | transformations['val'] = transforms.Compose( 80 | [Resize(test_size, largest=True), # to maintain same ratio w.r.t. 224 images 81 | transforms.ToTensor(), 82 | transforms.Normalize(mean, std)]) 83 | return transformations 84 | 85 | transforms_list = ['torch', 'full'] 86 | -------------------------------------------------------------------------------- /imnet_resnet50_scratch/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from .train import Trainer 8 | from .config import TrainerConfig, ClusterConfig 9 | -------------------------------------------------------------------------------- /imnet_resnet50_scratch/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | from typing import NamedTuple 8 | 9 | 10 | class ClusterConfig(NamedTuple): 11 | dist_backend: str 12 | dist_url: str 13 | 14 | 15 | class TrainerConfig(NamedTuple): 16 | data_folder: str 17 | epochs: int 18 | lr: float 19 | input_size: int 20 | batch_per_gpu: int 21 | save_folder: str 22 | imnet_path: str 23 | workers: int 24 | local_rank: int 25 | global_rank: int 26 | num_tasks: int 27 | job_id: str 28 | -------------------------------------------------------------------------------- /imnet_resnet50_scratch/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from torch.utils.data.sampler import BatchSampler 8 | import torch 9 | import numpy as np 10 | from torch.utils.data.dataloader import default_collate 11 | from collections.abc import Mapping, Sequence 12 | import math 13 | import torch.distributed as dist 14 | 15 | class RASampler(torch.utils.data.Sampler): 16 | """ 17 | Batch Sampler with Repeated Augmentations (RA) 18 | - dataset_len: original length of the dataset 19 | - batch_size 20 | - repetitions: instances per image 21 | - len_factor: multiplicative factor for epoch size 22 | """ 23 | 24 | def __init__(self,dataset,num_replicas, rank, dataset_len, batch_size, repetitions=1, len_factor=1.0, shuffle=False, drop_last=False): 25 | self.dataset=dataset 26 | self.dataset_len = dataset_len 27 | self.batch_size = batch_size 28 | self.repetitions = repetitions 29 | self.len_images = int(dataset_len * len_factor) 30 | self.shuffle = shuffle 31 | self.drop_last = drop_last 32 | if num_replicas is None: 33 | if not dist.is_available(): 34 | raise RuntimeError("Requires distributed package to be available") 35 | num_replicas = dist.get_world_size() 36 | if rank is None: 37 | if not dist.is_available(): 38 | raise RuntimeError("Requires distributed package to be available") 39 | rank = dist.get_rank() 40 | self.dataset = dataset 41 | self.num_replicas = num_replicas 42 | self.rank = rank 43 | self.epoch = 0 44 | self.num_samples = int(math.ceil(len(self.dataset) * self.repetitions * 1.0 / self.num_replicas)) 45 | self.total_size = self.num_samples * self.num_replicas 46 | 47 | 48 | def shuffler(self): 49 | if self.shuffle: 50 | new_perm = lambda: iter(np.random.permutation(self.dataset_len)) 51 | else: 52 | new_perm = lambda: iter(np.arange(self.dataset_len)) 53 | shuffle = new_perm() 54 | while True: 55 | try: 56 | index = next(shuffle) 57 | except StopIteration: 58 | shuffle = new_perm() 59 | index = next(shuffle) 60 | for repetition in range(self.repetitions): 61 | yield index 62 | 63 | def __iter__(self): 64 | shuffle = iter(self.shuffler()) 65 | seen = 0 66 | indices=[] 67 | for _ in range(self.len_images): 68 | index = next(shuffle) 69 | indices.append(index) 70 | indices += indices[:(self.total_size - len(indices))] 71 | assert len(indices) == self.total_size 72 | # subsample 73 | indices = indices[self.rank:self.total_size:self.num_replicas] 74 | assert len(indices) == self.num_samples 75 | 76 | return iter(indices) 77 | 78 | 79 | def __len__(self): 80 | return self.num_samples 81 | 82 | def set_epoch(self, epoch): 83 | self.epoch = epoch 84 | 85 | def list_collate(batch): 86 | """ 87 | Collate into a list instead of a tensor to deal with variable-sized inputs 88 | """ 89 | elem_type = type(batch[0]) 90 | 
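    # Recursive collation: tensors are returned as a plain Python list (not
    # stacked), numpy arrays are first converted to tensors, mappings and
    # sequences are collated element-wise, and anything else falls through to
    # PyTorch's default_collate. This is what lets variable-sized inputs
    # survive batching.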
if isinstance(batch[0], torch.Tensor): 91 | return batch 92 | elif elem_type.__module__ == 'numpy': 93 | if elem_type.__name__ == 'ndarray': 94 | return list_collate([torch.from_numpy(b) for b in batch]) 95 | elif isinstance(batch[0], Mapping): 96 | return {key: list_collate([d[key] for d in batch]) for key in batch[0]} 97 | elif isinstance(batch[0], Sequence): 98 | transposed = zip(*batch) 99 | return [list_collate(samples) for samples in transposed] 100 | return default_collate(batch) 101 | -------------------------------------------------------------------------------- /imnet_resnet50_scratch/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import os 8 | import os.path as osp 9 | from typing import Optional 10 | import torch 11 | import torch.distributed 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import attr 15 | from torchvision import datasets 16 | import torchvision.models as models 17 | import numpy as np 18 | from .config import TrainerConfig, ClusterConfig 19 | from .transforms import get_transforms 20 | from .samplers import RASampler 21 | @attr.s(auto_attribs=True) 22 | class TrainerState: 23 | """ 24 | Contains the state of the Trainer. 25 | It can be saved to checkpoint the training and loaded to resume it. 26 | """ 27 | 28 | epoch: int 29 | accuracy:float 30 | model: nn.Module 31 | optimizer: optim.Optimizer 32 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler 33 | 34 | def save(self, filename: str) -> None: 35 | data = attr.asdict(self) 36 | # store only the state dict 37 | data["model"] = self.model.state_dict() 38 | data["optimizer"] = self.optimizer.state_dict() 39 | data["lr_scheduler"] = self.lr_scheduler.state_dict() 40 | data["accuracy"] = self.accuracy 41 | torch.save(data, filename) 42 | 43 | @classmethod 44 | def load(cls, filename: str, default: "TrainerState") -> "TrainerState": 45 | data = torch.load(filename) 46 | # We need this default to load the state dict 47 | model = default.model 48 | model.load_state_dict(data["model"]) 49 | data["model"] = model 50 | 51 | optimizer = default.optimizer 52 | optimizer.load_state_dict(data["optimizer"]) 53 | data["optimizer"] = optimizer 54 | 55 | lr_scheduler = default.lr_scheduler 56 | lr_scheduler.load_state_dict(data["lr_scheduler"]) 57 | data["lr_scheduler"] = lr_scheduler 58 | return cls(**data) 59 | 60 | 61 | class Trainer: 62 | def __init__(self, train_cfg: TrainerConfig, cluster_cfg: ClusterConfig) -> None: 63 | self._train_cfg = train_cfg 64 | self._cluster_cfg = cluster_cfg 65 | 66 | def __call__(self) -> Optional[float]: 67 | """ 68 | Called for each task. 69 | 70 | :return: The master task return the final accuracy of the model. 
71 | """ 72 | self._setup_process_group() 73 | self._init_state() 74 | final_acc = self._train() 75 | return final_acc 76 | 77 | def checkpoint(self, rm_init=True): 78 | save_dir = osp.join(self._train_cfg.save_folder, str(self._train_cfg.job_id)) 79 | os.makedirs(save_dir, exist_ok=True) 80 | self._state.save(osp.join(save_dir, "checkpoint.pth")) 81 | self._state.save(osp.join(save_dir, "checkpoint_"+str(self._state.epoch)+".pth")) 82 | if rm_init: 83 | os.remove(self._cluster_cfg.dist_url[7:]) 84 | empty_trainer = Trainer(self._train_cfg, self._cluster_cfg) 85 | return empty_trainer 86 | 87 | def _setup_process_group(self) -> None: 88 | torch.cuda.set_device(self._train_cfg.local_rank) 89 | torch.distributed.init_process_group( 90 | backend=self._cluster_cfg.dist_backend, 91 | init_method=self._cluster_cfg.dist_url, 92 | world_size=self._train_cfg.num_tasks, 93 | rank=self._train_cfg.global_rank, 94 | ) 95 | print(f"Process group: {self._train_cfg.num_tasks} tasks, rank: {self._train_cfg.global_rank}") 96 | 97 | def _init_state(self) -> None: 98 | """ 99 | Initialize the state and load it from an existing checkpoint if any 100 | """ 101 | torch.manual_seed(0) 102 | np.random.seed(0) 103 | print("Create data loaders", flush=True) 104 | 105 | Input_size_Image=self._train_cfg.input_size 106 | 107 | Test_size=Input_size_Image 108 | print("Input size : "+str(Input_size_Image)) 109 | print("Test size : "+str(Input_size_Image)) 110 | print("Initial LR :"+str(self._train_cfg.lr)) 111 | 112 | transf=get_transforms(input_size=Input_size_Image,test_size=Test_size, kind='full', crop=True, need=('train', 'val'), backbone=None) 113 | transform_train = transf['train'] 114 | transform_test = transf['val'] 115 | 116 | train_set = datasets.ImageFolder(self._train_cfg.imnet_path + '/train',transform=transform_train) 117 | train_sampler = RASampler( 118 | train_set,self._train_cfg.num_tasks,self._train_cfg.global_rank,len(train_set),self._train_cfg.batch_per_gpu,repetitions=3,len_factor=2.0,shuffle=True, drop_last=False 119 | ) 120 | self._train_loader = torch.utils.data.DataLoader( 121 | train_set, 122 | batch_size=self._train_cfg.batch_per_gpu, 123 | num_workers=(self._train_cfg.workers-1), 124 | sampler=train_sampler, 125 | ) 126 | test_set = datasets.ImageFolder(self._train_cfg.imnet_path + '/val',transform=transform_test) 127 | 128 | self._test_loader = torch.utils.data.DataLoader( 129 | test_set, batch_size=self._train_cfg.batch_per_gpu, shuffle=False, num_workers=(self._train_cfg.workers-1),#sampler=test_sampler, Attention je le met pas pour l instant 130 | ) 131 | 132 | print(f"Total batch_size: {self._train_cfg.batch_per_gpu * self._train_cfg.num_tasks}", flush=True) 133 | 134 | print("Create distributed model", flush=True) 135 | model = models.resnet50(pretrained=False) 136 | 137 | model.cuda(self._train_cfg.local_rank) 138 | model = torch.nn.parallel.DistributedDataParallel( 139 | model, device_ids=[self._train_cfg.local_rank], output_device=self._train_cfg.local_rank 140 | ) 141 | linear_scaled_lr = 8.0 * self._train_cfg.lr * self._train_cfg.batch_per_gpu * self._train_cfg.num_tasks /512.0 142 | optimizer = optim.SGD(model.parameters(), lr=linear_scaled_lr, momentum=0.9,weight_decay=1e-4) 143 | lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30) 144 | self._state = TrainerState( 145 | epoch=0,accuracy=0.0, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler 146 | ) 147 | checkpoint_fn = osp.join(self._train_cfg.save_folder, str(self._train_cfg.job_id), "checkpoint.pth") 148 
| if os.path.isfile(checkpoint_fn): 149 | print(f"Load existing checkpoint from {checkpoint_fn}", flush=True) 150 | self._state = TrainerState.load(checkpoint_fn, default=self._state) 151 | 152 | def _train(self) -> Optional[float]: 153 | criterion = nn.CrossEntropyLoss() 154 | print_freq = 10 155 | acc = None 156 | max_accuracy=0.0 157 | # Start from the loaded epoch 158 | start_epoch = self._state.epoch 159 | for epoch in range(start_epoch, self._train_cfg.epochs): 160 | print(f"Start epoch {epoch}", flush=True) 161 | self._state.model.train() 162 | self._state.lr_scheduler.step(epoch) 163 | self._state.epoch = epoch 164 | running_loss = 0.0 165 | count=0 166 | for i, data in enumerate(self._train_loader): 167 | inputs, labels = data 168 | inputs = inputs.cuda(self._train_cfg.local_rank, non_blocking=True) 169 | labels = labels.cuda(self._train_cfg.local_rank, non_blocking=True) 170 | 171 | outputs = self._state.model(inputs) 172 | loss = criterion(outputs, labels) 173 | 174 | self._state.optimizer.zero_grad() 175 | loss.backward() 176 | self._state.optimizer.step() 177 | 178 | running_loss += loss.item() 179 | count=count+1 180 | if i % print_freq == print_freq - 1: 181 | print(f"[{epoch:02d}, {i:05d}] loss: {running_loss/print_freq:.3f}", flush=True) 182 | running_loss = 0.0 183 | if count>=5005 * 512 /(self._train_cfg.batch_per_gpu * self._train_cfg.num_tasks): 184 | break 185 | 186 | if epoch==self._train_cfg.epochs-1: 187 | print("Start evaluation of the model", flush=True) 188 | 189 | correct = 0 190 | total = 0 191 | count=0.0 192 | running_val_loss = 0.0 193 | self._state.model.eval() 194 | with torch.no_grad(): 195 | for data in self._test_loader: 196 | images, labels = data 197 | images = images.cuda(self._train_cfg.local_rank, non_blocking=True) 198 | labels = labels.cuda(self._train_cfg.local_rank, non_blocking=True) 199 | outputs = self._state.model(images) 200 | loss_val = criterion(outputs, labels) 201 | _, predicted = torch.max(outputs.data, 1) 202 | total += labels.size(0) 203 | correct += (predicted == labels).sum().item() 204 | running_val_loss += loss_val.item() 205 | count=count+1.0 206 | 207 | acc = correct / total 208 | ls_nm=running_val_loss/count 209 | print(f"Accuracy of the network on the 50000 test images: {acc:.1%}", flush=True) 210 | print(f"Loss of the network on the 50000 test images: {ls_nm:.3f}", flush=True) 211 | self._state.accuracy = acc 212 | if self._train_cfg.global_rank == 0: 213 | self.checkpoint(rm_init=False) 214 | print("accuracy val epoch "+str(epoch)+" acc= "+str(acc)) 215 | max_accuracy=np.max((max_accuracy,acc)) 216 | if epoch==self._train_cfg.epochs-1: 217 | return acc 218 | 219 | 220 | 221 | -------------------------------------------------------------------------------- /imnet_resnet50_scratch/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | import torch 8 | import torchvision.transforms.functional as F 9 | from torchvision import transforms 10 | 11 | 12 | import numpy as np 13 | 14 | class Resize(transforms.Resize): 15 | """ 16 | Resize with a ``largest=False'' argument 17 | allowing to resize to a common largest side without cropping 18 | """ 19 | 20 | 21 | def __init__(self, size, largest=False, **kwargs): 22 | super().__init__(size, **kwargs) 23 | self.largest = largest 24 | 25 | @staticmethod 26 | def target_size(w, h, size, largest=False): 27 | if h < w and largest: 28 | w, h = size, int(size * h / w) 29 | else: 30 | w, h = int(size * w / h), size 31 | size = (h, w) 32 | return size 33 | 34 | def __call__(self, img): 35 | size = self.size 36 | w, h = img.size 37 | target_size = self.target_size(w, h, size, self.largest) 38 | return F.resize(img, target_size, self.interpolation) 39 | 40 | def __repr__(self): 41 | r = super().__repr__() 42 | return r[:-1] + ', largest={})'.format(self.largest) 43 | 44 | 45 | 46 | 47 | 48 | def get_transforms(input_size=224,test_size=224, kind='full', crop=True, need=('train', 'val'), backbone=None): 49 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 50 | if backbone is not None and backbone in ['pnasnet5large', 'nasnetamobile']: 51 | mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] 52 | 53 | transformations = {} 54 | if 'train' in need: 55 | if kind == 'torch': 56 | transformations['train'] = transforms.Compose([ 57 | transforms.RandomResizedCrop(input_size), 58 | transforms.RandomHorizontalFlip(), 59 | transforms.ToTensor(), 60 | transforms.Normalize(mean, std), 61 | ]) 62 | elif kind == 'full': 63 | transformations['train'] = transforms.Compose([ 64 | transforms.RandomResizedCrop(input_size), 65 | transforms.RandomHorizontalFlip(), 66 | transforms.ColorJitter(0.3, 0.3, 0.3), 67 | transforms.ToTensor(), 68 | transforms.Normalize(mean, std), 69 | ]) 70 | 71 | else: 72 | raise ValueError('Transforms kind {} unknown'.format(kind)) 73 | if 'val' in need: 74 | if crop: 75 | transformations['val'] = transforms.Compose( 76 | [Resize(int((256 / 224) * test_size)), # to maintain same ratio w.r.t. 224 images 77 | transforms.CenterCrop(test_size), 78 | transforms.ToTensor(), 79 | transforms.Normalize(mean, std)]) 80 | else: 81 | transformations['val'] = transforms.Compose( 82 | [Resize(test_size, largest=True), 83 | transforms.ToTensor(), 84 | transforms.Normalize(mean, std)]) 85 | return transformations 86 | 87 | transforms_list = ['torch', 'full'] 88 | -------------------------------------------------------------------------------- /main_evaluate_imnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | #
7 | import os
8 | import uuid
9 | from pathlib import Path
10 | from imnet_evaluate import TrainerConfig, ClusterConfig, Trainer
11 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
12 | 
13 | 
14 | 
15 | def run(input_size,architecture,weight_path,imnet_path,batch,workers,shared_folder_path,job_id,local_rank,global_rank,num_tasks):
16 |     cluster_cfg = ClusterConfig(dist_backend="nccl", dist_url="")
17 |     shared_folder=None
18 |     data_folder_Path=None
19 |     if Path(str(shared_folder_path)).is_dir():
20 |         shared_folder=Path(shared_folder_path+"/evaluate/")
21 |     else:
22 |         raise RuntimeError("No shared folder available")
23 |     if Path(str(imnet_path)).is_dir():
24 |         data_folder_Path=Path(str(imnet_path))
25 |     else:
26 |         raise RuntimeError("ImageNet path is not a directory")
27 |     train_cfg = TrainerConfig(
28 |         data_folder=str(data_folder_Path),
29 |         architecture=architecture,
30 |         weight_path=weight_path,
31 |         input_size=input_size,
32 |         imnet_path=imnet_path,
33 |         batch_per_gpu=batch,
34 |         workers=workers,
35 |         local_rank=local_rank,
36 |         global_rank=global_rank,
37 |         num_tasks=num_tasks,
38 |         job_id=job_id,
39 |         save_folder=str(shared_folder),
40 | 
41 |     )
42 | 
43 |     # Create the executor
44 |     os.makedirs(str(shared_folder), exist_ok=True)
45 |     init_file = shared_folder / f"{uuid.uuid4().hex}_init"
46 |     if init_file.exists():
47 |         os.remove(str(init_file))
48 | 
49 |     cluster_cfg = cluster_cfg._replace(dist_url=init_file.as_uri())
50 |     trainer = Trainer(train_cfg, cluster_cfg)
51 | 
52 |     # The code should be launched on each GPU
53 |     try:
54 |         if global_rank==0:
55 |             val_accuracy = trainer.__call__()
56 |             print(f"Validation accuracy: {val_accuracy}")
57 |         else:
58 |             trainer.__call__()
59 |     except Exception as exc:
60 |         print(f"Job failed: {exc}")
61 | 
62 | 
63 | if __name__ == "__main__":
64 |     parser = ArgumentParser(description="Evaluation script for FixRes models",formatter_class=ArgumentDefaultsHelpFormatter)
65 |     parser.add_argument('--input-size', default=320, type=int, help='Images input size')
66 |     parser.add_argument('--architecture', default='IGAM_Resnext101_32x48d', type=str,choices=['ResNet50', 'PNASNet' , 'IGAM_Resnext101_32x48d'], help='Neural network architecture')
67 |     parser.add_argument('--weight-path', default='/where/are/the/weights.pth', type=str, help='Neural network weights')
68 |     parser.add_argument('--imnet-path', default='/the/imagenet/path', type=str, help='ImageNet dataset path')
69 |     parser.add_argument('--shared-folder-path', default='your/shared/folder', type=str, help='Shared folder')
70 |     parser.add_argument('--batch', default=32, type=int, help='Batch size per GPU')
71 |     parser.add_argument('--workers', default=40, type=int, help='Number of CPU workers')
72 |     parser.add_argument('--job-id', default='0', type=str, help='ID of the execution')
73 |     parser.add_argument('--local-rank', default=0, type=int, help='GPU: local rank')
74 |     parser.add_argument('--global-rank', default=0, type=int, help='GPU: global rank')
75 |     parser.add_argument('--num-tasks', default=32, type=int, help='How many GPUs are used')
76 | 
77 |     args = parser.parse_args()
78 |     run(args.input_size,args.architecture,args.weight_path,args.imnet_path,args.batch,args.workers,args.shared_folder_path,args.job_id,args.local_rank,args.global_rank,args.num_tasks)
79 | 
--------------------------------------------------------------------------------
/main_evaluate_softmax.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 10 | 11 | def run(path,name,version): 12 | if version=='v1': 13 | softmax=np.load(path+name+'_softmax.npy') 14 | else: 15 | softmax=np.load(path+name+'_softmax_v2.npy') 16 | 17 | label=np.load(path+'labels.npy') 18 | assert softmax.shape[0] == 50000, "softmax file doesn't match the 50,000-image ImageNet validation set" 19 | 20 | # Top-1 accuracy: does the highest-scoring class match the ground-truth label? 21 | predictions = np.argmax(softmax, axis=1) 22 | correct = float((predictions == label).sum()) 23 | 24 | acc = (correct / 50000.0) * 100 25 | print("Top-1 Accuracy: %.1f" % acc) 26 | return acc 27 | 28 | 29 | 30 | if __name__ == "__main__": 31 | parser = ArgumentParser(description="Evaluation script for softmax outputs extracted from FixRes models",formatter_class=ArgumentDefaultsHelpFormatter) 32 | parser.add_argument('--architecture', default='IGAM_Resnext101_32x48d', type=str,choices=['ResNet50' , 'ResNet50CutMix', 'PNASNet' , 'IGAM_Resnext101_32x48d'], help='Neural network architecture') 33 | parser.add_argument('--save-path', default='/where/are/save/softmax/', type=str, help='Path where the softmax outputs were saved') 34 | parser.add_argument('--version', default='v1', type=str,choices=['v1' , 'v2'], help='softmax file version') 35 | args = parser.parse_args() 36 | run(args.save_path,args.architecture,args.version) 37 | -------------------------------------------------------------------------------- /main_extract.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | import os 8 | import uuid 9 | from pathlib import Path 10 | from imnet_extract import TrainerConfig, ClusterConfig, Trainer 11 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 12 | 13 | 14 | def run(input_size,architecture,weight_path,dataset_path,batch,workers,save_path,shared_folder_path,job_id,local_rank,global_rank,num_tasks): 15 | shared_folder=None 16 | data_folder_Path=None 17 | if Path(str(shared_folder_path)).is_dir(): 18 | shared_folder=Path(shared_folder_path+"/extract/") 19 | else: 20 | raise RuntimeError("No shared folder available") 21 | if Path(str(dataset_path)).is_dir(): 22 | data_folder_Path=Path(str(dataset_path)) 23 | else: 24 | raise RuntimeError("Dataset path is not a directory") 25 | cluster_cfg = ClusterConfig(dist_backend="nccl", dist_url="") 26 | train_cfg = TrainerConfig( 27 | data_folder=str(data_folder_Path), 28 | architecture=architecture, 29 | weight_path=weight_path, 30 | input_size=input_size, 31 | dataset_path=dataset_path, 32 | batch_per_gpu=batch, 33 | workers=workers, 34 | save_path=save_path, 35 | local_rank=local_rank, 36 | global_rank=global_rank, 37 | num_tasks=num_tasks, 38 | job_id=job_id, 39 | save_folder=str(shared_folder), 40 | 41 | ) 42 | 43 | os.makedirs(str(shared_folder), exist_ok=True) 44 | init_file = shared_folder / f"{uuid.uuid4().hex}_init" 45 | if init_file.exists(): 46 | os.remove(str(init_file)) 47 | 48 | cluster_cfg = cluster_cfg._replace(dist_url=init_file.as_uri()) 49 | trainer = Trainer(train_cfg, cluster_cfg) 50 | 51 | # The code should be launched once on each GPU 52 | try: 53 | if global_rank==0: 54 | val_accuracy = trainer() 55 | print(f"Validation accuracy: {val_accuracy}") 56 | else: 57 | trainer() 58 | except Exception as e: 59 | print(f"Job failed: {e}") 60 | 61 | 62 | if __name__ == "__main__": 63 | parser = ArgumentParser(description="Feature-extraction script for FixRes models",formatter_class=ArgumentDefaultsHelpFormatter) 64 | parser.add_argument('--input-size', default=320, type=int, help='Images input size') 65 | parser.add_argument('--architecture', default='IGAM_Resnext101_32x48d', type=str,choices=['ResNet50', 'PNASNet' , 'IGAM_Resnext101_32x48d'], help='Neural network architecture') 66 | parser.add_argument('--weight-path', default='/checkpoint/htouvron/baseline_our_method_BN_FC_URU/13972162/checkpoint_0.pth', type=str, help='Neural network weights') 67 | parser.add_argument('--dataset-path', default='/datasets01_101/imagenet_full_size/061417/val', type=str, help='Dataset path') 68 | parser.add_argument('--batch', default=32, type=int, help='Batch per GPU') 69 | parser.add_argument('--workers', default=40, type=int, help='Number of CPU workers') 70 | parser.add_argument('--save-path', default='/checkpoint/htouvron/github_reproduce_result_extract/output_extract/', type=str, help='Path where outputs will be saved') 71 | parser.add_argument('--shared-folder-path', default='your/shared/folder', type=str, help='Shared Folder') 72 | parser.add_argument('--job-id', default='0', type=str, help='id of the execution') 73 | parser.add_argument('--local-rank', default=0, type=int, help='GPU: Local rank') 74 | parser.add_argument('--global-rank', default=0, type=int, help='GPU: global rank') 75 | parser.add_argument('--num-tasks', default=32, type=int, help='How many GPUs are used') 76 | args = parser.parse_args() 77 | run(args.input_size,args.architecture,args.weight_path,args.dataset_path,args.batch,args.workers,args.save_path,args.shared_folder_path,args.job_id,args.local_rank,args.global_rank,args.num_tasks) 78 | 
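
Note: each of these entry points (main_extract.py, main_evaluate_imnet.py, and the scripts below) must be started once per GPU with its own local and global rank, as discussed in the cluster-settings section. The snippet below is a minimal single-node launcher sketch, not part of the repository: it assumes all GPUs sit on one machine (so the global rank equals the local rank), and `NUM_GPUS` and `SHARED` are placeholders you would set for your own setup.

```python
# Minimal single-node launcher sketch (hypothetical helper, not part of the repo):
# starts one main_extract.py process per visible GPU, as the scripts expect.
import subprocess
import sys

NUM_GPUS = 8                       # assumption: number of GPUs on this node
SHARED = "/path/to/shared/folder"  # assumption: folder visible to every rank

procs = [
    subprocess.Popen([
        sys.executable, "main_extract.py",
        "--job-id", "0",
        "--local-rank", str(rank),   # GPU index on this node
        "--global-rank", str(rank),  # equals the local rank on a single node
        "--num-tasks", str(NUM_GPUS),
        "--shared-folder-path", SHARED,
    ])
    for rank in range(NUM_GPUS)
]
for p in procs:
    p.wait()  # block until every rank has finished
```

On a multi-node cluster the same idea applies, except the global rank must be offset by the node index; adapting this to your scheduler is exactly the customization the README recommends.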
-------------------------------------------------------------------------------- /main_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import os 8 | import uuid 9 | from pathlib import Path 10 | from imnet_finetune import TrainerConfig, ClusterConfig, Trainer 11 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 12 | 13 | 14 | 15 | def run(input_sizes,epochs,learning_rate,batch,imnet_path,architecture,resnet_weight_path,workers,shared_folder_path,job_id,local_rank,global_rank,num_tasks,EfficientNet_models): 16 | cluster_cfg = ClusterConfig(dist_backend="nccl", dist_url="") 17 | shared_folder=None 18 | data_folder_Path=None 19 | if Path(str(shared_folder_path)).is_dir(): 20 | shared_folder=Path(shared_folder_path+"/finetune/") 21 | else: 22 | raise RuntimeError("No shared folder available") 23 | if Path(str(imnet_path)).is_dir(): 24 | data_folder_Path=Path(str(imnet_path)) 25 | else: 26 | raise RuntimeError("ImageNet dataset path is not a directory") 27 | train_cfg = TrainerConfig( 28 | data_folder=str(data_folder_Path), 29 | epochs=epochs, 30 | lr=learning_rate, 31 | input_size=input_sizes, 32 | batch_per_gpu=batch, 33 | save_folder=str(shared_folder), 34 | imnet_path=imnet_path, 35 | architecture=architecture, 36 | resnet_weight_path=resnet_weight_path, 37 | workers=workers, 38 | local_rank=local_rank, 39 | global_rank=global_rank, 40 | num_tasks=num_tasks, 41 | job_id=job_id, 42 | EfficientNet_models=EfficientNet_models, 43 | 44 | ) 45 | 46 | # Create the executor 47 | os.makedirs(str(shared_folder), exist_ok=True) 48 | init_file = shared_folder / f"{uuid.uuid4().hex}_init" 49 | if init_file.exists(): 50 | os.remove(str(init_file)) 51 | 52 | cluster_cfg = cluster_cfg._replace(dist_url=init_file.as_uri()) 53 | trainer = Trainer(train_cfg, cluster_cfg) 54 | 55 | # The code should be launched once on each GPU 56 | try: 57 | if global_rank==0: 58 | val_accuracy = trainer() 59 | print(f"Validation accuracy: {val_accuracy}") 60 | else: 61 | trainer() 62 | except Exception as e: 63 | print(f"Job failed: {e}") 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = ArgumentParser(description="Fine-tuning script for FixRes models",formatter_class=ArgumentDefaultsHelpFormatter) 68 | parser.add_argument('--learning-rate', default=1e-3, type=float, help='base learning rate') 69 | parser.add_argument('--epochs', default=1, type=int, help='epochs') 70 | parser.add_argument('--input-size', default=320, type=int, help='images input size') 71 | parser.add_argument('--batch', default=8, type=int, help='Batch per GPU') 72 | parser.add_argument('--imnet-path', default='/the/imagenet/path', type=str, help='ImageNet dataset path') 73 | parser.add_argument('--architecture', default='IGAM_Resnext101_32x48d', type=str,choices=['ResNet50', 'PNASNet' , 'IGAM_Resnext101_32x48d','EfficientNet'], help='Neural network architecture') 74 | parser.add_argument('--resnet-weight-path', default='/where/are/the/weights.pth', type=str, help='Neural network weights (only for ResNet50)') 75 | parser.add_argument('--workers', default=10, type=int, help='Number of CPU workers') 76 | parser.add_argument('--job-id', default='0', type=str, help='id of the execution') 77 | parser.add_argument('--local-rank', default=0, type=int, help='GPU: Local rank') 78 | 
parser.add_argument('--global-rank', default=0, type=int, help='GPU: global rank') 79 | parser.add_argument('--num-tasks', default=32, type=int, help='How many GPUs are used') 80 | parser.add_argument('--shared-folder-path', default='your/shared/folder', type=str, help='Shared Folder') 81 | parser.add_argument('--EfficientNet-models', default='tf_efficientnet_b0_ap', type=str, help='EfficientNet Models') 82 | 83 | 84 | args = parser.parse_args() 85 | run(args.input_size,args.epochs,args.learning_rate,args.batch,args.imnet_path,args.architecture,args.resnet_weight_path,args.workers,args.shared_folder_path,args.job_id,args.local_rank,args.global_rank,args.num_tasks,args.EfficientNet_models) 86 | -------------------------------------------------------------------------------- /main_resnet50_scratch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | import os 8 | import uuid 9 | from pathlib import Path 10 | from imnet_resnet50_scratch import TrainerConfig, ClusterConfig, Trainer 11 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 12 | 13 | 14 | def run(input_sizes,learning_rate,epochs,batch,node,workers,imnet_path,shared_folder_path,job_id,local_rank,global_rank,num_tasks): 15 | cluster_cfg = ClusterConfig(dist_backend="nccl", dist_url="") 16 | shared_folder=None 17 | data_folder_Path=None 18 | if Path(str(shared_folder_path)).is_dir(): 19 | shared_folder=Path(shared_folder_path+"/training/") 20 | else: 21 | raise RuntimeError("No shared folder available") 22 | if Path(str(imnet_path)).is_dir(): 23 | data_folder_Path=Path(str(imnet_path)) 24 | else: 25 | raise RuntimeError("ImageNet dataset path is not a directory") 26 | train_cfg = TrainerConfig( 27 | data_folder=str(data_folder_Path), 28 | epochs=epochs, 29 | lr=learning_rate, 30 | input_size=input_sizes, 31 | batch_per_gpu=batch, 32 | save_folder=str(shared_folder), 33 | workers=workers, 34 | imnet_path=imnet_path, 35 | local_rank=local_rank, 36 | global_rank=global_rank, 37 | num_tasks=num_tasks, 38 | job_id=job_id, 39 | ) 40 | 41 | # Create the executor 42 | os.makedirs(str(shared_folder), exist_ok=True) 43 | init_file = shared_folder / f"{uuid.uuid4().hex}_init" 44 | if init_file.exists(): 45 | os.remove(str(init_file)) 46 | 47 | cluster_cfg = cluster_cfg._replace(dist_url=init_file.as_uri()) 48 | trainer = Trainer(train_cfg, cluster_cfg) 49 | 50 | # The code should be launched once on each GPU 51 | try: 52 | if global_rank==0: 53 | val_accuracy = trainer() 54 | print(f"Validation accuracy: {val_accuracy}") 55 | else: 56 | trainer() 57 | except Exception as e: 58 | print(f"Job failed: {e}") 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = ArgumentParser(description="Training script for ResNet50 FixRes",formatter_class=ArgumentDefaultsHelpFormatter) 63 | parser.add_argument('--learning-rate', default=0.02, type=float, help='base learning rate') 64 | parser.add_argument('--input-size', default=224, type=int, help='images input size') 65 | parser.add_argument('--epochs', default=120, type=int, help='epochs') 66 | parser.add_argument('--batch', default=64, type=int, help='Batch per GPU') 67 | parser.add_argument('--node', default=1, type=int, help='GPU nodes') 68 | parser.add_argument('--workers', default=10, type=int, help='Number of CPU workers') 69 | 
parser.add_argument('--imnet-path', default='/the/imagenet/path', type=str, help='ImageNet dataset path') 70 | parser.add_argument('--shared-folder-path', default='your/shared/folder', type=str, help='Shared Folder') 71 | parser.add_argument('--job-id', default='0', type=str, help='id of the execution') 72 | parser.add_argument('--local-rank', default=0, type=int, help='GPU: Local rank') 73 | parser.add_argument('--global-rank', default=0, type=int, help='GPU: global rank') 74 | parser.add_argument('--num-tasks', default=8, type=int, help='How many GPUs are used') 75 | args = parser.parse_args() 76 | run(args.input_size,args.learning_rate,args.epochs,args.batch,args.node,args.workers,args.imnet_path,args.shared_folder_path,args.job_id,args.local_rank,args.global_rank,args.num_tasks) 77 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrs 2 | numpy 3 | torch 4 | torchvision 5 | Pillow 6 | tqdm -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from setuptools import setup 8 | from setuptools import find_packages 9 | 10 | 11 | with open("requirements.txt") as f: 12 | requirements = f.read().splitlines() 13 | 14 | 15 | setup( 16 | name="Fixing the train-test resolution discrepancy scripts", 17 | version="1.0", 18 | description="Scripts for the models from https://arxiv.org/abs/1906.06423", 19 | author="Facebook AI Research", 20 | packages=find_packages(), 21 | install_requires=requirements, 22 | ) 23 | -------------------------------------------------------------------------------- /transforms_v2.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | import torch 9 | import torchvision.transforms.functional as F 10 | from torchvision import transforms 11 | import math 12 | import sys 13 | import random 14 | from PIL import Image 15 | import numbers 16 | import types 17 | import collections 18 | import warnings 19 | import numpy as np 20 | 21 | # accimage is an optional, faster image backend; fall back to PIL when unavailable 22 | try: 23 | import accimage 24 | except ImportError: 25 | accimage = None 26 | 27 | 28 | 29 | 30 | def _is_pil_image(img): 31 | if accimage is not None: 32 | return isinstance(img, (Image.Image, accimage.Image)) 33 | else: 34 | return isinstance(img, Image.Image) 35 | def crop(img, i, j, h, w): 36 | """Crop the given PIL Image. 37 | Args: 38 | img (PIL Image): Image to be cropped. 39 | i (int): Vertical coordinate of the upper-left corner of the crop. 40 | j (int): Horizontal coordinate of the upper-left corner of the crop. 41 | h (int): Height of the cropped image. 42 | w (int): Width of the cropped image. 43 | Returns: 44 | PIL Image: Cropped image. 45 | """ 46 | if not _is_pil_image(img): 47 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 48 | 49 | return img.crop((j, i, j + w, i + h)) 50 | 51 | 52 | def center_crop_new(img, output_size): 53 | if isinstance(output_size, numbers.Number): 54 | output_size = (int(output_size), int(output_size)) 55 | w, h = img.size 56 | th, tw = output_size 57 | i = int(round((h - th) / 2.)) # top coordinate of the centered crop 58 | j = int(round((w - tw) / 2.)) # left coordinate of the centered crop 59 | jit = 0 60 | if j > 0: 61 | jit = np.random.randint(int(j + 1)) # random horizontal offset of at most j pixels 62 | val = np.random.randint(2) 63 | scale = (1.0) * (val == 0) + (-1.0) * (val == 1) # shift left or right with equal probability 64 | return crop(img, i, int(j + scale * jit), th, tw) # the "shifted Center Crop" used for fine-tuning 65 | 66 | 67 | 68 | 69 | class CenterCrop(object): 70 | """Crops the given PIL Image at the center. 71 | Args: 72 | size (sequence or int): Desired output size of the crop. If size is an 73 | int instead of sequence like (h, w), a square crop (size, size) is 74 | made. 75 | """ 76 | 77 | def __init__(self, size): 78 | if isinstance(size, numbers.Number): 79 | self.size = (int(size), int(size)) 80 | else: 81 | self.size = size 82 | 83 | def __call__(self, img): 84 | """ 85 | Args: 86 | img (PIL Image): Image to be cropped. 87 | Returns: 88 | PIL Image: Cropped image. 89 | """ 90 | return center_crop_new(img, self.size) 91 | 92 | def __repr__(self): 93 | return self.__class__.__name__ + '(size={0})'.format(self.size) 94 | 95 | class Resize(transforms.Resize): 96 | """ 97 | Resize with an extra `largest=False` argument: when largest=True, the image is 98 | resized so that its largest side matches `size`, so no cropping is required 99 | """ 100 | 101 | 102 | def __init__(self, size, largest=False, **kwargs): 103 | super().__init__(size, **kwargs) 104 | self.largest = largest 105 | 106 | @staticmethod 107 | def target_size(w, h, size, largest=False): 108 | if h < w and largest: 109 | w, h = size, int(size * h / w) 110 | else: 111 | w, h = int(size * w / h), size 112 | size = (h, w) 113 | return size 114 | 115 | def __call__(self, img): 116 | size = self.size 117 | w, h = img.size 118 | target_size = self.target_size(w, h, size, self.largest) 119 | return F.resize(img, target_size, self.interpolation) 120 | 121 | def __repr__(self): 122 | r = super().__repr__() 123 | return r[:-1] + ', largest={})'.format(self.largest) 124 | 125 | 126 | class Lighting(object): 127 | """ 128 | PCA-based color jitter ("lighting" noise) applied to image tensors 129 | """ 130 | def __init__(self, alpha_std, eig_val, eig_vec): 131 | self.alpha_std = alpha_std 132 | self.eig_val = torch.as_tensor(eig_val, dtype=torch.float).view(1, 3) 133 | self.eig_vec = torch.as_tensor(eig_vec, dtype=torch.float) 134 | 135 | def __call__(self, data): 136 | if self.alpha_std == 0: 137 | return data 138 | alpha = torch.empty(1, 3).normal_(0, self.alpha_std) 139 | rgb = ((self.eig_vec * alpha) * self.eig_val).sum(1) 140 | data += rgb.view(3, 1, 1) 141 | data /= 1. 
+ self.alpha_std 142 | return data 143 | 144 | 145 | class Bound(object): 146 | def __init__(self, lower=0., upper=1.): 147 | self.lower = lower 148 | self.upper = upper 149 | 150 | def __call__(self, data): 151 | return data.clamp_(self.lower, self.upper) 152 | 153 | 154 | def get_transforms(input_size=224,test_size=224, kind='full', crop=True, need=('train', 'val'), backbone=None): 155 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 156 | if backbone is not None and backbone in ['pnasnet5large', 'nasnetamobile']: 157 | mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] 158 | transformations = {} 159 | if 'train' in need: 160 | if kind == 'torch': 161 | transformations['train'] = transforms.Compose([ 162 | transforms.RandomResizedCrop(input_size), 163 | transforms.RandomHorizontalFlip(), 164 | transforms.ToTensor(), 165 | transforms.Normalize(mean, std), 166 | ]) 167 | elif kind == 'full': 168 | transformations['train'] = transforms.Compose([ 169 | transforms.RandomResizedCrop(input_size), 170 | transforms.RandomHorizontalFlip(), 171 | transforms.ColorJitter(0.3, 0.3, 0.3), 172 | transforms.ToTensor(), 173 | transforms.Normalize(mean, std), 174 | ]) 175 | 176 | else: 177 | raise ValueError('Transforms kind {} unknown'.format(kind)) 178 | if 'val' in need: 179 | if crop: 180 | transformations['val_test'] = transforms.Compose( 181 | [Resize(int((256 / 224) * test_size)), # to maintain same ratio w.r.t. 224 images 182 | transforms.CenterCrop(test_size), 183 | transforms.ToTensor(), 184 | transforms.Normalize(mean, std)]) 185 | transformations['val_train'] = transforms.Compose( 186 | [Resize(int((256 / 224) * test_size)), # to maintain same ratio w.r.t. 224 images 187 | transforms.RandomHorizontalFlip(), 188 | transforms.ColorJitter(0.05, 0.05, 0.05), 189 | CenterCrop(test_size), 190 | transforms.ToTensor(), 191 | transforms.Normalize(mean, std)]) 192 | else: 193 | transformations['val'] = transforms.Compose( 194 | [Resize(test_size, largest=True), 195 | transforms.ToTensor(), 196 | transforms.Normalize(mean, std)]) 197 | return transformations 198 | 199 | transforms_list = ['torch', 'full'] 200 | --------------------------------------------------------------------------------
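
For reference, here is a small usage sketch for `transforms_v2.get_transforms`; the image path is a placeholder. With `crop=True` it returns the deterministic `val_test` pipeline used for evaluation alongside the lightly augmented `val_train` pipeline (horizontal flip, mild color jitter and the jittered `CenterCrop` above) used when fine-tuning at the test resolution.

```python
# Usage sketch for transforms_v2.get_transforms; 'some_image.jpg' is a placeholder.
from PIL import Image
from transforms_v2 import get_transforms

# Fine-tuning at test resolution 320: 'val_test' is the deterministic
# evaluation pipeline, 'val_train' adds flip / color jitter / shifted crop.
tfs = get_transforms(input_size=224, test_size=320, kind='full',
                     crop=True, need=('train', 'val'))

img = Image.open('some_image.jpg').convert('RGB')  # placeholder image
x_eval = tfs['val_test'](img)   # normalized 3 x 320 x 320 tensor
x_tune = tfs['val_train'](img)  # randomly flipped/jittered/shifted variant
print(x_eval.shape, x_tune.shape)
```

Note that with `crop=False` the function instead returns a single `val` pipeline that resizes the largest side to `test_size` without cropping, so callers must pick the matching dictionary key.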