├── .gitignore
├── DeepAA_evaluate
│   ├── README.md
│   ├── __init__.py
│   ├── augmentations.py
│   ├── autoaugment.py
│   ├── common.py
│   ├── data.py
│   ├── deep_autoaugment.py
│   ├── fast_autoaugment.py
│   ├── imagenet.py
│   ├── lr_scheduler.py
│   ├── metrics.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── convnet.py
│   │   ├── mlp.py
│   │   ├── resnet.py
│   │   ├── shakeshake
│   │   │   ├── __init__.py
│   │   │   ├── shake_resnet.py
│   │   │   ├── shake_resnext.py
│   │   │   └── shakeshake.py
│   │   └── wideresnet.py
│   ├── train.py
│   └── utils.py
├── DeepAA_search.py
├── DeepAA_utils.py
├── README.md
├── __init__.py
├── aug_lib.py
├── augmentation.py
├── confs
│   ├── resnet50_imagenet_DeepAA_8x256_1.yaml
│   ├── resnet50_imagenet_DeepAA_8x256_2.yaml
│   ├── wresnet28x10_cifar100_DeepAA_1.yaml
│   ├── wresnet28x10_cifar100_DeepAA_1_wd1e-3.yaml
│   ├── wresnet28x10_cifar100_DeepAA_2.yaml
│   ├── wresnet28x10_cifar100_DeepAA_2_wd1e-3.yaml
│   ├── wresnet28x10_cifar100_DeepAA_BatchAug8x_1.yaml
│   ├── wresnet28x10_cifar100_DeepAA_BatchAug8x_2.yaml
│   ├── wresnet28x10_cifar10_DeepAA_1.yaml
│   ├── wresnet28x10_cifar10_DeepAA_1_wd1e-3.yaml
│   ├── wresnet28x10_cifar10_DeepAA_2.yaml
│   ├── wresnet28x10_cifar10_DeepAA_2_wd1e-3.yaml
│   ├── wresnet28x10_cifar10_DeepAA_BatchAug8x_1.yaml
│   └── wresnet28x10_cifar10_DeepAA_BatchAug8x_2.yaml
├── data_generator.py
├── imagenet_data_utils.py
├── images
│   ├── DeepAA.png
│   ├── DeepAA_slideslive.png
│   ├── magnitude_distribution_cifar.png
│   ├── magnitude_distribution_imagenet.png
│   └── operation_distribution.png
├── lr_scheduler.py
├── policy.py
├── policy_port
│   ├── policy_DeepAA_cifar_1.npz
│   ├── policy_DeepAA_cifar_2.npz
│   ├── policy_DeepAA_imagenet_1.npz
│   └── policy_DeepAA_imagenet_2.npz
├── requirements.txt
├── resnet.py
├── resnet_imagenet.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | model_checkpoints/**
3 | *.pyc
4 | __pycache__
5 | .idea/**
6 | results/**
7 | plots/**
8 | temp/results/
--------------------------------------------------------------------------------
/DeepAA_evaluate/README.md:
--------------------------------------------------------------------------------
1 | Code for evaluating the generated DeepAA policy.
2 | 
3 | The code in this folder is adapted from [TrivialAugment](https://github.com/automl/trivialaugment).
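4 | 
5 | Evaluation is launched via `train.py` in this folder together with one of the YAML configs under the top-level `confs/` directory (see `train.py` for the exact command-line flags); the ported DeepAA policies are stored as `policy_port/*.npz`.
6 | 
7 | To apply a DeepAA policy inside your own data pipeline, the transform can be wired up the same way `data.py` does it. A minimal sketch (the `EXP` value is illustrative and must match the experiment name used in your config):
8 | 
9 | ```python
10 | from torchvision.transforms import transforms
11 | from DeepAA_evaluate.deep_autoaugment import Augmentation_DeepAA
12 | 
13 | transform_train = transforms.Compose([
14 |     Augmentation_DeepAA(EXP='cifar_1', use_crop=False),  # use_crop=True only for ImageNet-style pipelines
15 |     transforms.ToTensor(),
16 |     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
17 | ])
18 | ```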
-------------------------------------------------------------------------------- /DeepAA_evaluate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/DeepAA_evaluate/__init__.py -------------------------------------------------------------------------------- /DeepAA_evaluate/augmentations.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from DeepAA_evaluate import autoaugment, fast_autoaugment 8 | import aug_lib 9 | 10 | 11 | class Lighting(object): 12 | """Lighting noise(AlexNet - style PCA - based noise)""" 13 | 14 | def __init__(self, alphastd, eigval, eigvec): 15 | self.alphastd = alphastd 16 | self.eigval = torch.Tensor(eigval) 17 | self.eigvec = torch.Tensor(eigvec) 18 | 19 | def __call__(self, img): 20 | if self.alphastd == 0: 21 | return img 22 | 23 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 24 | rgb = self.eigvec.type_as(img).clone() \ 25 | .mul(alpha.view(1, 3).expand(3, 3)) \ 26 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 27 | .sum(1).squeeze() 28 | 29 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 30 | 31 | 32 | class CutoutDefault(object): 33 | """ 34 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 35 | """ 36 | def __init__(self, length): 37 | self.length = length 38 | 39 | def __call__(self, img): 40 | h, w = img.size(1), img.size(2) 41 | mask = np.ones((h, w), np.float32) 42 | y = np.random.randint(h) 43 | x = np.random.randint(w) 44 | 45 | y1 = np.clip(y - self.length // 2, 0, h) 46 | y2 = np.clip(y + self.length // 2, 0, h) 47 | x1 = np.clip(x - self.length // 2, 0, w) 48 | x2 = np.clip(x + self.length // 2, 0, w) 49 | 50 | mask[y1: y2, x1: x2] = 0. 51 | mask = torch.from_numpy(mask) 52 | mask = mask.expand_as(img) 53 | img *= mask 54 | return img 55 | 56 | 57 | def get_randaugment(n,m,weights,bs): 58 | if n == 101 and m == 101: 59 | return autoaugment.CifarAutoAugment(fixed_posterize=False) 60 | if n == 102 and m == 102: 61 | return autoaugment.CifarAutoAugment(fixed_posterize=True) 62 | if n == 201 and m == 201: 63 | return autoaugment.SVHNAutoAugment(fixed_posterize=False) 64 | if n == 202 and m == 202: 65 | return autoaugment.SVHNAutoAugment(fixed_posterize=False) 66 | if n == 301 and m == 301: 67 | return fast_autoaugment.cifar10_faa 68 | if n == 401 and m == 401: 69 | return fast_autoaugment.svhn_faa 70 | assert m < 100 and n < 100 71 | if m == 0: 72 | if weights is not None: 73 | return aug_lib.UniAugmentWeighted(n, probs=weights) 74 | elif n == 0: 75 | return aug_lib.UniAugment() 76 | else: 77 | raise ValueError('Wrong RandAug Params.') 78 | else: 79 | assert n > 0 and m > 0 80 | return aug_lib.RandAugment(n, m) -------------------------------------------------------------------------------- /DeepAA_evaluate/autoaugment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Transforms used in the Augmentation Policies.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import random 23 | import numpy as np 24 | # pylint:disable=g-multiple-import 25 | from PIL import ImageOps, ImageEnhance, ImageFilter, Image 26 | # pylint:enable=g-multiple-import 27 | 28 | 29 | IMAGE_SIZE = 32 30 | # What is the dataset mean and std of the images on the training set 31 | PARAMETER_MAX = 30 # What is the max 'level' a transform could be predicted 32 | 33 | def pil_wrap(img): 34 | """Convert the `img` numpy tensor to a PIL Image.""" 35 | return img.convert('RGBA') 36 | 37 | 38 | def pil_unwrap(img): 39 | """Converts the PIL img to a numpy array.""" 40 | return img.convert('RGB') 41 | 42 | def apply_policy(policy, img, use_fixed_posterize=False): 43 | """Apply the `policy` to the numpy `img`. 44 | 45 | Args: 46 | policy: A list of tuples with the form (name, probability, level) where 47 | `name` is the name of the augmentation operation to apply, `probability` 48 | is the probability of applying the operation and `level` is what strength 49 | the operation to apply. 50 | img: Numpy image that will have `policy` applied to it. 51 | 52 | Returns: 53 | The result of applying `policy` to `img`. 54 | """ 55 | nametotransform = fixed_AA_NAME_TO_TRANSFORM if use_fixed_posterize else AA_NAME_TO_TRANSFORM 56 | pil_img = pil_wrap(img) 57 | for xform in policy: 58 | assert len(xform) == 3 59 | name, probability, level = xform 60 | xform_fn = nametotransform[name].pil_transformer(probability, level) 61 | pil_img = xform_fn(pil_img) 62 | return pil_unwrap(pil_img) 63 | 64 | 65 | def random_flip(x): 66 | """Flip the input x horizontally with 50% probability.""" 67 | if np.random.rand(1)[0] > 0.5: 68 | return np.fliplr(x) 69 | return x 70 | 71 | 72 | def zero_pad_and_crop(img, amount=4): 73 | """Zero pad by `amount` zero pixels on each side then take a random crop. 74 | 75 | Args: 76 | img: numpy image that will be zero padded and cropped. 77 | amount: amount of zeros to pad `img` with horizontally and verically. 78 | 79 | Returns: 80 | The cropped zero padded img. The returned numpy array will be of the same 81 | shape as `img`. 82 | """ 83 | padded_img = np.zeros((img.shape[0] + amount * 2, img.shape[1] + amount * 2, 84 | img.shape[2])) 85 | padded_img[amount:img.shape[0] + amount, amount: 86 | img.shape[1] + amount, :] = img 87 | top = np.random.randint(low=0, high=2 * amount) 88 | left = np.random.randint(low=0, high=2 * amount) 89 | new_img = padded_img[top:top + img.shape[0], left:left + img.shape[1], :] 90 | return new_img 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | def float_parameter(level, maxval): 99 | """Helper function to scale `val` between 0 and maxval . 100 | 101 | Args: 102 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 103 | maxval: Maximum value that the operation can have. This will be scaled 104 | to level/PARAMETER_MAX. 
105 | 106 | Returns: 107 | A float that results from scaling `maxval` according to `level`. 108 | """ 109 | return float(level) * maxval / PARAMETER_MAX 110 | 111 | 112 | def int_parameter(level, maxval): 113 | """Helper function to scale `val` between 0 and maxval . 114 | 115 | Args: 116 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 117 | maxval: Maximum value that the operation can have. This will be scaled 118 | to level/PARAMETER_MAX. 119 | 120 | Returns: 121 | An int that results from scaling `maxval` according to `level`. 122 | """ 123 | return int(level * maxval / PARAMETER_MAX) 124 | 125 | 126 | 127 | 128 | class TransformFunction(object): 129 | """Wraps the Transform function for pretty printing options.""" 130 | 131 | def __init__(self, func, name): 132 | self.f = func 133 | self.name = name 134 | 135 | def __repr__(self): 136 | return '<' + self.name + '>' 137 | 138 | def __call__(self, pil_img): 139 | return self.f(pil_img) 140 | 141 | 142 | class TransformT(object): 143 | """Each instance of this class represents a specific transform.""" 144 | 145 | def __init__(self, name, xform_fn): 146 | self.name = name 147 | self.xform = xform_fn 148 | 149 | def pil_transformer(self, probability, level): 150 | 151 | def return_function(im): 152 | if random.random() < probability: 153 | im = self.xform(im, level) 154 | return im 155 | 156 | name = self.name + '({:.1f},{})'.format(probability, level) 157 | return TransformFunction(return_function, name) 158 | 159 | def do_transform(self, image, level): 160 | f = self.pil_transformer(PARAMETER_MAX, level) 161 | return f(image) 162 | 163 | 164 | ################## Transform Functions ################## 165 | identity = TransformT('identity', lambda pil_img, level: pil_img) 166 | flip_lr = TransformT( 167 | 'FlipLR', 168 | lambda pil_img, level: pil_img.transpose(Image.FLIP_LEFT_RIGHT)) 169 | flip_ud = TransformT( 170 | 'FlipUD', 171 | lambda pil_img, level: pil_img.transpose(Image.FLIP_TOP_BOTTOM)) 172 | # pylint:disable=g-long-lambda 173 | auto_contrast = TransformT( 174 | 'AutoContrast', 175 | lambda pil_img, level: ImageOps.autocontrast( 176 | pil_img.convert('RGB')).convert('RGBA')) 177 | equalize = TransformT( 178 | 'Equalize', 179 | lambda pil_img, level: ImageOps.equalize( 180 | pil_img.convert('RGB')).convert('RGBA')) 181 | invert = TransformT( 182 | 'Invert', 183 | lambda pil_img, level: ImageOps.invert( 184 | pil_img.convert('RGB')).convert('RGBA')) 185 | # pylint:enable=g-long-lambda 186 | blur = TransformT( 187 | 'Blur', lambda pil_img, level: pil_img.filter(ImageFilter.BLUR)) 188 | smooth = TransformT( 189 | 'Smooth', 190 | lambda pil_img, level: pil_img.filter(ImageFilter.SMOOTH)) 191 | 192 | 193 | def _rotate_impl(pil_img, level): 194 | """Rotates `pil_img` from -30 to 30 degrees depending on `level`.""" 195 | degrees = int_parameter(level, 30) 196 | if random.random() > 0.5: 197 | degrees = -degrees 198 | return pil_img.rotate(degrees) 199 | 200 | 201 | rotate = TransformT('Rotate', _rotate_impl) 202 | 203 | 204 | def _posterize_impl(pil_img, level): 205 | """Applies PIL Posterize to `pil_img`.""" 206 | level = int_parameter(level, 4) 207 | return ImageOps.posterize(pil_img.convert('RGB'), 4 - level).convert('RGBA') 208 | 209 | 210 | posterize = TransformT('Posterize', _posterize_impl) 211 | 212 | def _fixed_posterize_impl(pil_img, level): 213 | """Applies PIL Posterize to `pil_img`.""" 214 | level = int_parameter(level, 4) 215 | return ImageOps.posterize(pil_img.convert('RGB'), 8 - 
level).convert('RGBA') 216 | 217 | fixed_posterize = TransformT('Posterize', _fixed_posterize_impl) 218 | 219 | 220 | def _shear_x_impl(pil_img, level): 221 | """Applies PIL ShearX to `pil_img`. 222 | 223 | The ShearX operation shears the image along the horizontal axis with `level` 224 | magnitude. 225 | 226 | Args: 227 | pil_img: Image in PIL object. 228 | level: Strength of the operation specified as an Integer from 229 | [0, `PARAMETER_MAX`]. 230 | 231 | Returns: 232 | A PIL Image that has had ShearX applied to it. 233 | """ 234 | level = float_parameter(level, 0.3) 235 | if random.random() > 0.5: 236 | level = -level 237 | return pil_img.transform((32, 32), Image.AFFINE, (1, level, 0, 0, 1, 0)) 238 | 239 | 240 | shear_x = TransformT('ShearX', _shear_x_impl) 241 | 242 | 243 | def _shear_y_impl(pil_img, level): 244 | """Applies PIL ShearY to `pil_img`. 245 | 246 | The ShearY operation shears the image along the vertical axis with `level` 247 | magnitude. 248 | 249 | Args: 250 | pil_img: Image in PIL object. 251 | level: Strength of the operation specified as an Integer from 252 | [0, `PARAMETER_MAX`]. 253 | 254 | Returns: 255 | A PIL Image that has had ShearX applied to it. 256 | """ 257 | level = float_parameter(level, 0.3) 258 | if random.random() > 0.5: 259 | level = -level 260 | return pil_img.transform((32, 32), Image.AFFINE, (1, 0, 0, level, 1, 0)) 261 | 262 | 263 | shear_y = TransformT('ShearY', _shear_y_impl) 264 | 265 | 266 | def _translate_x_impl(pil_img, level): 267 | """Applies PIL TranslateX to `pil_img`. 268 | 269 | Translate the image in the horizontal direction by `level` 270 | number of pixels. 271 | 272 | Args: 273 | pil_img: Image in PIL object. 274 | level: Strength of the operation specified as an Integer from 275 | [0, `PARAMETER_MAX`]. 276 | 277 | Returns: 278 | A PIL Image that has had TranslateX applied to it. 279 | """ 280 | level = int_parameter(level, 10) 281 | if random.random() > 0.5: 282 | level = -level 283 | return pil_img.transform((32, 32), Image.AFFINE, (1, 0, level, 0, 1, 0)) 284 | 285 | 286 | translate_x = TransformT('TranslateX', _translate_x_impl) 287 | 288 | 289 | def _translate_y_impl(pil_img, level): 290 | """Applies PIL TranslateY to `pil_img`. 291 | 292 | Translate the image in the vertical direction by `level` 293 | number of pixels. 294 | 295 | Args: 296 | pil_img: Image in PIL object. 297 | level: Strength of the operation specified as an Integer from 298 | [0, `PARAMETER_MAX`]. 299 | 300 | Returns: 301 | A PIL Image that has had TranslateY applied to it. 302 | """ 303 | level = int_parameter(level, 10) 304 | if random.random() > 0.5: 305 | level = -level 306 | return pil_img.transform((32, 32), Image.AFFINE, (1, 0, 0, 0, 1, level)) 307 | 308 | 309 | translate_y = TransformT('TranslateY', _translate_y_impl) 310 | 311 | 312 | def _crop_impl(pil_img, level, interpolation=Image.BILINEAR): 313 | """Applies a crop to `pil_img` with the size depending on the `level`.""" 314 | cropped = pil_img.crop((level, level, IMAGE_SIZE - level, IMAGE_SIZE - level)) 315 | resized = cropped.resize((IMAGE_SIZE, IMAGE_SIZE), interpolation) 316 | return resized 317 | 318 | 319 | crop_bilinear = TransformT('CropBilinear', _crop_impl) 320 | 321 | 322 | def _solarize_impl(pil_img, level): 323 | """Applies PIL Solarize to `pil_img`. 324 | 325 | Translate the image in the vertical direction by `level` 326 | number of pixels. 327 | 328 | Args: 329 | pil_img: Image in PIL object. 
330 | level: Strength of the operation specified as an Integer from 331 | [0, `PARAMETER_MAX`]. 332 | 333 | Returns: 334 | A PIL Image that has had Solarize applied to it. 335 | """ 336 | level = int_parameter(level, 256) 337 | return ImageOps.solarize(pil_img.convert('RGB'), 256 - level).convert('RGBA') 338 | 339 | 340 | solarize = TransformT('Solarize', _solarize_impl) 341 | 342 | 343 | def _enhancer_impl(enhancer): 344 | """Sets level to be between 0.1 and 1.8 for ImageEnhance transforms of PIL.""" 345 | def impl(pil_img, level): 346 | v = float_parameter(level, 1.8) + .1 # going to 0 just destroys it 347 | return enhancer(pil_img).enhance(v) 348 | return impl 349 | 350 | 351 | color = TransformT('Color', _enhancer_impl(ImageEnhance.Color)) 352 | contrast = TransformT('Contrast', _enhancer_impl(ImageEnhance.Contrast)) 353 | brightness = TransformT('Brightness', _enhancer_impl( 354 | ImageEnhance.Brightness)) 355 | sharpness = TransformT('Sharpness', _enhancer_impl(ImageEnhance.Sharpness)) 356 | 357 | def create_cutout_mask(img_height, img_width, num_channels, size): 358 | """Creates a zero mask used for cutout of shape `img_height` x `img_width`. 359 | 360 | Args: 361 | img_height: Height of image cutout mask will be applied to. 362 | img_width: Width of image cutout mask will be applied to. 363 | num_channels: Number of channels in the image. 364 | size: Size of the zeros mask. 365 | 366 | Returns: 367 | A mask of shape `img_height` x `img_width` with all ones except for a 368 | square of zeros of shape `size` x `size`. This mask is meant to be 369 | elementwise multiplied with the original image. Additionally returns 370 | the `upper_coord` and `lower_coord` which specify where the cutout mask 371 | will be applied. 372 | """ 373 | assert img_height == img_width 374 | 375 | # Sample center where cutout mask will be applied 376 | height_loc = np.random.randint(low=0, high=img_height) 377 | width_loc = np.random.randint(low=0, high=img_width) 378 | 379 | # Determine upper right and lower left corners of patch 380 | upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2)) 381 | lower_coord = (min(img_height, height_loc + size // 2), 382 | min(img_width, width_loc + size // 2)) 383 | mask_height = lower_coord[0] - upper_coord[0] 384 | mask_width = lower_coord[1] - upper_coord[1] 385 | assert mask_height > 0 386 | assert mask_width > 0 387 | 388 | mask = np.ones((img_height, img_width, num_channels)) 389 | zeros = np.zeros((mask_height, mask_width, num_channels)) 390 | mask[upper_coord[0]:lower_coord[0], upper_coord[1]:lower_coord[1], :] = ( 391 | zeros) 392 | return mask, upper_coord, lower_coord 393 | 394 | def _cutout_pil_impl(pil_img, level): 395 | """Apply cutout to pil_img at the specified level.""" 396 | size = int_parameter(level, 20) 397 | if size <= 0: 398 | return pil_img 399 | img_height, img_width, num_channels = (32, 32, 3) 400 | _, upper_coord, lower_coord = ( 401 | create_cutout_mask(img_height, img_width, num_channels, size)) 402 | pixels = pil_img.load() # create the pixel map 403 | for i in range(upper_coord[0], lower_coord[0]): # for every col: 404 | for j in range(upper_coord[1], lower_coord[1]): # For every row 405 | pixels[i, j] = (125, 122, 113, 0) # set the colour accordingly 406 | return pil_img 407 | 408 | cutout = TransformT('Cutout', _cutout_pil_impl) 409 | 410 | 411 | 412 | ALL_TRANSFORMS = [ 413 | identity, 414 | auto_contrast, 415 | equalize, 416 | rotate, 417 | posterize, 418 | solarize, 419 | color, 420 | contrast, 421 | brightness, 
422 | sharpness, 423 | shear_x, 424 | shear_y, 425 | translate_x, 426 | translate_y, 427 | ] 428 | 429 | AA_ALL_TRANSFORMS = [ 430 | flip_lr, 431 | flip_ud, 432 | auto_contrast, 433 | equalize, 434 | invert, 435 | rotate, 436 | posterize, 437 | crop_bilinear, 438 | solarize, 439 | color, 440 | contrast, 441 | brightness, 442 | sharpness, 443 | shear_x, 444 | shear_y, 445 | translate_x, 446 | translate_y, 447 | cutout, 448 | blur, 449 | smooth 450 | ] 451 | 452 | 453 | fixed_AA_ALL_TRANSFORMS = [ 454 | flip_lr, 455 | flip_ud, 456 | auto_contrast, 457 | equalize, 458 | invert, 459 | rotate, 460 | fixed_posterize, 461 | crop_bilinear, 462 | solarize, 463 | color, 464 | contrast, 465 | brightness, 466 | sharpness, 467 | shear_x, 468 | shear_y, 469 | translate_x, 470 | translate_y, 471 | cutout, 472 | blur, 473 | smooth 474 | ] 475 | 476 | 477 | class RandAugment: 478 | def __init__(self, n, m): 479 | self.n = n 480 | self.m = m # [0, 30] 481 | 482 | def __call__(self, img): 483 | img = pil_wrap(img) 484 | ops = random.choices(ALL_TRANSFORMS, k=self.n) 485 | for op in ops: 486 | img = op.pil_transformer(1.,self.m)(img) 487 | img = pil_unwrap(img) 488 | 489 | return img 490 | 491 | AA_NAME_TO_TRANSFORM = {t.name: t for t in AA_ALL_TRANSFORMS} 492 | fixed_AA_NAME_TO_TRANSFORM = {t.name: t for t in fixed_AA_ALL_TRANSFORMS} 493 | 494 | NAME_TO_TRANSFORM = {t.name: t for t in ALL_TRANSFORMS} 495 | 496 | def good_policies(): 497 | """AutoAugment policies found on Cifar.""" 498 | exp0_0 = [ 499 | [('Invert', 0.1, 7), ('Contrast', 0.2, 6)], 500 | [('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)], 501 | [('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)], 502 | [('ShearY', 0.5, 8), ('TranslateY', 0.7, 9)], 503 | [('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)]] 504 | exp0_1 = [ 505 | [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)], 506 | [('TranslateY', 0.9, 9), ('TranslateY', 0.7, 9)], 507 | [('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)], 508 | [('Equalize', 0.8, 8), ('Invert', 0.1, 3)], 509 | [('TranslateY', 0.7, 9), ('AutoContrast', 0.9, 1)]] 510 | exp0_2 = [ 511 | [('Solarize', 0.4, 5), ('AutoContrast', 0.0, 2)], 512 | [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)], 513 | [('AutoContrast', 0.9, 0), ('Solarize', 0.4, 3)], 514 | [('Equalize', 0.7, 5), ('Invert', 0.1, 3)], 515 | [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)]] 516 | exp0_3 = [ 517 | [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 1)], 518 | [('TranslateY', 0.8, 9), ('TranslateY', 0.9, 9)], 519 | [('AutoContrast', 0.8, 0), ('TranslateY', 0.7, 9)], 520 | [('TranslateY', 0.2, 7), ('Color', 0.9, 6)], 521 | [('Equalize', 0.7, 6), ('Color', 0.4, 9)]] 522 | exp1_0 = [ 523 | [('ShearY', 0.2, 7), ('Posterize', 0.3, 7)], 524 | [('Color', 0.4, 3), ('Brightness', 0.6, 7)], 525 | [('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)], 526 | [('Equalize', 0.6, 5), ('Equalize', 0.5, 1)], 527 | [('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)]] 528 | exp1_1 = [ 529 | [('Brightness', 0.3, 7), ('AutoContrast', 0.5, 8)], 530 | [('AutoContrast', 0.9, 4), ('AutoContrast', 0.5, 6)], 531 | [('Solarize', 0.3, 5), ('Equalize', 0.6, 5)], 532 | [('TranslateY', 0.2, 4), ('Sharpness', 0.3, 3)], 533 | [('Brightness', 0.0, 8), ('Color', 0.8, 8)]] 534 | exp1_2 = [ 535 | [('Solarize', 0.2, 6), ('Color', 0.8, 6)], 536 | [('Solarize', 0.2, 6), ('AutoContrast', 0.8, 1)], 537 | [('Solarize', 0.4, 1), ('Equalize', 0.6, 5)], 538 | [('Brightness', 0.0, 0), ('Solarize', 0.5, 2)], 539 | [('AutoContrast', 0.9, 5), ('Brightness', 0.5, 3)]] 540 | exp1_3 = [ 541 | [('Contrast', 0.7, 5), 
('Brightness', 0.0, 2)], 542 | [('Solarize', 0.2, 8), ('Solarize', 0.1, 5)], 543 | [('Contrast', 0.5, 1), ('TranslateY', 0.2, 9)], 544 | [('AutoContrast', 0.6, 5), ('TranslateY', 0.0, 9)], 545 | [('AutoContrast', 0.9, 4), ('Equalize', 0.8, 4)]] 546 | exp1_4 = [ 547 | [('Brightness', 0.0, 7), ('Equalize', 0.4, 7)], 548 | [('Solarize', 0.2, 5), ('Equalize', 0.7, 5)], 549 | [('Equalize', 0.6, 8), ('Color', 0.6, 2)], 550 | [('Color', 0.3, 7), ('Color', 0.2, 4)], 551 | [('AutoContrast', 0.5, 2), ('Solarize', 0.7, 2)]] 552 | exp1_5 = [ 553 | [('AutoContrast', 0.2, 0), ('Equalize', 0.1, 0)], 554 | [('ShearY', 0.6, 5), ('Equalize', 0.6, 5)], 555 | [('Brightness', 0.9, 3), ('AutoContrast', 0.4, 1)], 556 | [('Equalize', 0.8, 8), ('Equalize', 0.7, 7)], 557 | [('Equalize', 0.7, 7), ('Solarize', 0.5, 0)]] 558 | exp1_6 = [ 559 | [('Equalize', 0.8, 4), ('TranslateY', 0.8, 9)], 560 | [('TranslateY', 0.8, 9), ('TranslateY', 0.6, 9)], 561 | [('TranslateY', 0.9, 0), ('TranslateY', 0.5, 9)], 562 | [('AutoContrast', 0.5, 3), ('Solarize', 0.3, 4)], 563 | [('Solarize', 0.5, 3), ('Equalize', 0.4, 4)]] 564 | exp2_0 = [ 565 | [('Color', 0.7, 7), ('TranslateX', 0.5, 8)], 566 | [('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)], 567 | [('TranslateY', 0.4, 3), ('Sharpness', 0.2, 6)], 568 | [('Brightness', 0.9, 6), ('Color', 0.2, 8)], 569 | [('Solarize', 0.5, 2), ('Invert', 0.0, 3)]] 570 | exp2_1 = [ 571 | [('AutoContrast', 0.1, 5), ('Brightness', 0.0, 0)], 572 | [('Cutout', 0.2, 4), ('Equalize', 0.1, 1)], 573 | [('Equalize', 0.7, 7), ('AutoContrast', 0.6, 4)], 574 | [('Color', 0.1, 8), ('ShearY', 0.2, 3)], 575 | [('ShearY', 0.4, 2), ('Rotate', 0.7, 0)]] 576 | exp2_2 = [ 577 | [('ShearY', 0.1, 3), ('AutoContrast', 0.9, 5)], 578 | [('TranslateY', 0.3, 6), ('Cutout', 0.3, 3)], 579 | [('Equalize', 0.5, 0), ('Solarize', 0.6, 6)], 580 | [('AutoContrast', 0.3, 5), ('Rotate', 0.2, 7)], 581 | [('Equalize', 0.8, 2), ('Invert', 0.4, 0)]] 582 | exp2_3 = [ 583 | [('Equalize', 0.9, 5), ('Color', 0.7, 0)], 584 | [('Equalize', 0.1, 1), ('ShearY', 0.1, 3)], 585 | [('AutoContrast', 0.7, 3), ('Equalize', 0.7, 0)], 586 | [('Brightness', 0.5, 1), ('Contrast', 0.1, 7)], 587 | [('Contrast', 0.1, 4), ('Solarize', 0.6, 5)]] 588 | exp2_4 = [ 589 | [('Solarize', 0.2, 3), ('ShearX', 0.0, 0)], 590 | [('TranslateX', 0.3, 0), ('TranslateX', 0.6, 0)], 591 | [('Equalize', 0.5, 9), ('TranslateY', 0.6, 7)], 592 | [('ShearX', 0.1, 0), ('Sharpness', 0.5, 1)], 593 | [('Equalize', 0.8, 6), ('Invert', 0.3, 6)]] 594 | exp2_5 = [ 595 | [('AutoContrast', 0.3, 9), ('Cutout', 0.5, 3)], 596 | [('ShearX', 0.4, 4), ('AutoContrast', 0.9, 2)], 597 | [('ShearX', 0.0, 3), ('Posterize', 0.0, 3)], 598 | [('Solarize', 0.4, 3), ('Color', 0.2, 4)], 599 | [('Equalize', 0.1, 4), ('Equalize', 0.7, 6)]] 600 | exp2_6 = [ 601 | [('Equalize', 0.3, 8), ('AutoContrast', 0.4, 3)], 602 | [('Solarize', 0.6, 4), ('AutoContrast', 0.7, 6)], 603 | [('AutoContrast', 0.2, 9), ('Brightness', 0.4, 8)], 604 | [('Equalize', 0.1, 0), ('Equalize', 0.0, 6)], 605 | [('Equalize', 0.8, 4), ('Equalize', 0.0, 4)]] 606 | exp2_7 = [ 607 | [('Equalize', 0.5, 5), ('AutoContrast', 0.1, 2)], 608 | [('Solarize', 0.5, 5), ('AutoContrast', 0.9, 5)], 609 | [('AutoContrast', 0.6, 1), ('AutoContrast', 0.7, 8)], 610 | [('Equalize', 0.2, 0), ('AutoContrast', 0.1, 2)], 611 | [('Equalize', 0.6, 9), ('Equalize', 0.4, 4)]] 612 | exp0s = exp0_0 + exp0_1 + exp0_2 + exp0_3 613 | exp1s = exp1_0 + exp1_1 + exp1_2 + exp1_3 + exp1_4 + exp1_5 + exp1_6 614 | exp2s = exp2_0 + exp2_1 + exp2_2 + exp2_3 + exp2_4 + exp2_5 + exp2_6 + 
exp2_7 615 | return exp0s + exp1s + exp2s 616 | 617 | cifar_gp = good_policies() 618 | 619 | first_aug_ops = [("ShearX",0.9,4), ("ShearY",0.9,8), ("Equalize",0.6,5), ("Invert",0.9,3), ("Equalize",0.6,1), ("ShearX",0.9,4), ("ShearY",0.9,8), ("ShearY",0.9,5), ("Invert",0.9,6), ("Equalize",0.6,3), ("ShearX",0.9,4), ("ShearY",0.8,8), ("Equalize",0.9,5), ("Invert",0.9,4), ("Contrast",0.3,3), ("Invert",0.8,5), ("ShearY",0.7,6), ("Invert",0.6,4), ("ShearY",0.3,7), ("ShearX",0.1,6), ("Solarize",0.7,2), ("ShearY",0.8,4), ("ShearX",0.7,9), ("ShearY",0.8,5), ("ShearX",0.7,2)] 620 | second_aug_ops = [("Invert",0.2,3), ("Invert",0.7,5), ("Solarize",0.6,6), ("Equalize",0.6,3), ("Rotate",0.9,3), ("AutoContrast",0.8,3), ("Invert",0.4,5), ("Solarize",0.2,6), ("AutoContrast",0.8,1), ("Rotate",0.9,3), ("Solarize",0.3,3), ("Invert",0.7,4), ("TranslateY",0.6,6), ("Equalize",0.6,7), ("Rotate",0.8,4), ("TranslateY",0.0,2), ("Solarize",0.4,8), ("Rotate",0.8,4), ("TranslateX",0.9,3), ("Invert",0.6,5), ("TranslateY",0.6,7), ("Invert",0.8,8), ("TranslateY",0.8,3), ("AutoContrast",0.7,3), ("Invert",0.1,5)] 621 | 622 | svhn_gp = [[a1, a2] for a1, a2 in zip(first_aug_ops,second_aug_ops)] 623 | 624 | class CifarAutoAugment: 625 | def __init__(self, fixed_posterize): 626 | self.fixed_posterize = fixed_posterize 627 | 628 | def __call__(self, img): 629 | epoch_policy = cifar_gp[np.random.choice(len(cifar_gp))] 630 | final_img = apply_policy(epoch_policy, img, use_fixed_posterize=self.fixed_posterize) 631 | 632 | return final_img 633 | 634 | class SVHNAutoAugment: 635 | def __init__(self, fixed_posterize): 636 | self.fixed_posterize = fixed_posterize 637 | 638 | def __call__(self, img): 639 | epoch_policy = svhn_gp[np.random.choice(len(svhn_gp))] 640 | final_img = apply_policy(epoch_policy, img, use_fixed_posterize=self.fixed_posterize) 641 | 642 | return final_img -------------------------------------------------------------------------------- /DeepAA_evaluate/common.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | import random 4 | from copy import copy 5 | from typing import Union 6 | from collections import Counter 7 | 8 | import numpy as np 9 | import torch 10 | from torch.utils.checkpoint import check_backward_validity, detach_variable, get_device_states, set_device_states 11 | from torchvision.datasets import VisionDataset, CIFAR10, CIFAR100, ImageFolder 12 | from torch.utils.data import Subset, ConcatDataset 13 | 14 | from PIL import Image 15 | 16 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s') 17 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 18 | 19 | 20 | def get_logger(name, level=logging.DEBUG): 21 | logger = logging.getLogger(name) 22 | logger.handlers.clear() 23 | logger.setLevel(level) 24 | ch = logging.StreamHandler() 25 | ch.setLevel(level) 26 | ch.setFormatter(formatter) 27 | logger.addHandler(ch) 28 | return logger 29 | 30 | 31 | def add_filehandler(logger, filepath): 32 | fh = logging.FileHandler(filepath) 33 | fh.setLevel(logging.DEBUG) 34 | fh.setFormatter(formatter) 35 | logger.addHandler(fh) 36 | 37 | 38 | def copy_and_replace_transform(ds: Union[CIFAR10, ImageFolder, Subset], transform): 39 | assert ds.dataset.transform is not None if isinstance(ds,Subset) else (all(d.transform is not None for d in ds.datasets) if isinstance(ds,ConcatDataset) else ds.transform is not None) # make sure still uses old style transform 40 | if isinstance(ds, Subset): 
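# A Subset wraps an inner dataset: shallow-copy that inner dataset, attach the new
# transform to the copy, and re-point a copied Subset at it, so the original dataset
# (and any loader still using it) keeps its old transform.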
41 | new_super_ds = copy(ds.dataset) 42 | new_super_ds.transform = transform 43 | new_ds = copy(ds) 44 | new_ds.dataset = new_super_ds 45 | elif isinstance(ds, ConcatDataset): 46 | def copy_and_replace_transform(ds): 47 | new_ds = copy(ds) 48 | new_ds.transform = transform 49 | return new_ds 50 | 51 | new_ds = ConcatDataset([copy_and_replace_transform(d) for d in ds.datasets]) 52 | 53 | else: 54 | new_ds = copy(ds) 55 | new_ds.transform = transform 56 | return new_ds 57 | 58 | def apply_weightnorm(nn): 59 | def apply_weightnorm_(module): 60 | if 'Linear' in type(module).__name__ or 'Conv' in type(module).__name__: 61 | torch.nn.utils.weight_norm(module, name='weight', dim=0) 62 | nn.apply(apply_weightnorm_) 63 | 64 | 65 | def shufflelist_with_seed(lis, seed='2020'): 66 | s = random.getstate() 67 | random.seed(seed) 68 | random.shuffle(lis) 69 | random.setstate(s) 70 | 71 | 72 | def stratified_split(labels, val_share): 73 | assert isinstance(labels, list) 74 | counter = Counter(labels) 75 | indices_per_label = {label: [i for i,l in enumerate(labels) if l == label] for label in counter} 76 | per_label_split = {} 77 | for label, count in counter.items(): 78 | indices = indices_per_label[label] 79 | assert count == len(indices) 80 | shufflelist_with_seed(indices, f'2020_{label}_{count}') 81 | train_val_border = round(count*(1.-val_share)) 82 | per_label_split[label] = (indices[:train_val_border], indices[train_val_border:]) 83 | final_split = ([],[]) 84 | for label, split in per_label_split.items(): 85 | for f_s, s in zip(final_split, split): 86 | f_s.extend(s) 87 | shufflelist_with_seed(final_split[0], '2020_yoyo') 88 | shufflelist_with_seed(final_split[1], '2020_yo') 89 | return final_split 90 | 91 | 92 | def denormalize(img, mean, std): 93 | mean, std = torch.tensor(mean).to(img.device), torch.tensor(std).to(img.device) 94 | return img.mul_(std[:,None,None]).add_(mean[:,None,None]) 95 | 96 | def normalize(img, mean, std): 97 | mean, std = torch.tensor(mean).to(img.device), torch.tensor(std).to(img.device) 98 | return img.sub_(mean[:,None,None]).div_(std[:,None,None]) -------------------------------------------------------------------------------- /DeepAA_evaluate/data.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | from collections import Counter 5 | 6 | import torchvision 7 | from PIL import Image 8 | 9 | from torch.utils.data import SubsetRandomSampler, Sampler 10 | from torch.utils.data.distributed import DistributedSampler 11 | from torch.utils.data.dataset import ConcatDataset, Subset 12 | from torchvision.transforms import transforms 13 | from sklearn.model_selection import StratifiedShuffleSplit 14 | from theconf import Config as C 15 | 16 | from DeepAA_evaluate.augmentations import * 17 | from DeepAA_evaluate.common import get_logger, copy_and_replace_transform, stratified_split, denormalize 18 | from DeepAA_evaluate.imagenet import ImageNet 19 | 20 | from DeepAA_evaluate.augmentations import Lighting 21 | 22 | from DeepAA_evaluate.deep_autoaugment import Augmentation_DeepAA 23 | 24 | logger = get_logger('DeepAA_evaluate') 25 | logger.setLevel(logging.INFO) 26 | _IMAGENET_PCA = { 27 | 'eigval': [0.2175, 0.0188, 0.0045], 28 | 'eigvec': [ 29 | [-0.5675, 0.7192, 0.4009], 30 | [-0.5808, -0.0045, -0.8140], 31 | [-0.5836, -0.6948, 0.4203], 32 | ] 33 | } 34 | _CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) # these are for CIFAR 10, not for cifar100 actaully. 
They are pretty similar, though. 35 | # mean für cifar 100: tensor([0.5071, 0.4866, 0.4409]) 36 | 37 | def expand(num_classes, dtype, tensor): 38 | e = torch.zeros( 39 | tensor.size(0), num_classes, dtype=dtype, device=torch.device("cuda") 40 | ) 41 | e = e.scatter(1, tensor.unsqueeze(1), 1.0) 42 | return e 43 | 44 | def mixup_data(data, label, alpha): 45 | with torch.no_grad(): 46 | if alpha > 0: 47 | lam = np.random.beta(alpha, alpha) 48 | else: 49 | lam = 1.0 50 | batch_size = data.size()[0] 51 | index = torch.randperm(batch_size).to(data.device) 52 | mixed_data = lam * data + (1.0-lam) * data[index,:] 53 | return mixed_data, label, label[index], lam 54 | 55 | 56 | class PrefetchedWrapper(object): 57 | # Ref: https://github.com/NVIDIA/DeepLearningExamples/blob/d788e8d4968e72c722c5148a50a7d4692f6e7bd3/PyTorch/Classification/ConvNets/image_classification/dataloaders.py#L405 58 | def prefetched_loader(loader, num_classes, one_hot): 59 | mean = ( 60 | torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]) 61 | .cuda() 62 | .view(1, 3, 1, 1) 63 | ) 64 | std = ( 65 | torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]) 66 | .cuda() 67 | .view(1, 3, 1, 1) 68 | ) 69 | 70 | stream = torch.cuda.Stream() 71 | first = True 72 | 73 | for next_input, next_target in loader: 74 | with torch.cuda.stream(stream): 75 | next_input = next_input.cuda(non_blocking=True) 76 | next_target = next_target.cuda(non_blocking=True) 77 | next_input = next_input.float() 78 | if one_hot: 79 | raise Exception('Currently do not use onehot encoding, becasue num_calsses==None') 80 | next_target = expand(num_classes, torch.float, next_target) 81 | 82 | next_input = next_input.sub_(mean).div_(std) 83 | 84 | if not first: 85 | yield input, target 86 | else: 87 | first = False 88 | 89 | torch.cuda.current_stream().wait_stream(stream) 90 | input = next_input 91 | target = next_target 92 | 93 | yield input, target 94 | 95 | def __init__(self, dataloader, start_epoch, num_classes, one_hot): 96 | self.dataloader = dataloader 97 | self.epoch = start_epoch 98 | self.one_hot = one_hot 99 | self.num_classes = num_classes 100 | 101 | def __iter__(self): 102 | if self.dataloader.sampler is not None and isinstance( 103 | self.dataloader.sampler, torch.utils.data.distributed.DistributedSampler 104 | ): 105 | 106 | self.dataloader.sampler.set_epoch(self.epoch) 107 | self.epoch += 1 108 | return PrefetchedWrapper.prefetched_loader( 109 | self.dataloader, self.num_classes, self.one_hot 110 | ) 111 | 112 | def __len__(self): 113 | return len(self.dataloader) 114 | 115 | def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, distributed=False, started_with_spawn=False, summary_writer=None): 116 | print(f'started with spawn {started_with_spawn}') 117 | dataset_info = {} 118 | pre_transform_train = transforms.Compose([]) 119 | if 'cifar' in dataset and (C.get()['aug'] in ['DeepAA']): 120 | transform_train = transforms.Compose([ 121 | # transforms.RandomCrop(32, padding=4), 122 | # transforms.RandomHorizontalFlip(), 123 | transforms.ToTensor(), 124 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 125 | ]) 126 | transform_test = transforms.Compose([ 127 | transforms.ToTensor(), 128 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 129 | ]) 130 | dataset_info['mean'] = _CIFAR_MEAN 131 | dataset_info['std'] = _CIFAR_STD 132 | dataset_info['img_dims'] = (3,32,32) 133 | dataset_info['num_labels'] = 100 if '100' in dataset and 'ten' not in dataset else 10 134 | elif 'cifar' in dataset: 135 | transform_train = transforms.Compose([ 136 | 
transforms.RandomCrop(32, padding=4), 137 | transforms.RandomHorizontalFlip(), 138 | transforms.ToTensor(), 139 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 140 | ]) 141 | transform_test = transforms.Compose([ 142 | transforms.ToTensor(), 143 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 144 | ]) 145 | dataset_info['mean'] = _CIFAR_MEAN 146 | dataset_info['std'] = _CIFAR_STD 147 | dataset_info['img_dims'] = (3,32,32) 148 | dataset_info['num_labels'] = 100 if '100' in dataset and 'ten' not in dataset else 10 149 | elif 'pre_transform_cifar' in dataset: 150 | pre_transform_train = transforms.Compose([ 151 | transforms.RandomCrop(32, padding=4), 152 | transforms.RandomHorizontalFlip(),]) 153 | transform_train = transforms.Compose([ 154 | transforms.ToTensor(), 155 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 156 | ]) 157 | transform_test = transforms.Compose([ 158 | transforms.ToTensor(), 159 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 160 | ]) 161 | dataset_info['mean'] = _CIFAR_MEAN 162 | dataset_info['std'] = _CIFAR_STD 163 | dataset_info['img_dims'] = (3, 32, 32) 164 | dataset_info['num_labels'] = 100 if '100' in dataset and 'ten' not in dataset else 10 165 | elif 'svhn' in dataset: 166 | svhn_mean = [0.4379, 0.4440, 0.4729] 167 | svhn_std = [0.1980, 0.2010, 0.1970] 168 | transform_train = transforms.Compose([ 169 | transforms.ToTensor(), 170 | transforms.Normalize(svhn_mean, svhn_std), 171 | ]) 172 | transform_test = transforms.Compose([ 173 | transforms.ToTensor(), 174 | transforms.Normalize(svhn_mean, svhn_std), 175 | ]) 176 | dataset_info['mean'] = svhn_mean 177 | dataset_info['std'] = svhn_std 178 | dataset_info['img_dims'] = (3, 32, 32) 179 | dataset_info['num_labels'] = 10 180 | elif 'imagenet' in dataset and C.get()['aug'] in ['DeepAA']: 181 | transform_train = transforms.Compose([ 182 | transforms.ToTensor(), 183 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Image size (224, 224) instead of (224, 244) in TA 184 | ]) 185 | 186 | transform_test = transforms.Compose([ 187 | transforms.Resize(256, interpolation=Image.BICUBIC), 188 | transforms.CenterCrop((224,224)), 189 | transforms.ToTensor(), 190 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 191 | ]) 192 | dataset_info['mean'] = [0.485, 0.456, 0.406] 193 | dataset_info['std'] = [0.229, 0.224, 0.225] 194 | dataset_info['img_dims'] = (3,224,224) 195 | dataset_info['num_labels'] = 1000 196 | elif 'imagenet' in dataset and C.get()['aug']=='inception': 197 | transform_train = transforms.Compose([ 198 | transforms.RandomResizedCrop((224,224), scale=(0.08, 1.0), interpolation=Image.BICUBIC), # Image size (224, 224) instead of (224, 244) in TA 199 | transforms.RandomHorizontalFlip(), 200 | transforms.ColorJitter( 201 | brightness=0.4, 202 | contrast=0.4, 203 | saturation=0.4, 204 | ), 205 | transforms.ToTensor(), 206 | Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']), 207 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 208 | ]) 209 | 210 | transform_test = transforms.Compose([ 211 | transforms.Resize(256, interpolation=Image.BICUBIC), 212 | transforms.CenterCrop((224,224)), 213 | transforms.ToTensor(), 214 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 215 | ]) 216 | dataset_info['mean'] = [0.485, 0.456, 0.406] 217 | dataset_info['std'] = [0.229, 0.224, 0.225] 218 | dataset_info['img_dims'] = (3,224,224) 219 | dataset_info['num_labels'] = 1000 220 | elif 'smallwidth_imagenet' in 
dataset: 221 | transform_train = transforms.Compose([ 222 | transforms.RandomResizedCrop((224,224), scale=(0.08, 1.0), interpolation=Image.BICUBIC), 223 | transforms.RandomHorizontalFlip(), 224 | transforms.ColorJitter( 225 | brightness=0.4, 226 | contrast=0.4, 227 | saturation=0.4, 228 | ), 229 | transforms.ToTensor(), 230 | Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']), 231 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 232 | ]) 233 | 234 | transform_test = transforms.Compose([ 235 | transforms.Resize(256, interpolation=Image.BICUBIC), 236 | transforms.CenterCrop((224,224)), 237 | transforms.ToTensor(), 238 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 239 | ]) 240 | dataset_info['mean'] = [0.485, 0.456, 0.406] 241 | dataset_info['std'] = [0.229, 0.224, 0.225] 242 | dataset_info['img_dims'] = (3,224,224) 243 | dataset_info['num_labels'] = 1000 244 | elif 'ohl_pipeline_imagenet' in dataset: 245 | pre_transform_train = transforms.Compose([ 246 | transforms.RandomResizedCrop((224, 224), scale=(0.08, 1.0), interpolation=Image.BICUBIC), 247 | transforms.RandomHorizontalFlip(), 248 | ]) 249 | transform_train = transforms.Compose([ 250 | transforms.ToTensor(), 251 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[1.,1.,1.]) 252 | ]) 253 | 254 | transform_test = transforms.Compose([ 255 | transforms.Resize(256, interpolation=Image.BICUBIC), 256 | transforms.CenterCrop((224,224)), 257 | transforms.ToTensor(), 258 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[1.,1.,1.]) 259 | ]) 260 | dataset_info['mean'] = [0.485, 0.456, 0.406] 261 | dataset_info['std'] = [1.,1.,1.] 262 | dataset_info['img_dims'] = (3,224,224) 263 | dataset_info['num_labels'] = 1000 264 | elif 'largewidth_imagenet' in dataset: 265 | transform_train = transforms.Compose([ 266 | transforms.RandomResizedCrop((224, 244), scale=(0.08, 1.0), interpolation=Image.BICUBIC), 267 | transforms.RandomHorizontalFlip(), 268 | transforms.ColorJitter( 269 | brightness=0.4, 270 | contrast=0.4, 271 | saturation=0.4, 272 | ), 273 | transforms.ToTensor(), 274 | Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']), 275 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 276 | ]) 277 | 278 | transform_test = transforms.Compose([ 279 | transforms.Resize(256, interpolation=Image.BICUBIC), 280 | transforms.CenterCrop((224, 244)), 281 | transforms.ToTensor(), 282 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 283 | ]) 284 | dataset_info['mean'] = [0.485, 0.456, 0.406] 285 | dataset_info['std'] = [0.229, 0.224, 0.225] 286 | dataset_info['img_dims'] = (3, 224, 244) 287 | dataset_info['num_labels'] = 1000 288 | else: 289 | raise ValueError('dataset=%s' % dataset) 290 | 291 | logger.debug('augmentation: %s' % C.get()['aug']) 292 | if C.get()['aug'] == 'randaugment': 293 | assert not C.get()['randaug'].get('corrected_sample_space') and not C.get()['randaug'].get('google_augmentations') 294 | transform_train.transforms.insert(0, get_randaugment(n=C.get()['randaug']['N'], m=C.get()['randaug']['M'], 295 | weights=C.get()['randaug'].get('weights',None), bs=C.get()['batch'])) 296 | elif C.get()['aug'] in ['default', 'inception', 'inception320']: 297 | pass 298 | elif C.get()['aug'] in ['DeepAA']: 299 | transform_train.transforms.insert(0, Augmentation_DeepAA(EXP = C.get()['deepaa']['EXP'], 300 | use_crop = ('imagenet' in dataset) and C.get()['aug'] == 'DeepAA' 301 | )) 302 | else: 303 | raise 
ValueError('not found augmentations. %s' % C.get()['aug']) 304 | 305 | transform_train.transforms.insert(0, pre_transform_train) 306 | 307 | if C.get()['cutout'] > 0: 308 | transform_train.transforms.append(CutoutDefault(C.get()['cutout'])) 309 | 310 | if 'preprocessor' in C.get(): 311 | if 'imagenet' in dataset: 312 | print("Only using cropping/centering transforms on dataset, since preprocessor active.") 313 | transform_train = transforms.Compose([ 314 | transforms.RandomResizedCrop(224, scale=(0.08, 1.0), interpolation=Image.BICUBIC), 315 | PILImageToHWCByteTensor(), 316 | ]) 317 | 318 | transform_test = transforms.Compose([ 319 | transforms.Resize(256, interpolation=Image.BICUBIC), 320 | transforms.CenterCrop(224), 321 | PILImageToHWCByteTensor(), 322 | ]) 323 | else: 324 | print("Not using any transforms in dataset, since preprocessor is active.") 325 | transform_train = PILImageToHWCByteTensor() 326 | transform_test = PILImageToHWCByteTensor() 327 | 328 | if dataset in ('cifar10', 'pre_transform_cifar10'): 329 | total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train) 330 | testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test) 331 | elif dataset in ('cifar100', 'pre_transform_cifar100'): 332 | total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=True, transform=transform_train) 333 | testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=True, transform=transform_test) 334 | elif dataset == 'svhncore': 335 | total_trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, 336 | transform=transform_train) 337 | testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test) 338 | elif dataset == 'svhn': 339 | trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train) 340 | extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=True, transform=transform_train) 341 | total_trainset = ConcatDataset([trainset, extraset]) 342 | testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test) 343 | elif dataset in ('imagenet', 'ohl_pipeline_imagenet', 'smallwidth_imagenet'): 344 | # Ignore archive only means to not to try to extract the files again, because they already are and the zip files 345 | # are not there no more 346 | total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train, ignore_archive=True) 347 | testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test, ignore_archive=True) 348 | 349 | # compatibility 350 | total_trainset.targets = [lb for _, lb in total_trainset.samples] 351 | else: 352 | raise ValueError('invalid dataset name=%s' % dataset) 353 | 354 | if 'throwaway_share_of_ds' in C.get(): 355 | assert 'val_step_trainloader_val_share' not in C.get() 356 | share = C.get()['throwaway_share_of_ds']['throwaway_share'] 357 | train_subset_inds, rest_inds = stratified_split(total_trainset.targets if hasattr(total_trainset, 'targets') else list(total_trainset.labels),share) 358 | if C.get()['throwaway_share_of_ds']['use_throwaway_as_val']: 359 | testset = copy_and_replace_transform(Subset(total_trainset, rest_inds), transform_test) 360 | total_trainset = Subset(total_trainset, train_subset_inds) 361 | 362 | train_sampler = None 363 | if split > 0.0: 364 | 
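# Carve out a stratified train/validation split: five candidate shuffles are generated
# with a fixed seed and `split_idx` selects which one is used; `train_idx` feeds the
# SubsetRandomSampler below and `valid_idx` the (sequential) SubsetSampler.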
sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=0) 365 | sss = sss.split(list(range(len(total_trainset))), total_trainset.targets) 366 | for _ in range(split_idx + 1): 367 | train_idx, valid_idx = next(sss) 368 | 369 | train_sampler = SubsetRandomSampler(train_idx) 370 | valid_sampler = SubsetSampler(valid_idx) 371 | else: 372 | valid_sampler = SubsetSampler([]) 373 | 374 | if distributed: 375 | assert split == 0.0, "Split not supported for distributed training." 376 | if C.get().get('all_workers_use_the_same_batches', False): 377 | train_sampler = DistributedSampler(total_trainset, num_replicas=1, rank=0) 378 | else: 379 | train_sampler = DistributedSampler(total_trainset) 380 | test_sampler = None 381 | test_train_sampler = None # if these are specified, acc/loss computation is wrong for results. 382 | # while one has to say, that this setting leads to the test sets being computed seperately on each gpu which 383 | # might be considered not-very-climate-friendly 384 | else: 385 | test_sampler = None 386 | test_train_sampler = None 387 | 388 | trainloader = torch.utils.data.DataLoader( 389 | total_trainset, batch_size=batch, shuffle=train_sampler is None, num_workers= os.cpu_count()//8 if distributed else 32, # fix the data laoder 390 | pin_memory=True, 391 | sampler=train_sampler, drop_last=True, persistent_workers=True) 392 | validloader = torch.utils.data.DataLoader( 393 | total_trainset, batch_size=batch, shuffle=False, num_workers=0 if started_with_spawn else 8, pin_memory=True, 394 | sampler=valid_sampler, drop_last=False) 395 | 396 | testloader = torch.utils.data.DataLoader( 397 | testset, batch_size=batch, shuffle=False, num_workers=16 if started_with_spawn else 8, pin_memory=True, 398 | drop_last=False, sampler=test_sampler, persistent_workers=True 399 | ) 400 | # We use this 'hacky' solution s.t. we do not need to keep the dataset twice in memory. 401 | test_total_trainset = copy_and_replace_transform(total_trainset, transform_test) 402 | test_trainloader = torch.utils.data.DataLoader( 403 | test_total_trainset, batch_size=batch, shuffle=False, num_workers=0 if started_with_spawn else 8, pin_memory=True, 404 | drop_last=False, sampler=test_train_sampler 405 | ) 406 | test_trainloader.denorm = lambda x: denormalize(x, dataset_info['mean'], dataset_info['std']) 407 | 408 | return train_sampler, trainloader, validloader, testloader, test_trainloader, dataset_info 409 | # trainloader_prefetch = PrefetchedWrapper(trainloader, start_epoch=0, num_classes=None, one_hot=False) 410 | # testloader_prefetch = PrefetchedWrapper(testloader, start_epoch=0, num_classes=None, one_hot=False) 411 | # return train_sampler, trainloader_prefetch, validloader, testloader_prefetch, test_trainloader, dataset_info 412 | 413 | 414 | class SubsetSampler(Sampler): 415 | r"""Samples elements from a given list of indices, without replacement. 
416 | 417 | Arguments: 418 | indices (sequence): a sequence of indices 419 | """ 420 | 421 | def __init__(self, indices): 422 | self.indices = indices 423 | 424 | def __iter__(self): 425 | return (i for i in self.indices) 426 | 427 | def __len__(self): 428 | return len(self.indices) -------------------------------------------------------------------------------- /DeepAA_evaluate/deep_autoaugment.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | import random 4 | import math 5 | 6 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 7 | import numpy as np 8 | import torch 9 | import os 10 | import json 11 | import hashlib 12 | import requests 13 | import scipy 14 | from torchvision.transforms.transforms import Compose 15 | 16 | random_mirror = True 17 | 18 | ########################################################################## 19 | CIFAR_MEANS = np.array([0.49139968, 0.48215841, 0.44653091], dtype=np.float32) 20 | # CIFAR10_STDS = np.array([0.24703223, 0.24348513, 0.26158784], dtype=np.float32) 21 | CIFAR_STDS = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32) 22 | 23 | SVHN_MEANS = np.array([0.4379, 0.4440, 0.4729], dtype=np.float32) 24 | SVHN_STDS = np.array([0.1980, 0.2010, 0.1970], dtype=np.float32) 25 | 26 | IMAGENET_MEANS = np.array([0.485, 0.456, 0.406], dtype=np.float32) 27 | IMAGENET_STDS = np.array([0.229, 0.224, 0.225], dtype=np.float32) 28 | 29 | def ShearX(img, v): # [-0.3, 0.3] 30 | assert -0.3 <= v <= 0.3 31 | if random_mirror and random.random() > 0.5: 32 | v = -v 33 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 34 | 35 | 36 | def ShearY(img, v): # [-0.3, 0.3] 37 | assert -0.3 <= v <= 0.3 38 | if random_mirror and random.random() > 0.5: 39 | v = -v 40 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 41 | 42 | 43 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 44 | assert -0.45 <= v <= 0.45 45 | if random_mirror and random.random() > 0.5: 46 | v = -v 47 | v = v * img.size[0] 48 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 49 | 50 | 51 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 52 | assert -0.45 <= v <= 0.45 53 | if random_mirror and random.random() > 0.5: 54 | v = -v 55 | v = v * img.size[1] 56 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 57 | 58 | 59 | def TranslateXAbs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 60 | assert 0 <= v <= 10 61 | if random_mirror and random.random() > 0.5: 62 | v = -v 63 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 64 | 65 | 66 | def TranslateYAbs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 67 | assert 0 <= v <= 10 68 | if random_mirror and random.random() > 0.5: 69 | v = -v 70 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 71 | 72 | 73 | def Rotate(img, v): # [-30, 30] 74 | assert -30 <= v <= 30 75 | if random_mirror and random.random() > 0.5: 76 | v = -v 77 | return img.rotate(v) 78 | 79 | 80 | def AutoContrast(img, _): 81 | return PIL.ImageOps.autocontrast(img) 82 | 83 | 84 | def Invert(img, _): 85 | return PIL.ImageOps.invert(img) 86 | 87 | 88 | def Equalize(img, _): 89 | return PIL.ImageOps.equalize(img) 90 | 91 | 92 | def Flip(img, _): # not from the paper 93 | return PIL.ImageOps.mirror(img) 94 | 95 | 96 | def Solarize(img, v): # [0, 256] 97 | 
assert 0 <= v <= 256 98 | return PIL.ImageOps.solarize(img, v) 99 | 100 | 101 | def Posterize(img, v): # [4, 8] 102 | assert 4 <= v <= 8 103 | v = int(v) 104 | v = max(1, v) 105 | return PIL.ImageOps.posterize(img, v) 106 | 107 | 108 | def Posterize2(img, v): # [0, 4] 109 | assert 0 <= v <= 4 110 | v = int(v) 111 | return PIL.ImageOps.posterize(img, v) 112 | 113 | 114 | def Contrast(img, v): # [0.1,1.9] 115 | assert 0.1 <= v <= 1.9 116 | return PIL.ImageEnhance.Contrast(img).enhance(v) 117 | 118 | 119 | def Color(img, v): # [0.1,1.9] 120 | assert 0.1 <= v <= 1.9 121 | return PIL.ImageEnhance.Color(img).enhance(v) 122 | 123 | 124 | def Brightness(img, v): # [0.1,1.9] 125 | assert 0.1 <= v <= 1.9 126 | return PIL.ImageEnhance.Brightness(img).enhance(v) 127 | 128 | 129 | def Sharpness(img, v): # [0.1,1.9] 130 | assert 0.1 <= v <= 1.9 131 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 132 | 133 | 134 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 135 | assert 0.0 <= v <= 0.2 136 | if v <= 0.: 137 | return img 138 | 139 | v = v * img.size[0] 140 | return Cutout_default(img, v) 141 | 142 | 143 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 144 | # assert 0 <= v <= 20 145 | if v < 0: 146 | return img 147 | w, h = img.size 148 | # x0 = np.random.uniform(w) 149 | # y0 = np.random.uniform(h) 150 | x0 = random.uniform(0, w) 151 | y0 = random.uniform(0, h) 152 | 153 | x0 = int(max(0, x0 - v / 2.)) 154 | y0 = int(max(0, y0 - v / 2.)) 155 | x1 = min(w, x0 + v) 156 | y1 = min(h, y0 + v) 157 | 158 | xy = (x0, y0, x1, y1) 159 | # color = (125, 123, 114) 160 | color = (0, 0, 0) 161 | img = img.copy() 162 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 163 | return img 164 | 165 | 166 | def SamplePairing(imgs): # [0, 0.4] 167 | def f(img1, v): 168 | i = np.random.choice(len(imgs)) 169 | img2 = PIL.Image.fromarray(imgs[i]) 170 | return PIL.Image.blend(img1, img2, v) 171 | 172 | return f 173 | 174 | # =============== OPS for DeepAA ==============: 175 | def mean_pad_randcrop(img, v): 176 | # v: Pad with mean value=[125, 123, 114] by v pixels on each side and then take random crop 177 | assert v <= 10, 'The maximum shift should be less then 10' 178 | padded_size = (img.size[0] + 2*v, img.size[1] + 2*v) 179 | new_img = PIL.Image.new('RGB', padded_size, color=(125, 123, 114)) 180 | # new_img = PIL.Image.new('RGB', padded_size, color=(0, 0, 0)) 181 | new_img.paste(img, (v, v)) 182 | top = random.randint(0, v*2) 183 | left = random.randint(0, v*2) 184 | new_img = new_img.crop((left, top, left + img.size[0], top + img.size[1])) 185 | return new_img 186 | 187 | 188 | 189 | def Cutout_default(img, v): # Used in FastAA, different from CutoutABS, the actual cutout size can be smaller than v on the boundary 190 | # Passed random number generation test 191 | # assert 0 <= v <= 20 192 | if v < 0: 193 | return img 194 | w, h = img.size 195 | # x = np.random.uniform(w) 196 | # y = np.random.uniform(h) 197 | if v <= 16: # for cutout of cifar and SVHN 198 | assert w == h == 32 199 | x = random.uniform(0, w) 200 | y = random.uniform(0, h) 201 | 202 | x0 = int(min(w, max(0, x - v // 2))) # clip to the range (0, w) 203 | x1 = int(min(w, max(0, x + v // 2))) 204 | y0 = int(min(h, max(0, y - v // 2))) 205 | y1 = int(min(h, max(0, y + v // 2))) 206 | 207 | xy = (x0, y0, x1, y1) 208 | color = (125, 123, 114) 209 | # color = (0, 0, 0) 210 | img = img.copy() 211 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 212 | # img = CutoutAbs(img, v) 213 | return img 214 | else: 215 | raise NotImplementedError 216 | 217 | 
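# Annotation (not in the original source): the DeepAA default ops defined below take a
# PIL image plus a magnitude argument that they ignore, e.g.
#   img = RandCrop(img, None)    # pad 4 px with the mean colour, then random-crop back to the original size
#   img = RandCutout(img, None)  # fixed 16-px cutout, clipped at the image border as in Fast AutoAugment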
def RandCrop(img, _): 218 | v = 4 219 | return mean_pad_randcrop(img, v) 220 | 221 | def RandCutout(img, _): 222 | v = 16 # Cutout 0.5 means 0.5*32=16 pixels as in the FastAA paper 223 | return Cutout_default(img, v) 224 | 225 | def RandCutout60(img, _): 226 | v = 60 # Cutout 0.5 means 0.5*32=16 pixels as in the FastAA paper 227 | return Cutout_default(img, v) 228 | 229 | def RandFlip(img, _): 230 | if random.random() > 0.5: 231 | img = Flip(img, None) 232 | return img 233 | 234 | def Identity(img, _): 235 | return img 236 | 237 | # ===================== ops for imagenet ============= 238 | def RandResizeCrop_imagenet(img, _): 239 | # ported from torchvision 240 | # for ImageNet use only 241 | scale = (0.08, 1.0) 242 | ratio = (3. / 4., 4. / 3.) 243 | size = IMAGENET_SIZE # (224, 224) 244 | 245 | def get_params(img, scale, ratio): 246 | width, height = img.size 247 | area = float(width * height) 248 | log_ratio = [math.log(r) for r in ratio] 249 | 250 | for _ in range(10): 251 | target_area = area * random.uniform(scale[0], scale[1]) 252 | aspect_ratio = math.exp(random.uniform(log_ratio[0], log_ratio[1])) 253 | 254 | w = round(math.sqrt(target_area * aspect_ratio)) 255 | h = round(math.sqrt(target_area / aspect_ratio)) 256 | if 0 < w <= width and 0 < h <= height: 257 | top = random.randint(0, height - h) 258 | left = random.randint(0, width - w) 259 | return left, top, w, h 260 | 261 | # fallback to central crop 262 | in_ratio = float(width) / float(height) 263 | if in_ratio < min(ratio): 264 | w = width 265 | h = round(w / min(ratio)) 266 | elif in_ratio > max(ratio): 267 | h = height 268 | w = round(h * max(ratio)) 269 | else: 270 | w = width 271 | h = height 272 | top = (height - h) // 2 273 | left = (width - w) // 2 274 | return left, top, w, h 275 | 276 | left, top, w_box, h_box = get_params(img, scale, ratio) 277 | box = (left, top, left + w_box, top + h_box) 278 | img = img.resize(size=size, resample=PIL.Image.CUBIC, box=box) 279 | return img 280 | 281 | 282 | def Resize_imagenet(img, size): 283 | w, h = img.size 284 | if isinstance(size, int): 285 | short, long = (w, h) if w <= h else (h, w) 286 | if short == size: 287 | return img 288 | new_short, new_long = size, int(size * long / short) 289 | new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short) 290 | return img.resize((new_w, new_h), PIL.Image.BICUBIC) 291 | elif isinstance(size, tuple) or isinstance(size, list): 292 | assert len(size) == 2, 'Check the size {}'.format(size) 293 | return img.resize(size, PIL.Image.BICUBIC) 294 | else: 295 | raise Exception 296 | 297 | 298 | def centerCrop_imagenet(img, _): 299 | # for ImageNet only 300 | # https://github.com/pytorch/vision/blob/master/torchvision/transforms/functional.py 301 | crop_width, crop_height = IMAGENET_SIZE # (224,224) 302 | image_width, image_height = img.size 303 | 304 | if crop_width > image_width or crop_height > image_height: 305 | padding_ltrb = [ 306 | (crop_width - image_width) // 2 if crop_width > image_width else 0, 307 | (crop_height - image_height) // 2 if crop_height > image_height else 0, 308 | (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, 309 | (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, 310 | ] 311 | img = pad(img, padding_ltrb, fill=0) 312 | image_width, image_height = img.size 313 | if crop_width == image_width and crop_height == image_height: 314 | return img 315 | 316 | crop_top = int(round((image_height - crop_height) / 2.)) 317 | crop_left = int(round((image_width - 
crop_width) / 2.)) 318 | return img.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height)) 319 | 320 | 321 | def _parse_fill(fill, img, name="fillcolor"): 322 | # Process fill color for affine transforms 323 | num_bands = len(img.getbands()) 324 | if fill is None: 325 | fill = 0 326 | if isinstance(fill, (int, float)) and num_bands > 1: 327 | fill = tuple([fill] * num_bands) 328 | if isinstance(fill, (list, tuple)): 329 | if len(fill) != num_bands: 330 | msg = ("The number of elements in 'fill' does not match the number of " 331 | "bands of the image ({} != {})") 332 | raise ValueError(msg.format(len(fill), num_bands)) 333 | 334 | fill = tuple(fill) 335 | 336 | return {name: fill} 337 | 338 | 339 | def pad(img, padding_ltrb, fill=0, padding_mode='constant'): 340 | if isinstance(padding_ltrb, list): 341 | padding_ltrb = tuple(padding_ltrb) 342 | if padding_mode == 'constant': 343 | opts = _parse_fill(fill, img, name='fill') 344 | if img.mode == 'P': 345 | palette = img.getpalette() 346 | image = PIL.ImageOps.expand(img, border=padding_ltrb, **opts) 347 | image.putpalette(palette) 348 | return image 349 | return PIL.ImageOps.expand(img, border=padding_ltrb, **opts) 350 | elif len(padding_ltrb) == 4: 351 | image_width, image_height = img.size 352 | cropping = -np.minimum(padding_ltrb, 0) 353 | if cropping.any(): 354 | crop_left, crop_top, crop_right, crop_bottom = cropping 355 | img = img.crop((crop_left, crop_top, image_width - crop_right, image_height - crop_bottom)) 356 | pad_left, pad_top, pad_right, pad_bottom = np.maximum(padding_ltrb, 0) 357 | 358 | if img.mode == 'P': 359 | palette = img.getpalette() 360 | img = np.asarray(img) 361 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) 362 | img = PIL.Image.fromarray(img) 363 | img.putpalette(palette) 364 | return img 365 | 366 | img = np.asarray(img) 367 | # RGB image 368 | if len(img.shape) == 3: 369 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode) 370 | # Grayscale image 371 | if len(img.shape) == 2: 372 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) 373 | 374 | return PIL.Image.fromarray(img) 375 | else: 376 | raise Exception 377 | 378 | def augment_list(for_autoaug=True, for_DeepAA_cifar=True, for_DeepAA_imagenet=True): # 16 oeprations and their ranges 379 | l = [ 380 | (ShearX, -0.3, 0.3), # 0 381 | (ShearY, -0.3, 0.3), # 1 382 | (TranslateX, -0.45, 0.45), # 2 383 | (TranslateY, -0.45, 0.45), # 3 384 | (Rotate, -30, 30), # 4 385 | (AutoContrast, 0, 1), # 5 386 | (Invert, 0, 1), # 6 387 | (Equalize, 0, 1), # 7 388 | (Solarize, 0, 256), # 8 389 | (Posterize, 4, 8), # 9 390 | (Contrast, 0.1, 1.9), # 10 391 | (Color, 0.1, 1.9), # 11 392 | (Brightness, 0.1, 1.9), # 12 393 | (Sharpness, 0.1, 1.9), # 13 394 | (Cutout, 0, 0.2), # 14 395 | # (SamplePairing(imgs), 0, 0.4), # 15 396 | ] 397 | if for_autoaug: 398 | l += [ 399 | (CutoutAbs, 0, 20), # compatible with auto-augment 400 | (Posterize2, 0, 4), # 9 401 | (TranslateXAbs, 0, 10), # 9 402 | (TranslateYAbs, 0, 10), # 9 403 | ] 404 | if for_DeepAA_cifar: 405 | l += [ 406 | (Identity, 0., 1.0), 407 | (RandFlip, 0., 1.0), # Additional 15 408 | (RandCutout, 0., 1.0), # 16 409 | (RandCrop, 0., 1.0), # 17 410 | ] 411 | if for_DeepAA_imagenet: 412 | l += [ 413 | (RandResizeCrop_imagenet, 0., 1.0), 414 | (RandCutout60, 0., 1.0) 415 | ] 416 | 417 | return l 418 | 419 | 420 | augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()} 421 | 422 | def 
Cutout16(img, _): 423 | # return CutoutAbs(img, 16) 424 | return Cutout_default(img, 16) 425 | 426 | augmentation_TA_list = [ 427 | (Identity, 0., 1.0), 428 | (ShearX, -0.3, 0.3), # 0 429 | (ShearY, -0.3, 0.3), # 1 430 | (TranslateX, -0.45, 0.45), # 2 431 | (TranslateY, -0.45, 0.45), # 3 432 | (Rotate, -30, 30), # 4 433 | (AutoContrast, 0, 1), # 5 434 | # (Invert, 0, 1), # 6 435 | (Equalize, 0, 1), # 7 436 | (Solarize, 0, 256), # 8 437 | (Posterize, 4, 8), # 9 438 | (Contrast, 0.1, 1.9), # 10 439 | (Color, 0.1, 1.9), # 11 440 | (Brightness, 0.1, 1.9), # 12 441 | (Sharpness, 0.1, 1.9), # 13 442 | (Flip, 0., 1.0), # Additional 15 443 | (Cutout16, 0, 20), # (RandCutout, 0, 20), # compatible with auto-augment 444 | (RandCrop, 0., 1.0), # 17 445 | ] 446 | 447 | 448 | def get_augment(name): 449 | return augment_dict[name] 450 | 451 | 452 | def apply_augment(img, name, level): 453 | augment_fn, low, high = get_augment(name) 454 | return augment_fn(img.copy(), level * (high - low) + low) 455 | 456 | 457 | class Lighting(object): 458 | """Lighting noise(AlexNet - style PCA - based noise)""" 459 | 460 | def __init__(self, alphastd, eigval, eigvec): 461 | self.alphastd = alphastd 462 | self.eigval = torch.Tensor(eigval) 463 | self.eigvec = torch.Tensor(eigvec) 464 | 465 | def __call__(self, img): 466 | if self.alphastd == 0: 467 | return img 468 | 469 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 470 | rgb = self.eigvec.type_as(img).clone() \ 471 | .mul(alpha.view(1, 3).expand(3, 3)) \ 472 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 473 | .sum(1).squeeze() 474 | 475 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 476 | 477 | 478 | class Augmentation_DeepAA(object): 479 | def __init__(self, EXP='cifar', use_crop=False): 480 | self.use_crop = use_crop 481 | policy_data = np.load('./policy_port/policy_DeepAA_{}.npz'.format(EXP)) 482 | self.policy_probs = policy_data['policy_probs'] 483 | 484 | self.l_ops = policy_data['l_ops'] 485 | self.l_mags = policy_data['l_mags'] 486 | self.ops = policy_data['ops'] 487 | self.mags = policy_data['mags'] 488 | self.op_names = policy_data['op_names'] 489 | 490 | def __call__(self, img): 491 | for k_policy in self.policy_probs: 492 | k_samp = random.choices(range(len(k_policy)), weights=k_policy, k=1)[0] 493 | op, mag = np.squeeze(self.ops[k_samp]), np.squeeze(self.mags[k_samp]).astype(np.float32)/float(self.l_mags-1) 494 | op_name = self.op_names[op].split(':')[0] 495 | img = apply_augment(img, op_name, mag) 496 | if self.use_crop: 497 | w, h = img.size 498 | if w==IMAGENET_SIZE[0] and h==IMAGENET_SIZE[1]: 499 | return img 500 | # return centerCrop_imagenet(Resize_imagenet(img, 256), None) 501 | return centerCrop_imagenet(img, None) 502 | return img 503 | 504 | 505 | IMAGENET_SIZE = (224, 224) -------------------------------------------------------------------------------- /DeepAA_evaluate/imagenet.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets.imagenet import * 2 | 3 | class ImageNet(ImageFolder): 4 | """`ImageNet `_ 2012 Classification Dataset. 5 | Copied from torchvision, besides warning below. 6 | 7 | Args: 8 | root (string): Root directory of the ImageNet Dataset. 9 | split (string, optional): The dataset split, supports ``train``, or ``val``. 10 | transform (callable, optional): A function/transform that takes in an PIL image 11 | and returns a transformed version. 
E.g, ``transforms.RandomCrop`` 12 | target_transform (callable, optional): A function/transform that takes in the 13 | target and transforms it. 14 | loader (callable, optional): A function to load an image given its path. 15 | 16 | Attributes: 17 | classes (list): List of the class name tuples. 18 | class_to_idx (dict): Dict with items (class_name, class_index). 19 | wnids (list): List of the WordNet IDs. 20 | wnid_to_idx (dict): Dict with items (wordnet_id, class_index). 21 | imgs (list): List of (image path, class_index) tuples 22 | targets (list): The class_index value for each image in the dataset 23 | 24 | WARN:: 25 | This is the same ImageNet class as in torchvision.datasets.imagenet, but it has the `ignore_archive` argument. 26 | This allows us to only copy the unzipped files before training. 27 | """ 28 | 29 | def __init__(self, root, split='train', download=None, ignore_archive=False, **kwargs): 30 | if download is True: 31 | msg = ("The dataset is no longer publicly accessible. You need to " 32 | "download the archives externally and place them in the root " 33 | "directory.") 34 | raise RuntimeError(msg) 35 | elif download is False: 36 | msg = ("The use of the download flag is deprecated, since the dataset " 37 | "is no longer publicly accessible.") 38 | warnings.warn(msg, RuntimeWarning) 39 | 40 | root = self.root = os.path.expanduser(root) 41 | self.split = verify_str_arg(split, "split", ("train", "val")) 42 | 43 | if not ignore_archive: 44 | self.parse_archives() 45 | wnid_to_classes = load_meta_file(self.root)[0] 46 | 47 | super(ImageNet, self).__init__(self.split_folder, **kwargs) 48 | self.root = root 49 | 50 | self.wnids = self.classes 51 | self.wnid_to_idx = self.class_to_idx 52 | self.classes = [wnid_to_classes[wnid] for wnid in self.wnids] 53 | self.class_to_idx = {cls: idx 54 | for idx, clss in enumerate(self.classes) 55 | for cls in clss} 56 | 57 | def parse_archives(self): 58 | if not check_integrity(os.path.join(self.root, META_FILE)): 59 | parse_devkit_archive(self.root) 60 | 61 | if not os.path.isdir(self.split_folder): 62 | if self.split == 'train': 63 | parse_train_archive(self.root) 64 | elif self.split == 'val': 65 | parse_val_archive(self.root) 66 | 67 | @property 68 | def split_folder(self): 69 | return os.path.join(self.root, self.split) 70 | 71 | def extra_repr(self): 72 | return "Split: {split}".format(**self.__dict__) -------------------------------------------------------------------------------- /DeepAA_evaluate/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from theconf import Config as C 4 | 5 | 6 | def adjust_learning_rate_resnet(optimizer): 7 | """ 8 | Sets the learning rate to the initial LR decayed by 10 on every predefined epochs 9 | Ref: AutoAugment 10 | """ 11 | 12 | if C.get()['epoch'] == 90: 13 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 80]) 14 | elif C.get()['epoch'] == 180: 15 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 120, 160]) 16 | elif C.get()['epoch'] == 270: 17 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, [90, 180, 240]) 18 | else: 19 | raise ValueError('invalid epoch=%d for resnet scheduler' % C.get()['epoch']) 20 | -------------------------------------------------------------------------------- /DeepAA_evaluate/metrics.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | from collections import defaultdict 5 | 6 | from torch 
import nn 7 | 8 | 9 | def accuracy(output, target, topk=(1,)): 10 | """Computes the precision@k for the specified values of k""" 11 | maxk = max(topk) 12 | batch_size = target.size(0) 13 | 14 | _, pred = output.topk(maxk, 1, True, True) 15 | pred = pred.t() 16 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 17 | 18 | res = [] 19 | for k in topk: 20 | correct_k = correct[:k].flatten().float().sum(0) 21 | res.append(correct_k.mul_(1. / batch_size)) 22 | return res 23 | 24 | 25 | def cross_entropy_smooth(input, target, size_average=True, label_smoothing=0.1): 26 | y = torch.eye(10).cuda() 27 | lb_oh = y[target] 28 | 29 | target = lb_oh * (1 - label_smoothing) + 0.5 * label_smoothing 30 | 31 | logsoftmax = nn.LogSoftmax() 32 | if size_average: 33 | return torch.mean(torch.sum(-target * logsoftmax(input), dim=1)) 34 | else: 35 | return torch.sum(torch.sum(-target * logsoftmax(input), dim=1)) 36 | 37 | 38 | class Accumulator: 39 | def __init__(self): 40 | self.metrics = defaultdict(lambda: 0.) 41 | 42 | def add(self, key, value): 43 | self.metrics[key] += value 44 | 45 | def add_dict(self, dict): 46 | for key, value in dict.items(): 47 | self.add(key, value) 48 | 49 | def __getitem__(self, item): 50 | return self.metrics[item] 51 | 52 | def __setitem__(self, key, value): 53 | self.metrics[key] = value 54 | 55 | def __contains__(self, item): 56 | return self.metrics.__contains__(item) 57 | 58 | def get_dict(self): 59 | return copy.deepcopy(dict(self.metrics)) 60 | 61 | def items(self): 62 | return self.metrics.items() 63 | 64 | def __str__(self): 65 | return str(dict(self.metrics)) 66 | 67 | def __truediv__(self, other): 68 | newone = Accumulator() 69 | for key, value in self.items(): 70 | newone[key] = value / other 71 | return newone 72 | 73 | def divide(self, divisor, **special_divisors): 74 | newone = Accumulator() 75 | for key, value in self.items(): 76 | if key in special_divisors: 77 | newone[key] = value/special_divisors[key] 78 | else: 79 | newone[key] = value/divisor 80 | return newone 81 | 82 | 83 | class SummaryWriterDummy: 84 | def __init__(self, log_dir): 85 | pass 86 | 87 | def add_scalar(self, *args, **kwargs): 88 | pass 89 | 90 | def add_image(self, *args, **kwargs): 91 | pass -------------------------------------------------------------------------------- /DeepAA_evaluate/networks/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | from torch.nn import DataParallel 5 | import torch.backends.cudnn as cudnn 6 | # from torchvision import models 7 | 8 | from DeepAA_evaluate.networks.resnet import ResNet 9 | from DeepAA_evaluate.networks.shakeshake.shake_resnet import ShakeResNet 10 | from DeepAA_evaluate.networks.wideresnet import WideResNet 11 | from DeepAA_evaluate.networks.shakeshake.shake_resnext import ShakeResNeXt 12 | from DeepAA_evaluate.networks.convnet import SeqConvNet 13 | from DeepAA_evaluate.networks.mlp import MLP 14 | from DeepAA_evaluate.common import apply_weightnorm 15 | 16 | 17 | 18 | # example usage get_model( 19 | def get_model(conf, bs, num_class=10, writer=None): 20 | name = conf['type'] 21 | ad_creators = (None,None) 22 | 23 | 24 | if name == 'resnet50': 25 | model = ResNet(dataset='imagenet', depth=50, num_classes=num_class, bottleneck=True) 26 | elif name == 'resnet200': 27 | model = ResNet(dataset='imagenet', depth=200, num_classes=num_class, bottleneck=True) 28 | elif name == 'resnet18': 29 | model = ResNet(dataset='imagenet', depth=18, 
num_classes=num_class, bottleneck=False) 30 | elif name == 'wresnet40_2': 31 | model = WideResNet(40, 2, dropout_rate=conf.get('dropout',0.0), num_classes=num_class, adaptive_dropouter_creator=ad_creators[0],adaptive_conv_dropouter_creator=ad_creators[1], groupnorm=conf.get('groupnorm', False), examplewise_bn=conf.get('examplewise_bn', False), virtual_bn=conf.get('virtual_bn', False)) 32 | elif name == 'wresnet28_10': 33 | model = WideResNet(28, 10, dropout_rate=conf.get('dropout',0.0), num_classes=num_class, adaptive_dropouter_creator=ad_creators[0],adaptive_conv_dropouter_creator=ad_creators[1], groupnorm=conf.get('groupnorm',False), examplewise_bn=conf.get('examplewise_bn', False), virtual_bn=conf.get('virtual_bn', False)) 34 | elif name == 'wresnet28_2': 35 | model = WideResNet(28, 2, dropout_rate=conf.get('dropout', 0.0), num_classes=num_class, 36 | adaptive_dropouter_creator=ad_creators[0], adaptive_conv_dropouter_creator=ad_creators[1], 37 | groupnorm=conf.get('groupnorm', False), examplewise_bn=conf.get('examplewise_bn', False), 38 | virtual_bn=conf.get('virtual_bn', False)) 39 | elif name == 'miniconvnet': 40 | model = SeqConvNet(num_class,adaptive_dropout_creator=ad_creators[0],batch_norm=False) 41 | elif name == 'mlp': 42 | model = MLP(num_class, (3,32,32), adaptive_dropouter_creator=ad_creators[0]) 43 | elif name == 'shakeshake26_2x96d': 44 | model = ShakeResNet(26, 96, num_class) 45 | elif name == 'shakeshake26_2x112d': 46 | model = ShakeResNet(26, 112, num_class) 47 | elif name == 'shakeshake26_2x96d_next': 48 | model = ShakeResNeXt(26, 96, 4, num_class) 49 | else: 50 | raise NameError('no model named, %s' % name) 51 | 52 | if conf.get('weight_norm', False): 53 | print('Using weight norm.') 54 | apply_weightnorm(model) 55 | 56 | #model = model.cuda() 57 | #model = DataParallel(model) 58 | cudnn.benchmark = True 59 | return model 60 | 61 | 62 | def num_class(dataset): 63 | return { 64 | 'cifar10': 10, 65 | 'noised_cifar10': 10, 66 | 'targetnoised_cifar10': 10, 67 | 'reduced_cifar10': 10, 68 | 'cifar10.1': 10, 69 | 'pre_transform_cifar10': 10, 70 | 'cifar100': 100, 71 | 'pre_transform_cifar100': 100, 72 | 'fiftyexample_cifar100': 100, 73 | 'tenclass_cifar100': 10, 74 | 'svhn': 10, 75 | 'svhncore': 10, 76 | 'reduced_svhn': 10, 77 | 'imagenet': 1000, 78 | 'smallwidth_imagenet': 1000, 79 | 'ohl_pipeline_imagenet': 1000, 80 | 'reduced_imagenet': 120, 81 | }[dataset] 82 | -------------------------------------------------------------------------------- /DeepAA_evaluate/networks/convnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class SeqConvNet(nn.Module): 5 | def __init__(self,D_out,fixed_dropout=None,in_channels=3,channels=(64,64),h_dims=(200,100),adaptive_dropout_creator=None,batch_norm=False): 6 | super().__init__() 7 | print("Using SeqConvNet") 8 | assert len(channels) == 2 == len(h_dims) 9 | pool = lambda: nn.MaxPool2d(2,2) 10 | dropout = lambda: torch.nn.Dropout(p=fixed_dropout) 11 | dropout_li = lambda: ([] if fixed_dropout is None else [dropout()]) 12 | relu = lambda: torch.nn.ReLU(inplace=False) 13 | flatten = lambda l: [item for sublist in l for item in sublist] 14 | convs = [nn.Conv2d(in_channels, channels[0], 5),nn.Conv2d(channels[0], channels[1], 5)] 15 | fcs = [nn.Linear(channels[1] * 5 * 5, h_dims[0]),nn.Linear(h_dims[0], h_dims[1])] 16 | self.final_fc = nn.Linear(h_dims[1], D_out) 17 | self.conv_blocks = nn.Sequential(*flatten([[conv,relu(),pool()] + dropout_li() for conv in 
convs])) 18 | self.bn = nn.BatchNorm1d(h_dims[1], momentum=.9) if batch_norm else nn.Identity() 19 | self.fc_blocks = nn.Sequential(*flatten([[fc,relu()] + dropout_li() for fc in fcs])) 20 | self.adaptive_dropouters = [adaptive_dropout_creator(h_dims[1])] if adaptive_dropout_creator is not None else [] 21 | 22 | def forward(self, x): 23 | x = self.conv_blocks(x) 24 | x = torch.nn.Flatten()(x) 25 | x = self.fc_blocks(x) 26 | if self.adaptive_dropouters: 27 | x = self.adaptive_dropouters[0](x) 28 | x = self.bn(x) 29 | x = self.final_fc(x) 30 | return x 31 | 32 | -------------------------------------------------------------------------------- /DeepAA_evaluate/networks/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def MLP(D_out,in_dims,adaptive_dropouter_creator): 6 | print('adaptive dropouter', adaptive_dropouter_creator) 7 | in_dim = 1 8 | for d in in_dims: in_dim *= d 9 | ada_dropper = adaptive_dropouter_creator(100) if adaptive_dropouter_creator is not None else None 10 | model = nn.Sequential( 11 | nn.Flatten(), 12 | nn.Linear(in_dim, 300), 13 | nn.Tanh(), 14 | nn.Linear(300,100), 15 | ada_dropper or nn.Identity(), 16 | nn.Tanh(), 17 | nn.Linear(100,D_out) 18 | ) 19 | model.adaptive_dropouters = [ada_dropper] if ada_dropper is not None else [] 20 | return model 21 | -------------------------------------------------------------------------------- /DeepAA_evaluate/networks/resnet.py: -------------------------------------------------------------------------------- 1 | # Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 2 | # gamma is initialized ot 0 in the last BN of each residual block 3 | 4 | import torch.nn as nn 5 | import math 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | "3x3 convolution with padding" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, inplanes, planes, stride=1, downsample=None): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = conv3x3(inplanes, planes, stride) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | nn.init.zeros_(self.bn2.weight) 24 | self.relu = nn.ReLU(inplace=True) 25 | 26 | self.downsample = downsample 27 | self.stride = stride 28 | 29 | def forward(self, x): 30 | residual = x 31 | 32 | out = self.conv1(x) 33 | out = self.bn1(out) 34 | out = self.relu(out) 35 | 36 | out = self.conv2(out) 37 | out = self.bn2(out) 38 | 39 | if self.downsample is not None: 40 | residual = self.downsample(x) 41 | 42 | out += residual 43 | out = self.relu(out) 44 | 45 | return out 46 | 47 | 48 | class Bottleneck(nn.Module): 49 | expansion = 4 50 | 51 | def __init__(self, inplanes, planes, stride=1, downsample=None): 52 | super(Bottleneck, self).__init__() 53 | 54 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 55 | self.bn1 = nn.BatchNorm2d(planes) 56 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False) 59 | self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion) 60 | nn.init.zeros_(self.bn3.weight) 61 | self.relu = nn.ReLU(inplace=True) 62 | 63 | self.downsample = downsample 64 | self.stride = stride 65 | 
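    # Note (added): because `nn.init.zeros_(self.bn3.weight)` above zeroes the
    # last BatchNorm's gamma, the residual branch contributes ~0 at the start
    # of training, so each Bottleneck initially reduces to its shortcut path.
    # This is the zero-gamma initialization mentioned in the file header.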
66 | def forward(self, x): 67 | residual = x 68 | 69 | out = self.conv1(x) 70 | out = self.bn1(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv2(out) 74 | out = self.bn2(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv3(out) 78 | out = self.bn3(out) 79 | if self.downsample is not None: 80 | residual = self.downsample(x) 81 | 82 | out += residual 83 | out = self.relu(out) 84 | 85 | return out 86 | 87 | class ResNet(nn.Module): 88 | def __init__(self, dataset, depth, num_classes, bottleneck=False): 89 | super(ResNet, self).__init__() 90 | self.dataset = dataset 91 | if self.dataset.startswith('cifar'): 92 | self.inplanes = 16 93 | print(bottleneck) 94 | if bottleneck == True: 95 | n = int((depth - 2) / 9) 96 | block = Bottleneck 97 | else: 98 | n = int((depth - 2) / 6) 99 | block = BasicBlock 100 | 101 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 102 | self.bn1 = nn.BatchNorm2d(self.inplanes) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.layer1 = self._make_layer(block, 16, n) 105 | self.layer2 = self._make_layer(block, 32, n, stride=2) 106 | self.layer3 = self._make_layer(block, 64, n, stride=2) 107 | # self.avgpool = nn.AvgPool2d(8) 108 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 109 | self.fc = nn.Linear(64 * block.expansion, num_classes) 110 | 111 | elif dataset == 'imagenet': 112 | blocks ={18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} 113 | layers ={18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]} 114 | assert layers[depth], 'invalid detph for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200)' 115 | 116 | self.inplanes = 64 117 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 118 | self.bn1 = nn.BatchNorm2d(64) 119 | self.relu = nn.ReLU(inplace=True) 120 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 121 | self.layer1 = self._make_layer(blocks[depth], 64, layers[depth][0]) 122 | self.layer2 = self._make_layer(blocks[depth], 128, layers[depth][1], stride=2) 123 | self.layer3 = self._make_layer(blocks[depth], 256, layers[depth][2], stride=2) 124 | self.layer4 = self._make_layer(blocks[depth], 512, layers[depth][3], stride=2) 125 | # self.avgpool = nn.AvgPool2d(7) 126 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 127 | self.fc = nn.Linear(512 * blocks[depth].expansion, num_classes) 128 | 129 | for m in self.modules(): 130 | if isinstance(m, nn.Conv2d): 131 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 132 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 133 | elif isinstance(m, nn.BatchNorm2d): 134 | m.weight.data.fill_(1) 135 | m.bias.data.zero_() 136 | 137 | def _make_layer(self, block, planes, blocks, stride=1): 138 | downsample = None 139 | if stride != 1 or self.inplanes != planes * block.expansion: 140 | downsample = nn.Sequential( 141 | nn.Conv2d(self.inplanes, planes * block.expansion, 142 | kernel_size=1, stride=stride, bias=False), 143 | nn.BatchNorm2d(planes * block.expansion), 144 | ) 145 | 146 | layers = [] 147 | layers.append(block(self.inplanes, planes, stride, downsample)) 148 | self.inplanes = planes * block.expansion 149 | for i in range(1, blocks): 150 | layers.append(block(self.inplanes, planes)) 151 | 152 | return nn.Sequential(*layers) 153 | 154 | def forward(self, x): 155 | if self.dataset == 'cifar10' or self.dataset == 'cifar100': 156 | x = self.conv1(x) 157 | x = self.bn1(x) 158 | x = self.relu(x) 159 | 160 | x = self.layer1(x) 161 | x = self.layer2(x) 162 | x = self.layer3(x) 163 | 164 | x = self.avgpool(x) 165 | x = x.view(x.size(0), -1) 166 | x = self.fc(x) 167 | 168 | elif self.dataset == 'imagenet': 169 | x = self.conv1(x) 170 | x = self.bn1(x) 171 | x = self.relu(x) 172 | x = self.maxpool(x) 173 | 174 | x = self.layer1(x) 175 | x = self.layer2(x) 176 | x = self.layer3(x) 177 | x = self.layer4(x) 178 | 179 | x = self.avgpool(x) 180 | x = x.view(x.size(0), -1) 181 | x = self.fc(x) 182 | 183 | return x 184 | -------------------------------------------------------------------------------- /DeepAA_evaluate/networks/shakeshake/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/DeepAA_evaluate/networks/shakeshake/__init__.py -------------------------------------------------------------------------------- /DeepAA_evaluate/networks/shakeshake/shake_resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from DeepAA_evaluate.networks.shakeshake.shakeshake import ShakeShake 9 | from DeepAA_evaluate.networks.shakeshake.shakeshake import Shortcut 10 | 11 | 12 | class ShakeBlock(nn.Module): 13 | 14 | def __init__(self, in_ch, out_ch, stride=1): 15 | super(ShakeBlock, self).__init__() 16 | self.equal_io = in_ch == out_ch 17 | if self.equal_io: 18 | self.shortcut = lambda x: x 19 | else: 20 | self.shortcut = Shortcut(in_ch, out_ch, stride=stride) 21 | #self.shortcut = self.equal_io and None or Shortcut(in_ch, out_ch, stride=stride) 22 | 23 | self.branch1 = self._make_branch(in_ch, out_ch, stride) 24 | self.branch2 = self._make_branch(in_ch, out_ch, stride) 25 | 26 | def forward(self, x): 27 | h1 = self.branch1(x) 28 | h2 = self.branch2(x) 29 | h = ShakeShake.apply(h1, h2, self.training) 30 | #h0 = x if self.equal_io else self.shortcut(x) 31 | h0 = self.shortcut(x) 32 | return h + h0 33 | 34 | def _make_branch(self, in_ch, out_ch, stride=1): 35 | return nn.Sequential( 36 | nn.ReLU(inplace=False), 37 | nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride, bias=False), 38 | nn.BatchNorm2d(out_ch), 39 | nn.ReLU(inplace=False), 40 | nn.Conv2d(out_ch, out_ch, 3, padding=1, stride=1, bias=False), 41 | nn.BatchNorm2d(out_ch)) 42 | 43 | 44 | class ShakeResNet(nn.Module): 45 | 46 | def __init__(self, depth, w_base, label): 47 | super(ShakeResNet, self).__init__() 48 | n_units = (depth - 2) / 6 49 | 50 | in_chs = [16, w_base, w_base * 
2, w_base * 4] 51 | self.in_chs = in_chs 52 | 53 | self.c_in = nn.Conv2d(3, in_chs[0], 3, padding=1) 54 | self.layer1 = self._make_layer(n_units, in_chs[0], in_chs[1]) 55 | self.layer2 = self._make_layer(n_units, in_chs[1], in_chs[2], 2) 56 | self.layer3 = self._make_layer(n_units, in_chs[2], in_chs[3], 2) 57 | self.fc_out = nn.Linear(in_chs[3], label) 58 | 59 | # Initialize paramters 60 | for m in self.modules(): 61 | if isinstance(m, nn.Conv2d): 62 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 63 | m.weight.data.normal_(0, math.sqrt(2. / n)) 64 | elif isinstance(m, nn.BatchNorm2d): 65 | m.weight.data.fill_(1) 66 | m.bias.data.zero_() 67 | elif isinstance(m, nn.Linear): 68 | m.bias.data.zero_() 69 | 70 | def forward(self, x): 71 | h = self.c_in(x) 72 | h = self.layer1(h) 73 | h = self.layer2(h) 74 | h = self.layer3(h) 75 | h = F.relu(h) 76 | h = F.avg_pool2d(h, 8) 77 | h = h.view(-1, self.in_chs[3]) 78 | h = self.fc_out(h) 79 | return h 80 | 81 | def _make_layer(self, n_units, in_ch, out_ch, stride=1): 82 | layers = [] 83 | for i in range(int(n_units)): 84 | layers.append(ShakeBlock(in_ch, out_ch, stride=stride)) 85 | in_ch, stride = out_ch, 1 86 | return nn.Sequential(*layers) 87 | -------------------------------------------------------------------------------- /DeepAA_evaluate/networks/shakeshake/shake_resnext.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from DeepAA_evaluate.networks.shakeshake.shakeshake import ShakeShake 9 | from DeepAA_evaluate.networks.shakeshake.shakeshake import Shortcut 10 | 11 | 12 | class ShakeBottleNeck(nn.Module): 13 | 14 | def __init__(self, in_ch, mid_ch, out_ch, cardinary, stride=1): 15 | super(ShakeBottleNeck, self).__init__() 16 | self.equal_io = in_ch == out_ch 17 | self.shortcut = None if self.equal_io else Shortcut(in_ch, out_ch, stride=stride) 18 | 19 | self.branch1 = self._make_branch(in_ch, mid_ch, out_ch, cardinary, stride) 20 | self.branch2 = self._make_branch(in_ch, mid_ch, out_ch, cardinary, stride) 21 | 22 | def forward(self, x): 23 | h1 = self.branch1(x) 24 | h2 = self.branch2(x) 25 | h = ShakeShake.apply(h1, h2, self.training) 26 | h0 = x if self.equal_io else self.shortcut(x) 27 | return h + h0 28 | 29 | def _make_branch(self, in_ch, mid_ch, out_ch, cardinary, stride=1): 30 | return nn.Sequential( 31 | nn.Conv2d(in_ch, mid_ch, 1, padding=0, bias=False), 32 | nn.BatchNorm2d(mid_ch), 33 | nn.ReLU(inplace=False), 34 | nn.Conv2d(mid_ch, mid_ch, 3, padding=1, stride=stride, groups=cardinary, bias=False), 35 | nn.BatchNorm2d(mid_ch), 36 | nn.ReLU(inplace=False), 37 | nn.Conv2d(mid_ch, out_ch, 1, padding=0, bias=False), 38 | nn.BatchNorm2d(out_ch)) 39 | 40 | 41 | class ShakeResNeXt(nn.Module): 42 | 43 | def __init__(self, depth, w_base, cardinary, label): 44 | super(ShakeResNeXt, self).__init__() 45 | n_units = (depth - 2) // 9 46 | n_chs = [64, 128, 256, 1024] 47 | self.n_chs = n_chs 48 | self.in_ch = n_chs[0] 49 | 50 | self.c_in = nn.Conv2d(3, n_chs[0], 3, padding=1) 51 | self.layer1 = self._make_layer(n_units, n_chs[0], w_base, cardinary) 52 | self.layer2 = self._make_layer(n_units, n_chs[1], w_base, cardinary, 2) 53 | self.layer3 = self._make_layer(n_units, n_chs[2], w_base, cardinary, 2) 54 | self.fc_out = nn.Linear(n_chs[3], label) 55 | 56 | # Initialize paramters 57 | for m in self.modules(): 58 | if isinstance(m, nn.Conv2d): 59 | n = m.kernel_size[0] * 
m.kernel_size[1] * m.out_channels 60 | m.weight.data.normal_(0, math.sqrt(2. / n)) 61 | elif isinstance(m, nn.BatchNorm2d): 62 | m.weight.data.fill_(1) 63 | m.bias.data.zero_() 64 | elif isinstance(m, nn.Linear): 65 | m.bias.data.zero_() 66 | 67 | def forward(self, x): 68 | h = self.c_in(x) 69 | h = self.layer1(h) 70 | h = self.layer2(h) 71 | h = self.layer3(h) 72 | h = F.relu(h) 73 | h = F.avg_pool2d(h, 8) 74 | h = h.view(-1, self.n_chs[3]) 75 | h = self.fc_out(h) 76 | return h 77 | 78 | def _make_layer(self, n_units, n_ch, w_base, cardinary, stride=1): 79 | layers = [] 80 | mid_ch, out_ch = n_ch * (w_base // 64) * cardinary, n_ch * 4 81 | for i in range(n_units): 82 | layers.append(ShakeBottleNeck(self.in_ch, mid_ch, out_ch, cardinary, stride=stride)) 83 | self.in_ch, stride = out_ch, 1 84 | return nn.Sequential(*layers) 85 | -------------------------------------------------------------------------------- /DeepAA_evaluate/networks/shakeshake/shakeshake.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | class ShakeShake(torch.autograd.Function): 10 | 11 | @staticmethod 12 | def forward(ctx, x1, x2, training=True): 13 | if training: 14 | alpha = torch.cuda.FloatTensor(x1.size(0)).uniform_() 15 | alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x1) 16 | else: 17 | alpha = 0.5 18 | return alpha * x1 + (1 - alpha) * x2 19 | 20 | @staticmethod 21 | def backward(ctx, grad_output): 22 | beta = torch.cuda.FloatTensor(grad_output.size(0)).uniform_() 23 | beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output) 24 | beta = Variable(beta) 25 | 26 | return beta * grad_output, (1 - beta) * grad_output, None 27 | 28 | 29 | class Shortcut(nn.Module): 30 | 31 | def __init__(self, in_ch, out_ch, stride): 32 | super(Shortcut, self).__init__() 33 | self.stride = stride 34 | self.conv1 = nn.Conv2d(in_ch, out_ch // 2, 1, stride=1, padding=0, bias=False) 35 | self.conv2 = nn.Conv2d(in_ch, out_ch // 2, 1, stride=1, padding=0, bias=False) 36 | self.bn = nn.BatchNorm2d(out_ch) 37 | 38 | def forward(self, x): 39 | h = F.relu(x) 40 | 41 | h1 = F.avg_pool2d(h, 1, self.stride) 42 | h1 = self.conv1(h1) 43 | 44 | h2 = F.avg_pool2d(F.pad(h, (-1, 1, -1, 1)), 1, self.stride) 45 | h2 = self.conv2(h2) 46 | 47 | h = torch.cat((h1, h2), 1) 48 | return self.bn(h) 49 | -------------------------------------------------------------------------------- /DeepAA_evaluate/networks/wideresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | 8 | _bn_momentum = 0.1 9 | CpG = 8 10 | 11 | 12 | class ExampleWiseBatchNorm2d(nn.BatchNorm2d): 13 | def __init__(self, num_features, eps=1e-5, momentum=0.1, 14 | affine=True, track_running_stats=True): 15 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 16 | 17 | def forward(self, input): 18 | self._check_input_dim(input) 19 | 20 | exponential_average_factor = 0.0 21 | 22 | if self.training and self.track_running_stats: 23 | if self.num_batches_tracked is not None: 24 | self.num_batches_tracked += 1 25 | if self.momentum is None: # use cumulative moving average 26 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 27 | else: # use exponential moving average 28 | 
exponential_average_factor = self.momentum 29 | 30 | # calculate running estimates 31 | if self.training: 32 | mean = input.mean([0, 2, 3]) 33 | # use biased var in train 34 | var = input.var([0, 2, 3], unbiased=False) 35 | n = input.numel() / input.size(1) 36 | with torch.no_grad(): 37 | self.running_mean = exponential_average_factor * mean\ 38 | + (1 - exponential_average_factor) * self.running_mean 39 | # update running_var with unbiased var 40 | self.running_var = exponential_average_factor * var * n / (n - 1)\ 41 | + (1 - exponential_average_factor) * self.running_var 42 | local_means = input.mean([2, 3]) 43 | local_global_means = local_means + (mean.unsqueeze(0) - local_means).detach() 44 | local_vars = input.var([2, 3], unbiased=False) 45 | local_global_vars = local_vars + (var.unsqueeze(0) - local_vars).detach() 46 | input = (input - local_global_means[:,:,None,None]) / (torch.sqrt(local_global_vars[:,:,None,None] + self.eps)) 47 | else: 48 | mean = self.running_mean 49 | var = self.running_var 50 | input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps)) 51 | 52 | if self.affine: 53 | input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None] 54 | 55 | return input 56 | 57 | 58 | class VirtualBatchNorm2d(nn.BatchNorm2d): 59 | def __init__(self, num_features, eps=1e-5, momentum=0.1, 60 | affine=True, track_running_stats=True): 61 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 62 | 63 | def forward(self, input): 64 | self._check_input_dim(input) 65 | 66 | exponential_average_factor = 0.0 67 | 68 | if self.training and self.track_running_stats: 69 | if self.num_batches_tracked is not None: 70 | self.num_batches_tracked += 1 71 | if self.momentum is None: # use cumulative moving average 72 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 73 | else: # use exponential moving average 74 | exponential_average_factor = self.momentum 75 | 76 | # calculate running estimates 77 | if self.training: 78 | mean = input.mean([0, 2, 3]) 79 | # use biased var in train 80 | var = input.var([0, 2, 3], unbiased=False) 81 | n = input.numel() / input.size(1) 82 | with torch.no_grad(): 83 | self.running_mean = exponential_average_factor * mean \ 84 | + (1 - exponential_average_factor) * self.running_mean 85 | # update running_var with unbiased var 86 | self.running_var = exponential_average_factor * var * n / (n - 1) \ 87 | + (1 - exponential_average_factor) * self.running_var 88 | input = (input - mean.detach()[None, :, None, None]) / (torch.sqrt(var.detach()[None, :, None, None] + self.eps)) 89 | else: 90 | mean = self.running_mean 91 | var = self.running_var 92 | input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps)) 93 | 94 | if self.affine: 95 | input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None] 96 | 97 | return input 98 | 99 | 100 | def conv3x3(in_planes, out_planes, stride=1): 101 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 102 | 103 | 104 | def conv_init(m): 105 | classname = m.__class__.__name__ 106 | if classname.find('Conv') != -1: 107 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 108 | init.constant_(m.bias, 0) 109 | elif classname.find('BatchNorm') != -1: 110 | init.constant_(m.weight, 1) 111 | init.constant_(m.bias, 0) 112 | 113 | 114 | class WideBasic(nn.Module): 115 | def __init__(self, in_planes, planes, dropout_rate, norm_creator, stride=1, 
adaptive_dropouter_creator=None): 116 | super(WideBasic, self).__init__() 117 | self.bn1 = norm_creator(in_planes) 118 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 119 | if adaptive_dropouter_creator is None: 120 | self.dropout = nn.Dropout(p=dropout_rate) 121 | else: 122 | self.dropout = adaptive_dropouter_creator(planes, 3, stride, 1) 123 | self.bn2 = norm_creator(planes) 124 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 125 | 126 | self.shortcut = nn.Sequential() 127 | if stride != 1 or in_planes != planes: 128 | self.shortcut = nn.Sequential( 129 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 130 | ) 131 | 132 | def forward(self, x): 133 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 134 | out = self.conv2(F.relu(self.bn2(out))) 135 | out += self.shortcut(x) 136 | 137 | return out 138 | 139 | 140 | class WideResNet(nn.Module): 141 | def __init__(self, depth, widen_factor, dropout_rate, num_classes, adaptive_dropouter_creator, adaptive_conv_dropouter_creator, groupnorm, examplewise_bn, virtual_bn): 142 | super(WideResNet, self).__init__() 143 | self.in_planes = 16 144 | self.adaptive_conv_dropouter_creator = adaptive_conv_dropouter_creator 145 | 146 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' 147 | assert sum([groupnorm,examplewise_bn,virtual_bn]) <= 1 148 | n = int((depth - 4) / 6) 149 | k = widen_factor 150 | 151 | nStages = [16, 16*k, 32*k, 64*k] 152 | 153 | self.adaptive_dropouters = [] #nn.ModuleList() 154 | 155 | if groupnorm: 156 | print('Uses group norm.') 157 | self.norm_creator = lambda c: nn.GroupNorm(max(c//CpG, 1), c) 158 | elif examplewise_bn: 159 | print("Uses Example Wise BN") 160 | self.norm_creator = lambda c: ExampleWiseBatchNorm2d(c, momentum=_bn_momentum) 161 | elif virtual_bn: 162 | print("Uses Virtual BN") 163 | self.norm_creator = lambda c: VirtualBatchNorm2d(c, momentum=_bn_momentum) 164 | else: 165 | self.norm_creator = lambda c: nn.BatchNorm2d(c, momentum=_bn_momentum) 166 | 167 | self.conv1 = conv3x3(3, nStages[0]) 168 | self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1) 169 | self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2) 170 | self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2) 171 | self.bn1 = self.norm_creator(nStages[3]) 172 | self.linear = nn.Linear(nStages[3], num_classes) 173 | if adaptive_dropouter_creator is not None: 174 | last_dropout = adaptive_dropouter_creator(nStages[3]) 175 | else: 176 | last_dropout = lambda x: x 177 | self.adaptive_dropouters.append(last_dropout) 178 | 179 | # self.apply(conv_init) 180 | 181 | def to(self, *args, **kwargs): 182 | super().to(*args,**kwargs) 183 | print(*args) 184 | for ad in self.adaptive_dropouters: 185 | if hasattr(ad,'to'): 186 | ad.to(*args,**kwargs) 187 | return self 188 | 189 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 190 | strides = [stride] + [1]*(num_blocks-1) 191 | layers = [] 192 | 193 | for i,stride in enumerate(strides): 194 | ada_conv_drop_c = self.adaptive_conv_dropouter_creator if i == 0 else None 195 | new_block = block(self.in_planes, planes, dropout_rate, self.norm_creator, stride, adaptive_dropouter_creator=ada_conv_drop_c) 196 | layers.append(new_block) 197 | if ada_conv_drop_c is not None: 198 | self.adaptive_dropouters.append(new_block.dropout) 199 | 200 | self.in_planes = planes 201 | 202 | return nn.Sequential(*layers) 203 | 
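    # Note (added): for 32x32 CIFAR inputs the three stages above yield an
    # 8x8 feature map with nStages[3] = 64 * widen_factor channels; forward()
    # below average-pools it to a single 64*k vector, applies the optional
    # adaptive dropout, and feeds the result to self.linear.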
204 | def forward(self, x): 205 | out = self.conv1(x) 206 | out = self.layer1(out) 207 | out = self.layer2(out) 208 | out = self.layer3(out) 209 | out = F.relu(self.bn1(out)) 210 | # out = F.avg_pool2d(out, 8) 211 | out = F.adaptive_avg_pool2d(out, (1, 1)) 212 | out = out.view(out.size(0), -1) 213 | out = self.adaptive_dropouters[-1](out) 214 | out = self.linear(out) 215 | 216 | return out 217 | -------------------------------------------------------------------------------- /DeepAA_evaluate/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib 4 | matplotlib.use('TkAgg') 5 | import matplotlib.pyplot as plt 6 | 7 | import torchvision.transforms.functional as F 8 | 9 | 10 | plt.rcParams["savefig.bbox"] = 'tight' 11 | 12 | 13 | def save_images(imgs, dir): 14 | if not isinstance(imgs, list): 15 | imgs = [imgs] 16 | fix, axs = plt.subplots(ncols=len(imgs), squeeze=False) 17 | for i, img in enumerate(imgs): 18 | img = img.detach() 19 | img = F.to_pil_image(img) 20 | axs[0, i].imshow(np.asarray(img)) 21 | axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) 22 | fix.savefig(dir) 23 | return fix -------------------------------------------------------------------------------- /DeepAA_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import copy 5 | import random 6 | import datetime 7 | 8 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 9 | import tensorflow as tf 10 | tf.get_logger().setLevel(logging.ERROR) 11 | 12 | 13 | from data_generator import DataGenerator, DataAugmentation 14 | from utils import CTLHistory 15 | from lr_scheduler import GradualWarmup_Cosine_Scheduler 16 | import resnet 17 | from resnet_imagenet import imagenet_resnet50 18 | 19 | from data_generator import get_cifar10_data, get_cifar100_data 20 | 21 | from augmentation import AutoContrast, Invert, Equalize, Solarize, Posterize, Contrast, Brightness, Sharpness, \ 22 | Identity, Color, ShearX, ShearY, TranslateX, TranslateY, Rotate 23 | from augmentation import RandCrop, RandCutout, RandFlip, RandCutout60 24 | from augmentation import RandResizeCrop_imagenet, centerCrop_imagenet 25 | 26 | 27 | from policy import DA_Policy_logits 28 | from augmentation import IMAGENET_SIZE 29 | 30 | import torch 31 | import threading 32 | import queue 33 | from imagenet_data_utils import get_imagenet_split 34 | 35 | def aug_op_cifar_list(): # oeprators and their ranges 36 | l = [ 37 | (Identity, 0., 1.0), # 0 38 | (ShearX, -0.3, 0.3), # 1 39 | (ShearY, -0.3, 0.3), # 2 40 | (TranslateX, -0.45, 0.45), # 3 41 | (TranslateY, -0.45, 0.45), # 4 42 | (Rotate, -30., 30.), # 5 43 | (AutoContrast, 0., 1.), # 6 44 | (Invert, 0., 1.), # 7 45 | (Equalize, 0., 1.), # 8 46 | (Solarize, 0., 256.), # 9 47 | (Posterize, 4., 8.), # 10, 48 | (Contrast, 0.1, 1.9), # 11 49 | (Color, 0.1, 1.9), # 12 50 | (Brightness, 0.1, 1.9), # 13 51 | (Sharpness, 0.1, 1.9), # 14 52 | (RandFlip, 0., 1.0), # 15 53 | (RandCutout, 0., 1.0), # 16 54 | (RandCrop, 0., 1.0), # 17 55 | ] 56 | names = [] 57 | for op in l: 58 | info = op.__str__().split(' ') 59 | name = '{}:({},{}'.format(info[1], info[-2], info[-1]) 60 | names.append(name) 61 | 62 | return l, names 63 | 64 | def aug_op_imagenet_list(): # 16 oeprations and their ranges 65 | l = [ 66 | (Identity, 0., 1.0), # 0 67 | (ShearX, -0.3, 0.3), # 1 68 | (ShearY, -0.3, 0.3), # 2 69 | (TranslateX, -0.45, 0.45), # 3 70 | (TranslateY, 
-0.45, 0.45), # 4 71 | (Rotate, -30., 30.), # 5 72 | (AutoContrast, 0., 1.), # 6 73 | (Invert, 0., 1.), # 7 74 | (Equalize, 0., 1.), # 8 75 | (Solarize, 0., 256.), # 9 76 | (Posterize, 4., 8.), # 10 77 | (Contrast, 0.1, 1.9), # 11 78 | (Color, 0.1, 1.9), # 12 79 | (Brightness, 0.1, 1.9), # 13 80 | (Sharpness, 0.1, 1.9), # 14 81 | (RandFlip, 0., 1.0), # 15 82 | (RandCutout60, 0., 1.0), # 16 83 | (RandResizeCrop_imagenet, 0., 1.), 84 | ] 85 | names = [] 86 | for op in l: 87 | info = op.__str__().split(' ') 88 | name = '{}:({},{}'.format(info[1], info[-2], info[-1]) 89 | names.append(name) 90 | 91 | return l, names 92 | 93 | 94 | # Get the model 95 | def get_model(args, model, n_classes): 96 | if model == 'WRN_28_10': 97 | model = resnet.cifar_WRN_28_10(dropout=0, l2_reg=0.00025, 98 | preact_shortcuts=False, n_classes=n_classes, input_shape=args.img_size) 99 | elif model == 'WRN_40_2': 100 | model = resnet.cifar_WRN_40_2(dropout=0, l2_reg=0.00025, 101 | preact_shortcuts=False, n_classes=n_classes, input_shape=args.img_size) 102 | elif model == 'resnet50': 103 | model = imagenet_resnet50() 104 | else: 105 | raise Exception('Unrecognized model') 106 | return model 107 | 108 | # metric to keep track of 109 | train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy() 110 | test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy() 111 | train_loss = tf.keras.metrics.Mean() 112 | test_loss = tf.keras.metrics.Mean() 113 | 114 | def get_img_size(args): 115 | if 'cifar' in args.dataset: 116 | return (32, 32, 3) 117 | elif 'imagenet' in args.dataset: 118 | return (*IMAGENET_SIZE, 3) 119 | else: 120 | raise Exception 121 | 122 | # get the data 123 | def get_dataset(args): 124 | print('Loading train and retrain dataset.') 125 | if args.dataset in ['cifar10', 'cifar100']: 126 | if args.dataset == 'cifar10': 127 | assert args.n_classes == 10 128 | x_train_, y_train_, x_val, y_val, x_test, y_test = get_cifar10_data(val_size=10000) 129 | x_train, y_train = x_train_[:args.pretrain_size], y_train_[:args.pretrain_size] 130 | x_search, y_search = x_train_[args.pretrain_size:], y_train_[args.pretrain_size:] 131 | elif args.dataset == 'cifar100': 132 | assert args.n_classes == 100 133 | x_train_, y_train_, x_val, y_val, x_test, y_test = get_cifar100_data(val_size=10000) 134 | x_train, y_train = x_train_[:args.pretrain_size], y_train_[:args.pretrain_size] 135 | x_search, y_search = x_train_[args.pretrain_size:], y_train_[args.pretrain_size:] 136 | train_ds = DataGenerator(x_train, y_train, batch_size=args.batch_size, drop_last=True) 137 | search_ds = DataGenerator(x_search, y_search, batch_size=args.batch_size, drop_last=True) 138 | val_ds = DataGenerator(x_val, y_val, batch_size=args.val_batch_size, drop_last=True) 139 | test_ds = DataGenerator(x_test, y_test, batch_size=args.test_batch_size, drop_last=False, shuffle=False) # setting shuffle=False for parallel evaluation 140 | elif args.dataset == 'imagenet': 141 | assert args.n_classes == 1000 142 | def collate_fn_imagenet_list(l): # return a list 143 | images, labels = zip(*l) 144 | assert images[0].dtype == np.uint8 145 | return list(images), np.array(labels, dtype=np.int32) 146 | if args.dataset == 'imagenet': 147 | train_ds_total, val_ds, search_ds, train_ds, test_ds = get_imagenet_split(n_GPU=1, seed=300) 148 | assert len(train_ds) == 1 and isinstance(train_ds, list), 'Train_ds should be a length=1 list' 149 | train_ds = train_ds[0] 150 | test_ds = torch.utils.data.DataLoader( 151 | test_ds, batch_size=256, shuffle=False, num_workers=64, 152 | 
pin_memory=False, 153 | drop_last=False, sampler=None, 154 | collate_fn=collate_fn_imagenet_list, 155 | ) 156 | else: 157 | raise Exception('Unrecognized dataset') 158 | 159 | return train_ds, val_ds, test_ds, search_ds 160 | 161 | def get_augmentation(args): 162 | if 'cifar' in args.dataset: 163 | augmentation_default = DataAugmentation(num_classes=args.n_classes, dataset=args.dataset, image_shape=args.img_size, 164 | ops_list=(None, None), 165 | default_pre_aug=None, 166 | default_post_aug=[RandCrop, 167 | RandFlip, 168 | RandCutout]) 169 | 170 | augmentation_search = DataAugmentation(num_classes=args.n_classes, dataset=args.dataset, image_shape=args.img_size, 171 | ops_list=aug_op_cifar_list(), 172 | default_pre_aug=None, 173 | default_post_aug=None) 174 | 175 | augmentation_test = DataAugmentation(num_classes=args.n_classes, dataset=args.dataset, image_shape=args.img_size, 176 | ops_list=(None, None), 177 | default_pre_aug=None, 178 | default_post_aug=None) 179 | elif 'imagenet' in args.dataset: 180 | augmentation_default = DataAugmentation(num_classes=args.n_classes, dataset=args.dataset, 181 | image_shape=args.img_size, 182 | ops_list=(None, None), 183 | default_pre_aug=None, 184 | default_post_aug=[RandResizeCrop_imagenet, # 185 | RandFlip]) 186 | 187 | augmentation_search = DataAugmentation(num_classes=args.n_classes, dataset=args.dataset, image_shape=args.img_size, 188 | ops_list=aug_op_imagenet_list(), 189 | default_pre_aug=None, 190 | default_post_aug=None) 191 | 192 | 193 | augmentation_test = DataAugmentation(num_classes=args.n_classes, dataset=args.dataset, 194 | image_shape=args.img_size, 195 | ops_list=(None, None), 196 | default_pre_aug=None, 197 | default_post_aug=[ 198 | centerCrop_imagenet, 199 | ]) 200 | return augmentation_default, augmentation_search, augmentation_test 201 | 202 | def get_optim_net(args, nb_train_steps): 203 | scheduler_lr = GradualWarmup_Cosine_Scheduler(starting_lr=0., initial_lr=args.pretrain_lr, 204 | ending_lr=1e-7, 205 | warmup_steps= 0, 206 | total_steps=nb_train_steps * args.nb_epochs) 207 | 208 | optim_net = tf.optimizers.SGD(learning_rate=scheduler_lr, momentum=0.9, nesterov=True) 209 | return optim_net 210 | 211 | 212 | 213 | 214 | def get_policy(args, op_names, ops_mid_magnitude, available_policies): 215 | policy = DA_Policy_logits(args.l_ops, args.l_mags, args.l_uniq, 216 | op_names=op_names, 217 | ops_mid_magnitude=ops_mid_magnitude, N_repeat_random=args.N_repeat_random, 218 | available_policies=available_policies) 219 | return policy 220 | 221 | def get_optim_policy(policy_lr): 222 | optim_policy = tf.optimizers.Adam(learning_rate=policy_lr, beta_1=0.9, beta_2=0.999) 223 | return optim_policy 224 | 225 | 226 | # get the loss 227 | def get_loss_fun(): 228 | train_loss_fun = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, 229 | reduction=tf.keras.losses.Reduction.NONE) 230 | test_loss_fun = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, 231 | reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE) 232 | val_loss_fun = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, 233 | reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE) 234 | return train_loss_fun, test_loss_fun, val_loss_fun 235 | 236 | 237 | def get_lops_luniq(args, ops_mid_magnitude): 238 | if 'cifar' in args.dataset: 239 | _, op_names = aug_op_cifar_list() 240 | elif 'imagenet' in args.dataset: 241 | _, op_names = aug_op_imagenet_list() 242 | else: 243 | raise Exception('Unknown dataset ={}'.format(args.dataset)) 244 | 245 | 
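    # Note (added): the loop below computes l_uniq, the number of distinct
    # (op, magnitude) sub-policies: an op with a defined middle magnitude keeps
    # l_mags - 1 magnitudes (the middle one is excluded), 'random' ops and ops
    # without a magnitude each count once, and mid_mag == -1 keeps all l_mags.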
names_modified = [op_name.split(':')[0] for op_name in op_names] 246 | l_ops = len(op_names) 247 | l_uniq = 0 248 | for k_name, name in enumerate(names_modified): 249 | mid_mag = ops_mid_magnitude[name] 250 | if mid_mag == 'random': 251 | l_uniq += 1 # The op is a random op, however we only sample one op 252 | elif mid_mag is not None and mid_mag >=0 and mid_mag <= args.l_mags-1: 253 | l_uniq += args.l_mags-1 254 | elif mid_mag is not None and mid_mag == -1: # magnitude==-1 means all l_mags are independnt policies; or mid_mag > args.l_mags-1) 255 | l_uniq += args.l_mags 256 | elif mid_mag is None: 257 | l_uniq += 1 258 | else: 259 | raise Exception('mid_mag = {} is invalid'.format(mid_mag)) 260 | return l_ops, l_uniq 261 | 262 | def get_all_policy(policy_train): 263 | l_ops, l_mags = policy_train.l_ops, policy_train.l_mags 264 | ops, mags = np.meshgrid(np.arange(l_ops), np.arange(l_mags), indexing='ij') 265 | ops = np.reshape(ops, [l_ops*l_mags,1]) 266 | mags = np.reshape(mags, [l_ops*l_mags,1]) 267 | return ops.astype(np.int32), mags.astype(np.int32) 268 | 269 | class PrefetchGenerator(threading.Thread): 270 | def __init__(self, search_ds, val_ds, n_classes, search_bs=8, val_bs=64): 271 | threading.Thread.__init__(self) 272 | self.queue = queue.Queue(1) 273 | self.search_ds = search_ds 274 | self.val_ds = val_ds 275 | self.n_classes = n_classes 276 | self.search_bs = search_bs 277 | self.val_bs = val_bs 278 | self.daemon = True 279 | self.start() 280 | 281 | @staticmethod 282 | def sample_label_and_batch(dataset, bs, n_classes, MAX_iterations=100): 283 | for k in range(MAX_iterations): 284 | try: 285 | lab = random.randint(0, n_classes-1) 286 | imgs, labs = dataset.sample_labeled_data_batch(lab, bs) 287 | except: 288 | print('Insufficient data in a single class, try {}/{}'.format(k, MAX_iterations)) 289 | continue 290 | return lab, imgs, labs 291 | raise Exception('Maximum number of iteration {} reached'.format(MAX_iterations)) 292 | 293 | def run(self): 294 | while True: 295 | images_val, labels_val, images_train, labels_train = [], [], [], [] 296 | for _ in range(self.search_bs): 297 | lab, imgs_val, labs_val = PrefetchGenerator.sample_label_and_batch(self.val_ds, self.val_bs, self.n_classes) 298 | imgs_train, labs_train = self.search_ds.sample_labeled_data_batch(lab, 1) 299 | images_val.append(imgs_val) 300 | labels_val.append(labs_val) 301 | images_train.append(imgs_train) 302 | labels_train.append(labs_train) 303 | self.queue.put( (images_val, labels_val, images_train, labels_train) ) 304 | 305 | def next(self): 306 | next_item = self.queue.get() 307 | return next_item 308 | 309 | 310 | def save_policy(args, all_using_policies, augmentation_search): 311 | ops, mags = all_using_policies[0].unique_policy 312 | op_names = augmentation_search.op_names 313 | policy_probs = [] 314 | for k_policy, policy in enumerate(all_using_policies): 315 | policy_probs.append(tf.nn.softmax(policy.logits).numpy()) 316 | policy_probs = np.stack(policy_probs, axis=0) 317 | 318 | np.savez('./policy_port/policy_DeepAA_{}.npz'.format(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")), 319 | policy_probs=policy_probs, l_ops=args.l_ops, l_mags=args.l_mags, 320 | ops=ops, mags=mags, op_names=op_names) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep AutoAugment 2 | 3 | This is the official implementation of Deep AutoAugment 
([DeepAA](https://openreview.net/forum?id=St-53J9ZARf)), a fully automated data augmentation policy search method. The leaderboard is available on Papers with Code: https://paperswithcode.com/paper/deep-autoaugment-1 4 | 5 |

6 | ![DeepAA](./images/DeepAA.png) 7 |
8 | 9 | ## 5-Minute Explanation Video 10 | Click the figure to watch this short video explaining our work. 11 | 12 | [![slideslive_link](./images/DeepAA_slideslive.png)](https://recorder-v3.slideslive.com/#/share?share=64177&s=6d93977f-2a40-436d-a404-8808aee650fa) 13 | 14 | ## Requirements 15 | DeepAA is implemented using TensorFlow. 16 | To be consistent with previous work, we run the policy evaluation based on [TrivialAugment](https://github.com/automl/trivialaugment), which is implemented using PyTorch. 17 | 18 | ### Install required packages 19 | a. Create a conda virtual environment. 20 | ```shell 21 | conda create -n deepaa python=3.7 22 | conda activate deepaa 23 | ``` 24 | 25 | b. Install Tensorflow and PyTorch. 26 | ```shell 27 | conda install tensorflow-gpu=2.5 cudnn=8.1 cudatoolkit=11.2 -c conda-forge 28 | pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html 29 | ``` 30 | 31 | c. Install other dependencies. 32 | ```shell 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | 37 | ## Experiments 38 | 39 | ### Run augmentation policy search on CIFAR-10/100. 40 | ```shell 41 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 42 | python DeepAA_search.py --dataset cifar10 --n_classes 10 --use_model WRN_40_2 --n_policies 6 --search_bno 1024 --pretrain_lr 0.1 --seed 1 --batch_size 128 --test_batch_size 512 --policy_lr 0.025 --l_mags 13 --use_pool --pretrain_size 5000 --nb_epochs 45 --EXP_G 16 --EXP_gT_factor=4 --train_same_labels 16 43 | ``` 44 | 45 | ### Run augmentation policy search on ImageNet. 46 | ```shell 47 | mkdir pretrained_imagenet 48 | ``` 49 | Download the [files](https://drive.google.com/drive/folders/1QmqWfF_dzyZPDIuvkiLHp0X6JiUNbIZI?usp=sharing) and copy them to the `./pretrained_imagenet` folder. 50 | ```shell 51 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 52 | python DeepAA_search.py --dataset imagenet --n_classes 1000 --use_model resnet50 --n_policies 6 --search_bno 1024 --seed 1 --batch_size 128 --test_batch_size 512 --policy_lr 0.025 --l_mags 13 --use_pool --EXP_G 16 --EXP_gT_factor=4 --train_same_labels 16 53 | ``` 54 | 55 | ### Evaluate the policy found on CIFAR-10/100 and ImageNet. 56 | ```shell 57 | mkdir ckpt 58 | python -m DeepAA_evaluate.train -c confs/wresnet28x10_cifar10_DeepAA_1.yaml --dataroot ./data --save ckpt/DeepAA_cifar10.pth --tag Exp_DeepAA_cifar10 59 | python -m DeepAA_evaluate.train -c confs/wresnet28x10_cifar100_DeepAA_1.yaml --dataroot ./data --save ckpt/DeepAA_cifar100.pth --tag Exp_DeepAA_cifar100 60 | python -m DeepAA_evaluate.train -c confs/resnet50_imagenet_DeepAA_8x256_1.yaml --dataroot ./data --save ckpt/DeepAA_imagenet.pth --tag Exp_DeepAA_imagenet 61 | ``` 62 | 63 | ### Evaluate the policy found on CIFAR-10/100 with Batch Augmentation. 64 | ```shell 65 | mkdir ckpt 66 | python -m DeepAA_evaluate.train -c confs/wresnet28x10_cifar10_DeepAA_BatchAug8x_1.yaml --dataroot ./data --save ckpt/DeepAA_cifar10.pth --tag Exp_DeepAA_cifar10 67 | python -m DeepAA_evaluate.train -c confs/wresnet28x10_cifar100_DeepAA_BatchAug8x_1.yaml --dataroot ./data --save ckpt/DeepAA_cifar100.pth --tag Exp_DeepAA_cifar100 68 | ``` 69 | 70 | ## Visualization 71 | 72 | The policies found on CIFAR-10/100 and ImageNet are visualized as follows. 73 | 74 |

75 | ![Operation distribution](./images/operation_distribution.png) 76 |

77 | 78 | The distribution of operations at each layer of the policy for (a) CIFAR-10/100 and (b) ImageNet. The probability of each operation is summed up over all 12 discrete intensity levels of the corresponding transformation. 79 | 80 |
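To recover the numbers behind this plot from a released policy, you can read the `.npz` files in `./policy_port`; `save_policy` in `DeepAA_utils.py` writes them with the keys `policy_probs`, `ops`, `mags`, and `op_names`. The sketch below is illustrative only and assumes that column *i* of `policy_probs` corresponds to row *i* of the saved `ops`/`mags` arrays, i.e. the policy was searched over the full set of unique (operation, magnitude) pairs.

```python
# Illustrative sketch (not part of the repo): aggregate the probability mass per
# operation in a released DeepAA policy, mirroring the operation-distribution plot.
# Assumption: column i of `policy_probs` corresponds to row i of `ops`/`mags`.
import numpy as np

policy = np.load('./policy_port/policy_DeepAA_cifar_1.npz')
policy_probs = policy['policy_probs']     # shape: (n_policy_layers, n_op_mag_pairs)
ops = policy['ops'].squeeze()             # operation index of each (op, mag) pair
op_names = [str(n).split(':')[0] for n in policy['op_names']]

for layer, probs in enumerate(policy_probs):
    per_op = {}
    for op_idx, p in zip(ops, probs):
        name = op_names[int(op_idx)]
        per_op[name] = per_op.get(name, 0.0) + float(p)
    top5 = sorted(per_op.items(), key=lambda kv: kv[1], reverse=True)[:5]
    print(f'layer {layer}:', [(n, round(p, 3)) for n, p in top5])
```

If `policy_probs.shape[1]` differs from `len(ops)`, the policy was saved over a restricted candidate set and this one-to-one alignment does not hold.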

81 | ![Magnitude distribution on CIFAR](./images/magnitude_distribution_cifar.png) 82 |

83 | 84 | The distribution of discrete magnitudes of each augmentation transformation in each layer of the policy for CIFAR-10/100. The x-axis represents the discrete magnitudes and the y-axis represents the probability. The magnitude is discretized to 12 levels with each transformation having its own range. A large absolute value of the magnitude corresponds to high transformation intensity. Note that we do not show identity, autoContrast, invert, equalize, flips, Cutout and crop because they do not have intensity parameters. 85 | 86 |
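The discretization itself is a simple linear mapping: a level is normalized to [0, 1] and then scaled into the operation's own range, exactly as `DataAugmentation.sequantially_augment` does with `mag = mag * (maxval - minval) + minval`. The sketch below illustrates this for `Rotate`, assuming the `--l_mags 13` setting from the search commands above, the [-30, 30] degree range defined in `augmentation.py`, and a `level / (l_mags - 1)` normalization (an assumption for illustration); with 13 levels, the middle identity-like magnitude is dropped by `_get_unique_policy` in `policy.py`, which leaves the 12 intensity levels plotted here.

```python
# Illustrative sketch (not part of the repo): map discrete magnitude levels to a
# continuous Rotate angle.
# Assumptions: l_mags = 13, levels normalized as level / (l_mags - 1), and the
# Rotate range of [-30, 30] degrees from augmentation.py.
l_mags = 13
minval, maxval = -30.0, 30.0   # Rotate range in degrees

for level in range(l_mags):
    frac = level / (l_mags - 1)                  # normalize the level to [0, 1]
    angle = frac * (maxval - minval) + minval    # same scaling as sequantially_augment
    print(f'level {level:2d} -> rotate by {angle:+6.1f} degrees')

# The middle level (6) maps to 0 degrees (identity); the other 12 levels are the
# discrete intensities visualized above.
```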

87 | ![Magnitude distribution on ImageNet](./images/magnitude_distribution_imagenet.png) 88 |

89 | 90 | The distribution of discrete magnitudes of each augmentation transformation in each layer of the policy for ImageNet. The x-axis represents the discrete magnitudes and the y-axis represents the probability. The magnitude is discretized to 12 levels with each transformation having its own range. A large absolute value of the magnitude corresponds to high transformation intensity. Note that we do not show identity, autoContrast, invert, equalize, flips, Cutout and crop because they do not have intensity parameters. 91 | 92 | ## Citation 93 | If you find this useful for your work, please consider citing: 94 | ``` 95 | @inproceedings{ 96 | zheng2022deep, 97 | title={Deep AutoAugment}, 98 | author={Yu Zheng and Zhi Zhang and Shen Yan and Mi Zhang}, 99 | booktitle={International Conference on Learning Representations}, 100 | year={2022}, 101 | url={https://openreview.net/forum?id=St-53J9ZARf} 102 | } 103 | ``` 104 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/__init__.py -------------------------------------------------------------------------------- /augmentation.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py 4 | 5 | import random 6 | 7 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 8 | import numpy as np 9 | from PIL import Image 10 | import math 11 | 12 | IMAGENET_SIZE = (224, 224) # (width, height) may set to (244, 224) 13 | 14 | _IMAGENET_PCA = { 15 | 'eigval': [0.2175, 0.0188, 0.0045], 16 | 'eigvec': [ 17 | [-0.5675, 0.7192, 0.4009], 18 | [-0.5808, -0.0045, -0.8140], 19 | [-0.5836, -0.6948, 0.4203], 20 | ] 21 | } 22 | _CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 23 | 24 | def ShearX(img, v): # [-0.3, 0.3] 25 | assert -0.3 <= v <= 0.3 26 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 27 | 28 | def ShearY(img, v): # [-0.3, 0.3] 29 | assert -0.3 <= v <= 0.3 30 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 31 | 32 | 33 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 34 | assert -0.45 <= v <= 0.45 35 | v = v * img.size[0] 36 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 37 | 38 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 39 | assert -0.45 <= v <= 0.45 40 | v = v * img.size[1] 41 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 42 | 43 | 44 | def Rotate(img, v): # [-30, 30] 45 | assert -30 <= v <= 30 46 | return img.rotate(v) 47 | 48 | 49 | def AutoContrast(img, _): 50 | return PIL.ImageOps.autocontrast(img) 51 | 52 | 53 | def Invert(img, _): 54 | return PIL.ImageOps.invert(img) 55 | 56 | 57 | def Equalize(img, _): 58 | return PIL.ImageOps.equalize(img) 59 | 60 | 61 | def Flip(img, _): # not from the paper 62 | return PIL.ImageOps.mirror(img) 63 | 64 | 65 | def Solarize(img, v): # [0, 256] 66 | assert 0 <= v <= 256 67 | return PIL.ImageOps.solarize(img, v) 68 | 69 | 70 | def SolarizeAdd(img, addition=0, threshold=128): 71 | img_np = np.array(img).astype(np.int) 72 | img_np = img_np + addition 73 | img_np = 
np.clip(img_np, 0, 255) 74 | img_np = img_np.astype(np.uint8) 75 | img = Image.fromarray(img_np) 76 | return PIL.ImageOps.solarize(img, threshold) 77 | 78 | 79 | def Posterize(img, v): # [4, 8] 80 | assert 4 <= v <= 8 # FastAA 81 | v = int(v) 82 | return PIL.ImageOps.posterize(img, v) 83 | 84 | 85 | def Contrast(img, v): # [0.1,1.9] 86 | assert 0.1 <= v <= 1.9 87 | return PIL.ImageEnhance.Contrast(img).enhance(v) 88 | 89 | 90 | def Color(img, v): # [0.1,1.9] 91 | assert 0.1 <= v <= 1.9 92 | return PIL.ImageEnhance.Color(img).enhance(v) 93 | 94 | 95 | def Brightness(img, v): # [0.1,1.9] 96 | assert 0.1 <= v <= 1.9 97 | return PIL.ImageEnhance.Brightness(img).enhance(v) 98 | 99 | 100 | def Sharpness(img, v): # [0.1,1.9] 101 | assert 0.1 <= v <= 1.9 102 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 103 | 104 | 105 | def RandCrop(img, _): 106 | v = 4 107 | return mean_pad_randcrop(img, v) 108 | 109 | def RandCutout(img, _): 110 | v = 16 111 | w, h = img.size 112 | x = random.uniform(0, w) 113 | y = random.uniform(0, h) 114 | 115 | x0 = int(min(w, max(0, x - v // 2))) # clip to the range (0, w) 116 | x1 = int(min(w, max(0, x + v // 2))) 117 | y0 = int(min(h, max(0, y - v // 2))) 118 | y1 = int(min(h, max(0, y + v // 2))) 119 | 120 | xy = (x0, y0, x1, y1) 121 | color = (125, 123, 114) 122 | # color = (0, 0, 0) 123 | img = img.copy() 124 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 125 | return img 126 | 127 | 128 | def RandCutout60(img, _): 129 | v = 60 130 | w, h = img.size 131 | x_left = max(0, w // 2 - 256 // 2) 132 | x_right = min(w, w // 2 + 256 // 2) 133 | y_bottom = max(0, h // 2 - 256 // 2) 134 | y_top = min(h, h // 2 + 256 // 2) 135 | 136 | x = random.uniform(x_left, x_right) 137 | y = random.uniform(y_bottom, y_top) 138 | 139 | x0 = int(min(w, max(0, x - v // 2))) 140 | x1 = int(min(w, max(0, x + v // 2))) 141 | y0 = int(min(h, max(0, y - v // 2))) 142 | y1 = int(min(h, max(0, y + v // 2))) 143 | 144 | xy = (x0, y0, x1, y1) 145 | color = (125, 123, 114) 146 | # color = (0, 0, 0) 147 | img = img.copy() 148 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 149 | return img 150 | 151 | 152 | def RandFlip(img, _): 153 | if random.random() > 0.5: 154 | img = Flip(img, None) 155 | return img 156 | 157 | 158 | 159 | def mean_pad_randcrop(img, v): 160 | # v: Pad with mean value=[125, 123, 114] by v pixels on each side and then take random crop 161 | assert v <= 10, 'The maximum shift should be less then 10' 162 | padded_size = (img.size[0] + 2*v, img.size[1] + 2*v) 163 | new_img = PIL.Image.new('RGB', padded_size, color=(125, 123, 114)) 164 | new_img.paste(img, (v, v)) 165 | top = random.randint(0, v*2) 166 | left = random.randint(0, v*2) 167 | new_img = new_img.crop((left, top, left + img.size[0], top + img.size[1])) 168 | return new_img 169 | 170 | def Identity(img, v): 171 | return img 172 | 173 | 174 | def RandResizeCrop_imagenet(img, _): 175 | # ported from torchvision 176 | # for ImageNet use only 177 | scale = (0.08, 1.0) 178 | ratio = (3. / 4., 4. / 3.) 
179 | size = IMAGENET_SIZE # (224, 224) 180 | 181 | def get_params(img, scale, ratio): 182 | width, height = img.size 183 | area = float(width * height) 184 | log_ratio = [math.log(r) for r in ratio] 185 | 186 | for _ in range(10): 187 | target_area = area * random.uniform(scale[0], scale[1]) 188 | aspect_ratio = math.exp(random.uniform(log_ratio[0], log_ratio[1])) 189 | 190 | w = round(math.sqrt(target_area * aspect_ratio)) 191 | h = round(math.sqrt(target_area / aspect_ratio)) 192 | if 0 < w <= width and 0 < h <= height: 193 | top = random.randint(0, height - h) 194 | left = random.randint(0, width - w) 195 | return left, top, w, h 196 | 197 | # fallback to central crop 198 | in_ratio = float(width) / float(height) 199 | if in_ratio < min(ratio): 200 | w = width 201 | h = round(w / min(ratio)) 202 | elif in_ratio > max(ratio): 203 | h = height 204 | w = round(h * max(ratio)) 205 | else: 206 | w = width 207 | h = height 208 | top = (height - h) // 2 209 | left = (width - w) // 2 210 | return left, top, w, h 211 | 212 | left, top, w_box, h_box = get_params(img, scale, ratio) 213 | box = (left, top, left + w_box, top + h_box) 214 | img = img.resize(size=size, resample=PIL.Image.CUBIC, box=box) 215 | return img 216 | 217 | 218 | def Resize_imagenet(img, size): 219 | w, h = img.size 220 | if isinstance(size, int): 221 | short, long = (w, h) if w <= h else (h, w) 222 | if short == size: 223 | return img 224 | new_short, new_long = size, int(size * long / short) 225 | new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short) 226 | return img.resize((new_w, new_h), PIL.Image.BICUBIC) 227 | elif isinstance(size, tuple) or isinstance(size, list): 228 | assert len(size) == 2, 'Check the size {}'.format(size) 229 | return img.resize(size, PIL.Image.BICUBIC) 230 | else: 231 | raise Exception 232 | 233 | 234 | def centerCrop_imagenet(img, _): 235 | # for ImageNet only 236 | # https://github.com/pytorch/vision/blob/master/torchvision/transforms/functional.py 237 | crop_width, crop_height = IMAGENET_SIZE # (224,224) 238 | image_width, image_height = img.size 239 | 240 | if crop_width > image_width or crop_height > image_height: 241 | padding_ltrb = [ 242 | (crop_width - image_width) // 2 if crop_width > image_width else 0, 243 | (crop_height - image_height) // 2 if crop_height > image_height else 0, 244 | (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, 245 | (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, 246 | ] 247 | img = pad(img, padding_ltrb, fill=0) 248 | image_width, image_height = img.size 249 | if crop_width == image_width and crop_height == image_height: 250 | return img 251 | 252 | crop_top = int(round((image_height - crop_height) / 2.)) 253 | crop_left = int(round((image_width - crop_width) / 2.)) 254 | return img.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height)) 255 | 256 | # def centerCrop_imagenet_default(img): 257 | # return centerCrop_imagenet(img, None) 258 | 259 | def _parse_fill(fill, img, name="fillcolor"): 260 | # Process fill color for affine transforms 261 | num_bands = len(img.getbands()) 262 | if fill is None: 263 | fill = 0 264 | if isinstance(fill, (int, float)) and num_bands > 1: 265 | fill = tuple([fill] * num_bands) 266 | if isinstance(fill, (list, tuple)): 267 | if len(fill) != num_bands: 268 | msg = ("The number of elements in 'fill' does not match the number of " 269 | "bands of the image ({} != {})") 270 | raise ValueError(msg.format(len(fill), num_bands)) 271 | 272 | fill = 
tuple(fill) 273 | 274 | return {name: fill} 275 | 276 | 277 | def pad(img, padding_ltrb, fill=0, padding_mode='constant'): 278 | if isinstance(padding_ltrb, list): 279 | padding_ltrb = tuple(padding_ltrb) 280 | if padding_mode == 'constant': 281 | opts = _parse_fill(fill, img, name='fill') 282 | if img.mode == 'P': 283 | palette = img.getpalette() 284 | image = PIL.ImageOps.expand(img, border=padding_ltrb, **opts) 285 | image.putpalette(palette) 286 | return image 287 | return PIL.ImageOps.expand(img, border=padding_ltrb, **opts) 288 | elif len(padding_ltrb) == 4: 289 | image_width, image_height = img.size 290 | cropping = -np.minimum(padding_ltrb, 0) 291 | if cropping.any(): 292 | crop_left, crop_top, crop_right, crop_bottom = cropping 293 | img = img.crop((crop_left, crop_top, image_width - crop_right, image_height - crop_bottom)) 294 | pad_left, pad_top, pad_right, pad_bottom = np.maximum(padding_ltrb, 0) 295 | 296 | if img.mode == 'P': 297 | palette = img.getpalette() 298 | img = np.asarray(img) 299 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) 300 | img = Image.fromarray(img) 301 | img.putpalette(palette) 302 | return img 303 | 304 | img = np.asarray(img) 305 | # RGB image 306 | if len(img.shape) == 3: 307 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode) 308 | # Grayscale image 309 | if len(img.shape) == 2: 310 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode) 311 | 312 | return Image.fromarray(img) 313 | else: 314 | raise Exception 315 | 316 | def get_mid_magnitude(l_mags): 317 | ops_mid_magnitude = {'Identity': None, 318 | 'ShearX': (l_mags - 1) // 2, 319 | 'ShearY': (l_mags - 1) // 2, 320 | 'TranslateX': (l_mags - 1) // 2, 321 | 'TranslateY': (l_mags - 1) // 2, 322 | 'Rotate': (l_mags - 1) // 2, 323 | 'AutoContrast': None, 324 | 'Invert': None, 325 | 'Equalize': None, 326 | 'Solarize': l_mags - 1, 327 | 'Posterize': l_mags - 1, 328 | 'Contrast': (l_mags - 1) // 2, 329 | 'Color': (l_mags - 1) // 2, 330 | 'Brightness': (l_mags - 1) // 2, 331 | 'Sharpness': (l_mags - 1) // 2, 332 | 'RandFlip': 'random', 333 | 'RandCutout': 'random', 334 | 'RandCutout60': 'random', 335 | 'RandCrop': 'random', 336 | 'RandResizeCrop_imagenet': 'random', 337 | } 338 | return ops_mid_magnitude -------------------------------------------------------------------------------- /confs/resnet50_imagenet_DeepAA_8x256_1.yaml: -------------------------------------------------------------------------------- 1 | #load_main_model: true 2 | save_model: true 3 | model: 4 | type: resnet50 5 | dataset: imagenet 6 | aug: DeepAA 7 | deepaa: 8 | EXP: imagenet_1 9 | augmentation_search_space: Not_used 10 | cutout: -1 11 | batch: 256 12 | gpus: 8 13 | epoch: 270 14 | lr: .1 15 | lr_schedule: 16 | type: 'resnet' 17 | warmup: 18 | multiplier: 8.0 19 | epoch: 5 20 | optimizer: 21 | type: sgd 22 | nesterov: True 23 | decay: 0.0001 24 | clip: 0 25 | test_interval: 20 26 | 27 | -------------------------------------------------------------------------------- /confs/resnet50_imagenet_DeepAA_8x256_2.yaml: -------------------------------------------------------------------------------- 1 | #load_main_model: true 2 | save_model: true 3 | model: 4 | type: resnet50 5 | dataset: imagenet 6 | aug: DeepAA 7 | deepaa: 8 | EXP: imagenet_2 9 | augmentation_search_space: Not_used 10 | cutout: -1 11 | batch: 256 12 | gpus: 8 13 | epoch: 270 14 | lr: .1 15 | lr_schedule: 16 | type: 'resnet' 17 | warmup: 18 | multiplier: 8.0 19 | epoch: 5 20 
| optimizer: 21 | type: sgd 22 | nesterov: True 23 | decay: 0.0001 24 | clip: 0 25 | test_interval: 20 26 | 27 | -------------------------------------------------------------------------------- /confs/wresnet28x10_cifar100_DeepAA_1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: wresnet28_10 3 | dataset: cifar100 4 | aug: DeepAA 5 | deepaa: 6 | EXP: cifar_1 7 | cutout: -1 8 | batch: 128 9 | gpus: 1 10 | augmentation_search_space: Not_used # fixed_standard 11 | epoch: 200 12 | lr: 0.1 13 | lr_schedule: 14 | type: 'cosine' 15 | warmup: 16 | multiplier: 1 17 | epoch: 5 18 | optimizer: 19 | type: sgd 20 | nesterov: True 21 | decay: 0.0005 22 | 23 | -------------------------------------------------------------------------------- /confs/wresnet28x10_cifar100_DeepAA_1_wd1e-3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: wresnet28_10 3 | dataset: cifar100 4 | aug: DeepAA 5 | deepaa: 6 | EXP: cifar_1 7 | cutout: -1 8 | batch: 128 9 | gpus: 1 10 | augmentation_search_space: Not_used # fixed_standard 11 | epoch: 200 12 | lr: 0.1 13 | lr_schedule: 14 | type: 'cosine' 15 | warmup: 16 | multiplier: 1 17 | epoch: 5 18 | optimizer: 19 | type: sgd 20 | nesterov: True 21 | decay: 0.001 22 | 23 | -------------------------------------------------------------------------------- /confs/wresnet28x10_cifar100_DeepAA_2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: wresnet28_10 3 | dataset: cifar100 4 | aug: DeepAA 5 | deepaa: 6 | EXP: cifar_2 7 | cutout: -1 8 | batch: 128 9 | gpus: 1 10 | augmentation_search_space: Not_used # fixed_standard 11 | epoch: 200 12 | lr: 0.1 13 | lr_schedule: 14 | type: 'cosine' 15 | warmup: 16 | multiplier: 1 17 | epoch: 5 18 | optimizer: 19 | type: sgd 20 | nesterov: True 21 | decay: 0.0005 22 | 23 | -------------------------------------------------------------------------------- /confs/wresnet28x10_cifar100_DeepAA_2_wd1e-3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: wresnet28_10 3 | dataset: cifar100 4 | aug: DeepAA 5 | deepaa: 6 | EXP: cifar_2 7 | cutout: -1 8 | batch: 128 9 | gpus: 1 10 | augmentation_search_space: Not_used # fixed_standard 11 | epoch: 200 12 | lr: 0.1 13 | lr_schedule: 14 | type: 'cosine' 15 | warmup: 16 | multiplier: 1 17 | epoch: 5 18 | optimizer: 19 | type: sgd 20 | nesterov: True 21 | decay: 0.001 22 | 23 | -------------------------------------------------------------------------------- /confs/wresnet28x10_cifar100_DeepAA_BatchAug8x_1.yaml: -------------------------------------------------------------------------------- 1 | all_workers_use_the_same_batches: true 2 | model: 3 | type: wresnet28_10 4 | dataset: cifar100 5 | aug: DeepAA 6 | deepaa: 7 | EXP: cifar_1 8 | cutout: -1 9 | batch: 128 10 | gpus: 8 11 | augmentation_search_space: Not_used 12 | epoch: 35 13 | lr: 0.4 14 | lr_schedule: 15 | type: 'cosine' 16 | warmup: 17 | multiplier: 1 18 | epoch: 5 19 | optimizer: 20 | type: sgd 21 | nesterov: True 22 | decay: 0.0005 23 | 24 | -------------------------------------------------------------------------------- /confs/wresnet28x10_cifar100_DeepAA_BatchAug8x_2.yaml: -------------------------------------------------------------------------------- 1 | all_workers_use_the_same_batches: true 2 | model: 3 | type: wresnet28_10 4 | dataset: cifar100 5 | aug: DeepAA 6 | deepaa: 7 | EXP: cifar_2 8 | cutout: -1 9 | batch: 128 
10 | gpus: 8 11 | augmentation_search_space: Not_used 12 | epoch: 35 13 | lr: 0.4 14 | lr_schedule: 15 | type: 'cosine' 16 | warmup: 17 | multiplier: 1 18 | epoch: 5 19 | optimizer: 20 | type: sgd 21 | nesterov: True 22 | decay: 0.0005 23 | 24 | -------------------------------------------------------------------------------- /confs/wresnet28x10_cifar10_DeepAA_1.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: wresnet28_10 3 | dataset: cifar10 4 | aug: DeepAA 5 | deepaa: 6 | EXP: cifar_1 7 | cutout: -1 8 | batch: 128 9 | gpus: 1 10 | augmentation_search_space: Not_used # fixed_standard 11 | epoch: 200 12 | lr: 0.1 13 | lr_schedule: 14 | type: 'cosine' 15 | warmup: 16 | multiplier: 1 17 | epoch: 5 18 | optimizer: 19 | type: sgd 20 | nesterov: True 21 | decay: 0.0005 22 | 23 | -------------------------------------------------------------------------------- /confs/wresnet28x10_cifar10_DeepAA_1_wd1e-3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: wresnet28_10 3 | dataset: cifar10 4 | aug: DeepAA 5 | deepaa: 6 | EXP: cifar_1 7 | cutout: -1 8 | batch: 128 9 | gpus: 1 10 | augmentation_search_space: Not_used # fixed_standard 11 | epoch: 200 12 | lr: 0.1 13 | lr_schedule: 14 | type: 'cosine' 15 | warmup: 16 | multiplier: 1 17 | epoch: 5 18 | optimizer: 19 | type: sgd 20 | nesterov: True 21 | decay: 0.001 22 | 23 | -------------------------------------------------------------------------------- /confs/wresnet28x10_cifar10_DeepAA_2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: wresnet28_10 3 | dataset: cifar10 4 | aug: DeepAA 5 | deepaa: 6 | EXP: cifar_2 7 | cutout: -1 8 | batch: 128 9 | gpus: 1 10 | augmentation_search_space: Not_used # fixed_standard 11 | epoch: 200 12 | lr: 0.1 13 | lr_schedule: 14 | type: 'cosine' 15 | warmup: 16 | multiplier: 1 17 | epoch: 5 18 | optimizer: 19 | type: sgd 20 | nesterov: True 21 | decay: 0.0005 22 | 23 | -------------------------------------------------------------------------------- /confs/wresnet28x10_cifar10_DeepAA_2_wd1e-3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | type: wresnet28_10 3 | dataset: cifar10 4 | aug: DeepAA 5 | deepaa: 6 | EXP: cifar_2 7 | cutout: -1 8 | batch: 128 9 | gpus: 1 10 | augmentation_search_space: Not_used # fixed_standard 11 | epoch: 200 12 | lr: 0.1 13 | lr_schedule: 14 | type: 'cosine' 15 | warmup: 16 | multiplier: 1 17 | epoch: 5 18 | optimizer: 19 | type: sgd 20 | nesterov: True 21 | decay: 0.001 22 | 23 | -------------------------------------------------------------------------------- /confs/wresnet28x10_cifar10_DeepAA_BatchAug8x_1.yaml: -------------------------------------------------------------------------------- 1 | all_workers_use_the_same_batches: true 2 | model: 3 | type: wresnet28_10 4 | dataset: cifar10 5 | aug: DeepAA 6 | deepaa: 7 | EXP: cifar_1 8 | cutout: -1 9 | batch: 128 10 | gpus: 8 11 | augmentation_search_space: Not_used 12 | epoch: 100 13 | lr: 0.2 14 | lr_schedule: 15 | type: 'cosine' 16 | warmup: 17 | multiplier: 1 18 | epoch: 5 19 | optimizer: 20 | type: sgd 21 | nesterov: True 22 | decay: 0.0005 23 | 24 | -------------------------------------------------------------------------------- /confs/wresnet28x10_cifar10_DeepAA_BatchAug8x_2.yaml: -------------------------------------------------------------------------------- 1 | all_workers_use_the_same_batches: true 2 | 
model: 3 | type: wresnet28_10 4 | dataset: cifar10 5 | aug: DeepAA 6 | deepaa: 7 | EXP: cifar_2 8 | cutout: -1 9 | batch: 128 10 | gpus: 8 11 | augmentation_search_space: Not_used 12 | epoch: 100 13 | lr: 0.2 14 | lr_schedule: 15 | type: 'cosine' 16 | warmup: 17 | multiplier: 1 18 | epoch: 5 19 | optimizer: 20 | type: sgd 21 | nesterov: True 22 | decay: 0.0005 23 | 24 | -------------------------------------------------------------------------------- /data_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import logging 4 | import numpy as np 5 | import math 6 | from PIL import Image 7 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 8 | import tensorflow as tf 9 | tf.get_logger().setLevel(logging.ERROR) 10 | from tensorflow.keras.utils import Sequence 11 | from augmentation import IMAGENET_SIZE, centerCrop_imagenet 12 | 13 | 14 | CIFAR_MEANS = np.array([0.49139968, 0.48215841, 0.44653091], dtype=np.float32) 15 | CIFAR_STDS = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32) 16 | 17 | IMAGENET_MEANS = np.array([0.485, 0.456, 0.406], dtype=np.float32) 18 | IMAGENET_STDS = np.array([0.229, 0.224, 0.225], dtype=np.float32) 19 | 20 | def split_train_validation(x, y, val_size): 21 | indices = np.arange(len(x)) 22 | np.random.shuffle(indices) 23 | x_train, x_val, y_train, y_val = x[:-val_size], x[-val_size:], y[:-val_size], y[-val_size:] 24 | return x_train, y_train, x_val, y_val 25 | 26 | def get_cifar100_data(num_classes=100, val_size=10000): 27 | (x_train_val, y_train_val), (x_test, y_test) = tf.keras.datasets.cifar100.load_data() 28 | y_train_val = y_train_val.squeeze() 29 | y_test = y_test.squeeze() 30 | if val_size > 0: 31 | x_train, y_train, x_val, y_val = split_train_validation(x_train_val, y_train_val, val_size=val_size) 32 | else: 33 | x_train, y_train = x_train_val, y_train_val 34 | x_val, y_val = None, None 35 | return x_train, y_train, x_val, y_val, x_test, y_test 36 | 37 | def get_cifar10_data(num_classes=10, val_size=10000): 38 | (x_train_val, y_train_val), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() 39 | y_train_val = y_train_val.squeeze() 40 | y_test = y_test.squeeze() 41 | if val_size > 0: 42 | x_train, y_train, x_val, y_val = split_train_validation(x_train_val, y_train_val, val_size=val_size) 43 | else: 44 | x_train, y_train = x_train_val, y_train_val 45 | x_val, y_val = None, None 46 | return x_train, y_train, x_val, y_val, x_test, y_test 47 | 48 | 49 | class DataGenerator(Sequence): 50 | def __init__(self, 51 | data, 52 | labels, 53 | img_dim=None, 54 | batch_size=32, 55 | num_classes=10, 56 | shuffle=True, 57 | drop_last=True, 58 | ): 59 | 60 | self._data = data 61 | self.data = self._data # initially without calling augment, the output data is not augmented 62 | self.labels = labels 63 | self.img_dim = img_dim 64 | self.batch_size = batch_size 65 | self.num_classes = num_classes 66 | self.shuffle = shuffle 67 | self.drop_last = drop_last 68 | self.on_epoch_end() 69 | 70 | def reset_augment(self): 71 | self.data = self._data 72 | 73 | def on_epoch_end(self): 74 | self.indices = np.arange(len(self._data)) 75 | if self.shuffle: 76 | np.random.shuffle(self.indices) 77 | 78 | def sample_labeled_data_batch(self, label, bs): 79 | # suffle indices every time 80 | indices = np.arange(len(self._data)) 81 | np.random.shuffle(indices) 82 | if isinstance(self.labels, list): 83 | labels = [self.labels[k] for k in indices] 84 | else: 85 | labels = self.labels[indices] 86 | matched_labels = 
np.array(labels) == int(label) 87 | matched_indices = [id for id, isMatched in enumerate(matched_labels) if isMatched] 88 | if len(matched_indices) - bs >=0: 89 | start_idx = np.random.randint(0, len(matched_indices)-bs) 90 | batch_indices = matched_indices[start_idx:start_idx + bs] 91 | else: 92 | print('Not enough matched data, required {}, but got {} instead'.format(bs, len(matched_indices))) 93 | batch_indices = matched_indices 94 | data_indices = indices[batch_indices] 95 | return [self.data[k] for k in data_indices], np.array([self.labels[k] for k in data_indices], dtype=self.labels[0].dtype) 96 | 97 | def __len__(self): 98 | if self.drop_last: 99 | return int(np.floor(len(self.data) / self.batch_size)) # drop the last batch 100 | else: 101 | return int(np.ceil(len(self.data) / self.batch_size)) # drop the last batch 102 | 103 | def __getitem__(self, idx): 104 | curr_batch = self.indices[idx*self.batch_size:(idx+1)*self.batch_size] 105 | batch_len = len(curr_batch) 106 | if isinstance(self.data, list) and isinstance(self.labels, list): 107 | return [self.data[k] for k in curr_batch], np.array([self.labels[k] for k in curr_batch], np.int32) 108 | else: 109 | return self.data[curr_batch], self.labels[curr_batch] 110 | 111 | class DataAugmentation(object): 112 | def __init__(self, num_classes, dataset, image_shape, ops_list=None, default_pre_aug=None, default_post_aug=None): 113 | self.ops, self.op_names = ops_list 114 | self.default_pre_aug = default_pre_aug 115 | self.default_post_aug = default_post_aug 116 | self.num_classes = num_classes 117 | self.dataset = dataset 118 | self.image_shape = image_shape 119 | if 'imagenet' in self.dataset: 120 | assert self.image_shape == (*IMAGENET_SIZE, 3) 121 | elif 'cifar' in self.dataset: 122 | assert self.image_shape == (32, 32, 3) 123 | else: 124 | raise Exception('Unrecognized dataset') 125 | 126 | def sequantially_augment(self, args): 127 | idx, img_, op_idxs, mags, aug_finish = args 128 | assert img_.dtype == np.uint8, 'Input images should be unporocessed, should stay in np.uint8' 129 | img = copy.deepcopy(img_) 130 | pil_img = Image.fromarray(img) # Convert to PIL.Image 131 | if self.default_pre_aug is not None: 132 | for op in self.default_pre_aug: 133 | pil_img = op(pil_img) 134 | if self.ops is not None: 135 | for op_idx, mag in zip(op_idxs, mags): 136 | op, minval, maxval = self.ops[op_idx] 137 | assert mag > -1e-5 and mag < 1. + 1e-5, 'magnitudes should be in the range of (0., 1.)' 138 | mag = mag * (maxval - minval) + minval 139 | pil_img = op(pil_img, mag) 140 | if self.default_post_aug is not None and self.use_post_aug: 141 | for op in self.default_post_aug: 142 | pil_img = op(pil_img, None) 143 | if 'cifar' in self.dataset: 144 | img = np.asarray(pil_img, dtype=np.uint8) 145 | return idx, img 146 | elif 'imagenet' in self.dataset: 147 | if aug_finish: 148 | pil_img = self.crop_IMAGENET(pil_img) 149 | img = np.asarray(pil_img, dtype=np.uint8) 150 | return idx, img 151 | else: 152 | raise Exception 153 | 154 | def postprocessing_standardization(self, pil_img): 155 | x = np.asarray(pil_img, dtype=np.float32) / 255. 
156 | if 'cifar' in self.dataset: 157 | x = (x - CIFAR_MEANS) / CIFAR_STDS 158 | elif 'imagenet' in self.dataset: 159 | x = (x - IMAGENET_MEANS) / IMAGENET_STDS 160 | else: 161 | raise Exception('Unrecoginized dataset') 162 | return x 163 | 164 | def crop_IMAGENET(self, img): 165 | # cropping imagenet dataset to the same size 166 | if isinstance(img, np.ndarray): 167 | assert img.shape == (IMAGENET_SIZE[1], IMAGENET_SIZE[0], 3) and img.dtype==np.uint8, 'numpy array should be {}, but got {}. crop_IMAGENET does not apply to numpy array, but got {}'.format(IMAGENET_SIZE, img.size, img.dtype) 168 | return img 169 | w, h = img.size 170 | if w == IMAGENET_SIZE[0] and h == IMAGENET_SIZE[1]: 171 | return img 172 | return centerCrop_imagenet(img, None) 173 | 174 | def check_data_type(self, images, labels): 175 | assert images[0].dtype == np.uint8 176 | if 'imagenet' in self.dataset: 177 | assert type(labels[0]) == np.int32 178 | elif 'cifar' in self.dataset: 179 | assert type(labels[0]) == np.uint8 180 | else: 181 | raise Exception('Unrecognized dataset') 182 | 183 | def __call__(self, images, labels, samples_op, samples_mag, use_post_aug, pool=None, chunksize=None, aug_finish=True): 184 | self.check_data_type(images, labels) 185 | 186 | self.use_post_aug = use_post_aug 187 | self.batch_len = len(labels) 188 | if aug_finish: 189 | aug_imgs = np.empty([self.batch_len, *self.image_shape], dtype=np.float32) 190 | else: 191 | aug_imgs = [None]*self.batch_len 192 | aug_results = pool.imap_unordered(self.sequantially_augment, 193 | zip(range(self.batch_len), images, samples_op, samples_mag, [aug_finish]*self.batch_len), 194 | chunksize=math.ceil(float(self.batch_len) / float(pool._processes)) if chunksize is None else chunksize) 195 | for idx, img in aug_results: 196 | aug_imgs[idx] = img 197 | 198 | if aug_finish: 199 | aug_imgs = self.postprocessing_standardization(aug_imgs) 200 | 201 | return aug_imgs, labels -------------------------------------------------------------------------------- /imagenet_data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from torchvision.datasets.imagenet import * 4 | from torch import randperm, default_generator 5 | from torch._utils import _accumulate 6 | from torch.utils.data.dataset import Subset 7 | 8 | 9 | _DATA_TYPE = tf.float32 10 | 11 | CMYK_IMAGES = [ 12 | 'n01739381_1309.JPEG', 13 | 'n02077923_14822.JPEG', 14 | 'n02447366_23489.JPEG', 15 | 'n02492035_15739.JPEG', 16 | 'n02747177_10752.JPEG', 17 | 'n03018349_4028.JPEG', 18 | 'n03062245_4620.JPEG', 19 | 'n03347037_9675.JPEG', 20 | 'n03467068_12171.JPEG', 21 | 'n03529860_11437.JPEG', 22 | 'n03544143_17228.JPEG', 23 | 'n03633091_5218.JPEG', 24 | 'n03710637_5125.JPEG', 25 | 'n03961711_5286.JPEG', 26 | 'n04033995_2932.JPEG', 27 | 'n04258138_17003.JPEG', 28 | 'n04264628_27969.JPEG', 29 | 'n04336792_7448.JPEG', 30 | 'n04371774_5854.JPEG', 31 | 'n04596742_4225.JPEG', 32 | 'n07583066_647.JPEG', 33 | 'n13037406_4650.JPEG', 34 | ] 35 | 36 | PNG_IMAGES = ['n02105855_2933.JPEG'] 37 | 38 | class ImageNet(ImageFolder): 39 | """`ImageNet `_ 2012 Classification Dataset. 40 | Copied from torchvision, besides warning below. 41 | 42 | Args: 43 | root (string): Root directory of the ImageNet Dataset. 44 | split (string, optional): The dataset split, supports ``train``, or ``val``. 45 | transform (callable, optional): A function/transform that takes in an PIL image 46 | and returns a transformed version. 
E.g, ``transforms.RandomCrop`` 47 | target_transform (callable, optional): A function/transform that takes in the 48 | target and transforms it. 49 | loader (callable, optional): A function to load an image given its path. 50 | 51 | Attributes: 52 | classes (list): List of the class name tuples. 53 | class_to_idx (dict): Dict with items (class_name, class_index). 54 | wnids (list): List of the WordNet IDs. 55 | wnid_to_idx (dict): Dict with items (wordnet_id, class_index). 56 | imgs (list): List of (image path, class_index) tuples 57 | targets (list): The class_index value for each image in the dataset 58 | 59 | WARN:: 60 | This is the same ImageNet class as in torchvision.datasets.imagenet, but it has the `ignore_archive` argument. 61 | This allows us to only copy the unzipped files before training. 62 | """ 63 | 64 | def __init__(self, root, split='train', download=None, ignore_archive=False, **kwargs): 65 | if download is True: 66 | msg = ("The dataset is no longer publicly accessible. You need to " 67 | "download the archives externally and place them in the root " 68 | "directory.") 69 | raise RuntimeError(msg) 70 | elif download is False: 71 | msg = ("The use of the download flag is deprecated, since the dataset " 72 | "is no longer publicly accessible.") 73 | warnings.warn(msg, RuntimeWarning) 74 | 75 | root = self.root = os.path.expanduser(root) 76 | self.split = verify_str_arg(split, "split", ("train", "val")) 77 | 78 | if not ignore_archive: 79 | self.parse_archives() 80 | wnid_to_classes = load_meta_file(self.root)[0] 81 | 82 | super(ImageNet, self).__init__(self.split_folder, **kwargs) 83 | self.root = root 84 | 85 | self.wnids = self.classes 86 | self.wnid_to_idx = self.class_to_idx 87 | self.classes = [wnid_to_classes[wnid] for wnid in self.wnids] 88 | self.class_to_idx = {cls: idx 89 | for idx, clss in enumerate(self.classes) 90 | for cls in clss} 91 | 92 | def parse_archives(self): 93 | if not check_integrity(os.path.join(self.root, META_FILE)): 94 | parse_devkit_archive(self.root) 95 | 96 | if not os.path.isdir(self.split_folder): 97 | if self.split == 'train': 98 | parse_train_archive(self.root) 99 | elif self.split == 'val': 100 | parse_val_archive(self.root) 101 | 102 | @property 103 | def split_folder(self): 104 | return os.path.join(self.root, self.split) 105 | 106 | def extra_repr(self): 107 | return "Split: {split}".format(**self.__dict__) 108 | 109 | class ImageNet_DeepAA(ImageNet): 110 | def __init__(self, root, split='train', download=None, **kwargs): 111 | super(ImageNet_DeepAA, self).__init__(root, split=split, download=download, ignore_archive=True, **kwargs) 112 | _, self.labels_ = zip(*self.samples) 113 | 114 | def on_epoch_end(self): 115 | print('Dummy one_epoch_end for ImageNet dataset using torchvision') 116 | pass 117 | 118 | def sample_labeled_data_batch(self, label, val_bs): # generate val and train batch at the same time 119 | matched_indices = [id for id, lab in enumerate(self.labels_) if lab==label] 120 | matched_indices = np.array(matched_indices) 121 | assert len(matched_indices) > val_bs, 'Make sure the have enough data' 122 | np.random.shuffle(matched_indices) 123 | val_indices = matched_indices[:val_bs] 124 | 125 | val_samples, val_labels = zip(*[self[id] for id in val_indices]) 126 | val_samples = list(val_samples) 127 | val_labels = np.array(val_labels, dtype=np.int32) 128 | 129 | return val_samples, val_labels 130 | 131 | class Subset_ImageNet(Subset): 132 | def __init__(self, dataset, indices): 133 | super(Subset_ImageNet, 
self).__init__(dataset, indices) 134 | self.subset_labels_ = [self.dataset.labels_[k] for k in indices] 135 | 136 | 137 | def on_epoch_end(self): 138 | pass 139 | 140 | def sample_labeled_data_batch(self, label, val_bs): 141 | matched_indices = [self.indices[id] for id, lab in enumerate(self.subset_labels_) if lab == label] 142 | matched_indices = np.array(matched_indices) 143 | assert len(matched_indices) > val_bs, 'Make sure the have enough data' 144 | np.random.shuffle(matched_indices) 145 | val_indices = matched_indices[:val_bs] 146 | 147 | val_samples, val_labels = zip(*[self.dataset[id] for id in val_indices]) # applies transforms 148 | val_samples = list(val_samples) 149 | val_labels = np.array(val_labels, dtype=np.int32) 150 | 151 | return val_samples, val_labels 152 | 153 | def random_split_ImageNet(dataset, lengths, generator=default_generator): 154 | if sum(lengths) != len(dataset): 155 | raise ValueError('Sum of input lengths does not equal the length of the input dataset') 156 | indices = randperm(sum(lengths), generator=generator).tolist() 157 | return [Subset_ImageNet(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)] 158 | 159 | def get_imagenet_split(val_size=400000, train_sep_size=100000, dataroot='./data', n_GPU=None, seed=300): 160 | transform = lambda img: np.array(img, dtype=np.uint8) 161 | total_trainset = ImageNet_DeepAA(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform) 162 | testset = ImageNet_DeepAA(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform) 163 | 164 | N_per_shard = (len(total_trainset) - val_size - train_sep_size)//n_GPU 165 | remaining_data = len(total_trainset) - val_size - train_sep_size - n_GPU * N_per_shard 166 | if remaining_data > 0: 167 | splits = [val_size, train_sep_size, *[N_per_shard]*n_GPU, remaining_data] 168 | else: 169 | splits = [val_size, train_sep_size, *[N_per_shard]*n_GPU] 170 | all_ds = random_split_ImageNet(total_trainset, 171 | lengths=splits, 172 | generator=torch.Generator().manual_seed(seed)) 173 | val_ds = all_ds[0] 174 | train_ds_sep = all_ds[1] 175 | pretrain_ds_splits = all_ds[2:2+n_GPU] 176 | return total_trainset, val_ds, train_ds_sep, pretrain_ds_splits, testset -------------------------------------------------------------------------------- /images/DeepAA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/images/DeepAA.png -------------------------------------------------------------------------------- /images/DeepAA_slideslive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/images/DeepAA_slideslive.png -------------------------------------------------------------------------------- /images/magnitude_distribution_cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/images/magnitude_distribution_cifar.png -------------------------------------------------------------------------------- /images/magnitude_distribution_imagenet.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/images/magnitude_distribution_imagenet.png -------------------------------------------------------------------------------- /images/operation_distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/images/operation_distribution.png -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.optimizers.schedules import LearningRateSchedule 3 | from tensorflow.python.framework import ops 4 | from tensorflow.python.ops import math_ops, control_flow_ops 5 | 6 | class GradualWarmup_Cosine_Scheduler(LearningRateSchedule): 7 | def __init__(self, starting_lr, initial_lr, ending_lr, warmup_steps, total_steps, name=None): 8 | super(GradualWarmup_Cosine_Scheduler, self).__init__() 9 | 10 | self.starting_lr = starting_lr 11 | self.initial_lr = initial_lr 12 | self.ending_lr = ending_lr 13 | self.warmup_steps = warmup_steps 14 | self.total_steps = total_steps 15 | self.name = name 16 | 17 | def __call__(self, step): 18 | with ops.name_scope_v2(self.name or 'GradualWarmup_Cosine') as name: 19 | initial_lr = ops.convert_to_tensor_v2(self.initial_lr, name='initial_learning_rate') 20 | dtype = initial_lr.dtype 21 | starting_lr = math_ops.cast(self.starting_lr, dtype) 22 | ending_lr = math_ops.cast(self.ending_lr, dtype) 23 | warmup_steps = math_ops.cast(self.warmup_steps, dtype) 24 | total_steps = math_ops.cast(self.total_steps, dtype) 25 | one = math_ops.cast(1.0, dtype) 26 | point5 = math_ops.cast(0.5, dtype) 27 | pi = math_ops.cast(3.1415926536, dtype) 28 | step = math_ops.cast(step, dtype) 29 | 30 | lr = tf.cond(step < warmup_steps, 31 | true_fn=lambda: self._warmup_schedule(starting_lr, initial_lr, step, warmup_steps), 32 | false_fn=lambda: self._cosine_annealing_schedule(initial_lr, ending_lr, step, warmup_steps, total_steps, pi, 33 | point5, one)) 34 | return lr 35 | 36 | def _warmup_schedule(self, starting_lr, initial_lr, step, warmup_steps): 37 | ratio = math_ops.divide(step, warmup_steps) 38 | lr = math_ops.add(starting_lr, 39 | math_ops.multiply(initial_lr - starting_lr, ratio)) 40 | return lr 41 | 42 | def _cosine_annealing_schedule(self, initial_lr, ending_lr, step, warmup_steps, total_steps, pi, point5, one): 43 | ratio = math_ops.divide(step - warmup_steps, total_steps - warmup_steps) 44 | cosine_ratio_pi = math_ops.cos(math_ops.multiply(ratio, pi)) 45 | second_part = math_ops.multiply(point5, 46 | math_ops.multiply(initial_lr - ending_lr, 47 | one + cosine_ratio_pi)) 48 | lr = math_ops.add(ending_lr, second_part) 49 | return lr 50 | 51 | 52 | def get_config(self): 53 | return { 54 | 'starting_lr': self.starting_lr, 55 | 'initial_lr': self.initial_lr, 56 | 'ending_lr': self.ending_lr, 57 | 'warmup_steps': self.warmup_steps, 58 | 'total_steps': self.total_steps, 59 | 'name': self.name 60 | } -------------------------------------------------------------------------------- /policy.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import math 4 | import json 5 | 6 | from tensorflow_probability import distributions as tfd 7 | 8 | from resnet import Resnet 9 | 10 | CIFAR_MEANS = np.array([0.49139968, 
0.48215841, 0.44653091], dtype=np.float32) 11 | CIFAR_STDS = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32) 12 | 13 | SVHN_MEANS = np.array([0.4379, 0.4440, 0.4729], dtype=np.float32) 14 | SVHN_STDS = np.array([0.1980, 0.2010, 0.1970], dtype=np.float32) 15 | 16 | IMAGENET_MEANS = np.array([0.485, 0.456, 0.406], dtype=np.float32) 17 | IMAGENET_STDS = np.array([0.229, 0.224, 0.225], dtype=np.float32) 18 | 19 | class DA_Policy_logits(tf.keras.Model): 20 | def __init__(self, l_ops, l_mags, l_uniq, op_names, ops_mid_magnitude, 21 | N_repeat_random, available_policies, policy_init='identity'): 22 | super().__init__() 23 | self.l_uniq = l_uniq 24 | self.l_ops = l_ops 25 | self.l_mags = l_mags 26 | self.N_repeat_random = N_repeat_random 27 | self.available_policies = available_policies 28 | 29 | if policy_init == 'uniform': 30 | init_value = tf.constant([0.0]*len(available_policies), dtype=tf.float32) 31 | elif policy_init == 'identity': 32 | init_value = tf.constant([8.0] + [0.0]*(len(available_policies)-1), dtype=tf.float32) 33 | init_value = init_value - tf.reduce_mean(init_value) 34 | else: 35 | raise Exception 36 | self.logits = tf.Variable(initial_value=init_value, trainable=True) 37 | 38 | self.ops_mid_magnitude = ops_mid_magnitude 39 | self.unique_policy = self._get_unique_policy(op_names, l_ops, l_mags) 40 | self.N_random, self.repeat_cfg, self.reduce_random_mat = self._get_repeat_random(op_names, l_ops, l_mags, 41 | l_uniq, N_repeat_random) 42 | self.act = tf.nn.softmax 43 | 44 | def sample(self, images_orig, images, onehot_ops_mags, augNum): 45 | bs = len(images_orig) 46 | probs = self.act(self.logits, axis=-1) 47 | dist = tfd.Categorical(probs=probs) 48 | samples_om = dist.sample(augNum*bs).numpy() # (augNum, bs) 49 | 50 | ops_dense, mags_dense, reduce_random_mat, ops_mags_idx, probs, probs_exp = self.get_dense_aug(images, repeat_random_ops=False) 51 | ops = ops_dense[samples_om] 52 | mags = mags_dense[samples_om] 53 | ops_mags_idx_sample = ops_mags_idx[samples_om] 54 | probs_sample = probs.numpy()[samples_om] 55 | 56 | return ops, mags, ops_mags_idx_sample, probs_sample 57 | 58 | def probs(self, images_orig, images, onehot_ops_mags, training): 59 | bs = len(images_orig) 60 | probs = self.act(self.logits, axis=-1) 61 | probs = tf.repeat(probs[tf.newaxis], bs, axis=0) 62 | return probs 63 | 64 | def get_dense_aug(self, images, repeat_random_ops): 65 | ops_uniq, mags_uniq = self.unique_policy 66 | ops_dense = np.squeeze(ops_uniq)[self.available_policies] 67 | mags_dense = np.squeeze(mags_uniq)[self.available_policies] 68 | ops_mags_idx = self.available_policies 69 | if repeat_random_ops: 70 | isRepeat = [np.any(np.array(ops_dense == repeat_op_idx), axis=1) for repeat_op_idx in self.repeat_ops_idx] 71 | isRepeat = np.stack(isRepeat, axis=1) 72 | isRepeat = np.any(isRepeat, axis=1) 73 | nRepeat = [self.N_repeat_random if isrepeat else 1 for isrepeat in isRepeat] 74 | 75 | ops_dense = np.repeat(ops_dense, nRepeat, axis=0) 76 | mags_dense = np.repeat(mags_dense, nRepeat, axis=0) 77 | reduce_random_mat = np.eye(len(self.available_policies)) / np.array(nRepeat, dtype=np.float32) 78 | reduce_random_mat = np.repeat(reduce_random_mat, nRepeat, axis=1) 79 | else: 80 | nRepeat = [1] * len(self.available_policies) 81 | reduce_random_mat = np.eye(len(self.available_policies)) 82 | 83 | probs = self.act(self.logits) 84 | probs_exp = np.repeat(probs/np.array(nRepeat, dtype=np.float32), nRepeat, axis=0) 85 | return ops_dense, mags_dense, reduce_random_mat, ops_mags_idx, probs, probs_exp 86 | 87 | 
def _get_unique_policy(self, op_names, l_ops, l_mags): 88 | names_modified = [op_name.split(':')[0] for op_name in op_names] 89 | ops_list, mags_list = [], [] 90 | repeat_ops_idx = [] 91 | for k_name, name in enumerate(names_modified): 92 | if self.ops_mid_magnitude[name] == 'random': 93 | repeat_ops_idx.append(k_name) 94 | ops_sub, mags_sub = np.array([[k_name]], dtype=np.int32), np.array([[(l_mags - 1) // 2]], dtype=np.int32) 95 | elif self.ops_mid_magnitude[name] is not None and self.ops_mid_magnitude[name]>=0 and self.ops_mid_magnitude[name]<=l_mags-1: 96 | ops_sub = k_name * np.ones([l_mags - 1, 1], dtype=np.int32) 97 | mags_sub = np.array([l for l in range(l_mags) if l != self.ops_mid_magnitude[name]], dtype=np.int32)[:, np.newaxis] 98 | elif self.ops_mid_magnitude[name] is not None and self.ops_mid_magnitude[name]<0: #or self.ops_mid_magnitude[name]>l_mags-1): 99 | ops_sub = k_name * np.ones([l_mags, 1], dtype=np.int32) 100 | mags_sub = np.arange(l_mags, dtype=np.int32)[:, np.newaxis] 101 | elif self.ops_mid_magnitude[name] is None: 102 | ops_sub, mags_sub = np.array([[k_name]], dtype=np.int32), np.array([[(l_mags - 1) // 2]], dtype=np.int32) 103 | else: 104 | raise Exception('Unrecognized middle magnitude') 105 | ops_list.append(ops_sub) 106 | mags_list.append(mags_sub) 107 | ops = np.concatenate(ops_list, axis=0) 108 | mags = np.concatenate(mags_list, axis=0) 109 | self.repeat_ops_idx = repeat_ops_idx 110 | return ops.astype(np.int32), mags.astype(np.int32) 111 | 112 | def _get_repeat_random(self, op_names, l_ops, l_mags, l_uniq, N_repeat_random): 113 | names_modified = [op_name.split(':')[0] for op_name in op_names] 114 | N_random = sum([1 for name in names_modified if self.ops_mid_magnitude[name]=='random']) 115 | repeat_cfg = [] 116 | for k_name, name in enumerate(names_modified): 117 | if self.ops_mid_magnitude[name] == 'random': 118 | repeat_cfg.append(N_repeat_random) # we may repeat random operations for N_repeat_random times 119 | elif self.ops_mid_magnitude[name] is not None and self.ops_mid_magnitude[name] == -1: 120 | repeat_cfg.append([1]*l_mags) 121 | elif self.ops_mid_magnitude[name] is not None and self.ops_mid_magnitude[name] >= 0 and self.ops_mid_magnitude[name]<=l_mags-1: 122 | repeat_cfg.extend([1]*(l_mags-1)) 123 | elif self.ops_mid_magnitude[name] is None: 124 | repeat_cfg.append(1) 125 | else: 126 | raise Exception 127 | repeat_cfg = np.array(repeat_cfg, dtype=np.int32) 128 | 129 | reduce_mat = np.eye(l_uniq)/repeat_cfg[np.newaxis].astype(np.float) 130 | reduce_mat = np.repeat(reduce_mat, repeat_cfg, axis=1) 131 | return N_random, repeat_cfg, reduce_mat 132 | 133 | @property 134 | def idx_removed_redundant(self): 135 | idx_removed_redundant = np.concatenate([[1] if rep == 1 else [1]+[0]*(rep-1) for rep in self.repeat_cfg ]).nonzero()[0] 136 | assert len(idx_removed_redundant) == self.l_uniq, 'removing the repeated random operations' 137 | return idx_removed_redundant -------------------------------------------------------------------------------- /policy_port/policy_DeepAA_cifar_1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/policy_port/policy_DeepAA_cifar_1.npz -------------------------------------------------------------------------------- /policy_port/policy_DeepAA_cifar_2.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/policy_port/policy_DeepAA_cifar_2.npz -------------------------------------------------------------------------------- /policy_port/policy_DeepAA_imagenet_1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/policy_port/policy_DeepAA_imagenet_1.npz -------------------------------------------------------------------------------- /policy_port/policy_DeepAA_imagenet_2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/policy_port/policy_DeepAA_imagenet_2.npz -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/wbaek/theconf 2 | git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git 3 | git+https://github.com/ildoonet/pystopwatch2.git 4 | 5 | keras==2.4.0 6 | tensorflow-datasets==4.3.0 7 | tensorflow-probability==0.13.0 8 | matplotlib 9 | seaborn 10 | pandas 11 | packaging 12 | 13 | colored 14 | pretrainedmodels 15 | tqdm 16 | tensorboardx 17 | sklearn 18 | matplotlib 19 | psutil 20 | requests 21 | Pillow -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | # ref: https://github.com/gahaalt/resnets-in-tensorflow2/blob/master/Models/Resnets.py 4 | _bn_momentum = 0.9 5 | 6 | def regularized_padded_conv(*args, **kwargs): 7 | return tf.keras.layers.Conv2D(*args, **kwargs, padding='same', kernel_regularizer=_regularizer, bias_regularizer=_regularizer, 8 | kernel_initializer='he_normal', use_bias=True) 9 | 10 | 11 | def bn_relu(x): 12 | x = tf.keras.layers.experimental.SyncBatchNormalization(momentum=_bn_momentum)(x) 13 | return tf.keras.layers.ReLU()(x) 14 | 15 | 16 | def shortcut(x, filters, stride, mode): 17 | if x.shape[-1] == filters: # maybe and stride==1 18 | return x 19 | elif mode == 'B': 20 | return regularized_padded_conv(filters, 1, strides=stride)(x) 21 | elif mode == 'B_original': 22 | x = regularized_padded_conv(filters, 1, strides=stride)(x) 23 | return tf.keras.layers.experimental.SyncBatchNormalization(momentum=_bn_momentum)(x) 24 | elif mode == 'A': 25 | return tf.pad(tf.keras.layers.MaxPool2D(1, stride)(x) if stride > 1 else x, 26 | paddings=[(0, 0), (0, 0), (0, 0), (0, filters - x.shape[-1])]) 27 | else: 28 | raise KeyError("Parameter shortcut_type not recognized!") 29 | 30 | 31 | def original_block(x, filters, stride=1, **kwargs): 32 | c1 = regularized_padded_conv(filters, 3, strides=stride)(x) 33 | c2 = regularized_padded_conv(filters, 3)(bn_relu(c1)) 34 | c2 = tf.keras.layers.experimental.SyncBatchNormalization(momentum=_bn_momentum)(c2) 35 | 36 | mode = 'B_original' if _shortcut_type == 'B' else _shortcut_type 37 | x = shortcut(x, filters, stride, mode=mode) 38 | return tf.keras.layers.ReLU()(x + c2) 39 | 40 | 41 | def preactivation_block(x, filters, stride=1, preact_block=False): 42 | flow = bn_relu(x) 43 | 44 | c1 = regularized_padded_conv(filters, 3)(flow) 45 | if _dropout: 46 | c1 = tf.keras.layers.Dropout(_dropout)(c1) 47 | 48 | c2 = regularized_padded_conv(filters, 3, 
strides=stride)(bn_relu(c1)) 49 | x = shortcut(x, filters, stride, mode=_shortcut_type) 50 | return x + c2 51 | 52 | 53 | def bootleneck_block(x, filters, stride=1, preact_block=False): 54 | flow = bn_relu(x) 55 | if preact_block: 56 | x = flow 57 | 58 | c1 = regularized_padded_conv(filters // _bootleneck_width, 1)(flow) 59 | c2 = regularized_padded_conv(filters // _bootleneck_width, 3, strides=stride)(bn_relu(c1)) 60 | c3 = regularized_padded_conv(filters, 1)(bn_relu(c2)) 61 | x = shortcut(x, filters, stride, mode=_shortcut_type) 62 | return x + c3 63 | 64 | 65 | def group_of_blocks(x, block_type, num_blocks, filters, stride, block_idx=0): 66 | global _preact_shortcuts 67 | preact_block = True if _preact_shortcuts or block_idx == 0 else False 68 | 69 | x = block_type(x, filters, stride, preact_block=preact_block) 70 | for i in range(num_blocks - 1): 71 | x = block_type(x, filters) 72 | return x 73 | 74 | 75 | def Resnet(input_shape, n_classes, l2_reg=1e-4, group_sizes=(2, 2, 2), features=(16, 32, 64), strides=(1, 2, 2), 76 | shortcut_type='B', block_type='preactivated', first_conv={"filters": 16, "kernel_size": 3, "strides": 1}, 77 | dropout=0, cardinality=1, bootleneck_width=4, preact_shortcuts=True, 78 | final_dense_kernel_initializer=None, final_dense_bias_initializer=None): 79 | global _regularizer, _shortcut_type, _preact_projection, _dropout, _cardinality, _bootleneck_width, _preact_shortcuts 80 | _bootleneck_width = bootleneck_width # used in ResNeXts and bootleneck blocks 81 | _regularizer = tf.keras.regularizers.l2(l2_reg) 82 | _shortcut_type = shortcut_type # used in blocks 83 | _cardinality = cardinality # used in ResNeXts 84 | _dropout = dropout # used in Wide ResNets 85 | _preact_shortcuts = preact_shortcuts 86 | 87 | block_types = {'preactivated': preactivation_block, 88 | 'bootleneck': bootleneck_block, 89 | 'original': original_block} 90 | 91 | selected_block = block_types[block_type] 92 | inputs = tf.keras.layers.Input(shape=input_shape) 93 | flow = regularized_padded_conv(**first_conv)(inputs) 94 | 95 | if block_type == 'original': 96 | flow = bn_relu(flow) 97 | 98 | for block_idx, (group_size, feature, stride) in enumerate(zip(group_sizes, features, strides)): 99 | flow = group_of_blocks(flow, 100 | block_type=selected_block, 101 | num_blocks=group_size, 102 | block_idx=block_idx, 103 | filters=feature, 104 | stride=stride) 105 | 106 | if block_type != 'original': 107 | flow = bn_relu(flow) 108 | 109 | flow = tf.keras.layers.GlobalAveragePooling2D()(flow) 110 | 111 | if final_dense_kernel_initializer is not None: 112 | assert final_dense_bias_initializer is not None, 'make sure kernel and bias initializer is not None at the same time' 113 | outputs = tf.keras.layers.Dense(n_classes, kernel_regularizer=_regularizer, 114 | kernel_initializer=final_dense_kernel_initializer, 115 | bias_initializer=final_dense_bias_initializer)(flow) 116 | else: 117 | outputs = tf.keras.layers.Dense(n_classes, kernel_regularizer=_regularizer)(flow) 118 | model = tf.keras.Model(inputs=inputs, outputs=outputs) 119 | return model 120 | 121 | 122 | def load_weights_func(model, model_name): 123 | try: 124 | model.load_weights(os.path.join('saved_models', model_name + '.tf')) 125 | except tf.errors.NotFoundError: 126 | print("No weights found for this model!") 127 | return model 128 | 129 | 130 | def cifar_resnet20(block_type='original', shortcut_type='A', l2_reg=1e-4, load_weights=False, input_shape=None, n_classes=None): 131 | model = Resnet(input_shape=input_shape, n_classes=n_classes, 
l2_reg=l2_reg, group_sizes=(3, 3, 3), features=(16, 32, 64), 132 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1}, 133 | shortcut_type=shortcut_type, 134 | block_type=block_type, preact_shortcuts=False) 135 | if load_weights: model = load_weights_func(model, 'cifar_resnet20') 136 | return model 137 | 138 | 139 | def cifar_resnet32(block_type='original', shortcut_type='A', l2_reg=1e-4, load_weights=False, input_shape=None): 140 | model = Resnet(input_shape=input_shape, n_classes=10, l2_reg=l2_reg, group_sizes=(5, 5, 5), features=(16, 32, 64), 141 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1}, 142 | shortcut_type=shortcut_type, 143 | block_type=block_type, preact_shortcuts=False) 144 | if load_weights: model = load_weights_func(model, 'cifar_resnet32') 145 | return model 146 | 147 | 148 | def cifar_resnet44(block_type='original', shortcut_type='A', l2_reg=1e-4, load_weights=False, input_shape=None): 149 | model = Resnet(input_shape=input_shape, n_classes=10, l2_reg=l2_reg, group_sizes=(7, 7, 7), features=(16, 32, 64), 150 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1}, 151 | shortcut_type=shortcut_type, 152 | block_type=block_type, preact_shortcuts=False) 153 | if load_weights: model = load_weights_func(model, 'cifar_resnet44') 154 | return model 155 | 156 | 157 | def cifar_resnet56(block_type='original', shortcut_type='A', l2_reg=1e-4, load_weights=False, input_shape=None): 158 | model = Resnet(input_shape=input_shape, n_classes=10, l2_reg=l2_reg, group_sizes=(9, 9, 9), features=(16, 32, 64), 159 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1}, 160 | shortcut_type=shortcut_type, 161 | block_type=block_type, preact_shortcuts=False) 162 | if load_weights: model = load_weights_func(model, 'cifar_resnet56') 163 | return model 164 | 165 | 166 | def cifar_resnet110(block_type='preactivated', shortcut_type='B', l2_reg=1e-4, load_weights=False, input_shape=None): 167 | model = Resnet(input_shape=input_shape, n_classes=10, l2_reg=l2_reg, group_sizes=(18, 18, 18), 168 | features=(16, 32, 64), 169 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1}, 170 | shortcut_type=shortcut_type, 171 | block_type=block_type, preact_shortcuts=False) 172 | if load_weights: model = load_weights_func(model, 'cifar_resnet110') 173 | return model 174 | 175 | 176 | def cifar_resnet164(shortcut_type='B', load_weights=False, l2_reg=1e-4, input_shape=None): 177 | model = Resnet(input_shape=input_shape, n_classes=10, l2_reg=l2_reg, group_sizes=(18, 18, 18), 178 | features=(64, 128, 256), 179 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1}, 180 | shortcut_type=shortcut_type, 181 | block_type='bootleneck', preact_shortcuts=True) 182 | if load_weights: model = load_weights_func(model, 'cifar_resnet164') 183 | return model 184 | 185 | 186 | def cifar_resnet1001(shortcut_type='B', load_weights=False, l2_reg=1e-4, input_shape=None): 187 | model = Resnet(input_shape=input_shape, n_classes=10, l2_reg=l2_reg, group_sizes=(111, 111, 111), 188 | features=(64, 128, 256), 189 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1}, 190 | shortcut_type=shortcut_type, 191 | block_type='bootleneck', preact_shortcuts=True) 192 | if load_weights: model = load_weights_func(model, 'cifar_resnet1001') 193 | return model 194 | 195 | 196 | def cifar_wide_resnet(N, K, block_type='preactivated', shortcut_type='B', dropout=0, l2_reg=2.5e-4, 
n_classes=None, preact_shortcuts=False, input_shape=None): 197 | assert (N - 4) % 6 == 0, "N-4 has to be divisible by 6" 198 | lpb = (N - 4) // 6 # layers per block - since N is total number of convolutional layers in Wide ResNet 199 | model = Resnet(input_shape=input_shape, n_classes=n_classes, l2_reg=l2_reg, group_sizes=(lpb, lpb, lpb), 200 | features=(16 * K, 32 * K, 64 * K), 201 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1}, 202 | shortcut_type=shortcut_type, 203 | block_type=block_type, dropout=dropout, preact_shortcuts=preact_shortcuts) 204 | return model 205 | 206 | 207 | def cifar_WRN_16_4(shortcut_type='B', load_weights=False, dropout=0, l2_reg=2.5e-4, input_shape=None): 208 | model = cifar_wide_resnet(16, 4, 'preactivated', shortcut_type, dropout=dropout, l2_reg=l2_reg, input_shape=input_shape) 209 | if load_weights: model = load_weights_func(model, 'cifar_WRN_16_4') 210 | return model 211 | 212 | 213 | def cifar_WRN_40_4(shortcut_type='B', load_weights=False, dropout=0, l2_reg=2.5e-4, input_shape=None): 214 | model = cifar_wide_resnet(40, 4, 'preactivated', shortcut_type, dropout=dropout, l2_reg=l2_reg, input_shape=input_shape) 215 | if load_weights: model = load_weights_func(model, 'cifar_WRN_40_4') 216 | return model 217 | 218 | 219 | def cifar_WRN_16_8(shortcut_type='B', load_weights=False, dropout=0, l2_reg=2.5e-4, input_shape=None): 220 | model = cifar_wide_resnet(16, 8, 'preactivated', shortcut_type, dropout=dropout, l2_reg=l2_reg, input_shape=input_shape) 221 | if load_weights: model = load_weights_func(model, 'cifar_WRN_16_8') 222 | return model 223 | 224 | 225 | def cifar_WRN_28_10(shortcut_type='B', load_weights=False, dropout=0, l2_reg=2.5e-4, n_classes=None, preact_shortcuts=False, input_shape=None): 226 | model = cifar_wide_resnet(28, 10, 'preactivated', shortcut_type, dropout=dropout, l2_reg=l2_reg, n_classes = n_classes, preact_shortcuts=preact_shortcuts, input_shape=input_shape) 227 | return model 228 | 229 | def cifar_WRN_28_2(shortcut_type='B', load_weights=False, dropout=0, l2_reg=2.5e-4, n_classes=None, preact_shortcuts=False, input_shape=None): 230 | model = cifar_wide_resnet(28, 2, 'preactivated', shortcut_type, dropout=dropout, l2_reg=l2_reg, n_classes = n_classes, preact_shortcuts=preact_shortcuts, input_shape=input_shape) 231 | return model 232 | 233 | 234 | def cifar_WRN_40_2(shortcut_type='B', load_weights=False, dropout=0, l2_reg=2.5e-4, n_classes=None, preact_shortcuts=False, input_shape=None): 235 | model = cifar_wide_resnet(40, 2, 'preactivated', shortcut_type, dropout=dropout, l2_reg=l2_reg, n_classes = n_classes, preact_shortcuts=preact_shortcuts, input_shape=input_shape) 236 | return model 237 | 238 | def cifar_resnext(N, cardinality, width, shortcut_type='B'): 239 | assert (N - 3) % 9 == 0, "N-3 has to be divisible by 9" 240 | lpb = (N - 3) // 9 # layers per block - since N is the total number of convolutional layers in the ResNeXt 241 | model = Resnet(input_shape=(32, 32, 3), n_classes=10, l2_reg=1e-4, group_sizes=(lpb, lpb, lpb), 242 | features=(16 * width, 32 * width, 64 * width), 243 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1}, 244 | shortcut_type=shortcut_type, 245 | block_type='resnext', cardinality=cardinality, width=width) # NOTE: not usable as-is in this port - block_types above defines no 'resnext' block and Resnet() takes no `width` argument 246 | return model 247 | 248 | 249 | if __name__ == '__main__': 250 | model = cifar_WRN_28_10(dropout=0, l2_reg=5e-4/2., preact_shortcuts=False, n_classes=10) --------------------------------------------------------------------------------
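For quick orientation, a minimal usage sketch for the CIFAR builders in resnet.py above. This snippet is not part of the repository; it assumes TensorFlow 2.x is installed and that it is run from the repo root so `resnet.py` is importable, and the input batch is random dummy data:

# Minimal sketch: build a WRN-28-10 and run a dummy forward pass (illustrative only).
import numpy as np
from resnet import cifar_WRN_28_10

model = cifar_WRN_28_10(dropout=0, l2_reg=5e-4 / 2., n_classes=10, input_shape=(32, 32, 3))
logits = model(np.random.rand(4, 32, 32, 3).astype('float32'), training=False)
print(logits.shape)  # expected: (4, 10) class logits

The other builders (cifar_resnet20, cifar_WRN_16_8, etc.) wrap the same Resnet() factory, so the same pattern applies to them.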
/resnet_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | # ref: https://github.com/gahaalt/resnets-in-tensorflow2/blob/master/Models/Resnets.py 4 | _bn_momentum = 0.9 5 | 6 | def regularized_padded_conv(*args, **kwargs): 7 | return tf.keras.layers.Conv2D(*args, **kwargs, padding='same', kernel_regularizer=_regularizer, bias_regularizer=_regularizer, 8 | kernel_initializer='he_normal', use_bias=False) 9 | 10 | 11 | def bn_relu(x, gamma_initializer='ones'): 12 | x = tf.keras.layers.experimental.SyncBatchNormalization(momentum=_bn_momentum, gamma_initializer=gamma_initializer)(x) 13 | return tf.keras.layers.ReLU()(x) 14 | 15 | 16 | def shortcut(x, filters, stride, mode): 17 | if x.shape[-1] == filters: # maybe and stride==1 18 | return x 19 | elif mode == 'B': 20 | return regularized_padded_conv(filters, 1, strides=stride)(x) 21 | elif mode == 'B_original': 22 | x = regularized_padded_conv(filters, 1, strides=stride)(x) 23 | return tf.keras.layers.experimental.SyncBatchNormalization(momentum=_bn_momentum)(x) 24 | elif mode == 'A': 25 | return tf.pad(tf.keras.layers.MaxPool2D(1, stride)(x) if stride > 1 else x, 26 | paddings=[(0, 0), (0, 0), (0, 0), (0, filters - x.shape[-1])]) 27 | else: 28 | raise KeyError("Parameter shortcut_type not recognized!") 29 | 30 | 31 | def original_block(x, filters, stride=1, **kwargs): 32 | c1 = regularized_padded_conv(filters, 3, strides=stride)(x) 33 | c2 = regularized_padded_conv(filters, 3)(bn_relu(c1)) 34 | c2 = tf.keras.layers.experimental.SyncBatchNormalization(momentum=_bn_momentum)(c2) 35 | 36 | mode = 'B_original' if _shortcut_type == 'B' else _shortcut_type 37 | x = shortcut(x, filters, stride, mode=mode) 38 | return tf.keras.layers.ReLU()(x + c2) 39 | 40 | 41 | def bootleneck_block(x, filters, stride=1, preact_block=False): # preact_block==False 42 | # flow = bn_relu(x) 43 | # if preact_block: 44 | # x = flow 45 | residual = x 46 | c1 = regularized_padded_conv(filters // _bootleneck_width, 1)(bn_relu(x)) 47 | c2 = regularized_padded_conv(filters // _bootleneck_width, 3, strides=stride)(bn_relu(c1)) 48 | c3 = regularized_padded_conv(filters, 1)(bn_relu(c2)) 49 | if x.shape[-1] != filters or stride != 1: 50 | residual = shortcut(x, filters, stride, mode=_shortcut_type) 51 | return tf.keras.layers.ReLU()(residual + tf.keras.layers.experimental.SyncBatchNormalization(momentum=_bn_momentum, gamma_initializer='zeros')(c3)) 52 | 53 | 54 | def group_of_blocks(x, block_type, num_blocks, filters, stride, block_idx=0): 55 | global _preact_shortcuts 56 | preact_block = False 57 | 58 | x = block_type(x, filters, stride, preact_block=preact_block) 59 | for i in range(num_blocks - 1): 60 | x = block_type(x, filters) 61 | return x 62 | 63 | 64 | def Resnet(input_shape, n_classes, l2_reg=1e-4, group_sizes=(2, 2, 2), features=(16, 32, 64), strides=(1, 2, 2), 65 | shortcut_type='B', block_type='preactivated', first_conv={"filters": 16, "kernel_size": 3, "strides": 1}, 66 | dropout=0, cardinality=1, bootleneck_width=4, preact_shortcuts=True): 67 | global _regularizer, _shortcut_type, _preact_projection, _dropout, _cardinality, _bootleneck_width, _preact_shortcuts 68 | _bootleneck_width = bootleneck_width # used in ResNeXts and bootleneck blocks 69 | _regularizer = tf.keras.regularizers.l2(l2_reg) 70 | _shortcut_type = shortcut_type # used in blocks 71 | _cardinality = cardinality # used in ResNeXts 72 | _dropout = dropout # used in Wide ResNets 73 | _preact_shortcuts = 
preact_shortcuts 74 | 75 | block_types = { 76 | # 'preactivated': preactivation_block, 77 | 'bootleneck': bootleneck_block, 78 | 'original': original_block 79 | } 80 | 81 | selected_block = block_types[block_type] 82 | inputs = tf.keras.layers.Input(shape=input_shape) 83 | flow = regularized_padded_conv(**first_conv)(inputs) 84 | 85 | # if block_type == 'original': 86 | flow = bn_relu(flow) 87 | flow = tf.keras.layers.MaxPool2D(pool_size=(3,3), strides=2, padding='same')(flow) 88 | 89 | for block_idx, (group_size, feature, stride) in enumerate(zip(group_sizes, features, strides)): 90 | flow = group_of_blocks(flow, 91 | block_type=selected_block, 92 | num_blocks=group_size, 93 | block_idx=block_idx, 94 | filters=feature, 95 | stride=stride) 96 | 97 | # if block_type != 'original': 98 | # flow = bn_relu(flow) 99 | 100 | flow = tf.keras.layers.GlobalAveragePooling2D()(flow) 101 | 102 | outputs = tf.keras.layers.Dense(n_classes, kernel_regularizer=_regularizer, bias_regularizer=_regularizer, use_bias=True)(flow) 103 | model = tf.keras.Model(inputs=inputs, outputs=outputs) 104 | return model 105 | 106 | def imagenet_resnet50(block_type='bootleneck', shortcut_type='B_original', l2_reg=0.5e-4, load_weights=False, input_shape=(224,224,3), n_classes=1000): 107 | bootleneck_width = 4 108 | model = Resnet(input_shape=input_shape, n_classes=n_classes, l2_reg=l2_reg, group_sizes=(3,4,6,3), 109 | features=(64*bootleneck_width, 128*bootleneck_width, 256*bootleneck_width, 512*bootleneck_width), 110 | strides=(1, 2, 2, 2), first_conv={"filters": 64, "kernel_size": 7, "strides": 2}, 111 | shortcut_type=shortcut_type, 112 | block_type=block_type, preact_shortcuts=False, 113 | bootleneck_width=bootleneck_width) 114 | return model 115 | 116 | def imagenet_resnet50_pretrained(input_shape, n_classes, l2_reg): 117 | _regularizer = tf.keras.regularizers.l2(l2_reg) 118 | inputs = tf.keras.layers.Input(shape=input_shape) 119 | base_model = tf.keras.applications.resnet50.ResNet50(include_top=False, input_shape=input_shape, 120 | pooling='avg', weights='imagenet') 121 | base_model.trainable = False 122 | x = base_model(inputs, training=False) # do not update batch augmentation 123 | outputs = tf.keras.layers.Dense(n_classes, kernel_regularizer=_regularizer, bias_regularizer=_regularizer, use_bias=True)(x) 124 | model = tf.keras.Model(inputs=inputs, outputs=outputs) 125 | return model 126 | 127 | def imagenet_resnet18(block_type='original', shortcut_type='B_original', l2_reg=0.5e-4, load_weights=False, input_shape=(224,224,3), n_classes=1000): 128 | model = Resnet(input_shape=input_shape, n_classes=n_classes, l2_reg=l2_reg, group_sizes=(2,2,2,2), 129 | features=(64, 128, 256, 512), 130 | strides=(1, 2, 2, 2), first_conv={"filters": 64, "kernel_size": 7, "strides": 2}, 131 | shortcut_type=shortcut_type, 132 | block_type=block_type, preact_shortcuts=False, 133 | bootleneck_width=None) 134 | return model 135 | 136 | def load_weights_func(model, model_name): 137 | try: 138 | model.load_weights(os.path.join('saved_models', model_name + '.tf')) 139 | except tf.errors.NotFoundError: 140 | print("No weights found for this model!") 141 | return model 142 | 143 | 144 | if __name__ == '__main__': 145 | model = imagenet_resnet50() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import matplotlib 5 | # configure backend here 6 | matplotlib.use('Agg') 
7 | # matplotlib.use('tkagg') 8 | import matplotlib.pyplot as plt 9 | import matplotlib.patheffects as PathEffects 10 | from mpl_toolkits.axes_grid1 import ImageGrid 11 | import tensorflow as tf 12 | import math 13 | import sys 14 | from data_generator import CIFAR_MEANS, CIFAR_STDS 15 | 16 | gfile = tf.io.gfile 17 | 18 | class Logger(object): 19 | """Prints to both STDOUT and a file.""" 20 | 21 | def __init__(self, filepath): 22 | self.terminal = sys.stdout 23 | self.log = gfile.GFile(filepath, 'a+') 24 | 25 | def write(self, message): 26 | self.terminal.write(message) 27 | self.terminal.flush() 28 | self.log.write(message) 29 | self.log.flush() 30 | 31 | def flush(self): 32 | self.terminal.flush() 33 | self.log.flush() 34 | 35 | class CTLEarlyStopping: 36 | def __init__(self, 37 | monitor='val_loss', 38 | min_delta=0, 39 | patience=0, 40 | mode='auto', 41 | ): 42 | self.monitor = monitor 43 | self.patience = patience 44 | self.min_delta = abs(min_delta) 45 | self.wait = 0 46 | self.stop_training = False 47 | self.improvement = False 48 | 49 | if mode not in ['auto', 'min', 'max']: 50 | logging.warning('EarlyStopping mode %s is unknown, ' 51 | 'fallback to auto mode.', mode) 52 | mode = 'auto' 53 | 54 | if mode == 'min': 55 | self.monitor_op = np.less 56 | elif mode == 'max': 57 | self.monitor_op = np.greater 58 | else: 59 | if 'acc' in self.monitor: 60 | self.monitor_op = np.greater 61 | else: 62 | self.monitor_op = np.less 63 | 64 | if self.monitor_op == np.greater: 65 | self.min_delta *= 1 66 | else: 67 | self.min_delta *= -1 68 | 69 | self.best = np.Inf if self.monitor_op == np.less else -np.Inf 70 | 71 | 72 | def check_progress(self, current): 73 | if self.monitor_op(current - self.min_delta, self.best): 74 | print(f"{self.monitor} improved from {self.best:.4f} to {current:.4f}.", end=" ") 75 | self.best = current 76 | self.wait = 0 77 | self.improvement = True 78 | else: 79 | self.wait += 1 80 | self.improvement = False 81 | print(f"{self.monitor} didn't improve") 82 | if self.wait >= self.patience: 83 | print("Early stopping") 84 | self.stop_training = True 85 | 86 | return self.improvement, self.stop_training 87 | 88 | 89 | ########################################################################################## 90 | 91 | 92 | class CTLHistory: 93 | def __init__(self, 94 | filename=None, 95 | save_dir='plots'): 96 | 97 | self.history = {'train_loss':[], 98 | "train_acc":[], 99 | "val_loss":[], 100 | "val_acc":[], 101 | "lr":[], 102 | "wd":[]} 103 | 104 | self.save_dir = save_dir 105 | if not os.path.exists(self.save_dir): 106 | os.mkdir(self.save_dir) 107 | 108 | try: 109 | filename = 'history_cuda.png' 110 | except: 111 | filename = 'history.png' if filename is None else filename 112 | 113 | self.plot_name = os.path.join(self.save_dir, filename) 114 | 115 | 116 | 117 | def update(self, train_stats, val_stats, record_lr_wd): 118 | train_loss, train_acc = train_stats 119 | val_loss, val_acc = val_stats 120 | lr_history, wd_history = record_lr_wd 121 | 122 | self.history['train_loss'].append(train_loss) 123 | self.history['train_acc'].append(np.round(train_acc*100)) 124 | self.history['val_loss'].append(val_loss) 125 | self.history['val_acc'].append(np.round(val_acc*100)) 126 | self.history['lr'].extend(lr_history) 127 | self.history['wd'].extend(wd_history) 128 | 129 | 130 | def plot_and_save(self, initial_epoch=0): 131 | train_loss = self.history['train_loss'] 132 | train_acc = self.history['train_acc'] 133 | val_loss = self.history['val_loss'] 134 | val_acc = 
self.history['val_acc'] 135 | 136 | epochs = [(i+initial_epoch) for i in range(len(train_loss))] 137 | 138 | f, ax = plt.subplots(3, 1, figsize=(15,8)) 139 | ax[0].plot(epochs, train_loss) 140 | ax[0].plot(epochs, val_loss) 141 | ax[0].set_title('loss progression') 142 | ax[0].set_xlabel('Epochs') 143 | ax[0].set_ylabel('loss values') 144 | ax[0].legend(['train', 'test']) 145 | 146 | ax[1].plot(epochs, train_acc) 147 | ax[1].plot(epochs, val_acc) 148 | ax[1].set_title('accuracy progression') 149 | ax[1].set_xlabel('Epochs') 150 | ax[1].set_ylabel('Accuracy') 151 | ax[1].legend(['train', 'test']) 152 | 153 | steps = len(self.history['lr']) 154 | bs = steps/len(train_loss) # steps per epoch - lr/wd are recorded once per optimizer step 155 | ax[2].plot([s/bs for s in range(steps)], self.history['lr']) 156 | ax[2].plot([s/bs for s in range(steps)], self.history['wd']) 157 | ax[2].set_title('learning rate and weight decay') 158 | ax[2].set_xlabel('Epochs') 159 | ax[2].set_ylabel('lr and wd') 160 | ax[2].legend(['lr', 'wd']) 161 | 162 | plt.savefig(self.plot_name) 163 | plt.close() 164 | 165 | def repeat(x, n, axis): # np.repeat-style repetition that also accepts (nested) Python lists 166 | if isinstance(x, np.ndarray): 167 | return np.repeat(x, n, axis=axis) 168 | elif isinstance(x, list): 169 | return repeat_list(x, n, axis) 170 | else: 171 | raise Exception('Unsupported data type {}'.format(type(x))) 172 | 173 | def repeat_list(x, n, axis): # list analogue of np.repeat: axis 0 repeats each element in place, axis > 1 recurses into the elements 174 | assert isinstance(x, list), 'Can only consume list type' 175 | if axis == 0: 176 | x_new = sum([[x_] * n for x_ in x], []) 177 | elif axis > 1: 178 | x_new = [repeat(x_, n, axis=axis - 1) for x_ in x] 179 | else: 180 | raise Exception('Unsupported axis {} for list input'.format(axis)) 181 | return x_new 182 | 183 | def tile(x): 184 | return None --------------------------------------------------------------------------------
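As a closing illustration, a small driver for the CTLEarlyStopping helper in utils.py above. It is not part of the repository: the validation losses are made up to show the improve/patience/stop behaviour, and the import assumes the snippet is run from the repo root (utils.py pulls CIFAR_MEANS/CIFAR_STDS from data_generator.py at import time):

# Illustrative only: feed made-up val_loss values to CTLEarlyStopping (mode='min', patience=2).
from utils import CTLEarlyStopping

stopper = CTLEarlyStopping(monitor='val_loss', min_delta=1e-3, patience=2, mode='min')
for val_loss in [0.90, 0.85, 0.86, 0.87, 0.88]:
    improved, stop = stopper.check_progress(val_loss)
    if stop:  # set once val_loss has failed to improve for `patience` consecutive checks
        break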