├── .gitignore
├── DeepAA_evaluate
│   ├── README.md
│   ├── __init__.py
│   ├── augmentations.py
│   ├── autoaugment.py
│   ├── common.py
│   ├── data.py
│   ├── deep_autoaugment.py
│   ├── fast_autoaugment.py
│   ├── imagenet.py
│   ├── lr_scheduler.py
│   ├── metrics.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── convnet.py
│   │   ├── mlp.py
│   │   ├── resnet.py
│   │   ├── shakeshake
│   │   │   ├── __init__.py
│   │   │   ├── shake_resnet.py
│   │   │   ├── shake_resnext.py
│   │   │   └── shakeshake.py
│   │   └── wideresnet.py
│   ├── train.py
│   └── utils.py
├── DeepAA_search.py
├── DeepAA_utils.py
├── README.md
├── __init__.py
├── aug_lib.py
├── augmentation.py
├── confs
│   ├── resnet50_imagenet_DeepAA_8x256_1.yaml
│   ├── resnet50_imagenet_DeepAA_8x256_2.yaml
│   ├── wresnet28x10_cifar100_DeepAA_1.yaml
│   ├── wresnet28x10_cifar100_DeepAA_1_wd1e-3.yaml
│   ├── wresnet28x10_cifar100_DeepAA_2.yaml
│   ├── wresnet28x10_cifar100_DeepAA_2_wd1e-3.yaml
│   ├── wresnet28x10_cifar100_DeepAA_BatchAug8x_1.yaml
│   ├── wresnet28x10_cifar100_DeepAA_BatchAug8x_2.yaml
│   ├── wresnet28x10_cifar10_DeepAA_1.yaml
│   ├── wresnet28x10_cifar10_DeepAA_1_wd1e-3.yaml
│   ├── wresnet28x10_cifar10_DeepAA_2.yaml
│   ├── wresnet28x10_cifar10_DeepAA_2_wd1e-3.yaml
│   ├── wresnet28x10_cifar10_DeepAA_BatchAug8x_1.yaml
│   └── wresnet28x10_cifar10_DeepAA_BatchAug8x_2.yaml
├── data_generator.py
├── imagenet_data_utils.py
├── images
│   ├── DeepAA.png
│   ├── DeepAA_slideslive.png
│   ├── magnitude_distribution_cifar.png
│   ├── magnitude_distribution_imagenet.png
│   └── operation_distribution.png
├── lr_scheduler.py
├── policy.py
├── policy_port
│   ├── policy_DeepAA_cifar_1.npz
│   ├── policy_DeepAA_cifar_2.npz
│   ├── policy_DeepAA_imagenet_1.npz
│   └── policy_DeepAA_imagenet_2.npz
├── requirements.txt
├── resnet.py
├── resnet_imagenet.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | model_checkpoints/**
3 | *.pyc
4 | __pycache__
5 | .idea/**
6 | results/**
7 | plots/**
8 | temp/results/
--------------------------------------------------------------------------------
/DeepAA_evaluate/README.md:
--------------------------------------------------------------------------------
1 | Code for evaluating the generated DeepAA policy.
2 |
3 | The code in this folder is adapted from [TrivialAugment](https://github.com/automl/trivialaugment).
--------------------------------------------------------------------------------
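Usage sketch (an assumption, since the exact command line is not shown in this dump): the YAML files under confs/ are theconf configurations consumed by DeepAA_evaluate/train.py, so an evaluation run presumably looks like `python DeepAA_evaluate/train.py -c confs/wresnet28x10_cifar10_DeepAA_1.yaml`.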
/DeepAA_evaluate/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/DeepAA_evaluate/__init__.py
--------------------------------------------------------------------------------
/DeepAA_evaluate/augmentations.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from rpmcruz/autoaugment
2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from DeepAA_evaluate import autoaugment, fast_autoaugment
8 | import aug_lib
9 |
10 |
11 | class Lighting(object):
12 | """Lighting noise(AlexNet - style PCA - based noise)"""
13 |
14 | def __init__(self, alphastd, eigval, eigvec):
15 | self.alphastd = alphastd
16 | self.eigval = torch.Tensor(eigval)
17 | self.eigvec = torch.Tensor(eigvec)
18 |
19 | def __call__(self, img):
20 | if self.alphastd == 0:
21 | return img
22 |
23 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
24 | rgb = self.eigvec.type_as(img).clone() \
25 | .mul(alpha.view(1, 3).expand(3, 3)) \
26 | .mul(self.eigval.view(1, 3).expand(3, 3)) \
27 | .sum(1).squeeze()
28 |
29 | return img.add(rgb.view(3, 1, 1).expand_as(img))
30 |
31 |
32 | class CutoutDefault(object):
33 | """
34 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
35 | """
36 | def __init__(self, length):
37 | self.length = length
38 |
39 | def __call__(self, img):
40 | h, w = img.size(1), img.size(2)
41 | mask = np.ones((h, w), np.float32)
42 | y = np.random.randint(h)
43 | x = np.random.randint(w)
44 |
45 | y1 = np.clip(y - self.length // 2, 0, h)
46 | y2 = np.clip(y + self.length // 2, 0, h)
47 | x1 = np.clip(x - self.length // 2, 0, w)
48 | x2 = np.clip(x + self.length // 2, 0, w)
49 |
50 | mask[y1: y2, x1: x2] = 0.
51 | mask = torch.from_numpy(mask)
52 | mask = mask.expand_as(img)
53 | img *= mask
54 | return img
55 |
56 |
57 | def get_randaugment(n,m,weights,bs):
58 | if n == 101 and m == 101:
59 | return autoaugment.CifarAutoAugment(fixed_posterize=False)
60 | if n == 102 and m == 102:
61 | return autoaugment.CifarAutoAugment(fixed_posterize=True)
62 | if n == 201 and m == 201:
63 | return autoaugment.SVHNAutoAugment(fixed_posterize=False)
64 | if n == 202 and m == 202:
65 | return autoaugment.SVHNAutoAugment(fixed_posterize=True)
66 | if n == 301 and m == 301:
67 | return fast_autoaugment.cifar10_faa
68 | if n == 401 and m == 401:
69 | return fast_autoaugment.svhn_faa
70 | assert m < 100 and n < 100
71 | if m == 0:
72 | if weights is not None:
73 | return aug_lib.UniAugmentWeighted(n, probs=weights)
74 | elif n == 0:
75 | return aug_lib.UniAugment()
76 | else:
77 | raise ValueError('Wrong RandAug Params.')
78 | else:
79 | assert n > 0 and m > 0
80 | return aug_lib.RandAugment(n, m)
--------------------------------------------------------------------------------
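A minimal usage sketch for the helpers above (not part of the repository): CutoutDefault operates on a normalized CHW tensor, and get_randaugment uses sentinel values such as n == m == 101 to select the CIFAR AutoAugment policy. The mean/std constants are the CIFAR-10 values used elsewhere in this repo.

from torchvision import transforms
from DeepAA_evaluate.augmentations import CutoutDefault, get_randaugment

# tensor pipeline: normalize, then zero out a random 16x16 patch
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    CutoutDefault(16),
])

# n == m == 101 is the sentinel for the (non-fixed-posterize) CIFAR AutoAugment policy
cifar_aa = get_randaugment(n=101, m=101, weights=None, bs=None)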
/DeepAA_evaluate/autoaugment.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The TensorFlow Authors All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Transforms used in the Augmentation Policies."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import random
23 | import numpy as np
24 | # pylint:disable=g-multiple-import
25 | from PIL import ImageOps, ImageEnhance, ImageFilter, Image
26 | # pylint:enable=g-multiple-import
27 |
28 |
29 | IMAGE_SIZE = 32
30 | # Maximum 'level' a transform can take; magnitudes are scaled by level / PARAMETER_MAX
31 | PARAMETER_MAX = 30
32 |
33 | def pil_wrap(img):
34 | """Convert the `img` numpy tensor to a PIL Image."""
35 | return img.convert('RGBA')
36 |
37 |
38 | def pil_unwrap(img):
39 | """Converts the PIL img to a numpy array."""
40 | return img.convert('RGB')
41 |
42 | def apply_policy(policy, img, use_fixed_posterize=False):
43 | """Apply the `policy` to the numpy `img`.
44 |
45 | Args:
46 | policy: A list of tuples with the form (name, probability, level) where
47 | `name` is the name of the augmentation operation to apply, `probability`
48 | is the probability of applying the operation and `level` is what strength
49 | the operation to apply.
50 | img: PIL image that will have `policy` applied to it.
51 |
52 | Returns:
53 | The result of applying `policy` to `img`.
54 | """
55 | nametotransform = fixed_AA_NAME_TO_TRANSFORM if use_fixed_posterize else AA_NAME_TO_TRANSFORM
56 | pil_img = pil_wrap(img)
57 | for xform in policy:
58 | assert len(xform) == 3
59 | name, probability, level = xform
60 | xform_fn = nametotransform[name].pil_transformer(probability, level)
61 | pil_img = xform_fn(pil_img)
62 | return pil_unwrap(pil_img)
63 |
64 |
65 | def random_flip(x):
66 | """Flip the input x horizontally with 50% probability."""
67 | if np.random.rand(1)[0] > 0.5:
68 | return np.fliplr(x)
69 | return x
70 |
71 |
72 | def zero_pad_and_crop(img, amount=4):
73 | """Zero pad by `amount` zero pixels on each side then take a random crop.
74 |
75 | Args:
76 | img: numpy image that will be zero padded and cropped.
77 | amount: amount of zeros to pad `img` with horizontally and vertically.
78 |
79 | Returns:
80 | The cropped zero padded img. The returned numpy array will be of the same
81 | shape as `img`.
82 | """
83 | padded_img = np.zeros((img.shape[0] + amount * 2, img.shape[1] + amount * 2,
84 | img.shape[2]))
85 | padded_img[amount:img.shape[0] + amount, amount:
86 | img.shape[1] + amount, :] = img
87 | top = np.random.randint(low=0, high=2 * amount)
88 | left = np.random.randint(low=0, high=2 * amount)
89 | new_img = padded_img[top:top + img.shape[0], left:left + img.shape[1], :]
90 | return new_img
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 | def float_parameter(level, maxval):
99 | """Helper function to scale `val` between 0 and maxval .
100 |
101 | Args:
102 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
103 | maxval: Maximum value that the operation can have. This will be scaled
104 | to level/PARAMETER_MAX.
105 |
106 | Returns:
107 | A float that results from scaling `maxval` according to `level`.
108 | """
109 | return float(level) * maxval / PARAMETER_MAX
110 |
111 |
112 | def int_parameter(level, maxval):
113 | """Helper function to scale `val` between 0 and maxval .
114 |
115 | Args:
116 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
117 | maxval: Maximum value that the operation can have. This will be scaled
118 | to level/PARAMETER_MAX.
119 |
120 | Returns:
121 | An int that results from scaling `maxval` according to `level`.
122 | """
123 | return int(level * maxval / PARAMETER_MAX)
124 |
125 |
126 |
127 |
128 | class TransformFunction(object):
129 | """Wraps the Transform function for pretty printing options."""
130 |
131 | def __init__(self, func, name):
132 | self.f = func
133 | self.name = name
134 |
135 | def __repr__(self):
136 | return '<' + self.name + '>'
137 |
138 | def __call__(self, pil_img):
139 | return self.f(pil_img)
140 |
141 |
142 | class TransformT(object):
143 | """Each instance of this class represents a specific transform."""
144 |
145 | def __init__(self, name, xform_fn):
146 | self.name = name
147 | self.xform = xform_fn
148 |
149 | def pil_transformer(self, probability, level):
150 |
151 | def return_function(im):
152 | if random.random() < probability:
153 | im = self.xform(im, level)
154 | return im
155 |
156 | name = self.name + '({:.1f},{})'.format(probability, level)
157 | return TransformFunction(return_function, name)
158 |
159 | def do_transform(self, image, level):
160 | f = self.pil_transformer(PARAMETER_MAX, level)
161 | return f(image)
162 |
163 |
164 | ################## Transform Functions ##################
165 | identity = TransformT('identity', lambda pil_img, level: pil_img)
166 | flip_lr = TransformT(
167 | 'FlipLR',
168 | lambda pil_img, level: pil_img.transpose(Image.FLIP_LEFT_RIGHT))
169 | flip_ud = TransformT(
170 | 'FlipUD',
171 | lambda pil_img, level: pil_img.transpose(Image.FLIP_TOP_BOTTOM))
172 | # pylint:disable=g-long-lambda
173 | auto_contrast = TransformT(
174 | 'AutoContrast',
175 | lambda pil_img, level: ImageOps.autocontrast(
176 | pil_img.convert('RGB')).convert('RGBA'))
177 | equalize = TransformT(
178 | 'Equalize',
179 | lambda pil_img, level: ImageOps.equalize(
180 | pil_img.convert('RGB')).convert('RGBA'))
181 | invert = TransformT(
182 | 'Invert',
183 | lambda pil_img, level: ImageOps.invert(
184 | pil_img.convert('RGB')).convert('RGBA'))
185 | # pylint:enable=g-long-lambda
186 | blur = TransformT(
187 | 'Blur', lambda pil_img, level: pil_img.filter(ImageFilter.BLUR))
188 | smooth = TransformT(
189 | 'Smooth',
190 | lambda pil_img, level: pil_img.filter(ImageFilter.SMOOTH))
191 |
192 |
193 | def _rotate_impl(pil_img, level):
194 | """Rotates `pil_img` from -30 to 30 degrees depending on `level`."""
195 | degrees = int_parameter(level, 30)
196 | if random.random() > 0.5:
197 | degrees = -degrees
198 | return pil_img.rotate(degrees)
199 |
200 |
201 | rotate = TransformT('Rotate', _rotate_impl)
202 |
203 |
204 | def _posterize_impl(pil_img, level):
205 | """Applies PIL Posterize to `pil_img`."""
206 | level = int_parameter(level, 4)
207 | return ImageOps.posterize(pil_img.convert('RGB'), 4 - level).convert('RGBA')
208 |
209 |
210 | posterize = TransformT('Posterize', _posterize_impl)
211 |
212 | def _fixed_posterize_impl(pil_img, level):
213 | """Applies PIL Posterize to `pil_img`."""
214 | level = int_parameter(level, 4)
215 | return ImageOps.posterize(pil_img.convert('RGB'), 8 - level).convert('RGBA')
216 |
217 | fixed_posterize = TransformT('Posterize', _fixed_posterize_impl)
218 |
219 |
220 | def _shear_x_impl(pil_img, level):
221 | """Applies PIL ShearX to `pil_img`.
222 |
223 | The ShearX operation shears the image along the horizontal axis with `level`
224 | magnitude.
225 |
226 | Args:
227 | pil_img: Image in PIL object.
228 | level: Strength of the operation specified as an Integer from
229 | [0, `PARAMETER_MAX`].
230 |
231 | Returns:
232 | A PIL Image that has had ShearX applied to it.
233 | """
234 | level = float_parameter(level, 0.3)
235 | if random.random() > 0.5:
236 | level = -level
237 | return pil_img.transform((32, 32), Image.AFFINE, (1, level, 0, 0, 1, 0))
238 |
239 |
240 | shear_x = TransformT('ShearX', _shear_x_impl)
241 |
242 |
243 | def _shear_y_impl(pil_img, level):
244 | """Applies PIL ShearY to `pil_img`.
245 |
246 | The ShearY operation shears the image along the vertical axis with `level`
247 | magnitude.
248 |
249 | Args:
250 | pil_img: Image in PIL object.
251 | level: Strength of the operation specified as an Integer from
252 | [0, `PARAMETER_MAX`].
253 |
254 | Returns:
255 | A PIL Image that has had ShearY applied to it.
256 | """
257 | level = float_parameter(level, 0.3)
258 | if random.random() > 0.5:
259 | level = -level
260 | return pil_img.transform((32, 32), Image.AFFINE, (1, 0, 0, level, 1, 0))
261 |
262 |
263 | shear_y = TransformT('ShearY', _shear_y_impl)
264 |
265 |
266 | def _translate_x_impl(pil_img, level):
267 | """Applies PIL TranslateX to `pil_img`.
268 |
269 | Translate the image in the horizontal direction by `level`
270 | number of pixels.
271 |
272 | Args:
273 | pil_img: Image in PIL object.
274 | level: Strength of the operation specified as an Integer from
275 | [0, `PARAMETER_MAX`].
276 |
277 | Returns:
278 | A PIL Image that has had TranslateX applied to it.
279 | """
280 | level = int_parameter(level, 10)
281 | if random.random() > 0.5:
282 | level = -level
283 | return pil_img.transform((32, 32), Image.AFFINE, (1, 0, level, 0, 1, 0))
284 |
285 |
286 | translate_x = TransformT('TranslateX', _translate_x_impl)
287 |
288 |
289 | def _translate_y_impl(pil_img, level):
290 | """Applies PIL TranslateY to `pil_img`.
291 |
292 | Translate the image in the vertical direction by `level`
293 | number of pixels.
294 |
295 | Args:
296 | pil_img: Image in PIL object.
297 | level: Strength of the operation specified as an Integer from
298 | [0, `PARAMETER_MAX`].
299 |
300 | Returns:
301 | A PIL Image that has had TranslateY applied to it.
302 | """
303 | level = int_parameter(level, 10)
304 | if random.random() > 0.5:
305 | level = -level
306 | return pil_img.transform((32, 32), Image.AFFINE, (1, 0, 0, 0, 1, level))
307 |
308 |
309 | translate_y = TransformT('TranslateY', _translate_y_impl)
310 |
311 |
312 | def _crop_impl(pil_img, level, interpolation=Image.BILINEAR):
313 | """Applies a crop to `pil_img` with the size depending on the `level`."""
314 | cropped = pil_img.crop((level, level, IMAGE_SIZE - level, IMAGE_SIZE - level))
315 | resized = cropped.resize((IMAGE_SIZE, IMAGE_SIZE), interpolation)
316 | return resized
317 |
318 |
319 | crop_bilinear = TransformT('CropBilinear', _crop_impl)
320 |
321 |
322 | def _solarize_impl(pil_img, level):
323 | """Applies PIL Solarize to `pil_img`.
324 |
325 | Inverts all pixel values above a threshold that decreases as
326 | `level` increases.
327 |
328 | Args:
329 | pil_img: Image in PIL object.
330 | level: Strength of the operation specified as an Integer from
331 | [0, `PARAMETER_MAX`].
332 |
333 | Returns:
334 | A PIL Image that has had Solarize applied to it.
335 | """
336 | level = int_parameter(level, 256)
337 | return ImageOps.solarize(pil_img.convert('RGB'), 256 - level).convert('RGBA')
338 |
339 |
340 | solarize = TransformT('Solarize', _solarize_impl)
341 |
342 |
343 | def _enhancer_impl(enhancer):
344 | """Sets level to be between 0.1 and 1.8 for ImageEnhance transforms of PIL."""
345 | def impl(pil_img, level):
346 | v = float_parameter(level, 1.8) + .1 # going to 0 just destroys it
347 | return enhancer(pil_img).enhance(v)
348 | return impl
349 |
350 |
351 | color = TransformT('Color', _enhancer_impl(ImageEnhance.Color))
352 | contrast = TransformT('Contrast', _enhancer_impl(ImageEnhance.Contrast))
353 | brightness = TransformT('Brightness', _enhancer_impl(
354 | ImageEnhance.Brightness))
355 | sharpness = TransformT('Sharpness', _enhancer_impl(ImageEnhance.Sharpness))
356 |
357 | def create_cutout_mask(img_height, img_width, num_channels, size):
358 | """Creates a zero mask used for cutout of shape `img_height` x `img_width`.
359 |
360 | Args:
361 | img_height: Height of image cutout mask will be applied to.
362 | img_width: Width of image cutout mask will be applied to.
363 | num_channels: Number of channels in the image.
364 | size: Size of the zeros mask.
365 |
366 | Returns:
367 | A mask of shape `img_height` x `img_width` with all ones except for a
368 | square of zeros of shape `size` x `size`. This mask is meant to be
369 | elementwise multiplied with the original image. Additionally returns
370 | the `upper_coord` and `lower_coord` which specify where the cutout mask
371 | will be applied.
372 | """
373 | assert img_height == img_width
374 |
375 | # Sample center where cutout mask will be applied
376 | height_loc = np.random.randint(low=0, high=img_height)
377 | width_loc = np.random.randint(low=0, high=img_width)
378 |
379 | # Determine upper right and lower left corners of patch
380 | upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2))
381 | lower_coord = (min(img_height, height_loc + size // 2),
382 | min(img_width, width_loc + size // 2))
383 | mask_height = lower_coord[0] - upper_coord[0]
384 | mask_width = lower_coord[1] - upper_coord[1]
385 | assert mask_height > 0
386 | assert mask_width > 0
387 |
388 | mask = np.ones((img_height, img_width, num_channels))
389 | zeros = np.zeros((mask_height, mask_width, num_channels))
390 | mask[upper_coord[0]:lower_coord[0], upper_coord[1]:lower_coord[1], :] = (
391 | zeros)
392 | return mask, upper_coord, lower_coord
393 |
394 | def _cutout_pil_impl(pil_img, level):
395 | """Apply cutout to pil_img at the specified level."""
396 | size = int_parameter(level, 20)
397 | if size <= 0:
398 | return pil_img
399 | img_height, img_width, num_channels = (32, 32, 3)
400 | _, upper_coord, lower_coord = (
401 | create_cutout_mask(img_height, img_width, num_channels, size))
402 | pixels = pil_img.load() # create the pixel map
403 | for i in range(upper_coord[0], lower_coord[0]): # for every col:
404 | for j in range(upper_coord[1], lower_coord[1]): # For every row
405 | pixels[i, j] = (125, 122, 113, 0) # set the colour accordingly
406 | return pil_img
407 |
408 | cutout = TransformT('Cutout', _cutout_pil_impl)
409 |
410 |
411 |
412 | ALL_TRANSFORMS = [
413 | identity,
414 | auto_contrast,
415 | equalize,
416 | rotate,
417 | posterize,
418 | solarize,
419 | color,
420 | contrast,
421 | brightness,
422 | sharpness,
423 | shear_x,
424 | shear_y,
425 | translate_x,
426 | translate_y,
427 | ]
428 |
429 | AA_ALL_TRANSFORMS = [
430 | flip_lr,
431 | flip_ud,
432 | auto_contrast,
433 | equalize,
434 | invert,
435 | rotate,
436 | posterize,
437 | crop_bilinear,
438 | solarize,
439 | color,
440 | contrast,
441 | brightness,
442 | sharpness,
443 | shear_x,
444 | shear_y,
445 | translate_x,
446 | translate_y,
447 | cutout,
448 | blur,
449 | smooth
450 | ]
451 |
452 |
453 | fixed_AA_ALL_TRANSFORMS = [
454 | flip_lr,
455 | flip_ud,
456 | auto_contrast,
457 | equalize,
458 | invert,
459 | rotate,
460 | fixed_posterize,
461 | crop_bilinear,
462 | solarize,
463 | color,
464 | contrast,
465 | brightness,
466 | sharpness,
467 | shear_x,
468 | shear_y,
469 | translate_x,
470 | translate_y,
471 | cutout,
472 | blur,
473 | smooth
474 | ]
475 |
476 |
477 | class RandAugment:
478 | def __init__(self, n, m):
479 | self.n = n
480 | self.m = m # [0, 30]
481 |
482 | def __call__(self, img):
483 | img = pil_wrap(img)
484 | ops = random.choices(ALL_TRANSFORMS, k=self.n)
485 | for op in ops:
486 | img = op.pil_transformer(1.,self.m)(img)
487 | img = pil_unwrap(img)
488 |
489 | return img
490 |
491 | AA_NAME_TO_TRANSFORM = {t.name: t for t in AA_ALL_TRANSFORMS}
492 | fixed_AA_NAME_TO_TRANSFORM = {t.name: t for t in fixed_AA_ALL_TRANSFORMS}
493 |
494 | NAME_TO_TRANSFORM = {t.name: t for t in ALL_TRANSFORMS}
495 |
496 | def good_policies():
497 | """AutoAugment policies found on Cifar."""
498 | exp0_0 = [
499 | [('Invert', 0.1, 7), ('Contrast', 0.2, 6)],
500 | [('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)],
501 | [('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)],
502 | [('ShearY', 0.5, 8), ('TranslateY', 0.7, 9)],
503 | [('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)]]
504 | exp0_1 = [
505 | [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)],
506 | [('TranslateY', 0.9, 9), ('TranslateY', 0.7, 9)],
507 | [('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)],
508 | [('Equalize', 0.8, 8), ('Invert', 0.1, 3)],
509 | [('TranslateY', 0.7, 9), ('AutoContrast', 0.9, 1)]]
510 | exp0_2 = [
511 | [('Solarize', 0.4, 5), ('AutoContrast', 0.0, 2)],
512 | [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)],
513 | [('AutoContrast', 0.9, 0), ('Solarize', 0.4, 3)],
514 | [('Equalize', 0.7, 5), ('Invert', 0.1, 3)],
515 | [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)]]
516 | exp0_3 = [
517 | [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 1)],
518 | [('TranslateY', 0.8, 9), ('TranslateY', 0.9, 9)],
519 | [('AutoContrast', 0.8, 0), ('TranslateY', 0.7, 9)],
520 | [('TranslateY', 0.2, 7), ('Color', 0.9, 6)],
521 | [('Equalize', 0.7, 6), ('Color', 0.4, 9)]]
522 | exp1_0 = [
523 | [('ShearY', 0.2, 7), ('Posterize', 0.3, 7)],
524 | [('Color', 0.4, 3), ('Brightness', 0.6, 7)],
525 | [('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)],
526 | [('Equalize', 0.6, 5), ('Equalize', 0.5, 1)],
527 | [('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)]]
528 | exp1_1 = [
529 | [('Brightness', 0.3, 7), ('AutoContrast', 0.5, 8)],
530 | [('AutoContrast', 0.9, 4), ('AutoContrast', 0.5, 6)],
531 | [('Solarize', 0.3, 5), ('Equalize', 0.6, 5)],
532 | [('TranslateY', 0.2, 4), ('Sharpness', 0.3, 3)],
533 | [('Brightness', 0.0, 8), ('Color', 0.8, 8)]]
534 | exp1_2 = [
535 | [('Solarize', 0.2, 6), ('Color', 0.8, 6)],
536 | [('Solarize', 0.2, 6), ('AutoContrast', 0.8, 1)],
537 | [('Solarize', 0.4, 1), ('Equalize', 0.6, 5)],
538 | [('Brightness', 0.0, 0), ('Solarize', 0.5, 2)],
539 | [('AutoContrast', 0.9, 5), ('Brightness', 0.5, 3)]]
540 | exp1_3 = [
541 | [('Contrast', 0.7, 5), ('Brightness', 0.0, 2)],
542 | [('Solarize', 0.2, 8), ('Solarize', 0.1, 5)],
543 | [('Contrast', 0.5, 1), ('TranslateY', 0.2, 9)],
544 | [('AutoContrast', 0.6, 5), ('TranslateY', 0.0, 9)],
545 | [('AutoContrast', 0.9, 4), ('Equalize', 0.8, 4)]]
546 | exp1_4 = [
547 | [('Brightness', 0.0, 7), ('Equalize', 0.4, 7)],
548 | [('Solarize', 0.2, 5), ('Equalize', 0.7, 5)],
549 | [('Equalize', 0.6, 8), ('Color', 0.6, 2)],
550 | [('Color', 0.3, 7), ('Color', 0.2, 4)],
551 | [('AutoContrast', 0.5, 2), ('Solarize', 0.7, 2)]]
552 | exp1_5 = [
553 | [('AutoContrast', 0.2, 0), ('Equalize', 0.1, 0)],
554 | [('ShearY', 0.6, 5), ('Equalize', 0.6, 5)],
555 | [('Brightness', 0.9, 3), ('AutoContrast', 0.4, 1)],
556 | [('Equalize', 0.8, 8), ('Equalize', 0.7, 7)],
557 | [('Equalize', 0.7, 7), ('Solarize', 0.5, 0)]]
558 | exp1_6 = [
559 | [('Equalize', 0.8, 4), ('TranslateY', 0.8, 9)],
560 | [('TranslateY', 0.8, 9), ('TranslateY', 0.6, 9)],
561 | [('TranslateY', 0.9, 0), ('TranslateY', 0.5, 9)],
562 | [('AutoContrast', 0.5, 3), ('Solarize', 0.3, 4)],
563 | [('Solarize', 0.5, 3), ('Equalize', 0.4, 4)]]
564 | exp2_0 = [
565 | [('Color', 0.7, 7), ('TranslateX', 0.5, 8)],
566 | [('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)],
567 | [('TranslateY', 0.4, 3), ('Sharpness', 0.2, 6)],
568 | [('Brightness', 0.9, 6), ('Color', 0.2, 8)],
569 | [('Solarize', 0.5, 2), ('Invert', 0.0, 3)]]
570 | exp2_1 = [
571 | [('AutoContrast', 0.1, 5), ('Brightness', 0.0, 0)],
572 | [('Cutout', 0.2, 4), ('Equalize', 0.1, 1)],
573 | [('Equalize', 0.7, 7), ('AutoContrast', 0.6, 4)],
574 | [('Color', 0.1, 8), ('ShearY', 0.2, 3)],
575 | [('ShearY', 0.4, 2), ('Rotate', 0.7, 0)]]
576 | exp2_2 = [
577 | [('ShearY', 0.1, 3), ('AutoContrast', 0.9, 5)],
578 | [('TranslateY', 0.3, 6), ('Cutout', 0.3, 3)],
579 | [('Equalize', 0.5, 0), ('Solarize', 0.6, 6)],
580 | [('AutoContrast', 0.3, 5), ('Rotate', 0.2, 7)],
581 | [('Equalize', 0.8, 2), ('Invert', 0.4, 0)]]
582 | exp2_3 = [
583 | [('Equalize', 0.9, 5), ('Color', 0.7, 0)],
584 | [('Equalize', 0.1, 1), ('ShearY', 0.1, 3)],
585 | [('AutoContrast', 0.7, 3), ('Equalize', 0.7, 0)],
586 | [('Brightness', 0.5, 1), ('Contrast', 0.1, 7)],
587 | [('Contrast', 0.1, 4), ('Solarize', 0.6, 5)]]
588 | exp2_4 = [
589 | [('Solarize', 0.2, 3), ('ShearX', 0.0, 0)],
590 | [('TranslateX', 0.3, 0), ('TranslateX', 0.6, 0)],
591 | [('Equalize', 0.5, 9), ('TranslateY', 0.6, 7)],
592 | [('ShearX', 0.1, 0), ('Sharpness', 0.5, 1)],
593 | [('Equalize', 0.8, 6), ('Invert', 0.3, 6)]]
594 | exp2_5 = [
595 | [('AutoContrast', 0.3, 9), ('Cutout', 0.5, 3)],
596 | [('ShearX', 0.4, 4), ('AutoContrast', 0.9, 2)],
597 | [('ShearX', 0.0, 3), ('Posterize', 0.0, 3)],
598 | [('Solarize', 0.4, 3), ('Color', 0.2, 4)],
599 | [('Equalize', 0.1, 4), ('Equalize', 0.7, 6)]]
600 | exp2_6 = [
601 | [('Equalize', 0.3, 8), ('AutoContrast', 0.4, 3)],
602 | [('Solarize', 0.6, 4), ('AutoContrast', 0.7, 6)],
603 | [('AutoContrast', 0.2, 9), ('Brightness', 0.4, 8)],
604 | [('Equalize', 0.1, 0), ('Equalize', 0.0, 6)],
605 | [('Equalize', 0.8, 4), ('Equalize', 0.0, 4)]]
606 | exp2_7 = [
607 | [('Equalize', 0.5, 5), ('AutoContrast', 0.1, 2)],
608 | [('Solarize', 0.5, 5), ('AutoContrast', 0.9, 5)],
609 | [('AutoContrast', 0.6, 1), ('AutoContrast', 0.7, 8)],
610 | [('Equalize', 0.2, 0), ('AutoContrast', 0.1, 2)],
611 | [('Equalize', 0.6, 9), ('Equalize', 0.4, 4)]]
612 | exp0s = exp0_0 + exp0_1 + exp0_2 + exp0_3
613 | exp1s = exp1_0 + exp1_1 + exp1_2 + exp1_3 + exp1_4 + exp1_5 + exp1_6
614 | exp2s = exp2_0 + exp2_1 + exp2_2 + exp2_3 + exp2_4 + exp2_5 + exp2_6 + exp2_7
615 | return exp0s + exp1s + exp2s
616 |
617 | cifar_gp = good_policies()
618 |
619 | first_aug_ops = [("ShearX",0.9,4), ("ShearY",0.9,8), ("Equalize",0.6,5), ("Invert",0.9,3), ("Equalize",0.6,1), ("ShearX",0.9,4), ("ShearY",0.9,8), ("ShearY",0.9,5), ("Invert",0.9,6), ("Equalize",0.6,3), ("ShearX",0.9,4), ("ShearY",0.8,8), ("Equalize",0.9,5), ("Invert",0.9,4), ("Contrast",0.3,3), ("Invert",0.8,5), ("ShearY",0.7,6), ("Invert",0.6,4), ("ShearY",0.3,7), ("ShearX",0.1,6), ("Solarize",0.7,2), ("ShearY",0.8,4), ("ShearX",0.7,9), ("ShearY",0.8,5), ("ShearX",0.7,2)]
620 | second_aug_ops = [("Invert",0.2,3), ("Invert",0.7,5), ("Solarize",0.6,6), ("Equalize",0.6,3), ("Rotate",0.9,3), ("AutoContrast",0.8,3), ("Invert",0.4,5), ("Solarize",0.2,6), ("AutoContrast",0.8,1), ("Rotate",0.9,3), ("Solarize",0.3,3), ("Invert",0.7,4), ("TranslateY",0.6,6), ("Equalize",0.6,7), ("Rotate",0.8,4), ("TranslateY",0.0,2), ("Solarize",0.4,8), ("Rotate",0.8,4), ("TranslateX",0.9,3), ("Invert",0.6,5), ("TranslateY",0.6,7), ("Invert",0.8,8), ("TranslateY",0.8,3), ("AutoContrast",0.7,3), ("Invert",0.1,5)]
621 |
622 | svhn_gp = [[a1, a2] for a1, a2 in zip(first_aug_ops,second_aug_ops)]
623 |
624 | class CifarAutoAugment:
625 | def __init__(self, fixed_posterize):
626 | self.fixed_posterize = fixed_posterize
627 |
628 | def __call__(self, img):
629 | epoch_policy = cifar_gp[np.random.choice(len(cifar_gp))]
630 | final_img = apply_policy(epoch_policy, img, use_fixed_posterize=self.fixed_posterize)
631 |
632 | return final_img
633 |
634 | class SVHNAutoAugment:
635 | def __init__(self, fixed_posterize):
636 | self.fixed_posterize = fixed_posterize
637 |
638 | def __call__(self, img):
639 | epoch_policy = svhn_gp[np.random.choice(len(svhn_gp))]
640 | final_img = apply_policy(epoch_policy, img, use_fixed_posterize=self.fixed_posterize)
641 |
642 | return final_img
--------------------------------------------------------------------------------
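A small usage sketch for the policy classes defined above (a blank 32x32 PIL image stands in for a CIFAR sample):

from PIL import Image
from DeepAA_evaluate.autoaugment import CifarAutoAugment, RandAugment

img = Image.new('RGB', (32, 32))              # stand-in for a CIFAR image
aa = CifarAutoAugment(fixed_posterize=True)   # samples one CIFAR sub-policy per call
img_aa = aa(img)                              # returns an RGB PIL image
ra = RandAugment(n=2, m=15)                   # 2 random ops, magnitude 15 out of PARAMETER_MAX=30
img_ra = ra(img)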
/DeepAA_evaluate/common.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import warnings
3 | import random
4 | from copy import copy
5 | from typing import Union
6 | from collections import Counter
7 |
8 | import numpy as np
9 | import torch
10 | from torch.utils.checkpoint import check_backward_validity, detach_variable, get_device_states, set_device_states
11 | from torchvision.datasets import VisionDataset, CIFAR10, CIFAR100, ImageFolder
12 | from torch.utils.data import Subset, ConcatDataset
13 |
14 | from PIL import Image
15 |
16 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s')
17 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
18 |
19 |
20 | def get_logger(name, level=logging.DEBUG):
21 | logger = logging.getLogger(name)
22 | logger.handlers.clear()
23 | logger.setLevel(level)
24 | ch = logging.StreamHandler()
25 | ch.setLevel(level)
26 | ch.setFormatter(formatter)
27 | logger.addHandler(ch)
28 | return logger
29 |
30 |
31 | def add_filehandler(logger, filepath):
32 | fh = logging.FileHandler(filepath)
33 | fh.setLevel(logging.DEBUG)
34 | fh.setFormatter(formatter)
35 | logger.addHandler(fh)
36 |
37 |
38 | def copy_and_replace_transform(ds: Union[CIFAR10, ImageFolder, Subset], transform):
39 | assert ds.dataset.transform is not None if isinstance(ds,Subset) else (all(d.transform is not None for d in ds.datasets) if isinstance(ds,ConcatDataset) else ds.transform is not None) # make sure still uses old style transform
40 | if isinstance(ds, Subset):
41 | new_super_ds = copy(ds.dataset)
42 | new_super_ds.transform = transform
43 | new_ds = copy(ds)
44 | new_ds.dataset = new_super_ds
45 | elif isinstance(ds, ConcatDataset):
46 | def copy_and_replace_transform(ds):
47 | new_ds = copy(ds)
48 | new_ds.transform = transform
49 | return new_ds
50 |
51 | new_ds = ConcatDataset([copy_and_replace_transform(d) for d in ds.datasets])
52 |
53 | else:
54 | new_ds = copy(ds)
55 | new_ds.transform = transform
56 | return new_ds
57 |
58 | def apply_weightnorm(nn):
59 | def apply_weightnorm_(module):
60 | if 'Linear' in type(module).__name__ or 'Conv' in type(module).__name__:
61 | torch.nn.utils.weight_norm(module, name='weight', dim=0)
62 | nn.apply(apply_weightnorm_)
63 |
64 |
65 | def shufflelist_with_seed(lis, seed='2020'):
66 | s = random.getstate()
67 | random.seed(seed)
68 | random.shuffle(lis)
69 | random.setstate(s)
70 |
71 |
72 | def stratified_split(labels, val_share):
73 | assert isinstance(labels, list)
74 | counter = Counter(labels)
75 | indices_per_label = {label: [i for i,l in enumerate(labels) if l == label] for label in counter}
76 | per_label_split = {}
77 | for label, count in counter.items():
78 | indices = indices_per_label[label]
79 | assert count == len(indices)
80 | shufflelist_with_seed(indices, f'2020_{label}_{count}')
81 | train_val_border = round(count*(1.-val_share))
82 | per_label_split[label] = (indices[:train_val_border], indices[train_val_border:])
83 | final_split = ([],[])
84 | for label, split in per_label_split.items():
85 | for f_s, s in zip(final_split, split):
86 | f_s.extend(s)
87 | shufflelist_with_seed(final_split[0], '2020_yoyo')
88 | shufflelist_with_seed(final_split[1], '2020_yo')
89 | return final_split
90 |
91 |
92 | def denormalize(img, mean, std):
93 | mean, std = torch.tensor(mean).to(img.device), torch.tensor(std).to(img.device)
94 | return img.mul_(std[:,None,None]).add_(mean[:,None,None])
95 |
96 | def normalize(img, mean, std):
97 | mean, std = torch.tensor(mean).to(img.device), torch.tensor(std).to(img.device)
98 | return img.sub_(mean[:,None,None]).div_(std[:,None,None])
--------------------------------------------------------------------------------
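A short sketch of the split helpers above, assuming CIFAR-10 can be downloaded to ./data:

from torch.utils.data import Subset
from torchvision import datasets, transforms
from DeepAA_evaluate.common import stratified_split, copy_and_replace_transform

ds = datasets.CIFAR10('./data', train=True, download=True, transform=transforms.ToTensor())
# deterministic, label-balanced 90/10 split
train_idx, val_idx = stratified_split(list(ds.targets), val_share=0.1)
train_ds = Subset(ds, train_idx)
# validation view of the same underlying data, with a (here identical) eval-time transform
val_ds = copy_and_replace_transform(Subset(ds, val_idx), transforms.ToTensor())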
/DeepAA_evaluate/data.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 | from collections import Counter
5 |
6 | import torchvision
7 | from PIL import Image
8 |
9 | from torch.utils.data import SubsetRandomSampler, Sampler
10 | from torch.utils.data.distributed import DistributedSampler
11 | from torch.utils.data.dataset import ConcatDataset, Subset
12 | from torchvision.transforms import transforms
13 | from sklearn.model_selection import StratifiedShuffleSplit
14 | from theconf import Config as C
15 |
16 | from DeepAA_evaluate.augmentations import *
17 | from DeepAA_evaluate.common import get_logger, copy_and_replace_transform, stratified_split, denormalize
18 | from DeepAA_evaluate.imagenet import ImageNet
19 |
20 | from DeepAA_evaluate.augmentations import Lighting
21 |
22 | from DeepAA_evaluate.deep_autoaugment import Augmentation_DeepAA
23 |
24 | logger = get_logger('DeepAA_evaluate')
25 | logger.setLevel(logging.INFO)
26 | _IMAGENET_PCA = {
27 | 'eigval': [0.2175, 0.0188, 0.0045],
28 | 'eigvec': [
29 | [-0.5675, 0.7192, 0.4009],
30 | [-0.5808, -0.0045, -0.8140],
31 | [-0.5836, -0.6948, 0.4203],
32 | ]
33 | }
34 | _CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) # these are for CIFAR-10, not actually for CIFAR-100. They are pretty similar, though.
35 | # mean for CIFAR-100: tensor([0.5071, 0.4866, 0.4409])
36 |
37 | def expand(num_classes, dtype, tensor):
38 | e = torch.zeros(
39 | tensor.size(0), num_classes, dtype=dtype, device=torch.device("cuda")
40 | )
41 | e = e.scatter(1, tensor.unsqueeze(1), 1.0)
42 | return e
43 |
44 | def mixup_data(data, label, alpha):
45 | with torch.no_grad():
46 | if alpha > 0:
47 | lam = np.random.beta(alpha, alpha)
48 | else:
49 | lam = 1.0
50 | batch_size = data.size()[0]
51 | index = torch.randperm(batch_size).to(data.device)
52 | mixed_data = lam * data + (1.0-lam) * data[index,:]
53 | return mixed_data, label, label[index], lam
54 |
55 |
56 | class PrefetchedWrapper(object):
57 | # Ref: https://github.com/NVIDIA/DeepLearningExamples/blob/d788e8d4968e72c722c5148a50a7d4692f6e7bd3/PyTorch/Classification/ConvNets/image_classification/dataloaders.py#L405
58 | def prefetched_loader(loader, num_classes, one_hot):
59 | mean = (
60 | torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255])
61 | .cuda()
62 | .view(1, 3, 1, 1)
63 | )
64 | std = (
65 | torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255])
66 | .cuda()
67 | .view(1, 3, 1, 1)
68 | )
69 |
70 | stream = torch.cuda.Stream()
71 | first = True
72 |
73 | for next_input, next_target in loader:
74 | with torch.cuda.stream(stream):
75 | next_input = next_input.cuda(non_blocking=True)
76 | next_target = next_target.cuda(non_blocking=True)
77 | next_input = next_input.float()
78 | if one_hot:
79 | raise Exception('Currently do not use one-hot encoding, because num_classes == None')
80 | next_target = expand(num_classes, torch.float, next_target)
81 |
82 | next_input = next_input.sub_(mean).div_(std)
83 |
84 | if not first:
85 | yield input, target
86 | else:
87 | first = False
88 |
89 | torch.cuda.current_stream().wait_stream(stream)
90 | input = next_input
91 | target = next_target
92 |
93 | yield input, target
94 |
95 | def __init__(self, dataloader, start_epoch, num_classes, one_hot):
96 | self.dataloader = dataloader
97 | self.epoch = start_epoch
98 | self.one_hot = one_hot
99 | self.num_classes = num_classes
100 |
101 | def __iter__(self):
102 | if self.dataloader.sampler is not None and isinstance(
103 | self.dataloader.sampler, torch.utils.data.distributed.DistributedSampler
104 | ):
105 |
106 | self.dataloader.sampler.set_epoch(self.epoch)
107 | self.epoch += 1
108 | return PrefetchedWrapper.prefetched_loader(
109 | self.dataloader, self.num_classes, self.one_hot
110 | )
111 |
112 | def __len__(self):
113 | return len(self.dataloader)
114 |
115 | def get_dataloaders(dataset, batch, dataroot, split=0.15, split_idx=0, distributed=False, started_with_spawn=False, summary_writer=None):
116 | print(f'started with spawn {started_with_spawn}')
117 | dataset_info = {}
118 | pre_transform_train = transforms.Compose([])
119 | if 'cifar' in dataset and (C.get()['aug'] in ['DeepAA']):
120 | transform_train = transforms.Compose([
121 | # transforms.RandomCrop(32, padding=4),
122 | # transforms.RandomHorizontalFlip(),
123 | transforms.ToTensor(),
124 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
125 | ])
126 | transform_test = transforms.Compose([
127 | transforms.ToTensor(),
128 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
129 | ])
130 | dataset_info['mean'] = _CIFAR_MEAN
131 | dataset_info['std'] = _CIFAR_STD
132 | dataset_info['img_dims'] = (3,32,32)
133 | dataset_info['num_labels'] = 100 if '100' in dataset and 'ten' not in dataset else 10
134 | elif 'cifar' in dataset:
135 | transform_train = transforms.Compose([
136 | transforms.RandomCrop(32, padding=4),
137 | transforms.RandomHorizontalFlip(),
138 | transforms.ToTensor(),
139 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
140 | ])
141 | transform_test = transforms.Compose([
142 | transforms.ToTensor(),
143 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
144 | ])
145 | dataset_info['mean'] = _CIFAR_MEAN
146 | dataset_info['std'] = _CIFAR_STD
147 | dataset_info['img_dims'] = (3,32,32)
148 | dataset_info['num_labels'] = 100 if '100' in dataset and 'ten' not in dataset else 10
149 | elif 'pre_transform_cifar' in dataset:
150 | pre_transform_train = transforms.Compose([
151 | transforms.RandomCrop(32, padding=4),
152 | transforms.RandomHorizontalFlip(),])
153 | transform_train = transforms.Compose([
154 | transforms.ToTensor(),
155 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
156 | ])
157 | transform_test = transforms.Compose([
158 | transforms.ToTensor(),
159 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD),
160 | ])
161 | dataset_info['mean'] = _CIFAR_MEAN
162 | dataset_info['std'] = _CIFAR_STD
163 | dataset_info['img_dims'] = (3, 32, 32)
164 | dataset_info['num_labels'] = 100 if '100' in dataset and 'ten' not in dataset else 10
165 | elif 'svhn' in dataset:
166 | svhn_mean = [0.4379, 0.4440, 0.4729]
167 | svhn_std = [0.1980, 0.2010, 0.1970]
168 | transform_train = transforms.Compose([
169 | transforms.ToTensor(),
170 | transforms.Normalize(svhn_mean, svhn_std),
171 | ])
172 | transform_test = transforms.Compose([
173 | transforms.ToTensor(),
174 | transforms.Normalize(svhn_mean, svhn_std),
175 | ])
176 | dataset_info['mean'] = svhn_mean
177 | dataset_info['std'] = svhn_std
178 | dataset_info['img_dims'] = (3, 32, 32)
179 | dataset_info['num_labels'] = 10
180 | elif 'imagenet' in dataset and C.get()['aug'] in ['DeepAA']:
181 | transform_train = transforms.Compose([
182 | transforms.ToTensor(),
183 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Image size (224, 224) instead of (224, 244) in TA
184 | ])
185 |
186 | transform_test = transforms.Compose([
187 | transforms.Resize(256, interpolation=Image.BICUBIC),
188 | transforms.CenterCrop((224,224)),
189 | transforms.ToTensor(),
190 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
191 | ])
192 | dataset_info['mean'] = [0.485, 0.456, 0.406]
193 | dataset_info['std'] = [0.229, 0.224, 0.225]
194 | dataset_info['img_dims'] = (3,224,224)
195 | dataset_info['num_labels'] = 1000
196 | elif 'imagenet' in dataset and C.get()['aug']=='inception':
197 | transform_train = transforms.Compose([
198 | transforms.RandomResizedCrop((224,224), scale=(0.08, 1.0), interpolation=Image.BICUBIC), # Image size (224, 224) instead of (224, 244) in TA
199 | transforms.RandomHorizontalFlip(),
200 | transforms.ColorJitter(
201 | brightness=0.4,
202 | contrast=0.4,
203 | saturation=0.4,
204 | ),
205 | transforms.ToTensor(),
206 | Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
207 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
208 | ])
209 |
210 | transform_test = transforms.Compose([
211 | transforms.Resize(256, interpolation=Image.BICUBIC),
212 | transforms.CenterCrop((224,224)),
213 | transforms.ToTensor(),
214 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
215 | ])
216 | dataset_info['mean'] = [0.485, 0.456, 0.406]
217 | dataset_info['std'] = [0.229, 0.224, 0.225]
218 | dataset_info['img_dims'] = (3,224,224)
219 | dataset_info['num_labels'] = 1000
220 | elif 'smallwidth_imagenet' in dataset:
221 | transform_train = transforms.Compose([
222 | transforms.RandomResizedCrop((224,224), scale=(0.08, 1.0), interpolation=Image.BICUBIC),
223 | transforms.RandomHorizontalFlip(),
224 | transforms.ColorJitter(
225 | brightness=0.4,
226 | contrast=0.4,
227 | saturation=0.4,
228 | ),
229 | transforms.ToTensor(),
230 | Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
231 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
232 | ])
233 |
234 | transform_test = transforms.Compose([
235 | transforms.Resize(256, interpolation=Image.BICUBIC),
236 | transforms.CenterCrop((224,224)),
237 | transforms.ToTensor(),
238 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
239 | ])
240 | dataset_info['mean'] = [0.485, 0.456, 0.406]
241 | dataset_info['std'] = [0.229, 0.224, 0.225]
242 | dataset_info['img_dims'] = (3,224,224)
243 | dataset_info['num_labels'] = 1000
244 | elif 'ohl_pipeline_imagenet' in dataset:
245 | pre_transform_train = transforms.Compose([
246 | transforms.RandomResizedCrop((224, 224), scale=(0.08, 1.0), interpolation=Image.BICUBIC),
247 | transforms.RandomHorizontalFlip(),
248 | ])
249 | transform_train = transforms.Compose([
250 | transforms.ToTensor(),
251 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[1.,1.,1.])
252 | ])
253 |
254 | transform_test = transforms.Compose([
255 | transforms.Resize(256, interpolation=Image.BICUBIC),
256 | transforms.CenterCrop((224,224)),
257 | transforms.ToTensor(),
258 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[1.,1.,1.])
259 | ])
260 | dataset_info['mean'] = [0.485, 0.456, 0.406]
261 | dataset_info['std'] = [1.,1.,1.]
262 | dataset_info['img_dims'] = (3,224,224)
263 | dataset_info['num_labels'] = 1000
264 | elif 'largewidth_imagenet' in dataset:
265 | transform_train = transforms.Compose([
266 | transforms.RandomResizedCrop((224, 244), scale=(0.08, 1.0), interpolation=Image.BICUBIC),
267 | transforms.RandomHorizontalFlip(),
268 | transforms.ColorJitter(
269 | brightness=0.4,
270 | contrast=0.4,
271 | saturation=0.4,
272 | ),
273 | transforms.ToTensor(),
274 | Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']),
275 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
276 | ])
277 |
278 | transform_test = transforms.Compose([
279 | transforms.Resize(256, interpolation=Image.BICUBIC),
280 | transforms.CenterCrop((224, 244)),
281 | transforms.ToTensor(),
282 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
283 | ])
284 | dataset_info['mean'] = [0.485, 0.456, 0.406]
285 | dataset_info['std'] = [0.229, 0.224, 0.225]
286 | dataset_info['img_dims'] = (3, 224, 244)
287 | dataset_info['num_labels'] = 1000
288 | else:
289 | raise ValueError('dataset=%s' % dataset)
290 |
291 | logger.debug('augmentation: %s' % C.get()['aug'])
292 | if C.get()['aug'] == 'randaugment':
293 | assert not C.get()['randaug'].get('corrected_sample_space') and not C.get()['randaug'].get('google_augmentations')
294 | transform_train.transforms.insert(0, get_randaugment(n=C.get()['randaug']['N'], m=C.get()['randaug']['M'],
295 | weights=C.get()['randaug'].get('weights',None), bs=C.get()['batch']))
296 | elif C.get()['aug'] in ['default', 'inception', 'inception320']:
297 | pass
298 | elif C.get()['aug'] in ['DeepAA']:
299 | transform_train.transforms.insert(0, Augmentation_DeepAA(EXP = C.get()['deepaa']['EXP'],
300 | use_crop = ('imagenet' in dataset) and C.get()['aug'] == 'DeepAA'
301 | ))
302 | else:
303 | raise ValueError('not found augmentations. %s' % C.get()['aug'])
304 |
305 | transform_train.transforms.insert(0, pre_transform_train)
306 |
307 | if C.get()['cutout'] > 0:
308 | transform_train.transforms.append(CutoutDefault(C.get()['cutout']))
309 |
310 | if 'preprocessor' in C.get():
311 | if 'imagenet' in dataset:
312 | print("Only using cropping/centering transforms on dataset, since preprocessor active.")
313 | transform_train = transforms.Compose([
314 | transforms.RandomResizedCrop(224, scale=(0.08, 1.0), interpolation=Image.BICUBIC),
315 | PILImageToHWCByteTensor(),
316 | ])
317 |
318 | transform_test = transforms.Compose([
319 | transforms.Resize(256, interpolation=Image.BICUBIC),
320 | transforms.CenterCrop(224),
321 | PILImageToHWCByteTensor(),
322 | ])
323 | else:
324 | print("Not using any transforms in dataset, since preprocessor is active.")
325 | transform_train = PILImageToHWCByteTensor()
326 | transform_test = PILImageToHWCByteTensor()
327 |
328 | if dataset in ('cifar10', 'pre_transform_cifar10'):
329 | total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
330 | testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
331 | elif dataset in ('cifar100', 'pre_transform_cifar100'):
332 | total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=True, transform=transform_train)
333 | testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=True, transform=transform_test)
334 | elif dataset == 'svhncore':
335 | total_trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True,
336 | transform=transform_train)
337 | testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test)
338 | elif dataset == 'svhn':
339 | trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train)
340 | extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=True, transform=transform_train)
341 | total_trainset = ConcatDataset([trainset, extraset])
342 | testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test)
343 | elif dataset in ('imagenet', 'ohl_pipeline_imagenet', 'smallwidth_imagenet'):
344 | # ignore_archive means: do not try to extract the archive files again, because they are
345 | # already extracted and the zip files are no longer there
346 | total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train, ignore_archive=True)
347 | testset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform_test, ignore_archive=True)
348 |
349 | # compatibility
350 | total_trainset.targets = [lb for _, lb in total_trainset.samples]
351 | else:
352 | raise ValueError('invalid dataset name=%s' % dataset)
353 |
354 | if 'throwaway_share_of_ds' in C.get():
355 | assert 'val_step_trainloader_val_share' not in C.get()
356 | share = C.get()['throwaway_share_of_ds']['throwaway_share']
357 | train_subset_inds, rest_inds = stratified_split(total_trainset.targets if hasattr(total_trainset, 'targets') else list(total_trainset.labels),share)
358 | if C.get()['throwaway_share_of_ds']['use_throwaway_as_val']:
359 | testset = copy_and_replace_transform(Subset(total_trainset, rest_inds), transform_test)
360 | total_trainset = Subset(total_trainset, train_subset_inds)
361 |
362 | train_sampler = None
363 | if split > 0.0:
364 | sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=0)
365 | sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
366 | for _ in range(split_idx + 1):
367 | train_idx, valid_idx = next(sss)
368 |
369 | train_sampler = SubsetRandomSampler(train_idx)
370 | valid_sampler = SubsetSampler(valid_idx)
371 | else:
372 | valid_sampler = SubsetSampler([])
373 |
374 | if distributed:
375 | assert split == 0.0, "Split not supported for distributed training."
376 | if C.get().get('all_workers_use_the_same_batches', False):
377 | train_sampler = DistributedSampler(total_trainset, num_replicas=1, rank=0)
378 | else:
379 | train_sampler = DistributedSampler(total_trainset)
380 | test_sampler = None
381 | test_train_sampler = None # if these are specified, the acc/loss computation for the results is wrong.
382 | # Note that this setting leads to the test set being evaluated separately on each GPU, which
383 | # might be considered not very climate-friendly
384 | else:
385 | test_sampler = None
386 | test_train_sampler = None
387 |
388 | trainloader = torch.utils.data.DataLoader(
389 | total_trainset, batch_size=batch, shuffle=train_sampler is None, num_workers=os.cpu_count()//8 if distributed else 32, # fix the data loader
390 | pin_memory=True,
391 | sampler=train_sampler, drop_last=True, persistent_workers=True)
392 | validloader = torch.utils.data.DataLoader(
393 | total_trainset, batch_size=batch, shuffle=False, num_workers=0 if started_with_spawn else 8, pin_memory=True,
394 | sampler=valid_sampler, drop_last=False)
395 |
396 | testloader = torch.utils.data.DataLoader(
397 | testset, batch_size=batch, shuffle=False, num_workers=16 if started_with_spawn else 8, pin_memory=True,
398 | drop_last=False, sampler=test_sampler, persistent_workers=True
399 | )
400 | # We use this 'hacky' solution s.t. we do not need to keep the dataset twice in memory.
401 | test_total_trainset = copy_and_replace_transform(total_trainset, transform_test)
402 | test_trainloader = torch.utils.data.DataLoader(
403 | test_total_trainset, batch_size=batch, shuffle=False, num_workers=0 if started_with_spawn else 8, pin_memory=True,
404 | drop_last=False, sampler=test_train_sampler
405 | )
406 | test_trainloader.denorm = lambda x: denormalize(x, dataset_info['mean'], dataset_info['std'])
407 |
408 | return train_sampler, trainloader, validloader, testloader, test_trainloader, dataset_info
409 | # trainloader_prefetch = PrefetchedWrapper(trainloader, start_epoch=0, num_classes=None, one_hot=False)
410 | # testloader_prefetch = PrefetchedWrapper(testloader, start_epoch=0, num_classes=None, one_hot=False)
411 | # return train_sampler, trainloader_prefetch, validloader, testloader_prefetch, test_trainloader, dataset_info
412 |
413 |
414 | class SubsetSampler(Sampler):
415 | r"""Samples elements from a given list of indices, without replacement.
416 |
417 | Arguments:
418 | indices (sequence): a sequence of indices
419 | """
420 |
421 | def __init__(self, indices):
422 | self.indices = indices
423 |
424 | def __iter__(self):
425 | return (i for i in self.indices)
426 |
427 | def __len__(self):
428 | return len(self.indices)
--------------------------------------------------------------------------------
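The train/validation split mechanism used by get_dataloaders above, shown standalone with placeholder labels (split and split_idx mirror the function arguments; valid_idx would be wrapped in the SubsetSampler class defined above):

import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import SubsetRandomSampler

labels = np.random.randint(0, 10, size=1000)    # stand-in for total_trainset.targets
split, split_idx = 0.15, 0
sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=0)
splits = sss.split(list(range(len(labels))), labels)
for _ in range(split_idx + 1):                  # advance to the requested fold
    train_idx, valid_idx = next(splits)
train_sampler = SubsetRandomSampler(train_idx)  # shuffled training subset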
/DeepAA_evaluate/deep_autoaugment.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from rpmcruz/autoaugment
2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
3 | import random
4 | import math
5 |
6 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
7 | import numpy as np
8 | import torch
9 | import os
10 | import json
11 | import hashlib
12 | import requests
13 | import scipy
14 | from torchvision.transforms.transforms import Compose
15 |
16 | random_mirror = True
17 |
18 | ##########################################################################
19 | CIFAR_MEANS = np.array([0.49139968, 0.48215841, 0.44653091], dtype=np.float32)
20 | # CIFAR10_STDS = np.array([0.24703223, 0.24348513, 0.26158784], dtype=np.float32)
21 | CIFAR_STDS = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)
22 |
23 | SVHN_MEANS = np.array([0.4379, 0.4440, 0.4729], dtype=np.float32)
24 | SVHN_STDS = np.array([0.1980, 0.2010, 0.1970], dtype=np.float32)
25 |
26 | IMAGENET_MEANS = np.array([0.485, 0.456, 0.406], dtype=np.float32)
27 | IMAGENET_STDS = np.array([0.229, 0.224, 0.225], dtype=np.float32)
28 |
29 | def ShearX(img, v): # [-0.3, 0.3]
30 | assert -0.3 <= v <= 0.3
31 | if random_mirror and random.random() > 0.5:
32 | v = -v
33 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
34 |
35 |
36 | def ShearY(img, v): # [-0.3, 0.3]
37 | assert -0.3 <= v <= 0.3
38 | if random_mirror and random.random() > 0.5:
39 | v = -v
40 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
41 |
42 |
43 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
44 | assert -0.45 <= v <= 0.45
45 | if random_mirror and random.random() > 0.5:
46 | v = -v
47 | v = v * img.size[0]
48 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
49 |
50 |
51 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
52 | assert -0.45 <= v <= 0.45
53 | if random_mirror and random.random() > 0.5:
54 | v = -v
55 | v = v * img.size[1]
56 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
57 |
58 |
59 | def TranslateXAbs(img, v): # absolute translation in pixels, v in [0, 10]
60 | assert 0 <= v <= 10
61 | if random_mirror and random.random() > 0.5:
62 | v = -v
63 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
64 |
65 |
66 | def TranslateYAbs(img, v): # absolute translation in pixels, v in [0, 10]
67 | assert 0 <= v <= 10
68 | if random_mirror and random.random() > 0.5:
69 | v = -v
70 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
71 |
72 |
73 | def Rotate(img, v): # [-30, 30]
74 | assert -30 <= v <= 30
75 | if random_mirror and random.random() > 0.5:
76 | v = -v
77 | return img.rotate(v)
78 |
79 |
80 | def AutoContrast(img, _):
81 | return PIL.ImageOps.autocontrast(img)
82 |
83 |
84 | def Invert(img, _):
85 | return PIL.ImageOps.invert(img)
86 |
87 |
88 | def Equalize(img, _):
89 | return PIL.ImageOps.equalize(img)
90 |
91 |
92 | def Flip(img, _): # not from the paper
93 | return PIL.ImageOps.mirror(img)
94 |
95 |
96 | def Solarize(img, v): # [0, 256]
97 | assert 0 <= v <= 256
98 | return PIL.ImageOps.solarize(img, v)
99 |
100 |
101 | def Posterize(img, v): # [4, 8]
102 | assert 4 <= v <= 8
103 | v = int(v)
104 | v = max(1, v)
105 | return PIL.ImageOps.posterize(img, v)
106 |
107 |
108 | def Posterize2(img, v): # [0, 4]
109 | assert 0 <= v <= 4
110 | v = int(v)
111 | return PIL.ImageOps.posterize(img, v)
112 |
113 |
114 | def Contrast(img, v): # [0.1,1.9]
115 | assert 0.1 <= v <= 1.9
116 | return PIL.ImageEnhance.Contrast(img).enhance(v)
117 |
118 |
119 | def Color(img, v): # [0.1,1.9]
120 | assert 0.1 <= v <= 1.9
121 | return PIL.ImageEnhance.Color(img).enhance(v)
122 |
123 |
124 | def Brightness(img, v): # [0.1,1.9]
125 | assert 0.1 <= v <= 1.9
126 | return PIL.ImageEnhance.Brightness(img).enhance(v)
127 |
128 |
129 | def Sharpness(img, v): # [0.1,1.9]
130 | assert 0.1 <= v <= 1.9
131 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
132 |
133 |
134 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
135 | assert 0.0 <= v <= 0.2
136 | if v <= 0.:
137 | return img
138 |
139 | v = v * img.size[0]
140 | return Cutout_default(img, v)
141 |
142 |
143 | def CutoutAbs(img, v): # v: cutout size in absolute pixels
144 | # assert 0 <= v <= 20
145 | if v < 0:
146 | return img
147 | w, h = img.size
148 | # x0 = np.random.uniform(w)
149 | # y0 = np.random.uniform(h)
150 | x0 = random.uniform(0, w)
151 | y0 = random.uniform(0, h)
152 |
153 | x0 = int(max(0, x0 - v / 2.))
154 | y0 = int(max(0, y0 - v / 2.))
155 | x1 = min(w, x0 + v)
156 | y1 = min(h, y0 + v)
157 |
158 | xy = (x0, y0, x1, y1)
159 | # color = (125, 123, 114)
160 | color = (0, 0, 0)
161 | img = img.copy()
162 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
163 | return img
164 |
165 |
166 | def SamplePairing(imgs): # [0, 0.4]
167 | def f(img1, v):
168 | i = np.random.choice(len(imgs))
169 | img2 = PIL.Image.fromarray(imgs[i])
170 | return PIL.Image.blend(img1, img2, v)
171 |
172 | return f
173 |
174 | # =============== OPS for DeepAA ==============:
175 | def mean_pad_randcrop(img, v):
176 | # v: Pad with mean value=[125, 123, 114] by v pixels on each side and then take random crop
177 | assert v <= 10, 'The maximum shift should be at most 10'
178 | padded_size = (img.size[0] + 2*v, img.size[1] + 2*v)
179 | new_img = PIL.Image.new('RGB', padded_size, color=(125, 123, 114))
180 | # new_img = PIL.Image.new('RGB', padded_size, color=(0, 0, 0))
181 | new_img.paste(img, (v, v))
182 | top = random.randint(0, v*2)
183 | left = random.randint(0, v*2)
184 | new_img = new_img.crop((left, top, left + img.size[0], top + img.size[1]))
185 | return new_img
186 |
187 |
188 |
189 | def Cutout_default(img, v): # Used in FastAA; unlike CutoutAbs, the actual cutout size can be smaller than v at the image boundary
190 | # Passed random number generation test
191 | # assert 0 <= v <= 20
192 | if v < 0:
193 | return img
194 | w, h = img.size
195 | # x = np.random.uniform(w)
196 | # y = np.random.uniform(h)
197 | if v <= 16: # for cutout of cifar and SVHN
198 | assert w == h == 32
199 | x = random.uniform(0, w)
200 | y = random.uniform(0, h)
201 |
202 | x0 = int(min(w, max(0, x - v // 2))) # clip to the range (0, w)
203 | x1 = int(min(w, max(0, x + v // 2)))
204 | y0 = int(min(h, max(0, y - v // 2)))
205 | y1 = int(min(h, max(0, y + v // 2)))
206 |
207 | xy = (x0, y0, x1, y1)
208 | color = (125, 123, 114)
209 | # color = (0, 0, 0)
210 | img = img.copy()
211 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
212 | # img = CutoutAbs(img, v)
213 | return img
214 | else:
215 | raise NotImplementedError
216 |
217 | def RandCrop(img, _):
218 | v = 4
219 | return mean_pad_randcrop(img, v)
220 |
221 | def RandCutout(img, _):
222 | v = 16 # Cutout 0.5 means 0.5*32=16 pixels as in the FastAA paper
223 | return Cutout_default(img, v)
224 |
225 | def RandCutout60(img, _):
226 | v = 60  # 60-pixel cutout (registered for the DeepAA ImageNet policy), unlike the 16-pixel CIFAR cutout above
227 | return Cutout_default(img, v)
228 |
229 | def RandFlip(img, _):
230 | if random.random() > 0.5:
231 | img = Flip(img, None)
232 | return img
233 |
234 | def Identity(img, _):
235 | return img
236 |
237 | # ===================== ops for imagenet =============
238 | def RandResizeCrop_imagenet(img, _):
239 | # ported from torchvision
240 | # for ImageNet use only
241 | scale = (0.08, 1.0)
242 | ratio = (3. / 4., 4. / 3.)
243 | size = IMAGENET_SIZE # (224, 224)
244 |
245 | def get_params(img, scale, ratio):
246 | width, height = img.size
247 | area = float(width * height)
248 | log_ratio = [math.log(r) for r in ratio]
249 |
250 | for _ in range(10):
251 | target_area = area * random.uniform(scale[0], scale[1])
252 | aspect_ratio = math.exp(random.uniform(log_ratio[0], log_ratio[1]))
253 |
254 | w = round(math.sqrt(target_area * aspect_ratio))
255 | h = round(math.sqrt(target_area / aspect_ratio))
256 | if 0 < w <= width and 0 < h <= height:
257 | top = random.randint(0, height - h)
258 | left = random.randint(0, width - w)
259 | return left, top, w, h
260 |
261 | # fallback to central crop
262 | in_ratio = float(width) / float(height)
263 | if in_ratio < min(ratio):
264 | w = width
265 | h = round(w / min(ratio))
266 | elif in_ratio > max(ratio):
267 | h = height
268 | w = round(h * max(ratio))
269 | else:
270 | w = width
271 | h = height
272 | top = (height - h) // 2
273 | left = (width - w) // 2
274 | return left, top, w, h
275 |
276 | left, top, w_box, h_box = get_params(img, scale, ratio)
277 | box = (left, top, left + w_box, top + h_box)
278 | img = img.resize(size=size, resample=PIL.Image.CUBIC, box=box)
279 | return img
280 |
281 |
282 | def Resize_imagenet(img, size):
283 | w, h = img.size
284 | if isinstance(size, int):
285 | short, long = (w, h) if w <= h else (h, w)
286 | if short == size:
287 | return img
288 | new_short, new_long = size, int(size * long / short)
289 | new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
290 | return img.resize((new_w, new_h), PIL.Image.BICUBIC)
291 | elif isinstance(size, tuple) or isinstance(size, list):
292 | assert len(size) == 2, 'Check the size {}'.format(size)
293 | return img.resize(size, PIL.Image.BICUBIC)
294 | else:
295 | raise Exception
296 |
297 |
298 | def centerCrop_imagenet(img, _):
299 | # for ImageNet only
300 | # https://github.com/pytorch/vision/blob/master/torchvision/transforms/functional.py
301 | crop_width, crop_height = IMAGENET_SIZE # (224,224)
302 | image_width, image_height = img.size
303 |
304 | if crop_width > image_width or crop_height > image_height:
305 | padding_ltrb = [
306 | (crop_width - image_width) // 2 if crop_width > image_width else 0,
307 | (crop_height - image_height) // 2 if crop_height > image_height else 0,
308 | (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
309 | (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
310 | ]
311 | img = pad(img, padding_ltrb, fill=0)
312 | image_width, image_height = img.size
313 | if crop_width == image_width and crop_height == image_height:
314 | return img
315 |
316 | crop_top = int(round((image_height - crop_height) / 2.))
317 | crop_left = int(round((image_width - crop_width) / 2.))
318 | return img.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height))
319 |
320 |
321 | def _parse_fill(fill, img, name="fillcolor"):
322 | # Process fill color for affine transforms
323 | num_bands = len(img.getbands())
324 | if fill is None:
325 | fill = 0
326 | if isinstance(fill, (int, float)) and num_bands > 1:
327 | fill = tuple([fill] * num_bands)
328 | if isinstance(fill, (list, tuple)):
329 | if len(fill) != num_bands:
330 | msg = ("The number of elements in 'fill' does not match the number of "
331 | "bands of the image ({} != {})")
332 | raise ValueError(msg.format(len(fill), num_bands))
333 |
334 | fill = tuple(fill)
335 |
336 | return {name: fill}
337 |
338 |
339 | def pad(img, padding_ltrb, fill=0, padding_mode='constant'):
340 | if isinstance(padding_ltrb, list):
341 | padding_ltrb = tuple(padding_ltrb)
342 | if padding_mode == 'constant':
343 | opts = _parse_fill(fill, img, name='fill')
344 | if img.mode == 'P':
345 | palette = img.getpalette()
346 | image = PIL.ImageOps.expand(img, border=padding_ltrb, **opts)
347 | image.putpalette(palette)
348 | return image
349 | return PIL.ImageOps.expand(img, border=padding_ltrb, **opts)
350 | elif len(padding_ltrb) == 4:
351 | image_width, image_height = img.size
352 | cropping = -np.minimum(padding_ltrb, 0)
353 | if cropping.any():
354 | crop_left, crop_top, crop_right, crop_bottom = cropping
355 | img = img.crop((crop_left, crop_top, image_width - crop_right, image_height - crop_bottom))
356 | pad_left, pad_top, pad_right, pad_bottom = np.maximum(padding_ltrb, 0)
357 |
358 | if img.mode == 'P':
359 | palette = img.getpalette()
360 | img = np.asarray(img)
361 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
362 | img = PIL.Image.fromarray(img)
363 | img.putpalette(palette)
364 | return img
365 |
366 | img = np.asarray(img)
367 | # RGB image
368 | if len(img.shape) == 3:
369 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
370 | # Grayscale image
371 | if len(img.shape) == 2:
372 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
373 |
374 | return PIL.Image.fromarray(img)
375 | else:
376 | raise Exception
377 |
378 | def augment_list(for_autoaug=True, for_DeepAA_cifar=True, for_DeepAA_imagenet=True): # operations and their ranges
379 | l = [
380 | (ShearX, -0.3, 0.3), # 0
381 | (ShearY, -0.3, 0.3), # 1
382 | (TranslateX, -0.45, 0.45), # 2
383 | (TranslateY, -0.45, 0.45), # 3
384 | (Rotate, -30, 30), # 4
385 | (AutoContrast, 0, 1), # 5
386 | (Invert, 0, 1), # 6
387 | (Equalize, 0, 1), # 7
388 | (Solarize, 0, 256), # 8
389 | (Posterize, 4, 8), # 9
390 | (Contrast, 0.1, 1.9), # 10
391 | (Color, 0.1, 1.9), # 11
392 | (Brightness, 0.1, 1.9), # 12
393 | (Sharpness, 0.1, 1.9), # 13
394 | (Cutout, 0, 0.2), # 14
395 | # (SamplePairing(imgs), 0, 0.4), # 15
396 | ]
397 | if for_autoaug:
398 | l += [
399 | (CutoutAbs, 0, 20), # compatible with auto-augment
400 | (Posterize2, 0, 4), # 9
401 | (TranslateXAbs, 0, 10), # 9
402 | (TranslateYAbs, 0, 10), # 9
403 | ]
404 | if for_DeepAA_cifar:
405 | l += [
406 | (Identity, 0., 1.0),
407 | (RandFlip, 0., 1.0), # Additional 15
408 | (RandCutout, 0., 1.0), # 16
409 | (RandCrop, 0., 1.0), # 17
410 | ]
411 | if for_DeepAA_imagenet:
412 | l += [
413 | (RandResizeCrop_imagenet, 0., 1.0),
414 | (RandCutout60, 0., 1.0)
415 | ]
416 |
417 | return l
418 |
419 |
420 | augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()}
421 |
422 | def Cutout16(img, _):
423 | # return CutoutAbs(img, 16)
424 | return Cutout_default(img, 16)
425 |
426 | augmentation_TA_list = [
427 | (Identity, 0., 1.0),
428 | (ShearX, -0.3, 0.3), # 0
429 | (ShearY, -0.3, 0.3), # 1
430 | (TranslateX, -0.45, 0.45), # 2
431 | (TranslateY, -0.45, 0.45), # 3
432 | (Rotate, -30, 30), # 4
433 | (AutoContrast, 0, 1), # 5
434 | # (Invert, 0, 1), # 6
435 | (Equalize, 0, 1), # 7
436 | (Solarize, 0, 256), # 8
437 | (Posterize, 4, 8), # 9
438 | (Contrast, 0.1, 1.9), # 10
439 | (Color, 0.1, 1.9), # 11
440 | (Brightness, 0.1, 1.9), # 12
441 | (Sharpness, 0.1, 1.9), # 13
442 | (Flip, 0., 1.0), # Additional 15
443 | (Cutout16, 0, 20), # (RandCutout, 0, 20), # compatible with auto-augment
444 | (RandCrop, 0., 1.0), # 17
445 | ]
446 |
447 |
448 | def get_augment(name):
449 | return augment_dict[name]
450 |
451 |
452 | def apply_augment(img, name, level):
453 | augment_fn, low, high = get_augment(name)
454 | return augment_fn(img.copy(), level * (high - low) + low)
455 |
456 |
457 | class Lighting(object):
458 | """Lighting noise(AlexNet - style PCA - based noise)"""
459 |
460 | def __init__(self, alphastd, eigval, eigvec):
461 | self.alphastd = alphastd
462 | self.eigval = torch.Tensor(eigval)
463 | self.eigvec = torch.Tensor(eigvec)
464 |
465 | def __call__(self, img):
466 | if self.alphastd == 0:
467 | return img
468 |
469 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
470 | rgb = self.eigvec.type_as(img).clone() \
471 | .mul(alpha.view(1, 3).expand(3, 3)) \
472 | .mul(self.eigval.view(1, 3).expand(3, 3)) \
473 | .sum(1).squeeze()
474 |
475 | return img.add(rgb.view(3, 1, 1).expand_as(img))
476 |
477 |
478 | class Augmentation_DeepAA(object):
479 | def __init__(self, EXP='cifar', use_crop=False):
480 | self.use_crop = use_crop
481 | policy_data = np.load('./policy_port/policy_DeepAA_{}.npz'.format(EXP))
482 | self.policy_probs = policy_data['policy_probs']
483 |
484 | self.l_ops = policy_data['l_ops']
485 | self.l_mags = policy_data['l_mags']
486 | self.ops = policy_data['ops']
487 | self.mags = policy_data['mags']
488 | self.op_names = policy_data['op_names']
489 |
490 | def __call__(self, img):
491 | for k_policy in self.policy_probs:
492 | k_samp = random.choices(range(len(k_policy)), weights=k_policy, k=1)[0]
493 | op, mag = np.squeeze(self.ops[k_samp]), np.squeeze(self.mags[k_samp]).astype(np.float32)/float(self.l_mags-1)
494 | op_name = self.op_names[op].split(':')[0]
495 | img = apply_augment(img, op_name, mag)
496 | if self.use_crop:
497 | w, h = img.size
498 | if w==IMAGENET_SIZE[0] and h==IMAGENET_SIZE[1]:
499 | return img
500 | # return centerCrop_imagenet(Resize_imagenet(img, 256), None)
501 | return centerCrop_imagenet(img, None)
502 | return img
503 |
504 |
505 | IMAGENET_SIZE = (224, 224)
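506 | 
507 | # --- Editor's illustrative sketch (not part of the original source) ---
508 | # apply_augment() rescales a normalized level in [0, 1] into each op's range,
509 | # e.g. apply_augment(img, 'Rotate', 0.5) maps to 0.5 * (30 - (-30)) + (-30) = 0 degrees.
510 | # Assuming a ported policy file such as ./policy_port/policy_DeepAA_cifar_1.npz is
511 | # available, the evaluation-time policy could be applied to a PIL image roughly as:
512 | #
513 | #   augment = Augmentation_DeepAA(EXP='cifar_1')
514 | #   augmented = augment(pil_image)  # samples one (op, magnitude) pair per policy stage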
--------------------------------------------------------------------------------
/DeepAA_evaluate/imagenet.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets.imagenet import *
2 |
3 | class ImageNet(ImageFolder):
4 | """`ImageNet `_ 2012 Classification Dataset.
5 | Copied from torchvision, besides warning below.
6 |
7 | Args:
8 | root (string): Root directory of the ImageNet Dataset.
9 | split (string, optional): The dataset split, supports ``train``, or ``val``.
10 | transform (callable, optional): A function/transform that takes in an PIL image
11 | and returns a transformed version. E.g, ``transforms.RandomCrop``
12 | target_transform (callable, optional): A function/transform that takes in the
13 | target and transforms it.
14 | loader (callable, optional): A function to load an image given its path.
15 |
16 | Attributes:
17 | classes (list): List of the class name tuples.
18 | class_to_idx (dict): Dict with items (class_name, class_index).
19 | wnids (list): List of the WordNet IDs.
20 | wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
21 | imgs (list): List of (image path, class_index) tuples
22 | targets (list): The class_index value for each image in the dataset
23 |
24 | WARN::
25 | This is the same ImageNet class as in torchvision.datasets.imagenet, but it has the `ignore_archive` argument.
26 | This allows us to only copy the unzipped files before training.
27 | """
28 |
29 | def __init__(self, root, split='train', download=None, ignore_archive=False, **kwargs):
30 | if download is True:
31 | msg = ("The dataset is no longer publicly accessible. You need to "
32 | "download the archives externally and place them in the root "
33 | "directory.")
34 | raise RuntimeError(msg)
35 | elif download is False:
36 | msg = ("The use of the download flag is deprecated, since the dataset "
37 | "is no longer publicly accessible.")
38 | warnings.warn(msg, RuntimeWarning)
39 |
40 | root = self.root = os.path.expanduser(root)
41 | self.split = verify_str_arg(split, "split", ("train", "val"))
42 |
43 | if not ignore_archive:
44 | self.parse_archives()
45 | wnid_to_classes = load_meta_file(self.root)[0]
46 |
47 | super(ImageNet, self).__init__(self.split_folder, **kwargs)
48 | self.root = root
49 |
50 | self.wnids = self.classes
51 | self.wnid_to_idx = self.class_to_idx
52 | self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
53 | self.class_to_idx = {cls: idx
54 | for idx, clss in enumerate(self.classes)
55 | for cls in clss}
56 |
57 | def parse_archives(self):
58 | if not check_integrity(os.path.join(self.root, META_FILE)):
59 | parse_devkit_archive(self.root)
60 |
61 | if not os.path.isdir(self.split_folder):
62 | if self.split == 'train':
63 | parse_train_archive(self.root)
64 | elif self.split == 'val':
65 | parse_val_archive(self.root)
66 |
67 | @property
68 | def split_folder(self):
69 | return os.path.join(self.root, self.split)
70 |
71 | def extra_repr(self):
72 | return "Split: {split}".format(**self.__dict__)
--------------------------------------------------------------------------------
/DeepAA_evaluate/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from theconf import Config as C
4 |
5 |
6 | def adjust_learning_rate_resnet(optimizer):
7 | """
8 | Sets the learning rate to the initial LR decayed by a factor of 10 at predefined epochs
9 | Ref: AutoAugment
10 | """
11 |
12 | if C.get()['epoch'] == 90:
13 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 80])
14 | elif C.get()['epoch'] == 180:
15 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 120, 160])
16 | elif C.get()['epoch'] == 270:
17 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, [90, 180, 240])
18 | else:
19 | raise ValueError('invalid epoch=%d for resnet scheduler' % C.get()['epoch'])
20 |
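21 | # Editor's note (illustrative, not part of the original source): MultiStepLR multiplies
22 | # the learning rate by its default gamma=0.1 at each milestone, so e.g. a 180-epoch run
23 | # decays the LR by 10x at epochs 60, 120 and 160.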
--------------------------------------------------------------------------------
/DeepAA_evaluate/metrics.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 | from collections import defaultdict
5 |
6 | from torch import nn
7 |
8 |
9 | def accuracy(output, target, topk=(1,)):
10 | """Computes the precision@k for the specified values of k"""
11 | maxk = max(topk)
12 | batch_size = target.size(0)
13 |
14 | _, pred = output.topk(maxk, 1, True, True)
15 | pred = pred.t()
16 | correct = pred.eq(target.view(1, -1).expand_as(pred))
17 |
18 | res = []
19 | for k in topk:
20 | correct_k = correct[:k].flatten().float().sum(0)
21 | res.append(correct_k.mul_(1. / batch_size))
22 | return res
23 |
24 |
25 | def cross_entropy_smooth(input, target, size_average=True, label_smoothing=0.1):
26 | y = torch.eye(10).cuda()
27 | lb_oh = y[target]
28 |
29 | target = lb_oh * (1 - label_smoothing) + 0.5 * label_smoothing
30 |
31 | logsoftmax = nn.LogSoftmax()
32 | if size_average:
33 | return torch.mean(torch.sum(-target * logsoftmax(input), dim=1))
34 | else:
35 | return torch.sum(torch.sum(-target * logsoftmax(input), dim=1))
36 |
37 |
38 | class Accumulator:
39 | def __init__(self):
40 | self.metrics = defaultdict(lambda: 0.)
41 |
42 | def add(self, key, value):
43 | self.metrics[key] += value
44 |
45 | def add_dict(self, dict):
46 | for key, value in dict.items():
47 | self.add(key, value)
48 |
49 | def __getitem__(self, item):
50 | return self.metrics[item]
51 |
52 | def __setitem__(self, key, value):
53 | self.metrics[key] = value
54 |
55 | def __contains__(self, item):
56 | return self.metrics.__contains__(item)
57 |
58 | def get_dict(self):
59 | return copy.deepcopy(dict(self.metrics))
60 |
61 | def items(self):
62 | return self.metrics.items()
63 |
64 | def __str__(self):
65 | return str(dict(self.metrics))
66 |
67 | def __truediv__(self, other):
68 | newone = Accumulator()
69 | for key, value in self.items():
70 | newone[key] = value / other
71 | return newone
72 |
73 | def divide(self, divisor, **special_divisors):
74 | newone = Accumulator()
75 | for key, value in self.items():
76 | if key in special_divisors:
77 | newone[key] = value/special_divisors[key]
78 | else:
79 | newone[key] = value/divisor
80 | return newone
81 |
82 |
83 | class SummaryWriterDummy:
84 | def __init__(self, log_dir):
85 | pass
86 |
87 | def add_scalar(self, *args, **kwargs):
88 | pass
89 |
90 | def add_image(self, *args, **kwargs):
91 | pass
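92 | 
93 | # Editor's note (illustrative, not part of the original source): accuracy() returns one
94 | # tensor per requested k, each holding the fraction of correct top-k predictions in the
95 | # batch, e.g. top1, top5 = accuracy(logits, targets, topk=(1, 5)).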
--------------------------------------------------------------------------------
/DeepAA_evaluate/networks/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from torch import nn
4 | from torch.nn import DataParallel
5 | import torch.backends.cudnn as cudnn
6 | # from torchvision import models
7 |
8 | from DeepAA_evaluate.networks.resnet import ResNet
9 | from DeepAA_evaluate.networks.shakeshake.shake_resnet import ShakeResNet
10 | from DeepAA_evaluate.networks.wideresnet import WideResNet
11 | from DeepAA_evaluate.networks.shakeshake.shake_resnext import ShakeResNeXt
12 | from DeepAA_evaluate.networks.convnet import SeqConvNet
13 | from DeepAA_evaluate.networks.mlp import MLP
14 | from DeepAA_evaluate.common import apply_weightnorm
15 |
16 |
17 |
18 | # example usage (illustrative): get_model({'type': 'wresnet28_10'}, bs=128, num_class=10)
19 | def get_model(conf, bs, num_class=10, writer=None):
20 | name = conf['type']
21 | ad_creators = (None,None)
22 |
23 |
24 | if name == 'resnet50':
25 | model = ResNet(dataset='imagenet', depth=50, num_classes=num_class, bottleneck=True)
26 | elif name == 'resnet200':
27 | model = ResNet(dataset='imagenet', depth=200, num_classes=num_class, bottleneck=True)
28 | elif name == 'resnet18':
29 | model = ResNet(dataset='imagenet', depth=18, num_classes=num_class, bottleneck=False)
30 | elif name == 'wresnet40_2':
31 | model = WideResNet(40, 2, dropout_rate=conf.get('dropout',0.0), num_classes=num_class, adaptive_dropouter_creator=ad_creators[0],adaptive_conv_dropouter_creator=ad_creators[1], groupnorm=conf.get('groupnorm', False), examplewise_bn=conf.get('examplewise_bn', False), virtual_bn=conf.get('virtual_bn', False))
32 | elif name == 'wresnet28_10':
33 | model = WideResNet(28, 10, dropout_rate=conf.get('dropout',0.0), num_classes=num_class, adaptive_dropouter_creator=ad_creators[0],adaptive_conv_dropouter_creator=ad_creators[1], groupnorm=conf.get('groupnorm',False), examplewise_bn=conf.get('examplewise_bn', False), virtual_bn=conf.get('virtual_bn', False))
34 | elif name == 'wresnet28_2':
35 | model = WideResNet(28, 2, dropout_rate=conf.get('dropout', 0.0), num_classes=num_class,
36 | adaptive_dropouter_creator=ad_creators[0], adaptive_conv_dropouter_creator=ad_creators[1],
37 | groupnorm=conf.get('groupnorm', False), examplewise_bn=conf.get('examplewise_bn', False),
38 | virtual_bn=conf.get('virtual_bn', False))
39 | elif name == 'miniconvnet':
40 | model = SeqConvNet(num_class,adaptive_dropout_creator=ad_creators[0],batch_norm=False)
41 | elif name == 'mlp':
42 | model = MLP(num_class, (3,32,32), adaptive_dropouter_creator=ad_creators[0])
43 | elif name == 'shakeshake26_2x96d':
44 | model = ShakeResNet(26, 96, num_class)
45 | elif name == 'shakeshake26_2x112d':
46 | model = ShakeResNet(26, 112, num_class)
47 | elif name == 'shakeshake26_2x96d_next':
48 | model = ShakeResNeXt(26, 96, 4, num_class)
49 | else:
50 | raise NameError('no model named %s' % name)
51 |
52 | if conf.get('weight_norm', False):
53 | print('Using weight norm.')
54 | apply_weightnorm(model)
55 |
56 | #model = model.cuda()
57 | #model = DataParallel(model)
58 | cudnn.benchmark = True
59 | return model
60 |
61 |
62 | def num_class(dataset):
63 | return {
64 | 'cifar10': 10,
65 | 'noised_cifar10': 10,
66 | 'targetnoised_cifar10': 10,
67 | 'reduced_cifar10': 10,
68 | 'cifar10.1': 10,
69 | 'pre_transform_cifar10': 10,
70 | 'cifar100': 100,
71 | 'pre_transform_cifar100': 100,
72 | 'fiftyexample_cifar100': 100,
73 | 'tenclass_cifar100': 10,
74 | 'svhn': 10,
75 | 'svhncore': 10,
76 | 'reduced_svhn': 10,
77 | 'imagenet': 1000,
78 | 'smallwidth_imagenet': 1000,
79 | 'ohl_pipeline_imagenet': 1000,
80 | 'reduced_imagenet': 120,
81 | }[dataset]
82 |
--------------------------------------------------------------------------------
/DeepAA_evaluate/networks/convnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class SeqConvNet(nn.Module):
5 | def __init__(self,D_out,fixed_dropout=None,in_channels=3,channels=(64,64),h_dims=(200,100),adaptive_dropout_creator=None,batch_norm=False):
6 | super().__init__()
7 | print("Using SeqConvNet")
8 | assert len(channels) == 2 == len(h_dims)
9 | pool = lambda: nn.MaxPool2d(2,2)
10 | dropout = lambda: torch.nn.Dropout(p=fixed_dropout)
11 | dropout_li = lambda: ([] if fixed_dropout is None else [dropout()])
12 | relu = lambda: torch.nn.ReLU(inplace=False)
13 | flatten = lambda l: [item for sublist in l for item in sublist]
14 | convs = [nn.Conv2d(in_channels, channels[0], 5),nn.Conv2d(channels[0], channels[1], 5)]
15 | fcs = [nn.Linear(channels[1] * 5 * 5, h_dims[0]),nn.Linear(h_dims[0], h_dims[1])]
16 | self.final_fc = nn.Linear(h_dims[1], D_out)
17 | self.conv_blocks = nn.Sequential(*flatten([[conv,relu(),pool()] + dropout_li() for conv in convs]))
18 | self.bn = nn.BatchNorm1d(h_dims[1], momentum=.9) if batch_norm else nn.Identity()
19 | self.fc_blocks = nn.Sequential(*flatten([[fc,relu()] + dropout_li() for fc in fcs]))
20 | self.adaptive_dropouters = [adaptive_dropout_creator(h_dims[1])] if adaptive_dropout_creator is not None else []
21 |
22 | def forward(self, x):
23 | x = self.conv_blocks(x)
24 | x = torch.nn.Flatten()(x)
25 | x = self.fc_blocks(x)
26 | if self.adaptive_dropouters:
27 | x = self.adaptive_dropouters[0](x)
28 | x = self.bn(x)
29 | x = self.final_fc(x)
30 | return x
31 |
32 |
--------------------------------------------------------------------------------
/DeepAA_evaluate/networks/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | def MLP(D_out,in_dims,adaptive_dropouter_creator):
6 | print('adaptive dropouter', adaptive_dropouter_creator)
7 | in_dim = 1
8 | for d in in_dims: in_dim *= d
9 | ada_dropper = adaptive_dropouter_creator(100) if adaptive_dropouter_creator is not None else None
10 | model = nn.Sequential(
11 | nn.Flatten(),
12 | nn.Linear(in_dim, 300),
13 | nn.Tanh(),
14 | nn.Linear(300,100),
15 | ada_dropper or nn.Identity(),
16 | nn.Tanh(),
17 | nn.Linear(100,D_out)
18 | )
19 | model.adaptive_dropouters = [ada_dropper] if ada_dropper is not None else []
20 | return model
21 |
--------------------------------------------------------------------------------
/DeepAA_evaluate/networks/resnet.py:
--------------------------------------------------------------------------------
1 | # Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
2 | # gamma is initialized to 0 in the last BN of each residual block
3 |
4 | import torch.nn as nn
5 | import math
6 |
7 |
8 | def conv3x3(in_planes, out_planes, stride=1):
9 | "3x3 convolution with padding"
10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
11 | padding=1, bias=False)
12 |
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, inplanes, planes, stride=1, downsample=None):
18 | super(BasicBlock, self).__init__()
19 | self.conv1 = conv3x3(inplanes, planes, stride)
20 | self.bn1 = nn.BatchNorm2d(planes)
21 | self.conv2 = conv3x3(planes, planes)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | nn.init.zeros_(self.bn2.weight)
24 | self.relu = nn.ReLU(inplace=True)
25 |
26 | self.downsample = downsample
27 | self.stride = stride
28 |
29 | def forward(self, x):
30 | residual = x
31 |
32 | out = self.conv1(x)
33 | out = self.bn1(out)
34 | out = self.relu(out)
35 |
36 | out = self.conv2(out)
37 | out = self.bn2(out)
38 |
39 | if self.downsample is not None:
40 | residual = self.downsample(x)
41 |
42 | out += residual
43 | out = self.relu(out)
44 |
45 | return out
46 |
47 |
48 | class Bottleneck(nn.Module):
49 | expansion = 4
50 |
51 | def __init__(self, inplanes, planes, stride=1, downsample=None):
52 | super(Bottleneck, self).__init__()
53 |
54 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
55 | self.bn1 = nn.BatchNorm2d(planes)
56 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
57 | self.bn2 = nn.BatchNorm2d(planes)
58 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False)
59 | self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion)
60 | nn.init.zeros_(self.bn3.weight)
61 | self.relu = nn.ReLU(inplace=True)
62 |
63 | self.downsample = downsample
64 | self.stride = stride
65 |
66 | def forward(self, x):
67 | residual = x
68 |
69 | out = self.conv1(x)
70 | out = self.bn1(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv2(out)
74 | out = self.bn2(out)
75 | out = self.relu(out)
76 |
77 | out = self.conv3(out)
78 | out = self.bn3(out)
79 | if self.downsample is not None:
80 | residual = self.downsample(x)
81 |
82 | out += residual
83 | out = self.relu(out)
84 |
85 | return out
86 |
87 | class ResNet(nn.Module):
88 | def __init__(self, dataset, depth, num_classes, bottleneck=False):
89 | super(ResNet, self).__init__()
90 | self.dataset = dataset
91 | if self.dataset.startswith('cifar'):
92 | self.inplanes = 16
93 | print(bottleneck)
94 | if bottleneck == True:
95 | n = int((depth - 2) / 9)
96 | block = Bottleneck
97 | else:
98 | n = int((depth - 2) / 6)
99 | block = BasicBlock
100 |
101 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
102 | self.bn1 = nn.BatchNorm2d(self.inplanes)
103 | self.relu = nn.ReLU(inplace=True)
104 | self.layer1 = self._make_layer(block, 16, n)
105 | self.layer2 = self._make_layer(block, 32, n, stride=2)
106 | self.layer3 = self._make_layer(block, 64, n, stride=2)
107 | # self.avgpool = nn.AvgPool2d(8)
108 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
109 | self.fc = nn.Linear(64 * block.expansion, num_classes)
110 |
111 | elif dataset == 'imagenet':
112 | blocks ={18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck}
113 | layers ={18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]}
114 | assert layers[depth], 'invalid depth for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200)'
115 |
116 | self.inplanes = 64
117 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
118 | self.bn1 = nn.BatchNorm2d(64)
119 | self.relu = nn.ReLU(inplace=True)
120 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
121 | self.layer1 = self._make_layer(blocks[depth], 64, layers[depth][0])
122 | self.layer2 = self._make_layer(blocks[depth], 128, layers[depth][1], stride=2)
123 | self.layer3 = self._make_layer(blocks[depth], 256, layers[depth][2], stride=2)
124 | self.layer4 = self._make_layer(blocks[depth], 512, layers[depth][3], stride=2)
125 | # self.avgpool = nn.AvgPool2d(7)
126 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
127 | self.fc = nn.Linear(512 * blocks[depth].expansion, num_classes)
128 |
129 | for m in self.modules():
130 | if isinstance(m, nn.Conv2d):
131 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
132 | m.weight.data.normal_(0, math.sqrt(2. / n))
133 | elif isinstance(m, nn.BatchNorm2d):
134 | m.weight.data.fill_(1)
135 | m.bias.data.zero_()
136 |
137 | def _make_layer(self, block, planes, blocks, stride=1):
138 | downsample = None
139 | if stride != 1 or self.inplanes != planes * block.expansion:
140 | downsample = nn.Sequential(
141 | nn.Conv2d(self.inplanes, planes * block.expansion,
142 | kernel_size=1, stride=stride, bias=False),
143 | nn.BatchNorm2d(planes * block.expansion),
144 | )
145 |
146 | layers = []
147 | layers.append(block(self.inplanes, planes, stride, downsample))
148 | self.inplanes = planes * block.expansion
149 | for i in range(1, blocks):
150 | layers.append(block(self.inplanes, planes))
151 |
152 | return nn.Sequential(*layers)
153 |
154 | def forward(self, x):
155 | if self.dataset == 'cifar10' or self.dataset == 'cifar100':
156 | x = self.conv1(x)
157 | x = self.bn1(x)
158 | x = self.relu(x)
159 |
160 | x = self.layer1(x)
161 | x = self.layer2(x)
162 | x = self.layer3(x)
163 |
164 | x = self.avgpool(x)
165 | x = x.view(x.size(0), -1)
166 | x = self.fc(x)
167 |
168 | elif self.dataset == 'imagenet':
169 | x = self.conv1(x)
170 | x = self.bn1(x)
171 | x = self.relu(x)
172 | x = self.maxpool(x)
173 |
174 | x = self.layer1(x)
175 | x = self.layer2(x)
176 | x = self.layer3(x)
177 | x = self.layer4(x)
178 |
179 | x = self.avgpool(x)
180 | x = x.view(x.size(0), -1)
181 | x = self.fc(x)
182 |
183 | return x
184 |
--------------------------------------------------------------------------------
/DeepAA_evaluate/networks/shakeshake/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/DeepAA_evaluate/networks/shakeshake/__init__.py
--------------------------------------------------------------------------------
/DeepAA_evaluate/networks/shakeshake/shake_resnet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import math
4 |
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from DeepAA_evaluate.networks.shakeshake.shakeshake import ShakeShake
9 | from DeepAA_evaluate.networks.shakeshake.shakeshake import Shortcut
10 |
11 |
12 | class ShakeBlock(nn.Module):
13 |
14 | def __init__(self, in_ch, out_ch, stride=1):
15 | super(ShakeBlock, self).__init__()
16 | self.equal_io = in_ch == out_ch
17 | if self.equal_io:
18 | self.shortcut = lambda x: x
19 | else:
20 | self.shortcut = Shortcut(in_ch, out_ch, stride=stride)
21 | #self.shortcut = self.equal_io and None or Shortcut(in_ch, out_ch, stride=stride)
22 |
23 | self.branch1 = self._make_branch(in_ch, out_ch, stride)
24 | self.branch2 = self._make_branch(in_ch, out_ch, stride)
25 |
26 | def forward(self, x):
27 | h1 = self.branch1(x)
28 | h2 = self.branch2(x)
29 | h = ShakeShake.apply(h1, h2, self.training)
30 | #h0 = x if self.equal_io else self.shortcut(x)
31 | h0 = self.shortcut(x)
32 | return h + h0
33 |
34 | def _make_branch(self, in_ch, out_ch, stride=1):
35 | return nn.Sequential(
36 | nn.ReLU(inplace=False),
37 | nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride, bias=False),
38 | nn.BatchNorm2d(out_ch),
39 | nn.ReLU(inplace=False),
40 | nn.Conv2d(out_ch, out_ch, 3, padding=1, stride=1, bias=False),
41 | nn.BatchNorm2d(out_ch))
42 |
43 |
44 | class ShakeResNet(nn.Module):
45 |
46 | def __init__(self, depth, w_base, label):
47 | super(ShakeResNet, self).__init__()
48 | n_units = (depth - 2) / 6
49 |
50 | in_chs = [16, w_base, w_base * 2, w_base * 4]
51 | self.in_chs = in_chs
52 |
53 | self.c_in = nn.Conv2d(3, in_chs[0], 3, padding=1)
54 | self.layer1 = self._make_layer(n_units, in_chs[0], in_chs[1])
55 | self.layer2 = self._make_layer(n_units, in_chs[1], in_chs[2], 2)
56 | self.layer3 = self._make_layer(n_units, in_chs[2], in_chs[3], 2)
57 | self.fc_out = nn.Linear(in_chs[3], label)
58 |
59 | # Initialize parameters
60 | for m in self.modules():
61 | if isinstance(m, nn.Conv2d):
62 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
63 | m.weight.data.normal_(0, math.sqrt(2. / n))
64 | elif isinstance(m, nn.BatchNorm2d):
65 | m.weight.data.fill_(1)
66 | m.bias.data.zero_()
67 | elif isinstance(m, nn.Linear):
68 | m.bias.data.zero_()
69 |
70 | def forward(self, x):
71 | h = self.c_in(x)
72 | h = self.layer1(h)
73 | h = self.layer2(h)
74 | h = self.layer3(h)
75 | h = F.relu(h)
76 | h = F.avg_pool2d(h, 8)
77 | h = h.view(-1, self.in_chs[3])
78 | h = self.fc_out(h)
79 | return h
80 |
81 | def _make_layer(self, n_units, in_ch, out_ch, stride=1):
82 | layers = []
83 | for i in range(int(n_units)):
84 | layers.append(ShakeBlock(in_ch, out_ch, stride=stride))
85 | in_ch, stride = out_ch, 1
86 | return nn.Sequential(*layers)
87 |
--------------------------------------------------------------------------------
/DeepAA_evaluate/networks/shakeshake/shake_resnext.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import math
4 |
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from DeepAA_evaluate.networks.shakeshake.shakeshake import ShakeShake
9 | from DeepAA_evaluate.networks.shakeshake.shakeshake import Shortcut
10 |
11 |
12 | class ShakeBottleNeck(nn.Module):
13 |
14 | def __init__(self, in_ch, mid_ch, out_ch, cardinary, stride=1):
15 | super(ShakeBottleNeck, self).__init__()
16 | self.equal_io = in_ch == out_ch
17 | self.shortcut = None if self.equal_io else Shortcut(in_ch, out_ch, stride=stride)
18 |
19 | self.branch1 = self._make_branch(in_ch, mid_ch, out_ch, cardinary, stride)
20 | self.branch2 = self._make_branch(in_ch, mid_ch, out_ch, cardinary, stride)
21 |
22 | def forward(self, x):
23 | h1 = self.branch1(x)
24 | h2 = self.branch2(x)
25 | h = ShakeShake.apply(h1, h2, self.training)
26 | h0 = x if self.equal_io else self.shortcut(x)
27 | return h + h0
28 |
29 | def _make_branch(self, in_ch, mid_ch, out_ch, cardinary, stride=1):
30 | return nn.Sequential(
31 | nn.Conv2d(in_ch, mid_ch, 1, padding=0, bias=False),
32 | nn.BatchNorm2d(mid_ch),
33 | nn.ReLU(inplace=False),
34 | nn.Conv2d(mid_ch, mid_ch, 3, padding=1, stride=stride, groups=cardinary, bias=False),
35 | nn.BatchNorm2d(mid_ch),
36 | nn.ReLU(inplace=False),
37 | nn.Conv2d(mid_ch, out_ch, 1, padding=0, bias=False),
38 | nn.BatchNorm2d(out_ch))
39 |
40 |
41 | class ShakeResNeXt(nn.Module):
42 |
43 | def __init__(self, depth, w_base, cardinary, label):
44 | super(ShakeResNeXt, self).__init__()
45 | n_units = (depth - 2) // 9
46 | n_chs = [64, 128, 256, 1024]
47 | self.n_chs = n_chs
48 | self.in_ch = n_chs[0]
49 |
50 | self.c_in = nn.Conv2d(3, n_chs[0], 3, padding=1)
51 | self.layer1 = self._make_layer(n_units, n_chs[0], w_base, cardinary)
52 | self.layer2 = self._make_layer(n_units, n_chs[1], w_base, cardinary, 2)
53 | self.layer3 = self._make_layer(n_units, n_chs[2], w_base, cardinary, 2)
54 | self.fc_out = nn.Linear(n_chs[3], label)
55 |
56 | # Initialize parameters
57 | for m in self.modules():
58 | if isinstance(m, nn.Conv2d):
59 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
60 | m.weight.data.normal_(0, math.sqrt(2. / n))
61 | elif isinstance(m, nn.BatchNorm2d):
62 | m.weight.data.fill_(1)
63 | m.bias.data.zero_()
64 | elif isinstance(m, nn.Linear):
65 | m.bias.data.zero_()
66 |
67 | def forward(self, x):
68 | h = self.c_in(x)
69 | h = self.layer1(h)
70 | h = self.layer2(h)
71 | h = self.layer3(h)
72 | h = F.relu(h)
73 | h = F.avg_pool2d(h, 8)
74 | h = h.view(-1, self.n_chs[3])
75 | h = self.fc_out(h)
76 | return h
77 |
78 | def _make_layer(self, n_units, n_ch, w_base, cardinary, stride=1):
79 | layers = []
80 | mid_ch, out_ch = n_ch * (w_base // 64) * cardinary, n_ch * 4
81 | for i in range(n_units):
82 | layers.append(ShakeBottleNeck(self.in_ch, mid_ch, out_ch, cardinary, stride=stride))
83 | self.in_ch, stride = out_ch, 1
84 | return nn.Sequential(*layers)
85 |
--------------------------------------------------------------------------------
/DeepAA_evaluate/networks/shakeshake/shakeshake.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
7 |
8 |
9 | class ShakeShake(torch.autograd.Function):
10 |
11 | @staticmethod
12 | def forward(ctx, x1, x2, training=True):
13 | if training:
14 | alpha = torch.cuda.FloatTensor(x1.size(0)).uniform_()
15 | alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x1)
16 | else:
17 | alpha = 0.5
18 | return alpha * x1 + (1 - alpha) * x2
19 |
20 | @staticmethod
21 | def backward(ctx, grad_output):
22 | beta = torch.cuda.FloatTensor(grad_output.size(0)).uniform_()
23 | beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output)
24 | beta = Variable(beta)
25 |
26 | return beta * grad_output, (1 - beta) * grad_output, None
27 |
28 |
29 | class Shortcut(nn.Module):
30 |
31 | def __init__(self, in_ch, out_ch, stride):
32 | super(Shortcut, self).__init__()
33 | self.stride = stride
34 | self.conv1 = nn.Conv2d(in_ch, out_ch // 2, 1, stride=1, padding=0, bias=False)
35 | self.conv2 = nn.Conv2d(in_ch, out_ch // 2, 1, stride=1, padding=0, bias=False)
36 | self.bn = nn.BatchNorm2d(out_ch)
37 |
38 | def forward(self, x):
39 | h = F.relu(x)
40 |
41 | h1 = F.avg_pool2d(h, 1, self.stride)
42 | h1 = self.conv1(h1)
43 |
44 | h2 = F.avg_pool2d(F.pad(h, (-1, 1, -1, 1)), 1, self.stride)
45 | h2 = self.conv2(h2)
46 |
47 | h = torch.cat((h1, h2), 1)
48 | return self.bn(h)
49 |
--------------------------------------------------------------------------------
/DeepAA_evaluate/networks/wideresnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 |
8 | _bn_momentum = 0.1
9 | CpG = 8
10 |
11 |
12 | class ExampleWiseBatchNorm2d(nn.BatchNorm2d):
13 | def __init__(self, num_features, eps=1e-5, momentum=0.1,
14 | affine=True, track_running_stats=True):
15 | super().__init__(num_features, eps, momentum, affine, track_running_stats)
16 |
17 | def forward(self, input):
18 | self._check_input_dim(input)
19 |
20 | exponential_average_factor = 0.0
21 |
22 | if self.training and self.track_running_stats:
23 | if self.num_batches_tracked is not None:
24 | self.num_batches_tracked += 1
25 | if self.momentum is None: # use cumulative moving average
26 | exponential_average_factor = 1.0 / float(self.num_batches_tracked)
27 | else: # use exponential moving average
28 | exponential_average_factor = self.momentum
29 |
30 | # calculate running estimates
31 | if self.training:
32 | mean = input.mean([0, 2, 3])
33 | # use biased var in train
34 | var = input.var([0, 2, 3], unbiased=False)
35 | n = input.numel() / input.size(1)
36 | with torch.no_grad():
37 | self.running_mean = exponential_average_factor * mean\
38 | + (1 - exponential_average_factor) * self.running_mean
39 | # update running_var with unbiased var
40 | self.running_var = exponential_average_factor * var * n / (n - 1)\
41 | + (1 - exponential_average_factor) * self.running_var
42 | local_means = input.mean([2, 3])
43 | local_global_means = local_means + (mean.unsqueeze(0) - local_means).detach()
44 | local_vars = input.var([2, 3], unbiased=False)
45 | local_global_vars = local_vars + (var.unsqueeze(0) - local_vars).detach()
46 | input = (input - local_global_means[:,:,None,None]) / (torch.sqrt(local_global_vars[:,:,None,None] + self.eps))
47 | else:
48 | mean = self.running_mean
49 | var = self.running_var
50 | input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
51 |
52 | if self.affine:
53 | input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]
54 |
55 | return input
56 |
57 |
58 | class VirtualBatchNorm2d(nn.BatchNorm2d):
59 | def __init__(self, num_features, eps=1e-5, momentum=0.1,
60 | affine=True, track_running_stats=True):
61 | super().__init__(num_features, eps, momentum, affine, track_running_stats)
62 |
63 | def forward(self, input):
64 | self._check_input_dim(input)
65 |
66 | exponential_average_factor = 0.0
67 |
68 | if self.training and self.track_running_stats:
69 | if self.num_batches_tracked is not None:
70 | self.num_batches_tracked += 1
71 | if self.momentum is None: # use cumulative moving average
72 | exponential_average_factor = 1.0 / float(self.num_batches_tracked)
73 | else: # use exponential moving average
74 | exponential_average_factor = self.momentum
75 |
76 | # calculate running estimates
77 | if self.training:
78 | mean = input.mean([0, 2, 3])
79 | # use biased var in train
80 | var = input.var([0, 2, 3], unbiased=False)
81 | n = input.numel() / input.size(1)
82 | with torch.no_grad():
83 | self.running_mean = exponential_average_factor * mean \
84 | + (1 - exponential_average_factor) * self.running_mean
85 | # update running_var with unbiased var
86 | self.running_var = exponential_average_factor * var * n / (n - 1) \
87 | + (1 - exponential_average_factor) * self.running_var
88 | input = (input - mean.detach()[None, :, None, None]) / (torch.sqrt(var.detach()[None, :, None, None] + self.eps))
89 | else:
90 | mean = self.running_mean
91 | var = self.running_var
92 | input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
93 |
94 | if self.affine:
95 | input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]
96 |
97 | return input
98 |
99 |
100 | def conv3x3(in_planes, out_planes, stride=1):
101 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True)
102 |
103 |
104 | def conv_init(m):
105 | classname = m.__class__.__name__
106 | if classname.find('Conv') != -1:
107 | init.xavier_uniform_(m.weight, gain=np.sqrt(2))
108 | init.constant_(m.bias, 0)
109 | elif classname.find('BatchNorm') != -1:
110 | init.constant_(m.weight, 1)
111 | init.constant_(m.bias, 0)
112 |
113 |
114 | class WideBasic(nn.Module):
115 | def __init__(self, in_planes, planes, dropout_rate, norm_creator, stride=1, adaptive_dropouter_creator=None):
116 | super(WideBasic, self).__init__()
117 | self.bn1 = norm_creator(in_planes)
118 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
119 | if adaptive_dropouter_creator is None:
120 | self.dropout = nn.Dropout(p=dropout_rate)
121 | else:
122 | self.dropout = adaptive_dropouter_creator(planes, 3, stride, 1)
123 | self.bn2 = norm_creator(planes)
124 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
125 |
126 | self.shortcut = nn.Sequential()
127 | if stride != 1 or in_planes != planes:
128 | self.shortcut = nn.Sequential(
129 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
130 | )
131 |
132 | def forward(self, x):
133 | out = self.dropout(self.conv1(F.relu(self.bn1(x))))
134 | out = self.conv2(F.relu(self.bn2(out)))
135 | out += self.shortcut(x)
136 |
137 | return out
138 |
139 |
140 | class WideResNet(nn.Module):
141 | def __init__(self, depth, widen_factor, dropout_rate, num_classes, adaptive_dropouter_creator, adaptive_conv_dropouter_creator, groupnorm, examplewise_bn, virtual_bn):
142 | super(WideResNet, self).__init__()
143 | self.in_planes = 16
144 | self.adaptive_conv_dropouter_creator = adaptive_conv_dropouter_creator
145 |
146 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
147 | assert sum([groupnorm,examplewise_bn,virtual_bn]) <= 1
148 | n = int((depth - 4) / 6)
149 | k = widen_factor
150 |
151 | nStages = [16, 16*k, 32*k, 64*k]
152 |
153 | self.adaptive_dropouters = [] #nn.ModuleList()
154 |
155 | if groupnorm:
156 | print('Uses group norm.')
157 | self.norm_creator = lambda c: nn.GroupNorm(max(c//CpG, 1), c)
158 | elif examplewise_bn:
159 | print("Uses Example Wise BN")
160 | self.norm_creator = lambda c: ExampleWiseBatchNorm2d(c, momentum=_bn_momentum)
161 | elif virtual_bn:
162 | print("Uses Virtual BN")
163 | self.norm_creator = lambda c: VirtualBatchNorm2d(c, momentum=_bn_momentum)
164 | else:
165 | self.norm_creator = lambda c: nn.BatchNorm2d(c, momentum=_bn_momentum)
166 |
167 | self.conv1 = conv3x3(3, nStages[0])
168 | self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1)
169 | self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2)
170 | self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2)
171 | self.bn1 = self.norm_creator(nStages[3])
172 | self.linear = nn.Linear(nStages[3], num_classes)
173 | if adaptive_dropouter_creator is not None:
174 | last_dropout = adaptive_dropouter_creator(nStages[3])
175 | else:
176 | last_dropout = lambda x: x
177 | self.adaptive_dropouters.append(last_dropout)
178 |
179 | # self.apply(conv_init)
180 |
181 | def to(self, *args, **kwargs):
182 | super().to(*args,**kwargs)
183 | print(*args)
184 | for ad in self.adaptive_dropouters:
185 | if hasattr(ad,'to'):
186 | ad.to(*args,**kwargs)
187 | return self
188 |
189 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
190 | strides = [stride] + [1]*(num_blocks-1)
191 | layers = []
192 |
193 | for i,stride in enumerate(strides):
194 | ada_conv_drop_c = self.adaptive_conv_dropouter_creator if i == 0 else None
195 | new_block = block(self.in_planes, planes, dropout_rate, self.norm_creator, stride, adaptive_dropouter_creator=ada_conv_drop_c)
196 | layers.append(new_block)
197 | if ada_conv_drop_c is not None:
198 | self.adaptive_dropouters.append(new_block.dropout)
199 |
200 | self.in_planes = planes
201 |
202 | return nn.Sequential(*layers)
203 |
204 | def forward(self, x):
205 | out = self.conv1(x)
206 | out = self.layer1(out)
207 | out = self.layer2(out)
208 | out = self.layer3(out)
209 | out = F.relu(self.bn1(out))
210 | # out = F.avg_pool2d(out, 8)
211 | out = F.adaptive_avg_pool2d(out, (1, 1))
212 | out = out.view(out.size(0), -1)
213 | out = self.adaptive_dropouters[-1](out)
214 | out = self.linear(out)
215 |
216 | return out
217 |
--------------------------------------------------------------------------------
/DeepAA_evaluate/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib
4 | matplotlib.use('TkAgg')
5 | import matplotlib.pyplot as plt
6 |
7 | import torchvision.transforms.functional as F
8 |
9 |
10 | plt.rcParams["savefig.bbox"] = 'tight'
11 |
12 |
13 | def save_images(imgs, dir):
14 | if not isinstance(imgs, list):
15 | imgs = [imgs]
16 | fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
17 | for i, img in enumerate(imgs):
18 | img = img.detach()
19 | img = F.to_pil_image(img)
20 | axs[0, i].imshow(np.asarray(img))
21 | axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
22 | fig.savefig(dir)
23 | return fig
--------------------------------------------------------------------------------
/DeepAA_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | import copy
5 | import random
6 | import datetime
7 |
8 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
9 | import tensorflow as tf
10 | tf.get_logger().setLevel(logging.ERROR)
11 |
12 |
13 | from data_generator import DataGenerator, DataAugmentation
14 | from utils import CTLHistory
15 | from lr_scheduler import GradualWarmup_Cosine_Scheduler
16 | import resnet
17 | from resnet_imagenet import imagenet_resnet50
18 |
19 | from data_generator import get_cifar10_data, get_cifar100_data
20 |
21 | from augmentation import AutoContrast, Invert, Equalize, Solarize, Posterize, Contrast, Brightness, Sharpness, \
22 | Identity, Color, ShearX, ShearY, TranslateX, TranslateY, Rotate
23 | from augmentation import RandCrop, RandCutout, RandFlip, RandCutout60
24 | from augmentation import RandResizeCrop_imagenet, centerCrop_imagenet
25 |
26 |
27 | from policy import DA_Policy_logits
28 | from augmentation import IMAGENET_SIZE
29 |
30 | import torch
31 | import threading
32 | import queue
33 | from imagenet_data_utils import get_imagenet_split
34 |
35 | def aug_op_cifar_list(): # operations and their ranges
36 | l = [
37 | (Identity, 0., 1.0), # 0
38 | (ShearX, -0.3, 0.3), # 1
39 | (ShearY, -0.3, 0.3), # 2
40 | (TranslateX, -0.45, 0.45), # 3
41 | (TranslateY, -0.45, 0.45), # 4
42 | (Rotate, -30., 30.), # 5
43 | (AutoContrast, 0., 1.), # 6
44 | (Invert, 0., 1.), # 7
45 | (Equalize, 0., 1.), # 8
46 | (Solarize, 0., 256.), # 9
47 | (Posterize, 4., 8.), # 10,
48 | (Contrast, 0.1, 1.9), # 11
49 | (Color, 0.1, 1.9), # 12
50 | (Brightness, 0.1, 1.9), # 13
51 | (Sharpness, 0.1, 1.9), # 14
52 | (RandFlip, 0., 1.0), # 15
53 | (RandCutout, 0., 1.0), # 16
54 | (RandCrop, 0., 1.0), # 17
55 | ]
56 | names = []
57 | for op in l:
58 | info = op.__str__().split(' ')
59 | name = '{}:({},{}'.format(info[1], info[-2], info[-1])
60 | names.append(name)
61 |
62 | return l, names
63 |
64 | def aug_op_imagenet_list(): # operations and their ranges
65 | l = [
66 | (Identity, 0., 1.0), # 0
67 | (ShearX, -0.3, 0.3), # 1
68 | (ShearY, -0.3, 0.3), # 2
69 | (TranslateX, -0.45, 0.45), # 3
70 | (TranslateY, -0.45, 0.45), # 4
71 | (Rotate, -30., 30.), # 5
72 | (AutoContrast, 0., 1.), # 6
73 | (Invert, 0., 1.), # 7
74 | (Equalize, 0., 1.), # 8
75 | (Solarize, 0., 256.), # 9
76 | (Posterize, 4., 8.), # 10
77 | (Contrast, 0.1, 1.9), # 11
78 | (Color, 0.1, 1.9), # 12
79 | (Brightness, 0.1, 1.9), # 13
80 | (Sharpness, 0.1, 1.9), # 14
81 | (RandFlip, 0., 1.0), # 15
82 | (RandCutout60, 0., 1.0), # 16
83 | (RandResizeCrop_imagenet, 0., 1.),
84 | ]
85 | names = []
86 | for op in l:
87 | info = op.__str__().split(' ')
88 | name = '{}:({},{}'.format(info[1], info[-2], info[-1])
89 | names.append(name)
90 |
91 | return l, names
92 |
93 |
94 | # Get the model
95 | def get_model(args, model, n_classes):
96 | if model == 'WRN_28_10':
97 | model = resnet.cifar_WRN_28_10(dropout=0, l2_reg=0.00025,
98 | preact_shortcuts=False, n_classes=n_classes, input_shape=args.img_size)
99 | elif model == 'WRN_40_2':
100 | model = resnet.cifar_WRN_40_2(dropout=0, l2_reg=0.00025,
101 | preact_shortcuts=False, n_classes=n_classes, input_shape=args.img_size)
102 | elif model == 'resnet50':
103 | model = imagenet_resnet50()
104 | else:
105 | raise Exception('Unrecognized model')
106 | return model
107 |
108 | # metrics to keep track of
109 | train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
110 | test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
111 | train_loss = tf.keras.metrics.Mean()
112 | test_loss = tf.keras.metrics.Mean()
113 |
114 | def get_img_size(args):
115 | if 'cifar' in args.dataset:
116 | return (32, 32, 3)
117 | elif 'imagenet' in args.dataset:
118 | return (*IMAGENET_SIZE, 3)
119 | else:
120 | raise Exception
121 |
122 | # get the data
123 | def get_dataset(args):
124 | print('Loading train and retrain dataset.')
125 | if args.dataset in ['cifar10', 'cifar100']:
126 | if args.dataset == 'cifar10':
127 | assert args.n_classes == 10
128 | x_train_, y_train_, x_val, y_val, x_test, y_test = get_cifar10_data(val_size=10000)
129 | x_train, y_train = x_train_[:args.pretrain_size], y_train_[:args.pretrain_size]
130 | x_search, y_search = x_train_[args.pretrain_size:], y_train_[args.pretrain_size:]
131 | elif args.dataset == 'cifar100':
132 | assert args.n_classes == 100
133 | x_train_, y_train_, x_val, y_val, x_test, y_test = get_cifar100_data(val_size=10000)
134 | x_train, y_train = x_train_[:args.pretrain_size], y_train_[:args.pretrain_size]
135 | x_search, y_search = x_train_[args.pretrain_size:], y_train_[args.pretrain_size:]
136 | train_ds = DataGenerator(x_train, y_train, batch_size=args.batch_size, drop_last=True)
137 | search_ds = DataGenerator(x_search, y_search, batch_size=args.batch_size, drop_last=True)
138 | val_ds = DataGenerator(x_val, y_val, batch_size=args.val_batch_size, drop_last=True)
139 | test_ds = DataGenerator(x_test, y_test, batch_size=args.test_batch_size, drop_last=False, shuffle=False) # setting shuffle=False for parallel evaluation
140 | elif args.dataset == 'imagenet':
141 | assert args.n_classes == 1000
142 | def collate_fn_imagenet_list(l): # return a list
143 | images, labels = zip(*l)
144 | assert images[0].dtype == np.uint8
145 | return list(images), np.array(labels, dtype=np.int32)
146 | if args.dataset == 'imagenet':
147 | train_ds_total, val_ds, search_ds, train_ds, test_ds = get_imagenet_split(n_GPU=1, seed=300)
148 | assert len(train_ds) == 1 and isinstance(train_ds, list), 'Train_ds should be a length=1 list'
149 | train_ds = train_ds[0]
150 | test_ds = torch.utils.data.DataLoader(
151 | test_ds, batch_size=256, shuffle=False, num_workers=64,
152 | pin_memory=False,
153 | drop_last=False, sampler=None,
154 | collate_fn=collate_fn_imagenet_list,
155 | )
156 | else:
157 | raise Exception('Unrecognized dataset')
158 |
159 | return train_ds, val_ds, test_ds, search_ds
160 |
161 | def get_augmentation(args):
162 | if 'cifar' in args.dataset:
163 | augmentation_default = DataAugmentation(num_classes=args.n_classes, dataset=args.dataset, image_shape=args.img_size,
164 | ops_list=(None, None),
165 | default_pre_aug=None,
166 | default_post_aug=[RandCrop,
167 | RandFlip,
168 | RandCutout])
169 |
170 | augmentation_search = DataAugmentation(num_classes=args.n_classes, dataset=args.dataset, image_shape=args.img_size,
171 | ops_list=aug_op_cifar_list(),
172 | default_pre_aug=None,
173 | default_post_aug=None)
174 |
175 | augmentation_test = DataAugmentation(num_classes=args.n_classes, dataset=args.dataset, image_shape=args.img_size,
176 | ops_list=(None, None),
177 | default_pre_aug=None,
178 | default_post_aug=None)
179 | elif 'imagenet' in args.dataset:
180 | augmentation_default = DataAugmentation(num_classes=args.n_classes, dataset=args.dataset,
181 | image_shape=args.img_size,
182 | ops_list=(None, None),
183 | default_pre_aug=None,
184 | default_post_aug=[RandResizeCrop_imagenet, #
185 | RandFlip])
186 |
187 | augmentation_search = DataAugmentation(num_classes=args.n_classes, dataset=args.dataset, image_shape=args.img_size,
188 | ops_list=aug_op_imagenet_list(),
189 | default_pre_aug=None,
190 | default_post_aug=None)
191 |
192 |
193 | augmentation_test = DataAugmentation(num_classes=args.n_classes, dataset=args.dataset,
194 | image_shape=args.img_size,
195 | ops_list=(None, None),
196 | default_pre_aug=None,
197 | default_post_aug=[
198 | centerCrop_imagenet,
199 | ])
200 | return augmentation_default, augmentation_search, augmentation_test
201 |
202 | def get_optim_net(args, nb_train_steps):
203 | scheduler_lr = GradualWarmup_Cosine_Scheduler(starting_lr=0., initial_lr=args.pretrain_lr,
204 | ending_lr=1e-7,
205 | warmup_steps= 0,
206 | total_steps=nb_train_steps * args.nb_epochs)
207 |
208 | optim_net = tf.optimizers.SGD(learning_rate=scheduler_lr, momentum=0.9, nesterov=True)
209 | return optim_net
210 |
211 |
212 |
213 |
214 | def get_policy(args, op_names, ops_mid_magnitude, available_policies):
215 | policy = DA_Policy_logits(args.l_ops, args.l_mags, args.l_uniq,
216 | op_names=op_names,
217 | ops_mid_magnitude=ops_mid_magnitude, N_repeat_random=args.N_repeat_random,
218 | available_policies=available_policies)
219 | return policy
220 |
221 | def get_optim_policy(policy_lr):
222 | optim_policy = tf.optimizers.Adam(learning_rate=policy_lr, beta_1=0.9, beta_2=0.999)
223 | return optim_policy
224 |
225 |
226 | # get the loss
227 | def get_loss_fun():
228 | train_loss_fun = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,
229 | reduction=tf.keras.losses.Reduction.NONE)
230 | test_loss_fun = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,
231 | reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
232 | val_loss_fun = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,
233 | reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE)
234 | return train_loss_fun, test_loss_fun, val_loss_fun
235 |
236 |
237 | def get_lops_luniq(args, ops_mid_magnitude):
238 | if 'cifar' in args.dataset:
239 | _, op_names = aug_op_cifar_list()
240 | elif 'imagenet' in args.dataset:
241 | _, op_names = aug_op_imagenet_list()
242 | else:
243 | raise Exception('Unknown dataset ={}'.format(args.dataset))
244 |
245 | names_modified = [op_name.split(':')[0] for op_name in op_names]
246 | l_ops = len(op_names)
247 | l_uniq = 0
248 | for k_name, name in enumerate(names_modified):
249 | mid_mag = ops_mid_magnitude[name]
250 | if mid_mag == 'random':
251 | l_uniq += 1 # The op is a random op, however we only sample one op
252 | elif mid_mag is not None and mid_mag >=0 and mid_mag <= args.l_mags-1:
253 | l_uniq += args.l_mags-1
254 |         elif mid_mag is not None and mid_mag == -1:  # mid_mag == -1 means all l_mags magnitude levels are treated as independent policies
255 | l_uniq += args.l_mags
256 | elif mid_mag is None:
257 | l_uniq += 1
258 | else:
259 | raise Exception('mid_mag = {} is invalid'.format(mid_mag))
260 | return l_ops, l_uniq
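# Worked example (comment only): with l_mags = 13, an operation whose mid magnitude is
# (13 - 1) // 2 = 6 (e.g. ShearX) contributes l_mags - 1 = 12 searchable (op, magnitude)
# pairs, while operations with mid magnitude None (e.g. AutoContrast) or 'random'
# (e.g. RandFlip) each contribute a single entry to l_uniq.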
261 |
262 | def get_all_policy(policy_train):
263 | l_ops, l_mags = policy_train.l_ops, policy_train.l_mags
264 | ops, mags = np.meshgrid(np.arange(l_ops), np.arange(l_mags), indexing='ij')
265 | ops = np.reshape(ops, [l_ops*l_mags,1])
266 | mags = np.reshape(mags, [l_ops*l_mags,1])
267 | return ops.astype(np.int32), mags.astype(np.int32)
268 |
269 | class PrefetchGenerator(threading.Thread):
270 | def __init__(self, search_ds, val_ds, n_classes, search_bs=8, val_bs=64):
271 | threading.Thread.__init__(self)
272 | self.queue = queue.Queue(1)
273 | self.search_ds = search_ds
274 | self.val_ds = val_ds
275 | self.n_classes = n_classes
276 | self.search_bs = search_bs
277 | self.val_bs = val_bs
278 | self.daemon = True
279 | self.start()
280 |
281 | @staticmethod
282 | def sample_label_and_batch(dataset, bs, n_classes, MAX_iterations=100):
283 | for k in range(MAX_iterations):
284 | try:
285 | lab = random.randint(0, n_classes-1)
286 | imgs, labs = dataset.sample_labeled_data_batch(lab, bs)
287 |             except Exception:
288 | print('Insufficient data in a single class, try {}/{}'.format(k, MAX_iterations))
289 | continue
290 | return lab, imgs, labs
291 | raise Exception('Maximum number of iteration {} reached'.format(MAX_iterations))
292 |
293 | def run(self):
294 | while True:
295 | images_val, labels_val, images_train, labels_train = [], [], [], []
296 | for _ in range(self.search_bs):
297 | lab, imgs_val, labs_val = PrefetchGenerator.sample_label_and_batch(self.val_ds, self.val_bs, self.n_classes)
298 | imgs_train, labs_train = self.search_ds.sample_labeled_data_batch(lab, 1)
299 | images_val.append(imgs_val)
300 | labels_val.append(labs_val)
301 | images_train.append(imgs_train)
302 | labels_train.append(labs_train)
303 | self.queue.put( (images_val, labels_val, images_train, labels_train) )
304 |
305 | def next(self):
306 | next_item = self.queue.get()
307 | return next_item
308 |
309 |
310 | def save_policy(args, all_using_policies, augmentation_search):
311 | ops, mags = all_using_policies[0].unique_policy
312 | op_names = augmentation_search.op_names
313 | policy_probs = []
314 | for k_policy, policy in enumerate(all_using_policies):
315 | policy_probs.append(tf.nn.softmax(policy.logits).numpy())
316 | policy_probs = np.stack(policy_probs, axis=0)
317 |
318 | np.savez('./policy_port/policy_DeepAA_{}.npz'.format(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")),
319 | policy_probs=policy_probs, l_ops=args.l_ops, l_mags=args.l_mags,
320 | ops=ops, mags=mags, op_names=op_names)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep AutoAugment
2 |
3 | This is the official implementation of Deep AutoAugment ([DeepAA](https://openreview.net/forum?id=St-53J9ZARf)), a fully automated data augmentation policy search method. Leaderboard is here: https://paperswithcode.com/paper/deep-autoaugment-1
4 |
5 |
6 | ![Deep AutoAugment overview](images/DeepAA.png)
7 |
8 |
9 | ## 5-Minute Explanation Video
10 | Click the figure to watch this short video explaining our work.
11 |
12 | [![Deep AutoAugment 5-minute video](images/DeepAA_slideslive.png)](https://recorder-v3.slideslive.com/#/share?share=64177&s=6d93977f-2a40-436d-a404-8808aee650fa)
13 |
14 | ## Requirements
15 | DeepAA is implemented using TensorFlow.
16 | To be consistent with previous work, we run the policy evaluation based on [TrivialAugment](https://github.com/automl/trivialaugment), which is implemented using PyTorch.
17 |
18 | ### Install required packages
19 | a. Create a conda virtual environment.
20 | ```shell
21 | conda create -n deepaa python=3.7
22 | conda activate deepaa
23 | ```
24 |
25 | b. Install TensorFlow and PyTorch.
26 | ```shell
27 | conda install tensorflow-gpu=2.5 cudnn=8.1 cudatoolkit=11.2 -c conda-forge
28 | pip3 install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
29 | ```
30 |
31 | c. Install other dependencies.
32 | ```shell
33 | pip install -r requirements.txt
34 | ```
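
d. (Optional) A quick sanity check, not part of the repository, to confirm that both frameworks were installed with GPU support before launching a search:
```python
# Optional check (not from the repo): verify TensorFlow and PyTorch can see the GPUs.
import tensorflow as tf
import torch

print('TensorFlow', tf.__version__, 'GPUs:', tf.config.list_physical_devices('GPU'))
print('PyTorch', torch.__version__, 'CUDA available:', torch.cuda.is_available())
```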
35 |
36 |
37 | ## Experiments
38 |
39 | ### Run augmentation policy search on CIFAR-10/100.
40 | ```shell
41 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
42 | python DeepAA_search.py --dataset cifar10 --n_classes 10 --use_model WRN_40_2 --n_policies 6 --search_bno 1024 --pretrain_lr 0.1 --seed 1 --batch_size 128 --test_batch_size 512 --policy_lr 0.025 --l_mags 13 --use_pool --pretrain_size 5000 --nb_epochs 45 --EXP_G 16 --EXP_gT_factor=4 --train_same_labels 16
43 | ```
44 |
45 | ### Run augmentation policy search on ImageNet.
46 | ```shell
47 | mkdir pretrained_imagenet
48 | ```
49 | Download the [files](https://drive.google.com/drive/folders/1QmqWfF_dzyZPDIuvkiLHp0X6JiUNbIZI?usp=sharing) and copy them to the `./pretrained_imagenet` folder.
50 | ```shell
51 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
52 | python DeepAA_search.py --dataset imagenet --n_classes 1000 --use_model resnet50 --n_policies 6 --search_bno 1024 --seed 1 --batch_size 128 --test_batch_size 512 --policy_lr 0.025 --l_mags 13 --use_pool --EXP_G 16 --EXP_gT_factor=4 --train_same_labels 16
53 | ```
54 |
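When a search run finishes, `DeepAA_search.py` writes the learned policy probabilities to a timestamped `.npz` file under `./policy_port/` (see `save_policy` in `DeepAA_utils.py`); the released policies in `policy_port/` should follow the same format. A minimal sketch for inspecting one:
```python
# Minimal sketch: inspect a searched or released DeepAA policy file.
import numpy as np

policy = np.load('policy_port/policy_DeepAA_cifar_1.npz')
print(sorted(policy.files))          # policy_probs, l_ops, l_mags, ops, mags, op_names
print(policy['op_names'])            # augmentation operations used during the search
print(policy['policy_probs'].shape)  # one probability vector per policy layer
```
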
55 | ### Evaluate the policy found on CIFAR-10/100 and ImageNet.
56 | ```shell
57 | mkdir ckpt
58 | python -m DeepAA_evaluate.train -c confs/wresnet28x10_cifar10_DeepAA_1.yaml --dataroot ./data --save ckpt/DeepAA_cifar10.pth --tag Exp_DeepAA_cifar10
59 | python -m DeepAA_evaluate.train -c confs/wresnet28x10_cifar100_DeepAA_1.yaml --dataroot ./data --save ckpt/DeepAA_cifar100.pth --tag Exp_DeepAA_cifar100
60 | python -m DeepAA_evaluate.train -c confs/resnet50_imagenet_DeepAA_8x256_1.yaml --dataroot ./data --save ckpt/DeepAA_imagenet.pth --tag Exp_DeepAA_imagenet
61 | ```
62 |
63 | ### Evaluate the policy found on CIFAR-10/100 with Batch Augmentation.
64 | ```shell
65 | mkdir ckpt
66 | python -m DeepAA_evaluate.train -c confs/wresnet28x10_cifar10_DeepAA_BatchAug8x_1.yaml --dataroot ./data --save ckpt/DeepAA_cifar10.pth --tag Exp_DeepAA_cifar10
67 | python -m DeepAA_evaluate.train -c confs/wresnet28x10_cifar100_DeepAA_BatchAug8x_1.yaml --dataroot ./data --save ckpt/DeepAA_cifar100.pth --tag Exp_DeepAA_cifar100
68 | ```
69 |
70 | ## Visualization
71 |
72 | The policies found on CIFAR-10/100 and ImageNet are visualized as follows.
73 |
74 |
75 | ![Operation distribution](images/operation_distribution.png)
76 |
77 |
78 | The distribution of operations at each layer of the policy for (a) CIFAR-10/100 and (b) ImageNet. The probability of each operation is summed up over all 12 discrete intensity levels of the corresponding transformation.
79 |
80 |
81 | ![Magnitude distribution on CIFAR-10/100](images/magnitude_distribution_cifar.png)
82 |
83 |
84 | The distribution of discrete magnitudes of each augmentation transformation in each layer of the policy for CIFAR-10/100. The x-axis represents the discrete magnitudes and the y-axis represents the probability. The magnitude is discretized to 12 levels with each transformation having its own range. A large absolute value of the magnitude corresponds to high transformation intensity. Note that we do not show identity, autoContrast, invert, equalize, flips, Cutout and crop because they do not have intensity parameters.
85 |
86 |
87 | ![Magnitude distribution on ImageNet](images/magnitude_distribution_imagenet.png)
88 |
89 |
90 | The distribution of discrete magnitudes of each augmentation transformation in each layer of the policy for ImageNet. The x-axis represents the discrete magnitudes and the y-axis represents the probability. The magnitude is discretized to 12 levels with each transformation having its own range. A large absolute value of the magnitude corresponds to high transformation intensity. Note that we do not show identity, autoContrast, invert, equalize, flips, Cutout and crop because they do not have intensity parameters.
91 |
92 | ## Citation
93 | If you find this useful for your work, please consider citing:
94 | ```
95 | @inproceedings{
96 | zheng2022deep,
97 | title={Deep AutoAugment},
98 | author={Yu Zheng and Zhi Zhang and Shen Yan and Mi Zhang},
99 | booktitle={International Conference on Learning Representations},
100 | year={2022},
101 | url={https://openreview.net/forum?id=St-53J9ZARf}
102 | }
103 | ```
104 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/__init__.py
--------------------------------------------------------------------------------
/augmentation.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from rpmcruz/autoaugment
2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
3 | # https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py
4 |
5 | import random
6 |
7 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
8 | import numpy as np
9 | from PIL import Image
10 | import math
11 |
12 | IMAGENET_SIZE = (224, 224) # (width, height) may set to (244, 224)
13 |
14 | _IMAGENET_PCA = {
15 | 'eigval': [0.2175, 0.0188, 0.0045],
16 | 'eigvec': [
17 | [-0.5675, 0.7192, 0.4009],
18 | [-0.5808, -0.0045, -0.8140],
19 | [-0.5836, -0.6948, 0.4203],
20 | ]
21 | }
22 | _CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
23 |
24 | def ShearX(img, v): # [-0.3, 0.3]
25 | assert -0.3 <= v <= 0.3
26 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
27 |
28 | def ShearY(img, v): # [-0.3, 0.3]
29 | assert -0.3 <= v <= 0.3
30 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
31 |
32 |
33 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
34 | assert -0.45 <= v <= 0.45
35 | v = v * img.size[0]
36 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
37 |
38 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
39 | assert -0.45 <= v <= 0.45
40 | v = v * img.size[1]
41 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
42 |
43 |
44 | def Rotate(img, v): # [-30, 30]
45 | assert -30 <= v <= 30
46 | return img.rotate(v)
47 |
48 |
49 | def AutoContrast(img, _):
50 | return PIL.ImageOps.autocontrast(img)
51 |
52 |
53 | def Invert(img, _):
54 | return PIL.ImageOps.invert(img)
55 |
56 |
57 | def Equalize(img, _):
58 | return PIL.ImageOps.equalize(img)
59 |
60 |
61 | def Flip(img, _): # not from the paper
62 | return PIL.ImageOps.mirror(img)
63 |
64 |
65 | def Solarize(img, v): # [0, 256]
66 | assert 0 <= v <= 256
67 | return PIL.ImageOps.solarize(img, v)
68 |
69 |
70 | def SolarizeAdd(img, addition=0, threshold=128):
71 |     img_np = np.array(img).astype(np.int32)  # use a concrete dtype; np.int is deprecated in recent NumPy
72 | img_np = img_np + addition
73 | img_np = np.clip(img_np, 0, 255)
74 | img_np = img_np.astype(np.uint8)
75 | img = Image.fromarray(img_np)
76 | return PIL.ImageOps.solarize(img, threshold)
77 |
78 |
79 | def Posterize(img, v): # [4, 8]
80 | assert 4 <= v <= 8 # FastAA
81 | v = int(v)
82 | return PIL.ImageOps.posterize(img, v)
83 |
84 |
85 | def Contrast(img, v): # [0.1,1.9]
86 | assert 0.1 <= v <= 1.9
87 | return PIL.ImageEnhance.Contrast(img).enhance(v)
88 |
89 |
90 | def Color(img, v): # [0.1,1.9]
91 | assert 0.1 <= v <= 1.9
92 | return PIL.ImageEnhance.Color(img).enhance(v)
93 |
94 |
95 | def Brightness(img, v): # [0.1,1.9]
96 | assert 0.1 <= v <= 1.9
97 | return PIL.ImageEnhance.Brightness(img).enhance(v)
98 |
99 |
100 | def Sharpness(img, v): # [0.1,1.9]
101 | assert 0.1 <= v <= 1.9
102 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
103 |
104 |
105 | def RandCrop(img, _):
106 | v = 4
107 | return mean_pad_randcrop(img, v)
108 |
109 | def RandCutout(img, _):
110 | v = 16
111 | w, h = img.size
112 | x = random.uniform(0, w)
113 | y = random.uniform(0, h)
114 |
115 | x0 = int(min(w, max(0, x - v // 2))) # clip to the range (0, w)
116 | x1 = int(min(w, max(0, x + v // 2)))
117 | y0 = int(min(h, max(0, y - v // 2)))
118 | y1 = int(min(h, max(0, y + v // 2)))
119 |
120 | xy = (x0, y0, x1, y1)
121 | color = (125, 123, 114)
122 | # color = (0, 0, 0)
123 | img = img.copy()
124 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
125 | return img
126 |
127 |
128 | def RandCutout60(img, _):
129 | v = 60
130 | w, h = img.size
131 | x_left = max(0, w // 2 - 256 // 2)
132 | x_right = min(w, w // 2 + 256 // 2)
133 | y_bottom = max(0, h // 2 - 256 // 2)
134 | y_top = min(h, h // 2 + 256 // 2)
135 |
136 | x = random.uniform(x_left, x_right)
137 | y = random.uniform(y_bottom, y_top)
138 |
139 | x0 = int(min(w, max(0, x - v // 2)))
140 | x1 = int(min(w, max(0, x + v // 2)))
141 | y0 = int(min(h, max(0, y - v // 2)))
142 | y1 = int(min(h, max(0, y + v // 2)))
143 |
144 | xy = (x0, y0, x1, y1)
145 | color = (125, 123, 114)
146 | # color = (0, 0, 0)
147 | img = img.copy()
148 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
149 | return img
150 |
151 |
152 | def RandFlip(img, _):
153 | if random.random() > 0.5:
154 | img = Flip(img, None)
155 | return img
156 |
157 |
158 |
159 | def mean_pad_randcrop(img, v):
160 | # v: Pad with mean value=[125, 123, 114] by v pixels on each side and then take random crop
161 |     assert v <= 10, 'The maximum shift should be less than 10'
162 | padded_size = (img.size[0] + 2*v, img.size[1] + 2*v)
163 | new_img = PIL.Image.new('RGB', padded_size, color=(125, 123, 114))
164 | new_img.paste(img, (v, v))
165 | top = random.randint(0, v*2)
166 | left = random.randint(0, v*2)
167 | new_img = new_img.crop((left, top, left + img.size[0], top + img.size[1]))
168 | return new_img
169 |
170 | def Identity(img, v):
171 | return img
172 |
173 |
174 | def RandResizeCrop_imagenet(img, _):
175 | # ported from torchvision
176 | # for ImageNet use only
177 | scale = (0.08, 1.0)
178 | ratio = (3. / 4., 4. / 3.)
179 | size = IMAGENET_SIZE # (224, 224)
180 |
181 | def get_params(img, scale, ratio):
182 | width, height = img.size
183 | area = float(width * height)
184 | log_ratio = [math.log(r) for r in ratio]
185 |
186 | for _ in range(10):
187 | target_area = area * random.uniform(scale[0], scale[1])
188 | aspect_ratio = math.exp(random.uniform(log_ratio[0], log_ratio[1]))
189 |
190 | w = round(math.sqrt(target_area * aspect_ratio))
191 | h = round(math.sqrt(target_area / aspect_ratio))
192 | if 0 < w <= width and 0 < h <= height:
193 | top = random.randint(0, height - h)
194 | left = random.randint(0, width - w)
195 | return left, top, w, h
196 |
197 | # fallback to central crop
198 | in_ratio = float(width) / float(height)
199 | if in_ratio < min(ratio):
200 | w = width
201 | h = round(w / min(ratio))
202 | elif in_ratio > max(ratio):
203 | h = height
204 | w = round(h * max(ratio))
205 | else:
206 | w = width
207 | h = height
208 | top = (height - h) // 2
209 | left = (width - w) // 2
210 | return left, top, w, h
211 |
212 | left, top, w_box, h_box = get_params(img, scale, ratio)
213 | box = (left, top, left + w_box, top + h_box)
214 |     img = img.resize(size=size, resample=PIL.Image.BICUBIC, box=box)
215 | return img
216 |
217 |
218 | def Resize_imagenet(img, size):
219 | w, h = img.size
220 | if isinstance(size, int):
221 | short, long = (w, h) if w <= h else (h, w)
222 | if short == size:
223 | return img
224 | new_short, new_long = size, int(size * long / short)
225 | new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
226 | return img.resize((new_w, new_h), PIL.Image.BICUBIC)
227 | elif isinstance(size, tuple) or isinstance(size, list):
228 | assert len(size) == 2, 'Check the size {}'.format(size)
229 | return img.resize(size, PIL.Image.BICUBIC)
230 | else:
231 | raise Exception
232 |
233 |
234 | def centerCrop_imagenet(img, _):
235 | # for ImageNet only
236 | # https://github.com/pytorch/vision/blob/master/torchvision/transforms/functional.py
237 | crop_width, crop_height = IMAGENET_SIZE # (224,224)
238 | image_width, image_height = img.size
239 |
240 | if crop_width > image_width or crop_height > image_height:
241 | padding_ltrb = [
242 | (crop_width - image_width) // 2 if crop_width > image_width else 0,
243 | (crop_height - image_height) // 2 if crop_height > image_height else 0,
244 | (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
245 | (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
246 | ]
247 | img = pad(img, padding_ltrb, fill=0)
248 | image_width, image_height = img.size
249 | if crop_width == image_width and crop_height == image_height:
250 | return img
251 |
252 | crop_top = int(round((image_height - crop_height) / 2.))
253 | crop_left = int(round((image_width - crop_width) / 2.))
254 | return img.crop((crop_left, crop_top, crop_left + crop_width, crop_top + crop_height))
255 |
256 | # def centerCrop_imagenet_default(img):
257 | # return centerCrop_imagenet(img, None)
258 |
259 | def _parse_fill(fill, img, name="fillcolor"):
260 | # Process fill color for affine transforms
261 | num_bands = len(img.getbands())
262 | if fill is None:
263 | fill = 0
264 | if isinstance(fill, (int, float)) and num_bands > 1:
265 | fill = tuple([fill] * num_bands)
266 | if isinstance(fill, (list, tuple)):
267 | if len(fill) != num_bands:
268 | msg = ("The number of elements in 'fill' does not match the number of "
269 | "bands of the image ({} != {})")
270 | raise ValueError(msg.format(len(fill), num_bands))
271 |
272 | fill = tuple(fill)
273 |
274 | return {name: fill}
275 |
276 |
277 | def pad(img, padding_ltrb, fill=0, padding_mode='constant'):
278 | if isinstance(padding_ltrb, list):
279 | padding_ltrb = tuple(padding_ltrb)
280 | if padding_mode == 'constant':
281 | opts = _parse_fill(fill, img, name='fill')
282 | if img.mode == 'P':
283 | palette = img.getpalette()
284 | image = PIL.ImageOps.expand(img, border=padding_ltrb, **opts)
285 | image.putpalette(palette)
286 | return image
287 | return PIL.ImageOps.expand(img, border=padding_ltrb, **opts)
288 | elif len(padding_ltrb) == 4:
289 | image_width, image_height = img.size
290 | cropping = -np.minimum(padding_ltrb, 0)
291 | if cropping.any():
292 | crop_left, crop_top, crop_right, crop_bottom = cropping
293 | img = img.crop((crop_left, crop_top, image_width - crop_right, image_height - crop_bottom))
294 | pad_left, pad_top, pad_right, pad_bottom = np.maximum(padding_ltrb, 0)
295 |
296 | if img.mode == 'P':
297 | palette = img.getpalette()
298 | img = np.asarray(img)
299 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
300 | img = Image.fromarray(img)
301 | img.putpalette(palette)
302 | return img
303 |
304 | img = np.asarray(img)
305 | # RGB image
306 | if len(img.shape) == 3:
307 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
308 | # Grayscale image
309 | if len(img.shape) == 2:
310 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
311 |
312 | return Image.fromarray(img)
313 | else:
314 | raise Exception
315 |
316 | def get_mid_magnitude(l_mags):
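    # Note on the convention (as used by get_lops_luniq in DeepAA_utils.py and by
    # DA_Policy_logits in policy.py):
    #   * an integer is the index of the "identity" magnitude level, i.e. the level at which
    #     the transform leaves the image unchanged; that level is excluded from the searchable
    #     (operation, magnitude) pairs,
    #   * 'random' marks inherently stochastic operations (flips, crops, Cutout), which are
    #     kept as a single entry and may be repeated N_repeat_random times during search,
    #   * None marks operations without a magnitude parameter (Identity, AutoContrast,
    #     Invert, Equalize).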
317 | ops_mid_magnitude = {'Identity': None,
318 | 'ShearX': (l_mags - 1) // 2,
319 | 'ShearY': (l_mags - 1) // 2,
320 | 'TranslateX': (l_mags - 1) // 2,
321 | 'TranslateY': (l_mags - 1) // 2,
322 | 'Rotate': (l_mags - 1) // 2,
323 | 'AutoContrast': None,
324 | 'Invert': None,
325 | 'Equalize': None,
326 | 'Solarize': l_mags - 1,
327 | 'Posterize': l_mags - 1,
328 | 'Contrast': (l_mags - 1) // 2,
329 | 'Color': (l_mags - 1) // 2,
330 | 'Brightness': (l_mags - 1) // 2,
331 | 'Sharpness': (l_mags - 1) // 2,
332 | 'RandFlip': 'random',
333 | 'RandCutout': 'random',
334 | 'RandCutout60': 'random',
335 | 'RandCrop': 'random',
336 | 'RandResizeCrop_imagenet': 'random',
337 | }
338 | return ops_mid_magnitude
--------------------------------------------------------------------------------
/confs/resnet50_imagenet_DeepAA_8x256_1.yaml:
--------------------------------------------------------------------------------
1 | #load_main_model: true
2 | save_model: true
3 | model:
4 | type: resnet50
5 | dataset: imagenet
6 | aug: DeepAA
7 | deepaa:
8 | EXP: imagenet_1
9 | augmentation_search_space: Not_used
10 | cutout: -1
11 | batch: 256
12 | gpus: 8
13 | epoch: 270
14 | lr: .1
15 | lr_schedule:
16 | type: 'resnet'
17 | warmup:
18 | multiplier: 8.0
19 | epoch: 5
20 | optimizer:
21 | type: sgd
22 | nesterov: True
23 | decay: 0.0001
24 | clip: 0
25 | test_interval: 20
26 |
27 |
--------------------------------------------------------------------------------
/confs/resnet50_imagenet_DeepAA_8x256_2.yaml:
--------------------------------------------------------------------------------
1 | #load_main_model: true
2 | save_model: true
3 | model:
4 | type: resnet50
5 | dataset: imagenet
6 | aug: DeepAA
7 | deepaa:
8 | EXP: imagenet_2
9 | augmentation_search_space: Not_used
10 | cutout: -1
11 | batch: 256
12 | gpus: 8
13 | epoch: 270
14 | lr: .1
15 | lr_schedule:
16 | type: 'resnet'
17 | warmup:
18 | multiplier: 8.0
19 | epoch: 5
20 | optimizer:
21 | type: sgd
22 | nesterov: True
23 | decay: 0.0001
24 | clip: 0
25 | test_interval: 20
26 |
27 |
--------------------------------------------------------------------------------
/confs/wresnet28x10_cifar100_DeepAA_1.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: wresnet28_10
3 | dataset: cifar100
4 | aug: DeepAA
5 | deepaa:
6 | EXP: cifar_1
7 | cutout: -1
8 | batch: 128
9 | gpus: 1
10 | augmentation_search_space: Not_used # fixed_standard
11 | epoch: 200
12 | lr: 0.1
13 | lr_schedule:
14 | type: 'cosine'
15 | warmup:
16 | multiplier: 1
17 | epoch: 5
18 | optimizer:
19 | type: sgd
20 | nesterov: True
21 | decay: 0.0005
22 |
23 |
--------------------------------------------------------------------------------
/confs/wresnet28x10_cifar100_DeepAA_1_wd1e-3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: wresnet28_10
3 | dataset: cifar100
4 | aug: DeepAA
5 | deepaa:
6 | EXP: cifar_1
7 | cutout: -1
8 | batch: 128
9 | gpus: 1
10 | augmentation_search_space: Not_used # fixed_standard
11 | epoch: 200
12 | lr: 0.1
13 | lr_schedule:
14 | type: 'cosine'
15 | warmup:
16 | multiplier: 1
17 | epoch: 5
18 | optimizer:
19 | type: sgd
20 | nesterov: True
21 | decay: 0.001
22 |
23 |
--------------------------------------------------------------------------------
/confs/wresnet28x10_cifar100_DeepAA_2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: wresnet28_10
3 | dataset: cifar100
4 | aug: DeepAA
5 | deepaa:
6 | EXP: cifar_2
7 | cutout: -1
8 | batch: 128
9 | gpus: 1
10 | augmentation_search_space: Not_used # fixed_standard
11 | epoch: 200
12 | lr: 0.1
13 | lr_schedule:
14 | type: 'cosine'
15 | warmup:
16 | multiplier: 1
17 | epoch: 5
18 | optimizer:
19 | type: sgd
20 | nesterov: True
21 | decay: 0.0005
22 |
23 |
--------------------------------------------------------------------------------
/confs/wresnet28x10_cifar100_DeepAA_2_wd1e-3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: wresnet28_10
3 | dataset: cifar100
4 | aug: DeepAA
5 | deepaa:
6 | EXP: cifar_2
7 | cutout: -1
8 | batch: 128
9 | gpus: 1
10 | augmentation_search_space: Not_used # fixed_standard
11 | epoch: 200
12 | lr: 0.1
13 | lr_schedule:
14 | type: 'cosine'
15 | warmup:
16 | multiplier: 1
17 | epoch: 5
18 | optimizer:
19 | type: sgd
20 | nesterov: True
21 | decay: 0.001
22 |
23 |
--------------------------------------------------------------------------------
/confs/wresnet28x10_cifar100_DeepAA_BatchAug8x_1.yaml:
--------------------------------------------------------------------------------
1 | all_workers_use_the_same_batches: true
2 | model:
3 | type: wresnet28_10
4 | dataset: cifar100
5 | aug: DeepAA
6 | deepaa:
7 | EXP: cifar_1
8 | cutout: -1
9 | batch: 128
10 | gpus: 8
11 | augmentation_search_space: Not_used
12 | epoch: 35
13 | lr: 0.4
14 | lr_schedule:
15 | type: 'cosine'
16 | warmup:
17 | multiplier: 1
18 | epoch: 5
19 | optimizer:
20 | type: sgd
21 | nesterov: True
22 | decay: 0.0005
23 |
24 |
--------------------------------------------------------------------------------
/confs/wresnet28x10_cifar100_DeepAA_BatchAug8x_2.yaml:
--------------------------------------------------------------------------------
1 | all_workers_use_the_same_batches: true
2 | model:
3 | type: wresnet28_10
4 | dataset: cifar100
5 | aug: DeepAA
6 | deepaa:
7 | EXP: cifar_2
8 | cutout: -1
9 | batch: 128
10 | gpus: 8
11 | augmentation_search_space: Not_used
12 | epoch: 35
13 | lr: 0.4
14 | lr_schedule:
15 | type: 'cosine'
16 | warmup:
17 | multiplier: 1
18 | epoch: 5
19 | optimizer:
20 | type: sgd
21 | nesterov: True
22 | decay: 0.0005
23 |
24 |
--------------------------------------------------------------------------------
/confs/wresnet28x10_cifar10_DeepAA_1.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: wresnet28_10
3 | dataset: cifar10
4 | aug: DeepAA
5 | deepaa:
6 | EXP: cifar_1
7 | cutout: -1
8 | batch: 128
9 | gpus: 1
10 | augmentation_search_space: Not_used # fixed_standard
11 | epoch: 200
12 | lr: 0.1
13 | lr_schedule:
14 | type: 'cosine'
15 | warmup:
16 | multiplier: 1
17 | epoch: 5
18 | optimizer:
19 | type: sgd
20 | nesterov: True
21 | decay: 0.0005
22 |
23 |
--------------------------------------------------------------------------------
/confs/wresnet28x10_cifar10_DeepAA_1_wd1e-3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: wresnet28_10
3 | dataset: cifar10
4 | aug: DeepAA
5 | deepaa:
6 | EXP: cifar_1
7 | cutout: -1
8 | batch: 128
9 | gpus: 1
10 | augmentation_search_space: Not_used # fixed_standard
11 | epoch: 200
12 | lr: 0.1
13 | lr_schedule:
14 | type: 'cosine'
15 | warmup:
16 | multiplier: 1
17 | epoch: 5
18 | optimizer:
19 | type: sgd
20 | nesterov: True
21 | decay: 0.001
22 |
23 |
--------------------------------------------------------------------------------
/confs/wresnet28x10_cifar10_DeepAA_2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: wresnet28_10
3 | dataset: cifar10
4 | aug: DeepAA
5 | deepaa:
6 | EXP: cifar_2
7 | cutout: -1
8 | batch: 128
9 | gpus: 1
10 | augmentation_search_space: Not_used # fixed_standard
11 | epoch: 200
12 | lr: 0.1
13 | lr_schedule:
14 | type: 'cosine'
15 | warmup:
16 | multiplier: 1
17 | epoch: 5
18 | optimizer:
19 | type: sgd
20 | nesterov: True
21 | decay: 0.0005
22 |
23 |
--------------------------------------------------------------------------------
/confs/wresnet28x10_cifar10_DeepAA_2_wd1e-3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | type: wresnet28_10
3 | dataset: cifar10
4 | aug: DeepAA
5 | deepaa:
6 | EXP: cifar_2
7 | cutout: -1
8 | batch: 128
9 | gpus: 1
10 | augmentation_search_space: Not_used # fixed_standard
11 | epoch: 200
12 | lr: 0.1
13 | lr_schedule:
14 | type: 'cosine'
15 | warmup:
16 | multiplier: 1
17 | epoch: 5
18 | optimizer:
19 | type: sgd
20 | nesterov: True
21 | decay: 0.001
22 |
23 |
--------------------------------------------------------------------------------
/confs/wresnet28x10_cifar10_DeepAA_BatchAug8x_1.yaml:
--------------------------------------------------------------------------------
1 | all_workers_use_the_same_batches: true
2 | model:
3 | type: wresnet28_10
4 | dataset: cifar10
5 | aug: DeepAA
6 | deepaa:
7 | EXP: cifar_1
8 | cutout: -1
9 | batch: 128
10 | gpus: 8
11 | augmentation_search_space: Not_used
12 | epoch: 100
13 | lr: 0.2
14 | lr_schedule:
15 | type: 'cosine'
16 | warmup:
17 | multiplier: 1
18 | epoch: 5
19 | optimizer:
20 | type: sgd
21 | nesterov: True
22 | decay: 0.0005
23 |
24 |
--------------------------------------------------------------------------------
/confs/wresnet28x10_cifar10_DeepAA_BatchAug8x_2.yaml:
--------------------------------------------------------------------------------
1 | all_workers_use_the_same_batches: true
2 | model:
3 | type: wresnet28_10
4 | dataset: cifar10
5 | aug: DeepAA
6 | deepaa:
7 | EXP: cifar_2
8 | cutout: -1
9 | batch: 128
10 | gpus: 8
11 | augmentation_search_space: Not_used
12 | epoch: 100
13 | lr: 0.2
14 | lr_schedule:
15 | type: 'cosine'
16 | warmup:
17 | multiplier: 1
18 | epoch: 5
19 | optimizer:
20 | type: sgd
21 | nesterov: True
22 | decay: 0.0005
23 |
24 |
--------------------------------------------------------------------------------
/data_generator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import logging
4 | import numpy as np
5 | import math
6 | from PIL import Image
7 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
8 | import tensorflow as tf
9 | tf.get_logger().setLevel(logging.ERROR)
10 | from tensorflow.keras.utils import Sequence
11 | from augmentation import IMAGENET_SIZE, centerCrop_imagenet
12 |
13 |
14 | CIFAR_MEANS = np.array([0.49139968, 0.48215841, 0.44653091], dtype=np.float32)
15 | CIFAR_STDS = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)
16 |
17 | IMAGENET_MEANS = np.array([0.485, 0.456, 0.406], dtype=np.float32)
18 | IMAGENET_STDS = np.array([0.229, 0.224, 0.225], dtype=np.float32)
19 |
20 | def split_train_validation(x, y, val_size):
21 | indices = np.arange(len(x))
22 | np.random.shuffle(indices)
23 | x_train, x_val, y_train, y_val = x[:-val_size], x[-val_size:], y[:-val_size], y[-val_size:]
24 | return x_train, y_train, x_val, y_val
25 |
26 | def get_cifar100_data(num_classes=100, val_size=10000):
27 | (x_train_val, y_train_val), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
28 | y_train_val = y_train_val.squeeze()
29 | y_test = y_test.squeeze()
30 | if val_size > 0:
31 | x_train, y_train, x_val, y_val = split_train_validation(x_train_val, y_train_val, val_size=val_size)
32 | else:
33 | x_train, y_train = x_train_val, y_train_val
34 | x_val, y_val = None, None
35 | return x_train, y_train, x_val, y_val, x_test, y_test
36 |
37 | def get_cifar10_data(num_classes=10, val_size=10000):
38 | (x_train_val, y_train_val), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
39 | y_train_val = y_train_val.squeeze()
40 | y_test = y_test.squeeze()
41 | if val_size > 0:
42 | x_train, y_train, x_val, y_val = split_train_validation(x_train_val, y_train_val, val_size=val_size)
43 | else:
44 | x_train, y_train = x_train_val, y_train_val
45 | x_val, y_val = None, None
46 | return x_train, y_train, x_val, y_val, x_test, y_test
47 |
48 |
49 | class DataGenerator(Sequence):
50 | def __init__(self,
51 | data,
52 | labels,
53 | img_dim=None,
54 | batch_size=32,
55 | num_classes=10,
56 | shuffle=True,
57 | drop_last=True,
58 | ):
59 |
60 | self._data = data
61 | self.data = self._data # initially without calling augment, the output data is not augmented
62 | self.labels = labels
63 | self.img_dim = img_dim
64 | self.batch_size = batch_size
65 | self.num_classes = num_classes
66 | self.shuffle = shuffle
67 | self.drop_last = drop_last
68 | self.on_epoch_end()
69 |
70 | def reset_augment(self):
71 | self.data = self._data
72 |
73 | def on_epoch_end(self):
74 | self.indices = np.arange(len(self._data))
75 | if self.shuffle:
76 | np.random.shuffle(self.indices)
77 |
78 | def sample_labeled_data_batch(self, label, bs):
79 |         # shuffle indices every time
80 | indices = np.arange(len(self._data))
81 | np.random.shuffle(indices)
82 | if isinstance(self.labels, list):
83 | labels = [self.labels[k] for k in indices]
84 | else:
85 | labels = self.labels[indices]
86 | matched_labels = np.array(labels) == int(label)
87 | matched_indices = [id for id, isMatched in enumerate(matched_labels) if isMatched]
88 | if len(matched_indices) - bs >=0:
89 |             start_idx = np.random.randint(0, len(matched_indices) - bs + 1)  # high is exclusive; allow the last valid window
90 | batch_indices = matched_indices[start_idx:start_idx + bs]
91 | else:
92 | print('Not enough matched data, required {}, but got {} instead'.format(bs, len(matched_indices)))
93 | batch_indices = matched_indices
94 | data_indices = indices[batch_indices]
95 | return [self.data[k] for k in data_indices], np.array([self.labels[k] for k in data_indices], dtype=self.labels[0].dtype)
96 |
97 | def __len__(self):
98 | if self.drop_last:
99 | return int(np.floor(len(self.data) / self.batch_size)) # drop the last batch
100 | else:
101 |             return int(np.ceil(len(self.data) / self.batch_size)) # keep the last, possibly smaller, batch
102 |
103 | def __getitem__(self, idx):
104 | curr_batch = self.indices[idx*self.batch_size:(idx+1)*self.batch_size]
105 | batch_len = len(curr_batch)
106 | if isinstance(self.data, list) and isinstance(self.labels, list):
107 | return [self.data[k] for k in curr_batch], np.array([self.labels[k] for k in curr_batch], np.int32)
108 | else:
109 | return self.data[curr_batch], self.labels[curr_batch]
110 |
111 | class DataAugmentation(object):
112 | def __init__(self, num_classes, dataset, image_shape, ops_list=None, default_pre_aug=None, default_post_aug=None):
113 | self.ops, self.op_names = ops_list
114 | self.default_pre_aug = default_pre_aug
115 | self.default_post_aug = default_post_aug
116 | self.num_classes = num_classes
117 | self.dataset = dataset
118 | self.image_shape = image_shape
119 | if 'imagenet' in self.dataset:
120 | assert self.image_shape == (*IMAGENET_SIZE, 3)
121 | elif 'cifar' in self.dataset:
122 | assert self.image_shape == (32, 32, 3)
123 | else:
124 | raise Exception('Unrecognized dataset')
125 |
126 | def sequantially_augment(self, args):
127 | idx, img_, op_idxs, mags, aug_finish = args
128 |         assert img_.dtype == np.uint8, 'Input images should be unprocessed and stay in np.uint8'
129 | img = copy.deepcopy(img_)
130 | pil_img = Image.fromarray(img) # Convert to PIL.Image
131 | if self.default_pre_aug is not None:
132 | for op in self.default_pre_aug:
133 | pil_img = op(pil_img)
134 | if self.ops is not None:
135 | for op_idx, mag in zip(op_idxs, mags):
136 | op, minval, maxval = self.ops[op_idx]
137 | assert mag > -1e-5 and mag < 1. + 1e-5, 'magnitudes should be in the range of (0., 1.)'
138 | mag = mag * (maxval - minval) + minval
139 | pil_img = op(pil_img, mag)
140 | if self.default_post_aug is not None and self.use_post_aug:
141 | for op in self.default_post_aug:
142 | pil_img = op(pil_img, None)
143 | if 'cifar' in self.dataset:
144 | img = np.asarray(pil_img, dtype=np.uint8)
145 | return idx, img
146 | elif 'imagenet' in self.dataset:
147 | if aug_finish:
148 | pil_img = self.crop_IMAGENET(pil_img)
149 | img = np.asarray(pil_img, dtype=np.uint8)
150 | return idx, img
151 | else:
152 | raise Exception
153 |
154 | def postprocessing_standardization(self, pil_img):
155 | x = np.asarray(pil_img, dtype=np.float32) / 255.
156 | if 'cifar' in self.dataset:
157 | x = (x - CIFAR_MEANS) / CIFAR_STDS
158 | elif 'imagenet' in self.dataset:
159 | x = (x - IMAGENET_MEANS) / IMAGENET_STDS
160 | else:
161 |             raise Exception('Unrecognized dataset')
162 | return x
163 |
164 | def crop_IMAGENET(self, img):
165 | # cropping imagenet dataset to the same size
166 | if isinstance(img, np.ndarray):
167 |             assert img.shape == (IMAGENET_SIZE[1], IMAGENET_SIZE[0], 3) and img.dtype==np.uint8, 'numpy array input should have shape {} and dtype uint8, but got shape {} and dtype {}'.format((IMAGENET_SIZE[1], IMAGENET_SIZE[0], 3), img.shape, img.dtype)
168 | return img
169 | w, h = img.size
170 | if w == IMAGENET_SIZE[0] and h == IMAGENET_SIZE[1]:
171 | return img
172 | return centerCrop_imagenet(img, None)
173 |
174 | def check_data_type(self, images, labels):
175 | assert images[0].dtype == np.uint8
176 | if 'imagenet' in self.dataset:
177 | assert type(labels[0]) == np.int32
178 | elif 'cifar' in self.dataset:
179 | assert type(labels[0]) == np.uint8
180 | else:
181 | raise Exception('Unrecognized dataset')
182 |
183 | def __call__(self, images, labels, samples_op, samples_mag, use_post_aug, pool=None, chunksize=None, aug_finish=True):
184 | self.check_data_type(images, labels)
185 |
186 | self.use_post_aug = use_post_aug
187 | self.batch_len = len(labels)
188 | if aug_finish:
189 | aug_imgs = np.empty([self.batch_len, *self.image_shape], dtype=np.float32)
190 | else:
191 | aug_imgs = [None]*self.batch_len
192 | aug_results = pool.imap_unordered(self.sequantially_augment,
193 | zip(range(self.batch_len), images, samples_op, samples_mag, [aug_finish]*self.batch_len),
194 | chunksize=math.ceil(float(self.batch_len) / float(pool._processes)) if chunksize is None else chunksize)
195 | for idx, img in aug_results:
196 | aug_imgs[idx] = img
197 |
198 | if aug_finish:
199 | aug_imgs = self.postprocessing_standardization(aug_imgs)
200 |
201 | return aug_imgs, labels
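
# Usage sketch (comment only, not from the repo): DataAugmentation instances are called with a
# multiprocessing pool; samples_op / samples_mag hold one sequence of operation indices and
# normalized magnitudes in [0, 1] per image, e.g.
#
#   import multiprocessing
#   with multiprocessing.Pool(8) as pool:
#       aug_imgs, labs = augmentation(images, labels, samples_op, samples_mag,
#                                     use_post_aug=True, pool=pool)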
--------------------------------------------------------------------------------
/imagenet_data_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | from torchvision.datasets.imagenet import *
4 | from torch import randperm, default_generator
5 | from torch._utils import _accumulate
6 | from torch.utils.data.dataset import Subset
7 |
8 |
9 | _DATA_TYPE = tf.float32
10 |
11 | CMYK_IMAGES = [
12 | 'n01739381_1309.JPEG',
13 | 'n02077923_14822.JPEG',
14 | 'n02447366_23489.JPEG',
15 | 'n02492035_15739.JPEG',
16 | 'n02747177_10752.JPEG',
17 | 'n03018349_4028.JPEG',
18 | 'n03062245_4620.JPEG',
19 | 'n03347037_9675.JPEG',
20 | 'n03467068_12171.JPEG',
21 | 'n03529860_11437.JPEG',
22 | 'n03544143_17228.JPEG',
23 | 'n03633091_5218.JPEG',
24 | 'n03710637_5125.JPEG',
25 | 'n03961711_5286.JPEG',
26 | 'n04033995_2932.JPEG',
27 | 'n04258138_17003.JPEG',
28 | 'n04264628_27969.JPEG',
29 | 'n04336792_7448.JPEG',
30 | 'n04371774_5854.JPEG',
31 | 'n04596742_4225.JPEG',
32 | 'n07583066_647.JPEG',
33 | 'n13037406_4650.JPEG',
34 | ]
35 |
36 | PNG_IMAGES = ['n02105855_2933.JPEG']
37 |
38 | class ImageNet(ImageFolder):
39 |     """`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
40 |     Copied from torchvision, apart from the warning below.
41 |
42 | Args:
43 | root (string): Root directory of the ImageNet Dataset.
44 | split (string, optional): The dataset split, supports ``train``, or ``val``.
45 | transform (callable, optional): A function/transform that takes in an PIL image
46 | and returns a transformed version. E.g, ``transforms.RandomCrop``
47 | target_transform (callable, optional): A function/transform that takes in the
48 | target and transforms it.
49 | loader (callable, optional): A function to load an image given its path.
50 |
51 | Attributes:
52 | classes (list): List of the class name tuples.
53 | class_to_idx (dict): Dict with items (class_name, class_index).
54 | wnids (list): List of the WordNet IDs.
55 | wnid_to_idx (dict): Dict with items (wordnet_id, class_index).
56 | imgs (list): List of (image path, class_index) tuples
57 | targets (list): The class_index value for each image in the dataset
58 |
59 | WARN::
60 | This is the same ImageNet class as in torchvision.datasets.imagenet, but it has the `ignore_archive` argument.
61 | This allows us to only copy the unzipped files before training.
62 | """
63 |
64 | def __init__(self, root, split='train', download=None, ignore_archive=False, **kwargs):
65 | if download is True:
66 | msg = ("The dataset is no longer publicly accessible. You need to "
67 | "download the archives externally and place them in the root "
68 | "directory.")
69 | raise RuntimeError(msg)
70 | elif download is False:
71 | msg = ("The use of the download flag is deprecated, since the dataset "
72 | "is no longer publicly accessible.")
73 | warnings.warn(msg, RuntimeWarning)
74 |
75 | root = self.root = os.path.expanduser(root)
76 | self.split = verify_str_arg(split, "split", ("train", "val"))
77 |
78 | if not ignore_archive:
79 | self.parse_archives()
80 | wnid_to_classes = load_meta_file(self.root)[0]
81 |
82 | super(ImageNet, self).__init__(self.split_folder, **kwargs)
83 | self.root = root
84 |
85 | self.wnids = self.classes
86 | self.wnid_to_idx = self.class_to_idx
87 | self.classes = [wnid_to_classes[wnid] for wnid in self.wnids]
88 | self.class_to_idx = {cls: idx
89 | for idx, clss in enumerate(self.classes)
90 | for cls in clss}
91 |
92 | def parse_archives(self):
93 | if not check_integrity(os.path.join(self.root, META_FILE)):
94 | parse_devkit_archive(self.root)
95 |
96 | if not os.path.isdir(self.split_folder):
97 | if self.split == 'train':
98 | parse_train_archive(self.root)
99 | elif self.split == 'val':
100 | parse_val_archive(self.root)
101 |
102 | @property
103 | def split_folder(self):
104 | return os.path.join(self.root, self.split)
105 |
106 | def extra_repr(self):
107 | return "Split: {split}".format(**self.__dict__)
108 |
109 | class ImageNet_DeepAA(ImageNet):
110 | def __init__(self, root, split='train', download=None, **kwargs):
111 | super(ImageNet_DeepAA, self).__init__(root, split=split, download=download, ignore_archive=True, **kwargs)
112 | _, self.labels_ = zip(*self.samples)
113 |
114 | def on_epoch_end(self):
115 |         print('Dummy on_epoch_end for ImageNet dataset using torchvision')
116 | pass
117 |
118 | def sample_labeled_data_batch(self, label, val_bs): # generate val and train batch at the same time
119 | matched_indices = [id for id, lab in enumerate(self.labels_) if lab==label]
120 | matched_indices = np.array(matched_indices)
121 |         assert len(matched_indices) > val_bs, 'Make sure there is enough data in this class'
122 | np.random.shuffle(matched_indices)
123 | val_indices = matched_indices[:val_bs]
124 |
125 | val_samples, val_labels = zip(*[self[id] for id in val_indices])
126 | val_samples = list(val_samples)
127 | val_labels = np.array(val_labels, dtype=np.int32)
128 |
129 | return val_samples, val_labels
130 |
131 | class Subset_ImageNet(Subset):
132 | def __init__(self, dataset, indices):
133 | super(Subset_ImageNet, self).__init__(dataset, indices)
134 | self.subset_labels_ = [self.dataset.labels_[k] for k in indices]
135 |
136 |
137 | def on_epoch_end(self):
138 | pass
139 |
140 | def sample_labeled_data_batch(self, label, val_bs):
141 | matched_indices = [self.indices[id] for id, lab in enumerate(self.subset_labels_) if lab == label]
142 | matched_indices = np.array(matched_indices)
143 |         assert len(matched_indices) > val_bs, 'Make sure there is enough data in this class'
144 | np.random.shuffle(matched_indices)
145 | val_indices = matched_indices[:val_bs]
146 |
147 | val_samples, val_labels = zip(*[self.dataset[id] for id in val_indices]) # applies transforms
148 | val_samples = list(val_samples)
149 | val_labels = np.array(val_labels, dtype=np.int32)
150 |
151 | return val_samples, val_labels
152 |
153 | def random_split_ImageNet(dataset, lengths, generator=default_generator):
154 | if sum(lengths) != len(dataset):
155 | raise ValueError('Sum of input lengths does not equal the length of the input dataset')
156 | indices = randperm(sum(lengths), generator=generator).tolist()
157 | return [Subset_ImageNet(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]
158 |
159 | def get_imagenet_split(val_size=400000, train_sep_size=100000, dataroot='./data', n_GPU=None, seed=300):
160 | transform = lambda img: np.array(img, dtype=np.uint8)
161 | total_trainset = ImageNet_DeepAA(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform)
162 | testset = ImageNet_DeepAA(root=os.path.join(dataroot, 'imagenet-pytorch'), split='val', transform=transform)
163 |
164 | N_per_shard = (len(total_trainset) - val_size - train_sep_size)//n_GPU
165 | remaining_data = len(total_trainset) - val_size - train_sep_size - n_GPU * N_per_shard
166 | if remaining_data > 0:
167 | splits = [val_size, train_sep_size, *[N_per_shard]*n_GPU, remaining_data]
168 | else:
169 | splits = [val_size, train_sep_size, *[N_per_shard]*n_GPU]
170 | all_ds = random_split_ImageNet(total_trainset,
171 | lengths=splits,
172 | generator=torch.Generator().manual_seed(seed))
173 | val_ds = all_ds[0]
174 | train_ds_sep = all_ds[1]
175 | pretrain_ds_splits = all_ds[2:2+n_GPU]
176 | return total_trainset, val_ds, train_ds_sep, pretrain_ds_splits, testset
--------------------------------------------------------------------------------
/images/DeepAA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/images/DeepAA.png
--------------------------------------------------------------------------------
/images/DeepAA_slideslive.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/images/DeepAA_slideslive.png
--------------------------------------------------------------------------------
/images/magnitude_distribution_cifar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/images/magnitude_distribution_cifar.png
--------------------------------------------------------------------------------
/images/magnitude_distribution_imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/images/magnitude_distribution_imagenet.png
--------------------------------------------------------------------------------
/images/operation_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/images/operation_distribution.png
--------------------------------------------------------------------------------
/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras.optimizers.schedules import LearningRateSchedule
3 | from tensorflow.python.framework import ops
4 | from tensorflow.python.ops import math_ops, control_flow_ops
5 |
6 | class GradualWarmup_Cosine_Scheduler(LearningRateSchedule):
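    # Learning-rate schedule implemented in __call__ below:
    #   warmup (step < warmup_steps):
    #       lr = starting_lr + (initial_lr - starting_lr) * step / warmup_steps
    #   cosine annealing (step >= warmup_steps):
    #       lr = ending_lr + 0.5 * (initial_lr - ending_lr)
    #                      * (1 + cos(pi * (step - warmup_steps) / (total_steps - warmup_steps)))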
7 | def __init__(self, starting_lr, initial_lr, ending_lr, warmup_steps, total_steps, name=None):
8 | super(GradualWarmup_Cosine_Scheduler, self).__init__()
9 |
10 | self.starting_lr = starting_lr
11 | self.initial_lr = initial_lr
12 | self.ending_lr = ending_lr
13 | self.warmup_steps = warmup_steps
14 | self.total_steps = total_steps
15 | self.name = name
16 |
17 | def __call__(self, step):
18 | with ops.name_scope_v2(self.name or 'GradualWarmup_Cosine') as name:
19 | initial_lr = ops.convert_to_tensor_v2(self.initial_lr, name='initial_learning_rate')
20 | dtype = initial_lr.dtype
21 | starting_lr = math_ops.cast(self.starting_lr, dtype)
22 | ending_lr = math_ops.cast(self.ending_lr, dtype)
23 | warmup_steps = math_ops.cast(self.warmup_steps, dtype)
24 | total_steps = math_ops.cast(self.total_steps, dtype)
25 | one = math_ops.cast(1.0, dtype)
26 | point5 = math_ops.cast(0.5, dtype)
27 | pi = math_ops.cast(3.1415926536, dtype)
28 | step = math_ops.cast(step, dtype)
29 |
30 | lr = tf.cond(step < warmup_steps,
31 | true_fn=lambda: self._warmup_schedule(starting_lr, initial_lr, step, warmup_steps),
32 | false_fn=lambda: self._cosine_annealing_schedule(initial_lr, ending_lr, step, warmup_steps, total_steps, pi,
33 | point5, one))
34 | return lr
35 |
36 | def _warmup_schedule(self, starting_lr, initial_lr, step, warmup_steps):
37 | ratio = math_ops.divide(step, warmup_steps)
38 | lr = math_ops.add(starting_lr,
39 | math_ops.multiply(initial_lr - starting_lr, ratio))
40 | return lr
41 |
42 | def _cosine_annealing_schedule(self, initial_lr, ending_lr, step, warmup_steps, total_steps, pi, point5, one):
43 | ratio = math_ops.divide(step - warmup_steps, total_steps - warmup_steps)
44 | cosine_ratio_pi = math_ops.cos(math_ops.multiply(ratio, pi))
45 | second_part = math_ops.multiply(point5,
46 | math_ops.multiply(initial_lr - ending_lr,
47 | one + cosine_ratio_pi))
48 | lr = math_ops.add(ending_lr, second_part)
49 | return lr
50 |
51 |
52 | def get_config(self):
53 | return {
54 | 'starting_lr': self.starting_lr,
55 | 'initial_lr': self.initial_lr,
56 | 'ending_lr': self.ending_lr,
57 | 'warmup_steps': self.warmup_steps,
58 | 'total_steps': self.total_steps,
59 | 'name': self.name
60 | }
--------------------------------------------------------------------------------
/policy.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import math
4 | import json
5 |
6 | from tensorflow_probability import distributions as tfd
7 |
8 | from resnet import Resnet
9 |
10 | CIFAR_MEANS = np.array([0.49139968, 0.48215841, 0.44653091], dtype=np.float32)
11 | CIFAR_STDS = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)
12 |
13 | SVHN_MEANS = np.array([0.4379, 0.4440, 0.4729], dtype=np.float32)
14 | SVHN_STDS = np.array([0.1980, 0.2010, 0.1970], dtype=np.float32)
15 |
16 | IMAGENET_MEANS = np.array([0.485, 0.456, 0.406], dtype=np.float32)
17 | IMAGENET_STDS = np.array([0.229, 0.224, 0.225], dtype=np.float32)
18 |
19 | class DA_Policy_logits(tf.keras.Model):
20 | def __init__(self, l_ops, l_mags, l_uniq, op_names, ops_mid_magnitude,
21 | N_repeat_random, available_policies, policy_init='identity'):
22 | super().__init__()
23 | self.l_uniq = l_uniq
24 | self.l_ops = l_ops
25 | self.l_mags = l_mags
26 | self.N_repeat_random = N_repeat_random
27 | self.available_policies = available_policies
28 |
29 | if policy_init == 'uniform':
30 | init_value = tf.constant([0.0]*len(available_policies), dtype=tf.float32)
31 | elif policy_init == 'identity':
32 | init_value = tf.constant([8.0] + [0.0]*(len(available_policies)-1), dtype=tf.float32)
33 | init_value = init_value - tf.reduce_mean(init_value)
34 | else:
35 | raise Exception
36 | self.logits = tf.Variable(initial_value=init_value, trainable=True)
37 |
38 | self.ops_mid_magnitude = ops_mid_magnitude
39 | self.unique_policy = self._get_unique_policy(op_names, l_ops, l_mags)
40 | self.N_random, self.repeat_cfg, self.reduce_random_mat = self._get_repeat_random(op_names, l_ops, l_mags,
41 | l_uniq, N_repeat_random)
42 | self.act = tf.nn.softmax
43 |
44 | def sample(self, images_orig, images, onehot_ops_mags, augNum):
45 | bs = len(images_orig)
46 | probs = self.act(self.logits, axis=-1)
47 | dist = tfd.Categorical(probs=probs)
48 | samples_om = dist.sample(augNum*bs).numpy() # (augNum, bs)
49 |
50 | ops_dense, mags_dense, reduce_random_mat, ops_mags_idx, probs, probs_exp = self.get_dense_aug(images, repeat_random_ops=False)
51 | ops = ops_dense[samples_om]
52 | mags = mags_dense[samples_om]
53 | ops_mags_idx_sample = ops_mags_idx[samples_om]
54 | probs_sample = probs.numpy()[samples_om]
55 |
56 | return ops, mags, ops_mags_idx_sample, probs_sample
57 |
58 | def probs(self, images_orig, images, onehot_ops_mags, training):
59 | bs = len(images_orig)
60 | probs = self.act(self.logits, axis=-1)
61 | probs = tf.repeat(probs[tf.newaxis], bs, axis=0)
62 | return probs
63 |
64 | def get_dense_aug(self, images, repeat_random_ops):
65 | ops_uniq, mags_uniq = self.unique_policy
66 | ops_dense = np.squeeze(ops_uniq)[self.available_policies]
67 | mags_dense = np.squeeze(mags_uniq)[self.available_policies]
68 | ops_mags_idx = self.available_policies
69 | if repeat_random_ops:
70 | isRepeat = [np.any(np.array(ops_dense == repeat_op_idx), axis=1) for repeat_op_idx in self.repeat_ops_idx]
71 | isRepeat = np.stack(isRepeat, axis=1)
72 | isRepeat = np.any(isRepeat, axis=1)
73 | nRepeat = [self.N_repeat_random if isrepeat else 1 for isrepeat in isRepeat]
74 |
75 | ops_dense = np.repeat(ops_dense, nRepeat, axis=0)
76 | mags_dense = np.repeat(mags_dense, nRepeat, axis=0)
77 | reduce_random_mat = np.eye(len(self.available_policies)) / np.array(nRepeat, dtype=np.float32)
78 | reduce_random_mat = np.repeat(reduce_random_mat, nRepeat, axis=1)
79 | else:
80 | nRepeat = [1] * len(self.available_policies)
81 | reduce_random_mat = np.eye(len(self.available_policies))
82 |
83 | probs = self.act(self.logits)
84 | probs_exp = np.repeat(probs/np.array(nRepeat, dtype=np.float32), nRepeat, axis=0)
85 | return ops_dense, mags_dense, reduce_random_mat, ops_mags_idx, probs, probs_exp
86 |
87 | def _get_unique_policy(self, op_names, l_ops, l_mags):
88 | names_modified = [op_name.split(':')[0] for op_name in op_names]
89 | ops_list, mags_list = [], []
90 | repeat_ops_idx = []
91 | for k_name, name in enumerate(names_modified):
92 | if self.ops_mid_magnitude[name] == 'random':
93 | repeat_ops_idx.append(k_name)
94 | ops_sub, mags_sub = np.array([[k_name]], dtype=np.int32), np.array([[(l_mags - 1) // 2]], dtype=np.int32)
95 | elif self.ops_mid_magnitude[name] is not None and self.ops_mid_magnitude[name]>=0 and self.ops_mid_magnitude[name]<=l_mags-1:
96 | ops_sub = k_name * np.ones([l_mags - 1, 1], dtype=np.int32)
97 | mags_sub = np.array([l for l in range(l_mags) if l != self.ops_mid_magnitude[name]], dtype=np.int32)[:, np.newaxis]
98 |             elif self.ops_mid_magnitude[name] is not None and self.ops_mid_magnitude[name] < 0:  # i.e. mid magnitude == -1: every magnitude level is an independent policy
99 | ops_sub = k_name * np.ones([l_mags, 1], dtype=np.int32)
100 | mags_sub = np.arange(l_mags, dtype=np.int32)[:, np.newaxis]
101 | elif self.ops_mid_magnitude[name] is None:
102 | ops_sub, mags_sub = np.array([[k_name]], dtype=np.int32), np.array([[(l_mags - 1) // 2]], dtype=np.int32)
103 | else:
104 | raise Exception('Unrecognized middle magnitude')
105 | ops_list.append(ops_sub)
106 | mags_list.append(mags_sub)
107 | ops = np.concatenate(ops_list, axis=0)
108 | mags = np.concatenate(mags_list, axis=0)
109 | self.repeat_ops_idx = repeat_ops_idx
110 | return ops.astype(np.int32), mags.astype(np.int32)
111 |
112 | def _get_repeat_random(self, op_names, l_ops, l_mags, l_uniq, N_repeat_random):
113 | names_modified = [op_name.split(':')[0] for op_name in op_names]
114 | N_random = sum([1 for name in names_modified if self.ops_mid_magnitude[name]=='random'])
115 | repeat_cfg = []
116 | for k_name, name in enumerate(names_modified):
117 | if self.ops_mid_magnitude[name] == 'random':
118 |                 repeat_cfg.append(N_repeat_random) # random operations may be repeated N_repeat_random times
119 |             elif self.ops_mid_magnitude[name] is not None and self.ops_mid_magnitude[name] == -1:
120 |                 repeat_cfg.extend([1]*l_mags)  # all l_mags levels are independent entries
121 | elif self.ops_mid_magnitude[name] is not None and self.ops_mid_magnitude[name] >= 0 and self.ops_mid_magnitude[name]<=l_mags-1:
122 | repeat_cfg.extend([1]*(l_mags-1))
123 | elif self.ops_mid_magnitude[name] is None:
124 | repeat_cfg.append(1)
125 | else:
126 | raise Exception
127 | repeat_cfg = np.array(repeat_cfg, dtype=np.int32)
128 |
129 |         reduce_mat = np.eye(l_uniq)/repeat_cfg[np.newaxis].astype(np.float32)  # np.float is deprecated in recent NumPy
130 | reduce_mat = np.repeat(reduce_mat, repeat_cfg, axis=1)
131 | return N_random, repeat_cfg, reduce_mat
132 |
133 | @property
134 | def idx_removed_redundant(self):
135 | idx_removed_redundant = np.concatenate([[1] if rep == 1 else [1]+[0]*(rep-1) for rep in self.repeat_cfg ]).nonzero()[0]
136 | assert len(idx_removed_redundant) == self.l_uniq, 'removing the repeated random operations'
137 | return idx_removed_redundant
--------------------------------------------------------------------------------
/policy_port/policy_DeepAA_cifar_1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/policy_port/policy_DeepAA_cifar_1.npz
--------------------------------------------------------------------------------
/policy_port/policy_DeepAA_cifar_2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/policy_port/policy_DeepAA_cifar_2.npz
--------------------------------------------------------------------------------
/policy_port/policy_DeepAA_imagenet_1.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/policy_port/policy_DeepAA_imagenet_1.npz
--------------------------------------------------------------------------------
/policy_port/policy_DeepAA_imagenet_2.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/DeepAA/7a1b94fa930b392bddff17c8d5f6a9b8c8e44a7b/policy_port/policy_DeepAA_imagenet_2.npz
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/wbaek/theconf
2 | git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
3 | git+https://github.com/ildoonet/pystopwatch2.git
4 |
5 | keras==2.4.0
6 | tensorflow-datasets==4.3.0
7 | tensorflow-probability==0.13.0
8 | matplotlib
9 | seaborn
10 | pandas
11 | packaging
12 |
13 | colored
14 | pretrainedmodels
15 | tqdm
16 | tensorboardx
17 | scikit-learn
18 | psutil
19 | requests
20 | Pillow
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | # ref: https://github.com/gahaalt/resnets-in-tensorflow2/blob/master/Models/Resnets.py
4 | _bn_momentum = 0.9
5 |
6 | def regularized_padded_conv(*args, **kwargs):
7 | return tf.keras.layers.Conv2D(*args, **kwargs, padding='same', kernel_regularizer=_regularizer, bias_regularizer=_regularizer,
8 | kernel_initializer='he_normal', use_bias=True)
9 |
10 |
11 | def bn_relu(x):
12 | x = tf.keras.layers.experimental.SyncBatchNormalization(momentum=_bn_momentum)(x)
13 | return tf.keras.layers.ReLU()(x)
14 |
15 |
16 | def shortcut(x, filters, stride, mode):
17 |     if x.shape[-1] == filters:  # identity shortcut; arguably this should also check stride == 1
18 | return x
19 | elif mode == 'B':
20 | return regularized_padded_conv(filters, 1, strides=stride)(x)
21 | elif mode == 'B_original':
22 | x = regularized_padded_conv(filters, 1, strides=stride)(x)
23 | return tf.keras.layers.experimental.SyncBatchNormalization(momentum=_bn_momentum)(x)
24 | elif mode == 'A':
25 | return tf.pad(tf.keras.layers.MaxPool2D(1, stride)(x) if stride > 1 else x,
26 | paddings=[(0, 0), (0, 0), (0, 0), (0, filters - x.shape[-1])])
27 | else:
28 | raise KeyError("Parameter shortcut_type not recognized!")
29 |
30 |
31 | def original_block(x, filters, stride=1, **kwargs):
32 | c1 = regularized_padded_conv(filters, 3, strides=stride)(x)
33 | c2 = regularized_padded_conv(filters, 3)(bn_relu(c1))
34 | c2 = tf.keras.layers.experimental.SyncBatchNormalization(momentum=_bn_momentum)(c2)
35 |
36 | mode = 'B_original' if _shortcut_type == 'B' else _shortcut_type
37 | x = shortcut(x, filters, stride, mode=mode)
38 | return tf.keras.layers.ReLU()(x + c2)
39 |
40 |
41 | def preactivation_block(x, filters, stride=1, preact_block=False):
42 | flow = bn_relu(x)
43 |
44 | c1 = regularized_padded_conv(filters, 3)(flow)
45 | if _dropout:
46 | c1 = tf.keras.layers.Dropout(_dropout)(c1)
47 |
48 | c2 = regularized_padded_conv(filters, 3, strides=stride)(bn_relu(c1))
49 | x = shortcut(x, filters, stride, mode=_shortcut_type)
50 | return x + c2
51 |
52 |
53 | def bootleneck_block(x, filters, stride=1, preact_block=False):
54 | flow = bn_relu(x)
55 | if preact_block:
56 | x = flow
57 |
58 | c1 = regularized_padded_conv(filters // _bootleneck_width, 1)(flow)
59 | c2 = regularized_padded_conv(filters // _bootleneck_width, 3, strides=stride)(bn_relu(c1))
60 | c3 = regularized_padded_conv(filters, 1)(bn_relu(c2))
61 | x = shortcut(x, filters, stride, mode=_shortcut_type)
62 | return x + c3
63 |
64 |
65 | def group_of_blocks(x, block_type, num_blocks, filters, stride, block_idx=0):
66 | global _preact_shortcuts
67 | preact_block = True if _preact_shortcuts or block_idx == 0 else False
68 |
69 | x = block_type(x, filters, stride, preact_block=preact_block)
70 | for i in range(num_blocks - 1):
71 | x = block_type(x, filters)
72 | return x
73 |
74 |
75 | def Resnet(input_shape, n_classes, l2_reg=1e-4, group_sizes=(2, 2, 2), features=(16, 32, 64), strides=(1, 2, 2),
76 | shortcut_type='B', block_type='preactivated', first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
77 | dropout=0, cardinality=1, bootleneck_width=4, preact_shortcuts=True,
78 | final_dense_kernel_initializer=None, final_dense_bias_initializer=None):
79 | global _regularizer, _shortcut_type, _preact_projection, _dropout, _cardinality, _bootleneck_width, _preact_shortcuts
80 | _bootleneck_width = bootleneck_width # used in ResNeXts and bootleneck blocks
81 | _regularizer = tf.keras.regularizers.l2(l2_reg)
82 | _shortcut_type = shortcut_type # used in blocks
83 | _cardinality = cardinality # used in ResNeXts
84 | _dropout = dropout # used in Wide ResNets
85 | _preact_shortcuts = preact_shortcuts
86 |
87 | block_types = {'preactivated': preactivation_block,
88 | 'bootleneck': bootleneck_block,
89 | 'original': original_block}
90 |
91 | selected_block = block_types[block_type]
92 | inputs = tf.keras.layers.Input(shape=input_shape)
93 | flow = regularized_padded_conv(**first_conv)(inputs)
94 |
95 | if block_type == 'original':
96 | flow = bn_relu(flow)
97 |
98 | for block_idx, (group_size, feature, stride) in enumerate(zip(group_sizes, features, strides)):
99 | flow = group_of_blocks(flow,
100 | block_type=selected_block,
101 | num_blocks=group_size,
102 | block_idx=block_idx,
103 | filters=feature,
104 | stride=stride)
105 |
106 | if block_type != 'original':
107 | flow = bn_relu(flow)
108 |
109 | flow = tf.keras.layers.GlobalAveragePooling2D()(flow)
110 |
111 | if final_dense_kernel_initializer is not None:
112 |         assert final_dense_bias_initializer is not None, 'final_dense_bias_initializer must also be provided when final_dense_kernel_initializer is set'
113 | outputs = tf.keras.layers.Dense(n_classes, kernel_regularizer=_regularizer,
114 | kernel_initializer=final_dense_kernel_initializer,
115 | bias_initializer=final_dense_bias_initializer)(flow)
116 | else:
117 | outputs = tf.keras.layers.Dense(n_classes, kernel_regularizer=_regularizer)(flow)
118 | model = tf.keras.Model(inputs=inputs, outputs=outputs)
119 | return model
120 |
121 |
122 | def load_weights_func(model, model_name):
123 | try:
124 | model.load_weights(os.path.join('saved_models', model_name + '.tf'))
125 | except tf.errors.NotFoundError:
126 | print("No weights found for this model!")
127 | return model
128 |
129 |
130 | def cifar_resnet20(block_type='original', shortcut_type='A', l2_reg=1e-4, load_weights=False, input_shape=None, n_classes=None):
131 | model = Resnet(input_shape=input_shape, n_classes=n_classes, l2_reg=l2_reg, group_sizes=(3, 3, 3), features=(16, 32, 64),
132 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
133 | shortcut_type=shortcut_type,
134 | block_type=block_type, preact_shortcuts=False)
135 | if load_weights: model = load_weights_func(model, 'cifar_resnet20')
136 | return model
137 |
138 |
139 | def cifar_resnet32(block_type='original', shortcut_type='A', l2_reg=1e-4, load_weights=False, input_shape=None):
140 | model = Resnet(input_shape=input_shape, n_classes=10, l2_reg=l2_reg, group_sizes=(5, 5, 5), features=(16, 32, 64),
141 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
142 | shortcut_type=shortcut_type,
143 | block_type=block_type, preact_shortcuts=False)
144 | if load_weights: model = load_weights_func(model, 'cifar_resnet32')
145 | return model
146 |
147 |
148 | def cifar_resnet44(block_type='original', shortcut_type='A', l2_reg=1e-4, load_weights=False, input_shape=None):
149 | model = Resnet(input_shape=input_shape, n_classes=10, l2_reg=l2_reg, group_sizes=(7, 7, 7), features=(16, 32, 64),
150 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
151 | shortcut_type=shortcut_type,
152 | block_type=block_type, preact_shortcuts=False)
153 | if load_weights: model = load_weights_func(model, 'cifar_resnet44')
154 | return model
155 |
156 |
157 | def cifar_resnet56(block_type='original', shortcut_type='A', l2_reg=1e-4, load_weights=False, input_shape=None):
158 | model = Resnet(input_shape=input_shape, n_classes=10, l2_reg=l2_reg, group_sizes=(9, 9, 9), features=(16, 32, 64),
159 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
160 | shortcut_type=shortcut_type,
161 | block_type=block_type, preact_shortcuts=False)
162 | if load_weights: model = load_weights_func(model, 'cifar_resnet56')
163 | return model
164 |
165 |
166 | def cifar_resnet110(block_type='preactivated', shortcut_type='B', l2_reg=1e-4, load_weights=False, input_shape=None):
167 | model = Resnet(input_shape=input_shape, n_classes=10, l2_reg=l2_reg, group_sizes=(18, 18, 18),
168 | features=(16, 32, 64),
169 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
170 | shortcut_type=shortcut_type,
171 | block_type=block_type, preact_shortcuts=False)
172 | if load_weights: model = load_weights_func(model, 'cifar_resnet110')
173 | return model
174 |
175 |
176 | def cifar_resnet164(shortcut_type='B', load_weights=False, l2_reg=1e-4, input_shape=None):
177 | model = Resnet(input_shape=input_shape, n_classes=10, l2_reg=l2_reg, group_sizes=(18, 18, 18),
178 | features=(64, 128, 256),
179 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
180 | shortcut_type=shortcut_type,
181 | block_type='bootleneck', preact_shortcuts=True)
182 | if load_weights: model = load_weights_func(model, 'cifar_resnet164')
183 | return model
184 |
185 |
186 | def cifar_resnet1001(shortcut_type='B', load_weights=False, l2_reg=1e-4, input_shape=None):
187 | model = Resnet(input_shape=input_shape, n_classes=10, l2_reg=l2_reg, group_sizes=(111, 111, 111),
188 | features=(64, 128, 256),
189 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
190 | shortcut_type=shortcut_type,
191 | block_type='bootleneck', preact_shortcuts=True)
192 | if load_weights: model = load_weights_func(model, 'cifar_resnet1001')
193 | return model
194 |
195 |
196 | def cifar_wide_resnet(N, K, block_type='preactivated', shortcut_type='B', dropout=0, l2_reg=2.5e-4, n_classes=None, preact_shortcuts=False, input_shape=None):
197 | assert (N - 4) % 6 == 0, "N-4 has to be divisible by 6"
198 | lpb = (N - 4) // 6 # layers per block - since N is total number of convolutional layers in Wide ResNet
199 | model = Resnet(input_shape=input_shape, n_classes=n_classes, l2_reg=l2_reg, group_sizes=(lpb, lpb, lpb),
200 | features=(16 * K, 32 * K, 64 * K),
201 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
202 | shortcut_type=shortcut_type,
203 | block_type=block_type, dropout=dropout, preact_shortcuts=preact_shortcuts)
204 | return model
205 |
206 |
207 | def cifar_WRN_16_4(shortcut_type='B', load_weights=False, dropout=0, l2_reg=2.5e-4, input_shape=None):
208 | model = cifar_wide_resnet(16, 4, 'preactivated', shortcut_type, dropout=dropout, l2_reg=l2_reg, input_shape=input_shape)
209 | if load_weights: model = load_weights_func(model, 'cifar_WRN_16_4')
210 | return model
211 |
212 |
213 | def cifar_WRN_40_4(shortcut_type='B', load_weights=False, dropout=0, l2_reg=2.5e-4, input_shape=None):
214 | model = cifar_wide_resnet(40, 4, 'preactivated', shortcut_type, dropout=dropout, l2_reg=l2_reg, input_shape=input_shape)
215 | if load_weights: model = load_weights_func(model, 'cifar_WRN_40_4')
216 | return model
217 |
218 |
219 | def cifar_WRN_16_8(shortcut_type='B', load_weights=False, dropout=0, l2_reg=2.5e-4, input_shape=None):
220 | model = cifar_wide_resnet(16, 8, 'preactivated', shortcut_type, dropout=dropout, l2_reg=l2_reg, input_shape=input_shape)
221 | if load_weights: model = load_weights_func(model, 'cifar_WRN_16_8')
222 | return model
223 |
224 |
225 | def cifar_WRN_28_10(shortcut_type='B', load_weights=False, dropout=0, l2_reg=2.5e-4, n_classes=None, preact_shortcuts=False, input_shape=None):
226 | model = cifar_wide_resnet(28, 10, 'preactivated', shortcut_type, dropout=dropout, l2_reg=l2_reg, n_classes = n_classes, preact_shortcuts=preact_shortcuts, input_shape=input_shape)
227 | return model
228 |
229 | def cifar_WRN_28_2(shortcut_type='B', load_weights=False, dropout=0, l2_reg=2.5e-4, n_classes=None, preact_shortcuts=False, input_shape=None):
230 | model = cifar_wide_resnet(28, 2, 'preactivated', shortcut_type, dropout=dropout, l2_reg=l2_reg, n_classes = n_classes, preact_shortcuts=preact_shortcuts, input_shape=input_shape)
231 | return model
232 |
233 |
234 | def cifar_WRN_40_2(shortcut_type='B', load_weights=False, dropout=0, l2_reg=2.5e-4, n_classes=None, preact_shortcuts=False, input_shape=None):
235 | model = cifar_wide_resnet(40, 2, 'preactivated', shortcut_type, dropout=dropout, l2_reg=l2_reg, n_classes = n_classes, preact_shortcuts=preact_shortcuts, input_shape=input_shape)
236 | return model
237 |
238 | def cifar_resnext(N, cardinality, width, shortcut_type='B'):
239 |     assert (N - 3) % 9 == 0, "N-3 has to be divisible by 9"
240 |     lpb = (N - 3) // 9 # layers per block - since N is the total number of convolutional layers in the ResNeXt
241 | model = Resnet(input_shape=(32, 32, 3), n_classes=10, l2_reg=1e-4, group_sizes=(lpb, lpb, lpb),
242 | features=(16 * width, 32 * width, 64 * width),
243 | strides=(1, 2, 2), first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
244 | shortcut_type=shortcut_type,
245 |                    block_type='resnext', cardinality=cardinality, width=width)  # NOTE: 'resnext' is not registered in block_types and Resnet() has no `width` argument, so this helper currently fails
246 | return model
247 |
248 |
249 | if __name__ == '__main__':
250 | model = cifar_WRN_28_10(dropout=0, l2_reg=5e-4/2., preact_shortcuts=False, n_classes=10)
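251 |
252 |     # A minimal smoke check (a sketch, not part of the original pipeline): build
253 |     # the same WideResNet-28-10 with an explicit CIFAR-10 input shape and print
254 |     # its size; the trainable parameter count should come out to roughly 36M.
255 |     demo_model = cifar_WRN_28_10(dropout=0, l2_reg=5e-4 / 2., preact_shortcuts=False,
256 |                                  n_classes=10, input_shape=(32, 32, 3))
257 |     print('WRN-28-10 parameters:', demo_model.count_params())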
--------------------------------------------------------------------------------
/resnet_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | # ref: https://github.com/gahaalt/resnets-in-tensorflow2/blob/master/Models/Resnets.py
4 | _bn_momentum = 0.9
5 |
6 | def regularized_padded_conv(*args, **kwargs):
7 | return tf.keras.layers.Conv2D(*args, **kwargs, padding='same', kernel_regularizer=_regularizer, bias_regularizer=_regularizer,
8 | kernel_initializer='he_normal', use_bias=False)
9 |
10 |
11 | def bn_relu(x, gamma_initializer='ones'):
12 | x = tf.keras.layers.experimental.SyncBatchNormalization(momentum=_bn_momentum, gamma_initializer=gamma_initializer)(x)
13 | return tf.keras.layers.ReLU()(x)
14 |
15 |
16 | def shortcut(x, filters, stride, mode):
17 |     if x.shape[-1] == filters:  # identity shortcut; arguably this should also check stride == 1
18 | return x
19 | elif mode == 'B':
20 | return regularized_padded_conv(filters, 1, strides=stride)(x)
21 | elif mode == 'B_original':
22 | x = regularized_padded_conv(filters, 1, strides=stride)(x)
23 | return tf.keras.layers.experimental.SyncBatchNormalization(momentum=_bn_momentum)(x)
24 | elif mode == 'A':
25 | return tf.pad(tf.keras.layers.MaxPool2D(1, stride)(x) if stride > 1 else x,
26 | paddings=[(0, 0), (0, 0), (0, 0), (0, filters - x.shape[-1])])
27 | else:
28 | raise KeyError("Parameter shortcut_type not recognized!")
29 |
30 |
31 | def original_block(x, filters, stride=1, **kwargs):
32 | c1 = regularized_padded_conv(filters, 3, strides=stride)(x)
33 | c2 = regularized_padded_conv(filters, 3)(bn_relu(c1))
34 | c2 = tf.keras.layers.experimental.SyncBatchNormalization(momentum=_bn_momentum)(c2)
35 |
36 | mode = 'B_original' if _shortcut_type == 'B' else _shortcut_type
37 | x = shortcut(x, filters, stride, mode=mode)
38 | return tf.keras.layers.ReLU()(x + c2)
39 |
40 |
41 | def bootleneck_block(x, filters, stride=1, preact_block=False): # preact_block==False
42 | # flow = bn_relu(x)
43 | # if preact_block:
44 | # x = flow
45 | residual = x
46 | c1 = regularized_padded_conv(filters // _bootleneck_width, 1)(bn_relu(x))
47 | c2 = regularized_padded_conv(filters // _bootleneck_width, 3, strides=stride)(bn_relu(c1))
48 | c3 = regularized_padded_conv(filters, 1)(bn_relu(c2))
49 | if x.shape[-1] != filters or stride != 1:
50 | residual = shortcut(x, filters, stride, mode=_shortcut_type)
51 | return tf.keras.layers.ReLU()(residual + tf.keras.layers.experimental.SyncBatchNormalization(momentum=_bn_momentum, gamma_initializer='zeros')(c3))
52 |
53 |
54 | def group_of_blocks(x, block_type, num_blocks, filters, stride, block_idx=0):
55 | global _preact_shortcuts
56 | preact_block = False
57 |
58 | x = block_type(x, filters, stride, preact_block=preact_block)
59 | for i in range(num_blocks - 1):
60 | x = block_type(x, filters)
61 | return x
62 |
63 |
64 | def Resnet(input_shape, n_classes, l2_reg=1e-4, group_sizes=(2, 2, 2), features=(16, 32, 64), strides=(1, 2, 2),
65 | shortcut_type='B', block_type='preactivated', first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
66 | dropout=0, cardinality=1, bootleneck_width=4, preact_shortcuts=True):
67 | global _regularizer, _shortcut_type, _preact_projection, _dropout, _cardinality, _bootleneck_width, _preact_shortcuts
68 | _bootleneck_width = bootleneck_width # used in ResNeXts and bootleneck blocks
69 | _regularizer = tf.keras.regularizers.l2(l2_reg)
70 | _shortcut_type = shortcut_type # used in blocks
71 | _cardinality = cardinality # used in ResNeXts
72 | _dropout = dropout # used in Wide ResNets
73 | _preact_shortcuts = preact_shortcuts
74 |
75 | block_types = {
76 | # 'preactivated': preactivation_block,
77 | 'bootleneck': bootleneck_block,
78 | 'original': original_block
79 | }
80 |
81 | selected_block = block_types[block_type]
82 | inputs = tf.keras.layers.Input(shape=input_shape)
83 | flow = regularized_padded_conv(**first_conv)(inputs)
84 |
85 | # if block_type == 'original':
86 | flow = bn_relu(flow)
87 | flow = tf.keras.layers.MaxPool2D(pool_size=(3,3), strides=2, padding='same')(flow)
88 |
89 | for block_idx, (group_size, feature, stride) in enumerate(zip(group_sizes, features, strides)):
90 | flow = group_of_blocks(flow,
91 | block_type=selected_block,
92 | num_blocks=group_size,
93 | block_idx=block_idx,
94 | filters=feature,
95 | stride=stride)
96 |
97 | # if block_type != 'original':
98 | # flow = bn_relu(flow)
99 |
100 | flow = tf.keras.layers.GlobalAveragePooling2D()(flow)
101 |
102 | outputs = tf.keras.layers.Dense(n_classes, kernel_regularizer=_regularizer, bias_regularizer=_regularizer, use_bias=True)(flow)
103 | model = tf.keras.Model(inputs=inputs, outputs=outputs)
104 | return model
105 |
106 | def imagenet_resnet50(block_type='bootleneck', shortcut_type='B_original', l2_reg=0.5e-4, load_weights=False, input_shape=(224,224,3), n_classes=1000):
107 | bootleneck_width = 4
108 | model = Resnet(input_shape=input_shape, n_classes=n_classes, l2_reg=l2_reg, group_sizes=(3,4,6,3),
109 | features=(64*bootleneck_width, 128*bootleneck_width, 256*bootleneck_width, 512*bootleneck_width),
110 | strides=(1, 2, 2, 2), first_conv={"filters": 64, "kernel_size": 7, "strides": 2},
111 | shortcut_type=shortcut_type,
112 | block_type=block_type, preact_shortcuts=False,
113 | bootleneck_width=bootleneck_width)
114 | return model
115 |
116 | def imagenet_resnet50_pretrained(input_shape, n_classes, l2_reg):
117 | _regularizer = tf.keras.regularizers.l2(l2_reg)
118 | inputs = tf.keras.layers.Input(shape=input_shape)
119 | base_model = tf.keras.applications.resnet50.ResNet50(include_top=False, input_shape=input_shape,
120 | pooling='avg', weights='imagenet')
121 | base_model.trainable = False
122 |     x = base_model(inputs, training=False) # run BatchNorm in inference mode so its running statistics are not updated
123 | outputs = tf.keras.layers.Dense(n_classes, kernel_regularizer=_regularizer, bias_regularizer=_regularizer, use_bias=True)(x)
124 | model = tf.keras.Model(inputs=inputs, outputs=outputs)
125 | return model
126 |
127 | def imagenet_resnet18(block_type='original', shortcut_type='B_original', l2_reg=0.5e-4, load_weights=False, input_shape=(224,224,3), n_classes=1000):
128 | model = Resnet(input_shape=input_shape, n_classes=n_classes, l2_reg=l2_reg, group_sizes=(2,2,2,2),
129 | features=(64, 128, 256, 512),
130 | strides=(1, 2, 2, 2), first_conv={"filters": 64, "kernel_size": 7, "strides": 2},
131 | shortcut_type=shortcut_type,
132 | block_type=block_type, preact_shortcuts=False,
133 | bootleneck_width=None)
134 | return model
135 |
136 | def load_weights_func(model, model_name):
137 | try:
138 | model.load_weights(os.path.join('saved_models', model_name + '.tf'))
139 | except tf.errors.NotFoundError:
140 | print("No weights found for this model!")
141 | return model
142 |
143 |
144 | if __name__ == '__main__':
145 | model = imagenet_resnet50()
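146 |
147 |     # A small sanity check (a sketch, not part of the original pipeline): with the
148 |     # default 1000-way head the standard ResNet-50 layout reports roughly 25M
149 |     # parameters, and the classifier output should have shape (None, 1000).
150 |     print('ResNet-50 parameters:', model.count_params())
151 |     print('output shape:', model.output_shape)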
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | import matplotlib
5 | # configure backend here
6 | matplotlib.use('Agg')
7 | # matplotlib.use('tkagg')
8 | import matplotlib.pyplot as plt
9 | import matplotlib.patheffects as PathEffects
10 | from mpl_toolkits.axes_grid1 import ImageGrid
11 | import tensorflow as tf
12 | import math
13 | import sys
14 | from data_generator import CIFAR_MEANS, CIFAR_STDS
15 |
16 | gfile = tf.io.gfile
17 |
18 | class Logger(object):
19 | """Prints to both STDOUT and a file."""
20 |
21 | def __init__(self, filepath):
22 | self.terminal = sys.stdout
23 | self.log = gfile.GFile(filepath, 'a+')
24 |
25 | def write(self, message):
26 | self.terminal.write(message)
27 | self.terminal.flush()
28 | self.log.write(message)
29 | self.log.flush()
30 |
31 | def flush(self):
32 | self.terminal.flush()
33 | self.log.flush()
34 |
35 | class CTLEarlyStopping:
36 | def __init__(self,
37 | monitor='val_loss',
38 | min_delta=0,
39 | patience=0,
40 | mode='auto',
41 | ):
42 | self.monitor = monitor
43 | self.patience = patience
44 | self.min_delta = abs(min_delta)
45 | self.wait = 0
46 | self.stop_training = False
47 | self.improvement = False
48 |
49 | if mode not in ['auto', 'min', 'max']:
50 | logging.warning('EarlyStopping mode %s is unknown, '
51 | 'fallback to auto mode.', mode)
52 | mode = 'auto'
53 |
54 | if mode == 'min':
55 | self.monitor_op = np.less
56 | elif mode == 'max':
57 | self.monitor_op = np.greater
58 | else:
59 | if 'acc' in self.monitor:
60 | self.monitor_op = np.greater
61 | else:
62 | self.monitor_op = np.less
63 |
64 | if self.monitor_op == np.greater:
65 | self.min_delta *= 1
66 | else:
67 | self.min_delta *= -1
68 |
69 |         self.best = np.inf if self.monitor_op == np.less else -np.inf
70 |
71 |
72 | def check_progress(self, current):
73 | if self.monitor_op(current - self.min_delta, self.best):
74 | print(f"{self.monitor} improved from {self.best:.4f} to {current:.4f}.", end=" ")
75 | self.best = current
76 | self.wait = 0
77 | self.improvement = True
78 | else:
79 | self.wait += 1
80 | self.improvement = False
81 | print(f"{self.monitor} didn't improve")
82 | if self.wait >= self.patience:
83 | print("Early stopping")
84 | self.stop_training = True
85 |
86 | return self.improvement, self.stop_training
87 |
88 |
89 | ##########################################################################################
90 |
91 |
92 | class CTLHistory:
93 | def __init__(self,
94 | filename=None,
95 | save_dir='plots'):
96 |
97 | self.history = {'train_loss':[],
98 | "train_acc":[],
99 | "val_loss":[],
100 | "val_acc":[],
101 | "lr":[],
102 | "wd":[]}
103 |
104 | self.save_dir = save_dir
105 | if not os.path.exists(self.save_dir):
106 | os.mkdir(self.save_dir)
107 |
108 |         # fall back to a default plot filename when none is given
109 |         if filename is None:
110 |             filename = 'history.png'
111 |
112 |
113 | self.plot_name = os.path.join(self.save_dir, filename)
114 |
115 |
116 |
117 | def update(self, train_stats, val_stats, record_lr_wd):
118 | train_loss, train_acc = train_stats
119 | val_loss, val_acc = val_stats
120 | lr_history, wd_history = record_lr_wd
121 |
122 | self.history['train_loss'].append(train_loss)
123 | self.history['train_acc'].append(np.round(train_acc*100))
124 | self.history['val_loss'].append(val_loss)
125 | self.history['val_acc'].append(np.round(val_acc*100))
126 | self.history['lr'].extend(lr_history)
127 | self.history['wd'].extend(wd_history)
128 |
129 |
130 | def plot_and_save(self, initial_epoch=0):
131 | train_loss = self.history['train_loss']
132 | train_acc = self.history['train_acc']
133 | val_loss = self.history['val_loss']
134 | val_acc = self.history['val_acc']
135 |
136 | epochs = [(i+initial_epoch) for i in range(len(train_loss))]
137 |
138 | f, ax = plt.subplots(3, 1, figsize=(15,8))
139 | ax[0].plot(epochs, train_loss)
140 | ax[0].plot(epochs, val_loss)
141 | ax[0].set_title('loss progression')
142 | ax[0].set_xlabel('Epochs')
143 | ax[0].set_ylabel('loss values')
144 | ax[0].legend(['train', 'test'])
145 |
146 | ax[1].plot(epochs, train_acc)
147 | ax[1].plot(epochs, val_acc)
148 | ax[1].set_title('accuracy progression')
149 | ax[1].set_xlabel('Epochs')
150 | ax[1].set_ylabel('Accuracy')
151 | ax[1].legend(['train', 'test'])
152 |
153 | steps = len(self.history['lr'])
154 | bs = steps/len(train_loss)
155 | ax[2].plot([s/bs for s in range(steps)], self.history['lr'])
156 | ax[2].plot([s/bs for s in range(steps)], self.history['wd'])
157 | ax[2].set_title('learning rate and weight decay')
158 | ax[2].set_xlabel('Epochs')
159 | ax[2].set_ylabel('lr and wd')
160 | ax[2].legend(['lr', 'wd'])
161 |
162 | plt.savefig(self.plot_name)
163 | plt.close()
164 |
165 | def repeat(x, n, axis):
166 | if isinstance(x, np.ndarray):
167 | return np.repeat(x, n, axis=axis)
168 | elif isinstance(x, list):
169 | return repeat_list(x, n, axis)
170 | else:
171 |         raise Exception('Unsupported data type {}'.format(type(x)))
172 |
173 | def repeat_list(x, n, axis):
174 | assert isinstance(x, list), 'Can only consume list type'
175 | if axis == 0:
176 | x_new = sum([[x_] * n for x_ in x], [])
177 |     elif axis >= 1:  # recurse into nested lists, repeating along the remaining axes
178 |         x_new = [repeat(x_, n, axis=axis - 1) for x_ in x]
179 |     else:
180 |         raise Exception('axis must be a non-negative integer')
181 | return x_new
182 |
183 | def tile(x):
184 |     return None  # placeholder; tiling is not implemented
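185 |
186 |
187 | if __name__ == '__main__':
188 |     # Hedged usage sketch (not exercised elsewhere in the code base): `repeat`
189 |     # mirrors np.repeat for both ndarrays and nested Python lists, so both calls
190 |     # below should print the same six values.
191 |     print(repeat(np.array([1, 2, 3]), 2, axis=0))  # -> [1 1 2 2 3 3]
192 |     print(repeat([1, 2, 3], 2, axis=0))            # -> [1, 1, 2, 2, 3, 3]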
--------------------------------------------------------------------------------