├── CONTRIBUTING-ARCHIVED.md ├── README.txt ├── data.py ├── eval.py ├── existing_methods.py ├── gen_color_mnist.py ├── main.py ├── models.py └── utils.py /CONTRIBUTING-ARCHIVED.md: -------------------------------------------------------------------------------- 1 | # ARCHIVED 2 | 3 | This project is `Archived` and is no longer actively maintained. 4 | We are not accepting contributions or pull requests. 5 | 6 | -------------------------------------------------------------------------------- /README.txt: -------------------------------------------------------------------------------- 1 | ReadMe for the paper "Predicting with High Correlation Features": 2 | 3 | Software versions used: 4 | 5 | 1. Python 3.6.8 6 | 2. PyTorch 1.0.0 7 | 8 | 9 | Commands: 10 | 11 | 1. Generate Colored MNIST: 12 | python gen_color_mnist.py 13 | 14 | 2. Sample command to run correlation-based regularization: 15 | python main.py --dataset fgbg_cmnist_cpr0.5-0.5 --seed 0 --root_dir cmnist --save_dir corr --beta 0.1 16 | 17 | 3.
Sample commands to run existing regularization/robustness methods: 18 | 19 | - Maximum Likelihood Estimate (MLE): 20 | 21 | python existing_methods.py --dataset fgbg_cmnist_cpr0.5-0.5 --seed 0 --root_dir cmnist --lr 0.0001 --bs 128 --save_dir mle 22 | 23 | - Adaptive Batch Normalization (AdaBN): 24 | 25 | python existing_methods.py --dataset fgbg_cmnist_cpr0.5-0.5 --seed 0 --root_dir cmnist --lr 0.0001 --bs 32 --save_dir adabn --bn --bn_eval 26 | 27 | - Adversarial Logit Pairing (ALP): 28 | 29 | python existing_methods.py --dataset fgbg_cmnist_cpr0.5-0.5 --seed 0 --root_dir cmnist --lr 0.0001 --save_dir alp --alp --nsteps 20 --stepsz 2 --epsilon 8 --beta 0.1 30 | 31 | - Clean Logit Pairing (CLP): 32 | 33 | python existing_methods.py --dataset fgbg_cmnist_cpr0.5-0.5 --seed 0 --root_dir cmnist --lr 0.0001 --save_dir clp --clp --beta 0.5 34 | 35 | - Projected Gradient Descent (PGD) based adversarial training: 36 | 37 | python existing_methods.py --dataset fgbg_cmnist_cpr0.5-0.5 --seed 0 --root_dir cmnist --lr 0.0001 --save_dir pgd --pgd --nsteps 20 --stepsz 2 --epsilon 8 38 | 39 | - Variational Information Bottleneck (VIB; replace [beta] with the desired KL-penalty coefficient, since --beta defaults to 0, which disables the penalty): 40 | 41 | python existing_methods.py --dataset fgbg_cmnist_cpr0.5-0.5 --seed 0 --root_dir cmnist --lr 0.001 --save_dir vib --vib --beta [beta] 42 | 43 | - Input Noise: 44 | 45 | python existing_methods.py --dataset fgbg_cmnist_cpr0.5-0.5 --seed 0 --root_dir cmnist --lr 0.001 --save_dir inp_noise --inp_noise 0.2 46 | 47 | 48 | 4.
Evaluate a trained model on another dataset (here [root_dir] and [save_dir] should be the directories in which the model to be used is saved): 49 | python eval.py --root_dir cmnist --save_dir corr --dataset mnistm 50 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | from utils import _split_train_val 10 | import torchvision.datasets as datasets 11 | import torch.utils.data as utils 12 | import errno 13 | from PIL import Image 14 | 15 | torch.manual_seed(0) 16 | 17 | NUM_WORKERS = 0 18 | 19 | def get_dataset(args): 20 | if args.dataset=='mnist': 21 | trans = ([ transforms.ToTensor()]) 22 | trans = transforms.Compose(trans) 23 | fulltrainset = torchvision.datasets.MNIST(root=args.data, train=True, transform=trans, download=True) 24 | 25 | train_set, valset = _split_train_val(fulltrainset, val_fraction=0.1) 26 | 27 | 28 | trainloader = torch.utils.data.DataLoader(train_set, batch_size=args.bs, shuffle=True, 29 | num_workers=NUM_WORKERS, pin_memory=True) 30 | validloader = torch.utils.data.DataLoader(valset, batch_size=args.bs, shuffle=False, 31 | num_workers=NUM_WORKERS, pin_memory=True) 32 | 33 | 34 | test_set = torchvision.datasets.MNIST(root=args.data, train=False, transform=trans) 35 | testloader = torch.utils.data.DataLoader(test_set, batch_size=args.bs, shuffle=False, num_workers=NUM_WORKERS) 36 | 37 | nb_classes = 10 38 | dim_inp=28*28 # np.prod(train_set.data.size()[1:]) 39 | elif 'cmnist' in args.dataset: 40 | data_dir_cmnist = args.data + 'cmnist/' + args.dataset + '/' 41 | data_x = np.load(data_dir_cmnist+'train_x.npy') 42 | data_y = np.load(data_dir_cmnist+'train_y.npy') 43 | 44 | data_x = torch.from_numpy(data_x).type('torch.FloatTensor') 45 
| data_y = torch.from_numpy(data_y).type('torch.LongTensor') 46 | 47 | my_dataset = utils.TensorDataset(data_x,data_y) 48 | 49 | train_set, valset = _split_train_val(my_dataset, val_fraction=0.1) 50 | 51 | trainloader = torch.utils.data.DataLoader(train_set, batch_size=args.bs, shuffle=True, num_workers=NUM_WORKERS) 52 | validloader = torch.utils.data.DataLoader(valset, batch_size=args.bs, shuffle=False, 53 | num_workers=NUM_WORKERS, pin_memory=True) 54 | 55 | 56 | data_x = np.load(data_dir_cmnist+'test_x.npy') 57 | data_y = np.load(data_dir_cmnist+'test_y.npy') 58 | data_x = torch.from_numpy(data_x).type('torch.FloatTensor') 59 | data_y = torch.from_numpy(data_y).type('torch.LongTensor') 60 | my_dataset = utils.TensorDataset(data_x,data_y) 61 | testloader = torch.utils.data.DataLoader(my_dataset, batch_size=args.bs, shuffle=False, num_workers=NUM_WORKERS) 62 | 63 | nb_classes = 10 64 | dim_inp=28*28* 3 65 | elif args.dataset=='mnistm': 66 | trans = ([transforms.ToTensor()]) 67 | trans = transforms.Compose(trans) 68 | fulltrainset = MNISTM(root=args.data, train=True, transform=trans, download=True) 69 | 70 | train_set, valset = _split_train_val(fulltrainset, val_fraction=0.1) 71 | 72 | 73 | trainloader = torch.utils.data.DataLoader(train_set, batch_size=args.bs, shuffle=True, 74 | num_workers=2, pin_memory=True) 75 | validloader = torch.utils.data.DataLoader(valset, batch_size=args.bs, shuffle=False, 76 | num_workers=2, pin_memory=True) 77 | 78 | 79 | test_set = MNISTM(root=args.data, train=False, transform=trans) 80 | testloader = torch.utils.data.DataLoader(test_set, batch_size=args.bs, shuffle=False, num_workers=2) 81 | 82 | nb_classes = 10 83 | dim_inp=3*28*28 # np.prod(train_set.data.size()[1:]) 84 | elif args.dataset=='svhn': 85 | trans = ([torchvision.transforms.Resize((28,28), interpolation=2), transforms.ToTensor()]) 86 | trans = transforms.Compose(trans) 87 | fulltrainset = torchvision.datasets.SVHN(args.data, split='train', transform=trans, 
target_transform=None, download=True) 88 | 89 | train_set, valset = _split_train_val(fulltrainset, val_fraction=0.1) 90 | 91 | 92 | trainloader = torch.utils.data.DataLoader(train_set, batch_size=args.bs, shuffle=True, 93 | num_workers=NUM_WORKERS, pin_memory=True) 94 | validloader = torch.utils.data.DataLoader(valset, batch_size=args.bs, shuffle=False, 95 | num_workers=NUM_WORKERS, pin_memory=True) 96 | 97 | 98 | test_set = torchvision.datasets.SVHN(args.data, split='test', transform=trans, target_transform=None, download=True) 99 | testloader = torch.utils.data.DataLoader(test_set, batch_size=args.bs, shuffle=False, num_workers=NUM_WORKERS) 100 | 101 | nb_classes = 10 102 | dim_inp=3*28*28 103 | return trainloader, validloader, testloader, nb_classes, dim_inp 104 | 105 | 106 | class MNISTM(torch.utils.data.Dataset): 107 | """`MNIST-M Dataset.""" 108 | 109 | url = "https://github.com/VanushVaswani/keras_mnistm/releases/download/1.0/keras_mnistm.pkl.gz" 110 | 111 | raw_folder = 'raw' 112 | processed_folder = 'processed' 113 | training_file = 'mnist_m_train.pt' 114 | test_file = 'mnist_m_test.pt' 115 | 116 | def __init__(self, 117 | root, mnist_root="data", 118 | train=True, 119 | transform=None, target_transform=None, 120 | download=False): 121 | """Init MNIST-M dataset.""" 122 | super(MNISTM, self).__init__() 123 | self.root = os.path.expanduser(root) 124 | self.mnist_root = os.path.expanduser(mnist_root) 125 | self.transform = transform 126 | self.target_transform = target_transform 127 | self.train = train # training set or test set 128 | 129 | if download: 130 | self.download() 131 | 132 | if not self._check_exists(): 133 | raise RuntimeError('Dataset not found.' 
+ 134 | ' You can use download=True to download it') 135 | 136 | if self.train: 137 | self.train_data, self.train_labels = \ 138 | torch.load(os.path.join(self.root, 139 | self.processed_folder, 140 | self.training_file)) 141 | else: 142 | self.test_data, self.test_labels = \ 143 | torch.load(os.path.join(self.root, 144 | self.processed_folder, 145 | self.test_file)) 146 | 147 | def __getitem__(self, index): 148 | """Get images and target for data loader. 149 | Args: 150 | index (int): Index 151 | Returns: 152 | tuple: (image, target) where target is index of the target class. 153 | """ 154 | if self.train: 155 | img, target = self.train_data[index], self.train_labels[index] 156 | else: 157 | img, target = self.test_data[index], self.test_labels[index] 158 | 159 | # doing this so that it is consistent with all other datasets 160 | # to return a PIL Image 161 | img = Image.fromarray(img.squeeze().numpy(), mode='RGB') 162 | 163 | if self.transform is not None: 164 | img = self.transform(img) 165 | 166 | if self.target_transform is not None: 167 | target = self.target_transform(target) 168 | 169 | return img, target 170 | 171 | def __len__(self): 172 | """Return size of dataset.""" 173 | if self.train: 174 | return len(self.train_data) 175 | else: 176 | return len(self.test_data) 177 | 178 | def _check_exists(self): 179 | return os.path.exists(os.path.join(self.root, 180 | self.processed_folder, 181 | self.training_file)) and \ 182 | os.path.exists(os.path.join(self.root, 183 | self.processed_folder, 184 | self.test_file)) 185 | 186 | def download(self): 187 | """Download the MNIST data.""" 188 | # import essential packages 189 | from six.moves import urllib 190 | import gzip 191 | import pickle 192 | from torchvision import datasets 193 | 194 | # check if dataset already exists 195 | if self._check_exists(): 196 | return 197 | 198 | # make data dirs 199 | try: 200 | os.makedirs(os.path.join(self.root, self.raw_folder)) 201 | os.makedirs(os.path.join(self.root, 
self.processed_folder)) 202 | except OSError as e: 203 | if e.errno == errno.EEXIST: 204 | pass 205 | else: 206 | raise 207 | 208 | # download pkl files 209 | print('Downloading ' + self.url) 210 | filename = self.url.rpartition('/')[2] 211 | file_path = os.path.join(self.root, self.raw_folder, filename) 212 | if not os.path.exists(file_path.replace('.gz', '')): 213 | data = urllib.request.urlopen(self.url) 214 | with open(file_path, 'wb') as f: 215 | f.write(data.read()) 216 | with open(file_path.replace('.gz', ''), 'wb') as out_f, \ 217 | gzip.GzipFile(file_path) as zip_f: 218 | out_f.write(zip_f.read()) 219 | os.unlink(file_path) 220 | 221 | # process and save as torch files 222 | print('Processing...') 223 | 224 | # load MNIST-M images from pkl file 225 | with open(file_path.replace('.gz', ''), "rb") as f: 226 | mnist_m_data = pickle.load(f, encoding='bytes') 227 | mnist_m_train_data = torch.ByteTensor(mnist_m_data[b'train']) 228 | mnist_m_test_data = torch.ByteTensor(mnist_m_data[b'test']) 229 | 230 | # get MNIST labels 231 | mnist_train_labels = datasets.MNIST(root=self.mnist_root, 232 | train=True, 233 | download=True).train_labels 234 | mnist_test_labels = datasets.MNIST(root=self.mnist_root, 235 | train=False, 236 | download=True).test_labels 237 | 238 | # save MNIST-M dataset 239 | training_set = (mnist_m_train_data, mnist_train_labels) 240 | test_set = (mnist_m_test_data, mnist_test_labels) 241 | with open(os.path.join(self.root, 242 | self.processed_folder, 243 | self.training_file), 'wb') as f: 244 | torch.save(training_set, f) 245 | with open(os.path.join(self.root, 246 | self.processed_folder, 247 | self.test_file), 'wb') as f: 248 | torch.save(test_set, f) 249 | 250 | print('Done!') 251 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | 2 | from argparse import Namespace 3 | import argparse 4 | import numpy as np 5 | import 
os 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | import tqdm 10 | from data import get_dataset 11 | 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Predicting with high correlation features') 15 | 16 | # Directories 17 | parser.add_argument('--data', type=str, default='datasets/', 18 | help='location of the data corpus') 19 | parser.add_argument('--root_dir', type=str, default='default/', 20 | help='root dir path to save the log and the final model') 21 | parser.add_argument('--save_dir', type=str, default='0/', 22 | help='dir path (inside root_dir) to save the log and the final model') 23 | 24 | # dataset 25 | parser.add_argument('--dataset', type=str, default='mnistm', 26 | help='dataset name') 27 | 28 | # Adaptive BN 29 | parser.add_argument('--bn_eval', action='store_true', 30 | help='adapt BN stats during eval') 31 | 32 | # hyperparameters 33 | parser.add_argument('--seed', type=int, default=1111, 34 | help='random seed') 35 | parser.add_argument('--bs', type=int, default=128, metavar='N', 36 | help='batch size') 37 | 38 | # meta specifications 39 | parser.add_argument('--cuda', action='store_false', 40 | help='use CUDA') 41 | parser.add_argument('--gpu', nargs='+', type=int, default=[0]) 42 | 43 | 44 | args = parser.parse_args() 45 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in args.gpu) 46 | 47 | 48 | args.root_dir = os.path.join('runs/', args.root_dir) 49 | args.save_dir = os.path.join(args.root_dir, args.save_dir) 50 | 51 | use_cuda = torch.cuda.is_available() 52 | torch.manual_seed(args.seed) 53 | if use_cuda: 54 | torch.cuda.manual_seed(args.seed) 55 | 56 | ############################################################################### 57 | # Load data 58 | ############################################################################### 59 | 60 | print('==> Preparing data..') 61 | trainloader, validloader, testloader, nb_classes, dim_inp = get_dataset(args) 62 | 63 | 64 | def 
test(loader, model): 65 | global best_acc, args 66 | 67 | if args.bn_eval: # forward prop data twice to update BN running averages 68 | model.train() 69 | for _ in range(2): 70 | for batch_idx, (inputs, targets) in enumerate(loader): 71 | if use_cuda: 72 | inputs, targets = inputs.cuda(), targets.cuda() 73 | inputs, targets = Variable(inputs), Variable(targets) 74 | _ = (model(inputs, train=False)) 75 | 76 | model.eval() 77 | test_loss, correct, total = 0,0,0 78 | tot_iters = len(loader) 79 | for batch_idx in tqdm.tqdm(range(tot_iters), total=tot_iters): 80 | inputs, targets = next(iter(loader)) 81 | if use_cuda: 82 | inputs, targets = inputs.cuda(), targets.cuda() 83 | with torch.no_grad(): 84 | inputs, targets = Variable(inputs), Variable(targets) 85 | outputs = (model(inputs, train=False)) 86 | _, predicted = torch.max(nn.Softmax(dim=1)(outputs).data, 1) 87 | total += targets.size(0) 88 | correct += predicted.eq(targets.data).cpu().sum() 89 | 90 | 91 | # Save checkpoint. 92 | acc = 100.*float(correct)/float(total) 93 | return acc 94 | 95 | with open(args.save_dir + '/best_model.pt', 'rb') as f: 96 | best_state = torch.load(f) 97 | model = best_state['model'] 98 | if use_cuda: 99 | model.cuda() 100 | # Run on test data. 
101 | test_acc = test(testloader, model=model) 102 | best_val_acc = test(validloader, model=model) 103 | print('=' * 89) 104 | status = 'Test acc {:3.4f} at best val acc {:3.4f}'.format(test_acc, best_val_acc) 105 | print(status) 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /existing_methods.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | Existing regularization/robustness methods 4 | 5 | ''' 6 | 7 | from argparse import Namespace 8 | import sys 9 | import argparse 10 | import math 11 | import numpy as np 12 | import os 13 | import torch 14 | import torch.nn as nn 15 | from torch.autograd import Variable 16 | from torch import autograd 17 | import pickle as pkl 18 | from models import ResNet_model, CNN 19 | import torch.nn.functional as F 20 | import glob 21 | import tqdm 22 | import torch.utils.data as utils 23 | import json 24 | from data import get_dataset 25 | from utils import AttackPGD, add_gaussian_noise, pairing_loss 26 | 27 | parser = argparse.ArgumentParser(description='Predicting with high correlation features') 28 | 29 | # Directories 30 | parser.add_argument('--data', type=str, default='datasets/', 31 | help='location of the data corpus') 32 | parser.add_argument('--root_dir', type=str, default='default/', 33 | help='root dir path to save the log and the final model') 34 | parser.add_argument('--save_dir', type=str, default='0/', 35 | help='dir path (inside root_dir) to save the log and the final model') 36 | 37 | parser.add_argument('--load_dir', type=str, default='', 38 | help='dir path (inside root_dir) to load model from') 39 | 40 | 41 | ######################## 42 | ### Baseline methods ### 43 | 44 | # Vanilla MLE (simply run without any baseline method argument below) 45 | 46 | # Projected gradient descent (PGD) based adversarial training 47 | parser.add_argument('--pgd', action='store_true', help='PGD') 48 | 
parser.add_argument('--nsteps', type=int, default=20, metavar='N', 49 | help='num of steps for PGD') 50 | parser.add_argument('--stepsz', type=int, default=2, metavar='N', 51 | help='step size for 1st order adv training') 52 | parser.add_argument('--epsilon', type=float, default=8, 53 | help='number of pixel values (0-255) allowed for PGD which is normalized by 255 in the code') 54 | 55 | # Input Gaussian noise 56 | parser.add_argument('--inp_noise', type=float, default=0., help='Gaussian input noise with standard deviation specified here') 57 | 58 | # Adversarial logit pairing (ALP/CLP) 59 | parser.add_argument('--alp', action='store_true', 60 | help='clean logit pairing') 61 | parser.add_argument('--clp', action='store_true', 62 | help='clean logit pairing') 63 | 64 | parser.add_argument('--beta', type=float, default=0, 65 | help='coefficient used for regularization term in ALP/CLP/VIB') 66 | parser.add_argument('--anneal_beta', action='store_true', help='anneal beta from 0.0001 to specified value gradually') 67 | 68 | # Variational Information Bottleneck (VIB) 69 | parser.add_argument('--vib', action='store_true', 70 | help='use Variational Information Bottleneck') 71 | 72 | # Adaptive batch norm 73 | parser.add_argument('--bn_eval', action='store_true', 74 | help='adapt BN stats during eval') 75 | 76 | 77 | ### Baseline methods ### 78 | ######################## 79 | 80 | 81 | 82 | 83 | # dataset and architecture 84 | parser.add_argument('--dataset', type=str, default='fgbg_cmnist_cpr0.5-0.5', 85 | help='dataset name') 86 | parser.add_argument('--arch', type=str, default='resnet', 87 | help='arch name (resnet,cnn)') 88 | parser.add_argument('--depth', type=int, default=56, 89 | help='number of resblocks if using resnet architecture') 90 | parser.add_argument('--k', type=int, default=1, 91 | help='widening factor for wide resnet architecture') 92 | 93 | # Optimization hyper-parameters 94 | parser.add_argument('--seed', type=int, default=1111, 95 | help='random 
seed') 96 | parser.add_argument('--bs', type=int, default=128, metavar='N', 97 | help='batch size') 98 | parser.add_argument('--bn', action='store_true', 99 | help='Use Batch norm') 100 | parser.add_argument('--noaffine', action='store_true', 101 | help='no affine transformations') 102 | parser.add_argument('--lr', type=float, default=0.001, 103 | help='learning rate ') 104 | parser.add_argument('--epochs', type=int, default=200, 105 | help='upper epoch limit') 106 | parser.add_argument('--init', type=str, default="he") 107 | parser.add_argument('--wdecay', type=float, default=0.0001, 108 | help='weight decay applied to all weights') 109 | 110 | 111 | # meta specifications 112 | parser.add_argument('--validation', action='store_true', 113 | help='Compute accuracy on validation set at each epoch') 114 | parser.add_argument('--cuda', action='store_false', 115 | help='use CUDA') 116 | parser.add_argument('--gpu', nargs='+', type=int, default=[0]) 117 | 118 | 119 | args = parser.parse_args() 120 | args.root_dir = os.path.join('runs/', args.root_dir) 121 | args.save_dir = os.path.join(args.root_dir, args.save_dir) 122 | if not os.path.exists(args.save_dir): 123 | os.makedirs(args.save_dir) 124 | log_dir = args.save_dir + '/' 125 | 126 | with open(args.save_dir + '/config.txt', 'w') as f: 127 | json.dump(args.__dict__, f, indent=2) 128 | with open(args.save_dir + '/log.txt', 'w') as f: 129 | f.write('python ' + ' '.join(s for s in sys.argv) + '\n') 130 | 131 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in args.gpu) 132 | 133 | 134 | 135 | 136 | 137 | # Set the random seed manually for reproducibility. 
138 | use_cuda = torch.cuda.is_available() 139 | torch.manual_seed(args.seed) 140 | if use_cuda: 141 | torch.cuda.manual_seed(args.seed) 142 | 143 | 144 | ############################################################################### 145 | # Load data 146 | ############################################################################### 147 | print('==> Preparing data..') 148 | trainloader, validloader, testloader, nb_classes, dim_inp = get_dataset(args) 149 | 150 | 151 | ############################################################################### 152 | # Build the model 153 | ############################################################################### 154 | epoch = 0 155 | if args.load_dir=='': 156 | inp_channels=3 157 | print('==> Building model..') 158 | if args.arch == 'resnet': 159 | model0 = ResNet_model(bn= args.bn, num_classes=nb_classes, depth=args.depth,\ 160 | inp_channels=inp_channels, k=args.k, affine=not args.noaffine, inp_noise=args.inp_noise, VIB=args.vib) 161 | elif args.arch == 'cnn': 162 | model0 = CNN(bn= args.bn, affine=not args.noaffine, num_classes=nb_classes, inp_noise=args.inp_noise, VIB=args.vib) 163 | else: 164 | with open(args.root_dir + '/' + args.load_dir + '/best_model.pt', 'rb') as f: 165 | best_state = torch.load(f) 166 | model0 = best_state['model'] 167 | epoch = best_state['epoch'] 168 | print('==> Loading model from epoch ', epoch) 169 | 170 | params = list(model0.parameters()) 171 | model = torch.nn.DataParallel(model0, device_ids=range(len(args.gpu))) 172 | 173 | 174 | adv_PGD_config = config = { 175 | 'epsilon': args.epsilon / (255), 176 | 'num_steps': args.nsteps, 177 | 'step_size': args.stepsz / (255.), 178 | 'random_start': True 179 | } 180 | AttackPGD_ = AttackPGD(adv_PGD_config) 181 | 182 | nb = 0 183 | if args.init == 'he': 184 | for m in model.modules(): 185 | if isinstance(m, nn.Conv2d): 186 | nb += 1 187 | # print ('Update init of ', m) 188 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 189 | 
m.weight.data.normal_(0, math.sqrt(2. / n)) 190 | elif isinstance(m, nn.BatchNorm2d) and not args.noaffine: 191 | # print ('Update init of ', m) 192 | m.weight.data.fill_(1) 193 | m.bias.data.zero_() 194 | print( 'Number of Conv layers: ', (nb)) 195 | 196 | 197 | 198 | if use_cuda: 199 | model.cuda() 200 | total_params = sum(np.prod(x.size()) if len(x.size()) > 1 else x.size()[0] for x in model.parameters()) 201 | print('Args:', args) 202 | print( 'Model total parameters:', total_params) 203 | with open(args.save_dir + '/log.txt', 'a') as f: 204 | f.write(str(args) + ',total_params=' + str(total_params) + '\n') 205 | 206 | criterion = nn.CrossEntropyLoss() 207 | 208 | 209 | ############################################################################### 210 | # Training/Testing code 211 | ############################################################################### 212 | 213 | 214 | def test(loader, model, save=False, epoch=0): 215 | global best_acc, args 216 | 217 | if args.bn_eval: # forward prop data twice to update BN running averages 218 | model.train() 219 | for _ in range(2): 220 | for batch_idx, (inputs, targets) in enumerate(loader): 221 | if use_cuda: 222 | inputs, targets = inputs.cuda(), targets.cuda() 223 | inputs, targets = Variable(inputs), Variable(targets) 224 | _ = (model(inputs, train=False)) 225 | 226 | model.eval() 227 | correct, total = 0,0 228 | tot_iters = len(loader) 229 | for batch_idx in tqdm.tqdm(range(tot_iters), total=tot_iters): 230 | inputs, targets = next(iter(loader)) 231 | if use_cuda: 232 | inputs, targets = inputs.cuda(), targets.cuda() 233 | with torch.no_grad(): 234 | inputs, targets = Variable(inputs), Variable(targets) 235 | outputs = (model(inputs, train=False)) 236 | 237 | _, predicted = torch.max(nn.Softmax(dim=1)(outputs).data, 1) 238 | total += targets.size(0) 239 | correct += predicted.eq(targets.data).cpu().sum() 240 | 241 | 242 | # Save checkpoint. 
243 | acc = 100.*float(correct)/float(total) 244 | 245 | if save and acc > best_acc: 246 | best_acc = acc 247 | print('Saving best model..') 248 | state = { 249 | 'model': model0, 250 | 'epoch': epoch 251 | } 252 | with open(args.save_dir + '/best_model.pt', 'wb') as f: 253 | torch.save(state, f) 254 | return acc 255 | 256 | 257 | def train(epoch): 258 | global trainloader, optimizer, args, model, best_loss 259 | model.train() 260 | correct = 0 261 | total = 0 262 | total_loss, reg_loss, tot_regularization_loss = 0, 0, 0 263 | 264 | 265 | optimizer.zero_grad() 266 | tot_iters = len(trainloader) 267 | for batch_idx in tqdm.tqdm(range(tot_iters), total=tot_iters): 268 | inputs, targets = next(iter(trainloader)) 269 | if use_cuda: 270 | inputs, targets = inputs.cuda(), targets.cuda() 271 | 272 | inputs = Variable(inputs) 273 | 274 | if args.pgd: 275 | outputs = AttackPGD_(inputs, targets, model) 276 | loss = criterion(outputs, targets) 277 | elif args.alp: 278 | outputs_adv = AttackPGD_(inputs, targets, model) 279 | loss_adv = criterion(outputs_adv, targets) 280 | outputs = (model(inputs)) 281 | loss_clean = criterion(outputs, targets) 282 | loss = loss_clean + loss_adv 283 | reg_loss = pairing_loss(outputs_adv, outputs) 284 | elif args.clp: 285 | outputs = (model(inputs)) 286 | loss_clean = criterion(outputs, targets) 287 | loss = loss_clean 288 | reg_loss = pairing_loss(outputs, outputs, stochastic_pairing = True) 289 | elif args.vib: 290 | outputs, mn, logvar = model(inputs) 291 | loss = criterion(outputs, targets) 292 | reg_loss = -0.5 * torch.sum(1 + logvar - mn.pow(2) - logvar.exp())/inputs.size(0) 293 | else: 294 | outputs = (model(inputs)) 295 | loss = criterion(outputs, targets) 296 | 297 | tot_regularization_loss += reg_loss 298 | 299 | 300 | total_loss_ = loss + args.beta* reg_loss 301 | total_loss_.backward() # retain_graph=True 302 | 303 | total_loss += loss.data.cpu() 304 | _, predicted = torch.max(nn.Softmax(dim=1)(outputs).data, 1) 305 | total += 
targets.size(0) 306 | correct += predicted.eq(targets.data).cpu().sum() 307 | 308 | 309 | # nn.utils.clip_grad_norm_(model.parameters(), 0.1) 310 | optimizer.step() 311 | optimizer.zero_grad() 312 | 313 | 314 | 315 | acc = 100.*correct/total 316 | return total_loss/(batch_idx+1), acc, tot_regularization_loss/(batch_idx+1) 317 | 318 | 319 | optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.wdecay) 320 | 321 | best_acc, best_loss =0, np.inf 322 | train_loss_list, train_acc_list, valid_acc_list, test_acc_list, reg_loss_list = [], [], [], [], [] 323 | 324 | 325 | if args.anneal_beta: 326 | beta_ = args.beta 327 | args.beta = 0.0001 328 | def train_fn(): 329 | global epoch, args, best_loss, best_acc 330 | while epoch150)*255).type('torch.FloatTensor') 53 | x_rgb = torch.ones(x.size(0),3, x.size()[2], x.size()[3]).type('torch.FloatTensor') 54 | x_rgb = x_rgb* x 55 | x_rgb_fg = 1.*x_rgb 56 | 57 | color_choice = np.argmax(np.random.multinomial(1, cpr, targets.shape[0]), axis=1) if cpr is not None else 0 58 | c = Cfg[color_choice,targets] if cpr is not None else Cfg[color_choice,np.random.randint(nb_classes, size=targets.shape[0])] 59 | c = c.reshape(-1, 3, 1, 1) 60 | c= torch.from_numpy(c).type('torch.FloatTensor') 61 | x_rgb_fg[:,0] = x_rgb_fg[:,0]* c[:,0] 62 | x_rgb_fg[:,1] = x_rgb_fg[:,1]* c[:,1] 63 | x_rgb_fg[:,2] = x_rgb_fg[:,2]* c[:,2] 64 | 65 | bg = (255-x_rgb) 66 | # c = C[targets] if np.random.rand()>cpr else C[np.random.randint(C.shape[0], size=targets.shape[0])] 67 | color_choice = np.argmax(np.random.multinomial(1, cpr, targets.shape[0]), axis=1) if cpr is not None else 0 68 | c = Cbg[color_choice,targets] if cpr is not None else Cbg[color_choice,np.random.randint(nb_classes, size=targets.shape[0])] 69 | c = c.reshape(-1, 3, 1, 1) 70 | c= torch.from_numpy(c).type('torch.FloatTensor') 71 | bg[:,0] = bg[:,0]* c[:,0] 72 | bg[:,1] = bg[:,1]* c[:,1] 73 | bg[:,2] = bg[:,2]* c[:,2] 74 | x_rgb = x_rgb_fg + bg 75 | x_rgb = x_rgb + torch.tensor((noise)* 
np.random.randn(*x_rgb.size())).type('torch.FloatTensor') 76 | x_rgb = torch.clamp(x_rgb, 0.,255.) 77 | if i==0: 78 | color_data_x = np.zeros((bs* tot_iters, *img_size)) 79 | color_data_y = np.zeros((bs* tot_iters,)) 80 | color_data_x[i*bs: (i+1)*bs] = x_rgb/255. 81 | color_data_y[i*bs: (i+1)*bs] = targets 82 | return color_data_x, color_data_y 83 | 84 | dir_name = data_path + 'cmnist/' + 'fgbg_cmnist_cpr' + '-'.join(str(p) for p in args.cpr) + '/' 85 | print(dir_name) 86 | if not os.path.exists(data_path + 'cmnist/'): 87 | os.mkdir(data_path + 'cmnist/') 88 | if not os.path.exists(dir_name): 89 | os.mkdir(dir_name) 90 | 91 | 92 | color_data_x, color_data_y = gen_fgbgcolor_data(trainloader, img_size=(3,28,28), cpr=args.cpr, noise=10.) 93 | np.save(dir_name+ '/train_x.npy', color_data_x) 94 | np.save(dir_name+ '/train_y.npy', color_data_y) 95 | 96 | 97 | color_data_x, color_data_y = gen_fgbgcolor_data(testloader, img_size=(3,28,28), cpr=None, noise=10.) 98 | np.save(dir_name + 'test_x.npy', color_data_x) 99 | np.save(dir_name + 'test_y.npy', color_data_y) 100 | 101 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | Predicting with high correlation features 4 | 5 | ''' 6 | 7 | from argparse import Namespace 8 | import sys 9 | import argparse 10 | import math 11 | import numpy as np 12 | import os 13 | import torch 14 | import torch.nn as nn 15 | from torch.autograd import Variable 16 | from torch import autograd 17 | import pickle as pkl 18 | from models import ResNet_model, CNN 19 | import torch.nn.functional as F 20 | from utils import correlation_reg 21 | import glob 22 | import tqdm 23 | import torch.utils.data as utils 24 | import json 25 | from data import get_dataset 26 | 27 | parser = argparse.ArgumentParser(description='Predicting with high correlation features') 28 | 29 | # Directories 30 | parser.add_argument('--data', 
type=str, default='datasets/', 31 | help='location of the data corpus') 32 | parser.add_argument('--root_dir', type=str, default='default/', 33 | help='root dir path to save the log and the final model') 34 | parser.add_argument('--save_dir', type=str, default='0/', 35 | help='dir path (inside root_dir) to save the log and the final model') 36 | 37 | parser.add_argument('--load_dir', type=str, default='', 38 | help='dir path (inside root_dir) to load model from') 39 | 40 | 41 | # Baseline (correlation based) method 42 | parser.add_argument('--beta', type=float, default=1, 43 | help='coefficient for correlation based penalty') 44 | 45 | # adaptive batch norm 46 | parser.add_argument('--bn_eval', action='store_true', 47 | help='adapt BN stats during eval') 48 | 49 | # dataset and architecture 50 | parser.add_argument('--dataset', type=str, default='fgbg_cmnist_cpr0.5-0.5', 51 | help='dataset name') 52 | parser.add_argument('--arch', type=str, default='resnet', 53 | help='arch name (resnet,cnn)') 54 | parser.add_argument('--depth', type=int, default=56, 55 | help='number of resblocks if using resnet architecture') 56 | parser.add_argument('--k', type=int, default=1, 57 | help='widening factor for wide resnet architecture') 58 | 59 | # Optimization hyper-parameters 60 | parser.add_argument('--seed', type=int, default=1111, 61 | help='random seed') 62 | parser.add_argument('--bs', type=int, default=128, metavar='N', 63 | help='batch size') 64 | parser.add_argument('--bn', action='store_true', 65 | help='Use Batch norm') 66 | parser.add_argument('--noaffine', action='store_true', 67 | help='no affine transformations') 68 | parser.add_argument('--lr', type=float, default=0.001, 69 | help='learning rate ') 70 | parser.add_argument('--epochs', type=int, default=100, 71 | help='upper epoch limit') 72 | parser.add_argument('--init', type=str, default="he") 73 | parser.add_argument('--wdecay', type=float, default=0.0001, 74 | help='weight decay applied to all weights') 75 | 76 
| 77 | # meta specifications 78 | parser.add_argument('--validation', action='store_true', 79 | help='Compute accuracy on validation set at each epoch') 80 | parser.add_argument('--cuda', action='store_false', 81 | help='use CUDA') 82 | parser.add_argument('--gpu', nargs='+', type=int, default=[0]) 83 | 84 | 85 | args = parser.parse_args() 86 | args.root_dir = os.path.join('runs/', args.root_dir) 87 | args.save_dir = os.path.join(args.root_dir, args.save_dir) 88 | if not os.path.exists(args.save_dir): 89 | os.makedirs(args.save_dir) 90 | log_dir = args.save_dir + '/' 91 | 92 | with open(args.save_dir + '/config.txt', 'w') as f: 93 | json.dump(args.__dict__, f, indent=2) 94 | with open(args.save_dir + '/log.txt', 'w') as f: 95 | f.write('python ' + ' '.join(s for s in sys.argv) + '\n') 96 | 97 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in args.gpu) 98 | 99 | 100 | 101 | 102 | 103 | # Set the random seed manually for reproducibility. 104 | use_cuda = torch.cuda.is_available() 105 | torch.manual_seed(args.seed) 106 | if use_cuda: 107 | torch.cuda.manual_seed(args.seed) 108 | 109 | 110 | ############################################################################### 111 | # Load data 112 | ############################################################################### 113 | print('==> Preparing data..') 114 | trainloader, validloader, testloader, nb_classes, dim_inp = get_dataset(args) 115 | 116 | 117 | ############################################################################### 118 | # Build the model 119 | ############################################################################### 120 | if args.load_dir=='': 121 | inp_channels=3 122 | print('==> Building model..') 123 | if args.arch == 'resnet': 124 | model0 = ResNet_model(bn= args.bn, num_classes=nb_classes, depth=args.depth,\ 125 | inp_channels=inp_channels, k=args.k, affine=not args.noaffine) 126 | elif args.arch == 'cnn': 127 | model0 = CNN(bn= args.bn, affine=not args.noaffine, 
num_classes=nb_classes) 128 | else: 129 | with open(args.root_dir + '/' + args.load_dir + '/best_model.pt', 'rb') as f: 130 | best_state = torch.load(f) 131 | model0 = best_state['model'] 132 | 133 | params = list(model0.parameters()) 134 | model = torch.nn.DataParallel(model0, device_ids=range(len(args.gpu))) 135 | 136 | 137 | nb = 0 138 | if args.init == 'he': 139 | for m in model.modules(): 140 | if isinstance(m, nn.Conv2d): 141 | nb += 1 142 | # print ('Update init of ', m) 143 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 144 | m.weight.data.normal_(0, math.sqrt(2. / n)) 145 | elif isinstance(m, nn.BatchNorm2d) and not args.noaffine: 146 | # print ('Update init of ', m) 147 | m.weight.data.fill_(1) 148 | m.bias.data.zero_() 149 | print( 'Number of Conv layers: ', (nb)) 150 | 151 | 152 | 153 | if use_cuda: 154 | model.cuda() 155 | total_params = sum(np.prod(x.size()) if len(x.size()) > 1 else x.size()[0] for x in model.parameters()) 156 | print('Args:', args) 157 | print( 'Model total parameters:', total_params) 158 | with open(args.save_dir + '/log.txt', 'a') as f: 159 | f.write(str(args) + ',total_params=' + str(total_params) + '\n') 160 | 161 | criterion = nn.CrossEntropyLoss() 162 | 163 | 164 | ############################################################################### 165 | # Training/Testing code 166 | ############################################################################### 167 | 168 | 169 | def test(loader, model, save=False, epoch=0): 170 | global best_acc, args 171 | 172 | if args.bn_eval: # forward prop data twice to update BN running averages 173 | model.train() 174 | for _ in range(2): 175 | for batch_idx, (inputs, targets) in enumerate(loader): 176 | if use_cuda: 177 | inputs, targets = inputs.cuda(), targets.cuda() 178 | inputs, targets = Variable(inputs), Variable(targets) 179 | _ = (model(inputs, train=False)) 180 | 181 | model.eval() 182 | correct, total = 0,0 183 | tot_iters = len(loader) 184 | for batch_idx in 
tqdm.tqdm(range(tot_iters), total=tot_iters): 185 | inputs, targets = next(iter(loader)) 186 | if use_cuda: 187 | inputs, targets = inputs.cuda(), targets.cuda() 188 | with torch.no_grad(): 189 | inputs, targets = Variable(inputs), Variable(targets) 190 | outputs = (model(inputs, train=False)) 191 | 192 | _, predicted = torch.max(nn.Softmax(dim=1)(outputs).data, 1) 193 | total += targets.size(0) 194 | correct += predicted.eq(targets.data).cpu().sum() 195 | 196 | 197 | # Save checkpoint. 198 | acc = 100.*float(correct)/float(total) 199 | 200 | if save and acc > best_acc: 201 | best_acc = acc 202 | print('Saving best model..') 203 | state = { 204 | 'model': model0, 205 | 'epoch': epoch 206 | } 207 | with open(args.save_dir + '/best_model.pt', 'wb') as f: 208 | torch.save(state, f) 209 | return acc 210 | 211 | 212 | def train(epoch): 213 | global trainloader, optimizer, args, model, best_loss 214 | model.train() 215 | correct = 0 216 | total = 0 217 | total_loss, tot_regularization_loss = 0, 0 218 | 219 | 220 | optimizer.zero_grad() 221 | tot_iters = len(trainloader) 222 | for batch_idx in tqdm.tqdm(range(tot_iters), total=tot_iters): 223 | inputs, targets = next(iter(trainloader)) 224 | if use_cuda: 225 | inputs, targets = inputs.cuda(), targets.cuda() 226 | 227 | inputs = Variable(inputs) 228 | 229 | outputs, hid_repr = (model(inputs, ret_hid=True)) 230 | 231 | loss = criterion(outputs, targets) 232 | 233 | regularization_loss = 0 234 | regularization_loss = correlation_reg(hid_repr, targets.cpu().numpy()) 235 | tot_regularization_loss = tot_regularization_loss + regularization_loss.data 236 | 237 | total_loss_ = loss + args.beta* regularization_loss 238 | total_loss_.backward() 239 | 240 | total_loss += loss.data.cpu() 241 | _, predicted = torch.max(nn.Softmax(dim=1)(outputs).data, 1) 242 | total += targets.size(0) 243 | correct += predicted.eq(targets.data).cpu().sum() 244 | 245 | 246 | # nn.utils.clip_grad_norm_(model.parameters(), 0.1) 247 | optimizer.step() 248 
| optimizer.zero_grad() 249 | 250 | 251 | 252 | acc = 100.*correct/total 253 | return total_loss/(batch_idx+1), acc, tot_regularization_loss/(batch_idx+1) 254 | 255 | 256 | optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.wdecay) 257 | 258 | best_acc, best_loss =0, np.inf 259 | train_loss_list, train_acc_list, valid_acc_list, test_acc_list, reg_loss_list = [], [], [], [], [] 260 | epoch = 0 261 | 262 | 263 | def train_fn(): 264 | global epoch, args, best_loss, best_acc 265 | while epoch0 and train: 67 | x = x + self.inp_noise*torch.randn_like(x) 68 | h=self.conv1(x) 69 | x = F.relu(self.bn1(h)) 70 | x = F.max_pool2d(x, 2, 2) 71 | 72 | x=self.conv2(x) 73 | x = F.relu(self.bn2(x)) 74 | 75 | x=self.conv3(x) 76 | x = F.relu(self.bn3(x)) 77 | x = F.max_pool2d(x, 2, 2) 78 | 79 | x=self.conv4(x) 80 | x = F.relu(self.bn4(x)) 81 | 82 | x = nn.AvgPool2d(*[x.size()[2]*2])(x) 83 | x = x.view(x.size()[0], -1) 84 | 85 | if self.VIB: 86 | mn = self.mn(x) 87 | logvar = self.logvar(x) 88 | x = reparameterize(mn,logvar) 89 | 90 | 91 | x = self.fc(x) 92 | if ret_hid: 93 | return x, h 94 | elif self.VIB and train: 95 | return out, mn, logvar 96 | else: 97 | return x 98 | 99 | 100 | class resblock(nn.Module): 101 | 102 | def __init__(self, depth, channels, stride=1, bn='', nresblocks=1.,affine=True, kernel_size=3, bias=True): 103 | self.depth = depth 104 | self. 
channels = channels 105 | 106 | super(resblock, self).__init__() 107 | self.bn1 = nn.BatchNorm2d(depth,affine=affine) if bn else nn.Sequential() 108 | self.conv2 = (nn.Conv2d(depth, channels, kernel_size=kernel_size, stride=stride, padding=1, bias=bias)) 109 | self.bn2 = nn.BatchNorm2d(channels, affine=affine) if bn else nn.Sequential() 110 | 111 | self.conv3 = nn.Conv2d(channels, channels, kernel_size=kernel_size, stride=1, padding=1, bias=bias) 112 | 113 | self.shortcut = nn.Sequential() 114 | if stride > 1 or depth!=channels: 115 | layers = [] 116 | conv_layer = nn.Conv2d(depth, channels, kernel_size=1, stride=stride, padding=0, bias=bias) 117 | layers += [conv_layer, nn.BatchNorm2d(channels,affine=affine) if bn else nn.Sequential()] 118 | self.shortcut = nn.Sequential(*layers) 119 | 120 | def forward(self, x): 121 | out = ACT(self.bn1(x)) 122 | out = ACT(self.bn2(self.conv2(out))) 123 | out = (self.conv3(out)) 124 | short = self.shortcut(x) 125 | out += 1.*short 126 | return out 127 | 128 | 129 | 130 | class ResNet(nn.Module): 131 | def __init__(self, depth=56, nb_filters=16, num_classes=10, bn=False, affine=True, kernel_size=3, inp_channels=3, k=1, pad_conv1=0, bias=False, inp_noise=0, VIB=False): # n=9->Resnet-56 132 | super(ResNet, self).__init__() 133 | self.inp_noise = inp_noise 134 | self.VIB = VIB 135 | nstage = 3 136 | 137 | self.pre_clf=[] 138 | 139 | assert ((depth-2)%6 ==0), 'resnet depth should be 6n+2' 140 | n = int((depth-2)/6) 141 | 142 | nfilters = [nb_filters, nb_filters*k, 2* nb_filters*k, 4* nb_filters*k, num_classes] 143 | self.nfilters = nfilters 144 | self.num_classes = num_classes 145 | self.conv1 = (nn.Conv2d(inp_channels, nfilters[0], kernel_size=kernel_size, stride=1, padding=pad_conv1, bias=bias)) 146 | self.bn1 = nn.BatchNorm2d(nfilters[0], affine=affine) if bn else nn.Sequential() 147 | 148 | 149 | nb_filters_prev = nb_filters_cur = nfilters[0] 150 | for stage in range(nstage): 151 | nb_filters_cur = nfilters[stage+1] 152 | for i in 
range(n): 153 | subsample = 1 if (i > 0 or stage == 0) else 2 154 | layer = resblock(nb_filters_prev, nb_filters_cur, subsample, bn=bn, nresblocks = nstage*n, affine=affine, kernel_size=3, bias=bias) 155 | self.pre_clf.append(layer) 156 | nb_filters_prev = nb_filters_cur 157 | 158 | self.pre_clf = nn.Sequential(*self.pre_clf) 159 | 160 | if self.VIB: 161 | self.mn = MLPLayer(nb_filters_cur, 256, 'none', act=False, bias=bias) 162 | self.logvar = MLPLayer(nb_filters_cur, 256, 'none', act=False, bias=bias) 163 | nb_filters_cur = 256 164 | 165 | self.fc = MLPLayer(nb_filters_cur, nfilters[-1], 'none', act=False, bias=bias) 166 | 167 | def forward(self, x, ret_hid=False, train=True): 168 | if x.size()[1]==1: # if MNIST is given, replicate 1 channel to make input have 3 channel 169 | out = torch.ones(x.size(0), 3, x.size(2), x.size(3)).type('torch.cuda.FloatTensor') 170 | out = out*x 171 | else: 172 | out = x 173 | 174 | if self.inp_noise>0 and train: 175 | out = out + self.inp_noise*torch.randn_like(out) 176 | hid = self.conv1(out) 177 | 178 | out = self.bn1(hid) 179 | 180 | out = self.pre_clf(out) 181 | 182 | fc = torch.mean(out.view(out.size(0), out.size(1), -1), dim=2) 183 | fc = fc.view(fc.size()[0], -1) 184 | 185 | if self.VIB: 186 | mn = self.mn(fc) 187 | logvar = self.logvar(fc) 188 | fc = reparameterize(mn,logvar) 189 | 190 | out = self.fc((fc)) 191 | 192 | 193 | if ret_hid: 194 | return out, hid 195 | elif self.VIB and train: 196 | return out, mn, logvar 197 | else: 198 | return out 199 | 200 | 201 | # Resnet nomenclature: 6n+2 = 3x2xn + 2; 3 stages, each with n number of resblocks containing 2 conv layers each, and finally 2 non-res conv layers 202 | def ResNet_model(bn=False, num_classes=10, depth=56, nb_filters=16, kernel_size=3, inp_channels=3, k=1, pad_conv1=0, affine=True, inp_noise=0, VIB=False): 203 | return ResNet(depth=depth, nb_filters=nb_filters, num_classes=num_classes, bn=bn, kernel_size=kernel_size, \ 204 | inp_channels=inp_channels, k=k, 
pad_conv1=pad_conv1, affine=affine, inp_noise=inp_noise, VIB=VIB) 205 | 206 | 207 | def reparameterize(mu, logvar): 208 | std = torch.exp(0.5*logvar) 209 | eps = torch.randn_like(std) 210 | return mu + eps*std 211 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch 5 | from torch.autograd import Variable 6 | import numpy as np 7 | import torch.autograd as autograd 8 | import torch.nn.functional as F 9 | from torch.nn.parameter import Parameter 10 | from torch.nn.init import zeros_, ones_ 11 | import tqdm 12 | 13 | class AttackPGD(nn.Module): 14 | def __init__(self, config): 15 | super(AttackPGD, self).__init__() 16 | self.rand = config['random_start'] 17 | self.step_size = config['step_size'] 18 | self.epsilon = config['epsilon'] 19 | self.num_steps = config['num_steps'] 20 | 21 | def forward(self, inputs, targets, basic_net): 22 | 23 | x = inputs.detach() 24 | if self.rand: 25 | x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon) 26 | for i in range(self.num_steps): 27 | x.requires_grad_() 28 | with torch.enable_grad(): 29 | logits = basic_net(x) 30 | loss = F.cross_entropy(logits, targets, reduction='sum') 31 | grad = torch.autograd.grad(loss, [x])[0] 32 | x = x.detach() + self.step_size*torch.sign(grad.detach()) 33 | x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon) 34 | x = torch.clamp(x, 0, 1) 35 | 36 | return basic_net(x) 37 | 38 | def pairing_loss(logit1, logit2, stochastic_pairing=False): 39 | if stochastic_pairing: 40 | exchanged_idx = np.random.permutation(logit1.shape[0]) 41 | stoc_target_logit2 = logit2[exchanged_idx] 42 | loss = torch.sum( (stoc_target_logit2-logit1)**2 )/logit1.size()[0] 43 | else: 44 | loss = torch.sum( (logit2-logit1)**2 )/logit1.size()[0] 45 | return loss 46 | 47 | def dim_permute(h): 48 | if 
len(h.size())>2: 49 | h=h.permute(1,0,2,3).contiguous() 50 | h = h.view(h.size(0), -1) 51 | else: 52 | h=h.permute(1,0).contiguous() 53 | h = h.view(h.size(0),-1) 54 | return h 55 | 56 | 57 | def compute_l2_norm(h, subtract_mean=False): 58 | h = dim_permute(h) 59 | N = (h.size(1)) 60 | if subtract_mean: 61 | mn = (h).mean(dim=1, keepdim=True) 62 | h = h-mn 63 | 64 | l2_norm = (h**2).sum() 65 | return torch.sqrt(l2_norm) 66 | 67 | def correlation_reg(hid, targets, within_class=True, subtract_mean=True): 68 | norm_fn = compute_l2_norm 69 | if within_class: 70 | uniq = np.unique(targets) 71 | reg_=0 72 | for u in uniq: 73 | idx = np.where(targets==u)[0] 74 | 75 | norm = norm_fn(hid[idx], subtract_mean=subtract_mean) 76 | reg_ += (norm)**2 77 | else: 78 | norm = norm_fn(hid, subtract_mean=subtract_mean) 79 | reg_ = (norm)**2 80 | return reg_ 81 | 82 | 83 | 84 | def idx2onehot(idx, n, h=1, w=1): 85 | 86 | assert torch.max(idx).item() < n 87 | if idx.dim() == 1: 88 | idx = idx.unsqueeze(1) 89 | 90 | onehot = torch.zeros(idx.size(0), n).cuda() 91 | onehot.scatter_(1, idx, 1) 92 | if h*w>1: 93 | onehot = onehot.view(idx.size(0), n, 1, 1) 94 | onehot_tensor = torch.ones(idx.size(0), n, h, w).cuda() 95 | onehot = onehot_tensor* onehot 96 | return onehot 97 | 98 | 99 | def _split_train_val(trainset, val_fraction=0, nsamples=-1): 100 | if nsamples>-1: 101 | n_train, n_val = int(nsamples), len(trainset)-int(nsamples) 102 | else: 103 | n_train = int((1. - val_fraction) * len(trainset)) 104 | n_val = len(trainset) - n_train 105 | train_subset, val_subset = torch.utils.data.random_split(trainset, (n_train, n_val)) 106 | return train_subset, val_subset 107 | 108 | 109 | class add_gaussian_noise(): 110 | def __init__(self, std): 111 | self.std = std 112 | def __call__(self,x): 113 | noise = self.std*torch.randn_like(x) 114 | return x + noise 115 | 116 | --------------------------------------------------------------------------------