├── .gitignore
├── README.md
├── TinyImagesExclusionIdcs
│   ├── 80mn_cifar100_test_idxs.txt
│   ├── 80mn_cifar100_train_idxs.txt
│   ├── 80mn_cifar101_idxs.txt
│   ├── 80mn_cifar10_test_idxs.txt
│   ├── 80mn_cifar10_train_idxs.txt
│   └── 80mn_cifar_idxs.txt
├── cifar10_robustness_test.py
├── cifar10_visualize_failure.py
├── counterfactual.png
├── environment.yml
├── run_training_cifar10.py
└── utils
    ├── 80mn_cifar100_test_idxs.txt
    ├── 80mn_cifar101_idxs.txt
    ├── 80mn_cifar10_test_idxs.txt
    ├── 80mn_cifar_idxs.txt
    ├── __init__.py
    ├── adversarial_attacks
    │   ├── __init__.py
    │   ├── adversarialattack.py
    │   ├── apgd.py
    │   ├── argmin_pgd.py
    │   ├── cutout_pgd.py
    │   ├── dummy_attack.py
    │   ├── fab.py
    │   ├── fgm.py
    │   ├── l1_projection.py
    │   ├── l1_reg_pgd.py
    │   ├── monotone_pgd.py
    │   ├── noise.py
    │   ├── pgd.py
    │   ├── restartattack.py
    │   ├── tanh_attack.py
    │   └── utils.py
    ├── adversarial_test.py
    ├── average_model.py
    ├── bit_downstream_schedule.py
    ├── compute_auc.py
    ├── datasets
    │   ├── __init__.py
    │   ├── augmentations
    │   │   ├── __init__.py
    │   │   ├── autoaugment.py
    │   │   ├── cifar_augmentation.py
    │   │   ├── cutout.py
    │   │   ├── imagenet_augmentation.py
    │   │   ├── svhn_augmentation.py
    │   │   └── utils.py
    │   ├── celebA.py
    │   ├── cifar.py
    │   ├── cifar_corrupted.py
    │   ├── cinic_10.py
    │   ├── combo_dataset.py
    │   ├── fgvc_aircraft.py
    │   ├── flowers.py
    │   ├── food_101.py
    │   ├── food_101N.py
    │   ├── imagenet.py
    │   ├── imagenet_natural_adversarials.py
    │   ├── imagenet_subsets.py
    │   ├── lsun.py
    │   ├── mnist.py
    │   ├── noise_datasets.py
    │   ├── openimages.py
    │   ├── paths.py
    │   ├── pets.py
    │   ├── preproc.py
    │   ├── semisupervised_dataset.py
    │   ├── stanford_cars.py
    │   ├── svhn.py
    │   ├── tinyImages.py
    │   ├── tiny_image_net.py
    │   ├── utils.py
    │   └── various.py
    ├── distances.py
    ├── eval.py
    ├── find_nearest_neighbours.py
    ├── generate_all_classes.py
    ├── id_radius_confidence.py
    ├── load_trained_model.py
    ├── model_normalization.py
    ├── models.py
    ├── models
    │   ├── __init__.py
    │   ├── big_transfer
    │   │   ├── __init__.py
    │   │   └── models.py
    │   ├── big_transfer_factory.py
    │   ├── ebm_wrn.py
    │   ├── model_factory_224.py
    │   ├── model_factory_32.py
    │   └── models_32x32
    │       ├── __init__.py
    │       ├── fixup_resnet.py
    │       ├── pyramid.py
    │       ├── resnet.py
    │       ├── shake_pyramidnet.py
    │       ├── shakedrop.py
    │       ├── wide_resnet.py
    │       └── wideresnet_carmon.py
    ├── od_radius_confidence.py
    ├── plotting.py
    ├── resize_right
    │   ├── __init__.py
    │   ├── interp_methods.py
    │   └── resize_right.py
    ├── run_file_helpers.py
    ├── temperature_wrapper.py
    ├── train_types
    │   ├── ACET_training.py
    │   ├── AdversarialACET.py
    │   ├── Adversarial_training.py
    │   ├── BCEACET_training.py
    │   ├── BCEAdversarial_training.py
    │   ├── BCECEDA_training.py
    │   ├── CEDA_training.py
    │   ├── TRADESACET_training.py
    │   ├── TRADESCEDA_training.py
    │   ├── TRADES_training.py
    │   ├── __init__.py
    │   ├── helpers.py
    │   ├── in_distribution_training.py
    │   ├── in_out_distribution_training.py
    │   ├── msda
    │   │   ├── __init__.py
    │   │   ├── config_creators.py
    │   │   ├── dummy_msda.py
    │   │   ├── factory.py
    │   │   ├── fmix.py
    │   │   ├── fmix_utils.py
    │   │   ├── mixed_sample_data_augmentation.py
    │   │   └── mixup.py
    │   ├── optimizers
    │   │   ├── __init__.py
    │   │   ├── config_creators.py
    │   │   ├── factory.py
    │   │   └── sam.py
    │   ├── out_distribution_training.py
    │   ├── output_backend.py
    │   ├── plain_training.py
    │   ├── randomized_smoothing_training.py
    │   ├── schedulers
    │   │   ├── __init__.py
    │   │   ├── config_creators.py
    │   │   ├── cosine_lr.py
    │   │   ├── factory.py
    │   │   ├── scheduler.py
    │   │   ├── scheduler_factory.py
    │   │   └── step_lr.py
    │   ├── train_loss.py
    │   └── train_type.py
    └── visual_counterfactual_generation.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | 
dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | #pycharm 132 | .idea/ 133 | -------------------------------------------------------------------------------- /cifar10_robustness_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | import pathlib 8 | import matplotlib as mpl 9 | mpl.use('Agg') 10 | import matplotlib.pyplot as plt 11 | import argparse 12 | import utils.adversarial_attacks as aa 13 | from autoattack import AutoAttack 14 | 15 | from utils.load_trained_model import load_model 16 | import utils.datasets as dl 17 | 18 | model_descriptions = [ 19 | ('WideResNet34x10', 'cifar10_500k_apgd_asam', 'best_avg', None, False), 20 | ('WideResNet34x10', 'cifar10_pgd', 'best_avg', None, False), 21 | ('WideResNet34x10', 'cifar10_apgd', 'best_avg', None, False), 22 | ('WideResNet34x10', 'cifar10_500k_pgd', 'best_avg', None, False), 23 | ('WideResNet34x10', 'cifar10_500k_apgd', 'best_avg', None, False), 24 | ] 25 | 26 | 27 | parser = argparse.ArgumentParser(description='Parse arguments.', prefix_chars='-') 28 | 29 | parser.add_argument('--gpu','--list', nargs='+', default=[0], 30 | help='GPU indices, if more than 1 parallel modules will be 
called') 31 | 32 | hps = parser.parse_args() 33 | 34 | if len(hps.gpu)==0: 35 | device = torch.device('cpu') 36 | print('Warning! Computing on CPU') 37 | num_devices = 1 38 | elif len(hps.gpu)==1: 39 | device_ids = [int(hps.gpu[0])] 40 | device = torch.device('cuda:' + str(hps.gpu[0])) 41 | num_devices = 1 42 | else: 43 | device_ids = [int(i) for i in hps.gpu] 44 | device = torch.device('cuda:' + str(min(device_ids))) 45 | num_devices = len(device_ids) 46 | 47 | L2 = True 48 | LINF = False 49 | 50 | ROBUSTNESS_DATAPOINTS = 10_000 51 | dataset = 'cifar10' 52 | 53 | bs = 500 * num_devices 54 | 55 | print(f'Testing on {ROBUSTNESS_DATAPOINTS} points') 56 | 57 | for model_idx, (type, folder, checkpoint, temperature, temp) in enumerate(model_descriptions): 58 | model = load_model(type, folder, checkpoint, 59 | temperature, device, load_temp=temp, dataset=dataset) 60 | model.to(device) 61 | 62 | if len(hps.gpu) > 1: 63 | model = nn.DataParallel(model, device_ids=device_ids) 64 | 65 | model.eval() 66 | print(f'\n\n{folder} {checkpoint}\n ') 67 | 68 | if dataset == 'cifar10': 69 | dataloader = dl.get_CIFAR10(False, batch_size=bs, augm_type='none') 70 | elif dataset == 'cifar100': 71 | dataloader = dl.get_CIFAR100(False, batch_size=bs, augm_type='none') 72 | else: 73 | raise NotImplementedError() 74 | 75 | acc = 0.0 76 | with torch.no_grad(): 77 | for data, target in dataloader: 78 | data = data.to(device) 79 | target = target.to(device) 80 | out = model(data) 81 | _, pred = torch.max(out, dim=1) 82 | acc += torch.sum(pred == target).item() / len(dataloader.dataset) 83 | 84 | print(f'Clean accuracy {acc}') 85 | 86 | if dataset == 'cifar10': 87 | dataloader = dl.get_CIFAR10(False, batch_size=ROBUSTNESS_DATAPOINTS, augm_type='none') 88 | elif dataset == 'cifar100': 89 | dataloader = dl.get_CIFAR100(False, batch_size=ROBUSTNESS_DATAPOINTS, augm_type='none') 90 | else: 91 | raise NotImplementedError() 92 | 93 | data_iterator = iter(dataloader) 94 | ref_data, target = next(data_iterator) 95 | 96 | if L2: 97 | print('Eps: 0.5') 98 | 99 | attack = AutoAttack(model, device=device, norm='L2', eps=0.5, verbose=True) 100 | attack.run_standard_evaluation(ref_data, target, bs=bs) 101 | 102 | # print('Eps: 1.0') 103 | # attack = AutoAttack(model, device=device, norm='L2', eps=1.0, attacks_to_run=attacks_to_run,verbose=True) 104 | # attack.run_standard_evaluation(ref_data, target, bs=bs) 105 | if LINF: 106 | print('Eps: 8/255') 107 | attack = AutoAttack(model, device=device, norm='Linf', eps=8./255.,verbose=True) 108 | attack.run_standard_evaluation(ref_data, target, bs=bs) 109 | 110 | -------------------------------------------------------------------------------- /cifar10_visualize_failure.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | import pathlib 8 | import matplotlib as mpl 9 | mpl.use('Agg') 10 | import matplotlib.pyplot as plt 11 | import argparse 12 | 13 | import utils.datasets as dl 14 | from utils.visual_counterfactual_generation import visual_counterfactuals 15 | 16 | parser = argparse.ArgumentParser(description='Parse arguments.', prefix_chars='-') 17 | 18 | parser.add_argument('--gpu','--list', nargs='+', default=[0], 19 | help='GPU indices, if more than 1 parallel modules will be called') 20 | 21 | hps = parser.parse_args() 22 | 23 | 24 | bs = 128 25 | big_model_bs = 20 26 | 27 | if 
len(hps.gpu)==0: 28 | device = torch.device('cpu') 29 | print('Warning! Computing on CPU') 30 | elif len(hps.gpu)==1: 31 | device_ids = None 32 | device = torch.device('cuda:' + str(hps.gpu[0])) 33 | bs = bs 34 | big_model_bs = big_model_bs 35 | else: 36 | device_ids = [int(i) for i in hps.gpu] 37 | device = torch.device('cuda:' + str(min(device_ids))) 38 | bs = bs * len(device_ids) 39 | big_model_bs = big_model_bs * len(device_ids) 40 | 41 | model_descriptions = [ 42 | ('WideResNet34x10', 'cifar10_pgd', 'best_avg', None, False), 43 | ('WideResNet34x10', 'cifar10_apgd', 'best_avg', None, False), 44 | ('WideResNet34x10', 'cifar10_500k_pgd', 'best_avg', None, False), 45 | ('WideResNet34x10', 'cifar10_500k_apgd', 'best_avg', None, False), 46 | ('WideResNet34x10', 'cifar10_500k_apgd_asam', 'best_avg', None, False), 47 | ] 48 | 49 | model_batchsize = bs * np.ones(len(model_descriptions), dtype=int) 50 | num_examples = 16 51 | 52 | dataloader = dl.get_CIFAR10(False, bs, augm_type='none') 53 | num_datapoints = len(dataloader.dataset) 54 | 55 | class_labels = dl.cifar.get_CIFAR10_labels() 56 | eval_dir = 'Cifar10Eval/' 57 | 58 | norm = 'l2' 59 | 60 | if norm == 'l1': 61 | radii = np.linspace(15, 90, 6) 62 | visual_counterfactuals(model_descriptions, radii, dataloader, model_batchsize, num_examples, class_labels, device, 63 | eval_dir, 'cifar10', norm='l1', stepsize=5, device_ids=device_ids) 64 | else: 65 | radii = np.linspace(0.5, 3, 6) 66 | visual_counterfactuals(model_descriptions, radii, dataloader, model_batchsize, num_examples, class_labels, device, eval_dir, 'cifar10', device_ids=device_ids) 67 | -------------------------------------------------------------------------------- /counterfactual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/M4xim4l/InNOutRobustness/d81d1d26e5ebc9193009e3d92bd67b5e01d6cfd6/counterfactual.png --------------------------------------------------------------------------------
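For orientation, the clean-accuracy evaluation pattern shared by the test scripts above, restated as a reusable helper. This is a sketch: the helper name clean_accuracy is ours, but the loader call and the accumulation mirror cifar10_robustness_test.py.

import torch
import utils.datasets as dl

def clean_accuracy(model, device, bs=500):
    # test split, no augmentation, as in cifar10_robustness_test.py
    loader = dl.get_CIFAR10(False, batch_size=bs, augm_type='none')
    acc = 0.0
    with torch.no_grad():
        for data, target in loader:
            out = model(data.to(device))
            acc += (out.argmax(dim=1) == target.to(device)).sum().item() / len(loader.dataset)
    return acc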
/run_training_cifar10.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | mpl.use('Agg') 3 | 4 | import os 5 | import torch 6 | import torch.nn as nn 7 | 8 | from utils.model_normalization import Cifar10Wrapper 9 | import utils.datasets as dl 10 | import utils.models.model_factory_32 as factory 11 | import utils.run_file_helpers as rh 12 | from distutils.util import strtobool 13 | 14 | import argparse 15 | 16 | parser = argparse.ArgumentParser(description='Define hyperparameters.', prefix_chars='-') 17 | parser.add_argument('--net', type=str, default='ResNet18', help='Resnet18, 34 or 50, WideResNet28') 18 | parser.add_argument('--model_params', nargs='+', default=[]) 19 | parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10 or semi-cifar10') 20 | parser.add_argument('--od_dataset', type=str, default='tinyImages', 21 | help=('tinyImages or cifar100')) 22 | parser.add_argument('--exclude_cifar', dest='exclude_cifar', type=lambda x: bool(strtobool(x)), 23 | default=True, help='whether to exclude cifar10 from tiny images') 24 | 25 | rh.parser_add_commons(parser) 26 | rh.parser_add_adversarial_commons(parser) 27 | rh.parser_add_adversarial_norms(parser, 'cifar10') 28 | 29 | hps = parser.parse_args() 30 | # 31 | device_ids = None 32 | if len(hps.gpu)==0: 33 | device = torch.device('cpu') 34 | print('Warning! Computing on CPU') 35 | elif len(hps.gpu)==1: 36 | device = torch.device('cuda:' + str(hps.gpu[0])) 37 | else: 38 | device_ids = [int(i) for i in hps.gpu] 39 | device = torch.device('cuda:' + str(min(device_ids))) 40 | 41 | #Load model 42 | model_root_dir = 'Cifar10Models' 43 | logs_root_dir = 'Cifar10Logs' 44 | num_classes = 10 45 | 46 | model, model_name, model_config, img_size = factory.build_model(hps.net, num_classes, model_params=hps.model_params) 47 | model_dir = os.path.join(model_root_dir, model_name) 48 | log_dir = os.path.join(logs_root_dir, model_name) 49 | 50 | start_epoch, optim_state_dict = rh.load_model_checkpoint(model, model_dir, device, hps) 51 | model = Cifar10Wrapper(model).to(device) 52 | 53 | msda_config = rh.create_msda_config(hps) 54 | 55 | #load dataset 56 | od_bs = int(hps.od_bs_factor * hps.bs) 57 | 58 | id_config = {} 59 | if hps.dataset == 'cifar10': 60 | train_loader = dl.get_CIFAR10(train=True, batch_size=hps.bs, augm_type=hps.augm, size=img_size, 61 | config_dict=id_config) 62 | elif hps.dataset == 'semi-cifar10': 63 | train_loader = dl.get_CIFAR10_ti_500k(train=True, batch_size=hps.bs, augm_type=hps.augm, fraction=0.7, 64 | size=img_size, 65 | config_dict=id_config) 66 | else: 67 | raise ValueError(f'Dataset {hps.dataset} not supported') 68 | 69 | if hps.train_type.lower() in ['ceda', 'acet', 'advacet', 'tradesacet', 'tradesceda']: 70 | od_config = {} 71 | loader_config = {'ID config': id_config, 'OD config': od_config} 72 | 73 | if hps.od_dataset == 'tinyImages': 74 | tiny_train = dl.get_80MTinyImages(batch_size=od_bs, augm_type=hps.augm, num_workers=1, size=img_size, 75 | exclude_cifar=hps.exclude_cifar, exclude_cifar10_1=hps.exclude_cifar, config_dict=od_config) 76 | elif hps.od_dataset == 'cifar100': 77 | tiny_train = dl.get_CIFAR100(train=True, batch_size=od_bs, shuffle=True, augm_type=hps.augm, 78 | size=img_size, config_dict=od_config) 79 | elif hps.od_dataset == 'openImages': 80 | tiny_train = dl.get_openImages('train', batch_size=od_bs, shuffle=True, augm_type=hps.augm, size=img_size, exclude_dataset=None, config_dict=od_config) 81 | else: 82 | loader_config = {'ID config': id_config} 83 | 84 | test_loader = dl.get_CIFAR10(train=False, batch_size=hps.bs, augm_type='none', size=img_size) 85 | 86 | scheduler_config, optimizer_config = rh.create_optim_scheduler_swa_configs(hps) 87 | id_attack_config, od_attack_config = rh.create_attack_config(hps, 'cifar10') 88 | trainer = rh.create_trainer(hps, model, optimizer_config, scheduler_config, device, num_classes, 89 | model_dir, log_dir, msda_config=msda_config, model_config=model_config, 90 | id_attack_config=id_attack_config, od_attack_config=od_attack_config) 91 | ##DEBUG: 92 | # torch.autograd.set_detect_anomaly(True) 93 | torch.backends.cudnn.benchmark = True 94 | 95 | # run training 96 | if trainer.requires_out_distribution(): 97 | train_loaders, test_loaders = trainer.create_loaders_dict(train_loader, test_loader=test_loader, 98 | out_distribution_loader=tiny_train) 99 | trainer.train(train_loaders, test_loaders, loader_config=loader_config, start_epoch=start_epoch, 100 | optim_state_dict=optim_state_dict, device_ids=device_ids) 101 | else: 102 | train_loaders, test_loaders = trainer.create_loaders_dict(train_loader, test_loader=test_loader) 103 | trainer.train(train_loaders, test_loaders, loader_config=loader_config, start_epoch=start_epoch, 104 | optim_state_dict=optim_state_dict, device_ids=device_ids) 105 | --------------------------------------------------------------------------------
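Before the utils/ sources, a minimal sketch of how the attack classes exported by utils.adversarial_attacks below are driven. The tiny torchvision classifier and the hyperparameters are illustrative assumptions, not taken from the repository:

import torch
import torchvision.models as tvm
from utils.adversarial_attacks import ArgminPGD

model = tvm.resnet18(num_classes=10).eval()
x = torch.rand(8, 3, 32, 32)        # the attacks assume inputs in [0, 1]
y = torch.randint(0, 10, (8,))

# untargeted l-inf PGD; ArgminPGD returns the iterate with the lowest loss seen
attack = ArgminPGD(8/255, 20, 2/255, 10, norm='inf', loss='CrossEntropy', model=model)
x_adv = attack(x, y, targeted=False)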
/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/M4xim4l/InNOutRobustness/d81d1d26e5ebc9193009e3d92bd67b5e01d6cfd6/utils/__init__.py -------------------------------------------------------------------------------- /utils/adversarial_attacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .apgd import APGDAttack 2 | from .fab import L1FABAttack, L2FABAttack, LinfFABAttack 3 | from .argmin_pgd import ArgminPGD 4 | from .fgm import FGM 5 | from .cutout_pgd import CutoutPGD 6 | from .monotone_pgd import MonotonePGD 7 | from .tanh_attack import TanhIterativeAttack 8 | from .dummy_attack import DummyAttack 9 | from .pgd import PGD 10 | from .noise import UniformNoiseGenerator, NormalNoiseGenerator -------------------------------------------------------------------------------- /utils/adversarial_attacks/adversarialattack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | from .utils import logits_diff_loss, conf_diff_loss, confidence_loss, reduce 7 | 8 | class AdversarialAttack(): 9 | def __init__(self, loss, num_classes, model=None, save_trajectory=False): 10 | #loss should either be a string specifying one of the predefined loss functions 11 | #OR 12 | #a custom loss function taking 4 arguments as train_loss class 13 | self.loss = loss 14 | self.save_trajectory = save_trajectory 15 | self.last_trajectory = None 16 | self.num_classes = num_classes 17 | if model is not None: 18 | self.model = model 19 | else: 20 | self.model = None 21 | 22 | def __call__(self, *args, **kwargs): 23 | return self.perturb(*args,**kwargs) 24 | 25 | def set_loss(self, loss): 26 | self.loss = loss 27 | 28 | def set_model(self, model): 29 | self.model = model 30 | 31 | def _get_loss_f(self, x, y, targeted, reduction): 32 | #x, y original ref_data / target 33 | #targeted whether to use a targeted attack or not 34 | #reduction: reduction to use: 'sum', 'mean', 'none' 35 | if isinstance(self.loss, str): 36 | if self.loss.lower() in ['crossentropy', 'ce']: 37 | if not targeted: 38 | l_f = lambda data, data_out: -F.cross_entropy(data_out, y, reduction=reduction) 39 | else: 40 | l_f = lambda data, data_out: F.cross_entropy(data_out, y, reduction=reduction ) 41 | elif self.loss.lower() =='kl': 42 | if not targeted: 43 | l_f = lambda data, data_out: -reduce(F.kl_div(torch.log_softmax(data_out,dim=1), y, reduction='none').sum(dim=1), reduction) 44 | else: 45 | l_f = lambda data, data_out: reduce(F.kl_div(torch.log_softmax(data_out,dim=1), y, reduction='none').sum(dim=1), reduction) 46 | elif self.loss.lower() == 'logitsdiff': 47 | if not targeted: 48 | y_oh = F.one_hot(y, self.num_classes) 49 | y_oh = y_oh.float() 50 | l_f = lambda data, data_out: -logits_diff_loss(data_out, y_oh, reduction=reduction) 51 | else: 52 | y_oh = F.one_hot(y, self.num_classes) 53 | y_oh = y_oh.float() 54 | l_f = lambda data, data_out: logits_diff_loss(data_out, y_oh, reduction=reduction) 55 | elif self.loss.lower() == 'conf': 56 | if not targeted: 57 | l_f = lambda data, data_out: confidence_loss(data_out, y, reduction=reduction) 58 | else: 59 | l_f = lambda data, data_out: -confidence_loss(data_out, y, reduction=reduction) 60 | elif self.loss.lower() == 'confdiff': 61 | if not targeted: 62 | y_oh = F.one_hot(y, self.num_classes) 63 | y_oh = y_oh.float() 64 | l_f = lambda data, 
data_out: -conf_diff_loss(data_out, y_oh, reduction=reduction) 65 | else: 66 | y_oh = F.one_hot(y, self.num_classes) 67 | y_oh = y_oh.float() 68 | l_f = lambda data, data_out: conf_diff_loss(data_out, y_oh, reduction=reduction) 69 | else: 70 | raise ValueError(f'Loss {self.loss} not supported') 71 | else: 72 | #custom 5 argument loss 73 | #(x_adv, x_adv_out, x, y, reduction) 74 | l_f = lambda data, data_out: self.loss(data, data_out, x, y, reduction=reduction) 75 | 76 | return l_f 77 | 78 | def get_config_dict(self): 79 | raise NotImplementedError() 80 | 81 | def get_last_trajectory(self): 82 | #output dimension: (iterations, batch_size, img_dimension) 83 | if not self.save_trajectory or self.last_trajectory is None: 84 | raise AssertionError() 85 | else: 86 | return self.last_trajectory 87 | 88 | def _get_trajectory_depth(self): 89 | raise NotImplementedError() 90 | 91 | def _check_model(self): 92 | if self.model is None: 93 | raise RuntimeError('Attack density_model not set') 94 | 95 | def perturb(self, x, y, targeted=False, x_init=None): 96 | #force child class implementation 97 | raise NotImplementedError() -------------------------------------------------------------------------------- /utils/adversarial_attacks/argmin_pgd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from .restartattack import RestartAttack 6 | from .utils import project_perturbation, normalize_perturbation, create_early_stopping_mask, initialize_perturbation 7 | 8 | 9 | ################################################################################################### 10 | class ArgminPGD(RestartAttack): 11 | def __init__(self, eps, iterations, stepsize, num_classes, momentum=0.9, norm='inf', loss='CrossEntropy', 12 | normalize_grad=True, early_stopping=0, restarts=0, init_noise_generator=None, model=None, 13 | save_trajectory=False): 14 | super().__init__(loss, restarts, num_classes, model=model, save_trajectory=save_trajectory) 15 | self.eps = eps 16 | self.iterations = iterations 17 | self.stepsize = stepsize 18 | self.momentum = momentum 19 | self.norm = norm 20 | self.loss = loss 21 | self.normalize_grad = normalize_grad 22 | self.early_stopping = early_stopping 23 | self.init_noise_generator = init_noise_generator 24 | 25 | def _get_trajectory_depth(self): 26 | return self.iterations + 1 27 | 28 | def get_config_dict(self): 29 | dict = {} 30 | dict['type'] = 'ArgminPGD' 31 | dict['eps'] = self.eps 32 | dict['iterations'] = self.iterations 33 | dict['stepsize'] = self.stepsize 34 | dict['momentum'] = self.momentum 35 | dict['norm'] = self.norm 36 | if isinstance(self.loss, str): 37 | dict['loss'] = self.loss 38 | else: 39 | dict['loss'] = 'custom' 40 | dict['normalize_grad'] = self.normalize_grad 41 | dict['early_stopping'] = self.early_stopping 42 | dict['restarts'] = self.restarts 43 | return dict 44 | 45 | 46 | def perturb_inner(self, x, y, targeted=False, x_init=None): 47 | l_f = self._get_loss_f(x, y, targeted, 'none') 48 | 49 | best_perts = x.new_empty(x.shape) 50 | best_losses = 1e13 * x.new_ones(x.shape[0]) 51 | 52 | velocity = torch.zeros_like(x) 53 | 54 | #initialize perturbation 55 | pert = initialize_perturbation(x, self.eps, self.norm, x_init, self.init_noise_generator) 56 | 57 | #trajectory container 58 | if self.save_trajectory: 59 | trajectory = torch.zeros((self.iterations + 1,) + x.shape, device=x.device) 60 | trajectory[0, :] = x 61 | else: 62 | trajectory 
= None 63 | 64 | for i in range(self.iterations + 1): 65 | pert.requires_grad_(True) 66 | with torch.enable_grad(): 67 | p_data = x + pert 68 | out = self.model(p_data) 69 | loss_expanded = l_f(p_data, out) 70 | 71 | new_best = loss_expanded < best_losses 72 | best_losses[new_best] = loss_expanded[new_best].clone().detach() 73 | best_perts[new_best, :] = pert[new_best, :].clone().detach() 74 | 75 | if i == self.iterations: 76 | break 77 | 78 | if self.early_stopping > 0: 79 | finished, mask = create_early_stopping_mask(out, y, self.early_stopping, targeted) 80 | if finished: 81 | break 82 | else: 83 | mask = 1. 84 | 85 | loss = torch.mean(loss_expanded) 86 | grad = torch.autograd.grad(loss, pert)[0] 87 | 88 | with torch.no_grad(): 89 | # pgd on given loss 90 | if self.normalize_grad: 91 | # https://arxiv.org/pdf/1710.06081.pdf the l1 normalization follows the momentum iterative method 92 | l1_norm_gradient = 1e-10 + torch.sum(grad.abs().reshape(x.shape[0], -1), dim=1).view(-1,1,1,1) 93 | velocity = self.momentum * velocity + grad / l1_norm_gradient 94 | norm_velocity = normalize_perturbation(velocity, self.norm) 95 | else: 96 | # velocity update as in pytorch https://pytorch.org/docs/stable/optim.html 97 | velocity = self.momentum * velocity + grad 98 | norm_velocity = velocity 99 | 100 | pert = pert - self.stepsize * mask * norm_velocity 101 | pert = project_perturbation(pert, self.eps, self.norm) 102 | pert = torch.clamp(x + pert, 0, 1) - x #box constraint 103 | 104 | if self.save_trajectory: 105 | trajectory[i + 1] = x + pert 106 | 107 | final_loss = best_losses 108 | p_data = (x + best_perts).detach() 109 | return p_data, final_loss, trajectory 110 | -------------------------------------------------------------------------------- /utils/adversarial_attacks/dummy_attack.py: -------------------------------------------------------------------------------- 1 | from .adversarialattack import AdversarialAttack 2 | 3 | ################################## 4 | class DummyAttack(AdversarialAttack): 5 | def __init__(self): 6 | super().__init__(None, 0, model=None) 7 | 8 | def perturb(self, x, y, targeted=False, x_init=None): 9 | return x 10 | -------------------------------------------------------------------------------- /utils/adversarial_attacks/fgm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from .restartattack import RestartAttack 6 | from .utils import project_perturbation, normalize_perturbation, initialize_perturbation 7 | 8 | class FGM(RestartAttack): 9 | #one step attack with l2 or inf norm constraint 10 | def __init__(self, eps, num_classes, norm='inf', loss='CrossEntropy', restarts=0, init_noise_generator=None, 11 | model=None, save_trajectory=False): 12 | super().__init__(loss, restarts, num_classes, model=model, save_trajectory=save_trajectory) 13 | self.eps = eps 14 | self.norm = norm 15 | self.init_noise_generator = init_noise_generator 16 | 17 | def _get_trajectory_depth(self): 18 | return 2 19 | 20 | def get_config_dict(self): 21 | dict = {} 22 | dict['type'] = 'FGM' 23 | dict['eps'] = self.eps 24 | dict['norm'] = self.norm 25 | if isinstance(self.loss, str): 26 | dict['loss'] = self.loss 27 | else: 28 | dict['loss'] = 'custom' 29 | dict['restarts'] = self.restarts 30 | return dict 31 | 32 | 33 | def perturb_inner(self, x, y, targeted=False, x_init=None): 34 | l_f = self._get_loss_f(x, y, targeted, 'none') 35 | 36 | pert = 
initialize_perturbation(x, self.eps, self.norm, x_init, self.init_noise_generator) 37 | 38 | pert.requires_grad_(True) 39 | 40 | with torch.enable_grad(): 41 | p_data = x + pert 42 | out = self.model(p_data) 43 | loss_expanded = l_f(p_data, out) 44 | loss = loss_expanded.mean() 45 | grad = torch.autograd.grad(loss, pert)[0] 46 | 47 | with torch.no_grad(): 48 | pert = project_perturbation(pert - self.eps * normalize_perturbation(grad, self.norm), self.eps, self.norm) 49 | p_data = x + pert 50 | p_data = torch.clamp(p_data, 0, 1) 51 | final_loss = l_f(p_data, self.model(p_data)) 52 | 53 | if self.save_trajectory: 54 | trajectory = torch.zeros((2,) + x.shape, device=x.device) 55 | trajectory[0, :] = x 56 | trajectory[1, :] = p_data 57 | else: 58 | trajectory = None 59 | 60 | return p_data, final_loss, trajectory 61 | -------------------------------------------------------------------------------- /utils/adversarial_attacks/l1_projection.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def project_onto_l1_ball(x, eps): 4 | """ 5 | Compute Euclidean projection onto the L1 ball for a batch. 6 | 7 | min ||x - u||_2 s.t. ||u||_1 <= eps 8 | 9 | Inspired by the corresponding numpy version by Adrien Gaidon. 10 | 11 | Parameters 12 | ---------- 13 | x: (batch_size, *) torch array 14 | batch of arbitrary-out_size tensors to project, possibly on GPU 15 | 16 | eps: float 17 | radius of l-1 ball to project onto 18 | 19 | Returns 20 | ------- 21 | u: (batch_size, *) torch array 22 | batch of projected tensors, reshaped to match the original 23 | 24 | Notes 25 | ----- 26 | The complexity of this algorithm is in O(dlogd) as it involves sorting x. 27 | 28 | References 29 | ---------- 30 | [1] Efficient Projections onto the l1-Ball for Learning in High Dimensions 31 | John Duchi, Shai Shalev-Shwartz, Yoram Singer, and Tushar Chandra. 
32 |     International Conference on Machine Learning (ICML 2008) 33 |     """ 34 |     original_shape = x.shape 35 |     x = x.view(x.shape[0], -1) 36 |     mask = (torch.norm(x, p=1, dim=1) < eps).float().unsqueeze(1) 37 |     mu, _ = torch.sort(torch.abs(x), dim=1, descending=True) 38 |     cumsum = torch.cumsum(mu, dim=1) 39 |     arange = torch.arange(1, x.shape[1] + 1, device=x.device) 40 |     rho, _ = torch.max((mu * arange > (cumsum - eps)) * arange, dim=1) 41 |     theta = (cumsum[torch.arange(x.shape[0]), rho.cpu() - 1] - eps) / rho 42 |     proj = (torch.abs(x) - theta.unsqueeze(1)).clamp(min=0) 43 |     x = mask * x + (1 - mask) * proj * torch.sign(x) 44 |     return x.view(original_shape) -------------------------------------------------------------------------------- /utils/adversarial_attacks/l1_reg_pgd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from .restartattack import RestartAttack 6 | from .utils import project_perturbation, normalize_perturbation, create_early_stopping_mask, initialize_perturbation 7 | 8 | 9 | ################################################################################################### 10 | class L1RegularizedPGD(RestartAttack): 11 | def __init__(self, eps, iterations, stepsize, num_classes, reg_weight=1.0, momentum=0.9, norm='l2', loss='CrossEntropy', 12 | normalize_grad=True, early_stopping=0, restarts=0, init_noise_generator=None, model=None, 13 | save_trajectory=False): 14 | super().__init__(loss, restarts, num_classes, model=model, save_trajectory=save_trajectory) 15 | self.eps = eps 16 | self.iterations = iterations 17 | self.stepsize = stepsize 18 | self.reg_weight = reg_weight 19 | self.momentum = momentum 20 | self.norm = norm 21 | self.loss = loss 22 | self.normalize_grad = normalize_grad 23 | self.early_stopping = early_stopping 24 | self.init_noise_generator = init_noise_generator 25 | 26 | def _get_trajectory_depth(self): 27 | return self.iterations + 1 28 | 29 | def get_config_dict(self): 30 | dict = {} 31 | dict['type'] = 'L1RegularizedPGD' 32 | dict['eps'] = self.eps 33 | dict['iterations'] = self.iterations 34 | dict['stepsize'] = self.stepsize 35 | dict['reg_weight'] = self.reg_weight 36 | dict['momentum'] = self.momentum 37 | dict['norm'] = self.norm 38 | if isinstance(self.loss, str): 39 | dict['loss'] = self.loss 40 | else: 41 | dict['loss'] = 'custom' 42 | dict['normalize_grad'] = self.normalize_grad 43 | dict['early_stopping'] = self.early_stopping 44 | dict['restarts'] = self.restarts 45 | return dict 46 | 47 | 48 | def perturb_inner(self, x, y, targeted=False, x_init=None): 49 | l_f = self._get_loss_f(x, y, targeted, 'none') 50 | 51 | best_perts = x.new_empty(x.shape) 52 | best_losses = 1e13 * x.new_ones(x.shape[0]) 53 | 54 | velocity_plus = torch.zeros_like(x) 55 | velocity_minus = torch.zeros_like(x) 56 | 57 | #initialize perturbation and split it into non-negative parts: pert = pert_plus - pert_minus 58 | pert = initialize_perturbation(x, self.eps, self.norm, x_init, self.init_noise_generator) 59 | pert_plus = torch.zeros_like(pert) 60 | pert_minus = torch.zeros_like(pert) 61 | 62 | pert_plus[pert > 0] = pert[pert > 0] 63 | pert_minus[pert < 0] = pert[pert < 0].abs() 64 | 65 | #trajectory container 66 | if self.save_trajectory: 67 | trajectory = torch.zeros((self.iterations + 1,) + x.shape, device=x.device) 68 | trajectory[0, :] = x 69 | else: 70 | trajectory = None 71 | 72 | for i in range(self.iterations + 1): 73 | pert_plus.requires_grad_(True) 74 | pert_minus.requires_grad_(True) 75 | pert = pert_plus - pert_minus 76 | with torch.enable_grad(): 77 | p_data = x + pert 78 | out = self.model(p_data) 79 | main_loss_expanded = l_f(p_data, out) 80 | #per-sample l1 penalty on the split perturbation 81 | l1_reg = torch.sum((pert_plus + pert_minus).reshape(x.shape[0], -1), dim=1) 82 | loss_expanded = main_loss_expanded + self.reg_weight * l1_reg 83 | 84 | new_best = loss_expanded < best_losses 85 | best_losses[new_best] = loss_expanded[new_best].clone().detach() 86 | best_perts[new_best, :] = pert[new_best, :].clone().detach() 87 | 88 | if i == self.iterations: 89 | break 90 | 91 | if self.early_stopping > 0: 92 | finished, mask = create_early_stopping_mask(out, y, self.early_stopping, targeted) 93 | if finished: 94 | break 95 | else: 96 | mask = 1. 97 | 98 | loss = torch.mean(loss_expanded) 99 | pert_plus_grad, pert_minus_grad = torch.autograd.grad(loss, [pert_plus, pert_minus]) 100 | 101 | with torch.no_grad(): 102 | if self.normalize_grad: 103 | # https://arxiv.org/pdf/1710.06081.pdf the l1 normalization follows the momentum iterative method 104 | l1_norm_gradient = 1e-10 + torch.sum((pert_plus_grad.abs() + pert_minus_grad.abs()).reshape(x.shape[0], -1), dim=1).view(-1,1,1,1) 105 | pert_plus_grad = pert_plus_grad / l1_norm_gradient 106 | pert_minus_grad = pert_minus_grad / l1_norm_gradient 107 | 108 | # velocity update as in pytorch https://pytorch.org/docs/stable/optim.html 109 | velocity_plus = self.momentum * velocity_plus + pert_plus_grad 110 | velocity_minus = self.momentum * velocity_minus + pert_minus_grad 111 | 112 | #descent step on both halves; clamping keeps the split non-negative 113 | pert_plus = torch.clamp(pert_plus - self.stepsize * mask * velocity_plus, min=0) 114 | pert_minus = torch.clamp(pert_minus - self.stepsize * mask * velocity_minus, min=0) 115 | 116 | pert = pert_plus - pert_minus 117 | pert = project_perturbation(pert, self.eps, self.norm) 118 | pert = torch.clamp(x + pert, 0, 1) - x #box constraint 119 | pert_plus = torch.clamp(pert, min=0) 120 | pert_minus = torch.clamp(-pert, min=0) 121 | 122 | if self.save_trajectory: 123 | trajectory[i + 1] = x + pert 124 | 125 | final_loss = best_losses 126 | p_data = (x + best_perts).detach() 127 | return p_data, final_loss, trajectory 128 | --------------------------------------------------------------------------------
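The generators in noise.py below are plugged into the PGD-style attacks via their init_noise_generator argument; a minimal sketch of standalone use (shapes and parameters are illustrative):

import torch
from utils.adversarial_attacks import UniformNoiseGenerator, NormalNoiseGenerator

x = torch.zeros(2, 3, 32, 32)
uniform = UniformNoiseGenerator(min=-8/255, max=8/255)
gauss = NormalNoiseGenerator(sigma=0.01, mu=0)
delta_uniform = uniform(x)   # noise tensors matching x's shape
delta_gauss = gauss(x)

/utils/adversarial_attacks/noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributions as distributions 5 | import math 6 | import torch.optim as optim 7 | 8 | ################################################### 9 | class AdversarialNoiseGenerator(torch.nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | return 13 | 14 | def forward(self, x): 15 | #generate noise matching the size of x 16 | raise NotImplementedError() 17 | 18 | class UniformNoiseGenerator(AdversarialNoiseGenerator): 19 | def __init__(self, min=0.0, max=1.0): 20 | super().__init__() 21 | self.min = min 22 | self.max = max 23 | 24 | def forward(self, x): 25 | return (self.max - self.min) * torch.rand_like(x) + self.min 26 | 27 | class NormalNoiseGenerator(AdversarialNoiseGenerator): 28 | def __init__(self, sigma=1.0, mu=0): 29 | super().__init__() 30 | self.sigma = sigma 31 | self.mu = mu 32 | 33 | def forward(self, x): 34 | return self.sigma * torch.randn_like(x) + self.mu 35 | 36 | class CALNoiseGenerator(AdversarialNoiseGenerator): 37 | def __init__(self, rho=1, lambda_scheme='normal'): 38 | super().__init__() 39 | self.rho = rho 40 | self.lambda_scheme = lambda_scheme 41 | 42 | def forward(self, x): 43 | if self.lambda_scheme == 'normal': 44 | lambda_targets = x.new_zeros(x.shape[0]) 45 | reject_idcs = lambda_targets < 1 46 | #rejection sample from truncated normal 47 | while sum(reject_idcs) > 0: 48 | lambda_targets[reject_idcs] = math.sqrt(self.rho) * torch.randn(sum(reject_idcs), device=x.device).abs() + 1e-8 49 | reject_idcs = lambda_targets > 1 50 | elif self.lambda_scheme == 'uniform': 51 | lambda_targets = torch.rand(x.shape[0], device=x.device) 52 | 53 | 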
target_dists_sqr = -torch.log( lambda_targets) * self.rho 54 | dirs = torch.randn_like(x) 55 | dirs_lengths = torch.norm( dirs.view( x.shape[0], -1) , dim=1) 56 | dirs_normalized = dirs / dirs_lengths.view(x.shape[0], 1, 1, 1) 57 | perts = target_dists_sqr.sqrt().view(x.shape[0], 1, 1, 1) * dirs_normalized 58 | return perts 59 | 60 | -------------------------------------------------------------------------------- /utils/adversarial_attacks/pgd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributions as distributions 5 | import math 6 | import torch.optim as optim 7 | from .restartattack import RestartAttack 8 | from .utils import project_perturbation, normalize_perturbation, create_early_stopping_mask, initialize_perturbation 9 | 10 | class PGD(RestartAttack): 11 | def __init__(self, eps, iterations, stepsize, num_classes, momentum=0.9, norm='inf', loss='CrossEntropy', 12 | normalize_grad=True, early_stopping=0, restarts=0, init_noise_generator=None, model=None, 13 | save_trajectory=False): 14 | super().__init__(loss, restarts, num_classes, model=model, save_trajectory=save_trajectory) 15 | #loss either pass 'CrossEntropy' or 'LogitsDiff' or custom loss function 16 | self.eps = eps 17 | self.iterations = iterations 18 | self.stepsize = stepsize 19 | self.momentum = momentum 20 | self.norm = norm 21 | self.loss = loss 22 | self.normalize_grad = normalize_grad 23 | self.early_stopping = early_stopping 24 | self.init_noise_generator = init_noise_generator 25 | 26 | def _get_trajectory_depth(self): 27 | return self.iterations + 1 28 | 29 | def get_config_dict(self): 30 | dict = {} 31 | dict['type'] = 'PGD' 32 | dict['eps'] = self.eps 33 | dict['iterations'] = self.iterations 34 | dict['stepsize'] = self.stepsize 35 | dict['momentum'] = self.momentum 36 | dict['norm'] = self.norm 37 | if isinstance(self.loss, str): 38 | dict['loss'] = self.loss 39 | else: 40 | dict['loss'] = 'custom' 41 | dict['normalize_grad'] = self.normalize_grad 42 | dict['early_stopping'] = self.early_stopping 43 | dict['restarts'] = self.restarts 44 | return dict 45 | 46 | 47 | def perturb_inner(self, x, y, targeted=False, x_init=None): 48 | l_f = self._get_loss_f(x, y, targeted, 'none') 49 | 50 | velocity = torch.zeros_like(x) 51 | 52 | #initialize perturbation 53 | pert = initialize_perturbation(x, self.eps, self.norm, x_init, self.init_noise_generator) 54 | 55 | #trajectory container 56 | if self.save_trajectory: 57 | trajectory = torch.zeros((self.iterations + 1,) + x.shape, device=x.device) 58 | trajectory[0, :] = x 59 | else: 60 | trajectory = None 61 | 62 | for i in range(self.iterations): 63 | pert.requires_grad_(True) 64 | with torch.enable_grad(): 65 | p_data = x + pert 66 | out = self.model(p_data) 67 | 68 | if self.early_stopping > 0: 69 | finished, mask = create_early_stopping_mask(out, y, self.early_stopping, targeted) 70 | if finished: 71 | break 72 | else: 73 | mask = 1. 
74 | 75 | loss_expanded = l_f(p_data, out) 76 | loss = loss_expanded.mean() 77 | grad = torch.autograd.grad(loss, pert)[0] 78 | 79 | with torch.no_grad(): 80 | # pgd on given loss 81 | if self.normalize_grad: 82 | # https://arxiv.org/pdf/1710.06081.pdf the l1 normalization follows the momentum iterative method 83 | l1_norm_gradient = 1e-10 + torch.sum(grad.abs().view(x.shape[0], -1), dim=1).view(-1,1,1,1) 84 | velocity = self.momentum * velocity + grad / l1_norm_gradient 85 | norm_velocity = normalize_perturbation(velocity, self.norm) 86 | else: 87 | # velocity update as in pytorch https://pytorch.org/docs/stable/optim.html 88 | velocity = self.momentum * velocity + grad 89 | norm_velocity = velocity 90 | 91 | pert = pert - self.stepsize * mask * norm_velocity 92 | pert = project_perturbation(pert, self.eps, self.norm) 93 | pert = torch.clamp(x + pert, 0, 1) - x #box constraint 94 | 95 | if self.save_trajectory: 96 | trajectory[i + 1] = x + pert 97 | 98 | p_data = x + pert 99 | return p_data, l_f(p_data, self.model(p_data)), trajectory 100 | -------------------------------------------------------------------------------- /utils/adversarial_attacks/restartattack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from .adversarialattack import AdversarialAttack 6 | 7 | class RestartAttack(AdversarialAttack): 8 | #Base class for attacks that start from different initial values 9 | #Make sure that they MINIMIZE the given loss function 10 | def __init__(self, loss, restarts, num_classes, model=None, save_trajectory=False): 11 | super().__init__(loss, num_classes, model=model, save_trajectory=save_trajectory) 12 | self.restarts = restarts 13 | 14 | def perturb_inner(self, x, y, targeted=False, x_init=None): 15 | #force child class implementation 16 | raise NotImplementedError() 17 | 18 | def perturb(self, x, y, targeted=False, x_init=None): 19 | #base class method that handles various restarts 20 | self._check_model() 21 | 22 | is_train = self.model.training 23 | self.model.eval() 24 | 25 | restarts_data = x.new_empty((1 + self.restarts,) + x.shape) 26 | restarts_objs = x.new_empty((1 + self.restarts, x.shape[0])) 27 | 28 | if self.save_trajectory: 29 | self.last_trajectory = None 30 | trajectories_shape = (1 + self.restarts, self._get_trajectory_depth(),) + x.shape 31 | restart_trajectories = x.new_empty(trajectories_shape, device=torch.device('cpu')) 32 | 33 | for k in range(1 + self.restarts): 34 | k_data, k_obj, k_trajectory = self.perturb_inner(x, y, targeted=targeted, x_init=x_init) 35 | restarts_data[k, :] = k_data 36 | restarts_objs[k, :] = k_obj 37 | if self.save_trajectory: 38 | restart_trajectories[k, :] = k_trajectory.cpu() 39 | 40 | bs = x.shape[0] 41 | best_idx = torch.argmin(restarts_objs, 0) 42 | best_data = restarts_data[best_idx, range(bs), :] 43 | 44 | if self.save_trajectory: 45 | self.last_trajectory = restart_trajectories[best_idx, :, range(bs), :] 46 | 47 | #reset density_model status 48 | if is_train: 49 | self.model.train() 50 | else: 51 | self.model.eval() 52 | 53 | return best_data 54 | -------------------------------------------------------------------------------- /utils/adversarial_attacks/tanh_attack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributions as distributions 5 | import math 6 | import 
torch.optim as optim 7 | from .restartattack import RestartAttack 8 | from .utils import project_perturbation, normalize_perturbation, create_early_stopping_mask 9 | 10 | #################################################################################################### 11 | def atanh(x): 12 | return 0.5*torch.log((1+x)/(1-x)) 13 | 14 | def img_to_tanh(x, boxmin=0, boxmax=1): 15 | boxmul = 0.5 * (boxmax - boxmin) 16 | boxplus = 0.5 * (boxmin + boxmax) 17 | return atanh((x - boxplus) / boxmul) 18 | 19 | def tanh_to_img(w, boxmin=0, boxmax=1): 20 | boxmul = 0.5 * (boxmax - boxmin) 21 | boxplus = 0.5 * (boxmin + boxmax) 22 | #transform tanh space image to normal image space with bound [low_b, up_b] 23 | return torch.tanh(w) * boxmul + boxplus 24 | 25 | def CW_loss(pert_img, orig_img, out, y_oh, targeted, reg_weight,confidence=0, reduce=True): 26 | bs = pert_img.shape[0] 27 | loss_1 = torch.sum(((pert_img - orig_img) ** 2).view(bs, -1), dim=1) 28 | 29 | # logits of gt class 30 | out_real = torch.sum((out * y_oh), 1) 31 | # logits of other highest scoring 32 | out_other = torch.max(out * (1. - y_oh) - y_oh * 100000000., 1)[0] 33 | 34 | if targeted: 35 | # maximize target class and minimize second highest 36 | loss_2 = torch.clamp_min(out_other - out_real, -confidence) 37 | else: 38 | # minimize target and max second highest 39 | loss_2 = torch.clamp_min(out_real - out_other, -confidence) 40 | 41 | loss = reg_weight * loss_1 + loss_2 42 | if reduce: 43 | loss = torch.mean(loss) 44 | return loss 45 | 46 | #https://arxiv.org/pdf/1608.04644.pdf 47 | #Tensorflow: https://github.com/carlini/nn_robust_attacks/blob/master/l2_attack.py 48 | class TanhIterativeAttack(RestartAttack): 49 | def __init__(self, iterations, stepsize, num_classes, loss='CW', restarts=0, init_noise_generator=None, confidence=0.0, early_stopping=0, model=None, reg_weight=1): 50 | super().__init__(loss, restarts, num_classes, model=model, save_trajectory=False) 51 | self.iterations = iterations 52 | self.stepsize = stepsize 53 | self.num_classes = num_classes 54 | self.loss = loss 55 | self.init_noise_generator = init_noise_generator 56 | self.confidence = confidence 57 | self.reg_weight = reg_weight 58 | self.early_stopping=early_stopping 59 | 60 | def get_config_dict(self): 61 | dict = {} 62 | dict['type'] = 'TanhIterative' 63 | dict['iterations'] = self.iterations 64 | dict['stepsize'] = self.stepsize 65 | dict['reg weight'] = self.reg_weight 66 | dict['confidence'] = self.confidence 67 | #config_dict['init_sigma'] = self.init_sigma 68 | return dict 69 | 70 | def perturb_inner(self, x, y, targeted=False, x_init=None): 71 | bs = y.shape[0] 72 | 73 | if self.loss == 'CW': 74 | y_oh = torch.nn.functional.one_hot(y, self.num_classes) 75 | y_oh = y_oh.float() 76 | l_f = lambda data, data_out: CW_loss(data, x, data_out, y_oh, targeted, self.reg_weight, confidence=self.confidence) 77 | else: 78 | l_f = lambda data, data_out: self.loss(data, data_out, x, y) 79 | 80 | 81 | data = x.clone().detach() 82 | 83 | 84 | data_tanh = img_to_tanh(data) 85 | 86 | if self.init_noise_generator is None: 87 | pert_tanh = torch.zeros_like(data) 88 | else: 89 | raise NotImplementedError() 90 | 91 | pert_tanh.requires_grad_(True) 92 | 93 | optimizer = optim.Adam([pert_tanh], self.stepsize) 94 | 95 | for i in range(self.iterations): 96 | optimizer.zero_grad() 97 | 98 | with torch.enable_grad(): 99 | #distance to original image 100 | pert_img = tanh_to_img(data_tanh + pert_tanh) 101 | out = self.model(pert_img) 102 | loss = l_f(pert_img, out) 103 | 104 | if self.early_stopping > 0: 105 | conf, pred = torch.max(torch.nn.functional.softmax(out, dim=1), 1) 106 | conf_mask = conf > self.early_stopping 107 | if targeted: 108 | correct_mask = torch.eq(y, pred) 109 | else: 110 | correct_mask = (~torch.eq(y, pred)) 111 | mask = (conf_mask & correct_mask).detach() 112 | saved_perts = pert_tanh[mask,:].detach() 113 | 114 | if sum(mask.float()) == x.shape[0]: 115 | break 116 | 117 | loss.backward() 118 | optimizer.step() 119 | 120 | if self.early_stopping > 0: 121 | with torch.no_grad(): 122 | pert_tanh[mask] = saved_perts 123 | 124 | 125 | pert_img = tanh_to_img(data_tanh + pert_tanh) 126 | loss = l_f(pert_img, self.model(pert_img)) 127 | return pert_img, loss, None 128 | 129 | --------------------------------------------------------------------------------
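A minimal sketch of driving the Carlini-Wagner-style attack above (the toy model and the budget are illustrative; inputs are kept in the interior of the [0, 1] box because atanh diverges at exact 0/1 pixels):

import torch
import torch.nn as nn
from utils.adversarial_attacks import TanhIterativeAttack

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).eval()
x = 0.25 + 0.5 * torch.rand(4, 3, 32, 32)   # stay away from the box corners
y = torch.randint(0, 10, (4,))

attack = TanhIterativeAttack(iterations=50, stepsize=0.01, num_classes=10, model=model, reg_weight=1.0)
x_adv = attack(x, y, targeted=False)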
/utils/average_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch.nn import Module 4 | from copy import deepcopy 5 | from torch.nn.modules.batchnorm import _BatchNorm 6 | 7 | class AveragedModel(Module): 8 | """ 9 | Modified AveragedModel from torch swa_utils that supports EMA and SWA updates and batch norm averaging 10 | """ 11 | def __init__(self, model, avg_type='ema', ema_decay=0.990, avg_batchnorm=False, device=None): 12 | super(AveragedModel, self).__init__() 13 | self.module = deepcopy(model) 14 | if device is not None: 15 | self.module = self.module.to(device) 16 | self.register_buffer('n_averaged', 17 | torch.tensor(0, dtype=torch.long, device=device)) 18 | 19 | assert avg_type in ['ema', 'swa'] 20 | self.avg_type = avg_type 21 | self.ema_decay = ema_decay 22 | self.avg_batchnorm = avg_batchnorm 23 | 24 | def forward(self, *args, **kwargs): 25 | return self.module(*args, **kwargs) 26 | 27 | def update_parameters(self, model): 28 | n = self.n_averaged.item() 29 | if self.avg_type == 'ema': 30 | decay = min( 31 | self.ema_decay, 32 | (1 + n) / (10 + n) 33 | ) 34 | avg_fn = lambda averaged_model_parameter, model_parameter: \ 35 | decay * averaged_model_parameter + (1.0 - decay) * model_parameter 36 | elif self.avg_type == 'swa': 37 | avg_fn = lambda averaged_model_parameter, model_parameter: \ 38 | averaged_model_parameter + (model_parameter - averaged_model_parameter) / (n + 1) 39 | else: 40 | raise NotImplementedError() 41 | 42 | for p_swa, p_model in zip(self.parameters(), model.parameters()): 43 | device = p_swa.device 44 | p_model_ = p_model.detach().to(device) 45 | if n == 0: 46 | p_swa.detach().copy_(p_model_) 47 | else: 48 | p_swa.detach().copy_(avg_fn(p_swa.detach(), p_model_,)) 49 | 50 | if self.avg_batchnorm: 51 | for avg_mod, model_mod in zip(self.module.modules(), model.modules()): 52 | if issubclass(type(model_mod), _BatchNorm): 53 | device = avg_mod.running_mean.device 54 | mean_model_ = model_mod.running_mean.detach().to(device) 55 | var_model_ = model_mod.running_var.detach().to(device) 56 | if n == 0: 57 | avg_mod.running_mean.detach().copy_(mean_model_) 58 | avg_mod.running_var.detach().copy_(var_model_) 59 | else: 60 | avg_mod.running_mean.detach().copy_( 61 | avg_fn(avg_mod.running_mean.detach(), mean_model_)) 62 | avg_mod.running_var.detach().copy_( 63 | avg_fn(avg_mod.running_var.detach(), var_model_)) 64 | 65 | self.n_averaged += 1 --------------------------------------------------------------------------------
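A minimal sketch of how AveragedModel is driven inside a training loop (the toy model, data and optimizer are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.average_model import AveragedModel

model = nn.Linear(10, 2)
ema_model = AveragedModel(model, avg_type='ema', ema_decay=0.999)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(100):
    x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))
    loss = F.cross_entropy(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema_model.update_parameters(model)   # EMA weights track the live model

# evaluate with the averaged weights
ema_model.eval()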
/utils/bit_downstream_schedule.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import utils.train_types.schedulers as schedulers 3 | import utils.train_types.optimizers as optimizers 4 | 5 | def get_bit_scheduler_optim_configs(dataset_size, dataloader_length, lr=0.001, decay=0, nesterov=False): 6 | # Bit hyperrule 7 | # https://github.com/google-research/big_transfer/blob/0bb237d6e34ab770b56502c90424d262e565a7f3/bit_hyperrule.py#L30 8 | if dataset_size < 20_000: 9 | decay_steps = [200, 300, 400] 10 | total_updates = 500 11 | warmup_steps = 100 12 | elif dataset_size < 250_000: 13 | decay_steps = [3000, 6000, 9000] 14 | total_updates = 10000 15 | warmup_steps = 500 16 | # elif dataset_size < 500_000: 17 | # decay_steps = [4500, 9000, 13500] 18 | # total_updates = 15000 19 | # warmup_steps = 500 20 | else: 21 | decay_steps = [6000, 12_000, 18_000] 22 | total_updates = 20000 23 | warmup_steps = 500 24 | 25 | decay_rate = 0.1 26 | epochs = int(np.ceil(total_updates / dataloader_length)) 27 | 28 | # convert from batch to epoch 29 | decay_epochs = [decay_step / dataloader_length for decay_step in decay_steps] 30 | warmup_length_epochs = warmup_steps / dataloader_length 31 | 32 | scheduler_config = schedulers.create_piecewise_consant_scheduler_config(epochs, decay_epochs, decay_rate, 33 | warmup_length=warmup_length_epochs) 34 | 35 | #########OPTIMIZER 36 | optimizer_config = optimizers.create_optimizer_config('SGD', lr, momentum=0.9, 37 | weight_decay=decay, nesterov=nesterov) 38 | 39 | return scheduler_config, optimizer_config, epochs 40 | 41 | def get_ssL_bit_scheduler_optim_configs(dataset_size, dataloader_length, lr=0.001, decay=0, nesterov=False): 42 | # Bit hyperrule 43 | # https://github.com/google-research/big_transfer/blob/0bb237d6e34ab770b56502c90424d262e565a7f3/bit_hyperrule.py#L30 44 | if dataset_size < 10_000: 45 | decay_steps = [200, 300, 400] 46 | total_updates = 500 47 | warmup_steps = 100 48 | elif dataset_size < 20_000: 49 | decay_steps = [200, 300, 400] 50 | total_updates = 500 51 | warmup_steps = 100 52 | elif dataset_size < 250_000: 53 | decay_steps = [3000, 6000, 9000] 54 | total_updates = 10000 55 | warmup_steps = 500 56 | # elif dataset_size < 500_000: 57 | # decay_steps = [4500, 9000, 13500] 58 | # total_updates = 15000 59 | # warmup_steps = 500 60 | else: 61 | decay_steps = [6000, 12_000, 18_000] 62 | total_updates = 20000 63 | warmup_steps = 500 64 | 65 | decay_rate = 0.1 66 | epochs = int(np.ceil(total_updates / dataloader_length)) 67 | 68 | # convert from batch to epoch 69 | decay_epochs = [decay_step / dataloader_length for decay_step in decay_steps] 70 | warmup_length_epochs = warmup_steps / dataloader_length 71 | 72 | scheduler_config = schedulers.create_piecewise_consant_scheduler_config(epochs, decay_epochs, decay_rate, 73 | warmup_length=warmup_length_epochs) 74 | 75 | #########OPTIMIZER 76 | optimizer_config = optimizers.create_optimizer_config('SGD', lr, momentum=0.9, 77 | weight_decay=decay, nesterov=nesterov) 78 | 79 | return scheduler_config, optimizer_config, epochs 80 | -------------------------------------------------------------------------------- /utils/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar import get_CIFAR10, get_CIFAR100, get_CIFAR10_1 2 | from .semisupervised_dataset import get_CIFAR10_ti_500k 3 | from .svhn import get_SVHN 4 | from .celebA import celebA_feature_set, celebA_ImageNetOD 5 | from .imagenet import get_ImageNet 6 | from .imagenet_subsets import get_restrictedImageNet, get_restrictedImageNetOD, get_ImageNet100, get_ImageNetCloseToCifar 7 | from .fgvc_aircraft import get_fgvc_aircraft 8 | from .food_101N import get_food_101N 9 | 
from .food_101 import get_food_101 10 | from .flowers import get_flowers 11 | from .pets import get_pets 12 | from .stanford_cars import get_stanford_cars 13 | from .tinyImages import get_80MTinyImages, TinyImagesDataset, TINY_LENGTH 14 | from .tiny_image_net import get_TinyImageNet 15 | from .lsun import get_LSUN_CR, get_LSUN_scenes 16 | from .openimages import get_openImages 17 | from .cifar_corrupted import get_CIFAR10_C, get_CIFAR100_C 18 | from .cinic_10 import get_CINIC10 19 | from .noise_datasets import get_noise_dataset -------------------------------------------------------------------------------- /utils/datasets/augmentations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/M4xim4l/InNOutRobustness/d81d1d26e5ebc9193009e3d92bd67b5e01d6cfd6/utils/datasets/augmentations/__init__.py -------------------------------------------------------------------------------- /utils/datasets/augmentations/cutout.py: -------------------------------------------------------------------------------- 1 | #https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | class Cutout(object): 8 | """Randomly mask out one or more patches from an image. 9 | Args: 10 | n_holes (int): Number of patches to cut out of each image. 11 | length (int): The length (in pixels) of each square patch. 12 | """ 13 | def __init__(self, n_holes, length, fill_color=torch.tensor([0,0,0])): 14 | self.n_holes = n_holes 15 | self.length = length 16 | self.fill_color = fill_color 17 | 18 | def __call__(self, img): 19 | """ 20 | Args: 21 | img (Tensor): Tensor image of out_size (C, H, W). 22 | Returns: 23 | Tensor: Image with n_holes of dimension length x length cut out of it. 24 | """ 25 | h = img.shape[1] 26 | w = img.shape[2] 27 | 28 | mask = np.ones((h, w), np.float32) 29 | 30 | for n in range(self.n_holes): 31 | y = np.random.randint(h) 32 | x = np.random.randint(w) 33 | 34 | y1 = np.clip(y - self.length // 2, 0, h) 35 | y2 = np.clip(y + self.length // 2, 0, h) 36 | x1 = np.clip(x - self.length // 2, 0, w) 37 | x2 = np.clip(x + self.length // 2, 0, w) 38 | 39 | mask[y1: y2, x1: x2] = 0. 40 | 41 | mask = torch.from_numpy(mask) 42 | mask = mask.expand_as(img) 43 | 44 | img = img * mask + (1 - mask) * self.fill_color[:, None, None] 45 | 46 | return img 47 | -------------------------------------------------------------------------------- /utils/datasets/augmentations/svhn_augmentation.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import torch 3 | from utils.datasets.augmentations.autoaugment import SVHNPolicy, CIFAR10Policy 4 | from utils.datasets.augmentations.cutout import Cutout 5 | from .utils import INTERPOLATION_STRING_TO_TYPE 6 | 7 | SVHN_mean = (0.4377, 0.4438, 0.4728) 8 | 9 | DEFAULT_SVHN_PARAMETERS = { 10 | 'interpolation': 'bilinear', 11 | 'mean': SVHN_mean, 12 | 'crop_pct': 0.875 13 | } 14 | 15 | def get_SVHN_augmentation(augm_type='none', in_size=32, out_size=32, augm_parameters=None, config_dict=None): 16 | if augm_parameters is None: 17 | augm_parameters = DEFAULT_SVHN_PARAMETERS 18 | 19 | mean_int = tuple(int(255. * v) for v in augm_parameters['mean']) 20 | mean_tensor = torch.FloatTensor(augm_parameters['mean']) 21 | padding_size = int((1. 
- augm_parameters['crop_pct']) * in_size)
22 |     interpolation_mode = INTERPOLATION_STRING_TO_TYPE[augm_parameters['interpolation']]
23 | 
24 |     if augm_type == 'none':
25 |         transform_list = []
26 |     elif augm_type == 'default' or augm_type == 'default_cutout':
27 |         transform_list = [
28 |             transforms.RandomCrop(in_size, padding=padding_size, fill=mean_int),
29 |         ]
30 |     elif augm_type == 'autoaugment' or augm_type == 'autoaugment_cutout':
31 |         transform_list = [
32 |             transforms.RandomCrop(in_size, padding=padding_size, fill=mean_int),
33 |             SVHNPolicy(fillcolor=mean_int),
34 |         ]
35 |     elif augm_type == 'cifar_autoaugment' or augm_type == 'cifar_autoaugment_cutout':
36 |         transform_list = [
37 |             transforms.RandomCrop(in_size, padding=padding_size, fill=mean_int),
38 |             CIFAR10Policy(fillcolor=mean_int),
39 |         ]
40 |     else:
41 |         raise ValueError(f'Unknown SVHN augmentation type: {augm_type}')
42 | 
43 |     cutout_window = 16
44 |     cutout_color = mean_tensor
45 |     cutout_size = 0
46 | 
47 |     if out_size != in_size:
48 |         if 'cutout' in augm_type:
49 |             transform_list.append(transforms.Resize(out_size, interpolation=interpolation_mode))
50 |             transform_list.append(transforms.ToTensor())
51 |             cutout_size = int(out_size / in_size * cutout_window)
52 |             print(f'Relative Cutout window {cutout_window / in_size} - Absolute Cutout window: {cutout_size}')
53 |             transform_list.append(Cutout(n_holes=1, length=cutout_size, fill_color=cutout_color))
54 |         else:
55 |             transform_list.append(transforms.Resize(out_size, interpolation=interpolation_mode))
56 |             transform_list.append(transforms.ToTensor())
57 |     elif 'cutout' in augm_type:
58 |         cutout_size = cutout_window
59 |         print(f'Relative Cutout window {cutout_size / in_size} - Absolute Cutout window: {cutout_size}')
60 |         transform_list.append(transforms.ToTensor())
61 |         transform_list.append(Cutout(n_holes=1, length=cutout_size, fill_color=cutout_color))
62 |     else:
63 |         transform_list.append(transforms.ToTensor())
64 | 
65 |     transform = transforms.Compose(transform_list)
66 | 
67 |     if config_dict is not None:
68 |         config_dict['type'] = augm_type  # was `type`, which stored the Python builtin instead of the augmentation name
69 |         config_dict['Input size'] = in_size
70 |         config_dict['Output size'] = out_size
71 |         if 'cutout' in augm_type:
72 |             config_dict['Cutout size'] = cutout_size
73 |         for key, value in augm_parameters.items():
74 |             config_dict[key] = value
75 | 
76 |     return transform
--------------------------------------------------------------------------------
/utils/datasets/augmentations/utils.py:
--------------------------------------------------------------------------------
1 | import PIL.Image as Img
2 | from torchvision.transforms.functional import InterpolationMode
3 | 
4 | INTERPOLATION_STRING_TO_TYPE = {
5 |     'nearest': InterpolationMode.NEAREST,
6 |     'bilinear': InterpolationMode.BILINEAR,
7 |     'bicubic': InterpolationMode.BICUBIC,
8 |     'lanczos': InterpolationMode.LANCZOS
9 | }
--------------------------------------------------------------------------------
/utils/datasets/cinic_10.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributions
3 | from torch.utils.data import DataLoader
4 | from torchvision.datasets import ImageFolder
5 | import os
6 | 
7 | from .paths import get_CINIC10_path
8 | from utils.datasets.augmentations.cifar_augmentation import get_cifar10_augmentation
9 | 
10 | DEFAULT_TRAIN_BATCHSIZE = 128
11 | DEFAULT_TEST_BATCHSIZE = 128
12 | 
13 | 
14 | def get_CINIC10(split='train', batch_size=None, shuffle=False,
15 |                 augm_type='none', cutout_window=16, num_workers=2, size=32, config_dict=None):
16 |     if batch_size == 
None: 17 | batch_size = DEFAULT_TEST_BATCHSIZE 18 | 19 | augm_config = {} 20 | transform = get_cifar10_augmentation(type=augm_type, cutout_window=cutout_window, out_size=size, config_dict=augm_config) 21 | 22 | path = get_CINIC10_path() 23 | if split == 'train': 24 | cinic_subdir = 'train' 25 | elif split == 'val': 26 | cinic_subdir = 'valid' 27 | elif split == 'test': 28 | cinic_subdir = 'test' 29 | else: 30 | raise ValueError() 31 | 32 | cinic_directory = os.path.join(path, cinic_subdir) 33 | cinic_dataset = ImageFolder(cinic_directory,transform=transform) 34 | 35 | loader = torch.utils.data.DataLoader(cinic_dataset, batch_size=batch_size, 36 | shuffle=shuffle, num_workers=num_workers) 37 | 38 | if config_dict is not None: 39 | config_dict['Dataset'] = 'CINIC-10' 40 | config_dict['Batch out_size'] = batch_size 41 | config_dict['Augmentation'] = augm_config 42 | 43 | return loader 44 | -------------------------------------------------------------------------------- /utils/datasets/combo_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | class ComboDataset(Dataset): 5 | def __init__(self, datasets): 6 | num_datasets = len(datasets) 7 | self.dataset_lengths = torch.zeros(num_datasets, dtype=torch.long) 8 | 9 | for i, ds in enumerate(datasets): 10 | self.dataset_lengths[i] = len(ds) 11 | 12 | self.cum_lengths = torch.cumsum(self.dataset_lengths, dim=0) 13 | self.length = torch.sum(self.dataset_lengths) 14 | self.datasets = datasets 15 | 16 | def __len__(self): 17 | return self.length 18 | 19 | def __getitem__(self, index): 20 | ds_idx = torch.nonzero(self.cum_lengths > index, as_tuple=False)[0] 21 | if ds_idx > 0: 22 | item_idx = index - self.cum_lengths[ds_idx - 1] 23 | else: 24 | item_idx = index 25 | 26 | return self.datasets[ds_idx][item_idx] 27 | -------------------------------------------------------------------------------- /utils/datasets/flowers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions 3 | from torch.utils.data import DataLoader, Dataset 4 | 5 | import numpy as np 6 | import os 7 | from torchvision.datasets.folder import default_loader 8 | from scipy.io import loadmat 9 | 10 | from .paths import get_flowers_path 11 | from utils.datasets.augmentations.imagenet_augmentation import get_imageNet_augmentation 12 | 13 | FLOWERS_LABELS = [ 14 | "pink primrose", "hard-leaved pocket orchid", "canterbury bells", 15 | "sweet pea", "english marigold", "tiger lily", "moon orchid", 16 | "bird of paradise", "monkshood", "globe thistle", "snapdragon", 17 | "colt's foot", "king protea", "spear thistle", "yellow iris", 18 | "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", 19 | "giant white arum lily", "fire lily", "pincushion flower", "fritillary", 20 | "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", 21 | "stemless gentian", "artichoke", "sweet william", "carnation", 22 | "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", 23 | "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", 24 | "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", 25 | "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", 26 | "common dandelion", "petunia", "wild pansy", "primula", "sunflower", 27 | "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", 28 | "pink-yellow dahlia?", 
"cautleya spicata", "japanese anemone", 29 | "black-eyed susan", "silverbush", "californian poppy", "osteospermum", 30 | "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", 31 | "azalea", "water lily", "rose", "thorn apple", "morning glory", 32 | "passion flower", "lotus", "toad lily", "anthurium", "frangipani", 33 | "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", 34 | "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", 35 | "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", 36 | "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", 37 | "blackberry lily" 38 | ] 39 | 40 | def get_flowers_labels(): 41 | return FLOWERS_LABELS 42 | 43 | def get_flowers(split='train', batch_size=128, shuffle=True, augm_type='none', 44 | size=224, num_workers=8, config_dict=None): 45 | 46 | augm_config = {} 47 | transform = get_imageNet_augmentation(augm_type, out_size=size, config_dict=augm_config) 48 | path = get_flowers_path() 49 | dataset = Flowers(path, split, transform=transform) 50 | loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 51 | shuffle=shuffle, num_workers=num_workers) 52 | 53 | if config_dict is not None: 54 | config_dict['Dataset'] = 'Flowers' 55 | config_dict['Batch out_size'] = batch_size 56 | config_dict['Augmentation'] = augm_config 57 | 58 | return loader 59 | 60 | 61 | class Flowers(Dataset): 62 | def __init__(self, root, split, transform=None): 63 | self.root = root 64 | self.labels = loadmat(os.path.join(root, 'imagelabels.mat'))['labels'][0].astype(np.long) 65 | self.transform = transform 66 | setids = loadmat(os.path.join(root, 'setid.mat')) 67 | 68 | if split == 'train': 69 | self.indices = setids['trnid'][0] 70 | elif split =='val': 71 | self.indices = setids['valid'][0] 72 | elif split =='train_val': 73 | trn_idcs = setids['trnid'][0] 74 | val_idcs = setids['valid'][0] 75 | self.indices = np.concatenate([trn_idcs, val_idcs]) 76 | elif split == 'test': 77 | self.indices = setids['tstid'][0] 78 | else: 79 | raise ValueError() 80 | 81 | self.indices = self.indices 82 | self.loader = default_loader 83 | self.length = len(self.indices) 84 | 85 | def __getitem__(self, index): 86 | img_idx = self.indices[index] 87 | #matlab starts with 1, so decrease both index and target idx by 1 88 | target = self.labels[img_idx - 1] - 1 89 | path = os.path.join(self.root, 'jpg', f'image_{img_idx:05d}.jpg') 90 | sample = self.loader(path) 91 | if self.transform is not None: 92 | sample = self.transform(sample) 93 | 94 | return sample, target 95 | 96 | def __len__(self): 97 | return self.length 98 | -------------------------------------------------------------------------------- /utils/datasets/food_101.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions 3 | from torch.utils.data import DataLoader, Dataset 4 | import os 5 | 6 | from .paths import get_food_101_path 7 | from .augmentations.imagenet_augmentation import get_imageNet_augmentation 8 | from torchvision.datasets.folder import default_loader 9 | 10 | 11 | def get_food_101_labels(): 12 | path = get_food_101_path() 13 | class_list = [] 14 | classes_file = os.path.join(path, 'meta', 'meta', 'classes.txt') 15 | with open(classes_file) as classestxt: 16 | for line_number, line in enumerate(classestxt): 17 | class_list.append(line.rstrip()) 18 | return class_list 19 | 20 | 21 | def get_food_101(split='train', batch_size=128, shuffle=True, augm_type='none', 22 | 
size=224, num_workers=8, config_dict=None): 23 | 24 | augm_config = {} 25 | transform = get_imageNet_augmentation(augm_type, out_size=size, config_dict=augm_config) 26 | 27 | path = get_food_101_path() 28 | dataset = Food101(path, split, transform=transform) 29 | loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 30 | shuffle=shuffle, num_workers=num_workers) 31 | 32 | if config_dict is not None: 33 | config_dict['Dataset'] = 'Food-101' 34 | config_dict['Batch out_size'] = batch_size 35 | config_dict['Augmentation'] = augm_config 36 | 37 | return loader 38 | 39 | #Adapted to Kaggle Food101 download 40 | #https://www.kaggle.com/kmader/food41 41 | class Food101(Dataset): 42 | def __init__(self, root, split, transform=None): 43 | class_labels = get_food_101_labels() 44 | self.root = root 45 | label_to_target = {label : target for target, label in enumerate(class_labels)} 46 | self.transform = transform 47 | 48 | if split == 'train': 49 | meta_txt = os.path.join(self.root, 'meta', 'meta', 'train.txt') 50 | elif split == 'val': 51 | meta_txt = os.path.join(self.root, 'meta', 'meta', 'test.txt') 52 | else: 53 | raise ValueError() 54 | 55 | self.img_label_list = [] 56 | with open(meta_txt)as fileID: 57 | for row in fileID: 58 | img = row.rstrip() 59 | target = label_to_target[img.split('/')[0]] 60 | self.img_label_list.append((img,target)) 61 | 62 | print(f'Food 101 {split} - {len(self.img_label_list)} Images') 63 | 64 | self.loader = default_loader 65 | self.length = len(self.img_label_list) 66 | 67 | def __getitem__(self, index): 68 | sub_path, target = self.img_label_list[index] 69 | path = os.path.join(self.root, 'images', sub_path + '.jpg') 70 | sample = self.loader(path) 71 | if self.transform is not None: 72 | sample = self.transform(sample) 73 | 74 | return sample, target 75 | 76 | def __len__(self): 77 | return self.length 78 | -------------------------------------------------------------------------------- /utils/datasets/food_101N.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions 3 | from torch.utils.data import DataLoader 4 | import os 5 | 6 | from .paths import get_food_101N_path 7 | from utils.datasets.augmentations.imagenet_augmentation import get_imageNet_augmentation 8 | from torchvision.datasets.folder import default_loader 9 | import csv 10 | 11 | 12 | def get_food_101N_labels(): 13 | path = get_food_101N_path() 14 | class_list = [] 15 | classes_file = os.path.join(path, 'meta', 'classes.txt') 16 | with open(classes_file) as classestxt: 17 | for line_number, line in enumerate(classestxt): 18 | if line_number > 0: #skip the first line as it's no class 19 | class_list.append(line.rstrip()) 20 | return class_list 21 | 22 | 23 | def get_food_101N(split='train', batch_size=128, shuffle=True, augm_type='none', 24 | size=224, num_workers=8): 25 | transform = get_imageNet_augmentation(augm_type, out_size=size) 26 | path = get_food_101N_path() 27 | dataset = Food101N(path, split, transform=transform) 28 | loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 29 | shuffle=shuffle, num_workers=num_workers) 30 | return loader 31 | 32 | 33 | class Food101N(torch.utils.data.Dataset): 34 | def __init__(self, root, split, verified_only=True, transform=None): 35 | class_labels = get_food_101N_labels() 36 | self.root = root 37 | label_to_target = {label:target for target, label in enumerate(class_labels)} 38 | self.transform = transform 39 | 40 | if split == 'train': 41 | meta_tsv 
= os.path.join(self.root, 'meta', 'verified_train.tsv')
42 |         elif split == 'val':
43 |             meta_tsv = os.path.join(self.root, 'meta', 'verified_val.tsv')
44 |         else:
45 |             raise ValueError(f'Split {split} not supported')
46 | 
47 |         self.img_label_list = []
48 |         with open(meta_tsv) as tsvfile:
49 |             reader = csv.DictReader(tsvfile, dialect='excel-tab')
50 |             for row in reader:
51 |                 if not verified_only or int(row['verification_label']) == 1:  # was a truthiness test on the raw string, which let '0' (unverified) through and dropped everything when verified_only=False
52 |                     img = row['class_name/key']
53 |                     target = label_to_target[img.split('/')[0]]
54 |                     self.img_label_list.append((img, target))
55 | 
56 |         print(f'Food 101N {split} - Verified only {verified_only} - {len(self.img_label_list)} Images')
57 | 
58 |         self.loader = default_loader
59 |         self.length = len(self.img_label_list)
60 | 
61 |     def __getitem__(self, index):
62 |         sub_path, target = self.img_label_list[index]
63 |         path = os.path.join(self.root, 'images', sub_path)
64 |         sample = self.loader(path)
65 |         if self.transform is not None:
66 |             sample = self.transform(sample)
67 | 
68 |         return sample, target
69 | 
70 |     def __len__(self):
71 |         return self.length
72 | 
--------------------------------------------------------------------------------
/utils/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributions
3 | from torchvision import datasets
4 | from torch.utils.data import DataLoader
5 | from utils.datasets.augmentations.imagenet_augmentation import get_imageNet_augmentation
6 | from .paths import get_imagenet_path
7 | 
8 | DEFAULT_TRAIN_BATCHSIZE = 128
9 | DEFAULT_TEST_BATCHSIZE = 128
10 | 
11 | 
12 | def get_imagenet_labels():
13 |     path = get_imagenet_path()
14 |     dataset = datasets.ImageNet(path, split='val', transform=None)  # was transform='none', which is not a callable
15 |     classes_extended = dataset.classes
16 |     labels = []
17 |     for a in classes_extended:
18 |         labels.append(a[0])
19 |     return labels
20 | 
21 | def get_imagenet_label_wid_pairs():
22 |     path = get_imagenet_path()
23 |     dataset = datasets.ImageNet(path, split='val', transform=None)
24 |     classes_extended = dataset.classes
25 |     wids = dataset.wnids
26 | 
27 |     label_wid_pairs = []
28 |     for a, b in zip(classes_extended, wids):
29 |         label_wid_pairs.append((a[0], b))
30 |     return label_wid_pairs
31 | 
32 | def get_ImageNet(train=True, batch_size=None, shuffle=None, augm_type='test', num_workers=8, size=224, config_dict=None):
33 |     if batch_size == None:
34 |         if train:
35 |             batch_size = DEFAULT_TRAIN_BATCHSIZE
36 |         else:
37 |             batch_size = DEFAULT_TEST_BATCHSIZE
38 | 
39 |     augm_config = {}
40 |     transform = get_imageNet_augmentation(type=augm_type, out_size=size, config_dict=augm_config)
41 |     if not train and augm_type != 'none':
42 |         print('Warning: ImageNet test set with data augmentation')
43 | 
44 |     if shuffle is None:
45 |         shuffle = train
46 | 
47 |     path = get_imagenet_path()
48 | 
49 |     if train:
50 |         dataset = datasets.ImageNet(path, split='train', transform=transform)
51 |     else:
52 |         dataset = datasets.ImageNet(path, split='val', transform=transform)
53 | 
54 |     loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
55 |                                          shuffle=shuffle, num_workers=num_workers)
56 | 
57 |     if config_dict is not None:
58 |         config_dict['Dataset'] = 'ImageNet'
59 |         config_dict['Batch size'] = batch_size
60 |         config_dict['Augmentation'] = augm_config
61 | 
62 |     return loader
63 | 
64 | 
65 | 
66 | 
67 | 
68 | 
69 | 
--------------------------------------------------------------------------------
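# === Annotation (not part of the original sources) =============================
# Every loader in utils/datasets follows the pattern of get_ImageNet above: build
# an augmentation transform, wrap the underlying torchvision dataset, return a
# DataLoader, and optionally record the settings in a caller-supplied config_dict.
# A minimal usage sketch, assuming the repo root is on PYTHONPATH and the data
# roots configured in utils/datasets/paths.py exist on this machine:
#
#     import utils.datasets as dl
#     cfg = {}
#     val_loader = dl.get_ImageNet(train=False, batch_size=32,
#                                  augm_type='none', config_dict=cfg)
#     images, targets = next(iter(val_loader))  # float tensors in [0, 1]
#     print(cfg)                                # {'Dataset': 'ImageNet', ...}
# ===============================================================================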
/utils/datasets/imagenet_natural_adversarials.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributions
3 | from torchvision import datasets
4 | from torch.utils.data import DataLoader
5 | 
6 | from .paths import get_imagenet_o_path
7 | from utils.datasets.augmentations.imagenet_augmentation import get_imageNet_augmentation
8 | 
9 | DEFAULT_TRAIN_BATCHSIZE = 128
10 | DEFAULT_TEST_BATCHSIZE = 128
11 | 
12 | def get_imagenet_o(batch_size=None, shuffle=False, augm_type='none',
13 |                    num_workers=8, size=224):
14 |     if batch_size == None:
15 |         batch_size = DEFAULT_TEST_BATCHSIZE
16 | 
17 |     transform = get_imageNet_augmentation(type=augm_type, out_size=size)
18 | 
19 |     path = get_imagenet_o_path()
20 | 
21 |     dataset = datasets.ImageFolder(path, transform=transform)
22 |     loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
23 |                                          shuffle=shuffle, num_workers=num_workers)
24 |     return loader
25 | 
--------------------------------------------------------------------------------
/utils/datasets/lsun.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributions
3 | from torchvision import datasets, transforms
4 | 
5 | from .paths import get_base_data_dir, get_LSUN_scenes_path
6 | from torch.utils.data import DataLoader, SubsetRandomSampler
7 | 
8 | DEFAULT_TRAIN_BATCHSIZE = 128
9 | DEFAULT_TEST_BATCHSIZE = 128
10 | 
11 | from utils.datasets.augmentations.imagenet_augmentation import get_imageNet_augmentation
12 | from utils.datasets.augmentations.cifar_augmentation import get_cifar10_augmentation
13 | 
14 | # LSUN classroom
15 | def get_LSUN_CR(train=False, batch_size=None, size=32):
16 |     if train:
17 |         raise ValueError('Training set for LSUN classroom is not available')  # the exception was constructed but never raised
18 |     if batch_size is None:
19 |         batch_size = DEFAULT_TEST_BATCHSIZE
20 | 
21 |     transform = transforms.Compose([
22 |         transforms.Resize(size=(size, size)),
23 |         transforms.ToTensor()
24 |     ])
25 |     path = get_base_data_dir()
26 |     data_dir = path + '/LSUN'
27 |     dataset = datasets.LSUN(data_dir, classes=['classroom_val'], transform=transform)
28 |     loader = DataLoader(dataset, batch_size=batch_size,
29 |                         shuffle=False, num_workers=4)
30 |     return loader
31 | 
32 | def get_LSUN_scenes(split='train', samples_per_class=None, batch_size=None, shuffle=None, augm_type='none',
33 |                     augm_class='imagenet', num_workers=8, size=224, config_dict=None):
34 |     if batch_size is None:
35 |         batch_size = DEFAULT_TEST_BATCHSIZE
36 | 
37 |     augm_config = {}
38 | 
39 |     if augm_class == 'imagenet':
40 |         transform = get_imageNet_augmentation(type=augm_type, out_size=size, config_dict=augm_config)
41 |     elif augm_class == 'cifar':
42 |         raise NotImplementedError()
43 |         # transform = get_cifar10_augmentation(type=augm_type, out_size=size, in_size=224, config_dict=augm_config)  (unreachable until implemented)
44 |     else:
45 |         raise NotImplementedError()
46 |     path = get_LSUN_scenes_path()
47 |     dataset = datasets.LSUN(path, classes=split, transform=transform)
48 | 
49 |     if samples_per_class is None:
50 |         loader = DataLoader(dataset, batch_size=batch_size,
51 |                             shuffle=shuffle, num_workers=num_workers)
52 | 
53 |     else:
54 |         num_classes = len(dataset.dbs)
55 |         idcs = torch.zeros(num_classes, samples_per_class, dtype=torch.long)
56 |         start_idx = 0
57 |         for i in range(num_classes):
58 |             idcs[i, :] = torch.arange(start_idx, start_idx + samples_per_class)
59 |             start_idx = dataset.indices[i]
60 |         idcs = idcs.view(-1).numpy()
61 |         sampler = SubsetRandomSampler(idcs)
62 |         loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, 
63 |                             num_workers=num_workers)
64 | 
65 |     return loader
66 | 
67 | 
68 | def get_LSUN_scenes_labels():
69 |     return ['bedroom', 'bridge', 'church_outdoor', 'classroom',
70 |             'conference_room', 'dining_room', 'kitchen',
71 |             'living_room', 'restaurant', 'tower']
72 | 
--------------------------------------------------------------------------------
/utils/datasets/mnist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributions
3 | from torchvision import datasets, transforms
4 | from . import preproc as pre  # missing import: EMNIST below uses pre.Transpose (replaces an unused VisionDataset import)
5 | from .paths import get_base_data_dir
6 | DEFAULT_TRAIN_BATCHSIZE = 128
7 | DEFAULT_TEST_BATCHSIZE = 512
8 | path = get_base_data_dir()  # `path` was used below but never defined; assuming the repo's common data root
9 | def MNIST(train=True, batch_size=None, augm_flag=True, shuffle=None):
10 |     if batch_size==None:
11 |         if train:
12 |             batch_size=DEFAULT_TRAIN_BATCHSIZE
13 |         else:
14 |             batch_size=DEFAULT_TEST_BATCHSIZE
15 | 
16 |     if shuffle is None:
17 |         shuffle = train
18 | 
19 |     transform_base = [transforms.ToTensor()]
20 |     transform_train = transforms.Compose([
21 |         transforms.RandomCrop(28, padding=2),
22 |     ] + transform_base)
23 |     transform_test = transforms.Compose(transform_base)
24 | 
25 |     transform_train = transforms.RandomChoice([transform_train, transform_test])
26 | 
27 |     transform = transform_train if (augm_flag and train) else transform_test
28 | 
29 |     dataset = datasets.MNIST(path, train=train, transform=transform)
30 |     loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
31 |                                          shuffle=shuffle, num_workers=4)
32 |     return loader
33 | 
34 | def EMNIST(train=False, batch_size=None, augm_flag=False, shuffle=None):
35 |     if batch_size==None:
36 |         if train:
37 |             batch_size=DEFAULT_TRAIN_BATCHSIZE
38 |         else:
39 |             batch_size=DEFAULT_TEST_BATCHSIZE
40 | 
41 |     if shuffle is None:
42 |         shuffle = train
43 | 
44 |     transform_base = [transforms.ToTensor(), pre.Transpose()] #EMNIST is rotated 90 degrees from MNIST
45 |     transform_train = transforms.Compose([
46 |         transforms.RandomCrop(28, padding=4),
47 |     ] + transform_base)
48 |     transform_test = transforms.Compose(transform_base)
49 | 
50 |     transform_train = transforms.RandomChoice([transform_train, transform_test])
51 | 
52 |     transform = transform_train if (augm_flag and train) else transform_test
53 | 
54 |     dataset = datasets.EMNIST(path, split='letters',
55 |                               train=train, transform=transform)
56 |     loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
57 |                                          shuffle=shuffle, num_workers=1)
58 |     return loader
59 | 
60 | 
61 | def FMNIST(train=False, batch_size=None, augm_flag=False, shuffle=None):
62 |     if batch_size==None:
63 |         if train:
64 |             batch_size=DEFAULT_TRAIN_BATCHSIZE
65 |         else:
66 |             batch_size=DEFAULT_TEST_BATCHSIZE
67 | 
68 |     if shuffle is None:
69 |         shuffle = train
70 | 
71 |     transform_base = [transforms.ToTensor()]
72 |     transform_train = transforms.Compose([
73 |         transforms.RandomCrop(28, padding=2),
74 |     ] + transform_base)
75 |     transform_test = transforms.Compose(transform_base)
76 | 
77 |     transform_train = transforms.RandomChoice([transform_train, transform_test])
78 | 
79 |     transform = transform_train if (augm_flag and train) else transform_test
80 | 
81 |     dataset = datasets.FashionMNIST(path, train=train, transform=transform)
82 |     loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
83 |                                          shuffle=shuffle, num_workers=1)
84 |     return loader
85 | 
--------------------------------------------------------------------------------
/utils/datasets/noise_datasets.py:
--------------------------------------------------------------------------------
1 | 
import torch 2 | import torch.distributions 3 | from torchvision import transforms 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from utils.datasets.augmentations.cifar_augmentation import get_cifar10_augmentation 7 | 8 | DEFAULT_TRAIN_BATCHSIZE = 128 9 | DEFAULT_TEST_BATCHSIZE = 128 10 | 11 | 12 | def get_noise_dataset(length, type='normal', batch_size=128, augm_type='none', cutout_window=32, 13 | num_workers=8, size=32, config_dict=None): 14 | augm_config = {} 15 | transform = get_cifar10_augmentation(type=augm_type, cutout_window=cutout_window, out_size=size, 16 | in_size=size, config_dict=augm_config) 17 | 18 | dataset = NoiseDataset(length, type, size, transform) 19 | loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 20 | shuffle=False, num_workers=num_workers) 21 | 22 | if config_dict is not None: 23 | config_dict['Dataset'] = 'NoiseData' 24 | config_dict['Length'] = length 25 | config_dict['Noise Type'] = type 26 | config_dict['Batch size'] = batch_size 27 | config_dict['Augmentation'] = augm_config 28 | 29 | return loader 30 | 31 | class NoiseDataset(Dataset): 32 | def __init__(self, length, type, size, transform): 33 | assert type in ['uniform', 'normal'] 34 | 35 | self.type = type 36 | self.size = size 37 | self.length = length 38 | 39 | np.random.seed(123) 40 | if self.type == 'uniform': 41 | data_np = np.random.rand(length, 3, self.size, self.size).astype(np.float32) 42 | self.data = torch.from_numpy(data_np) 43 | elif self.type == 'normal': 44 | data_np = 0.5 + np.random.randn(length, 3, self.size, self.size).astype(np.float32) 45 | self.data = torch.clamp(torch.from_numpy(data_np), min=0, max=1) 46 | else: 47 | raise NotImplementedError() 48 | 49 | 50 | transform = transforms.Compose([ 51 | transforms.ToPILImage(), 52 | transform]) 53 | 54 | self.transform = transform 55 | 56 | def __getitem__(self, index): 57 | target = 0 58 | img = self.data[index].squeeze(dim=0) 59 | img = self.transform(img) 60 | 61 | return img, target 62 | 63 | def __len__(self): 64 | return self.length 65 | -------------------------------------------------------------------------------- /utils/datasets/openimages.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions 3 | from torch.utils.data import DataLoader 4 | from torchvision.datasets import ImageFolder 5 | import numpy as np 6 | 7 | from .paths import get_openimages_path 8 | from utils.datasets.augmentations.imagenet_augmentation import get_imageNet_augmentation 9 | import os 10 | 11 | DEFAULT_TRAIN_BATCHSIZE = 128 12 | DEFAULT_TEST_BATCHSIZE = 128 13 | 14 | class OpenImages(ImageFolder): 15 | def __init__(self, root, split, transform=None, target_transform=None, exclude_dataset=None): 16 | if split == 'train': 17 | path = os.path.join(root, 'train') 18 | elif split == 'val': 19 | path = os.path.join(root, 'val') 20 | elif split == 'test': 21 | raise NotImplementedError() 22 | path = os.path.join(root, 'test') 23 | else: 24 | raise ValueError() 25 | 26 | super().__init__(path, transform=transform, target_transform=target_transform) 27 | exclude_idcs = [] 28 | 29 | if exclude_dataset is not None and split == 'train': 30 | if exclude_dataset == 'imageNet100': 31 | duplicate_file = 'openImages_imageNet100_duplicates.txt' 32 | elif exclude_dataset == 'flowers': 33 | duplicate_file = 'utils/openImages_flowers_idxs.txt' 34 | elif exclude_dataset == 'pets': 35 | duplicate_file = 'utils/openImages_pets_idxs.txt' 36 | elif exclude_dataset 
== 'cars': 37 | duplicate_file = 'utils/openImages_cars_idxs.txt' 38 | elif exclude_dataset == 'food-101': 39 | duplicate_file = 'utils/openImages_food-101_idxs.txt' 40 | elif exclude_dataset == 'cifar10': 41 | print('Warning; CIFAR10 duplicates not checked') 42 | duplicate_file = None 43 | else: 44 | raise ValueError(f'Exclusion dataset {exclude_dataset} not supported') 45 | 46 | if duplicate_file is not None: 47 | with open(duplicate_file, 'r') as idxs: 48 | for idx in idxs: 49 | exclude_idcs.append(int(idx)) 50 | 51 | self.exclude_idcs = set(exclude_idcs) 52 | print(f'OpenImages {split} - Length: {len(self)} - Exclude images: {len(self.exclude_idcs)}') 53 | 54 | def __getitem__(self, index): 55 | while index in self.exclude_idcs: 56 | index = np.random.randint(len(self)) 57 | 58 | return super().__getitem__(index) 59 | 60 | def get_openImages(split='train', batch_size=128, shuffle=None, augm_type='none', num_workers=8, size=224, 61 | exclude_dataset=None, config_dict=None): 62 | 63 | augm_config = {} 64 | transform = get_imageNet_augmentation(type=augm_type, out_size=size, config_dict=augm_config) 65 | 66 | if shuffle is None: 67 | shuffle = True if split == 'train' else False 68 | 69 | path = get_openimages_path() 70 | 71 | dataset = OpenImages(path, split, transform=transform, exclude_dataset=exclude_dataset) 72 | loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 73 | shuffle=shuffle, num_workers=num_workers) 74 | 75 | if config_dict is not None: 76 | config_dict['Dataset'] = 'OpenImages' 77 | config_dict['Exclude Dataset'] = exclude_dataset 78 | config_dict['Length'] = len(dataset) 79 | config_dict['Batch out_size'] = batch_size 80 | config_dict['Augmentation'] = augm_config 81 | 82 | return loader 83 | -------------------------------------------------------------------------------- /utils/datasets/paths.py: -------------------------------------------------------------------------------- 1 | import socket 2 | print() 3 | import os 4 | 5 | def get_base_data_dir(): 6 | path = '/scratch/datasets/' 7 | return path 8 | 9 | def get_svhn_path(): 10 | return os.path.join(get_base_data_dir(), 'SVHN') 11 | 12 | def get_CIFAR10_path(): 13 | return os.path.join(get_base_data_dir(), 'CIFAR10') 14 | 15 | def get_CIFAR100_path(): 16 | return os.path.join(get_base_data_dir(), 'CIFAR100') 17 | 18 | def get_CIFAR10_C_path(): 19 | return os.path.join(get_base_data_dir(), 'CIFAR-10-C') 20 | 21 | def get_CIFAR100_C_path(): 22 | return os.path.join(get_base_data_dir(), 'CIFAR-100-C') 23 | 24 | def get_CINIC10_path(): 25 | return os.path.join(get_base_data_dir(), 'cinic_10') 26 | 27 | def get_celebA_path(): 28 | return get_base_data_dir() 29 | 30 | def get_stanford_cars_path(): 31 | return os.path.join(get_base_data_dir(), 'stanford_cars') 32 | 33 | def get_flowers_path(): 34 | return os.path.join(get_base_data_dir(), 'flowers') 35 | 36 | def get_pets_path(): 37 | return os.path.join(get_base_data_dir(), 'pets') 38 | 39 | def get_food_101N_path(): 40 | return os.path.join(get_base_data_dir(), 'Food-101N', 'Food-101N_release') 41 | 42 | def get_food_101_path(): 43 | return os.path.join(get_base_data_dir(), 'Food-101') 44 | 45 | def get_fgvc_aircraft_path(): 46 | return os.path.join(get_base_data_dir(), 'FGVC/fgvc-aircraft-2013b') 47 | 48 | def get_cub_path(): 49 | return os.path.join(get_base_data_dir(), 'CUB') 50 | 51 | def get_LSUN_scenes_path(): 52 | return os.path.join(get_base_data_dir(), 'LSUN_scenes') 53 | 54 | 55 | def get_tiny_images_files(shuffled=True): 56 | if shuffled == 
True:
57 |         raise NotImplementedError()
58 |     else:
59 |         return os.path.join(get_base_data_dir(), '80M Tiny Images/tiny_images.bin')
60 | 
61 | def get_tiny_images_lmdb():
62 |     raise NotImplementedError()
63 | 
64 | def get_imagenet_path():
65 |     path = os.path.join(get_base_data_dir(), 'imagenet/')
66 |     return path
67 | 
68 | def get_imagenet_o_path():
69 |     return os.path.join(get_base_data_dir(), 'imagenet-o/')
70 | 
71 | def get_openimages_path():
72 |     path = os.path.join(get_base_data_dir(), 'openimages/')
73 |     return path
74 | 
75 | def get_tiny_imagenet_path():
76 |     return os.path.join(get_base_data_dir(), 'TinyImageNet/tiny-imagenet-200/')
--------------------------------------------------------------------------------
/utils/datasets/pets.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import torch
4 | import torch.distributions
5 | from torch.utils.data import DataLoader, Dataset
6 | from torchvision.datasets.folder import default_loader
7 | 
8 | from utils.datasets.augmentations.imagenet_augmentation import get_imageNet_augmentation
9 | from .paths import get_pets_path
10 | 
11 | class_labels = ['Abyssinian', 'american bulldog', 'american pit bull terrier', 'basset hound', 'beagle', 'Bengal',
12 |                 'Birman', 'Bombay', 'boxer', 'British Shorthair', 'chihuahua', 'Egyptian Mau', 'english cocker spaniel',
13 |                 'english setter', 'german shorthaired', 'great pyrenees', 'havanese', 'japanese chin', 'keeshond',
14 |                 'leonberger', 'Maine Coon', 'miniature pinscher', 'newfoundland', 'Persian', 'pomeranian', 'pug',
15 |                 'Ragdoll', 'Russian Blue', 'saint bernard', 'samoyed', 'scottish terrier', 'shiba inu', 'Siamese',
16 |                 'Sphynx', 'staffordshire bull terrier', 'wheaten terrier', 'yorkshire terrier']
17 | 
18 | def get_pets_labels():
19 |     return class_labels
20 | 
21 | def get_pets(split='train', batch_size=128, shuffle=True, augm_type='none',
22 |              size=224, num_workers=8, config_dict=None):
23 | 
24 |     augm_config = {}
25 |     transform = get_imageNet_augmentation(augm_type, out_size=size, config_dict=augm_config)
26 |     path = get_pets_path()
27 |     dataset = Pets(path, split, transform=transform)
28 |     loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
29 |                                          shuffle=shuffle, num_workers=num_workers)
30 | 
31 |     if config_dict is not None:
32 |         config_dict['Dataset'] = 'Pets'  # was 'Flowers', a copy-paste leftover
33 |         config_dict['Batch size'] = batch_size
34 |         config_dict['Augmentation'] = augm_config
35 | 
36 |     return loader
37 | 
38 | 
39 | class Pets(Dataset):
40 |     def __init__(self, root, split, transform=None):
41 |         if split == 'train':
42 |             annotations_file = os.path.join(root, 'annotations/trainval.txt')
43 |         elif split == 'test':
44 |             annotations_file = os.path.join(root, 'annotations/test.txt')
45 |         else:
46 |             raise ValueError(f'Split {split} not supported')
47 | 
48 |         self.img_root = os.path.join(root, 'images')
49 | 
50 |         self.labels = []
51 |         self.imgs = []
52 | 
53 |         with open(annotations_file, 'r') as fileID:
54 |             for line in fileID:
55 |                 line_parts = line.rstrip().split(' ')
56 |                 img = line_parts[0]
57 |                 label = int(line_parts[1]) - 1 #labels are in range 1:37, transform to 0:36
58 | 
59 |                 self.imgs.append(img)
60 |                 self.labels.append(label)
61 | 
62 |         self.transform = transform
63 |         self.loader = default_loader
64 |         self.length = len(self.imgs)
65 | 
66 |     def __getitem__(self, index):
67 |         img = self.imgs[index]
68 |         target = self.labels[index]
69 |         path = os.path.join(self.img_root, f'{img}.jpg')
70 |         sample = self.loader(path)
71 |         if self.transform is not 
None: 72 | sample = self.transform(sample) 73 | 74 | return sample, target 75 | 76 | def __len__(self): 77 | return self.length 78 | -------------------------------------------------------------------------------- /utils/datasets/preproc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | 4 | import numpy as np 5 | import scipy.ndimage.filters as filters 6 | 7 | 8 | class Transpose(object): 9 | def __init__(self): 10 | pass 11 | def __call__(self, data): 12 | return data.transpose(-1,-2) 13 | 14 | 15 | class Gray(object): 16 | def __init__(self): 17 | pass 18 | def __call__(self, data): 19 | return data.mean(-3, keepdim=True) 20 | 21 | 22 | class PermutationNoise(object): 23 | def __init__(self): 24 | pass 25 | def __call__(self, data): 26 | shape = data.shape 27 | new_data = 0*data 28 | idx = [torch.tensor(np.random.permutation(np.prod(shape[-2:])))] 29 | for i, x in enumerate(data): 30 | new_data[i] = (x.view(np.prod(shape[-2:]))[idx]).view(shape[-2:]) 31 | return new_data 32 | 33 | 34 | class GaussianFilter(object): 35 | def __init__(self): 36 | pass 37 | def __call__(self, data): 38 | sigma = 1.+1.5*torch.rand(1).item() 39 | return torch.tensor(filters.gaussian_filter(data, sigma, mode='reflect')) 40 | 41 | 42 | class ContrastRescaling(object): 43 | def __init__(self): 44 | pass 45 | def __call__(self, data): 46 | gamma = 5+ 25.*torch.rand(1).item() 47 | return torch.sigmoid(gamma*(data-.5)) 48 | 49 | 50 | class AdversarialNoise(object): 51 | def __init__(self, model, device, epsilon=0.3): 52 | self.model = model 53 | self.pretransform = dl.noise_transform 54 | self.device = device 55 | self.epsilon = epsilon 56 | 57 | def __call__(self, data): 58 | perturbed = tt.generate_adv_noise(self.model, self.epsilon, 59 | device=self.device, batch_size=1, 60 | norm=20, num_of_it=40, 61 | alpha=0.01, seed_images=data.unsqueeze(0)) 62 | return perturbed.squeeze(0) -------------------------------------------------------------------------------- /utils/datasets/svhn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions 3 | from torchvision import datasets 4 | from torch.utils.data import Dataset 5 | 6 | from .combo_dataset import ComboDataset 7 | from .paths import get_svhn_path 8 | from utils.datasets.augmentations.svhn_augmentation import get_SVHN_augmentation 9 | 10 | DEFAULT_TRAIN_BATCHSIZE = 128 11 | DEFAULT_TEST_BATCHSIZE = 128 12 | 13 | 14 | def get_SVHN_labels(): 15 | class_labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 16 | return class_labels 17 | 18 | class SVHNTrainExtraCombo(ComboDataset): 19 | def __init__(self, transform=None): 20 | path = get_svhn_path() 21 | train = datasets.SVHN(path, split='train', transform=transform, download=True) 22 | extra = datasets.SVHN(path, split='extra', transform=transform, download=True) 23 | 24 | super().__init__([train, extra]) 25 | print(f'SVHN Train + Extra - Train: {len(train)} - Extra {len(extra)} - Total {self.length}') 26 | 27 | def get_SVHN(split='train', shuffle = None, batch_size=None, augm_type='none', size=32, num_workers=4, config_dict=None): 28 | if batch_size==None: 29 | if split in ['train', 'extra']: 30 | batch_size=DEFAULT_TRAIN_BATCHSIZE 31 | else: 32 | batch_size=DEFAULT_TEST_BATCHSIZE 33 | 34 | if shuffle is None: 35 | if split in ['train', 'extra']: 36 | shuffle = True 37 | else: 38 | shuffle = False 39 | 40 | augm_config = {} 41 | 
transform = get_SVHN_augmentation(augm_type, out_size=size, config_dict=augm_config)
42 | 
43 |     path = get_svhn_path()
44 |     if split == 'svhn_train_extra':
45 |         dataset = SVHNTrainExtraCombo(transform)
46 |     else:
47 |         dataset = datasets.SVHN(path, split=split, transform=transform, download=True)
48 | 
49 |     loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
50 |                                          shuffle=shuffle, num_workers=num_workers)
51 | 
52 |     if config_dict is not None:
53 |         config_dict['Dataset'] = 'SVHN'
54 |         config_dict['SVHN Split'] = split
55 |         config_dict['Batch size'] = batch_size
56 |         config_dict['Augmentation'] = augm_config
57 | 
58 |     return loader
59 | 
--------------------------------------------------------------------------------
/utils/datasets/tinyImages.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributions
3 | from torch.utils.data import Dataset
4 | from torchvision import transforms
5 | import numpy as np
6 | import os
7 | 
8 | from utils.datasets.augmentations.cifar_augmentation import get_cifar10_augmentation
9 | from .paths import get_tiny_images_files
10 | 
11 | DEFAULT_TRAIN_BATCHSIZE = 128
12 | DEFAULT_TEST_BATCHSIZE = 128
13 | 
14 | def get_80MTinyImages(batch_size=100, augm_type='default', shuffle=True, cutout_window=16, num_workers=1,
15 |                       size=32, exclude_cifar=False, exclude_cifar10_1=False, config_dict=None):
16 |     #dataset is the dataset that will be excluded, eg CIFAR10
17 |     if num_workers > 1:
18 |         pass
19 |         #raise ValueError('Bug in the current multithreaded tinyimages implementation')
20 | 
21 |     augm_config = {}
22 |     transform = get_cifar10_augmentation(augm_type, cutout_window=cutout_window, out_size=size, config_dict=augm_config)
23 | 
24 |     dataset_out = TinyImagesDataset(transform,
25 |                                     exclude_cifar=exclude_cifar, exclude_cifar10_1=exclude_cifar10_1)
26 | 
27 |     loader = torch.utils.data.DataLoader(dataset_out, batch_size=batch_size,
28 |                                          shuffle=shuffle, num_workers=num_workers)
29 | 
30 |     if config_dict is not None:  # (a duplicated, nested copy of this check was removed)
31 | 
32 |         config_dict['Dataset'] = '80M Tiny Images'
33 |         config_dict['Shuffle'] = shuffle
34 |         config_dict['Batch size'] = batch_size
35 |         config_dict['Exclude CIFAR'] = exclude_cifar
36 |         config_dict['Exclude CIFAR10.1'] = exclude_cifar10_1
37 |         config_dict['Augmentation'] = augm_config
38 | 
39 |     return loader
40 | 
41 | def _preload_tiny_images(idcs, file_id):
42 |     imgs = np.zeros((len(idcs), 32, 32, 3), dtype='uint8')
43 |     for lin_idx, idx in enumerate(idcs):
44 |         imgs[lin_idx, :] = _load_tiny_image(idx, file_id)
45 |     return imgs
46 | 
47 | def _load_tiny_image(idx, file_id):
48 |     try:
49 |         file_id.seek(idx * 3072)
50 |         data = file_id.read(3072)
51 |     finally:
52 |         pass
53 | 
54 |     data_np = np.frombuffer(data, dtype='uint8').reshape(32, 32, 3, order="F")  # np.fromstring is deprecated for binary data
55 |     return data_np
56 | 
57 | 
58 | def _load_cifar_exclusion_idcs(exclude_cifar, exclude_cifar10_1):
59 |     cifar_idxs = []
60 |     main_idcs_dir = 'TinyImagesExclusionIdcs/'
61 | 
62 |     our_exclusion_files = [
63 |         '80mn_cifar10_test_idxs.txt',
64 |         '80mn_cifar100_test_idxs.txt',
65 |         '80mn_cifar10_train_idxs.txt',
66 |         '80mn_cifar100_train_idxs.txt',
67 |     ]
68 |     if exclude_cifar:
69 |         with open(os.path.join(main_idcs_dir, '80mn_cifar_idxs.txt'), 'r') as idxs:
70 |             for idx in idxs:
71 |                 # indices in file take the 80mn database to start at 1, hence "- 1"
72 |                 cifar_idxs.append(int(idx) - 1)
73 | 
74 |         for file in our_exclusion_files:
75 |             with open(os.path.join(main_idcs_dir, file), 'r') as idxs:
76 |                 for idx 
in idxs: 77 | cifar_idxs.append(int(idx)) 78 | 79 | if exclude_cifar10_1: 80 | with open(os.path.join(main_idcs_dir, '80mn_cifar101_idxs.txt'), 'r') as idxs: 81 | for idx in idxs: 82 | cifar_idxs.append(int(idx)) 83 | 84 | cifar_idxs = torch.unique(torch.LongTensor(cifar_idxs)) 85 | return cifar_idxs 86 | 87 | TINY_LENGTH = 79302017 88 | 89 | # Code from https://github.com/hendrycks/outlier-exposure 90 | class TinyImagesDataset(Dataset): 91 | def __init__(self, transform_base, exclude_cifar=False, exclude_cifar10_1=False): 92 | self.data_location = get_tiny_images_files(False) 93 | self.memap = np.memmap(self.data_location, mode='r', dtype='uint8', order='C').reshape(TINY_LENGTH, -1) 94 | 95 | if transform_base is not None: 96 | transform = transforms.Compose([ 97 | transforms.ToPILImage(), 98 | transform_base]) 99 | else: 100 | transform = transforms.Compose([ 101 | transforms.ToPILImage(), 102 | transforms.ToTensor()]) 103 | 104 | self.transform = transform 105 | self.exclude_cifar = exclude_cifar 106 | 107 | exclusion_idcs = _load_cifar_exclusion_idcs(exclude_cifar, exclude_cifar10_1) 108 | 109 | self.included_indices = torch.ones(TINY_LENGTH, dtype=torch.long) 110 | self.included_indices[exclusion_idcs] = 0 111 | self.included_indices = torch.nonzero(self.included_indices, as_tuple=False).squeeze() 112 | self.length = len(self.included_indices) 113 | print(f'80M Tiny Images - Length {self.length} - Excluding {len(exclusion_idcs)} images') 114 | 115 | def __getitem__(self, ii): 116 | index = self.included_indices[ii] 117 | img = self.memap[index].reshape(32, 32, 3, order="F") 118 | 119 | if self.transform is not None: 120 | img = self.transform(img) 121 | 122 | return img, 0 # 0 is the class 123 | 124 | def __len__(self): 125 | return self.length 126 | 127 | 128 | 129 | -------------------------------------------------------------------------------- /utils/datasets/tiny_image_net.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributions 4 | from .paths import get_tiny_imagenet_path 5 | from torch.utils.data import Dataset 6 | from torchvision.datasets import ImageFolder 7 | from utils.datasets.augmentations.cifar_augmentation import get_cifar10_augmentation 8 | 9 | DEFAULT_TRAIN_BATCHSIZE = 128 10 | DEFAULT_TEST_BATCHSIZE = 128 11 | 12 | 13 | def get_TinyImageNetClassNames(cleaned=True): 14 | class_labels = [] 15 | path = get_tiny_imagenet_path() 16 | with open(f'{path}label_clearnames.txt', 'r') as fileID: 17 | for line_idx, line in enumerate(fileID.readlines()): 18 | line_elements = str(line).split("\t") 19 | class_labels.append(line_elements[1].rstrip()) 20 | 21 | 22 | if cleaned: 23 | class_labels_cleaned = [] 24 | for label in class_labels: 25 | class_labels_cleaned.append(label.split(',')[0]) 26 | else: 27 | class_labels_cleaned = class_labels 28 | 29 | return class_labels_cleaned 30 | 31 | 32 | def get_TinyImageNet(split, batch_size=None, shuffle=None, augm_type='none', cutout_window=32, num_workers=8, size=64, config_dict=None): 33 | if batch_size == None: 34 | if split == 'train': 35 | batch_size = DEFAULT_TRAIN_BATCHSIZE 36 | else: 37 | batch_size = DEFAULT_TEST_BATCHSIZE 38 | 39 | augm_config = {} 40 | transform = get_cifar10_augmentation(type=augm_type, cutout_window=cutout_window, out_size=size, 41 | in_size=64, config_dict=augm_config) 42 | 43 | if shuffle is None: 44 | shuffle = True if split == 'train' else False 45 | 46 | 47 | path = get_tiny_imagenet_path() 48 | dataset = 
TinyImageNet(path, split, transform_base=transform) 49 | loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 50 | shuffle=shuffle, num_workers=num_workers) 51 | 52 | if config_dict is not None: 53 | config_dict['Dataset'] = 'TinyImageNet' 54 | config_dict['Batch size'] = batch_size 55 | config_dict['Augmentation'] = augm_config 56 | 57 | return loader 58 | 59 | class TinyImageNet(ImageFolder): 60 | def __init__(self, path, split, transform_base): 61 | assert split in ['train', 'test', 'val'] 62 | 63 | root = os.path.join(path, split) 64 | super().__init__(root, transform=transform_base) 65 | 66 | print(f'TinyImageNet {split} - Length {len(self)}') 67 | -------------------------------------------------------------------------------- /utils/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numbers 5 | import math 6 | 7 | 8 | class GaussianSmoothing(nn.Module): 9 | """ 10 | Apply gaussian smoothing on a 11 | 1d, 2d or 3d tensor. Filtering is performed seperately for each channel 12 | in the input using a depthwise convolution. 13 | Arguments: 14 | channels (int, sequence): Number of channels of the input tensors. Output will 15 | have this number of channels as well. 16 | kernel_size (int, sequence): Size of the gaussian kernel. 17 | sigma (float, sequence): Standard deviation of the gaussian kernel. 18 | dim (int, optional): The number of dimensions of the ref_data. 19 | Default value is 2 (spatial). 20 | """ 21 | def __init__(self, channels, kernel_size, sigma, dim=2): 22 | super(GaussianSmoothing, self).__init__() 23 | if isinstance(kernel_size, numbers.Number): 24 | kernel_size = [kernel_size] * dim 25 | if isinstance(sigma, numbers.Number): 26 | sigma = [sigma] * dim 27 | 28 | # The gaussian kernel is the product of the 29 | # gaussian function of each dimension. 30 | kernel = 1 31 | meshgrids = torch.meshgrid( 32 | [ 33 | torch.arange(size, dtype=torch.float32) 34 | for size in kernel_size 35 | ] 36 | ) 37 | for size, std, mgrid in zip(kernel_size, sigma, meshgrids): 38 | mean = (size - 1) / 2 39 | kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ 40 | torch.exp(-((mgrid - mean) / std) ** 2 / 2) 41 | 42 | # Make sure sum of values in gaussian kernel equals 1. 43 | kernel = kernel / torch.sum(kernel) 44 | 45 | # Reshape to depthwise convolutional weight 46 | kernel = kernel.view(1, 1, *kernel.size()) 47 | kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) 48 | 49 | self.register_buffer('weight', kernel) 50 | self.groups = channels 51 | 52 | if dim == 1: 53 | self.conv = F.conv1d 54 | elif dim == 2: 55 | self.conv = F.conv2d 56 | elif dim == 3: 57 | self.conv = F.conv3d 58 | else: 59 | raise RuntimeError( 60 | 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim) 61 | ) 62 | 63 | def forward(self, input): 64 | """ 65 | Apply gaussian filter to input. 66 | Arguments: 67 | input (torch.Tensor): Input to apply gaussian filter on. 68 | Returns: 69 | filtered (torch.Tensor): Filtered output. 
70 | """ 71 | return self.conv(input, weight=self.weight, groups=self.groups) 72 | 73 | -------------------------------------------------------------------------------- /utils/datasets/various.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions 3 | from torchvision import datasets, transforms 4 | from torchvision.datasets.vision import VisionDataset 5 | from .preproc import PermutationNoise, GaussianFilter, ContrastRescaling 6 | from .paths import get_base_data_dir, get_svhn_path, get_CIFAR100_path, get_CIFAR10_path 7 | import numpy as np 8 | 9 | 10 | DEFAULT_TRAIN_BATCHSIZE = 128 11 | DEFAULT_TEST_BATCHSIZE = 128 12 | 13 | def get_permutationNoise(dataset, train=True, batch_size=None): 14 | if batch_size==None: 15 | if train: 16 | batch_size=DEFAULT_TRAIN_BATCHSIZE 17 | else: 18 | batch_size=DEFAULT_TEST_BATCHSIZE 19 | transform = transforms.Compose([ 20 | transforms.ToTensor(), 21 | PermutationNoise(), 22 | GaussianFilter(), 23 | ContrastRescaling() 24 | ]) 25 | 26 | path = get_base_data_dir() 27 | if dataset=='MNIST': 28 | dataset = datasets.MNIST(path, train=train, transform=transform) 29 | elif dataset=='FMNIST': 30 | dataset = datasets.FashionMNIST(path, train=train, transform=transform) 31 | elif dataset=='SVHN': 32 | dataset = datasets.SVHN(get_svhn_path(), split='train' if train else 'test', transform=transform) 33 | elif dataset=='CIFAR10': 34 | dataset = datasets.CIFAR10(get_CIFAR10_path(), train=train, transform=transform) 35 | elif dataset=='CIFAR100': 36 | dataset = datasets.CIFAR100(get_CIFAR100_path(), train=train, transform=transform) 37 | loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 38 | shuffle=False, num_workers=4) 39 | #cifar_loader = PrecomputeLoader(cifar_loader, batch_size=batch_size, shuffle=True) 40 | return loader 41 | 42 | 43 | def get_UniformNoise(dataset, train=False, batch_size=None): 44 | if batch_size==None: 45 | if train: 46 | batch_size=DEFAULT_TRAIN_BATCHSIZE 47 | else: 48 | batch_size=DEFAULT_TEST_BATCHSIZE 49 | import torch.utils.data as data_utils 50 | 51 | if dataset in ['MNIST', 'FMNIST']: 52 | shape = (1, 28, 28) 53 | elif dataset in ['SVHN', 'CIFAR10', 'CIFAR100']: 54 | shape = (3, 32, 32) 55 | elif dataset in ['imageNet', 'restrictedImageNet']: 56 | shape = (3, 224, 224) 57 | 58 | data = torch.rand((100*batch_size,) + shape) 59 | train = data_utils.TensorDataset(data, torch.zeros(data.shape[0], device=data.device)) 60 | loader = torch.utils.data.DataLoader(train, batch_size=batch_size, 61 | shuffle=False, num_workers=1) 62 | return loader 63 | 64 | class UniormNoiseDataset(torch.utils.data.Dataset): 65 | def __init__(self, dim, length=100000000): 66 | 67 | def load_image(idx): 68 | return torch.rand(dim) 69 | 70 | self.load_image = load_image 71 | 72 | self.length = length 73 | 74 | transform = None 75 | 76 | self.transform = transform 77 | 78 | def __getitem__(self, index): 79 | 80 | img = self.load_image(index) 81 | if self.transform is not None: 82 | img = self.transform(img) 83 | 84 | return img, 0 # 0 is the class 85 | 86 | def __len__(self): 87 | return self.length 88 | 89 | 90 | def ImageNetMinusCifar10(train=False, batch_size=None, augm_flag=False): 91 | if train: 92 | print('Warning: Training set for ImageNet not available') 93 | if batch_size is None: 94 | batch_size = DEFAULT_TEST_BATCHSIZE 95 | 96 | path = get_base_data_dir() 97 | dir_imagenet = path + '/imagenet/val/' 98 | n_test_imagenet = 30000 99 | 100 | transform = 
transforms.ToTensor() 101 | 102 | dataset = torch.utils.data.Subset(datasets.ImageFolder(dir_imagenet, transform=transform), 103 | np.random.permutation(range(n_test_imagenet))[:10000]) 104 | loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 105 | shuffle=False, num_workers=1) 106 | return loader 107 | -------------------------------------------------------------------------------- /utils/find_nearest_neighbours.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pathlib 4 | import os 5 | import numpy as np 6 | import matplotlib as mpl 7 | mpl.use('Agg') 8 | import matplotlib.pyplot as plt 9 | import lpips 10 | 11 | class L2(torch.nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, ref_point, batch): 16 | batch = batch.view(batch.shape[0], -1) 17 | ref_point = ref_point.view(1, -1) 18 | 19 | l2 = torch.sqrt(torch.sum(((batch - ref_point) ** 2), dim=1)) 20 | return l2 21 | 22 | class FeatureDist(torch.nn.Module): 23 | def __init__(self, model): 24 | super().__init__() 25 | self.model = model 26 | 27 | def forward(self, ref_point, batch): 28 | all_data = torch.cat([ref_point, batch]) 29 | out = self.model(all_data) 30 | ref_feature = out[0,:].view(1, -1) 31 | batch_features = out[1:, :].view(batch.shape[0], -1) 32 | l2 = torch.sqrt(torch.sum(((batch_features - ref_feature) ** 2), dim=1)) 33 | return l2 34 | 35 | 36 | class LPIPS(torch.nn.Module): 37 | def __init__(self): 38 | super().__init__() 39 | self.loss_fn = lpips.LPIPS(net='alex') 40 | 41 | def forward(self, ref_point, batch): 42 | ref_batch = ref_point.expand_as(batch) 43 | sim = self.loss_fn(ref_batch, batch).squeeze() 44 | return sim 45 | 46 | def find_nearest_neighbours(dist_function, ref_batch, data_loader, data_set, device, num_neighbours, out_dir, out_prefix, is_similarity=False): 47 | distances = torch.zeros(ref_batch.shape[0], len(data_set)) 48 | ref_batch = ref_batch.to(device) 49 | 50 | data_idx = 0 51 | with torch.no_grad(): 52 | for data, _ in data_loader: 53 | data = data.to(device) 54 | 55 | for ref_i in range(ref_batch.shape[0]): 56 | ref_point = ref_batch[ref_i].unsqueeze(0) 57 | d_data = dist_function(ref_point, data) 58 | 59 | distances[ref_i, data_idx:(data_idx+data.shape[0])] = d_data.detach().cpu() 60 | 61 | data_idx += data.shape[0] 62 | 63 | pathlib.Path(out_dir).mkdir(parents=True, exist_ok=True) 64 | for ref_i in range(ref_batch.shape[0]): 65 | 66 | num_cols = 1 + num_neighbours 67 | scale_factor = 2 68 | fig, ax = plt.subplots(1, num_cols, figsize=(scale_factor * num_cols, 1.3 * scale_factor)) 69 | ax = np.expand_dims(ax, axis=0) 70 | # plot original: 71 | ax[0, 0].axis('off') 72 | ax[0, 0].title.set_text(f'Target') 73 | target_img = ref_batch[ref_i, :].permute(1, 2, 0).cpu().detach() 74 | ax[0, 0].imshow(target_img, interpolation='lanczos') 75 | 76 | d_i = distances[ref_i, :] 77 | if is_similarity: 78 | d_i_sort_idcs = torch.argsort(d_i, descending=True) 79 | else: 80 | d_i_sort_idcs = torch.argsort(d_i, descending=False) 81 | 82 | for j in range(num_neighbours): 83 | j_idx = d_i_sort_idcs[j] 84 | ref_img = data_set[j_idx][0].permute(1, 2, 0).cpu().detach() 85 | ax[0, j + 1].axis('off') 86 | ax[0, j + 1].imshow(ref_img, interpolation='lanczos') 87 | d_j = d_i[j_idx] 88 | ax[0, j + 1].title.set_text(f'{d_j:.3f}') 89 | 90 | plt.tight_layout() 91 | 92 | fig.savefig(os.path.join(out_dir, f'{out_prefix}_{ref_i}.png')) 93 | fig.savefig(os.path.join(out_dir, 
f'{out_prefix}_{ref_i}.pdf')) 94 | plt.close(fig) 95 | -------------------------------------------------------------------------------- /utils/model_normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class NormalizationWrapper(torch.nn.Module): 4 | def __init__(self, model, mean, std): 5 | super().__init__() 6 | 7 | mean = mean[..., None, None] 8 | std = std[..., None, None] 9 | 10 | self.train(model.training) 11 | 12 | self.model = model 13 | self.register_buffer("mean", mean) 14 | self.register_buffer("std", std) 15 | 16 | def forward(self, x, *args, **kwargs): 17 | x_normalized = (x - self.mean)/self.std 18 | return self.model(x_normalized, *args, **kwargs) 19 | 20 | def state_dict(self, destination=None, prefix='', keep_vars=False): 21 | return self.model.state_dict() 22 | 23 | def IdentityWrapper(model): 24 | mean = torch.tensor([0., 0., 0.]) 25 | std = torch.tensor([1., 1., 1.]) 26 | return NormalizationWrapper(model, mean, std) 27 | 28 | def Cifar10Wrapper(model): 29 | mean = torch.tensor([0.4913997551666284, 0.48215855929893703, 0.4465309133731618]) 30 | std = torch.tensor([0.24703225141799082, 0.24348516474564, 0.26158783926049628]) 31 | return NormalizationWrapper(model, mean, std) 32 | 33 | def Cifar100Wrapper(model): 34 | mean = torch.tensor([0.4913997551666284, 0.48215855929893703, 0.4465309133731618]) 35 | std = torch.tensor([0.24703225141799082, 0.24348516474564, 0.26158783926049628]) 36 | return NormalizationWrapper(model, mean, std) 37 | 38 | def SVHNWrapper(model): 39 | mean = torch.tensor([0.4377, 0.4438, 0.4728]) 40 | std = torch.tensor([0.1201, 0.1231, 0.1052]) 41 | return NormalizationWrapper(model, mean, std) 42 | 43 | def CelebAWrapper(model): 44 | mean = torch.tensor([0.5063, 0.4258, 0.3832]) 45 | std = torch.tensor([0.2632, 0.2424, 0.2385]) 46 | return NormalizationWrapper(model, mean, std) 47 | 48 | def TinyImageNetWrapper(model): 49 | mean = torch.tensor([0.4802, 0.4481, 0.3975]) 50 | std = torch.tensor([0.2302, 0.2265, 0.2262]) 51 | return NormalizationWrapper(model, mean, std) 52 | 53 | def ImageNetWrapper(model): 54 | mean = torch.tensor([0.485, 0.456, 0.406]) 55 | std = torch.tensor([0.229, 0.224, 0.225]) 56 | return NormalizationWrapper(model, mean, std) 57 | 58 | def RestrictedImageNetWrapper(model): 59 | mean = torch.tensor([0.4717, 0.4499, 0.3837]) 60 | std = torch.tensor([0.2600, 0.2516, 0.2575]) 61 | return NormalizationWrapper(model, mean, std) 62 | 63 | def BigTransferWrapper(model): 64 | mean = torch.tensor([0.5, 0.5, 0.5]) 65 | std = torch.tensor([0.5, 0.5, 0.5]) 66 | return NormalizationWrapper(model, mean, std) 67 | 68 | def LSUNScenesWrapper(model): 69 | #imagenet 70 | mean = torch.tensor([0.485, 0.456, 0.406]) 71 | std = torch.tensor([0.229, 0.224, 0.225]) 72 | return NormalizationWrapper(model, mean, std) 73 | -------------------------------------------------------------------------------- /utils/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/M4xim4l/InNOutRobustness/d81d1d26e5ebc9193009e3d92bd67b5e01d6cfd6/utils/models/__init__.py -------------------------------------------------------------------------------- /utils/models/big_transfer/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import KNOWN_MODELS -------------------------------------------------------------------------------- /utils/models/big_transfer_factory.py: 
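# === Annotation (not part of the original sources) =============================
# model_normalization.py above folds per-dataset input normalization into the
# network, so loaders and attacks can work directly on [0, 1] images. A minimal
# sketch combining it with the 32x32 model factory further below (hypothetical
# usage for a CIFAR-10 model; assumes the repo root is on PYTHONPATH):
#
#     import torch
#     from utils.model_normalization import Cifar10Wrapper
#     from utils.models.model_factory_32 import build_model
#
#     net, name, cfg, img_size = build_model('resnet18', num_classes=10)
#     model = Cifar10Wrapper(net)               # (x - mean) / std happens inside
#     logits = model(torch.rand(8, 3, 32, 32))  # raw [0, 1] inputs, no Normalize()
# ===============================================================================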
--------------------------------------------------------------------------------
1 | from .big_transfer import KNOWN_MODELS
2 | import numpy as np
3 | import os
4 | 
5 | BIG_TRANSFER_MODEL_DIR = 'BigTransfer/'
6 | 
7 | def build_model_big_transfer(model_name, num_classes, pretrained=True):
8 | model = KNOWN_MODELS[model_name](head_size=num_classes, zero_head=True)
9 | if pretrained:
10 | model.load_from(np.load(os.path.join(BIG_TRANSFER_MODEL_DIR, f"{model_name}.npz")))
11 | 
12 | 
13 | return model, model_name
-------------------------------------------------------------------------------- /utils/models/model_factory_224.py: --------------------------------------------------------------------------------
1 | from timm.models.factory import create_model
2 | 
3 | def build_model(model_name, num_classes, **kwargs):
4 | model_name = model_name.lower()
5 | if model_name == 'sslresnext50':
6 | model_name = 'SSLResNext50'
7 | model = create_model('ssl_resnext50_32x4d', num_classes=num_classes, **kwargs)
8 | config = dict(name=model_name, **kwargs)
9 | elif model_name == 'resnet50':
10 | model_name = 'ResNet50'
11 | model = create_model('resnet50', num_classes=num_classes, **kwargs)
12 | config = dict(name=model_name, **kwargs)
13 | elif model_name == 'tresnetm':
14 | model_name = 'TResNet-M'
15 | model = create_model('tresnet_m', num_classes=num_classes, **kwargs)
16 | config = dict(name=model_name, **kwargs)
17 | elif model_name == 'seresnext26t':
18 | model_name = 'SE-ResNeXt-26-T'
19 | model = create_model('seresnext26t_32x4d', num_classes=num_classes, **kwargs)
20 | config = dict(name=model_name, **kwargs)
21 | elif model_name == 'seresnext50':
22 | model_name = 'SE-ResNeXt-50'
23 | model = create_model('seresnext50_32x4d', num_classes=num_classes, **kwargs)
24 | config = dict(name=model_name, **kwargs)
25 | else:
26 | print(f'Net {model_name} not supported')
27 | raise NotImplementedError()
28 | 
29 | return model, model_name, config
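30 | 
31 | # Usage sketch (illustrative comment, not part of the original file): assumes
32 | # timm is installed; extra keyword arguments such as pretrained=True are simply
33 | # forwarded to timm's create_model.
34 | #
35 | # model, model_name, config = build_model('seresnext50', num_classes=10, pretrained=True)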
-------------------------------------------------------------------------------- /utils/models/model_factory_32.py: --------------------------------------------------------------------------------
1 | from utils.models.models_32x32.resnet import ResNet50, ResNet18, ResNet34
2 | from utils.models.models_32x32.fixup_resnet import fixup_resnet20, fixup_resnet56
3 | from utils.models.models_32x32.wide_resnet import WideResNet28x2, WideResNet28x10, WideResNet28x20, WideResNet34x20, WideResNet40x10, WideResNet70x16, WideResNet34x10
4 | from timm.models.factory import create_model
5 | from utils.models.models_32x32.pyramid import aa_PyramidNet
6 | 
7 | def try_number_conversion(s):
8 | try:
9 | value = float(s)
10 | return value
11 | except ValueError:
12 | return s
13 | 
14 | def parse_params(params_list):
15 | params = {}
16 | 
17 | if params_list is not None:
18 | assert len(params_list) % 2 == 0
19 | for i in range(len(params_list) // 2 ):
20 | key = params_list[2*i]
21 | value = params_list[2*i + 1]
22 | value = try_number_conversion(value)
23 | params[key] = value
24 | 
25 | print(params)
26 | 
27 | return params
28 | 
29 | def build_model(model_name, num_classes, model_params=None):
30 | model_name = model_name.lower()
31 | model_config = parse_params(model_params)
32 | 
33 | img_size = 32
34 | if model_name == 'resnet18':
35 | model = ResNet18(num_classes=num_classes)
36 | model_name = 'ResNet18'
37 | elif model_name == 'resnet34':
38 | model = ResNet34(num_classes=num_classes)
39 | model_name = 'ResNet34'
40 | elif model_name == 'resnet50':
41 | model = ResNet50(num_classes=num_classes)
42 | model_name = 'ResNet50'
43 | elif model_name == 'fixup_resnet20':
44 | model = fixup_resnet20(num_classes=num_classes)
45 | model_name = 'FixupResNet20'
46 | elif model_name == 'fixup_resnet56':
47 | model = fixup_resnet56(num_classes=num_classes)
48 | model_name = 'FixupResNet56'
49 | elif model_name == 'shakedrop_pyramid':
50 | model = aa_PyramidNet(depth=110, alpha=270, num_classes=num_classes)
51 | model_name = 'ShakedropPyramid'
52 | elif model_name == 'shakedrop_pyramid272':
53 | model = aa_PyramidNet(depth=272, alpha=200, num_classes=num_classes)
54 | model_name = 'ShakedropPyramid272'
55 | elif model_name == 'wideresnet28x2':
56 | model = WideResNet28x2(num_classes=num_classes, **model_config)
57 | model_name = 'WideResNet28x2'
58 | elif model_name == 'wideresnet28x10':
59 | model = WideResNet28x10(num_classes=num_classes)
60 | model_name = 'WideResNet28x10'
61 | elif model_name == 'wideresnet28x20':
62 | model = WideResNet28x20(num_classes=num_classes)
63 | model_name = 'WideResNet28x20'
64 | elif model_name == 'wideresnet34x10':
65 | model = WideResNet34x10(num_classes=num_classes, **model_config)
66 | model_name = 'WideResNet34x10'
67 | elif model_name == 'wideresnet34x20':
68 | model = WideResNet34x20(num_classes=num_classes)
69 | model_name = 'WideResNet34x20'
70 | elif model_name == 'wideresnet40x10':
71 | model = WideResNet40x10(num_classes=num_classes)
72 | model_name = 'WideResNet40x10'
73 | elif model_name == 'wideresnet70x16':
74 | model = WideResNet70x16(num_classes=num_classes)
75 | model_name = 'WideResNet70x16'
76 | elif model_name == 'vit-b16':
77 | model = create_model('vit_base_patch16_224_in21k', num_classes=num_classes, pretrained=True)
78 | model_name = 'ViT-B16'
79 | img_size = 224
80 | elif model_name == 'vit-b32':
81 | model = create_model('vit_base_patch32_224_in21k', num_classes=num_classes, pretrained=True)
82 | model_name = 'ViT-B32'
83 | img_size = 224
84 | else:
85 | print(f'Net {model_name} not supported')
86 | raise NotImplementedError()
87 | 
88 | return model, model_name, model_config, img_size
89 | 
90 | 
-------------------------------------------------------------------------------- /utils/models/models_32x32/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/M4xim4l/InNOutRobustness/d81d1d26e5ebc9193009e3d92bd67b5e01d6cfd6/utils/models/models_32x32/__init__.py -------------------------------------------------------------------------------- /utils/models/models_32x32/fixup_resnet.py: --------------------------------------------------------------------------------
1 | #https://github.com/hongyi-zhang/Fixup/blob/master/cifar/models/fixup_resnet_cifar.py
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | 
6 | 
7 | __all__ = ['FixupResNet', 'fixup_resnet20', 'fixup_resnet32', 'fixup_resnet44', 'fixup_resnet56', 'fixup_resnet110', 'fixup_resnet1202']
8 | 
9 | 
10 | def conv3x3(in_planes, out_planes, stride=1):
11 | """3x3 convolution with padding"""
12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=1, bias=False)
14 | 
15 | 
16 | class FixupBasicBlock(nn.Module):
17 | expansion = 1
18 | 
19 | def __init__(self, inplanes, planes, stride=1, downsample=None):
20 | super(FixupBasicBlock, self).__init__()
21 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
22 | self.bias1a = nn.Parameter(torch.zeros(1))
23 | self.conv1 = conv3x3(inplanes, planes, stride)
24 | 
self.bias1b = nn.Parameter(torch.zeros(1)) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.bias2a = nn.Parameter(torch.zeros(1)) 27 | self.conv2 = conv3x3(planes, planes) 28 | self.scale = nn.Parameter(torch.ones(1)) 29 | self.bias2b = nn.Parameter(torch.zeros(1)) 30 | self.downsample = downsample 31 | 32 | def forward(self, x): 33 | identity = x 34 | 35 | out = self.conv1(x + self.bias1a) 36 | out = self.relu(out + self.bias1b) 37 | 38 | out = self.conv2(out + self.bias2a) 39 | out = out * self.scale + self.bias2b 40 | 41 | if self.downsample is not None: 42 | identity = self.downsample(x + self.bias1a) 43 | identity = torch.cat((identity, torch.zeros_like(identity)), 1) 44 | 45 | out += identity 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | 51 | class FixupResNet(nn.Module): 52 | 53 | def __init__(self, block, layers, num_classes=10): 54 | super(FixupResNet, self).__init__() 55 | self.num_layers = sum(layers) 56 | self.inplanes = 16 57 | self.conv1 = conv3x3(3, 16) 58 | self.bias1 = nn.Parameter(torch.zeros(1)) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.layer1 = self._make_layer(block, 16, layers[0]) 61 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 62 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 63 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 64 | self.bias2 = nn.Parameter(torch.zeros(1)) 65 | self.fc = nn.Linear(64, num_classes) 66 | 67 | for m in self.modules(): 68 | if isinstance(m, FixupBasicBlock): 69 | nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5)) 70 | nn.init.constant_(m.conv2.weight, 0) 71 | elif isinstance(m, nn.Linear): 72 | nn.init.constant_(m.weight, 0) 73 | nn.init.constant_(m.bias, 0) 74 | 75 | def _make_layer(self, block, planes, blocks, stride=1): 76 | downsample = None 77 | if stride != 1: 78 | downsample = nn.AvgPool2d(1, stride=stride) 79 | 80 | layers = [] 81 | layers.append(block(self.inplanes, planes, stride, downsample)) 82 | self.inplanes = planes 83 | for _ in range(1, blocks): 84 | layers.append(block(planes, planes)) 85 | 86 | return nn.Sequential(*layers) 87 | 88 | def forward(self, x): 89 | x = self.conv1(x) 90 | x = self.relu(x + self.bias1) 91 | 92 | x = self.layer1(x) 93 | x = self.layer2(x) 94 | x = self.layer3(x) 95 | 96 | x = self.avgpool(x) 97 | x = x.view(x.size(0), -1) 98 | x = self.fc(x + self.bias2) 99 | 100 | return x 101 | 102 | 103 | def fixup_resnet20(**kwargs): 104 | """Constructs a Fixup-ResNet-20 density_model. 105 | """ 106 | model = FixupResNet(FixupBasicBlock, [3, 3, 3], **kwargs) 107 | return model 108 | 109 | 110 | def fixup_resnet32(**kwargs): 111 | """Constructs a Fixup-ResNet-32 density_model. 112 | """ 113 | model = FixupResNet(FixupBasicBlock, [5, 5, 5], **kwargs) 114 | return model 115 | 116 | 117 | def fixup_resnet44(**kwargs): 118 | """Constructs a Fixup-ResNet-44 density_model. 119 | """ 120 | model = FixupResNet(FixupBasicBlock, [7, 7, 7], **kwargs) 121 | return model 122 | 123 | 124 | def fixup_resnet56(**kwargs): 125 | """Constructs a Fixup-ResNet-56 density_model. 126 | """ 127 | model = FixupResNet(FixupBasicBlock, [9, 9, 9], **kwargs) 128 | return model 129 | 130 | 131 | def fixup_resnet110(**kwargs): 132 | """Constructs a Fixup-ResNet-110 density_model. 133 | """ 134 | model = FixupResNet(FixupBasicBlock, [18, 18, 18], **kwargs) 135 | return model 136 | 137 | 138 | def fixup_resnet1202(**kwargs): 139 | """Constructs a Fixup-ResNet-1202 density_model. 
140 | """ 141 | model = FixupResNet(FixupBasicBlock, [200, 200, 200], **kwargs) 142 | return model -------------------------------------------------------------------------------- /utils/models/models_32x32/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, in_planes, planes, stride=1): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 19 | self.bn2 = nn.BatchNorm2d(planes) 20 | 21 | self.shortcut = nn.Sequential() 22 | if stride != 1 or in_planes != self.expansion*planes: 23 | self.shortcut = nn.Sequential( 24 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 25 | nn.BatchNorm2d(self.expansion*planes) 26 | ) 27 | 28 | def forward(self, x): 29 | out = F.relu(self.bn1(self.conv1(x))) 30 | out = self.bn2(self.conv2(out)) 31 | out += self.shortcut(x) 32 | out = F.relu(out) 33 | return out 34 | 35 | 36 | class Bottleneck(nn.Module): 37 | expansion = 4 38 | 39 | def __init__(self, in_planes, planes, stride=1): 40 | super(Bottleneck, self).__init__() 41 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 42 | self.bn1 = nn.BatchNorm2d(planes) 43 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 44 | self.bn2 = nn.BatchNorm2d(planes) 45 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 46 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 47 | 48 | self.shortcut = nn.Sequential() 49 | if stride != 1 or in_planes != self.expansion*planes: 50 | self.shortcut = nn.Sequential( 51 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 52 | nn.BatchNorm2d(self.expansion*planes) 53 | ) 54 | 55 | def forward(self, x): 56 | out = F.relu(self.bn1(self.conv1(x))) 57 | out = F.relu(self.bn2(self.conv2(out))) 58 | out = self.bn3(self.conv3(out)) 59 | out += self.shortcut(x) 60 | out = F.relu(out) 61 | return out 62 | 63 | 64 | class ResNet(nn.Module): 65 | def __init__(self, block, num_blocks, num_classes=10): 66 | super(ResNet, self).__init__() 67 | self.in_planes = 64 68 | 69 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 70 | self.bn1 = nn.BatchNorm2d(64) 71 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 72 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 73 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 74 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 75 | self.linear = nn.Linear(512*block.expansion, num_classes) 76 | 77 | def _make_layer(self, block, planes, num_blocks, stride): 78 | strides = [stride] + [1]*(num_blocks-1) 79 | layers = [] 80 | for stride in strides: 81 | layers.append(block(self.in_planes, planes, stride)) 82 | self.in_planes = planes * block.expansion 83 | return nn.Sequential(*layers) 84 | 85 | def forward(self, x): 86 | out = 
F.relu(self.bn1(self.conv1(x))) 87 | out = self.layer1(out) 88 | out = self.layer2(out) 89 | out = self.layer3(out) 90 | out = self.layer4(out) 91 | out = F.avg_pool2d(out, 4) 92 | out = out.view(out.size(0), -1) 93 | out = self.linear(out) 94 | return out 95 | 96 | 97 | def ResNet18(num_classes=10): 98 | return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes) 99 | 100 | def ResNet34(num_classes=10): 101 | return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes) 102 | 103 | def ResNet50(num_classes=10): 104 | return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes) 105 | 106 | def ResNet101(num_classes=10): 107 | return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes) 108 | 109 | def ResNet152(num_classes=10): 110 | return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes) 111 | 112 | 113 | def test(): 114 | net = ResNet18() 115 | y = net(torch.randn(1,3,32,32)) 116 | print(y.size()) 117 | -------------------------------------------------------------------------------- /utils/models/models_32x32/shake_pyramidnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from utils.models.models_32x32.shakedrop import ShakeDrop 10 | 11 | 12 | class ShakeBasicBlock(nn.Module): 13 | 14 | def __init__(self, in_ch, out_ch, stride=1, p_shakedrop=1.0): 15 | super(ShakeBasicBlock, self).__init__() 16 | self.downsampled = stride == 2 17 | self.branch = self._make_branch(in_ch, out_ch, stride=stride) 18 | self.shortcut = not self.downsampled and None or nn.AvgPool2d(2) 19 | self.shake_drop = ShakeDrop(p_shakedrop) 20 | 21 | def forward(self, x): 22 | h = self.branch(x) 23 | h = self.shake_drop(h) 24 | h0 = x if not self.downsampled else self.shortcut(x) 25 | pad_zero = torch.zeros((h0.size(0), h.size(1) - h0.size(1), h0.size(2), h0.size(3)), dtype=torch.float, device=x.device) 26 | h0 = torch.cat([h0, pad_zero], dim=1) 27 | 28 | return h + h0 29 | 30 | def _make_branch(self, in_ch, out_ch, stride=1): 31 | return nn.Sequential( 32 | nn.BatchNorm2d(in_ch), 33 | nn.Conv2d(in_ch, out_ch, 3, padding=1, stride=stride, bias=False), 34 | nn.BatchNorm2d(out_ch), 35 | nn.ReLU(inplace=True), 36 | nn.Conv2d(out_ch, out_ch, 3, padding=1, stride=1, bias=False), 37 | nn.BatchNorm2d(out_ch)) 38 | 39 | 40 | class ShakePyramidNet(nn.Module): 41 | def __init__(self, depth=110, alpha=270, label=10): 42 | super(ShakePyramidNet, self).__init__() 43 | in_ch = 16 44 | # for BasicBlock 45 | n_units = (depth - 2) // 6 46 | in_chs = [in_ch] + [in_ch + math.ceil((alpha / (3 * n_units)) * (i + 1)) for i in range(3 * n_units)] 47 | block = ShakeBasicBlock 48 | 49 | self.in_chs, self.u_idx = in_chs, 0 50 | self.ps_shakedrop = [1 - (1.0 - (0.5 / (3 * n_units)) * (i + 1)) for i in range(3 * n_units)] 51 | 52 | self.c_in = nn.Conv2d(3, in_chs[0], 3, padding=1) 53 | self.bn_in = nn.BatchNorm2d(in_chs[0]) 54 | self.layer1 = self._make_layer(n_units, block, 1) 55 | self.layer2 = self._make_layer(n_units, block, 2) 56 | self.layer3 = self._make_layer(n_units, block, 2) 57 | self.bn_out = nn.BatchNorm2d(in_chs[-1]) 58 | self.fc_out = nn.Linear(in_chs[-1], label) 59 | 60 | # Initialize paramters 61 | for m in self.modules(): 62 | if isinstance(m, nn.Conv2d): 63 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 64 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 65 | elif isinstance(m, nn.BatchNorm2d): 66 | m.weight.data.fill_(1) 67 | m.bias.data.zero_() 68 | elif isinstance(m, nn.Linear): 69 | m.bias.data.zero_() 70 | 71 | def forward(self, x): 72 | h = self.bn_in(self.c_in(x)) 73 | h = self.layer1(h) 74 | h = self.layer2(h) 75 | h = self.layer3(h) 76 | h = F.relu(self.bn_out(h)) 77 | h = F.avg_pool2d(h, 8) 78 | h = h.view(h.size(0), -1) 79 | h = self.fc_out(h) 80 | return h 81 | 82 | def _make_layer(self, n_units, block, stride=1): 83 | layers = [] 84 | for i in range(int(n_units)): 85 | layers.append(block(self.in_chs[self.u_idx], self.in_chs[self.u_idx+1], 86 | stride, self.ps_shakedrop[self.u_idx])) 87 | self.u_idx, stride = self.u_idx + 1, 1 88 | return nn.Sequential(*layers) 89 | -------------------------------------------------------------------------------- /utils/models/models_32x32/shakedrop.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | class ShakeDropFunction(torch.autograd.Function): 10 | @staticmethod 11 | def forward(ctx, x, training=True, p_drop=0.5, alpha_range=[-1, 1]): 12 | ctx.training = training 13 | ctx.p_drop = p_drop 14 | if training: 15 | gate = torch.empty(1, device=x.device).bernoulli_(1 - p_drop) 16 | ctx.save_for_backward(gate) 17 | if gate.item() == 0: 18 | alpha = torch.empty(x.size(0), device=x.device).uniform_(*alpha_range) 19 | alpha = alpha.view(alpha.size(0), 1, 1, 1).expand_as(x) 20 | return alpha * x 21 | else: 22 | return x 23 | else: 24 | return (1 - p_drop) * x 25 | 26 | @staticmethod 27 | def backward(ctx, grad_output): 28 | training = ctx.training 29 | p_drop = ctx.p_drop 30 | if training: 31 | gate = ctx.saved_tensors[0] 32 | if gate.item() == 0: 33 | beta = torch.empty(grad_output.size(0), device=grad_output.device).uniform_(0, 1) 34 | beta = beta.view(beta.size(0), 1, 1, 1).expand_as(grad_output) 35 | beta = Variable(beta) 36 | return beta * grad_output, None, None, None 37 | else: 38 | return grad_output, None, None, None 39 | else: 40 | return (1 - p_drop) * grad_output, None, None, None 41 | 42 | 43 | class ShakeDrop(nn.Module): 44 | 45 | def __init__(self, p_drop=0.5, alpha_range=[-1, 1]): 46 | super(ShakeDrop, self).__init__() 47 | self.p_drop = p_drop 48 | self.alpha_range = alpha_range 49 | 50 | def forward(self, x): 51 | return ShakeDropFunction.apply(x, self.training, self.p_drop, self.alpha_range) 52 | -------------------------------------------------------------------------------- /utils/models/models_32x32/wideresnet_carmon.py: -------------------------------------------------------------------------------- 1 | """Based on code from https://github.com/yaodongyu/TRADES""" 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class BasicBlock(nn.Module): 10 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 11 | super(BasicBlock, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.relu1 = nn.ReLU(inplace=True) 14 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(out_planes) 17 | self.relu2 = nn.ReLU(inplace=True) 18 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 19 | padding=1, bias=False) 20 | self.droprate = dropRate 21 | self.equalInOut = (in_planes == out_planes) 22 | 
self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 23 | padding=0, bias=False) or None 24 | 25 | def forward(self, x): 26 | if not self.equalInOut: 27 | x = self.relu1(self.bn1(x)) 28 | else: 29 | out = self.relu1(self.bn1(x)) 30 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 31 | if self.droprate > 0: 32 | out = F.dropout(out, p=self.droprate, training=self.training) 33 | out = self.conv2(out) 34 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 35 | 36 | 37 | class NetworkBlock(nn.Module): 38 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 39 | super(NetworkBlock, self).__init__() 40 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 41 | 42 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 43 | layers = [] 44 | for i in range(int(nb_layers)): 45 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 46 | return nn.Sequential(*layers) 47 | 48 | def forward(self, x): 49 | return self.layer(x) 50 | 51 | 52 | class WideResNet(nn.Module): 53 | def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0.0): 54 | super(WideResNet, self).__init__() 55 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 56 | assert ((depth - 4) % 6 == 0) 57 | n = (depth - 4) / 6 58 | block = BasicBlock 59 | # 1st conv before any network block 60 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 61 | padding=1, bias=False) 62 | # 1st block 63 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 64 | # 1st sub-block 65 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 66 | # 2nd block 67 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 68 | # 3rd block 69 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 70 | # global average pooling and classifier 71 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.fc = nn.Linear(nChannels[3], num_classes) 74 | self.nChannels = nChannels[3] 75 | 76 | for m in self.modules(): 77 | if isinstance(m, nn.Conv2d): 78 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 79 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
80 | elif isinstance(m, nn.BatchNorm2d):
81 | m.weight.data.fill_(1)
82 | m.bias.data.zero_()
83 | elif isinstance(m, nn.Linear):
84 | m.bias.data.zero_()
85 | 
86 | def forward(self, x, return_prelogit=False):
87 | out = self.conv1(x)
88 | out = self.block1(out)
89 | out = self.block2(out)
90 | out = self.block3(out)
91 | out = self.relu(self.bn1(out))
92 | out = F.avg_pool2d(out, 8)
93 | out = out.view(-1, self.nChannels)
94 | if return_prelogit:
95 | return self.fc(out), out
96 | else:
97 | return self.fc(out)
98 | 
99 | 
-------------------------------------------------------------------------------- /utils/plotting.py: --------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | 
3 | 
4 | def plot_samples(Y, data, dataset='MNIST'):
5 | for i in range(10):
6 | plt.subplot(2,5,i+1)
7 | string = ''
8 | for y in Y:
9 | string += ('\n'
10 | + classes_dict[dataset][y.argmax(1)[i].item()]
11 | + ": %.3f" % y[i].max().exp().item() )
12 | 
13 | plt.title(string)
14 | if dataset in ['MNIST', 'FMNIST']:
15 | plt.imshow(data[i].squeeze().detach().cpu(), cmap='gray', interpolation='none')
16 | elif dataset in ['CIFAR10', 'SVHN', 'CIFAR100']:
17 | plt.imshow(data[i].transpose(0,2).transpose(0,1).detach().cpu(), interpolation='none')
18 | plt.xticks([])
19 | plt.yticks([])
20 | plt.show()
21 | print('\n')
22 | 
23 | 
24 | classes_FMNIST = (
25 | 't-shirt/top',
26 | 'trousers',
27 | 'pullover',
28 | 'dress',
29 | 'coat',
30 | 'sandal',
31 | 'shirt',
32 | 'sneaker',
33 | 'bag',
34 | 'boot')
35 | 
36 | 
37 | classes_CIFAR10 = ('plane',
38 | 'car',
39 | 'bird',
40 | 'cat',
41 | 'deer',
42 | 'dog',
43 | 'frog',
44 | 'horse',
45 | 'ship',
46 | 'truck')
47 | 
48 | 
49 | classes_CIFAR100 = ['apple',
50 | 'aquarium_fish',
51 | 'baby',
52 | 'bear',
53 | 'beaver',
54 | 'bed',
55 | 'bee',
56 | 'beetle',
57 | 'bicycle',
58 | 'bottle',
59 | 'bowl',
60 | 'boy',
61 | 'bridge',
62 | 'bus',
63 | 'butterfly',
64 | 'camel',
65 | 'can',
66 | 'castle',
67 | 'caterpillar',
68 | 'cattle',
69 | 'chair',
70 | 'chimpanzee',
71 | 'clock',
72 | 'cloud',
73 | 'cockroach',
74 | 'couch',
75 | 'crab',
76 | 'crocodile',
77 | 'cup',
78 | 'dinosaur',
79 | 'dolphin',
80 | 'elephant',
81 | 'flatfish',
82 | 'forest',
83 | 'fox',
84 | 'girl',
85 | 'hamster',
86 | 'house',
87 | 'kangaroo',
88 | 'keyboard',
89 | 'lamp',
90 | 'lawn_mower',
91 | 'leopard',
92 | 'lion',
93 | 'lizard',
94 | 'lobster',
95 | 'man',
96 | 'maple_tree',
97 | 'motorcycle',
98 | 'mountain',
99 | 'mouse',
100 | 'mushroom',
101 | 'oak_tree',
102 | 'orange',
103 | 'orchid',
104 | 'otter',
105 | 'palm_tree',
106 | 'pear',
107 | 'pickup_truck',
108 | 'pine_tree',
109 | 'plain',
110 | 'plate',
111 | 'poppy',
112 | 'porcupine',
113 | 'possum',
114 | 'rabbit',
115 | 'raccoon',
116 | 'ray',
117 | 'road',
118 | 'rocket',
119 | 'rose',
120 | 'sea',
121 | 'seal',
122 | 'shark',
123 | 'shrew',
124 | 'skunk',
125 | 'skyscraper',
126 | 'snail',
127 | 'snake',
128 | 'spider',
129 | 'squirrel',
130 | 'streetcar',
131 | 'sunflower',
132 | 'sweet_pepper',
133 | 'table',
134 | 'tank',
135 | 'telephone',
136 | 'television',
137 | 'tiger',
138 | 'tractor',
139 | 'train',
140 | 'trout',
141 | 'tulip',
142 | 'turtle',
143 | 'wardrobe',
144 | 'whale',
145 | 'willow_tree',
146 | 'wolf',
147 | 'woman',
148 | 'worm']
149 | 
150 | 
151 | classes_dict = {'MNIST': list(range(10)),
152 | 'FMNIST': classes_FMNIST,
153 | 'SVHN': list(range(10)),
154 | 'CIFAR10': classes_CIFAR10,
155 | 'CIFAR100': classes_CIFAR100,
156 | }
157 | 
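158 | # Usage sketch (illustrative comment, not part of the original file): plot_samples
159 | # expects a list of log-probability tensors of shape [batch, num_classes] and a
160 | # batch of at least 10 images in CHW layout; `model` and `batch` below are
161 | # assumed placeholders, with F = torch.nn.functional:
162 | #
163 | # log_probs = F.log_softmax(model(batch), dim=1)
164 | # plot_samples([log_probs], batch, dataset='CIFAR10')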
-------------------------------------------------------------------------------- /utils/resize_right/__init__.py: --------------------------------------------------------------------------------
1 | from .resize_right import resize, ResizeLayer
-------------------------------------------------------------------------------- /utils/resize_right/interp_methods.py: --------------------------------------------------------------------------------
1 | from math import pi
2 | 
3 | try:
4 | import torch
5 | except ImportError:
6 | torch = None
7 | 
8 | try:
9 | import numpy
10 | except ImportError:
11 | numpy = None
12 | 
13 | if numpy is None and torch is None:
14 | raise ImportError("Must have either Numpy or PyTorch but both not found")
15 | 
16 | 
17 | def set_framework_dependencies(x):
18 | if type(x) is numpy.ndarray:
19 | to_dtype = lambda a: a
20 | fw = numpy
21 | else:
22 | to_dtype = lambda a: a.to(x.dtype)
23 | fw = torch
24 | eps = fw.finfo(fw.float32).eps
25 | return fw, to_dtype, eps
26 | 
27 | 
28 | def cubic(x):
29 | fw, to_dtype, eps = set_framework_dependencies(x)
30 | absx = fw.abs(x)
31 | absx2 = absx ** 2
32 | absx3 = absx ** 3
33 | return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +
34 | (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) *
35 | to_dtype((1. < absx) & (absx <= 2.)))
36 | 
37 | 
38 | def lanczos2(x):
39 | fw, to_dtype, eps = set_framework_dependencies(x)
40 | return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) /
41 | ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2))
42 | 
43 | 
44 | def lanczos3(x):
45 | fw, to_dtype, eps = set_framework_dependencies(x)
46 | return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) /
47 | ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3))
48 | 
49 | 
50 | def linear(x):
51 | fw, to_dtype, eps = set_framework_dependencies(x)
52 | return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) *
53 | to_dtype((0 <= x) & (x <= 1)))
54 | 
55 | 
56 | def box(x):
57 | fw, to_dtype, eps = set_framework_dependencies(x)
58 | return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))
-------------------------------------------------------------------------------- /utils/temperature_wrapper.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 | 
6 | import numpy as np
7 | 
8 | class TemperatureWrapper(nn.Module):
9 | def __init__(self, model, T=1.):
10 | super().__init__()
11 | 
12 | self.train(model.training)
13 | 
14 | self.model = model
15 | self.T = T
16 | 
17 | def forward(self, x):
18 | logits = self.model(x)
19 | return logits / self.T
20 | 
21 | @staticmethod
22 | def compute_temperature(model, loader, device):
23 | model.eval()
24 | logits = []
25 | labels = []
26 | with torch.no_grad():
27 | for data, target in loader:
28 | data = data.to(device)
29 | 
30 | logits.append(model(data).detach().cpu())
31 | labels.append(target)
32 | 
33 | logits = torch.cat(logits, 0)
34 | labels = torch.cat(labels, 0)
35 | 
36 | ca = []
37 | log_T = torch.linspace(-3., 1., 2000)
38 | 
39 | for t in log_T:
40 | ca.append(TemperatureWrapper._get_ece_inner(logits / np.exp(t), labels)[0])
41 | ece, idx = torch.stack(ca, 0).min(0)
42 | 
43 | T = float(np.exp(log_T[idx]))
44 | return T
45 | 
46 | @staticmethod
47 | def compute_ece(model, loader, device):
48 | model.eval()
49 | logits = []
50 | labels = []
51 | with torch.no_grad():
52 | for data, target in loader:
53 | data = data.to(device)
54 | 
55 | 
logits.append(model(data).detach().cpu()) 56 | labels.append(target) 57 | 58 | logits = torch.cat(logits, 0) 59 | labels = torch.cat(labels, 0) 60 | ece = TemperatureWrapper._get_ece_inner(logits, labels)[0] 61 | return ece 62 | 63 | @staticmethod 64 | def _get_ece_inner(logits, labels, n_bins=20): 65 | bin_boundaries = torch.linspace(0, 1, n_bins + 1) 66 | bin_lowers = bin_boundaries[:-1] 67 | bin_uppers = bin_boundaries[1:] 68 | 69 | softmaxes = F.softmax(logits, dim=1) 70 | confidences, predictions = torch.max(softmaxes, 1) 71 | accuracies = predictions.eq(labels) 72 | 73 | ece = torch.zeros(1, device=logits.device) 74 | for bin_lower, bin_upper in zip(bin_lowers, bin_uppers): 75 | # Calculated |confidence - accuracy| in each bin 76 | in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item()) 77 | prop_in_bin = in_bin.float().mean() 78 | if prop_in_bin.item() > 0: 79 | accuracy_in_bin = accuracies[in_bin].float().mean() 80 | avg_confidence_in_bin = confidences[in_bin].mean() 81 | ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin 82 | 83 | return ece 84 | 85 | -------------------------------------------------------------------------------- /utils/train_types/AdversarialACET.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | 6 | import utils.distances as d 7 | 8 | from .ACET_training import ACETObjective, ACETTargetedObjective 9 | from .Adversarial_training import AdversarialTraining, AdversarialLoss 10 | from .in_out_distribution_training import InOutDistributionTraining 11 | from .helpers import get_adversarial_attack, create_attack_config, get_distance 12 | from .train_loss import CrossEntropyProxy 13 | 14 | 15 | ###################################################### 16 | class AdversarialACET(InOutDistributionTraining): 17 | def __init__(self, model, id_attack_config, od_attack_config, optimizer_config, epochs, device, num_classes, 18 | train_clean=True, 19 | attack_loss='LogitsDiff', lr_scheduler_config=None, model_config=None, 20 | target_confidences=False, 21 | attack_obj='log_conf', train_obj='log_conf', od_weight=1., test_epochs=1, verbose=100, 22 | saved_model_dir='SavedModels', saved_log_dir='Logs'): 23 | 24 | id_distance = get_distance(id_attack_config['norm']) 25 | od_distance = get_distance(od_attack_config['norm']) 26 | 27 | if train_clean: 28 | id_clean_weight = 1.0 29 | id_adv_weight = 1.0 30 | else: 31 | id_clean_weight = 0.0 32 | id_adv_weight = 1.0 33 | 34 | super().__init__('AdvACET', model, id_distance, od_distance, optimizer_config, epochs, device, num_classes, 35 | train_clean=train_clean, id_weight=0.5, id_adv_weight=id_adv_weight, clean_weight=id_clean_weight, 36 | od_weight=0.5 * od_weight, od_clean_weight=0.0, od_adv_weight=1.0, 37 | lr_scheduler_config=lr_scheduler_config, 38 | model_config=model_config, test_epochs=test_epochs, verbose=verbose, 39 | saved_model_dir=saved_model_dir, saved_log_dir=saved_log_dir) 40 | 41 | # Adversarial specific 42 | self.id_attack_config = id_attack_config 43 | self.attack_loss = attack_loss 44 | 45 | # ACET specifics 46 | self.target_confidences = target_confidences 47 | self.od_attack_config = od_attack_config 48 | self.od_attack_obj = attack_obj 49 | self.od_train_obj = train_obj 50 | 51 | 52 | def _get_adversarialacet_config(self): 53 | config_dict = {} 54 | config_dict['Train Clean'] = self.train_clean 55 | config_dict['Adversarial 
Loss'] = self.attack_loss 56 | config_dict['ID Weight'] = self.id_weight 57 | config_dict['Clean Weight'] = self.clean_weight 58 | config_dict['Adversarial Weight'] = self.id_adv_weight 59 | 60 | config_dict['OD Targeted Confidences'] = self.target_confidences 61 | config_dict['OD Train Objective'] = self.od_train_obj 62 | config_dict['OD Attack_obj'] = self.od_attack_obj 63 | config_dict['OD Weight'] = self.od_weight 64 | config_dict['OD Adversarial Weight'] = self.od_adv_weight 65 | 66 | return config_dict 67 | 68 | def _get_train_type_config(self, loader_config=None): 69 | base_config = self._get_base_config() 70 | 71 | configs = {} 72 | configs['Base'] = base_config 73 | configs['ID Attack'] = self.id_attack_config 74 | configs['AdversarialACET'] = self._get_adversarialacet_config() 75 | configs['OD Attack'] = self.od_attack_config 76 | configs['Optimizer'] = self.optimizer_config 77 | configs['Scheduler'] = self.lr_scheduler_config 78 | 79 | configs['Data Loader'] = loader_config 80 | configs['MSDA'] = self.msda_config 81 | configs['Model'] = self.model_config 82 | 83 | return configs 84 | 85 | def _get_id_criterion(self, epoch, model, name_prefix='ID'): 86 | id_train_criterion = AdversarialLoss(model, epoch, self.id_attack_config, self.classes, 87 | inner_objective=self.attack_loss, 88 | log_stats=True, name_prefix=name_prefix) 89 | return id_train_criterion 90 | 91 | def _get_od_clean_criterion(self, epoch, model, name_prefix='OD'): 92 | return None 93 | 94 | def _get_od_criterion(self, epoch, model, name_prefix='OD'): 95 | if self.target_confidences: 96 | train_criterion = ACETTargetedObjective(model, epoch, self.od_attack_config, self.od_train_obj, 97 | self.od_attack_obj, self.classes, 98 | log_stats=True, name_prefix=name_prefix) 99 | else: 100 | train_criterion = ACETObjective(model, epoch, self.od_attack_config, self.od_train_obj, 101 | self.od_attack_obj, self.classes, 102 | log_stats=True, name_prefix=name_prefix) 103 | return train_criterion 104 | 105 | def _get_od_attack(self, epoch, att_criterion): 106 | return get_adversarial_attack(self.od_attack_config, self.model, att_criterion, num_classes=self.classes, 107 | epoch=epoch) 108 | -------------------------------------------------------------------------------- /utils/train_types/Adversarial_training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | 6 | from .in_distribution_training import InDistributionTraining 7 | from .train_loss import MinMaxLoss, TrainLoss 8 | from .helpers import get_adversarial_attack, create_attack_config, get_distance 9 | 10 | from utils.adversarial_attacks import * 11 | from utils.distances import LPDistance 12 | 13 | class AdversarialLoss(MinMaxLoss): 14 | def __init__(self, model, epoch, attack_config, num_classes, inner_objective='crossentropy', log_stats=False, name_prefix=None): 15 | super().__init__('AdversarialLoss', 'log_probabilities', log_stats=log_stats, name_prefix=name_prefix) 16 | self.attack = get_adversarial_attack(attack_config, model, inner_objective, num_classes=num_classes, epoch=epoch) 17 | 18 | def inner_max(self, data, target): 19 | adv_samples = self.attack(data, target, targeted=False) 20 | return adv_samples 21 | 22 | def forward(self, data, model_out, orig_data, y, reduction='mean'): 23 | prep_out = self._prepare_input(model_out) 24 | loss_expanded = F.cross_entropy(prep_out, y, reduction='none' ) 25 | 
self._log_stats(loss_expanded) 26 | return TrainLoss.reduce(loss_expanded, reduction) 27 | 28 | class AdversarialTraining(InDistributionTraining): 29 | def __init__(self, model, id_attack_config, optimizer_config, epochs, device, num_classes, train_clean=True, 30 | attack_loss='logits_diff', lr_scheduler_config=None, model_config=None, 31 | test_epochs=1, verbose=100, saved_model_dir='SavedModels', 32 | saved_log_dir='Logs'): 33 | 34 | distance = get_distance(id_attack_config['norm']) 35 | 36 | if train_clean: 37 | clean_weight = 0.5 38 | adv_weight = 0.5 39 | else: 40 | clean_weight = 0.0 41 | adv_weight = 1.0 42 | 43 | super().__init__('Adversarial Training', model, distance, optimizer_config, epochs, device, num_classes, 44 | train_clean=train_clean, clean_weight=clean_weight, id_adv_weight=adv_weight, 45 | lr_scheduler_config=lr_scheduler_config, model_config=model_config, 46 | test_epochs=test_epochs, verbose=verbose, saved_model_dir=saved_model_dir, 47 | saved_log_dir=saved_log_dir) 48 | self.id_attack_config = id_attack_config 49 | self.attack_loss = attack_loss 50 | 51 | 52 | def _get_id_criterion(self, epoch, model, name_prefix='ID'): 53 | id_train_criterion = AdversarialLoss(model, epoch, self.id_attack_config, self.classes, inner_objective=self.attack_loss, 54 | log_stats=True, name_prefix=name_prefix) 55 | return id_train_criterion 56 | 57 | def _get_train_type_config(self, loader_config=None): 58 | base_config = self._get_base_config() 59 | adv_config = self._get_adversarial_training_config() 60 | 61 | configs = {} 62 | configs['Base'] = base_config 63 | configs['Adversarial Training'] = adv_config 64 | configs['ID Attack'] = self.id_attack_config 65 | configs['Optimizer'] = self.optimizer_config 66 | configs['Scheduler'] = self.lr_scheduler_config 67 | configs['MSDA'] = self.msda_config 68 | configs['Data Loader'] = loader_config 69 | configs['Model'] = self.model_config 70 | 71 | 72 | return configs 73 | 74 | def _get_adversarial_training_config(self): 75 | config_dict = {} 76 | config_dict['Train Clean'] = self.train_clean 77 | config_dict['Adversarial Loss'] = self.attack_loss 78 | config_dict['Clean Weight'] = self.clean_weight 79 | config_dict['Adversarial Weight'] = self.id_adv_weight 80 | return config_dict 81 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /utils/train_types/BCEACET_training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import utils.distances as d 6 | 7 | from .out_distribution_training import OutDistributionTraining 8 | from .train_loss import TrainLoss, MinMaxLoss, NegativeWrapper, BCELogitsProxy 9 | from .helpers import create_attack_config, get_adversarial_attack, get_distance 10 | 11 | 12 | class BCEACETObjective(MinMaxLoss): 13 | def __init__(self, model, epoch, attack_config, mask_features, min_features_mask, max_features_mask, num_features, 14 | log_stats=False, number_of_batches=None, name_prefix=None): 15 | super().__init__('BCEACET', expected_format='logits', log_stats=log_stats, name_prefix=name_prefix) 16 | self.attack_config = attack_config 17 | self.model = model 18 | self.epoch = epoch 19 | self.num_features = num_features 20 | self.mask_features = mask_features 21 | self.min_features_mask = min_features_mask 22 | self.max_features_mask = max_features_mask 23 | 24 | 25 | def _generate_mask(self, target): 26 | if 
self.mask_features:
27 | mask = torch.zeros_like(target, dtype=torch.float)
28 | for idx in range(mask.shape[0]):
29 | num_non_masked = torch.randint(self.min_features_mask, self.max_features_mask + 1, (1,)).item()
30 | non_masked_idcs = torch.randperm(mask.shape[1], device=mask.device)[:num_non_masked]
31 | mask[idx, non_masked_idcs] = 1
32 | 
33 | self.mask = mask
34 | else:
35 | self.mask = None
36 | 
37 | def inner_max(self, data, target):
38 | uniform_target = 0.5 * data.new_ones((data.shape[0],self.num_features), dtype=torch.float)
39 | self._generate_mask(uniform_target)
40 | self.obj = BCELogitsProxy(mask=self.mask, log_stats=False)
41 | neg_obj = NegativeWrapper(self.obj)
42 | self.attack = get_adversarial_attack(self.attack_config, self.model, neg_obj,
43 | num_classes=self.num_features, epoch=self.epoch)
44 | adv_samples = self.attack(data, uniform_target, targeted=False)
45 | return adv_samples
46 | 
47 | 
48 | def forward(self, data, model_out, orig_data, y, reduction='mean'):
49 | prep_out = self._prepare_input(model_out)
50 | assert self.obj is not None, 'Inner Max has to be called first'
51 | uniform_target = 0.5 * data.new_ones((data.shape[0],self.num_features), dtype=torch.float)
52 | loss_expanded = self.obj(None, prep_out, None, uniform_target, reduction='none' )
53 | self.obj = None
54 | self._log_stats(loss_expanded)
55 | return TrainLoss.reduce(loss_expanded, reduction)
56 | 
57 | 
58 | class BCEACETTraining(OutDistributionTraining):
59 | def __init__(self, model, od_attack_config, optimizer_config, epochs, device, num_classes,
60 | mask_features, min_features_mask, max_features_mask, lr_scheduler_config=None, lam=1.,
61 | test_epochs=5, verbose=100,
62 | saved_model_dir='SavedModels', saved_log_dir='Logs'):
63 | 
64 | distance = get_distance(od_attack_config['norm'])
65 | 
66 | super().__init__('BCEACET', model, distance, optimizer_config, epochs, device, num_classes,
67 | clean_criterion='bce', lr_scheduler_config=lr_scheduler_config, od_weight=lam,
68 | test_epochs=test_epochs, verbose=verbose, saved_model_dir=saved_model_dir,
69 | saved_log_dir=saved_log_dir)
70 | 
71 | self.od_attack_config = od_attack_config
72 | self.mask_features = mask_features
73 | self.min_features_mask = min_features_mask
74 | self.max_features_mask = max_features_mask
75 | 
76 | 
77 | def _get_od_criterion(self, epoch, model):
78 | train_criterion = BCEACETObjective(model, epoch, self.od_attack_config, self.mask_features,
79 | self.min_features_mask, self.max_features_mask, self.classes,
80 | log_stats=True, name_prefix='OD')
81 | return train_criterion
82 | 
83 | def _get_od_attack(self, epoch, att_criterion):
84 | return get_adversarial_attack(self.od_attack_config, self.model, att_criterion, num_classes=self.classes, epoch=epoch)
85 | 
86 | def _get_ACET_config(self):
87 | ACET_config = {'lambda': self.od_weight}
88 | return ACET_config
89 | 
90 | def _get_train_type_config(self, loader_config=None):
91 | base_config = self._get_base_config()
92 | ACET_config = self._get_ACET_config()
93 | configs = {}
94 | configs['Base'] = base_config
95 | configs['ACET'] = ACET_config
96 | configs['OD Attack'] = self.od_attack_config
97 | configs['Optimizer'] = self.optimizer_config
98 | configs['Scheduler'] = self.lr_scheduler_config
99 | 
100 | configs['Data Loader'] = loader_config
101 | configs['MSDA'] = self.msda_config
102 | configs['Model'] = self.model_config
103 | 
104 | return configs -------------------------------------------------------------------------------- /utils/train_types/BCEAdversarial_training.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .helpers import get_adversarial_attack, create_attack_config, get_distance 4 | from .in_distribution_training import InDistributionTraining 5 | from .train_loss import TrainLoss, MinMaxLoss, BCELogitsProxy, NegativeWrapper, BCAccuracyConfidenceLogger 6 | 7 | 8 | class BCEAdversarialLoss(MinMaxLoss): 9 | def __init__(self, model, epoch, attack_config, mask_features, min_features_mask, max_features_mask, num_classes, 10 | log_stats=False,name_prefix=None): 11 | super().__init__('AdversarialLoss', 'logits', log_stats=log_stats, name_prefix=name_prefix) 12 | self.attack_config = attack_config 13 | self.model = model 14 | self.epoch = epoch 15 | self.mask_features = mask_features 16 | self.min_features_mask = min_features_mask 17 | self.max_features_mask = max_features_mask 18 | self.num_classes = num_classes 19 | self.obj = None 20 | 21 | def _generate_mask(self, target): 22 | if self.mask_features: 23 | mask = torch.zeros_like(target, dtype=torch.float) 24 | for idx in range(mask.shape[0]): 25 | num_non_masked = torch.randint(self.min_features_mask, self.max_features_mask + 1, (1,)).item() 26 | non_masked_idcs = torch.randperm(mask.shape[1], device=mask.device)[:num_non_masked] 27 | mask[idx, non_masked_idcs] = 1 28 | 29 | self.mask = mask 30 | else: 31 | self.mask = None 32 | 33 | def inner_max(self, data, target): 34 | self._generate_mask(target) 35 | self.obj = BCELogitsProxy(mask=self.mask, log_stats=False) 36 | neg_obj = NegativeWrapper(self.obj) 37 | self.attack = get_adversarial_attack(self.attack_config, self.model, neg_obj, 38 | num_classes=self.num_classes, epoch=self.epoch) 39 | adv_samples = self.attack(data, target, targeted=False) 40 | 41 | return adv_samples 42 | 43 | def forward(self, data, model_out, orig_data, y, reduction='mean'): 44 | prep_out = self._prepare_input(model_out) 45 | assert self.obj is not None, 'Inner Max has to be called first' 46 | loss_expanded = self.obj(data, prep_out, orig_data, y, reduction='none') 47 | self.obj = None 48 | self._log_stats(loss_expanded) 49 | return TrainLoss.reduce(loss_expanded, reduction) 50 | 51 | 52 | class BCEAdversarialTraining(InDistributionTraining): 53 | def __init__(self, model, id_attack_config, optimizer_config, epochs, device, num_classes, train_clean=True, 54 | lr_scheduler_config=None, model_config=None, 55 | test_epochs=1, verbose=100, saved_model_dir='SavedModels', 56 | saved_log_dir='Logs'): 57 | distance = get_distance(id_attack_config['norm']) 58 | 59 | super().__init__('BCEAdversarial Training', model, distance, optimizer_config, 60 | epochs, device, num_classes, 61 | clean_criterion='bce', train_clean=train_clean, 62 | lr_scheduler_config=lr_scheduler_config, model_config=model_config, 63 | test_epochs=test_epochs, verbose=verbose, saved_model_dir=saved_model_dir, 64 | saved_log_dir=saved_log_dir) 65 | self.id_attack_config = id_attack_config 66 | 67 | def _get_id_criterion(self, epoch, model, name_prefix='ID'): 68 | id_train_criterion = BCEAdversarialLoss(model, epoch, self.id_attack_config, True, 1, 1, self.classes, 69 | log_stats=True, name_prefix=name_prefix) 70 | return id_train_criterion 71 | 72 | def _get_id_accuracy_conf_logger(self, name_prefix): 73 | return BCAccuracyConfidenceLogger(self.classes, name_prefix=name_prefix) 74 | 75 | def _get_train_type_config(self, loader_config=None): 76 | base_config = self._get_base_config() 77 | adv_config = self._get_adversarial_training_config() 78 | 79 | 
configs = {} 80 | configs['Base'] = base_config 81 | configs['Adversarial Training'] = adv_config 82 | configs['ID Attack'] = self.id_attack_config 83 | configs['Optimizer'] = self.optimizer_config 84 | configs['Scheduler'] = self.lr_scheduler_config 85 | 86 | configs['Data Loader'] = loader_config 87 | configs['MSDA'] = self.msda_config 88 | configs['Model'] = self.model_config 89 | 90 | return configs 91 | 92 | def _get_adversarial_training_config(self): 93 | config_dict = {} 94 | config_dict['train_clean'] = self.train_clean 95 | return config_dict 96 | -------------------------------------------------------------------------------- /utils/train_types/BCECEDA_training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import utils.distances as d 6 | 7 | from .out_distribution_training import OutDistributionTraining 8 | from .train_loss import TrainLoss, MinMaxLoss, BCELogitsProxy 9 | 10 | class BCEACETObjective(MinMaxLoss): 11 | def __init__(self, epoch, mask_features, min_features_mask, max_features_mask, num_features, 12 | log_stats=False, number_of_batches=None, name_prefix=None): 13 | super().__init__('BCEACET', expected_format='logits', log_stats=log_stats, name_prefix=name_prefix) 14 | self.epoch = epoch 15 | self.num_features = num_features 16 | self.mask_features = mask_features 17 | self.min_features_mask = min_features_mask 18 | self.max_features_mask = max_features_mask 19 | 20 | 21 | def _generate_mask(self, target): 22 | if self.mask_features: 23 | mask = torch.zeros_like(target, dtype=torch.float) 24 | for idx in range(mask.shape[0]): 25 | num_non_masked = torch.randint(self.min_features_mask, self.max_features_mask + 1, (1,)).item() 26 | non_masked_idcs = torch.randperm(mask.shape[1], device=mask.device)[:num_non_masked] 27 | mask[idx, non_masked_idcs] = 1 28 | 29 | self.mask = mask 30 | else: 31 | self.mask = None 32 | 33 | def inner_max(self, data, target): 34 | return data 35 | 36 | def forward(self, data, model_out, orig_data, y, reduction='mean'): 37 | prep_out = self._prepare_input(model_out) 38 | uniform_target = 0.5 * data.new_ones((data.shape[0],self.num_features), dtype=torch.float) 39 | self._generate_mask(uniform_target) 40 | obj = BCELogitsProxy(mask=self.mask, log_stats=False) 41 | loss_expanded = obj(None, prep_out, None, uniform_target, reduction='none' ) 42 | self._log_stats(loss_expanded) 43 | return TrainLoss.reduce(loss_expanded, reduction) 44 | 45 | class BCECEDATraining(OutDistributionTraining): 46 | def __init__(self, model, optimizer_config, epochs, device, num_classes, 47 | mask_features, min_features_mask, max_features_mask, 48 | lr_scheduler_config=None, 49 | lam=1., test_epochs=1, verbose=100, saved_model_dir= 'SavedModels', saved_log_dir= 'Logs'): 50 | 51 | distance = d.LPDistance(p=2) 52 | super().__init__('BCECEDA', model, distance, optimizer_config, epochs, device, num_classes, 53 | clean_criterion='bce', lr_scheduler_config=lr_scheduler_config, od_weight=lam, 54 | test_epochs=test_epochs, verbose=verbose, saved_model_dir=saved_model_dir, 55 | saved_log_dir=saved_log_dir) 56 | 57 | self.mask_features = mask_features 58 | self.min_features_mask = min_features_mask 59 | self.max_features_mask = max_features_mask 60 | 61 | def _get_od_criterion(self, epoch, model): 62 | train_criterion = BCEACETObjective(epoch, self.mask_features, self.min_features_mask, self.max_features_mask, self.classes, 63 | 
log_stats=True, name_prefix='OD')
64 | return None, train_criterion
65 | 
66 | def _get_CEDA_config(self):
67 | CEDA_config = {'lambda': self.od_weight}
68 | return CEDA_config
69 | 
70 | def _get_train_type_config(self, loader_config=None):
71 | base_config = self._get_base_config()
72 | ceda_config = self._get_CEDA_config()
73 | 
74 | configs = {}
75 | configs['Base'] = base_config
76 | configs['CEDA'] = ceda_config
77 | configs['Optimizer'] = self.optimizer_config
78 | configs['Scheduler'] = self.lr_scheduler_config
79 | 
80 | configs['Data Loader'] = loader_config
81 | configs['MSDA'] = self.msda_config
82 | configs['Model'] = self.model_config
83 | 
84 | return configs
85 | 
86 | 
87 | 
88 | 
89 | 
90 | 
-------------------------------------------------------------------------------- /utils/train_types/TRADESACET_training.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | 
6 | import utils.distances as d
7 | 
8 | from .ACET_training import ACETObjective
9 | from .CEDA_training import CEDAObjective
10 | from .TRADES_training import TRADESLoss
11 | from .in_out_distribution_training import InOutDistributionTraining
12 | from .train_loss import AccuracyConfidenceLogger, DistanceLogger, ConfidenceLogger, SingleValueLogger, NegativeWrapper
13 | from .helpers import interleave_forward, get_distance
14 | import torch.cuda.amp as amp
15 | 
16 | import math
17 | ######################################################
18 | class TRADESACETTraining(InOutDistributionTraining):
19 | def __init__(self, model, id_attack_config, od_attack_config, optimizer_config, epochs, device, num_classes,
20 | trades_weight=1, lr_scheduler_config=None,
21 | acet_obj='kl', od_weight=1., model_config=None,
22 | test_epochs=1, verbose=100, saved_model_dir= 'SavedModels', saved_log_dir= 'Logs'):
23 | 
24 | id_distance = get_distance(id_attack_config['norm'])
25 | od_distance = get_distance(od_attack_config['norm'])
26 | 
27 | 
28 | super().__init__('TRADESACET', model, id_distance, od_distance, optimizer_config, epochs, device, num_classes,
29 | train_clean=False, id_trades=True, id_weight=0.5, clean_weight=1.0, id_adv_weight=trades_weight,
30 | od_trades=False, od_weight=0.5 * od_weight, od_clean_weight=0.0, od_adv_weight=1.0,
31 | lr_scheduler_config=lr_scheduler_config,
32 | model_config=model_config, test_epochs=test_epochs,
33 | verbose=verbose, saved_model_dir=saved_model_dir, saved_log_dir=saved_log_dir)
34 | 
35 | #Trades
36 | self.id_attack_config = id_attack_config
37 | 
38 | #od
39 | self.od_attack_config = od_attack_config
40 | self.acet_obj = acet_obj
41 | 
42 | def requires_out_distribution(self):
43 | return True
44 | 
45 | def create_loaders_dict(self, train_loader, test_loader=None, out_distribution_loader=None, out_distribution_test_loader=None, *args, **kwargs):
46 | train_loaders = {
47 | 'train_loader': train_loader,
48 | 'out_distribution_loader': out_distribution_loader
49 | }
50 | 
51 | test_loaders = {}
52 | if test_loader is not None:
53 | test_loaders['test_loader'] = test_loader
54 | if out_distribution_test_loader is not None:
55 | test_loaders['out_distribution_test_loader'] = out_distribution_test_loader
56 | 
57 | return train_loaders, test_loaders
58 | 
59 | def _validate_loaders(self, train_loaders, test_loaders):
60 | if not 'train_loader' in train_loaders:
61 | raise ValueError('Train loader not given')
62 | if not 'out_distribution_loader' in train_loaders:
63 | 
raise ValueError('Out distribution loader is required for out distribution training') 64 | 65 | def _get_id_criterion(self, epoch, model, name_prefix='ID'): 66 | trades_reg = TRADESLoss(model, epoch, self.id_attack_config, self.classes, log_stats=True, name_prefix=name_prefix) 67 | return trades_reg 68 | 69 | def _get_od_clean_criterion(self, epoch, model, name_prefix='OD'): 70 | od_clean_criterion = None 71 | return od_clean_criterion 72 | 73 | def _get_od_criterion(self, epoch, model, name_prefix='OD'): 74 | od_criterion = ACETObjective(model, epoch, self.od_attack_config, self.acet_obj, self.acet_obj, 75 | self.classes, log_stats=True, name_prefix=name_prefix) 76 | return od_criterion 77 | 78 | def _get_TRADESACET_config(self): 79 | config_dict = {} 80 | config_dict['ID Weight'] = self.id_weight 81 | config_dict['Clean Weight'] = self.clean_weight 82 | config_dict['Trades Weight'] = self.id_adv_weight 83 | 84 | config_dict['OD Weight'] = self.od_weight 85 | config_dict['OD ACET Weight'] = self.od_adv_weight 86 | config_dict['OD ACET Objective'] = self.acet_obj 87 | return config_dict 88 | 89 | def _get_train_type_config(self, loader_config=None): 90 | base_config = self._get_base_config() 91 | tradesacet_config = self._get_TRADESACET_config() 92 | configs = {} 93 | configs['Base'] = base_config 94 | configs['TRADESACET'] = tradesacet_config 95 | configs['ID Attack'] = self.id_attack_config 96 | configs['OD Attack'] = self.od_attack_config 97 | configs['Optimizer'] = self.optimizer_config 98 | configs['Scheduler'] = self.lr_scheduler_config 99 | 100 | configs['Data Loader'] = loader_config 101 | configs['MSDA'] = self.msda_config 102 | configs['Model'] = self.model_config 103 | 104 | return configs 105 | 106 | -------------------------------------------------------------------------------- /utils/train_types/TRADESCEDA_training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | 6 | import utils.distances as d 7 | 8 | from .ACET_training import ACETObjective 9 | from .CEDA_training import CEDAObjective 10 | from .TRADES_training import TRADESLoss 11 | from .in_out_distribution_training import InOutDistributionTraining 12 | from .train_loss import AccuracyConfidenceLogger, DistanceLogger, ConfidenceLogger, SingleValueLogger, NegativeWrapper 13 | from .helpers import interleave_forward, get_distance 14 | import torch.cuda.amp as amp 15 | 16 | import math 17 | ###################################################### 18 | class TRADESCEDATraining(InOutDistributionTraining): 19 | def __init__(self, model, id_attack_config, od_attack_config, optimizer_config, epochs, device, num_classes, 20 | id_trades_weight=1., od_trades_weight=1., lr_scheduler_config=None, 21 | ceda_obj='kl', od_weight=1., model_config=None, 22 | test_epochs=1, verbose=100, saved_model_dir= 'SavedModels', saved_log_dir= 'Logs'): 23 | 24 | id_distance = get_distance(id_attack_config['norm']) 25 | od_distance = get_distance(od_attack_config['norm']) 26 | 27 | super().__init__('TRADESCEDA', model, id_distance, od_distance, optimizer_config, epochs, device, num_classes, 28 | train_clean=False, id_trades=True, id_weight=0.5, clean_weight=1.0, id_adv_weight=id_trades_weight, 29 | od_trades=True, od_weight=0.5 * od_weight, od_clean_weight=1.0, od_adv_weight=od_trades_weight, 30 | lr_scheduler_config=lr_scheduler_config, 31 | model_config=model_config, test_epochs=test_epochs, 32 | 
verbose=verbose, saved_model_dir=saved_model_dir, saved_log_dir=saved_log_dir) 33 | 34 | #Trades 35 | self.id_attack_config = id_attack_config 36 | 37 | #od 38 | self.od_attack_config = od_attack_config 39 | self.ceda_obj = ceda_obj 40 | 41 | def requires_out_distribution(self): 42 | return True 43 | 44 | def create_loaders_dict(self, train_loader, test_loader=None, out_distribution_loader=None, out_distribution_test_loader=None, *args, **kwargs): 45 | train_loaders = { 46 | 'train_loader': train_loader, 47 | 'out_distribution_loader': out_distribution_loader 48 | } 49 | 50 | test_loaders = {} 51 | if test_loader is not None: 52 | test_loaders['test_loader'] = test_loader 53 | if out_distribution_test_loader is not None: 54 | test_loaders['out_distribution_test_loader'] = out_distribution_test_loader 55 | 56 | return train_loaders, test_loaders 57 | 58 | def _validate_loaders(self, train_loaders, test_loaders): 59 | if 'train_loader' not in train_loaders: 60 | raise ValueError('Train loader not given') 61 | if 'out_distribution_loader' not in train_loaders: 62 | raise ValueError('Out distribution loader is required for out distribution training') 63 | 64 | def _get_id_criterion(self, epoch, model, name_prefix='ID'): 65 | trades_reg = TRADESLoss(model, epoch, self.id_attack_config, self.classes, log_stats=True, name_prefix=name_prefix) 66 | return trades_reg 67 | 68 | def _get_od_clean_criterion(self, epoch, model, name_prefix='OD'): 69 | od_clean_criterion = CEDAObjective(self.ceda_obj, self.classes, log_stats=True, name_prefix=name_prefix) 70 | return od_clean_criterion 71 | 72 | def _get_od_criterion(self, epoch, model, name_prefix='OD'): 73 | od_criterion = TRADESLoss(model, epoch, self.od_attack_config, self.classes, 74 | log_stats=True, name_prefix=name_prefix) 75 | return od_criterion 76 | 77 | def _get_TRADESCEDA_config(self): 78 | config_dict = {} 79 | config_dict['ID Weight'] = self.id_weight 80 | config_dict['Clean Weight'] = self.clean_weight 81 | config_dict['Trades Weight'] = self.id_adv_weight 82 | 83 | config_dict['OD Weight'] = self.od_weight 84 | config_dict['OD Clean Weight'] = self.od_clean_weight 85 | config_dict['OD Trades Weight'] = self.od_adv_weight 86 | config_dict['OD CEDA Objective'] = self.ceda_obj 87 | return config_dict 88 | 89 | def _get_train_type_config(self, loader_config=None): 90 | base_config = self._get_base_config() 91 | tradesceda_config = self._get_TRADESCEDA_config() 92 | configs = {} 93 | configs['Base'] = base_config 94 | configs['TRADESCEDA'] = tradesceda_config 95 | configs['ID Attack'] = self.id_attack_config 96 | configs['OD Attack'] = self.od_attack_config 97 | configs['Optimizer'] = self.optimizer_config 98 | configs['Scheduler'] = self.lr_scheduler_config 99 | 100 | configs['Data Loader'] = loader_config 101 | configs['MSDA'] = self.msda_config 102 | configs['Model'] = self.model_config 103 | 104 | return configs 105 | 106 | -------------------------------------------------------------------------------- /utils/train_types/TRADES_training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | 6 | from .in_distribution_training import InDistributionTraining 7 | from .train_loss import LoggingLoss, TrainLoss, CrossEntropyProxy, AccuracyConfidenceLogger, DistanceLogger,\ 8 | SingleValueLogger, KLDivergenceProxy, NegativeWrapper, MinMaxLoss 9 | from .helpers import interleave_forward,
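TRADESCEDATraining above differs from TRADESACET in that it keeps a clean CEDA term on the out-distribution batch (od_clean_weight=1.0) and applies TRADES smoothing on both halves. A hedged sketch of just the OD part, with kl_attack standing in for the attack built from od_attack_config:

import torch.nn.functional as F

def tradesceda_od_loss(model, x_od, kl_attack, od_trades_weight=1.0):
    log_p = F.log_softmax(model(x_od), dim=1)
    ceda = -log_p.mean()                      # clean CEDA term: CE to uniform, up to log K
    p = log_p.exp().detach()
    x_adv = kl_attack(x_od, p)                # TRADES inner max on the OD batch
    kl = F.kl_div(F.log_softmax(model(x_adv), dim=1), p, reduction='batchmean')
    return ceda + od_trades_weight * kl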
get_adversarial_attack, create_attack_config, get_distance 10 | 11 | class TRADESLoss(MinMaxLoss): 12 | def __init__(self, model, epoch, attack_config, num_classes, log_stats=False, name_prefix=None): 13 | super().__init__('TRADES', 'logits', log_stats=log_stats, name_prefix=name_prefix) 14 | self.model = model 15 | self.epoch = epoch 16 | self.attack_config = attack_config 17 | 18 | self.div = KLDivergenceProxy(log_stats=False) 19 | self.adv_attack = get_adversarial_attack(self.attack_config, self.model, 'kl', num_classes, epoch=self.epoch) 20 | 21 | def inner_max(self, data, target): 22 | is_train = self.model.training 23 | #attack is run in test mode so target distribution should also be estimated in test not train 24 | self.model.eval() 25 | target_distribution = F.softmax(self.model(data), dim=1).detach() 26 | x_adv = self.adv_attack(data, target_distribution) 27 | 28 | if is_train: 29 | self.model.train() 30 | else: 31 | self.model.eval() 32 | 33 | return x_adv.detach() 34 | 35 | #model out will be model out at adversarial samples 36 | #y will be the softmax distribution at original datapoint 37 | def forward(self, data, model_out, orig_data, y, reduction='mean'): 38 | prep_out = self._prepare_input(model_out) 39 | loss_expanded = self.div(data, prep_out, orig_data, y, reduction='none') 40 | self._log_stats(loss_expanded) 41 | return TrainLoss.reduce(loss_expanded, reduction) 42 | 43 | #Base class for train types that use custom losses/attacks on the in distribution such as adversarial training 44 | class TRADESTraining(InDistributionTraining): 45 | def __init__(self, model, id_attack_config, optimizer_config, epochs, device, num_classes, trades_weight=1., 46 | lr_scheduler_config=None, model_config=None, test_epochs=1, verbose=100, 47 | saved_model_dir= 'SavedModels', saved_log_dir= 'Logs'): 48 | 49 | distance = get_distance(id_attack_config['norm']) 50 | 51 | super().__init__('TRADES', model, distance, optimizer_config, epochs, device, num_classes, 52 | train_clean=False, id_trades=True, clean_weight=1.0, id_adv_weight=trades_weight, 53 | lr_scheduler_config=lr_scheduler_config, model_config=model_config, test_epochs=test_epochs, 54 | verbose=verbose, saved_model_dir= saved_model_dir, saved_log_dir=saved_log_dir) 55 | 56 | self.id_attack_config = id_attack_config 57 | 58 | def _get_id_criterion(self, epoch, model, name_prefix='ID'): 59 | trades_reg = TRADESLoss(model, epoch, self.id_attack_config, self.classes, log_stats=True, name_prefix=name_prefix) 60 | return trades_reg 61 | 62 | def _get_TRADES_config(self): 63 | config_dict = {} 64 | config_dict['Clean Weight'] = self.clean_weight 65 | config_dict['Trades Weight'] = self.id_adv_weight 66 | return config_dict 67 | 68 | def _get_train_type_config(self, loader_config=None): 69 | base_config = self._get_base_config() 70 | adv_config = self._get_TRADES_config() 71 | 72 | configs = {} 73 | configs['Base'] = base_config 74 | configs['TRADES'] = adv_config 75 | configs['ID Attack'] = self.id_attack_config 76 | configs['Optimizer'] = self.optimizer_config 77 | configs['Scheduler'] = self.lr_scheduler_config 78 | 79 | configs['Data Loader'] = loader_config 80 | configs['MSDA'] = self.msda_config 81 | configs['Model'] = self.model_config 82 | 83 | return configs 84 | -------------------------------------------------------------------------------- /utils/train_types/__init__.py: -------------------------------------------------------------------------------- 1 | from .ACET_training import ACETTraining 2 | from .BCEACET_training import 
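TRADESLoss.inner_max above delegates to get_adversarial_attack with the 'kl' objective; the snippet below is a hypothetical stand-alone L-inf PGD variant of that inner maximisation, maximising the KL divergence to the fixed clean softmax. It is a sketch, not the attack implementation from utils.adversarial_attacks.

import torch
import torch.nn.functional as F

def pgd_kl_inner_max(model, x, p_clean, eps=8/255, step_size=2/255, iters=10):
    # random start inside the eps-ball, then iterated signed-gradient ascent on the KL
    x_adv = (x + eps * torch.empty_like(x).uniform_(-1, 1)).clamp(0, 1).detach()
    for _ in range(iters):
        x_adv.requires_grad_(True)
        loss = F.kl_div(F.log_softmax(model(x_adv), dim=1), p_clean, reduction='sum')
        grad = torch.autograd.grad(loss, x_adv)[0]
        x_adv = x_adv.detach() + step_size * grad.sign()
        x_adv = torch.min(torch.max(x_adv, x - eps), x + eps).clamp(0, 1)  # project back
    return x_adv.detach()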
BCEACETTraining 3 | from .Adversarial_training import AdversarialTraining 4 | from .BCEAdversarial_training import BCEAdversarialTraining 5 | from .AdversarialACET import AdversarialACET 6 | from .CEDA_training import CEDATraining 7 | from .BCECEDA_training import BCECEDATraining 8 | from .plain_training import PlainTraining 9 | from .train_type import TrainType 10 | from .train_loss import TrainLoss 11 | from .TRADES_training import TRADESTraining 12 | from .TRADESACET_training import TRADESACETTraining 13 | from .TRADESCEDA_training import TRADESCEDATraining 14 | from .randomized_smoothing_training import RandomizedSmoothingTraining 15 | from .helpers import create_attack_config 16 | -------------------------------------------------------------------------------- /utils/train_types/msda/__init__.py: -------------------------------------------------------------------------------- 1 | from .config_creators import create_fmix_config, create_mixup_config 2 | from .factory import get_msda -------------------------------------------------------------------------------- /utils/train_types/msda/config_creators.py: -------------------------------------------------------------------------------- 1 | def create_fmix_config(decay_power=3.0, alpha=1.0): 2 | fmix_config = {'type': 'FMix', 'decay_power': decay_power, 'alpha': alpha} 3 | return fmix_config 4 | 5 | def create_mixup_config(alpha=1.0): 6 | mixup_config = {'type': 'Mixup', 'alpha': alpha} 7 | return mixup_config 8 | -------------------------------------------------------------------------------- /utils/train_types/msda/dummy_msda.py: -------------------------------------------------------------------------------- 1 | from .mixed_sample_data_augmentation import MixedSampleDataAugmentation 2 | 3 | class DummyMSDA(MixedSampleDataAugmentation): 4 | def __init__(self): 5 | super().__init__(None) 6 | 7 | def apply_mix(self, x): 8 | return x -------------------------------------------------------------------------------- /utils/train_types/msda/factory.py: -------------------------------------------------------------------------------- 1 | from .fmix import FMix 2 | from .mixup import Mixup 3 | from .dummy_msda import DummyMSDA 4 | 5 | def get_msda(loss, msda_config, log_stats=True, name_prefix=None): 6 | if msda_config is None: 7 | return loss, DummyMSDA() 8 | elif msda_config['type'] == 'FMix': 9 | fmix = FMix(loss, decay_power=msda_config['decay_power'], alpha=msda_config['alpha'], 10 | log_stats=log_stats, name_prefix=name_prefix) 11 | fmix_loss = fmix.loss 12 | return fmix_loss, fmix 13 | elif msda_config['type'] == 'Mixup': 14 | mixup = Mixup(loss, alpha=msda_config['alpha'], 15 | log_stats=log_stats, name_prefix=name_prefix) 16 | mixup_loss = mixup.loss 17 | return mixup_loss, mixup 18 | else: 19 | raise NotImplementedError(f'MSDA type {msda_config["type"]} not supported') 20 | 21 | -------------------------------------------------------------------------------- /utils/train_types/msda/fmix.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from .fmix_utils import sample_mask 3 | import torch 4 | from .mixed_sample_data_augmentation import MixedSampleDataAugmentation 5 | from ..train_loss import MinMaxLoss, TrainLoss 6 | 7 | class FMixLoss(MinMaxLoss): 8 | def __init__(self, base_loss, lam=None, index=None, reformulate=False, log_stats=False, name_prefix=None): 9 | name = 'FMix_' + base_loss.name 10 | super().__init__(name, expected_format='logits', log_stats=log_stats, name_prefix=name_prefix) 11 | self.index =
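Hedged usage sketch of get_msda above: the factory returns the wrapped loss together with the augmentation object; callers mix the batch first, then evaluate the wrapped loss, which already holds lam and the permutation index. base_loss, model, x and y are assumed to exist, base_loss being one of the TrainLoss objects from train_loss.py.

from utils.train_types.msda import create_mixup_config, get_msda

mixed_loss, msda = get_msda(base_loss, create_mixup_config(alpha=1.0), name_prefix='ID')
x_mix = msda(x)                               # samples lam/index and stores them on the wrapped loss
loss = mixed_loss(x_mix, model(x_mix), x, y)  # convex combination of the two per-label losses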
index 12 | self.lam = lam 13 | self.reformulate = reformulate 14 | self.base_loss = base_loss 15 | 16 | def inner_max(self, data, target): 17 | return data 18 | 19 | def forward(self, data, model_out, orig_data, y, reduction='mean'): 20 | assert self.index is not None 21 | assert self.lam is not None 22 | 23 | if not self.reformulate: 24 | y2 = y[self.index] 25 | loss_expanded = self.base_loss(data, model_out, orig_data, y, reduction='none') * self.lam\ 26 | + self.base_loss(data, model_out, orig_data, y2, reduction='none') * (1 - self.lam) 27 | else: 28 | loss_expanded = self.base_loss(data, model_out, orig_data, y, reduction='none') 29 | 30 | self._log_stats(loss_expanded) 31 | return TrainLoss.reduce(loss_expanded, reduction) 32 | 33 | class FMix(MixedSampleDataAugmentation): 34 | r""" FMix augmentation 35 | Args: 36 | decay_power (float): Decay power for frequency decay prop 1/f**d 37 | alpha (float): Alpha value for beta distribution from which to sample mean of mask 38 | size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims. -1 computes on the fly 39 | max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask. 40 | reformulate (bool): If True, uses the reformulation of [1]. 41 | 42 | """ 43 | def __init__(self, base_loss, decay_power=3, alpha=1, size=(-1, -1), max_soft=0.0, reformulate=False, 44 | log_stats=False, name_prefix=None): 45 | self.decay_power = decay_power 46 | self.reformulate = reformulate 47 | self.size = size 48 | self.alpha = alpha 49 | self.max_soft = max_soft 50 | 51 | loss = FMixLoss(base_loss, reformulate=reformulate, log_stats=log_stats, name_prefix=name_prefix) 52 | super().__init__(loss) 53 | 54 | def apply_mix(self, x): 55 | size = [] 56 | for i, s in enumerate(self.size): 57 | if s != -1: 58 | size.append(s) 59 | else: 60 | size.append(x.shape[i+2]) 61 | 62 | lam, mask = sample_mask(self.alpha, self.decay_power, size, self.max_soft, self.reformulate) 63 | index = torch.randperm(x.size(0)).to(x.device) 64 | mask = torch.from_numpy(mask).float().to(x.device) 65 | 66 | if len(self.size) == 1 and x.ndim == 3: 67 | mask = mask.unsqueeze(2) 68 | 69 | # Mix the images 70 | x_mix = mask * x + (1 - mask) * x[index] 71 | 72 | self.loss.lam = lam #FMixLoss.forward reads self.lam; the old od_weight assignment left it unset 73 | self.loss.index = index 74 | 75 | return x_mix 76 | 77 | 78 | -------------------------------------------------------------------------------- /utils/train_types/msda/mixed_sample_data_augmentation.py: -------------------------------------------------------------------------------- 1 | class MixedSampleDataAugmentation(): 2 | def __init__(self, loss): 3 | self.loss = loss 4 | 5 | def apply_mix(self, x): 6 | raise NotImplementedError 7 | 8 | def __call__(self, *args, **kwargs): 9 | return self.apply_mix(*args, **kwargs) 10 | 11 | -------------------------------------------------------------------------------- /utils/train_types/msda/mixup.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from .mixed_sample_data_augmentation import MixedSampleDataAugmentation 4 | from ..train_loss import MinMaxLoss, TrainLoss 5 | 6 | class MixupLoss(MinMaxLoss): 7 | def __init__(self, base_loss, lam=None, index=None, log_stats=False, name_prefix=None): 8 | name = 'Mixup_' + base_loss.name 9 | super().__init__(name, expected_format='logits', log_stats=log_stats, name_prefix=name_prefix) 10 | self.index = index 11 | self.lam = lam 12 | self.base_loss = base_loss 13 | 14 | def inner_max(self, data,
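A minimal numeric illustration of the FMix combination step above, with a random binary mask standing in for the low-frequency mask that sample_mask actually draws: each output pixel comes from either x[i] or its mixing partner x[index[i]], and the mask mean plays the role of the label weight lam.

import torch

x = torch.rand(4, 3, 32, 32)
index = torch.randperm(4)
mask = (torch.rand(1, 1, 32, 32) > 0.5).float()  # stand-in for the sampled low-frequency mask
x_mix = mask * x + (1 - mask) * x[index]
lam = mask.mean().item()                          # weight of y versus y[index] in the loss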
target): 15 | return data 16 | 17 | def forward(self, data, model_out, orig_data, y, reduction='mean'): 18 | assert self.index is not None 19 | assert self.lam is not None 20 | 21 | y2 = y[self.index] 22 | loss_expanded = self.lam * self.base_loss(data, model_out, orig_data, y, reduction='none')\ 23 | + (1. - self.lam) * self.base_loss(data, model_out, orig_data, y2, reduction='none') 24 | 25 | self._log_stats(loss_expanded) 26 | return TrainLoss.reduce(loss_expanded, reduction) 27 | 28 | class Mixup(MixedSampleDataAugmentation): 29 | def __init__(self, base_loss, alpha=1, log_stats=False, name_prefix=None): 30 | self.alpha = alpha 31 | loss = MixupLoss(base_loss, log_stats=log_stats, name_prefix=name_prefix) 32 | super().__init__(loss) 33 | 34 | def apply_mix(self, x): 35 | if self.alpha > 0: 36 | lam = np.random.beta(self.alpha, self.alpha) 37 | else: 38 | lam = 1 39 | 40 | index = torch.randperm(x.size(0)).to(x.device) 41 | x_mix = lam * x + (1 - lam) * x[index, :] 42 | 43 | self.loss.lam = lam #MixupLoss.forward reads self.lam; the old od_weight assignment left it unset 44 | self.loss.index = index 45 | 46 | return x_mix 47 | 48 | 49 | -------------------------------------------------------------------------------- /utils/train_types/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .config_creators import create_optimizer_config, create_sam_optimizer_config -------------------------------------------------------------------------------- /utils/train_types/optimizers/config_creators.py: -------------------------------------------------------------------------------- 1 | def create_optimizer_config(optimizer_type, lr, 2 | weight_decay=0, momentum=0, nesterov=False, 3 | mixed_precision=False, ema=False, ema_decay=0.990): 4 | optimizer_config = {'optimizer_type': optimizer_type, 'lr': lr, 'weight_decay': weight_decay, 5 | 'mixed_precision': mixed_precision, 'ema': ema, 'ema_decay': ema_decay, } 6 | if optimizer_type.lower() == 'sgd': 7 | optimizer_config['momentum'] = momentum 8 | optimizer_config['nesterov'] = nesterov 9 | return optimizer_config 10 | 11 | def create_sam_optimizer_config(lr, 12 | weight_decay=0, momentum=0, nesterov=False, 13 | sam_rho=0.05, sam_adaptive=False, 14 | ema=False, ema_decay=0.990): 15 | optimizer_config = {'optimizer_type': 'SAM', 'lr': lr, 'weight_decay': weight_decay, 16 | 'momentum': momentum, 'nesterov': nesterov, 17 | 'sam_rho': sam_rho, 'sam_adaptive': sam_adaptive, 18 | 'ema': ema, 'ema_decay': ema_decay, } 19 | return optimizer_config 20 | 21 | def add_cosine_swa_to_optimizer_config(epochs, cycle_length, update_frequency, 22 | virtual_schedule_length, virtual_schedule_swa_end, 23 | virtual_schedule_lr, scheduler_config): 24 | swa_config = {'epochs': epochs, 'cycle_length': cycle_length, 25 | 'update_frequency': update_frequency, 26 | 'swa_schedule_type': 'cosine', 27 | 'virtual_schedule_length': virtual_schedule_length, 28 | 'virtual_schedule_swa_end': virtual_schedule_swa_end, 29 | 'virtual_schedule_lr': virtual_schedule_lr} 30 | scheduler_config['swa_config'] = swa_config 31 | 32 | 33 | def add_constant_swa_to_optimizer_config(epochs, update_frequency, 34 | virtual_schedule_lr, scheduler_config): 35 | swa_config = {'epochs': epochs, 36 | 'update_frequency': update_frequency, 37 | 'swa_schedule_type': 'constant', 38 | 'virtual_schedule_lr': virtual_schedule_lr} 39 | scheduler_config['swa_config'] = swa_config -------------------------------------------------------------------------------- /utils/train_types/optimizers/factory.py:
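Hedged example of the two optimizer config dicts defined above (the factory that consumes them is only referenced by URL in this dump, so the values are illustrative):

from utils.train_types.optimizers import create_optimizer_config, create_sam_optimizer_config

sgd_cfg = create_optimizer_config('SGD', lr=0.1, weight_decay=5e-4,
                                  momentum=0.9, nesterov=True, ema=True)
sam_cfg = create_sam_optimizer_config(lr=0.1, weight_decay=5e-4, momentum=0.9,
                                      sam_rho=0.05, sam_adaptive=False)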
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/M4xim4l/InNOutRobustness/d81d1d26e5ebc9193009e3d92bd67b5e01d6cfd6/utils/train_types/optimizers/factory.py -------------------------------------------------------------------------------- /utils/train_types/optimizers/sam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class SAM(torch.optim.Optimizer): 5 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs): 6 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 7 | 8 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) 9 | super(SAM, self).__init__(params, defaults) 10 | 11 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 12 | self.param_groups = self.base_optimizer.param_groups 13 | 14 | @torch.no_grad() 15 | def first_step(self, zero_grad=False): 16 | grad_norm = self._grad_norm() 17 | for group in self.param_groups: 18 | scale = group["rho"] / (grad_norm + 1e-12) 19 | 20 | for p in group["params"]: 21 | if p.grad is None: continue 22 | self.state[p]["old_p"] = p.data.clone() 23 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) 24 | p.add_(e_w) # climb to the local maximum "w + e(w)" 25 | 26 | if zero_grad: self.zero_grad() 27 | 28 | @torch.no_grad() 29 | def second_step(self, zero_grad=False): 30 | for group in self.param_groups: 31 | for p in group["params"]: 32 | if p.grad is None: continue 33 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" 34 | 35 | self.base_optimizer.step() # do the actual "sharpness-aware" update 36 | 37 | if zero_grad: self.zero_grad() 38 | 39 | @torch.no_grad() 40 | def step(self, closure=None): 41 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" 42 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass 43 | 44 | self.first_step(zero_grad=True) 45 | closure() 46 | self.second_step() 47 | 48 | def _grad_norm(self): 49 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism 50 | norm = torch.norm( 51 | torch.stack([ 52 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) 53 | for group in self.param_groups for p in group["params"] 54 | if p.grad is not None 55 | ]), 56 | p=2 57 | ) 58 | return norm 59 | 60 | def load_state_dict(self, state_dict): 61 | super().load_state_dict(state_dict) 62 | self.base_optimizer.param_groups = self.param_groups 63 | -------------------------------------------------------------------------------- /utils/train_types/plain_training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .in_distribution_training import InDistributionTraining 3 | from .helpers import get_distance 4 | from .train_loss import ZeroLoss 5 | 6 | class PlainTraining(InDistributionTraining): 7 | def __init__(self, model, optimizer_config, epochs, device, num_classes, clean_criterion='ce', 8 | lr_scheduler_config=None, msda_config=None, model_config=None, test_epochs=1, verbose=100, saved_model_dir='SavedModels', 9 | saved_log_dir='Logs'): 10 | distance = get_distance('l2') 11 | super().__init__('plain', model, distance, optimizer_config, epochs, device, num_classes, 12 | train_clean=False, id_trades=True, clean_weight=1., id_adv_weight=0., 13 | 
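Usage sketch for the SAM wrapper above, following the two-step protocol its step() enforces: one ordinary backward pass, then a closure that re-runs the full forward/backward at the perturbed weights "w + e(w)". model, x, y and loss_fn are assumed to exist.

import torch

optimizer = SAM(model.parameters(), torch.optim.SGD, rho=0.05, lr=0.1, momentum=0.9)

def closure():
    loss = loss_fn(model(x), y)
    loss.backward()
    return loss

loss = loss_fn(model(x), y)
loss.backward()
optimizer.step(closure)   # first_step -> closure at perturbed weights -> second_step
optimizer.zero_grad()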
lr_scheduler_config=lr_scheduler_config, msda_config=msda_config, model_config=model_config, 14 | test_epochs=test_epochs, clean_criterion=clean_criterion, 15 | verbose=verbose, saved_model_dir=saved_model_dir, saved_log_dir=saved_log_dir) 16 | 17 | def _get_id_criterion(self, epoch, model, name_prefix='ID'): 18 | id_train_criterion = ZeroLoss() 19 | return id_train_criterion 20 | 21 | def _get_train_type_config(self, loader_config=None): 22 | base_config = self._get_base_config() 23 | plain_config = self._get_plain_config() 24 | 25 | configs = {} 26 | configs['Base'] = base_config 27 | configs['Plain Training'] = plain_config 28 | configs['Optimizer'] = self.optimizer_config 29 | configs['Scheduler'] = self.lr_scheduler_config 30 | configs['MSDA'] = self.msda_config 31 | configs['Data Loader'] = loader_config 32 | configs['Model'] = self.model_config 33 | 34 | 35 | return configs 36 | 37 | def _get_plain_config(self): 38 | config_dict = {} 39 | config_dict['Clean Weight'] = self.clean_weight 40 | return config_dict 41 | 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /utils/train_types/randomized_smoothing_training.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | 6 | from .in_distribution_training import InDistributionTraining 7 | from .train_loss import MinMaxLoss, TrainLoss 8 | from .helpers import get_distance 9 | 10 | 11 | class RandomizedSmoothingLoss(MinMaxLoss): 12 | def __init__(self, noise_scales, log_stats=False, name_prefix=None): 13 | super().__init__('RandomizedSmoothingLoss', 'log_probabilities', log_stats=log_stats, name_prefix=name_prefix) 14 | self.noise_scales = torch.FloatTensor(noise_scales) 15 | 16 | def inner_max(self, data, target): 17 | chosen_noise_scale = torch.randint(len(self.noise_scales), (data.shape[0],)) 18 | noise_eps = self.noise_scales[chosen_noise_scale].view(data.shape[0], *([1] * len(data.shape[1:]))).to(data.device) 19 | adv_samples = (data + noise_eps * torch.randn_like(data)).clamp(0.0, 1.0) 20 | return adv_samples 21 | 22 | def forward(self, data, model_out, orig_data, y, reduction='mean'): 23 | prep_out = self._prepare_input(model_out) 24 | loss_expanded = F.cross_entropy(prep_out, y, reduction='none' ) 25 | self._log_stats(loss_expanded) 26 | return TrainLoss.reduce(loss_expanded, reduction) 27 | 28 | class RandomizedSmoothingTraining(InDistributionTraining): 29 | def __init__(self, model, optimizer_config, epochs, device, num_classes, noise_scales, train_clean=True, 30 | lr_scheduler_config=None, model_config=None, 31 | test_epochs=1, verbose=100, saved_model_dir='SavedModels', 32 | saved_log_dir='Logs'): 33 | 34 | distance = get_distance('l2') 35 | self.noise_scales = noise_scales 36 | 37 | super().__init__('RandomizedSmoothing', model, distance, optimizer_config, epochs, device, num_classes, 38 | train_clean=train_clean, lr_scheduler_config=lr_scheduler_config, model_config=model_config, 39 | test_epochs=test_epochs, verbose=verbose, saved_model_dir=saved_model_dir, 40 | saved_log_dir=saved_log_dir) 41 | 42 | 43 | def _get_id_criterion(self, epoch, model, name_prefix='ID'): 44 | id_train_criterion = RandomizedSmoothingLoss(self.noise_scales, log_stats=True, name_prefix=name_prefix) 45 | return id_train_criterion 46 | 47 | def _get_train_type_config(self, loader_config=None): 48 | base_config = self._get_base_config() 49 | adv_config = 
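In isolation, RandomizedSmoothingLoss.inner_max above is plain Gaussian noise augmentation in the style of Cohen et al.: pick one sigma per sample and train on the perturbed input. A self-contained sketch with assumed shapes:

import torch

noise_scales = torch.tensor([0.12, 0.25, 0.50])
x = torch.rand(8, 3, 32, 32)
idx = torch.randint(len(noise_scales), (x.shape[0],))
sigma = noise_scales[idx].view(-1, 1, 1, 1)       # one noise scale per sample
x_noisy = (x + sigma * torch.randn_like(x)).clamp(0.0, 1.0)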
self._get_randomized_smoothing_training_config() 50 | 51 | configs = {} 52 | configs['Base'] = base_config 53 | configs['Randomized Smoothing Training'] = adv_config 54 | configs['Optimizer'] = self.optimizer_config 55 | configs['Scheduler'] = self.lr_scheduler_config 56 | configs['MSDA'] = self.msda_config 57 | 58 | configs['Data Loader'] = loader_config 59 | configs['Model'] = self.model_config 60 | 61 | return configs 62 | 63 | def _get_randomized_smoothing_training_config(self): 64 | config_dict = {} 65 | config_dict['Num Noise scales'] = len(self.noise_scales) 66 | config_dict['Min Noise scales'] = torch.min(torch.as_tensor(self.noise_scales)).item() #noise_scales may be a plain list, so convert before min/max 67 | config_dict['Max Noise scales'] = torch.max(torch.as_tensor(self.noise_scales)).item() 68 | return config_dict 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /utils/train_types/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .config_creators import * 2 | from .scheduler_factory import create_scheduler -------------------------------------------------------------------------------- /utils/train_types/schedulers/config_creators.py: -------------------------------------------------------------------------------- 1 | 2 | def create_cosine_annealing_scheduler_config(cycle_length, lr_min, cycle_multiplier=1, warmup_length=0): 3 | scheduler_config = {'cycle_length': cycle_length, 'scheduler_type': 'CosineAnnealing', 4 | 'lr_min': lr_min, 'cycle_multiplier': cycle_multiplier, 5 | 'warmup_length': warmup_length} 6 | return scheduler_config 7 | 8 | def create_piecewise_consant_scheduler_config(epochs, decay_epochs, decay_rate, warmup_length=0): 9 | scheduler_config = {'cycle_length': epochs, 'scheduler_type': 'StepLR', 10 | 'decay_epochs': decay_epochs, 'decay_rate': decay_rate, 11 | 'warmup_length': warmup_length} 12 | return scheduler_config 13 | 14 | -------------------------------------------------------------------------------- /utils/train_types/schedulers/cosine_lr.py: -------------------------------------------------------------------------------- 1 | """ Cosine Scheduler 2 | 3 | Cosine LR schedule with warmup, cycle/restarts, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | 12 | from .scheduler import Scheduler 13 | 14 | 15 | _logger = logging.getLogger(__name__) 16 | 17 | 18 | class CosineLRScheduler(Scheduler): 19 | """ 20 | Cosine decay with restarts. 21 | This is described in the paper https://arxiv.org/abs/1608.03983.
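Hedged example of the scheduler config dicts created above (illustrative values), as consumed by create_scheduler in scheduler_factory.py further down:

from utils.train_types.schedulers import create_cosine_annealing_scheduler_config, create_piecewise_consant_scheduler_config

cosine_cfg = create_cosine_annealing_scheduler_config(cycle_length=100, lr_min=0.0,
                                                      cycle_multiplier=1, warmup_length=5)
step_cfg = create_piecewise_consant_scheduler_config(epochs=100, decay_epochs=[50, 75, 90],
                                                     decay_rate=0.1)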
22 | 23 | Inspiration from 24 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 25 | """ 26 | 27 | def __init__(self, 28 | optimizer: torch.optim.Optimizer, 29 | t_initial: int, 30 | t_mul: float = 1., 31 | lr_min: float = 0., 32 | decay_rate: float = 1., 33 | warmup_t=0, 34 | warmup_lr_init=0, 35 | warmup_prefix=False, 36 | cycle_limit=0, 37 | noise_range_t=None, 38 | noise_pct=0.67, 39 | noise_std=1.0, 40 | noise_seed=42, 41 | initialize=True) -> None: 42 | super().__init__( 43 | optimizer, param_group_field="lr", 44 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 45 | initialize=initialize) 46 | 47 | assert t_initial >= 0 48 | assert lr_min >= 0 49 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 50 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 51 | "rate since t_initial = t_mul = eta_mul = 1.") 52 | self.t_initial = t_initial 53 | self.t_mul = t_mul 54 | self.lr_min = lr_min 55 | self.decay_rate = decay_rate 56 | self.cycle_limit = cycle_limit 57 | self.warmup_t = warmup_t 58 | self.warmup_lr_init = warmup_lr_init 59 | self.warmup_prefix = warmup_prefix 60 | if self.warmup_t: 61 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 62 | super().update_groups(self.warmup_lr_init) 63 | else: 64 | self.warmup_steps = [1 for _ in self.base_values] 65 | 66 | def _get_lr(self, t): 67 | if t < self.warmup_t: 68 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 69 | else: 70 | if self.warmup_prefix: 71 | t = t - self.warmup_t 72 | 73 | if self.t_mul != 1: 74 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 75 | t_i = self.t_mul ** i * self.t_initial 76 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 77 | else: 78 | i = t // self.t_initial 79 | t_i = self.t_initial 80 | t_curr = t - (self.t_initial * i) 81 | 82 | gamma = self.decay_rate ** i 83 | lr_min = self.lr_min * gamma 84 | lr_max_values = [v * gamma for v in self.base_values] 85 | 86 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 87 | lrs = [ 88 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 89 | ] 90 | else: 91 | lrs = [self.lr_min for _ in self.base_values] 92 | 93 | return lrs 94 | 95 | def get_epoch_values(self, epoch: float): 96 | return self._get_lr(epoch) 97 | 98 | def get_cycle_length(self, cycles=0): 99 | if not cycles: 100 | cycles = self.cycle_limit 101 | cycles = max(1, cycles) 102 | if self.t_mul == 1.0: 103 | return self.t_initial * cycles 104 | else: 105 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 106 | -------------------------------------------------------------------------------- /utils/train_types/schedulers/factory.py: -------------------------------------------------------------------------------- 1 | # import torch.optim as optim 2 | # import math 3 | # import numpy as np 4 | # 5 | # def get_scheduler(optimizer, lr_scheduler_config): 6 | # if lr_scheduler_config['scheduler_type'] == 'StepLR': 7 | # batchwise_scheduler = False 8 | # scheduler = optim.lr_scheduler.StepLR(optimizer, lr_scheduler_config['step_size'], 9 | # lr_scheduler_config['gamma'], 10 | # lr_scheduler_config['last_epoch']) 11 | # elif lr_scheduler_config['scheduler_type'] == 'ExponentialLR': 12 | # batchwise_scheduler = False 13 | # scheduler = 
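The branch for t_mul == 1 in _get_lr above implements the SGDR closed form lr(t) = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t_curr / t_i)). A tiny standalone check, ignoring warmup, decay_rate and cycle_limit:

import math

def cosine_lr(t, t_initial=10, lr_max=0.1, lr_min=0.001):
    i, t_curr = divmod(t, t_initial)   # i: cycle index (unused here since decay_rate=1)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_initial))

assert abs(cosine_lr(0) - 0.1) < 1e-9      # cycle start: lr_max
assert abs(cosine_lr(5) - 0.0505) < 1e-9   # midpoint: (lr_max + lr_min) / 2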
optim.lr_scheduler.ExponentialLR(optimizer, lr_scheduler_config['gamma'], 14 | # lr_scheduler_config['last_epoch']) 15 | # elif lr_scheduler_config['scheduler_type'] == 'PiecewiseConstant': 16 | # batchwise_scheduler = False 17 | # 18 | # def piecewise(epoch): 19 | # for stage_end, stage_factor in zip(lr_scheduler_config['epoch_stages'], 20 | # lr_scheduler_config['stages_factors']): 21 | # if epoch < stage_end: 22 | # return stage_factor 23 | # print(f'Warning: Epoch {epoch} not in epoch stages') 24 | # return lr_scheduler_config['stages_factors'][-1] 25 | # 26 | # scheduler = optim.lr_scheduler.LambdaLR(optimizer, piecewise) 27 | # elif lr_scheduler_config['scheduler_type'] == 'CosineAnnealing': 28 | # batchwise_scheduler = True 29 | # period_length_batches = lr_scheduler_config['period_length_batches'] 30 | # 31 | # def cosine_annealing(step, period_length_batches, lr_min, lr_max): 32 | # return lr_min + (lr_max - lr_min) * 0.5 * ( 33 | # 1 + math.cos((step % period_length_batches) / period_length_batches * math.pi)) 34 | # 35 | # cosine_lambda = lambda x: (lr_scheduler_config['period_falloff'] ** np.floor( 36 | # x / period_length_batches)) * cosine_annealing(x, period_length_batches, lr_scheduler_config['lr_min'], 37 | # lr_scheduler_config['lr_max']) 38 | # 39 | # if 'warmup_length_batches' in lr_scheduler_config: 40 | # warmup_length = lr_scheduler_config['warmup_length_batches'] 41 | # lr_lambda = lambda x: min(cosine_lambda(x), lr_scheduler_config['lr_max'] * x / warmup_length ) 42 | # else: 43 | # lr_lambda = cosine_lambda 44 | # 45 | # scheduler = optim.lr_scheduler.LambdaLR( 46 | # optimizer, lr_lambda=lr_lambda) 47 | # elif lr_scheduler_config['scheduler_type'] == 'CyclicalLR': 48 | # batchwise_scheduler = True 49 | # # Scaler: we can adapt this if we do not want the triangular CLR 50 | # period_length_batches = lr_scheduler_config['period_length_batches'] 51 | # midpoint = lr_scheduler_config['midpoint'] * period_length_batches 52 | # period_falloff = lr_scheduler_config['period_falloff'] 53 | # xp = np.array([0, midpoint, period_length_batches]) 54 | # yp = np.array([lr_scheduler_config['lr_start'], lr_scheduler_config['lr_mid'], 55 | # lr_scheduler_config['lr_end']]) 56 | # 57 | # def cylic_lr(x): 58 | # period_factor = (period_falloff ** np.floor(x / period_length_batches)) 59 | # interp = np.interp(x % period_length_batches, xp, yp) 60 | # return period_factor * interp 61 | # 62 | # scheduler = optim.lr_scheduler.LambdaLR(optimizer, [cylic_lr]) 63 | # elif lr_scheduler_config['scheduler_type'] == 'LogarithmicFindLRScheduler': 64 | # batchwise_scheduler = True 65 | # q = math.pow(lr_scheduler_config['lr_end'] / lr_scheduler_config['lr_start'], 66 | # 1 / (lr_scheduler_config['period_length_batches'] - 1)) 67 | # 68 | # def log_scheduler(step, lr_start, q): 69 | # return lr_start * q ** step 70 | # 71 | # lr_lambda = lambda x: log_scheduler(x, lr_scheduler_config['lr_start'], q) 72 | # scheduler = optim.lr_scheduler.LambdaLR(optimizer, [lr_lambda]) 73 | # else: 74 | # raise ValueError('Scheduler not supported {}'.format(lr_scheduler_config['scheduler_type'])) 75 | # 76 | # return scheduler, batchwise_scheduler 77 | -------------------------------------------------------------------------------- /utils/train_types/schedulers/scheduler_factory.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .cosine_lr import CosineLRScheduler 5 | from .step_lr 
import StepLRScheduler 6 | 7 | def create_scheduler(args, optimizer): 8 | num_epochs = args['cycle_length'] 9 | noise_range = None 10 | 11 | if args['scheduler_type'] == 'CosineAnnealing': 12 | lr_scheduler = CosineLRScheduler( 13 | optimizer, 14 | t_initial=num_epochs, 15 | lr_min=args['lr_min'], 16 | warmup_t=args['warmup_length'], 17 | t_mul=args['cycle_multiplier'], 18 | noise_range_t=noise_range, 19 | ) 20 | elif args['scheduler_type'] == 'StepLR': 21 | lr_scheduler = StepLRScheduler( 22 | optimizer, 23 | decay_epochs=args['decay_epochs'], 24 | decay_rate=args['decay_rate'], 25 | warmup_t=args['warmup_length'], 26 | noise_range_t=noise_range, 27 | ) 28 | else: 29 | raise NotImplementedError(f'Scheduler {args["scheduler_type"]} not implemented') #args is a config dict, not an argparse namespace 30 | 31 | return lr_scheduler, num_epochs -------------------------------------------------------------------------------- /utils/train_types/schedulers/step_lr.py: -------------------------------------------------------------------------------- 1 | """ Step Scheduler 2 | 3 | Basic step LR schedule with warmup, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import math 8 | import torch 9 | 10 | from .scheduler import Scheduler 11 | 12 | 13 | class StepLRScheduler(Scheduler): 14 | """ Decays the base LR by decay_rate at each milestone in decay_epochs, with optional linear warmup. 15 | """ 16 | 17 | def __init__(self, 18 | optimizer: torch.optim.Optimizer, 19 | decay_epochs: list, 20 | decay_rate: float = 1., 21 | warmup_t=0, 22 | warmup_lr_init=0, 23 | noise_range_t=None, 24 | noise_pct=0.67, 25 | noise_std=1.0, 26 | noise_seed=42, 27 | initialize=True, 28 | ) -> None: 29 | super().__init__( 30 | optimizer, param_group_field="lr", 31 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 32 | initialize=initialize) 33 | 34 | self.decay_epochs = decay_epochs 35 | self.decay_rate = decay_rate 36 | self.warmup_t = warmup_t 37 | self.warmup_lr_init = warmup_lr_init 38 | if self.warmup_t: 39 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 40 | super().update_groups(self.warmup_lr_init) 41 | else: 42 | self.warmup_steps = [1 for _ in self.base_values] 43 | 44 | def _get_lr(self, t): 45 | if t < self.warmup_t: 46 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 47 | else: 48 | decay_factor = 1. 49 | for epoch in self.decay_epochs: 50 | if t >= epoch: 51 | decay_factor *= self.decay_rate 52 | lrs = [v * decay_factor for v in self.base_values] 53 | return lrs 54 | 55 | def get_epoch_values(self, epoch: float): 56 | return self._get_lr(epoch) 57 | --------------------------------------------------------------------------------
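End-to-end sketch of the scheduler plumbing in this package (model is assumed to exist, and the per-epoch step(epoch) call follows the timm-style Scheduler base class these files were hacked from):

import torch
from utils.train_types.schedulers import create_piecewise_consant_scheduler_config, create_scheduler

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
cfg = create_piecewise_consant_scheduler_config(epochs=100, decay_epochs=[50, 75, 90], decay_rate=0.1)
lr_scheduler, num_epochs = create_scheduler(cfg, optimizer)
for epoch in range(num_epochs):
    ...                               # one training epoch
    lr_scheduler.step(epoch + 1)      # assumed timm-style epoch stepping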