├── .gitignore ├── LICENSE ├── README.md ├── demo.py ├── randaugment ├── __init__.py └── randaugment.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea/ 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Ildoo Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-randaugment 2 | Unofficial PyTorch Reimplementation of AutoAugment and RandAugment. 
3 | 4 | Code taken from https://github.com/DeepVoltaire/AutoAugment and https://github.com/jizongFox/uda 5 | ## How to install: 6 | ```bash 7 | pip install randaugment 8 | ``` 9 | --- 10 | 11 | ## How to use: 12 | ```python 13 | from randaugment import RandAugment, ImageNetPolicy 14 | data = ImageFolder(rootdir, transform=transforms.Compose( 15 | [ 16 | transforms.RandomCrop(32, padding=4, fill=128), # fill parameter needs torchvision installed from source 17 | transforms.RandomHorizontalFlip(), 18 | RandAugment(), 19 | #ImageNetPolicy(), 20 | Cutout(size=16), # (https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py) 21 | transforms.ToTensor(), 22 | transforms.Normalize(...) 23 | ]) 24 | ) 25 | loader = DataLoader(data, ...) 26 | ``` 27 | 28 | 29 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from PIL import Image 3 | 4 | from randaugment.randaugment import RandAugment,Cutout 5 | 6 | url = "https://images.pexels.com/photos/356378/pexels-photo-356378.jpeg?auto=compress&cs=tinysrgb&dpr=2&h=650&w=940" 7 | img = Image.open(requests.get(url, stream=True).raw) 8 | img=img.resize((256,256)) 9 | random_transform = RandAugment() 10 | cutout = Cutout(size=30) 11 | for i in range(5): 12 | img_ = random_transform(img) 13 | img_ = cutout(img_) 14 | img_.show() 15 | -------------------------------------------------------------------------------- /randaugment/__init__.py: -------------------------------------------------------------------------------- 1 | from .randaugment import CIFAR10Policy, CIFAR10PolicyAll, SVHNPolicy, RandAugment, ImageNetPolicy, Cutout -------------------------------------------------------------------------------- /randaugment/randaugment.py: -------------------------------------------------------------------------------- 1 | ## the code is mostly taken from autoaugment pytorch repo: 2 | # 
# https://github.com/DeepVoltaire/AutoAugment
"""PIL-based AutoAugment policies, RandAugment and Cutout.

Every transform in this module is a callable that takes a ``PIL.Image`` and
returns a ``PIL.Image``, so it can be placed before ``transforms.ToTensor()``
in a torchvision pipeline.
"""

import random

import numpy as np
from PIL import Image, ImageEnhance, ImageOps


class Cutout:
    """Fill a randomly placed square patch of a PIL image with a constant colour.

    The patch centre is sampled uniformly over the image and the patch is
    clipped at the image border, so patches near an edge are smaller.

    Args:
        size: Nominal side length of the patch in pixels. ``size // 2`` pixels
            are used on each side of the centre, so an odd size yields a
            ``size - 1`` patch. A size below 2 gives an empty patch and the
            image is returned unchanged (as a copy).
    """

    # Colour written into the patch. Only the first ``len(img.getbands())``
    # components are used: RGB images get (125, 122, 113), RGBA images also get
    # alpha 0, single-band images get 125.
    _FILL = (125, 122, 113, 0)

    def __init__(self, size=16) -> None:
        self.size = size

    @staticmethod
    def _sample_patch(dim0, dim1, size):
        """Sample a patch centre and return its clipped ``(upper, lower)`` corners.

        ``upper`` is inclusive and ``lower`` exclusive along both dimensions; the
        patch is empty when ``int(size) < 2``.
        """
        loc0 = np.random.randint(low=0, high=dim0)
        loc1 = np.random.randint(low=0, high=dim1)
        half = int(size) // 2
        upper_coord = (max(0, loc0 - half), max(0, loc1 - half))
        lower_coord = (min(dim0, loc0 + half), min(dim1, loc1 + half))
        return upper_coord, lower_coord

    def _create_cutout_mask(self, img_height, img_width, num_channels, size):
        """Creates a zero mask used for cutout of shape `img_height` x `img_width`.

        Args:
            img_height: Height of image cutout mask will be applied to.
            img_width: Width of image cutout mask will be applied to.
            num_channels: Number of channels in the image.
            size: Size of the zeros mask.
        Returns:
            A float mask of shape (img_height, img_width, num_channels) with all
            ones except for a (border-clipped) square of zeros of side at most
            ``size``, meant to be multiplied elementwise with the image. Also
            returns the `upper_coord` (inclusive) and `lower_coord` (exclusive)
            corners of the zeroed square. For ``size < 2`` the square is empty
            and the mask is all ones.
        """
        upper_coord, lower_coord = self._sample_patch(img_height, img_width, size)
        mask = np.ones((img_height, img_width, num_channels))
        mask[upper_coord[0]:lower_coord[0], upper_coord[1]:lower_coord[1], :] = 0
        return mask, upper_coord, lower_coord

    def __call__(self, pil_img):
        """Return a copy of ``pil_img`` with one random square patch filled."""
        pil_img = pil_img.copy()
        # PIL reports size as (width, height); the sampled coordinates are (x, y).
        width, height = pil_img.size
        (x0, y0), (x1, y1) = self._sample_patch(width, height, self.size)
        if x1 <= x0 or y1 <= y0:
            return pil_img  # empty patch: nothing to cut out
        num_bands = len(pil_img.getbands())
        fill = self._FILL[:num_bands] if num_bands > 1 else self._FILL[0]
        pil_img.paste(fill, (x0, y0, x1, y1))
        return pil_img


class ImageNetPolicy(object):
    """ Randomly choose one of the best 25 Sub-policies on ImageNet.

    Example:
    >>> policy = ImageNetPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     ImageNetPolicy(),
    >>>     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        """
        Auto augment from https://arxiv.org/pdf/1805.09501.pdf
        :param fillcolor: colour for pixels uncovered by shear/translate/rotate.
        """

        self.policies = [
            SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
            SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
            SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
            SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
            SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
            SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
            SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
            SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
            SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
            SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
            SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"


class CIFAR10PolicyAll(object):
    """ Randomly choose one of 95 CIFAR10 Sub-policies (a longer list than the
    25 used by ``CIFAR10Policy``).

    Example:
    >>> policy = CIFAR10PolicyAll()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     CIFAR10PolicyAll(),
    >>>     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        """
        :param fillcolor: colour for pixels uncovered by shear/translate/rotate.
        """
        self.policies = [
            SubPolicy(0.1, "Invert", 7, 0.2, "Contrast", 6, fillcolor),
            SubPolicy(0.7, "Rotate", 2, 0.3, "TranslateX", 9, fillcolor),
            SubPolicy(0.8, "Sharpness", 1, 0.9, "Sharpness", 3, fillcolor),
            SubPolicy(0.5, "ShearY", 8, 0.7, "TranslateY", 9, fillcolor),
            SubPolicy(0.5, "AutoContrast", 8, 0.9, "Equalize", 2, fillcolor),
            SubPolicy(0.4, "Solarize", 5, 0.9, "AutoContrast", 3, fillcolor),
            SubPolicy(0.9, "TranslateY", 9, 0.7, "TranslateY", 9, fillcolor),
            SubPolicy(0.9, "AutoContrast", 2, 0.8, "Solarize", 3, fillcolor),
            SubPolicy(0.8, "Equalize", 8, 0.1, "Invert", 3, fillcolor),
            SubPolicy(0.7, "TranslateY", 9, 0.9, "AutoContrast", 1, fillcolor),
            SubPolicy(0.4, "Solarize", 5, 0.0, "AutoContrast", 2, fillcolor),
            SubPolicy(0.7, "TranslateY", 9, 0.7, "TranslateY", 9, fillcolor),
            SubPolicy(0.9, "AutoContrast", 0, 0.4, "Solarize", 3, fillcolor),
            SubPolicy(0.7, "Equalize", 5, 0.1, "Invert", 3, fillcolor),
            SubPolicy(0.7, "TranslateY", 9, 0.7, "TranslateY", 9, fillcolor),
            SubPolicy(0.4, "Solarize", 5, 0.9, "AutoContrast", 1, fillcolor),
            SubPolicy(0.8, "TranslateY", 9, 0.9, "TranslateY", 9, fillcolor),
            SubPolicy(0.8, "AutoContrast", 0, 0.7, "TranslateY", 9, fillcolor),
            SubPolicy(0.2, "TranslateY", 7, 0.9, "Color", 6, fillcolor),
            SubPolicy(0.7, "Equalize", 6, 0.4, "Color", 9, fillcolor),
            SubPolicy(0.2, "ShearY", 7, 0.3, "Posterize", 7, fillcolor),
            SubPolicy(0.4, "Color", 3, 0.6, "Brightness", 7, fillcolor),
            SubPolicy(0.3, "Sharpness", 9, 0.7, "Brightness", 9, fillcolor),
            SubPolicy(0.6, "Equalize", 5, 0.5, "Equalize", 1, fillcolor),
            SubPolicy(0.6, "Contrast", 7, 0.6, "Sharpness", 5, fillcolor),
            SubPolicy(0.3, "Brightness", 7, 0.5, "AutoContrast", 8, fillcolor),
            SubPolicy(0.9, "AutoContrast", 4, 0.5, "AutoContrast", 6, fillcolor),
            SubPolicy(0.3, "Solarize", 5, 0.6, "Equalize", 5, fillcolor),
            SubPolicy(0.2, "TranslateY", 4, 0.3, "Sharpness", 3, fillcolor),
            SubPolicy(0.0, "Brightness", 8, 0.8, "Color", 8, fillcolor),
            SubPolicy(0.2, "Solarize", 6, 0.8, "Color", 6, fillcolor),
            SubPolicy(0.2, "Solarize", 6, 0.8, "AutoContrast", 1, fillcolor),
            SubPolicy(0.4, "Solarize", 1, 0.6, "Equalize", 5, fillcolor),
            SubPolicy(0.0, "Brightness", 0, 0.5, "Solarize", 2, fillcolor),
            SubPolicy(0.9, "AutoContrast", 5, 0.5, "Brightness", 3, fillcolor),
            SubPolicy(0.7, "Contrast", 5, 0.0, "Brightness", 2, fillcolor),
            SubPolicy(0.2, "Solarize", 8, 0.1, "Solarize", 5, fillcolor),
            SubPolicy(0.5, "Contrast", 1, 0.2, "TranslateY", 9, fillcolor),
            SubPolicy(0.6, "AutoContrast", 5, 0.0, "TranslateY", 9, fillcolor),
            SubPolicy(0.9, "AutoContrast", 4, 0.8, "Equalize", 4, fillcolor),
            SubPolicy(0.0, "Brightness", 7, 0.4, "Equalize", 7, fillcolor),
            SubPolicy(0.2, "Solarize", 5, 0.7, "Equalize", 5, fillcolor),
            SubPolicy(0.6, "Equalize", 8, 0.6, "Color", 2, fillcolor),
            SubPolicy(0.3, "Color", 7, 0.2, "Color", 4, fillcolor),
            SubPolicy(0.5, "AutoContrast", 2, 0.7, "Solarize", 2, fillcolor),
            SubPolicy(0.2, "AutoContrast", 0, 0.1, "Equalize", 0, fillcolor),
            SubPolicy(0.6, "ShearY", 5, 0.6, "Equalize", 5, fillcolor),
            SubPolicy(0.9, "Brightness", 3, 0.4, "AutoContrast", 1, fillcolor),
            SubPolicy(0.8, "Equalize", 8, 0.7, "Equalize", 7, fillcolor),
            SubPolicy(0.7, "Equalize", 7, 0.5, "Solarize", 0, fillcolor),
            SubPolicy(0.8, "Equalize", 4, 0.8, "TranslateY", 9, fillcolor),
            SubPolicy(0.8, "TranslateY", 9, 0.6, "TranslateY", 9, fillcolor),
            SubPolicy(0.9, "TranslateY", 0, 0.5, "TranslateY", 9, fillcolor),
            SubPolicy(0.5, "AutoContrast", 3, 0.3, "Solarize", 4, fillcolor),
            SubPolicy(0.5, "Solarize", 3, 0.4, "Equalize", 4, fillcolor),
            SubPolicy(0.7, "Color", 7, 0.5, "TranslateX", 8, fillcolor),
            SubPolicy(0.3, "Equalize", 7, 0.4, "AutoContrast", 8, fillcolor),
            SubPolicy(0.4, "TranslateY", 3, 0.2, "Sharpness", 6, fillcolor),
            SubPolicy(0.9, "Brightness", 6, 0.2, "Color", 8, fillcolor),
            SubPolicy(0.5, "Solarize", 2, 0.0, "Invert", 3, fillcolor),
            SubPolicy(0.1, "AutoContrast", 5, 0.0, "Brightness", 0, fillcolor),
            SubPolicy(0.2, "Cutout", 4, 0.1, "Equalize", 1, fillcolor),
            SubPolicy(0.7, "Equalize", 7, 0.6, "AutoContrast", 4, fillcolor),
            SubPolicy(0.1, "Color", 8, 0.2, "ShearY", 3, fillcolor),
            SubPolicy(0.4, "ShearY", 2, 0.7, "Rotate", 0, fillcolor),
            SubPolicy(0.1, "ShearY", 3, 0.9, "AutoContrast", 5, fillcolor),
            SubPolicy(0.3, "TranslateY", 6, 0.3, "Cutout", 3, fillcolor),
            SubPolicy(0.5, "Equalize", 0, 0.6, "Solarize", 6, fillcolor),
            SubPolicy(0.3, "AutoContrast", 5, 0.2, "Rotate", 7, fillcolor),
            SubPolicy(0.8, "Equalize", 2, 0.4, "Invert", 0, fillcolor),
            SubPolicy(0.9, "Equalize", 5, 0.7, "Color", 0, fillcolor),
            SubPolicy(0.1, "Equalize", 1, 0.1, "ShearY", 3, fillcolor),
            SubPolicy(0.7, "AutoContrast", 3, 0.7, "Equalize", 0, fillcolor),
            SubPolicy(0.5, "Brightness", 1, 0.1, "Contrast", 7, fillcolor),
            SubPolicy(0.1, "Contrast", 4, 0.6, "Solarize", 5, fillcolor),
            SubPolicy(0.2, "Solarize", 3, 0.0, "ShearX", 0, fillcolor),
            SubPolicy(0.3, "TranslateX", 0, 0.6, "TranslateX", 0, fillcolor),
            SubPolicy(0.5, "Equalize", 9, 0.6, "TranslateY", 7, fillcolor),
            SubPolicy(0.1, "ShearX", 0, 0.5, "Sharpness", 1, fillcolor),
            SubPolicy(0.8, "Equalize", 6, 0.3, "Invert", 6, fillcolor),
            SubPolicy(0.3, "AutoContrast", 9, 0.5, "Cutout", 3, fillcolor),
            SubPolicy(0.4, "ShearX", 4, 0.9, "AutoContrast", 2, fillcolor),
            SubPolicy(0.0, "ShearX", 3, 0.0, "Posterize", 3, fillcolor),
            SubPolicy(0.4, "Solarize", 3, 0.2, "Color", 4, fillcolor),
            SubPolicy(0.1, "Equalize", 4, 0.7, "Equalize", 6, fillcolor),
            SubPolicy(0.3, "Equalize", 8, 0.4, "AutoContrast", 3, fillcolor),
            SubPolicy(0.6, "Solarize", 4, 0.7, "AutoContrast", 6, fillcolor),
            SubPolicy(0.2, "AutoContrast", 9, 0.4, "Brightness", 8, fillcolor),
            SubPolicy(0.1, "Equalize", 0, 0.0, "Equalize", 6, fillcolor),
            SubPolicy(0.8, "Equalize", 4, 0.0, "Equalize", 4, fillcolor),
            SubPolicy(0.5, "Equalize", 5, 0.1, "AutoContrast", 2, fillcolor),
            SubPolicy(0.5, "Solarize", 5, 0.9, "AutoContrast", 5, fillcolor),
            SubPolicy(0.6, "AutoContrast", 1, 0.7, "AutoContrast", 8, fillcolor),
            SubPolicy(0.2, "Equalize", 0, 0.1, "AutoContrast", 2, fillcolor),
            SubPolicy(0.6, "Equalize", 9, 0.4, "Equalize", 4, fillcolor),
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"


class CIFAR10Policy(object):
    """ Randomly choose one of the best 25 Sub-policies on CIFAR10.

    Example:
    >>> policy = CIFAR10Policy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     CIFAR10Policy(),
    >>>     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        """
        Auto augment from https://arxiv.org/pdf/1805.09501.pdf
        :param fillcolor: colour for pixels uncovered by shear/translate/rotate.
        """

        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor),
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"


class SVHNPolicy(object):
    """ Randomly choose one of the best 25 Sub-policies on SVHN.

    Example:
    >>> policy = SVHNPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform=transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     SVHNPolicy(),
    >>>     transforms.ToTensor()])
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        """
        Auto augment from https://arxiv.org/pdf/1805.09501.pdf
        :param fillcolor: colour for pixels uncovered by shear/translate/rotate.
        """
        self.policies = [
            SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
            SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
            SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
            SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
            SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
            SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
            SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
            SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
            SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor),
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment SVHN Policy"


def _rgba_background(fillcolor):
    """Turn a ``fillcolor`` (int or 3/4-tuple) into an RGBA background colour.

    3-component colours get alpha 128, so the default (128, 128, 128) maps to
    (128, 128, 128, 128). Unrecognised specs fall back to (128, 128, 128, 128).
    """
    if isinstance(fillcolor, int):
        fillcolor = (fillcolor,) * 3
    fillcolor = tuple(fillcolor)
    if len(fillcolor) == 3:
        return fillcolor + (128,)
    if len(fillcolor) == 4:
        return fillcolor
    return (128,) * 4


class SubPolicy(object):
    """Two image operations, each applied independently with its own probability.

    Args:
        p1, p2: Probability of applying operation 1 / operation 2.
        operation1, operation2: Case-insensitive operation names; one of the
            keys of ``_RANGES`` (e.g. ``"rotate"``, ``"ShearX"``, ``"Cutout"``).
        magnitude_idx1, magnitude_idx2: Index into the 10-level magnitude table
            of the corresponding operation.
        fillcolor: Colour for pixels uncovered by shear, translate and rotate.

    Raises:
        KeyError: if an operation name is unknown.
        IndexError: if a magnitude index is outside the 10-level table.
    """

    # Ten magnitude levels per operation, shared by all instances (RandAugment
    # builds a SubPolicy per image, so these are not recomputed each time).
    _RANGES = {
        "shearx": np.linspace(0, 0.3, 10),
        "sheary": np.linspace(0, 0.3, 10),
        "translatex": np.linspace(0, 150 / 331, 10),
        "translatey": np.linspace(0, 150 / 331, 10),
        "rotate": np.linspace(0, 30, 10),
        "color": np.linspace(0.0, 0.9, 10),
        # Builtin ``int``: the ``np.int`` alias was removed in NumPy 1.24.
        "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
        "solarize": np.linspace(256, 0, 10),
        "contrast": np.linspace(0.0, 0.9, 10),
        "sharpness": np.linspace(0.0, 0.9, 10),
        "brightness": np.linspace(0.0, 0.9, 10),
        "autocontrast": [0] * 10,
        "equalize": [0] * 10,
        "invert": [0] * 10,
        "cutout": np.round(np.linspace(0, 20, 10), 0).astype(int),
    }

    def __init__(
            self,
            p1,
            operation1,
            magnitude_idx1,
            p2,
            operation2,
            magnitude_idx2,
            fillcolor=(128, 128, 128),
    ):
        ranges = self._RANGES
        rotate_background = _rgba_background(fillcolor)

        # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            # Rotate in RGBA so the uncovered corners (alpha 0) can be replaced
            # by the fill colour, then convert back to the input mode.
            # NOTE: PIL rotates counter-clockwise for positive angles and the
            # sign is never randomised here, so rotation is always one-sided.
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(
                rot, Image.new("RGBA", rot.size, rotate_background), rot
            ).convert(img.mode)

        func = {
            "shearx": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor,
            ),
            "sheary": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor,
            ),
            "translatex": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor,
            ),
            "translatey": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor,
            ),
            "rotate": rotate_with_fill,
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
                1 + magnitude * random.choice([-1, 1])
            ),
            "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])
            ),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])
            ),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])
            ),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img),
            "cutout": lambda img, magnitude: Cutout(magnitude)(img),
        }

        self.p1 = p1
        self._operation1_name = operation1
        self.operation1 = func[operation1.lower()]
        self.magnitude1 = ranges[operation1.lower()][magnitude_idx1]
        self.p2 = p2
        self._operation2_name = operation2
        self.operation2 = func[operation2.lower()]
        self.magnitude2 = ranges[operation2.lower()][magnitude_idx2]

    def __call__(self, img):
        """Apply operation 1 with prob. ``p1``, then operation 2 with prob. ``p2``."""
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img

    def __repr__(self):
        return f"{self._operation1_name} with p:{self.p1} and magnitude:{self.magnitude1} \t" \
               f"{self._operation2_name} with p:{self.p2} and magnitude:{self.magnitude2} \n"


class RandAugment:
    """Apply a randomly drawn pair of operations with random magnitudes.

    Adapted from the UDA TensorFlow implementation:
    https://github.com/jizongFox/uda

    Each call picks one (op1, op2) pair uniformly from every combination of the
    operations in ``get_trans_list()`` with magnitude indices 1-9, and applies
    it as a ``SubPolicy`` in which each operation fires with probability 0.5.
    """

    @classmethod
    def get_trans_list(cls):
        """Return the names of the operations RandAugment samples from."""
        trans_list = [
            'Invert', 'Cutout', 'Sharpness', 'AutoContrast', 'Posterize',
            'ShearX', 'TranslateX', 'TranslateY', 'ShearY', 'Rotate',
            'Equalize', 'Contrast', 'Color', 'Solarize', 'Brightness']
        return trans_list

    @classmethod
    def get_rand_policies(cls):
        """Return every ``[(0.5, op, mag), (0.5, op, mag)]`` pair as a list."""
        op_list = []
        for trans in cls.get_trans_list():
            for magnitude in range(1, 10):
                op_list += [(0.5, trans, magnitude)]
        policies = []
        for op_1 in op_list:
            for op_2 in op_list:
                policies += [[op_1, op_2]]
        return policies

    def __init__(self) -> None:
        super().__init__()
        self._policies = self.get_rand_policies()

    def __call__(self, img):
        randomly_chosen_policy = self._policies[random.randint(0, len(self._policies) - 1)]
        policy = SubPolicy(*randomly_chosen_policy[0], *randomly_chosen_policy[1])
        return policy(img)

    def __repr__(self):
        return "Random Augment Policy"

# --------------------------------------------------------------------------------
# /setup.py:
# --------------------------------------------------------------------------------
import pathlib
from setuptools import setup, find_packages

HERE = pathlib.Path(__file__).parent
README_path = (HERE / "README.md")

with open(README_path, encoding='utf-8') as f:
    long_description = f.read()
setup(
    name='randaugment',
    version='1.0.2',
    packages=find_packages(),
    url='https://github.com/jizongFox/pytorch-randaugment',
    license='MIT',
    author='Jizong Peng',
    author_email='jizong.peng.1@etsmtl.net',
    long_description=long_description,
    long_description_content_type='text/markdown',
    # randaugment/randaugment.py imports numpy and PIL at module level.
    install_requires=['numpy', 'Pillow'],
    classifiers=[
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.7",
    ],
)
# --------------------------------------------------------------------------------