├── README.md
├── augmentations.py
├── common.py
├── confs
│   ├── pyramid272_cifar100_b128_rwt.yaml
│   ├── pyramid272_cifar100_b152_rwt.yaml
│   ├── pyramid272_cifar100_b256_rwt.yaml
│   ├── pyramid272_cifar100_b64_rwt.yaml
│   ├── pyramid272_cifar100_rwaug_b64.yaml
│   ├── pyramid272_cifar10_b128_rwt.yaml
│   ├── pyramid272_cifar10_b152_ra.yaml
│   ├── pyramid272_cifar10_b152_rwt.yaml
│   ├── pyramid272_cifar10_b64.yaml
│   ├── pyramid272_cifar10_b64_rwt.yaml
│   ├── res18_cifar100_b32_rwaug_search.yaml
│   ├── shake26_2x96d_cifar100_b128_ra.yaml
│   ├── shake26_2x96d_cifar100_b128_rwaug.yaml
│   ├── shake26_2x96d_cifar100_b128_rwaug_300e.yaml
│   ├── shake26_2x96d_cifar10_b128_noaug_300e.yaml
│   ├── shake26_2x96d_cifar10_b128_ra.yaml
│   ├── shake26_2x96d_cifar10_b128_rwaug.yaml
│   ├── shake26_2x96d_cifar10_b128_rwaug_300e.yaml
│   ├── shake26_2x96d_cifar10_b256_ra.yaml
│   ├── shake26_2x96d_cifar10_b256_rwaug.yaml
│   ├── shake26_2x96d_cifar10_b512.yaml
│   ├── shake26_2x96d_cifar10_b64_rwaug_search.yaml
│   ├── shake26_2x96d_svhn_b128_rwaug.yaml
│   ├── wresnet28x10_cifar100_b128_noaug_100e.yaml
│   ├── wresnet28x10_cifar100_b128_noaug_50e.yaml
│   ├── wresnet28x10_cifar100_b128_ra_100e.yaml
│   ├── wresnet28x10_cifar100_b128_ra_300e.yaml
│   ├── wresnet28x10_cifar100_b128_ra_50e.yaml
│   ├── wresnet28x10_cifar100_b128_ra_th.yaml
│   ├── wresnet28x10_cifar100_b128_rwaug_100e.yaml
│   ├── wresnet28x10_cifar100_b128_rwaug_50e.yaml
│   ├── wresnet28x10_cifar100_b128_rwaug_train.yaml
│   ├── wresnet28x10_cifar100_b128_rwaug_train_300e.yaml
│   ├── wresnet28x10_cifar100_b16_fast_search.yaml
│   ├── wresnet28x10_cifar100_b16_ra_refine.yaml
│   ├── wresnet28x10_cifar100_b256_rwaug_train.yaml
│   ├── wresnet28x10_cifar100_b256_total_ohl_rand.yaml
│   ├── wresnet28x10_cifar100_b32_rwaug_search.yaml
│   ├── wresnet28x10_cifar100_b64_mix_search.yaml
│   ├── wresnet28x10_cifar100_b64_mix_search_resdec.yaml
│   ├── wresnet28x10_cifar100_b64_rwaug_search.yaml
│   ├── wresnet28x10_cifar10_b128_noaug_100e.yaml
│   ├── wresnet28x10_cifar10_b128_noaug_50e.yaml
│   ├── wresnet28x10_cifar10_b128_ra.yaml
│   ├── wresnet28x10_cifar10_b128_ra_100e.yaml
│   ├── wresnet28x10_cifar10_b128_ra_300e.yaml
│   ├── wresnet28x10_cifar10_b128_ra_50e.yaml
│   ├── wresnet28x10_cifar10_b128_ra_th.yaml
│   ├── wresnet28x10_cifar10_b128_ra_th_100e.yaml
│   ├── wresnet28x10_cifar10_b128_ra_th_300e.yaml
│   ├── wresnet28x10_cifar10_b128_rwaug.yaml
│   ├── wresnet28x10_cifar10_b128_rwaug_100e.yaml
│   ├── wresnet28x10_cifar10_b128_rwaug_50e.yaml
│   ├── wresnet28x10_cifar10_b128_rwaug_cutmix.yaml
│   ├── wresnet28x10_cifar10_b128_rwaug_mixup.yaml
│   ├── wresnet28x10_cifar10_b128_rwaug_warmup.yaml
│   ├── wresnet28x10_cifar10_b16_fast_search.yaml
│   ├── wresnet28x10_cifar10_b16_ra_refine.yaml
│   ├── wresnet28x10_cifar10_b16_search.yaml
│   ├── wresnet28x10_cifar10_b256_cutout.yaml
│   ├── wresnet28x10_cifar10_b256_default.yaml
│   ├── wresnet28x10_cifar10_b256_rwaug.yaml
│   ├── wresnet28x10_cifar10_b32_fast_search.yaml
│   ├── wresnet28x10_cifar10_b32_search.yaml
│   ├── wresnet28x10_cifar10_b64_search.yaml
│   ├── wresnet28x10_cifar10_reduce_b16_rwaug_search.yaml
│   ├── wresnet28x10_cifar10_reduce_b64_mix_search.yaml
│   ├── wresnet28x10_cifar10_reduce_b64_rwaug_search.yaml
│   ├── wresnet28x10_svhn_b128.yaml
│   ├── wresnet28x10_svhn_b128_rwaug.yaml
│   ├── wresnet28x10_svhn_b256.yaml
│   ├── wresnet28x10_svhn_b256_rwaug.yaml
│   ├── wresnet28x10_svhn_b32_search.yaml
│   ├── wresnet40x2_cifar100_b32_fast_search.yaml
│   ├── wresnet40x2_cifar10_b32_fast_search.yaml
│   └── wresnet40x2_svhn_b32_fast_search.yaml
├── data.py
├── imagenet.py
├── lr_scheduler.py
├── metrics.py
├── networks.py
├── process_npy.py
├── pyramidnet.py
├── resnet.py
├── search.py
├── shake_resnet.py
├── shake_resnext.py
├── shakedrop.py
├── shakeshake.py
├── smooth_ce.py
├── train.py
├── wideresnet.py
├── wresnet40x2_cifar100_new_search_smoothed.npy
└── wresnet40x2_cifar10_new_search_smoothed.npy

/README.md:
--------------------------------------------------------------------------------
 1 | # DDAS_code
 2 | Direct Differentiable Augmentation Search
 3 | 
 4 | **Command for search on CIFAR-10:**
 5 | ```
 6 | python search.py -c confs/wresnet40x2_cifar10_b32_fast_search.yaml --dataroot ../data/ --tag 1 --save wresnet_cifar10_new_search_bs_final.pth --explore_ratio 0.99999 --cv-ratio 0.96 --param_lr 0.005 --tp_lr 0.001 --init_tp 0.35
 7 | ```
 8 | **Command for training on CIFAR-10:**
 9 | ```
10 | python process_npy.py --file_name wresnet_cifar10_new_search_bs_final_save_dict.npy --out_name wresnet_cifar10_new_search_smoothed.npy
11 | python train.py -c confs/wresnet28x10_cifar10_b128_rwaug.yaml --dataroot ../data/ --save wresnet_cifar10_new_search2_bs1.pth --load_tp wresnet_cifar10_new_search_smoothed.npy --tag 1
12 | ```
13 | **Command for search on CIFAR-100:**
14 | ```
15 | python search.py -c confs/wresnet40x2_cifar100_b32_fast_search.yaml --dataroot ../data/ --save wresnet_cifar_new_search_bs_final.pth --tag 1 --explore_ratio 0.9999 --cv-ratio 0.96 --param_lr 0.005 --tp_lr 0.001 --init_tp 0.35
16 | ```
17 | **Command for training on CIFAR-100:**
18 | ```
19 | python process_npy.py --file_name wresnet_cifar_new_search_bs_final_save_dict.npy --out_name wresnet_cifar100_smoothed.npy
20 | python train.py -c confs/wresnet28x10_cifar100_b128_rwaug_train.yaml --dataroot ../data/ --save wresnet_cifar100_new_bs1.pth --load_tp wresnet_cifar100_smoothed.npy --tag 1
21 | ```
22 | 
--------------------------------------------------------------------------------
/augmentations.py:
--------------------------------------------------------------------------------
 1 | from __future__ import absolute_import
 2 | from __future__ import division
 3 | from __future__ import print_function
 4 | 
 5 | import random
 6 | import numpy as np
 7 | import torch
 8 | # pylint:disable=g-multiple-import
 9 | from PIL import ImageOps, ImageEnhance, ImageFilter, Image
10 | from torchvision.transforms import transforms
11 | # pylint:enable=g-multiple-import
12 | 
13 | 
14 | IMAGE_SIZE = 32
15 | # What is the dataset mean and std of the images on the training set
16 | MEANS = [0.49139968, 0.48215841, 0.44653091]
17 | STDS = [0.24703223, 0.24348513, 0.26158784]
18 | PARAMETER_MAX = 10  # What is the max 'level' a transform could be predicted
19 | 
20 | 
21 | def random_flip(x):
22 |   """Flip the input x horizontally with 50% probability."""
23 |   if np.random.rand(1)[0] > 0.5:
24 |     return np.fliplr(x)
25 |   return x
26 | 
27 | 
28 | def zero_pad_and_crop(img, amount=4):
29 |   """Zero pad by `amount` zero pixels on each side then take a random crop.
30 |   Args:
31 |     img: numpy image that will be zero padded and cropped.
32 |     amount: amount of zeros to pad `img` with horizontally and vertically.
33 |   Returns:
34 |     The cropped zero padded img. The returned numpy array will be of the same
35 |     shape as `img`.
36 | """ 37 | padded_img = np.zeros((img.shape[0] + amount * 2, img.shape[1] + amount * 2, 38 | img.shape[2])) 39 | padded_img[amount:img.shape[0] + amount, amount: 40 | img.shape[1] + amount, :] = img 41 | top = np.random.randint(low=0, high=2 * amount) 42 | left = np.random.randint(low=0, high=2 * amount) 43 | new_img = padded_img[top:top + img.shape[0], left:left + img.shape[1], :] 44 | return new_img 45 | 46 | 47 | def create_cutout_mask(img_height, img_width, num_channels, size): 48 | """Creates a zero mask used for cutout of shape `img_height` x `img_width`. 49 | Args: 50 | img_height: Height of image cutout mask will be applied to. 51 | img_width: Width of image cutout mask will be applied to. 52 | num_channels: Number of channels in the image. 53 | size: Size of the zeros mask. 54 | Returns: 55 | A mask of shape `img_height` x `img_width` with all ones except for a 56 | square of zeros of shape `size` x `size`. This mask is meant to be 57 | elementwise multiplied with the original image. Additionally returns 58 | the `upper_coord` and `lower_coord` which specify where the cutout mask 59 | will be applied. 60 | """ 61 | assert img_height == img_width 62 | 63 | # Sample center where cutout mask will be applied 64 | height_loc = np.random.randint(low=0, high=img_height) 65 | width_loc = np.random.randint(low=0, high=img_width) 66 | 67 | # Determine upper right and lower left corners of patch 68 | upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2)) 69 | lower_coord = (min(img_height, height_loc + size // 2), 70 | min(img_width, width_loc + size // 2)) 71 | mask_height = lower_coord[0] - upper_coord[0] 72 | mask_width = lower_coord[1] - upper_coord[1] 73 | assert mask_height > 0 74 | assert mask_width > 0 75 | 76 | mask = np.ones((img_height, img_width, num_channels)) 77 | zeros = np.zeros((mask_height, mask_width, num_channels)) 78 | mask[upper_coord[0]:lower_coord[0], upper_coord[1]:lower_coord[1], :] = ( 79 | zeros) 80 | return mask, upper_coord, lower_coord 81 | 82 | 83 | def cutout_numpy(img, size=16): 84 | """Apply cutout with mask of shape `size` x `size` to `img`. 85 | The cutout operation is from the paper https://arxiv.org/abs/1708.04552. 86 | This operation applies a `size`x`size` mask of zeros to a random location 87 | within `img`. 88 | Args: 89 | img: Numpy image that cutout will be applied to. 90 | size: Height/width of the cutout mask that will be 91 | Returns: 92 | A numpy tensor that is the result of applying the cutout mask to `img`. 93 | """ 94 | img_height, img_width, num_channels = (img.shape[0], img.shape[1], 95 | img.shape[2]) 96 | assert len(img.shape) == 3 97 | mask, _, _ = create_cutout_mask(img_height, img_width, num_channels, size) 98 | return img * mask 99 | 100 | 101 | def float_parameter(level, maxval): 102 | """Helper function to scale `val` between 0 and maxval . 103 | Args: 104 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 105 | maxval: Maximum value that the operation can have. This will be scaled 106 | to level/PARAMETER_MAX. 107 | Returns: 108 | A float that results from scaling `maxval` according to `level`. 109 | """ 110 | return float(level) * maxval / PARAMETER_MAX 111 | 112 | 113 | def int_parameter(level, maxval): 114 | """Helper function to scale `val` between 0 and maxval . 115 | Args: 116 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 117 | maxval: Maximum value that the operation can have. This will be scaled 118 | to level/PARAMETER_MAX. 
119 | Returns: 120 | An int that results from scaling `maxval` according to `level`. 121 | """ 122 | return int(level * maxval / PARAMETER_MAX) 123 | 124 | 125 | def pil_wrap(img): 126 | """Convert the `img` numpy tensor to a PIL Image.""" 127 | return Image.fromarray( 128 | np.uint8((img * STDS + MEANS) * 255.0)).convert('RGBA') 129 | 130 | 131 | def pil_unwrap(pil_img): 132 | """Converts the PIL img to a numpy array.""" 133 | pic_array = (np.array(pil_img.getdata()).reshape((32, 32, 4)) / 255.0) 134 | i1, i2 = np.where(pic_array[:, :, 3] == 0) 135 | pic_array = (pic_array[:, :, :3] - MEANS) / STDS 136 | pic_array[i1, i2] = [0, 0, 0] 137 | return pic_array 138 | 139 | 140 | def apply_policy(policy, img): 141 | """Apply the `policy` to the numpy `img`. 142 | Args: 143 | policy: A list of tuples with the form (name, probability, level) where 144 | `name` is the name of the augmentation operation to apply, `probability` 145 | is the probability of applying the operation and `level` is what strength 146 | the operation to apply. 147 | img: Numpy image that will have `policy` applied to it. 148 | Returns: 149 | The result of applying `policy` to `img`. 150 | """ 151 | pil_img = pil_wrap(img) 152 | for xform in policy: 153 | assert len(xform) == 3 154 | name, probability, level = xform 155 | xform_fn = NAME_TO_TRANSFORM[name].pil_transformer(probability, level) 156 | pil_img = xform_fn(pil_img) 157 | return pil_unwrap(pil_img) 158 | 159 | 160 | class TransformFunction(object): 161 | """Wraps the Transform function for pretty printing options.""" 162 | 163 | def __init__(self, func, name): 164 | self.f = func 165 | self.name = name 166 | 167 | def __repr__(self): 168 | return '<' + self.name + '>' 169 | 170 | def __call__(self, pil_img): 171 | return self.f(pil_img) 172 | 173 | 174 | class TransformT(object): 175 | """Each instance of this class represents a specific transform.""" 176 | 177 | def __init__(self, name, xform_fn): 178 | self.name = name 179 | self.xform = xform_fn 180 | 181 | def pil_transformer(self, probability, level): 182 | 183 | def return_function(im): 184 | if random.random() < probability: 185 | im = self.xform(im, level) 186 | return im 187 | 188 | name = self.name + '({:.1f},{})'.format(probability, level) 189 | return TransformFunction(return_function, name) 190 | 191 | def do_transform(self, image, level): 192 | f = self.pil_transformer(PARAMETER_MAX, level) 193 | return f(image) 194 | 195 | 196 | ################## Transform Functions ################## 197 | aug_ohl_list = [] 198 | aug_name_ls = [] 199 | identity = TransformT('identity', lambda pil_img, level: pil_img) 200 | # identity_ohl = TransformT('identity_ohl', lambda pil_img: pil_img) 201 | identity_ohl = lambda pil_img: identity.do_transform(pil_img, 0) 202 | aug_ohl_list.append(identity_ohl) 203 | aug_name_ls.append('identity.') 204 | flip_lr = TransformT( 205 | 'FlipLR', 206 | lambda pil_img, level: pil_img.transpose(Image.FLIP_LEFT_RIGHT)) 207 | 208 | flip_lr_ohl = lambda pil_img: flip_lr.do_transform(pil_img, 0) 209 | aug_ohl_list.append(flip_lr_ohl) 210 | aug_name_ls.append('FlipLR.') 211 | flip_ud = TransformT( 212 | 'FlipUD', 213 | lambda pil_img, level: pil_img.transpose(Image.FLIP_TOP_BOTTOM)) 214 | aug_name_ls.append('FlipUD.') 215 | flip_ud_ohl = lambda pil_img: flip_ud.do_transform(pil_img, 0) 216 | aug_ohl_list.append(flip_ud_ohl) 217 | 218 | # pylint:disable=g-long-lambda 219 | auto_contrast = TransformT( 220 | 'AutoContrast', 221 | lambda pil_img, level: ImageOps.autocontrast( 222 | 
        pil_img.convert('RGB')).convert('RGBA'))
223 | auto_contrast_ohl = lambda pil_img: auto_contrast.do_transform(pil_img, 0)
224 | aug_ohl_list.append(auto_contrast_ohl)
225 | aug_name_ls.append('AutoContrast.')
226 | equalize = TransformT(
227 |     'Equalize',
228 |     lambda pil_img, level: ImageOps.equalize(
229 |         pil_img.convert('RGB')).convert('RGBA'))
230 | equalize_ohl = lambda pil_img: equalize.do_transform(pil_img, 0)
231 | aug_ohl_list.append(equalize_ohl)
232 | aug_name_ls.append('Equalize.')
233 | invert = TransformT(
234 |     'Invert',
235 |     lambda pil_img, level: ImageOps.invert(
236 |         pil_img.convert('RGB')).convert('RGBA'))
237 | invert_ohl = lambda pil_img: invert.do_transform(pil_img, 0)
238 | aug_ohl_list.append(invert_ohl)
239 | aug_name_ls.append('Invert.')
240 | # pylint:enable=g-long-lambda
241 | blur = TransformT(
242 |     'Blur', lambda pil_img, level: pil_img.filter(ImageFilter.BLUR))
243 | blur_ohl = lambda pil_img: blur.do_transform(pil_img, 0)
244 | aug_ohl_list.append(blur_ohl)
245 | aug_name_ls.append('Blur.')
246 | smooth = TransformT(
247 |     'Smooth',
248 |     lambda pil_img, level: pil_img.filter(ImageFilter.SMOOTH))
249 | smooth_ohl = lambda pil_img: smooth.do_transform(pil_img, 0)
250 | aug_ohl_list.append(smooth_ohl)
251 | aug_name_ls.append('Smooth.')
252 | aug_ohl_list_rotate=[]
253 | def _rotate_impl(pil_img, level):
254 |   """Rotates `pil_img` from -30 to 30 degrees depending on `level`."""
255 |   degrees = int_parameter(level, 30)
256 |   if random.random() > 0.5:
257 |     degrees = -degrees
258 |   return pil_img.rotate(degrees)
259 | 
260 | M_list = [0, 2, 10, 14]
261 | #M_list = [5]
262 | rotate = TransformT('Rotate', _rotate_impl)
263 | for m in M_list:
264 |   mop = lambda pil_img, m=m: rotate.do_transform(pil_img, m)  # m=m binds the magnitude per iteration
265 |   aug_ohl_list.append(mop)
266 |   aug_ohl_list_rotate.append(mop)
267 |   aug_name_ls.append('Rotate.'+str(m))
268 | 
269 | def _posterize_impl(pil_img, level):
270 |   """Applies PIL Posterize to `pil_img`."""
271 |   level = int_parameter(level, 4)
272 |   return ImageOps.posterize(pil_img.convert('RGB'), 4 - level).convert('RGBA')
273 | 
274 | posterize = TransformT('Posterize', _posterize_impl)
275 | for m in M_list:
276 |   mop = lambda pil_img, m=m: posterize.do_transform(pil_img, m)
277 |   aug_ohl_list.append(mop)
278 |   aug_name_ls.append('Posterize.'+str(m))
279 | 
280 | def _shear_x_impl(pil_img, level):
281 |   """Applies PIL ShearX to `pil_img`.
282 |   The ShearX operation shears the image along the horizontal axis with `level`
283 |   magnitude.
284 |   Args:
285 |     pil_img: Image in PIL object.
286 |     level: Strength of the operation specified as an Integer from
287 |       [0, `PARAMETER_MAX`].
288 |   Returns:
289 |     A PIL Image that has had ShearX applied to it.
290 |   """
291 |   level = float_parameter(level, 0.3)
292 |   if random.random() > 0.5:
293 |     level = -level
294 |   return pil_img.transform((32, 32), Image.AFFINE, (1, level, 0, 0, 1, 0))
295 | 
296 | 
297 | shear_x = TransformT('ShearX', _shear_x_impl)
298 | for m in M_list:
299 |   mop = lambda pil_img, m=m: shear_x.do_transform(pil_img, m)
300 |   aug_ohl_list.append(mop)
301 |   aug_name_ls.append('ShearX.'+str(m))
302 | 
303 | 
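# --- editor's example (not part of the original source) ----------------------
# The `m=m` default arguments added to the magnitude loops above matter:
# Python closures bind names, not values, so without the default every stored
# lambda would see the loop variable's final value (m == 14) and the four
# magnitudes per op would collapse into one. A self-contained sketch:
_late = [lambda: m for m in M_list]          # late binding: all return 14
_bound = [lambda m=m: m for m in M_list]     # value captured at definition time
assert [f() for f in _late] == [14, 14, 14, 14]
assert [f() for f in _bound] == [0, 2, 10, 14]
# ------------------------------------------------------------------------------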
304 | def _shear_y_impl(pil_img, level):
305 |   """Applies PIL ShearY to `pil_img`.
306 |   The ShearY operation shears the image along the vertical axis with `level`
307 |   magnitude.
308 |   Args:
309 |     pil_img: Image in PIL object.
310 |     level: Strength of the operation specified as an Integer from
311 |       [0, `PARAMETER_MAX`].
312 |   Returns:
313 |     A PIL Image that has had ShearY applied to it.
314 |   """
315 |   level = float_parameter(level, 0.3)
316 |   if random.random() > 0.5:
317 |     level = -level
318 |   return pil_img.transform((32, 32), Image.AFFINE, (1, 0, 0, level, 1, 0))
319 | 
320 | 
321 | shear_y = TransformT('ShearY', _shear_y_impl)
322 | for m in M_list:
323 |   mop = lambda pil_img, m=m: shear_y.do_transform(pil_img, m)
324 |   aug_ohl_list.append(mop)
325 |   aug_name_ls.append('ShearY.'+str(m))
326 | 
327 | 
328 | def _translate_x_impl(pil_img, level):
329 |   """Applies PIL TranslateX to `pil_img`.
330 |   Translate the image in the horizontal direction by `level`
331 |   number of pixels.
332 |   Args:
333 |     pil_img: Image in PIL object.
334 |     level: Strength of the operation specified as an Integer from
335 |       [0, `PARAMETER_MAX`].
336 |   Returns:
337 |     A PIL Image that has had TranslateX applied to it.
338 |   """
339 |   level = int_parameter(level, 10)
340 |   if random.random() > 0.5:
341 |     level = -level
342 |   return pil_img.transform((32, 32), Image.AFFINE, (1, 0, level, 0, 1, 0))
343 | 
344 | 
345 | translate_x = TransformT('TranslateX', _translate_x_impl)
346 | for m in M_list:
347 |   mop = lambda pil_img, m=m: translate_x.do_transform(pil_img, m)
348 |   aug_ohl_list.append(mop)
349 |   aug_name_ls.append('TranslateX.'+str(m))
350 | 
351 | 
352 | def _translate_y_impl(pil_img, level):
353 |   """Applies PIL TranslateY to `pil_img`.
354 |   Translate the image in the vertical direction by `level`
355 |   number of pixels.
356 |   Args:
357 |     pil_img: Image in PIL object.
358 |     level: Strength of the operation specified as an Integer from
359 |       [0, `PARAMETER_MAX`].
360 |   Returns:
361 |     A PIL Image that has had TranslateY applied to it.
362 |   """
363 |   level = int_parameter(level, 10)
364 |   if random.random() > 0.5:
365 |     level = -level
366 |   return pil_img.transform((32, 32), Image.AFFINE, (1, 0, 0, 0, 1, level))
367 | 
368 | 
369 | translate_y = TransformT('TranslateY', _translate_y_impl)
370 | for m in M_list:
371 |   mop = lambda pil_img, m=m: translate_y.do_transform(pil_img, m)
372 |   aug_ohl_list.append(mop)
373 |   aug_name_ls.append('TranslateY.'+str(m))
374 | 
375 | 
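# --- editor's example (not part of the original source) ----------------------
# `aug_name_ls` and `aug_ohl_list` are appended to in lockstep above, so they
# zip into a name -> callable table. Names follow '<Op>.' for parameter-free
# ops and '<Op>.<magnitude>' for magnitude-staged ones. A hedged usage sketch
# on a synthetic 32x32 RGBA image (the op name below is only an illustration):
def _demo_named_op(name='TranslateX.10'):
    ops = dict(zip(aug_name_ls, aug_ohl_list))
    rgba = Image.fromarray(
        np.random.randint(0, 256, (32, 32, 4), dtype=np.uint8), 'RGBA')
    # each entry wraps TransformT.do_transform, which passes PARAMETER_MAX as
    # the probability, so the op always fires at its fixed magnitude
    return ops[name](rgba)
# ------------------------------------------------------------------------------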
376 | def _crop_impl(pil_img, level, interpolation=Image.BILINEAR):
377 |   """Applies a crop to `pil_img` with the size depending on the `level`."""
378 |   cropped = pil_img.crop((level, level, IMAGE_SIZE - level, IMAGE_SIZE - level))
379 |   resized = cropped.resize((IMAGE_SIZE, IMAGE_SIZE), interpolation)
380 |   return resized
381 | 
382 | 
383 | crop_bilinear = TransformT('CropBilinear', _crop_impl)
384 | 
385 | 
386 | def _solarize_impl(pil_img, level):
387 |   """Applies PIL Solarize to `pil_img`.
388 |   Inverts all pixel values above a threshold, which decreases
389 |   as `level` increases.
390 |   Args:
391 |     pil_img: Image in PIL object.
392 |     level: Strength of the operation specified as an Integer from
393 |       [0, `PARAMETER_MAX`].
394 |   Returns:
395 |     A PIL Image that has had Solarize applied to it.
396 |   """
397 |   level = int_parameter(level, 256)
398 |   return ImageOps.solarize(pil_img.convert('RGB'), 256 - level).convert('RGBA')
399 | 
400 | 
401 | solarize = TransformT('Solarize', _solarize_impl)
402 | for m in M_list:
403 |   mop = lambda pil_img, m=m: solarize.do_transform(pil_img, m)
404 |   aug_ohl_list.append(mop)
405 |   aug_name_ls.append('Solarize.'+str(m))
406 | 
407 | 
408 | def _cutout_pil_impl(pil_img, level):
409 |   """Apply cutout to pil_img at the specified level."""
410 |   size = int_parameter(level, 20)
411 |   if size <= 0:
412 |     return pil_img
413 |   img_height, img_width, num_channels = (32, 32, 3)
414 |   _, upper_coord, lower_coord = (
415 |       create_cutout_mask(img_height, img_width, num_channels, size))
416 |   pixels = pil_img.load()  # create the pixel map
417 |   for i in range(upper_coord[0], lower_coord[0]):  # for every col
418 |     for j in range(upper_coord[1], lower_coord[1]):  # for every row
419 |       pixels[i, j] = (125, 122, 113, 0)  # set the colour accordingly
420 |   return pil_img
421 | 
422 | cutout = TransformT('Cutout', _cutout_pil_impl)
423 | 
424 | 
425 | def _enhancer_impl(enhancer):
426 |   """Sets level to be between 0.1 and 1.8 for ImageEnhance transforms of PIL."""
427 |   def impl(pil_img, level):
428 |     v = float_parameter(level, 1.8) + .1  # going to 0 just destroys it
429 |     return enhancer(pil_img).enhance(v)
430 |   return impl
431 | 
432 | 
433 | color = TransformT('Color', _enhancer_impl(ImageEnhance.Color))
434 | 
435 | for m in M_list:
436 |   mop = lambda pil_img, m=m: color.do_transform(pil_img, m)
437 |   aug_ohl_list.append(mop)
438 |   aug_name_ls.append('Color.'+str(m))
439 | 
440 | contrast = TransformT('Contrast', _enhancer_impl(ImageEnhance.Contrast))
441 | 
442 | for m in M_list:
443 |   mop = lambda pil_img, m=m: contrast.do_transform(pil_img, m)
444 |   aug_ohl_list.append(mop)
445 |   aug_name_ls.append('Contrast.'+str(m))
446 | 
447 | brightness = TransformT('Brightness', _enhancer_impl(
448 |     ImageEnhance.Brightness))
449 | 
450 | for m in M_list:
451 |   mop = lambda pil_img, m=m: brightness.do_transform(pil_img, m)
452 |   aug_ohl_list.append(mop)
453 |   aug_name_ls.append('Brightness.'+str(m))
454 | 
455 | sharpness = TransformT('Sharpness', _enhancer_impl(ImageEnhance.Sharpness))
456 | 
457 | for m in M_list:
458 |   mop = lambda pil_img, m=m: sharpness.do_transform(pil_img, m)
459 |   aug_ohl_list.append(mop)
460 |   aug_name_ls.append('Sharpness.'+str(m))
461 | 
462 | ALL_TRANSFORMS = [
463 |     identity,
464 |     auto_contrast,
465 |     equalize,
466 |     rotate,
467 |     posterize,
468 |     solarize,
469 |     color,
470 |     contrast,
471 |     brightness,
472 |     sharpness,
473 |     shear_x,
474 |     shear_y,
475 |     translate_x,
476 |     translate_y,
477 | ]
478 | 
479 | random_policy_ops = [
480 |     'Identity', 'AutoContrast', 'Equalize', 'Rotate',
481 |     'Solarize', 'Color', 'Contrast', 'Brightness',
482 |     'Sharpness', 'ShearX', 'TranslateX', 'TranslateY',
483 |     'Posterize', 'ShearY'
484 | ]
485 | 
486 | def augment_list():
487 |   l = [
488 |       identity,
489 |       auto_contrast,
490 |       equalize,
491 |       rotate,
492 |       posterize,
493 |       solarize,
494 |       color,
495 |       contrast,
496 |       brightness,
497 |       sharpness,
498 |       shear_x,
499 |       shear_y,
500 |       translate_x,
501 |       translate_y]
502 |   return l
503 | 
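# --- editor's example (not part of the original source) ----------------------
# `apply_policy` (defined earlier in this file) consumes (name, probability,
# level) triples and resolves each name through NAME_TO_TRANSFORM. That
# mapping's definition is not visible in this listing; the guarded assignment
# below is an assumption reconstructed from ALL_TRANSFORMS, and the policy
# itself is purely illustrative.
if 'NAME_TO_TRANSFORM' not in globals():
    NAME_TO_TRANSFORM = {t.name: t for t in ALL_TRANSFORMS}  # assumed definition

def _demo_apply_policy():
    # apply_policy expects a normalized float image (see pil_wrap / pil_unwrap)
    img = (np.random.rand(32, 32, 3) - MEANS) / STDS
    policy = [('Rotate', 0.8, 6), ('Equalize', 0.5, 0)]
    return apply_policy(policy, img)  # numpy array, same shape as img
# ------------------------------------------------------------------------------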
504 | def augment_mag_stage_list():
505 |   l = [
506 |       identity,
507 |       auto_contrast,
508 |       equalize,
509 |       rotate,
510 |       posterize,
511 |       solarize,
512 |       color,
513 |       contrast,
514 |       brightness,
515 |       sharpness,
516 |       shear_x,
517 |       shear_y,
518 |       translate_x,
519 |       translate_y]
520 |   amsl = []
521 |   for m in M_list:
522 |     tmp = []
523 |     for op in l:
524 |       tmp.append(lambda pil_img, op=op, m=m: op.do_transform(pil_img, m))  # bind op and m per iteration
525 |     amsl.append(tmp)
526 |   return amsl
527 | class Curriculum_Aug:
528 |   def __init__(self, n, th):
529 |     self.n = n
530 |     self.aug_ohl_list = augment_mag_stage_list()
531 |     self.sl = len(self.aug_ohl_list[0])
532 |     self.th=th
533 |     self.stage = 1
534 | 
535 |   def __call__(self, img):
536 |     ss = np.random.choice(self.stage, self.n)
537 |     ids = np.random.choice(self.sl, self.n)
538 |     # print(idxs)
539 |     if random.random()
 90 |     if C.get()['cutout'] > 0:
 91 |         transform_train.transforms.append(CutoutDefault(C.get()['cutout']))
 92 | 
 93 |     if dataset == 'cifar10':
 94 |         total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train)
 95 |         testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test)
 96 |     elif dataset == 'cifar100':
 97 |         total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=True, transform=transform_train)
 98 |         testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=True, transform=transform_test)
 99 |     elif dataset == 'svhn':
100 |         trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train)
101 |         extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=True, transform=transform_train)
102 |         total_trainset = ConcatDataset([trainset, extraset])
103 |         testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test)
104 |     elif dataset == 'imagenet':
105 |         total_trainset = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'train'), transform=transform_train)
106 |         testset = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'val'), transform=transform_test)
107 |         # total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet'), transform=transform_train)
108 |         # testset = ImageNet(root=os.path.join(dataroot, 'imagenet'), split='val', transform=transform_test)
109 | 
110 |         # compatibility
111 |         total_trainset.targets = [lb for _, lb in total_trainset.samples]
112 |     else:
113 |         raise ValueError('invalid dataset name=%s' % dataset)
114 | 
115 |     train_sampler = None
116 |     if split > 0.0:
117 |         sss = StratifiedShuffleSplit(n_splits=5, test_size=split, random_state=0)
118 |         sss = sss.split(list(range(len(total_trainset))), total_trainset.targets)
119 |         for _ in range(split_idx + 1):
120 |             train_idx, valid_idx = next(sss)
121 | 
122 |         train_sampler = SubsetRandomSampler(train_idx)
123 |         valid_sampler = SubsetSampler(valid_idx)
124 |     else:
125 |         valid_sampler = SubsetSampler([])
126 | 
127 |     trainloader = torch.utils.data.DataLoader(
128 |         total_trainset, batch_size=batch, shuffle=True if train_sampler is None else False, num_workers = 16, pin_memory=True,
129 |         sampler=train_sampler, drop_last=True)
130 |     validloader = torch.utils.data.DataLoader(
131 |         total_trainset, batch_size=batch, shuffle=False, num_workers = 16, pin_memory=True,
132 |         sampler=valid_sampler, drop_last=False)
133 | 
134 |     testloader = torch.utils.data.DataLoader(
135 |         testset, batch_size=batch, shuffle=False, num_workers = 16, pin_memory=True,
136 |         drop_last=False
137 |     )
138 |     return train_sampler, trainloader, validloader, testloader
139 | 
140 | class SubsetSampler(Sampler):
141 |     r"""Samples elements from a given list of indices, without replacement.
142 | 143 | Arguments: 144 | indices (sequence): a sequence of indices 145 | """ 146 | 147 | def __init__(self, indices): 148 | self.indices = indices 149 | 150 | def __iter__(self): 151 | return (i for i in self.indices) 152 | 153 | def __len__(self): 154 | return len(self.indices) 155 | def get_val_test_dataloader(dataset, batch, dataroot, split = 0.1): 156 | if 'cifar' in dataset or 'svhn' in dataset: 157 | transform_test = transforms.Compose([ 158 | transforms.ToTensor(), 159 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 160 | ]) 161 | elif 'imagenet' in dataset: 162 | transform_test = transforms.Compose([ 163 | transforms.Resize(256, interpolation=Image.BICUBIC), 164 | transforms.CenterCrop(224), 165 | transforms.ToTensor(), 166 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 167 | ]) 168 | else: 169 | raise ValueError('dataset=%s' % dataset) 170 | 171 | if dataset == 'cifar10' or dataset == 'reduced_cifar10': 172 | total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_test) 173 | testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=transform_test) 174 | elif dataset == 'cifar100': 175 | total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=True, transform=transform_test) 176 | testset = torchvision.datasets.CIFAR100(root=dataroot, train=False, download=True, transform=transform_test) 177 | elif dataset == 'svhn': 178 | trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_test) 179 | extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=True, transform=transform_test) 180 | total_trainset = ConcatDataset([trainset, extraset]) 181 | testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test) 182 | elif dataset == 'svhn' or dataset == 'svhn_core': 183 | 184 | total_trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_test) 185 | total_trainset.targets = [lb for lb in total_trainset.labels] 186 | testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=transform_test) 187 | elif dataset == 'imagenet': 188 | total_trainset = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'train'), transform=transform_test) 189 | testset = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'val'), transform=transform_test) 190 | 191 | total_trainset.targets = [lb for _, lb in total_trainset.samples] 192 | else: 193 | raise ValueError('invalid dataset name=%s' % dataset) 194 | train_sampler = None 195 | sss = StratifiedShuffleSplit(n_splits=1, test_size=split, random_state=0) 196 | if dataset == 'reduced_cifar10': 197 | sss = StratifiedShuffleSplit(n_splits=1, test_size=split/10, random_state=0) 198 | sss = sss.split(list(range(len(total_trainset))), total_trainset.targets) 199 | train_idx, valid_idx = next(sss) 200 | #assuming that train idx is smaller than train index! 
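# (editor's note, added for clarity) The statements below enforce the
# assumption stated in the comment above: `valid_idx` is truncated so the
# validation split is never larger than the training split, and for
# 'reduced_cifar10' the training indices are further cut to 4,000 examples,
# matching the usual reduced-data search setting.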
201 | valid_idx = valid_idx[0:len(train_idx)] 202 | train_sampler = SubsetSampler(train_idx) 203 | if dataset == 'reduced_cifar10': 204 | train_idx = train_idx[0:4000] 205 | train_sampler = SubsetSampler(train_idx) 206 | train_target = [total_trainset.targets[idx] for idx in train_idx] 207 | 208 | train_target = [total_trainset.targets[idx] for idx in train_idx] 209 | label_freq = {} 210 | for lb in set(train_target): 211 | label_freq[lb] = train_target.count(lb) 212 | print(label_freq) 213 | print("length of train idx") 214 | print(len(train_idx)) 215 | print(len(valid_idx)) 216 | print(len(train_sampler.indices)) 217 | valid_sampler = SubsetSampler(valid_idx) 218 | 219 | validloader = torch.utils.data.DataLoader( 220 | total_trainset, batch_size=256, shuffle=False, num_workers=2, pin_memory=True, 221 | sampler=valid_sampler, drop_last=False) 222 | 223 | testloader = torch.utils.data.DataLoader( 224 | testset, batch_size=256, shuffle=False, num_workers=2, pin_memory=True, 225 | drop_last=False 226 | ) 227 | return train_sampler, validloader, testloader 228 | 229 | #fixed batch size for each data loader! 230 | def Get_DataLoaders_Epoch_s(dataset, batch, dataroot, random_sampler, AugTypes, loader_num = 4): 231 | loaders = [] 232 | idx_epoch = [] 233 | assert len(AugTypes) == loader_num 234 | for idx in random_sampler: 235 | idx_epoch.append(idx) 236 | #turn random sample to fixed id sampler! 237 | SubsetSampler_epoch = SubsetSampler(idx_epoch) 238 | for i in range(loader_num): 239 | loaders.append(get_dataloader_epoch(dataset, batch, dataroot, sampler = SubsetSampler_epoch, AugType = AugTypes[i])) 240 | #here to delet the augmentation in the 1st loader 241 | print(loaders[0].dataset.transform.transforms.pop(0)) 242 | return loaders 243 | 244 | def get_dataloader_epoch(dataset, batch, dataroot, sampler=None, AugType = (2,5)): 245 | if 'cifar' in dataset or 'svhn' in dataset: 246 | transform_train = transforms.Compose([ 247 | transforms.RandomCrop(32, padding=4), 248 | transforms.RandomHorizontalFlip(), 249 | transforms.ToTensor(), 250 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 251 | ]) 252 | transform_test = transforms.Compose([ 253 | transforms.ToTensor(), 254 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 255 | ]) 256 | elif 'imagenet' in dataset: 257 | transform_train = transforms.Compose([ 258 | transforms.RandomResizedCrop(224, scale=(0.08, 1.0), interpolation=Image.BICUBIC), 259 | transforms.RandomHorizontalFlip(), 260 | transforms.ColorJitter( 261 | brightness=0.4, 262 | contrast=0.4, 263 | saturation=0.4, 264 | ), 265 | transforms.ToTensor(), 266 | Lighting(0.1, _IMAGENET_PCA['eigval'], _IMAGENET_PCA['eigvec']), 267 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 268 | ]) 269 | 270 | transform_test = transforms.Compose([ 271 | transforms.Resize(256, interpolation=Image.BICUBIC), 272 | transforms.CenterCrop(224), 273 | transforms.ToTensor(), 274 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 275 | ]) 276 | else: 277 | raise ValueError('dataset=%s' % dataset) 278 | 279 | #logger.debug('augmentation: %s' % C.get()['aug']) 280 | if C.get()['aug'] == 'randaugment': 281 | transform_train.transforms.insert(0, RandAugment(C.get()['randaug']['N'], C.get()['randaug']['M'])) 282 | elif C.get()['aug'] == 'rwaug_s': 283 | transform_train.transforms.insert(0, RWAug_Search(C.get()['rwaug']['n'],AugType[1])) 284 | elif C.get()['aug'] == 'randaugment_G': 285 | transform_train.transforms.insert(0, RandAugment_G(AugType[0], 
AugType[1])) 286 | elif C.get()['aug'] == 'randaugment_C': 287 | transform_train.transforms.insert(0, RandAugment_C(AugType[0], AugType[1])) 288 | elif C.get()['aug'] in ['default', 'inception', 'inception320','mix']: 289 | pass 290 | else: 291 | raise ValueError('not found augmentations. %s' % C.get()['aug']) 292 | 293 | if C.get()['cutout'] > 0: 294 | transform_train.transforms.append(CutoutDefault(C.get()['cutout'])) 295 | 296 | if dataset == 'cifar10' or dataset == 'reduced_cifar10': 297 | total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=transform_train) 298 | elif dataset == 'cifar100': 299 | total_trainset = torchvision.datasets.CIFAR100(root=dataroot, train=True, download=True, transform=transform_train) 300 | elif dataset == 'svhn': 301 | trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train) 302 | extraset = torchvision.datasets.SVHN(root=dataroot, split='extra', download=True, transform=transform_train) 303 | total_trainset = ConcatDataset([trainset, extraset]) 304 | elif dataset == 'svhn_core': 305 | total_trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=transform_train) 306 | elif dataset == 'imagenet': 307 | total_trainset = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'train'), transform=transform_train) 308 | # testset = torchvision.datasets.ImageFolder(root=os.path.join(dataroot, 'val'), transform=transform_test) 309 | # total_trainset = ImageNet(root=os.path.join(dataroot, 'imagenet-pytorch'), transform=transform_train) 310 | total_trainset.targets = [lb for _, lb in total_trainset.samples] 311 | else: 312 | raise ValueError('invalid dataset name=%s' % dataset) 313 | 314 | train_sampler = sampler 315 | 316 | trainloader = torch.utils.data.DataLoader( 317 | total_trainset, batch_size=batch, shuffle=True if train_sampler is None else False, num_workers=1, pin_memory=True, 318 | sampler=train_sampler, drop_last=True) 319 | return trainloader 320 | 321 | if __name__ == '__main__': 322 | a=[1,2,3,4,5,6,7,8] 323 | sb=SubsetSampler(a) 324 | for i in sb: 325 | print(i) 326 | -------------------------------------------------------------------------------- /imagenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import shutil 4 | import torch 5 | 6 | ARCHIVE_DICT = { 7 | 'train': { 8 | 'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar', 9 | 'md5': '1d675b47d978889d74fa0da5fadfb00e', 10 | }, 11 | 'val': { 12 | 'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar', 13 | 'md5': '29b22e2961454d5413ddabcf34fc5622', 14 | }, 15 | 'devkit': { 16 | 'url': 'http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar.gz', 17 | 'md5': 'fa75699e90414af021442c21a62c3abf', 18 | } 19 | } 20 | 21 | 22 | import torchvision 23 | from torchvision.datasets.utils import check_integrity, download_url 24 | 25 | 26 | # copy ILSVRC/ImageSets/CLS-LOC/train_cls.txt to ./root/ 27 | # to skip os walk (it's too slow) using ILSVRC/ImageSets/CLS-LOC/train_cls.txt file 28 | class ImageNet(torchvision.datasets.ImageFolder): 29 | """`ImageNet `_ 2012 Classification Dataset. 30 | 31 | Args: 32 | root (string): Root directory of the ImageNet Dataset. 33 | split (string, optional): The dataset split, supports ``train``, or ``val``. 
34 | download (bool, optional): If true, downloads the dataset from the internet and 35 | puts it in root directory. If dataset is already downloaded, it is not 36 | downloaded again. 37 | transform (callable, optional): A function/transform that takes in an PIL image 38 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 39 | target_transform (callable, optional): A function/transform that takes in the 40 | target and transforms it. 41 | loader (callable, optional): A function to load an image given its path. 42 | 43 | Attributes: 44 | classes (list): List of the class names. 45 | class_to_idx (dict): Dict with items (class_name, class_index). 46 | wnids (list): List of the WordNet IDs. 47 | wnid_to_idx (dict): Dict with items (wordnet_id, class_index). 48 | imgs (list): List of (image path, class_index) tuples 49 | targets (list): The class_index value for each image in the dataset 50 | """ 51 | 52 | def __init__(self, root, split='train', download=False, **kwargs): 53 | root = self.root = os.path.expanduser(root) 54 | self.split = self._verify_split(split) 55 | 56 | if download: 57 | self.download() 58 | wnid_to_classes = self._load_meta_file()[0] 59 | 60 | # to skip os walk (it's too slow) using ILSVRC/ImageSets/CLS-LOC/train_cls.txt file 61 | listfile = os.path.join(root, 'train_cls.txt') 62 | if split == 'train' and os.path.exists(listfile): 63 | torchvision.datasets.VisionDataset.__init__(self, root, **kwargs) 64 | with open(listfile, 'r') as f: 65 | datalist = [ 66 | line.strip().split(' ')[0] 67 | for line in f.readlines() 68 | if line.strip() 69 | ] 70 | 71 | classes = list(set([line.split('/')[0] for line in datalist])) 72 | classes.sort() 73 | class_to_idx = {classes[i]: i for i in range(len(classes))} 74 | 75 | samples = [ 76 | (os.path.join(self.split_folder, line + '.JPEG'), class_to_idx[line.split('/')[0]]) 77 | for line in datalist 78 | ] 79 | 80 | self.loader = torchvision.datasets.folder.default_loader 81 | self.extensions = torchvision.datasets.folder.IMG_EXTENSIONS 82 | 83 | self.classes = classes 84 | self.class_to_idx = class_to_idx 85 | self.samples = samples 86 | self.targets = [s[1] for s in samples] 87 | 88 | self.imgs = self.samples 89 | else: 90 | super(ImageNet, self).__init__(self.split_folder, **kwargs) 91 | 92 | self.root = root 93 | 94 | idcs = [idx for _, idx in self.imgs] 95 | self.wnids = self.classes 96 | self.wnid_to_idx = {wnid: idx for idx, wnid in zip(idcs, self.wnids)} 97 | self.classes = [wnid_to_classes[wnid] for wnid in self.wnids] 98 | self.class_to_idx = {cls: idx 99 | for clss, idx in zip(self.classes, idcs) 100 | for cls in clss} 101 | 102 | def download(self): 103 | if not check_integrity(self.meta_file): 104 | tmpdir = os.path.join(self.root, 'tmp') 105 | 106 | archive_dict = ARCHIVE_DICT['devkit'] 107 | download_and_extract_tar(archive_dict['url'], self.root, 108 | extract_root=tmpdir, 109 | md5=archive_dict['md5']) 110 | devkit_folder = _splitexts(os.path.basename(archive_dict['url']))[0] 111 | meta = parse_devkit(os.path.join(tmpdir, devkit_folder)) 112 | self._save_meta_file(*meta) 113 | 114 | shutil.rmtree(tmpdir) 115 | 116 | if not os.path.isdir(self.split_folder): 117 | archive_dict = ARCHIVE_DICT[self.split] 118 | download_and_extract_tar(archive_dict['url'], self.root, 119 | extract_root=self.split_folder, 120 | md5=archive_dict['md5']) 121 | 122 | if self.split == 'train': 123 | prepare_train_folder(self.split_folder) 124 | elif self.split == 'val': 125 | val_wnids = self._load_meta_file()[1] 126 | 
                prepare_val_folder(self.split_folder, val_wnids)
127 |         else:
128 |             msg = ("You set download=True, but a folder '{}' already exists in "
129 |                    "the root directory. If you want to re-download or re-extract the "
130 |                    "archive, delete the folder.")
131 |             print(msg.format(self.split))
132 | 
133 |     @property
134 |     def meta_file(self):
135 |         return os.path.join(self.root, 'meta.bin')
136 | 
137 |     def _load_meta_file(self):
138 |         if check_integrity(self.meta_file):
139 |             return torch.load(self.meta_file)
140 |         raise RuntimeError("Meta file not found or corrupted.",
141 |                            "You can use download=True to create it.")
142 | 
143 |     def _save_meta_file(self, wnid_to_class, val_wnids):
144 |         torch.save((wnid_to_class, val_wnids), self.meta_file)
145 | 
146 |     def _verify_split(self, split):
147 |         if split not in self.valid_splits:
148 |             msg = "Unknown split {}. ".format(split)
149 |             msg += "Valid splits are {}.".format(", ".join(self.valid_splits))
150 |             raise ValueError(msg)
151 |         return split
152 | 
153 |     @property
154 |     def valid_splits(self):
155 |         return 'train', 'val'
156 | 
157 |     @property
158 |     def split_folder(self):
159 |         return os.path.join(self.root, self.split)
160 | 
161 |     def extra_repr(self):
162 |         return "Split: {split}".format(**self.__dict__)
163 | 
164 | 
165 | def extract_tar(src, dest=None, gzip=None, delete=False):
166 |     import tarfile
167 | 
168 |     if dest is None:
169 |         dest = os.path.dirname(src)
170 |     if gzip is None:
171 |         gzip = src.lower().endswith('.gz')
172 | 
173 |     mode = 'r:gz' if gzip else 'r'
174 |     with tarfile.open(src, mode) as tarfh:
175 |         tarfh.extractall(path=dest)
176 | 
177 |     if delete:
178 |         os.remove(src)
179 | 
180 | 
181 | def download_and_extract_tar(url, download_root, extract_root=None, filename=None,
182 |                              md5=None, **kwargs):
183 |     download_root = os.path.expanduser(download_root)
184 |     if extract_root is None:
185 |         extract_root = download_root
186 |     if filename is None:
187 |         filename = os.path.basename(url)
188 | 
189 |     if not check_integrity(os.path.join(download_root, filename), md5):
190 |         download_url(url, download_root, filename=filename, md5=md5)
191 | 
192 |     extract_tar(os.path.join(download_root, filename), extract_root, **kwargs)
193 | 
194 | 
195 | def parse_devkit(root):
196 |     idx_to_wnid, wnid_to_classes = parse_meta(root)
197 |     val_idcs = parse_val_groundtruth(root)
198 |     val_wnids = [idx_to_wnid[idx] for idx in val_idcs]
199 |     return wnid_to_classes, val_wnids
200 | 
201 | 
202 | def parse_meta(devkit_root, path='data', filename='meta.mat'):
203 |     import scipy.io as sio
204 | 
205 |     metafile = os.path.join(devkit_root, path, filename)
206 |     meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
207 |     nums_children = list(zip(*meta))[4]
208 |     meta = [meta[idx] for idx, num_children in enumerate(nums_children)
209 |             if num_children == 0]
210 |     idcs, wnids, classes = list(zip(*meta))[:3]
211 |     classes = [tuple(clss.split(', ')) for clss in classes]
212 |     idx_to_wnid = {idx: wnid for idx, wnid in zip(idcs, wnids)}
213 |     wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
214 |     return idx_to_wnid, wnid_to_classes
215 | 
216 | 
217 | def parse_val_groundtruth(devkit_root, path='data',
218 |                           filename='ILSVRC2012_validation_ground_truth.txt'):
219 |     with open(os.path.join(devkit_root, path, filename), 'r') as txtfh:
220 |         val_idcs = txtfh.readlines()
221 |     return [int(val_idx) for val_idx in val_idcs]
222 | 
223 | 
224 | def prepare_train_folder(folder):
225 |     for archive in [os.path.join(folder, archive) for archive in os.listdir(folder)]:
226 | 
extract_tar(archive, os.path.splitext(archive)[0], delete=True) 227 | 228 | 229 | def prepare_val_folder(folder, wnids): 230 | img_files = sorted([os.path.join(folder, file) for file in os.listdir(folder)]) 231 | 232 | for wnid in set(wnids): 233 | os.mkdir(os.path.join(folder, wnid)) 234 | 235 | for wnid, img_file in zip(wnids, img_files): 236 | shutil.move(img_file, os.path.join(folder, wnid, os.path.basename(img_file))) 237 | 238 | 239 | def _splitexts(root): 240 | exts = [] 241 | ext = '.' 242 | while ext: 243 | root, ext = os.path.splitext(root) 244 | exts.append(ext) 245 | return root, ''.join(reversed(exts)) 246 | -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from theconf import Config as C 4 | 5 | 6 | def adjust_learning_rate_resnet(optimizer): 7 | """ 8 | Sets the learning rate to the initial LR decayed by 10 on every predefined epochs 9 | Ref: AutoAugment 10 | """ 11 | 12 | if C.get()['epoch'] == 90: 13 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 80]) 14 | elif C.get()['epoch'] == 180: 15 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, [60, 120, 160]) 16 | elif C.get()['epoch'] == 270: # autoaugment 17 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, [90, 180, 240]) 18 | else: 19 | raise ValueError('invalid epoch=%d for resnet scheduler' % C.get()['epoch']) 20 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | from collections import defaultdict 5 | 6 | from torch import nn 7 | 8 | 9 | def accuracy(output, target, topk=(1,)): 10 | """Computes the precision@k for the specified values of k""" 11 | maxk = max(topk) 12 | batch_size = target.size(0) 13 | 14 | _, pred = output.topk(maxk, 1, True, True) 15 | pred = pred.t() 16 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 17 | 18 | res = [] 19 | for k in topk: 20 | correct_k = correct[:k].contiguous().view(-1).float().sum(0) 21 | res.append(correct_k.mul_(1. / batch_size)) 22 | return res 23 | 24 | 25 | class Accumulator: 26 | def __init__(self): 27 | self.metrics = defaultdict(lambda: 0.) 
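    # (editor's sketch, not part of the original class) Typical usage in a
    # training loop, relying on the __truediv__ overload defined below: given
    # a string key, it divides every other accumulated metric by that key's
    # total, leaving the key itself untouched.
    #
    #     acc = Accumulator()
    #     acc.add_dict({'loss': 115.2, 'top1': 92.0, 'cnt': 128})
    #     acc.add_dict({'loss': 102.4, 'top1': 96.0, 'cnt': 128})
    #     avg = acc / 'cnt'   # 'loss' and 'top1' divided by acc['cnt'] == 256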
28 | 29 | def add(self, key, value): 30 | self.metrics[key] += value 31 | 32 | def add_dict(self, dict): 33 | for key, value in dict.items(): 34 | self.add(key, value) 35 | 36 | def __getitem__(self, item): 37 | return self.metrics[item] 38 | 39 | def __setitem__(self, key, value): 40 | self.metrics[key] = value 41 | 42 | def get_dict(self): 43 | return copy.deepcopy(dict(self.metrics)) 44 | 45 | def items(self): 46 | return self.metrics.items() 47 | 48 | def __str__(self): 49 | return str(dict(self.metrics)) 50 | 51 | def __truediv__(self, other): 52 | newone = Accumulator() 53 | for key, value in self.items(): 54 | if isinstance(other, str): 55 | if other != key: 56 | newone[key] = value / self[other] 57 | else: 58 | newone[key] = value 59 | else: 60 | newone[key] = value / other 61 | return newone 62 | 63 | 64 | class SummaryWriterDummy: 65 | def __init__(self, log_dir): 66 | pass 67 | 68 | def add_scalar(self, *args, **kwargs): 69 | pass 70 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | from torch.nn import DataParallel 5 | import torch.backends.cudnn as cudnn 6 | # from torchvision import models 7 | 8 | from resnet import ResNet 9 | from pyramidnet import PyramidNet 10 | from shake_resnet import ShakeResNet 11 | from wideresnet import WideResNet 12 | from shake_resnext import ShakeResNeXt 13 | 14 | 15 | def get_model(conf, num_class=10): 16 | name = conf['type'] 17 | 18 | if name == 'resnet50': 19 | model = ResNet(dataset='imagenet', depth=50, num_classes=num_class, bottleneck=True) 20 | elif name == 'resnet18': 21 | model = ResNet(dataset='imagenet', depth=18, num_classes=num_class, bottleneck=True) 22 | elif name == 'resnet200': 23 | model = ResNet(dataset='imagenet', depth=200, num_classes=num_class, bottleneck=True) 24 | elif name == 'wresnet40_2': 25 | model = WideResNet(40, 2, dropout_rate=0.0, num_classes=num_class) 26 | elif name == 'wresnet28_10': 27 | model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_class) 28 | 29 | elif name == 'shakeshake26_2x32d': 30 | model = ShakeResNet(26, 32, num_class) 31 | elif name == 'shakeshake26_2x64d': 32 | model = ShakeResNet(26, 64, num_class) 33 | elif name == 'shakeshake26_2x96d': 34 | model = ShakeResNet(26, 96, num_class) 35 | elif name == 'shakeshake26_2x112d': 36 | model = ShakeResNet(26, 112, num_class) 37 | 38 | elif name == 'shakeshake26_2x96d_next': 39 | model = ShakeResNeXt(26, 96, 4, num_class) 40 | 41 | elif name == 'pyramid': 42 | model = PyramidNet('cifar10', depth=conf['depth'], alpha=conf['alpha'], num_classes=num_class, bottleneck=conf['bottleneck']) 43 | else: 44 | raise NameError('no model named, %s' % name) 45 | 46 | model = model.cuda() 47 | model = DataParallel(model) 48 | cudnn.benchmark = True 49 | return model 50 | 51 | def get_model_np(conf, num_class=10): 52 | name = conf['type'] 53 | 54 | if name == 'resnet50': 55 | model = ResNet(dataset='imagenet', depth=50, num_classes=num_class, bottleneck=True) 56 | elif name == 'resnet200': 57 | model = ResNet(dataset='imagenet', depth=200, num_classes=num_class, bottleneck=True) 58 | elif name == 'wresnet40_2': 59 | model = WideResNet(40, 2, dropout_rate=0.0, num_classes=num_class) 60 | elif name == 'wresnet28_10': 61 | model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_class) 62 | 63 | elif name == 'shakeshake26_2x32d': 64 | model = ShakeResNet(26, 32, num_class) 65 | elif 
name == 'shakeshake26_2x64d': 66 | model = ShakeResNet(26, 64, num_class) 67 | elif name == 'shakeshake26_2x96d': 68 | model = ShakeResNet(26, 96, num_class) 69 | elif name == 'shakeshake26_2x112d': 70 | model = ShakeResNet(26, 112, num_class) 71 | 72 | elif name == 'shakeshake26_2x96d_next': 73 | model = ShakeResNeXt(26, 96, 4, num_class) 74 | 75 | elif name == 'pyramid': 76 | model = PyramidNet('cifar10', depth=conf['depth'], alpha=conf['alpha'], num_classes=num_class, bottleneck=conf['bottleneck']) 77 | else: 78 | raise NameError('no model named, %s' % name) 79 | 80 | model = model.cuda() 81 | #model = DataParallel(model) 82 | cudnn.benchmark = True 83 | return model 84 | 85 | 86 | def num_class(dataset): 87 | return { 88 | 'cifar10': 10, 89 | 'reduced_cifar10': 10, 90 | 'cifar10.1': 10, 91 | 'cifar100': 100, 92 | 'svhn': 10, 93 | 'reduced_svhn': 10, 94 | 'svhn_core':10, 95 | 'imagenet': 1000, 96 | 'reduced_imagenet': 120, 97 | }[dataset] 98 | -------------------------------------------------------------------------------- /process_npy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import argparse 4 | l = 1 5 | 6 | def avg_filter(y, al): 7 | return [sum(y[i : i + al])/al for i in range(0,len(y)-al+1)] 8 | 9 | parser = argparse.ArgumentParser(conflict_handler='resolve') 10 | parser.add_argument('--file_name', type=str, default='wresnet_cifar_new_ab_02_r_save_dict.npy') 11 | parser.add_argument('--out_name', type=str, default='wresnet_cifar_new_ab_02_r_save_dict_smooth_1.npy') 12 | args = parser.parse_args() 13 | #wresnet40x2_cifar10_fast_itp035_slr_sa_b1_save_dict.npy 14 | name = args.file_name 15 | pdict = np.load(name,allow_pickle=True).item() 16 | 17 | prob_dis = pdict['dis_ps'] 18 | 19 | pdict_save = pdict.copy() 20 | y1 = pdict['tps'] 21 | y_save_smooth = avg_filter(y1, l) + [avg_filter(y1, l)[-1] for _ in range(l)] 22 | pdict_save['tps'] = y_save_smooth 23 | pdict_save['w0s_mt'] = y_save_smooth[:len(pdict_save['dis_ps'])] 24 | 25 | print(pdict['tps']) 26 | print(pdict_save['tps']) 27 | print(len(pdict_save['tps'])) 28 | print(len(pdict_save['dis_ps'])) 29 | out_name = args.out_name 30 | np.save(out_name, pdict_save) 31 | -------------------------------------------------------------------------------- /pyramidnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | from shakedrop import ShakeDrop 6 | 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | """ 10 | 3x3 convolution with padding 11 | """ 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | outchannel_ratio = 1 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None, p_shakedrop=1.0): 19 | super(BasicBlock, self).__init__() 20 | self.bn1 = nn.BatchNorm2d(inplanes) 21 | self.conv1 = conv3x3(inplanes, planes, stride) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.conv2 = conv3x3(planes, planes) 24 | self.bn3 = nn.BatchNorm2d(planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.downsample = downsample 27 | self.stride = stride 28 | self.shake_drop = ShakeDrop(p_shakedrop) 29 | 30 | def forward(self, x): 31 | 32 | out = self.bn1(x) 33 | out = self.conv1(out) 34 | out = self.bn2(out) 35 | out = self.relu(out) 36 | out = self.conv2(out) 37 | out = self.bn3(out) 38 | 39 | out = self.shake_drop(out) 40 | 41 | if 
self.downsample is not None: 42 | shortcut = self.downsample(x) 43 | featuremap_size = shortcut.size()[2:4] 44 | else: 45 | shortcut = x 46 | featuremap_size = out.size()[2:4] 47 | 48 | batch_size = out.size()[0] 49 | residual_channel = out.size()[1] 50 | shortcut_channel = shortcut.size()[1] 51 | 52 | if residual_channel != shortcut_channel: 53 | padding = torch.autograd.Variable( 54 | torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], 55 | featuremap_size[1]).fill_(0)) 56 | out += torch.cat((shortcut, padding), 1) 57 | else: 58 | out += shortcut 59 | 60 | return out 61 | 62 | 63 | class Bottleneck(nn.Module): 64 | outchannel_ratio = 4 65 | 66 | def __init__(self, inplanes, planes, stride=1, downsample=None, p_shakedrop=1.0): 67 | super(Bottleneck, self).__init__() 68 | self.bn1 = nn.BatchNorm2d(inplanes) 69 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 70 | self.bn2 = nn.BatchNorm2d(planes) 71 | self.conv2 = nn.Conv2d(planes, (planes * 1), kernel_size=3, stride=stride, 72 | padding=1, bias=False) 73 | self.bn3 = nn.BatchNorm2d((planes * 1)) 74 | self.conv3 = nn.Conv2d((planes * 1), planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False) 75 | self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio) 76 | self.relu = nn.ReLU(inplace=True) 77 | self.downsample = downsample 78 | self.stride = stride 79 | self.shake_drop = ShakeDrop(p_shakedrop) 80 | 81 | def forward(self, x): 82 | 83 | out = self.bn1(x) 84 | out = self.conv1(out) 85 | 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | out = self.conv2(out) 89 | 90 | out = self.bn3(out) 91 | out = self.relu(out) 92 | out = self.conv3(out) 93 | 94 | out = self.bn4(out) 95 | 96 | out = self.shake_drop(out) 97 | 98 | if self.downsample is not None: 99 | shortcut = self.downsample(x) 100 | featuremap_size = shortcut.size()[2:4] 101 | else: 102 | shortcut = x 103 | featuremap_size = out.size()[2:4] 104 | 105 | batch_size = out.size()[0] 106 | residual_channel = out.size()[1] 107 | shortcut_channel = shortcut.size()[1] 108 | 109 | if residual_channel != shortcut_channel: 110 | padding = torch.autograd.Variable( 111 | torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], 112 | featuremap_size[1]).fill_(0)) 113 | out += torch.cat((shortcut, padding), 1) 114 | else: 115 | out += shortcut 116 | 117 | return out 118 | 119 | 120 | class PyramidNet(nn.Module): 121 | 122 | def __init__(self, dataset, depth, alpha, num_classes, bottleneck=True): 123 | super(PyramidNet, self).__init__() 124 | self.dataset = dataset 125 | if self.dataset.startswith('cifar'): 126 | self.inplanes = 16 127 | if bottleneck: 128 | n = int((depth - 2) / 9) 129 | block = Bottleneck 130 | else: 131 | n = int((depth - 2) / 6) 132 | block = BasicBlock 133 | 134 | self.addrate = alpha / (3 * n * 1.0) 135 | self.ps_shakedrop = [1. 
- (1.0 - (0.5 / (3 * n)) * (i + 1)) for i in range(3 * n)] 136 | 137 | self.input_featuremap_dim = self.inplanes 138 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=3, stride=1, padding=1, bias=False) 139 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 140 | 141 | self.featuremap_dim = self.input_featuremap_dim 142 | self.layer1 = self.pyramidal_make_layer(block, n) 143 | self.layer2 = self.pyramidal_make_layer(block, n, stride=2) 144 | self.layer3 = self.pyramidal_make_layer(block, n, stride=2) 145 | 146 | self.final_featuremap_dim = self.input_featuremap_dim 147 | self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim) 148 | self.relu_final = nn.ReLU(inplace=True) 149 | self.avgpool = nn.AvgPool2d(8) 150 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes) 151 | 152 | elif dataset == 'imagenet': 153 | blocks = {18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} 154 | layers = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 155 | 200: [3, 24, 36, 3]} 156 | 157 | if layers.get(depth) is None: 158 | if bottleneck == True: 159 | blocks[depth] = Bottleneck 160 | temp_cfg = int((depth - 2) / 12) 161 | else: 162 | blocks[depth] = BasicBlock 163 | temp_cfg = int((depth - 2) / 8) 164 | 165 | layers[depth] = [temp_cfg, temp_cfg, temp_cfg, temp_cfg] 166 | print('=> the layer configuration for each stage is set to', layers[depth]) 167 | 168 | self.inplanes = 64 169 | self.addrate = alpha / (sum(layers[depth]) * 1.0) 170 | 171 | self.input_featuremap_dim = self.inplanes 172 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=7, stride=2, padding=3, bias=False) 173 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 174 | self.relu = nn.ReLU(inplace=True) 175 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 176 | 177 | self.featuremap_dim = self.input_featuremap_dim 178 | self.layer1 = self.pyramidal_make_layer(blocks[depth], layers[depth][0]) 179 | self.layer2 = self.pyramidal_make_layer(blocks[depth], layers[depth][1], stride=2) 180 | self.layer3 = self.pyramidal_make_layer(blocks[depth], layers[depth][2], stride=2) 181 | self.layer4 = self.pyramidal_make_layer(blocks[depth], layers[depth][3], stride=2) 182 | 183 | self.final_featuremap_dim = self.input_featuremap_dim 184 | self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim) 185 | self.relu_final = nn.ReLU(inplace=True) 186 | self.avgpool = nn.AvgPool2d(7) 187 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes) 188 | 189 | for m in self.modules(): 190 | if isinstance(m, nn.Conv2d): 191 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 192 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 193 | elif isinstance(m, nn.BatchNorm2d): 194 | m.weight.data.fill_(1) 195 | m.bias.data.zero_() 196 | 197 | assert len(self.ps_shakedrop) == 0, self.ps_shakedrop 198 | 199 | def pyramidal_make_layer(self, block, block_depth, stride=1): 200 | downsample = None 201 | if stride != 1: # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio: 202 | downsample = nn.AvgPool2d((2, 2), stride=(2, 2), ceil_mode=True) 203 | 204 | layers = [] 205 | self.featuremap_dim = self.featuremap_dim + self.addrate 206 | layers.append(block(self.input_featuremap_dim, int(round(self.featuremap_dim)), stride, downsample, p_shakedrop=self.ps_shakedrop.pop(0))) 207 | for i in range(1, block_depth): 208 | temp_featuremap_dim = self.featuremap_dim + self.addrate 209 | layers.append( 210 | block(int(round(self.featuremap_dim)) * block.outchannel_ratio, int(round(temp_featuremap_dim)), 1, p_shakedrop=self.ps_shakedrop.pop(0))) 211 | self.featuremap_dim = temp_featuremap_dim 212 | self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio 213 | 214 | return nn.Sequential(*layers) 215 | 216 | def forward(self, x): 217 | if self.dataset == 'cifar10' or self.dataset == 'cifar100': 218 | x = self.conv1(x) 219 | x = self.bn1(x) 220 | 221 | x = self.layer1(x) 222 | x = self.layer2(x) 223 | x = self.layer3(x) 224 | 225 | x = self.bn_final(x) 226 | x = self.relu_final(x) 227 | x = self.avgpool(x) 228 | x = x.view(x.size(0), -1) 229 | x = self.fc(x) 230 | 231 | elif self.dataset == 'imagenet': 232 | x = self.conv1(x) 233 | x = self.bn1(x) 234 | x = self.relu(x) 235 | x = self.maxpool(x) 236 | 237 | x = self.layer1(x) 238 | x = self.layer2(x) 239 | x = self.layer3(x) 240 | x = self.layer4(x) 241 | 242 | x = self.bn_final(x) 243 | x = self.relu_final(x) 244 | x = self.avgpool(x) 245 | x = x.view(x.size(0), -1) 246 | x = self.fc(x) 247 | 248 | return x 249 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | # Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 2 | 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = conv3x3(planes, planes) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.relu = nn.ReLU(inplace=True) 23 | 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | 30 | out = self.conv1(x) 31 | out = self.bn1(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv2(out) 35 | out = self.bn2(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | 46 | class Bottleneck(nn.Module): 47 | expansion = 4 48 | 49 | def __init__(self, inplanes, planes, stride=1, downsample=None): 50 | super(Bottleneck, self).__init__() 51 | 52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.conv2 = nn.Conv2d(planes, planes, 
kernel_size=3, stride=stride, padding=1, bias=False) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False) 57 | self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion) 58 | self.relu = nn.ReLU(inplace=True) 59 | 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | residual = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv3(out) 75 | out = self.bn3(out) 76 | if self.downsample is not None: 77 | residual = self.downsample(x) 78 | 79 | out += residual 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | class ResNet(nn.Module): 85 | def __init__(self, dataset, depth, num_classes, bottleneck=False): 86 | super(ResNet, self).__init__() 87 | self.dataset = dataset 88 | if self.dataset.startswith('cifar'): 89 | self.inplanes = 16 90 | print(bottleneck) 91 | if bottleneck == True: 92 | n = int((depth - 2) / 9) 93 | block = Bottleneck 94 | else: 95 | n = int((depth - 2) / 6) 96 | block = BasicBlock 97 | 98 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 99 | self.bn1 = nn.BatchNorm2d(self.inplanes) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.layer1 = self._make_layer(block, 16, n) 102 | self.layer2 = self._make_layer(block, 32, n, stride=2) 103 | self.layer3 = self._make_layer(block, 64, n, stride=2) 104 | # self.avgpool = nn.AvgPool2d(8) 105 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 106 | self.fc = nn.Linear(64 * block.expansion, num_classes) 107 | 108 | elif dataset == 'imagenet': 109 | blocks = {18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} 110 | layers = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]} 111 | assert depth in layers, 'invalid depth for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200)' 112 | 113 | self.inplanes = 64 114 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 115 | self.bn1 = nn.BatchNorm2d(64) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 118 | self.layer1 = self._make_layer(blocks[depth], 64, layers[depth][0]) 119 | self.layer2 = self._make_layer(blocks[depth], 128, layers[depth][1], stride=2) 120 | self.layer3 = self._make_layer(blocks[depth], 256, layers[depth][2], stride=2) 121 | self.layer4 = self._make_layer(blocks[depth], 512, layers[depth][3], stride=2) 122 | # self.avgpool = nn.AvgPool2d(7) 123 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 124 | self.fc = nn.Linear(512 * blocks[depth].expansion, num_classes) 125 | 126 | for m in self.modules(): 127 | if isinstance(m, nn.Conv2d): 128 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 129 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 130 | elif isinstance(m, nn.BatchNorm2d): 131 | m.weight.data.fill_(1) 132 | m.bias.data.zero_() 133 | 134 | def _make_layer(self, block, planes, blocks, stride=1): 135 | downsample = None 136 | if stride != 1 or self.inplanes != planes * block.expansion: 137 | downsample = nn.Sequential( 138 | nn.Conv2d(self.inplanes, planes * block.expansion, 139 | kernel_size=1, stride=stride, bias=False), 140 | nn.BatchNorm2d(planes * block.expansion), 141 | ) 142 | 143 | layers = [] 144 | layers.append(block(self.inplanes, planes, stride, downsample)) 145 | self.inplanes = planes * block.expansion 146 | for i in range(1, blocks): 147 | layers.append(block(self.inplanes, planes)) 148 | 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x): 152 | if self.dataset == 'cifar10' or self.dataset == 'cifar100': 153 | x = self.conv1(x) 154 | x = self.bn1(x) 155 | x = self.relu(x) 156 | 157 | x = self.layer1(x) 158 | x = self.layer2(x) 159 | x = self.layer3(x) 160 | 161 | x = self.avgpool(x) 162 | x = x.view(x.size(0), -1) 163 | x = self.fc(x) 164 | 165 | elif self.dataset == 'imagenet': 166 | x = self.conv1(x) 167 | x = self.bn1(x) 168 | x = self.relu(x) 169 | x = self.maxpool(x) 170 | 171 | x = self.layer1(x) 172 | x = self.layer2(x) 173 | x = self.layer3(x) 174 | x = self.layer4(x) 175 | 176 | x = self.avgpool(x) 177 | x = x.view(x.size(0), -1) 178 | x = self.fc(x) 179 | 180 | return x 181 | -------------------------------------------------------------------------------- /search.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import logging 4 | import math 5 | import os 6 | from collections import OrderedDict 7 | import numpy as np 8 | import copy 9 | from torchvision.transforms import transforms 10 | import torch 11 | import random 12 | from torch import nn, optim 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | from tqdm import tqdm 16 | from theconf import Config as C, ConfigArgumentParser 17 | from tensorboardX import SummaryWriter 18 | 19 | from common import get_logger 20 | from data import get_dataloaders, Get_DataLoaders_Epoch_s, get_val_test_dataloader 21 | from lr_scheduler import adjust_learning_rate_resnet 22 | from metrics import accuracy, Accumulator 23 | from networks import get_model, num_class 24 | from warmup_scheduler import GradualWarmupScheduler 25 | from augmentations import aug_ohl_list, RWAug_Search 26 | 27 | from common import add_filehandler 28 | from smooth_ce import SmoothCrossEntropyLoss 29 | 30 | from itertools import cycle 31 | 32 | logger = get_logger('RandAugment') 33 | logger.setLevel(logging.INFO) 34 | 35 | dis_ps = [] 36 | tps = [] 37 | #Function to run normal training! 
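The module-level `dis_ps` and `tps` lists above collect, once per epoch, the softmax distribution over the augmentation operations and the sigmoid of the total-probability parameter; `run_epoch_search` below appends to them and snapshots everything to `<save>_save_dict.npy`. For reference, a minimal self-contained sketch of that bookkeeping, with hypothetical values standing in for the learned `aug_param` and `tp_param`:

```python
import numpy as np
import torch

# Hypothetical stand-ins for the learned parameters in search.py.
aug_param = torch.zeros(5)           # zero logits over 5 ops -> uniform softmax
tp_param = torch.tensor([-0.619])    # sigmoid(-0.619) ~= 0.35

dis_ps, tps, w0s_mt = [], [], []
dis_ps.append(torch.softmax(aug_param, dim=0).numpy())  # per-op probabilities
tps.append(torch.sigmoid(tp_param).item())              # overall augmentation probability

# Same dictionary layout that run_epoch_search writes every epoch.
save_dict = {'dis_ps': dis_ps, 'w0s_mt': w0s_mt, 'tps': np.array(tps)}
np.save('example_save_dict.npy', save_dict)
```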
38 | def run_epoch(model, loader, loss_fn, optimizer, desc_default='', epoch=0, writer=None, verbose=1, scheduler=None): 39 | tqdm_disable = bool(os.environ.get('TASK_NAME', '')) 40 | if verbose: 41 | loader = tqdm(loader, disable=tqdm_disable) 42 | loader.set_description('[%s %04d/%04d]' % (desc_default, epoch, C.get()['epoch'])) 43 | 44 | metrics = Accumulator() 45 | cnt = 0 46 | total_steps = len(loader) 47 | steps = 0 48 | for data, label in loader: 49 | steps += 1 50 | data, label = data.cuda(), label.cuda() 51 | 52 | if optimizer: 53 | optimizer.zero_grad() 54 | 55 | preds = model(data) 56 | loss = loss_fn(preds, label) 57 | 58 | if optimizer: 59 | loss.backward() 60 | if C.get()['optimizer'].get('clip', 5) > 0: 61 | nn.utils.clip_grad_norm_(model.parameters(), C.get()['optimizer'].get('clip', 5)) 62 | optimizer.step() 63 | 64 | top1, top5 = accuracy(preds, label, (1, 5)) 65 | metrics.add_dict({ 66 | 'loss': loss.item() * len(data), 67 | 'top1': top1.item() * len(data), 68 | 'top5': top5.item() * len(data), 69 | }) 70 | cnt += len(data) 71 | if verbose: 72 | postfix = metrics / cnt 73 | if optimizer: 74 | postfix['lr'] = optimizer.param_groups[0]['lr'] 75 | loader.set_postfix(postfix) 76 | 77 | if scheduler is not None: 78 | scheduler.step(epoch - 1 + float(steps) / total_steps) 79 | 80 | del preds, loss, top1, top5, data, label 81 | 82 | if tqdm_disable: 83 | if optimizer: 84 | logger.info('[%s %03d/%03d] %s lr=%.6f', desc_default, epoch, C.get()['epoch'], metrics / cnt, optimizer.param_groups[0]['lr']) 85 | else: 86 | logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt) 87 | logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt) 88 | metrics /= cnt 89 | if optimizer: 90 | metrics.metrics['lr'] = optimizer.param_groups[0]['lr'] 91 | if verbose: 92 | for key, value in metrics.items(): 93 | writer.add_scalar(key, value, epoch) 94 | return metrics 95 | #Function to run search 96 | def run_epoch_search(model, loaders, val_loader, loss_fn, optimizer,optimizer_aug, optimizer_tp, 97 | aug_param, tp_param, desc_default='',explore_ratio = 0, w0s_at=[],w0s_mt=[], 98 | ops_num = 2, epoch=0, writer=None, verbose=1, scheduler=None, 99 | dict_reward={}): 100 | tqdm_disable = bool(os.environ.get('TASK_NAME', '')) 101 | 102 | transform_or = copy.deepcopy(loaders[0].dataset.transform) 103 | loaders[0].dataset.transform = transforms.ToTensor() 104 | 105 | if verbose: 106 | loader_t = tqdm(loaders[0], disable=tqdm_disable) 107 | loader_t.set_description('[%s %04d/%04d]' % (desc_default, epoch, C.get()['epoch'])) 108 | 109 | rw_search = RWAug_Search(ops_num,[0,0]) 110 | metrics = Accumulator() 111 | cnt = 0 112 | total_steps = len(loaders[0]) 113 | steps = 0 114 | val_loader = cycle(val_loader) 115 | 116 | dis_ps.append(torch.nn.Softmax()(aug_param).data.numpy()) 117 | tps.append(torch.sigmoid(tp_param).data.item()) 118 | print(dis_ps[-1]) 119 | print(torch.sigmoid(tp_param)) 120 | 121 | #Save the probability 122 | save_dict = {} 123 | save_dict['dis_ps'] = dis_ps 124 | save_dict['w0s_mt'] = w0s_mt 125 | save_dict['tps'] = np.array(tps) 126 | np.save(args.save[:-4]+'_save_dict'+'.npy',save_dict) 127 | 128 | #Select the augmentation operation 129 | for data in loader_t: 130 | aug_types = [] 131 | for idl in range(1,len(loaders)): 132 | if random.random() > explore_ratio: 133 | tmp_type = (ops_num, select_op(aug_param, ops_num)) 134 | else: 135 | tmp_type = (ops_num, select_op(torch.zeros(len(aug_ohl_list)), ops_num)) 136 | 
aug_types.append(tmp_type) 137 | aug_probs = [] 138 | 139 | #Calculate the probability to select this augmentation operation! 140 | for aug_type in aug_types: 141 | idxs = aug_type[1] 142 | aug_probs.append(trace_prob(aug_param, idxs).item()) 143 | Z = sum(aug_probs) 144 | 145 | data_bacth = [copy.deepcopy(data) for _ in range(len(loaders))] 146 | grad_ls = [] 147 | gip_ls =torch.zeros(len(loaders)) 148 | 149 | if optimizer: 150 | optimizer.zero_grad() 151 | if optimizer_aug: 152 | optimizer_aug.zero_grad() 153 | if optimizer_tp: 154 | optimizer_tp.zero_grad() 155 | 156 | for idl in range(len(data_bacth)): 157 | data, label = data_bacth[idl] 158 | print(label) 159 | pil_imgs = [] 160 | 161 | #Transform the data to PIL forms! 162 | for nb in range(len(data)): 163 | pil_imgs.append(transforms.ToPILImage()(data[nb])) 164 | if idl > 0: 165 | rw_search.n = aug_types[idl - 1][0] 166 | rw_search.idxs = aug_types[idl - 1][1] 167 | 168 | #Do the selected augmentation 169 | for idp in range(len(pil_imgs)): 170 | if idl > 0: 171 | pil_imgs[idp] = rw_search(pil_imgs[idp]) 172 | pil_imgs[idp] = transform_or(pil_imgs[idp]).unsqueeze(0) 173 | data = torch.cat(pil_imgs) 174 | data_train, label_train = data.cuda(), label.cuda() 175 | preds_train = model(data_train) 176 | loss_train = loss_fn(preds_train, label_train) 177 | loss_T = torch.sum(loss_train) 178 | if idl == 0: 179 | preds_train0, label_train0, loss_T0 = preds_train, label_train, loss_T 180 | grads_T = torch.autograd.grad(loss_T, (model.parameters())) 181 | grad_ls.append(grads_T) 182 | del data_train, label_train, loss_train, preds_train,loss_T 183 | 184 | #Update the model parameters! 185 | grad_T = grad_ls[0] 186 | print("tp") 187 | print(torch.sigmoid(tp_param).item()) 188 | for gt, p in zip(grad_T, model.parameters()): 189 | p.grad = (1 - torch.sigmoid(tp_param).item()) * gt.data 190 | for idl in range(1,len(grad_ls)): 191 | for gt, p in zip(grad_ls[idl], model.parameters()): 192 | p.grad = p.grad + torch.sigmoid(tp_param).item() * aug_probs[idl - 1]/Z * gt.data 193 | if optimizer: 194 | optimizer.step() 195 | 196 | #Calculate the validation gradient 197 | data_val, label_val = next(val_loader) 198 | data_val, label_val = data_val.cuda(), label_val.cuda() 199 | preds_val = model(data_val) 200 | 201 | loss_V = loss_fn(preds_val, label_val).sum() 202 | grads_V = torch.autograd.grad(loss_V, (model.parameters())) 203 | del data_val, label_val, preds_val, loss_V 204 | 205 | #Calculate the inner product of gradients! 206 | for idl in range(len(data_bacth)): 207 | gip_ls[idl] = sum([torch.sum(gt*gv) for gt, gv in zip(grad_ls[idl], grads_V)]).data 208 | if idl == 0: 209 | gip0 = gip_ls[idl].data 210 | gip_ls[idl] = gip_ls[idl] - gip0 211 | 212 | gd_norm = torch.norm(gip_ls,p=1) 213 | print("gip_norm") 214 | print(gd_norm) 215 | 216 | #Update the augmentation parameters! 
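The update that follows treats each sampled operation pair as an arm whose selection probability `trace_prob(aug_param, idxs)` should grow when that pair's training gradient aligns with the validation gradient (positive `gip_ls[idl]`) and shrink otherwise, with the signal gated by `sigmoid(tp_param)` and normalized by `Z` and `gd_norm`. A self-contained sketch of one such step, using made-up values for the detached quantities computed above (the learning rates mirror the `--param_lr` and `--tp_lr` defaults):

```python
import torch

# Made-up stand-ins for the detached quantities computed above.
aug_param = torch.zeros(5, requires_grad=True)   # logits over 5 ops
tp_param = torch.zeros(1, requires_grad=True)    # logit of the overall aug probability
optimizer_aug = torch.optim.Adam([aug_param], lr=0.005, betas=(0.5, 0.999))
optimizer_tp = torch.optim.Adam([tp_param], lr=0.001, betas=(0.5, 0.999))

idxs = [1, 3]          # one sampled pair of operation indices
gip = 0.2              # <grad_train, grad_val> minus the no-augmentation baseline
Z, gd_norm = 1.0, 1.0  # pair-probability sum and L1 norm of all inner products

probs = torch.softmax(aug_param, dim=0)
trace_p = probs[idxs[0]] * probs[idxs[1]]        # trace_prob: product over the pair
trace_loss = -gip * torch.sigmoid(tp_param) * trace_p / Z / gd_norm
trace_loss.backward()   # raises logits of ops whose gradients helped validation
optimizer_aug.step()
optimizer_tp.step()
```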
217 | for idl in range(1,len(loaders)): 218 | idxs = aug_types[idl - 1][1] 219 | trace_loss = -1 * gip_ls[idl].data.item() * torch.sigmoid(tp_param) * trace_prob(aug_param, idxs)/Z/gd_norm 220 | trace_loss.backward() 221 | print("current pop value!!!") 222 | print(torch.nn.Softmax()(aug_param).data.numpy()) 223 | optimizer_aug.step() 224 | optimizer_tp.step() 225 | 226 | del grads_V, grads_T 227 | 228 | top1, top5 = accuracy(preds_train0, label_train0, (1, 5)) 229 | metrics.add_dict({ 230 | 'loss': loss_T0.item() * len(data), 231 | 'top1': top1.item() * len(data), 232 | 'top5': top5.item() * len(data), 233 | }) 234 | cnt += len(data) 235 | if verbose: 236 | postfix = metrics / cnt 237 | if optimizer: 238 | postfix['lr'] = optimizer.param_groups[0]['lr'] 239 | loader_t.set_postfix(postfix) 240 | 241 | if scheduler is not None: 242 | scheduler.step(epoch - 1 + float(steps) / total_steps) 243 | 244 | del top1, top5 245 | del gip_ls, grad_ls 246 | 247 | steps += 1 248 | 249 | if tqdm_disable: 250 | if optimizer: 251 | logger.info('[%s %03d/%03d] %s lr=%.6f', desc_default, epoch, C.get()['epoch'], metrics / cnt, optimizer.param_groups[0]['lr']) 252 | else: 253 | logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt) 254 | logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt) 255 | metrics /= cnt 256 | if optimizer: 257 | metrics.metrics['lr'] = optimizer.param_groups[0]['lr'] 258 | if optimizer_aug: 259 | print("param learning rate") 260 | print(optimizer_aug.param_groups[0]['lr']) 261 | if verbose: 262 | for key, value in metrics.items(): 263 | writer.add_scalar(key, value, epoch) 264 | return metrics 265 | softmax = torch.nn.Softmax() 266 | 267 | def select_op(op_params, num_ops): 268 | prob = softmax(op_params) 269 | op_ids = torch.multinomial(prob, 2, replacement=True).tolist() 270 | return op_ids 271 | 272 | def trace_prob(op_params, op_ids): 273 | probs = softmax(op_params) 274 | tp = 1 275 | for idx in op_ids: 276 | tp = tp * probs[idx] 277 | return tp 278 | 279 | def train_and_eval(tag, dataroot, loader_num = 6, test_ratio=0.1, ops_num = 2, explore_ratio = 0, param_lr = 0.05, reporter=None, metric='last', save_path=None, only_eval=False,args = None): 280 | if not reporter: 281 | reporter = lambda **kwargs: 0 282 | 283 | max_epoch = C.get()['epoch'] 284 | aug_length = len(aug_ohl_list) 285 | 286 | #Initialize the augmentation parameters! 
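The initialization on the next lines sets `tp_alpha` to the logit of `--init_tp`, so that `sigmoid(tp_param)` starts exactly at the requested overall augmentation probability, and zero-initializes `aug_param`, which makes the initial softmax over operations uniform. A quick numeric check of that relationship (using the parser's default `init_tp` and a hypothetical 5-op search space):

```python
import numpy as np
import torch

init_tp = 0.3                                  # parser default for --init_tp
tp_alpha = np.log(init_tp / (1 - init_tp))     # logit(0.3) ~= -0.847
print(torch.sigmoid(torch.tensor(tp_alpha)))   # tensor(0.3000) -> recovers init_tp

aug_param = torch.zeros(5)                     # hypothetical 5-op search space
print(torch.softmax(aug_param, dim=0))         # tensor([0.2, 0.2, 0.2, 0.2, 0.2])
```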
287 | tp_alpha = np.log(args.init_tp/(1-args.init_tp)) 288 | tp_param = torch.nn.Parameter(torch.ones(1,requires_grad=True) * tp_alpha,requires_grad=True) 289 | aug_param = torch.nn.Parameter(torch.zeros(aug_length,requires_grad=True),requires_grad=True) 290 | optimizer_aug = torch.optim.Adam((aug_param,),lr=param_lr, betas=(0.5, 0.999)) 291 | optimizer_tp = torch.optim.Adam((tp_param,),lr=args.tp_lr, betas=(0.5, 0.999)) 292 | 293 | trainsampler, validloader, testloader_ = get_val_test_dataloader(C.get()['dataset'], C.get()['batch'], dataroot, test_ratio) 294 | 295 | # create a model & an optimizer 296 | model = get_model(C.get()['model'], num_class(C.get()['dataset'])) 297 | 298 | lb_smooth = C.get()['optimizer'].get('label_smoothing', 0.0) 299 | if lb_smooth > 0.0: 300 | criterion = SmoothCrossEntropyLoss(lb_smooth,reduction='none') 301 | criterion_val = SmoothCrossEntropyLoss(lb_smooth) 302 | else: 303 | criterion = nn.CrossEntropyLoss(reduction='none') 304 | criterion_val = nn.CrossEntropyLoss() 305 | if C.get()['optimizer']['type'] == 'sgd': 306 | optimizer = optim.SGD( 307 | model.parameters(), 308 | lr=C.get()['lr'], 309 | momentum=0, 310 | weight_decay=0, 311 | #nesterov=C.get()['optimizer']['nesterov'] 312 | ) 313 | else: 314 | raise ValueError('invalid optimizer type=%s' % C.get()['optimizer']['type']) 315 | 316 | if C.get()['optimizer'].get('lars', False): 317 | from torchlars import LARS 318 | optimizer = LARS(optimizer) 319 | logger.info('*** LARS Enabled.') 320 | 321 | lr_scheduler_type = C.get()['lr_schedule'].get('type', 'cosine') 322 | if lr_scheduler_type == 'cosine': 323 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=C.get()['epoch'], eta_min=0.) 324 | elif lr_scheduler_type == 'resnet': 325 | scheduler = adjust_learning_rate_resnet(optimizer) 326 | else: 327 | raise ValueError('invalid lr_scheduler=%s' % lr_scheduler_type) 328 | 329 | if C.get()['lr_schedule'].get('warmup', None): 330 | scheduler = GradualWarmupScheduler( 331 | optimizer, 332 | multiplier=C.get()['lr_schedule']['warmup']['multiplier'], 333 | total_epoch=C.get()['lr_schedule']['warmup']['epoch'], 334 | after_scheduler=scheduler 335 | ) 336 | 337 | # if not tag: 338 | # from RandAugment.metrics import SummaryWriterDummy as SummaryWriter 339 | # logger.warning('tag not provided, no tensorboard log.') 340 | # else: 341 | 342 | writers = [SummaryWriter(log_dir='./logs/%s/%s' % (tag, x)) for x in ['train', 'valid', 'test']] 343 | 344 | result = OrderedDict() 345 | epoch_start = 1 346 | if save_path and os.path.exists(save_path): 347 | logger.info('%s file found. loading...' % save_path) 348 | data = torch.load(save_path) 349 | if 'model' in data or 'state_dict' in data: 350 | key = 'model' if 'model' in data else 'state_dict' 351 | logger.info('checkpoint epoch@%d' % data['epoch']) 352 | if not isinstance(model, DataParallel): 353 | model.load_state_dict({k.replace('module.', ''): v for k, v in data[key].items()}) 354 | else: 355 | model.load_state_dict({k if 'module.' in k else 'module.'+k: v for k, v in data[key].items()}) 356 | optimizer.load_state_dict(data['optimizer']) 357 | if data['epoch'] < C.get()['epoch']: 358 | epoch_start = data['epoch'] 359 | else: 360 | only_eval = True 361 | else: 362 | model.load_state_dict({k: v for k, v in data.items()}) 363 | del data 364 | else: 365 | logger.info('"%s" file not found. skip to pretrain weights...' % save_path) 366 | if only_eval: 367 | logger.warning('model checkpoint not found. 
only-evaluation mode is off.') 368 | only_eval = False 369 | 370 | if only_eval: 371 | logger.info('evaluation only+') 372 | model.eval() 373 | rs = dict() 374 | rs['train'] = run_epoch(model, trainloader, criterion_val, None, desc_default='train', epoch=0, writer=writers[0]) 375 | rs['valid'] = run_epoch(model, validloader, criterion_val, None, desc_default='valid', epoch=0, writer=writers[1]) 376 | rs['test'] = run_epoch(model, testloader_, criterion_val, None, desc_default='*test', epoch=0, writer=writers[2]) 377 | for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']): 378 | if setname not in rs: 379 | continue 380 | result['%s_%s' % (key, setname)] = rs[setname][key] 381 | result['epoch'] = 0 382 | return result 383 | 384 | # search loop 385 | best_top1 = 0 386 | dict_reward = {} 387 | w0s_at=[] 388 | w0s_mt=[] 389 | for epoch in range(epoch_start, max_epoch + 1): 390 | 391 | AugTypes=[(ops_num, select_op(aug_param, ops_num)) for _ in range(loader_num)] 392 | random.shuffle(trainsampler.indices) 393 | print(AugTypes) 394 | loaders = Get_DataLoaders_Epoch_s( 395 | C.get()['dataset'], C.get()['batch'], dataroot, trainsampler, AugTypes, loader_num = len(AugTypes)) 396 | for loader in loaders[1:]: 397 | print((loader.dataset.transform.transforms[0].n,loader.dataset.transform.transforms[0].idxs)) 398 | model.train() 399 | rs = dict() 400 | 401 | rs['train'] = run_epoch_search( 402 | model, loaders, validloader, criterion, optimizer,optimizer_aug, optimizer_tp, aug_param, tp_param, 403 | explore_ratio = explore_ratio, ops_num = ops_num, desc_default='train', epoch=epoch, 404 | writer=writers[0], verbose=True, scheduler=scheduler, dict_reward=dict_reward, w0s_at=w0s_at, w0s_mt = w0s_mt) 405 | model.eval() 406 | 407 | if math.isnan(rs['train']['loss']): 408 | raise Exception('train loss is NaN.') 409 | 410 | if epoch % 1 == 0 or epoch == max_epoch: 411 | rs['valid'] = run_epoch(model, validloader, criterion_val, None, desc_default='valid', epoch=epoch, writer=writers[1], verbose=True) 412 | rs['test'] = run_epoch(model, testloader_, criterion_val, None, desc_default='*test', epoch=epoch, writer=writers[2], verbose=True) 413 | 414 | if metric == 'last' or rs[metric]['top1'] > best_top1: 415 | if metric != 'last': 416 | best_top1 = rs[metric]['top1'] 417 | for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']): 418 | result['%s_%s' % (key, setname)] = rs[setname][key] 419 | result['epoch'] = epoch 420 | 421 | writers[1].add_scalar('valid_top1/best', rs['valid']['top1'], epoch) 422 | writers[2].add_scalar('test_top1/best', rs['test']['top1'], epoch) 423 | 424 | reporter( 425 | loss_valid=rs['valid']['loss'], top1_valid=rs['valid']['top1'], 426 | loss_test=rs['test']['loss'], top1_test=rs['test']['top1'] 427 | ) 428 | 429 | # save checkpoint 430 | if save_path: 431 | logger.info('save model@%d to %s' % (epoch, save_path)) 432 | torch.save({ 433 | 'epoch': epoch, 434 | 'log': { 435 | 'train': rs['train'].get_dict(), 436 | 'valid': rs['valid'].get_dict(), 437 | 'test': rs['test'].get_dict(), 438 | }, 439 | 'optimizer': optimizer.state_dict(), 440 | 'model': model.state_dict() 441 | }, save_path) 442 | torch.save({ 443 | 'epoch': epoch, 444 | 'log': { 445 | 'train': rs['train'].get_dict(), 446 | 'valid': rs['valid'].get_dict(), 447 | 'test': rs['test'].get_dict(), 448 | }, 449 | 'optimizer': optimizer.state_dict(), 450 | 'model': model.state_dict() 451 | }, save_path.replace('.pth', '_e%d_top1_%.3f_%.3f' % (epoch, 
rs['train']['top1'], rs['test']['top1']) + '.pth')) 452 | 453 | del model 454 | 455 | result['top1_test'] = best_top1 456 | return result 457 | 458 | def setup_seed(seed): 459 | torch.manual_seed(seed) 460 | torch.cuda.manual_seed_all(seed) 461 | np.random.seed(seed) 462 | random.seed(seed) 463 | torch.backends.cudnn.deterministic = True 464 | # Set the random seed for reproducibility 465 | 466 | if __name__ == '__main__': 467 | parser = ConfigArgumentParser(conflict_handler='resolve') 468 | parser.add_argument('--tag', type=str, default='') 469 | parser.add_argument('--dataroot', type=str, default='/data/private/pretrainedmodels', help='torchvision data folder') 470 | parser.add_argument('--save', type=str, default='') 471 | parser.add_argument('--cv-ratio', type=float, default=0.1) 472 | parser.add_argument('--explore_ratio', type=float, default=0.2) 473 | parser.add_argument('--param_lr', type=float, default=0.005) 474 | parser.add_argument('--tp_lr', type=float, default=0.001) 475 | parser.add_argument('--init_tp', type=float, default=0.3) 476 | parser.add_argument('--cv', type=int, default=0) 477 | parser.add_argument('--loader_num', type=int, default=4) 478 | parser.add_argument('--ops_num', type=int, default=2) 479 | parser.add_argument('--only-eval', action='store_true') 480 | parser.add_argument('--rand-seed', type=int, default=20) 481 | args = parser.parse_args() 482 | 483 | assert (args.only_eval and args.save) or not args.only_eval, 'checkpoint path not provided in evaluation mode.' 484 | 485 | setup_seed(args.rand_seed) 486 | 487 | if not args.only_eval: 488 | if args.save: 489 | logger.info('checkpoint will be saved at %s' % args.save) 490 | else: 491 | logger.warning('Provide --save argument to save the checkpoint. Without it, training result will not be saved!') 492 | 493 | if args.save: 494 | add_filehandler(logger, args.save.replace('.pth', '') + '.log') 495 | 496 | logger.info(json.dumps(C.get().conf, indent=4)) 497 | 498 | import time 499 | t = time.time() 500 | result = train_and_eval( 501 | args.tag, args.dataroot, loader_num = args.loader_num, param_lr = args.param_lr, 502 | ops_num = args.ops_num, test_ratio=args.cv_ratio, save_path=args.save, 503 | only_eval=args.only_eval, explore_ratio = args.explore_ratio, metric='test',args = args) 504 | elapsed = time.time() - t 505 | 506 | logger.info('done.') 507 | logger.info('model: %s' % C.get()['model']) 508 | logger.info('augmentation: %s' % C.get()['aug']) 509 | logger.info('\n' + json.dumps(result, indent=4)) 510 | logger.info('elapsed time: %.3f Hours' % (elapsed / 3600.)) 511 | logger.info('top1 error in testset: %.4f' % (1. 
- result['top1_test'])) 512 | logger.info(args.save) 513 | -------------------------------------------------------------------------------- /shake_resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from shakeshake import ShakeShake 9 | from shakeshake import Shortcut 10 | 11 | 12 | class ShakeBlock(nn.Module): 13 | 14 | def __init__(self, in_ch, out_ch, stride=1): 15 | super(ShakeBlock, self).__init__() 16 | self.equal_io = in_ch == out_ch 17 | self.shortcut = None if self.equal_io else Shortcut(in_ch, out_ch, stride=stride) 18 | 19 | self.branch1 = self._make_branch(in_ch, out_ch, stride) 20 | self.branch2 = self._make_branch(in_ch, out_ch, stride) 21 | 22 | def forward(self, x): 23 | h1 = self.branch1(x) 24 | h2 = self.branch2(x) 25 | h = ShakeShake.apply(h1, h2, self.training) 26 | h0 = x if self.equal_io else self.shortcut(x) 27 | return h + h0 28 | 29 | def _make_branch(self, in_ch, out_ch, stride=1): 30 | return nn.Sequential( 31 | nn.ReLU(inplace=False), 32 | nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride, bias=False), 33 | nn.BatchNorm2d(out_ch), 34 | nn.ReLU(inplace=False), 35 | nn.Conv2d(out_ch, out_ch, 3, padding=1, stride=1, bias=False), 36 | nn.BatchNorm2d(out_ch)) 37 | 38 | 39 | class ShakeResNet(nn.Module): 40 | 41 | def __init__(self, depth, w_base, label): 42 | super(ShakeResNet, self).__init__() 43 | n_units = (depth - 2) / 6 44 | 45 | in_chs = [16, w_base, w_base * 2, w_base * 4] 46 | self.in_chs = in_chs 47 | 48 | self.c_in = nn.Conv2d(3, in_chs[0], 3, padding=1) 49 | self.layer1 = self._make_layer(n_units, in_chs[0], in_chs[1]) 50 | self.layer2 = self._make_layer(n_units, in_chs[1], in_chs[2], 2) 51 | self.layer3 = self._make_layer(n_units, in_chs[2], in_chs[3], 2) 52 | self.fc_out = nn.Linear(in_chs[3], label) 53 | 54 | # Initialize paramters 55 | for m in self.modules(): 56 | if isinstance(m, nn.Conv2d): 57 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 58 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 59 | elif isinstance(m, nn.BatchNorm2d): 60 | m.weight.data.fill_(1) 61 | m.bias.data.zero_() 62 | elif isinstance(m, nn.Linear): 63 | m.bias.data.zero_() 64 | 65 | def forward(self, x): 66 | h = self.c_in(x) 67 | h = self.layer1(h) 68 | h = self.layer2(h) 69 | h = self.layer3(h) 70 | h = F.relu(h) 71 | h = F.avg_pool2d(h, 8) 72 | h = h.view(-1, self.in_chs[3]) 73 | h = self.fc_out(h) 74 | return h 75 | 76 | def _make_layer(self, n_units, in_ch, out_ch, stride=1): 77 | layers = [] 78 | for i in range(int(n_units)): 79 | layers.append(ShakeBlock(in_ch, out_ch, stride=stride)) 80 | in_ch, stride = out_ch, 1 81 | return nn.Sequential(*layers) 82 | -------------------------------------------------------------------------------- /shake_resnext.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from shakeshake import ShakeShake 9 | from shakeshake import Shortcut 10 | 11 | 12 | class ShakeBottleNeck(nn.Module): 13 | 14 | def __init__(self, in_ch, mid_ch, out_ch, cardinary, stride=1): 15 | super(ShakeBottleNeck, self).__init__() 16 | self.equal_io = in_ch == out_ch 17 | self.shortcut = None if self.equal_io else Shortcut(in_ch, out_ch, stride=stride) 18 | 19 | self.branch1 = self._make_branch(in_ch, mid_ch, out_ch, cardinary, stride) 20 | self.branch2 = self._make_branch(in_ch, mid_ch, out_ch, cardinary, stride) 21 | 22 | def forward(self, x): 23 | h1 = self.branch1(x) 24 | h2 = self.branch2(x) 25 | h = ShakeShake.apply(h1, h2, self.training) 26 | h0 = x if self.equal_io else self.shortcut(x) 27 | return h + h0 28 | 29 | def _make_branch(self, in_ch, mid_ch, out_ch, cardinary, stride=1): 30 | return nn.Sequential( 31 | nn.Conv2d(in_ch, mid_ch, 1, padding=0, bias=False), 32 | nn.BatchNorm2d(mid_ch), 33 | nn.ReLU(inplace=False), 34 | nn.Conv2d(mid_ch, mid_ch, 3, padding=1, stride=stride, groups=cardinary, bias=False), 35 | nn.BatchNorm2d(mid_ch), 36 | nn.ReLU(inplace=False), 37 | nn.Conv2d(mid_ch, out_ch, 1, padding=0, bias=False), 38 | nn.BatchNorm2d(out_ch)) 39 | 40 | 41 | class ShakeResNeXt(nn.Module): 42 | 43 | def __init__(self, depth, w_base, cardinary, label): 44 | super(ShakeResNeXt, self).__init__() 45 | n_units = (depth - 2) // 9 46 | n_chs = [64, 128, 256, 1024] 47 | self.n_chs = n_chs 48 | self.in_ch = n_chs[0] 49 | 50 | self.c_in = nn.Conv2d(3, n_chs[0], 3, padding=1) 51 | self.layer1 = self._make_layer(n_units, n_chs[0], w_base, cardinary) 52 | self.layer2 = self._make_layer(n_units, n_chs[1], w_base, cardinary, 2) 53 | self.layer3 = self._make_layer(n_units, n_chs[2], w_base, cardinary, 2) 54 | self.fc_out = nn.Linear(n_chs[3], label) 55 | 56 | # Initialize paramters 57 | for m in self.modules(): 58 | if isinstance(m, nn.Conv2d): 59 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 60 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 61 | elif isinstance(m, nn.BatchNorm2d): 62 | m.weight.data.fill_(1) 63 | m.bias.data.zero_() 64 | elif isinstance(m, nn.Linear): 65 | m.bias.data.zero_() 66 | 67 | def forward(self, x): 68 | h = self.c_in(x) 69 | h = self.layer1(h) 70 | h = self.layer2(h) 71 | h = self.layer3(h) 72 | h = F.relu(h) 73 | h = F.avg_pool2d(h, 8) 74 | h = h.view(-1, self.n_chs[3]) 75 | h = self.fc_out(h) 76 | return h 77 | 78 | def _make_layer(self, n_units, n_ch, w_base, cardinary, stride=1): 79 | layers = [] 80 | mid_ch, out_ch = n_ch * (w_base // 64) * cardinary, n_ch * 4 81 | for i in range(n_units): 82 | layers.append(ShakeBottleNeck(self.in_ch, mid_ch, out_ch, cardinary, stride=stride)) 83 | self.in_ch, stride = out_ch, 1 84 | return nn.Sequential(*layers) 85 | -------------------------------------------------------------------------------- /shakedrop.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | class ShakeDropFunction(torch.autograd.Function): 10 | 11 | @staticmethod 12 | def forward(ctx, x, training=True, p_drop=0.5, alpha_range=[-1, 1]): 13 | if training: 14 | gate = torch.cuda.FloatTensor([0]).bernoulli_(1 - p_drop) 15 | ctx.save_for_backward(gate) 16 | if gate.item() == 0: 17 | alpha = torch.cuda.FloatTensor(x.size(0)).uniform_(*alpha_range) 18 | alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x) 19 | return alpha * x 20 | else: 21 | return x 22 | else: 23 | return (1 - p_drop) * x 24 | 25 | @staticmethod 26 | def backward(ctx, grad_output): 27 | gate = ctx.saved_tensors[0] 28 | if gate.item() == 0: 29 | beta = torch.cuda.FloatTensor(grad_output.size(0)).uniform_(0, 1) 30 | beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output) 31 | beta = Variable(beta) 32 | return beta * grad_output, None, None, None 33 | else: 34 | return grad_output, None, None, None 35 | 36 | 37 | class ShakeDrop(nn.Module): 38 | 39 | def __init__(self, p_drop=0.5, alpha_range=[-1, 1]): 40 | super(ShakeDrop, self).__init__() 41 | self.p_drop = p_drop 42 | self.alpha_range = alpha_range 43 | 44 | def forward(self, x): 45 | return ShakeDropFunction.apply(x, self.training, self.p_drop, self.alpha_range) 46 | -------------------------------------------------------------------------------- /shakeshake.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | class ShakeShake(torch.autograd.Function): 10 | 11 | @staticmethod 12 | def forward(ctx, x1, x2, training=True): 13 | if training: 14 | alpha = torch.cuda.FloatTensor(x1.size(0)).uniform_() 15 | alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x1) 16 | else: 17 | alpha = 0.5 18 | return alpha * x1 + (1 - alpha) * x2 19 | 20 | @staticmethod 21 | def backward(ctx, grad_output): 22 | beta = torch.cuda.FloatTensor(grad_output.size(0)).uniform_() 23 | beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output) 24 | beta = Variable(beta) 25 | 26 | return beta * grad_output, (1 - beta) * grad_output, None 27 | 28 | 29 | class Shortcut(nn.Module): 30 | 31 | def __init__(self, in_ch, out_ch, stride): 32 | super(Shortcut, self).__init__() 33 | self.stride = stride 34 | self.conv1 = nn.Conv2d(in_ch, out_ch // 2, 1, stride=1, padding=0, bias=False) 35 | self.conv2 = 
nn.Conv2d(in_ch, out_ch // 2, 1, stride=1, padding=0, bias=False) 36 | self.bn = nn.BatchNorm2d(out_ch) 37 | 38 | def forward(self, x): 39 | h = F.relu(x) 40 | 41 | h1 = F.avg_pool2d(h, 1, self.stride) 42 | h1 = self.conv1(h1) 43 | 44 | h2 = F.avg_pool2d(F.pad(h, (-1, 1, -1, 1)), 1, self.stride) 45 | h2 = self.conv2(h2) 46 | 47 | h = torch.cat((h1, h2), 1) 48 | return self.bn(h) 49 | -------------------------------------------------------------------------------- /smooth_ce.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn.modules.module import Module 4 | 5 | 6 | class SmoothCrossEntropyLoss(Module): 7 | def __init__(self, label_smoothing=0.0, size_average=True, reduction='mean'): 8 | super().__init__() 9 | self.label_smoothing = label_smoothing 10 | self.size_average = size_average 11 | self.reduction = reduction 12 | 13 | def forward(self, input, target): 14 | if len(target.size()) == 1: 15 | target = torch.nn.functional.one_hot(target, num_classes=input.size(-1)) 16 | target = target.float().cuda() 17 | if self.label_smoothing > 0.0: 18 | s_by_c = self.label_smoothing / len(input[0]) 19 | smooth = torch.zeros_like(target) 20 | smooth = smooth + s_by_c 21 | target = target * (1. - s_by_c) + smooth 22 | 23 | return cross_entropy(input, target, self.size_average, self.reduction) 24 | 25 | 26 | def cross_entropy(input, target, size_average=True, reduction='mean'): 27 | """ Cross entropy that accepts soft targets 28 | Args: 29 | pred: predictions for neural network 30 | targets: targets, can be soft 31 | size_average: if false, sum is returned instead of mean 32 | Examples:: 33 | input = torch.FloatTensor([[1.1, 2.8, 1.3], [1.1, 2.1, 4.8]]) 34 | input = torch.autograd.Variable(out, requires_grad=True) 35 | target = torch.FloatTensor([[0.05, 0.9, 0.05], [0.05, 0.05, 0.9]]) 36 | target = torch.autograd.Variable(y1) 37 | loss = cross_entropy(input, target) 38 | loss.backward() 39 | """ 40 | logsoftmax = torch.nn.LogSoftmax(dim=1) 41 | if reduction == 'none': 42 | return torch.sum(-target * logsoftmax(input), dim=1) 43 | if size_average: 44 | return torch.mean(torch.sum(-target * logsoftmax(input), dim=1)) 45 | else: 46 | return torch.sum(torch.sum(-target * logsoftmax(input), dim=1)) 47 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import logging 4 | import math 5 | import os 6 | from collections import OrderedDict 7 | import numpy as np 8 | 9 | import torch 10 | from torch import nn, optim 11 | from torch.nn.parallel.data_parallel import DataParallel 12 | 13 | from tqdm import tqdm 14 | from theconf import Config as C, ConfigArgumentParser 15 | 16 | from common import get_logger 17 | from data import get_dataloaders 18 | from lr_scheduler import adjust_learning_rate_resnet 19 | from metrics import accuracy, Accumulator 20 | from networks import get_model, num_class 21 | from warmup_scheduler import GradualWarmupScheduler 22 | 23 | from common import add_filehandler 24 | from smooth_ce import SmoothCrossEntropyLoss 25 | 26 | logger = get_logger('RandAugment') 27 | logger.setLevel(logging.INFO) 28 | 29 | 30 | def run_epoch(model, loader, loss_fn, optimizer, desc_default='', epoch=0, writer=None, verbose=1, scheduler=None): 31 | tqdm_disable = bool(os.environ.get('TASK_NAME', '')) # KakaoBrain Environment 32 | if verbose: 33 | loader = 
tqdm(loader, disable=tqdm_disable) 34 | loader.set_description('[%s %04d/%04d]' % (desc_default, epoch, C.get()['epoch'])) 35 | 36 | metrics = Accumulator() 37 | cnt = 0 38 | total_steps = len(loader) 39 | steps = 0 40 | for data, label in loader: 41 | steps += 1 42 | data, label = data.cuda(), label.cuda() 43 | 44 | if optimizer: 45 | optimizer.zero_grad() 46 | 47 | preds = model(data) 48 | loss = loss_fn(preds, label) 49 | 50 | if optimizer: 51 | loss.backward() 52 | if C.get()['optimizer'].get('clip', 5) > 0: 53 | nn.utils.clip_grad_norm_(model.parameters(), C.get()['optimizer'].get('clip', 5)) 54 | optimizer.step() 55 | 56 | top1, top5 = accuracy(preds, label, (1, 5)) 57 | metrics.add_dict({ 58 | 'loss': loss.item() * len(data), 59 | 'top1': top1.item() * len(data), 60 | 'top5': top5.item() * len(data), 61 | }) 62 | cnt += len(data) 63 | if verbose: 64 | postfix = metrics / cnt 65 | if optimizer: 66 | postfix['lr'] = optimizer.param_groups[0]['lr'] 67 | loader.set_postfix(postfix) 68 | 69 | if scheduler is not None: 70 | scheduler.step(epoch - 1 + float(steps) / total_steps) 71 | 72 | del preds, loss, top1, top5, data, label 73 | 74 | if tqdm_disable: 75 | if optimizer: 76 | logger.info('[%s %03d/%03d] %s lr=%.6f', desc_default, epoch, C.get()['epoch'], metrics / cnt, optimizer.param_groups[0]['lr']) 77 | else: 78 | logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt) 79 | logger.info('[%s %03d/%03d] %s', desc_default, epoch, C.get()['epoch'], metrics / cnt) 80 | metrics /= cnt 81 | if optimizer: 82 | metrics.metrics['lr'] = optimizer.param_groups[0]['lr'] 83 | if verbose: 84 | for key, value in metrics.items(): 85 | writer.add_scalar(key, value, epoch) 86 | return metrics 87 | 88 | 89 | def train_and_eval(tag, dataroot, test_ratio=0.0, cv_fold=0, reporter=None, metric='last', save_path=None, only_eval=False, reduct_factor=1.0, args = None): 90 | if not reporter: 91 | reporter = lambda **kwargs: 0 92 | 93 | max_epoch = C.get()['epoch'] 94 | trainsampler, trainloader, validloader, testloader_ = get_dataloaders(C.get()['dataset'], C.get()['batch'], dataroot, test_ratio, split_idx=cv_fold) 95 | 96 | # create a model & an optimizer 97 | model = get_model(C.get()['model'], num_class(C.get()['dataset'])) 98 | 99 | lb_smooth = C.get()['optimizer'].get('label_smoothing', 0.0) 100 | if lb_smooth > 0.0: 101 | criterion = SmoothCrossEntropyLoss(lb_smooth) 102 | else: 103 | criterion = nn.CrossEntropyLoss() 104 | if C.get()['optimizer']['type'] == 'sgd': 105 | optimizer = optim.SGD( 106 | model.parameters(), 107 | lr=C.get()['lr'], 108 | momentum=C.get()['optimizer'].get('momentum', 0.9), 109 | weight_decay=C.get()['optimizer']['decay'], 110 | nesterov=C.get()['optimizer']['nesterov'] 111 | ) 112 | else: 113 | raise ValueError('invalid optimizer type=%s' % C.get()['optimizer']['type']) 114 | 115 | if C.get()['optimizer'].get('lars', False): 116 | from torchlars import LARS 117 | optimizer = LARS(optimizer) 118 | logger.info('*** LARS Enabled.') 119 | 120 | lr_scheduler_type = C.get()['lr_schedule'].get('type', 'cosine') 121 | if lr_scheduler_type == 'cosine': 122 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=C.get()['epoch'], eta_min=0.) 
123 | elif lr_scheduler_type == 'resnet': 124 | scheduler = adjust_learning_rate_resnet(optimizer) 125 | else: 126 | raise ValueError('invalid lr_scheduler=%s' % lr_scheduler_type) 127 | 128 | if C.get()['lr_schedule'].get('warmup', None): 129 | scheduler = GradualWarmupScheduler( 130 | optimizer, 131 | multiplier=C.get()['lr_schedule']['warmup']['multiplier'], 132 | total_epoch=C.get()['lr_schedule']['warmup']['epoch'], 133 | after_scheduler=scheduler 134 | ) 135 | if not tag: 136 | from RandAugment.metrics import SummaryWriterDummy as SummaryWriter 137 | logger.warning('tag not provided, no tensorboard log.') 138 | else: 139 | from tensorboardX import SummaryWriter 140 | writers = [SummaryWriter(log_dir='./logs/%s/%s' % (tag, x)) for x in ['train', 'valid', 'test']] 141 | 142 | result = OrderedDict() 143 | epoch_start = 1 144 | if save_path and os.path.exists(save_path): 145 | logger.info('%s file found. loading...' % save_path) 146 | data = torch.load(save_path) 147 | if 'model' in data or 'state_dict' in data: 148 | key = 'model' if 'model' in data else 'state_dict' 149 | logger.info('checkpoint epoch@%d' % data['epoch']) 150 | if not isinstance(model, DataParallel): 151 | model.load_state_dict({k.replace('module.', ''): v for k, v in data[key].items()}) 152 | else: 153 | model.load_state_dict({k if 'module.' in k else 'module.'+k: v for k, v in data[key].items()}) 154 | optimizer.load_state_dict(data['optimizer']) 155 | if data['epoch'] < C.get()['epoch']: 156 | epoch_start = data['epoch'] 157 | else: 158 | only_eval = True 159 | else: 160 | model.load_state_dict({k: v for k, v in data.items()}) 161 | del data 162 | else: 163 | logger.info('"%s" file not found. skip to pretrain weights...' % save_path) 164 | if only_eval: 165 | logger.warning('model checkpoint not found. 
only-evaluation mode is off.') 166 | only_eval = False 167 | 168 | if only_eval: 169 | logger.info('evaluation only+') 170 | model.eval() 171 | rs = dict() 172 | rs['train'] = run_epoch(model, trainloader, criterion, None, desc_default='train', epoch=0, writer=writers[0]) 173 | rs['valid'] = run_epoch(model, validloader, criterion, None, desc_default='valid', epoch=0, writer=writers[1]) 174 | rs['test'] = run_epoch(model, testloader_, criterion, None, desc_default='*test', epoch=0, writer=writers[2]) 175 | for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']): 176 | if setname not in rs: 177 | continue 178 | result['%s_%s' % (key, setname)] = rs[setname][key] 179 | result['epoch'] = 0 180 | return result 181 | 182 | # train loop 183 | best_top1 = 0 184 | flag_load = 1 185 | # print(th_ls) 186 | for epoch in range(epoch_start, max_epoch + 1): 187 | if args.load_tp == 'none': 188 | break 189 | else: 190 | if flag_load == 1: 191 | prob_dict = np.load(args.load_tp,allow_pickle=True).item() 192 | dis_ps = prob_dict['dis_ps'] 193 | max_probs = prob_dict['w0s_mt'] 194 | print((len(dis_ps), len(max_probs))) 195 | th_ls = max_probs 196 | flag_load = 0 197 | th_epoch = args.mul * th_ls[int(epoch/((max_epoch+0.1)/len(th_ls)))] 198 | trainloader.dataset.transform.transforms[0].p = dis_ps[int(epoch/((max_epoch+0.1)/len(th_ls)))] 199 | print(trainloader.dataset.transform.transforms[0].p) 200 | trainloader.dataset.transform.transforms[0].th = th_epoch 201 | print(trainloader.dataset.transform.transforms[0].th) 202 | model.train() 203 | rs = dict() 204 | rs['train'] = run_epoch(model, trainloader, criterion, optimizer, desc_default='train', epoch=epoch, writer=writers[0], verbose=True, scheduler=scheduler) 205 | model.eval() 206 | 207 | if math.isnan(rs['train']['loss']): 208 | raise Exception('train loss is NaN.') 209 | 210 | if epoch % 1 == 0 or epoch == max_epoch: 211 | rs['valid'] = run_epoch(model, validloader, criterion, None, desc_default='valid', epoch=epoch, writer=writers[1], verbose=True) 212 | rs['test'] = run_epoch(model, testloader_, criterion, None, desc_default='*test', epoch=epoch, writer=writers[2], verbose=True) 213 | 214 | if metric == 'last' or rs[metric]['top1'] > best_top1: 215 | if metric != 'last': 216 | best_top1 = rs[metric]['top1'] 217 | for key, setname in itertools.product(['loss', 'top1', 'top5'], ['train', 'valid', 'test']): 218 | result['%s_%s' % (key, setname)] = rs[setname][key] 219 | result['epoch'] = epoch 220 | 221 | writers[1].add_scalar('valid_top1/best', rs['valid']['top1'], epoch) 222 | writers[2].add_scalar('test_top1/best', rs['test']['top1'], epoch) 223 | 224 | reporter( 225 | loss_valid=rs['valid']['loss'], top1_valid=rs['valid']['top1'], 226 | loss_test=rs['test']['loss'], top1_test=rs['test']['top1'] 227 | ) 228 | 229 | # save checkpoint 230 | if save_path: 231 | logger.info('save model@%d to %s' % (epoch, save_path)) 232 | torch.save({ 233 | 'epoch': epoch, 234 | 'log': { 235 | 'train': rs['train'].get_dict(), 236 | 'valid': rs['valid'].get_dict(), 237 | 'test': rs['test'].get_dict(), 238 | }, 239 | 'optimizer': optimizer.state_dict(), 240 | 'model': model.state_dict() 241 | }, save_path) 242 | #torch.save({ 243 | # 'epoch': epoch, 244 | # 'log': { 245 | # 'train': rs['train'].get_dict(), 246 | # 'valid': rs['valid'].get_dict(), 247 | # 'test': rs['test'].get_dict(), 248 | # }, 249 | # 'optimizer': optimizer.state_dict(), 250 | # 'model': model.state_dict() 251 | #}, save_path.replace('.pth', '_e%d_top1_%.3f_%.3f' % 
(epoch, rs['train']['top1'], rs['test']['top1']) + '.pth')) 252 | 253 | del model 254 | 255 | result['top1_test'] = best_top1 256 | return result 257 | 258 | 259 | if __name__ == '__main__': 260 | parser = ConfigArgumentParser(conflict_handler='resolve') 261 | parser.add_argument('--tag', type=str, default='') 262 | parser.add_argument('--dataroot', type=str, default='/data/private/pretrainedmodels', help='torchvision data folder') 263 | parser.add_argument('--save', type=str, default='') 264 | parser.add_argument('--rf', type=float, default=2.0) 265 | parser.add_argument('--cv-ratio', type=float, default=0.0) 266 | parser.add_argument('--mul', type=float, default=1) 267 | parser.add_argument('--sqrt', type=float, default=1) 268 | parser.add_argument('--cv', type=int, default=0) 269 | parser.add_argument('--load_tp', type=str, default='none') 270 | parser.add_argument('--only-eval', action='store_true') 271 | args = parser.parse_args() 272 | 273 | assert (args.only_eval and args.save) or not args.only_eval, 'checkpoint path not provided in evaluation mode.' 274 | 275 | if not args.only_eval: 276 | if args.save: 277 | logger.info('checkpoint will be saved at %s' % args.save) 278 | else: 279 | logger.warning('Provide --save argument to save the checkpoint. Without it, training result will not be saved!') 280 | 281 | if args.save: 282 | add_filehandler(logger, args.save.replace('.pth', '') + '.log') 283 | 284 | logger.info(json.dumps(C.get().conf, indent=4)) 285 | 286 | import time 287 | t = time.time() 288 | result = train_and_eval(args.tag, args.dataroot, test_ratio=args.cv_ratio, cv_fold=args.cv, save_path=args.save, only_eval=args.only_eval, metric='test',reduct_factor = args.rf, args = args) 289 | elapsed = time.time() - t 290 | 291 | logger.info('done.') 292 | logger.info('model: %s' % C.get()['model']) 293 | logger.info('augmentation: %s' % C.get()['aug']) 294 | logger.info('\n' + json.dumps(result, indent=4)) 295 | logger.info('elapsed time: %.3f Hours' % (elapsed / 3600.)) 296 | logger.info('top1 error in testset: %.4f' % (1. 
- result['top1_test'])) 297 | logger.info(args.save) 298 | -------------------------------------------------------------------------------- /wideresnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.init as init 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | _bn_momentum = 0.1 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 12 | 13 | 14 | def conv_init(m): 15 | classname = m.__class__.__name__ 16 | if classname.find('Conv') != -1: 17 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 18 | init.constant_(m.bias, 0) 19 | elif classname.find('BatchNorm') != -1: 20 | init.constant_(m.weight, 1) 21 | init.constant_(m.bias, 0) 22 | 23 | 24 | class WideBasic(nn.Module): 25 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 26 | super(WideBasic, self).__init__() 27 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=_bn_momentum) 28 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 29 | self.dropout = nn.Dropout(p=dropout_rate) 30 | self.bn2 = nn.BatchNorm2d(planes, momentum=_bn_momentum) 31 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 32 | 33 | self.shortcut = nn.Sequential() 34 | if stride != 1 or in_planes != planes: 35 | self.shortcut = nn.Sequential( 36 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 37 | ) 38 | 39 | def forward(self, x): 40 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 41 | out = self.conv2(F.relu(self.bn2(out))) 42 | out += self.shortcut(x) 43 | 44 | return out 45 | 46 | 47 | class WideResNet(nn.Module): 48 | def __init__(self, depth, widen_factor, dropout_rate, num_classes): 49 | super(WideResNet, self).__init__() 50 | self.in_planes = 16 51 | 52 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' 53 | n = int((depth - 4) / 6) 54 | k = widen_factor 55 | 56 | nStages = [16, 16*k, 32*k, 64*k] 57 | 58 | self.conv1 = conv3x3(3, nStages[0]) 59 | self.layer1 = self._wide_layer(WideBasic, nStages[1], n, dropout_rate, stride=1) 60 | self.layer2 = self._wide_layer(WideBasic, nStages[2], n, dropout_rate, stride=2) 61 | self.layer3 = self._wide_layer(WideBasic, nStages[3], n, dropout_rate, stride=2) 62 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=_bn_momentum) 63 | self.linear = nn.Linear(nStages[3], num_classes) 64 | 65 | # self.apply(conv_init) 66 | 67 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 68 | strides = [stride] + [1]*(num_blocks-1) 69 | layers = [] 70 | 71 | for stride in strides: 72 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 73 | self.in_planes = planes 74 | 75 | return nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | out = self.conv1(x) 79 | out = self.layer1(out) 80 | out = self.layer2(out) 81 | out = self.layer3(out) 82 | out = F.relu(self.bn1(out)) 83 | # out = F.avg_pool2d(out, 8) 84 | out = F.adaptive_avg_pool2d(out, (1, 1)) 85 | out = out.view(out.size(0), -1) 86 | out = self.linear(out) 87 | 88 | return out 89 | -------------------------------------------------------------------------------- /wresnet40x2_cifar100_new_search_smoothed.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zxcvfd13502/DDAS_code/b538eeaf57e803a39a8b1dbc8fd768253064434f/wresnet40x2_cifar100_new_search_smoothed.npy -------------------------------------------------------------------------------- /wresnet40x2_cifar10_new_search_smoothed.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxcvfd13502/DDAS_code/b538eeaf57e803a39a8b1dbc8fd768253064434f/wresnet40x2_cifar10_new_search_smoothed.npy --------------------------------------------------------------------------------
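The two `.npy` files above are the smoothed search results that `train.py --load_tp` consumes. A minimal sketch of inspecting one of them, using the dictionary keys that train.py's training loop reads (`dis_ps` and `w0s_mt`):

```python
import numpy as np

# Load one of the shipped search results (keys as read in train.py's train loop).
prob_dict = np.load('wresnet40x2_cifar10_new_search_smoothed.npy', allow_pickle=True).item()
dis_ps = prob_dict['dis_ps']      # per-epoch probability distributions over the ops
max_probs = prob_dict['w0s_mt']   # per-epoch values used to set the transform threshold
print(len(dis_ps), len(max_probs))
```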