├── LICENSE
├── README.md
├── assets
│   ├── augmix.gif
│   └── pseudocode.png
├── augment_and_mix.py
├── augmentations.py
├── cifar.py
├── imagenet.py
├── models
│   └── cifar
│       └── allconv.py
├── requirements.txt
└── third_party
    ├── ResNeXt_DenseNet
    │   ├── LICENSE
    │   ├── METADATA
    │   ├── __init__.py
    │   └── models
    │       ├── __init__.py
    │       ├── densenet.py
    │       └── resnext.py
    ├── WideResNet_pytorch
    │   ├── LICENSE
    │   ├── METADATA
    │   ├── __init__.py
    │   └── wideresnet.py
    └── __init__.py
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AugMix
2 |
3 | ![AugMix](assets/augmix.gif)
4 |
5 | ## Introduction
6 |
7 | We propose AugMix, a data processing technique that mixes augmented images and
8 | enforces consistent embeddings of the augmented images, which results in
9 | increased robustness and improved uncertainty calibration. As with random
10 | cropping or CutOut, AugMix does not require tuning to work correctly, and thus
11 | enables plug-and-play data augmentation. AugMix significantly improves
12 | robustness and uncertainty measures on challenging image classification
13 | benchmarks, closing the gap between previous methods and the best possible
14 | performance by more than half in some cases. With AugMix, we obtain
15 | state-of-the-art results on ImageNet-C and ImageNet-P, as well as in
16 | uncertainty estimation when the train and test distributions do not match.
17 |
18 | For more details please see our [ICLR 2020 paper](https://arxiv.org/pdf/1912.02781.pdf).
19 |
20 | ## Pseudocode
21 |
22 | ![AugMix pseudocode](assets/pseudocode.png)
23 |
24 | ## Contents
25 |
26 | This directory includes a reference implementation in NumPy of the augmentation
27 | method used in AugMix in `augment_and_mix.py`. The full AugMix method also adds
28 | a Jensen-Shannon Divergence consistency loss to enforce consistent predictions
29 | between two different augmentations of the input image and the clean image
30 | itself.
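
The reference implementation can be exercised on a single image as follows (a
minimal sketch; the random array is a stand-in for a real 32x32 RGB image with
values in [0, 1]):

```
import numpy as np
from augment_and_mix import augment_and_mix

# Stand-in for a float32 image of shape (height, width, channels) in [0, 1].
image = np.random.rand(32, 32, 3).astype(np.float32)

# Returns the augmented-and-mixed image, normalized with CIFAR-10 statistics.
mixed = augment_and_mix(image, severity=3, width=3, depth=-1, alpha=1.)
```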
31 |
32 | We also include PyTorch re-implementations of AugMix on both CIFAR-10/100 and
33 | ImageNet in `cifar.py` and `imagenet.py` respectively, which both support
34 | training and evaluation on CIFAR-10/100-C and ImageNet-C.
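
In `cifar.py` and `imagenet.py`, the Jensen-Shannon consistency term described
above reduces to the following (a condensed sketch of the training loop;
`logits_clean`, `logits_aug1`, and `logits_aug2` come from a single forward
pass over the clean batch and its two augmented views):

```
import torch
import torch.nn.functional as F

def jsd_consistency_loss(logits_clean, logits_aug1, logits_aug2):
  # Class distributions for the clean batch and its two augmented views.
  p_clean = F.softmax(logits_clean, dim=1)
  p_aug1 = F.softmax(logits_aug1, dim=1)
  p_aug2 = F.softmax(logits_aug2, dim=1)
  # Clamp the mixture distribution to avoid exploding KL divergence.
  p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
  # Jensen-Shannon divergence between the three distributions, weighted by 12.
  return 12 * (F.kl_div(p_mixture, p_clean, reduction='batchmean') +
               F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
               F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.
```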
35 |
36 | ## Requirements
37 |
38 | * numpy>=1.15.0
39 | * Pillow>=6.1.0
40 | * torch==1.2.0
41 | * torchvision==0.2.2
42 |
43 | ## Setup
44 |
45 | 1. Install PyTorch and other required Python libraries with:
46 |
47 | ```
48 | pip install -r requirements.txt
49 | ```
50 |
51 | 2. Download CIFAR-10-C and CIFAR-100-C datasets with:
52 |
53 | ```
54 | mkdir -p ./data/cifar
55 | curl -O https://zenodo.org/record/2535967/files/CIFAR-10-C.tar
56 | curl -O https://zenodo.org/record/3555552/files/CIFAR-100-C.tar
57 | tar -xvf CIFAR-100-C.tar -C data/cifar/
58 | tar -xvf CIFAR-10-C.tar -C data/cifar/
59 | ```
60 |
61 | 3. Download ImageNet-C with:
62 |
63 | ```
64 | mkdir -p ./data/imagenet/imagenet-c
65 | curl -O https://zenodo.org/record/2235448/files/blur.tar
66 | curl -O https://zenodo.org/record/2235448/files/digital.tar
67 | curl -O https://zenodo.org/record/2235448/files/noise.tar
68 | curl -O https://zenodo.org/record/2235448/files/weather.tar
69 | tar -xvf blur.tar -C data/imagenet/imagenet-c
70 | tar -xvf digital.tar -C data/imagenet/imagenet-c
71 | tar -xvf noise.tar -C data/imagenet/imagenet-c
72 | tar -xvf weather.tar -C data/imagenet/imagenet-c
73 | ```
74 |
75 | ## Usage
76 |
77 | The Jensen-Shannon Divergence loss term may be disabled for faster training at the cost of slightly lower performance by adding the flag `--no-jsd`.
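
For example, to train the default WideResNet on CIFAR-10 without the consistency loss:

```
python cifar.py --no-jsd
```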
78 |
79 | Training recipes used in our paper:
80 |
81 | WRN: `python cifar.py`
82 |
83 | AllConv: `python cifar.py -m allconv`
84 |
85 | ResNeXt: `python cifar.py -m resnext -e 200`
86 |
87 | DenseNet: `python cifar.py -m densenet -e 200 -wd 0.0001`
88 |
89 | ResNet-50: `python imagenet.py <path/to/imagenet> <path/to/imagenet-c>`
90 |
91 | ## Pretrained weights
92 |
93 | Weights for a ResNet-50 ImageNet classifier trained with AugMix for 180 epochs are available
94 | [here](https://drive.google.com/file/d/1z-1V3rdFiwqSECz7Wkmn4VJVefJGJGiF/view?usp=sharing).
95 |
96 | This model has a 65.3 mean Corruption Error (mCE) and a 77.53% top-1 accuracy on clean ImageNet data.
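
To evaluate the downloaded checkpoint with `imagenet.py`, a command of the
following shape should work, assuming the checkpoint matches the format saved
by this script (the paths are placeholders for local copies of ImageNet and
ImageNet-C):

```
python imagenet.py <path/to/imagenet> <path/to/imagenet-c> \
  --resume <downloaded_checkpoint.pth.tar> --evaluate
```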
97 |
98 | ## Citation
99 |
100 | If you find this useful for your work, please consider citing:
101 |
102 | ```
103 | @article{hendrycks2020augmix,
104 | title={{AugMix}: A Simple Data Processing Method to Improve Robustness and Uncertainty},
105 | author={Hendrycks, Dan and Mu, Norman and Cubuk, Ekin D. and Zoph, Barret and Gilmer, Justin and Lakshminarayanan, Balaji},
106 | journal={Proceedings of the International Conference on Learning Representations (ICLR)},
107 | year={2020}
108 | }
109 | ```
110 |
--------------------------------------------------------------------------------
/assets/augmix.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/augmix/9b9824c7c19bf7e72df2d085d97b99b3bfb00ba4/assets/augmix.gif
--------------------------------------------------------------------------------
/assets/pseudocode.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/augmix/9b9824c7c19bf7e72df2d085d97b99b3bfb00ba4/assets/pseudocode.png
--------------------------------------------------------------------------------
/augment_and_mix.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Reference implementation of AugMix's data augmentation method in numpy."""
16 | import augmentations
17 | import numpy as np
18 | from PIL import Image
19 |
20 | # CIFAR-10 constants
21 | MEAN = [0.4914, 0.4822, 0.4465]
22 | STD = [0.2023, 0.1994, 0.2010]
23 |
24 |
25 | def normalize(image):
26 | """Normalize input image channel-wise to zero mean and unit variance."""
27 | image = image.transpose(2, 0, 1) # Switch to channel-first
28 | mean, std = np.array(MEAN), np.array(STD)
29 | image = (image - mean[:, None, None]) / std[:, None, None]
30 | return image.transpose(1, 2, 0)
31 |
32 |
33 | def apply_op(image, op, severity):
34 | image = np.clip(image * 255., 0, 255).astype(np.uint8)
35 | pil_img = Image.fromarray(image) # Convert to PIL.Image
36 | pil_img = op(pil_img, severity)
37 | return np.asarray(pil_img) / 255.
38 |
39 |
40 | def augment_and_mix(image, severity=3, width=3, depth=-1, alpha=1.):
41 | """Perform AugMix augmentations and compute mixture.
42 |
43 | Args:
44 | image: Raw input image as float32 np.ndarray of shape (h, w, c)
45 |     severity: Severity of underlying augmentation operators (between 1 and 10).
46 | width: Width of augmentation chain
47 | depth: Depth of augmentation chain. -1 enables stochastic depth uniformly
48 | from [1, 3]
49 | alpha: Probability coefficient for Beta and Dirichlet distributions.
50 |
51 | Returns:
52 | mixed: Augmented and mixed image.
53 | """
54 | ws = np.float32(
55 | np.random.dirichlet([alpha] * width))
56 | m = np.float32(np.random.beta(alpha, alpha))
57 |
58 | mix = np.zeros_like(image)
59 | for i in range(width):
60 | image_aug = image.copy()
61 | d = depth if depth > 0 else np.random.randint(1, 4)
62 | for _ in range(d):
63 | op = np.random.choice(augmentations.augmentations)
64 | image_aug = apply_op(image_aug, op, severity)
65 | # Preprocessing commutes since all coefficients are convex
66 | mix += ws[i] * normalize(image_aug)
67 |
68 | mixed = (1 - m) * normalize(image) + m * mix
69 | return mixed
70 |
71 |
--------------------------------------------------------------------------------
/augmentations.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Base augmentations operators."""
16 |
17 | import numpy as np
18 | from PIL import Image, ImageOps, ImageEnhance
19 |
20 | # ImageNet code should change this value
21 | IMAGE_SIZE = 32
22 |
23 |
24 | def int_parameter(level, maxval):
25 | """Helper function to scale `val` between 0 and maxval .
26 |
27 | Args:
28 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
29 | maxval: Maximum value that the operation can have. This will be scaled to
30 | level/PARAMETER_MAX.
31 |
32 | Returns:
33 | An int that results from scaling `maxval` according to `level`.
34 | """
35 | return int(level * maxval / 10)
36 |
37 |
38 | def float_parameter(level, maxval):
39 | """Helper function to scale `val` between 0 and maxval.
40 |
41 | Args:
42 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
43 | maxval: Maximum value that the operation can have. This will be scaled to
44 | level/PARAMETER_MAX.
45 |
46 | Returns:
47 | A float that results from scaling `maxval` according to `level`.
48 | """
49 | return float(level) * maxval / 10.
50 |
51 |
52 | def sample_level(n):
53 | return np.random.uniform(low=0.1, high=n)
54 |
55 |
56 | def autocontrast(pil_img, _):
57 | return ImageOps.autocontrast(pil_img)
58 |
59 |
60 | def equalize(pil_img, _):
61 | return ImageOps.equalize(pil_img)
62 |
63 |
64 | def posterize(pil_img, level):
65 | level = int_parameter(sample_level(level), 4)
66 | return ImageOps.posterize(pil_img, 4 - level)
67 |
68 |
69 | def rotate(pil_img, level):
70 | degrees = int_parameter(sample_level(level), 30)
71 | if np.random.uniform() > 0.5:
72 | degrees = -degrees
73 | return pil_img.rotate(degrees, resample=Image.BILINEAR)
74 |
75 |
76 | def solarize(pil_img, level):
77 | level = int_parameter(sample_level(level), 256)
78 | return ImageOps.solarize(pil_img, 256 - level)
79 |
80 |
81 | def shear_x(pil_img, level):
82 | level = float_parameter(sample_level(level), 0.3)
83 | if np.random.uniform() > 0.5:
84 | level = -level
85 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
86 | Image.AFFINE, (1, level, 0, 0, 1, 0),
87 | resample=Image.BILINEAR)
88 |
89 |
90 | def shear_y(pil_img, level):
91 | level = float_parameter(sample_level(level), 0.3)
92 | if np.random.uniform() > 0.5:
93 | level = -level
94 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
95 | Image.AFFINE, (1, 0, 0, level, 1, 0),
96 | resample=Image.BILINEAR)
97 |
98 |
99 | def translate_x(pil_img, level):
100 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3)
101 | if np.random.random() > 0.5:
102 | level = -level
103 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
104 | Image.AFFINE, (1, 0, level, 0, 1, 0),
105 | resample=Image.BILINEAR)
106 |
107 |
108 | def translate_y(pil_img, level):
109 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3)
110 | if np.random.random() > 0.5:
111 | level = -level
112 | return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE),
113 | Image.AFFINE, (1, 0, 0, 0, 1, level),
114 | resample=Image.BILINEAR)
115 |
116 |
117 | # operation that overlaps with ImageNet-C's test set
118 | def color(pil_img, level):
119 | level = float_parameter(sample_level(level), 1.8) + 0.1
120 | return ImageEnhance.Color(pil_img).enhance(level)
121 |
122 |
123 | # operation that overlaps with ImageNet-C's test set
124 | def contrast(pil_img, level):
125 | level = float_parameter(sample_level(level), 1.8) + 0.1
126 | return ImageEnhance.Contrast(pil_img).enhance(level)
127 |
128 |
129 | # operation that overlaps with ImageNet-C's test set
130 | def brightness(pil_img, level):
131 | level = float_parameter(sample_level(level), 1.8) + 0.1
132 | return ImageEnhance.Brightness(pil_img).enhance(level)
133 |
134 |
135 | # operation that overlaps with ImageNet-C's test set
136 | def sharpness(pil_img, level):
137 | level = float_parameter(sample_level(level), 1.8) + 0.1
138 | return ImageEnhance.Sharpness(pil_img).enhance(level)
139 |
140 |
141 | augmentations = [
142 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
143 | translate_x, translate_y
144 | ]
145 |
146 | augmentations_all = [
147 | autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y,
148 | translate_x, translate_y, color, contrast, brightness, sharpness
149 | ]
150 |
--------------------------------------------------------------------------------
/cifar.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Main script to launch AugMix training on CIFAR-10/100.
16 |
17 | Supports WideResNet, AllConv, ResNeXt models on CIFAR-10 and CIFAR-100 as well
18 | as evaluation on CIFAR-10-C and CIFAR-100-C.
19 |
20 | Example usage:
21 | `python cifar.py`
22 | """
23 | from __future__ import print_function
24 |
25 | import argparse
26 | import os
27 | import shutil
28 | import time
29 |
30 | import augmentations
31 | from models.cifar.allconv import AllConvNet
32 | import numpy as np
33 | from third_party.ResNeXt_DenseNet.models.densenet import densenet
34 | from third_party.ResNeXt_DenseNet.models.resnext import resnext29
35 | from third_party.WideResNet_pytorch.wideresnet import WideResNet
36 |
37 | import torch
38 | import torch.backends.cudnn as cudnn
39 | import torch.nn.functional as F
40 | from torchvision import datasets
41 | from torchvision import transforms
42 |
43 | parser = argparse.ArgumentParser(
44 | description='Trains a CIFAR Classifier',
45 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
46 | parser.add_argument(
47 | '--dataset',
48 | type=str,
49 | default='cifar10',
50 | choices=['cifar10', 'cifar100'],
51 | help='Choose between CIFAR-10, CIFAR-100.')
52 | parser.add_argument(
53 | '--model',
54 | '-m',
55 | type=str,
56 | default='wrn',
57 | choices=['wrn', 'allconv', 'densenet', 'resnext'],
58 | help='Choose architecture.')
59 | # Optimization options
60 | parser.add_argument(
61 | '--epochs', '-e', type=int, default=100, help='Number of epochs to train.')
62 | parser.add_argument(
63 | '--learning-rate',
64 | '-lr',
65 | type=float,
66 | default=0.1,
67 | help='Initial learning rate.')
68 | parser.add_argument(
69 | '--batch-size', '-b', type=int, default=128, help='Batch size.')
70 | parser.add_argument('--eval-batch-size', type=int, default=1000)
71 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.')
72 | parser.add_argument(
73 | '--decay',
74 | '-wd',
75 | type=float,
76 | default=0.0005,
77 | help='Weight decay (L2 penalty).')
78 | # WRN Architecture options
79 | parser.add_argument(
80 | '--layers', default=40, type=int, help='total number of layers')
81 | parser.add_argument('--widen-factor', default=2, type=int, help='Widen factor')
82 | parser.add_argument(
83 | '--droprate', default=0.0, type=float, help='Dropout probability')
84 | # AugMix options
85 | parser.add_argument(
86 | '--mixture-width',
87 | default=3,
88 | type=int,
89 | help='Number of augmentation chains to mix per augmented example')
90 | parser.add_argument(
91 | '--mixture-depth',
92 | default=-1,
93 | type=int,
94 | help='Depth of augmentation chains. -1 denotes stochastic depth in [1, 3]')
95 | parser.add_argument(
96 | '--aug-severity',
97 | default=3,
98 | type=int,
99 | help='Severity of base augmentation operators')
100 | parser.add_argument(
101 | '--no-jsd',
102 | '-nj',
103 | action='store_true',
104 | help='Turn off JSD consistency loss.')
105 | parser.add_argument(
106 | '--all-ops',
107 | '-all',
108 | action='store_true',
109 | help='Turn on all operations (+brightness,contrast,color,sharpness).')
110 | # Checkpointing options
111 | parser.add_argument(
112 | '--save',
113 | '-s',
114 | type=str,
115 | default='./snapshots',
116 | help='Folder to save checkpoints.')
117 | parser.add_argument(
118 | '--resume',
119 | '-r',
120 | type=str,
121 | default='',
122 | help='Checkpoint path for resume / test.')
123 | parser.add_argument('--evaluate', action='store_true', help='Eval only.')
124 | parser.add_argument(
125 | '--print-freq',
126 | type=int,
127 | default=50,
128 | help='Training loss print frequency (batches).')
129 | # Acceleration
130 | parser.add_argument(
131 | '--num-workers',
132 | type=int,
133 | default=4,
134 | help='Number of pre-fetching threads.')
135 |
136 | args = parser.parse_args()
137 |
138 | CORRUPTIONS = [
139 | 'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
140 | 'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
141 | 'brightness', 'contrast', 'elastic_transform', 'pixelate',
142 | 'jpeg_compression'
143 | ]
144 |
145 |
146 | def get_lr(step, total_steps, lr_max, lr_min):
147 | """Compute learning rate according to cosine annealing schedule."""
148 | return lr_min + (lr_max - lr_min) * 0.5 * (1 +
149 | np.cos(step / total_steps * np.pi))
150 |
151 |
152 | def aug(image, preprocess):
153 | """Perform AugMix augmentations and compute mixture.
154 |
155 | Args:
156 | image: PIL.Image input image
157 | preprocess: Preprocessing function which should return a torch tensor.
158 |
159 | Returns:
160 | mixed: Augmented and mixed image.
161 | """
162 | aug_list = augmentations.augmentations
163 | if args.all_ops:
164 | aug_list = augmentations.augmentations_all
165 |
166 | ws = np.float32(np.random.dirichlet([1] * args.mixture_width))
167 | m = np.float32(np.random.beta(1, 1))
168 |
169 | mix = torch.zeros_like(preprocess(image))
170 | for i in range(args.mixture_width):
171 | image_aug = image.copy()
172 | depth = args.mixture_depth if args.mixture_depth > 0 else np.random.randint(
173 | 1, 4)
174 | for _ in range(depth):
175 | op = np.random.choice(aug_list)
176 | image_aug = op(image_aug, args.aug_severity)
177 | # Preprocessing commutes since all coefficients are convex
178 | mix += ws[i] * preprocess(image_aug)
179 |
180 | mixed = (1 - m) * preprocess(image) + m * mix
181 | return mixed
182 |
183 |
184 | class AugMixDataset(torch.utils.data.Dataset):
185 | """Dataset wrapper to perform AugMix augmentation."""
186 |
187 | def __init__(self, dataset, preprocess, no_jsd=False):
188 | self.dataset = dataset
189 | self.preprocess = preprocess
190 | self.no_jsd = no_jsd
191 |
192 | def __getitem__(self, i):
193 | x, y = self.dataset[i]
194 | if self.no_jsd:
195 | return aug(x, self.preprocess), y
196 | else:
197 | im_tuple = (self.preprocess(x), aug(x, self.preprocess),
198 | aug(x, self.preprocess))
199 | return im_tuple, y
200 |
201 | def __len__(self):
202 | return len(self.dataset)
203 |
204 |
205 | def train(net, train_loader, optimizer, scheduler):
206 | """Train for one epoch."""
207 | net.train()
208 | loss_ema = 0.
209 | for i, (images, targets) in enumerate(train_loader):
210 | optimizer.zero_grad()
211 |
212 | if args.no_jsd:
213 | images = images.cuda()
214 | targets = targets.cuda()
215 | logits = net(images)
216 | loss = F.cross_entropy(logits, targets)
217 | else:
218 | images_all = torch.cat(images, 0).cuda()
219 | targets = targets.cuda()
220 | logits_all = net(images_all)
221 | logits_clean, logits_aug1, logits_aug2 = torch.split(
222 | logits_all, images[0].size(0))
223 |
224 | # Cross-entropy is only computed on clean images
225 | loss = F.cross_entropy(logits_clean, targets)
226 |
227 | p_clean, p_aug1, p_aug2 = F.softmax(
228 | logits_clean, dim=1), F.softmax(
229 | logits_aug1, dim=1), F.softmax(
230 | logits_aug2, dim=1)
231 |
232 | # Clamp mixture distribution to avoid exploding KL divergence
233 | p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
234 | loss += 12 * (F.kl_div(p_mixture, p_clean, reduction='batchmean') +
235 | F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
236 | F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.
237 |
238 | loss.backward()
239 | optimizer.step()
240 | scheduler.step()
241 | loss_ema = loss_ema * 0.9 + float(loss) * 0.1
242 | if i % args.print_freq == 0:
243 | print('Train Loss {:.3f}'.format(loss_ema))
244 |
245 | return loss_ema
246 |
247 |
248 | def test(net, test_loader):
249 | """Evaluate network on given dataset."""
250 | net.eval()
251 | total_loss = 0.
252 | total_correct = 0
253 | with torch.no_grad():
254 | for images, targets in test_loader:
255 | images, targets = images.cuda(), targets.cuda()
256 | logits = net(images)
257 | loss = F.cross_entropy(logits, targets)
258 | pred = logits.data.max(1)[1]
259 | total_loss += float(loss.data)
260 | total_correct += pred.eq(targets.data).sum().item()
261 |
262 | return total_loss / len(test_loader.dataset), total_correct / len(
263 | test_loader.dataset)
264 |
265 |
266 | def test_c(net, test_data, base_path):
267 | """Evaluate network on given corrupted dataset."""
268 | corruption_accs = []
269 | for corruption in CORRUPTIONS:
270 | # Reference to original data is mutated
271 | test_data.data = np.load(base_path + corruption + '.npy')
272 | test_data.targets = torch.LongTensor(np.load(base_path + 'labels.npy'))
273 |
274 | test_loader = torch.utils.data.DataLoader(
275 | test_data,
276 | batch_size=args.eval_batch_size,
277 | shuffle=False,
278 | num_workers=args.num_workers,
279 | pin_memory=True)
280 |
281 | test_loss, test_acc = test(net, test_loader)
282 | corruption_accs.append(test_acc)
283 | print('{}\n\tTest Loss {:.3f} | Test Error {:.3f}'.format(
284 | corruption, test_loss, 100 - 100. * test_acc))
285 |
286 | return np.mean(corruption_accs)
287 |
288 |
289 | def main():
290 | torch.manual_seed(1)
291 | np.random.seed(1)
292 |
293 | # Load datasets
294 | train_transform = transforms.Compose(
295 | [transforms.RandomHorizontalFlip(),
296 | transforms.RandomCrop(32, padding=4)])
297 | preprocess = transforms.Compose(
298 | [transforms.ToTensor(),
299 | transforms.Normalize([0.5] * 3, [0.5] * 3)])
300 | test_transform = preprocess
301 |
302 | if args.dataset == 'cifar10':
303 | train_data = datasets.CIFAR10(
304 | './data/cifar', train=True, transform=train_transform, download=True)
305 | test_data = datasets.CIFAR10(
306 | './data/cifar', train=False, transform=test_transform, download=True)
307 | base_c_path = './data/cifar/CIFAR-10-C/'
308 | num_classes = 10
309 | else:
310 | train_data = datasets.CIFAR100(
311 | './data/cifar', train=True, transform=train_transform, download=True)
312 | test_data = datasets.CIFAR100(
313 | './data/cifar', train=False, transform=test_transform, download=True)
314 | base_c_path = './data/cifar/CIFAR-100-C/'
315 | num_classes = 100
316 |
317 | train_data = AugMixDataset(train_data, preprocess, args.no_jsd)
318 | train_loader = torch.utils.data.DataLoader(
319 | train_data,
320 | batch_size=args.batch_size,
321 | shuffle=True,
322 | num_workers=args.num_workers,
323 | pin_memory=True)
324 |
325 | test_loader = torch.utils.data.DataLoader(
326 | test_data,
327 | batch_size=args.eval_batch_size,
328 | shuffle=False,
329 | num_workers=args.num_workers,
330 | pin_memory=True)
331 |
332 | # Create model
333 | if args.model == 'densenet':
334 | net = densenet(num_classes=num_classes)
335 | elif args.model == 'wrn':
336 | net = WideResNet(args.layers, num_classes, args.widen_factor, args.droprate)
337 | elif args.model == 'allconv':
338 | net = AllConvNet(num_classes)
339 | elif args.model == 'resnext':
340 | net = resnext29(num_classes=num_classes)
341 |
342 | optimizer = torch.optim.SGD(
343 | net.parameters(),
344 | args.learning_rate,
345 | momentum=args.momentum,
346 | weight_decay=args.decay,
347 | nesterov=True)
348 |
349 | # Distribute model across all visible GPUs
350 | net = torch.nn.DataParallel(net).cuda()
351 | cudnn.benchmark = True
352 |
353 |   start_epoch = 0
354 |   best_acc = 0
355 | if args.resume:
356 | if os.path.isfile(args.resume):
357 | checkpoint = torch.load(args.resume)
358 | start_epoch = checkpoint['epoch'] + 1
359 | best_acc = checkpoint['best_acc']
360 | net.load_state_dict(checkpoint['state_dict'])
361 | optimizer.load_state_dict(checkpoint['optimizer'])
362 | print('Model restored from epoch:', start_epoch)
363 |
364 | if args.evaluate:
365 | # Evaluate clean accuracy first because test_c mutates underlying data
366 | test_loss, test_acc = test(net, test_loader)
367 | print('Clean\n\tTest Loss {:.3f} | Test Error {:.2f}'.format(
368 | test_loss, 100 - 100. * test_acc))
369 |
370 | test_c_acc = test_c(net, test_data, base_c_path)
371 | print('Mean Corruption Error: {:.3f}'.format(100 - 100. * test_c_acc))
372 | return
373 |
374 | scheduler = torch.optim.lr_scheduler.LambdaLR(
375 | optimizer,
376 | lr_lambda=lambda step: get_lr( # pylint: disable=g-long-lambda
377 | step,
378 | args.epochs * len(train_loader),
379 | 1, # lr_lambda computes multiplicative factor
380 | 1e-6 / args.learning_rate))
381 |
382 | if not os.path.exists(args.save):
383 | os.makedirs(args.save)
384 | if not os.path.isdir(args.save):
385 | raise Exception('%s is not a dir' % args.save)
386 |
387 | log_path = os.path.join(args.save,
388 | args.dataset + '_' + args.model + '_training_log.csv')
389 | with open(log_path, 'w') as f:
390 | f.write('epoch,time(s),train_loss,test_loss,test_error(%)\n')
391 |
392 |   # best_acc is initialized above and restored from the checkpoint when resuming.
393 |   print('Beginning training from epoch:', start_epoch + 1)
394 | for epoch in range(start_epoch, args.epochs):
395 | begin_time = time.time()
396 |
397 | train_loss_ema = train(net, train_loader, optimizer, scheduler)
398 | test_loss, test_acc = test(net, test_loader)
399 |
400 | is_best = test_acc > best_acc
401 | best_acc = max(test_acc, best_acc)
402 | checkpoint = {
403 | 'epoch': epoch,
404 | 'dataset': args.dataset,
405 | 'model': args.model,
406 | 'state_dict': net.state_dict(),
407 | 'best_acc': best_acc,
408 | 'optimizer': optimizer.state_dict(),
409 | }
410 |
411 | save_path = os.path.join(args.save, 'checkpoint.pth.tar')
412 | torch.save(checkpoint, save_path)
413 | if is_best:
414 | shutil.copyfile(save_path, os.path.join(args.save, 'model_best.pth.tar'))
415 |
416 | with open(log_path, 'a') as f:
417 | f.write('%03d,%05d,%0.6f,%0.5f,%0.2f\n' % (
418 | (epoch + 1),
419 | time.time() - begin_time,
420 | train_loss_ema,
421 | test_loss,
422 | 100 - 100. * test_acc,
423 | ))
424 |
425 | print(
426 | 'Epoch {0:3d} | Time {1:5d} | Train Loss {2:.4f} | Test Loss {3:.3f} |'
427 | ' Test Error {4:.2f}'
428 | .format((epoch + 1), int(time.time() - begin_time), train_loss_ema,
429 | test_loss, 100 - 100. * test_acc))
430 |
431 | test_c_acc = test_c(net, test_data, base_c_path)
432 | print('Mean Corruption Error: {:.3f}'.format(100 - 100. * test_c_acc))
433 |
434 | with open(log_path, 'a') as f:
435 | f.write('%03d,%05d,%0.6f,%0.5f,%0.2f\n' %
436 | (args.epochs + 1, 0, 0, 0, 100 - 100 * test_c_acc))
437 |
438 |
439 | if __name__ == '__main__':
440 | main()
441 |
--------------------------------------------------------------------------------
/imagenet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Main script to launch AugMix training on ImageNet.
16 |
17 | Currently only supports ResNet-50 training.
18 |
19 | Example usage:
20 |   `python imagenet.py <path/to/imagenet> <path/to/imagenet-c>`
21 | """
22 | from __future__ import print_function
23 |
24 | import argparse
25 | import os
26 | import shutil
27 | import time
28 |
29 | import augmentations
30 |
31 | import numpy as np
32 | import torch
33 | import torch.backends.cudnn as cudnn
34 | import torch.nn.functional as F
35 | from torchvision import datasets
36 | from torchvision import models
37 | from torchvision import transforms
38 |
39 | augmentations.IMAGE_SIZE = 224
40 |
41 | model_names = sorted(name for name in models.__dict__
42 | if name.islower() and not name.startswith('__') and
43 | callable(models.__dict__[name]))
44 |
45 | parser = argparse.ArgumentParser(description='Trains an ImageNet Classifier')
46 | parser.add_argument(
47 | 'clean_data', metavar='DIR', help='path to clean ImageNet dataset')
48 | parser.add_argument(
49 | 'corrupted_data', metavar='DIR_C', help='path to ImageNet-C dataset')
50 | parser.add_argument(
51 | '--model',
52 | '-m',
53 | default='resnet50',
54 | choices=model_names,
55 | help='model architecture: ' + ' | '.join(model_names) +
56 | ' (default: resnet50)')
57 | # Optimization options
58 | parser.add_argument(
59 | '--epochs', '-e', type=int, default=90, help='Number of epochs to train.')
60 | parser.add_argument(
61 | '--learning-rate',
62 | '-lr',
63 | type=float,
64 | default=0.1,
65 | help='Initial learning rate.')
66 | parser.add_argument(
67 | '--batch-size', '-b', type=int, default=256, help='Batch size.')
68 | parser.add_argument('--eval-batch-size', type=int, default=1000)
69 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.')
70 | parser.add_argument(
71 | '--decay',
72 | '-wd',
73 | type=float,
74 | default=0.0001,
75 | help='Weight decay (L2 penalty).')
76 | # AugMix options
77 | parser.add_argument(
78 | '--mixture-width',
79 | default=3,
80 | type=int,
81 | help='Number of augmentation chains to mix per augmented example')
82 | parser.add_argument(
83 | '--mixture-depth',
84 | default=-1,
85 | type=int,
86 | help='Depth of augmentation chains. -1 denotes stochastic depth in [1, 3]')
87 | parser.add_argument(
88 | '--aug-severity',
89 | default=1,
90 | type=int,
91 | help='Severity of base augmentation operators')
92 | parser.add_argument(
93 | '--aug-prob-coeff',
94 | default=1.,
95 | type=float,
96 | help='Probability distribution coefficients')
97 | parser.add_argument(
98 | '--no-jsd',
99 | '-nj',
100 | action='store_true',
101 | help='Turn off JSD consistency loss.')
102 | parser.add_argument(
103 | '--all-ops',
104 | '-all',
105 | action='store_true',
106 | help='Turn on all operations (+brightness,contrast,color,sharpness).')
107 | # Checkpointing options
108 | parser.add_argument(
109 | '--save',
110 | '-s',
111 | type=str,
112 | default='./snapshots',
113 | help='Folder to save checkpoints.')
114 | parser.add_argument(
115 | '--resume',
116 | '-r',
117 | type=str,
118 | default='',
119 | help='Checkpoint path for resume / test.')
120 | parser.add_argument('--evaluate', action='store_true', help='Eval only.')
121 | parser.add_argument(
122 | '--print-freq',
123 | type=int,
124 | default=10,
125 | help='Training loss print frequency (batches).')
126 | parser.add_argument(
127 | '--pretrained',
128 | dest='pretrained',
129 | action='store_true',
130 | help='use pre-trained model')
131 | # Acceleration
132 | parser.add_argument(
133 | '--num-workers',
134 | type=int,
135 | default=4,
136 | help='Number of pre-fetching threads.')
137 |
138 | args = parser.parse_args()
139 |
140 | CORRUPTIONS = [
141 | 'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur',
142 | 'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
143 | 'brightness', 'contrast', 'elastic_transform', 'pixelate',
144 | 'jpeg_compression'
145 | ]
146 |
147 | # Raw AlexNet errors taken from https://github.com/hendrycks/robustness
148 | ALEXNET_ERR = [
149 | 0.886428, 0.894468, 0.922640, 0.819880, 0.826268, 0.785948, 0.798360,
150 | 0.866816, 0.826572, 0.819324, 0.564592, 0.853204, 0.646056, 0.717840,
151 | 0.606500
152 | ]
153 |
154 |
155 | def adjust_learning_rate(optimizer, epoch):
156 | """Sets the learning rate to the initial LR (linearly scaled to batch size) decayed by 10 every n / 3 epochs."""
157 | b = args.batch_size / 256.
158 | k = args.epochs // 3
159 | if epoch < k:
160 | m = 1
161 | elif epoch < 2 * k:
162 | m = 0.1
163 | else:
164 | m = 0.01
165 | lr = args.learning_rate * m * b
166 | for param_group in optimizer.param_groups:
167 | param_group['lr'] = lr
168 |
169 |
170 | def accuracy(output, target, topk=(1,)):
171 | """Computes the accuracy over the k top predictions for the specified values of k."""
172 | with torch.no_grad():
173 | maxk = max(topk)
174 | batch_size = target.size(0)
175 |
176 | _, pred = output.topk(maxk, 1, True, True)
177 | pred = pred.t()
178 | correct = pred.eq(target.view(1, -1).expand_as(pred))
179 |
180 | res = []
181 | for k in topk:
182 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
183 | res.append(correct_k.mul_(100.0 / batch_size))
184 | return res
185 |
186 |
187 | def compute_mce(corruption_accs):
188 | """Compute mCE (mean Corruption Error) normalized by AlexNet performance."""
189 | mce = 0.
190 | for i in range(len(CORRUPTIONS)):
191 | avg_err = 1 - np.mean(corruption_accs[CORRUPTIONS[i]])
192 | ce = 100 * avg_err / ALEXNET_ERR[i]
193 | mce += ce / 15
194 | return mce
195 |
196 |
197 | def aug(image, preprocess):
198 | """Perform AugMix augmentations and compute mixture.
199 |
200 | Args:
201 | image: PIL.Image input image
202 | preprocess: Preprocessing function which should return a torch tensor.
203 |
204 | Returns:
205 | mixed: Augmented and mixed image.
206 | """
207 | aug_list = augmentations.augmentations
208 | if args.all_ops:
209 | aug_list = augmentations.augmentations_all
210 |
211 | ws = np.float32(
212 | np.random.dirichlet([args.aug_prob_coeff] * args.mixture_width))
213 | m = np.float32(np.random.beta(args.aug_prob_coeff, args.aug_prob_coeff))
214 |
215 | mix = torch.zeros_like(preprocess(image))
216 | for i in range(args.mixture_width):
217 | image_aug = image.copy()
218 | depth = args.mixture_depth if args.mixture_depth > 0 else np.random.randint(
219 | 1, 4)
220 | for _ in range(depth):
221 | op = np.random.choice(aug_list)
222 | image_aug = op(image_aug, args.aug_severity)
223 | # Preprocessing commutes since all coefficients are convex
224 | mix += ws[i] * preprocess(image_aug)
225 |
226 | mixed = (1 - m) * preprocess(image) + m * mix
227 | return mixed
228 |
229 |
230 | class AugMixDataset(torch.utils.data.Dataset):
231 | """Dataset wrapper to perform AugMix augmentation."""
232 |
233 | def __init__(self, dataset, preprocess, no_jsd=False):
234 | self.dataset = dataset
235 | self.preprocess = preprocess
236 | self.no_jsd = no_jsd
237 |
238 | def __getitem__(self, i):
239 | x, y = self.dataset[i]
240 | if self.no_jsd:
241 | return aug(x, self.preprocess), y
242 | else:
243 | im_tuple = (self.preprocess(x), aug(x, self.preprocess),
244 | aug(x, self.preprocess))
245 | return im_tuple, y
246 |
247 | def __len__(self):
248 | return len(self.dataset)
249 |
250 |
251 | def train(net, train_loader, optimizer):
252 | """Train for one epoch."""
253 | net.train()
254 | data_ema = 0.
255 | batch_ema = 0.
256 | loss_ema = 0.
257 | acc1_ema = 0.
258 | acc5_ema = 0.
259 |
260 | end = time.time()
261 | for i, (images, targets) in enumerate(train_loader):
262 | # Compute data loading time
263 | data_time = time.time() - end
264 | optimizer.zero_grad()
265 |
266 | if args.no_jsd:
267 | images = images.cuda()
268 | targets = targets.cuda()
269 | logits = net(images)
270 | loss = F.cross_entropy(logits, targets)
271 | acc1, acc5 = accuracy(logits, targets, topk=(1, 5)) # pylint: disable=unbalanced-tuple-unpacking
272 | else:
273 | images_all = torch.cat(images, 0).cuda()
274 | targets = targets.cuda()
275 | logits_all = net(images_all)
276 | logits_clean, logits_aug1, logits_aug2 = torch.split(
277 | logits_all, images[0].size(0))
278 |
279 | # Cross-entropy is only computed on clean images
280 | loss = F.cross_entropy(logits_clean, targets)
281 |
282 | p_clean, p_aug1, p_aug2 = F.softmax(
283 | logits_clean, dim=1), F.softmax(
284 | logits_aug1, dim=1), F.softmax(
285 | logits_aug2, dim=1)
286 |
287 | # Clamp mixture distribution to avoid exploding KL divergence
288 | p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
289 | loss += 12 * (F.kl_div(p_mixture, p_clean, reduction='batchmean') +
290 | F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
291 | F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.
292 | acc1, acc5 = accuracy(logits_clean, targets, topk=(1, 5)) # pylint: disable=unbalanced-tuple-unpacking
293 |
294 | loss.backward()
295 | optimizer.step()
296 |
297 | # Compute batch computation time and update moving averages.
298 | batch_time = time.time() - end
299 | end = time.time()
300 |
301 | data_ema = data_ema * 0.1 + float(data_time) * 0.9
302 | batch_ema = batch_ema * 0.1 + float(batch_time) * 0.9
303 | loss_ema = loss_ema * 0.1 + float(loss) * 0.9
304 | acc1_ema = acc1_ema * 0.1 + float(acc1) * 0.9
305 | acc5_ema = acc5_ema * 0.1 + float(acc5) * 0.9
306 |
307 | if i % args.print_freq == 0:
308 | print(
309 | 'Batch {}/{}: Data Time {:.3f} | Batch Time {:.3f} | Train Loss {:.3f} | Train Acc1 '
310 | '{:.3f} | Train Acc5 {:.3f}'.format(i, len(train_loader), data_ema,
311 | batch_ema, loss_ema, acc1_ema,
312 | acc5_ema))
313 |
314 | return loss_ema, acc1_ema, batch_ema
315 |
316 |
317 | def test(net, test_loader):
318 | """Evaluate network on given dataset."""
319 | net.eval()
320 | total_loss = 0.
321 | total_correct = 0
322 | with torch.no_grad():
323 | for images, targets in test_loader:
324 | images, targets = images.cuda(), targets.cuda()
325 | logits = net(images)
326 | loss = F.cross_entropy(logits, targets)
327 | pred = logits.data.max(1)[1]
328 | total_loss += float(loss.data)
329 | total_correct += pred.eq(targets.data).sum().item()
330 |
331 | return total_loss / len(test_loader.dataset), total_correct / len(
332 | test_loader.dataset)
333 |
334 |
335 | def test_c(net, test_transform):
336 | """Evaluate network on given corrupted dataset."""
337 | corruption_accs = {}
338 | for c in CORRUPTIONS:
339 | print(c)
340 | for s in range(1, 6):
341 | valdir = os.path.join(args.corrupted_data, c, str(s))
342 | val_loader = torch.utils.data.DataLoader(
343 | datasets.ImageFolder(valdir, test_transform),
344 | batch_size=args.eval_batch_size,
345 | shuffle=False,
346 | num_workers=args.num_workers,
347 | pin_memory=True)
348 |
349 | loss, acc1 = test(net, val_loader)
350 | if c in corruption_accs:
351 | corruption_accs[c].append(acc1)
352 | else:
353 | corruption_accs[c] = [acc1]
354 |
355 | print('\ts={}: Test Loss {:.3f} | Test Acc1 {:.3f}'.format(
356 | s, loss, 100. * acc1))
357 |
358 | return corruption_accs
359 |
360 |
361 | def main():
362 | torch.manual_seed(1)
363 | np.random.seed(1)
364 |
365 | # Load datasets
366 | mean = [0.485, 0.456, 0.406]
367 | std = [0.229, 0.224, 0.225]
368 | train_transform = transforms.Compose(
369 | [transforms.RandomResizedCrop(224),
370 | transforms.RandomHorizontalFlip()])
371 | preprocess = transforms.Compose(
372 | [transforms.ToTensor(),
373 | transforms.Normalize(mean, std)])
374 | test_transform = transforms.Compose([
375 | transforms.Resize(256),
376 | transforms.CenterCrop(224),
377 | preprocess,
378 | ])
379 |
380 | traindir = os.path.join(args.clean_data, 'train')
381 | valdir = os.path.join(args.clean_data, 'val')
382 | train_dataset = datasets.ImageFolder(traindir, train_transform)
383 | train_dataset = AugMixDataset(train_dataset, preprocess)
384 | train_loader = torch.utils.data.DataLoader(
385 | train_dataset,
386 | batch_size=args.batch_size,
387 | shuffle=True,
388 | num_workers=args.num_workers)
389 | val_loader = torch.utils.data.DataLoader(
390 | datasets.ImageFolder(valdir, test_transform),
391 | batch_size=args.batch_size,
392 | shuffle=False,
393 | num_workers=args.num_workers)
394 |
395 | if args.pretrained:
396 | print("=> using pre-trained model '{}'".format(args.model))
397 | net = models.__dict__[args.model](pretrained=True)
398 | else:
399 | print("=> creating model '{}'".format(args.model))
400 | net = models.__dict__[args.model]()
401 |
402 | optimizer = torch.optim.SGD(
403 | net.parameters(),
404 | args.learning_rate,
405 | momentum=args.momentum,
406 | weight_decay=args.decay)
407 |
408 | # Distribute model across all visible GPUs
409 | net = torch.nn.DataParallel(net).cuda()
410 | cudnn.benchmark = True
411 |
412 |   start_epoch = 0
413 |   best_acc1 = 0
414 | if args.resume:
415 | if os.path.isfile(args.resume):
416 | checkpoint = torch.load(args.resume)
417 | start_epoch = checkpoint['epoch'] + 1
418 | best_acc1 = checkpoint['best_acc1']
419 | net.load_state_dict(checkpoint['state_dict'])
420 | optimizer.load_state_dict(checkpoint['optimizer'])
421 | print('Model restored from epoch:', start_epoch)
422 |
423 | if args.evaluate:
424 | test_loss, test_acc1 = test(net, val_loader)
425 | print('Clean\n\tTest Loss {:.3f} | Test Acc1 {:.3f}'.format(
426 | test_loss, 100 * test_acc1))
427 |
428 | corruption_accs = test_c(net, test_transform)
429 | for c in CORRUPTIONS:
430 |       print('\t'.join(map(str, [c] + corruption_accs[c])))
431 |
432 | print('mCE (normalized by AlexNet): ', compute_mce(corruption_accs))
433 | return
434 |
435 | if not os.path.exists(args.save):
436 | os.makedirs(args.save)
437 | if not os.path.isdir(args.save):
438 | raise Exception('%s is not a dir' % args.save)
439 |
440 | log_path = os.path.join(args.save,
441 | 'imagenet_{}_training_log.csv'.format(args.model))
442 | with open(log_path, 'w') as f:
443 | f.write(
444 | 'epoch,batch_time,train_loss,train_acc1(%),test_loss,test_acc1(%)\n')
445 |
446 |   # best_acc1 is initialized above and restored from the checkpoint when resuming.
447 |   print('Beginning training from epoch:', start_epoch + 1)
448 | for epoch in range(start_epoch, args.epochs):
449 | adjust_learning_rate(optimizer, epoch)
450 |
451 | train_loss_ema, train_acc1_ema, batch_ema = train(net, train_loader,
452 | optimizer)
453 | test_loss, test_acc1 = test(net, val_loader)
454 |
455 | is_best = test_acc1 > best_acc1
456 | best_acc1 = max(test_acc1, best_acc1)
457 | checkpoint = {
458 | 'epoch': epoch,
459 | 'model': args.model,
460 | 'state_dict': net.state_dict(),
461 | 'best_acc1': best_acc1,
462 | 'optimizer': optimizer.state_dict(),
463 | }
464 |
465 | save_path = os.path.join(args.save, 'checkpoint.pth.tar')
466 | torch.save(checkpoint, save_path)
467 | if is_best:
468 | shutil.copyfile(save_path, os.path.join(args.save, 'model_best.pth.tar'))
469 |
470 | with open(log_path, 'a') as f:
471 | f.write('%03d,%0.3f,%0.6f,%0.2f,%0.5f,%0.2f\n' % (
472 | (epoch + 1),
473 | batch_ema,
474 | train_loss_ema,
475 | 100. * train_acc1_ema,
476 | test_loss,
477 | 100. * test_acc1,
478 | ))
479 |
480 | print(
481 | 'Epoch {:3d} | Train Loss {:.4f} | Test Loss {:.3f} | Test Acc1 '
482 | '{:.2f}'
483 | .format((epoch + 1), train_loss_ema, test_loss, 100. * test_acc1))
484 |
485 | corruption_accs = test_c(net, test_transform)
486 | for c in CORRUPTIONS:
487 | print('\t'.join(map(str, [c] + corruption_accs[c])))
488 |
489 | print('mCE (normalized by AlexNet):', compute_mce(corruption_accs))
490 |
491 |
492 | if __name__ == '__main__':
493 | main()
494 |
--------------------------------------------------------------------------------
/models/cifar/allconv.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """AllConv implementation (https://arxiv.org/abs/1412.6806)."""
16 | import math
17 | import torch
18 | import torch.nn as nn
19 |
20 |
21 | class GELU(nn.Module):
22 |   """Sigmoid approximation of GELU (https://arxiv.org/abs/1606.08415)."""
23 | def forward(self, x):
24 | return torch.sigmoid(1.702 * x) * x
25 |
26 |
27 | def make_layers(cfg):
28 |   """Build the convolutional feature layers described by cfg."""
29 | layers = []
30 | in_channels = 3
31 | for v in cfg:
32 | if v == 'Md':
33 | layers += [nn.MaxPool2d(kernel_size=2, stride=2), nn.Dropout(p=0.5)]
34 | elif v == 'A':
35 | layers += [nn.AvgPool2d(kernel_size=8)]
36 | elif v == 'NIN':
37 |       conv2d = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=1)  # padding=1 grows the map (6->8->10 for CIFAR) so the 8x8 avg pool fits
38 | layers += [conv2d, nn.BatchNorm2d(in_channels), GELU()]
39 | elif v == 'nopad':
40 | conv2d = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=0)
41 | layers += [conv2d, nn.BatchNorm2d(in_channels), GELU()]
42 | else:
43 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
44 | layers += [conv2d, nn.BatchNorm2d(v), GELU()]
45 | in_channels = v
46 | return nn.Sequential(*layers)
47 |
48 |
49 | class AllConvNet(nn.Module):
50 | """AllConvNet main class."""
51 |
52 | def __init__(self, num_classes):
53 | super(AllConvNet, self).__init__()
54 |
55 | self.num_classes = num_classes
56 | self.width1, w1 = 96, 96
57 | self.width2, w2 = 192, 192
58 |
59 | self.features = make_layers(
60 | [w1, w1, w1, 'Md', w2, w2, w2, 'Md', 'nopad', 'NIN', 'NIN', 'A'])
61 | self.classifier = nn.Linear(self.width2, num_classes)
62 |
63 | for m in self.modules():
64 | if isinstance(m, nn.Conv2d):
65 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
66 | m.weight.data.normal_(0, math.sqrt(2. / n)) # He initialization
67 | elif isinstance(m, nn.BatchNorm2d):
68 | m.weight.data.fill_(1)
69 | m.bias.data.zero_()
70 | elif isinstance(m, nn.Linear):
71 | m.bias.data.zero_()
72 |
73 | def forward(self, x):
74 | x = self.features(x)
75 | x = x.view(x.size(0), -1)
76 | x = self.classifier(x)
77 | return x
78 |
--------------------------------------------------------------------------------
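A quick shape check for the model above, as a sketch (it assumes the repo root is on PYTHONPATH so the file imports as models.cifar.allconv). The padded 'NIN' layers leave a 10x10 map in front of the 8x8 average pool, which then collapses to 1x1, so the classifier sees a flat 192-dimensional vector:

import torch
from models.cifar.allconv import AllConvNet

net = AllConvNet(num_classes=10)
x = torch.randn(2, 3, 32, 32)              # a CIFAR-sized batch of two images
assert net.features(x).shape == (2, 192, 1, 1)
assert net(x).shape == (2, 10)             # one logit per class

--------------------------------------------------------------------------------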
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.15.0
2 | Pillow>=6.1.0
3 | torch==1.2.0
4 | torchvision==0.2.2
5 |
--------------------------------------------------------------------------------
/third_party/ResNeXt_DenseNet/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Xuanyi Dong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/third_party/ResNeXt_DenseNet/METADATA:
--------------------------------------------------------------------------------
1 | name: "ResNeXt-DenseNet"
2 | description: "PyTorch implementations of ResNeXt and DenseNet."
3 |
4 | third_party {
5 | url {
6 | type: GIT
7 | value: "https://github.com/D-X-Y/ResNeXt-DenseNet"
8 | }
9 | version: "0de9a8c8fd095b37eb60945f8dafefdbfe1cef6b"
10 | last_upgrade_date { year: 2019 month: 12 day: 4 }
11 | license_type: PERMISSIVE
12 | }
13 |
--------------------------------------------------------------------------------
/third_party/ResNeXt_DenseNet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/augmix/9b9824c7c19bf7e72df2d085d97b99b3bfb00ba4/third_party/ResNeXt_DenseNet/__init__.py
--------------------------------------------------------------------------------
/third_party/ResNeXt_DenseNet/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/augmix/9b9824c7c19bf7e72df2d085d97b99b3bfb00ba4/third_party/ResNeXt_DenseNet/models/__init__.py
--------------------------------------------------------------------------------
/third_party/ResNeXt_DenseNet/models/densenet.py:
--------------------------------------------------------------------------------
1 | """DenseNet implementation (https://arxiv.org/abs/1608.06993)."""
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class Bottleneck(nn.Module):
9 | """Bottleneck block for DenseNet."""
10 |
11 | def __init__(self, n_channels, growth_rate):
12 | super(Bottleneck, self).__init__()
13 | inter_channels = 4 * growth_rate
14 | self.bn1 = nn.BatchNorm2d(n_channels)
15 | self.conv1 = nn.Conv2d(
16 | n_channels, inter_channels, kernel_size=1, bias=False)
17 | self.bn2 = nn.BatchNorm2d(inter_channels)
18 | self.conv2 = nn.Conv2d(
19 | inter_channels, growth_rate, kernel_size=3, padding=1, bias=False)
20 |
21 | def forward(self, x):
22 | out = self.conv1(F.relu(self.bn1(x)))
23 | out = self.conv2(F.relu(self.bn2(out)))
24 | out = torch.cat((x, out), 1)
25 | return out
26 |
27 |
28 | class SingleLayer(nn.Module):
29 |   """Single 3x3 conv layer (non-bottleneck variant)."""
30 |
31 | def __init__(self, n_channels, growth_rate):
32 | super(SingleLayer, self).__init__()
33 | self.bn1 = nn.BatchNorm2d(n_channels)
34 | self.conv1 = nn.Conv2d(
35 | n_channels, growth_rate, kernel_size=3, padding=1, bias=False)
36 |
37 | def forward(self, x):
38 | out = self.conv1(F.relu(self.bn1(x)))
39 | out = torch.cat((x, out), 1)
40 | return out
41 |
42 |
43 | class Transition(nn.Module):
44 | """Transition block."""
45 |
46 | def __init__(self, n_channels, n_out_channels):
47 | super(Transition, self).__init__()
48 | self.bn1 = nn.BatchNorm2d(n_channels)
49 | self.conv1 = nn.Conv2d(
50 | n_channels, n_out_channels, kernel_size=1, bias=False)
51 |
52 | def forward(self, x):
53 | out = self.conv1(F.relu(self.bn1(x)))
54 | out = F.avg_pool2d(out, 2)
55 | return out
56 |
57 |
58 | class DenseNet(nn.Module):
59 | """DenseNet main class."""
60 |
61 | def __init__(self, growth_rate, depth, reduction, n_classes, bottleneck):
62 | super(DenseNet, self).__init__()
63 |
64 | if bottleneck:
65 | n_dense_blocks = int((depth - 4) / 6)
66 | else:
67 | n_dense_blocks = int((depth - 4) / 3)
68 |
69 | n_channels = 2 * growth_rate
70 | self.conv1 = nn.Conv2d(3, n_channels, kernel_size=3, padding=1, bias=False)
71 |
72 | self.dense1 = self._make_dense(n_channels, growth_rate, n_dense_blocks,
73 | bottleneck)
74 | n_channels += n_dense_blocks * growth_rate
75 | n_out_channels = int(math.floor(n_channels * reduction))
76 | self.trans1 = Transition(n_channels, n_out_channels)
77 |
78 | n_channels = n_out_channels
79 | self.dense2 = self._make_dense(n_channels, growth_rate, n_dense_blocks,
80 | bottleneck)
81 | n_channels += n_dense_blocks * growth_rate
82 | n_out_channels = int(math.floor(n_channels * reduction))
83 | self.trans2 = Transition(n_channels, n_out_channels)
84 |
85 | n_channels = n_out_channels
86 | self.dense3 = self._make_dense(n_channels, growth_rate, n_dense_blocks,
87 | bottleneck)
88 | n_channels += n_dense_blocks * growth_rate
89 |
90 | self.bn1 = nn.BatchNorm2d(n_channels)
91 | self.fc = nn.Linear(n_channels, n_classes)
92 |
93 | for m in self.modules():
94 | if isinstance(m, nn.Conv2d):
95 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
96 | m.weight.data.normal_(0, math.sqrt(2. / n))
97 | elif isinstance(m, nn.BatchNorm2d):
98 | m.weight.data.fill_(1)
99 | m.bias.data.zero_()
100 | elif isinstance(m, nn.Linear):
101 | m.bias.data.zero_()
102 |
103 | def _make_dense(self, n_channels, growth_rate, n_dense_blocks, bottleneck):
104 | layers = []
105 | for _ in range(int(n_dense_blocks)):
106 | if bottleneck:
107 | layers.append(Bottleneck(n_channels, growth_rate))
108 | else:
109 | layers.append(SingleLayer(n_channels, growth_rate))
110 | n_channels += growth_rate
111 | return nn.Sequential(*layers)
112 |
113 | def forward(self, x):
114 | out = self.conv1(x)
115 | out = self.trans1(self.dense1(out))
116 | out = self.trans2(self.dense2(out))
117 | out = self.dense3(out)
118 |     out = F.avg_pool2d(F.relu(self.bn1(out)), 8).view(x.size(0), -1)  # squeeze() would also drop a batch dim of one
119 | out = self.fc(out)
120 | return out
121 |
122 |
123 | def densenet(growth_rate=12, depth=40, num_classes=10):
124 | model = DenseNet(growth_rate, depth, 1., num_classes, False)
125 | return model
126 |
--------------------------------------------------------------------------------
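With the factory defaults above (growth_rate=12, depth=40, reduction 1.0, no bottleneck), each of the three dense blocks stacks (40 - 4) / 3 = 12 layers, so the channel count grows from 2 * 12 = 24 to 24 + 3 * 12 * 12 = 456 at the final BatchNorm. A small smoke test of that arithmetic, assuming the repo root is on PYTHONPATH:

import torch
from third_party.ResNeXt_DenseNet.models.densenet import densenet

net = densenet()                           # growth_rate=12, depth=40, 10 classes
assert net.fc.in_features == 456           # 24 + 3 blocks * 12 layers * 12 growth
assert net(torch.randn(2, 3, 32, 32)).shape == (2, 10)

--------------------------------------------------------------------------------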
/third_party/ResNeXt_DenseNet/models/resnext.py:
--------------------------------------------------------------------------------
1 | """ResNeXt implementation (https://arxiv.org/abs/1611.05431)."""
2 | import math
3 | import torch.nn as nn
4 | from torch.nn import init
5 | import torch.nn.functional as F
6 |
7 |
8 | class ResNeXtBottleneck(nn.Module):
9 | """ResNeXt Bottleneck Block type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)."""
10 | expansion = 4
11 |
12 | def __init__(self,
13 | inplanes,
14 | planes,
15 | cardinality,
16 | base_width,
17 | stride=1,
18 | downsample=None):
19 | super(ResNeXtBottleneck, self).__init__()
20 |
21 | dim = int(math.floor(planes * (base_width / 64.0)))
22 |
23 | self.conv_reduce = nn.Conv2d(
24 | inplanes,
25 | dim * cardinality,
26 | kernel_size=1,
27 | stride=1,
28 | padding=0,
29 | bias=False)
30 | self.bn_reduce = nn.BatchNorm2d(dim * cardinality)
31 |
32 | self.conv_conv = nn.Conv2d(
33 | dim * cardinality,
34 | dim * cardinality,
35 | kernel_size=3,
36 | stride=stride,
37 | padding=1,
38 | groups=cardinality,
39 | bias=False)
40 | self.bn = nn.BatchNorm2d(dim * cardinality)
41 |
42 | self.conv_expand = nn.Conv2d(
43 | dim * cardinality,
44 | planes * 4,
45 | kernel_size=1,
46 | stride=1,
47 | padding=0,
48 | bias=False)
49 | self.bn_expand = nn.BatchNorm2d(planes * 4)
50 |
51 | self.downsample = downsample
52 |
53 | def forward(self, x):
54 | residual = x
55 |
56 | bottleneck = self.conv_reduce(x)
57 | bottleneck = F.relu(self.bn_reduce(bottleneck), inplace=True)
58 |
59 | bottleneck = self.conv_conv(bottleneck)
60 | bottleneck = F.relu(self.bn(bottleneck), inplace=True)
61 |
62 | bottleneck = self.conv_expand(bottleneck)
63 | bottleneck = self.bn_expand(bottleneck)
64 |
65 | if self.downsample is not None:
66 | residual = self.downsample(x)
67 |
68 | return F.relu(residual + bottleneck, inplace=True)
69 |
70 |
71 | class CifarResNeXt(nn.Module):
72 |   """ResNeXt optimized for CIFAR, as specified in https://arxiv.org/pdf/1611.05431.pdf."""
73 |
74 | def __init__(self, block, depth, cardinality, base_width, num_classes):
75 | super(CifarResNeXt, self).__init__()
76 |
77 | # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model
78 | assert (depth - 2) % 9 == 0, 'depth should be one of 29, 38, 47, 56, 101'
79 | layer_blocks = (depth - 2) // 9
80 |
81 | self.cardinality = cardinality
82 | self.base_width = base_width
83 | self.num_classes = num_classes
84 |
85 | self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
86 | self.bn_1 = nn.BatchNorm2d(64)
87 |
88 | self.inplanes = 64
89 | self.stage_1 = self._make_layer(block, 64, layer_blocks, 1)
90 | self.stage_2 = self._make_layer(block, 128, layer_blocks, 2)
91 | self.stage_3 = self._make_layer(block, 256, layer_blocks, 2)
92 | self.avgpool = nn.AvgPool2d(8)
93 | self.classifier = nn.Linear(256 * block.expansion, num_classes)
94 |
95 | for m in self.modules():
96 | if isinstance(m, nn.Conv2d):
97 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
98 | m.weight.data.normal_(0, math.sqrt(2. / n))
99 | elif isinstance(m, nn.BatchNorm2d):
100 | m.weight.data.fill_(1)
101 | m.bias.data.zero_()
102 | elif isinstance(m, nn.Linear):
103 |         init.kaiming_normal_(m.weight)  # in-place variant; kaiming_normal is deprecated
104 | m.bias.data.zero_()
105 |
106 | def _make_layer(self, block, planes, blocks, stride=1):
107 | downsample = None
108 | if stride != 1 or self.inplanes != planes * block.expansion:
109 | downsample = nn.Sequential(
110 | nn.Conv2d(
111 | self.inplanes,
112 | planes * block.expansion,
113 | kernel_size=1,
114 | stride=stride,
115 | bias=False),
116 | nn.BatchNorm2d(planes * block.expansion),
117 | )
118 |
119 | layers = []
120 | layers.append(
121 | block(self.inplanes, planes, self.cardinality, self.base_width, stride,
122 | downsample))
123 | self.inplanes = planes * block.expansion
124 | for _ in range(1, blocks):
125 | layers.append(
126 | block(self.inplanes, planes, self.cardinality, self.base_width))
127 |
128 | return nn.Sequential(*layers)
129 |
130 | def forward(self, x):
131 | x = self.conv_1_3x3(x)
132 | x = F.relu(self.bn_1(x), inplace=True)
133 | x = self.stage_1(x)
134 | x = self.stage_2(x)
135 | x = self.stage_3(x)
136 | x = self.avgpool(x)
137 | x = x.view(x.size(0), -1)
138 | return self.classifier(x)
139 |
140 |
141 | def resnext29(num_classes=10, cardinality=4, base_width=32):
142 | model = CifarResNeXt(ResNeXtBottleneck, 29, cardinality, base_width,
143 | num_classes)
144 | return model
145 |
--------------------------------------------------------------------------------
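Inside ResNeXtBottleneck the grouped 3x3 convolution runs at width dim * cardinality, where dim = floor(planes * base_width / 64), and each block outputs planes * expansion channels. For the resnext29 defaults (cardinality=4, base_width=32), the stage widths work out as in this sketch:

import math

cardinality, base_width, expansion = 4, 32, 4   # resnext29 defaults above
for stage, planes in enumerate((64, 128, 256), start=1):
    dim = int(math.floor(planes * (base_width / 64.0)))
    print('stage_%d: grouped width %d, output channels %d'
          % (stage, dim * cardinality, planes * expansion))
# stage_1: grouped width 128, output channels 256
# stage_2: grouped width 256, output channels 512
# stage_3: grouped width 512, output channels 1024

--------------------------------------------------------------------------------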
/third_party/WideResNet_pytorch/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 xternalz
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/third_party/WideResNet_pytorch/METADATA:
--------------------------------------------------------------------------------
1 | name: "WideResNet-pytorch"
2 | description: "PyTorch implementation of WideResNet."
3 |
4 | third_party {
5 | url {
6 | type: GIT
7 | value: "https://github.com/xternalz/WideResNet-pytorch"
8 | }
9 | version: "1171f93d5a9ae28eb5e603e5e7545f488d0df6ab"
10 | last_upgrade_date { year: 2019 month: 12 day: 4 }
11 | license_type: PERMISSIVE
12 | }
13 |
--------------------------------------------------------------------------------
/third_party/WideResNet_pytorch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/augmix/9b9824c7c19bf7e72df2d085d97b99b3bfb00ba4/third_party/WideResNet_pytorch/__init__.py
--------------------------------------------------------------------------------
/third_party/WideResNet_pytorch/wideresnet.py:
--------------------------------------------------------------------------------
1 | """WideResNet implementation (https://arxiv.org/abs/1605.07146)."""
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class BasicBlock(nn.Module):
9 | """Basic ResNet block."""
10 |
11 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0):
12 | super(BasicBlock, self).__init__()
13 | self.bn1 = nn.BatchNorm2d(in_planes)
14 | self.relu1 = nn.ReLU(inplace=True)
15 | self.conv1 = nn.Conv2d(
16 | in_planes,
17 | out_planes,
18 | kernel_size=3,
19 | stride=stride,
20 | padding=1,
21 | bias=False)
22 | self.bn2 = nn.BatchNorm2d(out_planes)
23 | self.relu2 = nn.ReLU(inplace=True)
24 | self.conv2 = nn.Conv2d(
25 | out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
26 | self.drop_rate = drop_rate
27 | self.is_in_equal_out = (in_planes == out_planes)
28 |     self.conv_shortcut = None if self.is_in_equal_out else nn.Conv2d(
29 |         in_planes,
30 |         out_planes,
31 |         kernel_size=1,
32 |         stride=stride,
33 |         padding=0,
34 |         bias=False)
35 |
36 | def forward(self, x):
37 | if not self.is_in_equal_out:
38 | x = self.relu1(self.bn1(x))
39 | else:
40 | out = self.relu1(self.bn1(x))
41 | if self.is_in_equal_out:
42 | out = self.relu2(self.bn2(self.conv1(out)))
43 | else:
44 | out = self.relu2(self.bn2(self.conv1(x)))
45 | if self.drop_rate > 0:
46 | out = F.dropout(out, p=self.drop_rate, training=self.training)
47 | out = self.conv2(out)
48 | if not self.is_in_equal_out:
49 | return torch.add(self.conv_shortcut(x), out)
50 | else:
51 | return torch.add(x, out)
52 |
53 |
54 | class NetworkBlock(nn.Module):
55 | """Layer container for blocks."""
56 |
57 | def __init__(self,
58 | nb_layers,
59 | in_planes,
60 | out_planes,
61 | block,
62 | stride,
63 | drop_rate=0.0):
64 | super(NetworkBlock, self).__init__()
65 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers,
66 | stride, drop_rate)
67 |
68 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride,
69 | drop_rate):
70 | layers = []
71 | for i in range(nb_layers):
72 |       layers.append(
73 |           block(in_planes if i == 0 else out_planes, out_planes,
74 |                 stride if i == 0 else 1, drop_rate))
75 | return nn.Sequential(*layers)
76 |
77 | def forward(self, x):
78 | return self.layer(x)
79 |
80 |
81 | class WideResNet(nn.Module):
82 | """WideResNet class."""
83 |
84 | def __init__(self, depth, num_classes, widen_factor=1, drop_rate=0.0):
85 | super(WideResNet, self).__init__()
86 | n_channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
87 | assert (depth - 4) % 6 == 0
88 | n = (depth - 4) // 6
89 | block = BasicBlock
90 | # 1st conv before any network block
91 | self.conv1 = nn.Conv2d(
92 | 3, n_channels[0], kernel_size=3, stride=1, padding=1, bias=False)
93 | # 1st block
94 | self.block1 = NetworkBlock(n, n_channels[0], n_channels[1], block, 1,
95 | drop_rate)
96 | # 2nd block
97 | self.block2 = NetworkBlock(n, n_channels[1], n_channels[2], block, 2,
98 | drop_rate)
99 | # 3rd block
100 | self.block3 = NetworkBlock(n, n_channels[2], n_channels[3], block, 2,
101 | drop_rate)
102 | # global average pooling and classifier
103 | self.bn1 = nn.BatchNorm2d(n_channels[3])
104 | self.relu = nn.ReLU(inplace=True)
105 | self.fc = nn.Linear(n_channels[3], num_classes)
106 | self.n_channels = n_channels[3]
107 |
108 | for m in self.modules():
109 | if isinstance(m, nn.Conv2d):
110 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
111 | m.weight.data.normal_(0, math.sqrt(2. / n))
112 | elif isinstance(m, nn.BatchNorm2d):
113 | m.weight.data.fill_(1)
114 | m.bias.data.zero_()
115 | elif isinstance(m, nn.Linear):
116 | m.bias.data.zero_()
117 |
118 | def forward(self, x):
119 | out = self.conv1(x)
120 | out = self.block1(out)
121 | out = self.block2(out)
122 | out = self.block3(out)
123 | out = self.relu(self.bn1(out))
124 | out = F.avg_pool2d(out, 8)
125 | out = out.view(-1, self.n_channels)
126 | return self.fc(out)
127 |
--------------------------------------------------------------------------------
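WideResNet requires (depth - 4) % 6 == 0, giving n = (depth - 4) / 6 BasicBlocks per stage, with the stage widths [16, 32, 64] scaled by widen_factor. A sketch instantiating a WRN-40-2, a common CIFAR configuration (assumes the repo root is on PYTHONPATH):

import torch
from third_party.WideResNet_pytorch.wideresnet import WideResNet

net = WideResNet(depth=40, num_classes=10, widen_factor=2)
assert net.fc.in_features == 64 * 2        # last stage width * widen_factor
assert net(torch.randn(2, 3, 32, 32)).shape == (2, 10)

--------------------------------------------------------------------------------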
/third_party/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/augmix/9b9824c7c19bf7e72df2d085d97b99b3bfb00ba4/third_party/__init__.py
--------------------------------------------------------------------------------