├── LICENSE ├── README.md ├── eval_cifar.py ├── models ├── __init__.py ├── densenet.py ├── dpn.py ├── efficientnet.py ├── googlenet.py ├── lenet.py ├── mobilenet.py ├── mobilenetv2.py ├── pnasnet.py ├── regnet.py ├── resnet.py ├── resnext.py ├── senet.py ├── shufflenet.py ├── shufflenetv2.py └── vgg.py ├── preactresnet.py ├── train_cifar.py ├── utils.py ├── utils_plus.py └── wideresnet.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bag of Tricks for Adversarial Training 2 | Empirical tricks for training state-of-the-art robust models on CIFAR-10. A playground for fine-tuning the basic adversarial training settings. 3 | 4 | [Bag of Tricks for Adversarial Training](https://openreview.net/forum?id=Xb8xvrtB8Ce) (ICLR 2021) 5 | 6 | [Tianyu Pang](http://ml.cs.tsinghua.edu.cn/~tianyu/), [Xiao Yang](https://github.com/ShawnXYang), [Yinpeng Dong](http://ml.cs.tsinghua.edu.cn/~yinpeng/), [Hang Su](http://www.suhangss.me/), and [Jun Zhu](http://ml.cs.tsinghua.edu.cn/~jun/index.shtml). 7 | 8 | ## Environment settings and libraries we used in our experiments 9 | 10 | This project has been tested under the following environment settings: 11 | - OS: Ubuntu 18.04.4 12 | - GPU: GeForce 2080 Ti or Tesla P100 13 | - CUDA: 10.1, cuDNN: v7.6 14 | - Python: 3.6 15 | - PyTorch: >= 1.4.0 16 | - Torchvision: >= 0.4.0 17 | 18 | ## Acknowledgement 19 | The code is modified from [Rice et al.
2020](https://github.com/locuslab/robust_overfitting), and the model architectures are implemented by [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar). 20 | 21 | ## Threat Model 22 | We consider the most widely studied setting: 23 | - **L-inf norm constraint with a maximal epsilon of 8/255 on CIFAR-10**. 24 | - **No access to additional data, neither labeled nor unlabeled**. 25 | - **Utilize the PGD-AT framework in [Madry et al. 2018](https://arxiv.org/abs/1706.06083)**. 26 | 27 | (Implementations on the TRADES framework can be found [here](https://github.com/ShawnXYang/AT_HE)) 28 | 29 | ## Trick Candidates 30 | Importance rating: *Critical*; *Useful*; *Insignificance* 31 | 32 | - **Early stopping w.r.t. training epochs** (*Critical*). 33 | Early stopping w.r.t. training epochs was first introduced in the [code of TRADES](https://github.com/yaodongyu/TRADES), and was later thoroughly studied by [Rice et al., 2020](https://arxiv.org/abs/2002.11569). Due to its effectiveness, we regard this trick as a default choice. 34 | 35 | - **Early stopping w.r.t. attack intensity** (*Useful*). Early stopping w.r.t. attack iterations was studied by [Wang et al. 2019](https://proceedings.mlr.press/v97/wang19i/wang19i.pdf) and [Zhang et al. 2020](https://arxiv.org/abs/2002.11242). Here we adopt the strategy of the latter, where the authors show that this trick can improve clean accuracy. The relevant flags are `--earlystopPGD`, which indicates whether to apply this trick, and `--earlystopPGDepoch1` and `--earlystopPGDepoch2`, which indicate the epochs at which the tolerance t is increased by one, as detailed in [Zhang et al. 2020](https://arxiv.org/abs/2002.11242). (*Note that early stopping w.r.t. attack intensity may degrade worst-case robustness under strong attacks*) 36 | 37 | - **Warmup w.r.t. learning rate** (*Insignificance*). Warmup w.r.t.
learning rate was found useful for [FastAT](https://arxiv.org/abs/2001.03994), while [Rice et al., 2020](https://arxiv.org/abs/2002.11569) found that the piecewise decay schedule is more compatible with early stopping w.r.t. training epochs. The relevant flags are `--warmup_lr`, which indicates whether to apply this trick, and `--warmup_lr_epoch`, which indicates the epoch at which the gradual increase of the learning rate ends. 38 | 39 | - **Warmup w.r.t. epsilon** (*Insignificance*). [Qin et al. 2019](https://arxiv.org/abs/1907.02610) use warmup w.r.t. epsilon in their implementation, where epsilon gradually increases from 0 to 8/255 in the first 15 epochs. Similarly, the relevant flags are `--warmup_eps`, which indicates whether to apply this trick, and `--warmup_eps_epoch`, which indicates the epoch at which the gradual increase of epsilon ends. 40 | 41 | - **Batch size** (*Insignificance*). The typical batch size used for CIFAR-10 in the adversarial setting is 128. Meanwhile, [Xie et al. 2019](https://arxiv.org/pdf/1812.03411.pdf) apply a large batch size of 4096 to perform adversarial training on ImageNet, where the model is distributed across 128 GPUs and achieves quite robust performance. The relevant flag is `--batch-size`. Following [Goyal et al. 2017](https://arxiv.org/abs/1706.02677), we take bs=128 and lr=0.1 as a basis, and scale the lr linearly when using a larger batch size, e.g., bs=256 and lr=0.2. 42 | 43 | - **Label smoothing** (*Useful*). Label smoothing is advocated by [Shafahi et al. 2019](https://arxiv.org/abs/1910.11585) to mimic the adversarial training procedure. The relevant flags are `--labelsmooth`, which indicates whether to apply this trick, and `--labelsmoothvalue`, which indicates the degree of smoothing applied to the label vectors. When `--labelsmoothvalue=0`, no label smoothing is applied. (*Note that only moderate label smoothing (~0.2) is helpful, while excessive label smoothing (>0.3) could be harmful, as observed in [Jiang et al.
2020](https://arxiv.org/abs/2006.13726)*) 44 | 45 | - **Optimizer** (*Insignificance*). Most AT methods use SGD with momentum as the optimizer. In other cases, [Carmon et al. 2019](https://arxiv.org/abs/1905.13736) apply SGD with Nesterov momentum, and [Rice et al., 2020](https://arxiv.org/abs/2002.11569) apply Adam with a cyclic learning rate schedule. The relevant flag is `--optimizer`, which covers the common optimizers implemented in the official PyTorch API as well as the recently proposed gradient centralization trick by [Yong et al. 2020](https://arxiv.org/abs/2004.01461). 46 | 47 | - **Weight decay** (*Critical*). The values of weight decay used in previous AT methods mainly fall into `1e-4` (e.g., [Wang et al. 2019](https://proceedings.mlr.press/v97/wang19i/wang19i.pdf)), `2e-4` (e.g., [Madry et al. 2018](https://arxiv.org/abs/1706.06083)), and `5e-4` (e.g., [Rice et al., 2020](https://arxiv.org/abs/2002.11569)). We find that slightly different values of weight decay can largely affect the robustness of the adversarially trained models. 48 | 49 | - **Activation function** (*Useful*). As shown in [Xie et al., 2020a](https://arxiv.org/pdf/2006.14536.pdf), smooth alternatives to `ReLU`, including `Softplus` and `GELU`, can improve the performance of adversarial training. The relevant flags are `--activation` to choose the activation, and `--softplus_beta` to set the beta for Softplus. Other hyperparameters keep their default values in the code. 50 | 51 | - **BN mode** (*Useful*). TRADES uses the eval mode of BN when crafting adversarial examples during training, while the PGD-AT methods implemented by [Madry et al. 2018](https://arxiv.org/abs/1706.06083) or [Rice et al., 2020](https://arxiv.org/abs/2002.11569) use the train mode of BN to craft training adversarial examples.
As indicated by [Xie et al., 2020b](https://arxiv.org/pdf/1906.03787.pdf), properly handling BN layers is critical to obtaining a well-performing adversarially trained model, while using the train mode of BN during the multi-step PGD process may blur the distribution. 52 | 53 | 54 | ## Baseline setting (on CIFAR-10) 55 | - **Architecture**: WideResNet-34-10 56 | - **Optimizer**: Momentum SGD with default hyperparameters 57 | - **Total epochs**: `110` 58 | - **Batch size**: `128` 59 | - **Weight decay**: `5e-4` 60 | - **Learning rate**: `lr=0.1`; decay to `lr=0.01` at epoch 100; decay to `lr=0.001` at epoch 105 61 | - **BN mode**: eval 62 | 63 | Running command for training: 64 | ```bash 65 | python train_cifar.py --model WideResNet --attack pgd \ 66 | --lr-schedule piecewise --norm l_inf --epsilon 8 \ 67 | --epochs 110 --attack-iters 10 --pgd-alpha 2 \ 68 | --fname auto \ 69 | --optimizer 'momentum' \ 70 | --weight_decay 5e-4 \ 71 | --batch-size 128 \ 72 | --BNeval 73 | ``` 74 | 75 | ## Empirical Evaluations 76 | *The evaluation results on the baselines are quoted from [AutoAttack](https://arxiv.org/abs/2003.01690) ([evaluation code](https://github.com/P2333/Bag-of-Tricks-for-AT/blob/master/eval_cifar.py))*. 77 | 78 | Note that **OURS (TRADES)** below only changes the weight decay value from `2e-4` (used in the original TRADES) to `5e-4`, and trains for 110 epochs (lr decays at epochs 100 and 105). To run the evaluation script `eval_cifar.py`, use the command 79 | ```bash 80 | python eval_cifar.py --out-dir 'path_to_the_model' --ATmethods 'TRADES' 81 | ``` 82 | Here `ATmethods` refers to the AT framework (e.g., PGDAT or TRADES).
83 | 84 | ### CIFAR-10 (eps = 8/255) 85 | |paper | Architecture | clean | AA | 86 | |---|:---:|:---:|:---:| 87 | | **OURS (TRADES)**[[Checkpoint](http://ml.cs.tsinghua.edu.cn/~xiaoyang/downloads/bag_of_tricks/wide20_trades_eps8_tricks.pt)] | WRN-34-20| 86.43 | 54.39 | 88 | | **OURS (TRADES)**[[Checkpoint](http://ml.cs.tsinghua.edu.cn/~xiaoyang/downloads/bag_of_tricks/wide10_trades_eps8_tricks.pt)] | WRN-34-10| 85.48 | 53.80 | 89 | | [(Pang et al., 2020)](https://arxiv.org/abs/2002.08619) | WRN-34-20| 85.14 | 53.74 | 90 | | [(Zhang et al., 2020)](https://arxiv.org/abs/2002.11242)| WRN-34-10| 84.52 | 53.51 | 91 | | [(Rice et al., 2020)](https://arxiv.org/abs/2002.11569) | WRN-34-20| 85.34 | 53.35 | 92 | 93 | 94 | ### CIFAR-10 (eps = 0.031) 95 | |paper | Architecture | clean | AA | 96 | |---|:---:|:---:|:---:| 97 | | **OURS (TRADES)**[[Checkpoint](http://ml.cs.tsinghua.edu.cn/~xiaoyang/downloads/bag_of_tricks/wide10_trades_tricks.pt)] | WRN-34-10| 85.34 | 54.64 | 98 | | [(Huang et al., 2020)](https://arxiv.org/abs/2002.10319) | WRN-34-10| 83.48 | 53.34 | 99 | | [(Zhang et al., 2019)](https://arxiv.org/abs/1901.08573) | WRN-34-10| 84.92 | 53.04 | 100 | 101 | ## References 102 | If you find the code useful for your research, please consider citing 103 | ```bib 104 | @inproceedings{pang2021bag, 105 | title={Bag of Tricks for Adversarial Training}, 106 | author={Pang, Tianyu and Yang, Xiao and Dong, Yinpeng and Su, Hang and Zhu, Jun}, 107 | booktitle={International Conference on Learning Representations (ICLR)}, 108 | year={2021} 109 | } 110 | ``` 111 | 112 | and/or our related works 113 | 114 | ```bib 115 | @inproceedings{wang2023better, 116 | title={Better Diffusion Models Further Improve Adversarial Training}, 117 | author={Wang, Zekai and Pang, Tianyu and Du, Chao and Lin, Min and Liu, Weiwei and Yan, Shuicheng}, 118 | booktitle={International Conference on Machine Learning (ICML)}, 119 | year={2023} 120 | } 121 | ``` 122 | ```bib 123 | 
@inproceedings{pang2022robustness, 124 | title={Robustness and Accuracy Could be Reconcilable by (Proper) Definition}, 125 | author={Pang, Tianyu and Lin, Min and Yang, Xiao and Zhu, Jun and Yan, Shuicheng}, 126 | booktitle={International Conference on Machine Learning (ICML)}, 127 | year={2022} 128 | } 129 | ``` 130 | -------------------------------------------------------------------------------- /eval_cifar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import logging 4 | import os 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from preactresnet import PreActResNet18 13 | from wideresnet import WideResNet 14 | from utils_plus import (upper_limit, lower_limit, clamp, get_loaders, 15 | attack_pgd, evaluate_pgd, evaluate_standard) 16 | from autoattack import AutoAttack 17 | # installing AutoAttack by: pip install git+https://github.com/fra31/auto-attack 18 | 19 | cifar10_mean = (0.4914, 0.4822, 0.4465) 20 | cifar10_std = (0.2471, 0.2435, 0.2616) 21 | mu = torch.tensor(cifar10_mean).view(3,1,1).cuda() 22 | std = torch.tensor(cifar10_std).view(3,1,1).cuda() 23 | 24 | def normalize_PGDAT(X): 25 | return (X - mu)/std 26 | 27 | def normalize_TRADES(X): 28 | return X 29 | 30 | def get_args(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--batch-size', default=128, type=int) 33 | parser.add_argument('--data-dir', default='../cifar-data', type=str) 34 | parser.add_argument('--epsilon', default=8, type=int) 35 | parser.add_argument('--out-dir', default='train_fgsm_output', type=str, help='Output directory') 36 | parser.add_argument('--seed', default=0, type=int, help='Random seed') 37 | parser.add_argument('--ATmethods', default='TRADES', type=str) 38 | return parser.parse_args() 39 | 40 | 41 | def main(): 42 | args = get_args() 43 | 44 | np.random.seed(args.seed) 45 | 
torch.manual_seed(args.seed) 46 | torch.cuda.manual_seed(args.seed) 47 | 48 | logger = logging.getLogger(__name__) 49 | 50 | logging.basicConfig( 51 | format='[%(asctime)s] - %(message)s', 52 | datefmt='%Y/%m/%d %H:%M:%S', 53 | level=logging.DEBUG, 54 | handlers=[ 55 | logging.StreamHandler() 56 | ]) 57 | 58 | logger.info(args) 59 | 60 | _, test_loader = get_loaders(args.data_dir, args.batch_size) 61 | 62 | best_state_dict = torch.load(os.path.join(args.out_dir, 'model_best.pth')) 63 | 64 | if args.ATmethods == 'TRADES': 65 | normalize = normalize_TRADES 66 | elif args.ATmethods == 'PGDAT': 67 | normalize = normalize_PGDAT 68 | 69 | # Evaluation 70 | model_test = PreActResNet18().cuda() 71 | # model_test = WideResNet(34, 10, widen_factor=10, dropRate=0.0) 72 | model_test = nn.DataParallel(model_test).cuda() # put this line after loading state_dict if the weights are saved without module. 73 | if 'state_dict' in best_state_dict.keys(): 74 | model_test.load_state_dict(best_state_dict['state_dict']) 75 | else: 76 | model_test.load_state_dict(best_state_dict) 77 | model_test.float() 78 | model_test.eval() 79 | 80 | 81 | ### Evaluate clean acc ### 82 | _, test_acc = evaluate_standard(test_loader, model_test, normalize=normalize) 83 | print('Clean acc: ', test_acc) 84 | 85 | ### Evaluate PGD (CE loss) acc ### 86 | _, pgd_acc_CE = evaluate_pgd(test_loader, model_test, attack_iters=10, restarts=1, eps=8, step=2, use_CWloss=False, normalize=normalize) 87 | print('PGD-10 (10 restarts, step 2, CE loss) acc: ', pgd_acc_CE) 88 | 89 | ### Evaluate PGD (CW loss) acc ### 90 | _, pgd_acc_CW = evaluate_pgd(test_loader, model_test, attack_iters=10, restarts=1, eps=8, step=2, use_CWloss=True, normalize=normalize) 91 | print('PGD-10 (10 restarts, step 2, CW loss) acc: ', pgd_acc_CW) 92 | 93 | ### Evaluate AutoAttack ### 94 | l = [x for (x, y) in test_loader] 95 | x_test = torch.cat(l, 0) 96 | l = [y for (x, y) in test_loader] 97 | y_test = torch.cat(l, 0) 98 | class normalize_model(): 
99 | def __init__(self, model): 100 | self.model_test = model 101 | def __call__(self, x): 102 | return self.model_test(normalize(x)) 103 | new_model = normalize_model(model_test) 104 | epsilon = 8 / 255. 105 | adversary = AutoAttack(new_model, norm='Linf', eps=epsilon, version='standard') 106 | X_adv = adversary.run_standard_evaluation(x_test, y_test, bs=128) 107 | 108 | if __name__ == "__main__": 109 | main() 110 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .dpn import * 3 | from .lenet import * 4 | from .senet import * 5 | from .pnasnet import * 6 | from .densenet import * 7 | from .googlenet import * 8 | from .shufflenet import * 9 | from .shufflenetv2 import * 10 | from .resnet import * 11 | from .resnext import * 12 | from .mobilenet import * 13 | from .mobilenetv2 import * 14 | from .efficientnet import * 15 | from .regnet import * 16 | -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | '''DenseNet in PyTorch.''' 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | def __init__(self, in_planes, growth_rate): 11 | super(Bottleneck, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(4*growth_rate) 15 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 16 | 17 | def forward(self, x): 18 | out = self.conv1(F.relu(self.bn1(x))) 19 | out = self.conv2(F.relu(self.bn2(out))) 20 | out = torch.cat([out,x], 1) 21 | return out 22 | 23 | 24 | class Transition(nn.Module): 25 | def __init__(self, in_planes, out_planes): 26 | 
super(Transition, self).__init__() 27 | self.bn = nn.BatchNorm2d(in_planes) 28 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 29 | # BN -> ReLU -> 1x1 conv to out_planes channels, then 2x2 average pooling halves height and width. 30 | def forward(self, x): 31 | out = self.conv(F.relu(self.bn(x))) 32 | out = F.avg_pool2d(out, 2) 33 | return out 34 | 35 | # Conv stem -> four dense stages, with a channel-compressing Transition after each of the first three -> BN-ReLU -> pooling -> linear head. 36 | class DenseNet(nn.Module): 37 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 38 | super(DenseNet, self).__init__() 39 | self.growth_rate = growth_rate 40 | # The stem conv emits 2*growth_rate channels. 41 | num_planes = 2*growth_rate 42 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 43 | # Each dense stage adds nblocks[i]*growth_rate channels; each Transition keeps floor(num_planes*reduction) of them. 44 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 45 | num_planes += nblocks[0]*growth_rate 46 | out_planes = int(math.floor(num_planes*reduction)) 47 | self.trans1 = Transition(num_planes, out_planes) 48 | num_planes = out_planes 49 | 50 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 51 | num_planes += nblocks[1]*growth_rate 52 | out_planes = int(math.floor(num_planes*reduction)) 53 | self.trans2 = Transition(num_planes, out_planes) 54 | num_planes = out_planes 55 | 56 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 57 | num_planes += nblocks[2]*growth_rate 58 | out_planes = int(math.floor(num_planes*reduction)) 59 | self.trans3 = Transition(num_planes, out_planes) 60 | num_planes = out_planes 61 | 62 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 63 | num_planes += nblocks[3]*growth_rate 64 | 65 | self.bn = nn.BatchNorm2d(num_planes) 66 | self.linear = nn.Linear(num_planes, num_classes) 67 | # Build `nblock` blocks; block k is constructed with in_planes + 
k*growth_rate input channels. 68 | def _make_dense_layers(self, block, in_planes, nblock): 69 | layers = [] 70 | for i in range(nblock): 71 | layers.append(block(in_planes, self.growth_rate)) 72 | in_planes += self.growth_rate 73 | return nn.Sequential(*layers) 74 | # NOTE(review): the fixed 4x4 pooling in forward() presumably assumes 32x32 inputs (three Transitions halve 32 -> 4); confirm before using other resolutions. 75 | def forward(self, x): 76 | out = self.conv1(x) 77 | out = self.trans1(self.dense1(out)) 78 | out = self.trans2(self.dense2(out)) 79 | out =
self.trans3(self.dense3(out)) 80 | out = self.dense4(out) 81 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 82 | out = out.view(out.size(0), -1) 83 | out = self.linear(out) 84 | return out 85 | # Named variants differ only in blocks per dense stage and growth_rate; densenet_cifar is the small growth_rate=12 configuration. 86 | def DenseNet121(): 87 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32) 88 | 89 | def DenseNet169(): 90 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32) 91 | 92 | def DenseNet201(): 93 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32) 94 | 95 | def DenseNet161(): 96 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48) 97 | 98 | def densenet_cifar(): 99 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12) 100 | # Manual smoke test: pushes one random 1x3x32x32 tensor through densenet_cifar() and prints the output. 101 | def test(): 102 | net = densenet_cifar() 103 | x = torch.randn(1,3,32,32) 104 | y = net(x) 105 | print(y) 106 | 107 | # test() 108 | -------------------------------------------------------------------------------- /models/dpn.py: -------------------------------------------------------------------------------- 1 | '''Dual Path Networks in PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | # Dual-path block: 1x1 conv -> grouped 3x3 conv (32 groups, carries the stride) -> 1x1 conv to out_planes+dense_depth channels. 7 | class Bottleneck(nn.Module): 8 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer): 9 | super(Bottleneck, self).__init__() 10 | self.out_planes = out_planes 11 | self.dense_depth = dense_depth 12 | 13 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False) 16 | self.bn2 = nn.BatchNorm2d(in_planes) 17 | self.conv3 = nn.Conv2d(in_planes, 
out_planes+dense_depth, kernel_size=1, bias=False) 18 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth) 19 | # A projection shortcut is built only for the first block of a stage (first_layer=True); otherwise the shortcut is the identity. 20 | self.shortcut = nn.Sequential() 21 | if first_layer: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False), 24 |
nn.BatchNorm2d(out_planes+dense_depth) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = F.relu(self.bn2(self.conv2(out))) 30 | out = self.bn3(self.conv3(out)) 31 | x = self.shortcut(x) 32 | d = self.out_planes 33 | out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1) 34 | out = F.relu(out) 35 | return out 36 | 37 | 38 | class DPN(nn.Module): 39 | def __init__(self, cfg): 40 | super(DPN, self).__init__() 41 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes'] 42 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth'] 43 | 44 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(64) 46 | self.last_planes = 64 47 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1) 48 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2) 49 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2) 50 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2) 51 | self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10) 52 | 53 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride): 54 | strides = [stride] + [1]*(num_blocks-1) 55 | layers = [] 56 | for i,stride in enumerate(strides): 57 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0)) 58 | self.last_planes = out_planes + (i+2) * dense_depth 59 | return nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | out = F.relu(self.bn1(self.conv1(x))) 63 | out = self.layer1(out) 64 | out = self.layer2(out) 65 | out = self.layer3(out) 66 | out = self.layer4(out) 67 | out = F.avg_pool2d(out, 4) 68 | out = out.view(out.size(0), -1) 69 | out = self.linear(out) 70 | return out 71 | 72 | 73 | def DPN26(): 74 | cfg = { 75 | 
'in_planes': (96,192,384,768), 76 | 'out_planes': (256,512,1024,2048), 77 | 'num_blocks': (2,2,2,2), 78 | 'dense_depth': (16,32,24,128) 79 | } 80 | return DPN(cfg) 81 | 82 | def DPN92(): 83 | cfg = { 84 | 'in_planes': (96,192,384,768), 85 | 'out_planes': (256,512,1024,2048), 86 | 'num_blocks': (3,4,20,3), 87 | 'dense_depth': (16,32,24,128) 88 | } 89 | return DPN(cfg) 90 | 91 | 92 | def test(): 93 | net = DPN92() 94 | x = torch.randn(1,3,32,32) 95 | y = net(x) 96 | print(y) 97 | 98 | # test() 99 | -------------------------------------------------------------------------------- /models/efficientnet.py: -------------------------------------------------------------------------------- 1 | '''EfficientNet in PyTorch. 2 | 3 | Paper: "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks". 4 | 5 | Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | def swish(x): 13 | return x * x.sigmoid() 14 | 15 | 16 | def drop_connect(x, drop_ratio): 17 | keep_ratio = 1.0 - drop_ratio 18 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device) 19 | mask.bernoulli_(keep_ratio) 20 | x.div_(keep_ratio) 21 | x.mul_(mask) 22 | return x 23 | 24 | 25 | class SE(nn.Module): 26 | '''Squeeze-and-Excitation block with Swish.''' 27 | 28 | def __init__(self, in_channels, se_channels): 29 | super(SE, self).__init__() 30 | self.se1 = nn.Conv2d(in_channels, se_channels, 31 | kernel_size=1, bias=True) 32 | self.se2 = nn.Conv2d(se_channels, in_channels, 33 | kernel_size=1, bias=True) 34 | 35 | def forward(self, x): 36 | out = F.adaptive_avg_pool2d(x, (1, 1)) 37 | out = swish(self.se1(out)) 38 | out = self.se2(out).sigmoid() 39 | out = x * out 40 | return out 41 | 42 | 43 | class Block(nn.Module): 44 | '''expansion + depthwise + pointwise + squeeze-excitation''' 45 | 46 | def __init__(self, 47 | in_channels, 48 | 
out_channels, 49 | kernel_size, 50 | stride, 51 | expand_ratio=1, 52 | se_ratio=0., 53 | drop_rate=0.): 54 | super(Block, self).__init__() 55 | self.stride = stride 56 | self.drop_rate = drop_rate 57 | self.expand_ratio = expand_ratio 58 | 59 | # Expansion 60 | channels = expand_ratio * in_channels 61 | self.conv1 = nn.Conv2d(in_channels, 62 | channels, 63 | kernel_size=1, 64 | stride=1, 65 | padding=0, 66 | bias=False) 67 | self.bn1 = nn.BatchNorm2d(channels) 68 | 69 | # Depthwise conv 70 | self.conv2 = nn.Conv2d(channels, 71 | channels, 72 | kernel_size=kernel_size, 73 | stride=stride, 74 | padding=(1 if kernel_size == 3 else 2), 75 | groups=channels, 76 | bias=False) 77 | self.bn2 = nn.BatchNorm2d(channels) 78 | 79 | # SE layers 80 | se_channels = int(in_channels * se_ratio) 81 | self.se = SE(channels, se_channels) 82 | 83 | # Output 84 | self.conv3 = nn.Conv2d(channels, 85 | out_channels, 86 | kernel_size=1, 87 | stride=1, 88 | padding=0, 89 | bias=False) 90 | self.bn3 = nn.BatchNorm2d(out_channels) 91 | 92 | # Skip connection if in and out shapes are the same (MV-V2 style) 93 | self.has_skip = (stride == 1) and (in_channels == out_channels) 94 | 95 | def forward(self, x): 96 | out = x if self.expand_ratio == 1 else swish(self.bn1(self.conv1(x))) 97 | out = swish(self.bn2(self.conv2(out))) 98 | out = self.se(out) 99 | out = self.bn3(self.conv3(out)) 100 | if self.has_skip: 101 | if self.training and self.drop_rate > 0: 102 | out = drop_connect(out, self.drop_rate) 103 | out = out + x 104 | return out 105 | 106 | 107 | class EfficientNet(nn.Module): 108 | def __init__(self, cfg, num_classes=10): 109 | super(EfficientNet, self).__init__() 110 | self.cfg = cfg 111 | self.conv1 = nn.Conv2d(3, 112 | 32, 113 | kernel_size=3, 114 | stride=1, 115 | padding=1, 116 | bias=False) 117 | self.bn1 = nn.BatchNorm2d(32) 118 | self.layers = self._make_layers(in_channels=32) 119 | self.linear = nn.Linear(cfg['out_channels'][-1], num_classes) 120 | 121 | def _make_layers(self, 
in_channels): 122 | layers = [] 123 | cfg = [self.cfg[k] for k in ['expansion', 'out_channels', 'num_blocks', 'kernel_size', 124 | 'stride']] 125 | b = 0 126 | blocks = sum(self.cfg['num_blocks']) 127 | for expansion, out_channels, num_blocks, kernel_size, stride in zip(*cfg): 128 | strides = [stride] + [1] * (num_blocks - 1) 129 | for stride in strides: 130 | drop_rate = self.cfg['drop_connect_rate'] * b / blocks 131 | layers.append( 132 | Block(in_channels, 133 | out_channels, 134 | kernel_size, 135 | stride, 136 | expansion, 137 | se_ratio=0.25, 138 | drop_rate=drop_rate)) 139 | in_channels = out_channels 140 | return nn.Sequential(*layers) 141 | 142 | def forward(self, x): 143 | out = swish(self.bn1(self.conv1(x))) 144 | out = self.layers(out) 145 | out = F.adaptive_avg_pool2d(out, 1) 146 | out = out.view(out.size(0), -1) 147 | dropout_rate = self.cfg['dropout_rate'] 148 | if self.training and dropout_rate > 0: 149 | out = F.dropout(out, p=dropout_rate) 150 | out = self.linear(out) 151 | return out 152 | 153 | 154 | def EfficientNetB0(): 155 | cfg = { 156 | 'num_blocks': [1, 2, 2, 3, 3, 4, 1], 157 | 'expansion': [1, 6, 6, 6, 6, 6, 6], 158 | 'out_channels': [16, 24, 40, 80, 112, 192, 320], 159 | 'kernel_size': [3, 3, 5, 3, 5, 5, 3], 160 | 'stride': [1, 2, 2, 2, 1, 2, 1], 161 | 'dropout_rate': 0.2, 162 | 'drop_connect_rate': 0.2, 163 | } 164 | return EfficientNet(cfg) 165 | 166 | 167 | def test(): 168 | net = EfficientNetB0() 169 | x = torch.randn(2, 3, 32, 32) 170 | y = net(x) 171 | print(y.shape) 172 | 173 | 174 | if __name__ == '__main__': 175 | test() 176 | -------------------------------------------------------------------------------- /models/googlenet.py: -------------------------------------------------------------------------------- 1 | '''GoogLeNet with PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Inception(nn.Module): 8 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, 
pool_planes): 9 | super(Inception, self).__init__() 10 | # 1x1 conv branch 11 | self.b1 = nn.Sequential( 12 | nn.Conv2d(in_planes, n1x1, kernel_size=1), 13 | nn.BatchNorm2d(n1x1), 14 | nn.ReLU(True), 15 | ) 16 | 17 | # 1x1 conv -> 3x3 conv branch 18 | self.b2 = nn.Sequential( 19 | nn.Conv2d(in_planes, n3x3red, kernel_size=1), 20 | nn.BatchNorm2d(n3x3red), 21 | nn.ReLU(True), 22 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1), 23 | nn.BatchNorm2d(n3x3), 24 | nn.ReLU(True), 25 | ) 26 | 27 | # 1x1 conv -> 5x5 conv branch 28 | self.b3 = nn.Sequential( 29 | nn.Conv2d(in_planes, n5x5red, kernel_size=1), 30 | nn.BatchNorm2d(n5x5red), 31 | nn.ReLU(True), 32 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1), 33 | nn.BatchNorm2d(n5x5), 34 | nn.ReLU(True), 35 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1), 36 | nn.BatchNorm2d(n5x5), 37 | nn.ReLU(True), 38 | ) 39 | 40 | # 3x3 pool -> 1x1 conv branch 41 | self.b4 = nn.Sequential( 42 | nn.MaxPool2d(3, stride=1, padding=1), 43 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 44 | nn.BatchNorm2d(pool_planes), 45 | nn.ReLU(True), 46 | ) 47 | 48 | def forward(self, x): 49 | y1 = self.b1(x) 50 | y2 = self.b2(x) 51 | y3 = self.b3(x) 52 | y4 = self.b4(x) 53 | return torch.cat([y1,y2,y3,y4], 1) 54 | 55 | 56 | class GoogLeNet(nn.Module): 57 | def __init__(self): 58 | super(GoogLeNet, self).__init__() 59 | self.pre_layers = nn.Sequential( 60 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 61 | nn.BatchNorm2d(192), 62 | nn.ReLU(True), 63 | ) 64 | 65 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 66 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 67 | 68 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 69 | 70 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 71 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 72 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 73 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 74 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 75 | 76 | self.a5 = 
Inception(832, 256, 160, 320, 32, 128, 128) 77 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 78 | 79 | self.avgpool = nn.AvgPool2d(8, stride=1) 80 | self.linear = nn.Linear(1024, 10) 81 | 82 | def forward(self, x): 83 | out = self.pre_layers(x) 84 | out = self.a3(out) 85 | out = self.b3(out) 86 | out = self.maxpool(out) 87 | out = self.a4(out) 88 | out = self.b4(out) 89 | out = self.c4(out) 90 | out = self.d4(out) 91 | out = self.e4(out) 92 | out = self.maxpool(out) 93 | out = self.a5(out) 94 | out = self.b5(out) 95 | out = self.avgpool(out) 96 | out = out.view(out.size(0), -1) 97 | out = self.linear(out) 98 | return out 99 | 100 | 101 | def test(): 102 | net = GoogLeNet() 103 | x = torch.randn(1,3,32,32) 104 | y = net(x) 105 | print(y.size()) 106 | 107 | # test() 108 | -------------------------------------------------------------------------------- /models/lenet.py: -------------------------------------------------------------------------------- 1 | '''LeNet in PyTorch.''' 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class LeNet(nn.Module): 6 | def __init__(self): 7 | super(LeNet, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 6, 5) 9 | self.conv2 = nn.Conv2d(6, 16, 5) 10 | self.fc1 = nn.Linear(16*5*5, 120) 11 | self.fc2 = nn.Linear(120, 84) 12 | self.fc3 = nn.Linear(84, 10) 13 | 14 | def forward(self, x): 15 | out = F.relu(self.conv1(x)) 16 | out = F.max_pool2d(out, 2) 17 | out = F.relu(self.conv2(out)) 18 | out = F.max_pool2d(out, 2) 19 | out = out.view(out.size(0), -1) 20 | out = F.relu(self.fc1(out)) 21 | out = F.relu(self.fc2(out)) 22 | out = self.fc3(out) 23 | return out 24 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''Depthwise conv + Pointwise conv''' 13 | def __init__(self, in_planes, out_planes, stride=1): 14 | super(Block, self).__init__() 15 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 18 | self.bn2 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | out = F.relu(self.bn1(self.conv1(x))) 22 | out = F.relu(self.bn2(self.conv2(out))) 23 | return out 24 | 25 | 26 | class MobileNet(nn.Module): 27 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 28 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 29 | 30 | def __init__(self, num_classes=10): 31 | super(MobileNet, self).__init__() 32 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(32) 34 | self.layers = self._make_layers(in_planes=32) 35 | self.linear = nn.Linear(1024, num_classes) 36 | 37 | def _make_layers(self, in_planes): 38 | layers = [] 39 | for x in self.cfg: 40 | out_planes = x if isinstance(x, int) else x[0] 41 | stride = 1 if isinstance(x, int) else x[1] 42 | layers.append(Block(in_planes, out_planes, stride)) 43 | in_planes = out_planes 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | out = F.relu(self.bn1(self.conv1(x))) 48 | out = self.layers(out) 49 | out = F.avg_pool2d(out, 2) 50 | out = out.view(out.size(0), -1) 51 | out = self.linear(out) 52 | return out 53 | 54 | 55 | def test(): 56 | net = MobileNet() 57 | x = torch.randn(1,3,32,32) 58 | y = net(x) 59 | print(y.size()) 60 | 61 | # test() 62 | -------------------------------------------------------------------------------- /models/mobilenetv2.py: 
-------------------------------------------------------------------------------- 1 | '''MobileNetV2 in PyTorch. 2 | 3 | See the paper "Inverted Residuals and Linear Bottlenecks: 4 | Mobile Networks for Classification, Detection and Segmentation" for more details. 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class Block(nn.Module): 12 | '''expand + depthwise + pointwise''' 13 | def __init__(self, in_planes, out_planes, expansion, stride): 14 | super(Block, self).__init__() 15 | self.stride = stride 16 | 17 | planes = expansion * in_planes 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 23 | self.bn3 = nn.BatchNorm2d(out_planes) 24 | 25 | self.shortcut = nn.Sequential() 26 | if stride == 1 and in_planes != out_planes: 27 | self.shortcut = nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 29 | nn.BatchNorm2d(out_planes), 30 | ) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.bn1(self.conv1(x))) 34 | out = F.relu(self.bn2(self.conv2(out))) 35 | out = self.bn3(self.conv3(out)) 36 | out = out + self.shortcut(x) if self.stride==1 else out 37 | return out 38 | 39 | 40 | class MobileNetV2(nn.Module): 41 | # (expansion, out_planes, num_blocks, stride) 42 | cfg = [(1, 16, 1, 1), 43 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 44 | (6, 32, 3, 2), 45 | (6, 64, 4, 2), 46 | (6, 96, 3, 1), 47 | (6, 160, 3, 2), 48 | (6, 320, 1, 1)] 49 | 50 | def __init__(self, num_classes=10): 51 | super(MobileNetV2, self).__init__() 52 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 53 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, 
padding=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(32) 55 | self.layers = self._make_layers(in_planes=32) 56 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 57 | self.bn2 = nn.BatchNorm2d(1280) 58 | self.linear = nn.Linear(1280, num_classes) 59 | 60 | def _make_layers(self, in_planes): 61 | layers = [] 62 | for expansion, out_planes, num_blocks, stride in self.cfg: 63 | strides = [stride] + [1]*(num_blocks-1) 64 | for stride in strides: 65 | layers.append(Block(in_planes, out_planes, expansion, stride)) 66 | in_planes = out_planes 67 | return nn.Sequential(*layers) 68 | 69 | def forward(self, x): 70 | out = F.relu(self.bn1(self.conv1(x))) 71 | out = self.layers(out) 72 | out = F.relu(self.bn2(self.conv2(out))) 73 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 74 | out = F.avg_pool2d(out, 4) 75 | out = out.view(out.size(0), -1) 76 | out = self.linear(out) 77 | return out 78 | 79 | 80 | def test(): 81 | net = MobileNetV2() 82 | x = torch.randn(2,3,32,32) 83 | y = net(x) 84 | print(y.size()) 85 | 86 | # test() 87 | -------------------------------------------------------------------------------- /models/pnasnet.py: -------------------------------------------------------------------------------- 1 | '''PNASNet in PyTorch. 
2 | 3 | Paper: Progressive Neural Architecture Search 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class SepConv(nn.Module): 11 | '''Separable Convolution.''' 12 | def __init__(self, in_planes, out_planes, kernel_size, stride): 13 | super(SepConv, self).__init__() 14 | self.conv1 = nn.Conv2d(in_planes, out_planes, 15 | kernel_size, stride, 16 | padding=(kernel_size-1)//2, 17 | bias=False, groups=in_planes) 18 | self.bn1 = nn.BatchNorm2d(out_planes) 19 | 20 | def forward(self, x): 21 | return self.bn1(self.conv1(x)) 22 | 23 | 24 | class CellA(nn.Module): 25 | def __init__(self, in_planes, out_planes, stride=1): 26 | super(CellA, self).__init__() 27 | self.stride = stride 28 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 29 | if stride==2: 30 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 31 | self.bn1 = nn.BatchNorm2d(out_planes) 32 | 33 | def forward(self, x): 34 | y1 = self.sep_conv1(x) 35 | y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 36 | if self.stride==2: 37 | y2 = self.bn1(self.conv1(y2)) 38 | return F.relu(y1+y2) 39 | 40 | class CellB(nn.Module): 41 | def __init__(self, in_planes, out_planes, stride=1): 42 | super(CellB, self).__init__() 43 | self.stride = stride 44 | # Left branch 45 | self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) 46 | self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride) 47 | # Right branch 48 | self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride) 49 | if stride==2: 50 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 51 | self.bn1 = nn.BatchNorm2d(out_planes) 52 | # Reduce channels 53 | self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 54 | self.bn2 = nn.BatchNorm2d(out_planes) 55 | 56 | def forward(self, x): 57 
| # Left branch 58 | y1 = self.sep_conv1(x) 59 | y2 = self.sep_conv2(x) 60 | # Right branch 61 | y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) 62 | if self.stride==2: 63 | y3 = self.bn1(self.conv1(y3)) 64 | y4 = self.sep_conv3(x) 65 | # Concat & reduce channels 66 | b1 = F.relu(y1+y2) 67 | b2 = F.relu(y3+y4) 68 | y = torch.cat([b1,b2], 1) 69 | return F.relu(self.bn2(self.conv2(y))) 70 | 71 | class PNASNet(nn.Module): 72 | def __init__(self, cell_type, num_cells, num_planes): 73 | super(PNASNet, self).__init__() 74 | self.in_planes = num_planes 75 | self.cell_type = cell_type 76 | 77 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False) 78 | self.bn1 = nn.BatchNorm2d(num_planes) 79 | 80 | self.layer1 = self._make_layer(num_planes, num_cells=6) 81 | self.layer2 = self._downsample(num_planes*2) 82 | self.layer3 = self._make_layer(num_planes*2, num_cells=6) 83 | self.layer4 = self._downsample(num_planes*4) 84 | self.layer5 = self._make_layer(num_planes*4, num_cells=6) 85 | 86 | self.linear = nn.Linear(num_planes*4, 10) 87 | 88 | def _make_layer(self, planes, num_cells): 89 | layers = [] 90 | for _ in range(num_cells): 91 | layers.append(self.cell_type(self.in_planes, planes, stride=1)) 92 | self.in_planes = planes 93 | return nn.Sequential(*layers) 94 | 95 | def _downsample(self, planes): 96 | layer = self.cell_type(self.in_planes, planes, stride=2) 97 | self.in_planes = planes 98 | return layer 99 | 100 | def forward(self, x): 101 | out = F.relu(self.bn1(self.conv1(x))) 102 | out = self.layer1(out) 103 | out = self.layer2(out) 104 | out = self.layer3(out) 105 | out = self.layer4(out) 106 | out = self.layer5(out) 107 | out = F.avg_pool2d(out, 8) 108 | out = self.linear(out.view(out.size(0), -1)) 109 | return out 110 | 111 | 112 | def PNASNetA(): 113 | return PNASNet(CellA, num_cells=6, num_planes=44) 114 | 115 | def PNASNetB(): 116 | return PNASNet(CellB, num_cells=6, num_planes=32) 117 | 118 | 119 | def test(): 
120 | net = PNASNetB() 121 | x = torch.randn(1,3,32,32) 122 | y = net(x) 123 | print(y) 124 | 125 | # test() 126 | -------------------------------------------------------------------------------- /models/regnet.py: -------------------------------------------------------------------------------- 1 | '''RegNet in PyTorch. 2 | 3 | Paper: "Designing Network Design Spaces". 4 | 5 | Reference: https://github.com/keras-team/keras-applications/blob/master/keras_applications/efficientnet.py 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class SE(nn.Module): 13 | '''Squeeze-and-Excitation block.''' 14 | 15 | def __init__(self, in_planes, se_planes): 16 | super(SE, self).__init__() 17 | self.se1 = nn.Conv2d(in_planes, se_planes, kernel_size=1, bias=True) 18 | self.se2 = nn.Conv2d(se_planes, in_planes, kernel_size=1, bias=True) 19 | 20 | def forward(self, x): 21 | out = F.adaptive_avg_pool2d(x, (1, 1)) 22 | out = F.relu(self.se1(out)) 23 | out = self.se2(out).sigmoid() 24 | out = x * out 25 | return out 26 | 27 | 28 | class Block(nn.Module): 29 | def __init__(self, w_in, w_out, stride, group_width, bottleneck_ratio, se_ratio): 30 | super(Block, self).__init__() 31 | # 1x1 32 | w_b = int(round(w_out * bottleneck_ratio)) 33 | self.conv1 = nn.Conv2d(w_in, w_b, kernel_size=1, bias=False) 34 | self.bn1 = nn.BatchNorm2d(w_b) 35 | # 3x3 36 | num_groups = w_b // group_width 37 | self.conv2 = nn.Conv2d(w_b, w_b, kernel_size=3, 38 | stride=stride, padding=1, groups=num_groups, bias=False) 39 | self.bn2 = nn.BatchNorm2d(w_b) 40 | # se 41 | self.with_se = se_ratio > 0 42 | if self.with_se: 43 | w_se = int(round(w_in * se_ratio)) 44 | self.se = SE(w_b, w_se) 45 | # 1x1 46 | self.conv3 = nn.Conv2d(w_b, w_out, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(w_out) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride != 1 or w_in != w_out: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(w_in, w_out, 53 | kernel_size=1, 
stride=stride, bias=False), 54 | nn.BatchNorm2d(w_out) 55 | ) 56 | 57 | def forward(self, x): 58 | out = F.relu(self.bn1(self.conv1(x))) 59 | out = F.relu(self.bn2(self.conv2(out))) 60 | if self.with_se: 61 | out = self.se(out) 62 | out = self.bn3(self.conv3(out)) 63 | out += self.shortcut(x) 64 | out = F.relu(out) 65 | return out 66 | 67 | 68 | class RegNet(nn.Module): 69 | def __init__(self, cfg, num_classes=10): 70 | super(RegNet, self).__init__() 71 | self.cfg = cfg 72 | self.in_planes = 64 73 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 74 | stride=1, padding=1, bias=False) 75 | self.bn1 = nn.BatchNorm2d(64) 76 | self.layer1 = self._make_layer(0) 77 | self.layer2 = self._make_layer(1) 78 | self.layer3 = self._make_layer(2) 79 | self.layer4 = self._make_layer(3) 80 | self.linear = nn.Linear(self.cfg['widths'][-1], num_classes) 81 | 82 | def _make_layer(self, idx): 83 | depth = self.cfg['depths'][idx] 84 | width = self.cfg['widths'][idx] 85 | stride = self.cfg['strides'][idx] 86 | group_width = self.cfg['group_width'] 87 | bottleneck_ratio = self.cfg['bottleneck_ratio'] 88 | se_ratio = self.cfg['se_ratio'] 89 | 90 | layers = [] 91 | for i in range(depth): 92 | s = stride if i == 0 else 1 93 | layers.append(Block(self.in_planes, width, 94 | s, group_width, bottleneck_ratio, se_ratio)) 95 | self.in_planes = width 96 | return nn.Sequential(*layers) 97 | 98 | def forward(self, x): 99 | out = F.relu(self.bn1(self.conv1(x))) 100 | out = self.layer1(out) 101 | out = self.layer2(out) 102 | out = self.layer3(out) 103 | out = self.layer4(out) 104 | out = F.adaptive_avg_pool2d(out, (1, 1)) 105 | out = out.view(out.size(0), -1) 106 | out = self.linear(out) 107 | return out 108 | 109 | 110 | def RegNetX_200MF(): 111 | cfg = { 112 | 'depths': [1, 1, 4, 7], 113 | 'widths': [24, 56, 152, 368], 114 | 'strides': [1, 1, 2, 2], 115 | 'group_width': 8, 116 | 'bottleneck_ratio': 1, 117 | 'se_ratio': 0, 118 | } 119 | return RegNet(cfg) 120 | 121 | 122 | def RegNetX_400MF(): 123 | 
cfg = { 124 | 'depths': [1, 2, 7, 12], 125 | 'widths': [32, 64, 160, 384], 126 | 'strides': [1, 1, 2, 2], 127 | 'group_width': 16, 128 | 'bottleneck_ratio': 1, 129 | 'se_ratio': 0, 130 | } 131 | return RegNet(cfg) 132 | 133 | 134 | def RegNetY_400MF(): 135 | cfg = { 136 | 'depths': [1, 2, 7, 12], 137 | 'widths': [32, 64, 160, 384], 138 | 'strides': [1, 1, 2, 2], 139 | 'group_width': 16, 140 | 'bottleneck_ratio': 1, 141 | 'se_ratio': 0.25, 142 | } 143 | return RegNet(cfg) 144 | 145 | 146 | def test(): 147 | net = RegNetX_200MF() 148 | print(net) 149 | x = torch.randn(2, 3, 32, 32) 150 | y = net(x) 151 | print(y.shape) 152 | 153 | 154 | if __name__ == '__main__': 155 | test() 156 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d( 20 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 21 | self.bn1 = nn.BatchNorm2d(planes) 22 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 23 | stride=1, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*planes: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*planes, 30 | kernel_size=1, stride=stride, bias=False), 31 | nn.BatchNorm2d(self.expansion*planes) 32 | ) 33 | 34 | def forward(self, x): 35 | out = F.relu(self.bn1(self.conv1(x))) 36 | out = self.bn2(self.conv2(out)) 37 | out += self.shortcut(x) 38 | out = F.relu(out) 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1): 46 | super(Bottleneck, self).__init__() 47 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 48 | self.bn1 = nn.BatchNorm2d(planes) 49 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 50 | stride=stride, padding=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv3 = nn.Conv2d(planes, self.expansion * 53 | planes, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 55 | 56 | self.shortcut = nn.Sequential() 57 | if stride != 1 or in_planes != self.expansion*planes: 58 | self.shortcut = nn.Sequential( 59 | nn.Conv2d(in_planes, self.expansion*planes, 60 | kernel_size=1, stride=stride, bias=False), 61 | nn.BatchNorm2d(self.expansion*planes) 62 | ) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(self.conv1(x))) 66 | out = F.relu(self.bn2(self.conv2(out))) 67 | out = 
self.bn3(self.conv3(out)) 68 | out += self.shortcut(x) 69 | out = F.relu(out) 70 | return out 71 | 72 | 73 | class ResNet(nn.Module): 74 | def __init__(self, block, num_blocks, num_classes=10): 75 | super(ResNet, self).__init__() 76 | self.in_planes = 64 77 | 78 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 79 | stride=1, padding=1, bias=False) 80 | self.bn1 = nn.BatchNorm2d(64) 81 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 82 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 83 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 84 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 85 | self.linear = nn.Linear(512*block.expansion, num_classes) 86 | 87 | def _make_layer(self, block, planes, num_blocks, stride): 88 | strides = [stride] + [1]*(num_blocks-1) 89 | layers = [] 90 | for stride in strides: 91 | layers.append(block(self.in_planes, planes, stride)) 92 | self.in_planes = planes * block.expansion 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | out = F.relu(self.bn1(self.conv1(x))) 97 | out = self.layer1(out) 98 | out = self.layer2(out) 99 | out = self.layer3(out) 100 | out = self.layer4(out) 101 | out = F.avg_pool2d(out, 4) 102 | out = out.view(out.size(0), -1) 103 | out = self.linear(out) 104 | return out 105 | 106 | 107 | def ResNet18(): 108 | return ResNet(BasicBlock, [2, 2, 2, 2]) 109 | 110 | 111 | def ResNet34(): 112 | return ResNet(BasicBlock, [3, 4, 6, 3]) 113 | 114 | 115 | def ResNet50(): 116 | return ResNet(Bottleneck, [3, 4, 6, 3]) 117 | 118 | 119 | def ResNet101(): 120 | return ResNet(Bottleneck, [3, 4, 23, 3]) 121 | 122 | 123 | def ResNet152(): 124 | return ResNet(Bottleneck, [3, 8, 36, 3]) 125 | 126 | 127 | def test(): 128 | net = ResNet18() 129 | y = net(torch.randn(1, 3, 32, 32)) 130 | print(y.size()) 131 | 132 | # test() 133 | -------------------------------------------------------------------------------- /models/resnext.py: 
# ---------------------------------------------------------------------------
# /models/resnext.py
# ---------------------------------------------------------------------------
'''ResNeXt in PyTorch.

See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    '''Grouped convolution block.

    1x1 conv -> 3x3 grouped conv (``cardinality`` groups, carries the stride)
    -> 1x1 conv, each followed by BatchNorm, added to a shortcut and passed
    through ReLU. Outputs ``expansion * cardinality * bottleneck_width``
    channels.
    '''
    expansion = 2

    def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1):
        super(Block, self).__init__()
        group_width = cardinality * bottleneck_width
        out_planes = self.expansion * group_width
        self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(group_width)
        self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride,
                               padding=1, groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(group_width)
        self.conv3 = nn.Conv2d(group_width, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Identity shortcut unless the resolution or channel count changes,
        # in which case a strided 1x1 projection (+BN) is used.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNeXt(nn.Module):
    '''Three-stage ResNeXt for CIFAR-sized inputs.

    Args:
        num_blocks: number of ``Block``s in each of the three stages.
        cardinality: number of groups in every grouped 3x3 convolution.
        bottleneck_width: per-group width of the first stage; doubled after
            every stage.
        num_classes: classifier output size.
    '''
    def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10):
        super(ResNeXt, self).__init__()
        self.cardinality = cardinality
        self.bottleneck_width = bottleneck_width
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(num_blocks[0], 1)
        self.layer2 = self._make_layer(num_blocks[1], 2)
        self.layer3 = self._make_layer(num_blocks[2], 2)
        # self.layer4 = self._make_layer(num_blocks[3], 2)
        # After three stages the width is 2 * cardinality * (4 * bottleneck_width).
        self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes)

    def _make_layer(self, num_blocks, stride):
        '''Build one stage; only its first block uses ``stride``.'''
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride))
            self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width
        # Increase bottleneck_width by 2 after each stage.
        self.bottleneck_width *= 2
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # out = self.layer4(out)
        # Global average pooling: same as avg_pool2d(out, 8) for 32x32 inputs,
        # but also valid for other input resolutions.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNeXt29_2x64d(num_classes=10):
    return ResNeXt(num_blocks=[3, 3, 3], cardinality=2, bottleneck_width=64, num_classes=num_classes)

def ResNeXt29_4x64d(num_classes=10):
    return ResNeXt(num_blocks=[3, 3, 3], cardinality=4, bottleneck_width=64, num_classes=num_classes)

def ResNeXt29_8x64d(num_classes=10):
    return ResNeXt(num_blocks=[3, 3, 3], cardinality=8, bottleneck_width=64, num_classes=num_classes)

def ResNeXt29_32x4d(num_classes=10):
    return ResNeXt(num_blocks=[3, 3, 3], cardinality=32, bottleneck_width=4, num_classes=num_classes)

def test_resnext():
    net = ResNeXt29_2x64d()
    x = torch.randn(1,3,32,32)
    y = net(x)
    print(y.size())

# test_resnext()

# ---------------------------------------------------------------------------
# /models/senet.py
# ---------------------------------------------------------------------------
'''SENet in PyTorch.

SENet won the ILSVRC 2017 (ImageNet) classification challenge. See the paper
"Squeeze-and-Excitation Networks" (Hu et al., arXiv:1709.01507) for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    '''Post-activation residual block with a squeeze-and-excitation gate.'''
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

        # SE layers: 1x1 convs act as fully connected layers on the pooled
        # (N, C, 1, 1) descriptor, with a 16x channel reduction in between.
        self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1)
        self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Squeeze: global average pool to one value per channel.
        w = F.adaptive_avg_pool2d(out, 1)
        w = F.relu(self.fc1(w))
        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        w = torch.sigmoid(self.fc2(w))
        # Excitation: rescale each channel; w broadcasts over H and W.
        out = out * w

        out += self.shortcut(x)
        out = F.relu(out)
        return out


class PreActBlock(nn.Module):
    '''Pre-activation residual block with a squeeze-and-excitation gate.

    The projection shortcut (when present) is applied to the first
    BN-ReLU output; otherwise the raw input is added back.
    '''
    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)
            )

        # SE layers
        self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1)
        self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))

        # Squeeze
        w = F.adaptive_avg_pool2d(out, 1)
        w = F.relu(self.fc1(w))
        w = torch.sigmoid(self.fc2(w))
        # Excitation
        out = out * w

        out += shortcut
        return out


class SENet(nn.Module):
    '''Four-stage SENet (ResNet-18-style layout) for CIFAR-sized inputs.'''
    def __init__(self, block, num_blocks, num_classes=10):
        super(SENet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block,  64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        '''Build one stage; only its first block uses ``stride``.'''
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling (same as avg_pool2d(out, 4) for 32x32 inputs).
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def SENet18(num_classes=10):
    return SENet(PreActBlock, [2, 2, 2, 2], num_classes=num_classes)


def test():
    net = SENet18()
    y = net(torch.randn(1,3,32,32))
    print(y.size())

# test()

# ---------------------------------------------------------------------------
# /models/shufflenet.py
# ---------------------------------------------------------------------------
'''ShuffleNet in PyTorch.

See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class ShuffleBlock(nn.Module):
    '''Channel shuffle across ``groups`` groups (C must be divisible by groups).'''
    def __init__(self, groups):
        super(ShuffleBlock, self).__init__()
        self.groups = groups

    def forward(self, x):
        '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
        N, C, H, W = x.size()
        g = self.groups
        return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)


class Bottleneck(nn.Module):
    '''ShuffleNet unit: 1x1 group conv -> channel shuffle -> 3x3 depthwise conv
    -> 1x1 group conv.

    With ``stride == 2`` the average-pooled input is concatenated to the branch
    output (``out_planes + in_planes`` channels); otherwise it is added, which
    requires ``in_planes == out_planes``.
    '''
    def __init__(self, in_planes, out_planes, stride, groups):
        super(Bottleneck, self).__init__()
        self.stride = stride

        # Integer division: channel counts must be ints (``out_planes/4`` is a
        # float under Python 3, which nn.Conv2d / nn.BatchNorm2d reject).
        mid_planes = out_planes // 4
        # The first unit after the 24-channel stem uses an ungrouped first 1x1 conv.
        g = 1 if in_planes == 24 else groups
        self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_planes)
        self.shuffle1 = ShuffleBlock(groups=g)
        self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride,
                               padding=1, groups=mid_planes, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_planes)
        self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.shortcut = nn.Sequential()
        if stride == 2:
            self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.shuffle1(out)
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        res = self.shortcut(x)
        out = F.relu(torch.cat([out, res], 1)) if self.stride == 2 else F.relu(out+res)
        return out


class ShuffleNet(nn.Module):
    '''ShuffleNet (v1) for CIFAR-sized inputs.

    Args:
        cfg: dict with 'out_planes' (three stage widths), 'num_blocks' (three
            stage depths) and 'groups' (group count of the 1x1 group convs).
        num_classes: classifier output size (default 10).
    '''
    def __init__(self, cfg, num_classes=10):
        super(ShuffleNet, self).__init__()
        out_planes = cfg['out_planes']
        num_blocks = cfg['num_blocks']
        groups = cfg['groups']

        self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.in_planes = 24
        self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups)
        self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups)
        self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups)
        self.linear = nn.Linear(out_planes[2], num_classes)

    def _make_layer(self, out_planes, num_blocks, groups):
        '''Build one stage whose first (stride-2) unit halves the resolution.'''
        layers = []
        for i in range(num_blocks):
            stride = 2 if i == 0 else 1
            # The strided unit concatenates its pooled input, so its branch only
            # has to produce the remaining out_planes - in_planes channels.
            cat_planes = self.in_planes if i == 0 else 0
            layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups))
            self.in_planes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pooling (same as avg_pool2d(out, 4) for 32x32 inputs).
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ShuffleNetG2(num_classes=10):
    cfg = {
        'out_planes': [200, 400, 800],
        'num_blocks': [4, 8, 4],
        'groups': 2
    }
    return ShuffleNet(cfg, num_classes=num_classes)

def ShuffleNetG3(num_classes=10):
    cfg = {
        'out_planes': [240, 480, 960],
        'num_blocks': [4, 8, 4],
        'groups': 3
    }
    return ShuffleNet(cfg, num_classes=num_classes)


def test():
    net = ShuffleNetG2()
    x = torch.randn(1,3,32,32)
    y = net(x)
    print(y)

# test()

# ---------------------------------------------------------------------------
# /models/shufflenetv2.py
# ---------------------------------------------------------------------------
'''ShuffleNetV2 in PyTorch.

See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class ShuffleBlock(nn.Module):
    '''Channel shuffle across ``groups`` groups (C must be divisible by groups).'''
    def __init__(self, groups=2):
        super(ShuffleBlock, self).__init__()
        self.groups = groups

    def forward(self, x):
        '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
        N, C, H, W = x.size()
        g = self.groups
        return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)


class SplitBlock(nn.Module):
    '''Split channels into the first ``int(C * ratio)`` and the remainder.'''
    def __init__(self, ratio):
        super(SplitBlock, self).__init__()
        self.ratio = ratio

    def forward(self, x):
        c = int(x.size(1) * self.ratio)
        return x[:, :c, :, :], x[:, c:, :, :]


class BasicBlock(nn.Module):
    '''ShuffleNetV2 basic unit: one channel split passes through untouched, the
    other goes through 1x1 -> 3x3 depthwise -> 1x1; then concat + shuffle.
    '''
    def __init__(self, in_channels, split_ratio=0.5):
        super(BasicBlock, self).__init__()
        self.split = SplitBlock(split_ratio)
        in_channels = int(in_channels * split_ratio)
        self.conv1 = nn.Conv2d(in_channels, in_channels,
                               kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, in_channels,
                               kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False)
        self.bn2 = nn.BatchNorm2d(in_channels)
        self.conv3 = nn.Conv2d(in_channels, in_channels,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(in_channels)
        self.shuffle = ShuffleBlock()

    def forward(self, x):
        x1, x2 = self.split(x)
        out = F.relu(self.bn1(self.conv1(x2)))
        out = self.bn2(self.conv2(out))
        out = F.relu(self.bn3(self.conv3(out)))
        out = torch.cat([x1, out], 1)
        out = self.shuffle(out)
        return out


class DownBlock(nn.Module):
    '''ShuffleNetV2 spatial-downsampling unit (stride 2, two branches).'''
    def __init__(self, in_channels, out_channels):
        super(DownBlock, self).__init__()
        mid_channels = out_channels // 2
        # left
        self.conv1 = nn.Conv2d(in_channels, in_channels,
                               kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, mid_channels,
                               kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        # right
        self.conv3 = nn.Conv2d(in_channels, mid_channels,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(mid_channels)
        self.conv4 = nn.Conv2d(mid_channels, mid_channels,
                               kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False)
        self.bn4 = nn.BatchNorm2d(mid_channels)
        self.conv5 = nn.Conv2d(mid_channels, mid_channels,
                               kernel_size=1, bias=False)
        self.bn5 = nn.BatchNorm2d(mid_channels)

        self.shuffle = ShuffleBlock()

    def forward(self, x):
        # left
        out1 = self.bn1(self.conv1(x))
        out1 = F.relu(self.bn2(self.conv2(out1)))
        # right
        out2 = F.relu(self.bn3(self.conv3(x)))
        out2 = self.bn4(self.conv4(out2))
        out2 = F.relu(self.bn5(self.conv5(out2)))
        # concat
        out = torch.cat([out1, out2], 1)
        out = self.shuffle(out)
        return out


class ShuffleNetV2(nn.Module):
    '''ShuffleNetV2 for CIFAR-sized inputs (no stem max-pool).

    Args:
        net_size: width multiplier; a key of ``configs`` (0.5, 1, 1.5 or 2).
        num_classes: classifier output size (default 10).
    '''
    def __init__(self, net_size, num_classes=10):
        super(ShuffleNetV2, self).__init__()
        out_channels = configs[net_size]['out_channels']
        num_blocks = configs[net_size]['num_blocks']

        self.conv1 = nn.Conv2d(3, 24, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.in_channels = 24
        self.layer1 = self._make_layer(out_channels[0], num_blocks[0])
        self.layer2 = self._make_layer(out_channels[1], num_blocks[1])
        self.layer3 = self._make_layer(out_channels[2], num_blocks[2])
        self.conv2 = nn.Conv2d(out_channels[2], out_channels[3],
                               kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels[3])
        self.linear = nn.Linear(out_channels[3], num_classes)

    def _make_layer(self, out_channels, num_blocks):
        '''Build one stage: a DownBlock followed by ``num_blocks`` BasicBlocks.'''
        layers = [DownBlock(self.in_channels, out_channels)]
        for _ in range(num_blocks):
            layers.append(BasicBlock(out_channels))
        # The DownBlock alone already changes the width, so update even when
        # num_blocks == 0.
        self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        # out = F.max_pool2d(out, 3, stride=2, padding=1)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn2(self.conv2(out)))
        # Global average pooling (same as avg_pool2d(out, 4) for 32x32 inputs).
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


configs = {
    0.5: {
        'out_channels': (48, 96, 192, 1024),
        'num_blocks': (3, 7, 3)
    },

    1: {
        'out_channels': (116, 232, 464, 1024),
        'num_blocks': (3, 7, 3)
    },
    1.5: {
        'out_channels': (176, 352, 704, 1024),
        'num_blocks': (3, 7, 3)
    },
    2: {
        'out_channels': (224, 488, 976, 2048),
        'num_blocks': (3, 7, 3)
    }
}


def test():
    net = ShuffleNetV2(net_size=0.5)
    x = torch.randn(3, 3, 32, 32)
    y = net(x)
    print(y.shape)


# test()

# ---------------------------------------------------------------------------
# /models/vgg.py
# ---------------------------------------------------------------------------
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn


cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    '''VGG with BatchNorm.

    Args:
        vgg_name: key of ``cfg`` selecting the layer layout (conv widths, with
            'M' marking a 2x2 max-pool).
        num_classes: classifier output size (default 10).

    The classifier takes 512 features, i.e. the final 512-channel map must be
    1x1 after the five max-pools, which holds for 32x32 inputs.
    '''
    def __init__(self, vgg_name, num_classes=10):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        '''Build the conv-BN-ReLU / max-pool stack described by ``cfg``.'''
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        # Kernel-1 / stride-1 average pool is an identity op; kept so the
        # Sequential keeps its original structure.
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)


def test():
    net = VGG('VGG11')
    x = torch.randn(2,3,32,32)
    y = net(x)
    print(y.size())

# test()

# ---------------------------------------------------------------------------
# /preactresnet.py
# ---------------------------------------------------------------------------
'''Pre-activation ResNet in PyTorch.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Identity Mappings in Deep Residual Networks. arXiv:1603.05027
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

# Normalization layer (and its options) used by every block below.
track_running_stats=True
affine=True
normal_func = nn.BatchNorm2d

# Alternative configuration: instance normalization without running stats.
# track_running_stats=False
# affine=True
# normal_func = nn.InstanceNorm2d


if not track_running_stats:
    print('BN track False')


# Constructors for the supported activation names. Each receives the Softplus
# ``beta``; only Softplus uses it.
_ACTIVATIONS = {
    'ReLU': lambda beta: nn.ReLU(inplace=True),
    'Softplus': lambda beta: nn.Softplus(beta=beta, threshold=20),
    'GELU': lambda beta: nn.GELU(),
    'ELU': lambda beta: nn.ELU(alpha=1.0, inplace=True),
    'LeakyReLU': lambda beta: nn.LeakyReLU(negative_slope=0.1, inplace=True),
    'SELU': lambda beta: nn.SELU(inplace=True),
    'CELU': lambda beta: nn.CELU(alpha=1.2, inplace=True),
    'Tanh': lambda beta: nn.Tanh(),
}


def _make_activation(activation, softplus_beta=1):
    '''Return a new activation module named ``activation`` and print the name.

    Raises:
        ValueError: if ``activation`` is not a key of ``_ACTIVATIONS`` (an
            unknown name would otherwise leave the module without ``relu``).
    '''
    if activation not in _ACTIVATIONS:
        raise ValueError('Unsupported activation %r; expected one of: %s'
                         % (activation, ', '.join(sorted(_ACTIVATIONS))))
    module = _ACTIVATIONS[activation](softplus_beta)
    print(activation)
    return module


class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.

    norm -> act -> 3x3 conv -> norm -> act -> 3x3 conv. The shortcut is a 1x1
    projection of the first activation's output when the shape changes, and
    the raw input otherwise.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, activation='ReLU', softplus_beta=1):
        super(PreActBlock, self).__init__()
        self.bn1 = normal_func(in_planes, track_running_stats=track_running_stats, affine=affine)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = normal_func(planes, track_running_stats=track_running_stats, affine=affine)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )
        self.relu = _make_activation(activation, softplus_beta)

    def forward(self, x):
        out = self.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(self.relu(self.bn2(out)))
        out += shortcut
        return out


class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.

    norm -> act -> 1x1 conv -> norm -> act -> 3x3 conv (stride) -> norm -> act
    -> 1x1 conv to ``expansion * planes`` channels, plus a shortcut as in
    ``PreActBlock``.
    '''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, activation='ReLU', softplus_beta=1):
        super(PreActBottleneck, self).__init__()
        self.bn1 = normal_func(in_planes, track_running_stats=track_running_stats, affine=affine)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = normal_func(planes, track_running_stats=track_running_stats, affine=affine)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = normal_func(planes, track_running_stats=track_running_stats, affine=affine)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)

        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )
        # Honour the configured activation, exactly like PreActBlock.
        self.relu = _make_activation(activation, softplus_beta)

    def forward(self, x):
        out = self.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))
        out += shortcut
        return out


class PreActResNet(nn.Module):
    '''Pre-activation ResNet for CIFAR-sized inputs.

    Args:
        block: ``PreActBlock`` or ``PreActBottleneck``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: classifier output size.
        normalize: L2-normalize the features (then multiply by ``scale``) and
            the classifier weight rows; the classifier then has no bias.
        normalize_only_FN: L2-normalize the features only.
        scale: feature scale used when ``normalize`` is True.
        activation: activation name (a key of ``_ACTIVATIONS``).
        softplus_beta: beta for the Softplus activation.

    Raises:
        ValueError: for an unsupported ``activation``.
    '''
    def __init__(self, block, num_blocks, num_classes=10, normalize=False, normalize_only_FN=False,
                 scale=15, activation='ReLU', softplus_beta=1):
        super(PreActResNet, self).__init__()
        self.in_planes = 64

        self.normalize = normalize
        self.normalize_only_FN = normalize_only_FN
        self.scale = scale

        self.activation = activation
        self.softplus_beta = softplus_beta

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.bn = normal_func(512 * block.expansion, track_running_stats=track_running_stats, affine=affine)

        if self.normalize:
            self.linear = nn.Linear(512*block.expansion, num_classes, bias=False)
        else:
            self.linear = nn.Linear(512*block.expansion, num_classes)

        self.relu = _make_activation(activation, softplus_beta)
        print('Use activation of ' + activation)

    def _make_layer(self, block, planes, num_blocks, stride):
        '''Build one stage; only its first block uses ``stride``.'''
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride,
                                activation=self.activation, softplus_beta=self.softplus_beta))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.relu(self.bn(out))
        # Global average pooling (same as avg_pool2d(out, 4) for 32x32 inputs).
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        if self.normalize_only_FN:
            out = F.normalize(out, p=2, dim=1)

        if self.normalize:
            out = F.normalize(out, p=2, dim=1) * self.scale
            # Weight normalization: rescale each classifier row to unit L2 norm,
            # written through .data so autograd does not track the rescale.
            self.linear.weight.data = F.normalize(self.linear.weight, p=2, dim=1)
        return self.linear(out)


def PreActResNet18(num_classes=10, normalize=False, normalize_only_FN=False, scale=15, activation='ReLU', softplus_beta=1):
    '''PreActResNet-18 (basic blocks, [2, 2, 2, 2]); see ``PreActResNet``.'''
    return PreActResNet(PreActBlock, [2, 2, 2, 2], num_classes=num_classes, normalize=normalize,
                        normalize_only_FN=normalize_only_FN, scale=scale,
                        activation=activation, softplus_beta=softplus_beta)

def PreActResNet34(**kwargs):
    '''PreActResNet-34; keyword arguments are forwarded to ``PreActResNet``.'''
    return PreActResNet(PreActBlock, [3, 4, 6, 3], **kwargs)

def PreActResNet50(**kwargs):
    '''PreActResNet-50; keyword arguments are forwarded to ``PreActResNet``.'''
    return PreActResNet(PreActBottleneck, [3, 4, 6, 3], **kwargs)

def PreActResNet101(**kwargs):
    '''PreActResNet-101; keyword arguments are forwarded to ``PreActResNet``.'''
    return PreActResNet(PreActBottleneck, [3, 4, 23, 3], **kwargs)

def PreActResNet152(**kwargs):
    '''PreActResNet-152; keyword arguments are forwarded to ``PreActResNet``.'''
    return PreActResNet(PreActBottleneck, [3, 8, 36, 3], **kwargs)


def test():
    net = PreActResNet18()
    y = net((torch.randn(1,3,32,32)))
    print(y.size())

# test()

# ---------------------------------------------------------------------------
# /train_cifar.py
# ---------------------------------------------------------------------------
import argparse
import logging
import sys
import time
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import os

from wideresnet import WideResNet
from preactresnet import PreActResNet18, PreActResNet50
from models import *

from utils import *

# Per-channel CIFAR-10 mean/std shaped (3, 1, 1) so they broadcast over
# (N, 3, H, W) batches. ``cifar10_mean``/``cifar10_std`` are not defined in this
# section -- presumably supplied by ``from utils import *``; confirm.
mu = torch.tensor(cifar10_mean).view(3,1,1).cuda()
std = torch.tensor(cifar10_std).view(3,1,1).cuda()


def normalize(X):
    '''Standardize an image batch with the CIFAR-10 channel statistics.'''
    return (X - mu)/std


# Valid pixel range of un-normalized images; perturbed inputs are clipped to it.
upper_limit, lower_limit = 1,0


def clamp(X, lower_limit, upper_limit):
    '''Clip ``X`` element-wise into ``[lower_limit, upper_limit]``.

    Callers in this file pass tensor bounds (e.g. ``lower_limit - X``), so the
    limits may differ per element.
    '''
    return torch.max(torch.min(X, upper_limit), lower_limit)


class LabelSmoothingLoss(nn.Module):
    '''Cross-entropy against label-smoothed targets.

    The true class gets probability ``1 - smoothing``; the remaining
    ``smoothing`` mass is spread evenly over the other ``classes - 1`` classes.

    Args:
        classes: number of classes (must be greater than 1).
        smoothing: probability mass moved off the true class.
        dim: class dimension of the prediction tensor.
    '''
    def __init__(self, classes=10, smoothing=0.0, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        '''Return the mean smoothed cross-entropy of logits ``pred`` vs. integer ``target``.'''
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            true_dist = torch.full_like(pred, self.smoothing / (self.cls - 1))
            # Put the confidence mass on the same class dimension that the
            # softmax and the final sum use.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))


class Batches():
    '''Iterable over a DataLoader yielding ``{'input': ..., 'target': ...}`` dicts.

    Inputs are cast to float and targets to long on ``device``. When
    ``set_random_choices`` is True, ``dataset.set_random_choices()`` is called
    at the start of every iteration pass.
    '''
    def __init__(self, dataset, batch_size, shuffle, set_random_choices=False, num_workers=0, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.set_random_choices = set_random_choices
        self.dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=shuffle, drop_last=drop_last
        )

    def __iter__(self):
        if self.set_random_choices:
            self.dataset.set_random_choices()
        # NOTE(review): ``device`` is not defined in this section; presumably a
        # module-level global assigned elsewhere in the script -- confirm.
        return ({'input': x.to(device).float(), 'target': y.to(device).long()} for (x,y) in self.dataloader)

    def __len__(self):
        return len(self.dataloader)


def mixup_data(x, y, alpha=1.0):
    '''Returns mixed inputs, pairs of targets, and lambda.

    ``lam`` is drawn from Beta(alpha, alpha) (or is 1 when ``alpha <= 0``);
    every sample is blended with a randomly permuted partner from the batch.
    '''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    # Draw the permutation with the CPU generator and move it to x's device,
    # so both CPU and CUDA batches work.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam


def mixup_criterion(criterion, pred, y_a, y_b, lam):
    '''Mixup loss: ``lam``-weighted blend of ``criterion`` on both target sets.'''
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)


def _logit_margin(x, y):
    '''Per-sample margin ``z_y - max_{i != y} z_i`` for logits ``x`` (N, C).

    Also returns the logits sorted ascending along dim 1.
    '''
    x_sorted, ind_sorted = x.sort(dim=1)
    # 1.0 where the true class has the top logit: then the runner-up is the
    # strongest competitor, otherwise the top logit is.
    ind = (ind_sorted[:, -1] == y).float()
    true_logit = x[torch.arange(x.shape[0], device=x.device), y]
    margin = true_logit - x_sorted[:, -2] * ind - x_sorted[:, -1] * (1. - ind)
    return margin, x_sorted


def dlr_loss(x, y):
    '''Mean Difference-of-Logits-Ratio loss (Croce & Hein, 2020).

    The negative margin divided by (largest - third-largest logit); requires
    at least three classes.
    '''
    margin, x_sorted = _logit_margin(x, y)
    loss_value = -margin / (x_sorted[:, -1] - x_sorted[:, -3] + 1e-12)
    return loss_value.mean()


def CW_loss(x, y):
    '''Mean Carlini-Wagner margin loss ``-(z_y - max_{i != y} z_i)``.'''
    margin, _ = _logit_margin(x, y)
    return (-margin).mean()


def _init_delta(X, epsilon, norm):
    '''Random PGD start inside the ``norm`` ball of radius ``epsilon``.

    l_inf: uniform in [-epsilon, epsilon]. l_2: a Gaussian direction rescaled
    to a radius drawn uniformly from [0, epsilon]. The result is clipped so
    ``X + delta`` stays within [lower_limit, upper_limit].

    Raises:
        ValueError: for a ``norm`` other than 'l_inf' or 'l_2'.
    '''
    delta = torch.zeros_like(X).cuda()
    if norm == "l_inf":
        delta.uniform_(-epsilon, epsilon)
    elif norm == "l_2":
        delta.normal_()
        d_flat = delta.view(delta.size(0),-1)
        n = d_flat.norm(p=2,dim=1).view(delta.size(0),1,1,1)
        r = torch.zeros_like(n).uniform_(0, 1)
        delta *= r/n*epsilon
    else:
        raise ValueError("Unsupported norm %r; expected 'l_inf' or 'l_2'" % (norm,))
    return clamp(delta, lower_limit-X, upper_limit-X)


def _pgd_step(d, g, x, alpha, epsilon, norm):
    '''One PGD ascent step on perturbation ``d`` given its gradient ``g``.

    Moves ``alpha`` along sign(g) (l_inf) or g / ||g|| (l_2), projects back onto
    the epsilon ball, then clips so ``x + d`` stays a valid image.
    '''
    if norm == "l_inf":
        d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon)
    elif norm == "l_2":
        g_norm = torch.norm(g.view(g.shape[0],-1),dim=1).view(-1,1,1,1)
        scaled_g = g/(g_norm + 1e-10)
        d = (d + scaled_g*alpha).view(d.size(0),-1).renorm(p=2,dim=0,maxnorm=epsilon).view_as(d)
    return clamp(d, lower_limit - x, upper_limit - x)


def attack_pgd(model, X, y, epsilon, alpha, attack_iters, restarts,
               norm, mixup=False, y_a=None, y_b=None, lam=None,
               early_stop=False, early_stop_pgd_max=1,
               multitarget=False,
               use_DLRloss=False, use_CWloss=False,
               epoch=0, totalepoch=110, gamma=0.8,
               use_adaptive=False, s_HE=15,
               fast_better=False, BNeval=False):
    '''Craft adversarial perturbations for ``X`` with randomly restarted PGD.

    Args:
        model: network applied to ``normalize(X + delta)``.
        X, y: clean images (within [lower_limit, upper_limit]) and labels.
        epsilon, alpha: perturbation budget and step size.
        attack_iters, restarts: PGD steps per restart and number of restarts.
        norm: 'l_inf' or 'l_2'.
        mixup, y_a, y_b, lam: use the mixup criterion on (y_a, y_b, lam).
        early_stop, early_stop_pgd_max: stop updating a sample after it has
            been misclassified ``early_stop_pgd_max`` times.
        multitarget: pick a random label per sample; ascend CE when it equals
            ``y``, otherwise descend CE toward it.
        use_DLRloss, use_CWloss, epoch, totalepoch, gamma: blend CE with the
            DLR / CW loss using weight ``gamma * epoch / totalepoch``.
        use_adaptive, s_HE: multiply logits by ``s_HE`` inside the CE loss.
        fast_better: add ``alpha/4 * ||dCE/ddelta||^2`` to the CE loss.
        BNeval: run the attack in eval mode; the model is switched back to
            train mode afterwards, even if the attack raises.

    Returns:
        (max_delta, iter_count): per sample, the perturbation with the highest
        final loss over all restarts, and a CPU tensor counting the PGD steps
        the sample stayed active in the last restart (zeros unless
        ``early_stop``).

    Raises:
        ValueError: for an unsupported ``norm``.
    '''
    max_loss = torch.zeros(y.shape[0]).cuda()
    max_delta = torch.zeros_like(X).cuda()
    # Defined up front so the return value exists even when restarts == 0.
    iter_count = torch.zeros(y.shape[0])

    if BNeval:
        model.eval()
    try:
        for _ in range(restarts):
            # Remaining "fooled" budget per sample before early stop freezes it.
            early_stop_pgd_count = early_stop_pgd_max * torch.ones(y.shape[0], dtype=torch.int32).cuda()

            delta = _init_delta(X, epsilon, norm)
            delta.requires_grad = True

            iter_count = torch.zeros(y.shape[0])

            for _ in range(attack_iters):
                output = model(normalize(X + delta))

                if early_stop:
                    # Each misclassified sample spends one unit of budget; only
                    # samples with budget left keep being updated.
                    if_success_fool = (output.max(1)[1] != y).to(dtype=torch.int32)
                    early_stop_pgd_count = early_stop_pgd_count - if_success_fool
                    index = torch.where(early_stop_pgd_count > 0)[0]
                    if len(index) == 0:
                        break
                    # iter_count lives on the CPU; indexing it with a CUDA index
                    # tensor raises, so move the index there first.
                    iter_count[index.cpu()] += 1
                else:
                    index = slice(None, None, None)

                if fast_better:
                    loss_ori = F.cross_entropy(output, y)
                    grad_ori = torch.autograd.grad(loss_ori, delta, create_graph=True)[0]
                    loss_grad = (alpha / 4.) * (torch.norm(grad_ori.view(grad_ori.shape[0], -1), p=2, dim=1) ** 2)
                    loss = loss_ori + loss_grad.mean()
                elif mixup:
                    # NOTE: runs a second forward pass rather than reusing
                    # ``output``; kept because in train mode it also updates
                    # BatchNorm running statistics.
                    criterion = nn.CrossEntropyLoss()
                    loss = mixup_criterion(criterion, model(normalize(X+delta)), y_a, y_b, lam)
                elif multitarget:
                    random_label = torch.randint(low=0, high=output.size(1), size=y.shape).to(y.device)
                    random_direction = 2*((random_label == y).to(dtype=torch.float32) - 0.5)
                    loss = torch.mean(random_direction * F.cross_entropy(output, random_label, reduction='none'))
                elif use_DLRloss:
                    beta_ = gamma * epoch / totalepoch
                    loss = (1. - beta_) * F.cross_entropy(output, y) + beta_ * dlr_loss(output, y)
                elif use_CWloss:
                    beta_ = gamma * epoch / totalepoch
                    loss = (1. - beta_) * F.cross_entropy(output, y) + beta_ * CW_loss(output, y)
                elif use_adaptive:
                    loss = F.cross_entropy(s_HE * output, y)
                else:
                    loss = F.cross_entropy(output, y)
                loss.backward()
                grad = delta.grad.detach()

                # Update only the still-active samples (all of them without early stop).
                d = delta[index, :, :, :]
                g = grad[index, :, :, :]
                x = X[index, :, :, :]
                delta.data[index, :, :, :] = _pgd_step(d, g, x, alpha, epsilon, norm)
                delta.grad.zero_()

            # Keep, per sample, the restart whose final perturbation has the highest loss.
            if mixup:
                criterion = nn.CrossEntropyLoss(reduction='none')
                all_loss = mixup_criterion(criterion, model(normalize(X+delta)), y_a, y_b, lam)
            else:
                all_loss = F.cross_entropy(model(normalize(X+delta)), y, reduction='none')
            improved = all_loss >= max_loss
            max_delta[improved] = delta.detach()[improved]
            max_loss = torch.max(max_loss, all_loss)
    finally:
        if BNeval:
            model.train()

    return max_delta, iter_count



def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='PreActResNet18')
    parser.add_argument('--l1', default=0, type=float)
    parser.add_argument('--data-dir', default='../cifar-data', type=str)
    parser.add_argument('--epochs', default=110, type=int)
    parser.add_argument('--lr-schedule', default='piecewise', choices=['superconverge', 'piecewise', 'linear', 'piecewisesmoothed', 'piecewisezoom', 'onedrop', 'multipledecay', 'cosine', 'cyclic'])
    parser.add_argument('--lr-max', default=0.1, type=float)
    parser.add_argument('--lr-one-drop', default=0.01, type=float)
    parser.add_argument('--lr-drop-epoch', default=100, type=int)
    parser.add_argument('--attack', default='pgd', type=str, choices=['pgd', 'fgsm', 'free', 'none'])
    parser.add_argument('--epsilon', default=8, type=int)
    parser.add_argument('--test_epsilon', default=8, type=int)
    parser.add_argument('--attack-iters', default=10, type=int)
    parser.add_argument('--restarts', default=1, type=int)
    parser.add_argument('--pgd-alpha', default=2, type=float)
    parser.add_argument('--test-pgd-alpha', default=2, type=float)
    parser.add_argument('--fgsm-alpha', default=1.25, type=float)
    parser.add_argument('--norm', default='l_inf', type=str, choices=['l_inf', 'l_2'])
    parser.add_argument('--fgsm-init', default='random', choices=['zero', 'random', 'previous'])
    parser.add_argument('--fname', default='cifar_model', type=str)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--half', action='store_true')
    parser.add_argument('--width-factor', default=10, type=int)
    parser.add_argument('--resume', default=0, type=int)
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--val', action='store_true')
    parser.add_argument('--chkpt-iters', default=100, type=int)
    parser.add_argument('--mixture', action='store_true') # whether use mixture of clean and adv examples in a mini-batch
    parser.add_argument('--mixture_alpha', type=float)
    parser.add_argument('--l2', default=0, type=float)

    # Group 1
    parser.add_argument('--earlystopPGD', action='store_true') # whether use early stop in PGD
    parser.add_argument('--earlystopPGDepoch1', default=60, type=int)
    parser.add_argument('--earlystopPGDepoch2', default=100, type=int)

    parser.add_argument('--warmup_lr', action='store_true') # whether warm_up lr from 0 to max_lr in the first n epochs
    parser.add_argument('--warmup_lr_epoch', default=15, type=int)

    parser.add_argument('--weight_decay', default=5e-4, type=float)#weight decay

    parser.add_argument('--warmup_eps', action='store_true') # whether warm_up eps from 0 to 8/255 in the first n epochs
    parser.add_argument('--warmup_eps_epoch', default=15, type=int)

    parser.add_argument('--batch-size', default=128, type=int) #batch size

    parser.add_argument('--labelsmooth', action='store_true') # whether use label smoothing
    parser.add_argument('--labelsmoothvalue', default=0.0, type=float)

    parser.add_argument('--lrdecay', default='base', type=str, choices=['intenselr', 'base', 'looselr', 'lineardecay'])

    # Group 2
    parser.add_argument('--use_DLRloss', action='store_true') # whether use DLRloss
    parser.add_argument('--use_CWloss', action='store_true') # whether use CWloss


    parser.add_argument('--use_multitarget', action='store_true') # whether use multitarget

    parser.add_argument('--use_stronger_adv', action='store_true') # whether use mixture of clean and adv examples in a mini-batch
    parser.add_argument('--stronger_index', default=0, type=int)

    parser.add_argument('--use_FNandWN', action='store_true') # whether use FN and WN
    parser.add_argument('--use_adaptive', action='store_true') # whether use s in attack during training
    parser.add_argument('--s_FN', default=15, type=float) # s in FN
    parser.add_argument('--m_FN', default=0.2, type=float) # s in FN

    parser.add_argument('--use_FNonly', action='store_true') # whether use FN only

    parser.add_argument('--fast_better', action='store_true')

    parser.add_argument('--BNeval', action='store_true') # whether use eval mode for BN when crafting adversarial examples
| 295 | parser.add_argument('--focalloss', action='store_true') # whether use focalloss 296 | parser.add_argument('--focallosslambda', default=2., type=float) 297 | 298 | parser.add_argument('--activation', default='ReLU', type=str) 299 | parser.add_argument('--softplus_beta', default=1., type=float) 300 | 301 | parser.add_argument('--optimizer', default='momentum', choices=['momentum', 'Nesterov', 'SGD_GC', 'SGD_GCC', 'Adam', 'AdamW']) 302 | 303 | parser.add_argument('--mixup', action='store_true') 304 | parser.add_argument('--mixup-alpha', type=float) 305 | 306 | parser.add_argument('--cutout', action='store_true') 307 | parser.add_argument('--cutout-len', type=int) 308 | 309 | return parser.parse_args() 310 | 311 | def get_auto_fname(args): 312 | names = args.model + '_' + args.lr_schedule + '_eps' + str(args.epsilon) + '_bs' + str(args.batch_size) + '_maxlr' + str(args.lr_max) 313 | # Group 1 314 | if args.earlystopPGD: 315 | names = names + '_earlystopPGD' + str(args.earlystopPGDepoch1) + str(args.earlystopPGDepoch2) 316 | if args.warmup_lr: 317 | names = names + '_warmuplr' + str(args.warmup_lr_epoch) 318 | if args.warmup_eps: 319 | names = names + '_warmupeps' + str(args.warmup_eps_epoch) 320 | if args.weight_decay != 5e-4: 321 | names = names + '_wd' + str(args.weight_decay) 322 | if args.labelsmooth: 323 | names = names + '_ls' + str(args.labelsmoothvalue) 324 | 325 | # Group 2 326 | if args.use_stronger_adv: 327 | names = names + '_usestrongeradv#' + str(args.stronger_index) 328 | if args.use_multitarget: 329 | names = names + '_usemultitarget' 330 | if args.use_DLRloss: 331 | names = names + '_useDLRloss' 332 | if args.use_CWloss: 333 | names = names + '_useCWloss' 334 | if args.use_FNandWN: 335 | names = names + '_HE' + 's' + str(args.s_FN) + 'm' + str(args.m_FN) 336 | if args.use_adaptive: 337 | names = names + 'adaptive' 338 | if args.use_FNonly: 339 | names = names + '_FNonly' 340 | if args.fast_better: 341 | names = names + '_fastbetter' 342 | if 
args.activation != 'ReLU': 343 | names = names + '_' + args.activation 344 | if args.activation == 'Softplus': 345 | names = names + str(args.softplus_beta) 346 | if args.lrdecay != 'base': 347 | names = names + '_' + args.lrdecay 348 | if args.BNeval: 349 | names = names + '_BNeval' 350 | if args.focalloss: 351 | names = names + '_focalloss' + str(args.focallosslambda) 352 | if args.optimizer != 'momentum': 353 | names = names + '_' + args.optimizer 354 | if args.mixup: 355 | names = names + '_mixup' + str(args.mixup_alpha) 356 | if args.cutout: 357 | names = names + '_cutout' + str(args.cutout_len) 358 | if args.attack != 'pgd': 359 | names = names + '_' + args.attack 360 | 361 | print('File name: ', names) 362 | return names 363 | 364 | 365 | def main(): 366 | args = get_args() 367 | if args.fname == 'auto': 368 | names = get_auto_fname(args) 369 | args.fname = 'trained_models/' + names 370 | else: 371 | args.fname = 'trained_models/' + args.fname 372 | 373 | if not os.path.exists(args.fname): 374 | os.makedirs(args.fname) 375 | 376 | logger = logging.getLogger(__name__) 377 | logging.basicConfig( 378 | format='[%(asctime)s] - %(message)s', 379 | datefmt='%Y/%m/%d %H:%M:%S', 380 | level=logging.DEBUG, 381 | handlers=[ 382 | logging.FileHandler(os.path.join(args.fname, 'eval.log' if args.eval else 'output.log')), 383 | logging.StreamHandler() 384 | ]) 385 | 386 | logger.info(args) 387 | 388 | 389 | # Set seed 390 | np.random.seed(args.seed) 391 | torch.manual_seed(args.seed) 392 | torch.cuda.manual_seed(args.seed) 393 | 394 | 395 | # Prepare data 396 | transforms = [Crop(32, 32), FlipLR()] 397 | if args.cutout: 398 | transforms.append(Cutout(args.cutout_len, args.cutout_len)) 399 | if args.val: 400 | try: 401 | dataset = torch.load("cifar10_validation_split.pth") 402 | except: 403 | print("Couldn't find a dataset with a validation split, did you run " 404 | "generate_validation.py?") 405 | return 406 | val_set = list(zip(transpose(dataset['val']['data']/255.), 
dataset['val']['labels'])) 407 | val_batches = Batches(val_set, args.batch_size, shuffle=False, num_workers=4) 408 | else: 409 | dataset = cifar10(args.data_dir) 410 | train_set = list(zip(transpose(pad(dataset['train']['data'], 4)/255.), 411 | dataset['train']['labels'])) 412 | train_set_x = Transform(train_set, transforms) 413 | train_batches = Batches(train_set_x, args.batch_size, shuffle=True, set_random_choices=True, num_workers=4) 414 | 415 | test_set = list(zip(transpose(dataset['test']['data']/255.), dataset['test']['labels'])) 416 | test_batches = Batches(test_set, args.batch_size, shuffle=False, num_workers=4) 417 | 418 | 419 | # Set perturbations 420 | epsilon = (args.epsilon / 255.) 421 | test_epsilon = (args.test_epsilon / 255.) 422 | pgd_alpha = (args.pgd_alpha / 255.) 423 | test_pgd_alpha = (args.test_pgd_alpha / 255.) 424 | 425 | 426 | # Set models 427 | if args.model == 'VGG': 428 | model = VGG('VGG19') 429 | elif args.model == 'ResNet18': 430 | model = ResNet18() 431 | elif args.model == 'GoogLeNet': 432 | model = GoogLeNet() 433 | elif args.model == 'DenseNet121': 434 | model = DenseNet121() 435 | elif args.model == 'DenseNet201': 436 | model = DenseNet201() 437 | elif args.model == 'ResNeXt29': 438 | model = ResNeXt29_2x64d() 439 | elif args.model == 'ResNeXt29L': 440 | model = ResNeXt29_32x4d() 441 | elif args.model == 'MobileNet': 442 | model = MobileNet() 443 | elif args.model == 'MobileNetV2': 444 | model = MobileNetV2() 445 | elif args.model == 'DPN26': 446 | model = DPN26() 447 | elif args.model == 'DPN92': 448 | model = DPN92() 449 | elif args.model == 'ShuffleNetG2': 450 | model = ShuffleNetG2() 451 | elif args.model == 'SENet18': 452 | model = SENet18() 453 | elif args.model == 'ShuffleNetV2': 454 | model = ShuffleNetV2(1) 455 | elif args.model == 'EfficientNetB0': 456 | model = EfficientNetB0() 457 | elif args.model == 'PNASNetA': 458 | model = PNASNetA() 459 | elif args.model == 'RegNetX': 460 | model = RegNetX_200MF() 461 | elif 
args.model == 'RegNetLX': 462 | model = RegNetX_400MF() 463 | elif args.model == 'PreActResNet50': 464 | model = PreActResNet50() 465 | elif args.model == 'PreActResNet18': 466 | model = PreActResNet18(normalize_only_FN=args.use_FNonly, normalize=args.use_FNandWN, scale=args.s_FN, 467 | activation=args.activation, softplus_beta=args.softplus_beta) 468 | elif args.model == 'WideResNet': 469 | model = WideResNet(34, 10, widen_factor=10, dropRate=0.0, normalize=args.use_FNandWN, 470 | activation=args.activation, softplus_beta=args.softplus_beta) 471 | elif args.model == 'WideResNet_20': 472 | model = WideResNet(34, 10, widen_factor=20, dropRate=0.0, normalize=args.use_FNandWN, 473 | activation=args.activation, softplus_beta=args.softplus_beta) 474 | else: 475 | raise ValueError("Unknown model") 476 | 477 | model = nn.DataParallel(model).cuda() 478 | model.train() 479 | 480 | # Set training hyperparameters 481 | if args.l2: 482 | decay, no_decay = [], [] 483 | for name,param in model.named_parameters(): 484 | if 'bn' not in name and 'bias' not in name: 485 | decay.append(param) 486 | else: 487 | no_decay.append(param) 488 | params = [{'params':decay, 'weight_decay':args.l2}, 489 | {'params':no_decay, 'weight_decay': 0 }] 490 | else: 491 | params = model.parameters() 492 | if args.lr_schedule == 'cyclic': 493 | opt = torch.optim.Adam(params, lr=args.lr_max, betas=(0.9, 0.999), eps=1e-08, weight_decay=args.weight_decay) 494 | else: 495 | if args.optimizer == 'momentum': 496 | opt = torch.optim.SGD(params, lr=args.lr_max, momentum=0.9, weight_decay=args.weight_decay) 497 | elif args.optimizer == 'Nesterov': 498 | opt = torch.optim.SGD(params, lr=args.lr_max, momentum=0.9, weight_decay=args.weight_decay, nesterov=True) 499 | elif args.optimizer == 'SGD_GC': 500 | opt = SGD_GC(params, lr=args.lr_max, momentum=0.9, weight_decay=args.weight_decay) 501 | elif args.optimizer == 'SGD_GCC': 502 | opt = SGD_GCC(params, lr=args.lr_max, momentum=0.9, weight_decay=args.weight_decay) 
503 | elif args.optimizer == 'Adam': 504 | opt = torch.optim.Adam(params, lr=args.lr_max, betas=(0.9, 0.999), eps=1e-08, weight_decay=args.weight_decay) 505 | elif args.optimizer == 'AdamW': 506 | opt = torch.optim.AdamW(params, lr=args.lr_max, betas=(0.9, 0.999), eps=1e-08, weight_decay=args.weight_decay) 507 | 508 | # Cross-entropy (mean) 509 | if args.labelsmooth: 510 | criterion = LabelSmoothingLoss(smoothing=args.labelsmoothvalue) 511 | else: 512 | criterion = nn.CrossEntropyLoss() 513 | 514 | # If we use freeAT or fastAT with previous init 515 | if args.attack == 'free': 516 | delta = torch.zeros(args.batch_size, 3, 32, 32).cuda() 517 | delta.requires_grad = True 518 | elif args.attack == 'fgsm' and args.fgsm_init == 'previous': 519 | delta = torch.zeros(args.batch_size, 3, 32, 32).cuda() 520 | delta.requires_grad = True 521 | 522 | if args.attack == 'free': 523 | epochs = int(math.ceil(args.epochs / args.attack_iters)) 524 | else: 525 | epochs = args.epochs 526 | 527 | 528 | # Set lr schedule 529 | if args.lr_schedule == 'superconverge': 530 | lr_schedule = lambda t: np.interp([t], [0, args.epochs * 2 // 5, args.epochs], [0, args.lr_max, 0])[0] 531 | elif args.lr_schedule == 'piecewise': 532 | def lr_schedule(t, warm_up_lr = args.warmup_lr): 533 | if t < 100: 534 | if warm_up_lr and t < args.warmup_lr_epoch: 535 | return (t + 1.) / args.warmup_lr_epoch * args.lr_max 536 | else: 537 | return args.lr_max 538 | if args.lrdecay == 'lineardecay': 539 | if t < 105: 540 | return args.lr_max * 0.02 * (105 - t) 541 | else: 542 | return 0. 543 | elif args.lrdecay == 'intenselr': 544 | if t < 102: 545 | return args.lr_max / 10. 546 | else: 547 | return args.lr_max / 100. 548 | elif args.lrdecay == 'looselr': 549 | if t < 150: 550 | return args.lr_max / 10. 551 | else: 552 | return args.lr_max / 100. 553 | elif args.lrdecay == 'base': 554 | if t < 105: 555 | return args.lr_max / 10. 556 | else: 557 | return args.lr_max / 100. 
558 | elif args.lr_schedule == 'linear': 559 | lr_schedule = lambda t: np.interp([t], [0, args.epochs // 3, args.epochs * 2 // 3, args.epochs], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0] 560 | elif args.lr_schedule == 'onedrop': 561 | def lr_schedule(t): 562 | if t < args.lr_drop_epoch: 563 | return args.lr_max 564 | else: 565 | return args.lr_one_drop 566 | elif args.lr_schedule == 'multipledecay': 567 | def lr_schedule(t): 568 | return args.lr_max - (t//(args.epochs//10))*(args.lr_max/10) 569 | elif args.lr_schedule == 'cosine': 570 | def lr_schedule(t): 571 | return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi)) 572 | elif args.lr_schedule == 'cyclic': 573 | def lr_schedule(t, stepsize=18, min_lr=1e-5, max_lr=args.lr_max): 574 | 575 | # Scaler: we can adapt this if we do not want the triangular CLR 576 | scaler = lambda x: 1. 577 | 578 | # Additional function to see where on the cycle we are 579 | cycle = math.floor(1 + t / (2 * stepsize)) 580 | x = abs(t / stepsize - 2 * cycle + 1) 581 | relative = max(0, (1 - x)) * scaler(cycle) 582 | 583 | return min_lr + (max_lr - min_lr) * relative 584 | 585 | 586 | 587 | 588 | #### Set stronger adv attacks when decay the lr #### 589 | def eps_alpha_schedule(t, warm_up_eps = args.warmup_eps, if_use_stronger_adv=args.use_stronger_adv, stronger_index=args.stronger_index): # Schedule number 0 590 | if stronger_index == 0: 591 | epsilon_s = [epsilon * 1.5, epsilon * 2] 592 | pgd_alpha_s = [pgd_alpha, pgd_alpha] 593 | elif stronger_index == 1: 594 | epsilon_s = [epsilon * 1.5, epsilon * 2] 595 | pgd_alpha_s = [pgd_alpha * 1.25, pgd_alpha * 1.5] 596 | elif stronger_index == 2: 597 | epsilon_s = [epsilon * 2, epsilon * 2.5] 598 | pgd_alpha_s = [pgd_alpha * 1.5, pgd_alpha * 2] 599 | else: 600 | print('Undefined stronger index') 601 | 602 | if if_use_stronger_adv: 603 | if t < 100: 604 | if t < args.warmup_eps_epoch and warm_up_eps: 605 | return (t + 1.) 
/ args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts 606 | else: 607 | return epsilon, pgd_alpha, args.restarts 608 | elif t < 105: 609 | return epsilon_s[0], pgd_alpha_s[0], args.restarts 610 | else: 611 | return epsilon_s[1], pgd_alpha_s[1], args.restarts 612 | else: 613 | if t < args.warmup_eps_epoch and warm_up_eps: 614 | return (t + 1.) / args.warmup_eps_epoch * epsilon, pgd_alpha, args.restarts 615 | else: 616 | return epsilon, pgd_alpha, args.restarts 617 | 618 | #### Set the counter for the early stop of PGD #### 619 | def early_stop_counter_schedule(t): 620 | if t < args.earlystopPGDepoch1: 621 | return 1 622 | elif t < args.earlystopPGDepoch2: 623 | return 2 624 | else: 625 | return 3 626 | 627 | 628 | 629 | 630 | 631 | best_test_robust_acc = 0 632 | best_val_robust_acc = 0 633 | if args.resume: 634 | start_epoch = args.resume 635 | model.load_state_dict(torch.load(os.path.join(args.fname, f'model_{start_epoch-1}.pth'))) 636 | opt.load_state_dict(torch.load(os.path.join(args.fname, f'opt_{start_epoch-1}.pth'))) 637 | logger.info(f'Resuming at epoch {start_epoch}') 638 | 639 | best_test_robust_acc = torch.load(os.path.join(args.fname, f'model_best.pth'))['test_robust_acc'] 640 | if args.val: 641 | best_val_robust_acc = torch.load(os.path.join(args.fname, f'model_val.pth'))['val_robust_acc'] 642 | else: 643 | start_epoch = 0 644 | 645 | if args.eval: 646 | if not args.resume: 647 | logger.info("No model loaded to evaluate, specify with --resume FNAME") 648 | return 649 | logger.info("[Evaluation mode]") 650 | 651 | # logger.info('Epoch \t Train Time \t Test Time \t LR \t Train Loss \t Train Grad \t Train Acc \t Train Robust Loss \t Train Robust Acc || \t Test Loss \t Test Acc \t Test Robust Loss \t Test Robust Acc') 652 | logger.info('Epoch \t Train Acc \t Train Robust Acc \t Test Acc \t Test Robust Acc') 653 | 654 | # Records per epoch for savetxt 655 | train_loss_record = [] 656 | train_acc_record = [] 657 | train_robust_loss_record = [] 658 | 
train_robust_acc_record = [] 659 | train_grad_record = [] 660 | 661 | test_loss_record = [] 662 | test_acc_record = [] 663 | test_robust_loss_record = [] 664 | test_robust_acc_record = [] 665 | test_grad_record = [] 666 | 667 | for epoch in range(start_epoch, epochs): 668 | model.train() 669 | start_time = time.time() 670 | 671 | train_loss = 0 672 | train_acc = 0 673 | train_robust_loss = 0 674 | train_robust_acc = 0 675 | train_n = 0 676 | train_grad = 0 677 | 678 | record_iter = torch.tensor([]) 679 | 680 | for i, batch in enumerate(train_batches): 681 | if args.eval: 682 | break 683 | X, y = batch['input'], batch['target'] 684 | 685 | onehot_target_withmargin_HE = args.m_FN * args.s_FN * torch.nn.functional.one_hot(y, num_classes=10) 686 | 687 | if args.mixup: 688 | X, y_a, y_b, lam = mixup_data(X, y, args.mixup_alpha) 689 | X, y_a, y_b = map(Variable, (X, y_a, y_b)) 690 | epoch_now = epoch + (i + 1) / len(train_batches) 691 | lr = lr_schedule(epoch_now) 692 | opt.param_groups[0].update(lr=lr) 693 | 694 | if args.attack == 'pgd': 695 | # Random initialization 696 | epsilon_sche, pgd_alpha_sche, restarts_sche = eps_alpha_schedule(epoch_now) 697 | early_counter_max = early_stop_counter_schedule(epoch_now) 698 | if args.mixup: 699 | delta, iter_counts = attack_pgd(model, X, y, epsilon_sche, pgd_alpha_sche, args.attack_iters, restarts_sche, args.norm, 700 | early_stop=args.earlystopPGD, early_stop_pgd_max=early_counter_max, 701 | mixup=True, y_a=y_a, y_b=y_b, lam=lam) 702 | else: 703 | delta, iter_counts = attack_pgd(model, X, y, epsilon_sche, pgd_alpha_sche, args.attack_iters, restarts_sche, args.norm, 704 | early_stop=args.earlystopPGD, early_stop_pgd_max=early_counter_max, multitarget=args.use_multitarget, 705 | use_DLRloss=args.use_DLRloss, use_CWloss=args.use_CWloss, 706 | epoch=epoch_now, totalepoch=args.epochs, gamma=0.8, 707 | use_adaptive=args.use_adaptive, s_HE=args.s_FN, 708 | fast_better=args.fast_better, BNeval=args.BNeval) 709 | 710 | record_iter = 
torch.cat((record_iter, iter_counts)) 711 | 712 | delta = delta.detach() 713 | elif args.attack == 'fgsm': 714 | delta,_ = attack_pgd(model, X, y, epsilon, args.fgsm_alpha*epsilon, 1, 1, args.norm, fast_better=args.fast_better) 715 | delta = delta.detach() 716 | # Standard training 717 | elif args.attack == 'none': 718 | delta = torch.zeros_like(X) 719 | 720 | 721 | adv_input = normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)) 722 | adv_input.requires_grad = True 723 | robust_output = model(adv_input) 724 | 725 | 726 | 727 | 728 | # Training losses 729 | if args.mixup: 730 | clean_input = normalize(X) 731 | clean_input.requires_grad = True 732 | output = model(clean_input) 733 | robust_loss = mixup_criterion(criterion, robust_output, y_a, y_b, lam) 734 | 735 | elif args.mixture: 736 | clean_input = normalize(X) 737 | clean_input.requires_grad = True 738 | output = model(clean_input) 739 | robust_loss = args.mixture_alpha * criterion(robust_output, y) + (1-args.mixture_alpha) * criterion(output, y) 740 | 741 | else: 742 | clean_input = normalize(X) 743 | clean_input.requires_grad = True 744 | output = model(clean_input) 745 | if args.focalloss: 746 | criterion_nonreduct = nn.CrossEntropyLoss(reduction='none') 747 | robust_confidence = F.softmax(robust_output, dim=1)[:, y].detach() 748 | robust_loss = (criterion_nonreduct(robust_output, y) * ((1. - robust_confidence) ** args.focallosslambda)).mean() 749 | 750 | elif args.use_DLRloss: 751 | beta_ = 0.8 * epoch_now / args.epochs 752 | robust_loss = (1. - beta_) * F.cross_entropy(robust_output, y) + beta_ * dlr_loss(robust_output, y) 753 | 754 | elif args.use_CWloss: 755 | beta_ = 0.8 * epoch_now / args.epochs 756 | robust_loss = (1. 
- beta_) * F.cross_entropy(robust_output, y) + beta_ * CW_loss(robust_output, y) 757 | 758 | elif args.use_FNandWN: 759 | #print('use FN and WN with margin') 760 | robust_loss = criterion(args.s_FN * robust_output - onehot_target_withmargin_HE, y) 761 | 762 | else: 763 | robust_loss = criterion(robust_output, y) 764 | 765 | 766 | 767 | 768 | if args.l1: 769 | for name,param in model.named_parameters(): 770 | if 'bn' not in name and 'bias' not in name: 771 | robust_loss += args.l1*param.abs().sum() 772 | 773 | 774 | opt.zero_grad() 775 | robust_loss.backward() 776 | opt.step() 777 | 778 | 779 | clean_input = normalize(X) 780 | clean_input.requires_grad = True 781 | output = model(clean_input) 782 | if args.mixup: 783 | loss = mixup_criterion(criterion, output, y_a, y_b, lam) 784 | else: 785 | loss = criterion(output, y) 786 | 787 | # Get the gradient norm values 788 | input_grads = torch.autograd.grad(loss, clean_input, create_graph=False)[0] 789 | 790 | # Record the statstic values 791 | train_robust_loss += robust_loss.item() * y.size(0) 792 | train_robust_acc += (robust_output.max(1)[1] == y).sum().item() 793 | train_loss += loss.item() * y.size(0) 794 | train_acc += (output.max(1)[1] == y).sum().item() 795 | train_n += y.size(0) 796 | train_grad += input_grads.abs().sum() 797 | 798 | train_time = time.time() 799 | if args.earlystopPGD: 800 | print('Iter mean: ', record_iter.mean().item(), ' Iter std: ', record_iter.std().item()) 801 | print('Learning rate: ', lr) 802 | #print('Eps: ', epsilon_sche) 803 | # Evaluate on test data 804 | model.eval() 805 | test_loss = 0 806 | test_acc = 0 807 | test_robust_loss = 0 808 | test_robust_acc = 0 809 | test_n = 0 810 | test_grad = 0 811 | for i, batch in enumerate(test_batches): 812 | X, y = batch['input'], batch['target'] 813 | 814 | # Random initialization 815 | if args.attack == 'none': 816 | delta = torch.zeros_like(X) 817 | else: 818 | delta, _ = attack_pgd(model, X, y, test_epsilon, test_pgd_alpha, 
args.attack_iters, args.restarts, args.norm, early_stop=False) 819 | delta = delta.detach() 820 | 821 | adv_input = normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit)) 822 | adv_input.requires_grad = True 823 | robust_output = model(adv_input) 824 | robust_loss = criterion(robust_output, y) 825 | 826 | clean_input = normalize(X) 827 | clean_input.requires_grad = True 828 | output = model(clean_input) 829 | loss = criterion(output, y) 830 | 831 | # Get the gradient norm values 832 | input_grads = torch.autograd.grad(loss, clean_input, create_graph=False)[0] 833 | 834 | test_robust_loss += robust_loss.item() * y.size(0) 835 | test_robust_acc += (robust_output.max(1)[1] == y).sum().item() 836 | test_loss += loss.item() * y.size(0) 837 | test_acc += (output.max(1)[1] == y).sum().item() 838 | test_n += y.size(0) 839 | test_grad += input_grads.abs().sum() 840 | 841 | test_time = time.time() 842 | 843 | if args.val: 844 | val_loss = 0 845 | val_acc = 0 846 | val_robust_loss = 0 847 | val_robust_acc = 0 848 | val_n = 0 849 | for i, batch in enumerate(val_batches): 850 | X, y = batch['input'], batch['target'] 851 | 852 | # Random initialization 853 | if args.attack == 'none': 854 | delta = torch.zeros_like(X) 855 | else: 856 | delta, _ = attack_pgd(model, X, y, test_epsilon, pgd_alpha, args.attack_iters, args.restarts, args.norm, early_stop=False) 857 | delta = delta.detach() 858 | 859 | robust_output = model(normalize(torch.clamp(X + delta[:X.size(0)], min=lower_limit, max=upper_limit))) 860 | robust_loss = criterion(robust_output, y) 861 | 862 | output = model(normalize(X)) 863 | loss = criterion(output, y) 864 | 865 | val_robust_loss += robust_loss.item() * y.size(0) 866 | val_robust_acc += (robust_output.max(1)[1] == y).sum().item() 867 | val_loss += loss.item() * y.size(0) 868 | val_acc += (output.max(1)[1] == y).sum().item() 869 | val_n += y.size(0) 870 | 871 | if not args.eval: 872 | # logger.info('%d \t %.1f \t %.1f \t %.4f \t %.4f \t 
%.4f \t %.4f \t %.4f \t %.4f \t %.4f %.4f \t %.4f \t %.4f', 873 | # epoch, train_time - start_time, test_time - train_time, lr, 874 | # train_loss/train_n, train_grad/train_n, train_acc/train_n, train_robust_loss/train_n, train_robust_acc/train_n, 875 | # test_loss/test_n, test_acc/test_n, test_robust_loss/test_n, test_robust_acc/test_n) 876 | logger.info('%d \t %.4f \t %.4f \t %.4f \t %.4f', 877 | epoch, train_acc/train_n, train_robust_acc/train_n, test_acc/test_n, test_robust_acc/test_n) 878 | 879 | # Save results 880 | train_loss_record.append(train_loss/train_n) 881 | train_acc_record.append(train_acc/train_n) 882 | train_robust_loss_record.append(train_robust_loss/train_n) 883 | train_robust_acc_record.append(train_robust_acc/train_n) 884 | train_grad_record.append(train_grad/train_n) 885 | 886 | np.savetxt(args.fname+'/train_loss_record.txt', np.array(train_loss_record)) 887 | np.savetxt(args.fname+'/train_acc_record.txt', np.array(train_acc_record)) 888 | np.savetxt(args.fname+'/train_robust_loss_record.txt', np.array(train_robust_loss_record)) 889 | np.savetxt(args.fname+'/train_robust_acc_record.txt', np.array(train_robust_acc_record)) 890 | np.savetxt(args.fname+'/train_grad_record.txt', np.array(train_grad_record)) 891 | 892 | test_loss_record.append(test_loss/test_n) 893 | test_acc_record.append(test_acc/test_n) 894 | test_robust_loss_record.append(test_robust_loss/test_n) 895 | test_robust_acc_record.append(test_robust_acc/test_n) 896 | test_grad_record.append(test_grad/test_n) 897 | 898 | np.savetxt(args.fname+'/test_loss_record.txt', np.array(test_loss_record)) 899 | np.savetxt(args.fname+'/test_acc_record.txt', np.array(test_acc_record)) 900 | np.savetxt(args.fname+'/test_robust_loss_record.txt', np.array(test_robust_loss_record)) 901 | np.savetxt(args.fname+'/test_robust_acc_record.txt', np.array(test_robust_acc_record)) 902 | np.savetxt(args.fname+'/test_grad_record.txt', np.array(test_grad_record)) 903 | 904 | 905 | 906 | 907 | if args.val: 908 | 
logger.info('validation %.4f \t %.4f \t %.4f \t %.4f', 909 | val_loss/val_n, val_acc/val_n, val_robust_loss/val_n, val_robust_acc/val_n) 910 | 911 | if val_robust_acc/val_n > best_val_robust_acc: 912 | torch.save({ 913 | 'state_dict':model.state_dict(), 914 | 'test_robust_acc':test_robust_acc/test_n, 915 | 'test_robust_loss':test_robust_loss/test_n, 916 | 'test_loss':test_loss/test_n, 917 | 'test_acc':test_acc/test_n, 918 | 'val_robust_acc':val_robust_acc/val_n, 919 | 'val_robust_loss':val_robust_loss/val_n, 920 | 'val_loss':val_loss/val_n, 921 | 'val_acc':val_acc/val_n, 922 | }, os.path.join(args.fname, f'model_val.pth')) 923 | best_val_robust_acc = val_robust_acc/val_n 924 | 925 | # save checkpoint 926 | if epoch > 99 or (epoch+1) % args.chkpt_iters == 0 or epoch+1 == epochs: 927 | torch.save(model.state_dict(), os.path.join(args.fname, f'model_{epoch}.pth')) 928 | torch.save(opt.state_dict(), os.path.join(args.fname, f'opt_{epoch}.pth')) 929 | 930 | # save best 931 | if test_robust_acc/test_n > best_test_robust_acc: 932 | torch.save({ 933 | 'state_dict':model.state_dict(), 934 | 'test_robust_acc':test_robust_acc/test_n, 935 | 'test_robust_loss':test_robust_loss/test_n, 936 | 'test_loss':test_loss/test_n, 937 | 'test_acc':test_acc/test_n, 938 | }, os.path.join(args.fname, f'model_best.pth')) 939 | best_test_robust_acc = test_robust_acc/test_n 940 | else: 941 | logger.info('%d \t %.1f \t \t %.1f \t \t %.4f \t %.4f \t %.4f \t %.4f \t \t %.4f \t \t %.4f \t %.4f \t %.4f \t \t %.4f', 942 | epoch, train_time - start_time, test_time - train_time, -1, 943 | -1, -1, -1, -1, 944 | test_loss/test_n, test_acc/test_n, test_robust_loss/test_n, test_robust_acc/test_n) 945 | return 946 | 947 | 948 | if __name__ == "__main__": 949 | main() 950 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import namedtuple 3 | 
import torch
from torch import nn
try:
    import torchvision
except ImportError:  # only cifar10() needs torchvision; keep the rest of the module importable
    torchvision = None
from torch.optim.optimizer import Optimizer, required

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

################################################################
## Components from https://github.com/davidcpage/cifar10-fast ##
################################################################

#####################
## data preprocessing
#####################

cifar10_mean = (0.4914, 0.4822, 0.4465)  # equals np.mean(train_set.train_data, axis=(0,1,2))/255
cifar10_std = (0.2471, 0.2435, 0.2616)  # equals np.std(train_set.train_data, axis=(0,1,2))/255


def pad(x, border=4):
    """Reflect-pad axes 1 and 2 (H and W of an NHWC array) by ``border`` pixels on each side."""
    return np.pad(x, [(0, 0), (border, border), (border, border), (0, 0)], mode='reflect')


def transpose(x, source='NHWC', target='NCHW'):
    """Permute the axes of ``x`` from the ``source`` layout string to the ``target`` layout."""
    return x.transpose([source.index(d) for d in target])

#####################
## data augmentation
#####################


class Crop(namedtuple('Crop', ('h', 'w'))):
    """Crop a CHW array to the ``h`` x ``w`` window whose top-left corner is (y0, x0)."""

    def __call__(self, x, x0, y0):
        return x[:, y0:y0 + self.h, x0:x0 + self.w]

    def options(self, x_shape):
        """Return every valid top-left offset for an input of shape (C, H, W)."""
        C, H, W = x_shape
        return {'x0': range(W + 1 - self.w), 'y0': range(H + 1 - self.h)}

    def output_shape(self, x_shape):
        """Shape of the cropped output for an input of shape (C, H, W)."""
        C, H, W = x_shape
        return (C, self.h, self.w)


class FlipLR(namedtuple('FlipLR', ())):
    """Reverse the last (width) axis of a CHW array when ``choice`` is truthy."""

    def __call__(self, x, choice):
        return x[:, :, ::-1].copy() if choice else x

    def options(self, x_shape):
        return {'choice': [True, False]}


class Cutout(namedtuple('Cutout', ('h', 'w'))):
    """Return a copy of a CHW array with the ``h`` x ``w`` window at (y0, x0) set to zero."""

    def __call__(self, x, x0, y0):
        x = x.copy()  # never modify the caller's array
        x[:, y0:y0 + self.h, x0:x0 + self.w].fill(0.0)
        return x

    def options(self, x_shape):
        """Return every valid top-left offset for an input of shape (C, H, W)."""
        C, H, W = x_shape
        return {'x0': range(W + 1 - self.w), 'y0': range(H + 1 - self.h)}


class Transform:
    """Dataset wrapper applying augmentations with per-sample, pre-drawn random parameters.

    ``set_random_choices`` draws one parameter set per sample for every transform;
    ``__getitem__`` reuses that draw until the next call. ``Batches`` re-draws at the
    start of each iteration when constructed with ``set_random_choices=True``.
    """

    def __init__(self, dataset, transforms):
        self.dataset, self.transforms = dataset, transforms
        self.choices = None

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # Draw lazily so indexing works even if set_random_choices() was never called
        # (zip() over a None draw would otherwise raise TypeError).
        if self.choices is None:
            self.set_random_choices()
        data, labels = self.dataset[index]
        for choices, f in zip(self.choices, self.transforms):
            args = {k: v[index] for (k, v) in choices.items()}
            data = f(data, **args)
        return data, labels

    def set_random_choices(self):
        """Draw, for every sample, one random option per transform parameter."""
        self.choices = []
        x_shape = self.dataset[0][0].shape
        N = len(self)
        for t in self.transforms:
            options = t.options(x_shape)
            # Later transforms see the shape produced by earlier ones (e.g. after Crop).
            x_shape = t.output_shape(x_shape) if hasattr(t, 'output_shape') else x_shape
            self.choices.append({k: np.random.choice(v, size=N) for (k, v) in options.items()})

#####################
## dataset
#####################


def cifar10(root):
    """Return CIFAR-10 train/test splits as ``{'data': ..., 'labels': ...}`` dicts.

    The dataset is downloaded into ``root`` if it is not already there.

    Raises:
        ImportError: if torchvision is not installed.
    """
    if torchvision is None:
        raise ImportError("cifar10() requires torchvision")
    train_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True)
    test_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True)
    return {
        'train': {'data': train_set.data, 'labels': train_set.targets},
        'test': {'data': test_set.data, 'labels': test_set.targets}
    }

#####################
## data loading
#####################


class Batches():
    """Iterate a dataset in batches of ``{'input': ..., 'target': ...}`` dicts.

    Inputs are moved to the module-level ``device`` and cast to half precision;
    targets are moved to ``device`` as int64.
    """

    def __init__(self, dataset, batch_size, shuffle, set_random_choices=False, num_workers=0, drop_last=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.set_random_choices = set_random_choices
        self.dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=shuffle, drop_last=drop_last
        )

    def __iter__(self):
        if self.set_random_choices:
            # Fresh augmentation parameters for every pass over the data.
            self.dataset.set_random_choices()
        return ({'input': x.to(device).half(), 'target': y.to(device).long()} for (x, y) in self.dataloader)

    def __len__(self):
        return len(self.dataloader)

#####################
## new optimizer
#####################


class _SGDWithGC(Optimizer):
    """SGD (momentum / Nesterov / weight decay) with Gradient Centralization (GC).

    Before the momentum update, every parameter whose gradient has more than
    ``_gc_min_dim`` dimensions has its gradient re-centred: the mean over all axes
    except the first is subtracted. Subclasses choose ``_gc_min_dim``.

    Weight decay and centralization are computed out of place, so ``p.grad`` is
    not modified by ``step``.
    """

    _gc_min_dim = 1  # centralize gradients with more than this many dimensions

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # The closure usually calls backward(), so autograd must be on for it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad

                if weight_decay != 0:
                    d_p = d_p.add(p, alpha=weight_decay)

                # GC operation: zero-mean the gradient over every axis but the first.
                if d_p.dim() > self._gc_min_dim:
                    d_p = d_p - d_p.mean(dim=tuple(range(1, d_p.dim())), keepdim=True)

                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.add_(d_p, alpha=-group['lr'])

        return loss


class SGD_GCC(_SGDWithGC):
    """SGD with Gradient Centralization for Conv layers only (gradients with more than 3 dims)."""

    _gc_min_dim = 3


class SGD_GC(_SGDWithGC):
    """SGD with Gradient Centralization for Conv and FC layers (gradients with more than 1 dim)."""

    _gc_min_dim = 1


# --------------------------------------------------------------------------------
# /utils_plus.py
# --------------------------------------------------------------------------------
#import apex.amp as amp
import torch
import torch.nn.functional as F
try:
    from torchvision import datasets, transforms
except ImportError:  # only get_loaders() needs torchvision
    datasets = transforms = None
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np

upper_limit, lower_limit = 1, 0


def clamp(X, lower_limit, upper_limit):
    """Element-wise clamp of ``X`` between ``lower_limit`` and ``upper_limit``.

    Both bounds must be tensors (broadcastable to ``X``): ``torch.min(X, int)``
    would be read as a reduction over a dimension, not an element-wise minimum.
    """
    return torch.max(torch.min(X, upper_limit), lower_limit)


def _identity(inputs):
    """Default ``normalize`` callable: return ``inputs`` unchanged."""
    return inputs


def _model_device(model):
    """Device of ``model``'s first parameter; CUDA if available, else CPU, for a parameter-less model."""
    param = next(model.parameters(), None)
    if param is not None:
        return param.device
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_loaders(dir_, batch_size, DATASET='CIFAR10', num_workers=2):
    """Build CIFAR train/test DataLoaders; inputs come from ``ToTensor`` with no normalization.

    Training data uses random 32x32 crops with 4-pixel padding and random horizontal
    flips; the test loader is not shuffled. Datasets are downloaded into ``dir_`` if missing.

    Raises:
        ValueError: if ``DATASET`` is neither ``'CIFAR10'`` nor ``'CIFAR100'``.
        ImportError: if torchvision is not installed.
    """
    if DATASET not in ('CIFAR10', 'CIFAR100'):
        raise ValueError(
            "Unsupported DATASET {!r}; expected 'CIFAR10' or 'CIFAR100'".format(DATASET))
    if datasets is None:
        raise ImportError("get_loaders() requires torchvision")

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    dataset_cls = getattr(datasets, DATASET)  # datasets.CIFAR10 or datasets.CIFAR100
    train_dataset = dataset_cls(dir_, train=True, transform=train_transform, download=True)
    test_dataset = dataset_cls(dir_, train=False, transform=test_transform, download=True)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=num_workers,
    )
    return train_loader, test_loader


def CW_loss(x, y):
    """Mean Carlini-Wagner margin loss for logits ``x`` (N, K) and integer labels ``y`` (N,).

    Per sample: -(true-class logit - largest logit among the other classes).
    """
    x_sorted, ind_sorted = x.sort(dim=1)
    # 1 where the true class currently has the top logit, so the runner-up is the rival.
    ind = (ind_sorted[:, -1] == y).float()

    loss_value = -(x[np.arange(x.shape[0]), y] - x_sorted[:, -2] * ind - x_sorted[:, -1] * (1. - ind))
    return loss_value.mean()


def attack_pgd(model, X, y, epsilon, alpha, attack_iters, restarts, use_CWloss=False, normalize=None):
    """Untargeted L-infinity PGD with random start; return the per-sample strongest perturbation.

    Each restart draws ``delta`` uniformly in [-epsilon, epsilon], clipped so that
    ``X + delta`` stays in [lower_limit, upper_limit]. It then takes up to
    ``attack_iters`` signed-gradient steps of size ``alpha``, updating only samples the
    model still classifies correctly, and stops early once none are. Across restarts
    each sample keeps the ``delta`` with the highest cross-entropy loss.

    Args:
        model: classifier applied to ``normalize(X + delta)``.
        X: clean inputs indexed as a 4-D batch (assumed N x C x H x W in [0, 1] -- confirm).
        y: integer labels, one per sample.
        epsilon, alpha: L-inf radius and step size, in input units.
        attack_iters, restarts: PGD steps per restart and number of restarts.
        use_CWloss: ascend :func:`CW_loss` instead of cross-entropy.
        normalize: callable applied before ``model``; ``None`` means identity.

    Returns:
        A detached tensor shaped like ``X`` on ``X``'s device.

    Note:
        ``loss.backward()`` also accumulates gradients into ``model``'s parameters.
    """
    if normalize is None:
        normalize = _identity
    # Allocate on X's device instead of hard-coding CUDA.
    max_loss = torch.zeros(y.shape[0], device=X.device)
    max_delta = torch.zeros_like(X)
    for _ in range(restarts):
        delta = torch.zeros_like(X)
        delta.uniform_(-epsilon, epsilon)
        delta.data = clamp(delta, lower_limit - X, upper_limit - X)
        delta.requires_grad = True
        for _ in range(attack_iters):
            output = model(normalize(X + delta))
            index = torch.where(output.max(1)[1] == y)
            if len(index[0]) == 0:
                break  # every sample is already misclassified
            if use_CWloss:
                loss = CW_loss(output, y)
            else:
                loss = F.cross_entropy(output, y)
            loss.backward()
            grad = delta.grad.detach()
            d = delta[index[0], :, :, :]
            g = grad[index[0], :, :, :]
            d = torch.clamp(d + alpha * torch.sign(g), -epsilon, epsilon)
            d = clamp(d, lower_limit - X[index[0], :, :, :], upper_limit - X[index[0], :, :, :])
            delta.data[index[0], :, :, :] = d
            delta.grad.zero_()
        all_loss = F.cross_entropy(model(normalize(X + delta)), y, reduction='none').detach()
        better = all_loss >= max_loss
        max_delta[better] = delta.detach()[better]
        max_loss = torch.max(max_loss, all_loss)
    return max_delta


def evaluate_pgd(test_loader, model, attack_iters, restarts, eps=8, step=2, use_CWloss=False, normalize=None):
    """Return (mean loss, accuracy) of ``model`` on PGD-perturbed batches from ``test_loader``.

    ``eps`` and ``step`` are given in 1/255 units. Batches are moved to the model's
    device, and the model is left in eval mode. ``normalize=None`` means identity.
    """
    epsilon = eps / 255.
    alpha = step / 255.
    if normalize is None:
        normalize = _identity
    device = _model_device(model)
    pgd_loss = 0
    pgd_acc = 0
    n = 0
    model.eval()
    for X, y in test_loader:
        X, y = X.to(device), y.to(device)
        pgd_delta = attack_pgd(model, X, y, epsilon, alpha, attack_iters, restarts,
                               use_CWloss=use_CWloss, normalize=normalize)
        with torch.no_grad():
            output = model(normalize(X + pgd_delta))
            loss = F.cross_entropy(output, y)
            pgd_loss += loss.item() * y.size(0)
            pgd_acc += (output.max(1)[1] == y).sum().item()
            n += y.size(0)
    return pgd_loss / n, pgd_acc / n


def evaluate_standard(test_loader, model, normalize=None):
    """Return (mean loss, accuracy) of ``model`` on clean batches from ``test_loader``.

    Batches are moved to the model's device, and the model is left in eval mode.
    ``normalize=None`` means identity.
    """
    if normalize is None:
        normalize = _identity
    device = _model_device(model)
    test_loss = 0
    test_acc = 0
    n = 0
    model.eval()
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            output = model(normalize(X))
            loss = F.cross_entropy(output, y)
            test_loss += loss.item() * y.size(0)
            test_acc += (output.max(1)[1] == y).sum().item()
            n += y.size(0)
    return test_loss / n, test_acc / n


# --------------------------------------------------------------------------------
# /wideresnet.py
# --------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# One-letter tag printed by BasicBlock for each supported activation.
_ACTIVATION_TAGS = {'ReLU': 'R', 'Softplus': 'S', 'GELU': 'G', 'ELU': 'E'}


def _make_activation(activation, softplus_beta):
    """Return a new activation module by name; raise ValueError for unsupported names."""
    if activation == 'ReLU':
        return nn.ReLU(inplace=True)
    if activation == 'Softplus':
        return nn.Softplus(beta=softplus_beta, threshold=20)
    if activation == 'GELU':
        return nn.GELU()
    if activation == 'ELU':
        return nn.ELU(alpha=1.0, inplace=True)
    raise ValueError("Unsupported activation: {!r}".format(activation))


class BasicBlock(nn.Module):
    """Pre-activation residual block: (BN-act-conv3x3) x 2 plus an identity or 1x1-conv shortcut.

    Raises:
        ValueError: if ``activation`` is not 'ReLU', 'Softplus', 'GELU' or 'ELU'.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0, activation='ReLU', softplus_beta=1):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.relu1 = _make_activation(activation, softplus_beta)
        self.relu2 = _make_activation(activation, softplus_beta)
        print(_ACTIVATION_TAGS[activation])

        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # A 1x1 projection is needed only when the channel count changes.
        self.convShortcut = (nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                       padding=0, bias=False)
                             if not self.equalInOut else None)

    def forward(self, x):
        if self.equalInOut:
            out = self.relu1(self.bn1(x))
        else:
            # The pre-activated input feeds both the residual branch and the 1x1 shortcut.
            x = self.relu1(self.bn1(x))
            out = x
        out = self.relu2(self.bn2(self.conv1(out)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        shortcut = x if self.equalInOut else self.convShortcut(x)
        return torch.add(shortcut, out)


class NetworkBlock(nn.Module):
    """A stage of ``nb_layers`` blocks; only the first changes channels and applies ``stride``."""

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0, activation='ReLU', softplus_beta=1):
        super(NetworkBlock, self).__init__()
        self.activation = activation
        self.softplus_beta = softplus_beta
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        layers = []
        for i in range(int(nb_layers)):
            layers.append(block(in_planes if i == 0 else out_planes, out_planes,
                                stride if i == 0 else 1, dropRate,
                                self.activation, self.softplus_beta))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class WideResNet(nn.Module):
    """Wide ResNet whose final 8x8 average pool implies 32x32 inputs (WRN-34-10 by default).

    Args:
        depth: total depth; ``(depth - 4) % 6 == 0`` is asserted.
        num_classes: classifier output size.
        widen_factor: channel multiplier for the three stages.
        dropRate: dropout probability inside each BasicBlock.
        normalize: if True, the classifier has no bias, and both the pooled feature and
            each classifier weight row are L2-normalized in ``forward``.
        activation: one of 'ReLU', 'Softplus', 'GELU', 'ELU'.
        softplus_beta: ``beta`` used when ``activation == 'Softplus'``.

    Raises:
        ValueError: if ``activation`` is not supported.
    """

    def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0.0, normalize=False, activation='ReLU', softplus_beta=1):
        super(WideResNet, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert ((depth - 4) % 6 == 0)
        n = (depth - 4) / 6
        block = BasicBlock
        self.normalize = normalize
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activation=activation, softplus_beta=softplus_beta)
        # NOTE(review): sub_block1 is never used in forward(). It is kept because removing it
        # would change the state_dict keys and the RNG draws made during weight init below
        # (later layers would initialize differently) -- confirm before deleting.
        self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate, activation=activation, softplus_beta=softplus_beta)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate, activation=activation, softplus_beta=softplus_beta)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate, activation=activation, softplus_beta=softplus_beta)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = _make_activation(activation, softplus_beta)
        print('Use activation of ' + activation)

        if self.normalize:
            self.fc = nn.Linear(nChannels[3], num_classes, bias=False)
        else:
            self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He (fan-out) normal initialization.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear) and not self.normalize:
                m.bias.data.zero_()

    def forward(self, x):
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        if self.normalize:
            out = F.normalize(out, p=2, dim=1)
            # Projects the classifier weights onto unit-norm rows in place on every call.
            for _, module in self.fc.named_modules():
                if isinstance(module, nn.Linear):
                    module.weight.data = F.normalize(module.weight, p=2, dim=1)
        return self.fc(out)