├── .gitignore ├── LICENSE ├── README.md ├── ada_aug ├── adaptive_augmentor.py ├── config.py ├── dataset.py ├── networks │ ├── __init__.py │ ├── projection.py │ ├── resnet.py │ └── wideresnet.py ├── operation.py ├── search.py ├── train.py └── utils.py ├── requirements.txt └── scripts ├── search.sh └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | #data 2 | data/ 3 | ada_aug/data/ 4 | search/ 5 | eval/ 6 | ada_aug/__pycache__/ 7 | debug/ 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | *.DS_Store 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 jamestszhim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### AdaAug 2 | AdaAug: Learning class- and instance-adaptive augmentation policies. 3 | 4 | ### Table of Contents 5 | 6 | 1. [Introduction](#introduction) 7 | 2. [Getting Started](#getting-started) 8 | 3. [Run Search](#run-adaaug-search) 9 | 4. [Run Training](#run-adaaug-training) 10 | 5. [Citation](#citation) 11 | 6. [References & Opensources](#references-&-opensources) 12 | 13 | ### Introduction 14 | 15 | AdaAug is a framework that finds class- and instance-adaptive data augmentation policies to augment a given dataset. 16 | 17 | This repository contains code for the work "AdaAug: Learning class- and instance-adaptive data augmentation policies" (https://openreview.net/forum?id=rWXfFogxRJN) implemented using the PyTorch library. 18 | 19 | ### Getting Started 20 | Code supports Python 3. 21 | 22 | #### Install requirements 23 | 24 | ```shell 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ### Run AdaAug search 29 | Script to search for the augmentation policy for is located in `scripts/search.sh`. Pass the dataset name as the argument to call the script. 30 | 31 | For example, to search for the augmentation policy for reduced_svhn dataset: 32 | 33 | ```shell 34 | bash scripts/search.sh reduced_svhn 35 | ``` 36 | 37 | The training log and candidate policies of the search will be output to the `./search` directory. 38 | 39 | ### Run AdaAug training 40 | To use the searched policy, paste the path of the g_model and h_model as the G and H variables respectively in `scripts/train.sh`. The path should look like this (./search/...). Then, pass the dataset name as the argument to call the script located in `scripts/train.sh`. 
The results will be output to the `./eval` directory 41 | 42 | ```shell 43 | bash scripts/train.sh reduced_svhn 44 | ``` 45 | 46 | ### Citation 47 | If you use this code in your research, please cite our paper. 48 | ``` 49 | @inproceedings{cheung2022adaaug, 50 | title = {AdaAug: Learning class and instance-adaptive data augmentation policies}, 51 | author = {Tsz-Him Cheung and Dit-Yan Yeung}, 52 | booktitle = {International Conference on Learning Representations}, 53 | year = {2022}, 54 | url = {https://openreview.net/forum?id=rWXfFogxRJN} 55 | } 56 | ``` 57 | 58 | ### References & Opensources 59 | Part of our implementation is adopted from the Fast AutoAugment and DADA repositories. 60 | - Fast AutoAugment (https://github.com/kakaobrain/fast-autoaugment) 61 | - DADA (https://github.com/VDIGPKU/DADA) 62 | -------------------------------------------------------------------------------- /ada_aug/adaptive_augmentor.py: -------------------------------------------------------------------------------- 1 | import random 2 | from cv2 import magnitude 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision import transforms 8 | 9 | 10 | from operation import apply_augment 11 | from networks import get_model 12 | from utils import PolicyHistory 13 | from config import OPS_NAMES 14 | 15 | default_config = {'sampling': 'prob', 16 | 'k_ops': 1, 17 | 'delta': 0, 18 | 'temp': 1.0, 19 | 'search_d': 32, 20 | 'target_d': 32} 21 | 22 | def perturb_param(param, delta): 23 | if delta <= 0: 24 | return param 25 | 26 | amt = random.uniform(0, delta) 27 | if random.random() < 0.5: 28 | return max(0, param-amt) 29 | else: 30 | return min(1, param+amt) 31 | 32 | def stop_gradient(trans_image, magnitude): 33 | images = trans_image 34 | adds = 0 35 | 36 | images = images - magnitude 37 | adds = adds + magnitude 38 | images = images.detach() + adds 39 | return images 40 | 41 | class AdaAug(nn.Module): 42 | def __init__(self, 
after_transforms, n_class, gf_model, h_model, save_dir=None, 43 | config=default_config): 44 | super(AdaAug, self).__init__() 45 | self.ops_names = OPS_NAMES 46 | self.n_ops = len(self.ops_names) 47 | self.after_transforms = after_transforms 48 | self.save_dir = save_dir 49 | self.gf_model = gf_model 50 | self.h_model = h_model 51 | self.n_class = n_class 52 | self.resize = config['search_d'] != config['target_d'] 53 | self.search_d = config['search_d'] 54 | self.k_ops = config['k_ops'] 55 | self.sampling = config['sampling'] 56 | self.temp = config['temp'] 57 | self.delta = config['delta'] 58 | self.history = PolicyHistory(self.ops_names, self.save_dir, self.n_class) 59 | 60 | def save_history(self, class2label=None): 61 | self.history.save(class2label) 62 | 63 | def plot_history(self): 64 | return self.history.plot() 65 | 66 | def predict_aug_params(self, images, mode): 67 | self.gf_model.eval() 68 | if mode == 'exploit': 69 | self.h_model.eval() 70 | T = self.temp 71 | elif mode == 'explore': 72 | self.h_model.train() 73 | T = 1.0 74 | a_params = self.h_model(self.gf_model.f(images.cuda())) 75 | magnitudes, weights = torch.split(a_params, self.n_ops, dim=1) 76 | magnitudes = torch.sigmoid(magnitudes) 77 | weights = torch.nn.functional.softmax(weights/T, dim=-1) 78 | return magnitudes, weights 79 | 80 | def add_history(self, images, targets): 81 | magnitudes, weights = self.predict_aug_params(images, 'exploit') 82 | for k in range(self.n_class): 83 | idxs = (targets == k).nonzero().squeeze() 84 | mean_lambda = magnitudes[idxs].mean(0).detach().cpu().tolist() 85 | mean_p = weights[idxs].mean(0).detach().cpu().tolist() 86 | std_lambda = magnitudes[idxs].std(0).detach().cpu().tolist() 87 | std_p = weights[idxs].std(0).detach().cpu().tolist() 88 | self.history.add(k, mean_lambda, mean_p, std_lambda, std_p) 89 | 90 | def get_aug_valid_imgs(self, images, magnitudes): 91 | """Return the mixed latent feature 92 | 93 | Args: 94 | images ([Tensor]): [description] 95 | 
magnitudes ([Tensor]): [description] 96 | Returns: 97 | [Tensor]: a set of augmented validation images 98 | """ 99 | trans_image_list = [] 100 | for i, image in enumerate(images): 101 | pil_img = transforms.ToPILImage()(image) 102 | # Prepare transformed image for mixing 103 | for k, ops_name in enumerate(self.ops_names): 104 | trans_image = apply_augment(pil_img, ops_name, magnitudes[i][k]) 105 | trans_image = self.after_transforms(trans_image) 106 | trans_image = stop_gradient(trans_image.cuda(), magnitudes[i][k]) 107 | trans_image_list.append(trans_image) 108 | return torch.stack(trans_image_list, dim=0) 109 | 110 | def explore(self, images): 111 | """Return the mixed latent feature 112 | 113 | Args: 114 | images ([Tensor]): [description] 115 | Returns: 116 | [Tensor]: return a batch of mixed features 117 | """ 118 | magnitudes, weights = self.predict_aug_params(images, 'explore') 119 | a_imgs = self.get_aug_valid_imgs(images, magnitudes) 120 | a_features = self.gf_model.f(a_imgs) 121 | ba_features = a_features.reshape(len(images), self.n_ops, -1) 122 | 123 | mixed_features = [w.matmul(feat) for w, feat in zip(weights, ba_features)] 124 | mixed_features = torch.stack(mixed_features, dim=0) 125 | return mixed_features 126 | 127 | def get_training_aug_images(self, images, magnitudes, weights): 128 | # visualization 129 | if self.k_ops > 0: 130 | trans_images = [] 131 | if self.sampling == 'prob': 132 | idx_matrix = torch.multinomial(weights, self.k_ops) 133 | elif self.sampling == 'max': 134 | idx_matrix = torch.topk(weights, self.k_ops, dim=1)[1] 135 | 136 | for i, image in enumerate(images): 137 | pil_image = transforms.ToPILImage()(image) 138 | for idx in idx_matrix[i]: 139 | m_pi = perturb_param(magnitudes[i][idx], self.delta) 140 | pil_image = apply_augment(pil_image, self.ops_names[idx], m_pi) 141 | trans_images.append(self.after_transforms(pil_image)) 142 | else: 143 | trans_images = [] 144 | for i, image in enumerate(images): 145 | pil_image = 
transforms.ToPILImage()(image) 146 | trans_image = self.after_transforms(pil_image) 147 | trans_images.append(trans_image) 148 | 149 | aug_imgs = torch.stack(trans_images, dim=0).cuda() 150 | return aug_imgs 151 | 152 | def exploit(self, images): 153 | resize_imgs = F.interpolate(images, size=self.search_d) if self.resize else images 154 | magnitudes, weights = self.predict_aug_params(resize_imgs, 'exploit') 155 | aug_imgs = self.get_training_aug_images(images, magnitudes, weights) 156 | return aug_imgs 157 | 158 | def forward(self, images, mode): 159 | if mode == 'explore': 160 | # return a set of mixed augmented features 161 | return self.explore(images) 162 | elif mode == 'exploit': 163 | # return a set of augmented images 164 | return self.exploit(images) 165 | elif mode == 'inference': 166 | return images 167 | -------------------------------------------------------------------------------- /ada_aug/config.py: -------------------------------------------------------------------------------- 1 | OPS_NAMES = ['ShearX', 2 | 'ShearY', 3 | 'TranslateX', 4 | 'TranslateY', 5 | 'Rotate', 6 | 'AutoContrast', 7 | 'Invert', 8 | 'Equalize', 9 | 'Solarize', 10 | 'Posterize', 11 | 'Contrast', 12 | 'Color', 13 | 'Brightness', 14 | 'Sharpness', 15 | 'Cutout', 16 | 'Flip', 17 | 'Identity'] 18 | 19 | 20 | def get_warmup_config(dset): 21 | # multiplier, epoch 22 | config = {'svhn': (2, 2), 23 | 'cifar10': (2, 5), 24 | 'cifar100': (4, 5), 25 | 'mnist': (1, 1), 26 | 'imagenet': (2, 3)} 27 | if 'svhn' in dset: 28 | return config['svhn'] 29 | elif 'cifar100' in dset: 30 | return config['cifar100'] 31 | elif 'cifar10' in dset: 32 | return config['cifar10'] 33 | elif 'mnist' in dset: 34 | return config['mnist'] 35 | elif 'imagenet' in dset: 36 | return config['imagenet'] 37 | else: 38 | return config['imagenet'] 39 | 40 | 41 | def get_search_divider(model_name): 42 | # batch size is too large if the search model is large 43 | # the divider split the update to multiple updates 44 | if 
model_name == 'wresnet40_2': 45 | return 32 46 | elif model_name == 'resnet50': 47 | return 128 48 | else: 49 | return 16 50 | -------------------------------------------------------------------------------- /ada_aug/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | from sklearn.model_selection import StratifiedShuffleSplit 5 | from torch.utils.data import Sampler, Subset, SubsetRandomSampler 6 | from torchvision import transforms 7 | 8 | 9 | _CIFAR_MEAN, _CIFAR_STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 10 | _SVHN_MEAN, _SVHN_STD = (0.43090966, 0.4302428, 0.44634357), (0.19652855, 0.19832038, 0.19942076) 11 | 12 | 13 | class CutoutDefault(object): 14 | """ 15 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 16 | """ 17 | def __init__(self, length): 18 | self.length = length 19 | 20 | def __call__(self, img): 21 | h, w = img.size(1), img.size(2) 22 | mask = np.ones((h, w), np.float32) 23 | y = np.random.randint(h) 24 | x = np.random.randint(w) 25 | 26 | y1 = np.clip(y - self.length // 2, 0, h) 27 | y2 = np.clip(y + self.length // 2, 0, h) 28 | x1 = np.clip(x - self.length // 2, 0, w) 29 | x2 = np.clip(x + self.length // 2, 0, w) 30 | 31 | mask[y1: y2, x1: x2] = 0. 32 | mask = torch.from_numpy(mask) 33 | mask = mask.expand_as(img) 34 | img *= mask 35 | return img 36 | 37 | 38 | class SubsetSampler(Sampler): 39 | """Samples elements from a given list of indices, without replacement. 
40 | 41 | Arguments: 42 | indices (sequence): a sequence of indices 43 | """ 44 | 45 | def __init__(self, indices): 46 | self.indices = indices 47 | 48 | def __iter__(self): 49 | return (i for i in self.indices) 50 | 51 | def __len__(self): 52 | return len(self.indices) 53 | 54 | 55 | class AugmentDataset(torch.utils.data.Dataset): 56 | def __init__(self, dataset, pre_transforms, after_transforms, valid_transforms, search, train): 57 | super(AugmentDataset, self).__init__() 58 | self.dataset = dataset 59 | self.pre_transforms = pre_transforms 60 | self.after_transforms = after_transforms 61 | self.valid_transforms = valid_transforms 62 | self.search = search 63 | self.train = train 64 | 65 | def __getitem__(self, index): 66 | if self.search: 67 | raw_image, target = self.dataset.__getitem__(index) 68 | raw_image = self.pre_transforms(raw_image) 69 | image = transforms.ToTensor()(raw_image) 70 | return image, target 71 | else: 72 | img, target = self.dataset.__getitem__(index) 73 | if self.train: 74 | img = self.pre_transforms(img) 75 | img = self.after_transforms(img) 76 | else: 77 | if self.valid_transforms is not None: 78 | img = self.valid_transforms(img) 79 | return img, target 80 | 81 | def __len__(self): 82 | return self.dataset.__len__() 83 | 84 | 85 | def get_num_class(dataset): 86 | return { 87 | 'cifar10': 10, 88 | 'reduced_cifar10': 10, 89 | 'svhn': 10, 90 | 'reduced_svhn': 10, 91 | }[dataset] 92 | 93 | 94 | def get_num_channel(dataset): 95 | return { 96 | 'cifar10': 3, 97 | 'reduced_cifar10': 3, 98 | 'svhn': 3, 99 | 'reduced_svhn': 3, 100 | }[dataset] 101 | 102 | 103 | def get_dataloaders(dataset, batch, num_workers, dataroot, cutout, 104 | cutout_length, split=0.5, split_idx=0, target_lb=-1, 105 | search=True, search_divider=1): 106 | ''' 107 | If search is True, dataloader will give batches of image without after_transforms, 108 | the transform will be done by augment agent 109 | If search is False, used in benchmark training 110 | ''' 111 | if 
'cifar10' in dataset: 112 | transform_train_pre = transforms.Compose([ 113 | transforms.RandomCrop(32, padding=4), 114 | transforms.RandomHorizontalFlip(), 115 | ]) 116 | transform_train_after = transforms.Compose([ 117 | transforms.ToTensor(), 118 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 119 | ]) 120 | transform_test = transforms.Compose([ 121 | transforms.ToTensor(), 122 | transforms.Normalize(_CIFAR_MEAN, _CIFAR_STD), 123 | ]) 124 | elif 'svhn' in dataset: 125 | transform_train_pre = transforms.Compose([ 126 | transforms.RandomCrop(32, padding=4), 127 | ]) 128 | transform_train_after = transforms.Compose([ 129 | transforms.ToTensor(), 130 | transforms.Normalize(_SVHN_MEAN, _SVHN_STD), 131 | ]) 132 | transform_test = transforms.Compose([ 133 | transforms.ToTensor(), 134 | transforms.Normalize(_SVHN_MEAN, _SVHN_STD), 135 | ]) 136 | else: 137 | raise ValueError('dataset=%s' % dataset) 138 | 139 | if cutout and cutout_length != 0: 140 | transform_train_after.transforms.append(CutoutDefault(cutout_length)) 141 | 142 | if dataset == 'cifar10': 143 | total_trainset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=None) 144 | search_dataset = None 145 | testset = torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=None) 146 | elif dataset == 'reduced_cifar10': 147 | search_dataset = torchvision.datasets.CIFAR10(root=dataroot, train=True, download=True, transform=None) 148 | sss = StratifiedShuffleSplit(n_splits=1, test_size=45744, random_state=0) 149 | sss = sss.split(list(range(len(search_dataset))), search_dataset.targets) 150 | train_idx, valid_idx = next(sss) 151 | targets = [search_dataset.targets[idx] for idx in train_idx] 152 | total_trainset = Subset(search_dataset, train_idx) 153 | total_trainset.targets = targets 154 | targets = [search_dataset.targets[idx] for idx in valid_idx] 155 | search_dataset = Subset(search_dataset, valid_idx) 156 | search_dataset.targets = targets 157 | testset = 
torchvision.datasets.CIFAR10(root=dataroot, train=False, download=True, transform=None) 158 | elif dataset == 'svhn': 159 | total_trainset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=None) 160 | testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=True, transform=None) 161 | search_dataset = None 162 | elif dataset == 'reduced_svhn': 163 | search_dataset = torchvision.datasets.SVHN(root=dataroot, split='train', download=True, transform=None) 164 | sss = StratifiedShuffleSplit(n_splits=1, test_size=73257-1000, random_state=0) # 1000 + 1000 trainset 165 | sss = sss.split(list(range(len(search_dataset))), search_dataset.labels) 166 | train_idx, search_idx = next(sss) 167 | targets = [search_dataset.labels[idx] for idx in train_idx] 168 | total_trainset = Subset(search_dataset, train_idx) 169 | total_trainset.labels = targets 170 | total_trainset.targets = targets 171 | targets = [search_dataset.labels[idx] for idx in search_idx] 172 | search_dataset = Subset(search_dataset, search_idx) 173 | search_dataset.labels = targets 174 | search_dataset.targets = targets 175 | testset = torchvision.datasets.SVHN(root=dataroot, split='test', download=False, transform=None) 176 | else: 177 | raise ValueError('invalid dataset name=%s' % dataset) 178 | 179 | train_sampler = None 180 | if split < 1.0: 181 | sss = StratifiedShuffleSplit(n_splits=5, test_size=1-split, random_state=0) 182 | sss = sss.split(list(range(len(total_trainset))), total_trainset.targets) 183 | for _ in range(split_idx + 1): 184 | train_idx, valid_idx = next(sss) 185 | 186 | print(len(valid_idx)) 187 | 188 | if target_lb >= 0: 189 | train_idx = [i for i in train_idx if total_trainset.targets[i] == target_lb] 190 | valid_idx = [i for i in valid_idx if total_trainset.targets[i] == target_lb] 191 | 192 | train_sampler = SubsetRandomSampler(train_idx) 193 | valid_sampler = SubsetSampler(valid_idx) 194 | else: 195 | valid_sampler = SubsetSampler([]) 196 | 
197 | test_sampler = None 198 | 199 | train_data = AugmentDataset(total_trainset, transform_train_pre, transform_train_after, transform_test, search=search, train=True) 200 | if search and search_dataset is not None: 201 | search_data = AugmentDataset(search_dataset, transform_train_pre, transform_train_after, transform_test, search=True, train=False) 202 | valid_data = AugmentDataset(total_trainset, transform_train_pre, transform_train_after, transform_test, search=False, train=False) 203 | test_data = AugmentDataset(testset, transform_train_pre, transform_train_after, transform_test, search=False, train=False) 204 | 205 | if train_sampler is None: 206 | trainloader = torch.utils.data.DataLoader( 207 | train_data, batch_size=batch, shuffle=True, 208 | drop_last=True, pin_memory=True, 209 | num_workers=num_workers) 210 | else: 211 | trainloader = torch.utils.data.DataLoader( 212 | train_data, batch_size=batch, shuffle=False, 213 | sampler=train_sampler, drop_last=False, 214 | pin_memory=True, num_workers=num_workers) 215 | 216 | validloader = torch.utils.data.DataLoader( 217 | valid_data, batch_size=batch, 218 | sampler=valid_sampler, drop_last=False, 219 | pin_memory=True, num_workers=num_workers) 220 | 221 | if search and search_dataset is not None: 222 | searchloader = torch.utils.data.DataLoader( 223 | search_data, batch_size=search_divider, 224 | shuffle=True, drop_last=True, pin_memory=True, 225 | num_workers=num_workers) 226 | else: 227 | searchloader = None 228 | 229 | testloader = torch.utils.data.DataLoader( 230 | test_data, batch_size=batch, 231 | sampler=test_sampler, drop_last=False, 232 | pin_memory=True, num_workers=num_workers) 233 | 234 | print(f'Dataset: {dataset}') 235 | print(f' |total: {len(train_data)}') 236 | print(f' |train: {len(trainloader)*batch}') 237 | print(f' |valid: {len(validloader)*batch}') 238 | print(f' |test: {len(testloader)*batch}') 239 | if search and search_dataset is not None: 240 | print(f' |search: 
{len(searchloader)*search_divider}') 241 | return trainloader, validloader, searchloader, testloader 242 | 243 | 244 | def unpickle(file): 245 | import pickle 246 | with open(file, 'rb') as fo: 247 | dict = pickle.load(fo, encoding='bytes') 248 | return dict 249 | 250 | 251 | def get_dataset_dimension(dset): 252 | return {'cifar10': 32, 253 | 'reduced_cifar10': 32, 254 | 'svhn': 32, 255 | 'reduced_svhn': 32}[dset] 256 | 257 | 258 | def get_label_name(dset, dataroot): 259 | if 'cifar10' in dset: 260 | meta = unpickle(f'{dataroot}/cifar-10-batches-py/batches.meta') 261 | classes = [t.decode('utf8') for t in meta[b'label_names']] 262 | elif 'svhn' in dset: 263 | classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 264 | else: 265 | class_idxs = np.arange(0, get_num_class(dset)) 266 | classes = [str(i) for i in class_idxs] 267 | return classes 268 | -------------------------------------------------------------------------------- /ada_aug/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | # from __future__ import absolute_import 3 | import torch 4 | 5 | from torch import nn 6 | from torch.nn import DataParallel 7 | 8 | from .resnet import ResNet 9 | from .wideresnet import WideResNet 10 | 11 | 12 | def get_model(model_name='wresnet40_2', num_class=10, n_channel=3, use_cuda=True, data_parallel=False): 13 | name = model_name 14 | 15 | if name == 'resnet50': 16 | model = ResNet(dataset='imagenet', n_channel=n_channel, depth=50, num_classes=num_class, bottleneck=True) 17 | elif name == 'resnet200': 18 | model = ResNet(dataset='imagenet', n_channel=n_channel, depth=200, num_classes=num_class, bottleneck=True) 19 | elif name == 'wresnet40_2': 20 | model = WideResNet(40, 2, dropout_rate=0.0, num_classes=num_class) 21 | elif name == 'wresnet28_10': 22 | model = WideResNet(28, 10, dropout_rate=0.0, num_classes=num_class) 23 | else: 24 | raise NameError('no model named, %s' % 
name) 25 | 26 | if data_parallel: 27 | model = model.cuda() 28 | model = DataParallel(model) 29 | else: 30 | if use_cuda: 31 | model = model.cuda() 32 | return model 33 | 34 | -------------------------------------------------------------------------------- /ada_aug/networks/projection.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from config import OPS_NAMES 3 | 4 | class Projection(nn.Module): 5 | def __init__(self, in_features, n_layers, n_hidden=128): 6 | super(Projection, self).__init__() 7 | self.n_layers = n_layers 8 | if self.n_layers > 0: 9 | layers = [nn.Linear(in_features, n_hidden), nn.ReLU()] 10 | for _ in range(self.n_layers-1): 11 | layers.append(nn.Linear(n_hidden, n_hidden)) 12 | layers.append(nn.ReLU()) 13 | layers.append(nn.Linear(n_hidden, 2*len(OPS_NAMES))) 14 | else: 15 | layers = [nn.Linear(in_features, 2*len(OPS_NAMES))] 16 | self.projection = nn.Sequential(*layers) 17 | 18 | def forward(self, x): 19 | return self.projection(x) -------------------------------------------------------------------------------- /ada_aug/networks/resnet.py: -------------------------------------------------------------------------------- 1 | # Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 2 | 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = conv3x3(planes, planes) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.relu = nn.ReLU(inplace=True) 23 | 24 | self.downsample = downsample 25 | 
self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | 30 | out = self.conv1(x) 31 | out = self.bn1(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv2(out) 35 | out = self.bn2(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | 46 | class Bottleneck(nn.Module): 47 | expansion = 4 48 | 49 | def __init__(self, inplanes, planes, stride=1, downsample=None): 50 | super(Bottleneck, self).__init__() 51 | 52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False) 57 | self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion) 58 | self.relu = nn.ReLU(inplace=True) 59 | 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | residual = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv3(out) 75 | out = self.bn3(out) 76 | if self.downsample is not None: 77 | residual = self.downsample(x) 78 | 79 | out += residual 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | 85 | class ResNet(nn.Module): 86 | def __init__(self, dataset, depth, n_channel, num_classes, bottleneck=False): 87 | super(ResNet, self).__init__() 88 | self.dataset = dataset 89 | if self.dataset.startswith('cifar'): 90 | self.inplanes = 16 91 | print(bottleneck) 92 | if bottleneck == True: 93 | n = int((depth - 2) / 9) 94 | block = Bottleneck 95 | else: 96 | n = int((depth - 2) / 6) 97 | block = BasicBlock 98 | 99 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 100 | 
self.bn1 = nn.BatchNorm2d(self.inplanes) 101 | self.relu = nn.ReLU(inplace=True) 102 | self.layer1 = self._make_layer(block, 16, n) 103 | self.layer2 = self._make_layer(block, 32, n, stride=2) 104 | self.layer3 = self._make_layer(block, 64, n, stride=2) 105 | # self.avgpool = nn.AvgPool2d(8) 106 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 107 | self.fc = nn.Linear(64 * block.expansion, num_classes) 108 | 109 | elif dataset == 'imagenet': 110 | blocks ={18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} 111 | layers ={18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]} 112 | assert layers[depth], 'invalid detph for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200)' 113 | 114 | self.inplanes = 64 115 | self.conv1 = nn.Conv2d(n_channel, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 116 | self.bn1 = nn.BatchNorm2d(64) 117 | self.relu = nn.ReLU(inplace=True) 118 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 119 | self.layer1 = self._make_layer(blocks[depth], 64, layers[depth][0]) 120 | self.layer2 = self._make_layer(blocks[depth], 128, layers[depth][1], stride=2) 121 | self.layer3 = self._make_layer(blocks[depth], 256, layers[depth][2], stride=2) 122 | self.layer4 = self._make_layer(blocks[depth], 512, layers[depth][3], stride=2) 123 | # self.avgpool = nn.AvgPool2d(7) 124 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 125 | self.fc = nn.Linear(512 * blocks[depth].expansion, num_classes) 126 | 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 130 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 131 | elif isinstance(m, nn.BatchNorm2d): 132 | m.weight.data.fill_(1) 133 | m.bias.data.zero_() 134 | 135 | def _make_layer(self, block, planes, blocks, stride=1): 136 | downsample = None 137 | if stride != 1 or self.inplanes != planes * block.expansion: 138 | downsample = nn.Sequential( 139 | nn.Conv2d(self.inplanes, planes * block.expansion, 140 | kernel_size=1, stride=stride, bias=False), 141 | nn.BatchNorm2d(planes * block.expansion), 142 | ) 143 | 144 | layers = [] 145 | layers.append(block(self.inplanes, planes, stride, downsample)) 146 | self.inplanes = planes * block.expansion 147 | for i in range(1, blocks): 148 | layers.append(block(self.inplanes, planes)) 149 | 150 | return nn.Sequential(*layers) 151 | 152 | def f(self, x): 153 | if self.dataset == 'cifar10' or self.dataset == 'cifar100': 154 | x = self.conv1(x) 155 | x = self.bn1(x) 156 | x = self.relu(x) 157 | 158 | x = self.layer1(x) 159 | x = self.layer2(x) 160 | x = self.layer3(x) 161 | 162 | x = self.avgpool(x) 163 | x = x.view(x.size(0), -1) 164 | 165 | elif self.dataset == 'imagenet': 166 | x = self.conv1(x) 167 | x = self.bn1(x) 168 | x = self.relu(x) 169 | x = self.maxpool(x) 170 | 171 | x = self.layer1(x) 172 | x = self.layer2(x) 173 | x = self.layer3(x) 174 | x = self.layer4(x) 175 | 176 | x = self.avgpool(x) 177 | x = x.view(x.size(0), -1) 178 | 179 | return x 180 | 181 | def g(self, x): 182 | return self.fc(x) 183 | 184 | def forward(self, x): 185 | x = self.f(x) 186 | x = self.g(x) 187 | 188 | return x 189 | -------------------------------------------------------------------------------- /ada_aug/networks/wideresnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.init as init 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 9 | 10 | 11 | 
def conv_init(m):
    """Initialize a module in place: Xavier weights (gain sqrt(2)) plus zero
    bias for conv layers; unit weight and zero bias for batch-norm layers.
    Other module types are left untouched.
    """
    kind = m.__class__.__name__
    if 'Conv' in kind:
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        init.constant_(m.bias, 0)
    elif 'BatchNorm' in kind:
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)


class WideBasic(nn.Module):
    """Pre-activation wide residual block: BN-ReLU-Conv, dropout,
    BN-ReLU-Conv, with a 1x1-conv shortcut whenever the stride or the
    channel count changes. Note the stride is applied on the *second* conv.
    """

    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(WideBasic, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.1)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes, momentum=0.1)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)

        # Identity shortcut unless the shapes diverge.
        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        h = F.relu(self.bn1(x))
        h = self.dropout(self.conv1(h))
        h = F.relu(self.bn2(h))
        h = self.conv2(h)
        return h + self.shortcut(x)
    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        """Stack `num_blocks` WideBasic units; only the first one applies
        the stage stride, the rest keep resolution."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []

        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride))
            self.in_planes = planes

        return nn.Sequential(*layers)

    def extract_multi_layer_feature(self, x):
        # NOTE(review): this method has no return statement and therefore
        # yields None — it looks like it should return `out` and/or `out2`;
        # confirm against callers before use.
        # NOTE(review): unlike `f`, it does not apply `self.conv1` first, so
        # `x` presumably must already be the stem output — verify.
        out = self.layer1(x)
        out = self.layer2(out)
        out2 = F.adaptive_avg_pool2d(out, (1, 1))
        out2 = out2.view(out2.size(0), -1)

        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = out.view(out.size(0), -1)

    def f(self, x):
        """Feature extractor: stem conv, 3 wide stages, BN-ReLU, global
        average pool, flatten. Paired with the classifier `g`."""
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        # out = F.avg_pool2d(out, 8)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = out.view(out.size(0), -1)

        return out

    def g(self, x):
        """Classifier head: pooled features -> class logits."""
        return self.fc(x)

    def forward(self, x):
        """Full forward pass: features `f` then classifier `g`."""
        x = self.f(x)
        x = self.g(x)

        return x


# ---- operation.py: per-image PIL augmentation primitives ----

def Identity(img, _):
    """No-op augmentation; keeps the (fn, low, high) op-list shape uniform."""
    return img


def ShearX(img, v):  # [-0.3, 0.3]
    """Horizontal shear by factor v; sign is randomized when random_mirror."""
    assert -0.3 <= v <= 0.3
    if random_mirror and random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))


def ShearY(img, v):  # [-0.3, 0.3]
    """Vertical shear by factor v; sign is randomized when random_mirror."""
    assert -0.3 <= v <= 0.3
    if random_mirror and random.random() > 0.5:
        v = -v
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
def TranslateX(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
    """Shift horizontally by fraction v of the image width."""
    assert -0.45 <= v <= 0.45
    if random_mirror and random.random() > 0.5:
        v = -v
    v = v * img.size[0]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))


def TranslateY(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
    """Shift vertically by fraction v of the image height."""
    assert -0.45 <= v <= 0.45
    if random_mirror and random.random() > 0.5:
        v = -v
    v = v * img.size[1]
    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))


def Rotate(img, v):  # [-30, 30]
    """Rotate by v degrees; sign is randomized when random_mirror."""
    assert -30 <= v <= 30
    if random_mirror and random.random() > 0.5:
        v = -v
    return img.rotate(v)


def AutoContrast(img, _):
    """Maximize image contrast (magnitude argument is ignored)."""
    return PIL.ImageOps.autocontrast(img)


def Invert(img, _):
    """Invert all pixel values (magnitude argument is ignored)."""
    return PIL.ImageOps.invert(img)


def Equalize(img, _):
    """Equalize the image histogram (magnitude argument is ignored)."""
    return PIL.ImageOps.equalize(img)


def Flip(img, _):  # not from the paper
    """Horizontal mirror (magnitude argument is ignored)."""
    return PIL.ImageOps.mirror(img)


def Solarize(img, v):  # [0, 256]
    """Invert all pixels above threshold v."""
    assert 0 <= v <= 256
    return PIL.ImageOps.solarize(img, v)


def Posterize(img, v):  # [4, 8]
    """Reduce each channel to v bits (AutoAugment range: 4-8 bits)."""
    assert 4 <= v <= 8
    v = int(v)
    return PIL.ImageOps.posterize(img, v)


def Posterize2(img, v):  # [0, 4]
    """Harsher posterize variant: 0-4 bits per channel."""
    assert 0 <= v <= 4
    v = int(v)
    return PIL.ImageOps.posterize(img, v)


def Contrast(img, v):  # [0.1,1.9]
    """Scale contrast; 1.0 is the identity."""
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Contrast(img).enhance(v)


def Color(img, v):  # [0.1,1.9]
    """Scale color saturation; 1.0 is the identity."""
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Color(img).enhance(v)


def Brightness(img, v):  # [0.1,1.9]
    """Scale brightness; 1.0 is the identity."""
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Brightness(img).enhance(v)


def Sharpness(img, v):  # [0.1,1.9]
    """Scale sharpness; 1.0 is the identity."""
    assert 0.1 <= v <= 1.9
    return PIL.ImageEnhance.Sharpness(img).enhance(v)


def Cutout(img, v):  # [0, 60] => percentage: [0, 0.2]
    """Erase a random square whose side is fraction v of the image width."""
    assert 0.0 <= v <= 0.2
    if v <= 0.:
        return img

    v = v * img.size[0]
    return CutoutAbs(img, v)
def CutoutAbs(img, v):  # [0, 60] => percentage: [0, 0.2]
    """Erase a square of side `v` pixels centered at a random point,
    clipped to the image; filled with a constant color (presumably the
    dataset mean pixel — confirm)."""
    # assert 0 <= v <= 20
    if v < 0:
        return img
    w, h = img.size
    x0 = np.random.uniform(w)
    y0 = np.random.uniform(h)

    x0 = int(max(0, x0 - v / 2.))
    y0 = int(max(0, y0 - v / 2.))
    x1 = min(w, x0 + v)
    y1 = min(h, y0 + v)

    xy = (x0, y0, x1, y1)
    color = (125, 123, 114)
    # color = (0, 0, 0)
    img = img.copy()  # do not mutate the caller's image
    PIL.ImageDraw.Draw(img).rectangle(xy, color)
    return img


# (fn, low, high) triples: `level` in [0, 1] is mapped affinely onto
# [low, high] by apply_augment below.
AUGMENT_LIST = [
    (ShearX, -0.3, 0.3),  # 0
    (ShearY, -0.3, 0.3),  # 1
    (TranslateX, -0.45, 0.45),  # 2
    (TranslateY, -0.45, 0.45),  # 3
    (Rotate, -30, 30),  # 4
    (AutoContrast, 0, 1),  # 5
    (Invert, 0, 1),  # 6
    (Equalize, 0, 1),  # 7
    (Solarize, 0, 256),  # 8
    (Posterize, 4, 8),  # 9
    (Contrast, 0.1, 1.9),  # 10
    (Color, 0.1, 1.9),  # 11
    (Brightness, 0.1, 1.9),  # 12
    (Sharpness, 0.1, 1.9),  # 13
    (Cutout, 0, 0.2),  # 14
    (Flip, 0, 1),  # 15
    (Identity, 0, 1)]  # 16

# Built once at import time; the previous code rebuilt this dict on every
# get_augment() call, i.e. on every augmented sample.
_AUGMENT_DICT = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in AUGMENT_LIST}


def get_augment(name):
    """Return the (fn, low, high) triple registered under `name`.

    Raises KeyError for unknown names (unchanged behavior).
    """
    return _AUGMENT_DICT[name]


def apply_augment(img, name, level):
    """Apply augmentation `name` to a copy of `img` with normalized
    magnitude `level` in [0, 1], rescaled to the op's native range."""
    augment_fn, low, high = get_augment(name)
    return augment_fn(img.copy(), level * (high - low) + low)
# ---- CLI for the AdaAug policy-search phase ----
parser = argparse.ArgumentParser("ada_aug")
parser.add_argument('--dataroot', type=str, default='./', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='name of dataset')
parser.add_argument('--batch_size', type=int, default=512, help='batch size')
parser.add_argument('--learning_rate', type=float, default=0.400, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=2e-4, help='weight decay')
parser.add_argument('--report_freq', type=float, default=1, help='report frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--epochs', type=int, default=20, help='number of training epochs')
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=2, help='seed')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--train_portion', type=float, default=1, help='portion of training data')
parser.add_argument('--proj_learning_rate', type=float, default=1e-2, help='learning rate for h')
parser.add_argument('--proj_weight_decay', type=float, default=1e-3, help='weight decay for h]')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
# NOTE(review): argparse `type=bool` treats any non-empty string as True
# (e.g. --use_cuda False is truthy); consider store_true/store_false flags.
parser.add_argument('--use_cuda', type=bool, default=True, help="use cuda default True")
parser.add_argument('--use_parallel', type=bool, default=False, help="use data parallel default False")
parser.add_argument('--model_name', type=str, default='wresnet40_2', help="mode _name")
parser.add_argument('--num_workers', type=int, default=0, help="num_workers")
parser.add_argument('--k_ops', type=int, default=1, help="number of augmentation applied during training")
parser.add_argument('--temperature', type=float, default=1.0, help="temperature")
parser.add_argument('--search_freq', type=float, default=1, help='exploration frequency')
parser.add_argument('--n_proj_layer', type=int, default=0, help="number of hidden layer in augmentation policy projection")

args = parser.parse_args()
# Experiment name "debug" redirects all artifacts under debug/.
debug = True if args.save == "debug" else False
args.save = '{}-{}'.format(time.strftime("%Y%m%d-%H%M%S"), args.save)
if debug:
    args.save = os.path.join('debug', args.save)
else:
    args.save = os.path.join('search', args.dataset, args.save)
utils.create_exp_dir(args.save)
# Log both to stdout and to <save>/log.txt with the same format.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)


def main():
    """Run the AdaAug search phase: jointly train the task network (gf)
    and the policy projection (h), then save both weight sets."""
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    torch.cuda.set_device(args.gpu)
    utils.reproducibility(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # dataset settings
    n_class = get_num_class(args.dataset)
    sdiv = get_search_divider(args.model_name)
    class2label = get_label_name(args.dataset, args.dataroot)

    train_queue, valid_queue, search_queue, test_queue = get_dataloaders(
        args.dataset, args.batch_size, args.num_workers,
        args.dataroot, args.cutout, args.cutout_length,
        split=args.train_portion, split_idx=0, target_lb=-1,
        search=True, search_divider=sdiv)

    logging.info(f'Dataset: {args.dataset}')
    logging.info(f'  |total: {len(train_queue.dataset)}')
    logging.info(f'  |train: {len(train_queue)*args.batch_size}')
    logging.info(f'  |valid: {len(valid_queue)*args.batch_size}')
    logging.info(f'  |search: {len(search_queue)*sdiv}')

    # model settings: gf = task model, h = policy projection over gf features
    gf_model = get_model(model_name=args.model_name, num_class=n_class,
                         use_cuda=True, data_parallel=False)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(gf_model))

    h_model = Projection(in_features=gf_model.fc.in_features,
                         n_layers=args.n_proj_layer, n_hidden=128).cuda()

    # training settings: SGD + cosine schedule for gf, Adam for h
    gf_optimizer = torch.optim.SGD(
        gf_model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(gf_optimizer,
        float(args.epochs), eta_min=args.learning_rate_min)

    h_optimizer = torch.optim.Adam(
        h_model.parameters(),
        lr=args.proj_learning_rate,
        betas=(0.9, 0.999),
        weight_decay=args.proj_weight_decay)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    # AdaAug settings
    # NOTE(review): k_ops/delta/temp here are hard-coded rather than taken
    # from args.k_ops/args.temperature — confirm this is intentional.
    after_transforms = train_queue.dataset.after_transforms
    adaaug_config = {'sampling': 'prob',
                     'k_ops': 1,
                     'delta': 0.0,
                     'temp': 1.0,
                     'search_d': get_dataset_dimension(args.dataset),
                     'target_d': get_dataset_dimension(args.dataset)}

    adaaug = AdaAug(after_transforms=after_transforms,
                    n_class=n_class,
                    gf_model=gf_model,
                    h_model=h_model,
                    save_dir=args.save,
                    config=adaaug_config)

    # Start training
    start_time = time.time()
    for epoch in range(args.epochs):
        lr = scheduler.get_last_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)

        # searching
        train_acc, train_obj = train(train_queue, search_queue, gf_model, adaaug,
            criterion, gf_optimizer, args.grad_clip, h_optimizer, epoch, args.search_freq)

        # validation
        valid_acc, valid_obj = infer(valid_queue, gf_model, criterion)

        logging.info(f'train_acc {train_acc} valid_acc {valid_acc}')
        scheduler.step()

        # checkpoint both sub-models every epoch (same path, overwritten)
        utils.save_model(gf_model, os.path.join(args.save, 'gf_weights.pt'))
        utils.save_model(h_model, os.path.join(args.save, 'h_weights.pt'))

    end_time = time.time()
    elapsed = end_time - start_time

    test_acc, test_obj = infer(test_queue, gf_model, criterion)
    utils.save_model(gf_model, os.path.join(args.save, 'gf_weights.pt'))
    utils.save_model(h_model, os.path.join(args.save, 'h_weights.pt'))
    adaaug.save_history(class2label)
    figure = adaaug.plot_history()

    logging.info(f'test_acc {test_acc}')
    logging.info('elapsed time: %.3f Hours' % (elapsed / 3600.))
    logging.info(f'saved to: {args.save}')

def train(train_queue, search_queue, gf_model, adaaug, criterion, gf_optimizer,
            grad_clip, h_optimizer, epoch, search_freq):
    """One search epoch: alternate exploitation (train gf on augmented
    batches) and exploration (train h on the search split)."""
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()

    for step, (input, target) in enumerate(train_queue):
        target = target.cuda(non_blocking=True)

        # exploitation: standard supervised step on AdaAug-augmented images
        timer = time.time()
        aug_images = adaaug(input, mode='exploit')
        gf_model.train()
        gf_optimizer.zero_grad()
        logits = gf_model(aug_images)
        loss = criterion(logits, target)
        loss.backward()
        nn.utils.clip_grad_norm_(gf_model.parameters(), grad_clip)
        gf_optimizer.step()

        #  stats
        prec1, prec5 = utils.accuracy(logits.detach(), target.detach(), topk=(1, 5))
        n = target.size(0)
        objs.update(loss.detach().item(), n)
        top1.update(prec1.detach().item(), n)
        top5.update(prec5.detach().item(), n)
        exploitation_time = time.time() - timer

        # exploration: update the policy projection h every `search_freq` steps
        timer = time.time()
        if step % search_freq == 0:
            # NOTE(review): next(iter(...)) builds a fresh DataLoader
            # iterator each time, always yielding the first batch and
            # re-spawning workers — consider a persistent iterator.
            input_search, target_search = next(iter(search_queue))
            target_search = target_search.cuda(non_blocking=True)

            h_optimizer.zero_grad()
            mixed_features = adaaug(input_search, mode='explore')
            logits = gf_model.g(mixed_features)
            loss = criterion(logits, target_search)
            loss.backward()
            h_optimizer.step()
            exploration_time = time.time() - timer

            #  log policy
            adaaug.add_history(input_search, target_search)

        global_step = epoch * len(train_queue) + step
        if global_step % args.report_freq == 0:
            # NOTE(review): exploration_time may be stale here when this
            # step did not run exploration (it is only rebound inside the
            # `if` above).
            logging.info('  |train %03d %e %f %f | %.3f + %.3f s', global_step,
                objs.avg, top1.avg, top5.avg, exploitation_time, exploration_time)

    return top1.avg, objs.avg


def infer(valid_queue, gf_model, criterion):
    """Evaluate gf on a queue without augmentation; returns (top1, loss)."""
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    gf_model.eval()

    with torch.no_grad():
        for input, target in valid_queue:
            input = input.cuda()
            target = target.cuda(non_blocking=True)

            logits = gf_model(input)
            loss = criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.detach().item(), n)
            top1.update(prec1.detach().item(), n)
            top5.update(prec5.detach().item(), n)

    return top1.avg, objs.avg


if __name__ == '__main__':
    main()
# ---- CLI for the AdaAug evaluation/training phase ----
parser = argparse.ArgumentParser("ada_aug")
parser.add_argument('--dataroot', type=str, default='./', help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='name of dataset')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
parser.add_argument('--batch_size', type=int, default=96, help='batch size')
parser.add_argument('--num_workers', type=int, default=0, help="num_workers")
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.0001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
# NOTE(review): argparse `type=bool` treats any non-empty string as True.
parser.add_argument('--use_cuda', type=bool, default=True, help="use cuda default True")
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--use_parallel', action='store_true', default=False, help="use data parallel default False")
parser.add_argument('--model_name', type=str, default='wresnet40_2', help="model name")
parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
parser.add_argument('--epochs', type=int, default=600, help='number of training epochs')
parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--seed', type=int, default=0, help='seed')
parser.add_argument('--search_dataset', type=str, default='./', help='search dataset name')
parser.add_argument('--gf_model_name', type=str, default='./', help='gf_model name')
parser.add_argument('--gf_model_path', type=str, default='./', help='gf_model path')
parser.add_argument('--h_model_path', type=str, default='./', help='h_model path')
parser.add_argument('--k_ops', type=int, default=1, help="number of augmentation applied during training")
parser.add_argument('--delta', type=float, default=0.3, help="degree of perturbation in magnitude")
parser.add_argument('--temperature', type=float, default=1.0, help="temperature")
parser.add_argument('--n_proj_layer', type=int, default=0, help="number of additional hidden layer in augmentation policy projection")
parser.add_argument('--n_proj_hidden', type=int, default=128, help="number of hidden units in augmentation policy projection layers")
parser.add_argument('--restore_path', type=str, default='./', help='restore model path')
parser.add_argument('--restore', action='store_true', default=False, help='restore model default False')

args = parser.parse_args()
# Experiment name "debug" redirects all artifacts under debug/.
debug = True if args.save == "debug" else False
args.save = '{}-{}'.format(time.strftime("%Y%m%d-%H%M%S"), args.save)
if debug:
    args.save = os.path.join('debug', args.save)
else:
    args.save = os.path.join('eval', args.dataset, args.save)
utils.create_exp_dir(args.save)
# Log both to stdout and to <save>/log.txt with the same format.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)


def main():
    """Train a task model from scratch using a frozen, previously-searched
    AdaAug policy (gf + h loaded from disk, gradients disabled)."""
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    torch.cuda.set_device(args.gpu)
    utils.reproducibility(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # dataset settings
    n_class = get_num_class(args.dataset)
    class2label = get_label_name(args.dataset, args.dataroot)
    train_queue, valid_queue, _, test_queue = get_dataloaders(
        args.dataset, args.batch_size, args.num_workers,
        args.dataroot, args.cutout, args.cutout_length,
        split=args.train_portion, split_idx=0, target_lb=-1,
        search=True)

    logging.info(f'Dataset: {args.dataset}')
    logging.info(f'  |total: {len(train_queue.dataset)}')
    logging.info(f'  |train: {len(train_queue)*args.batch_size}')
    logging.info(f'  |valid: {len(valid_queue)*args.batch_size}')

    #  task model settings
    task_model = get_model(model_name=args.model_name,
                           num_class=n_class,
                           use_cuda=True, data_parallel=False)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(task_model))

    #  task optimization settings: SGD + cosine schedule wrapped in warmup
    optimizer = torch.optim.SGD(
        task_model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=True
        )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    m, e = get_warmup_config(args.dataset)
    scheduler = GradualWarmupScheduler(
        optimizer,
        multiplier=m,
        total_epoch=e,
        after_scheduler=scheduler)
    logging.info(f'Optimizer: SGD, scheduler: CosineAnnealing, warmup: {m}/{e}')
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    #  restore setting: resume the task model mid-run if requested
    if args.restore:
        trained_epoch = utils.restore_ckpt(task_model, optimizer, scheduler, args.restore_path, location=args.gpu) + 1
        n_epoch = args.epochs - trained_epoch
        logging.info(f'Restoring model from {args.restore_path}, starting from epoch {trained_epoch}')
    else:
        trained_epoch = 0
        n_epoch = args.epochs

    #  load trained adaaug sub models (searched possibly on another dataset)
    search_n_class = get_num_class(args.search_dataset)
    gf_model = get_model(model_name=args.gf_model_name,
                         num_class=search_n_class,
                         use_cuda=True, data_parallel=False)

    h_model = Projection(in_features=gf_model.fc.in_features,
                         n_layers=args.n_proj_layer,
                         n_hidden=args.n_proj_hidden).cuda()

    utils.load_model(gf_model, f'{args.gf_model_path}/gf_weights.pt', location=args.gpu)
    utils.load_model(h_model, f'{args.h_model_path}/h_weights.pt', location=args.gpu)

    # The searched policy is frozen during task training.
    for param in gf_model.parameters():
        param.requires_grad = False

    for param in h_model.parameters():
        param.requires_grad = False

    after_transforms = train_queue.dataset.after_transforms
    adaaug_config = {'sampling': 'prob',
                     'k_ops': args.k_ops,
                     'delta': args.delta,
                     'temp': args.temperature,
                     'search_d': get_dataset_dimension(args.search_dataset),
                     'target_d': get_dataset_dimension(args.dataset)}

    adaaug = AdaAug(after_transforms=after_transforms,
                    n_class=search_n_class,
                    gf_model=gf_model,
                    h_model=h_model,
                    save_dir=args.save,
                    config=adaaug_config)

    #  start training
    for i_epoch in range(n_epoch):
        epoch = trained_epoch + i_epoch
        lr = scheduler.get_last_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)

        train_acc, train_obj = train(
            train_queue, task_model, criterion, optimizer, epoch, args.grad_clip, adaaug)
        logging.info('train_acc %f', train_acc)

        valid_acc, valid_obj, _, _ = infer(valid_queue, task_model, criterion)
        logging.info('valid_acc %f', valid_acc)

        scheduler.step()

        # Periodic test evaluation (report_freq is reused as an epoch period).
        if epoch % args.report_freq == 0:
            test_acc, test_obj, test_acc5, _ = infer(test_queue, task_model, criterion)
            logging.info('test_acc %f %f', test_acc, test_acc5)

        utils.save_ckpt(task_model, optimizer, scheduler, epoch,
            os.path.join(args.save, 'weights.pt'))

    adaaug.save_history(class2label)
    figure = adaaug.plot_history()
    test_acc, test_obj, test_acc5, _ = infer(test_queue, task_model, criterion)

    logging.info('test_acc %f %f', test_acc, test_acc5)
    logging.info(f'save to {args.save}')


def train(train_queue, model, criterion, optimizer, epoch, grad_clip, adaaug):
    """One training epoch on AdaAug-augmented batches; returns (top1, loss)."""
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()

    for step, (input, target) in enumerate(train_queue):
        target = target.cuda(non_blocking=True)

        # get augmented training data from adaaug
        aug_images = adaaug(input, mode='exploit')
        model.train()
        optimizer.zero_grad()
        logits = model(aug_images)
        loss = criterion(logits, target)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        n = input.size(0)
        objs.update(loss.detach().item(), n)
        top1.update(prec1.detach().item(), n)
        top5.update(prec5.detach().item(), n)

        global_step = step + epoch * len(train_queue)
        if global_step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', global_step, objs.avg, top1.avg, top5.avg)

        #  log the policy once per epoch, on the first batch
        if step == 0:
            adaaug.add_history(input, target)

    return top1.avg, objs.avg


def infer(valid_queue, model, criterion):
    """Evaluate without augmentation.

    Returns (top1, loss, top5, loss) — the loss is duplicated in the
    4-tuple (looks intentional to keep a fixed arity; callers unpack 4).
    """
    objs = utils.AvgrageMeter()
    top1 = utils.AvgrageMeter()
    top5 = utils.AvgrageMeter()
    model.eval()
    with torch.no_grad():
        for input, target in valid_queue:
            input = input.cuda()
            target = target.cuda(non_blocking=True)

            logits = model(input)
            loss = criterion(logits, target)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.detach().item(), n)
            top1.update(prec1.detach().item(), n)
            top5.update(prec5.detach().item(), n)

    return top1.avg, objs.avg, top5.avg, objs.avg


if __name__ == '__main__':
    main()


# ---- utils.py: checkpoint helpers ----

def save_ckpt(model, optimizer, scheduler, epoch, model_path):
    """Save a full resumable checkpoint: model, optimizer, scheduler, epoch."""
    torch.save({'model':model.state_dict(),
                'epoch': epoch,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()}, model_path)


def restore_ckpt(model, optimizer, scheduler, model_path, location):
    """Restore a checkpoint written by `save_ckpt` onto cuda:<location>;
    returns the epoch at which it was saved."""
    state = torch.load(model_path, map_location=f'cuda:{location}')
    model.load_state_dict(state['model'], strict=True)
    optimizer.load_state_dict(state['optimizer'])
    scheduler.load_state_dict(state['scheduler'])
    epoch = state['epoch']
    return epoch
def save_model(model, model_path):
    """Persist only the model weights (state_dict)."""
    torch.save(model.state_dict(), model_path)


def load_model(model, model_path, location):
    """Load weights written by `save_model` onto cuda:<location>."""
    model.load_state_dict(torch.load(model_path, map_location=f'cuda:{location}'), strict=True)


class PolicyHistory(object):
    """Accumulates per-class augmentation-policy statistics (mean/std of
    magnitudes and op weights) and dumps them as CSVs and stacked plots."""

    def __init__(self, op_names, save_dir, n_class):
        self.op_names = op_names
        self.save_dir = save_dir
        self._initialize(n_class)

    def _initialize(self, n_class):
        # One dict per class, each holding time series of policy stats.
        self.history = []
        # [{m:[], w:[]}, {}]
        for i in range(n_class):
            self.history.append({'magnitudes': [],
                                 'weights': [],
                                 'var_magnitudes': [],
                                 'var_weights': []})

    def add(self, class_idx, m_mu, w_mu, m_std, w_std):
        """Append one snapshot of the policy for `class_idx`."""
        if not isinstance(m_mu, list):  # ugly way to bypass batch with single element
            return
        self.history[class_idx]['magnitudes'].append(m_mu)
        self.history[class_idx]['weights'].append(w_mu)
        self.history[class_idx]['var_magnitudes'].append(m_std)
        self.history[class_idx]['var_weights'].append(w_std)

    def save(self, class2label=None):
        """Write one CSV per class/statistic under <save_dir>/policy and
        <save_dir>/vis_policy, columns named after the ops."""
        path = os.path.join(self.save_dir, 'policy')
        vis_path = os.path.join(self.save_dir, 'vis_policy')
        os.makedirs(path, exist_ok=True)
        os.makedirs(vis_path, exist_ok=True)
        header = ','.join(self.op_names)
        for i, history in enumerate(self.history):
            k = i if class2label is None else class2label[i]
            np.savetxt(f'{path}/policy{i}({k})_magnitude.csv',
                       history['magnitudes'], delimiter=',', header=header, comments='')
            np.savetxt(f'{path}/policy{i}({k})_weights.csv',
                       history['weights'], delimiter=',', header=header, comments='')
            np.savetxt(f'{vis_path}/policy{i}({k})_var_magnitude.csv',
                       history['var_magnitudes'], delimiter=',', header=header, comments='')
            np.savetxt(f'{vis_path}/policy{i}({k})_var_weights.csv',
                       history['var_weights'], delimiter=',', header=header, comments='')

    def plot(self):
        """Render the saved CSVs: per-class schedules plus final magnitude /
        probability bar charts. Returns the last matplotlib figure.

        NOTE(review): `axes[i][0]` indexing assumes n_class > 1 (subplots
        returns a 1-D array otherwise), and `file.split('/')` assumes POSIX
        path separators — confirm on the target platform.
        """
        PATH = self.save_dir
        mag_file_list = glob.glob(f'{PATH}/policy/*_magnitude.csv')
        weights_file_list = glob.glob(f'{PATH}/policy/*_weights.csv')
        n_class = len(mag_file_list)

        f, axes = plt.subplots(n_class, 2, figsize=(15, 5*n_class))

        for i, file in enumerate(mag_file_list):
            df = pd.read_csv(file).dropna()
            x = range(0, len(df))
            y = df.to_numpy().T
            axes[i][0].stackplot(x, y, labels=df.columns, edgecolor='none')
            axes[i][0].set_title(file.split('/')[-1][:-4])

        for i, file in enumerate(weights_file_list):
            df = pd.read_csv(file).dropna()
            x = range(0, len(df))
            y = df.to_numpy().T
            axes[i][1].stackplot(x, y, labels=df.columns, edgecolor='none')
            axes[i][1].set_title(file.split('/')[-1][:-4])

        axes[-1][-1].legend(loc='upper center', bbox_to_anchor=(-0.1, -0.2), fancybox=True, shadow=True, ncol=10)
        plt.savefig(f'{PATH}/policy/schedule.png')

        f, axes = plt.subplots(1, 1, figsize=(7,5))

        # Final-step magnitude per class, stacked by op.
        frames = []
        for i, file in enumerate(mag_file_list):
            df = pd.read_csv(file).dropna()
            df['class'] = file.split('/')[-1][:-4].split('_')[0]
            frames.append(df.tail(1))

        df = pd.concat(frames)
        df.set_index('class').plot(ax=axes, kind='bar', stacked=True, legend=False, rot=90, fontsize=8)
        axes.set_ylabel("magnitude")
        plt.savefig(f'{PATH}/policy/magnitude_by_class.png')

        f, axes = plt.subplots(1, 1, figsize=(7,5))
        # Final-step op probabilities per class, stacked by op.
        frames = []
        for i, file in enumerate(weights_file_list):
            df = pd.read_csv(file).dropna()
            df['class'] = file.split('/')[-1][:-4].split('_')[0].split('(')[1][:-1]
            frames.append(df.tail(1))

        df = pd.concat(frames)
        df.set_index('class').plot(ax=axes, kind='bar', stacked=True, legend=False, rot=90, fontsize=8)
        axes.set_ylabel("probability")
        axes.set_xlabel("")
        plt.savefig(f'{PATH}/policy/probability_by_class.png')

        return f
136 | 137 | def __init__(self): 138 | self.reset() 139 | 140 | def reset(self): 141 | self.avg = 0 142 | self.sum = 0 143 | self.cnt = 0 144 | 145 | def update(self, val, n=1): 146 | self.sum += val * n 147 | self.cnt += n 148 | self.avg = self.sum / self.cnt 149 | 150 | 151 | def accuracy(output, target, topk=(1,)): 152 | maxk = max(topk) 153 | batch_size = target.size(0) 154 | 155 | _, pred = output.topk(maxk, 1, True, True) 156 | pred = pred.t() 157 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 158 | 159 | res = [] 160 | 161 | for k in topk: 162 | correct_k = correct[:k].reshape(-1).float().sum(0) 163 | res.append(correct_k.mul_(100.0/batch_size)) 164 | return res 165 | 166 | 167 | class Cutout(object): 168 | def __init__(self, length): 169 | self.length = length 170 | 171 | def __call__(self, img): 172 | h, w = img.size(1), img.size(2) 173 | mask = np.ones((h, w), np.float32) 174 | y = np.random.randint(h) 175 | x = np.random.randint(w) 176 | 177 | y1 = np.clip(y - self.length // 2, 0, h) 178 | y2 = np.clip(y + self.length // 2, 0, h) 179 | x1 = np.clip(x - self.length // 2, 0, w) 180 | x2 = np.clip(x + self.length // 2, 0, w) 181 | 182 | mask[y1: y2, x1: x2] = 0. 
183 | mask = torch.from_numpy(mask) 184 | mask = mask.expand_as(img) 185 | img *= mask 186 | return img 187 | 188 | 189 | def count_parameters_in_MB(model): 190 | return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6 191 | 192 | 193 | def drop_path(x, drop_prob): 194 | if drop_prob > 0.: 195 | keep_prob = 1.-drop_prob 196 | mask = Variable(torch.cuda.FloatTensor( 197 | x.size(0), 1, 1, 1).bernoulli_(keep_prob)) 198 | x.div_(keep_prob) 199 | x.mul_(mask) 200 | return x 201 | 202 | 203 | def create_exp_dir(path, scripts_to_save=None): 204 | if not os.path.exists(path): 205 | os.mkdir(path) 206 | print('Experiment dir : {}'.format(path)) 207 | 208 | if scripts_to_save is not None: 209 | os.mkdir(os.path.join(path, 'scripts')) 210 | for script in scripts_to_save: 211 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 212 | shutil.copyfile(script, dst_file) 213 | 214 | 215 | def reproducibility(seed): 216 | random.seed(seed) 217 | np.random.seed(seed) 218 | torch.manual_seed(seed) 219 | torch.cuda.manual_seed(seed) 220 | torch.cuda.manual_seed_all(seed) 221 | torch.backends.cudnn.deterministic = True 222 | torch.backends.cudnn.benchmark = False 223 | torch.autograd.set_detect_anomaly(True) 224 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | aiohttp==3.6.2 3 | astroid==2.3.3 4 | async-timeout==3.0.1 5 | attrs==19.3.0 6 | autopep8==1.5 7 | backcall==0.1.0 8 | beautifulsoup4==4.8.2 9 | bleach==3.1.1 10 | blessings==1.7 11 | cachetools==4.2.0 12 | calmsize==0.1.3 13 | certifi==2019.11.28 14 | chardet==3.0.4 15 | Click==7.0 16 | cloudpickle==1.3.0 17 | colorama==0.4.3 18 | contextvars==2.4 19 | cycler==0.10.0 20 | decorator==4.4.1 21 | defusedxml==0.6.0 22 | dill==0.3.3 23 | entrypoints==0.3 24 | filelock==3.0.12 25 | flake8==3.8.4 26 | funcsigs==1.0.2 
27 | future==0.18.2 28 | google==2.0.3 29 | google-auth==1.24.0 30 | google-auth-oauthlib==0.4.2 31 | googletrans==3.0.0 32 | gpustat==0.6.0 33 | grpcio==1.27.2 34 | h11==0.9.0 35 | h2==3.2.0 36 | hpack==3.0.0 37 | hstspreload==2020.10.20 38 | httpcore==0.9.1 39 | httpx==0.13.3 40 | hyperframe==5.2.0 41 | idna==2.9 42 | idna-ssl==1.1.0 43 | immutables==0.14 44 | importlib-metadata==1.5.0 45 | ipykernel==5.1.4 46 | ipython==7.12.0 47 | ipython-genutils==0.2.0 48 | ipywidgets==7.5.1 49 | isort==4.3.21 50 | jedi==0.16.0 51 | Jinja2==2.11.1 52 | joblib==0.14.1 53 | jsonschema==3.2.0 54 | jupyter==1.0.0 55 | jupyter-client==6.0.0 56 | jupyter-console==6.1.0 57 | jupyter-contrib-core==0.3.3 58 | jupyter-contrib-nbextensions==0.5.1 59 | jupyter-core==4.6.3 60 | jupyter-highlight-selected-word==0.2.0 61 | jupyter-latex-envs==1.4.6 62 | jupyter-nbextensions-configurator==0.4.1 63 | jupyterthemes==0.20.0 64 | kiwisolver==1.1.0 65 | lazy-object-proxy==1.4.3 66 | lesscpy==0.14.0 67 | lxml==4.5.0 68 | Markdown==3.3.3 69 | MarkupSafe==1.1.1 70 | matplotlib==3.1.3 71 | mccabe==0.6.1 72 | mistune==0.8.4 73 | more-itertools==8.2.0 74 | multidict==4.7.5 75 | nbconvert==5.6.1 76 | nbformat==5.0.4 77 | nltk==3.5 78 | notebook==5.7.8 79 | numpy==1.18.1 80 | nvidia-ml-py3==7.352.0 81 | oauthlib==3.1.0 82 | packaging==20.1 83 | pandas==1.0.1 84 | pandocfilters==1.4.2 85 | parso==0.6.2 86 | pexpect==4.8.0 87 | pickleshare==0.7.5 88 | Pillow==7.1.1 89 | pluggy==0.13.1 90 | ply==3.11 91 | prometheus-client==0.7.1 92 | prompt-toolkit==3.0.3 93 | protobuf==3.11.3 94 | psutil==5.8.0 95 | ptyprocess==0.6.0 96 | py==1.8.1 97 | py-spy==0.3.3 98 | pyasn1==0.4.8 99 | pyasn1-modules==0.2.8 100 | pycodestyle==2.6.0 101 | pyflakes==2.2.0 102 | Pygments==2.5.2 103 | pylint==2.4.4 104 | pyparsing==2.4.6 105 | pyrsistent==0.15.7 106 | pytest==5.3.5 107 | python-dateutil==2.8.1 108 | pytorch-memlab==0.2.3 109 | pytz==2019.3 110 | PyYAML==5.3 111 | pyzmq==19.0.0 112 | qtconsole==4.6.0 113 | ray==0.8.2 114 
| redis==3.4.1
115 | regex==2020.4.4
116 | requests==2.23.0
117 | requests-oauthlib==1.3.0
118 | rfc3986==1.4.0
119 | rope==0.16.0
120 | rsa==4.6
121 | scikit-learn==0.22.1
122 | scipy==1.4.1
123 | seaborn==0.10.0
124 | Send2Trash==1.5.0
125 | sentencepiece==0.1.86
126 | six==1.14.0
127 | sklearn==0.0
128 | sniffio==1.2.0
129 | soupsieve==2.0
130 | tabulate==0.8.6
131 | tensorboard==2.4.0
132 | tensorboard-plugin-wit==1.7.0
133 | tensorboardX==2.0
134 | terminado==0.8.3
135 | testpath==0.4.4
136 | torch==1.5.1
137 | torchtext==0.6.0
138 | torchvision==0.6.1
139 | tornado==6.0.3
140 | tqdm==4.45.0
141 | traitlets==4.3.3
142 | typed-ast==1.4.1
143 | typing-extensions==3.7.4.1
144 | urllib3==1.25.8
145 | warmup-scheduler @ git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git@6b5e8953a80aef5b324104dc0c2e9b8c34d622bd
146 | wcwidth==0.1.8
147 | webencodings==0.5.1
148 | Werkzeug==1.0.1
149 | widgetsnbextension==3.5.1
150 | wrapt==1.11.2
151 | yarl==1.4.2
152 | zipp==3.0.0
153 |
--------------------------------------------------------------------------------
/scripts/search.sh:
--------------------------------------------------------------------------------
GPU=0

# Per-dataset hyper-parameters for augmentation-policy search.
function run_reduced_svhn {
    DATASET=reduced_svhn
    MODEL=wresnet40_2
    EPOCH=160
    BATCH=128
    LR=0.05
    WD=0.01
    SLR=0.001
    CUTOUT=0
    SF=1
    KOPS=2   # fix: was never set although passed to --k_ops below
    TEMP=2   # fix: was never set although passed to --temperature below
}

# cifar10
function run_reduced_cifar10 {
    DATASET=reduced_cifar10
    MODEL=wresnet40_2
    EPOCH=200
    BATCH=128
    LR=0.1
    WD=0.0005
    SLR=0.001
    CUTOUT=16
    SF=3
    KOPS=2   # fix: was never set; values mirror scripts/train.sh for this dataset
    TEMP=3
}

# Quote "$1" so an empty/missing argument does not break the test syntax.
if [ "$1" = "reduced_cifar10" ]; then
    run_reduced_cifar10
elif [ "$1" = "reduced_svhn" ]; then
    run_reduced_svhn
else
    echo "usage: $0 {reduced_cifar10|reduced_svhn}" >&2
    exit 1
fi

SAVE=${DATASET}_${MODEL}_${BATCH}_${EPOCH}_SLR${SLR}_SF${SF}_cutout_${CUTOUT}_lr${LR}_wd${WD}
python ada_aug/search.py --k_ops ${KOPS} --report_freq 10 --num_workers 4 --epochs ${EPOCH} --batch_size ${BATCH} --learning_rate ${LR} --dataset ${DATASET} --model_name ${MODEL} --save ${SAVE} --gpu ${GPU} --weight_decay ${WD} --proj_learning_rate ${SLR} --search_freq ${SF} --cutout --cutout_length ${CUTOUT} --temperature ${TEMP}
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
GPU=0

# Per-dataset hyper-parameters for training with a searched policy.
# GF / H are intentionally left blank: fill in the paths to the searched
# gf/h model checkpoints before running.
function run_reduced_svhn {
    DATASET=reduced_svhn
    MODEL=wresnet28_10
    EPOCH=180
    BATCH=128
    LR=0.05
    WD=0.01
    CUTOUT=0
    TRAIN_PORTION=1
    GF=
    H=
    SDN=reduced_svhn
    GFN=wresnet40_2
    DELTA=0.4
    TEMP=2
    KOPS=2
}

function run_reduced_cifar10 {
    DATASET=reduced_cifar10
    MODEL=wresnet28_10
    EPOCH=240
    BATCH=128
    LR=0.1
    WD=0.0005
    CUTOUT=16
    TRAIN_PORTION=1
    GF=
    H=
    SDN=reduced_cifar10
    GFN=wresnet40_2
    DELTA=0.3
    TEMP=3
    KOPS=2
}

# Quote "$1" so an empty/missing argument does not break the test syntax.
if [ "$1" = "reduced_cifar10" ]; then
    run_reduced_cifar10
elif [ "$1" = "reduced_svhn" ]; then
    run_reduced_svhn
else
    echo "usage: $0 {reduced_cifar10|reduced_svhn}" >&2
    exit 1
fi

SAVE=${DATASET}_${MODEL}_${BATCH}_${EPOCH}_cutout_${CUTOUT}_lr${LR}_wd${WD}_kops_${KOPS}_TEMP_${TEMP}_${DELTA}
# fix: pass ${TRAIN_PORTION} instead of a hardcoded 1 so the per-dataset
# setting actually takes effect (same value today, so behavior is unchanged).
python ada_aug/train.py --temperature ${TEMP} --delta ${DELTA} --search_dataset ${SDN} --gf_model_path ${GF} --h_model_path ${H} --gf_model_name ${GFN} --k_ops ${KOPS} --report_freq 10 --num_workers 8 --epochs ${EPOCH} --batch_size ${BATCH} --learning_rate ${LR} --dataset ${DATASET} --model_name ${MODEL} --save ${SAVE} --gpu ${GPU} --weight_decay ${WD} --train_portion ${TRAIN_PORTION} --use_parallel
--------------------------------------------------------------------------------