├── Block_Diag.png
├── Mapping_Function.png
├── Models_PCL
│   └── CIFAR10_PCL.pth.tar
├── Models_Softmax
│   └── CIFAR10_Softmax.pth.tar
├── README.md
├── contrastive_proximity.py
├── pcl_training.py
├── pcl_training_adversarial_fgsm.py
├── pcl_training_adversarial_pgd.py
├── proximity.py
├── resnet_model.py
├── robust_ml.py
├── robust_model.pth.tar
├── robustness.py
├── softmax_training.py
└── utils.py

/Block_Diag.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aamir-mustafa/pcl-adversarial-defense/37242580937c267efcac6679ec050aa438b47796/Block_Diag.png
--------------------------------------------------------------------------------
/Mapping_Function.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aamir-mustafa/pcl-adversarial-defense/37242580937c267efcac6679ec050aa438b47796/Mapping_Function.png
--------------------------------------------------------------------------------
/Models_PCL/CIFAR10_PCL.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aamir-mustafa/pcl-adversarial-defense/37242580937c267efcac6679ec050aa438b47796/Models_PCL/CIFAR10_PCL.pth.tar
--------------------------------------------------------------------------------
/Models_Softmax/CIFAR10_Softmax.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aamir-mustafa/pcl-adversarial-defense/37242580937c267efcac6679ec050aa438b47796/Models_Softmax/CIFAR10_Softmax.pth.tar
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Adversarial Defense by Restricting the Hidden Space of Deep Neural Networks (ICCV'19)

![Figure 1](Mapping_Function.png)

This repository is a PyTorch implementation of the ICCV'19 paper [Adversarial Defense by Restricting the Hidden Space of Deep Neural Networks](https://arxiv.org/abs/1904.00887).

To counter adversarial attacks, we propose the Prototype Conformity Loss (PCL), which disentangles the intermediate features of a deep network class-wise. As the figure illustrates, the main reason such adversarial samples exist is the close proximity of the learnt features of different classes in the latent feature space.

We provide scripts for reproducing the results from our paper.

## Clone the repository
Clone this repository to any location you like.
```bash
git clone https://github.com/aamir-mustafa/pcl-adversarial-defense
cd pcl-adversarial-defense
```
## Softmax (Cross-Entropy) Training
To speed up the formation of class clusters for our proposed loss, we first train the model with the cross-entropy loss alone.

``softmax_training.py`` -- (Initial softmax training).

* The trained checkpoints will be saved in the ``Models_Softmax`` folder. The PCL training scripts initialize the model from ``Models_Softmax/CIFAR10_Softmax.pth.tar``.

## Prototype Conformity Loss
The deep features for the prototype conformity loss are extracted from different intermediate layers using auxiliary branches, which map the features to a lower-dimensional output, as shown in the following figure.

![Block diagram](Block_Diag.png)

``pcl_training.py`` -- (Joint supervision with cross-entropy and our loss, applied to the 256- and 1024-dimensional features).

* The trained checkpoints will be saved in the ``Models_PCL`` folder.

## Adversarial Training
``pcl_training_adversarial_fgsm.py`` -- (Adversarial training using the FGSM attack).

``pcl_training_adversarial_pgd.py`` -- (Adversarial training using the PGD attack).

## Testing the Model's Robustness against White-Box Attacks

``robustness.py`` -- (Evaluate a trained model's robustness against various types of attacks).
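
The objective optimized by the PCL scripts can be illustrated with a minimal, framework-free sketch for a single feature level. This is not the repository's implementation: it assumes that `Proximity` (defined in `proximity.py`, not shown here) returns the mean squared distance of each feature to its own class center, while the push term mirrors what `Con_Proximity` computes (the mean squared distance to all other class centers, here without its clamping). The helper names `pcl_objective`, `squared_distance` and `cross_entropy` are hypothetical, and the default weights follow the `--weight-prox` and `--weight-conprox` defaults of `pcl_training.py`.

```python
import math


def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def cross_entropy(logits, label):
    """Softmax cross-entropy for one sample, using the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def pcl_objective(feats, logits, labels, centers, w_prox=1.0, w_conprox=1e-4):
    """Return CE + w_prox * pull - w_conprox * push and the three unweighted terms.

    pull: mean squared distance of each feature to its own class center
          (assumed behaviour of Proximity).
    push: mean squared distance of each feature to every other class center
          (what Con_Proximity averages, without its clamping).
    """
    n = len(feats)
    ce = sum(cross_entropy(z, y) for z, y in zip(logits, labels)) / n
    pull = sum(squared_distance(f, centers[y]) for f, y in zip(feats, labels)) / n
    others = [squared_distance(f, c)
              for f, y in zip(feats, labels)
              for k, c in enumerate(centers) if k != y]
    push = sum(others) / len(others)
    return ce + w_prox * pull - w_conprox * push, (ce, pull, push)


if __name__ == "__main__":
    centers = [[0.0, 0.0], [4.0, 0.0]]   # one 2-D center per class
    feats = [[0.0, 1.0], [4.0, 0.0]]     # one feature vector per sample
    labels = [0, 1]
    logits = [[2.0, 0.0], [0.0, 2.0]]
    total, (ce, pull, push) = pcl_objective(feats, logits, labels, centers)
    print(pull, push)  # 0.5 16.5
```

In the scripts, this combination is computed for both the 256- and 1024-dimensional features, and after `loss.backward()` the center gradients are multiplied by `1 / weight`, so the centers move at `--lr_prox` / `--lr_conprox` independently of the loss weights.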

## Comparison of the Softmax-Trained Model and Our Model
Retained classification accuracy (%) of the models under various types of adversarial attacks:

| Training Scheme | No Attack | FGSM | BIM | MIM | PGD |
| :-------------- | :-------- | :---- | :---- | :---- | :---- |
| Softmax | 92.15 | 21.48 | 0.01 | 0.02 | 0.00 |
| Ours | 89.55 | 55.76 | 39.75 | 36.44 | 31.10 |


## Citation
```
@InProceedings{Mustafa_2019_ICCV,
author = {Mustafa, Aamir and Khan, Salman and Hayat, Munawar and Goecke, Roland and Shen, Jianbing and Shao, Ling},
title = {Adversarial Defense by Restricting the Hidden Space of Deep Neural Networks},
booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
month = {October},
year = {2019}
}
```

--------------------------------------------------------------------------------
/contrastive_proximity.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class Con_Proximity(nn.Module):

    def __init__(self, num_classes=100, feat_dim=1024, use_gpu=True):
        super(Con_Proximity, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        # One learnable center per class: num_classes x feat_dim
        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        batch_size = x.size(0)
        # Squared distances ||x_i||^2 + ||c_j||^2 - 2 x_i.c_j, shape (batch_size, num_classes)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        classes = torch.arange(self.num_classes, device=x.device)
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))  # True at each sample's own class

        # Keep only the distances to the centers of all *other* classes
        dist = []
        for i in range(batch_size):
            value = distmat[i][~mask[i]]
            value = value.clamp(min=1e-12, max=1e+12)  # for numerical stability
            dist.append(value)
        dist = torch.cat(dist)
        loss = dist.mean()

        return loss

--------------------------------------------------------------------------------
/pcl_training.py:
--------------------------------------------------------------------------------

"""
Created on Wed Jan 23 10:15:27 2019

@author: aamir-mustafa
Implementation Part 2 of Paper:
"Adversarial Defense by Restricting the Hidden Space of Deep Neural Networks"

Here it is not necessary to save the best-performing model (in terms of accuracy). The model with high robustness
against adversarial attacks is chosen.

"""

# Essential imports
import os
import sys
import argparse
import datetime
import time
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from utils import AverageMeter, Logger
from proximity import Proximity
from contrastive_proximity import Con_Proximity
from resnet_model import *  # Imports the ResNet model


parser = argparse.ArgumentParser("Prototype Conformity Loss Implementation")
parser.add_argument('-j', '--workers', default=4, type=int,
                    help="number of data loading workers (default: 4)")
parser.add_argument('--train-batch', default=128, type=int, metavar='N',
                    help='train batchsize')
parser.add_argument('--test-batch', default=100, type=int, metavar='N',
                    help='test batchsize')
parser.add_argument('--schedule', type=int, nargs='+', default=[142, 230, 360],
                    help='Decrease learning rate at these epochs.')
parser.add_argument('--lr_model', type=float, default=0.01, help="learning rate for CE Loss")
parser.add_argument('--lr_prox', type=float, default=0.5, help="learning rate for Proximity Loss")  # as per paper
parser.add_argument('--weight-prox', type=float, default=1, help="weight for Proximity Loss")  # as per paper
parser.add_argument('--lr_conprox', type=float, default=0.0001, help="learning rate for Con-Proximity Loss")  # as per paper
parser.add_argument('--weight-conprox', type=float, default=0.0001, help="weight for Con-Proximity Loss")  # as per paper
parser.add_argument('--max-epoch', type=int, default=400)
parser.add_argument('--gamma', type=float, default=0.1, help="learning rate decay")
parser.add_argument('--eval-freq', type=int, default=10)
parser.add_argument('--print-freq', type=int, default=50)
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--use-cpu', action='store_true')
parser.add_argument('--save-dir', type=str, default='log')

args = parser.parse_args()
state = {k: v for k, v in args._get_kwargs()}


def main():
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    sys.stdout = Logger(osp.join(args.save_dir, 'log_' + 'CIFAR-10_PC_Loss' + '.txt'))

    if use_gpu:
        print("Currently using GPU: {}".format(args.gpu))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU")

    # Data loading
    num_classes = 10
    print('==> Preparing dataset')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])

    trainset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=True,
                                            download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch, pin_memory=True,
                                              shuffle=True, num_workers=args.workers)

    testset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=False,
                                           download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch, pin_memory=True,
                                             shuffle=False, num_workers=args.workers)

    # Loading the model
    model = resnet(num_classes=num_classes, depth=110)
    # The model is always wrapped in DataParallel and moved to the GPU, regardless of --use-cpu
    model = nn.DataParallel(model).cuda()

    criterion_xent = nn.CrossEntropyLoss()
    criterion_prox_1024 = Proximity(num_classes=num_classes, feat_dim=1024, use_gpu=use_gpu)
    criterion_prox_256 = Proximity(num_classes=num_classes, feat_dim=256, use_gpu=use_gpu)

    criterion_conprox_1024 = Con_Proximity(num_classes=num_classes, feat_dim=1024, use_gpu=use_gpu)
    criterion_conprox_256 = Con_Proximity(num_classes=num_classes, feat_dim=256, use_gpu=use_gpu)

    optimizer_model = torch.optim.SGD(model.parameters(), lr=args.lr_model, weight_decay=1e-04, momentum=0.9)

    optimizer_prox_1024 = torch.optim.SGD(criterion_prox_1024.parameters(), lr=args.lr_prox)
    optimizer_prox_256 = torch.optim.SGD(criterion_prox_256.parameters(), lr=args.lr_prox)

    optimizer_conprox_1024 = torch.optim.SGD(criterion_conprox_1024.parameters(), lr=args.lr_conprox)
    optimizer_conprox_256 = torch.optim.SGD(criterion_conprox_256.parameters(), lr=args.lr_conprox)

    filename = 'Models_Softmax/CIFAR10_Softmax.pth.tar'
    checkpoint = torch.load(filename)

    # Initialize from the softmax-trained weights; the optimizer starts fresh from --lr_model
    model.load_state_dict(checkpoint['state_dict'])

    start_time = time.time()

    for epoch in range(args.max_epoch):

        adjust_learning_rate(optimizer_model, epoch)
        adjust_learning_rate_prox(optimizer_prox_1024, epoch)
        adjust_learning_rate_prox(optimizer_prox_256, epoch)

        adjust_learning_rate_conprox(optimizer_conprox_1024, epoch)
        adjust_learning_rate_conprox(optimizer_conprox_256, epoch)

        print("==> Epoch {}/{}".format(epoch+1, args.max_epoch))
        train(model, criterion_xent, criterion_prox_1024, criterion_prox_256,
              criterion_conprox_1024, criterion_conprox_256,
              optimizer_model, optimizer_prox_1024, optimizer_prox_256,
              optimizer_conprox_1024, optimizer_conprox_256,
              trainloader, use_gpu, num_classes, epoch)

        # Evaluate every --eval-freq epochs (default 10) and after the final epoch
        if (args.eval_freq > 0 and (epoch+1) % args.eval_freq == 0) or (epoch+1) == args.max_epoch:
            print("==> Test")
            acc, err = test(model, testloader, use_gpu, num_classes, epoch)
            print("Accuracy (%): {}\t Error rate (%): {}".format(acc, err))

            state_ = {'epoch': epoch + 1, 'state_dict': model.state_dict(),
                      'optimizer_model': optimizer_model.state_dict(), 'optimizer_prox_1024': optimizer_prox_1024.state_dict(),
                      'optimizer_prox_256': optimizer_prox_256.state_dict(), 'optimizer_conprox_1024': optimizer_conprox_1024.state_dict(),
                      'optimizer_conprox_256': optimizer_conprox_256.state_dict(),}

            torch.save(state_, 'Models_PCL/CIFAR10_PCL.pth.tar')

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))


def train(model, criterion_xent, criterion_prox_1024, criterion_prox_256,
          criterion_conprox_1024, criterion_conprox_256,
          optimizer_model, optimizer_prox_1024, optimizer_prox_256,
          optimizer_conprox_1024, optimizer_conprox_256,
          trainloader, use_gpu, num_classes, epoch):

    model.train()
    xent_losses = AverageMeter()  # Computes and stores the average and current value
    prox_losses_1024 = AverageMeter()
    prox_losses_256 = AverageMeter()

    conprox_losses_1024 = AverageMeter()
    conprox_losses_256 = AverageMeter()
    losses = AverageMeter()

    # Batch-wise training
    for batch_idx, (data, labels) in enumerate(trainloader):
        if use_gpu:
            data, labels = data.cuda(), labels.cuda()
        feats128, feats256, feats1024, outputs = model(data)
        loss_xent = criterion_xent(outputs, labels)

        loss_prox_1024 = criterion_prox_1024(feats1024, labels)
        loss_prox_256 = criterion_prox_256(feats256, labels)

        loss_conprox_1024 = criterion_conprox_1024(feats1024, labels)
        loss_conprox_256 = criterion_conprox_256(feats256, labels)

        loss_prox_1024 *= args.weight_prox
        loss_prox_256 *= args.weight_prox

        loss_conprox_1024 *= args.weight_conprox
        loss_conprox_256 *= args.weight_conprox

        # Total loss: the Con_Proximity terms (distances to other-class centers) are subtracted
        loss = loss_xent + loss_prox_1024 + loss_prox_256 - loss_conprox_1024 - loss_conprox_256
        optimizer_model.zero_grad()

        optimizer_prox_1024.zero_grad()
        optimizer_prox_256.zero_grad()

        optimizer_conprox_1024.zero_grad()
        optimizer_conprox_256.zero_grad()

        loss.backward()
        optimizer_model.step()

        # Remove the loss weight from the center gradients, so each set of centers
        # is updated by its own learning rate alone
        for param in criterion_prox_1024.parameters():
            param.grad.data *= (1. / args.weight_prox)
        optimizer_prox_1024.step()

        for param in criterion_prox_256.parameters():
            param.grad.data *= (1. / args.weight_prox)
        optimizer_prox_256.step()

        for param in criterion_conprox_1024.parameters():
            param.grad.data *= (1. / args.weight_conprox)
        optimizer_conprox_1024.step()

        for param in criterion_conprox_256.parameters():
            param.grad.data *= (1. / args.weight_conprox)
        optimizer_conprox_256.step()

        losses.update(loss.item(), labels.size(0))
        xent_losses.update(loss_xent.item(), labels.size(0))
        prox_losses_1024.update(loss_prox_1024.item(), labels.size(0))
        prox_losses_256.update(loss_prox_256.item(), labels.size(0))

        conprox_losses_1024.update(loss_conprox_1024.item(), labels.size(0))
        conprox_losses_256.update(loss_conprox_256.item(), labels.size(0))

        if (batch_idx+1) % args.print_freq == 0:
            print("Batch {}/{}\t Loss {:.6f} ({:.6f}) XentLoss {:.6f} ({:.6f}) ProxLoss_1024 {:.6f} ({:.6f}) ProxLoss_256 {:.6f} ({:.6f}) \n ConProxLoss_1024 {:.6f} ({:.6f}) ConProxLoss_256 {:.6f} ({:.6f}) "
                  .format(batch_idx+1, len(trainloader), losses.val, losses.avg, xent_losses.val, xent_losses.avg,
                          prox_losses_1024.val, prox_losses_1024.avg, prox_losses_256.val, prox_losses_256.avg,
                          conprox_losses_1024.val, conprox_losses_1024.avg, conprox_losses_256.val,
                          conprox_losses_256.avg))


def test(model, testloader, use_gpu, num_classes, epoch):
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for data, labels in testloader:
            data, labels = data.cuda(), labels.cuda()  # the model always lives on the GPU (see main)
            feats128, feats256, feats1024, outputs = model(data)
            predictions = outputs.max(1)[1]
            total += labels.size(0)
            correct += (predictions == labels).sum().item()

    acc = correct * 100. / total
    err = 100. - acc
    return acc, err


def adjust_learning_rate(optimizer, epoch):
    # Multiply the optimizer's learning rate by gamma at every epoch listed in --schedule.
    # SGD reads the rate from param_group['lr'], and each optimizer is decayed independently.
    if epoch in args.schedule:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= args.gamma


# The center optimizers follow the same step schedule as the model optimizer
adjust_learning_rate_prox = adjust_learning_rate
adjust_learning_rate_conprox = adjust_learning_rate


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/pcl_training_adversarial_fgsm.py:
--------------------------------------------------------------------------------

"""
Created on Wed Jan 23 10:15:27 2019

@author: aamir-mustafa
Implementation Part 2 of Paper:
"Adversarial Defense by Restricting the Hidden Space of Deep Neural Networks"

Here it is not necessary to save the best-performing model (in terms of accuracy). The model with high robustness
against adversarial attacks is chosen.
This code implements Adversarial Training using the FGSM Attack.
"""

# Essential imports
import os
import sys
import argparse
import datetime
import time
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from utils import AverageMeter, Logger
from proximity import Proximity
from contrastive_proximity import Con_Proximity
from resnet_model import *  # Imports the ResNet model


parser = argparse.ArgumentParser("Prototype Conformity Loss Implementation")
parser.add_argument('-j', '--workers', default=4, type=int,
                    help="number of data loading workers (default: 4)")
parser.add_argument('--train-batch', default=64, type=int, metavar='N',
                    help='train batchsize')
parser.add_argument('--test-batch', default=100, type=int, metavar='N',
                    help='test batchsize')
parser.add_argument('--schedule', type=int, nargs='+', default=[142, 230, 360],
                    help='Decrease learning rate at these epochs.')
parser.add_argument('--lr_model', type=float, default=0.01, help="learning rate for model")
parser.add_argument('--lr_prox', type=float, default=0.5, help="learning rate for Proximity Loss")  # as per paper
parser.add_argument('--weight-prox', type=float, default=1, help="weight for Proximity Loss")  # as per paper
parser.add_argument('--lr_conprox', type=float, default=0.00001, help="learning rate for Con-Proximity Loss")  # as per paper
parser.add_argument('--weight-conprox', type=float, default=0.00001, help="weight for Con-Proximity Loss")  # as per paper
parser.add_argument('--max-epoch', type=int, default=500)
parser.add_argument('--gamma', type=float, default=0.1, help="learning rate decay")
parser.add_argument('--eval-freq', type=int, default=10)
parser.add_argument('--print-freq', type=int, default=50)
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--seed', type=int, default=1)
parser.add_argument('--use-cpu', action='store_true')
parser.add_argument('--save-dir', type=str, default='log')

args = parser.parse_args()
state = {k: v for k, v in args._get_kwargs()}

mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]


def normalize(t):
    # In place: map [0, 1] pixels to the normalized input space, channel by channel
    for c in range(3):
        t[:, c, :, :] = (t[:, c, :, :] - mean[c]) / std[c]
    return t


def un_normalize(t):
    # In place: map normalized inputs back to [0, 1] pixel space, channel by channel
    for c in range(3):
        t[:, c, :, :] = (t[:, c, :, :] * std[c]) + mean[c]
    return t


def FGSM(model, criterion, img, label, eps):
    # img is a normalized batch; the returned adversarial batch is in [0, 1] pixel space
    adv = img.clone()
    adv.requires_grad = True
    _, _, _, out = model(adv)
    loss = criterion(out, label)
    loss.backward()
    # One signed-gradient step of size eps, taken in pixel space
    adv.data = un_normalize(adv.data) + eps * adv.grad.sign()
    adv.data.clamp_(0.0, 1.0)
    adv.grad.data.zero_()
    return adv.detach()


def main():
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    sys.stdout = Logger(osp.join(args.save_dir, 'log_' + 'CIFAR-10_PC_Loss_FGSM_AdvTrain' + '.txt'))

    if use_gpu:
        print("Currently using GPU: {}".format(args.gpu))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU")

    # Data loading
    num_classes = 10
    print('==> Preparing dataset')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])

    trainset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=True,
                                            download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch, pin_memory=True,
                                              shuffle=True, num_workers=args.workers)

    testset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=False,
                                           download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch, pin_memory=True,
                                             shuffle=False, num_workers=args.workers)

    # Loading the model
    model = resnet(num_classes=num_classes, depth=110)
    # The model is always wrapped in DataParallel and moved to the GPU, regardless of --use-cpu
    model = nn.DataParallel(model).cuda()

    criterion_xent = nn.CrossEntropyLoss()
    criterion_prox_1024 = Proximity(num_classes=num_classes, feat_dim=1024, use_gpu=use_gpu)
    criterion_prox_256 = Proximity(num_classes=num_classes, feat_dim=256, use_gpu=use_gpu)

    criterion_conprox_1024 = Con_Proximity(num_classes=num_classes, feat_dim=1024, use_gpu=use_gpu)
    criterion_conprox_256 = Con_Proximity(num_classes=num_classes, feat_dim=256, use_gpu=use_gpu)

    optimizer_model = torch.optim.SGD(model.parameters(), lr=args.lr_model, weight_decay=1e-04, momentum=0.9)

    optimizer_prox_1024 = torch.optim.SGD(criterion_prox_1024.parameters(), lr=args.lr_prox)
    optimizer_prox_256 = torch.optim.SGD(criterion_prox_256.parameters(), lr=args.lr_prox)

    optimizer_conprox_1024 = torch.optim.SGD(criterion_conprox_1024.parameters(), lr=args.lr_conprox)
    optimizer_conprox_256 = torch.optim.SGD(criterion_conprox_256.parameters(), lr=args.lr_conprox)

    filename = 'Models_Softmax/CIFAR10_Softmax.pth.tar'
    checkpoint = torch.load(filename)

    # Initialize from the softmax-trained weights; the optimizer starts fresh from --lr_model
    model.load_state_dict(checkpoint['state_dict'])

    start_time = time.time()

    for epoch in range(args.max_epoch):

        adjust_learning_rate(optimizer_model, epoch)
        adjust_learning_rate_prox(optimizer_prox_1024, epoch)
        adjust_learning_rate_prox(optimizer_prox_256, epoch)

        adjust_learning_rate_conprox(optimizer_conprox_1024, epoch)
        adjust_learning_rate_conprox(optimizer_conprox_256, epoch)

        print("==> Epoch {}/{}".format(epoch+1, args.max_epoch))
        train(model, criterion_xent, criterion_prox_1024, criterion_prox_256,
              criterion_conprox_1024, criterion_conprox_256,
              optimizer_model, optimizer_prox_1024, optimizer_prox_256,
              optimizer_conprox_1024, optimizer_conprox_256,
              trainloader, use_gpu, num_classes, epoch)

        # Evaluate every --eval-freq epochs (default 10) and after the final epoch
        if (args.eval_freq > 0 and (epoch+1) % args.eval_freq == 0) or (epoch+1) == args.max_epoch:
            print("==> Test")
            acc, err = test(model, testloader, use_gpu, num_classes, epoch)
            print("Accuracy (%): {}\t Error rate (%): {}".format(acc, err))

            state_ = {'epoch': epoch + 1, 'state_dict': model.state_dict(),
                      'optimizer_model': optimizer_model.state_dict(), 'optimizer_prox_1024': optimizer_prox_1024.state_dict(),
                      'optimizer_prox_256': optimizer_prox_256.state_dict(), 'optimizer_conprox_1024': optimizer_conprox_1024.state_dict(),
                      'optimizer_conprox_256': optimizer_conprox_256.state_dict(),}

            os.makedirs('Models_PCL_AdvTrain_FGSM', exist_ok=True)  # not included in the repository tree
            torch.save(state_, 'Models_PCL_AdvTrain_FGSM/CIFAR10_PCL_AdvTrain_FGSM.pth.tar')

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))


def train(model, criterion_xent, criterion_prox_1024, criterion_prox_256,
          criterion_conprox_1024, criterion_conprox_256,
          optimizer_model, optimizer_prox_1024, optimizer_prox_256,
          optimizer_conprox_1024, optimizer_conprox_256,
          trainloader, use_gpu, num_classes, epoch):

    xent_losses = AverageMeter()  # Computes and stores the average and current value
    prox_losses_1024 = AverageMeter()
    prox_losses_256 = AverageMeter()

    conprox_losses_1024 = AverageMeter()
    conprox_losses_256 = AverageMeter()
    losses = AverageMeter()

    # Batch-wise training
    for batch_idx, (data, labels) in enumerate(trainloader):
        if use_gpu:
            data, labels = data.cuda(), labels.cuda()

        # Craft FGSM examples in eval mode, with eps drawn uniformly from [0.02, 0.05] per batch
        model.eval()
        eps = np.random.uniform(0.02, 0.05)
        adv = FGSM(model, criterion_xent, data, labels, eps=eps)  # Generates batch-wise adversarial images
        adv.requires_grad = False

        adv = normalize(adv)
        adv = adv.cuda()
        true_labels_adv = labels
        # Train on the clean batch concatenated with its adversarial counterpart (same labels)
        data = torch.cat((data, adv), 0)
        labels = torch.cat((labels, true_labels_adv))
        model.train()

        feats128, feats256, feats1024, outputs = model(data)
        loss_xent = criterion_xent(outputs, labels)

        loss_prox_1024 = criterion_prox_1024(feats1024, labels)
        loss_prox_256 = criterion_prox_256(feats256, labels)

        loss_conprox_1024 = criterion_conprox_1024(feats1024, labels)
        loss_conprox_256 = criterion_conprox_256(feats256, labels)

        loss_prox_1024 *= args.weight_prox
        loss_prox_256 *= args.weight_prox

        loss_conprox_1024 *= args.weight_conprox
        loss_conprox_256 *= args.weight_conprox

        # Total loss: the Con_Proximity terms (distances to other-class centers) are subtracted
        loss = loss_xent + loss_prox_1024 + loss_prox_256 - loss_conprox_1024 - loss_conprox_256
        optimizer_model.zero_grad()

        optimizer_prox_1024.zero_grad()
        optimizer_prox_256.zero_grad()

        optimizer_conprox_1024.zero_grad()
        optimizer_conprox_256.zero_grad()

        loss.backward()
        optimizer_model.step()

        # Remove the loss weight from the center gradients, so each set of centers
        # is updated by its own learning rate alone
        for param in criterion_prox_1024.parameters():
            param.grad.data *= (1. / args.weight_prox)
        optimizer_prox_1024.step()

        for param in criterion_prox_256.parameters():
            param.grad.data *= (1. / args.weight_prox)
        optimizer_prox_256.step()

        for param in criterion_conprox_1024.parameters():
            param.grad.data *= (1. / args.weight_conprox)
        optimizer_conprox_1024.step()

        for param in criterion_conprox_256.parameters():
            param.grad.data *= (1. / args.weight_conprox)
        optimizer_conprox_256.step()

        losses.update(loss.item(), labels.size(0))
        xent_losses.update(loss_xent.item(), labels.size(0))
        prox_losses_1024.update(loss_prox_1024.item(), labels.size(0))
        prox_losses_256.update(loss_prox_256.item(), labels.size(0))

        conprox_losses_1024.update(loss_conprox_1024.item(), labels.size(0))
        conprox_losses_256.update(loss_conprox_256.item(), labels.size(0))

        if (batch_idx+1) % args.print_freq == 0:
            print("Batch {}/{}\t Loss {:.6f} ({:.6f}) XentLoss {:.6f} ({:.6f}) ProxLoss_1024 {:.6f} ({:.6f}) ProxLoss_256 {:.6f} ({:.6f}) \n ConProxLoss_1024 {:.6f} ({:.6f}) ConProxLoss_256 {:.6f} ({:.6f}) "
                  .format(batch_idx+1, len(trainloader), losses.val, losses.avg, xent_losses.val, xent_losses.avg,
                          prox_losses_1024.val, prox_losses_1024.avg, prox_losses_256.val, prox_losses_256.avg,
                          conprox_losses_1024.val, conprox_losses_1024.avg, conprox_losses_256.val,
                          conprox_losses_256.avg))


def test(model, testloader, use_gpu, num_classes, epoch):
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for data, labels in testloader:
            data, labels = data.cuda(), labels.cuda()  # the model always lives on the GPU (see main)
            feats128, feats256, feats1024, outputs = model(data)
            predictions = outputs.max(1)[1]
            total += labels.size(0)
            correct += (predictions == labels).sum().item()

    acc = correct * 100. / total
    err = 100. - acc
    return acc, err


def adjust_learning_rate(optimizer, epoch):
    # Multiply the optimizer's learning rate by gamma at every epoch listed in --schedule.
    # SGD reads the rate from param_group['lr'], and each optimizer is decayed independently.
    if epoch in args.schedule:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= args.gamma


# The center optimizers follow the same step schedule as the model optimizer
adjust_learning_rate_prox = adjust_learning_rate
adjust_learning_rate_conprox = adjust_learning_rate


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/pcl_training_adversarial_pgd.py:
--------------------------------------------------------------------------------

"""
Created on Wed Jan 23 10:15:27 2019

@author: aamir-mustafa
Implementation Part 2 of Paper:
"Adversarial Defense by Restricting the Hidden Space of Deep Neural Networks"

Here it is not necessary to save the best-performing model (in terms of accuracy). The model with high robustness
against adversarial attacks is chosen.
This code implements Adversarial Training using the PGD Attack.
12 | """ 13 | 14 | #Essential Imports 15 | import os 16 | import sys 17 | import argparse 18 | import datetime 19 | import time 20 | import os.path as osp 21 | import numpy as np 22 | import torch 23 | import torch.nn as nn 24 | import torch.backends.cudnn as cudnn 25 | import torchvision 26 | import torchvision.transforms as transforms 27 | from utils import AverageMeter, Logger 28 | from proximity import Proximity 29 | from contrastive_proximity import Con_Proximity 30 | from resnet_model import * # Imports the ResNet Model 31 | 32 | 33 | parser = argparse.ArgumentParser("Prototype Conformity Loss Implementation") 34 | parser.add_argument('-j', '--workers', default=4, type=int, 35 | help="number of data loading workers (default: 4)") 36 | parser.add_argument('--train-batch', default=64, type=int, metavar='N', 37 | help='train batchsize') 38 | parser.add_argument('--test-batch', default=100, type=int, metavar='N', 39 | help='test batchsize') 40 | parser.add_argument('--schedule', type=int, nargs='+', default=[142, 230, 360], 41 | help='Decrease learning rate at these epochs.') 42 | parser.add_argument('--lr_model', type=float, default=0.01, help="learning rate for model") 43 | parser.add_argument('--lr_prox', type=float, default=0.5, help="learning rate for Proximity Loss") # as per paper 44 | parser.add_argument('--weight-prox', type=float, default=1, help="weight for Proximity Loss") # as per paper 45 | parser.add_argument('--lr_conprox', type=float, default=0.00001, help="learning rate for Con-Proximity Loss") # as per paper 46 | parser.add_argument('--weight-conprox', type=float, default=0.00001, help="weight for Con-Proximity Loss") # as per paper 47 | parser.add_argument('--max-epoch', type=int, default=500) 48 | parser.add_argument('--gamma', type=float, default=0.1, help="learning rate decay") 49 | parser.add_argument('--eval-freq', type=int, default=10) 50 | parser.add_argument('--print-freq', type=int, default=50) 51 | parser.add_argument('--gpu', 
type=str, default='0') 52 | parser.add_argument('--seed', type=int, default=1) 53 | parser.add_argument('--use-cpu', action='store_true') 54 | parser.add_argument('--save-dir', type=str, default='log') 55 | 56 | args = parser.parse_args() 57 | state = {k: v for k, v in args._get_kwargs()} 58 | 59 | mean = [0.4914, 0.4822, 0.4465] 60 | std = [0.2023, 0.1994, 0.2010] 61 | def normalize(t): 62 | t[:, 0, :, :] = (t[:, 0, :, :] - mean[0])/std[0] 63 | t[:, 1, :, :] = (t[:, 1, :, :] - mean[1])/std[1] 64 | t[:, 2, :, :] = (t[:, 2, :, :] - mean[2])/std[2] 65 | 66 | return t 67 | 68 | def un_normalize(t): 69 | t[:, 0, :, :] = (t[:, 0, :, :] * std[0]) + mean[0] 70 | t[:, 1, :, :] = (t[:, 1, :, :] * std[1]) + mean[1] 71 | t[:, 2, :, :] = (t[:, 2, :, :] * std[2]) + mean[2] 72 | 73 | return t 74 | 75 | def attack(model, criterion, img, label, eps, attack_type, iters): 76 | img = un_normalize(img.clone().detach()) # loader batches are normalized; attack in [0, 1] pixel space and return pixel-space images 77 | adv = img.clone() 78 | adv.requires_grad = True 79 | if attack_type == 'fgsm': 80 | iterations = 1 81 | else: 82 | iterations = iters 83 | 84 | if attack_type == 'pgd': 85 | step = 2 / 255 86 | else: 87 | step = eps / iterations 88 | 89 | noise = 0 90 | 91 | for j in range(iterations): 92 | _,_,_,out_adv = model(normalize(adv.clone())) # the model expects normalized input 93 | loss = criterion(out_adv, label) 94 | loss.backward() 95 | 96 | if attack_type == 'mim': 97 | adv_mean= torch.mean(torch.abs(adv.grad), dim=1, keepdim=True) 98 | adv_mean= torch.mean(torch.abs(adv_mean), dim=2, keepdim=True) 99 | adv_mean= torch.mean(torch.abs(adv_mean), dim=3, keepdim=True) 100 | adv.grad = adv.grad / adv_mean 101 | noise = noise + adv.grad 102 | else: 103 | noise = adv.grad 104 | 105 | # Optimization step 106 | adv.data = adv.data + step * noise.sign() 107 | # adv.data = adv.data + step * adv.grad.sign() 108 | 109 | if attack_type == 'pgd': 110 | adv.data = torch.where(adv.data > img.data + eps, img.data + eps, adv.data) 111 | adv.data = torch.where(adv.data < img.data - eps, img.data - eps, adv.data) 112 | adv.data.clamp_(0.0, 1.0) 113 | 114 | 
adv.grad.data.zero_() 115 | 116 | return adv.detach() 117 | 118 | def main(): 119 | torch.manual_seed(args.seed) 120 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 121 | use_gpu = torch.cuda.is_available() 122 | if args.use_cpu: use_gpu = False 123 | 124 | sys.stdout = Logger(osp.join(args.save_dir, 'log_' + 'CIFAR-10_PC_Loss_PGD_AdvTrain' + '.txt')) 125 | 126 | if use_gpu: 127 | print("Currently using GPU: {}".format(args.gpu)) 128 | cudnn.benchmark = True 129 | torch.cuda.manual_seed_all(args.seed) 130 | else: 131 | print("Currently using CPU") 132 | 133 | # Data Load 134 | num_classes=10 135 | print('==> Preparing dataset') 136 | transform_train = transforms.Compose([ 137 | transforms.RandomCrop(32, padding=4), 138 | transforms.RandomHorizontalFlip(), 139 | transforms.ToTensor(), 140 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),]) 141 | 142 | transform_test = transforms.Compose([ 143 | transforms.ToTensor(), 144 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),]) 145 | 146 | trainset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=True, 147 | download=True, transform=transform_train) 148 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch, pin_memory=True, 149 | shuffle=True, num_workers=args.workers) 150 | 151 | testset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=False, 152 | download=True, transform=transform_test) 153 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch, pin_memory=True, 154 | shuffle=False, num_workers=args.workers) 155 | 156 | # Loading the Model 157 | model = resnet(num_classes=num_classes,depth=110) 158 | 159 | if True: 160 | model = nn.DataParallel(model).cuda() 161 | 162 | criterion_xent = nn.CrossEntropyLoss() 163 | criterion_prox_1024 = Proximity(num_classes=num_classes, feat_dim=1024, use_gpu=use_gpu) 164 | criterion_prox_256 = Proximity(num_classes=num_classes, feat_dim=256, use_gpu=use_gpu) 
165 | 166 | criterion_conprox_1024 = Con_Proximity(num_classes=num_classes, feat_dim=1024, use_gpu=use_gpu) 167 | criterion_conprox_256 = Con_Proximity(num_classes=num_classes, feat_dim=256, use_gpu=use_gpu) 168 | 169 | optimizer_model = torch.optim.SGD(model.parameters(), lr=args.lr_model, weight_decay=1e-04, momentum=0.9) 170 | 171 | optimizer_prox_1024 = torch.optim.SGD(criterion_prox_1024.parameters(), lr=args.lr_prox) 172 | optimizer_prox_256 = torch.optim.SGD(criterion_prox_256.parameters(), lr=args.lr_prox) 173 | 174 | optimizer_conprox_1024 = torch.optim.SGD(criterion_conprox_1024.parameters(), lr=args.lr_conprox) 175 | optimizer_conprox_256 = torch.optim.SGD(criterion_conprox_256.parameters(), lr=args.lr_conprox) 176 | 177 | 178 | filename= 'Models_Softmax/CIFAR10_Softmax.pth.tar' 179 | checkpoint = torch.load(filename) 180 | 181 | model.load_state_dict(checkpoint['state_dict']) 182 | # Only the model weights are restored; optimizer_model starts fresh at lr=args.lr_model 183 | 184 | start_time = time.time() 185 | 186 | for epoch in range(args.max_epoch): 187 | 188 | adjust_learning_rate(optimizer_model, epoch) 189 | adjust_learning_rate_prox(optimizer_prox_1024, epoch) 190 | adjust_learning_rate_prox(optimizer_prox_256, epoch) 191 | 192 | adjust_learning_rate_conprox(optimizer_conprox_1024, epoch) 193 | adjust_learning_rate_conprox(optimizer_conprox_256, epoch) 194 | 195 | print("==> Epoch {}/{}".format(epoch+1, args.max_epoch)) 196 | train(model, criterion_xent, criterion_prox_1024, criterion_prox_256, 197 | criterion_conprox_1024, criterion_conprox_256, 198 | optimizer_model, optimizer_prox_1024, optimizer_prox_256, 199 | optimizer_conprox_1024, optimizer_conprox_256, 200 | trainloader, use_gpu, num_classes, epoch) 201 | 202 | if args.eval_freq > 0 and (epoch+1) % args.eval_freq == 0 or (epoch+1) == args.max_epoch: 203 | print("==> Test") # Tests every args.eval_freq epochs (default 10) 204 | acc, err = test(model, testloader, use_gpu, num_classes, epoch) 205 | print("Accuracy (%): {}\t Error rate
(%): {}".format(acc, err)) 206 | 207 | state_ = {'epoch': epoch + 1, 'state_dict': model.state_dict(), 208 | 'optimizer_model': optimizer_model.state_dict(), 'optimizer_prox_1024': optimizer_prox_1024.state_dict(), 209 | 'optimizer_prox_256': optimizer_prox_256.state_dict(), 'optimizer_conprox_1024': optimizer_conprox_1024.state_dict(), 210 | 'optimizer_conprox_256': optimizer_conprox_256.state_dict(),} 211 | os.makedirs('Models_PCL_AdvTrain_PGD', exist_ok=True) # this folder is not part of the repository 212 | torch.save(state_, 'Models_PCL_AdvTrain_PGD/CIFAR10_PCL_AdvTrain_PGD.pth.tar') 213 | 214 | elapsed = round(time.time() - start_time) 215 | elapsed = str(datetime.timedelta(seconds=elapsed)) 216 | print("Finished. Total elapsed time (h:m:s): {}".format(elapsed)) 217 | 218 | def train(model, criterion_xent, criterion_prox_1024, criterion_prox_256, 219 | criterion_conprox_1024, criterion_conprox_256, 220 | optimizer_model, optimizer_prox_1024, optimizer_prox_256, 221 | optimizer_conprox_1024, optimizer_conprox_256, 222 | trainloader, use_gpu, num_classes, epoch): 223 | 224 | # model.train() 225 | xent_losses = AverageMeter() #Computes and stores the average and current value 226 | prox_losses_1024 = AverageMeter() 227 | prox_losses_256= AverageMeter() 228 | 229 | conprox_losses_1024 = AverageMeter() 230 | conprox_losses_256= AverageMeter() 231 | losses = AverageMeter() 232 | 233 | #Batchwise training 234 | for batch_idx, (data, labels) in enumerate(trainloader): 235 | if use_gpu: 236 | data, labels = data.cuda(), labels.cuda() 237 | model.eval() 238 | eps= np.random.uniform(0.02,0.05) 239 | adv = attack(model, criterion_xent, data, labels, eps=eps, attack_type='pgd', iters= 10) # Generates Batch-wise Adv Images 240 | adv.requires_grad= False 241 | 242 | adv= normalize(adv) 243 | adv= adv.cuda() 244 | true_labels_adv= labels 245 | data= torch.cat((data, adv),0) 246 | labels= torch.cat((labels, true_labels_adv)) 247 | model.train() 248 | 249 | feats128, feats256, feats1024, outputs = model(data) 250 | loss_xent = criterion_xent(outputs, labels) 251 | 252 | 
loss_prox_1024 = criterion_prox_1024(feats1024, labels) 253 | loss_prox_256= criterion_prox_256(feats256, labels) 254 | 255 | loss_conprox_1024 = criterion_conprox_1024(feats1024, labels) 256 | loss_conprox_256= criterion_conprox_256(feats256, labels) 257 | 258 | loss_prox_1024 *= args.weight_prox 259 | loss_prox_256 *= args.weight_prox 260 | 261 | loss_conprox_1024 *= args.weight_conprox 262 | loss_conprox_256 *= args.weight_conprox 263 | 264 | loss = loss_xent + loss_prox_1024 + loss_prox_256 - loss_conprox_1024 - loss_conprox_256 # total loss 265 | optimizer_model.zero_grad() 266 | 267 | optimizer_prox_1024.zero_grad() 268 | optimizer_prox_256.zero_grad() 269 | 270 | optimizer_conprox_1024.zero_grad() 271 | optimizer_conprox_256.zero_grad() 272 | 273 | loss.backward() 274 | optimizer_model.step() 275 | 276 | for param in criterion_prox_1024.parameters(): 277 | param.grad.data *= (1. / args.weight_prox) 278 | optimizer_prox_1024.step() 279 | 280 | for param in criterion_prox_256.parameters(): 281 | param.grad.data *= (1. / args.weight_prox) 282 | optimizer_prox_256.step() 283 | 284 | 285 | for param in criterion_conprox_1024.parameters(): 286 | param.grad.data *= (1. / args.weight_conprox) 287 | optimizer_conprox_1024.step() 288 | 289 | for param in criterion_conprox_256.parameters(): 290 | param.grad.data *= (1. 
/ args.weight_conprox) 291 | optimizer_conprox_256.step() 292 | 293 | losses.update(loss.item(), labels.size(0)) 294 | xent_losses.update(loss_xent.item(), labels.size(0)) 295 | prox_losses_1024.update(loss_prox_1024.item(), labels.size(0)) 296 | prox_losses_256.update(loss_prox_256.item(), labels.size(0)) 297 | 298 | conprox_losses_1024.update(loss_conprox_1024.item(), labels.size(0)) 299 | conprox_losses_256.update(loss_conprox_256.item(), labels.size(0)) 300 | 301 | if (batch_idx+1) % args.print_freq == 0: 302 | print("Batch {}/{}\t Loss {:.6f} ({:.6f}) XentLoss {:.6f} ({:.6f}) ProxLoss_1024 {:.6f} ({:.6f}) ProxLoss_256 {:.6f} ({:.6f}) \n ConProxLoss_1024 {:.6f} ({:.6f}) ConProxLoss_256 {:.6f} ({:.6f}) " \ 303 | .format(batch_idx+1, len(trainloader), losses.val, losses.avg, xent_losses.val, xent_losses.avg, 304 | prox_losses_1024.val, prox_losses_1024.avg, prox_losses_256.val, prox_losses_256.avg , 305 | conprox_losses_1024.val, conprox_losses_1024.avg, conprox_losses_256.val, 306 | conprox_losses_256.avg )) 307 | 308 | 309 | def test(model, testloader, use_gpu, num_classes, epoch): 310 | model.eval() 311 | correct, total = 0, 0 312 | 313 | with torch.no_grad(): 314 | for data, labels in testloader: 315 | if True: 316 | data, labels = data.cuda(), labels.cuda() 317 | feats128, feats256, feats1024, outputs = model(data) 318 | predictions = outputs.data.max(1)[1] 319 | total += labels.size(0) 320 | correct += (predictions == labels.data).sum() 321 | 322 | 323 | acc = correct * 100. / total 324 | err = 100. 
- acc 325 | return acc, err 326 | 327 | def adjust_learning_rate(optimizer, epoch): 328 | global state 329 | if epoch in args.schedule: 330 | state['lr_model'] *= args.gamma 331 | for param_group in optimizer.param_groups: 332 | param_group['lr'] = state['lr_model'] # torch.optim reads the 'lr' key 333 | 334 | def adjust_learning_rate_prox(optimizer, epoch): 335 | # Derive the LR from the schedule: main() calls this once per prox optimizer, so scaling shared state would decay twice 336 | lr = args.lr_prox * args.gamma ** sum(epoch >= s for s in args.schedule) 337 | for param_group in optimizer.param_groups: 338 | param_group['lr'] = lr 339 | 340 | 341 | def adjust_learning_rate_conprox(optimizer, epoch): 342 | # Same schedule-derived LR for both conprox optimizers 343 | lr = args.lr_conprox * args.gamma ** sum(epoch >= s for s in args.schedule) 344 | for param_group in optimizer.param_groups: 345 | param_group['lr'] = lr 346 | 347 | if __name__ == '__main__': 348 | main() 349 | 350 | 351 | 352 | 353 | 354 | -------------------------------------------------------------------------------- /proximity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Proximity(nn.Module): 5 | 6 | def __init__(self, num_classes=100, feat_dim=1024, use_gpu=True): 7 | super(Proximity, self).__init__() 8 | self.num_classes = num_classes 9 | self.feat_dim = feat_dim 10 | self.use_gpu = use_gpu 11 | 12 | if self.use_gpu: 13 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 14 | else: 15 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 16 | 17 | def forward(self, x, labels): 18 | 19 | batch_size = x.size(0) 20 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 21 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 22 | distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2) # ||x||^2 + ||c||^2 - 2 x.c^T 23 | 24 | classes = torch.arange(self.num_classes).long() 25 | if self.use_gpu: classes = classes.cuda() 26 | labels = labels.unsqueeze(1).expand(batch_size,
self.num_classes) 27 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 28 | 29 | dist = [] 30 | for i in range(batch_size): 31 | value = distmat[i][mask[i]] 32 | value = value.clamp(min=1e-12, max=1e+12) 33 | dist.append(value) 34 | dist = torch.cat(dist) 35 | loss = dist.mean() 36 | 37 | return loss 38 | -------------------------------------------------------------------------------- /resnet_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Tue Apr 2 14:21:30 2019 3 | 4 | @author: aamir-mustafa 5 | """ 6 | 7 | import torch.nn as nn 8 | import math 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | "3x3 convolution with padding" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = conv3x3(inplanes, planes, stride) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.conv2 = conv3x3(planes, planes) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | self.downsample = downsample 27 | self.stride = stride 28 | 29 | def forward(self, x): 30 | residual = x 31 | 32 | out = self.conv1(x) 33 | out = self.bn1(out) 34 | out = self.relu(out) 35 | out = self.conv2(out) 36 | out = self.bn2(out) 37 | 38 | if self.downsample is not None: 39 | residual = self.downsample(x) 40 | 41 | out += residual 42 | out = self.relu(out) 43 | 44 | return out 45 | 46 | 47 | class Bottleneck(nn.Module): 48 | expansion = 4 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None): 51 | super(Bottleneck, self).__init__() 52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 55 | padding=1, 
bias=False) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 58 | self.bn3 = nn.BatchNorm2d(planes * 4) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): # (conv-bn-relu) x 3 times 64 | residual = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | out = self.relu(out) 72 | out = self.conv3(out) 73 | out = self.bn3(out) 74 | 75 | if self.downsample is not None: 76 | residual = self.downsample(x) 77 | 78 | out += residual # in our case is none 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | 84 | class ResNet(nn.Module): 85 | 86 | def __init__(self, depth, num_classes=10): 87 | super(ResNet, self).__init__() 88 | # Model type specifies number of layers for CIFAR-10 model 89 | assert (depth - 2) % 6 == 0, 'depth should be 6n+2' 90 | n = (depth - 2) // 6 91 | 92 | block = Bottleneck if depth >=44 else BasicBlock 93 | 94 | self.inplanes = 16 95 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 96 | bias=False) 97 | self.bn1 = nn.BatchNorm2d(16) 98 | self.relu = nn.ReLU(inplace=True) 99 | self.layer1 = self._make_layer(block, 16, n) 100 | 101 | self.layer2 = self._make_layer(block, 32, n, stride=2) 102 | self.layer3 = self._make_layer(block, 64, n, stride=2) 103 | self.avgpool = nn.AvgPool2d(8) 104 | 105 | self.maxpool2= nn.MaxPool2d(16) 106 | self.fc = nn.Linear(64 * block.expansion, 1024) 107 | self.fcf = nn.Linear(1024,num_classes) 108 | 109 | for m in self.modules(): 110 | if isinstance(m, nn.Conv2d): 111 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 112 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 113 | elif isinstance(m, nn.BatchNorm2d): 114 | m.weight.data.fill_(1) 115 | m.bias.data.zero_() 116 | 117 | def _make_layer(self, block, planes, blocks, stride=1): 118 | downsample = None 119 | if stride != 1 or self.inplanes != planes * block.expansion: 120 | downsample = nn.Sequential( 121 | nn.Conv2d(self.inplanes, planes * block.expansion, 122 | kernel_size=1, stride=stride, bias=False), 123 | nn.BatchNorm2d(planes * block.expansion), 124 | ) 125 | 126 | layers = [] 127 | layers.append(block(self.inplanes, planes, stride, downsample)) 128 | self.inplanes = planes * block.expansion 129 | for i in range(1, blocks): 130 | layers.append(block(self.inplanes, planes)) 131 | 132 | return nn.Sequential(*layers) 133 | 134 | def forward(self, x): 135 | x = self.conv1(x) 136 | 137 | 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | 141 | x = self.layer1(x) 142 | x = self.layer2(x) 143 | 144 | m = self.maxpool2(x) 145 | m = m.view(m.size(0), -1) # 128 dimensional 146 | 147 | x = self.layer3(x) 148 | 149 | 150 | x = self.avgpool(x) 151 | z = x.view(x.size(0), -1) # 256 dimensional 152 | x = self.fc(z) # 1024 dimensional 153 | y = self.fcf(x) # num_classes dimensional 154 | 155 | return m, z, x, y 156 | 157 | 158 | def resnet(**kwargs): 159 | """ 160 | Constructs a ResNet model. 
161 | """ 162 | return ResNet(**kwargs) -------------------------------------------------------------------------------- /robust_ml.py: -------------------------------------------------------------------------------- 1 | """ 2 | author: aamir-mustafa 3 | This is the RobustML interface implementation for the paper: 4 | Adversarial Defense by Restricting the Hidden Space of Deep Neural Networks 5 | """ 6 | 7 | import robustml 8 | import torch 9 | import torch.nn as nn 10 | import numpy as np 11 | from resnet_model import * # Imports the ResNet Model 12 | 13 | num_classes=10 14 | model = resnet(num_classes=num_classes,depth=110) 15 | model = nn.DataParallel(model).cuda() 16 | filename= 'robust_model.pth.tar' 17 | checkpoint = torch.load(filename) 18 | model.load_state_dict(checkpoint['state_dict']) 19 | 20 | #Normalize the data as per CIFAR-10 mean and std 21 | mean = [0.4914, 0.4822, 0.4465] 22 | std = [0.2023, 0.1994, 0.2010] 23 | def normalize(t): 24 | t[:, 0, :, :] = (t[:, 0, :, :] - mean[0])/std[0] 25 | t[:, 1, :, :] = (t[:, 1, :, :] - mean[1])/std[1] 26 | t[:, 2, :, :] = (t[:, 2, :, :] - mean[2])/std[2] 27 | return t 28 | 29 | 30 | class Model(robustml.model.Model): 31 | def __init__(self): 32 | 33 | self._dataset = robustml.dataset.CIFAR10() 34 | self._threat_model = robustml.threat_model.Linf(epsilon=8/255) 35 | 36 | @property 37 | def dataset(self): 38 | return self._dataset 39 | 40 | @property 41 | def threat_model(self): 42 | return self._threat_model 43 | 44 | def classify(self, x): 45 | X = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).cuda() # add a batch dimension 46 | model.eval() 47 | out=model(normalize(X.clone().detach()))[-1].argmax(dim=-1) # [-1] selects the logits; the model also returns intermediate features 48 | return out.item() # RobustML expects the label as a Python int 49 | 50 | if __name__ == '__main__': 51 | robust_model = Model() 52 | # A simple synthetic test input: a rectangle in the green channel 53 | x = np.zeros((3, 32, 32), dtype=np.float32) 54 | x[1:2 ,5:-5, 12:-12] = 1 55 | 56 | print('Predicted Class', robust_model.classify(x)) 57 | 58 | 59 | 60 | 61 | 62 | 
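Both `robust_ml.py` and `robustness.py` keep images in the [0, 1] pixel range and apply the CIFAR-10 normalization only inside the call to the model, so the L-infinity budget (`eps = 8/255`) is measured in pixel units. A minimal sketch of that pattern for an untargeted PGD attack follows. It assumes a model that returns logits only; the repository's `ResNet` returns `(feats_128, feats_256, feats_1024, logits)`, so it would be passed as `lambda t: model(t)[-1]`. The name `pgd_linf` and its defaults (step `2/255`, 10 iterations, no random start, mirroring the settings used in `robustness.py`) are illustrative and not part of the repository.

```python
import torch
import torch.nn as nn

# CIFAR-10 channel statistics used throughout this repository
CIFAR10_MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(1, 3, 1, 1)
CIFAR10_STD = torch.tensor([0.2023, 0.1994, 0.2010]).view(1, 3, 1, 1)


def pgd_linf(model, img, label, eps=8 / 255, step=2 / 255, iters=10):
    """Untargeted L-infinity PGD on images in [0, 1]; `model` must return logits (hypothetical helper)."""
    mean = CIFAR10_MEAN.to(img.device)
    std = CIFAR10_STD.to(img.device)
    criterion = nn.CrossEntropyLoss()
    img = img.detach()
    adv = img.clone()
    for _ in range(iters):
        adv.requires_grad_(True)
        # Normalize out of place so the gradient flows back to the pixel-space `adv`
        loss = criterion(model((adv - mean) / std), label)
        # autograd.grad returns only the input gradient; model .grad fields are not touched
        (grad,) = torch.autograd.grad(loss, adv)
        with torch.no_grad():
            adv = adv + step * grad.sign()                         # gradient-sign ascent step
            adv = torch.max(torch.min(adv, img + eps), img - eps)  # project onto the eps-ball
            adv = adv.clamp(0.0, 1.0)                              # keep a valid image
    return adv.detach()
```

Projecting onto the `eps`-ball before clamping to [0, 1] keeps every pixel both within `eps` of the clean image and a valid intensity, because the clean image itself already lies in [0, 1].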
-------------------------------------------------------------------------------- /robust_model.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aamir-mustafa/pcl-adversarial-defense/37242580937c267efcac6679ec050aa438b47796/robust_model.pth.tar -------------------------------------------------------------------------------- /robustness.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Sun Mar 24 17:51:08 2019 3 | 4 | @author: aamir-mustafa 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | from resnet_model import * # Imports the ResNet Model 11 | """ 12 | Adversarial Attack Options: fgsm, bim, mim, pgd 13 | """ 14 | num_classes=10 15 | 16 | model = resnet(num_classes=num_classes,depth=110) 17 | if True: 18 | model = nn.DataParallel(model).cuda() 19 | 20 | #Loading Trained Model 21 | softmax_filename= 'Models_Softmax/CIFAR10_Softmax.pth.tar' 22 | filename= 'Models_PCL/CIFAR10_PCL.pth.tar' 23 | robust_model= 'robust_model.pth.tar' 24 | checkpoint = torch.load(robust_model) 25 | model.load_state_dict(checkpoint['state_dict']) 26 | model.eval() 27 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 28 | 29 | # Loading Test Data (Un-normalized) 30 | transform_test = transforms.Compose([transforms.ToTensor(),]) 31 | 32 | testset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=False, 33 | download=True, transform=transform_test) 34 | test_loader = torch.utils.data.DataLoader(testset, batch_size=256, pin_memory=True, 35 | shuffle=False, num_workers=4) 36 | 37 | # Mean and Standard Deviation of the Dataset 38 | mean = [0.4914, 0.4822, 0.4465] 39 | std = [0.2023, 0.1994, 0.2010] 40 | def normalize(t): 41 | t[:, 0, :, :] = (t[:, 0, :, :] - mean[0])/std[0] 42 | t[:, 1, :, :] = (t[:, 1, :, :] - mean[1])/std[1] 43 | t[:, 2, :, :] = (t[:, 2,
:, :] - mean[2])/std[2] 44 | 45 | return t 46 | def un_normalize(t): 47 | t[:, 0, :, :] = (t[:, 0, :, :] * std[0]) + mean[0] 48 | t[:, 1, :, :] = (t[:, 1, :, :] * std[1]) + mean[1] 49 | t[:, 2, :, :] = (t[:, 2, :, :] * std[2]) + mean[2] 50 | 51 | return t 52 | 53 | # Attacking Images batch-wise 54 | def attack(model, criterion, img, label, eps, attack_type, iters): 55 | adv = img.detach() 56 | adv.requires_grad = True 57 | 58 | if attack_type == 'fgsm': 59 | iterations = 1 60 | else: 61 | iterations = iters 62 | 63 | if attack_type == 'pgd': 64 | step = 2 / 255 65 | else: 66 | step = eps / iterations 67 | 68 | noise = 0 69 | 70 | for j in range(iterations): 71 | _,_,_,out_adv = model(normalize(adv.clone())) 72 | loss = criterion(out_adv, label) 73 | loss.backward() 74 | 75 | if attack_type == 'mim': 76 | adv_mean= torch.mean(torch.abs(adv.grad), dim=1, keepdim=True) 77 | adv_mean= torch.mean(torch.abs(adv_mean), dim=2, keepdim=True) 78 | adv_mean= torch.mean(torch.abs(adv_mean), dim=3, keepdim=True) 79 | adv.grad = adv.grad / adv_mean 80 | noise = noise + adv.grad 81 | else: 82 | noise = adv.grad 83 | 84 | # Optimization step 85 | adv.data = adv.data + step * noise.sign() 86 | # adv.data = adv.data + step * adv.grad.sign() 87 | 88 | if attack_type == 'pgd': 89 | adv.data = torch.where(adv.data > img.data + eps, img.data + eps, adv.data) 90 | adv.data = torch.where(adv.data < img.data - eps, img.data - eps, adv.data) 91 | adv.data.clamp_(0.0, 1.0) 92 | 93 | adv.grad.data.zero_() 94 | 95 | return adv.detach() 96 | 97 | # Loss Criteria 98 | criterion = nn.CrossEntropyLoss() 99 | adv_acc = 0 100 | clean_acc = 0 101 | eps =8/255 # Epsilon for Adversarial Attack 102 | 103 | for i, (img, label) in enumerate(test_loader): 104 | img, label = img.to(device), label.to(device) 105 | 106 | clean_acc += torch.sum(model(normalize(img.clone().detach()))[3].argmax(dim=-1) == label).item() 107 | adv= attack(model, criterion, img, label, eps=eps, attack_type= 'bim', iters= 10 ) 108 | 
adv_acc += torch.sum(model(normalize(adv.clone().detach()))[3].argmax(dim=-1) == label).item() 109 | print('Batch: {0}'.format(i)) 110 | print('Clean accuracy:{0:.3%}\t Adversarial accuracy:{1:.3%}'.format(clean_acc / len(testset), adv_acc / len(testset))) 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /softmax_training.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Wed Jan 23 10:15:27 2019 3 | 4 | @author: aamir-mustafa 5 | This is Part 1 file for replicating the results for Paper: 6 | "Adversarial Defense by Restricting the Hidden Space of Deep Neural Networks" 7 | Here a ResNet model is trained with Softmax Loss for 164 epochs. 8 | """ 9 | 10 | # Essential Imports 11 | import os 12 | import sys 13 | import argparse 14 | import datetime 15 | import time 16 | import os.path as osp 17 | import numpy as np 18 | import torch 19 | import torch.nn as nn 20 | import torch.backends.cudnn as cudnn 21 | import torchvision 22 | import torchvision.transforms as transforms 23 | from utils import AverageMeter, Logger 24 | from resnet_model import * # Imports the ResNet Model 25 | 26 | parser = argparse.ArgumentParser("Softmax Training for CIFAR-10 Dataset") 27 | parser.add_argument('-j', '--workers', default=4, type=int, 28 | help="number of data loading workers (default: 4)") 29 | parser.add_argument('--train-batch', default=128, type=int, metavar='N', 30 | help='train batchsize') 31 | parser.add_argument('--test-batch', default=100, type=int, metavar='N', 32 | help='test batchsize') 33 | parser.add_argument('--lr', type=float, default=0.1, help="learning rate for model") 34 | parser.add_argument('--schedule', type=int, nargs='+', default=[81, 122, 140], 35 | help='Decrease learning rate at these epochs.') 36 | parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.') 37 | 
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 38 | help='momentum') 39 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 40 | metavar='W', help='weight decay (default: 1e-4)') 41 | parser.add_argument('--max-epoch', type=int, default=164) 42 | parser.add_argument('--eval-freq', type=int, default=10) 43 | parser.add_argument('--print-freq', type=int, default=50) 44 | parser.add_argument('--gpu', type=str, default='0') #gpu to be used 45 | parser.add_argument('--seed', type=int, default=1) 46 | parser.add_argument('--use-cpu', action='store_true') 47 | parser.add_argument('--save-dir', type=str, default='log') 48 | 49 | args = parser.parse_args() 50 | state = {k: v for k, v in args._get_kwargs()} 51 | 52 | #%% 53 | 54 | def main(): 55 | torch.manual_seed(args.seed) 56 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 57 | use_gpu = torch.cuda.is_available() 58 | if args.use_cpu: use_gpu = False 59 | 60 | sys.stdout = Logger(osp.join(args.save_dir, 'log_' + 'CIFAR-10_OnlySoftmax' + '.txt')) 61 | 62 | if use_gpu: 63 | print("Currently using GPU: {}".format(args.gpu)) 64 | cudnn.benchmark = True 65 | torch.cuda.manual_seed_all(args.seed) 66 | else: 67 | print("Currently using CPU") 68 | 69 | # Data Loading 70 | num_classes=10 71 | print('==> Preparing dataset ') 72 | transform_train = transforms.Compose([ 73 | transforms.RandomCrop(32, padding=4), 74 | transforms.RandomHorizontalFlip(), 75 | transforms.ToTensor(), 76 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),]) 77 | 78 | transform_test = transforms.Compose([ 79 | transforms.ToTensor(), 80 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),]) 81 | 82 | 83 | trainset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=True, 84 | download=True, transform=transform_train) 85 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch, pin_memory=True, 86 | shuffle=True, 
                                               num_workers=args.workers)

    testset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=False,
                                           download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch, pin_memory=True,
                                             shuffle=False, num_workers=args.workers)

    # Loading the Model

    model = resnet(num_classes=num_classes, depth=110)

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)

    start_time = time.time()

    for epoch in range(args.max_epoch):
        adjust_learning_rate(optimizer, epoch)

        print("==> Epoch {}/{}".format(epoch+1, args.max_epoch))
        print('LR: %f' % (state['lr']))

        train(trainloader, model, criterion, optimizer, epoch, use_gpu, num_classes)

        # Evaluate every args.eval_freq epochs (when eval_freq > 0) and always after the last epoch
        if (args.eval_freq > 0 and (epoch+1) % args.eval_freq == 0) or (epoch+1) == args.max_epoch:
            print("==> Test")
            acc, err = test(model, testloader, use_gpu, num_classes, epoch)
            print("Accuracy (%): {}\t Error rate (%): {}".format(acc, err))

            # Overwrite the single softmax checkpoint after each evaluation
            checkpoint = {'epoch': epoch + 1, 'state_dict': model.state_dict(),
                          'optimizer_model': optimizer.state_dict(), }
            torch.save(checkpoint, 'Models_Softmax/CIFAR10_Softmax.pth.tar')

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))


def train(trainloader, model, criterion, optimizer, epoch, use_gpu, num_classes):

    model.train()
    losses = AverageMeter()

    # Batch-wise training
    for batch_idx, (data, labels) in enumerate(trainloader):
        if use_gpu:
            data, labels = data.cuda(), labels.cuda()
        # The model also returns three intermediate feature tensors; only the
        # logits (outputs) are needed for plain cross-entropy training.
        feats_128, feats_256, feats_1024, outputs = model(data)
        loss_xent = criterion(outputs, labels)  # cross-entropy loss

        optimizer.zero_grad()
        loss_xent.backward()
        optimizer.step()

        # Running loss, weighted by the number of samples in the batch
        losses.update(loss_xent.item(), labels.size(0))

        if (batch_idx+1) % args.print_freq == 0:
            print("Batch {}/{}\t Loss {:.6f} ({:.6f}) "
                  .format(batch_idx+1, len(trainloader), losses.val, losses.avg))

def test(model, testloader, use_gpu, num_classes, epoch):
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for data, labels in testloader:
            if use_gpu:
                data, labels = data.cuda(), labels.cuda()
            feats_128, feats_256, feats_1024, outputs = model(data)
            predictions = outputs.argmax(dim=1)
            total += labels.size(0)
            # .item() keeps `correct` a Python int instead of a 0-dim tensor
            correct += (predictions == labels).sum().item()

    acc = correct * 100. / total
    err = 100. - acc
    return acc, err

def adjust_learning_rate(optimizer, epoch):
    # Step decay: multiply the LR by args.gamma at every epoch listed in args.schedule
    global state
    if epoch in args.schedule:
        state['lr'] *= args.gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = state['lr']

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import sys
import errno
import shutil
import os.path as osp

import torch

def mkdir_if_missing(directory):
    # An empty path (e.g. osp.dirname('checkpoint.pth.tar')) means the current
    # directory, which already exists; os.makedirs('') would raise otherwise.
    if directory and not osp.exists(directory):
        try:
            os.makedirs(directory)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

class AverageMeter(object):
    """Computes and stores the average and current value.

    Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
    mkdir_if_missing(osp.dirname(fpath))
    torch.save(state, fpath)
    if is_best:
        shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar'))

class Logger(object):
    """
    Write console output to external text file.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
    """
    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            mkdir_if_missing(os.path.dirname(fpath))
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        # Return the logger so that `with Logger(path) as log:` binds it
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        self.console.close()
        if self.file is not None:
            self.file.close()
--------------------------------------------------------------------------------
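`softmax_training.py` keeps two pieces of bookkeeping: the step decay in `adjust_learning_rate()` and the batch-size-weighted running loss in `AverageMeter`. Both can be checked without PyTorch or a GPU. What follows is a minimal sketch under stated assumptions. `step_lr` and `RunningMean` are hypothetical helpers that are not part of this repository. The milestone epochs, `gamma` and base learning rate are also placeholders; the script's real values come from its argument parser, which is not shown here.

```python
# Hypothetical, dependency-free re-statement of two rules used in
# softmax_training.py / utils.py; not code from the repository itself.


def step_lr(base_lr, epoch, schedule, gamma):
    """Learning rate in effect during 0-indexed `epoch` when the rate is
    multiplied by `gamma` at the start of every epoch listed in `schedule`
    (the same rule adjust_learning_rate() applies statefully)."""
    lr = base_lr
    for milestone in schedule:
        if epoch >= milestone:
            lr *= gamma
    return lr


class RunningMean:
    """Same update rule as AverageMeter.update(val, n): each value is
    weighted by n, e.g. the number of samples in the batch."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        # Guard against division by zero before the first update
        return self.sum / self.count if self.count else 0.0


if __name__ == "__main__":
    # Placeholder schedule: decay by 10x at epochs 100 and 150
    for epoch in (0, 99, 100, 149, 150, 199):
        print(epoch, step_lr(0.1, epoch, schedule=[100, 150], gamma=0.1))

    meter = RunningMean()
    meter.update(2.0, n=128)  # a full batch
    meter.update(1.0, n=16)   # a smaller final batch counts for less
    print(meter.avg)
```

`adjust_learning_rate()` multiplies the rate in place at each milestone. So the rate in effect at epoch e equals the base rate times `gamma` raised to the number of milestones at or below e. `step_lr` computes that closed form directly, which is easier to check than the stateful version.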