├── LICENSE
├── README.md
├── assets
│   ├── augmix.gif
│   └── pseudocode.png
├── augment_and_mix.py
├── augmentations.py
├── cifar.py
├── imagenet.py
├── models
│   └── cifar
│       └── allconv.py
├── requirements.txt
└── third_party
    ├── ResNeXt_DenseNet
    │   ├── LICENSE
    │   ├── METADATA
    │   ├── __init__.py
    │   └── models
    │       ├── __init__.py
    │       ├── densenet.py
    │       └── resnext.py
    ├── WideResNet_pytorch
    │   ├── LICENSE
    │   ├── METADATA
    │   ├── __init__.py
    │   └── wideresnet.py
    └── __init__.py
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner.
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AugMix 2 | 3 | 4 | 5 | ## Introduction 6 | 7 | We propose AugMix, a data processing technique that mixes augmented images and 8 | enforces consistent embeddings of the augmented images, which results in 9 | increased robustness and improved uncertainty calibration. 
Like random cropping or CutOut, AugMix does not
10 | require tuning to work correctly, and thus
11 | enables plug-and-play data augmentation. AugMix significantly improves
12 | robustness and uncertainty measures on challenging image classification
13 | benchmarks, closing the gap between previous methods and the best possible
14 | performance by more than half in some cases. With AugMix, we obtain
15 | state-of-the-art results on ImageNet-C and ImageNet-P, as well as in
16 | uncertainty estimation when the train and test distributions do not match.
17 |
18 | For more details, please see our [ICLR 2020 paper](https://arxiv.org/pdf/1912.02781.pdf).
19 |
20 | ## Pseudocode
21 |
22 | ![AugMix pseudocode](assets/pseudocode.png)
23 |
24 | ## Contents
25 |
26 | This directory includes a reference implementation in NumPy of the augmentation
27 | method used in AugMix in `augment_and_mix.py`. The full AugMix method also adds
28 | a Jensen-Shannon Divergence consistency loss to enforce consistent predictions
29 | between two different augmentations of the input image and the clean image
30 | itself.
31 |
32 | We also include PyTorch re-implementations of AugMix on both CIFAR-10/100 and
33 | ImageNet in `cifar.py` and `imagenet.py` respectively, which both support
34 | training and evaluation on CIFAR-10/100-C and ImageNet-C.
35 |
36 | ## Requirements
37 |
38 | * numpy>=1.15.0
39 | * Pillow>=6.1.0
40 | * torch==1.2.0
41 | * torchvision==0.2.2
42 |
43 | ## Setup
44 |
45 | 1. Install PyTorch and the other required Python libraries with:
46 |
47 | ```
48 | pip install -r requirements.txt
49 | ```
50 |
51 | 2. Download the CIFAR-10-C and CIFAR-100-C datasets with:
52 |
53 | ```
54 | mkdir -p ./data/cifar
55 | curl -O https://zenodo.org/record/2535967/files/CIFAR-10-C.tar
56 | curl -O https://zenodo.org/record/3555552/files/CIFAR-100-C.tar
57 | tar -xvf CIFAR-100-C.tar -C data/cifar/
58 | tar -xvf CIFAR-10-C.tar -C data/cifar/
59 | ```
60 |
61 | 3. Download ImageNet-C with:
62 |
63 | ```
64 | mkdir -p ./data/imagenet/imagenet-c
65 | curl -O https://zenodo.org/record/2235448/files/blur.tar
66 | curl -O https://zenodo.org/record/2235448/files/digital.tar
67 | curl -O https://zenodo.org/record/2235448/files/noise.tar
68 | curl -O https://zenodo.org/record/2235448/files/weather.tar
69 | tar -xvf blur.tar -C data/imagenet/imagenet-c
70 | tar -xvf digital.tar -C data/imagenet/imagenet-c
71 | tar -xvf noise.tar -C data/imagenet/imagenet-c
72 | tar -xvf weather.tar -C data/imagenet/imagenet-c
73 | ```
74 |
75 | ## Usage
76 |
77 | The Jensen-Shannon Divergence loss term may be disabled for faster training at the cost of slightly lower performance by adding the flag `--no-jsd`.
78 |
79 | Training recipes used in our paper:
80 |
81 | WRN: `python cifar.py`
82 |
83 | AllConv: `python cifar.py -m allconv`
84 |
85 | ResNeXt: `python cifar.py -m resnext -e 200`
86 |
87 | DenseNet: `python cifar.py -m densenet -e 200 -wd 0.0001`
88 |
89 | ResNet-50: `python imagenet.py <path/to/imagenet> <path/to/imagenet-c>`
90 |
91 | ## Pretrained weights
92 |
93 | Weights for a ResNet-50 ImageNet classifier trained with AugMix for 180 epochs are available
94 | [here](https://drive.google.com/file/d/1z-1V3rdFiwqSECz7Wkmn4VJVefJGJGiF/view?usp=sharing).
95 |
96 | This model has a 65.3 mean Corruption Error (mCE) and a 77.53% top-1 accuracy on clean ImageNet data.
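To restore these weights for evaluation, the sketch below assumes the downloaded file follows the same checkpoint format that `imagenet.py` writes: a dict whose `'state_dict'` entry comes from a `DataParallel`-wrapped model, so parameter names carry a `module.` prefix. The filename is a placeholder.

```
import torch
from torchvision import models

net = models.resnet50()
checkpoint = torch.load('checkpoint.pth.tar', map_location='cpu')  # placeholder path
# Strip the 'module.' prefix that torch.nn.DataParallel adds to parameter names.
state_dict = {k.replace('module.', '', 1): v
              for k, v in checkpoint['state_dict'].items()}
net.load_state_dict(state_dict)
net.eval()
```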
97 |
98 | ## Citation
99 |
100 | If you find this useful for your work, please consider citing:
101 |
102 | ```
103 | @article{hendrycks2020augmix,
104 |   title={{AugMix}: A Simple Data Processing Method to Improve Robustness and Uncertainty},
105 |   author={Hendrycks, Dan and Mu, Norman and Cubuk, Ekin D. and Zoph, Barret and Gilmer, Justin and Lakshminarayanan, Balaji},
106 |   journal={Proceedings of the International Conference on Learning Representations (ICLR)},
107 |   year={2020}
108 | }
109 | ```
110 |
--------------------------------------------------------------------------------
/assets/augmix.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/augmix/9b9824c7c19bf7e72df2d085d97b99b3bfb00ba4/assets/augmix.gif
--------------------------------------------------------------------------------
/assets/pseudocode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/augmix/9b9824c7c19bf7e72df2d085d97b99b3bfb00ba4/assets/pseudocode.png
--------------------------------------------------------------------------------
/augment_and_mix.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Reference implementation of AugMix's data augmentation method in numpy."""
16 | import augmentations
17 | import numpy as np
18 | from PIL import Image
19 |
20 | # CIFAR-10 constants
21 | MEAN = [0.4914, 0.4822, 0.4465]
22 | STD = [0.2023, 0.1994, 0.2010]
23 |
24 |
25 | def normalize(image):
26 |   """Normalize input image channel-wise to zero mean and unit variance."""
27 |   image = image.transpose(2, 0, 1)  # Switch to channel-first
28 |   mean, std = np.array(MEAN), np.array(STD)
29 |   image = (image - mean[:, None, None]) / std[:, None, None]
30 |   return image.transpose(1, 2, 0)
31 |
32 |
33 | def apply_op(image, op, severity):
34 |   image = np.clip(image * 255., 0, 255).astype(np.uint8)
35 |   pil_img = Image.fromarray(image)  # Convert to PIL.Image
36 |   pil_img = op(pil_img, severity)
37 |   return np.asarray(pil_img) / 255.
38 |
39 |
40 | def augment_and_mix(image, severity=3, width=3, depth=-1, alpha=1.):
41 |   """Perform AugMix augmentations and compute mixture.
42 |
43 |   Args:
44 |     image: Raw input image as float32 np.ndarray of shape (h, w, c).
45 |     severity: Severity of underlying augmentation operators (between 1 and 10).
46 |     width: Width of augmentation chain.
47 |     depth: Depth of augmentation chain. -1 enables stochastic depth uniformly
48 |       from [1, 3].
49 |     alpha: Probability coefficient for Beta and Dirichlet distributions.
50 |
51 |   Returns:
52 |     mixed: Augmented and mixed image.
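  Example (illustrative addition, not from the original file; assumes a float32
  RGB array in [0, 1] and a hypothetical file name, and note that the returned
  image is normalized by `normalize`):
    img = np.asarray(
        Image.open('example.png').convert('RGB').resize((32, 32)),
        dtype=np.float32) / 255.
    mixed = augment_and_mix(img, severity=3, width=3)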
53 |   """
54 |   ws = np.float32(
55 |       np.random.dirichlet([alpha] * width))
56 |   m = np.float32(np.random.beta(alpha, alpha))
57 |
58 |   mix = np.zeros_like(image)
59 |   for i in range(width):
60 |     image_aug = image.copy()
61 |     d = depth if depth > 0 else np.random.randint(1, 4)
62 |     for _ in range(d):
63 |       op = np.random.choice(augmentations.augmentations)
64 |       image_aug = apply_op(image_aug, op, severity)
65 |     # Preprocessing commutes since all coefficients are convex
66 |     mix += ws[i] * normalize(image_aug)
67 |
68 |   mixed = (1 - m) * normalize(image) + m * mix
69 |   return mixed
70 |
--------------------------------------------------------------------------------
/augmentations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Base augmentation operators."""
16 |
17 | import numpy as np
18 | from PIL import Image, ImageOps, ImageEnhance
19 |
20 | # ImageNet code should change this value
21 | IMAGE_SIZE = 32
22 |
23 |
24 | def int_parameter(level, maxval):
25 |   """Helper function to scale `maxval` according to `level`.
26 |
27 |   Args:
28 |     level: Level of the operation, which will lie in [0, 10].
29 |     maxval: Maximum value that the operation can have. This will be scaled to
30 |       level/10.
31 |
32 |   Returns:
33 |     An int that results from scaling `maxval` according to `level`.
34 |   """
35 |   return int(level * maxval / 10)
36 |
37 |
38 | def float_parameter(level, maxval):
39 |   """Helper function to scale `maxval` according to `level`.
40 |
41 |   Args:
42 |     level: Level of the operation, which will lie in [0, 10].
43 |     maxval: Maximum value that the operation can have. This will be scaled to
44 |       level/10.
45 |
46 |   Returns:
47 |     A float that results from scaling `maxval` according to `level`.
48 |   """
49 |   return float(level) * maxval / 10.
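# Illustrative note (added, not in the original file): `severity` caps a range
# of transform magnitudes rather than picking a fixed value. With the CIFAR
# default severity of 3, sample_level(3) below draws uniformly from [0.1, 3),
# so float_parameter(sample_level(3), 0.3) yields shear magnitudes in roughly
# [0.003, 0.09), and int_parameter(sample_level(3), 30) yields rotation angles
# between 0 and 8 degrees.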
50 | 51 | 52 | def sample_level(n): 53 | return np.random.uniform(low=0.1, high=n) 54 | 55 | 56 | def autocontrast(pil_img, _): 57 | return ImageOps.autocontrast(pil_img) 58 | 59 | 60 | def equalize(pil_img, _): 61 | return ImageOps.equalize(pil_img) 62 | 63 | 64 | def posterize(pil_img, level): 65 | level = int_parameter(sample_level(level), 4) 66 | return ImageOps.posterize(pil_img, 4 - level) 67 | 68 | 69 | def rotate(pil_img, level): 70 | degrees = int_parameter(sample_level(level), 30) 71 | if np.random.uniform() > 0.5: 72 | degrees = -degrees 73 | return pil_img.rotate(degrees, resample=Image.BILINEAR) 74 | 75 | 76 | def solarize(pil_img, level): 77 | level = int_parameter(sample_level(level), 256) 78 | return ImageOps.solarize(pil_img, 256 - level) 79 | 80 | 81 | def shear_x(pil_img, level): 82 | level = float_parameter(sample_level(level), 0.3) 83 | if np.random.uniform() > 0.5: 84 | level = -level 85 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 86 | Image.AFFINE, (1, level, 0, 0, 1, 0), 87 | resample=Image.BILINEAR) 88 | 89 | 90 | def shear_y(pil_img, level): 91 | level = float_parameter(sample_level(level), 0.3) 92 | if np.random.uniform() > 0.5: 93 | level = -level 94 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 95 | Image.AFFINE, (1, 0, 0, level, 1, 0), 96 | resample=Image.BILINEAR) 97 | 98 | 99 | def translate_x(pil_img, level): 100 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3) 101 | if np.random.random() > 0.5: 102 | level = -level 103 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 104 | Image.AFFINE, (1, 0, level, 0, 1, 0), 105 | resample=Image.BILINEAR) 106 | 107 | 108 | def translate_y(pil_img, level): 109 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3) 110 | if np.random.random() > 0.5: 111 | level = -level 112 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), 113 | Image.AFFINE, (1, 0, 0, 0, 1, level), 114 | resample=Image.BILINEAR) 115 | 116 | 117 | # operation that overlaps with ImageNet-C's test set 118 | def color(pil_img, level): 119 | level = float_parameter(sample_level(level), 1.8) + 0.1 120 | return ImageEnhance.Color(pil_img).enhance(level) 121 | 122 | 123 | # operation that overlaps with ImageNet-C's test set 124 | def contrast(pil_img, level): 125 | level = float_parameter(sample_level(level), 1.8) + 0.1 126 | return ImageEnhance.Contrast(pil_img).enhance(level) 127 | 128 | 129 | # operation that overlaps with ImageNet-C's test set 130 | def brightness(pil_img, level): 131 | level = float_parameter(sample_level(level), 1.8) + 0.1 132 | return ImageEnhance.Brightness(pil_img).enhance(level) 133 | 134 | 135 | # operation that overlaps with ImageNet-C's test set 136 | def sharpness(pil_img, level): 137 | level = float_parameter(sample_level(level), 1.8) + 0.1 138 | return ImageEnhance.Sharpness(pil_img).enhance(level) 139 | 140 | 141 | augmentations = [ 142 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, 143 | translate_x, translate_y 144 | ] 145 | 146 | augmentations_all = [ 147 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, 148 | translate_x, translate_y, color, contrast, brightness, sharpness 149 | ] 150 | -------------------------------------------------------------------------------- /cifar.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Main script to launch AugMix training on CIFAR-10/100. 16 | 17 | Supports WideResNet, AllConv, ResNeXt models on CIFAR-10 and CIFAR-100 as well 18 | as evaluation on CIFAR-10-C and CIFAR-100-C. 19 | 20 | Example usage: 21 | `python cifar.py` 22 | """ 23 | from __future__ import print_function 24 | 25 | import argparse 26 | import os 27 | import shutil 28 | import time 29 | 30 | import augmentations 31 | from models.cifar.allconv import AllConvNet 32 | import numpy as np 33 | from third_party.ResNeXt_DenseNet.models.densenet import densenet 34 | from third_party.ResNeXt_DenseNet.models.resnext import resnext29 35 | from third_party.WideResNet_pytorch.wideresnet import WideResNet 36 | 37 | import torch 38 | import torch.backends.cudnn as cudnn 39 | import torch.nn.functional as F 40 | from torchvision import datasets 41 | from torchvision import transforms 42 | 43 | parser = argparse.ArgumentParser( 44 | description='Trains a CIFAR Classifier', 45 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 46 | parser.add_argument( 47 | '--dataset', 48 | type=str, 49 | default='cifar10', 50 | choices=['cifar10', 'cifar100'], 51 | help='Choose between CIFAR-10, CIFAR-100.') 52 | parser.add_argument( 53 | '--model', 54 | '-m', 55 | type=str, 56 | default='wrn', 57 | choices=['wrn', 'allconv', 'densenet', 'resnext'], 58 | help='Choose architecture.') 59 | # Optimization options 60 | parser.add_argument( 61 | '--epochs', '-e', type=int, default=100, help='Number of epochs to train.') 62 | parser.add_argument( 63 | '--learning-rate', 64 | '-lr', 65 | type=float, 66 | default=0.1, 67 | help='Initial learning rate.') 68 | parser.add_argument( 69 | '--batch-size', '-b', type=int, default=128, help='Batch size.') 70 | parser.add_argument('--eval-batch-size', type=int, default=1000) 71 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.') 72 | parser.add_argument( 73 | '--decay', 74 | '-wd', 75 | type=float, 76 | default=0.0005, 77 | help='Weight decay (L2 penalty).') 78 | # WRN Architecture options 79 | parser.add_argument( 80 | '--layers', default=40, type=int, help='total number of layers') 81 | parser.add_argument('--widen-factor', default=2, type=int, help='Widen factor') 82 | parser.add_argument( 83 | '--droprate', default=0.0, type=float, help='Dropout probability') 84 | # AugMix options 85 | parser.add_argument( 86 | '--mixture-width', 87 | default=3, 88 | type=int, 89 | help='Number of augmentation chains to mix per augmented example') 90 | parser.add_argument( 91 | '--mixture-depth', 92 | default=-1, 93 | type=int, 94 | help='Depth of augmentation chains. 
-1 denotes stochastic depth in [1, 3]') 95 | parser.add_argument( 96 | '--aug-severity', 97 | default=3, 98 | type=int, 99 | help='Severity of base augmentation operators') 100 | parser.add_argument( 101 | '--no-jsd', 102 | '-nj', 103 | action='store_true', 104 | help='Turn off JSD consistency loss.') 105 | parser.add_argument( 106 | '--all-ops', 107 | '-all', 108 | action='store_true', 109 | help='Turn on all operations (+brightness,contrast,color,sharpness).') 110 | # Checkpointing options 111 | parser.add_argument( 112 | '--save', 113 | '-s', 114 | type=str, 115 | default='./snapshots', 116 | help='Folder to save checkpoints.') 117 | parser.add_argument( 118 | '--resume', 119 | '-r', 120 | type=str, 121 | default='', 122 | help='Checkpoint path for resume / test.') 123 | parser.add_argument('--evaluate', action='store_true', help='Eval only.') 124 | parser.add_argument( 125 | '--print-freq', 126 | type=int, 127 | default=50, 128 | help='Training loss print frequency (batches).') 129 | # Acceleration 130 | parser.add_argument( 131 | '--num-workers', 132 | type=int, 133 | default=4, 134 | help='Number of pre-fetching threads.') 135 | 136 | args = parser.parse_args() 137 | 138 | CORRUPTIONS = [ 139 | 'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 140 | 'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog', 141 | 'brightness', 'contrast', 'elastic_transform', 'pixelate', 142 | 'jpeg_compression' 143 | ] 144 | 145 | 146 | def get_lr(step, total_steps, lr_max, lr_min): 147 | """Compute learning rate according to cosine annealing schedule.""" 148 | return lr_min + (lr_max - lr_min) * 0.5 * (1 + 149 | np.cos(step / total_steps * np.pi)) 150 | 151 | 152 | def aug(image, preprocess): 153 | """Perform AugMix augmentations and compute mixture. 154 | 155 | Args: 156 | image: PIL.Image input image 157 | preprocess: Preprocessing function which should return a torch tensor. 158 | 159 | Returns: 160 | mixed: Augmented and mixed image. 161 | """ 162 | aug_list = augmentations.augmentations 163 | if args.all_ops: 164 | aug_list = augmentations.augmentations_all 165 | 166 | ws = np.float32(np.random.dirichlet([1] * args.mixture_width)) 167 | m = np.float32(np.random.beta(1, 1)) 168 | 169 | mix = torch.zeros_like(preprocess(image)) 170 | for i in range(args.mixture_width): 171 | image_aug = image.copy() 172 | depth = args.mixture_depth if args.mixture_depth > 0 else np.random.randint( 173 | 1, 4) 174 | for _ in range(depth): 175 | op = np.random.choice(aug_list) 176 | image_aug = op(image_aug, args.aug_severity) 177 | # Preprocessing commutes since all coefficients are convex 178 | mix += ws[i] * preprocess(image_aug) 179 | 180 | mixed = (1 - m) * preprocess(image) + m * mix 181 | return mixed 182 | 183 | 184 | class AugMixDataset(torch.utils.data.Dataset): 185 | """Dataset wrapper to perform AugMix augmentation.""" 186 | 187 | def __init__(self, dataset, preprocess, no_jsd=False): 188 | self.dataset = dataset 189 | self.preprocess = preprocess 190 | self.no_jsd = no_jsd 191 | 192 | def __getitem__(self, i): 193 | x, y = self.dataset[i] 194 | if self.no_jsd: 195 | return aug(x, self.preprocess), y 196 | else: 197 | im_tuple = (self.preprocess(x), aug(x, self.preprocess), 198 | aug(x, self.preprocess)) 199 | return im_tuple, y 200 | 201 | def __len__(self): 202 | return len(self.dataset) 203 | 204 | 205 | def train(net, train_loader, optimizer, scheduler): 206 | """Train for one epoch.""" 207 | net.train() 208 | loss_ema = 0. 
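# Illustrative note (added, not in the original file): with JSD training
# enabled, each batch element is a 3-tuple (clean, augmix_1, augmix_2). The
# loop below computes cross-entropy on the clean view only, plus a consistency
# term 12 * mean_i KL(p_i || M), where M is the average of the three softmax
# distributions; this mean KL from the average is the generalized
# Jensen-Shannon divergence. loss_ema smooths the printed loss with decay 0.9.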
209 | for i, (images, targets) in enumerate(train_loader): 210 | optimizer.zero_grad() 211 | 212 | if args.no_jsd: 213 | images = images.cuda() 214 | targets = targets.cuda() 215 | logits = net(images) 216 | loss = F.cross_entropy(logits, targets) 217 | else: 218 | images_all = torch.cat(images, 0).cuda() 219 | targets = targets.cuda() 220 | logits_all = net(images_all) 221 | logits_clean, logits_aug1, logits_aug2 = torch.split( 222 | logits_all, images[0].size(0)) 223 | 224 | # Cross-entropy is only computed on clean images 225 | loss = F.cross_entropy(logits_clean, targets) 226 | 227 | p_clean, p_aug1, p_aug2 = F.softmax( 228 | logits_clean, dim=1), F.softmax( 229 | logits_aug1, dim=1), F.softmax( 230 | logits_aug2, dim=1) 231 | 232 | # Clamp mixture distribution to avoid exploding KL divergence 233 | p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log() 234 | loss += 12 * (F.kl_div(p_mixture, p_clean, reduction='batchmean') + 235 | F.kl_div(p_mixture, p_aug1, reduction='batchmean') + 236 | F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3. 237 | 238 | loss.backward() 239 | optimizer.step() 240 | scheduler.step() 241 | loss_ema = loss_ema * 0.9 + float(loss) * 0.1 242 | if i % args.print_freq == 0: 243 | print('Train Loss {:.3f}'.format(loss_ema)) 244 | 245 | return loss_ema 246 | 247 | 248 | def test(net, test_loader): 249 | """Evaluate network on given dataset.""" 250 | net.eval() 251 | total_loss = 0. 252 | total_correct = 0 253 | with torch.no_grad(): 254 | for images, targets in test_loader: 255 | images, targets = images.cuda(), targets.cuda() 256 | logits = net(images) 257 | loss = F.cross_entropy(logits, targets) 258 | pred = logits.data.max(1)[1] 259 | total_loss += float(loss.data) 260 | total_correct += pred.eq(targets.data).sum().item() 261 | 262 | return total_loss / len(test_loader.dataset), total_correct / len( 263 | test_loader.dataset) 264 | 265 | 266 | def test_c(net, test_data, base_path): 267 | """Evaluate network on given corrupted dataset.""" 268 | corruption_accs = [] 269 | for corruption in CORRUPTIONS: 270 | # Reference to original data is mutated 271 | test_data.data = np.load(base_path + corruption + '.npy') 272 | test_data.targets = torch.LongTensor(np.load(base_path + 'labels.npy')) 273 | 274 | test_loader = torch.utils.data.DataLoader( 275 | test_data, 276 | batch_size=args.eval_batch_size, 277 | shuffle=False, 278 | num_workers=args.num_workers, 279 | pin_memory=True) 280 | 281 | test_loss, test_acc = test(net, test_loader) 282 | corruption_accs.append(test_acc) 283 | print('{}\n\tTest Loss {:.3f} | Test Error {:.3f}'.format( 284 | corruption, test_loss, 100 - 100. 
* test_acc))
285 |
286 |   return np.mean(corruption_accs)
287 |
288 |
289 | def main():
290 |   torch.manual_seed(1)
291 |   np.random.seed(1)
292 |
293 |   # Load datasets
294 |   train_transform = transforms.Compose(
295 |       [transforms.RandomHorizontalFlip(),
296 |        transforms.RandomCrop(32, padding=4)])
297 |   preprocess = transforms.Compose(
298 |       [transforms.ToTensor(),
299 |        transforms.Normalize([0.5] * 3, [0.5] * 3)])
300 |   test_transform = preprocess
301 |
302 |   if args.dataset == 'cifar10':
303 |     train_data = datasets.CIFAR10(
304 |         './data/cifar', train=True, transform=train_transform, download=True)
305 |     test_data = datasets.CIFAR10(
306 |         './data/cifar', train=False, transform=test_transform, download=True)
307 |     base_c_path = './data/cifar/CIFAR-10-C/'
308 |     num_classes = 10
309 |   else:
310 |     train_data = datasets.CIFAR100(
311 |         './data/cifar', train=True, transform=train_transform, download=True)
312 |     test_data = datasets.CIFAR100(
313 |         './data/cifar', train=False, transform=test_transform, download=True)
314 |     base_c_path = './data/cifar/CIFAR-100-C/'
315 |     num_classes = 100
316 |
317 |   train_data = AugMixDataset(train_data, preprocess, args.no_jsd)
318 |   train_loader = torch.utils.data.DataLoader(
319 |       train_data,
320 |       batch_size=args.batch_size,
321 |       shuffle=True,
322 |       num_workers=args.num_workers,
323 |       pin_memory=True)
324 |
325 |   test_loader = torch.utils.data.DataLoader(
326 |       test_data,
327 |       batch_size=args.eval_batch_size,
328 |       shuffle=False,
329 |       num_workers=args.num_workers,
330 |       pin_memory=True)
331 |
332 |   # Create model
333 |   if args.model == 'densenet':
334 |     net = densenet(num_classes=num_classes)
335 |   elif args.model == 'wrn':
336 |     net = WideResNet(args.layers, num_classes, args.widen_factor, args.droprate)
337 |   elif args.model == 'allconv':
338 |     net = AllConvNet(num_classes)
339 |   elif args.model == 'resnext':
340 |     net = resnext29(num_classes=num_classes)
341 |
342 |   optimizer = torch.optim.SGD(
343 |       net.parameters(),
344 |       args.learning_rate,
345 |       momentum=args.momentum,
346 |       weight_decay=args.decay,
347 |       nesterov=True)
348 |
349 |   # Distribute model across all visible GPUs
350 |   net = torch.nn.DataParallel(net).cuda()
351 |   cudnn.benchmark = True
352 |
353 |   start_epoch = 0
354 |   best_acc = 0  # May be overwritten below when resuming from a checkpoint.
355 |   if args.resume:
356 |     if os.path.isfile(args.resume):
357 |       checkpoint = torch.load(args.resume)
358 |       start_epoch = checkpoint['epoch'] + 1
359 |       best_acc = checkpoint['best_acc']
360 |       net.load_state_dict(checkpoint['state_dict'])
361 |       optimizer.load_state_dict(checkpoint['optimizer'])
362 |       print('Model restored from epoch:', start_epoch)
363 |
364 |   if args.evaluate:
365 |     # Evaluate clean accuracy first because test_c mutates underlying data
366 |     test_loss, test_acc = test(net, test_loader)
367 |     print('Clean\n\tTest Loss {:.3f} | Test Error {:.2f}'.format(
368 |         test_loss, 100 - 100. * test_acc))
369 |
370 |     test_c_acc = test_c(net, test_data, base_c_path)
371 |     print('Mean Corruption Error: {:.3f}'.format(100 - 100. * test_c_acc))
372 |     return
373 |
374 |   scheduler = torch.optim.lr_scheduler.LambdaLR(
375 |       optimizer,
376 |       lr_lambda=lambda step: get_lr(  # pylint: disable=g-long-lambda
377 |           step,
378 |           args.epochs * len(train_loader),
379 |           1,  # lr_lambda computes multiplicative factor
380 |           1e-6 / args.learning_rate))
381 |
382 |   if not os.path.exists(args.save):
383 |     os.makedirs(args.save)
384 |   if not os.path.isdir(args.save):
385 |     raise Exception('%s is not a dir' % args.save)
386 |
387 |   log_path = os.path.join(args.save,
388 |                           args.dataset + '_' + args.model + '_training_log.csv')
389 |   with open(log_path, 'w') as f:
390 |     f.write('epoch,time(s),train_loss,test_loss,test_error(%)\n')
391 |
392 |   # best_acc is initialized before the resume block so a restored value is kept.
393 |   print('Beginning training from epoch:', start_epoch + 1)
394 |   for epoch in range(start_epoch, args.epochs):
395 |     begin_time = time.time()
396 |
397 |     train_loss_ema = train(net, train_loader, optimizer, scheduler)
398 |     test_loss, test_acc = test(net, test_loader)
399 |
400 |     is_best = test_acc > best_acc
401 |     best_acc = max(test_acc, best_acc)
402 |     checkpoint = {
403 |         'epoch': epoch,
404 |         'dataset': args.dataset,
405 |         'model': args.model,
406 |         'state_dict': net.state_dict(),
407 |         'best_acc': best_acc,
408 |         'optimizer': optimizer.state_dict(),
409 |     }
410 |
411 |     save_path = os.path.join(args.save, 'checkpoint.pth.tar')
412 |     torch.save(checkpoint, save_path)
413 |     if is_best:
414 |       shutil.copyfile(save_path, os.path.join(args.save, 'model_best.pth.tar'))
415 |
416 |     with open(log_path, 'a') as f:
417 |       f.write('%03d,%05d,%0.6f,%0.5f,%0.2f\n' % (
418 |           (epoch + 1),
419 |           time.time() - begin_time,
420 |           train_loss_ema,
421 |           test_loss,
422 |           100 - 100. * test_acc,
423 |       ))
424 |
425 |     print(
426 |         'Epoch {0:3d} | Time {1:5d} | Train Loss {2:.4f} | Test Loss {3:.3f} |'
427 |         ' Test Error {4:.2f}'
428 |         .format((epoch + 1), int(time.time() - begin_time), train_loss_ema,
429 |                 test_loss, 100 - 100. * test_acc))
430 |
431 |   test_c_acc = test_c(net, test_data, base_c_path)
432 |   print('Mean Corruption Error: {:.3f}'.format(100 - 100. * test_c_acc))
433 |
434 |   with open(log_path, 'a') as f:
435 |     f.write('%03d,%05d,%0.6f,%0.5f,%0.2f\n' %
436 |             (args.epochs + 1, 0, 0, 0, 100 - 100 * test_c_acc))
437 |
438 |
439 | if __name__ == '__main__':
440 |   main()
441 |
--------------------------------------------------------------------------------
/imagenet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Main script to launch AugMix training on ImageNet.
16 |
17 | Currently only supports ResNet-50 training.
18 | 19 | Example usage: 20 | `python imagenet.py ` 21 | """ 22 | from __future__ import print_function 23 | 24 | import argparse 25 | import os 26 | import shutil 27 | import time 28 | 29 | import augmentations 30 | 31 | import numpy as np 32 | import torch 33 | import torch.backends.cudnn as cudnn 34 | import torch.nn.functional as F 35 | from torchvision import datasets 36 | from torchvision import models 37 | from torchvision import transforms 38 | 39 | augmentations.IMAGE_SIZE = 224 40 | 41 | model_names = sorted(name for name in models.__dict__ 42 | if name.islower() and not name.startswith('__') and 43 | callable(models.__dict__[name])) 44 | 45 | parser = argparse.ArgumentParser(description='Trains an ImageNet Classifier') 46 | parser.add_argument( 47 | 'clean_data', metavar='DIR', help='path to clean ImageNet dataset') 48 | parser.add_argument( 49 | 'corrupted_data', metavar='DIR_C', help='path to ImageNet-C dataset') 50 | parser.add_argument( 51 | '--model', 52 | '-m', 53 | default='resnet50', 54 | choices=model_names, 55 | help='model architecture: ' + ' | '.join(model_names) + 56 | ' (default: resnet50)') 57 | # Optimization options 58 | parser.add_argument( 59 | '--epochs', '-e', type=int, default=90, help='Number of epochs to train.') 60 | parser.add_argument( 61 | '--learning-rate', 62 | '-lr', 63 | type=float, 64 | default=0.1, 65 | help='Initial learning rate.') 66 | parser.add_argument( 67 | '--batch-size', '-b', type=int, default=256, help='Batch size.') 68 | parser.add_argument('--eval-batch-size', type=int, default=1000) 69 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.') 70 | parser.add_argument( 71 | '--decay', 72 | '-wd', 73 | type=float, 74 | default=0.0001, 75 | help='Weight decay (L2 penalty).') 76 | # AugMix options 77 | parser.add_argument( 78 | '--mixture-width', 79 | default=3, 80 | type=int, 81 | help='Number of augmentation chains to mix per augmented example') 82 | parser.add_argument( 83 | '--mixture-depth', 84 | default=-1, 85 | type=int, 86 | help='Depth of augmentation chains. 
-1 denotes stochastic depth in [1, 3]') 87 | parser.add_argument( 88 | '--aug-severity', 89 | default=1, 90 | type=int, 91 | help='Severity of base augmentation operators') 92 | parser.add_argument( 93 | '--aug-prob-coeff', 94 | default=1., 95 | type=float, 96 | help='Probability distribution coefficients') 97 | parser.add_argument( 98 | '--no-jsd', 99 | '-nj', 100 | action='store_true', 101 | help='Turn off JSD consistency loss.') 102 | parser.add_argument( 103 | '--all-ops', 104 | '-all', 105 | action='store_true', 106 | help='Turn on all operations (+brightness,contrast,color,sharpness).') 107 | # Checkpointing options 108 | parser.add_argument( 109 | '--save', 110 | '-s', 111 | type=str, 112 | default='./snapshots', 113 | help='Folder to save checkpoints.') 114 | parser.add_argument( 115 | '--resume', 116 | '-r', 117 | type=str, 118 | default='', 119 | help='Checkpoint path for resume / test.') 120 | parser.add_argument('--evaluate', action='store_true', help='Eval only.') 121 | parser.add_argument( 122 | '--print-freq', 123 | type=int, 124 | default=10, 125 | help='Training loss print frequency (batches).') 126 | parser.add_argument( 127 | '--pretrained', 128 | dest='pretrained', 129 | action='store_true', 130 | help='use pre-trained model') 131 | # Acceleration 132 | parser.add_argument( 133 | '--num-workers', 134 | type=int, 135 | default=4, 136 | help='Number of pre-fetching threads.') 137 | 138 | args = parser.parse_args() 139 | 140 | CORRUPTIONS = [ 141 | 'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 142 | 'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog', 143 | 'brightness', 'contrast', 'elastic_transform', 'pixelate', 144 | 'jpeg_compression' 145 | ] 146 | 147 | # Raw AlexNet errors taken from https://github.com/hendrycks/robustness 148 | ALEXNET_ERR = [ 149 | 0.886428, 0.894468, 0.922640, 0.819880, 0.826268, 0.785948, 0.798360, 150 | 0.866816, 0.826572, 0.819324, 0.564592, 0.853204, 0.646056, 0.717840, 151 | 0.606500 152 | ] 153 | 154 | 155 | def adjust_learning_rate(optimizer, epoch): 156 | """Sets the learning rate to the initial LR (linearly scaled to batch size) decayed by 10 every n / 3 epochs.""" 157 | b = args.batch_size / 256. 158 | k = args.epochs // 3 159 | if epoch < k: 160 | m = 1 161 | elif epoch < 2 * k: 162 | m = 0.1 163 | else: 164 | m = 0.01 165 | lr = args.learning_rate * m * b 166 | for param_group in optimizer.param_groups: 167 | param_group['lr'] = lr 168 | 169 | 170 | def accuracy(output, target, topk=(1,)): 171 | """Computes the accuracy over the k top predictions for the specified values of k.""" 172 | with torch.no_grad(): 173 | maxk = max(topk) 174 | batch_size = target.size(0) 175 | 176 | _, pred = output.topk(maxk, 1, True, True) 177 | pred = pred.t() 178 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 179 | 180 | res = [] 181 | for k in topk: 182 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 183 | res.append(correct_k.mul_(100.0 / batch_size)) 184 | return res 185 | 186 | 187 | def compute_mce(corruption_accs): 188 | """Compute mCE (mean Corruption Error) normalized by AlexNet performance.""" 189 | mce = 0. 190 | for i in range(len(CORRUPTIONS)): 191 | avg_err = 1 - np.mean(corruption_accs[CORRUPTIONS[i]]) 192 | ce = 100 * avg_err / ALEXNET_ERR[i] 193 | mce += ce / 15 194 | return mce 195 | 196 | 197 | def aug(image, preprocess): 198 | """Perform AugMix augmentations and compute mixture. 
199 | 200 | Args: 201 | image: PIL.Image input image 202 | preprocess: Preprocessing function which should return a torch tensor. 203 | 204 | Returns: 205 | mixed: Augmented and mixed image. 206 | """ 207 | aug_list = augmentations.augmentations 208 | if args.all_ops: 209 | aug_list = augmentations.augmentations_all 210 | 211 | ws = np.float32( 212 | np.random.dirichlet([args.aug_prob_coeff] * args.mixture_width)) 213 | m = np.float32(np.random.beta(args.aug_prob_coeff, args.aug_prob_coeff)) 214 | 215 | mix = torch.zeros_like(preprocess(image)) 216 | for i in range(args.mixture_width): 217 | image_aug = image.copy() 218 | depth = args.mixture_depth if args.mixture_depth > 0 else np.random.randint( 219 | 1, 4) 220 | for _ in range(depth): 221 | op = np.random.choice(aug_list) 222 | image_aug = op(image_aug, args.aug_severity) 223 | # Preprocessing commutes since all coefficients are convex 224 | mix += ws[i] * preprocess(image_aug) 225 | 226 | mixed = (1 - m) * preprocess(image) + m * mix 227 | return mixed 228 | 229 | 230 | class AugMixDataset(torch.utils.data.Dataset): 231 | """Dataset wrapper to perform AugMix augmentation.""" 232 | 233 | def __init__(self, dataset, preprocess, no_jsd=False): 234 | self.dataset = dataset 235 | self.preprocess = preprocess 236 | self.no_jsd = no_jsd 237 | 238 | def __getitem__(self, i): 239 | x, y = self.dataset[i] 240 | if self.no_jsd: 241 | return aug(x, self.preprocess), y 242 | else: 243 | im_tuple = (self.preprocess(x), aug(x, self.preprocess), 244 | aug(x, self.preprocess)) 245 | return im_tuple, y 246 | 247 | def __len__(self): 248 | return len(self.dataset) 249 | 250 | 251 | def train(net, train_loader, optimizer): 252 | """Train for one epoch.""" 253 | net.train() 254 | data_ema = 0. 255 | batch_ema = 0. 256 | loss_ema = 0. 257 | acc1_ema = 0. 258 | acc5_ema = 0. 259 | 260 | end = time.time() 261 | for i, (images, targets) in enumerate(train_loader): 262 | # Compute data loading time 263 | data_time = time.time() - end 264 | optimizer.zero_grad() 265 | 266 | if args.no_jsd: 267 | images = images.cuda() 268 | targets = targets.cuda() 269 | logits = net(images) 270 | loss = F.cross_entropy(logits, targets) 271 | acc1, acc5 = accuracy(logits, targets, topk=(1, 5)) # pylint: disable=unbalanced-tuple-unpacking 272 | else: 273 | images_all = torch.cat(images, 0).cuda() 274 | targets = targets.cuda() 275 | logits_all = net(images_all) 276 | logits_clean, logits_aug1, logits_aug2 = torch.split( 277 | logits_all, images[0].size(0)) 278 | 279 | # Cross-entropy is only computed on clean images 280 | loss = F.cross_entropy(logits_clean, targets) 281 | 282 | p_clean, p_aug1, p_aug2 = F.softmax( 283 | logits_clean, dim=1), F.softmax( 284 | logits_aug1, dim=1), F.softmax( 285 | logits_aug2, dim=1) 286 | 287 | # Clamp mixture distribution to avoid exploding KL divergence 288 | p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log() 289 | loss += 12 * (F.kl_div(p_mixture, p_clean, reduction='batchmean') + 290 | F.kl_div(p_mixture, p_aug1, reduction='batchmean') + 291 | F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3. 292 | acc1, acc5 = accuracy(logits_clean, targets, topk=(1, 5)) # pylint: disable=unbalanced-tuple-unpacking 293 | 294 | loss.backward() 295 | optimizer.step() 296 | 297 | # Compute batch computation time and update moving averages. 
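# Illustrative note (added, not in the original file): these running averages
# weight the current batch by 0.9 (old * 0.1 + new * 0.9), the reverse of the
# smoothing used in cifar.py, so the printed statistics track the most recent
# batches closely.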
298 | batch_time = time.time() - end 299 | end = time.time() 300 | 301 | data_ema = data_ema * 0.1 + float(data_time) * 0.9 302 | batch_ema = batch_ema * 0.1 + float(batch_time) * 0.9 303 | loss_ema = loss_ema * 0.1 + float(loss) * 0.9 304 | acc1_ema = acc1_ema * 0.1 + float(acc1) * 0.9 305 | acc5_ema = acc5_ema * 0.1 + float(acc5) * 0.9 306 | 307 | if i % args.print_freq == 0: 308 | print( 309 | 'Batch {}/{}: Data Time {:.3f} | Batch Time {:.3f} | Train Loss {:.3f} | Train Acc1 ' 310 | '{:.3f} | Train Acc5 {:.3f}'.format(i, len(train_loader), data_ema, 311 | batch_ema, loss_ema, acc1_ema, 312 | acc5_ema)) 313 | 314 | return loss_ema, acc1_ema, batch_ema 315 | 316 | 317 | def test(net, test_loader): 318 | """Evaluate network on given dataset.""" 319 | net.eval() 320 | total_loss = 0. 321 | total_correct = 0 322 | with torch.no_grad(): 323 | for images, targets in test_loader: 324 | images, targets = images.cuda(), targets.cuda() 325 | logits = net(images) 326 | loss = F.cross_entropy(logits, targets) 327 | pred = logits.data.max(1)[1] 328 | total_loss += float(loss.data) 329 | total_correct += pred.eq(targets.data).sum().item() 330 | 331 | return total_loss / len(test_loader.dataset), total_correct / len( 332 | test_loader.dataset) 333 | 334 | 335 | def test_c(net, test_transform): 336 | """Evaluate network on given corrupted dataset.""" 337 | corruption_accs = {} 338 | for c in CORRUPTIONS: 339 | print(c) 340 | for s in range(1, 6): 341 | valdir = os.path.join(args.corrupted_data, c, str(s)) 342 | val_loader = torch.utils.data.DataLoader( 343 | datasets.ImageFolder(valdir, test_transform), 344 | batch_size=args.eval_batch_size, 345 | shuffle=False, 346 | num_workers=args.num_workers, 347 | pin_memory=True) 348 | 349 | loss, acc1 = test(net, val_loader) 350 | if c in corruption_accs: 351 | corruption_accs[c].append(acc1) 352 | else: 353 | corruption_accs[c] = [acc1] 354 | 355 | print('\ts={}: Test Loss {:.3f} | Test Acc1 {:.3f}'.format( 356 | s, loss, 100. 
* acc1))
357 |
358 |   return corruption_accs
359 |
360 |
361 | def main():
362 |   torch.manual_seed(1)
363 |   np.random.seed(1)
364 |
365 |   # Load datasets
366 |   mean = [0.485, 0.456, 0.406]
367 |   std = [0.229, 0.224, 0.225]
368 |   train_transform = transforms.Compose(
369 |       [transforms.RandomResizedCrop(224),
370 |        transforms.RandomHorizontalFlip()])
371 |   preprocess = transforms.Compose(
372 |       [transforms.ToTensor(),
373 |        transforms.Normalize(mean, std)])
374 |   test_transform = transforms.Compose([
375 |       transforms.Resize(256),
376 |       transforms.CenterCrop(224),
377 |       preprocess,
378 |   ])
379 |
380 |   traindir = os.path.join(args.clean_data, 'train')
381 |   valdir = os.path.join(args.clean_data, 'val')
382 |   train_dataset = datasets.ImageFolder(traindir, train_transform)
383 |   train_dataset = AugMixDataset(train_dataset, preprocess)
384 |   train_loader = torch.utils.data.DataLoader(
385 |       train_dataset,
386 |       batch_size=args.batch_size,
387 |       shuffle=True,
388 |       num_workers=args.num_workers)
389 |   val_loader = torch.utils.data.DataLoader(
390 |       datasets.ImageFolder(valdir, test_transform),
391 |       batch_size=args.batch_size,
392 |       shuffle=False,
393 |       num_workers=args.num_workers)
394 |
395 |   if args.pretrained:
396 |     print("=> using pre-trained model '{}'".format(args.model))
397 |     net = models.__dict__[args.model](pretrained=True)
398 |   else:
399 |     print("=> creating model '{}'".format(args.model))
400 |     net = models.__dict__[args.model]()
401 |
402 |   optimizer = torch.optim.SGD(
403 |       net.parameters(),
404 |       args.learning_rate,
405 |       momentum=args.momentum,
406 |       weight_decay=args.decay)
407 |
408 |   # Distribute model across all visible GPUs
409 |   net = torch.nn.DataParallel(net).cuda()
410 |   cudnn.benchmark = True
411 |
412 |   start_epoch = 0
413 |   best_acc1 = 0  # May be overwritten below when resuming from a checkpoint.
414 |   if args.resume:
415 |     if os.path.isfile(args.resume):
416 |       checkpoint = torch.load(args.resume)
417 |       start_epoch = checkpoint['epoch'] + 1
418 |       best_acc1 = checkpoint['best_acc1']
419 |       net.load_state_dict(checkpoint['state_dict'])
420 |       optimizer.load_state_dict(checkpoint['optimizer'])
421 |       print('Model restored from epoch:', start_epoch)
422 |
423 |   if args.evaluate:
424 |     test_loss, test_acc1 = test(net, val_loader)
425 |     print('Clean\n\tTest Loss {:.3f} | Test Acc1 {:.3f}'.format(
426 |         test_loss, 100 * test_acc1))
427 |
428 |     corruption_accs = test_c(net, test_transform)
429 |     for c in CORRUPTIONS:
430 |       print('\t'.join(map(str, [c] + corruption_accs[c])))
431 |
432 |     print('mCE (normalized by AlexNet):', compute_mce(corruption_accs))
433 |     return
434 |
435 |   if not os.path.exists(args.save):
436 |     os.makedirs(args.save)
437 |   if not os.path.isdir(args.save):
438 |     raise Exception('%s is not a dir' % args.save)
439 |
440 |   log_path = os.path.join(args.save,
441 |                           'imagenet_{}_training_log.csv'.format(args.model))
442 |   with open(log_path, 'w') as f:
443 |     f.write(
444 |         'epoch,batch_time,train_loss,train_acc1(%),test_loss,test_acc1(%)\n')
445 |
446 |   # best_acc1 is initialized before the resume block so a restored value is kept.
447 |   print('Beginning training from epoch:', start_epoch + 1)
448 |   for epoch in range(start_epoch, args.epochs):
449 |     adjust_learning_rate(optimizer, epoch)
450 |
451 |     train_loss_ema, train_acc1_ema, batch_ema = train(net, train_loader,
452 |                                                       optimizer)
453 |     test_loss, test_acc1 = test(net, val_loader)
454 |
455 |     is_best = test_acc1 > best_acc1
456 |     best_acc1 = max(test_acc1, best_acc1)
457 |     checkpoint = {
458 |         'epoch': epoch,
459 |         'model': args.model,
460 |         'state_dict': net.state_dict(),
461 |         'best_acc1': best_acc1,
462 |         'optimizer': optimizer.state_dict(),
463 | } 464 | 465 | save_path = os.path.join(args.save, 'checkpoint.pth.tar') 466 | torch.save(checkpoint, save_path) 467 | if is_best: 468 | shutil.copyfile(save_path, os.path.join(args.save, 'model_best.pth.tar')) 469 | 470 | with open(log_path, 'a') as f: 471 | f.write('%03d,%0.3f,%0.6f,%0.2f,%0.5f,%0.2f\n' % ( 472 | (epoch + 1), 473 | batch_ema, 474 | train_loss_ema, 475 | 100. * train_acc1_ema, 476 | test_loss, 477 | 100. * test_acc1, 478 | )) 479 | 480 | print( 481 | 'Epoch {:3d} | Train Loss {:.4f} | Test Loss {:.3f} | Test Acc1 ' 482 | '{:.2f}' 483 | .format((epoch + 1), train_loss_ema, test_loss, 100. * test_acc1)) 484 | 485 | corruption_accs = test_c(net, test_transform) 486 | for c in CORRUPTIONS: 487 | print('\t'.join(map(str, [c] + corruption_accs[c]))) 488 | 489 | print('mCE (normalized by AlexNet):', compute_mce(corruption_accs)) 490 | 491 | 492 | if __name__ == '__main__': 493 | main() 494 | -------------------------------------------------------------------------------- /models/cifar/allconv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """AllConv implementation (https://arxiv.org/abs/1412.6806).""" 16 | import math 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | class GELU(nn.Module): 22 | 23 | def forward(self, x): 24 | return torch.sigmoid(1.702 * x) * x 25 | 26 | 27 | def make_layers(cfg): 28 | """Create a single layer.""" 29 | layers = [] 30 | in_channels = 3 31 | for v in cfg: 32 | if v == 'Md': 33 | layers += [nn.MaxPool2d(kernel_size=2, stride=2), nn.Dropout(p=0.5)] 34 | elif v == 'A': 35 | layers += [nn.AvgPool2d(kernel_size=8)] 36 | elif v == 'NIN': 37 | conv2d = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=1) 38 | layers += [conv2d, nn.BatchNorm2d(in_channels), GELU()] 39 | elif v == 'nopad': 40 | conv2d = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=0) 41 | layers += [conv2d, nn.BatchNorm2d(in_channels), GELU()] 42 | else: 43 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 44 | layers += [conv2d, nn.BatchNorm2d(v), GELU()] 45 | in_channels = v 46 | return nn.Sequential(*layers) 47 | 48 | 49 | class AllConvNet(nn.Module): 50 | """AllConvNet main class.""" 51 | 52 | def __init__(self, num_classes): 53 | super(AllConvNet, self).__init__() 54 | 55 | self.num_classes = num_classes 56 | self.width1, w1 = 96, 96 57 | self.width2, w2 = 192, 192 58 | 59 | self.features = make_layers( 60 | [w1, w1, w1, 'Md', w2, w2, w2, 'Md', 'nopad', 'NIN', 'NIN', 'A']) 61 | self.classifier = nn.Linear(self.width2, num_classes) 62 | 63 | for m in self.modules(): 64 | if isinstance(m, nn.Conv2d): 65 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 66 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) # He initialization 67 | elif isinstance(m, nn.BatchNorm2d): 68 | m.weight.data.fill_(1) 69 | m.bias.data.zero_() 70 | elif isinstance(m, nn.Linear): 71 | m.bias.data.zero_() 72 | 73 | def forward(self, x): 74 | x = self.features(x) 75 | x = x.view(x.size(0), -1) 76 | x = self.classifier(x) 77 | return x 78 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.15.0 2 | Pillow>=6.1.0 3 | torch==1.2.0 4 | torchvision==0.2.2 5 | -------------------------------------------------------------------------------- /third_party/ResNeXt_DenseNet/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Xuanyi Dong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /third_party/ResNeXt_DenseNet/METADATA: -------------------------------------------------------------------------------- 1 | name: "ResNeXt-DenseNet" 2 | description: "PyTorch implementations of ResNeXt and DenseNet." 
3 | 4 | third_party { 5 | url { 6 | type: GIT 7 | value: "https://github.com/D-X-Y/ResNeXt-DenseNet" 8 | } 9 | version: "0de9a8c8fd095b37eb60945f8dafefdbfe1cef6b" 10 | last_upgrade_date { year: 2019 month: 12 day: 4 } 11 | license_type: PERMISSIVE 12 | } 13 | -------------------------------------------------------------------------------- /third_party/ResNeXt_DenseNet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/augmix/9b9824c7c19bf7e72df2d085d97b99b3bfb00ba4/third_party/ResNeXt_DenseNet/__init__.py -------------------------------------------------------------------------------- /third_party/ResNeXt_DenseNet/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/augmix/9b9824c7c19bf7e72df2d085d97b99b3bfb00ba4/third_party/ResNeXt_DenseNet/models/__init__.py -------------------------------------------------------------------------------- /third_party/ResNeXt_DenseNet/models/densenet.py: -------------------------------------------------------------------------------- 1 | """DenseNet implementation (https://arxiv.org/abs/1608.06993).""" 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class Bottleneck(nn.Module): 9 | """Bottleneck block for DenseNet.""" 10 | 11 | def __init__(self, n_channels, growth_rate): 12 | super(Bottleneck, self).__init__() 13 | inter_channels = 4 * growth_rate 14 | self.bn1 = nn.BatchNorm2d(n_channels) 15 | self.conv1 = nn.Conv2d( 16 | n_channels, inter_channels, kernel_size=1, bias=False) 17 | self.bn2 = nn.BatchNorm2d(inter_channels) 18 | self.conv2 = nn.Conv2d( 19 | inter_channels, growth_rate, kernel_size=3, padding=1, bias=False) 20 | 21 | def forward(self, x): 22 | out = self.conv1(F.relu(self.bn1(x))) 23 | out = self.conv2(F.relu(self.bn2(out))) 24 | out = torch.cat((x, out), 1) 25 | return out 26 | 27 | 28 | class SingleLayer(nn.Module): 29 | """Single BN-ReLU-Conv layer whose output is concatenated with its input.""" 30 | 31 | def __init__(self, n_channels, growth_rate): 32 | super(SingleLayer, self).__init__() 33 | self.bn1 = nn.BatchNorm2d(n_channels) 34 | self.conv1 = nn.Conv2d( 35 | n_channels, growth_rate, kernel_size=3, padding=1, bias=False) 36 | 37 | def forward(self, x): 38 | out = self.conv1(F.relu(self.bn1(x))) 39 | out = torch.cat((x, out), 1) 40 | return out 41 | 42 | 43 | class Transition(nn.Module): 44 | """Transition block.""" 45 | 46 | def __init__(self, n_channels, n_out_channels): 47 | super(Transition, self).__init__() 48 | self.bn1 = nn.BatchNorm2d(n_channels) 49 | self.conv1 = nn.Conv2d( 50 | n_channels, n_out_channels, kernel_size=1, bias=False) 51 | 52 | def forward(self, x): 53 | out = self.conv1(F.relu(self.bn1(x))) 54 | out = F.avg_pool2d(out, 2) 55 | return out 56 | 57 | 58 | class DenseNet(nn.Module): 59 | """DenseNet main class.""" 60 | 61 | def __init__(self, growth_rate, depth, reduction, n_classes, bottleneck): 62 | super(DenseNet, self).__init__() 63 | 64 | if bottleneck: 65 | n_dense_blocks = (depth - 4) // 6 66 | else: 67 | n_dense_blocks = (depth - 4) // 3 68 | 69 | n_channels = 2 * growth_rate 70 | self.conv1 = nn.Conv2d(3, n_channels, kernel_size=3, padding=1, bias=False) 71 | 72 | self.dense1 = self._make_dense(n_channels, growth_rate, n_dense_blocks, 73 | bottleneck) 74 | n_channels += n_dense_blocks * growth_rate 75 | n_out_channels = int(math.floor(n_channels * reduction)) 76 | self.trans1 =
Transition(n_channels, n_out_channels) 77 | 78 | n_channels = n_out_channels 79 | self.dense2 = self._make_dense(n_channels, growth_rate, n_dense_blocks, 80 | bottleneck) 81 | n_channels += n_dense_blocks * growth_rate 82 | n_out_channels = int(math.floor(n_channels * reduction)) 83 | self.trans2 = Transition(n_channels, n_out_channels) 84 | 85 | n_channels = n_out_channels 86 | self.dense3 = self._make_dense(n_channels, growth_rate, n_dense_blocks, 87 | bottleneck) 88 | n_channels += n_dense_blocks * growth_rate 89 | 90 | self.bn1 = nn.BatchNorm2d(n_channels) 91 | self.fc = nn.Linear(n_channels, n_classes) 92 | 93 | for m in self.modules(): 94 | if isinstance(m, nn.Conv2d): 95 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 96 | m.weight.data.normal_(0, math.sqrt(2. / n)) 97 | elif isinstance(m, nn.BatchNorm2d): 98 | m.weight.data.fill_(1) 99 | m.bias.data.zero_() 100 | elif isinstance(m, nn.Linear): 101 | m.bias.data.zero_() 102 | 103 | def _make_dense(self, n_channels, growth_rate, n_dense_blocks, bottleneck): 104 | layers = [] 105 | for _ in range(int(n_dense_blocks)): 106 | if bottleneck: 107 | layers.append(Bottleneck(n_channels, growth_rate)) 108 | else: 109 | layers.append(SingleLayer(n_channels, growth_rate)) 110 | n_channels += growth_rate 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, x): 114 | out = self.conv1(x) 115 | out = self.trans1(self.dense1(out)) 116 | out = self.trans2(self.dense2(out)) 117 | out = self.dense3(out) 118 | out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8)) 119 | out = self.fc(out) 120 | return out 121 | 122 | 123 | def densenet(growth_rate=12, depth=40, num_classes=10): 124 | model = DenseNet(growth_rate, depth, 1., num_classes, False) 125 | return model 126 | -------------------------------------------------------------------------------- /third_party/ResNeXt_DenseNet/models/resnext.py: -------------------------------------------------------------------------------- 1 | """ResNeXt implementation (https://arxiv.org/abs/1611.05431).""" 2 | import math 3 | import torch.nn as nn 4 | from torch.nn import init 5 | import torch.nn.functional as F 6 | 7 | 8 | class ResNeXtBottleneck(nn.Module): 9 | """ResNeXt Bottleneck Block type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua).""" 10 | expansion = 4 11 | 12 | def __init__(self, 13 | inplanes, 14 | planes, 15 | cardinality, 16 | base_width, 17 | stride=1, 18 | downsample=None): 19 | super(ResNeXtBottleneck, self).__init__() 20 | 21 | dim = int(math.floor(planes * (base_width / 64.0))) 22 | 23 | self.conv_reduce = nn.Conv2d( 24 | inplanes, 25 | dim * cardinality, 26 | kernel_size=1, 27 | stride=1, 28 | padding=0, 29 | bias=False) 30 | self.bn_reduce = nn.BatchNorm2d(dim * cardinality) 31 | 32 | self.conv_conv = nn.Conv2d( 33 | dim * cardinality, 34 | dim * cardinality, 35 | kernel_size=3, 36 | stride=stride, 37 | padding=1, 38 | groups=cardinality, 39 | bias=False) 40 | self.bn = nn.BatchNorm2d(dim * cardinality) 41 | 42 | self.conv_expand = nn.Conv2d( 43 | dim * cardinality, 44 | planes * 4, 45 | kernel_size=1, 46 | stride=1, 47 | padding=0, 48 | bias=False) 49 | self.bn_expand = nn.BatchNorm2d(planes * 4) 50 | 51 | self.downsample = downsample 52 | 53 | def forward(self, x): 54 | residual = x 55 | 56 | bottleneck = self.conv_reduce(x) 57 | bottleneck = F.relu(self.bn_reduce(bottleneck), inplace=True) 58 | 59 | bottleneck = self.conv_conv(bottleneck) 60 | bottleneck = F.relu(self.bn(bottleneck), inplace=True) 61 | 62 | bottleneck = 
self.conv_expand(bottleneck) 63 | bottleneck = self.bn_expand(bottleneck) 64 | 65 | if self.downsample is not None: 66 | residual = self.downsample(x) 67 | 68 | return F.relu(residual + bottleneck, inplace=True) 69 | 70 | 71 | class CifarResNeXt(nn.Module): 72 | """ResNext optimized for the Cifar dataset, as specified in https://arxiv.org/pdf/1611.05431.pdf.""" 73 | 74 | def __init__(self, block, depth, cardinality, base_width, num_classes): 75 | super(CifarResNeXt, self).__init__() 76 | 77 | # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model 78 | assert (depth - 2) % 9 == 0, 'depth should be one of 29, 38, 47, 56, 101' 79 | layer_blocks = (depth - 2) // 9 80 | 81 | self.cardinality = cardinality 82 | self.base_width = base_width 83 | self.num_classes = num_classes 84 | 85 | self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False) 86 | self.bn_1 = nn.BatchNorm2d(64) 87 | 88 | self.inplanes = 64 89 | self.stage_1 = self._make_layer(block, 64, layer_blocks, 1) 90 | self.stage_2 = self._make_layer(block, 128, layer_blocks, 2) 91 | self.stage_3 = self._make_layer(block, 256, layer_blocks, 2) 92 | self.avgpool = nn.AvgPool2d(8) 93 | self.classifier = nn.Linear(256 * block.expansion, num_classes) 94 | 95 | for m in self.modules(): 96 | if isinstance(m, nn.Conv2d): 97 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 98 | m.weight.data.normal_(0, math.sqrt(2. / n)) 99 | elif isinstance(m, nn.BatchNorm2d): 100 | m.weight.data.fill_(1) 101 | m.bias.data.zero_() 102 | elif isinstance(m, nn.Linear): 103 | init.kaiming_normal_(m.weight) 104 | m.bias.data.zero_() 105 | 106 | def _make_layer(self, block, planes, blocks, stride=1): 107 | downsample = None 108 | if stride != 1 or self.inplanes != planes * block.expansion: 109 | downsample = nn.Sequential( 110 | nn.Conv2d( 111 | self.inplanes, 112 | planes * block.expansion, 113 | kernel_size=1, 114 | stride=stride, 115 | bias=False), 116 | nn.BatchNorm2d(planes * block.expansion), 117 | ) 118 | 119 | layers = [] 120 | layers.append( 121 | block(self.inplanes, planes, self.cardinality, self.base_width, stride, 122 | downsample)) 123 | self.inplanes = planes * block.expansion 124 | for _ in range(1, blocks): 125 | layers.append( 126 | block(self.inplanes, planes, self.cardinality, self.base_width)) 127 | 128 | return nn.Sequential(*layers) 129 | 130 | def forward(self, x): 131 | x = self.conv_1_3x3(x) 132 | x = F.relu(self.bn_1(x), inplace=True) 133 | x = self.stage_1(x) 134 | x = self.stage_2(x) 135 | x = self.stage_3(x) 136 | x = self.avgpool(x) 137 | x = x.view(x.size(0), -1) 138 | return self.classifier(x) 139 | 140 | 141 | def resnext29(num_classes=10, cardinality=4, base_width=32): 142 | model = CifarResNeXt(ResNeXtBottleneck, 29, cardinality, base_width, 143 | num_classes) 144 | return model 145 | -------------------------------------------------------------------------------- /third_party/WideResNet_pytorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 xternalz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above
copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /third_party/WideResNet_pytorch/METADATA: -------------------------------------------------------------------------------- 1 | name: "WideResNet-pytorch" 2 | description: "PyTorch implementation of WideResNet." 3 | 4 | third_party { 5 | url { 6 | type: GIT 7 | value: "https://github.com/xternalz/WideResNet-pytorch" 8 | } 9 | version: "1171f93d5a9ae28eb5e603e5e7545f488d0df6ab" 10 | last_upgrade_date { year: 2019 month: 12 day: 4 } 11 | license_type: PERMISSIVE 12 | } 13 | -------------------------------------------------------------------------------- /third_party/WideResNet_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/augmix/9b9824c7c19bf7e72df2d085d97b99b3bfb00ba4/third_party/WideResNet_pytorch/__init__.py -------------------------------------------------------------------------------- /third_party/WideResNet_pytorch/wideresnet.py: -------------------------------------------------------------------------------- 1 | """WideResNet implementation (https://arxiv.org/abs/1605.07146).""" 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class BasicBlock(nn.Module): 9 | """Basic ResNet block.""" 10 | 11 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0): 12 | super(BasicBlock, self).__init__() 13 | self.bn1 = nn.BatchNorm2d(in_planes) 14 | self.relu1 = nn.ReLU(inplace=True) 15 | self.conv1 = nn.Conv2d( 16 | in_planes, 17 | out_planes, 18 | kernel_size=3, 19 | stride=stride, 20 | padding=1, 21 | bias=False) 22 | self.bn2 = nn.BatchNorm2d(out_planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | self.conv2 = nn.Conv2d( 25 | out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False) 26 | self.drop_rate = drop_rate 27 | self.is_in_equal_out = (in_planes == out_planes) 28 | self.conv_shortcut = None if self.is_in_equal_out else nn.Conv2d( 29 | in_planes, 30 | out_planes, 31 | kernel_size=1, 32 | stride=stride, 33 | padding=0, 34 | bias=False) 35 | 36 | def forward(self, x): 37 | if not self.is_in_equal_out: 38 | # Pre-activation is shared by the residual branch and the conv shortcut. 39 | x = self.relu1(self.bn1(x)) 40 | out = self.relu2(self.bn2(self.conv1(x))) 41 | else: 42 | out = self.relu1(self.bn1(x)) 43 | out = self.relu2(self.bn2(self.conv1(out))) 44 | 45 | if self.drop_rate > 0: 46 | out = F.dropout(out, p=self.drop_rate, training=self.training) 47 | out = self.conv2(out) 48 | if not self.is_in_equal_out: 49 | return torch.add(self.conv_shortcut(x), out) 50 | else: 51 | return torch.add(x, out) 52 | 53 | 54 | class NetworkBlock(nn.Module): 55 | """Layer container for blocks.""" 56 | 57 | def __init__(self, 58 | nb_layers, 59 | in_planes, 60 | out_planes, 61 | block, 62 | stride, 63 | drop_rate=0.0): 64 | super(NetworkBlock, self).__init__() 65
| self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, 66 | stride, drop_rate) 67 | 68 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, 69 | drop_rate): 70 | layers = [] 71 | for i in range(nb_layers): 72 | layers.append( 73 | block(in_planes if i == 0 else out_planes, out_planes, 74 | stride if i == 0 else 1, drop_rate)) 75 | return nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | return self.layer(x) 79 | 80 | 81 | class WideResNet(nn.Module): 82 | """WideResNet class.""" 83 | 84 | def __init__(self, depth, num_classes, widen_factor=1, drop_rate=0.0): 85 | super(WideResNet, self).__init__() 86 | n_channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 87 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4 (e.g. 16, 22, 28, 40)' 88 | n = (depth - 4) // 6 89 | block = BasicBlock 90 | # 1st conv before any network block 91 | self.conv1 = nn.Conv2d( 92 | 3, n_channels[0], kernel_size=3, stride=1, padding=1, bias=False) 93 | # 1st block 94 | self.block1 = NetworkBlock(n, n_channels[0], n_channels[1], block, 1, 95 | drop_rate) 96 | # 2nd block 97 | self.block2 = NetworkBlock(n, n_channels[1], n_channels[2], block, 2, 98 | drop_rate) 99 | # 3rd block 100 | self.block3 = NetworkBlock(n, n_channels[2], n_channels[3], block, 2, 101 | drop_rate) 102 | # global average pooling and classifier 103 | self.bn1 = nn.BatchNorm2d(n_channels[3]) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.fc = nn.Linear(n_channels[3], num_classes) 106 | self.n_channels = n_channels[3] 107 | 108 | for m in self.modules(): 109 | if isinstance(m, nn.Conv2d): 110 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 111 | m.weight.data.normal_(0, math.sqrt(2. / n)) 112 | elif isinstance(m, nn.BatchNorm2d): 113 | m.weight.data.fill_(1) 114 | m.bias.data.zero_() 115 | elif isinstance(m, nn.Linear): 116 | m.bias.data.zero_() 117 | 118 | def forward(self, x): 119 | out = self.conv1(x) 120 | out = self.block1(out) 121 | out = self.block2(out) 122 | out = self.block3(out) 123 | out = self.relu(self.bn1(out)) 124 | out = F.avg_pool2d(out, 8) 125 | out = out.view(-1, self.n_channels) 126 | return self.fc(out) 127 | -------------------------------------------------------------------------------- /third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/augmix/9b9824c7c19bf7e72df2d085d97b99b3bfb00ba4/third_party/__init__.py --------------------------------------------------------------------------------
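For reference, a minimal smoke test for the CIFAR models defined above — a sketch, not a repository file, assuming the repository root is on PYTHONPATH (the import paths below mirror the directory tree):

import torch

from models.cifar.allconv import AllConvNet
from third_party.ResNeXt_DenseNet.models.densenet import densenet
from third_party.ResNeXt_DenseNet.models.resnext import resnext29
from third_party.WideResNet_pytorch.wideresnet import WideResNet

nets = {
    'allconv': AllConvNet(num_classes=10),
    'densenet': densenet(growth_rate=12, depth=40, num_classes=10),
    'resnext': resnext29(num_classes=10),
    'wrn': WideResNet(depth=28, num_classes=10, widen_factor=10),  # WRN-28-10
}
x = torch.randn(2, 3, 32, 32)  # CIFAR-sized dummy batch
for name, net in nets.items():
  with torch.no_grad():
    logits = net(x)
  assert logits.shape == (2, 10), name

All four constructors take CIFAR-sized 32x32 inputs; each network ends in an 8x8 average pool, so the spatial dimensions collapse to 1x1 before the final linear classifier.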
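The cfg list consumed by make_layers in allconv.py is a small vocabulary: an integer means a padded 3x3 conv with that many output channels, 'Md' a max-pool followed by dropout, 'nopad' an unpadded 3x3 conv, 'NIN' a 1x1 conv, and 'A' the final 8x8 average pool. A sketch expanding the exact config AllConvNet uses (widths w1=96, w2=192, as above):

from models.cifar.allconv import make_layers

# Same config that AllConvNet.__init__ passes to make_layers.
features = make_layers(
    [96, 96, 96, 'Md', 192, 192, 192, 'Md', 'nopad', 'NIN', 'NIN', 'A'])
print(features)  # nn.Sequential of Conv2d/BatchNorm2d/GELU, pooling, dropout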
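The 'mCE (normalized by AlexNet)' figure printed by imagenet.py comes from compute_mce, which is defined earlier in that file (outside this excerpt). For orientation only, a sketch of the standard ImageNet-C normalization it follows — per-corruption error averaged over severities, divided by AlexNet's error on the same corruption, then averaged over corruptions; the AlexNet error rates below are illustrative placeholders, not the table baked into imagenet.py:

import numpy as np

# Placeholder AlexNet error rates (hypothetical values for illustration).
ALEXNET_ERR = {'gaussian_noise': 0.886, 'shot_noise': 0.894}

def mce_sketch(corruption_accs, alexnet_err=ALEXNET_ERR):
  """corruption_accs: corruption name -> list of accuracies, one per severity."""
  ces = []
  for c, accs in corruption_accs.items():
    err = 1. - np.mean(accs)            # model error, averaged over severities
    ces.append(err / alexnet_err[c])    # normalize by AlexNet's error
  return 100. * np.mean(ces)            # mean Corruption Error, in percent

print(mce_sketch({'gaussian_noise': [0.4, 0.3], 'shot_noise': [0.5, 0.2]}))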