├── AutoAugment_Exploration.ipynb
├── LICENSE
├── README.md
├── autoaugment.py
├── figures
    ├── CIFAR100_results.png
    ├── CIFAR10_results.png
    ├── FGVC_results.png
    ├── Figure2_Paper.png
    ├── ImageNet_results.png
    └── SVHN_results.png
└── ops.py


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Philip Popien

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# AutoAugment - Learning Augmentation Policies from Data
Unofficial implementation of the ImageNet, CIFAR10 and SVHN Augmentation Policies learned by [AutoAugment](https://arxiv.org/abs/1805.09501v1), described in this [Google AI Blogpost](https://ai.googleblog.com/2018/06/improving-deep-learning-performance.html).
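
All policy classes in `autoaugment.py` share the same two-level scheme. For each image, one sub-policy is drawn uniformly at random. Each of that sub-policy's two operations then fires independently with its own probability, at a magnitude looked up from a 0-9 index. The sketch below reproduces this scheme without any dependencies. It assumes images are flat lists of 8-bit grayscale values instead of PIL images. `SubPolicySketch`, `PolicySketch`, `linspace` and the three operation helpers are hypothetical names for illustration, not the API of `autoaugment.py`.

```python
import random

# Hypothetical stand-ins for three operations in ops.py. They act on a flat
# list of 8-bit grayscale values instead of a PIL image.
def invert(pixels, _magnitude):
    # Magnitude is ignored, as in ops.Invert.
    return [255 - v for v in pixels]


def solarize(pixels, threshold):
    # Same rule as PIL's ImageOps.solarize: invert values >= threshold.
    return [v if v < threshold else 255 - v for v in pixels]


def posterize(pixels, bits):
    # Keep only the `bits` most significant bits of each value.
    mask = ~(2 ** (8 - bits) - 1) & 0xFF
    return [v & mask for v in pixels]


def linspace(start, stop, num=10):
    # Pure-Python stand-in for np.linspace with an exact endpoint.
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num - 1)] + [stop]


# Magnitude index 0-9 -> concrete magnitude, mirroring SubPolicy's `ranges`.
RANGES = {
    "invert": [0] * 10,
    "solarize": linspace(256, 0),
    "posterize": [round(m) for m in linspace(8, 4)],
}
OPS = {"invert": invert, "solarize": solarize, "posterize": posterize}


class SubPolicySketch:
    """Two (probability, operation, magnitude) steps, each applied independently."""

    def __init__(self, p1, op1, idx1, p2, op2, idx2, rng=random):
        self.steps = [(p1, OPS[op1], RANGES[op1][idx1]),
                      (p2, OPS[op2], RANGES[op2][idx2])]
        self.rng = rng

    def __call__(self, pixels):
        for p, op, magnitude in self.steps:
            if self.rng.random() < p:
                pixels = op(pixels, magnitude)
        return pixels


class PolicySketch:
    """Draw one sub-policy uniformly at random for every image."""

    def __init__(self, sub_policies, rng=random):
        self.sub_policies = sub_policies
        self.rng = rng

    def __call__(self, pixels):
        return self.rng.choice(self.sub_policies)(pixels)


# With p = 1.0 both steps always fire, so this call is deterministic.
sub = SubPolicySketch(1.0, "invert", 0, 1.0, "posterize", 9)
print(sub([0, 255, 130]))  # [240, 0, 112]
policy = PolicySketch([sub, SubPolicySketch(0.5, "solarize", 5, 0.0, "invert", 0)],
                      rng=random.Random(0))
print(policy([0, 255, 130]))
```

You can pass a seeded `random.Random` instance as `rng` to make the sketch's random choices reproducible. The repository classes, by contrast, draw from the global `random` module.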

__Update July 13th, 2018:__ Wrote a [Blogpost](https://towardsdatascience.com/how-to-improve-your-image-classifier-with-googles-autoaugment-77643f0be0c9) about AutoAugment and Double Transfer Learning.

##### Tested with Python 3.6. Needs pillow>=5.0.0

![Examples of the best ImageNet Policy](figures/Figure2_Paper.png)


------------------



## Example

```python
from PIL import Image
from autoaugment import ImageNetPolicy

image = Image.open(path).convert("RGB")  # the policies use an RGB fill color
policy = ImageNetPolicy()
transformed = policy(image)  # a new PIL image
```

To see examples of all operations and magnitudes applied to images, take a look at [AutoAugment_Exploration.ipynb](AutoAugment_Exploration.ipynb).

## Example as a PyTorch Transform - ImageNet

```python
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from autoaugment import ImageNetPolicy

data = ImageFolder(rootdir, transform=transforms.Compose(
    [transforms.RandomResizedCrop(224),
     transforms.RandomHorizontalFlip(), ImageNetPolicy(),
     transforms.ToTensor(), transforms.Normalize(...)]))
loader = DataLoader(data, ...)
```

## Example as a PyTorch Transform - CIFAR10

```python
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from autoaugment import CIFAR10Policy

data = ImageFolder(rootdir, transform=transforms.Compose(
    [transforms.RandomCrop(32, padding=4, fill=128),  # fill parameter needs torchvision installed from source
     transforms.RandomHorizontalFlip(), CIFAR10Policy(),
     transforms.ToTensor(),
     Cutout(n_holes=1, length=16),  # Cutout class from https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py (not part of this repo)
     transforms.Normalize(...)]))
loader = DataLoader(data, ...)
```

## Example as a PyTorch Transform - SVHN

```python
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from autoaugment import SVHNPolicy

data = ImageFolder(rootdir, transform=transforms.Compose(
    [SVHNPolicy(),
     transforms.ToTensor(),
     Cutout(n_holes=1, length=20),  # Cutout class from https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py (not part of this repo)
     transforms.Normalize(...)]))
loader = DataLoader(data, ...)
```

------------------


## Results with AutoAugment

### Generalizable Data Augmentations

> Finally, we show that policies found on one task can generalize well across different models and datasets.
> For example, the policy found on ImageNet leads to significant improvements on a variety of FGVC datasets. Even on datasets for
> which fine-tuning weights pre-trained on ImageNet does not help significantly [26], e.g. Stanford
> Cars [27] and FGVC Aircraft [28], training with the ImageNet policy reduces test set error by 1.16%
> and 1.76%, respectively. __This result suggests that transferring data augmentation policies offers an
> alternative method for transfer learning__.

### CIFAR 10

![CIFAR10 Results](figures/CIFAR10_results.png)

### CIFAR 100

![CIFAR100 Results](figures/CIFAR100_results.png)

### ImageNet

![ImageNet Results](figures/ImageNet_results.png)

### SVHN

![SVHN Results](figures/SVHN_results.png)

### Fine Grained Visual Classification Datasets

![FGVC Results](figures/FGVC_results.png)

--------------------------------------------------------------------------------
/autoaugment.py:
--------------------------------------------------------------------------------
import random

import numpy as np

from ops import *


class ImageNetPolicy(object):
    """ Randomly choose one of the best 25 Sub-policies on ImageNet.

    Example:
    >>> policy = ImageNetPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform = transforms.Compose([
    ...     transforms.Resize(256),
    ...     ImageNetPolicy(),
    ...     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),

            SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
            SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
            SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
            SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),

            SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
            SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
            SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),

            SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
            SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
            SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
            SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
            SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),

            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"


class CIFAR10Policy(object):
    """ Randomly choose one of the best 25 Sub-policies on CIFAR10.

    Example:
    >>> policy = CIFAR10Policy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform = transforms.Compose([
    ...     transforms.Resize(256),
    ...     CIFAR10Policy(),
    ...     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),

            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),

            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),

            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),

            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"


class SVHNPolicy(object):
    """ Randomly choose one of the best 25 Sub-policies on SVHN.

    Example:
    >>> policy = SVHNPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform = transforms.Compose([
    ...     transforms.Resize(256),
    ...     SVHNPolicy(),
    ...     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
            SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
            SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
            SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
            SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
            SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),

            SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
            SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),

            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment SVHN Policy"


class SubPolicy(object):
    """Apply operation1 with probability p1, then operation2 with probability p2.

    Each magnitude index (0-9) selects one of ten values from the operation's
    entry in `ranges`. fillcolor fills the pixels exposed by shear and translate.
    """
    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        # Ten evenly spaced magnitudes per operation, indexed by magnitude_idx.
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        func = {
            "shearX": ShearX(fillcolor=fillcolor),
            "shearY": ShearY(fillcolor=fillcolor),
            "translateX": TranslateX(fillcolor=fillcolor),
            "translateY": TranslateY(fillcolor=fillcolor),
            "rotate": Rotate(),
            "color": Color(),
            "posterize": Posterize(),
            "solarize": Solarize(),
            "contrast": Contrast(),
            "sharpness": Sharpness(),
            "brightness": Brightness(),
            "autocontrast": AutoContrast(),
            "equalize": Equalize(),
            "invert": Invert()
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img

--------------------------------------------------------------------------------
/figures/CIFAR100_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepVoltaire/AutoAugment/19c8c484807b3462e59561501794d744e88b56bf/figures/CIFAR100_results.png
--------------------------------------------------------------------------------
/figures/CIFAR10_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepVoltaire/AutoAugment/19c8c484807b3462e59561501794d744e88b56bf/figures/CIFAR10_results.png
--------------------------------------------------------------------------------
/figures/FGVC_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepVoltaire/AutoAugment/19c8c484807b3462e59561501794d744e88b56bf/figures/FGVC_results.png
--------------------------------------------------------------------------------
/figures/Figure2_Paper.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepVoltaire/AutoAugment/19c8c484807b3462e59561501794d744e88b56bf/figures/Figure2_Paper.png
--------------------------------------------------------------------------------
/figures/ImageNet_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepVoltaire/AutoAugment/19c8c484807b3462e59561501794d744e88b56bf/figures/ImageNet_results.png
--------------------------------------------------------------------------------
/figures/SVHN_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepVoltaire/AutoAugment/19c8c484807b3462e59561501794d744e88b56bf/figures/SVHN_results.png
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
from PIL import Image, ImageEnhance, ImageOps
import random


class ShearX(object):
    def __init__(self, fillcolor=(128, 128, 128)):
        self.fillcolor = fillcolor

    def __call__(self, x, magnitude):
        # Horizontal shear by +/- magnitude; exposed pixels get fillcolor.
        return x.transform(
            x.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
            Image.BICUBIC, fillcolor=self.fillcolor)


class ShearY(object):
    def __init__(self, fillcolor=(128, 128, 128)):
        self.fillcolor = fillcolor

    def __call__(self, x, magnitude):
        # Vertical shear by +/- magnitude; exposed pixels get fillcolor.
        return x.transform(
            x.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
            Image.BICUBIC, fillcolor=self.fillcolor)


class TranslateX(object):
    def __init__(self, fillcolor=(128, 128, 128)):
        self.fillcolor = fillcolor

    def __call__(self, x, magnitude):
        # Horizontal shift by +/- magnitude * image width pixels.
        return x.transform(
            x.size, Image.AFFINE, (1, 0, magnitude * x.size[0] * random.choice([-1, 1]), 0, 1, 0),
            fillcolor=self.fillcolor)


class TranslateY(object):
    def __init__(self, fillcolor=(128, 128, 128)):
        self.fillcolor = fillcolor

    def __call__(self, x, magnitude):
        # Vertical shift by +/- magnitude * image height pixels.
        return x.transform(
            x.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * x.size[1] * random.choice([-1, 1])),
            fillcolor=self.fillcolor)


class Rotate(object):
    # from https://stackoverflow.com/questions/
    # 5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
    def __call__(self, x, magnitude):
        # Rotate by +/- magnitude degrees. Corners exposed by the rotation are
        # transparent in the RGBA copy, so the composite fills them with a fixed
        # gray (128); the policies' fillcolor argument is not used here.
        rot = x.convert("RGBA").rotate(magnitude * random.choice([-1, 1]))
        return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(x.mode)


# Color, Contrast, Sharpness and Brightness use the enhancement factor
# 1 +/- magnitude; a factor of 1.0 returns the original image.
class Color(object):
    def __call__(self, x, magnitude):
        return ImageEnhance.Color(x).enhance(1 + magnitude * random.choice([-1, 1]))


class Posterize(object):
    def __call__(self, x, magnitude):
        # Keep `magnitude` bits per color channel.
        return ImageOps.posterize(x, magnitude)


class Solarize(object):
    def __call__(self, x, magnitude):
        # Invert every pixel value at or above the threshold `magnitude`.
        return ImageOps.solarize(x, magnitude)


class Contrast(object):
    def __call__(self, x, magnitude):
        return ImageEnhance.Contrast(x).enhance(1 + magnitude * random.choice([-1, 1]))


class Sharpness(object):
    def __call__(self, x, magnitude):
        return ImageEnhance.Sharpness(x).enhance(1 + magnitude * random.choice([-1, 1]))


class Brightness(object):
    def __call__(self, x, magnitude):
        return ImageEnhance.Brightness(x).enhance(1 + magnitude * random.choice([-1, 1]))


# AutoContrast, Equalize and Invert ignore magnitude; it is accepted only so
# that every operation has the same call signature.
class AutoContrast(object):
    def __call__(self, x, magnitude):
        return ImageOps.autocontrast(x)


class Equalize(object):
    def __call__(self, x, magnitude):
        return ImageOps.equalize(x)


class Invert(object):
    def __call__(self, x, magnitude):
        return ImageOps.invert(x)

--------------------------------------------------------------------------------
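The geometric operations in `ops.py` (`ShearX`, `ShearY`, `TranslateX`, `TranslateY`) all call `Image.transform` with `Image.AFFINE` and a 6-tuple `(a, b, c, d, e, f)`. Pillow documents that each output pixel `(x, y)` takes its value from input position `(a*x + b*y + c, d*x + e*y + f)`. The tuple therefore describes the mapping from output back to input. The sketch below evaluates that mapping in plain Python to show what the tuples built in `ops.py` do. `affine_source`, `shear_x_coeffs` and `translate_x_coeffs` are hypothetical helpers, not Pillow or repository functions.

```python
def affine_source(coeffs, x, y):
    """Input position PIL samples for output pixel (x, y) with Image.AFFINE."""
    a, b, c, d, e, f = coeffs
    return (a * x + b * y + c, d * x + e * y + f)


def shear_x_coeffs(magnitude, sign):
    # Same tuple as ShearX in ops.py: x_in = x + sign * magnitude * y, y_in = y.
    return (1, magnitude * sign, 0, 0, 1, 0)


def translate_x_coeffs(magnitude, width, sign):
    # Same tuple as TranslateX in ops.py; magnitude is a fraction of the width.
    return (1, 0, magnitude * width * sign, 0, 1, 0)


# Shear of 0.3: output pixel (0, 10) takes its value from input (3.0, 10),
# so lower rows sample further to the right.
print(affine_source(shear_x_coeffs(0.3, 1), 0, 10))  # (3.0, 10)

# Shift by a quarter of a 32-pixel-wide image: output (0, 0) samples input
# (8.0, 0), so the visible content moves 8 pixels to the left.
print(affine_source(translate_x_coeffs(0.25, 32, 1), 0, 0))  # (8.0, 0)
```

Because the tuple maps output to input, a positive offset moves the visible content the opposite way. `ops.py` draws the sign with `random.choice([-1, 1])`, so both directions occur. `ShearY` and `TranslateY` apply the same construction to the second row `(d, e, f)` of the tuple.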