├── .gitignore ├── README.md ├── datasets ├── __init__.py ├── cifar10.py └── mnist.py ├── deepdefense.py ├── deepdefense.sh └── models ├── __init__.py ├── cifar10.py └── mnist.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | 93 | # Rope project settings 94 | .ropeproject 95 | 96 | # data 97 | data/ 98 | data 99 | 100 | # PyCharm 101 | .idea/ 102 | 103 | # macOS 104 | .DS_Store 105 | 106 | # output directory 107 | output/ 108 | 109 | # python import cache 110 | *.pyc 111 | 112 | # nfs temp file 113 | .nfs* 114 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # deepdefense.pytorch 2 | Code for NeurIPS 2018 paper [Deep Defense: Training DNNs with Improved Adversarial Robustness](https://papers.nips.cc/paper/7324-deep-defense-training-dnns-with-improved-adversarial-robustness). 3 | 4 | Deep Defense is a recipe to improve the robustness of DNNs to adversarial perturbations. We integrate an adversarial perturbation-based regularizer into the training objective, such that the obtained models learn to resist potential attacks in a principled way. 5 | 6 | ## Environments 7 | * Python 3.5 8 | * PyTorch 0.4.1 9 | * glog 0.3.1 10 | 11 | ## Datasets and Reference Models 12 | For a fair comparison with DeepFool, we follow it in using [matconvnet](https://github.com/vlfeat/matconvnet/releases/tag/v1.0-beta24) to pre-process data and train reference models for MNIST and CIFAR-10. 13 | 14 | Please download the processed datasets and reference models (including MNIST and CIFAR-10) at [Google Drive](https://drive.google.com/open?id=15xoZ-LUbc9GZpTlxmCJmvL_DR2qYEu2J) or [Baidu Pan](https://pan.baidu.com/s/1-TSXR8kVcat7IXtuE74nJg).
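The loaders in ```datasets/``` expect these matconvnet-style ```.mat``` files under a ```data/``` directory at the repository root. A quick sanity check that the files are in place (a sketch; the access pattern mirrors ```datasets/mnist.py```):

```
import scipy.io as sio

# file names as hard-coded in datasets/mnist.py and datasets/cifar10.py
imdb = sio.loadmat('data/mnist-data-0208ce21.mat')
print(imdb['images'][0][0][0].transpose().shape)  # image array; first axis indexes examples
```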
15 | 16 | For the MLP with batch normalization example [(issue 2)](https://github.com/ZiangYan/deepdefense.pytorch/issues/2), please download the reference model at [Google Drive](https://drive.google.com/open?id=1Vy4xWeXhOX_QluYH33SHVE3q_KDOOBeF) or [Baidu Pan](https://pan.baidu.com/s/1cIgGX6b-1AQ4ybSyX2xDew). 17 | 18 | ## Usage 19 | To train a Deep Defense LeNet model using default parameters on MNIST: 20 | 21 | ``` 22 | python3 deepdefense.py --pretest --dataset mnist --arch LeNet 23 | ``` 24 | 25 | The ```--pretest``` argument evaluates performance before fine-tuning, so that we can check the performance of the reference model. 26 | 27 | Currently we've implemented ```MLP``` and ```LeNet``` for MNIST, and ```ConvNet``` for CIFAR-10. 28 | 29 | ## Citation 30 | Please cite our work in your publications if it helps your research: 31 | 32 | ``` 33 | @inproceedings{yan2018deep, 34 | title={Deep Defense: Training DNNs with Improved Adversarial Robustness}, 35 | author={Yan, Ziang and Guo, Yiwen and Zhang, Changshui}, 36 | booktitle={Advances in Neural Information Processing Systems}, 37 | pages={417--426}, 38 | year={2018} 39 | } 40 | ``` 41 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZiangYan/deepdefense.pytorch/48621f7d40c5c7f3470b59a77cc42a0e18e2f0bd/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | import numpy as np 3 | 4 | 5 | class CIFAR10Dataset(torch.utils.data.Dataset): 6 | def __init__(self, phase='train', num_val=5000): 7 | import scipy.io as sio 8 | imdb = sio.loadmat('data/cifar10-data-ce5d97dd.mat') 9 | images = imdb['images'][0][0][0].transpose() 10 | sets = imdb['images'][0][0][2].flatten() 11 | labels = (imdb['images'][0][0][1].flatten() - 1).astype(np.int64) 12 | train_idx = np.where(sets == 1)[0][num_val:] 13 | val_idx = np.where(sets == 1)[0][:num_val] 14 | trainval_idx = np.where(sets == 1)[0] 15 | test_idx = np.where(sets == 3)[0] 16 | assert phase in ['train', 'val', 'trainval', 'test'] 17 | self.images = eval('images[%s_idx]' % phase) 18 | self.labels = eval('labels[%s_idx]' % phase) 19 | self.perm = np.arange(self.labels.size) 20 | 21 | def __getitem__(self, index): 22 | if np.random.rand() > 0.5: 23 | images = np.fliplr(self.images[self.perm[index]]).copy() 24 | else: 25 | images = self.images[self.perm[index]] 26 | return images, self.labels[self.perm[index]] 27 | 28 | def __len__(self): 29 | return self.labels.size 30 | 31 | def shuffle(self, perm): 32 | self.perm = perm 33 | -------------------------------------------------------------------------------- /datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | import numpy as np 3 | import scipy.io as sio 4 | 5 | 6 | num_val = 10000 # the first num_val examples in the training set are used as the validation set 7 | 8 | 9 | class MNISTDataset(torch.utils.data.Dataset): 10 | def __init__(self, phase='train'): 11 | imdb = sio.loadmat('data/mnist-data-0208ce21.mat') 12 | images = imdb['images'][0][0][0].transpose() 13 | sets = imdb['images'][0][0][3].flatten() 14 | labels = imdb['images'][0][0][2].flatten() - 1 15 | train_idx = np.where(sets == 1)[0][num_val:] 16 | val_idx = np.where(sets ==
1)[0][:num_val] 17 | trainval_idx = np.where(sets == 1)[0] 18 | test_idx = np.where(sets == 3)[0] 19 | mean = imdb['images'][0][0][1].transpose() 20 | assert phase in ['train', 'val', 'trainval', 'test'] 21 | self.images = eval('images[%s_idx]' % phase) 22 | self.labels = eval('labels[%s_idx]' % phase) 23 | self.mean = mean 24 | self.perm = np.arange(self.labels.size) 25 | 26 | def shuffle(self, perm): 27 | self.perm = perm 28 | 29 | def __getitem__(self, index): 30 | return self.images[self.perm[index]], self.labels[self.perm[index]] 31 | 32 | def __len__(self): 33 | return self.labels.size 34 | -------------------------------------------------------------------------------- /deepdefense.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | import os 4 | import os.path as osp 5 | import glog as log 6 | import argparse 7 | import json 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | 15 | from datasets.mnist import MNISTDataset 16 | from datasets.cifar10 import CIFAR10Dataset 17 | from models.mnist import LeNet, InverseLeNet, MLP, InverseMLP, MLPBN, InverseMLPBN, BNTranspose 18 | from models.cifar10 import ConvNet, InverseConvNet, NIN, InverseNIN 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description='Use DeepDefense to improve robustness') 23 | parser.add_argument('--lr', default=0.0005, type=float, 24 | help='learning rate') 25 | parser.add_argument('--epochs', default=5, type=int, 26 | help='number of epochs to train') 27 | parser.add_argument('--max-iter', default=5, type=int, 28 | help='max iteration in deepfool attack') 29 | parser.add_argument('--lmbd', default=15, type=float, 30 | help='lmbd in regularization term') 31 | parser.add_argument('--c', default=25, type=float, 32 | help='c in regularization term') 33 | parser.add_argument('--d', default=5, type=float, 34 | help='d in regularization term') 35 | parser.add_argument('--decay', default=0.0005, type=float, 36 | help='weight decay') 37 | parser.add_argument('--batch', default=100, type=int, 38 | help='actual batch size in each iteration during training. ' 39 | 'we use gradient accumulation if args.batch < args.train_batch') 40 | parser.add_argument('--train-batch', default=100, type=int, 41 | help='training batch size. 
we always collect args.train_batch samples for one update') 42 | parser.add_argument('--test-batch', default=100, type=int, 43 | help='test batch size') 44 | parser.add_argument('--exp-dir', default='output/debug', type=str, 45 | help='directory to save models and logs for current experiment') 46 | parser.add_argument('--pretest', action='store_true', 47 | help='evaluate model before training') 48 | parser.add_argument('--seed', default=1234, type=int, 49 | help='random seed') 50 | parser.add_argument('--dataset', default='mnist', type=str, 51 | help='which dataset to use, e.g., mnist or cifar10') 52 | parser.add_argument('--arch', default='LeNet', type=str, 53 | help='network architecture, e.g., LeNet or MLP') 54 | 55 | if len(sys.argv) == 1: 56 | parser.print_help() 57 | sys.exit(1) 58 | 59 | args = parser.parse_args() 60 | return args 61 | 62 | 63 | class DeepFool(nn.Module): 64 | def __init__(self): 65 | super(DeepFool, self).__init__() 66 | 67 | self.num_labels = 10 68 | self.overshot = 0.02 69 | self.max_iter = args.max_iter 70 | 71 | # initialize net 72 | if args.dataset == 'mnist': 73 | assert args.arch in ['MLP', 'MLPBN', 'LeNet'] 74 | elif args.dataset == 'cifar10': 75 | assert args.arch in ['ConvNet', 'NIN'] 76 | else: 77 | raise NotImplementedError 78 | self.net = eval('%s()' % args.arch) 79 | self.net.load_weights() 80 | log.info(self.net) 81 | 82 | # initialize inversenet 83 | self.inverse_net = eval('Inverse%s()' % args.arch) 84 | log.info(self.inverse_net) 85 | self.inverse_net.copy_from(self.net) 86 | 87 | self.net.cuda() 88 | self.inverse_net.cuda() 89 | 90 | self.eps = 5e-6 if args.dataset == 'mnist' else 1e-5 # protect norm against nan 91 | 92 | def net_forward(self, input_image): 93 | return self.net.forward(input_image.cuda()) 94 | 95 | def inversenet_backward(self, input_image, idx): 96 | return self.inverse_net.forward_from_net(self.net, input_image, idx) 97 | 98 | def project_boundary_polyhedron(self, input_grad_, output_): 99 | batch_size = input_grad_.size()[0] # e.g., 100 for mnist 100 | image_dim = input_grad_.size()[1] # e.g., 784 for mnist 101 | # project under l_2 norm 102 | res_ = torch.abs(output_) / torch.norm(input_grad_ + self.eps, p=2, dim=1).view(output_.size()) 103 | _, ii = torch.min(res_, 1) 104 | 105 | # dir_ = res_[np.arange(batch_size), ii.data].view(batch_size, 1) 106 | # advanced indexing seems to be buggy in pytorch 0.3.x, we use gather instead 107 | dir_ = res_.gather(1, ii.view(batch_size, 1)) 108 | 109 | w = input_grad_.gather( 110 | 2, ii.view(batch_size, 1, 1).expand(batch_size, image_dim, 1)).view(batch_size, image_dim) 111 | dir_ = dir_ * w / torch.norm(w + self.eps, p=2, dim=1).view(batch_size, 1) 112 | return dir_ 113 | 114 | def forward_correct(self, input_image, label=None, pred=None, check=True): 115 | # this function is called when an image is correctly classified 116 | # label should be true label during training, and None during test 117 | 118 | num_image = input_image.size()[0] 119 | image_shape = input_image.size() 120 | self.label = pred.copy() 121 | if check: 122 | if self.training: 123 | # label should be true label 124 | assert label is not None 125 | assert np.all(self.label == label) 126 | else: 127 | # label should be None 128 | assert label is None 129 | outputt = self.net_forward(input_image) 130 | idx = torch.from_numpy(self.label).cuda().view(num_image, 1) 131 | output = outputt - outputt.gather(1, idx).expand_as(outputt) 132 | 133 | _, target_labels = torch.sort(-output, dim=1) 134 | target_labels =
target_labels.data[:, :self.num_labels] 135 | 136 | ww = self.inversenet_backward(input_image, target_labels) 137 | w = ww - ww[:, :, 0].contiguous().view(ww.size()[0], ww.size()[1], 1).expand_as(ww) 138 | 139 | self.noises = dict() 140 | self.inputs_perturbed = dict() 141 | self.inputs_perturbed['step_0'] = input_image 142 | self.label_perturbed = self.label.copy() 143 | self.iteration = 0 144 | self.fooled = np.zeros(num_image).astype(np.bool) 145 | 146 | while True: 147 | self.iteration += 1 148 | noise_this_step = \ 149 | self.project_boundary_polyhedron(w[:, :, 1:], output.gather(1, target_labels[:, 1:].cuda())) 150 | 151 | # if an image is already successfully fooled, no more perturbation should be applied to it 152 | t = torch.from_numpy(np.logical_not(self.fooled).astype(np.float32).copy()).cuda() 153 | t = t.view(num_image, 1).expand(num_image, noise_this_step.size()[1]) 154 | self.noise_this_step = noise_this_step * t 155 | 156 | self.inputs_perturbed['step_%d' % self.iteration] = \ 157 | self.inputs_perturbed['step_%d' % (self.iteration - 1)] + self.noise_this_step.view(image_shape) 158 | if len(self.noises) == 0: 159 | self.noises['step_%d' % self.iteration] = self.noise_this_step 160 | else: 161 | self.noises['step_%d' % self.iteration] = \ 162 | self.noises['step_%d' % (self.iteration - 1)] + self.noise_this_step 163 | 164 | # test whether we have successfully fooled these images 165 | _, t = torch.max(self.net_forward( 166 | input_image + (1 + self.overshot) * self.noises['step_%d' % self.iteration].view(image_shape)), 1) 167 | t = t.data.cpu().numpy().flatten() 168 | for i in range(num_image): 169 | # iterate over all images 170 | if not self.fooled[i]: 171 | self.label_perturbed[i] = t[i] 172 | if t[i] != self.label[i]: 173 | self.fooled[i] = True 174 | 175 | if np.all(self.fooled): 176 | # quit if we have already fooled all images 177 | break 178 | if self.iteration == self.max_iter: 179 | # quit if max iteration is reached 180 | break 181 | # otherwise, prepare the next fooling iteration 182 | 183 | outputt = self.net_forward(self.inputs_perturbed['step_%d' % self.iteration]) 184 | idx = torch.from_numpy(self.label).cuda().view(num_image, 1) 185 | output = outputt - outputt.gather(1, idx).expand_as(outputt) 186 | 187 | ww = self.inversenet_backward(self.inputs_perturbed['step_%d' % self.iteration], target_labels) 188 | w = ww - ww[:, :, 0].contiguous().view(ww.size()[0], ww.size()[1], 1).expand_as(ww) 189 | 190 | return (1 + self.overshot) * self.noises['step_%d' % self.iteration] 191 | 192 | def forward_wrong(self, input_image, label, pred, check=True): 193 | # this function is called when an image is incorrectly classified 194 | # this function is only called during training, and label is the true label 195 | 196 | num_image = input_image.size()[0] 197 | image_shape = input_image.size() 198 | self.label = pred.copy() 199 | if check: 200 | assert self.training 201 | assert label is not None 202 | assert np.all(self.label != label) 203 | 204 | idx = torch.from_numpy(self.label).cuda().view(num_image, 1) 205 | outputt = self.net_forward(input_image) 206 | output = outputt - outputt.gather(1, idx).expand_as(outputt) 207 | 208 | target_labels = torch.from_numpy(np.vstack((self.label, label)).T).cuda() 209 | 210 | ww = self.inversenet_backward(input_image, target_labels) 211 | w = ww - ww[:, :, 0].contiguous().view(ww.size()[0], ww.size()[1], 1).expand_as(ww) 212 | 213 | self.noises = dict() 214 | self.inputs_perturbed = dict() 215 | self.inputs_perturbed['step_0'] = input_image 216 |
self.label_perturbed = self.label.copy() 217 | self.iteration = 0 218 | self.fooled = np.zeros(num_image).astype(np.bool) 219 | 220 | while True: 221 | self.iteration += 1 222 | noise_this_step = \ 223 | self.project_boundary_polyhedron(w[:, :, 1:], output.gather(1, target_labels[:, 1:].cuda())) 224 | 225 | t = torch.from_numpy(np.logical_not(self.fooled).astype(np.float32)).cuda() 226 | t = t.view(num_image, 1).expand(num_image, noise_this_step.size()[1]) 227 | self.noise_this_step = noise_this_step * t 228 | 229 | self.inputs_perturbed['step_%d' % self.iteration] = \ 230 | self.inputs_perturbed['step_%d' % (self.iteration - 1)] + self.noise_this_step.view(image_shape) 231 | if len(self.noises) == 0: 232 | self.noises['step_%d' % self.iteration] = self.noise_this_step 233 | else: 234 | self.noises['step_%d' % self.iteration] = \ 235 | self.noises['step_%d' % (self.iteration - 1)] + self.noise_this_step 236 | 237 | _, t = torch.max(self.net_forward( 238 | input_image + (1 + self.overshot) * self.noises['step_%d' % self.iteration].view(image_shape)), 1) 239 | t = t.data.cpu().numpy().flatten() 240 | for i in range(num_image): 241 | if not self.fooled[i]: 242 | self.label_perturbed[i] = t[i] 243 | if t[i] == label[i]: 244 | self.fooled[i] = True 245 | 246 | if np.all(self.fooled): 247 | break 248 | if self.iteration == self.max_iter: 249 | break 250 | 251 | outputt = self.net_forward(self.inputs_perturbed['step_%d' % self.iteration]) 252 | idx = torch.from_numpy(self.label).cuda().view(num_image, 1) 253 | output = outputt - outputt.gather(1, idx).expand_as(outputt) 254 | 255 | # target will change as fooling process goes on 256 | # this is different from forward_correct 257 | self.label = outputt.data.cpu().numpy().argmax(axis=1) 258 | target_labels = torch.from_numpy(np.vstack((self.label, label)).T).cuda() 259 | 260 | ww = self.inversenet_backward(self.inputs_perturbed['step_%d' % self.iteration], target_labels) 261 | w = ww - ww[:, :, 0].contiguous().view(ww.size()[0], ww.size()[1], 1).expand_as(ww) 262 | 263 | return (1 + self.overshot) * self.noises['step_%d' % self.iteration] 264 | 265 | def forward(self, input_image): 266 | # this function should only be used during test 267 | # in training, use forward_correct and forward_wrong instead 268 | assert not self.training 269 | return self.forward_correct(input_image, pred=self.net_forward(input_image).data.cpu().numpy().argmax(axis=1), check=False) 270 | 271 | 272 | def test(model, phases='test'): 273 | model.eval() 274 | result = dict() 275 | if isinstance(phases, str): 276 | phases = [phases] 277 | for phase in phases: 278 | log.info('Evaluating deepfool robustness, phase=%s' % phase) 279 | loader = eval('%s_loader' % phase) 280 | 281 | num_image = len(loader.dataset) 282 | assert num_image % args.test_batch == 0 283 | log.info('Found %d images' % num_image) 284 | 285 | accuracy = np.zeros(num_image) 286 | ce_loss = np.zeros(num_image) 287 | noise_norm = np.zeros(num_image) 288 | ratio = np.zeros(num_image) 289 | iteration = np.zeros(num_image) 290 | 291 | for index, (image, label) in enumerate(loader): 292 | # get one batch 293 | image_var = image.cuda() 294 | image_var.requires_grad = True 295 | label_var = label.long().cuda() 296 | selected = np.arange(index * args.test_batch, (index + 1) * args.test_batch) 297 | 298 | # calculate cross entropy 299 | forward_result_var = model.net(image_var) 300 | ce_loss_var = F.cross_entropy(forward_result_var, label_var) 301 | ce_loss[selected] = ce_loss_var.data.cpu().numpy() 302 | pred = forward_result_var.data.cpu().numpy().argmax(axis=1) 303 | 304 | # calculate
accuracy 305 | accuracy[selected] = pred == label 306 | 307 | # calculate perturbation norm 308 | noise_var = model.forward_correct(image_var, label=label.cpu().numpy(), pred=pred, check=False) 309 | noise_loss_var = torch.norm(noise_var, dim=1) 310 | noise_norm[selected] = noise_loss_var.data.cpu().numpy().flatten() 311 | 312 | # calculate ratio 313 | # l_2 norm 314 | t = torch.norm(image_var.view(args.test_batch, -1), dim=1).data.cpu().numpy().flatten() 315 | ratio[selected] = noise_norm[selected] / t 316 | 317 | # save number of iteration 318 | iteration[selected] = model.iteration 319 | 320 | n = (index + 1) * args.test_batch 321 | if n % 1000 == 0: 322 | log.info('Evaluating %s set %d / %d,' % (phase, n, num_image)) 323 | log.info('\tnoise_norm\t: %f' % (noise_norm.sum() / n)) 324 | log.info('\tratio\t\t: %f' % (ratio.sum() / n)) 325 | log.info('\tce_loss\t\t: %f' % (ce_loss.sum() / n)) 326 | log.info('\taccuracy\t: %f' % (accuracy.sum() / n)) 327 | log.info('\titeration\t: %f' % (iteration.sum() / n)) 328 | 329 | log.info('Performance on %s set is:' % phase) 330 | log.info('\tnoise_norm\t: %f' % noise_norm.mean()) 331 | log.info('\tratio\t\t: %f' % ratio.mean()) 332 | log.info('\tce_loss\t\t: %f' % ce_loss.mean()) 333 | log.info('\taccuracy\t: %f' % accuracy.mean()) 334 | 335 | result['%s_accuracy' % phase] = accuracy.mean() 336 | result['%s_ratio' % phase] = ratio.mean() 337 | 338 | log.info('Performance of current model is:') 339 | for phase in ['train', 'val', 'test']: 340 | if '%s_accuracy' % phase in result: 341 | log.info('\t%s accuracy\t: %f' % (phase, result['%s_accuracy' % phase])) 342 | log.info('\t%s ratio\t: %f' % (phase, result['%s_ratio' % phase])) 343 | 344 | 345 | def train(model): 346 | num_epoch = args.epochs 347 | def trainable(name): 348 | if 'bn' in name: 349 | return False 350 | return True 351 | trainable_parameters = list(p[1] for p in model.named_parameters() if trainable(p[0])) 352 | optimizer = optim.SGD(trainable_parameters, lr=args.lr, weight_decay=args.decay, momentum=0.9) 353 | log.info('Train {} params among all {} params'.format(len(trainable_parameters), len(list(model.parameters())))) 354 | log.info('Trainable param list: {}'.format(list(p[0] for p in model.named_parameters() if trainable(p[0])))) 355 | num_image = len(train_loader.dataset) 356 | log.info('Found %d images' % num_image) 357 | 358 | assert (args.train_batch % args.batch == 0) and (args.train_batch >= args.batch) 359 | 360 | for epoch_idx in range(num_epoch): 361 | log.info('Training epoch %d' % epoch_idx) 362 | model.zero_grad() 363 | 364 | # reduce learning rate 365 | if epoch_idx == (0.8 * args.epochs): 366 | for param_group in optimizer.param_groups: 367 | lr = param_group['lr'] 368 | new_lr = lr * 0.5 369 | param_group['lr'] = new_lr 370 | log.info('epoch %d, cut learning rate from %f to %f' % (epoch_idx, lr, new_lr)) 371 | 372 | perm = np.random.permutation(num_image) 373 | train_loader.dataset.shuffle(perm) 374 | model.train() 375 | for m in model.modules(): 376 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, BNTranspose)): 377 | m.eval() 378 | for index, (image, label) in enumerate(train_loader): 379 | batch_in_train_batch = index % (args.train_batch // args.batch) 380 | if batch_in_train_batch == 0: 381 | noise_norm = np.zeros(args.train_batch) 382 | ratio = np.zeros(args.train_batch) 383 | ce_loss = np.zeros(args.train_batch) 384 | loss = np.zeros(args.train_batch) 385 | accuracy = np.zeros(args.train_batch) 386 | grad_norm = np.zeros(args.train_batch) 387
| optimizer.zero_grad() 388 | 389 | # get one batch of data 390 | if (args.dataset == 'cifar10') and (args.arch == 'NIN') and (np.random.rand() < 0.5): 391 | # flip with a probability of 50% 392 | inv = torch.arange(image.size(2) - 1, -1, -1).long() 393 | image = image.index_select(2, inv) 394 | image_var = image.cuda() 395 | image_var.requires_grad = True 396 | label_var = label.long().cuda() 397 | # selected index in train batch, used to store ce_loss and loss 398 | selected_in_train_batch = np.arange(batch_in_train_batch * args.batch, 399 | (batch_in_train_batch + 1) * args.batch).astype(np.int) 400 | 401 | # split pos and neg 402 | forward_result_var = model.net(image_var) 403 | _, pred = torch.max(forward_result_var, 1) 404 | pred = pred.data.cpu().numpy().flatten() 405 | pos_idx = np.where(pred == label)[0] 406 | neg_idx = np.where(pred != label)[0] 407 | accuracy[selected_in_train_batch] = pred == label 408 | 409 | # cross-entropy term; the DeepFool-based regularizer terms follow below 410 | ce_loss_var = F.cross_entropy(forward_result_var, label_var) 411 | ce_loss_var = ce_loss_var * args.batch / args.train_batch 412 | ce_loss_var.backward(retain_graph=True) 413 | ce_loss[selected_in_train_batch] = ce_loss_var.data.cpu().numpy() 414 | 415 | if (args.lmbd > 0) and (pos_idx.size > 0): 416 | pos_idx_var = torch.from_numpy(pos_idx).cuda() 417 | pos_image = image_var.index_select(0, pos_idx_var) 418 | noise_var = model.forward_correct(input_image=pos_image, 419 | label=label[pos_idx], 420 | pred=pred[pos_idx], 421 | check=True) 422 | noise_norm[batch_in_train_batch * args.batch + pos_idx] = torch.norm(noise_var, 423 | dim=1).data.cpu().numpy().flatten() 424 | # l_2 norm 425 | ratio[batch_in_train_batch * args.batch + pos_idx] = \ 426 | noise_norm[batch_in_train_batch * args.batch + pos_idx] / \ 427 | torch.norm(pos_image.view(pos_idx.size, -1), dim=1).data.cpu().numpy().flatten() 428 | 429 | # calculate perturbation norm 430 | noise_loss_var = torch.norm(noise_var, dim=1) 431 | t = pos_image.view(pos_idx.size, -1) 432 | noise_loss_var = noise_loss_var / torch.norm(t, dim=1) 433 | 434 | loss_var = args.lmbd * torch.exp(-args.c * noise_loss_var) 435 | loss_var = loss_var.sum() 436 | loss[batch_in_train_batch * args.batch + pos_idx] = loss_var.data.cpu().numpy() / args.batch 437 | 438 | # BP 439 | loss_var = loss_var / args.train_batch 440 | loss_var.backward() 441 | 442 | if (args.lmbd > 0) and (neg_idx.size > 0): 443 | neg_idx_var = torch.from_numpy(neg_idx).cuda() 444 | neg_image = image_var.index_select(0, neg_idx_var) 445 | noise_var = model.forward_wrong(input_image=neg_image, 446 | label=label[neg_idx], 447 | pred=pred[neg_idx], 448 | check=True) 449 | noise_norm[batch_in_train_batch * args.batch + neg_idx] = \ 450 | torch.norm(noise_var, dim=1).data.cpu().numpy().flatten() 451 | 452 | # l_2 norm 453 | ratio[batch_in_train_batch * args.batch + neg_idx] = \ 454 | noise_norm[batch_in_train_batch * args.batch + neg_idx] / \ 455 | torch.norm(neg_image.view(neg_idx.size, -1), dim=1).data.cpu().numpy().flatten() 456 | 457 | # calculate perturbation norm 458 | noise_loss_var = torch.norm(noise_var, dim=1) 459 | t = neg_image.view(neg_idx.size, -1) 460 | noise_loss_var = noise_loss_var / torch.norm(t, dim=1) 461 | 462 | loss_var = args.lmbd * torch.exp(args.d * noise_loss_var) 463 | loss_var = loss_var.sum() 464 | loss[batch_in_train_batch * args.batch + neg_idx] = loss_var.data.cpu().numpy() / args.batch 465 | 466 | # BP 467 | loss_var = loss_var / args.train_batch 468 | loss_var.backward() 469 | 470 | # calculate grad norm 471 | for p in
model.parameters(): 472 | if p.grad is not None: 473 | grad_norm[selected_in_train_batch] += p.grad.data.norm(2) ** 2 474 | grad_norm[selected_in_train_batch] = np.sqrt(grad_norm[selected_in_train_batch]) 475 | 476 | # update weights 477 | if batch_in_train_batch == (args.train_batch / args.batch - 1): 478 | optimizer.step() 479 | optimizer.zero_grad() 480 | model.zero_grad() 481 | 482 | log.info('Processing %d - %d / %d' % ((index + 1) * args.batch - args.train_batch, 483 | (index + 1) * args.batch, num_image)) 484 | log.info('\tnoise_norm\t: %f' % noise_norm.mean()) 485 | log.info('\tgrad_norm\t: %f' % grad_norm.mean()) 486 | log.info('\tratio\t\t: %f' % ratio.mean()) 487 | log.info('\tce_loss\t\t: %f' % ce_loss.mean()) 488 | log.info('\tloss\t\t: %f' % loss.mean()) 489 | log.info('\taccuracy\t: %f' % accuracy.mean()) 490 | 491 | # evaluate and save model after each epoch 492 | log.info('Evaluating model after epoch %d' % epoch_idx) 493 | test(model, phases='test') 494 | 495 | # save model 496 | fname = osp.join(args.exp_dir, 'epoch_%d.model' % epoch_idx) 497 | if not osp.exists(osp.dirname(fname)): 498 | os.makedirs(osp.dirname(fname)) 499 | torch.save(model.state_dict(), fname) 500 | log.info('Model of epoch %d saved to %s' % (epoch_idx, fname)) 501 | 502 | 503 | def main(): 504 | model = DeepFool() 505 | 506 | if args.pretest: 507 | log.info('Evaluating performance before fine-tuning') 508 | test(model, phases='test') 509 | 510 | log.info('Fine-tuning network') 511 | train(model) 512 | 513 | log.info('Saving model') 514 | fname = osp.join(args.exp_dir, 'final.model') 515 | if not osp.exists(osp.dirname(fname)): 516 | os.makedirs(osp.dirname(fname)) 517 | torch.save(model.cpu().state_dict(), fname) 518 | 519 | log.info('Final model saved to %s' % fname) 520 | 521 | 522 | if __name__ == '__main__': 523 | args = parse_args() 524 | 525 | log.info('Called with args:') 526 | log.info(args) 527 | 528 | np.random.seed(args.seed) 529 | torch.manual_seed(args.seed) 530 | torch.backends.cudnn.deterministic = True 531 | 532 | if args.dataset == 'mnist': 533 | train_loader = torch.utils.data.DataLoader(MNISTDataset(phase='trainval'), 534 | batch_size=args.batch, shuffle=False, num_workers=4, 535 | pin_memory=False, drop_last=False) 536 | test_loader = torch.utils.data.DataLoader(MNISTDataset(phase='test'), 537 | batch_size=args.test_batch, shuffle=False, num_workers=4, 538 | pin_memory=False, drop_last=False) 539 | elif args.dataset == 'cifar10': 540 | train_loader = torch.utils.data.DataLoader(CIFAR10Dataset(phase='trainval'), 541 | batch_size=args.batch, shuffle=False, num_workers=4, 542 | pin_memory=False, drop_last=False) 543 | test_loader = torch.utils.data.DataLoader(CIFAR10Dataset(phase='test'), 544 | batch_size=args.test_batch, shuffle=False, num_workers=4, 545 | pin_memory=False, drop_last=False) 546 | else: 547 | raise NotImplementedError 548 | 549 | # print this script to log 550 | fname = __file__ 551 | if fname.endswith('pyc'): 552 | fname = fname[:-1] 553 | with open(fname, 'r') as f: 554 | log.info(f.read()) 555 | 556 | # make experiment directory 557 | if not osp.exists(args.exp_dir): 558 | os.makedirs(args.exp_dir) 559 | 560 | # dump config 561 | with open(osp.join(args.exp_dir, 'config.json'), 'w') as f: 562 | json.dump(vars(args), f, sort_keys=True, indent=4) 563 | 564 | # backup scripts 565 | os.system('cp %s %s' % (fname, args.exp_dir)) 566 | os.system('cp -r datasets models %s' % args.exp_dir) 567 | 568 | # do the business 569 | main() 570 |
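In ```train()``` above, the Deep Defense regularizer turns each DeepFool perturbation into a scalar penalty: writing rho = ||r|| / ||x|| for the perturbation-to-image ratio (```noise_loss_var``` in the code), every correctly classified image contributes ```lmbd * exp(-c * rho)``` and every misclassified one contributes ```lmbd * exp(d * rho)```. A minimal numeric sketch of these two terms with the default hyper-parameters (random tensors stand in for real images and perturbations):

```
import torch

lmbd, c, d = 15., 25., 5.                          # defaults from parse_args()
x = torch.randn(4, 784)                            # stand-in flattened images
r = 0.05 * torch.randn(4, 784)                     # stand-in DeepFool perturbations
rho = torch.norm(r, dim=1) / torch.norm(x, dim=1)  # perturbation-to-image ratio
reg_correct = (lmbd * torch.exp(-c * rho)).sum()   # decreases as rho grows: minimizing it enlarges margins
reg_wrong = (lmbd * torch.exp(d * rho)).sum()      # increases with rho: minimizing it pulls wrong samples back
print(reg_correct.item(), reg_wrong.item())
```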
-------------------------------------------------------------------------------- /deepdefense.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -x 4 | set -e 5 | 6 | export PYTHONUNBUFFERED="True" 7 | 8 | DIR="output/debug/`date +'%Y-%m-%d_%H:%M:%S'`" 9 | R=$(head -c 500 /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 8 | head -n 1) 10 | DIR=${DIR}"-"${R} 11 | mkdir -p $DIR 12 | LOG=${DIR}"/train.log" 13 | exec &> >(tee -a "$LOG") 14 | echo Logging output to "$LOG" 15 | 16 | time python3 deepdefense.py --exp-dir $DIR $@ 17 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import mnist 2 | from . import cifar10 3 | -------------------------------------------------------------------------------- /models/cifar10.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | import scipy.io as sio 7 | import glog as log 8 | 9 | 10 | class ConvNet(nn.Module): 11 | def __init__(self): 12 | super(ConvNet, self).__init__() 13 | self.conv1 = nn.Conv2d(3, 32, 5, padding=2) 14 | self.conv2 = nn.Conv2d(32, 32, 5, padding=2) 15 | self.conv3 = nn.Conv2d(32, 64, 5, padding=2) 16 | self.conv4 = nn.Conv2d(64, 64, 4) 17 | self.conv5 = nn.Conv2d(64, 10, 1) 18 | 19 | for k in ['conv1', 'conv2', 'conv3', 'conv4', 'conv5']: 20 | w = self.__getattr__(k) 21 | torch.nn.init.kaiming_normal(w.weight.data) 22 | w.bias.data.fill_(0) 23 | 24 | self.out = dict() 25 | 26 | def save(self, x, name): 27 | self.out[name] = x 28 | 29 | def forward(self, x): 30 | self.save(x, 'x') 31 | x = self.conv1(x) 32 | self.save(x, 'conv1_out') 33 | x, pool1_ind = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True) 34 | self.save(x, 'pool1_out') 35 | self.save(pool1_ind, 'pool1_ind') 36 | x = F.relu(x) 37 | self.save(x, 'relu1_out') 38 | 39 | x = self.conv2(x) 40 | self.save(x, 'conv2_out') 41 | x = F.relu(x) 42 | self.save(x, 'relu2_out') 43 | x = F.avg_pool2d(x, kernel_size=2, stride=2) 44 | self.save(x, 'pool2_out') 45 | 46 | x = self.conv3(x) 47 | self.save(x, 'conv3_out') 48 | 49 | x = F.relu(x) 50 | self.save(x, 'relu3_out') 51 | x = F.avg_pool2d(x, kernel_size=2, stride=2) 52 | self.save(x, 'pool3_out') 53 | 54 | x = self.conv4(x) 55 | self.save(x, 'conv4_out') 56 | x = F.relu(x) 57 | self.save(x, 'relu4_out') 58 | 59 | x = self.conv5(x) 60 | self.save(x, 'conv5_out') 61 | 62 | x = x.view(-1, 10) 63 | self.save(x, 'flat_out') 64 | 65 | return x 66 | 67 | def load_weights(self, source=None): 68 | if source is None: 69 | source = 'data/cifar10-convnet-15742544.mat' 70 | if source.endswith('mat'): 71 | log.info('Load cifar10 weights from matlab model %s' % source) 72 | mcn = sio.loadmat(source) 73 | mcn_weights = dict() 74 | 75 | mcn_weights['conv1.weights'] = mcn['net'][0][0][0][0][0][0][0][1][0][0].transpose() 76 | mcn_weights['conv1.bias'] = mcn['net'][0][0][0][0][0][0][0][1][0][1].flatten() 77 | 78 | mcn_weights['conv2.weights'] = mcn['net'][0][0][0][0][3][0][0][1][0][0].transpose() 79 | mcn_weights['conv2.bias'] = mcn['net'][0][0][0][0][3][0][0][1][0][1].flatten() 80 | 81 | mcn_weights['conv3.weights'] = mcn['net'][0][0][0][0][6][0][0][1][0][0].transpose() 82 | mcn_weights['conv3.bias'] = mcn['net'][0][0][0][0][6][0][0][1][0][1].flatten() 83 | 84 | mcn_weights['conv4.weights'] = 
mcn['net'][0][0][0][0][9][0][0][1][0][0].transpose() 85 | mcn_weights['conv4.bias'] = mcn['net'][0][0][0][0][9][0][0][1][0][1].flatten() 86 | 87 | mcn_weights['conv5.weights'] = mcn['net'][0][0][0][0][11][0][0][1][0][0].transpose() 88 | mcn_weights['conv5.bias'] = mcn['net'][0][0][0][0][11][0][0][1][0][1].flatten() 89 | 90 | for k in ['conv1', 'conv2', 'conv3', 'conv4', 'conv5']: 91 | t = self.__getattr__(k) 92 | assert t.weight.data.size() == mcn_weights['%s.weights' % k].shape 93 | t.weight.data[:] = torch.from_numpy(mcn_weights['%s.weights' % k]) 94 | assert t.bias.data.size() == mcn_weights['%s.bias' % k].shape 95 | t.bias.data[:] = torch.from_numpy(mcn_weights['%s.bias' % k]) 96 | elif source.endswith('pth'): 97 | log.info('Load cifar10 weights from PyTorch model %s' % source) 98 | pth_weights = torch.load(source) 99 | for k in ['conv1', 'conv2', 'conv3', 'conv4', 'conv5']: 100 | t = self.__getattr__(k) 101 | assert t.weight.data.size() == pth_weights['net.%s.weight' % k].shape 102 | t.weight.data[:] = pth_weights['net.%s.weight' % k] 103 | assert t.bias.data.size() == pth_weights['net.%s.bias' % k].shape 104 | t.bias.data[:] = pth_weights['net.%s.bias' % k] 105 | 106 | 107 | class InverseConvNet(nn.Module): 108 | def __init__(self): 109 | super(InverseConvNet, self).__init__() 110 | self.transposeconv5 = nn.ConvTranspose2d(10, 64, 1, bias=False) 111 | self.transposeconv4 = nn.ConvTranspose2d(64, 64, 4, bias=False) 112 | self.transposeconv3 = nn.ConvTranspose2d(64, 32, 5, padding=2, bias=False) 113 | self.transposeconv2 = nn.ConvTranspose2d(32, 32, 5, padding=2, bias=False) 114 | self.transposeconv1 = nn.ConvTranspose2d(32, 3, 5, padding=2, bias=False) 115 | 116 | # inverse pool2 (average pooling) 117 | self.w2 = torch.zeros(32, 32, 2, 2).cuda() 118 | for i in range(32): 119 | self.w2[i, i, :, :] = 0.25 120 | 121 | # inverse pool3 (average pooling) 122 | self.w3 = torch.zeros(64, 64, 2, 2).cuda() 123 | for i in range(64): 124 | self.w3[i, i, :, :] = 0.25 125 | 126 | self.out = dict() 127 | 128 | def save(self, x, name): 129 | self.out[name] = x 130 | 131 | def forward(self, x, pool1_ind, relu1_mask, relu2_mask, relu3_mask, relu4_mask): 132 | x = x.view(-1, 10, 1, 1) 133 | self.save(x, 'conv5_out') 134 | 135 | x = self.transposeconv5(x) 136 | self.save(x, 'relu4_out') 137 | x = x * relu4_mask 138 | self.save(x, 'conv4_out') 139 | 140 | x = self.transposeconv4(x) 141 | self.save(x, 'pool3_out') 142 | x = F.conv_transpose2d(x, self.w3, stride=2) 143 | self.save(x, 'relu3_out') 144 | x = x * relu3_mask 145 | self.save(x, 'conv3_out') 146 | 147 | x = self.transposeconv3(x) 148 | self.save(x, 'pool2_out') 149 | x = F.conv_transpose2d(x, self.w2, stride=2) 150 | self.save(x, 'relu2_out') 151 | x = x * relu2_mask 152 | self.save(x, 'conv2_out') 153 | 154 | x = self.transposeconv2(x) 155 | self.save(x, 'relu1_out') 156 | x = x * relu1_mask 157 | self.save(x, 'pool1_out') 158 | x = F.max_unpool2d(x, pool1_ind, kernel_size=2, stride=2) 159 | self.save(x, 'conv1_out') 160 | x = self.transposeconv1(x) 161 | self.save(x, 'input_out') 162 | 163 | return x 164 | 165 | def copy_from(self, net): 166 | for k in ['conv1', 'conv2', 'conv3', 'conv4', 'conv5']: 167 | t = net.__getattr__(k) 168 | tt = self.__getattr__('transpose%s' % k) 169 | assert t.weight.size() == tt.weight.size() 170 | tt.weight = t.weight 171 | 172 | def forward_from_net(self, net, input_image, idx): 173 | num_target_label = idx.size()[1] 174 | batch_size = input_image.size()[0] 175 | image_shape = input_image.size()[1:] 176 | 177 |
output_var = net(input_image.cuda()) 178 | 179 | dzdy = np.zeros((idx.numel(), output_var.size()[1]), dtype=np.float32) 180 | dzdy[np.arange(idx.numel()), idx.view(idx.numel()).cpu().numpy()] = 1. 181 | 182 | inverse_input_var = torch.from_numpy(dzdy).cuda() 183 | inverse_input_var.requires_grad = True 184 | inverse_output_var = self.forward( 185 | inverse_input_var, 186 | net.out['pool1_ind'].repeat(1, num_target_label, 1, 1).view(idx.numel(), 32, 16, 16), 187 | (net.out['pool1_out'] > 0).float().repeat(1, num_target_label, 1, 1).view(idx.numel(), 32, 16, 16), 188 | (net.out['conv2_out'] > 0).float().repeat(1, num_target_label, 1, 1).view(idx.numel(), 32, 16, 16), 189 | (net.out['conv3_out'] > 0).float().repeat(1, num_target_label, 1, 1).view(idx.numel(), 64, 8, 8), 190 | (net.out['conv4_out'] > 0).float().repeat(1, num_target_label, 1, 1).view(idx.numel(), 64, 1, 1), 191 | ) 192 | 193 | dzdx = inverse_output_var.view(input_image.size()[0], idx.size()[1], -1).transpose(1, 2) 194 | return dzdx 195 | 196 | 197 | class NIN(nn.Module): 198 | def __init__(self): 199 | super(NIN, self).__init__() 200 | self.conv1 = nn.Conv2d(3, 192, 5, padding=2) 201 | self.cccp1 = nn.Conv2d(192, 160, 1) 202 | self.cccp2 = nn.Conv2d(160, 96, 1) 203 | 204 | self.conv2 = nn.Conv2d(96, 192, 5, padding=2) 205 | self.cccp3 = nn.Conv2d(192, 192, 1) 206 | self.cccp4 = nn.Conv2d(192, 192, 1) 207 | 208 | self.conv3 = nn.Conv2d(192, 192, 3, padding=1) 209 | self.cccp5 = nn.Conv2d(192, 192, 1) 210 | self.cccp6 = nn.Conv2d(192, 10, 1) 211 | 212 | for k in ['conv1', 'cccp1', 'cccp2', 'conv2', 'cccp3', 'cccp4', 'conv3', 'cccp5', 'cccp6']: 213 | w = self.__getattr__(k) 214 | torch.nn.init.kaiming_normal(w.weight.data) 215 | w.bias.data.fill_(0) 216 | 217 | # self.cccp6.weight.data[:] = 0.1 * self.cccp6.weight.data[:] 218 | 219 | self.p = 0.5 # dropout probability 220 | raise NotImplementedError('Code for NIN will be released soon since we need to ' 221 | 'clean up our codebase for dropout support') 222 | 223 | def forward(self, x, drop1_mask=None, drop2_mask=None): 224 | self.x = x 225 | self.conv1_out = self.conv1(self.x) 226 | self.relu1_out = F.relu(self.conv1_out) 227 | self.cccp1_out = self.cccp1(self.relu1_out) 228 | self.relu_cccp1_out = F.relu(self.cccp1_out) 229 | self.cccp2_out = self.cccp2(self.relu_cccp1_out) 230 | self.relu_cccp2_out = F.relu(self.cccp2_out) 231 | self.pool1_out, self.pool1_ind = F.max_pool2d(self.relu_cccp2_out, kernel_size=2, stride=2, return_indices=True) 232 | if self.training: 233 | # when dropout mask passed from outside is None, we need to generate a new dropout mask in this round 234 | if drop1_mask is None: 235 | # check if we can re-use previous mask Variable 236 | # if yes, we simply fill it with bernoulli noise without cloning it, which may save some running time 237 | if hasattr(self, 'drop1_mask') \ 238 | and self.drop1_mask is not None \ 239 | and self.drop1_mask.size() == self.pool1_out.size(): 240 | drop1_mask = self.drop1_mask 241 | else: 242 | drop1_mask = self.pool1_out.clone().detach() 243 | drop1_mask.data.bernoulli_(self.p).div_(1. 
- self.p) 244 | self.drop1_out = self.pool1_out * drop1_mask 245 | else: 246 | self.drop1_out = self.pool1_out 247 | self.drop1_mask = drop1_mask 248 | 249 | self.conv2_out = self.conv2(self.drop1_out) 250 | self.relu2_out = F.relu(self.conv2_out) 251 | self.cccp3_out = self.cccp3(self.relu2_out) 252 | self.relu_cccp3_out = F.relu(self.cccp3_out) 253 | self.cccp4_out = self.cccp4(self.relu_cccp3_out) 254 | self.relu_cccp4_out = F.relu(self.cccp4_out) 255 | self.pool2_out = F.avg_pool2d(self.relu_cccp4_out, kernel_size=2, stride=2) 256 | if self.training: 257 | if drop2_mask is None: 258 | if hasattr(self, 'drop2_mask') \ 259 | and self.drop2_mask is not None \ 260 | and self.drop2_mask.size() == self.pool2_out.size(): 261 | drop2_mask = self.drop2_mask 262 | else: 263 | drop2_mask = self.pool2_out.clone().detach() 264 | drop2_mask.data.bernoulli_(self.p).div_(1. - self.p) 265 | self.drop2_out = self.pool2_out * drop2_mask 266 | else: 267 | self.drop2_out = self.pool2_out 268 | self.drop2_mask = drop2_mask 269 | 270 | self.conv3_out = self.conv3(self.drop2_out) 271 | self.relu3_out = F.relu(self.conv3_out) 272 | self.cccp5_out = self.cccp5(self.relu3_out) 273 | self.relu_cccp5_out = F.relu(self.cccp5_out) 274 | self.cccp6_out = self.cccp6(self.relu_cccp5_out) 275 | self.pool3_out = F.avg_pool2d(self.cccp6_out, kernel_size=8) 276 | self.flat_out = self.pool3_out.view(-1, 10) 277 | return self.flat_out 278 | 279 | def load_weights(self, source=None): 280 | if source is None: 281 | source = 'data/cifar10-nin-62053fa9.mat' 282 | mcn = sio.loadmat(source) 283 | mcn_weights = dict() 284 | 285 | mcn_weights['conv1.weights'] = mcn['net'][0][0][0][0][0][0][0][2][0][0].transpose() 286 | mcn_weights['conv1.bias'] = mcn['net'][0][0][0][0][0][0][0][2][0][1].flatten() 287 | 288 | mcn_weights['cccp1.weights'] = mcn['net'][0][0][0][0][2][0][0][2][0][0].transpose() 289 | mcn_weights['cccp1.bias'] = mcn['net'][0][0][0][0][2][0][0][2][0][1].flatten() 290 | 291 | mcn_weights['cccp2.weights'] = mcn['net'][0][0][0][0][4][0][0][2][0][0].transpose() 292 | mcn_weights['cccp2.bias'] = mcn['net'][0][0][0][0][4][0][0][2][0][1].flatten() 293 | 294 | mcn_weights['conv2.weights'] = mcn['net'][0][0][0][0][7][0][0][2][0][0].transpose() 295 | mcn_weights['conv2.bias'] = mcn['net'][0][0][0][0][7][0][0][2][0][1].flatten() 296 | 297 | mcn_weights['cccp3.weights'] = mcn['net'][0][0][0][0][9][0][0][2][0][0].transpose() 298 | mcn_weights['cccp3.bias'] = mcn['net'][0][0][0][0][9][0][0][2][0][1].flatten() 299 | 300 | mcn_weights['cccp4.weights'] = mcn['net'][0][0][0][0][11][0][0][2][0][0].transpose() 301 | mcn_weights['cccp4.bias'] = mcn['net'][0][0][0][0][11][0][0][2][0][1].flatten() 302 | 303 | mcn_weights['conv3.weights'] = mcn['net'][0][0][0][0][14][0][0][2][0][0].transpose() 304 | mcn_weights['conv3.bias'] = mcn['net'][0][0][0][0][14][0][0][2][0][1].flatten() 305 | 306 | mcn_weights['cccp5.weights'] = mcn['net'][0][0][0][0][16][0][0][2][0][0].transpose() 307 | mcn_weights['cccp5.bias'] = mcn['net'][0][0][0][0][16][0][0][2][0][1].flatten() 308 | 309 | mcn_weights['cccp6.weights'] = mcn['net'][0][0][0][0][18][0][0][2][0][0].transpose() 310 | mcn_weights['cccp6.bias'] = mcn['net'][0][0][0][0][18][0][0][2][0][1].flatten() 311 | 312 | for k in ['conv1', 'cccp1', 'cccp2', 'conv2', 'cccp3', 'cccp4', 'conv3', 'cccp5', 'cccp6']: 313 | t = self.__getattr__(k) 314 | assert t.weight.data.size() == mcn_weights['%s.weights' % k].shape 315 | t.weight.data[:] = torch.from_numpy(mcn_weights['%s.weights' % k]) 316 | assert t.bias.data.size() == 
mcn_weights['%s.bias' % k].shape 317 | t.bias.data[:] = torch.from_numpy(mcn_weights['%s.bias' % k]) 318 | 319 | 320 | class InverseNIN(nn.Module): 321 | def __init__(self): 322 | super(InverseNIN, self).__init__() 323 | 324 | self.transposecccp6 = nn.ConvTranspose2d(10, 192, 1, bias=False) 325 | self.transposecccp5 = nn.ConvTranspose2d(192, 192, 1, bias=False) 326 | self.transposeconv3 = nn.ConvTranspose2d(192, 192, 3, padding=1, bias=False) 327 | 328 | self.transposecccp4 = nn.ConvTranspose2d(192, 192, 1, bias=False) 329 | self.transposecccp3 = nn.ConvTranspose2d(192, 192, 1, bias=False) 330 | self.transposeconv2 = nn.ConvTranspose2d(192, 96, 5, padding=2, bias=False) 331 | 332 | self.transposecccp2 = nn.ConvTranspose2d(96, 160, 1, bias=False) 333 | self.transposecccp1 = nn.ConvTranspose2d(160, 192, 1, bias=False) 334 | self.transposeconv1 = nn.ConvTranspose2d(192, 3, 5, padding=2, bias=False) 335 | 336 | # inverse pool2 (average pooling) 337 | self.w = torch.zeros(192, 192, 2, 2).cuda() 338 | for i in range(192): 339 | self.w[i, i, :, :] = 1. / 4 340 | 341 | def forward(self, x, relu1_mask, relu_cccp1_mask, relu_cccp2_mask, pool1_ind, drop1_mask, relu2_mask, 342 | relu_cccp3_mask, relu_cccp4_mask, drop2_mask, relu3_mask, relu_cccp5_mask): 343 | batch_size = x.size()[0] 344 | self.pool3_out = x.view(-1, 10, 1, 1) 345 | self.cccp6_out = self.pool3_out.expand(batch_size, 10, 8, 8) / 64. 346 | self.relu_cccp5_out = self.transposecccp6(self.cccp6_out) 347 | self.cccp5_out = self.relu_cccp5_out * relu_cccp5_mask 348 | self.relu3_out = self.transposecccp5(self.cccp5_out) 349 | self.conv3_out = self.relu3_out * relu3_mask 350 | self.drop2_out = self.transposeconv3(self.conv3_out) 351 | if self.training: 352 | self.pool2_out = self.drop2_out * drop2_mask 353 | else: 354 | self.pool2_out = self.drop2_out 355 | 356 | self.relu_cccp4_out = F.conv_transpose2d(self.pool2_out, self.w, stride=2) # inverse pool2 (average pooling) 357 | self.cccp4_out = self.relu_cccp4_out * relu_cccp4_mask 358 | self.relu_cccp3_out = self.transposecccp4(self.cccp4_out) 359 | self.cccp3_out = self.relu_cccp3_out * relu_cccp3_mask 360 | self.relu2_out = self.transposecccp3(self.cccp3_out) 361 | self.conv2_out = self.relu2_out * relu2_mask 362 | self.drop1_out = self.transposeconv2(self.conv2_out) 363 | if self.training: 364 | self.pool1_out = self.drop1_out * drop1_mask 365 | else: 366 | self.pool1_out = self.drop1_out 367 | 368 | self.relu_cccp2_out = F.max_unpool2d(self.pool1_out, pool1_ind, kernel_size=2, stride=2) 369 | self.cccp2_out = self.relu_cccp2_out * relu_cccp2_mask 370 | self.relu_cccp1_out = self.transposecccp2(self.cccp2_out) 371 | self.cccp1_out = self.relu_cccp1_out * relu_cccp1_mask 372 | self.relu1_out = self.transposecccp1(self.cccp1_out) 373 | self.conv1_out = self.relu1_out * relu1_mask 374 | self.input_out = self.transposeconv1(self.conv1_out) 375 | return self.input_out 376 | 377 | def copy_from(self, net): 378 | for k in ['conv1', 'cccp1', 'cccp2', 'conv2', 'cccp3', 'cccp4', 'conv3', 'cccp5', 'cccp6']: 379 | t = net.__getattr__(k) 380 | tt = self.__getattr__('transpose%s' % k) 381 | assert t.weight.size() == tt.weight.size() 382 | tt.weight = t.weight 383 | 384 | def forward_from_net(self, net, input_image, idx, drop1_mask=None, drop2_mask=None): 385 | idx = idx.contiguous() 386 | num_target_label = idx.size()[1] 387 | batch_size = input_image.size()[0] 388 | image_shape = input_image.size()[1:] 389 | 390 | output_var = net(input_image, drop1_mask=drop1_mask, drop2_mask=drop2_mask) 391 | 392 | 
dzdy = np.zeros((idx.numel(), output_var.size()[1]), dtype=np.float32) 393 | dzdy[np.arange(idx.numel()), idx.view(idx.numel()).cpu().numpy()] = 1. 394 | 395 | inverse_input_var = torch.from_numpy(dzdy).cuda() 396 | inverse_input_var.requires_grad = True 397 | inverse_output_var = self.forward( 398 | inverse_input_var, 399 | (net.conv1_out > 0.).float().repeat(1, num_target_label, 1, 1).view(idx.numel(), 192, 32, 32), 400 | (net.cccp1_out > 0.).float().repeat(1, num_target_label, 1, 1).view(idx.numel(), 160, 32, 32), 401 | (net.cccp2_out > 0.).float().repeat(1, num_target_label, 1, 1).view(idx.numel(), 96, 32, 32), 402 | net.pool1_ind.repeat(1, num_target_label, 1, 1).view(idx.numel(), 96, 16, 16), 403 | net.drop1_mask.repeat(1, num_target_label, 1, 1).view(idx.numel(), 96, 16, 16) 404 | if net.drop1_mask is not None else None, 405 | (net.conv2_out > 0.).float().repeat(1, num_target_label, 1, 1).view(idx.numel(), 192, 16, 16), 406 | (net.cccp3_out > 0.).float().repeat(1, num_target_label, 1, 1).view(idx.numel(), 192, 16, 16), 407 | (net.cccp4_out > 0.).float().repeat(1, num_target_label, 1, 1).view(idx.numel(), 192, 16, 16), 408 | net.drop2_mask.repeat(1, num_target_label, 1, 1).view(idx.numel(), 192, 8, 8) 409 | if net.drop2_mask is not None else None, 410 | (net.conv3_out > 0.).float().repeat(1, num_target_label, 1, 1).view(idx.numel(), 192, 8, 8), 411 | (net.cccp5_out > 0.).float().repeat(1, num_target_label, 1, 1).view(idx.numel(), 192, 8, 8), 412 | ) 413 | 414 | dzdx = inverse_output_var.view(input_image.size()[0], idx.size()[1], -1).transpose(1, 2) 415 | return dzdx 416 | -------------------------------------------------------------------------------- /models/mnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | import math 7 | import scipy.io as sio 8 | 9 | 10 | class LinearTranspose(nn.Module): 11 | def __init__(self, in_features, out_features, bias=True): 12 | super(LinearTranspose, self).__init__() 13 | self.in_features = in_features 14 | self.out_features = out_features 15 | self.weight = nn.Parameter(torch.Tensor(in_features, out_features)) 16 | if bias: 17 | self.bias = nn.Parameter(torch.Tensor(out_features)) 18 | else: 19 | self.register_parameter('bias', None) 20 | self.reset_parameters() 21 | 22 | def reset_parameters(self): 23 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 24 | self.weight.data.uniform_(-stdv, stdv) 25 | if self.bias is not None: 26 | self.bias.data.uniform_(-stdv, stdv) 27 | 28 | def forward(self, input): 29 | return F.linear(input, self.weight.transpose(0, 1), self.bias) 30 | 31 | def __repr__(self): 32 | return self.__class__.__name__ + ' (' \ 33 | + str(self.in_features) + ' -> ' \ 34 | + str(self.out_features) + ')' 35 | 36 | 37 | class BNTranspose(nn.Module): 38 | def __init__(self, num_features): 39 | super(BNTranspose, self).__init__() 40 | self.num_features = num_features 41 | self.weight = nn.Parameter(torch.Tensor(num_features)) 42 | # self.bias = nn.Parameter(torch.Tensor(num_features)) 43 | self.reset_parameters() 44 | 45 | def reset_parameters(self): 46 | pass 47 | 48 | def forward(self, input): 49 | return input * self.weight 50 | 51 | def __repr__(self): 52 | return self.__class__.__name__ + ' (' \ 53 | + str(self.num_features) + ')' 54 | 55 | 56 | class LeNet(nn.Module): 57 | def __init__(self): 58 | super(LeNet, self).__init__() 59 | # 1 input image channel, 20 output channels, 5x5 square convolution 60 | # kernel 61 | self.conv1 = nn.Conv2d(1, 20, 5) 62 | self.conv2 = nn.Conv2d(20, 50, 5) 63 | # an affine operation: y = Wx + b 64 | self.fc1 = nn.Linear(50 * 4 * 4, 500) 65 | self.fc2 = nn.Linear(500, 10) 66 | 67 | def forward(self, x): 68 | # Max pooling over a (2, 2) window 69 | self.x = x 70 | self.conv1_out = self.conv1(self.x) 71 | self.pool1_out, self.pool1_ind = F.max_pool2d(self.conv1_out, (2, 2), return_indices=True) 72 | self.conv2_out = self.conv2(self.pool1_out) 73 | self.pool2_out, self.pool2_ind = F.max_pool2d(self.conv2_out, (2, 2), return_indices=True) 74 | self.flat_out = self.pool2_out.view(-1, self.num_flat_features(self.pool2_out)) 75 | self.fc1_out = self.fc1(self.flat_out) 76 | self.relu1_out = F.relu(self.fc1_out) 77 | self.fc2_out = self.fc2(self.relu1_out) 78 | 79 | return self.fc2_out 80 | 81 | def load_weights(self, source=None): 82 | if source is None: 83 | source = 'data/mnist-lenet-92dd205e.mat' 84 | mcn = sio.loadmat(source) 85 | mcn_weights = dict() 86 | 87 | mcn_weights['conv1.weights'] = mcn['net'][0][0][0][0][0][0][0][1][0][0].transpose() 88 | mcn_weights['conv1.bias'] = mcn['net'][0][0][0][0][0][0][0][1][0][1].flatten() 89 | 90 | mcn_weights['conv2.weights'] = mcn['net'][0][0][0][0][2][0][0][1][0][0].transpose() 91 | mcn_weights['conv2.bias'] = mcn['net'][0][0][0][0][2][0][0][1][0][1].flatten() 92 | 93 | mcn_weights['fc1.weights'] = mcn['net'][0][0][0][0][4][0][0][1][0][0].transpose().reshape(500, -1) 94 | mcn_weights['fc1.bias'] = mcn['net'][0][0][0][0][4][0][0][1][0][1].flatten() 95 | 96 | mcn_weights['fc2.weights'] = mcn['net'][0][0][0][0][6][0][0][1][0][0].transpose().reshape(10, -1) 97 | mcn_weights['fc2.bias'] = mcn['net'][0][0][0][0][6][0][0][1][0][1].flatten() 98 | 99 | for k in ['conv1', 'conv2', 'fc1', 'fc2']: 100 | t = self.__getattr__(k) 101 | assert t.weight.data.size() == mcn_weights['%s.weights' % k].shape 102 | t.weight.data[:] = torch.from_numpy(mcn_weights['%s.weights' % k]).cuda() 103 | assert t.bias.data.size() == mcn_weights['%s.bias' % k].shape 104 | t.bias.data[:] = torch.from_numpy(mcn_weights['%s.bias' % k]).cuda() 105 | 106 | @staticmethod 107 | def num_flat_features(x): 108 | size = x.size()[1:] # all dimensions except the batch dimension 109 | num_features = 1 110 | for s in size: 111 | num_features *= s 112 | return num_features 113 | 114 | 115 | class InverseLeNet(nn.Module): 116 | def __init__(self): 117 | super(InverseLeNet,
self).__init__() 118 | self.transposefc2 = LinearTranspose(10, 500, bias=False) 119 | self.transposefc1 = LinearTranspose(500, 50 * 4 * 4, bias=False) 120 | self.unpool2 = nn.MaxUnpool2d(2, 2) 121 | self.transposeconv2 = nn.ConvTranspose2d(50, 20, 5, bias=False) 122 | self.unpool1 = nn.MaxUnpool2d(2, 2) 123 | self.transposeconv1 = nn.ConvTranspose2d(20, 1, 5, bias=False) 124 | 125 | def forward(self, x, relu1_mask, pool2_ind, pool1_ind): 126 | self.relu1_out = self.transposefc2(x) 127 | self.fc1_out = self.relu1_out * relu1_mask 128 | self.flat_out = self.transposefc1(self.fc1_out) 129 | self.pool2_out = self.flat_out.view(-1, 50, 4, 4) 130 | self.conv2_out = self.unpool2(self.pool2_out, pool2_ind) 131 | self.pool1_out = self.transposeconv2(self.conv2_out) 132 | self.conv1_out = self.unpool1(self.pool1_out, pool1_ind) 133 | self.input_out = self.transposeconv1(self.conv1_out) 134 | return self.input_out 135 | 136 | def copy_from(self, net): 137 | assert self.transposefc2.weight.data.size() == net.fc2.weight.data.size() 138 | self.transposefc2.weight = net.fc2.weight 139 | 140 | assert self.transposefc1.weight.data.size() == net.fc1.weight.data.size() 141 | self.transposefc1.weight = net.fc1.weight 142 | 143 | assert self.transposeconv2.weight.data.size() == net.conv2.weight.data.size() 144 | self.transposeconv2.weight = net.conv2.weight 145 | 146 | assert self.transposeconv1.weight.data.size() == net.conv1.weight.data.size() 147 | self.transposeconv1.weight = net.conv1.weight 148 | 149 | def forward_from_net(self, net, input_image, idx): 150 | num_target_label = idx.size()[1] 151 | batch_size = input_image.size()[0] 152 | image_shape = input_image.size()[1:] 153 | 154 | # use inversenet to calculate gradient 155 | output_var = net(input_image.cuda()) 156 | 157 | dzdy = np.zeros((idx.numel(), output_var.size()[1]), dtype=np.float32) 158 | dzdy[np.arange(idx.numel()), idx.view(idx.numel()).cpu().numpy()] = 1.
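        # dzdy is a batch of one-hot seeds: row j selects one target logit, and the
        # inverse pass below maps it to d(logit)/d(input) for that class, so all
        # target classes are covered in a single batched pass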
159 | 160 | inverse_input_var = torch.from_numpy(dzdy).cuda() 161 | inverse_input_var.requires_grad = True 162 | inverse_output_var = self.forward( 163 | inverse_input_var, 164 | (net.fc1_out > 0).float().repeat(1, num_target_label).view(idx.numel(), 500), 165 | net.pool2_ind.repeat(1, num_target_label, 1, 1).view(idx.numel(), 50, 4, 4), 166 | net.pool1_ind.repeat(1, num_target_label, 1, 1).view(idx.numel(), 20, 12, 12)) 167 | 168 | dzdx = inverse_output_var.view(input_image.size()[0], idx.size()[1], -1).transpose(1, 2) 169 | return dzdx 170 | 171 | 172 | class MLP(nn.Module): 173 | def __init__(self): 174 | super(MLP, self).__init__() 175 | # an affine operation: y = Wx + b 176 | self.fc1 = nn.Linear(784, 500) 177 | self.fc2 = nn.Linear(500, 150) 178 | self.fc3 = nn.Linear(150, 10) 179 | 180 | def forward(self, x): 181 | self.x = x 182 | self.flat_out = self.x.view(-1, 784) 183 | self.fc1_out = self.fc1(self.flat_out) 184 | self.relu1_out = F.relu(self.fc1_out) 185 | self.fc2_out = self.fc2(self.relu1_out) 186 | self.relu2_out = F.relu(self.fc2_out) 187 | self.fc3_out = self.fc3(self.relu2_out) 188 | return self.fc3_out 189 | 190 | def load_weights(self, source=None): 191 | if source is None: 192 | source = 'data/mnist-mlp-d072f4c8.mat' 193 | mcn = sio.loadmat(source) 194 | mcn_weights = dict() 195 | 196 | mcn_weights['fc1.weights'] = mcn['net'][0][0][0][0][0][0][0][1][0][0].transpose().reshape(500, -1) 197 | mcn_weights['fc1.bias'] = mcn['net'][0][0][0][0][0][0][0][1][0][1].flatten() 198 | 199 | mcn_weights['fc2.weights'] = mcn['net'][0][0][0][0][2][0][0][1][0][0].transpose().reshape(150, -1) 200 | mcn_weights['fc2.bias'] = mcn['net'][0][0][0][0][2][0][0][1][0][1].flatten() 201 | 202 | mcn_weights['fc3.weights'] = mcn['net'][0][0][0][0][4][0][0][1][0][0].transpose().reshape(10, -1) 203 | mcn_weights['fc3.bias'] = mcn['net'][0][0][0][0][4][0][0][1][0][1].flatten() 204 | 205 | for k in ['fc1', 'fc2', 'fc3']: 206 | t = self.__getattr__(k) 207 | assert t.weight.data.size() == mcn_weights['%s.weights' % k].shape 208 | t.weight.data[:] = torch.from_numpy(mcn_weights['%s.weights' % k]).cuda() 209 | assert t.bias.data.size() == mcn_weights['%s.bias' % k].shape 210 | t.bias.data[:] = torch.from_numpy(mcn_weights['%s.bias' % k]).cuda() 211 | 212 | 213 | class InverseMLP(nn.Module): 214 | def __init__(self): 215 | super(InverseMLP, self).__init__() 216 | self.transposefc3 = LinearTranspose(10, 150, bias=False) 217 | self.transposefc2 = LinearTranspose(150, 500, bias=False) 218 | self.transposefc1 = LinearTranspose(500, 784, bias=False) 219 | 220 | def forward(self, x, relu1_mask, relu2_mask): 221 | self.relu2_out = self.transposefc3(x) 222 | self.fc2_out = self.relu2_out * relu2_mask 223 | self.relu1_out = self.transposefc2(self.fc2_out) 224 | self.fc1_out = self.relu1_out * relu1_mask 225 | self.flat_out = self.transposefc1(self.fc1_out) 226 | self.input_out = self.flat_out.view(-1, 1, 28, 28) 227 | return self.input_out 228 | 229 | def copy_from(self, net): 230 | for k in ['fc1', 'fc2', 'fc3']: 231 | t = net.__getattr__(k) 232 | tt = self.__getattr__('transpose%s' % k) 233 | assert t.weight.data.size() == tt.weight.data.size() 234 | tt.weight = t.weight 235 | 236 | def forward_from_net(self, net, input_image, idx): 237 | num_target_label = idx.size()[1] 238 | batch_size = input_image.size()[0] 239 | image_shape = input_image.size()[1:] 240 | 241 | output_var = net(input_image.cuda()) 242 | 243 | dzdy = np.zeros((idx.numel(), output_var.size()[1]), dtype=np.float32) 244 | 
dzdy[np.arange(idx.numel()), idx.view(idx.numel()).cpu().numpy()] = 1. 245 | 246 | inverse_input_var = torch.from_numpy(dzdy).cuda() 247 | inverse_input_var.requires_grad = True 248 | inverse_output_var = self.forward( 249 | inverse_input_var, 250 | (net.fc1_out > 0).float().repeat(1, num_target_label).view(idx.numel(), 500), 251 | (net.fc2_out > 0).float().repeat(1, num_target_label).view(idx.numel(), 150), 252 | ) 253 | 254 | dzdx = inverse_output_var.view(input_image.size()[0], idx.size()[1], -1).transpose(1, 2) 255 | return dzdx 256 | 257 | 258 | class MLPBN(nn.Module): 259 | def __init__(self): 260 | super(MLPBN, self).__init__() 261 | # an affine operation: y = Wx + b 262 | self.fc1 = nn.Linear(784, 500) 263 | self.bn1 = nn.BatchNorm1d(500) 264 | self.fc2 = nn.Linear(500, 150) 265 | self.bn2 = nn.BatchNorm1d(150) 266 | self.fc3 = nn.Linear(150, 10) 267 | 268 | def forward(self, x): 269 | self.x = x 270 | self.flat_out = self.x.view(-1, 784) 271 | self.fc1_out = self.fc1(self.flat_out) 272 | self.bn1_out = self.bn1(self.fc1_out) 273 | self.relu1_out = F.relu(self.bn1_out) 274 | self.fc2_out = self.fc2(self.relu1_out) 275 | self.bn2_out = self.bn2(self.fc2_out) 276 | self.relu2_out = F.relu(self.bn2_out) 277 | self.fc3_out = self.fc3(self.relu2_out) 278 | return self.fc3_out 279 | 280 | def load_weights(self, source=None): 281 | if source is None: 282 | source = 'data/mnist-mlpbn-25b43980.pth' 283 | self.load_state_dict(torch.load(source)) 284 | 285 | 286 | class InverseMLPBN(nn.Module): 287 | def __init__(self): 288 | super(InverseMLPBN, self).__init__() 289 | self.transposefc3 = LinearTranspose(10, 150, bias=False) 290 | self.transposebn2 = BNTranspose(150) 291 | self.transposefc2 = LinearTranspose(150, 500, bias=False) 292 | self.transposebn1 = BNTranspose(500) 293 | self.transposefc1 = LinearTranspose(500, 784, bias=False) 294 | 295 | def forward(self, x, relu1_mask, relu2_mask): 296 | self.relu2_out = self.transposefc3(x) 297 | self.bn2_out = self.relu2_out * relu2_mask 298 | self.fc2_out = self.transposebn2(self.bn2_out) 299 | self.relu1_out = self.transposefc2(self.fc2_out) 300 | self.bn1_out = self.relu1_out * relu1_mask 301 | self.fc1_out = self.transposebn1(self.bn1_out) 302 | self.flat_out = self.transposefc1(self.fc1_out) 303 | self.input_out = self.flat_out.view(-1, 1, 28, 28) 304 | return self.input_out 305 | 306 | def copy_from(self, net): 307 | for k in ['fc1', 'fc2', 'fc3']: 308 | t = net.__getattr__(k) 309 | tt = self.__getattr__('transpose%s' % k) 310 | assert t.weight.data.size() == tt.weight.data.size() 311 | tt.weight = t.weight 312 | for k in ['bn1', 'bn2']: 313 | t = net.__getattr__(k) 314 | tt = self.__getattr__('transpose%s' % k) 315 | tt.weight.data[:] = (t.weight / torch.sqrt(t.running_var + t.eps)).data[:] 316 | # tt.bias.data[:] = (-t.running_mean * tt.weight + t.bias).data[:] 317 | 318 | def forward_from_net(self, net, input_image, idx): 319 | num_target_label = idx.size()[1] 320 | batch_size = input_image.size()[0] 321 | image_shape = input_image.size()[1:] 322 | 323 | output_var = net(input_image.cuda()) 324 | 325 | dzdy = np.zeros((idx.numel(), output_var.size()[1]), dtype=np.float32) 326 | dzdy[np.arange(idx.numel()), idx.view(idx.numel()).cpu().numpy()] = 1. 
327 | 328 | inverse_input_var = torch.from_numpy(dzdy).cuda() 329 | inverse_input_var.requires_grad = True 330 | inverse_output_var = self.forward( 331 | inverse_input_var, 332 | (net.bn1_out > 0).float().repeat(1, num_target_label).view(idx.numel(), 500), 333 | (net.bn2_out > 0).float().repeat(1, num_target_label).view(idx.numel(), 150), 334 | ) 335 | 336 | dzdx = inverse_output_var.view(input_image.size()[0], idx.size()[1], -1).transpose(1, 2) 337 | return dzdx 338 | --------------------------------------------------------------------------------
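The inverse networks above exist to compute input gradients d(logit_k)/d(input) for several target classes k in one batched pass, instead of looping over per-class autograd calls. A quick consistency check (a sketch, not part of the repo; random weights and a CUDA device are assumed) comparing ```InverseMLP.forward_from_net``` against plain autograd:

```
import torch
from models.mnist import MLP, InverseMLP

net, inv = MLP(), InverseMLP()
inv.copy_from(net)                                   # share weights, as DeepFool.__init__ does
net.cuda(); inv.cuda()

x = torch.randn(2, 1, 28, 28).cuda()
idx = torch.tensor([[0, 1], [0, 1]]).cuda()          # two target classes per image
dzdx = inv.forward_from_net(net, x, idx)             # shape: (batch, 784, num_target_classes)

x2 = x.clone().requires_grad_(True)
g = torch.autograd.grad(net(x2)[:, 1].sum(), x2)[0]  # autograd gradient of logit 1
print(torch.allclose(dzdx[:, :, 1], g.view(2, -1), atol=1e-5))  # expect True
```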