├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── attacker.py
├── imagenet.py
├── net.py
├── run.sh
└── utils
    ├── __init__.py
    ├── eval.py
    ├── fastaug
    │   ├── augmentations.py
    │   └── fastaug.py
    ├── images
    │   ├── advresult.png
    │   ├── cifar.png
    │   └── imagenet.png
    ├── logger.py
    ├── misc.py
    ├── mix_dataloader.py
    ├── progress
    │   ├── .gitignore
    │   ├── LICENSE
    │   ├── MANIFEST.in
    │   ├── README.rst
    │   ├── demo.gif
    │   ├── progress
    │   │   ├── __init__.py
    │   │   ├── bar.py
    │   │   ├── counter.py
    │   │   ├── helpers.py
    │   │   └── spinner.py
    │   ├── setup.py
    │   └── test_progress.py
    └── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # tmp dirs and files
2 | checkpoint
3 | checkpoints
4 | data
5 | cifar-debug.py
6 | test.eps
7 | dev
8 | monitor.py
9 | exp
10 |
11 | # Byte-compiled / optimized / DLL files
12 | __pycache__/
13 | *.py[cod]
14 | *$py.class
15 |
16 | # C extensions
17 | *.so
18 |
19 | # Distribution / packaging
20 | .Python
21 | env/
22 | build/
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *,cover
56 | .hypothesis/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # IPython Notebook
80 | .ipynb_checkpoints
81 |
82 | # pyenv
83 | .python-version
84 |
85 | # celery beat schedule file
86 | celerybeat-schedule
87 |
88 | # dotenv
89 | .env
90 |
91 | # virtualenv
92 | venv/
93 | ENV/
94 |
95 | # Spyder project settings
96 | .spyderproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # ide and editor settings
102 | .idea/
103 | .vscode/
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "utils/progress"]
2 | path = utils/progress
3 | url = https://github.com/verigak/progress.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Wei Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-classification-advprop
2 |
3 | **An implementation of Fast AdvProp by Jieru Mei, accepted at ICLR 2022, will be available in [this repository](https://github.com/meijieru/fast_advprop).**
4 |
5 | A PyTorch implementation of the CVPR 2020 paper *Adversarial Examples Improve Image Recognition* by Xie C, Tan M, Gong B, et al.
6 |
7 | Thanks to Cihang Xie and Yingwei Li for their guidance. The code is adapted from https://github.com/bearpaw/pytorch-classification.
8 |
9 |
10 | ## Features
11 | * Multi-GPU support
12 | * Training progress bar with rich info
13 | * Training log and training curve visualization code (see `./utils/logger.py`)
14 | * Training log using tensorboardX
15 |
16 | ## Environments
17 | This project is developed and tested in the following environments.
18 | * Ubuntu 16.04
19 | * CUDA 10.0.130
20 | * TITAN Xp
21 | * Python 3.8.1
22 |
23 | ## Requirements
24 | * Install [matplotlib]
25 | * Install [numpy]
26 | * Install [PyTorch]
27 | * Install [torchvision]
28 | * Install [tensorboardX]
29 |
30 | ## Training
31 | This is the example command for training:
32 | ```
33 | python imagenet.py --attack-iter 1 --attack-epsilon 1 --attack-step-size 1 -a resnet50 --train-batch 256 --num_classes 1000 --data /path/of/ImageNet --epochs 105 --schedule 30 60 90 100 --gamma 0.1 -c checkpoints/imagenet/advresnet-resnet50-smoothing --gpu-id 0,1,2,3,4,5,6,7 --lr_schedule step --mixbn
34 | ```
35 | The commonly used parameters are:
36 | * --attack-iter: the number of PGD attacker iterations. To train a model without AdvProp, set attack-iter to zero.
37 | * --attack-epsilon: the maximum magnitude of the accumulated perturbation over all iterations, in pixel-intensity units on the 0-255 scale.
38 | * --attack-step-size: the step size of the PGD attacker, i.e. the magnitude of the perturbation added in each iteration.
39 | * -a: the model architecture; choose one of the models in [torchvision.models.resnet](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py).
40 | * --train-batch: the total batch size across all GPUs.
41 | * --num_classes: the number of classes in the dataset; for ImageNet this should be 1000.
42 | * --data: the path to your ImageNet dataset.
43 | * --epochs: the number of training epochs.
44 | * --schedule: if lr_schedule is 'step', the learning rate is multiplied by gamma at each epoch listed in schedule.
45 | * --gpu-id: the IDs of the GPUs to use.
46 | * --lr_schedule: the learning rate schedule ('step' or 'cos'). I recommend 'step'.
47 |
48 | More options can be found in the code.
49 | ## Accuracy curve
50 |
51 | The figure below compares ResNet-50 trained with standard training and with AdvProp (using PGD-1). The red curve is the test accuracy with AdvProp and the orange curve is the test accuracy with standard training.
52 |
53 | The test accuracy with AdvProp consistently exceeds that with standard training. Another notable feature is that the accuracy using the main batch normalization consistently exceeds that using the auxiliary batch normalization.
54 |
55 | 
56 |
57 | The settings are the same as in run.sh. The final ResNet-50 top-1 test accuracy is 76.67% with standard training and 77.42% with AdvProp.
58 |
59 | If possible, we will provide more results in the future.
60 |
61 |
62 | ## Contribute
63 | Feel free to create a pull request if you find any bugs or you want to contribute.
64 |
--------------------------------------------------------------------------------
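The `--schedule` and `--gamma` options described in the README amount to a piecewise-constant learning rate. The following is a minimal sketch of that rule, not code from the repository: the helper name `step_lr` is hypothetical, the defaults are taken from the example command (`--schedule 30 60 90 100 --gamma 0.1`) and the default `--lr 0.1`, and warmup is ignored.

```python
def step_lr(epoch, base_lr=0.1, schedule=(30, 60, 90, 100), gamma=0.1):
    """Learning rate used during `epoch` (0-indexed) under a step schedule.

    Hypothetical helper for illustration; it assumes every milestone in
    `schedule` is applied exactly once and that there is no warmup phase.
    """
    # Count the milestones already reached; each one multiplies the rate by gamma.
    drops = sum(1 for milestone in schedule if epoch >= milestone)
    return base_lr * gamma ** drops


if __name__ == "__main__":
    for epoch in (0, 29, 30, 60, 90, 100, 104):
        print(epoch, step_lr(epoch))
```

With these values the rate stays at 0.1 for epochs 0 to 29 and is divided by 10 at epochs 30, 60, 90 and 100. In the training script the decay is applied cumulatively to a stored rate rather than recomputed from the epoch, so a milestone that falls inside the warmup epochs would be skipped there.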
/attacker.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | IMAGE_SCALE = 2.0/255
5 |
6 |
7 | def get_kernel(size, nsig, mode='gaussian', device='cuda:0'):
8 |     if mode == 'gaussian':
9 |         # since we have to normalize all the numbers
10 |         # there is no need to calculate the constant factors involving \pi and \sigma.
11 |         vec = torch.linspace(-nsig, nsig, steps=size).to(device)
12 |         vec = torch.exp(-vec*vec/2)
13 |         res = vec.view(-1, 1) @ vec.view(1, -1)
14 |         res = res / torch.sum(res)
15 |     elif mode == 'linear':
16 |         # originally, res[i][j] = (1-|i|/(k+1)) * (1-|j|/(k+1))
17 |         # since we have to normalize it
18 |         # calculate res[i][j] = (k+1-|i|)*(k+1-|j|), with size = 2k+1 and i, j in [-k, k]
19 |         vec = ((size+1)/2 - torch.abs(torch.arange(-(size-1)//2, (size-1)//2+1, dtype=torch.float32))).to(device)
20 |         res = vec.view(-1, 1) @ vec.view(1, -1)
21 |         res = res / torch.sum(res)
22 |     else:
23 |         raise ValueError("no such mode in get_kernel.")
24 |     return res
25 |
26 |
27 | class NoOpAttacker():
28 |
29 |     def attack(self, image, label, model):
30 |         return image, -torch.ones_like(label)
31 |
32 |
33 | class PGDAttacker():
34 |     def __init__(self, num_iter, epsilon, step_size, kernel_size=15, prob_start_from_clean=0.0, translation=False, device='cuda:0'):
35 |         step_size = max(step_size, epsilon / num_iter)
36 |         self.num_iter = num_iter
37 |         self.epsilon = epsilon * IMAGE_SCALE
38 |         self.step_size = step_size*IMAGE_SCALE
39 |         self.prob_start_from_clean = prob_start_from_clean
40 |         self.device=device
41 |         self.translation = translation
42 |         if translation:
43 |             # this is equivalent to a depthwise convolution
44 |             # details can be found in the docs of Conv2d.
45 |             # "When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, this operation is also termed in literature as depthwise convolution."
46 |             self.conv = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=kernel_size, stride=1, padding=(kernel_size-1)//2, bias=False, groups=3).to(self.device)  # padding keeps the gradient the same size as the image
47 |             self.gkernel = get_kernel(kernel_size, nsig=3, device=self.device).to(self.device)
48 |             self.conv.weight = nn.Parameter(self.gkernel.expand(3, 1, kernel_size, kernel_size).clone(), requires_grad=False)  # one fixed copy of the kernel per channel
49 |
50 |     def _create_random_target(self, label):
51 |         label_offset = torch.randint_like(label, low=0, high=1000)
52 |         return (label + label_offset) % 1000
53 |
54 |     def attack(self, image_clean, label, model, original=False):
55 |         if original:
56 |             target_label = label
57 |         else:
58 |             target_label = self._create_random_target(label)
59 |         lower_bound = torch.clamp(image_clean - self.epsilon, min=-1., max=1.)
60 |         upper_bound = torch.clamp(image_clean + self.epsilon, min=-1., max=1.)
61 |
62 |         ori_images = image_clean.clone().detach()
63 |
64 |         init_start = torch.empty_like(image_clean).uniform_(-self.epsilon, self.epsilon)
65 |
66 |         start_from_noise_index = (torch.rand([])>self.prob_start_from_clean).float()  # uniform sample, so the clean start is chosen with probability prob_start_from_clean
67 |         start_adv = image_clean + start_from_noise_index * init_start
68 |
69 |         adv = start_adv
70 |         for i in range(self.num_iter):
71 |             adv.requires_grad = True
72 |             logits = model(adv)
73 |             losses = F.cross_entropy(logits, target_label)
74 |             g = torch.autograd.grad(losses, adv,
75 |                                     retain_graph=False, create_graph=False)[0]
76 |             if self.translation:
77 |                 g = self.conv(g)
78 |             if original:
79 |                 adv = adv + torch.sign(g)*self.step_size
80 |             else:
81 |                 adv = adv - torch.sign(g) * self.step_size
82 |             adv = torch.where(adv > lower_bound, adv, lower_bound)
83 |             adv = torch.where(adv < upper_bound, adv, upper_bound).detach()
84 |
85 |         return adv, target_label
--------------------------------------------------------------------------------
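`PGDAttacker` takes `epsilon` and `step_size` on the 0-255 pixel scale and multiplies both by `IMAGE_SCALE = 2/255`, because the images are normalized to [-1, 1]. Each iteration then takes a signed-gradient step and projects the result back into the epsilon-ball around the clean image. The sketch below reproduces only that arithmetic on plain Python lists so it runs without PyTorch. The helper names `scale_budget`, `sign` and `pgd_step` are hypothetical, and the gradient is supplied by hand rather than computed from a model.

```python
IMAGE_SCALE = 2.0 / 255  # one 0-255 intensity level expressed on the [-1, 1] scale


def scale_budget(num_iter, epsilon, step_size):
    """Convert a pixel-scale attack budget to the [-1, 1] scale."""
    # As in PGDAttacker.__init__, the step is raised if needed so that
    # num_iter steps can reach the edge of the epsilon-ball.
    step_size = max(step_size, epsilon / num_iter)
    return epsilon * IMAGE_SCALE, step_size * IMAGE_SCALE


def sign(x):
    """Return -1, 0 or 1 depending on the sign of x."""
    return (x > 0) - (x < 0)


def pgd_step(adv, clean, grad, epsilon, step_size, ascend=True):
    """One signed-gradient step followed by projection, element by element.

    ascend=True increases the loss (untargeted attack on the true label);
    ascend=False decreases it (attack towards a chosen target label).
    """
    out = []
    for a, c, g in zip(adv, clean, grad):
        direction = sign(g) if ascend else -sign(g)
        a = a + direction * step_size
        # Stay within epsilon of the clean pixel and within the valid range.
        lower = max(c - epsilon, -1.0)
        upper = min(c + epsilon, 1.0)
        out.append(min(max(a, lower), upper))
    return out


if __name__ == "__main__":
    eps, step = scale_budget(num_iter=1, epsilon=1, step_size=1)
    clean = [0.0, 0.5, 1.0]
    grad = [0.3, -2.0, 4.0]  # hand-written stand-in for d(loss)/d(pixel)
    print(pgd_step(clean, clean, grad, eps, step))
```

In this toy run the last pixel is already at the upper limit of the valid range, so the projection leaves it at 1.0 even though the gradient asks for an increase.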
/imagenet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Training script for ImageNet
3 | Copyright (c) Wei YANG, 2017
4 | '''
5 | from __future__ import print_function
6 |
7 | import numpy as np
8 | from PIL import ImageFile
9 |
10 |
11 |
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 |
14 | import argparse
15 | import math
16 | import os
17 | import shutil
18 | import time
19 | import random
20 | from functools import partial
21 |
22 | from tensorboardX import SummaryWriter
23 | import torch
24 | import torch.nn as nn
25 | import torch.nn.functional as F
26 | import torch.nn.parallel
27 | import torch.backends.cudnn as cudnn
28 | import torch.optim as optim
29 | import torch.utils.data as data
30 | import torchvision.transforms as transforms
31 | import torchvision.datasets as datasets
32 | from torch.optim.lr_scheduler import _LRScheduler
33 |
34 | from attacker import NoOpAttacker, PGDAttacker
35 | import net
36 | from utils import Bar, Logger, AverageMeter, accuracy, mkdir_p, savefig
37 | from utils.fastaug.fastaug import FastAugmentation
38 | from utils.fastaug.augmentations import Lighting
39 |
40 |
41 | def to_status(m, status):
42 | if hasattr(m, 'batch_type'):
43 | m.batch_type = status
44 |
45 |
46 | to_clean_status = partial(to_status, status='clean')
47 | to_adv_status = partial(to_status, status='adv')
48 | to_mix_status = partial(to_status, status='mix')
49 |
50 | # Models
51 | default_model_names = sorted(name for name in net.__dict__ if name.islower() and not name.startswith('__') and callable(net.__dict__[name]) and not name.startswith("to_") and not name.startswith("partial"))
52 |
53 | model_names = default_model_names
54 |
55 | # Parse arguments
56 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
57 |
58 | # Datasets
59 | parser.add_argument('-d', '--data', default='path to dataset', type=str)
60 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
61 | help='number of data loading workers (default: 4)')
62 | # Optimization options
63 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
64 | help='number of total epochs to run')
65 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
66 | help='manual epoch number (useful on restarts)')
67 | parser.add_argument('--train-batch', default=256, type=int, metavar='N',
68 | help='train batchsize (default: 256)')
69 | parser.add_argument('--test-batch', default=200, type=int, metavar='N',
70 | help='test batchsize (default: 200)')
71 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
72 | metavar='LR', help='initial learning rate')
73 | parser.add_argument('--drop', '--dropout', default=0, type=float,
74 | metavar='Dropout', help='Dropout ratio')
75 | parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225],
76 | help='Decrease learning rate at these epochs.')
77 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
78 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
79 | help='momentum')
80 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
81 | metavar='W', help='weight decay (default: 1e-4)')
82 | # Checkpoints
83 | parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
84 | help='path to save checkpoint (default: checkpoint)')
85 |
86 | # commented by HYC
87 | # the learning rate of the setting 'step' cannot be handled automatically,
88 | # so you should change --lr as you wanted,
89 | # but you don't need to change other settings.
90 | # more information can be found in the function adjust_learning_rate
91 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
92 | help='path to latest checkpoint (default: none)')
93 | parser.add_argument('--load', default='', type=str)
94 | # Architecture
95 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
96 | choices=model_names,
97 | help='model architecture: ' +
98 | ' | '.join(model_names) +
99 | ' (default: resnet18)')
100 | # Miscs
101 | parser.add_argument('--manualSeed', type=int, help='manual seed')
102 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
103 | help='evaluate model on validation set')
104 | #Device options
105 | parser.add_argument('--gpu-id', default='0', type=str,
106 | help='id(s) for CUDA_VISIBLE_DEVICES')
107 |
108 | #Add by YW
109 | parser.add_argument('--warm', default=5, type=int, help='warm up epochs')
110 | parser.add_argument('--warm_lr', default=0.1, type=float, help='warm up start lr')
111 | parser.add_argument('--num_classes', default=200, type=int, help='number of classes')
112 | parser.add_argument('--mixbn', action='store_true', help='use mixbn')
113 | parser.add_argument('--lr_schedule', type=str, default='step', choices=['step', 'cos'])
114 | parser.add_argument('--fastaug', action='store_true')
115 | parser.add_argument('--already224', action='store_true')
116 | # added by HYC, training options, you'd better set smoothing to improve the accuracy.
117 | # but nesterov and lighting make the training too slow and don't have much improvement.
118 | parser.add_argument('--nesterov', action='store_true')
119 | parser.add_argument('--lighting', action='store_true')
120 | parser.add_argument('--smoothing', type=float, default=0)
121 | # added by HYC, attacker options
122 | parser.add_argument('--attack-iter', help='Adversarial attack iteration', type=int, default=0)
123 | parser.add_argument('--attack-epsilon', help='Adversarial attack maximal perturbation', type=float, default=1.0)
124 | parser.add_argument('--attack-step-size', help='Adversarial attack step size', type=float, default=1.0)
125 |
126 | args = parser.parse_args()
127 | state = {k: v for k, v in args._get_kwargs()}
128 |
129 | # Use CUDA
130 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
131 | use_cuda = torch.cuda.is_available()
132 |
133 | # Random seed
134 | if args.manualSeed is None:
135 | args.manualSeed = random.randint(1, 10000)
136 | random.seed(args.manualSeed)
137 | torch.manual_seed(args.manualSeed)
138 | if use_cuda:
139 | torch.cuda.manual_seed_all(args.manualSeed)
140 |
141 | best_acc = 0 # best test accuracy
142 |
143 | def main():
144 | global best_acc, state
145 | start_epoch = args.start_epoch # start from epoch 0 or last checkpoint epoch
146 |
147 | if args.attack_iter == 0:
148 | attacker = NoOpAttacker()
149 | else:
150 | attacker = PGDAttacker(args.attack_iter, args.attack_epsilon, args.attack_step_size, prob_start_from_clean=0.2 if not args.evaluate else 0.0)
151 |
152 | if not os.path.isdir(args.checkpoint):
153 | mkdir_p(args.checkpoint)
154 |
155 | # Data loading code
156 | traindir = os.path.join(args.data, 'train')
157 | valdir = os.path.join(args.data, 'val')
158 |
159 | # the choice of mean and std doesn't have much influence on accuracy;
160 | # (pic - 0.5) / 0.5 maps pixels to [-1, 1], which just makes things easier for the attacker.
161 |
162 | # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
163 | # std=[0.229, 0.224, 0.225])
164 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
165 | std=[0.5, 0.5, 0.5])
166 |
167 | transform_train = transforms.Compose([
168 | transforms.RandomResizedCrop(224),
169 | transforms.RandomHorizontalFlip(),
170 | transforms.ToTensor(),
171 | normalize,
172 | ])
173 | if args.fastaug:
174 | transform_train.transforms.insert(0, FastAugmentation())
175 | if args.lighting:
176 | __imagenet_pca = {
177 | 'eigval': np.array([0.2175, 0.0188, 0.0045]),
178 | 'eigvec': np.array([
179 | [-0.5675, 0.7192, 0.4009],
180 | [-0.5808, -0.0045, -0.8140],
181 | [-0.5836, -0.6948, 0.4203],
182 | ])
183 | }
184 | transform_train = transforms.Compose([
185 | transforms.RandomResizedCrop(224),
186 | transforms.RandomHorizontalFlip(),
187 | transforms.ToTensor(),
188 | Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']),
189 | normalize
190 | ])
191 | train_dataset = datasets.ImageFolder(traindir, transform_train)
192 | train_loader = torch.utils.data.DataLoader((train_dataset),
193 | batch_size=args.train_batch, shuffle=True,
194 | num_workers=args.workers, pin_memory=True)
195 |
196 | val_transforms = [
197 | transforms.ToTensor(),
198 | normalize,
199 | ]
200 | if not args.already224:
201 | val_transforms = [transforms.Resize(256), transforms.CenterCrop(224)] + val_transforms
202 | val_loader = torch.utils.data.DataLoader(
203 | datasets.ImageFolder(valdir, transforms.Compose(val_transforms)),
204 | batch_size=args.test_batch, shuffle=False,
205 | num_workers=args.workers, pin_memory=True)
206 |
207 | # create model
208 | print("=> creating model '{}'".format(args.arch))
209 | if args.mixbn:
210 | norm_layer = MixBatchNorm2d
211 | else:
212 | norm_layer = None
213 | model = net.__dict__[args.arch](num_classes=args.num_classes, norm_layer=norm_layer)
214 | model.set_attacker(attacker)
215 | model.set_mixbn(args.mixbn)
216 |
217 | model = torch.nn.DataParallel(model).cuda()
218 |
219 | cudnn.benchmark = True
220 | print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))
221 |
222 | # define loss function (criterion) and optimizer
223 | if args.smoothing == 0:
224 | criterion = nn.CrossEntropyLoss(reduction='none').cuda()
225 | # implement a cross-entropy loss with label smoothing.
226 | # First, apply log_softmax; then build the target distribution, giving the true class probability 1-smoothing
227 | # Finally, use kl_div, where:
228 | # KL(p||q) = -\int p(x)ln q(x) dx - (-\int p(x)ln p(x) dx)
229 | # kl_div differs from cross-entropy only by a constant (\int p(x) ln p(x) dx)
230 | else:
231 | criterion = partial(label_smoothing_cross_entropy, classes=args.num_classes, dim=-1, smoothing=args.smoothing)
232 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)
233 |
234 | # Resume
235 | title = 'ImageNet-' + args.arch
236 | if args.resume:
237 | # Load checkpoint.
238 | print('==> Resuming from checkpoint..')
239 | assert os.path.isfile(args.resume), 'Error: no checkpoint file found!'
240 | args.checkpoint = os.path.dirname(args.resume)
241 | checkpoint = torch.load(args.resume)
242 | best_acc = checkpoint['best_acc']
243 | start_epoch = checkpoint['epoch']
244 | model.load_state_dict(checkpoint['state_dict'])
245 | optimizer.load_state_dict(checkpoint['optimizer'])
246 | for param_group in optimizer.param_groups:
247 | param_group['lr'] = state['lr']
248 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
249 | else:
250 | if args.load:
251 | checkpoint = torch.load(args.load)
252 | if args.mixbn:
253 | to_merge = {}
254 | for key in checkpoint['state_dict']:
255 | if 'bn' in key:
256 | tmp = key.split("bn")
257 | aux_key = tmp[0] + 'bn' + tmp[1][0] + '.aux_bn' + tmp[1][1:]
258 | to_merge[aux_key] = checkpoint['state_dict'][key]
259 | elif 'downsample.1' in key:
260 | tmp = key.split("downsample.1")
261 | aux_key = tmp[0] + 'downsample.1.aux_bn' + tmp[1]
262 | to_merge[aux_key] = checkpoint['state_dict'][key]
263 | checkpoint['state_dict'].update(to_merge)
264 |
265 | model.load_state_dict(checkpoint['state_dict'])
266 | logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
267 | logger.set_names(['Learning Rate', 'Train Loss', 'Valid Loss', 'Train Acc.', 'Valid Acc.'])
268 |
269 | if args.evaluate:
270 | print('\nEvaluation only')
271 | test_loss, test_acc = test(val_loader, model, criterion, start_epoch, use_cuda)
272 | print(' Test Loss: %.8f, Test Acc: %.2f' % (test_loss, test_acc))
273 | return
274 |
275 |
276 | # Train and val
277 | writer = SummaryWriter(log_dir=args.checkpoint)
278 | warmup_scheduler = WarmUpLR(optimizer, len(train_loader) * args.warm, start_lr=args.warm_lr) if args.warm > 0 else None
279 | for epoch in range(start_epoch, args.epochs):
280 | if epoch >= args.warm and args.lr_schedule == 'step':
281 | adjust_learning_rate(optimizer, epoch, args)
282 |
283 | print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, optimizer.param_groups[-1]['lr']))
284 |
285 | train_func = partial(train, train_loader=train_loader, model=model, criterion=criterion,
286 | optimizer=optimizer, epoch=epoch, use_cuda=use_cuda,
287 | warmup_scheduler=warmup_scheduler, mixbn=args.mixbn,
288 | writer=writer, attacker=attacker)
289 | if args.mixbn:
290 | model.apply(to_mix_status)
291 | train_loss, train_acc, loss_main, loss_aux, top1_main, top1_aux = train_func()
292 | else:
293 | train_loss, train_acc = train_func()
294 | writer.add_scalar('Train/loss', train_loss, epoch)
295 | writer.add_scalar('Train/acc', train_acc, epoch)
296 |
297 | if args.mixbn:
298 | writer.add_scalar('Train/loss_main', loss_main, epoch)
299 | writer.add_scalar('Train/loss_aux', loss_aux, epoch)
300 | writer.add_scalar('Train/acc_main', top1_main, epoch)
301 | writer.add_scalar('Train/acc_aux', top1_aux, epoch)
302 | model.apply(to_clean_status)
303 | test_loss, test_acc = test(val_loader, model, criterion, epoch, use_cuda)
304 | writer.add_scalar('Test/loss', test_loss, epoch)
305 | writer.add_scalar('Test/acc', test_acc, epoch)
306 |
307 | # append logger file
308 | logger.append([state['lr'], train_loss, test_loss, train_acc, test_acc])
309 |
310 | # save model
311 | is_best = test_acc > best_acc
312 | best_acc = max(test_acc, best_acc)
313 | save_checkpoint({
314 | 'epoch': epoch + 1,
315 | 'state_dict': model.state_dict(),
316 | 'acc': test_acc,
317 | 'best_acc': best_acc,
318 | 'optimizer' : optimizer.state_dict(),
319 | }, is_best, checkpoint=args.checkpoint)
320 |
321 | print('Best acc:')
322 | print(best_acc)
323 | writer.close()
324 | logger.close()
325 | logger.plot()
326 | savefig(os.path.join(args.checkpoint, 'log.eps'))
327 |
328 |
329 | def train(train_loader, model, criterion, optimizer, epoch, use_cuda, warmup_scheduler, mixbn=False,
330 | writer=None, attacker=NoOpAttacker()):
331 | # switch to train mode
332 | model.train()
333 |
334 | batch_time = AverageMeter()
335 | data_time = AverageMeter()
336 | losses = AverageMeter()
337 | top1 = AverageMeter()
338 | top5 = AverageMeter()
339 | if mixbn:
340 | losses_main = AverageMeter()
341 | losses_aux = AverageMeter()
342 | top1_main = AverageMeter()
343 | top1_aux = AverageMeter()
344 | end = time.time()
345 |
346 | bar = Bar('Processing', max=len(train_loader))
347 | for batch_idx, (inputs, targets) in enumerate(train_loader):
348 | if epoch < args.warm:
349 | warmup_scheduler.step()
350 | elif args.lr_schedule == 'cos':
351 | adjust_learning_rate(optimizer, epoch, args, batch=batch_idx, nBatch=len(train_loader))
352 |
353 | # measure data loading time
354 | data_time.update(time.time() - end)
355 |
356 | if use_cuda:
357 | inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
358 |
359 | # you'd better see the code in net.py to understand what it does when attacker is PGD attacker.
360 | # the advprop part is done inside forward function.
361 | # if the advprop part is set outside the forward function, the way to concatenate the batches costs
362 | # more time. (around 10 minutes per epoch)
363 | outputs, targets = model(inputs, targets)
364 | if args.mixbn:
365 | outputs = outputs.transpose(1, 0).contiguous().view(-1, args.num_classes)
366 | targets = targets.transpose(1, 0).contiguous().view(-1)
367 | loss = criterion(outputs, targets).mean()
368 |
369 | # measure accuracy and record loss
370 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
371 | losses.update(loss.item(), outputs.size(0))
372 | top1.update(prec1.item(), outputs.size(0))
373 | top5.update(prec5.item(), outputs.size(0))
374 |
375 | if mixbn:
376 | with torch.no_grad():
377 | batch_size = outputs.size(0)
378 | loss_main = criterion(outputs[:batch_size // 2], targets[:batch_size // 2]).mean()
379 | loss_aux = criterion(outputs[batch_size // 2:], targets[batch_size // 2:]).mean()
380 | prec1_main = accuracy(outputs.data[:batch_size // 2],
381 | targets.data[:batch_size // 2], topk=(1,))[0]
382 | prec1_aux = accuracy(outputs.data[batch_size // 2:],
383 | targets.data[batch_size // 2:], topk=(1,))[0]
384 | losses_main.update(loss_main.item(), batch_size // 2)
385 | losses_aux.update(loss_aux.item(), batch_size // 2)
386 | top1_main.update(prec1_main.item(), batch_size // 2)
387 | top1_aux.update(prec1_aux.item(), batch_size // 2)
388 |
389 | # compute gradient and do SGD step
390 | optimizer.zero_grad()
391 | loss.backward()
392 | optimizer.step()
393 |
394 | # measure elapsed time
395 | batch_time.update(time.time() - end)
396 | end = time.time()
397 |
398 | # plot progress
399 | if not mixbn:
400 | loss_str = "{:.4f}".format(losses.avg)
401 | top1_str = "{:.4f}".format(top1.avg)
402 | else:
403 | loss_str = "{:.2f}/{:.2f}/{:.2f}".format(losses.avg, losses_main.avg, losses_aux.avg)
404 | top1_str = "{:.2f}/{:.2f}/{:.2f}".format(top1.avg, top1_main.avg, top1_aux.avg)
405 | bar.suffix = '({batch}/{size}) Data: {data:.2f}s | Batch: {bt:.2f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:s} | top1: {top1:s} | top5: {top5: .1f}'.format(
406 | batch=batch_idx + 1,
407 | size=len(train_loader),
408 | data=data_time.val,
409 | bt=batch_time.val,
410 | total=bar.elapsed_td,
411 | eta=bar.eta_td,
412 | loss=loss_str,
413 | top1=top1_str,
414 | top5=top5.avg,
415 | )
416 | bar.next()
417 | bar.finish()
418 | if mixbn:
419 | return losses.avg, top1.avg, losses_main.avg, losses_aux.avg, top1_main.avg, top1_aux.avg
420 | else:
421 | return (losses.avg, top1.avg)
422 |
423 | def test(val_loader, model, criterion, epoch, use_cuda):
424 | global best_acc
425 |
426 | batch_time = AverageMeter()
427 | data_time = AverageMeter()
428 | losses = AverageMeter()
429 | top1 = AverageMeter()
430 | top5 = AverageMeter()
431 |
432 | # switch to evaluate mode
433 | model.eval()
434 |
435 | end = time.time()
436 | bar = Bar('Processing', max=len(val_loader))
437 | for batch_idx, (inputs, targets) in enumerate(val_loader):
438 | # measure data loading time
439 | data_time.update(time.time() - end)
440 |
441 | if use_cuda:
442 | inputs, targets = inputs.cuda(), targets.cuda()
443 | # no Variable wrapper is needed: tensors are used directly and gradients are disabled by torch.no_grad() below
444 |
445 | # compute output
446 | with torch.no_grad():
447 | outputs, targets = model(inputs, targets)
448 | loss = criterion(outputs, targets).mean()
449 |
450 | # measure accuracy and record loss
451 | prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
452 | losses.update(loss.item(), inputs.size(0))
453 | top1.update(prec1.item(), inputs.size(0))
454 | top5.update(prec5.item(), inputs.size(0))
455 |
456 | # measure elapsed time
457 | batch_time.update(time.time() - end)
458 | end = time.time()
459 |
460 | # plot progress
461 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
462 | batch=batch_idx + 1,
463 | size=len(val_loader),
464 | data=data_time.avg,
465 | bt=batch_time.avg,
466 | total=bar.elapsed_td,
467 | eta=bar.eta_td,
468 | loss=losses.avg,
469 | top1=top1.avg,
470 | top5=top5.avg,
471 | )
472 | bar.next()
473 | bar.finish()
474 | return (losses.avg, top1.avg)
475 |
476 |
477 | def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='checkpoint.pth.tar'):
478 | filepath = os.path.join(checkpoint, filename)
479 | torch.save(state, filepath)
480 | if is_best:
481 | shutil.copyfile(filepath, os.path.join(checkpoint, 'model_best.pth.tar'))
482 |
483 |
484 | def adjust_learning_rate(optimizer, epoch, args, batch=None, nBatch=None):
485 | global state
486 | if args.lr_schedule == 'cos':
487 | T_total = args.epochs * nBatch
488 | T_cur = (epoch % args.epochs) * nBatch + batch
489 | state['lr'] = 0.5 * args.lr * (1 + math.cos(math.pi * T_cur / T_total))
490 | elif args.lr_schedule == 'step':
491 | if epoch in args.schedule:
492 | state['lr'] *= args.gamma
493 | else:
494 | raise NotImplementedError
495 | for param_group in optimizer.param_groups:
496 | param_group['lr'] = state['lr']
497 |
498 |
499 | def label_smoothing_cross_entropy(pred, target, classes, dim, reduction='batchmean', smoothing=0.1):
500 | '''
501 | adapted from https://github.com/OpenNMT/OpenNMT-py/blob/e8622eb5c6117269bb3accd8eb6f66282b5e67d9/onmt/utils/loss.py#L186
502 | and https://github.com/pytorch/pytorch/issues/7455
503 | '''
504 | confidence = 1.0-smoothing
505 | pred = pred.log_softmax(dim=dim)
506 | with torch.no_grad():
507 | true_dist = torch.zeros_like(pred)
508 | true_dist.fill_(smoothing / (classes -1))
509 | true_dist.scatter_(1, target.data.unsqueeze(1), confidence)
510 | return F.kl_div(pred, true_dist, reduction=reduction)
511 |
512 |
513 | class WarmUpLR(_LRScheduler):
514 | """warmup_training learning rate scheduler
515 | Args:
516 | optimizer: optimizer (e.g. SGD)
517 | total_iters: total number of iterations in the warmup phase
518 | """
519 |
520 | def __init__(self, optimizer, total_iters, last_epoch=-1, start_lr=0.1):
521 | self.total_iters = total_iters
522 | self.start_lr = start_lr
523 | super().__init__(optimizer, last_epoch)
524 |
525 | def get_lr(self):
526 | """Linearly ramp the learning rate over the first total_iters batches:
527 | at batch m it is start_lr + (base_lr - start_lr) * m / total_iters
528 | """
529 | ret = [self.start_lr + (base_lr - self.start_lr) * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]
530 | return ret
531 |
532 |
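The expression in WarmUpLR.get_lr is a linear interpolation from start_lr to base_lr. The sketch below evaluates the same formula in plain Python; warmup_lr and the sample values (base rate 0.4, four warmup iterations) are assumptions chosen for illustration.

```python
def warmup_lr(base_lr, step, total_iters, start_lr=0.1):
    """Linear warmup using the same formula as WarmUpLR.get_lr.

    warmup_lr is a hypothetical stand-alone helper; step plays the
    role of self.last_epoch.
    """
    return start_lr + (base_lr - start_lr) * step / (total_iters + 1e-8)


# Assumed toy setup: ramp from 0.1 to 0.4 over 4 iterations.
ramp = [warmup_lr(0.4, s, total_iters=4, start_lr=0.1) for s in range(5)]
print(ramp)
```

The small 1e-8 term only guards against division by zero, so the final value lands within rounding distance of base_lr.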
533 | class MixBatchNorm2d(nn.BatchNorm2d):
534 | '''
535 | If the tensors from the dataloader have shape [N, 3, 224, 224],
536 | the inputs to MixBatchNorm2d should have shape [2*N, 3, 224, 224].
537 |
538 | If batch_type is 'mix', the layer uses one batch norm (main bn) for the first half of the batch, input[:N],
539 | and another batch norm (auxiliary bn) for the second half, input[N:].
540 | During training, batch_type should be set to 'mix'.
541 |
542 | During validation, only one of the two batch norms is needed.
543 | If batch_type is 'clean', the features are computed with the main bn; if it is 'adv', they are computed with the auxiliary bn.
544 |
545 | Usually, to_clean_status, to_adv_status, and to_mix_status are applied recursively to set batch_type. Note that batch_type should be 'adv' while attacking.
546 | '''
547 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
548 | track_running_stats=True):
549 | super(MixBatchNorm2d, self).__init__(
550 | num_features, eps, momentum, affine, track_running_stats)
551 | self.aux_bn = nn.BatchNorm2d(num_features, eps=eps, momentum=momentum, affine=affine,
552 | track_running_stats=track_running_stats)
553 | self.batch_type = 'clean'
554 |
555 | def forward(self, input):
556 | if self.batch_type == 'adv':
557 | input = self.aux_bn(input)
558 | elif self.batch_type == 'clean':
559 | input = super(MixBatchNorm2d, self).forward(input)
560 | else:
561 | assert self.batch_type == 'mix'
562 | batch_size = input.shape[0]
563 | # input0 = self.aux_bn(input[: batch_size // 2])
564 | # input1 = super(MixBatchNorm2d, self).forward(input[batch_size // 2:])
565 | input0 = super(MixBatchNorm2d, self).forward(input[:batch_size // 2])
566 | input1 = self.aux_bn(input[batch_size // 2:])
567 | input = torch.cat((input0, input1), 0)
568 | return input
569 |
570 |
571 |
572 |
573 | if __name__ == '__main__':
574 | main()
575 |
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | # from torchvision.models.utils import load_state_dict_from_url
4 | from functools import partial
5 |
6 | from attacker import PGDAttacker, NoOpAttacker
7 |
8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
9 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
10 | 'wide_resnet50_2', 'wide_resnet101_2']
11 |
12 |
13 | model_urls = {
14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
19 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
20 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
21 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
22 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
23 | }
24 |
25 |
26 | def to_status(m, status):
27 | if hasattr(m, 'batch_type'):
28 | m.batch_type = status
29 |
30 |
31 | to_clean_status = partial(to_status, status='clean')
32 | to_adv_status = partial(to_status, status='adv')
33 | to_mix_status = partial(to_status, status='mix')
34 |
35 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
36 | """3x3 convolution with padding"""
37 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
38 | padding=dilation, groups=groups, bias=False, dilation=dilation)
39 |
40 |
41 | def conv1x1(in_planes, out_planes, stride=1):
42 | """1x1 convolution"""
43 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
44 |
45 |
46 | class BasicBlock(nn.Module):
47 | expansion = 1
48 |
49 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
50 | base_width=64, dilation=1, norm_layer=None):
51 | super(BasicBlock, self).__init__()
52 | if norm_layer is None:
53 | norm_layer = nn.BatchNorm2d
54 | if groups != 1 or base_width != 64:
55 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
56 | if dilation > 1:
57 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
58 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
59 | self.conv1 = conv3x3(inplanes, planes, stride)
60 | self.bn1 = norm_layer(planes)
61 | self.relu = nn.ReLU(inplace=True)
62 | self.conv2 = conv3x3(planes, planes)
63 | self.bn2 = norm_layer(planes)
64 | self.downsample = downsample
65 | self.stride = stride
66 |
67 | def forward(self, x):
68 | identity = x
69 |
70 | out = self.conv1(x)
71 | out = self.bn1(out)
72 | out = self.relu(out)
73 |
74 | out = self.conv2(out)
75 | out = self.bn2(out)
76 |
77 | if self.downsample is not None:
78 | identity = self.downsample(x)
79 |
80 | out += identity
81 | out = self.relu(out)
82 |
83 | return out
84 |
85 |
86 | class Bottleneck(nn.Module):
87 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
88 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
89 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
90 | # This variant is also known as ResNet V1.5 and improves accuracy according to
91 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
92 |
93 | expansion = 4
94 |
95 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
96 | base_width=64, dilation=1, norm_layer=None):
97 | super(Bottleneck, self).__init__()
98 | if norm_layer is None:
99 | norm_layer = nn.BatchNorm2d
100 | width = int(planes * (base_width / 64.)) * groups
101 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
102 | self.conv1 = conv1x1(inplanes, width)
103 | self.bn1 = norm_layer(width)
104 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
105 | self.bn2 = norm_layer(width)
106 | self.conv3 = conv1x1(width, planes * self.expansion)
107 | self.bn3 = norm_layer(planes * self.expansion)
108 | self.relu = nn.ReLU(inplace=True)
109 | self.downsample = downsample
110 | self.stride = stride
111 |
112 | def forward(self, x):
113 | identity = x
114 |
115 | out = self.conv1(x)
116 | out = self.bn1(out)
117 | out = self.relu(out)
118 |
119 | out = self.conv2(out)
120 | out = self.bn2(out)
121 | out = self.relu(out)
122 |
123 | out = self.conv3(out)
124 | out = self.bn3(out)
125 |
126 | if self.downsample is not None:
127 | identity = self.downsample(x)
128 |
129 | out += identity
130 | out = self.relu(out)
131 |
132 | return out
133 |
134 |
135 | class ResNet(nn.Module):
136 |
137 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
138 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
139 | norm_layer=None):
140 | super(ResNet, self).__init__()
141 | if norm_layer is None:
142 | norm_layer = nn.BatchNorm2d
143 | self._norm_layer = norm_layer
144 |
145 | self.inplanes = 64
146 | self.dilation = 1
147 | if replace_stride_with_dilation is None:
148 | # each element in the tuple indicates if we should replace
149 | # the 2x2 stride with a dilated convolution instead
150 | replace_stride_with_dilation = [False, False, False]
151 | if len(replace_stride_with_dilation) != 3:
152 | raise ValueError("replace_stride_with_dilation should be None "
153 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
154 | self.groups = groups
155 | self.base_width = width_per_group
156 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
157 | bias=False)
158 | self.bn1 = norm_layer(self.inplanes)
159 | self.relu = nn.ReLU(inplace=True)
160 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
161 | self.layer1 = self._make_layer(block, 64, layers[0])
162 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
163 | dilate=replace_stride_with_dilation[0])
164 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
165 | dilate=replace_stride_with_dilation[1])
166 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
167 | dilate=replace_stride_with_dilation[2])
168 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
169 | self.fc = nn.Linear(512 * block.expansion, num_classes)
170 |
171 | for m in self.modules():
172 | if isinstance(m, nn.Conv2d):
173 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
174 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
175 | nn.init.constant_(m.weight, 1)
176 | nn.init.constant_(m.bias, 0)
177 |
178 | # Zero-initialize the last BN in each residual branch,
179 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
180 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
181 | if zero_init_residual:
182 | for m in self.modules():
183 | if isinstance(m, Bottleneck):
184 | nn.init.constant_(m.bn3.weight, 0)
185 | elif isinstance(m, BasicBlock):
186 | nn.init.constant_(m.bn2.weight, 0)
187 |
188 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
189 | norm_layer = self._norm_layer
190 | downsample = None
191 | previous_dilation = self.dilation
192 | if dilate:
193 | self.dilation *= stride
194 | stride = 1
195 | if stride != 1 or self.inplanes != planes * block.expansion:
196 | downsample = nn.Sequential(
197 | conv1x1(self.inplanes, planes * block.expansion, stride),
198 | norm_layer(planes * block.expansion),
199 | )
200 |
201 | layers = []
202 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
203 | self.base_width, previous_dilation, norm_layer))
204 | self.inplanes = planes * block.expansion
205 | for _ in range(1, blocks):
206 | layers.append(block(self.inplanes, planes, groups=self.groups,
207 | base_width=self.base_width, dilation=self.dilation,
208 | norm_layer=norm_layer))
209 |
210 | return nn.Sequential(*layers)
211 |
212 | def _forward_impl(self, x):
213 | # See note [TorchScript super()]
214 | x = self.conv1(x)
215 | x = self.bn1(x)
216 | x = self.relu(x)
217 | x = self.maxpool(x)
218 |
219 | x = self.layer1(x)
220 | x = self.layer2(x)
221 | x = self.layer3(x)
222 | x = self.layer4(x)
223 |
224 | x = self.avgpool(x)
225 | x = torch.flatten(x, 1)
226 | x = self.fc(x)
227 |
228 | return x
229 |
230 | def forward(self, x):
231 | return self._forward_impl(x)
232 |
233 |
234 |
235 | class AdvResNet(ResNet):
236 | '''
237 | A modified version of the ResNet in torchvision.models.resnet.
238 | This model is usually wrapped in DataParallel,
239 | so set the attacker and mixbn before wrapping it.
240 | '''
241 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
242 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
243 | norm_layer=None, attacker=NoOpAttacker()):
244 | super().__init__(block, layers, num_classes=num_classes, zero_init_residual=zero_init_residual,
245 | groups=groups, width_per_group=width_per_group, replace_stride_with_dilation=replace_stride_with_dilation,
246 | norm_layer=norm_layer)
247 | self.attacker = attacker
248 | self.mixbn = False
249 |
250 | def set_attacker(self, attacker):
251 | self.attacker = attacker
252 |
253 | def set_mixbn(self, mixbn):
254 | self.mixbn = mixbn
255 |
256 | def forward(self, x, labels):
257 | training = self.training
258 | input_len = len(x)
259 | # only during training do we need to attack, and cat the clean and auxiliary pics
260 | if training:
261 | self.eval()
262 | self.apply(to_adv_status)
263 | if isinstance(self.attacker, NoOpAttacker):
264 | images = x
265 | targets = labels
266 | else:
267 | aux_images, _ = self.attacker.attack(x, labels, self._forward_impl)
268 | images = torch.cat([x, aux_images], dim=0)
269 | targets = torch.cat([labels, labels], dim=0)
270 | self.train()
271 | if self.mixbn:
272 | # DataParallel simply concatenates the per-GPU outputs along the first dimension,
273 | # so if we don't change the dimensions, the outputs will look like
274 | # [clean_batches_gpu1, adv_batches_gpu1, clean_batches_gpu2, adv_batches_gpu2...]
275 | # and it will be hard to tell clean batches from adversarial ones.
276 | self.apply(to_mix_status)
277 | return self._forward_impl(images).view(2, input_len, -1).transpose(1, 0), targets.view(2, input_len).transpose(1, 0)
278 | else:
279 | self.apply(to_clean_status)
280 | return self._forward_impl(images), targets
281 | else:
282 | images = x
283 | targets = labels
284 | return self._forward_impl(images), targets
285 |
286 |
287 |
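The view(2, input_len, -1).transpose(1, 0) call in AdvResNet.forward pairs each clean output with its adversarial counterpart, so that concatenating per-GPU results along the first dimension keeps the two kinds separable. The sketch below imitates that reordering with plain lists; pair_clean_and_adv and the string labels are illustrative assumptions, and list concatenation stands in for the DataParallel gather.

```python
def pair_clean_and_adv(rows, input_len):
    """Turn [clean..., adv...] (length 2*input_len) into input_len pairs.

    Mirrors view(2, input_len, -1).transpose(1, 0) on a per-GPU output.
    pair_clean_and_adv is a hypothetical helper for illustration.
    """
    clean, adv = rows[:input_len], rows[input_len:]
    return [[c, a] for c, a in zip(clean, adv)]


# Two simulated GPUs, each holding 2 clean and 2 adversarial outputs.
gpu0 = pair_clean_and_adv(["c0", "c1", "a0", "a1"], 2)
gpu1 = pair_clean_and_adv(["c2", "c3", "a2", "a3"], 2)

# Concatenating along the first dimension, as the gather step does.
gathered = gpu0 + gpu1

# Index 0 of every pair is clean, index 1 is adversarial.
clean_all = [pair[0] for pair in gathered]
adv_all = [pair[1] for pair in gathered]
print(clean_all, adv_all)
```

After the gather, selecting index 0 or 1 along the second dimension recovers all clean or all adversarial outputs regardless of how many devices were used.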
288 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
289 | model = AdvResNet(block, layers, **kwargs)
290 | if pretrained:
291 | raise ValueError('do not set pretrained as True, since we aim at training from scratch')
292 | # state_dict = load_state_dict_from_url(model_urls[arch],
293 | # progress=progress)
294 | # model.load_state_dict(state_dict)
295 | return model
296 |
297 |
298 | def resnet18(pretrained=False, progress=True, **kwargs):
299 | r"""ResNet-18 model from
300 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
301 |
302 | Args:
303 | pretrained (bool): If True, returns a model pre-trained on ImageNet
304 | progress (bool): If True, displays a progress bar of the download to stderr
305 | """
306 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
307 | **kwargs)
308 |
309 |
310 | def resnet34(pretrained=False, progress=True, **kwargs):
311 | r"""ResNet-34 model from
312 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
313 |
314 | Args:
315 | pretrained (bool): If True, returns a model pre-trained on ImageNet
316 | progress (bool): If True, displays a progress bar of the download to stderr
317 | """
318 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
319 | **kwargs)
320 |
321 |
322 | def resnet50(pretrained=False, progress=True, **kwargs):
323 | r"""ResNet-50 model from
324 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
325 |
326 | Args:
327 | pretrained (bool): If True, returns a model pre-trained on ImageNet
328 | progress (bool): If True, displays a progress bar of the download to stderr
329 | """
330 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
331 | **kwargs)
332 |
333 |
334 | def resnet101(pretrained=False, progress=True, **kwargs):
335 | r"""ResNet-101 model from
336 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
337 |
338 | Args:
339 | pretrained (bool): If True, returns a model pre-trained on ImageNet
340 | progress (bool): If True, displays a progress bar of the download to stderr
341 | """
342 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
343 | **kwargs)
344 |
345 |
346 | def resnet152(pretrained=False, progress=True, **kwargs):
347 | r"""ResNet-152 model from
348 | `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
349 |
350 | Args:
351 | pretrained (bool): If True, returns a model pre-trained on ImageNet
352 | progress (bool): If True, displays a progress bar of the download to stderr
353 | """
354 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
355 | **kwargs)
356 |
357 |
358 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
359 | r"""ResNeXt-50 32x4d model from
360 | `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
361 |
362 | Args:
363 | pretrained (bool): If True, returns a model pre-trained on ImageNet
364 | progress (bool): If True, displays a progress bar of the download to stderr
365 | """
366 | kwargs['groups'] = 32
367 | kwargs['width_per_group'] = 4
368 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
369 | pretrained, progress, **kwargs)
370 |
371 |
372 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
373 | r"""ResNeXt-101 32x8d model from
374 | `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
375 |
376 | Args:
377 | pretrained (bool): If True, returns a model pre-trained on ImageNet
378 | progress (bool): If True, displays a progress bar of the download to stderr
379 | """
380 | kwargs['groups'] = 32
381 | kwargs['width_per_group'] = 8
382 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
383 | pretrained, progress, **kwargs)
384 |
385 |
386 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
387 | r"""Wide ResNet-50-2 model from
388 | `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
389 |
390 | The model is the same as ResNet except for the bottleneck number of channels
391 | which is twice larger in every block. The number of channels in outer 1x1
392 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
393 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
394 |
395 | Args:
396 | pretrained (bool): If True, returns a model pre-trained on ImageNet
397 | progress (bool): If True, displays a progress bar of the download to stderr
398 | """
399 | kwargs['width_per_group'] = 64 * 2
400 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
401 | pretrained, progress, **kwargs)
402 |
403 |
404 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
405 | r"""Wide ResNet-101-2 model from
406 | `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
407 |
408 | The model is the same as ResNet except for the bottleneck number of channels
409 | which is twice larger in every block. The number of channels in outer 1x1
410 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
411 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
412 |
413 | Args:
414 | pretrained (bool): If True, returns a model pre-trained on ImageNet
415 | progress (bool): If True, displays a progress bar of the download to stderr
416 | """
417 | kwargs['width_per_group'] = 64 * 2
418 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
419 | pretrained, progress, **kwargs)
420 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python imagenet.py --attack-iter 1 --attack-epsilon 1 --attack-step-size 1 -a resnet50 --train-batch 256 --num_classes 1000 --data /path/of/ImageNet --epochs 105 --schedule 30 60 90 100 --gamma 0.1 -c checkpoints/imagenet/advresnet-resnet50-smoothing --gpu-id 0,1,2,3,4,5,6,7 --lr_schedule step --mixbn
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Useful utils
2 | """
3 | from .misc import *
4 | from .logger import *
5 | from .visualize import *
6 | from .eval import *
7 |
8 | # progress bar
9 | import os, sys
10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
11 | from progress.bar import Bar as Bar
--------------------------------------------------------------------------------
/utils/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | __all__ = ['accuracy']
4 |
5 | def accuracy(output, target, topk=(1,)):
6 | """Computes the precision@k for the specified values of k"""
7 | maxk = max(topk)
8 | batch_size = target.size(0)
9 |
10 | _, pred = output.topk(maxk, 1, True, True)
11 | pred = pred.t()
12 | correct = pred.eq(target.view(1, -1).expand_as(pred))
13 |
14 | res = []
15 | for k in topk:
16 | correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape: the slice of the transposed tensor may be non-contiguous
17 | res.append(correct_k.mul_(100.0 / batch_size))
18 | return res
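The accuracy function counts a sample as correct at k when its label appears among the k highest-scoring classes. A plain-Python sketch of the same metric follows, with lists in place of tensors; topk_accuracy and the toy scores are assumptions made for illustration.

```python
def topk_accuracy(scores, targets, topk=(1,)):
    """Percentage of samples whose label is among the top-k scores.

    topk_accuracy is a hypothetical list-based counterpart of accuracy().
    """
    maxk = max(topk)
    hits = {k: 0 for k in topk}
    for row, label in zip(scores, targets):
        # Class indices sorted by descending score, truncated to maxk.
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:maxk]
        for k in topk:
            if label in ranked[:k]:
                hits[k] += 1
    return [100.0 * hits[k] / len(targets) for k in topk]


# Assumed toy batch: 3 samples, 3 classes.
scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
targets = [1, 1, 0]
top1, top2 = topk_accuracy(scores, targets, topk=(1, 2))
print(top1, top2)
```

Here one of three samples is right at k=1 and two of three at k=2.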
--------------------------------------------------------------------------------
/utils/fastaug/augmentations.py:
--------------------------------------------------------------------------------
1 | # code in this file is adapted from rpmcruz/autoaugment
2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
3 | import random
4 |
5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
6 | import numpy as np
7 | import torch
8 | from torchvision.transforms.transforms import Compose
9 |
10 | random_mirror = True
11 |
12 |
13 | def ShearX(img, v): # [-0.3, 0.3]
14 | assert -0.3 <= v <= 0.3
15 | if random_mirror and random.random() > 0.5:
16 | v = -v
17 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
18 |
19 |
20 | def ShearY(img, v): # [-0.3, 0.3]
21 | assert -0.3 <= v <= 0.3
22 | if random_mirror and random.random() > 0.5:
23 | v = -v
24 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
25 |
26 |
27 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
28 | assert -0.45 <= v <= 0.45
29 | if random_mirror and random.random() > 0.5:
30 | v = -v
31 | v = v * img.size[0]
32 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
33 |
34 |
35 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
36 | assert -0.45 <= v <= 0.45
37 | if random_mirror and random.random() > 0.5:
38 | v = -v
39 | v = v * img.size[1]
40 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
41 |
42 |
43 | def TranslateXAbs(img, v): # [0, 10] => absolute pixel offset
44 | assert 0 <= v <= 10
45 | if random.random() > 0.5:
46 | v = -v
47 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
48 |
49 |
50 | def TranslateYAbs(img, v): # [0, 10] => absolute pixel offset
51 | assert 0 <= v <= 10
52 | if random.random() > 0.5:
53 | v = -v
54 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
55 |
56 |
57 | def Rotate(img, v): # [-30, 30]
58 | assert -30 <= v <= 30
59 | if random_mirror and random.random() > 0.5:
60 | v = -v
61 | return img.rotate(v)
62 |
63 |
64 | def AutoContrast(img, _):
65 | return PIL.ImageOps.autocontrast(img)
66 |
67 |
68 | def Invert(img, _):
69 | return PIL.ImageOps.invert(img)
70 |
71 |
72 | def Equalize(img, _):
73 | return PIL.ImageOps.equalize(img)
74 |
75 |
76 | def Flip(img, _): # not from the paper
77 | return PIL.ImageOps.mirror(img)
78 |
79 |
80 | def Solarize(img, v): # [0, 256]
81 | assert 0 <= v <= 256
82 | return PIL.ImageOps.solarize(img, v)
83 |
84 |
85 | def Posterize(img, v): # [4, 8]
86 | assert 4 <= v <= 8
87 | v = int(v)
88 | return PIL.ImageOps.posterize(img, v)
89 |
90 |
91 | def Posterize2(img, v): # [0, 4]
92 | assert 0 <= v <= 4
93 | v = int(v)
94 | return PIL.ImageOps.posterize(img, v)
95 |
96 |
97 | def Contrast(img, v): # [0.1,1.9]
98 | assert 0.1 <= v <= 1.9
99 | return PIL.ImageEnhance.Contrast(img).enhance(v)
100 |
101 |
102 | def Color(img, v): # [0.1,1.9]
103 | assert 0.1 <= v <= 1.9
104 | return PIL.ImageEnhance.Color(img).enhance(v)
105 |
106 |
107 | def Brightness(img, v): # [0.1,1.9]
108 | assert 0.1 <= v <= 1.9
109 | return PIL.ImageEnhance.Brightness(img).enhance(v)
110 |
111 |
112 | def Sharpness(img, v): # [0.1,1.9]
113 | assert 0.1 <= v <= 1.9
114 | return PIL.ImageEnhance.Sharpness(img).enhance(v)
115 |
116 |
117 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
118 | assert 0.0 <= v <= 0.2
119 | if v <= 0.:
120 | return img
121 |
122 | v = v * img.size[0]
123 | return CutoutAbs(img, v)
124 |
125 |
126 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
127 | # assert 0 <= v <= 20
128 | if v < 0:
129 | return img
130 | w, h = img.size
131 | x0 = np.random.uniform(w)
132 | y0 = np.random.uniform(h)
133 |
134 | x0 = int(max(0, x0 - v / 2.))
135 | y0 = int(max(0, y0 - v / 2.))
136 | x1 = min(w, x0 + v)
137 | y1 = min(h, y0 + v)
138 |
139 | xy = (x0, y0, x1, y1)
140 | color = (125, 123, 114)
141 | # color = (0, 0, 0)
142 | img = img.copy()
143 | PIL.ImageDraw.Draw(img).rectangle(xy, color)
144 | return img
145 |
146 |
147 | def SamplePairing(imgs): # [0, 0.4]
148 | def f(img1, v):
149 | i = np.random.choice(len(imgs))
150 | img2 = PIL.Image.fromarray(imgs[i])
151 | return PIL.Image.blend(img1, img2, v)
152 |
153 | return f
154 |
155 |
156 | def augment_list(for_autoaug=True): # 16 operations and their ranges
157 | l = [
158 | (ShearX, -0.3, 0.3), # 0
159 | (ShearY, -0.3, 0.3), # 1
160 | (TranslateX, -0.45, 0.45), # 2
161 | (TranslateY, -0.45, 0.45), # 3
162 | (Rotate, -30, 30), # 4
163 | (AutoContrast, 0, 1), # 5
164 | (Invert, 0, 1), # 6
165 | (Equalize, 0, 1), # 7
166 | (Solarize, 0, 256), # 8
167 | (Posterize, 4, 8), # 9
168 | (Contrast, 0.1, 1.9), # 10
169 | (Color, 0.1, 1.9), # 11
170 | (Brightness, 0.1, 1.9), # 12
171 | (Sharpness, 0.1, 1.9), # 13
172 | (Cutout, 0, 0.2), # 14
173 | # (SamplePairing(imgs), 0, 0.4), # 15
174 | ]
175 | if for_autoaug:
176 | l += [
177 | (CutoutAbs, 0, 20), # 15, compatible with auto-augment
178 | (Posterize2, 0, 4), # 16
179 | (TranslateXAbs, 0, 10), # 17
180 | (TranslateYAbs, 0, 10), # 18
181 | ]
182 | return l
183 |
184 |
185 | augment_dict = {fn.__name__: (fn, v1, v2) for fn, v1, v2 in augment_list()}
186 |
187 |
188 | def get_augment(name):
189 | return augment_dict[name]
190 |
191 |
192 | def apply_augment(img, name, level):
193 | augment_fn, low, high = get_augment(name)
194 | return augment_fn(img.copy(), level * (high - low) + low)
195 |
196 |
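apply_augment maps a normalized level in [0, 1] onto each operation's own range with level * (high - low) + low. The sketch below isolates that mapping; scale_level is a hypothetical name, and the ranges are copied from augment_list.

```python
def scale_level(level, low, high):
    """Map a level in [0, 1] onto [low, high], as apply_augment does.

    scale_level is an illustrative helper, not a repository function.
    """
    return level * (high - low) + low


rotate_deg = scale_level(0.5, -30, 30)           # midpoint of the Rotate range
shear = scale_level(1.0, -0.3, 0.3)              # upper end of the ShearX range
posterize_bits = int(scale_level(0.25, 4, 8))    # Posterize casts its value to int
print(rotate_deg, shear, posterize_bits)
```

A level of 0.5 therefore means "no rotation" for Rotate, because its range is symmetric around zero, while the same level means 6 bits for Posterize.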
197 | class Lighting(object):
198 | """Lighting noise(AlexNet - style PCA - based noise)"""
199 |
200 | def __init__(self, alphastd, eigval, eigvec):
201 | self.alphastd = alphastd
202 | self.eigval = torch.Tensor(eigval)
203 | self.eigvec = torch.Tensor(eigvec)
204 |
205 | def __call__(self, img):
206 | if self.alphastd == 0:
207 | return img
208 |
209 | alpha = img.new().resize_(3).normal_(0, self.alphastd)
210 | rgb = self.eigvec.type_as(img).clone() \
211 | .mul(alpha.view(1, 3).expand(3, 3)) \
212 | .mul(self.eigval.view(1, 3).expand(3, 3)) \
213 | .sum(1).squeeze()
214 |
215 | return img.add(rgb.view(3, 1, 1).expand_as(img))
216 |
--------------------------------------------------------------------------------
/utils/fastaug/fastaug.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from utils.fastaug.augmentations import apply_augment
4 |
5 |
6 | def fa_resnet50_rimagenet():
7 | p = [[["ShearY", 0.14143816458479197, 0.513124791615952], ["Sharpness", 0.9290316227291179, 0.9788406212603302]],
8 | [["Color", 0.21502874228385338, 0.3698477943880306], ["TranslateY", 0.49865058747734736, 0.4352676987103321]],
9 | [["Brightness", 0.6603452126485386, 0.6990174510500261], ["Cutout", 0.7742953773992511, 0.8362550883640804]],
10 | [["Posterize", 0.5188375788270497, 0.9863648925446865],
11 | ["TranslateY", 0.8365230108655313, 0.6000972236440252]],
12 | [["ShearY", 0.9714994964711299, 0.2563663552809896], ["Equalize", 0.8987567223581153, 0.1181761775609772]],
13 | [["Sharpness", 0.14346409304565366, 0.5342189791746006],
14 | ["Sharpness", 0.1219714162835897, 0.44746801278319975]],
15 | [["TranslateX", 0.08089260772173967, 0.028011721602479833],
16 | ["TranslateX", 0.34767877352421406, 0.45131294688688794]],
17 | [["Brightness", 0.9191164585327378, 0.5143232242627864], ["Color", 0.9235247849934283, 0.30604586249462173]],
18 | [["Contrast", 0.4584173187505879, 0.40314219914942756], ["Rotate", 0.550289356406774, 0.38419022293237126]],
19 | [["Posterize", 0.37046156420799325, 0.052693291117634544], ["Cutout", 0.7597581409366909, 0.7535799791937421]],
20 | [["Color", 0.42583964114658746, 0.6776641859552079], ["ShearY", 0.2864805671096011, 0.07580175477739545]],
21 | [["Brightness", 0.5065952125552232, 0.5508640233704984],
22 | ["Brightness", 0.4760021616081475, 0.3544313318097987]],
23 | [["Posterize", 0.5169630851995185, 0.9466018906715961], ["Posterize", 0.5390336503396841, 0.1171015788193209]],
24 | [["Posterize", 0.41153170909576176, 0.7213063942615204], ["Rotate", 0.6232230424824348, 0.7291984098675746]],
25 | [["Color", 0.06704687234714028, 0.5278429246040438], ["Sharpness", 0.9146652195810183, 0.4581415618941407]],
26 | [["ShearX", 0.22404644446773492, 0.6508620171913467],
27 | ["Brightness", 0.06421961538672451, 0.06859528721039095]],
28 | [["Rotate", 0.29864103693134797, 0.5244313199644495], ["Sharpness", 0.4006161706584276, 0.5203708477368657]],
29 | [["AutoContrast", 0.5748186910788027, 0.8185482599354216],
30 | ["Posterize", 0.9571441684265188, 0.1921474117448481]],
31 | [["ShearY", 0.5214786760436251, 0.8375629059785009], ["Invert", 0.6872393349333636, 0.9307694335024579]],
32 | [["Contrast", 0.47219838080793364, 0.8228524484275648],
33 | ["TranslateY", 0.7435518856840543, 0.5888865560614439]],
34 | [["Posterize", 0.10773482839638836, 0.6597021018893648], ["Contrast", 0.5218466423129691, 0.562985661685268]],
35 | [["Rotate", 0.4401753067886466, 0.055198255925702475], ["Rotate", 0.3702153509335602, 0.5821574425474759]],
36 | [["TranslateY", 0.6714729117832363, 0.7145542887432927],
37 | ["Equalize", 0.0023263758097700205, 0.25837341854887885]],
38 | [["Cutout", 0.3159707561240235, 0.19539664199170742], ["TranslateY", 0.8702824829864558, 0.5832348977243467]],
39 | [["AutoContrast", 0.24800812729140026, 0.08017301277245716],
40 | ["Brightness", 0.5775505849482201, 0.4905904775616114]],
41 | [["Color", 0.4143517886294533, 0.8445937742921498], ["ShearY", 0.28688910858536587, 0.17539366839474402]],
42 | [["Brightness", 0.6341134194059947, 0.43683815933640435],
43 | ["Brightness", 0.3362277685899835, 0.4612826163288225]],
44 | [["Sharpness", 0.4504035748829761, 0.6698294470467474],
45 | ["Posterize", 0.9610055612671645, 0.21070714173174876]],
46 | [["Posterize", 0.19490421920029832, 0.7235798208354267], ["Rotate", 0.8675551331308305, 0.46335565746433094]],
47 | [["Color", 0.35097958351003306, 0.42199181561523186], ["Invert", 0.914112788087429, 0.44775583211984815]],
48 | [["Cutout", 0.223575616055454, 0.6328591417299063], ["TranslateY", 0.09269465212259387, 0.5101073959070608]],
49 | [["Rotate", 0.3315734525975911, 0.9983593458299167], ["Sharpness", 0.12245416662856974, 0.6258689139914664]],
50 | [["ShearY", 0.696116760180471, 0.6317805202283014], ["Color", 0.847501151593963, 0.4440116609830195]],
51 | [["Solarize", 0.24945891607225948, 0.7651150206105561], ["Cutout", 0.7229677092930331, 0.12674657348602494]],
52 | [["TranslateX", 0.43461945065713675, 0.06476571036747841], ["Color", 0.6139316940180952, 0.7376264330632316]],
53 | [["Invert", 0.1933003530637138, 0.4497819016184308], ["Invert", 0.18391634069983653, 0.3199769100951113]],
54 | [["Color", 0.20418296626476137, 0.36785101882029814], ["Posterize", 0.624658293920083, 0.8390081535735991]],
55 | [["Sharpness", 0.5864963540530814, 0.586672446690273], ["Posterize", 0.1980280647652339, 0.222114611452575]],
56 | [["Invert", 0.3543654961628104, 0.5146369635250309], ["Equalize", 0.40751271919434434, 0.4325310837291978]],
57 | [["ShearY", 0.22602859359451877, 0.13137880879778158], ["Posterize", 0.7475029061591305, 0.803900538461099]],
58 | [["Sharpness", 0.12426276165599924, 0.5965912716602046], ["Invert", 0.22603903038966913, 0.4346802001255868]],
59 | [["TranslateY", 0.010307035630661765, 0.16577665156754046],
60 | ["Posterize", 0.4114319141395257, 0.829872913683949]],
61 | [["TranslateY", 0.9353069865746215, 0.5327821671247214], ["Color", 0.16990443486261103, 0.38794866007484197]],
62 | [["Cutout", 0.1028174322829021, 0.3955952903458266], ["ShearY", 0.4311995281335693, 0.48024695395374734]],
63 | [["Posterize", 0.1800334334284686, 0.0548749478418862],
64 | ["Brightness", 0.7545808536793187, 0.7699080551646432]],
65 | [["Color", 0.48695305373084197, 0.6674269768464615], ["ShearY", 0.4306032279086781, 0.06057690550239343]],
66 | [["Brightness", 0.4919399683825053, 0.677338905806407],
67 | ["Brightness", 0.24112708387760828, 0.42761103121157656]],
68 | [["Posterize", 0.4434818644882532, 0.9489450593207714],
69 | ["Posterize", 0.40957675116385955, 0.015664946759584186]],
70 | [["Posterize", 0.41307949855153797, 0.6843276552020272], ["Rotate", 0.8003545094091291, 0.7002300783416026]],
71 | [["Color", 0.7038570031770905, 0.4697612983649519], ["Sharpness", 0.9700016496081002, 0.25185103545948884]],
72 | [["AutoContrast", 0.714641656154856, 0.7962423001719023],
73 | ["Sharpness", 0.2410097684093468, 0.5919171048019731]],
74 | [["TranslateX", 0.8101567644494714, 0.7156447005337443], ["Solarize", 0.5634727831229329, 0.8875158446846]],
75 | [["Sharpness", 0.5335258857303261, 0.364743126378182], ["Color", 0.453280875871377, 0.5621962714743068]],
76 | [["Cutout", 0.7423678127672542, 0.7726370777867049], ["Invert", 0.2806161382641934, 0.6021111986900146]],
77 | [["TranslateY", 0.15190341320343761, 0.3860373175487939], ["Cutout", 0.9980805818665679, 0.05332384819400854]],
78 | [["Posterize", 0.36518675678786605, 0.2935819027397963],
79 | ["TranslateX", 0.26586180351840005, 0.303641300745208]],
80 | [["Brightness", 0.19994509744377761, 0.90813953707639], ["Equalize", 0.8447217761297836, 0.3449396603478335]],
81 | [["Sharpness", 0.9294773669936768, 0.999713346583839], ["Brightness", 0.1359744825665662, 0.1658489221872924]],
82 | [["TranslateX", 0.11456529257659381, 0.9063795878367734],
83 | ["Equalize", 0.017438134319894553, 0.15776887259743755]],
84 | [["ShearX", 0.9833726383270114, 0.5688194948373335], ["Equalize", 0.04975615490994345, 0.8078130016227757]],
85 | [["Brightness", 0.2654654830488695, 0.8989789725280538],
86 | ["TranslateX", 0.3681535065952329, 0.36433345713161036]],
87 | [["Rotate", 0.04956524209892327, 0.5371942433238247], ["ShearY", 0.0005527499145153714, 0.56082571605602]],
88 | [["Rotate", 0.7918337108932019, 0.5906896260060501], ["Posterize", 0.8223967034091191, 0.450216998388943]],
89 | [["Color", 0.43595106766978337, 0.5253013785221605], ["Sharpness", 0.9169421073531799, 0.8439997639348893]],
90 | [["TranslateY", 0.20052300197155504, 0.8202662448307549],
91 | ["Sharpness", 0.2875792108435686, 0.6997181624527842]],
92 | [["Color", 0.10568089980973616, 0.3349467065132249], ["Brightness", 0.13070947282207768, 0.5757725013960775]],
93 | [["AutoContrast", 0.3749999712869779, 0.6665578760607657],
94 | ["Brightness", 0.8101178402610292, 0.23271946112218125]],
95 | [["Color", 0.6473605933679651, 0.7903409763232029], ["ShearX", 0.588080941572581, 0.27223524148254086]],
96 | [["Cutout", 0.46293361616697304, 0.7107761001833921],
97 | ["AutoContrast", 0.3063766931658412, 0.8026114219854579]],
98 | [["Brightness", 0.7884854981520251, 0.5503669863113797],
99 | ["Brightness", 0.5832456158675261, 0.5840349298921661]],
100 | [["Solarize", 0.4157539625058916, 0.9161905834309929], ["Sharpness", 0.30628197221802017, 0.5386291658995193]],
101 | [["Sharpness", 0.03329610069672856, 0.17066672983670506], ["Invert", 0.9900547302690527, 0.6276238841220477]],
102 | [["Solarize", 0.551015648982762, 0.6937104775938737], ["Color", 0.8838491591064375, 0.31596634380795385]],
103 | [["AutoContrast", 0.16224182418148447, 0.6068227969351896],
104 | ["Sharpness", 0.9599468096118623, 0.4885289719905087]],
105 | [["TranslateY", 0.06576432526133724, 0.6899544605400214],
106 | ["Posterize", 0.2177096480169678, 0.9949164789616582]],
107 | [["Solarize", 0.529820544480292, 0.7576047224165541],
108 | ["Sharpness", 0.027047878909321643, 0.45425231553970685]],
109 | [["Sharpness", 0.9102526010473146, 0.8311987141993857], ["Invert", 0.5191838751826638, 0.6906136644742229]],
110 | [["Solarize", 0.4762773516008588, 0.7703654263842423], ["Color", 0.8048437792602289, 0.4741523094238038]],
111 | [["Sharpness", 0.7095055508594206, 0.7047344238075169], ["Sharpness", 0.5059623654132546, 0.6127255499234886]],
112 | [["TranslateY", 0.02150725921966186, 0.3515764519224378],
113 | ["Posterize", 0.12482170119714735, 0.7829851754051393]],
114 | [["Color", 0.7983830079184816, 0.6964694521670339], ["Brightness", 0.3666527856286296, 0.16093151636495978]],
115 | [["AutoContrast", 0.6724982375829505, 0.536777706678488],
116 | ["Sharpness", 0.43091754837597646, 0.7363240924241439]],
117 | [["Brightness", 0.2889770401966227, 0.4556557902380539],
118 | ["Sharpness", 0.8805303296690755, 0.6262218017754902]],
119 | [["Sharpness", 0.5341939854581068, 0.6697109101429343], ["Rotate", 0.6806606655137529, 0.4896914517968317]],
120 | [["Sharpness", 0.5690509737059344, 0.32790632371915096],
121 | ["Posterize", 0.7951894258661069, 0.08377850335209162]],
122 | [["Color", 0.6124132978216081, 0.5756485920709012], ["Brightness", 0.33053544654445344, 0.23321841707002083]],
123 | [["TranslateX", 0.0654795026615917, 0.5227246924310244], ["ShearX", 0.2932320531132063, 0.6732066478183716]],
124 | [["Cutout", 0.6226071187083615, 0.01009274433736012], ["ShearX", 0.7176799968189801, 0.3758780240463811]],
125 | [["Rotate", 0.18172339508029314, 0.18099184896819184], ["ShearY", 0.7862658331645667, 0.295658135767252]],
126 | [["Contrast", 0.4156099177015862, 0.7015784500878446], ["Sharpness", 0.6454135310009, 0.32335858947955287]],
127 | [["Color", 0.6215885089922037, 0.6882673235388836], ["Brightness", 0.3539881732605379, 0.39486736455795496]],
128 | [["Invert", 0.8164816716866418, 0.7238192000817796], ["Sharpness", 0.3876355847343607, 0.9870077619731956]],
129 | [["Brightness", 0.1875628712629315, 0.5068115936257], ["Sharpness", 0.8732419122060423, 0.5028019258530066]],
130 | [["Sharpness", 0.6140734993408259, 0.6458239834366959], ["Rotate", 0.5250107862824867, 0.533419456933602]],
131 | [["Sharpness", 0.5710893143725344, 0.15551651073007305], ["ShearY", 0.6548487860151722, 0.021365083044319146]],
132 | [["Color", 0.7610250354649954, 0.9084452893074055], ["Brightness", 0.6934611792619156, 0.4108071412071374]],
133 | [["ShearY", 0.07512550098923898, 0.32923768385754293], ["ShearY", 0.2559588911696498, 0.7082337365398496]],
134 | [["Cutout", 0.5401319018926146, 0.004750568603408445], ["ShearX", 0.7473354415031975, 0.34472481968368773]],
135 | [["Rotate", 0.02284154583679092, 0.1353450082435801], ["ShearY", 0.8192458031684238, 0.2811653613473772]],
136 | [["Contrast", 0.21142896718139154, 0.7230739568811746],
137 | ["Sharpness", 0.6902690582665707, 0.13488436112901683]],
138 | [["Posterize", 0.21701219600958138, 0.5900695769640687], ["Rotate", 0.7541095031505971, 0.5341162375286219]],
139 | [["Posterize", 0.5772853064792737, 0.45808311743269936],
140 | ["Brightness", 0.14366050177823675, 0.4644871239446629]],
141 | [["Cutout", 0.8951718842805059, 0.4970074074310499], ["Equalize", 0.3863835903119882, 0.9986531042150006]],
142 | [["Equalize", 0.039411354473938925, 0.7475477254908457],
143 | ["Sharpness", 0.8741966378291861, 0.7304822679596362]],
144 | [["Solarize", 0.4908704265218634, 0.5160677350249471], ["Color", 0.24961813832742435, 0.09362352627360726]],
145 | [["Rotate", 7.870457075154214e-05, 0.8086950025500952],
146 | ["Solarize", 0.10200484521793163, 0.12312889222989265]],
147 | [["Contrast", 0.8052564975559727, 0.3403813036543645], ["Solarize", 0.7690158533600184, 0.8234626822018851]],
148 | [["AutoContrast", 0.680362728854513, 0.9415320040873628],
149 | ["TranslateY", 0.5305871824686941, 0.8030609611614028]],
150 | [["Cutout", 0.1748050257378294, 0.06565343731910589], ["TranslateX", 0.1812738872339903, 0.6254461448344308]],
151 | [["Brightness", 0.4230502644722749, 0.3346463682905031], ["ShearX", 0.19107198973659312, 0.6715789128604919]],
152 | [["ShearX", 0.1706528684548394, 0.7816570201200446], ["TranslateX", 0.494545185948171, 0.4710810058360291]],
153 | [["TranslateX", 0.42356251508933324, 0.23865307292867322],
154 | ["TranslateX", 0.24407503619326745, 0.6013778508137331]],
155 | [["AutoContrast", 0.7719512185744232, 0.3107905373009763],
156 | ["ShearY", 0.49448082925617176, 0.5777951230577671]],
157 | [["Cutout", 0.13026983827940525, 0.30120438757485657], ["Brightness", 0.8857896834516185, 0.7731541459513939]],
158 | [["AutoContrast", 0.6422800349197934, 0.38637401090264556],
159 | ["TranslateX", 0.25085431400995084, 0.3170642592664873]],
160 | [["Sharpness", 0.22336654455367122, 0.4137774852324138], ["ShearY", 0.22446851054920894, 0.518341735882535]],
161 | [["Color", 0.2597579403253848, 0.7289643913060193], ["Sharpness", 0.5227416670468619, 0.9239943674030637]],
162 | [["Cutout", 0.6835337711563527, 0.24777620448593812],
163 | ["AutoContrast", 0.37260245353051846, 0.4840361183247263]],
164 | [["Posterize", 0.32756602788628375, 0.21185124493743707],
165 | ["ShearX", 0.25431504951763967, 0.19585996561416225]],
166 | [["AutoContrast", 0.07930627591849979, 0.5719381348340309],
167 | ["AutoContrast", 0.335512380071304, 0.4208050118308541]],
168 | [["Rotate", 0.2924360268257798, 0.5317629242879337], ["Sharpness", 0.4531050021499891, 0.4102650087199528]],
169 | [["Equalize", 0.5908862210984079, 0.468742362277498], ["Brightness", 0.08571766548550425, 0.5629320703375056]],
170 | [["Cutout", 0.52751122383816, 0.7287774744737556], ["Equalize", 0.28721628275296274, 0.8075179887475786]],
171 | [["AutoContrast", 0.24208377391366226, 0.34616549409607644],
172 | ["TranslateX", 0.17454707403766834, 0.5278055700078459]],
173 | [["Brightness", 0.5511881924749478, 0.999638675514418], ["Equalize", 0.14076197797220913, 0.2573030693317552]],
174 | [["ShearX", 0.668731433926434, 0.7564253049646743], ["Color", 0.63235486543845, 0.43954436063340785]],
175 | [["ShearX", 0.40511960873276237, 0.5710419512142979], ["Contrast", 0.9256769948746423, 0.7461350716211649]],
176 | [["Cutout", 0.9995917204023061, 0.22908419326246265], ["TranslateX", 0.5440902956629469, 0.9965570051216295]],
177 | [["Color", 0.22552987172228894, 0.4514558960849747], ["Sharpness", 0.638058150559443, 0.9987829481002615]],
178 | [["Contrast", 0.5362775837534763, 0.7052133185951871], ["ShearY", 0.220369845547023, 0.7593922994775721]],
179 | [["ShearX", 0.0317785822935219, 0.775536785253455], ["TranslateX", 0.7939510227015061, 0.5355620618496535]],
180 | [["Cutout", 0.46027969917602196, 0.31561199122527517], ["Color", 0.06154066467629451, 0.5384660000729091]],
181 | [["Sharpness", 0.7205483743301113, 0.552222392539886], ["Posterize", 0.5146496404711752, 0.9224333144307473]],
182 | [["ShearX", 0.00014547730356910538, 0.3553954298642108], ["TranslateY", 0.9625736029090676, 0.57403418640424]],
183 | [["Posterize", 0.9199917903297341, 0.6690259107633706],
184 | ["Posterize", 0.0932558110217602, 0.22279303372106138]],
185 | [["Invert", 0.25401453476874863, 0.3354329544078385], ["Posterize", 0.1832673201325652, 0.4304718799821412]],
186 | [["TranslateY", 0.02084122674367607, 0.12826181437197323], ["ShearY", 0.655862534043703, 0.3838330909470975]],
187 | [["Contrast", 0.35231797644104523, 0.3379356652070079], ["Cutout", 0.19685599014304822, 0.1254328595280942]],
188 | [["Sharpness", 0.18795594984191433, 0.09488678946484895], ["ShearX", 0.33332876790679306, 0.633523782574133]],
189 | [["Cutout", 0.28267175940290246, 0.7901991550267817], ["Contrast", 0.021200195312951198, 0.4733128702798515]],
190 | [["ShearX", 0.966231043411256, 0.7700673327786812], ["TranslateX", 0.7102390777763321, 0.12161245817120675]],
191 | [["Cutout", 0.5183324259533826, 0.30766086003013055], ["Color", 0.48399078150128927, 0.4967477809069189]],
192 | [["Sharpness", 0.8160855187385873, 0.47937658961644], ["Posterize", 0.46360395447862535, 0.7685454058155061]],
193 | [["ShearX", 0.10173571421694395, 0.3987290690178754], ["TranslateY", 0.8939980277379345, 0.5669994143735713]],
194 | [["Posterize", 0.6768089584801844, 0.7113149244621721],
195 | ["Posterize", 0.054896856043358935, 0.3660837250743921]],
196 | [["AutoContrast", 0.5915576211896306, 0.33607718177676493],
197 | ["Contrast", 0.3809408206617828, 0.5712201773913784]],
198 | [["AutoContrast", 0.012321347472748323, 0.06379072432796573],
199 | ["Rotate", 0.0017964439160045656, 0.7598026295973337]],
200 | [["Contrast", 0.6007100085192627, 0.36171972473370206], ["Invert", 0.09553573684975913, 0.12218510774295901]],
201 | [["AutoContrast", 0.32848604643836266, 0.2619457656206414],
202 | ["Invert", 0.27082113532501784, 0.9967965642293485]],
203 | [["AutoContrast", 0.6156282120903395, 0.9422706516080884],
204 | ["Sharpness", 0.4215509247379262, 0.4063347716503587]],
205 | [["Solarize", 0.25059210436331264, 0.7215305521159305], ["Invert", 0.1654465185253614, 0.9605851884186778]],
206 | [["AutoContrast", 0.4464438610980994, 0.685334175815482], ["Cutout", 0.24358625461158645, 0.4699066834058694]],
207 | [["Rotate", 0.5931657741857909, 0.6813978655574067], ["AutoContrast", 0.9259100547738681, 0.4903201223870492]],
208 | [["Color", 0.8203976071280751, 0.9777824466585101], ["Posterize", 0.4620669369254169, 0.2738895968716055]],
209 | [["Contrast", 0.13754352055786848, 0.3369433962088463],
210 | ["Posterize", 0.48371187792441916, 0.025718004361451302]],
211 | [["Rotate", 0.5208233630704999, 0.1760188899913535], ["TranslateX", 0.49753461392937226, 0.4142935276250922]],
212 | [["Cutout", 0.5967418240931212, 0.8028675552639539], ["Cutout", 0.20021854152659121, 0.19426330549590076]],
213 | [["ShearY", 0.549583567386676, 0.6601326640171705], ["Cutout", 0.6111813470383047, 0.4141935587984994]],
214 | [["Brightness", 0.6354891977535064, 0.31591459747846745],
215 | ["AutoContrast", 0.7853952208711621, 0.6555861906702081]],
216 | [["AutoContrast", 0.7333725370546154, 0.9919410576081586], ["Cutout", 0.9984177877923588, 0.2938253683694291]],
217 | [["Color", 0.33219296307742263, 0.6378995578424113],
218 | ["AutoContrast", 0.15432820754183288, 0.7897899838932103]],
219 | [["Contrast", 0.5905289460222578, 0.8158577207653422], ["Cutout", 0.3980284381203051, 0.43030531250317217]],
220 | [["TranslateX", 0.452093693346745, 0.5251475931559115], ["Rotate", 0.991422504871258, 0.4556503729269001]],
221 | [["Color", 0.04560406292983776, 0.061574671308480766],
222 | ["Brightness", 0.05161079440128734, 0.6718398142425688]],
223 | [["Contrast", 0.02913302416506853, 0.14402056093217708], ["Rotate", 0.7306930378774588, 0.47088249057922094]],
224 | [["Solarize", 0.3283072384190169, 0.82680847744367], ["Invert", 0.21632614168418854, 0.8792241691482687]],
225 | [["Equalize", 0.4860808352478527, 0.9440534949023064], ["Cutout", 0.31395897639184694, 0.41805859306017523]],
226 | [["Rotate", 0.2816043232522335, 0.5451282807926706], ["Color", 0.7388520447173302, 0.7706503658143311]],
227 | [["Color", 0.9342776719536201, 0.9039981381514299], ["Rotate", 0.6646389177840164, 0.5147917008383647]],
228 | [["Cutout", 0.08929430082050335, 0.22416445996932374], ["Posterize", 0.454485751267457, 0.500958345348237]],
229 | [["TranslateX", 0.14674201106374488, 0.7018633472428202],
230 | ["Sharpness", 0.6128796723832848, 0.743535235614809]],
231 | [["TranslateX", 0.5189900164469432, 0.6491132403587601],
232 | ["Contrast", 0.26309555778227806, 0.5976857969656114]],
233 | [["Solarize", 0.23569808291972655, 0.3315781686591778], ["ShearY", 0.07292078937544964, 0.7460326987587573]],
234 | [["ShearY", 0.7090542757477153, 0.5246437008439621], ["Sharpness", 0.9666919148538443, 0.4841687888767071]],
235 | [["Solarize", 0.3486952615189488, 0.7012877201721799], ["Invert", 0.1933387967311534, 0.9535472742828175]],
236 | [["AutoContrast", 0.5393460721514914, 0.6924005011697713],
237 | ["Cutout", 0.16988156769247176, 0.3667207571712882]],
238 | [["Rotate", 0.5815329514554719, 0.5390406879316949], ["AutoContrast", 0.7370538341589625, 0.7708822194197815]],
239 | [["Color", 0.8463701017918459, 0.9893491045831084], ["Invert", 0.06537367901579016, 0.5238468509941635]],
240 | [["Contrast", 0.8099771812443645, 0.39371603893945184],
241 | ["Posterize", 0.38273629875646487, 0.46493786058573966]],
242 | [["Color", 0.11164686537114032, 0.6771450570033168], ["Posterize", 0.27921361289661406, 0.7214300893597819]],
243 | [["Contrast", 0.5958265906571906, 0.5963959447666958], ["Sharpness", 0.2640889223630885, 0.3365870842641453]],
244 | [["Color", 0.255634146724125, 0.5610029792926452], ["ShearY", 0.7476893976084721, 0.36613194760395557]],
245 | [["ShearX", 0.2167581882130063, 0.022978065071245002], ["TranslateX", 0.1686864409720319, 0.4919575435512007]],
246 | [["Solarize", 0.10702753776284957, 0.3954707963684698], ["Contrast", 0.7256100635368403, 0.48845259655719686]],
247 | [["Sharpness", 0.6165615058519549, 0.2624079463213861], ["ShearX", 0.3804820351860919, 0.4738994677544202]],
248 | [["TranslateX", 0.18066394808448177, 0.8174509422318228],
249 | ["Solarize", 0.07964569396290502, 0.45495935736800974]],
250 | [["Sharpness", 0.2741884021129658, 0.9311045302358317], ["Cutout", 0.0009101326429323388, 0.5932102256756948]],
251 | [["Rotate", 0.8501796375826188, 0.5092564038282137], ["Brightness", 0.6520146983999912, 0.724091283316938]],
252 | [["Brightness", 0.10079744898900078, 0.7644088017429471],
253 | ["AutoContrast", 0.33540215138213575, 0.1487538541758792]],
254 | [["ShearY", 0.10632545944757177, 0.9565164562996977], ["Rotate", 0.275833816849538, 0.6200731548023757]],
255 | [["Color", 0.6749819274397422, 0.41042188598168844],
256 | ["AutoContrast", 0.22396590966461932, 0.5048018491863738]],
257 | [["Equalize", 0.5044277111650255, 0.2649182381110667],
258 | ["Brightness", 0.35715133289571355, 0.8653260893016869]],
259 | [["Cutout", 0.49083594426355326, 0.5602781291093129], ["Posterize", 0.721795488514384, 0.5525847430754974]],
260 | [["Sharpness", 0.5081835448947317, 0.7453323423804428],
261 | ["TranslateX", 0.11511932212234266, 0.4337766796030984]],
262 | [["Solarize", 0.3817050641766593, 0.6879004573473403], ["Invert", 0.0015041436267447528, 0.9793134066888262]],
263 | [["AutoContrast", 0.5107410439697935, 0.8276720355454423],
264 | ["Cutout", 0.2786270701864015, 0.43993387208414564]],
265 | [["Rotate", 0.6711202569428987, 0.6342930903972932], ["Posterize", 0.802820231163559, 0.42770002619222053]],
266 | [["Color", 0.9426854321337312, 0.9055431782458764], ["AutoContrast", 0.3556422423506799, 0.2773922428787449]],
267 | [["Contrast", 0.10318991257659992, 0.30841372533347416],
268 | ["Posterize", 0.4202264962677853, 0.05060395018085634]],
269 | [["Invert", 0.549305630337048, 0.886056156681853], ["Cutout", 0.9314157033373055, 0.3485836940307909]],
270 | [["ShearX", 0.5642891775895684, 0.16427372934801418], ["Invert", 0.228741164726475, 0.5066345406806475]],
271 | [["ShearY", 0.5813123201003086, 0.33474363490586106], ["Equalize", 0.11803439432255824, 0.8583936440614798]],
272 | [["Sharpness", 0.1642809706111211, 0.6958675237301609], ["ShearY", 0.5989560762277414, 0.6194018060415276]],
273 | [["Rotate", 0.05092104774529638, 0.9358045394527796], ["Cutout", 0.6443254331615441, 0.28548414658857657]],
274 | [["Brightness", 0.6986036769232594, 0.9618046340942727],
275 | ["Sharpness", 0.5564490243465492, 0.6295231286085622]],
276 | [["Brightness", 0.42725649792574105, 0.17628028916784244],
277 | ["Equalize", 0.4425109360966546, 0.6392872650036018]],
278 | [["ShearY", 0.5758622795525444, 0.8773349286588288], ["ShearX", 0.038525646435423666, 0.8755366512394268]],
279 | [["Sharpness", 0.3704459924265827, 0.9236361456197351], ["Color", 0.6379842432311235, 0.4548767717224531]],
280 | [["Contrast", 0.1619523824549347, 0.4506528800882731],
281 | ["AutoContrast", 0.34513874426188385, 0.3580290330996726]],
282 | [["Contrast", 0.728699731513527, 0.6932238009822878], ["Brightness", 0.8602917375630352, 0.5341445123280423]],
283 | [["Equalize", 0.3574552353044203, 0.16814745124536548], ["Rotate", 0.24191717169379262, 0.3279497108179034]],
284 | [["ShearY", 0.8567478695576244, 0.37746117240238164], ["ShearX", 0.9654125389830487, 0.9283047610798827]],
285 | [["ShearY", 0.4339052480582405, 0.5394548246617406], ["Cutout", 0.5070570647967001, 0.7846286976687882]],
286 | [["AutoContrast", 0.021620100406875065, 0.44425839772845227],
287 | ["AutoContrast", 0.33978157614075183, 0.47716564815092244]],
288 | [["Contrast", 0.9727600659025666, 0.6651758819229426],
289 | ["Brightness", 0.9893133904996626, 0.39176397622636105]],
290 | [["Equalize", 0.283428620586305, 0.18727922861893637], ["Rotate", 0.3556063466797136, 0.3722839913107821]],
291 | [["ShearY", 0.7276172841941864, 0.4834188516302227], ["ShearX", 0.010783217950465884, 0.9756458772142235]],
292 | [["ShearY", 0.2901753295101581, 0.5684700238749064], ["Cutout", 0.655585564610337, 0.9490071307790201]],
293 | [["AutoContrast", 0.008507193981450278, 0.4881150103902877],
294 | ["AutoContrast", 0.6561989723231185, 0.3715071329838596]],
295 | [["Contrast", 0.7702505530948414, 0.6961371266519999], ["Brightness", 0.9953051630261895, 0.3861962467326121]],
296 | [["Equalize", 0.2805270012472756, 0.17715406116880994], ["Rotate", 0.3111256593947474, 0.15824352183820073]],
297 | [["Brightness", 0.9888680802094193, 0.4856236485253163], ["ShearX", 0.022370252047332284, 0.9284975906226682]],
298 | [["ShearY", 0.4065719044318099, 0.7468528006921563],
299 | ["AutoContrast", 0.19494427109708126, 0.8613186475174786]],
300 | [["AutoContrast", 0.023296727279367765, 0.9170949567425306],
301 | ["AutoContrast", 0.11663051100921168, 0.7908646792175343]],
302 | [["AutoContrast", 0.7335191671571732, 0.4958357308292425], ["Color", 0.7964964008349845, 0.4977687544324929]],
303 | [["ShearX", 0.19905221600021472, 0.3033081933150046], ["Equalize", 0.9383410219319321, 0.3224669877230161]],
304 | [["ShearX", 0.8265450331466404, 0.6509091423603757], ["Sharpness", 0.7134181178748723, 0.6472835976443643]],
305 | [["ShearY", 0.46962439525486044, 0.223433110541722], ["Rotate", 0.7749806946212373, 0.5337060376916906]],
306 | [["Posterize", 0.1652499695106796, 0.04860659068586126],
307 | ["Brightness", 0.6644577712782511, 0.4144528269429337]],
308 | [["TranslateY", 0.6220449565731829, 0.4917495676722932],
309 | ["Posterize", 0.6255000355409635, 0.8374266890984867]],
310 | [["AutoContrast", 0.4887160797052227, 0.7106426020530529],
311 | ["Sharpness", 0.7684218571497236, 0.43678474722954763]],
312 | [["Invert", 0.13178101535845366, 0.8301141976359813], ["Color", 0.002820877424219378, 0.49444413062487075]],
313 | [["TranslateX", 0.9920683666478188, 0.5862245842588877],
314 | ["Posterize", 0.5536357075855376, 0.5454300367281468]],
315 | [["Brightness", 0.8150181219663427, 0.1411060258870707], ["Sharpness", 0.8548823004164599, 0.77008691072314]],
316 | [["Brightness", 0.9580478020413399, 0.7198667636628974], ["ShearY", 0.8431585033377366, 0.38750016565010803]],
317 | [["Solarize", 0.2331505347152334, 0.25754361489084787], ["TranslateY", 0.447431373734262, 0.5782399531772253]],
318 | [["TranslateY", 0.8904927998691309, 0.25872872455072315],
319 | ["AutoContrast", 0.7129888139716263, 0.7161603231650524]],
320 | [["ShearY", 0.6336216800247362, 0.5247508616674911], ["Cutout", 0.9167315119726633, 0.2060557387978919]],
321 | [["ShearX", 0.001661782345968199, 0.3682225725445044], ["Solarize", 0.12303352043754572, 0.5014989548584458]],
322 | [["Brightness", 0.9723625105116246, 0.6555444729681099], ["Contrast", 0.5539208721135375, 0.7819973409318487]],
323 | [["Equalize", 0.3262607499912611, 0.0006745572802121513],
324 | ["Contrast", 0.35341551623767103, 0.36814689398886347]],
325 | [["ShearY", 0.7478539900243613, 0.37322078030129185], ["TranslateX", 0.41558847793529247, 0.7394615158544118]],
326 | [["Invert", 0.13735541232529067, 0.5536403864332143], ["Cutout", 0.5109718190377135, 0.0447509485253679]],
327 | [["AutoContrast", 0.09403602327274725, 0.5909250807862687], ["ShearY", 0.53234060616395, 0.5316981359469398]],
328 | [["ShearX", 0.5651922367876323, 0.6794110241313183], ["Posterize", 0.7431624856363638, 0.7896861463783287]],
329 | [["Brightness", 0.30949179379286806, 0.7650569096019195],
330 | ["Sharpness", 0.5461629122105034, 0.6814369444005866]],
331 | [["Sharpness", 0.28459340191768434, 0.7802208350806028], ["Rotate", 0.15097973114238117, 0.5259683294104645]],
332 | [["ShearX", 0.6430803693700531, 0.9333735880102375], ["Contrast", 0.7522209520030653, 0.18831747966185058]],
333 | [["Contrast", 0.4219455937915647, 0.29949769435499646], ["Color", 0.6925322933509542, 0.8095523885795443]],
334 | [["ShearX", 0.23553236193043048, 0.17966207900468323],
335 | ["AutoContrast", 0.9039700567886262, 0.21983629944639108]],
336 | [["ShearX", 0.19256223146671514, 0.31200739880443584], ["Sharpness", 0.31962196883294713, 0.6828107668550425]],
337 | [["Cutout", 0.5947690279080912, 0.21728220253899178], ["Rotate", 0.6757188879871141, 0.489460599679474]],
338 | [["ShearY", 0.18365897125470526, 0.3988571115918058], ["Brightness", 0.7727489489504, 0.4790369956329955]],
339 | [["Contrast", 0.7090301084131432, 0.5178303607560537], ["ShearX", 0.16749258277688506, 0.33061773301592356]],
340 | [["ShearX", 0.3706690885419934, 0.38510677124319415],
341 | ["AutoContrast", 0.8288356276501032, 0.16556487668770264]],
342 | [["TranslateY", 0.16758043046445614, 0.30127092823893986],
343 | ["Brightness", 0.5194636577132354, 0.6225165310621702]],
344 | [["Cutout", 0.6087289363049726, 0.10439287037803044], ["Rotate", 0.7503452083033819, 0.7425316019981433]],
345 | [["ShearY", 0.24347189588329932, 0.5554979486672325], ["Brightness", 0.9468115239174161, 0.6132449358023568]],
346 | [["Brightness", 0.7144508395807994, 0.4610594769966929], ["ShearX", 0.16466683833092968, 0.3382903812375781]],
347 | [["Sharpness", 0.27743648684265465, 0.17200038071656915], ["Color", 0.47404262107546236, 0.7868991675614725]],
348 | [["Sharpness", 0.8603993513633618, 0.324604728411791], ["TranslateX", 0.3331597130403763, 0.9369586812977804]],
349 | [["Color", 0.1535813630595832, 0.4700116846558207], ["Color", 0.5435647971896318, 0.7639291483525243]],
350 | [["Brightness", 0.21486188101947656, 0.039347277341450576],
351 | ["Cutout", 0.7069526940684954, 0.39273934115015696]],
352 | [["ShearY", 0.7267130888840517, 0.6310800726389485], ["AutoContrast", 0.662163190824139, 0.31948540372237766]],
353 | [["ShearX", 0.5123132117185981, 0.1981015909438834],
354 | ["AutoContrast", 0.9009347363863067, 0.26790399126924036]],
355 | [["Brightness", 0.24245061453231648, 0.2673478678291436], ["ShearX", 0.31707976089283946, 0.6800582845544948]],
356 | [["Cutout", 0.9257780138367764, 0.03972673526848819], ["Rotate", 0.6807858944518548, 0.46974332280612097]],
357 | [["ShearY", 0.1543443071262312, 0.6051682587030671], ["Brightness", 0.9758203119828304, 0.4941406868162414]],
358 | [["Contrast", 0.07578049236491124, 0.38953819133407647], ["ShearX", 0.20194918288164293, 0.4141510791947318]],
359 | [["Color", 0.27826402243792286, 0.43517491081531157],
360 | ["AutoContrast", 0.6159269026143263, 0.2021846783488046]],
361 | [["AutoContrast", 0.5039377966534692, 0.19241507605941105],
362 | ["Invert", 0.5563931144385394, 0.7069728937319112]],
363 | [["Sharpness", 0.19031632433810566, 0.26310171056096743], ["Color", 0.4724537593175573, 0.6715201448387876]],
364 | [["ShearY", 0.2280910467786642, 0.33340559088059313], ["ShearY", 0.8858560034869303, 0.2598627441471076]],
365 | [["ShearY", 0.07291814128021593, 0.5819462692986321], ["Cutout", 0.27605696060512147, 0.9693427371868695]],
366 | [["Posterize", 0.4249871586563321, 0.8256952014328607],
367 | ["Posterize", 0.005907466926447169, 0.8081353382152597]],
368 | [["Brightness", 0.9071305290601128, 0.4781196213717954],
369 | ["Posterize", 0.8996214311439275, 0.5540717376630279]],
370 | [["Brightness", 0.06560728936236392, 0.9920627849065685],
371 | ["TranslateX", 0.04530789794044952, 0.5318568944702607]],
372 | [["TranslateX", 0.6800263601084814, 0.4611536772507228], ["Rotate", 0.7245888375283157, 0.0914772551375381]],
373 | [["Sharpness", 0.879556061897963, 0.42272481462067535],
374 | ["TranslateX", 0.4600350422524085, 0.5742175429334919]],
375 | [["AutoContrast", 0.5005776243176145, 0.22597121331684505],
376 | ["Invert", 0.10763286370369299, 0.6841782704962373]],
377 | [["Sharpness", 0.7422908472000116, 0.6850324203882405],
378 | ["TranslateX", 0.3832914614128403, 0.34798646673324896]],
379 | [["ShearY", 0.31939465302679326, 0.8792088167639516], ["Brightness", 0.4093604352811235, 0.21055483197261338]],
380 | [["AutoContrast", 0.7447595860998638, 0.19280222555998586],
381 | ["TranslateY", 0.317754779431227, 0.9983454520593591]],
382 | [["Equalize", 0.27706973689750847, 0.6447455020660622], ["Contrast", 0.5626579126863761, 0.7920049962776781]],
383 | [["Rotate", 0.13064369451773816, 0.1495367590684905], ["Sharpness", 0.24893941981801215, 0.6295943894521504]],
384 | [["ShearX", 0.6856269993063254, 0.5167938584189854], ["Sharpness", 0.24835352574609537, 0.9990550493102627]],
385 | [["AutoContrast", 0.461654115871693, 0.43097388896245004], ["Cutout", 0.366359682416437, 0.08011826474215511]],
386 | [["AutoContrast", 0.993892672935951, 0.2403608711236933], ["ShearX", 0.6620817870694181, 0.1744814077869482]],
387 | [["ShearY", 0.6396747719986443, 0.15031017143644265], ["Brightness", 0.9451954879495629, 0.26490678840264714]],
388 | [["Color", 0.19311480787397262, 0.15712300697448575], ["Posterize", 0.05391448762015258, 0.6943963643155474]],
389 | [["Sharpness", 0.6199669674684085, 0.5412492335319072], ["Invert", 0.14086213450149815, 0.2611850277919339]],
390 | [["Posterize", 0.5533129268803405, 0.5332478159319912], ["ShearX", 0.48956244029096635, 0.09223930853562916]],
391 | [["ShearY", 0.05871590849449765, 0.19549715278943228],
392 | ["TranslateY", 0.7208521362741379, 0.36414003004659434]],
393 | [["ShearY", 0.7316263417917531, 0.0629747985768501], ["Contrast", 0.036359793501448245, 0.48658745414898386]],
394 | [["Rotate", 0.3301497610942963, 0.5686622043085637], ["ShearX", 0.40581487555676843, 0.5866127743850192]],
395 | [["ShearX", 0.6679039628249283, 0.5292270693200821], ["Sharpness", 0.25901391739310703, 0.9778360586541461]],
396 | [["AutoContrast", 0.27373222012596854, 0.14456771405730712],
397 | ["Contrast", 0.3877220783523938, 0.7965158941894336]],
398 | [["Solarize", 0.29440905483979096, 0.06071633809388455],
399 | ["Equalize", 0.5246736285116214, 0.37575084834661976]],
400 | [["TranslateY", 0.2191269464520395, 0.7444942293988484],
401 | ["Posterize", 0.3840878524812771, 0.31812671711741247]],
402 | [["Solarize", 0.25159267140731356, 0.5833264622559661],
403 | ["Brightness", 0.07552262572348738, 0.33210648549288435]],
404 | [["AutoContrast", 0.9770099298399954, 0.46421915310428197],
405 | ["AutoContrast", 0.04707358934642503, 0.24922048012183493]],
406 | [["Cutout", 0.5379685806621965, 0.02038212605928355], ["Brightness", 0.5900728303717965, 0.28807872931416956]],
407 | [["Sharpness", 0.11596624872886108, 0.6086947716949325],
408 | ["AutoContrast", 0.34876470059667525, 0.22707897759730578]],
409 | [["Contrast", 0.276545513135698, 0.8822580384226156], ["Rotate", 0.04874027684061846, 0.6722214281612163]],
410 | [["ShearY", 0.595839851757025, 0.4389866852785822], ["Equalize", 0.5225492356128832, 0.2735290854063459]],
411 | [["Sharpness", 0.9918029636732927, 0.9919926583216121],
412 | ["Sharpness", 0.03672376137997366, 0.5563865980047012]],
413 | [["AutoContrast", 0.34169589759999847, 0.16419911552645738],
414 | ["Invert", 0.32995953043129234, 0.15073174739720568]],
415 | [["Posterize", 0.04600255098477292, 0.2632612790075844],
416 | ["TranslateY", 0.7852153329831825, 0.6990722310191976]],
417 | [["AutoContrast", 0.4414653815356372, 0.2657468780017082],
418 | ["Posterize", 0.30647061536763337, 0.3688222724948656]],
419 | [["Contrast", 0.4239361091421837, 0.6076562806342001], ["Cutout", 0.5780707784165284, 0.05361325256745192]],
420 | [["Sharpness", 0.7657895907855394, 0.9842407321667671], ["Sharpness", 0.5416352696151596, 0.6773681575200902]],
421 | [["AutoContrast", 0.13967381098331305, 0.10787258006315015],
422 | ["Posterize", 0.5019536507897069, 0.9881978222469807]],
423 | [["Brightness", 0.030528346448984903, 0.31562058762552847],
424 | ["TranslateY", 0.0843808140595676, 0.21019213305350526]],
425 | [["AutoContrast", 0.6934579165006736, 0.2530484168209199],
426 | ["Rotate", 0.0005751408130693636, 0.43790043943210005]],
427 | [["TranslateX", 0.611258547664328, 0.25465240215894935],
428 | ["Sharpness", 0.5001446909868196, 0.36102204109889413]],
429 | [["Contrast", 0.8995127327150193, 0.5493190695343996], ["Brightness", 0.242708780669213, 0.5461116653329015]],
430 | [["AutoContrast", 0.3751825351022747, 0.16845985803896962],
431 | ["Cutout", 0.25201103287363663, 0.0005893331783358435]],
432 | [["ShearX", 0.1518985779435941, 0.14768180777304504], ["Color", 0.85133530274324, 0.4006641163378305]],
433 | [["TranslateX", 0.5489668255504668, 0.4694591826554948], ["Rotate", 0.1917354490155893, 0.39993269385802177]],
434 | [["ShearY", 0.6689267479532809, 0.34304285013663577], ["Equalize", 0.24133154048883143, 0.279324043138247]],
435 | [["Contrast", 0.3412544002099494, 0.20217358823930232], ["Color", 0.8606984790510235, 0.14305503544676373]],
436 | [["Cutout", 0.21656155695311988, 0.5240101349572595], ["Brightness", 0.14109877717636352, 0.2016827341210295]],
437 | [["Sharpness", 0.24764371218833872, 0.19655480259925423],
438 | ["Posterize", 0.19460398862039913, 0.4975414350200679]],
439 | [["Brightness", 0.6071850094982323, 0.7270716448607151], ["Solarize", 0.111786402398499, 0.6325641684614275]],
440 | [["Contrast", 0.44772949532200856, 0.44267502710695955],
441 | ["AutoContrast", 0.360117506402693, 0.2623958228760273]],
442 | [["Sharpness", 0.8888131688583053, 0.936897400764746], ["Sharpness", 0.16080674198274894, 0.5681119841445879]],
443 | [["AutoContrast", 0.8004456226590612, 0.1788600469525269],
444 | ["Brightness", 0.24832285390647374, 0.02755350284841604]],
445 | [["ShearY", 0.06910320102646594, 0.26076407321544054], ["Contrast", 0.8633703022354964, 0.38968514704043056]],
446 | [["AutoContrast", 0.42306251382780613, 0.6883260271268138],
447 | ["Rotate", 0.3938724346852023, 0.16740881249086037]],
448 | [["Contrast", 0.2725343884286728, 0.6468194318074759], ["Sharpness", 0.32238942646494745, 0.6721149242783824]],
449 | [["AutoContrast", 0.942093919956842, 0.14675331481712853],
450 | ["Posterize", 0.5406276708262192, 0.683901182218153]],
451 | [["Cutout", 0.5386811894643584, 0.04498833938429728], ["Posterize", 0.17007257321724775, 0.45761177118620633]],
452 | [["Contrast", 0.13599408935104654, 0.53282738083886], ["Solarize", 0.26941667995081114, 0.20958261079465895]],
453 | [["Color", 0.6600788518606634, 0.9522228302165842], ["Invert", 0.0542722262516899, 0.5152431169321683]],
454 | [["Contrast", 0.5328934819727553, 0.2376220512388278], ["Posterize", 0.04890422575781711, 0.3182233123739474]],
455 | [["AutoContrast", 0.9289628064340965, 0.2976678437448435], ["Color", 0.20936893798507963, 0.9649612821434217]],
456 | [["Cutout", 0.9019423698575457, 0.24002036989728096],
457 | ["Brightness", 0.48734445615892974, 0.047660899809176316]],
458 | [["Sharpness", 0.09347824275711591, 0.01358686275590612],
459 | ["Posterize", 0.9248539660538934, 0.4064232632650468]],
460 | [["Brightness", 0.46575675383704634, 0.6280194775484345],
461 | ["Invert", 0.17276207634499413, 0.21263495428839635]],
462 | [["Brightness", 0.7238014711679732, 0.6178946027258592],
463 | ["Equalize", 0.3815496086340364, 0.07301281068847276]],
464 | [["Contrast", 0.754557393588416, 0.895332753570098], ["Color", 0.32709957750707447, 0.8425486003491515]],
465 | [["Rotate", 0.43406698081696576, 0.28628263254953723],
466 | ["TranslateY", 0.43949548709125374, 0.15927082198238685]],
467 | [["Brightness", 0.0015838339831640708, 0.09341692553352654],
468 | ["AutoContrast", 0.9113966907329718, 0.8345900469751112]],
469 | [["ShearY", 0.46698796308585017, 0.6150701348176804], ["Invert", 0.14894062704815722, 0.2778388046184728]],
470 | [["Color", 0.30360499169455957, 0.995713092016834], ["Contrast", 0.2597016288524961, 0.8654420870658932]],
471 | [["Brightness", 0.9661642031891435, 0.7322006407169436],
472 | ["TranslateY", 0.4393502786333408, 0.33934762664274265]],
473 | [["Color", 0.9323638351992302, 0.912776309755293], ["Brightness", 0.1618274755371618, 0.23485741708056307]],
474 | [["Color", 0.2216470771158821, 0.3359240197334976], ["Sharpness", 0.6328691811471494, 0.6298393874452548]],
475 | [["Solarize", 0.4772769142265505, 0.7073470698713035], ["ShearY", 0.2656114148206966, 0.31343097010487253]],
476 | [["Solarize", 0.3839017339304234, 0.5985505779429036],
477 | ["Equalize", 0.002412059429196589, 0.06637506181196245]],
478 | [["Contrast", 0.12751196553017863, 0.46980311434237976],
479 | ["Sharpness", 0.3467487455865491, 0.4054907610444406]],
480 | [["AutoContrast", 0.9321813669127206, 0.31328471589533274],
481 | ["Rotate", 0.05801738717432747, 0.36035756254444273]],
482 | [["TranslateX", 0.52092390458353, 0.5261722561643886], ["Contrast", 0.17836804476171306, 0.39354333443158535]],
483 | [["Posterize", 0.5458100909925713, 0.49447244994482603],
484 | ["Brightness", 0.7372536822363605, 0.5303409097463796]],
485 | [["Solarize", 0.1913974941725724, 0.5582966653986761], ["Equalize", 0.020733669175727026, 0.9377467166472878]],
486 | [["Equalize", 0.16265732137763889, 0.5206282340874929], ["Sharpness", 0.2421533133595281, 0.506389065871883]],
487 | [["AutoContrast", 0.9787324801448523, 0.24815051941486466],
488 | ["Rotate", 0.2423487151245957, 0.6456493129745148]], [["TranslateX", 0.6809867726670327, 0.6949687002397612],
489 | ["Contrast", 0.16125673359747458, 0.7582679978218987]],
490 | [["Posterize", 0.8212000950994955, 0.5225012157831872],
491 | ["Brightness", 0.8824891856626245, 0.4499216779709508]],
492 | [["Solarize", 0.12061313332505218, 0.5319371283368052], ["Equalize", 0.04120865969945108, 0.8179402157299602]],
493 | [["Rotate", 0.11278256686005855, 0.4022686554165438], ["ShearX", 0.2983451019112792, 0.42782525461812604]],
494 | [["ShearY", 0.8847385513289983, 0.5429227024179573], ["Rotate", 0.21316428726607445, 0.6712120087528564]],
495 | [["TranslateX", 0.46448081241068717, 0.4746090648963252],
496 | ["Brightness", 0.19973580961271142, 0.49252862676553605]],
497 | [["Posterize", 0.49664100539481526, 0.4460713166484651],
498 | ["Brightness", 0.6629559985581529, 0.35192346529003693]],
499 | [["Color", 0.22710733249173676, 0.37943185764616194], ["ShearX", 0.015809774971472595, 0.8472080190835669]],
500 | [["Contrast", 0.4187366322381491, 0.21621979869256666],
501 | ["AutoContrast", 0.7631045030367304, 0.44965231251615134]],
502 | [["Sharpness", 0.47240637876720515, 0.8080091811749525], ["Cutout", 0.2853425420104144, 0.6669811510150936]],
503 | [["Posterize", 0.7830320527127324, 0.2727062685529881], ["Solarize", 0.527834000867504, 0.20098218845222998]],
504 | [["Contrast", 0.366380535288225, 0.39766001659663075], ["Cutout", 0.8708808878088891, 0.20669525734273086]],
505 | [["ShearX", 0.6815427281122932, 0.6146858582671569], ["AutoContrast", 0.28330622372053493, 0.931352024154997]],
506 | [["AutoContrast", 0.8668174463154519, 0.39961453880632863],
507 | ["AutoContrast", 0.5718557712359253, 0.6337062930797239]],
508 | [["ShearY", 0.8923152519411871, 0.02480062504737446], ["Cutout", 0.14954159341231515, 0.1422219808492364]],
509 | [["Rotate", 0.3733718175355636, 0.3861928572224287], ["Sharpness", 0.5651126520194574, 0.6091103847442831]],
510 | [["Posterize", 0.8891714191922857, 0.29600154265251016],
511 | ["TranslateY", 0.7865351723963945, 0.5664998548985523]],
512 | [["TranslateX", 0.9298214806998273, 0.729856565052017],
513 | ["AutoContrast", 0.26349082482341846, 0.9638882609038888]],
514 | [["Sharpness", 0.8387378377527128, 0.42146721129032494],
515 | ["AutoContrast", 0.9860522000876452, 0.4200699464169384]],
516 | [["ShearY", 0.019609159303115145, 0.37197835936879514], ["Cutout", 0.22199340461754258, 0.015932573201085848]],
517 | [["Rotate", 0.43871085583928443, 0.3283504258860078], ["Sharpness", 0.6077702068037776, 0.6830305349618742]],
518 | [["Contrast", 0.6160211756538094, 0.32029451083389626], ["Cutout", 0.8037631428427006, 0.4025688837399259]],
519 | [["TranslateY", 0.051637820936985435, 0.6908417834391846],
520 | ["Sharpness", 0.7602756948473368, 0.4927111506643095]],
521 | [["Rotate", 0.4973618638052235, 0.45931479729281227], ["TranslateY", 0.04701789716427618, 0.9408779705948676]],
522 | [["Rotate", 0.5214194592768602, 0.8371249272013652], ["Solarize", 0.17734812472813338, 0.045020798970228315]],
523 | [["ShearX", 0.7457999920079351, 0.19025612553075893], ["Sharpness", 0.5994846101703786, 0.5665094068864229]],
524 | [["Contrast", 0.6172655452900769, 0.7811432139704904], ["Cutout", 0.09915620454670282, 0.3963692287596121]],
525 | [["TranslateX", 0.2650112299235817, 0.7377261946165307],
526 | ["AutoContrast", 0.5019539734059677, 0.26905046992024506]],
527 | [["Contrast", 0.6646299821370135, 0.41667784809592945], ["Cutout", 0.9698457154992128, 0.15429001887703997]],
528 | [["Sharpness", 0.9467079029475773, 0.44906457469098204], ["Cutout", 0.30036908747917396, 0.4766149689663106]],
529 | [["Equalize", 0.6667517691051055, 0.5014839828447363], ["Solarize", 0.4127890336820831, 0.9578274770236529]],
530 | [["Cutout", 0.6447384874120834, 0.2868806107728985], ["Cutout", 0.4800990488106021, 0.4757538246206956]],
531 | [["Solarize", 0.12560195032363236, 0.5557473475801568],
532 | ["Equalize", 0.019957161871490228, 0.5556797187823773]],
533 | [["Contrast", 0.12607637375759484, 0.4300633627435161],
534 | ["Sharpness", 0.3437273670109087, 0.40493203127714417]],
535 | [["AutoContrast", 0.884353334807183, 0.5880138314357569], ["Rotate", 0.9846032404597116, 0.3591877296622974]],
536 | [["TranslateX", 0.6862295865975581, 0.5307482119690076],
537 | ["Contrast", 0.19439251187251982, 0.3999195825722808]],
538 | [["Posterize", 0.4187641835025246, 0.5008988942651585],
539 | ["Brightness", 0.6665805605402482, 0.3853288204214253]],
540 | [["Posterize", 0.4507470690013903, 0.4232437206624681],
541 | ["TranslateX", 0.6054107416317659, 0.38123828040922203]],
542 | [["AutoContrast", 0.29562338573283276, 0.35608605102687474],
543 | ["TranslateX", 0.909954785390274, 0.20098894888066549]],
544 | [["Contrast", 0.6015278411777212, 0.6049140992035096], ["Cutout", 0.47178713636517855, 0.5333747244651914]],
545 | [["TranslateX", 0.490851976691112, 0.3829593925141144], ["Sharpness", 0.2716675173824095, 0.5131696240367152]],
546 | [["Posterize", 0.4190558294646337, 0.39316689077269873], ["Rotate", 0.5018526072725914, 0.295712490156129]],
547 | [["AutoContrast", 0.29624715560691617, 0.10937329832409388],
548 | ["Posterize", 0.8770505275992637, 0.43117765012206943]],
549 | [["Rotate", 0.6649970092751698, 0.47767131373391974], ["ShearX", 0.6257923540490786, 0.6643337040198358]],
550 | [["Sharpness", 0.5553620705849509, 0.8467799429696928], ["Cutout", 0.9006185811918932, 0.3537270716262]],
551 | [["ShearY", 0.0007619678283789788, 0.9494591850536303], ["Invert", 0.24267733654007673, 0.7851608409575828]],
552 | [["Contrast", 0.9730916198112872, 0.404670123321921], ["Sharpness", 0.5923587793251186, 0.7405792404430281]],
553 | [["Cutout", 0.07393909593373034, 0.44569630026328344], ["TranslateX", 0.2460593252211425, 0.4817527814541055]],
554 | [["Brightness", 0.31058654119340867, 0.7043749950260936], ["ShearX", 0.7632161538947713, 0.8043681264908555]],
555 | [["AutoContrast", 0.4352334371415373, 0.6377550087204297],
556 | ["Rotate", 0.2892714673415678, 0.49521052050510556]],
557 | [["Equalize", 0.509071051375276, 0.7352913414974414], ["ShearX", 0.5099959429711828, 0.7071566714593619]],
558 | [["Posterize", 0.9540506532512889, 0.8498853304461906], ["ShearY", 0.28199061357155397, 0.3161715627214629]],
559 | [["Posterize", 0.6740855359097433, 0.684004694936616], ["Posterize", 0.6816720350737863, 0.9654766942980918]],
560 | [["Solarize", 0.7149344531717328, 0.42212789795181643], ["Brightness", 0.686601460864528, 0.4263050070610551]],
561 | [["Cutout", 0.49577164991501, 0.08394890892056037], ["Rotate", 0.5810369852730606, 0.3320732965776973]],
562 | [["TranslateY", 0.1793755480490623, 0.6006520265468684],
563 | ["Brightness", 0.3769016576438939, 0.7190746300828186]],
564 | [["TranslateX", 0.7226363597757153, 0.3847027238123509],
565 | ["Brightness", 0.7641713191794035, 0.36234003077512544]],
566 | [["TranslateY", 0.1211227055347106, 0.6693523474608023],
567 | ["Brightness", 0.13011180247738063, 0.5126647617294864]],
568 | [["Equalize", 0.1501070550869129, 0.0038548909451806557],
569 | ["Posterize", 0.8266535939653881, 0.5502199643499207]], [["Sharpness", 0.550624117428359, 0.2023044586648523],
570 | ["Brightness", 0.06291556314780017,
571 | 0.7832635398703937]],
572 | [["Color", 0.3701578205508141, 0.9051537973590863], ["Contrast", 0.5763972727739397, 0.4905511239739898]],
573 | [["Rotate", 0.7678527224046323, 0.6723066265307555], ["Solarize", 0.31458533097383207, 0.38329324335154524]],
574 | [["Brightness", 0.292050127929522, 0.7047582807953063], ["ShearX", 0.040541891910333805, 0.06639328601282746]],
575 | [["TranslateY", 0.4293891393238555, 0.6608516902234284],
576 | ["Sharpness", 0.7794685477624004, 0.5168044063408147]],
577 | [["Color", 0.3682450402286552, 0.17274523597220048], ["ShearY", 0.3936056470397763, 0.5702597289866161]],
578 | [["Equalize", 0.43436990310624657, 0.9207072627823626], ["Contrast", 0.7608688260846083, 0.4759023148841439]],
579 | [["Brightness", 0.7926088966143935, 0.8270093925674497], ["ShearY", 0.4924174064969461, 0.47424347505831244]],
580 | [["Contrast", 0.043917555279430476, 0.15861903591675125], ["ShearX", 0.30439480405505853, 0.1682659341098064]],
581 | [["TranslateY", 0.5598255583454538, 0.721352536005039], ["Posterize", 0.9700921973303752, 0.6882015184440126]],
582 | [["AutoContrast", 0.3620887415037668, 0.5958176322317132],
583 | ["TranslateX", 0.14213781552733287, 0.6230799786459947]],
584 | [["Color", 0.490366889723972, 0.9863152892045195], ["Color", 0.817792262022319, 0.6755656429452775]],
585 | [["Brightness", 0.7030707021937771, 0.254633187122679], ["Color", 0.13977318232688843, 0.16378180123959793]],
586 | [["AutoContrast", 0.2933247831326118, 0.6283663376211102],
587 | ["Sharpness", 0.85430478154147, 0.9753613184208796]],
588 | [["Rotate", 0.6674299955457268, 0.48571208708018976], ["Contrast", 0.47491370175907016, 0.6401079552479657]],
589 | [["Sharpness", 0.37589579644127863, 0.8475131989077025],
590 | ["TranslateY", 0.9985149867598191, 0.057815729375099975]],
591 | [["Equalize", 0.0017194373841596389, 0.7888361311461602], ["Contrast", 0.6779293670669408, 0.796851411454113]],
592 | [["TranslateY", 0.3296782119072306, 0.39765117357271834],
593 | ["Sharpness", 0.5890554357001884, 0.6318339473765834]],
594 | [["Posterize", 0.25423810893163856, 0.5400430289894207],
595 | ["Sharpness", 0.9273643918988342, 0.6480913470982622]],
596 | [["Cutout", 0.850219975768305, 0.4169812455601289], ["Solarize", 0.5418755745870089, 0.5679666650495466]],
597 | [["Brightness", 0.008881361977310959, 0.9282562314720516],
598 | ["TranslateY", 0.7736066471553994, 0.20041167606029642]],
599 | [["Brightness", 0.05382537581401925, 0.6405265501035952],
600 | ["Contrast", 0.30484329473639593, 0.5449338155734242]],
601 | [["Color", 0.613257119787967, 0.4541503912724138], ["Brightness", 0.9061572524724674, 0.4030159294447347]],
602 | [["Brightness", 0.02739111568942537, 0.006028056532326534],
603 | ["ShearX", 0.17276751958646486, 0.05967365780621859]],
604 | [["TranslateY", 0.4376298213047888, 0.7691816164456199],
605 | ["Sharpness", 0.8162292718857824, 0.6054926462265117]],
606 | [["Color", 0.37963069679121214, 0.5946919433483344], ["Posterize", 0.08485417284005387, 0.5663580913231766]],
607 | [["Equalize", 0.49785780226818316, 0.9999137109183761], ["Sharpness", 0.7685879484682496, 0.6260846154212211]],
608 | [["AutoContrast", 0.4190931409670763, 0.2374852525139795],
609 | ["Posterize", 0.8797422264608563, 0.3184738541692057]],
610 | [["Rotate", 0.7307269024632872, 0.41523609600701106], ["ShearX", 0.6166685870692289, 0.647133807748274]],
611 | [["Sharpness", 0.5633713231039904, 0.8276694754755876], ["Cutout", 0.8329340776895764, 0.42656043027424073]],
612 | [["ShearY", 0.14934828370884312, 0.8622510773680372], ["Invert", 0.25925989086863277, 0.8813283584888576]],
613 | [["Contrast", 0.9457071292265932, 0.43228655518614034], ["Sharpness", 0.8485316947644338, 0.7590298998732413]],
614 | [["AutoContrast", 0.8386103589399184, 0.5859583131318076],
615 | ["Solarize", 0.466758711343543, 0.9956215363818983]],
616 | [["Rotate", 0.9387133710926467, 0.19180564509396503], ["Rotate", 0.5558247609706255, 0.04321698692007105]],
617 | [["ShearX", 0.3608716600695567, 0.15206159451532864], ["TranslateX", 0.47295292905710146, 0.5290760596129888]],
618 | [["TranslateX", 0.8357685981547495, 0.5991305115727084],
619 | ["Posterize", 0.5362929404188211, 0.34398525441943373]],
620 | [["ShearY", 0.6751984031632811, 0.6066293622133011], ["Contrast", 0.4122723990263818, 0.4062467515095566]],
621 | [["Color", 0.7515349936021702, 0.5122124665429213], ["Contrast", 0.03190514292904123, 0.22903520154660545]],
622 | [["Contrast", 0.5448962625054385, 0.38655673938910545],
623 | ["AutoContrast", 0.4867400684894492, 0.3433111101096984]],
624 | [["Rotate", 0.0008372434310827959, 0.28599951781141714],
625 | ["Equalize", 0.37113686925530087, 0.5243929348114981]],
626 | [["Color", 0.720054993488857, 0.2010177651701808], ["TranslateX", 0.23036196506059398, 0.11152764304368781]],
627 | [["Cutout", 0.859134208332423, 0.6727345740185254], ["ShearY", 0.02159833505865088, 0.46390076266538544]],
628 | [["Sharpness", 0.3428232157391428, 0.4067874527486514],
629 | ["Brightness", 0.5409415136577347, 0.3698432231874003]],
630 | [["Solarize", 0.27303978936454776, 0.9832186173589548], ["ShearY", 0.08831127213044043, 0.4681870331149774]],
631 | [["TranslateY", 0.2909309268736869, 0.4059460811623174],
632 | ["Sharpness", 0.6425125139803729, 0.20275737203293587]],
633 | [["Contrast", 0.32167626214661627, 0.28636162794046977], ["Invert", 0.4712405253509603, 0.7934644799163176]],
634 | [["Color", 0.867993060896951, 0.96574321666213], ["Color", 0.02233897320328512, 0.44478933557303063]],
635 | [["AutoContrast", 0.1841254751814967, 0.2779992148017741], ["Color", 0.3586283093530607, 0.3696246850445087]],
636 | [["Posterize", 0.2052935984046965, 0.16796913860308244], ["ShearX", 0.4807226832843722, 0.11296747254563266]],
637 | [["Cutout", 0.2016411266364791, 0.2765295444084803], ["Brightness", 0.3054112810424313, 0.695924264931216]],
638 | [["Rotate", 0.8405872184910479, 0.5434142541450815], ["Cutout", 0.4493615138203356, 0.893453735250007]],
639 | [["Contrast", 0.8433310507685494, 0.4915423577963278], ["ShearX", 0.22567799557913246, 0.20129892537008834]],
640 | [["Contrast", 0.045954277103674224, 0.5043900167190442], ["Cutout", 0.5552992473054611, 0.14436447810888237]],
641 | [["AutoContrast", 0.7719296115130478, 0.4440417544621306],
642 | ["Sharpness", 0.13992809206158283, 0.7988278670709781]],
643 | [["Color", 0.7838574233513952, 0.5971351401625151], ["TranslateY", 0.13562290583925385, 0.2253039635819158]],
644 | [["Cutout", 0.24870301109385806, 0.6937886690381568], ["TranslateY", 0.4033400068952813, 0.06253378991880915]],
645 | [["TranslateX", 0.0036059390486775644, 0.5234723884081843],
646 | ["Solarize", 0.42724862530733526, 0.8697702564187633]],
647 | [["Equalize", 0.5446026737834311, 0.9367992979112202], ["ShearY", 0.5943478903735789, 0.42345889214100046]],
648 | [["ShearX", 0.18611885697957506, 0.7320849092947314], ["ShearX", 0.3796416430900566, 0.03817761920009881]],
649 | [["Posterize", 0.37636778506979124, 0.26807924785236537],
650 | ["Brightness", 0.4317372554383255, 0.5473346211870932]],
651 | [["Brightness", 0.8100436240916665, 0.3817612088285007],
652 | ["Brightness", 0.4193974619003253, 0.9685902764026623]],
653 | [["Contrast", 0.701776402197012, 0.6612786008858009], ["Color", 0.19882787177960912, 0.17275597188875483]],
654 | [["Color", 0.9538303302832989, 0.48362384535228686], ["ShearY", 0.2179980837345602, 0.37027290936457313]],
655 | [["TranslateY", 0.6068028691503798, 0.3919346523454841], ["Cutout", 0.8228303342563138, 0.18372280287814613]],
656 | [["Equalize", 0.016416758802906828, 0.642838949194916], ["Cutout", 0.5761717838655257, 0.7600661153497648]],
657 | [["Color", 0.9417761826818639, 0.9916074035986558], ["Equalize", 0.2524209308597042, 0.6373703468715077]],
658 | [["Brightness", 0.75512589439513, 0.6155072321007569], ["Contrast", 0.32413476940254515, 0.4194739830159837]],
659 | [["Sharpness", 0.3339450765586968, 0.9973297539194967],
660 | ["AutoContrast", 0.6523930242124429, 0.1053482471037186]],
661 | [["ShearX", 0.2961391955838801, 0.9870036064904368], ["ShearY", 0.18705025965909403, 0.4550895821154484]],
662 | [["TranslateY", 0.36956447983807883, 0.36371471767143543],
663 | ["Sharpness", 0.6860051967688487, 0.2850190720087796]],
664 | [["Cutout", 0.13017742151902967, 0.47316674150067195], ["Invert", 0.28923829959551883, 0.9295585654924601]],
665 | [["Contrast", 0.7302368472279086, 0.7178974949876642],
666 | ["TranslateY", 0.12589674152030433, 0.7485392909494947]],
667 | [["Color", 0.6474693117772619, 0.5518269515590674], ["Contrast", 0.24643004970708016, 0.3435581358079418]],
668 | [["Contrast", 0.5650327855750835, 0.4843031798040887], ["Brightness", 0.3526684005761239, 0.3005305004600969]],
669 | [["Rotate", 0.09822284968122225, 0.13172798244520356], ["Equalize", 0.38135066977857157, 0.5135129123554154]],
670 | [["Contrast", 0.5902590645585712, 0.2196062383730596], ["ShearY", 0.14188379126120954, 0.1582612142182743]],
671 | [["Cutout", 0.8529913814417812, 0.89734031211874], ["Color", 0.07293767043078672, 0.32577659205278897]],
672 | [["Equalize", 0.21401668971453247, 0.040015259500028266], ["ShearY", 0.5126400895338797, 0.4726484828276388]],
673 | [["Brightness", 0.8269430025954498, 0.9678362841865166], ["ShearY", 0.17142069814830432, 0.4726727848289514]],
674 | [["Brightness", 0.699707089334018, 0.2795501395789335], ["ShearX", 0.5308818178242845, 0.10581814221896294]],
675 | [["Equalize", 0.32519644258946145, 0.15763390340309183],
676 | ["TranslateX", 0.6149090364414208, 0.7454832565718259]],
677 | [["AutoContrast", 0.5404508567155423, 0.7472387762067986],
678 | ["Equalize", 0.05649876539221024, 0.5628180219887216]]]
679 |     return p
680 |
681 |
682 | class FastAugmentation(object):
683 |     def __init__(self, policies=fa_resnet50_rimagenet()):
684 |         self.policies = policies  # list of sub-policies: [[name, prob, level], ...]
685 |
686 |     def __call__(self, img):
687 |         for _ in range(1):
688 |             policy = random.choice(self.policies)  # pick one sub-policy uniformly
689 |             for name, pr, level in policy:
690 |                 if random.random() > pr:  # skip this op with probability 1 - pr
691 |                     continue
692 |                 img = apply_augment(img, name, level)
693 |         return img
694 |
--------------------------------------------------------------------------------
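The sampling logic in `FastAugmentation.__call__` can be sketched without any image library. This is a minimal illustration rather than the repository's code: `apply_augment_stub` and `toy_policies` are hypothetical stand-ins, and the stub only records which operation would have run instead of editing an image.

```python
import random

# Hypothetical stand-in for apply_augment: records the op instead of editing an image.
def apply_augment_stub(history, name, level):
    return history + [(name, level)]

# Toy policies in the same [name, probability, level] layout as the lists above.
toy_policies = [
    [["Rotate", 1.0, 0.5], ["ShearX", 0.0, 0.3]],
    [["Color", 1.0, 0.9], ["Invert", 0.0, 0.1]],
]

def augment(history, policies, rng=random):
    # Pick one sub-policy uniformly, then apply each op with its own probability.
    policy = rng.choice(policies)
    for name, pr, level in policy:
        if rng.random() > pr:
            continue  # op skipped with probability 1 - pr
        history = apply_augment_stub(history, name, level)
    return history

result = augment([], toy_policies)
print(result)
```

With probabilities of 1.0 and 0.0 the first op of the chosen sub-policy is applied and the second is skipped, which makes the role of the `pr` field easy to see.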
/utils/images/advresult.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tingxueronghua/pytorch-classification-advprop/d0f6b045c8666d747cfefdfff0e0d7bfe869b60d/utils/images/advresult.png
--------------------------------------------------------------------------------
/utils/images/cifar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tingxueronghua/pytorch-classification-advprop/d0f6b045c8666d747cfefdfff0e0d7bfe869b60d/utils/images/cifar.png
--------------------------------------------------------------------------------
/utils/images/imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tingxueronghua/pytorch-classification-advprop/d0f6b045c8666d747cfefdfff0e0d7bfe869b60d/utils/images/imagenet.png
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # A simple torch style logger
2 | # (C) Wei YANG 2017
3 | from __future__ import absolute_import
4 | import matplotlib.pyplot as plt
5 | import os
6 | import sys
7 | import numpy as np
8 |
9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
10 |
11 | def savefig(fname, dpi=None):
12 |     dpi = 150 if dpi is None else dpi
13 |     plt.savefig(fname, dpi=dpi)
14 |
15 | def plot_overlap(logger, names=None):
16 |     names = logger.names if names is None else names
17 |     numbers = logger.numbers
18 |     for _, name in enumerate(names):
19 |         x = np.arange(len(numbers[name]))
20 |         plt.plot(x, np.asarray(numbers[name]))
21 |     return [logger.title + '(' + name + ')' for name in names]
22 |
23 | class Logger(object):
24 |     '''Save training process to log file with simple plot function.'''
25 |     def __init__(self, fpath, title=None, resume=False):
26 |         self.file = None
27 |         self.resume = resume
28 |         self.title = '' if title is None else title
29 |         if fpath is not None:
30 |             if resume:
31 |                 self.file = open(fpath, 'r')
32 |                 name = self.file.readline()
33 |                 self.names = name.rstrip().split('\t')
34 |                 self.numbers = {}
35 |                 for _, name in enumerate(self.names):
36 |                     self.numbers[name] = []
37 |
38 |                 for numbers in self.file:
39 |                     numbers = numbers.rstrip().split('\t')
40 |                     for i in range(0, len(numbers)):
41 |                         self.numbers[self.names[i]].append(float(numbers[i]))  # parse so plots are numeric
42 |                 self.file.close()
43 |                 self.file = open(fpath, 'a')
44 |             else:
45 |                 self.file = open(fpath, 'w')
46 |
47 |     def set_names(self, names):
48 |         if self.resume:
49 |             pass
50 |         # initialize numbers as empty list
51 |         self.numbers = {}
52 |         self.names = names
53 |         for _, name in enumerate(self.names):
54 |             self.file.write(name)
55 |             self.file.write('\t')
56 |             self.numbers[name] = []
57 |         self.file.write('\n')
58 |         self.file.flush()
59 |
60 |
61 |     def append(self, numbers):
62 |         assert len(self.names) == len(numbers), 'Numbers do not match names'
63 |         for index, num in enumerate(numbers):
64 |             self.file.write("{0:.6f}".format(num))
65 |             self.file.write('\t')
66 |             self.numbers[self.names[index]].append(num)
67 |         self.file.write('\n')
68 |         self.file.flush()
69 |
70 |     def plot(self, names=None):
71 |         names = self.names if names is None else names
72 |         numbers = self.numbers
73 |         for _, name in enumerate(names):
74 |             x = np.arange(len(numbers[name]))
75 |             plt.plot(x, np.asarray(numbers[name]))
76 |         plt.legend([self.title + '(' + name + ')' for name in names])
77 |         plt.grid(True)
78 |
79 |     def close(self):
80 |         if self.file is not None:
81 |             self.file.close()
82 |
83 | class LoggerMonitor(object):
84 |     '''Load and visualize multiple logs.'''
85 |     def __init__(self, paths):
86 |         '''paths is a dictionary of {name: filepath} pairs'''
87 |         self.loggers = []
88 |         for title, path in paths.items():
89 |             logger = Logger(path, title=title, resume=True)
90 |             self.loggers.append(logger)
91 |
92 |     def plot(self, names=None):
93 |         plt.figure()
94 |         plt.subplot(121)
95 |         legend_text = []
96 |         for logger in self.loggers:
97 |             legend_text += plot_overlap(logger, names)
98 |         plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
99 |         plt.grid(True)
100 |
101 | if __name__ == '__main__':
102 |     # # Example
103 |     # logger = Logger('test.txt')
104 |     # logger.set_names(['Train loss', 'Valid loss','Test loss'])
105 |
106 |     # length = 100
107 |     # t = np.arange(length)
108 |     # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
109 |     # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
110 |     # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
111 |
112 |     # for i in range(0, length):
113 |     #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
114 |     # logger.plot()
115 |
116 |     # Example: logger monitor
117 |     paths = {
118 |         'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
119 |         'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
120 |         'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
121 |     }
122 |
123 |     field = ['Valid Acc.']
124 |
125 |     monitor = LoggerMonitor(paths)
126 |     monitor.plot(names=field)
127 |     savefig('test.eps')
--------------------------------------------------------------------------------
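The on-disk format that `Logger` writes and reads back on resume can be shown with plain file I/O. This is a sketch of the format only, not of the class: the column names and values are made up, and converting values to `float` on read is an assumption about what plotting needs.

```python
import os
import tempfile

# Hypothetical column names and rows, chosen only for illustration.
names = ['Train Loss', 'Valid Acc.']
rows = [[0.9, 51.25], [0.5, 73.5]]

path = os.path.join(tempfile.mkdtemp(), 'log.txt')

# Write: every field is followed by a tab, values use six decimals.
with open(path, 'w') as f:
    f.write(''.join(name + '\t' for name in names) + '\n')
    for row in rows:
        f.write(''.join('{0:.6f}\t'.format(v) for v in row) + '\n')

# Read back: strip the trailing tab and newline, then split on tabs.
with open(path) as f:
    header = f.readline().rstrip().split('\t')
    numbers = {name: [] for name in header}
    for line in f:
        for name, value in zip(header, line.rstrip().split('\t')):
            numbers[name].append(float(value))

print(numbers)
```

Because each line ends with a tab before the newline, the reader has to `rstrip()` before splitting, otherwise an empty trailing field would appear.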
/utils/misc.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 |     - get_mean_and_std: calculate the mean and std value of dataset.
3 |     - init_params: net parameter initialization.
4 |     - mkdir_p, AverageMeter: directory creation and running averages.
5 | '''
6 | import errno
7 | import os
8 | import sys
9 | import time
10 | import math
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.init as init
15 |
16 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']
17 |
18 |
19 | def get_mean_and_std(dataset):
20 |     '''Compute the mean and std value of dataset.'''
21 |     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
22 |
23 |     mean = torch.zeros(3)
24 |     std = torch.zeros(3)
25 |     print('==> Computing mean and std..')
26 |     for inputs, targets in dataloader:
27 |         for i in range(3):
28 |             mean[i] += inputs[:,i,:,:].mean()
29 |             std[i] += inputs[:,i,:,:].std()
30 |     mean.div_(len(dataset))
31 |     std.div_(len(dataset))
32 |     return mean, std
33 |
34 | def init_params(net):
35 |     '''Init layer parameters.'''
36 |     for m in net.modules():
37 |         if isinstance(m, nn.Conv2d):
38 |             init.kaiming_normal_(m.weight, mode='fan_out')
39 |             if m.bias is not None:
40 |                 init.constant_(m.bias, 0)
41 |         elif isinstance(m, nn.BatchNorm2d):
42 |             init.constant_(m.weight, 1)
43 |             init.constant_(m.bias, 0)
44 |         elif isinstance(m, nn.Linear):
45 |             init.normal_(m.weight, std=1e-3)
46 |             if m.bias is not None:
47 |                 init.constant_(m.bias, 0)
48 |
49 | def mkdir_p(path):
50 |     '''make dir if not exist'''
51 |     try:
52 |         os.makedirs(path)
53 |     except OSError as exc:  # Python >2.5
54 |         if exc.errno == errno.EEXIST and os.path.isdir(path):
55 |             pass
56 |         else:
57 |             raise
58 |
59 | class AverageMeter(object):
60 |     """Computes and stores the average and current value
61 |        Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
62 |     """
63 |     def __init__(self):
64 |         self.reset()
65 |
66 |     def reset(self):
67 |         self.val = 0
68 |         self.avg = 0
69 |         self.sum = 0
70 |         self.count = 0
71 |
72 |     def update(self, val, n=1):
73 |         self.val = val
74 |         self.sum += val * n
75 |         self.count += n
76 |         self.avg = self.sum / self.count
--------------------------------------------------------------------------------
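A short usage sketch for `AverageMeter` follows. The class body is repeated in compact form only so the snippet runs on its own, and the per-batch losses and batch sizes are hypothetical values.

```python
class AverageMeter(object):
    """Same bookkeeping as the class above, repeated so the sketch runs alone."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

losses = AverageMeter()
# Hypothetical per-batch mean losses paired with their batch sizes.
for batch_loss, batch_size in [(2.0, 4), (1.0, 4), (0.5, 2)]:
    losses.update(batch_loss, n=batch_size)

# val holds the most recent batch value, avg the sample-weighted mean so far.
print(losses.val, losses.avg)
```

Passing `n=batch_size` weights each batch mean by its sample count, so a smaller final batch does not skew the running average.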
/utils/mix_dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class MixDataLoader(object):
5 |     r"""Iterate over two data loaders in lockstep and merge their batches.
6 |
7 |     Each step draws one ``(inputs, targets)`` batch from ``dataloader_main``
8 |     and one from ``dataloader_aux``, then concatenates the inputs and the
9 |     targets along the batch dimension, with the main batch first. The merged
10 |     batch therefore holds ``batch_main + batch_aux`` samples.
11 |
12 |     Both loaders must yield the same number of batches per epoch; the
13 |     constructor asserts this, and ``len()`` returns that shared count.
14 |
15 |     .. note::
16 |         Iteration ends when the first underlying loader is exhausted, which
17 |         surfaces as the ``StopIteration`` raised by ``next()`` in ``__next__``.
18 |     """
19 |
20 |     def __init__(self, dataloader_main, dataloader_aux):
21 |         self.dataloader_main = dataloader_main
22 |         self.dataloader_aux = dataloader_aux
23 |         assert len(dataloader_main) == len(dataloader_aux)
24 |         self.len = len(dataloader_main)
25 |
26 |     def __len__(self):
27 |         return self.len
28 |
29 |     def __iter__(self):
30 |         self.dataloader_main_iter = self.dataloader_main.__iter__()
31 |         self.dataloader_aux_iter = self.dataloader_aux.__iter__()
32 |         return self
33 |
34 |     def __next__(self):
35 |         inputs_main, target_main = next(self.dataloader_main_iter)
36 |         inputs_aux, target_aux = next(self.dataloader_aux_iter)
37 |         return (torch.cat([inputs_main, inputs_aux]), torch.cat([target_main, target_aux]))
38 |
--------------------------------------------------------------------------------
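The pairing logic of `MixDataLoader` can be illustrated without PyTorch. The sketch below is an analogue, not the repository's code: plain Python lists stand in for tensors, list concatenation stands in for `torch.cat`, and `mix_batches` is a hypothetical name. It is written as a generator, which keeps the iteration state local to each loop instead of storing it on the object.

```python
def mix_batches(loader_main, loader_aux):
    """Yield (inputs, targets) with main samples first, then auxiliary ones."""
    if len(loader_main) != len(loader_aux):
        raise ValueError("both loaders must yield the same number of batches")
    for (x_main, y_main), (x_aux, y_aux) in zip(loader_main, loader_aux):
        # List concatenation plays the role of torch.cat along dimension 0.
        yield x_main + x_aux, y_main + y_aux


# Two hypothetical "loaders", each with two batches of (inputs, targets).
main = [([1, 2], [0, 1]), ([3, 4], [1, 0])]
aux = [([10, 20], [0, 0]), ([30, 40], [1, 1])]

mixed = list(mix_batches(main, aux))
for inputs, targets in mixed:
    print(inputs, targets)
```

Because the main samples always occupy the first half of a mixed batch, downstream code can split the batch by position if it needs to treat the two sources differently.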
/utils/progress/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.egg-info
3 | build/
4 | dist/
5 |
--------------------------------------------------------------------------------
/utils/progress/LICENSE:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2012 Giorgos Verigakis
2 | #
3 | # Permission to use, copy, modify, and distribute this software for any
4 | # purpose with or without fee is hereby granted, provided that the above
5 | # copyright notice and this permission notice appear in all copies.
6 | #
7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14 |
--------------------------------------------------------------------------------
/utils/progress/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.rst LICENSE
2 |
--------------------------------------------------------------------------------
/utils/progress/README.rst:
--------------------------------------------------------------------------------
1 | Easy progress reporting for Python
2 | ==================================
3 |
4 | |pypi|
5 |
6 | |demo|
7 |
8 | .. |pypi| image:: https://img.shields.io/pypi/v/progress.svg
9 | .. |demo| image:: https://raw.github.com/verigak/progress/master/demo.gif
10 |     :alt: Demo
11 |
12 | Bars
13 | ----
14 |
15 | There are 7 progress bars to choose from:
16 |
17 | - ``Bar``
18 | - ``ChargingBar``
19 | - ``FillingSquaresBar``
20 | - ``FillingCirclesBar``
21 | - ``IncrementalBar``
22 | - ``PixelBar``
23 | - ``ShadyBar``
24 |
25 | To use them, just call ``next`` to advance and ``finish`` to finish:
26 |
27 | .. code-block:: python
28 |
29 |     from progress.bar import Bar
30 |
31 |     bar = Bar('Processing', max=20)
32 |     for i in range(20):
33 |         # Do some work
34 |         bar.next()
35 |     bar.finish()
36 |
37 | The result will be a bar like the following: ::
38 |
39 |     Processing |#############                   | 42/100
40 |
41 | To simplify the common case where the work is done in an iterator, you can
42 | use the ``iter`` method:
43 |
44 | .. code-block:: python
45 |
46 |     for i in Bar('Processing').iter(it):
47 |         # Do some work
48 |
49 | Progress bars are very customizable: you can change their width, their fill
50 | character, their suffix and more:
51 |
52 | .. code-block:: python
53 |
54 |     bar = Bar('Loading', fill='@', suffix='%(percent)d%%')
55 |
56 | This will produce a bar like the following: ::
57 |
58 |     Loading |@@@@@@@@@@@@@                   | 42%
59 |
60 | You can use a number of template arguments in ``message`` and ``suffix``:
61 |
62 | ==========  ================================
63 | Name        Value
64 | ==========  ================================
65 | index       current value
66 | max         maximum value
67 | remaining   max - index
68 | progress    index / max
69 | percent     progress * 100
70 | avg         simple moving average time per item (in seconds)
71 | elapsed     elapsed time in seconds
72 | elapsed_td  elapsed as a timedelta (useful for printing as a string)
73 | eta         avg * remaining
74 | eta_td      eta as a timedelta (useful for printing as a string)
75 | ==========  ================================
76 |
77 | Instead of passing all configuration options on instantiation, you can create
78 | your custom subclass:
79 |
80 | .. code-block:: python
81 |
82 |     class FancyBar(Bar):
83 |         message = 'Loading'
84 |         fill = '*'
85 |         suffix = '%(percent).1f%% - %(eta)ds'
86 |
87 | You can also override any of the arguments or create your own:
88 |
89 | .. code-block:: python
90 |
91 |     class SlowBar(Bar):
92 |         suffix = '%(remaining_hours)d hours remaining'
93 |         @property
94 |         def remaining_hours(self):
95 |             return self.eta // 3600
96 |
97 |
98 | Spinners
99 | ========
100 |
101 | For actions with an unknown number of steps you can use a spinner:
102 |
103 | .. code-block:: python
104 |
105 |     from progress.spinner import Spinner
106 |
107 |     spinner = Spinner('Loading ')
108 |     while state != 'FINISHED':
109 |         # Do some work
110 |         spinner.next()
111 |
112 | There are 5 predefined spinners:
113 |
114 | - ``Spinner``
115 | - ``PieSpinner``
116 | - ``MoonSpinner``
117 | - ``LineSpinner``
118 | - ``PixelSpinner``
119 |
120 |
121 | Other
122 | =====
123 |
124 | There are a number of other classes available too; please check the source or
125 | subclass one of them to create your own.
126 |
127 |
128 | License
129 | =======
130 |
131 | progress is licensed under ISC
132 |
--------------------------------------------------------------------------------
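The template arguments listed in the README work because `message` and `suffix` are formatted with the `%` operator against the progress object itself, which exposes its attributes through `__getitem__`. The sketch below reproduces that mechanism in isolation. `TemplateDemo` is a hypothetical stand-in, not a class from the library, and the snippet assumes Python 3 division.

```python
class TemplateDemo(object):
    """Hypothetical stand-in that exposes attributes as %-format mapping keys."""

    def __init__(self, index, maximum):
        self.index = index
        self.max = maximum
        self._hidden = 'secret'  # private attributes are not exposed

    def __getitem__(self, key):
        # '%(name)s' % obj looks up obj['name']; hide names starting with '_'.
        if key.startswith('_'):
            return None
        return getattr(self, key, None)

    @property
    def remaining(self):
        return max(self.max - self.index, 0)

    @property
    def percent(self):
        return 100 * min(1, self.index / self.max)


demo = TemplateDemo(index=5, maximum=20)
suffix = '%(index)d/%(max)d (%(percent)d%%, %(remaining)d left)' % demo
hidden = '%(_hidden)s' % demo
print(suffix)
print(hidden)
```

Any property defined on a subclass becomes usable as a template argument in the same way, which is what the `SlowBar` example in the README relies on.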
/utils/progress/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tingxueronghua/pytorch-classification-advprop/d0f6b045c8666d747cfefdfff0e0d7bfe869b60d/utils/progress/demo.gif
--------------------------------------------------------------------------------
/utils/progress/progress/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2012 Giorgos Verigakis
2 | #
3 | # Permission to use, copy, modify, and distribute this software for any
4 | # purpose with or without fee is hereby granted, provided that the above
5 | # copyright notice and this permission notice appear in all copies.
6 | #
7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14 |
15 | from __future__ import division
16 |
17 | from collections import deque
18 | from datetime import timedelta
19 | from math import ceil
20 | from sys import stderr
21 | from time import time
22 |
23 |
24 | __version__ = '1.3'
25 |
26 |
27 | class Infinite(object):
28 |     file = stderr
29 |     sma_window = 10  # Simple Moving Average window
30 |
31 |     def __init__(self, *args, **kwargs):
32 |         self.index = 0
33 |         self.start_ts = time()
34 |         self.avg = 0
35 |         self._ts = self.start_ts
36 |         self._xput = deque(maxlen=self.sma_window)
37 |         for key, val in kwargs.items():
38 |             setattr(self, key, val)
39 |
40 |     def __getitem__(self, key):
41 |         if key.startswith('_'):
42 |             return None
43 |         return getattr(self, key, None)
44 |
45 |     @property
46 |     def elapsed(self):
47 |         return int(time() - self.start_ts)
48 |
49 |     @property
50 |     def elapsed_td(self):
51 |         return timedelta(seconds=self.elapsed)
52 |
53 |     def update_avg(self, n, dt):
54 |         if n > 0:
55 |             self._xput.append(dt / n)
56 |             self.avg = sum(self._xput) / len(self._xput)
57 |
58 |     def update(self):
59 |         pass
60 |
61 |     def start(self):
62 |         pass
63 |
64 |     def finish(self):
65 |         pass
66 |
67 |     def next(self, n=1):
68 |         now = time()
69 |         dt = now - self._ts
70 |         self.update_avg(n, dt)
71 |         self._ts = now
72 |         self.index = self.index + n
73 |         self.update()
74 |
75 |     def iter(self, it):
76 |         try:
77 |             for x in it:
78 |                 yield x
79 |                 self.next()
80 |         finally:
81 |             self.finish()
82 |
83 |
84 | class Progress(Infinite):
85 |     def __init__(self, *args, **kwargs):
86 |         super(Progress, self).__init__(*args, **kwargs)
87 |         self.max = kwargs.get('max', 100)
88 |
89 |     @property
90 |     def eta(self):
91 |         return int(ceil(self.avg * self.remaining))
92 |
93 |     @property
94 |     def eta_td(self):
95 |         return timedelta(seconds=self.eta)
96 |
97 |     @property
98 |     def percent(self):
99 |         return self.progress * 100
100 |
101 |     @property
102 |     def progress(self):
103 |         return min(1, self.index / self.max)
104 |
105 |     @property
106 |     def remaining(self):
107 |         return max(self.max - self.index, 0)
108 |
109 |     def start(self):
110 |         self.update()
111 |
112 |     def goto(self, index):
113 |         incr = index - self.index
114 |         self.next(incr)
115 |
116 |     def iter(self, it):
117 |         try:
118 |             self.max = len(it)
119 |         except TypeError:
120 |             pass
121 |
122 |         try:
123 |             for x in it:
124 |                 yield x
125 |                 self.next()
126 |         finally:
127 |             self.finish()
128 |
--------------------------------------------------------------------------------
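The `avg` and `eta` values in the module above come from a simple moving average over the most recent step durations, held in a bounded `deque`. The sketch below isolates that calculation. `eta_seconds` is a hypothetical helper written for illustration, and the durations are made up; the window of 10 matches `sma_window` above.

```python
from collections import deque
from math import ceil


def eta_seconds(step_durations, remaining, window=10):
    """Estimate remaining time from the mean of the last `window` durations."""
    # A deque with maxlen silently drops the oldest entries once it is full.
    recent = deque(step_durations, maxlen=window)
    if not recent:
        return 0
    avg = sum(recent) / len(recent)
    return int(ceil(avg * remaining))


# Three slow steps followed by ten fast ones: only the fast ones stay in the window.
durations = [5.0] * 3 + [1.0] * 10
eta = eta_seconds(durations, remaining=7)
print(eta)
```

Using a bounded window means the estimate adapts to the current speed rather than being dominated by a slow start, at the cost of being noisier than a whole-run average.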
/utils/progress/progress/bar.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2012 Giorgos Verigakis
4 | #
5 | # Permission to use, copy, modify, and distribute this software for any
6 | # purpose with or without fee is hereby granted, provided that the above
7 | # copyright notice and this permission notice appear in all copies.
8 | #
9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 |
17 | from __future__ import unicode_literals
18 | from . import Progress
19 | from .helpers import WritelnMixin
20 |
21 |
22 | class Bar(WritelnMixin, Progress):
23 |     width = 32
24 |     message = ''
25 |     suffix = '%(index)d/%(max)d'
26 |     bar_prefix = ' |'
27 |     bar_suffix = '| '
28 |     empty_fill = ' '
29 |     fill = '#'
30 |     hide_cursor = True
31 |
32 |     def update(self):
33 |         filled_length = int(self.width * self.progress)
34 |         empty_length = self.width - filled_length
35 |
36 |         message = self.message % self
37 |         bar = self.fill * filled_length
38 |         empty = self.empty_fill * empty_length
39 |         suffix = self.suffix % self
40 |         line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix,
41 |                         suffix])
42 |         self.writeln(line)
43 |
44 |
45 | class ChargingBar(Bar):
46 |     suffix = '%(percent)d%%'
47 |     bar_prefix = ' '
48 |     bar_suffix = ' '
49 |     empty_fill = '∙'
50 |     fill = '█'
51 |
52 |
53 | class FillingSquaresBar(ChargingBar):
54 |     empty_fill = '▢'
55 |     fill = '▣'
56 |
57 |
58 | class FillingCirclesBar(ChargingBar):
59 |     empty_fill = '◯'
60 |     fill = '◉'
61 |
62 |
63 | class IncrementalBar(Bar):
64 |     phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█')
65 |
66 |     def update(self):
67 |         nphases = len(self.phases)
68 |         filled_len = self.width * self.progress
69 |         nfull = int(filled_len)  # Number of full chars
70 |         phase = int((filled_len - nfull) * nphases)  # Phase of last char
71 |         nempty = self.width - nfull  # Number of empty chars
72 |
73 |         message = self.message % self
74 |         bar = self.phases[-1] * nfull
75 |         current = self.phases[phase] if phase > 0 else ''
76 |         empty = self.empty_fill * max(0, nempty - len(current))
77 |         suffix = self.suffix % self
78 |         line = ''.join([message, self.bar_prefix, bar, current, empty,
79 |                         self.bar_suffix, suffix])
80 |         self.writeln(line)
81 |
82 |
83 | class PixelBar(IncrementalBar):
84 |     phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿')
85 |
86 |
87 | class ShadyBar(IncrementalBar):
88 |     phases = (' ', '░', '▒', '▓', '█')
89 |
89 |
--------------------------------------------------------------------------------
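`IncrementalBar.update` gets sub-character resolution by drawing full cells, then one partially filled cell chosen from `phases`, then padding. The sketch below extracts that arithmetic into a standalone function. `render_cells` is a hypothetical name, and the width of 8 is chosen only to keep the output short.

```python
PHASES = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█')


def render_cells(progress, width=8, phases=PHASES, empty_fill=' '):
    """Return the bar body for a progress fraction between 0 and 1."""
    filled_len = width * progress
    nfull = int(filled_len)                          # completely filled cells
    phase = int((filled_len - nfull) * len(phases))  # fill level of the next cell
    nempty = width - nfull
    current = phases[phase] if phase > 0 else ''
    empty = empty_fill * max(0, nempty - len(current))
    return phases[-1] * nfull + current + empty


# 0.3125 * 8 = 2.5 cells: two full blocks, one half block, five empty cells.
cells = render_cells(0.3125)
print('|' + cells + '|')
```

The returned string always has exactly `width` characters, so the suffix after the bar does not shift as progress advances.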
/utils/progress/progress/counter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2012 Giorgos Verigakis
4 | #
5 | # Permission to use, copy, modify, and distribute this software for any
6 | # purpose with or without fee is hereby granted, provided that the above
7 | # copyright notice and this permission notice appear in all copies.
8 | #
9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 |
17 | from __future__ import unicode_literals
18 | from . import Infinite, Progress
19 | from .helpers import WriteMixin
20 |
21 |
22 | class Counter(WriteMixin, Infinite):
23 |     message = ''
24 |     hide_cursor = True
25 |
26 |     def update(self):
27 |         self.write(str(self.index))
28 |
29 |
30 | class Countdown(WriteMixin, Progress):
31 |     hide_cursor = True
32 |
33 |     def update(self):
34 |         self.write(str(self.remaining))
35 |
36 |
37 | class Stack(WriteMixin, Progress):
38 |     phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█')
39 |     hide_cursor = True
40 |
41 |     def update(self):
42 |         nphases = len(self.phases)
43 |         i = min(nphases - 1, int(self.progress * nphases))
44 |         self.write(self.phases[i])
45 |
46 |
47 | class Pie(Stack):
48 |     phases = ('○', '◔', '◑', '◕', '●')
49 |
--------------------------------------------------------------------------------
/utils/progress/progress/helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2012 Giorgos Verigakis
2 | #
3 | # Permission to use, copy, modify, and distribute this software for any
4 | # purpose with or without fee is hereby granted, provided that the above
5 | # copyright notice and this permission notice appear in all copies.
6 | #
7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14 |
15 | from __future__ import print_function
16 |
17 |
18 | HIDE_CURSOR = '\x1b[?25l'
19 | SHOW_CURSOR = '\x1b[?25h'
20 |
21 |
22 | class WriteMixin(object):
23 |     hide_cursor = False
24 |
25 |     def __init__(self, message=None, **kwargs):
26 |         super(WriteMixin, self).__init__(**kwargs)
27 |         self._width = 0
28 |         if message:
29 |             self.message = message
30 |
31 |         if self.file.isatty():
32 |             if self.hide_cursor:
33 |                 print(HIDE_CURSOR, end='', file=self.file)
34 |             print(self.message, end='', file=self.file)
35 |             self.file.flush()
36 |
37 |     def write(self, s):
38 |         if self.file.isatty():
39 |             b = '\b' * self._width
40 |             c = s.ljust(self._width)
41 |             print(b + c, end='', file=self.file)
42 |             self._width = max(self._width, len(s))
43 |             self.file.flush()
44 |
45 |     def finish(self):
46 |         if self.file.isatty() and self.hide_cursor:
47 |             print(SHOW_CURSOR, end='', file=self.file)
48 |
49 |
50 | class WritelnMixin(object):
51 |     hide_cursor = False
52 |
53 |     def __init__(self, message=None, **kwargs):
54 |         super(WritelnMixin, self).__init__(**kwargs)
55 |         if message:
56 |             self.message = message
57 |
58 |         if self.file.isatty() and self.hide_cursor:
59 |             print(HIDE_CURSOR, end='', file=self.file)
60 |
61 |     def clearln(self):
62 |         if self.file.isatty():
63 |             print('\r\x1b[K', end='', file=self.file)
64 |
65 |     def writeln(self, line):
66 |         if self.file.isatty():
67 |             self.clearln()
68 |             print(line, end='', file=self.file)
69 |             self.file.flush()
70 |
71 |     def finish(self):
72 |         if self.file.isatty():
73 |             print(file=self.file)
74 |             if self.hide_cursor:
75 |                 print(SHOW_CURSOR, end='', file=self.file)
76 |
77 |
78 | from signal import signal, SIGINT
79 | from sys import exit
80 |
81 |
82 | class SigIntMixin(object):
83 |     """Registers a signal handler that calls finish on SIGINT"""
84 |
85 |     def __init__(self, *args, **kwargs):
86 |         super(SigIntMixin, self).__init__(*args, **kwargs)
87 |         signal(SIGINT, self._sigint_handler)
88 |
89 |     def _sigint_handler(self, signum, frame):
90 |         self.finish()
91 |         exit(0)
92 |
--------------------------------------------------------------------------------
/utils/progress/progress/spinner.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2012 Giorgos Verigakis
4 | #
5 | # Permission to use, copy, modify, and distribute this software for any
6 | # purpose with or without fee is hereby granted, provided that the above
7 | # copyright notice and this permission notice appear in all copies.
8 | #
9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 |
17 | from __future__ import unicode_literals
18 | from . import Infinite
19 | from .helpers import WriteMixin
20 |
21 |
22 | class Spinner(WriteMixin, Infinite):
23 |     message = ''
24 |     phases = ('-', '\\', '|', '/')
25 |     hide_cursor = True
26 |
27 |     def update(self):
28 |         i = self.index % len(self.phases)
29 |         self.write(self.phases[i])
30 |
31 |
32 | class PieSpinner(Spinner):
33 |     phases = ['◷', '◶', '◵', '◴']
34 |
35 |
36 | class MoonSpinner(Spinner):
37 |     phases = ['◑', '◒', '◐', '◓']
38 |
39 |
40 | class LineSpinner(Spinner):
41 |     phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻']
42 |
43 | class PixelSpinner(Spinner):
44 |     phases = ['⣾', '⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽']
45 |
--------------------------------------------------------------------------------
/utils/progress/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import setup
4 |
5 | import progress
6 |
7 |
8 | setup(
9 |     name='progress',
10 |     version=progress.__version__,
11 |     description='Easy to use progress bars',
12 |     long_description=open('README.rst').read(),
13 |     author='Giorgos Verigakis',
14 |     author_email='verigak@gmail.com',
15 |     url='http://github.com/verigak/progress/',
16 |     license='ISC',
17 |     packages=['progress'],
18 |     classifiers=[
19 |         'Environment :: Console',
20 |         'Intended Audience :: Developers',
21 |         'License :: OSI Approved :: ISC License (ISCL)',
22 |         'Programming Language :: Python :: 2.6',
23 |         'Programming Language :: Python :: 2.7',
24 |         'Programming Language :: Python :: 3.3',
25 |         'Programming Language :: Python :: 3.4',
26 |         'Programming Language :: Python :: 3.5',
27 |         'Programming Language :: Python :: 3.6',
28 |     ]
29 | )
30 |
--------------------------------------------------------------------------------
/utils/progress/test_progress.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from __future__ import print_function
4 |
5 | import random
6 | import time
7 |
8 | from progress.bar import (Bar, ChargingBar, FillingSquaresBar,
9 |                           FillingCirclesBar, IncrementalBar, PixelBar,
10 |                           ShadyBar)
11 | from progress.spinner import (Spinner, PieSpinner, MoonSpinner, LineSpinner,
12 |                               PixelSpinner)
13 | from progress.counter import Counter, Countdown, Stack, Pie
14 |
15 |
16 | def sleep():
17 |     t = 0.01
18 |     t += t * random.uniform(-0.1, 0.1)  # Add some variance
19 |     time.sleep(t)
20 |
21 |
22 | for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar):
23 |     suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]'
24 |     bar = bar_cls(bar_cls.__name__, suffix=suffix)
25 |     for i in bar.iter(range(200)):
26 |         sleep()
27 |
28 | for bar_cls in (IncrementalBar, PixelBar, ShadyBar):
29 |     suffix = '%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]'
30 |     bar = bar_cls(bar_cls.__name__, suffix=suffix)
31 |     for i in bar.iter(range(200)):
32 |         sleep()
33 |
34 | for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner):
35 |     for i in spin(spin.__name__ + ' ').iter(range(100)):
36 |         sleep()
37 |     print()
38 |
39 | for singleton in (Counter, Countdown, Stack, Pie):
40 |     for i in singleton(singleton.__name__ + ' ').iter(range(100)):
41 |         sleep()
42 |     print()
43 |
44 | bar = IncrementalBar('Random', suffix='%(index)d')
45 | for i in range(100):
46 |     bar.goto(random.randint(0, 100))
47 |     sleep()
48 | bar.finish()
49 |
--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | from .misc import *
8 |
9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']
10 |
11 | # functions to show an image
12 | def make_image(img, mean=(0,0,0), std=(1,1,1)):
13 |     for i in range(0, 3):
14 |         img[i] = img[i] * std[i] + mean[i]    # unnormalize
15 |     npimg = img.numpy()
16 |     return np.transpose(npimg, (1, 2, 0))
17 |
18 | def gauss(x,a,b,c):
19 |     return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a)
20 |
21 | def colorize(x):
22 |     ''' Converts a one-channel grayscale image to a color heatmap image '''
23 |     if x.dim() == 2:
24 |         x = x.unsqueeze(0)  # add a channel dimension by rebinding, not via `out=`
25 |     if x.dim() == 3:
26 |         cl = torch.zeros([3, x.size(1), x.size(2)])
27 |         cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)
28 |         cl[1] = gauss(x,1,.5,.3)
29 |         cl[2] = gauss(x,1,.2,.3)
30 |         cl[cl.gt(1)] = 1
31 |     elif x.dim() == 4:
32 |         cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
33 |         cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)
34 |         cl[:,1,:,:] = gauss(x,1,.5,.3)
35 |         cl[:,2,:,:] = gauss(x,1,.2,.3)
36 |     return cl
37 |
38 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
39 |     images = make_image(torchvision.utils.make_grid(images), Mean, Std)
40 |     plt.imshow(images)
41 |     plt.show()
42 |
43 |
44 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
45 |     im_size = images.size(2)
46 |
47 |     # save for adding mask
48 |     im_data = images.clone()
49 |     for i in range(0, 3):
50 |         im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i]    # unnormalize
51 |
52 |     images = make_image(torchvision.utils.make_grid(images), Mean, Std)
53 |     plt.subplot(2, 1, 1)
54 |     plt.imshow(images)
55 |     plt.axis('off')
56 |
57 |     # for b in range(mask.size(0)):
58 |     #     mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min())
59 |     mask_size = mask.size(2)
60 |     # print('Max %f Min %f' % (mask.max(), mask.min()))
61 |     mask = (upsampling(mask, scale_factor=im_size/mask_size))
62 |     # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size))
63 |     # for c in range(3):
64 |     #     mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c]
65 |
66 |     # print(mask.size())
67 |     mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
68 |     # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std)
69 |     plt.subplot(2, 1, 2)
70 |     plt.imshow(mask)
71 |     plt.axis('off')
72 |
73 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
74 |     im_size = images.size(2)
75 |
76 |     # save for adding mask
77 |     im_data = images.clone()
78 |     for i in range(0, 3):
79 |         im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i]    # unnormalize
80 |
81 |     images = make_image(torchvision.utils.make_grid(images), Mean, Std)
82 |     plt.subplot(1+len(masklist), 1, 1)
83 |     plt.imshow(images)
84 |     plt.axis('off')
85 |
86 |     for i in range(len(masklist)):
87 |         mask = masklist[i].data.cpu()
88 |         # for b in range(mask.size(0)):
89 |         #     mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min())
90 |         mask_size = mask.size(2)
91 |         # print('Max %f Min %f' % (mask.max(), mask.min()))
92 |         mask = (upsampling(mask, scale_factor=im_size/mask_size))
93 |         # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size))
94 |         # for c in range(3):
95 |         #     mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c]
96 |
97 |         # print(mask.size())
98 |         mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
99 |         # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std)
100 |         plt.subplot(1+len(masklist), 1, i+2)
101 |         plt.imshow(mask)
102 |         plt.axis('off')
103 |
104 |
105 |
106 | # x = torch.zeros(1, 3, 3)
107 | # out = colorize(x)
108 | # out_im = make_image(out)
109 | # plt.imshow(out_im)
110 | # plt.show()
--------------------------------------------------------------------------------