├── AKD2.py ├── ARD.py ├── IAD-I.py ├── IAD-II.py ├── LICENSE ├── README.md ├── attack_generator.py ├── basic_eval.py ├── models ├── __init__.py ├── mobilenetv2.py ├── preresnet.py ├── resnet.py └── wideresnet.py ├── pic └── overview.png ├── pre_train ├── AT.py ├── ST.py ├── attack_generator.py ├── models │ ├── __init__.py │ ├── preact_resnet.py │ ├── resnet.py │ ├── wide_resnet.py │ ├── wideresnet.py │ └── wrn_madry.py └── utils │ ├── __init__.py │ ├── eval.py │ ├── logger.py │ └── misc.py └── utils ├── __init__.py ├── eval.py ├── logger.py └── misc.py /AKD2.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | import torch.backends.cudnn as cudnn 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | import os 11 | import argparse 12 | from tqdm import tqdm 13 | from utils import Logger 14 | from models import * 15 | 16 | parser = argparse.ArgumentParser(description='AKD2') 17 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 18 | parser.add_argument('--lr_schedule', type=int, nargs='+', default=[100, 150], help='Decrease learning rate at these epochs.') 19 | parser.add_argument('--lr_factor', default=0.1, type=float, help='factor by which to decrease lr') 20 | parser.add_argument('--epochs', default=200, type=int, help='number of epochs for training') 21 | parser.add_argument('--output', default = '', type=str, help='output subdirectory') 22 | parser.add_argument('--model', default = 'ResNet18', type = str, help = 'student model name') 23 | parser.add_argument('--teacher_model', default = 'ResNet18', type = str, help = 'teacher network model') 24 | parser.add_argument('--teacher_path', default = './pre_train/AT_teacher_cifar10/bestpoint.pth.tar', type=str, help='path of AT teacher net being distilled') 25 | parser.add_argument('--teacher_st_path', default = './pre_train/ST_teacher_cifar10/bestpoint.pth.tar', type=str, help='path of ST teacher net being distilled') 26 | parser.add_argument('--temp', default=1.0, type=float, help='temperature for distillation') 27 | parser.add_argument('--val_period', default=1, type=int, help='print every __ epoch') 28 | parser.add_argument('--save_period', default=1, type=int, help='save every __ epoch') 29 | parser.add_argument('--alpha', default=0.5, type=float, help='weight for sum of losses') 30 | parser.add_argument('--dataset', default = 'CIFAR10', type=str, help='name of dataset') 31 | parser.add_argument('--out-dir',type=str,default='./AKD2_CIFAR10',help='dir of output') 32 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed') 33 | parser.add_argument('--resume',type=str,default=None,help='whether to resume training') 34 | parser.add_argument('--beta',type=float,default=0.0) 35 | parser.add_argument('--alpha1', default=0.5, type=float, help='weight for sum of losses') 36 | parser.add_argument('--alpha2', default=0.25, type=float, help='weight for sum of losses') 37 | 38 | args = parser.parse_args() 39 | 40 | seed = args.seed 41 | out_dir = args.out_dir 42 | torch.manual_seed(seed) 43 | np.random.seed(seed) 44 | torch.cuda.manual_seed_all(seed) 45 | torch.backends.cudnn.benchmark = True 46 | torch.backends.cudnn.deterministic = True 47 | 48 | 49 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 50 | 51 | def adjust_learning_rate(optimizer, epoch, lr): 
52 | if epoch in args.lr_schedule: 53 | lr *= args.lr_factor 54 | for param_group in optimizer.param_groups: 55 | param_group['lr'] = lr 56 | 57 | # Store path 58 | if not os.path.exists(out_dir): 59 | os.makedirs(out_dir) 60 | 61 | # Save checkpoint 62 | def save_checkpoint(state, checkpoint=out_dir, filename='checkpoint.pth.tar'): 63 | filepath = os.path.join(checkpoint, filename) 64 | torch.save(state, filepath) 65 | 66 | # prepare the dataset 67 | print('==> Preparing data..') 68 | transform_train = transforms.Compose([ 69 | transforms.RandomCrop(32, padding=4), 70 | transforms.RandomHorizontalFlip(), 71 | transforms.ToTensor(), 72 | ]) 73 | transform_test = transforms.Compose([ 74 | transforms.ToTensor(), 75 | ]) 76 | if args.dataset == 'CIFAR10': 77 | trainset = torchvision.datasets.CIFAR10(root='~/data/cifar-10', train=True, download=True, transform=transform_train) 78 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 79 | testset = torchvision.datasets.CIFAR10(root='~/data/cifar-10', train=False, download=True, transform=transform_test) 80 | testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 81 | num_classes = 10 82 | elif args.dataset == 'CIFAR100': 83 | trainset = torchvision.datasets.CIFAR100(root='~/data/cifar-100', train=True, download=True, transform=transform_train) 84 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 85 | testset = torchvision.datasets.CIFAR100(root='~/data/cifar-100', train=False, download=True, transform=transform_test) 86 | testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 87 | num_classes = 100 88 | 89 | class AttackPGD(nn.Module): 90 | def __init__(self, basic_net, config): 91 | super(AttackPGD, self).__init__() 92 | self.basic_net = basic_net 93 | self.step_size = config['step_size'] 94 | self.epsilon = config['epsilon'] 95 | self.num_steps = config['num_steps'] 96 | 97 | def forward(self, inputs, targets): 98 | x = inputs.detach() 99 | x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon) 100 | for i in range(self.num_steps): 101 | x.requires_grad_() 102 | with torch.enable_grad(): 103 | loss = F.cross_entropy(self.basic_net(x), targets, size_average=False) 104 | grad = torch.autograd.grad(loss, [x])[0] 105 | x = x.detach() + self.step_size*torch.sign(grad.detach()) 106 | x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon) 107 | x = torch.clamp(x, 0.0, 1.0) 108 | return self.basic_net(x), x 109 | 110 | 111 | # build teacher and student models 112 | # dataparalella 113 | 114 | print('==> Building model..'+args.model) 115 | # student 116 | if args.model == 'MobileNetV2': 117 | basic_net = MobileNetV2(num_classes=num_classes) 118 | elif args.model == 'WideResNet': 119 | basic_net = WideResNet(num_classes=num_classes) 120 | elif args.model == 'ResNet18': 121 | basic_net = ResNet18(num_classes=num_classes) 122 | basic_net = basic_net.to(device) 123 | basic_net = torch.nn.DataParallel(basic_net) 124 | 125 | # teacher 126 | if args.teacher_path != '': 127 | if args.teacher_model == 'MobileNetV2': 128 | teacher_net = MobileNetV2(num_classes=num_classes) 129 | elif args.teacher_model == 'WideResNet': 130 | teacher_net = WideResNet(num_classes=num_classes) 131 | elif args.teacher_model == 'ResNet18': 132 | teacher_net = ResNet18(num_classes=num_classes) 133 | teacher_net = teacher_net.to(device) 134 | for param in teacher_net.parameters(): 
135 | param.requires_grad = False 136 | 137 | teacher_st_net = ResNet18(num_classes=num_classes) 138 | teacher_st_net = teacher_st_net.to(device) 139 | for param in teacher_st_net.parameters(): 140 | param.requires_grad = False 141 | 142 | config_train = { 143 | 'epsilon': 8 / 255, 144 | 'num_steps': 10, 145 | 'step_size': 2 / 255, 146 | } 147 | 148 | net = AttackPGD(basic_net, config_train) 149 | 150 | if device == 'cuda': 151 | cudnn.benchmark = True 152 | 153 | print('==> Loading at teacher..') 154 | teacher_net = torch.nn.DataParallel(teacher_net) 155 | teacher_net.load_state_dict(torch.load(args.teacher_path)['state_dict']) 156 | teacher_net.eval() 157 | 158 | print('==> Loading st teacher..') 159 | teacher_st_net = torch.nn.DataParallel(teacher_st_net) 160 | teacher_st_net.load_state_dict(torch.load(args.teacher_st_path)['state_dict']) 161 | teacher_st_net.eval() 162 | 163 | 164 | KL_loss = nn.KLDivLoss(reduce=False) 165 | XENT_loss = nn.CrossEntropyLoss() 166 | lr=args.lr 167 | 168 | def train(epoch, optimizer, net, basic_net, teacher_net): 169 | net.train() 170 | train_loss = 0 171 | iterator = tqdm(trainloader, ncols=0, leave=False) 172 | for batch_idx, (inputs, targets) in enumerate(iterator): 173 | inputs, targets = inputs.to(device), targets.to(device) 174 | optimizer.zero_grad() 175 | outputs, pert_inputs = net.forward(inputs, targets) 176 | teacher_outputs = teacher_net(pert_inputs) 177 | st_outputs = teacher_st_net(pert_inputs) 178 | 179 | loss = (1-args.alpha1-args.alpha2)*XENT_loss(outputs, targets)+args.alpha1*args.temp*args.temp*(1/len(outputs))*torch.sum(KL_loss(F.log_softmax(outputs/args.temp, dim=1),F.softmax(teacher_outputs/args.temp, dim=1)).sum(dim=1))+args.alpha2*args.temp*args.temp*(1/len(outputs))*torch.sum(KL_loss(F.log_softmax(outputs/args.temp, dim=1),F.softmax(st_outputs/args.temp, dim=1)).sum(dim=1)) 180 | 181 | loss.backward() 182 | optimizer.step() 183 | train_loss += loss.item() 184 | iterator.set_description(str(loss.item())) 185 | 186 | print('Mean Training Loss:', train_loss/len(iterator)) 187 | return train_loss 188 | 189 | 190 | def test(epoch, optimizer, net, basic_net, teacher_net): 191 | net.eval() 192 | adv_correct = 0 193 | natural_correct = 0 194 | total = 0 195 | with torch.no_grad(): 196 | iterator = tqdm(testloader, ncols=0, leave=False) 197 | for batch_idx, (inputs, targets) in enumerate(iterator): 198 | inputs, targets = inputs.to(device), targets.to(device) 199 | adv_outputs, pert_inputs = net(inputs, targets) 200 | natural_outputs = basic_net(inputs) 201 | _, adv_predicted = adv_outputs.max(1) 202 | _, natural_predicted = natural_outputs.max(1) 203 | natural_correct += natural_predicted.eq(targets).sum().item() 204 | total += targets.size(0) 205 | adv_correct += adv_predicted.eq(targets).sum().item() 206 | iterator.set_description(str(adv_predicted.eq(targets).sum().item()/targets.size(0))) 207 | robust_acc = 100.*adv_correct/total 208 | natural_acc = 100.*natural_correct/total 209 | print('Natural acc:', natural_acc) 210 | print('Robust acc:', robust_acc) 211 | return natural_acc, robust_acc 212 | 213 | def main(): 214 | lr = args.lr 215 | best_acc = 0 216 | test_robust = 0 217 | optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=2e-4) 218 | logger_test = Logger(os.path.join(out_dir, 'student_results.txt'), title='student') 219 | logger_test.set_names(['Epoch', 'Natural Test Acc', 'PGD10 Acc']) 220 | for epoch in range(args.epochs): 221 | adjust_learning_rate(optimizer, epoch, lr) 222 | 223 | print("teacher >>>> 
student ") 224 | train_loss = train(epoch, optimizer, net, basic_net, teacher_net) 225 | 226 | if (epoch+1)%args.val_period == 0: 227 | natural_val, robust_val = test(epoch, optimizer, net, basic_net, teacher_net) 228 | logger_test.append([epoch + 1, natural_val, robust_val]) 229 | save_checkpoint({ 230 | 'epoch': epoch + 1, 231 | 'test_nat_acc': natural_val, 232 | 'test_pgd10_acc': robust_val, 233 | 'state_dict': basic_net.state_dict(), 234 | 'optimizer' : optimizer.state_dict(), 235 | }) 236 | 237 | if robust_val > best_acc: 238 | best_acc = robust_val 239 | save_checkpoint({ 240 | 'epoch': epoch + 1, 241 | 'state_dict': basic_net.state_dict(), 242 | 'test_nat_acc': natural_val, 243 | 'test_pgd10_acc': robust_val, 244 | 'optimizer' : optimizer.state_dict(), 245 | },filename='bestpoint.pth.tar') 246 | 247 | 248 | if __name__ == '__main__': 249 | main() 250 | -------------------------------------------------------------------------------- /ARD.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | import torch.backends.cudnn as cudnn 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | import os 11 | import argparse 12 | from tqdm import tqdm 13 | from utils import Logger 14 | from models import * 15 | 16 | parser = argparse.ArgumentParser(description='ARD') 17 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 18 | parser.add_argument('--lr_schedule', type=int, nargs='+', default=[100, 150], help='Decrease learning rate at these epochs.') 19 | parser.add_argument('--lr_factor', default=0.1, type=float, help='factor by which to decrease lr') 20 | parser.add_argument('--epochs', default=200, type=int, help='number of epochs for training') 21 | parser.add_argument('--output', default = '', type=str, help='output subdirectory') 22 | parser.add_argument('--model', default = 'ResNet18', type = str, help = 'student model name') 23 | parser.add_argument('--teacher_model', default = 'ResNet18', type = str, help = 'teacher network model') 24 | parser.add_argument('--teacher_path', default = './pre_train/AT_teacher_cifar10/bestpoint.pth.tar', type=str, help='path of teacher net being distilled') 25 | parser.add_argument('--temp', default=1.0, type=float, help='temperature for distillation') 26 | parser.add_argument('--val_period', default=1, type=int, help='print every __ epoch') 27 | parser.add_argument('--save_period', default=1, type=int, help='save every __ epoch') 28 | parser.add_argument('--alpha', default=1.0, type=float, help='weight for sum of losses') 29 | parser.add_argument('--dataset', default = 'CIFAR10', type=str, help='name of dataset') 30 | parser.add_argument('--out-dir',type=str,default='./ARD_CIFAR10',help='dir of output') 31 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed') 32 | parser.add_argument('--resume',type=str,default=None,help='whether to resume training') 33 | args = parser.parse_args() 34 | 35 | seed = args.seed 36 | out_dir = args.out_dir 37 | torch.manual_seed(seed) 38 | np.random.seed(seed) 39 | torch.cuda.manual_seed_all(seed) 40 | torch.backends.cudnn.benchmark = True 41 | torch.backends.cudnn.deterministic = True 42 | 43 | 44 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 45 | 46 | def adjust_learning_rate(optimizer, epoch, lr): 47 | if epoch in 
args.lr_schedule: 48 | lr *= args.lr_factor 49 | for param_group in optimizer.param_groups: 50 | param_group['lr'] = lr 51 | 52 | # Store path 53 | if not os.path.exists(out_dir): 54 | os.makedirs(out_dir) 55 | 56 | # Save checkpoint 57 | def save_checkpoint(state, checkpoint=out_dir, filename='checkpoint.pth.tar'): 58 | filepath = os.path.join(checkpoint, filename) 59 | torch.save(state, filepath) 60 | 61 | # prepare the dataset 62 | print('==> Preparing data..') 63 | transform_train = transforms.Compose([ 64 | transforms.RandomCrop(32, padding=4), 65 | transforms.RandomHorizontalFlip(), 66 | transforms.ToTensor(), 67 | ]) 68 | transform_test = transforms.Compose([ 69 | transforms.ToTensor(), 70 | ]) 71 | if args.dataset == 'CIFAR10': 72 | trainset = torchvision.datasets.CIFAR10(root='~/data/cifar-10', train=True, download=True, transform=transform_train) 73 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 74 | testset = torchvision.datasets.CIFAR10(root='~/data/cifar-10', train=False, download=True, transform=transform_test) 75 | testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 76 | num_classes = 10 77 | elif args.dataset == 'CIFAR100': 78 | trainset = torchvision.datasets.CIFAR100(root='~/data/cifar-100', train=True, download=True, transform=transform_train) 79 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 80 | testset = torchvision.datasets.CIFAR100(root='~/data/cifar-100', train=False, download=True, transform=transform_test) 81 | testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 82 | num_classes = 100 83 | 84 | class AttackPGD(nn.Module): 85 | def __init__(self, basic_net, config): 86 | super(AttackPGD, self).__init__() 87 | self.basic_net = basic_net 88 | self.step_size = config['step_size'] 89 | self.epsilon = config['epsilon'] 90 | self.num_steps = config['num_steps'] 91 | 92 | def forward(self, inputs, targets): 93 | x = inputs.detach() 94 | x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon) 95 | for i in range(self.num_steps): 96 | x.requires_grad_() 97 | with torch.enable_grad(): 98 | loss = F.cross_entropy(self.basic_net(x), targets, size_average=False) 99 | grad = torch.autograd.grad(loss, [x])[0] 100 | x = x.detach() + self.step_size*torch.sign(grad.detach()) 101 | x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon) 102 | x = torch.clamp(x, 0.0, 1.0) 103 | return self.basic_net(x), x 104 | 105 | 106 | # build teacher and student models 107 | 108 | print('==> Building model..'+args.model) 109 | # student 110 | if args.model == 'MobileNetV2': 111 | basic_net = MobileNetV2(num_classes=num_classes) 112 | elif args.model == 'WideResNet': 113 | basic_net = WideResNet(num_classes=num_classes) 114 | elif args.model == 'ResNet18': 115 | basic_net = ResNet18(num_classes=num_classes) 116 | basic_net = basic_net.to(device) 117 | basic_net = torch.nn.DataParallel(basic_net) 118 | 119 | # teacher 120 | if args.teacher_path != '': 121 | if args.teacher_model == 'MobileNetV2': 122 | teacher_net = MobileNetV2(num_classes=num_classes) 123 | elif args.teacher_model == 'WideResNet': 124 | teacher_net = WideResNet(num_classes=num_classes) 125 | elif args.teacher_model == 'ResNet18': 126 | teacher_net = ResNet18(num_classes=num_classes) 127 | teacher_net = teacher_net.to(device) 128 | for param in teacher_net.parameters(): 129 | param.requires_grad = False 130 | 131 | 
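
# ---------------------------------------------------------------------------
# Added reference sketch (not part of the original script; helper and argument
# names are illustrative only): the ARD objective assembled inline in train()
# below is equivalent to this helper, written with the non-deprecated
# `reduction` API. `adv_logits` are the student's logits on the PGD examples;
# `teacher_logits` and `nat_logits` are the teacher's and student's logits on
# the clean inputs.
def ard_loss(adv_logits, teacher_logits, nat_logits, targets, alpha, temp):
    # distillation term: temperature-scaled KL to the teacher, averaged over the batch
    kd = F.kl_div(F.log_softmax(adv_logits / temp, dim=1),
                  F.softmax(teacher_logits / temp, dim=1),
                  reduction='batchmean') * temp * temp
    # natural cross-entropy term on clean data
    ce = F.cross_entropy(nat_logits, targets)
    return alpha * kd + (1.0 - alpha) * ce
# ---------------------------------------------------------------------------
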
config_train = { 132 | 'epsilon': 8 / 255, 133 | 'num_steps': 10, 134 | 'step_size': 2 / 255, 135 | } 136 | 137 | net = AttackPGD(basic_net, config_train) 138 | 139 | if device == 'cuda': 140 | cudnn.benchmark = True 141 | 142 | print('==> Loading teacher..') 143 | teacher_net = torch.nn.DataParallel(teacher_net) 144 | teacher_net.load_state_dict(torch.load(args.teacher_path)['state_dict']) 145 | teacher_net.eval() 146 | 147 | net_t = AttackPGD(teacher_net, config_train) 148 | 149 | KL_loss = nn.KLDivLoss(reduce=False) 150 | XENT_loss = nn.CrossEntropyLoss() 151 | lr=args.lr 152 | 153 | def train(epoch, optimizer, net, basic_net, teacher_net): 154 | net.train() 155 | train_loss = 0 156 | iterator = tqdm(trainloader, ncols=0, leave=False) 157 | for batch_idx, (inputs, targets) in enumerate(iterator): 158 | inputs, targets = inputs.to(device), targets.to(device) 159 | optimizer.zero_grad() 160 | 161 | outputs, pert_inputs = net(inputs, targets) 162 | teacher_outputs = teacher_net(inputs) 163 | basic_outputs = basic_net(inputs) 164 | loss = args.alpha*args.temp*args.temp*(1/len(outputs))*torch.sum(KL_loss(F.log_softmax(outputs/args.temp, dim=1),F.softmax(teacher_outputs/args.temp, dim=1)).sum(dim=1))+(1.0-args.alpha)*XENT_loss(basic_outputs, targets) 165 | loss.backward() 166 | optimizer.step() 167 | train_loss += loss.item() 168 | iterator.set_description(str(loss.item())) 169 | 170 | print('Mean Training Loss:', train_loss/len(iterator)) 171 | return train_loss 172 | 173 | 174 | def test(epoch, optimizer, net, basic_net, teacher_net): 175 | net.eval() 176 | adv_correct = 0 177 | natural_correct = 0 178 | total = 0 179 | with torch.no_grad(): 180 | iterator = tqdm(testloader, ncols=0, leave=False) 181 | for batch_idx, (inputs, targets) in enumerate(iterator): 182 | inputs, targets = inputs.to(device), targets.to(device) 183 | adv_outputs, pert_inputs = net(inputs, targets) 184 | natural_outputs = basic_net(inputs) 185 | _, adv_predicted = adv_outputs.max(1) 186 | _, natural_predicted = natural_outputs.max(1) 187 | natural_correct += natural_predicted.eq(targets).sum().item() 188 | total += targets.size(0) 189 | adv_correct += adv_predicted.eq(targets).sum().item() 190 | iterator.set_description(str(adv_predicted.eq(targets).sum().item()/targets.size(0))) 191 | robust_acc = 100.*adv_correct/total 192 | natural_acc = 100.*natural_correct/total 193 | print('Natural acc:', natural_acc) 194 | print('Robust acc:', robust_acc) 195 | return natural_acc, robust_acc 196 | 197 | def main(): 198 | lr = args.lr 199 | best_acc = 0 200 | test_robust = 0 201 | optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=2e-4) 202 | logger_test = Logger(os.path.join(out_dir, 'student_results.txt'), title='student') 203 | logger_test.set_names(['Epoch', 'Natural Test Acc', 'PGD10 Acc']) 204 | for epoch in range(args.epochs): 205 | adjust_learning_rate(optimizer, epoch, lr) 206 | 207 | print("teacher >>>> student ") 208 | train_loss = train(epoch, optimizer, net, basic_net, teacher_net) 209 | 210 | if (epoch+1)%args.val_period == 0: 211 | natural_val, robust_val = test(epoch, optimizer, net, basic_net, teacher_net) 212 | logger_test.append([epoch + 1, natural_val, robust_val]) 213 | save_checkpoint({ 214 | 'epoch': epoch + 1, 215 | 'test_nat_acc': natural_val, 216 | 'test_pgd10_acc': robust_val, 217 | 'state_dict': basic_net.state_dict(), 218 | 'optimizer' : optimizer.state_dict(), 219 | }) 220 | 221 | if robust_val > best_acc: 222 | best_acc = robust_val 223 | save_checkpoint({ 224 | 'epoch': 
epoch + 1, 225 | 'state_dict': basic_net.state_dict(), 226 | 'test_nat_acc': natural_val, 227 | 'test_pgd10_acc': robust_val, 228 | 'optimizer' : optimizer.state_dict(), 229 | },filename='bestpoint.pth.tar') 230 | 231 | 232 | if __name__ == '__main__': 233 | main() 234 | -------------------------------------------------------------------------------- /IAD-I.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | import torch.backends.cudnn as cudnn 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | import os 11 | import argparse 12 | import time 13 | from tqdm import tqdm 14 | from utils import Logger 15 | from models import * 16 | 17 | parser = argparse.ArgumentParser(description='IAD-I') 18 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 19 | parser.add_argument('--lr_schedule', type=int, nargs='+', default=[100, 150], help='Decrease learning rate at these epochs.') 20 | parser.add_argument('--lr_factor', default=0.1, type=float, help='factor by which to decrease lr') 21 | parser.add_argument('--epochs', default=200, type=int, help='number of epochs for training') 22 | parser.add_argument('--output', default = '', type=str, help='output subdirectory') 23 | parser.add_argument('--model', default = 'ResNet18', type = str, help = 'student model name') 24 | parser.add_argument('--teacher_model', default = 'ResNet18', type = str, help = 'teacher network model') 25 | parser.add_argument('--teacher_path', default = './pre_train/AT_teacher_cifar10/bestpoint.pth.tar', type=str, help='path of teacher net being distilled') 26 | parser.add_argument('--temp', default=1.0, type=float, help='temperature for distillation') 27 | parser.add_argument('--val_period', default=1, type=int, help='print every __ epoch') 28 | parser.add_argument('--save_period', default=1, type=int, help='save every __ epoch') 29 | parser.add_argument('--alpha', default=1.0, type=float, help='weight for sum of losses') 30 | parser.add_argument('--dataset', default = 'CIFAR10', type=str, help='name of dataset') 31 | parser.add_argument('--out-dir',type=str,default='./IAD_I_CIFAR10',help='dir of output') 32 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed') 33 | parser.add_argument('--resume',type=str,default=None,help='whether to resume training') 34 | parser.add_argument('--beta',type=float, default=0.1) 35 | parser.add_argument('--begin',type=int, default=60) 36 | args = parser.parse_args() 37 | 38 | seed = args.seed 39 | out_dir = args.out_dir 40 | torch.manual_seed(seed) 41 | np.random.seed(seed) 42 | torch.cuda.manual_seed_all(seed) 43 | torch.backends.cudnn.benchmark = True 44 | torch.backends.cudnn.deterministic = True 45 | 46 | 47 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 48 | 49 | def adjust_learning_rate(optimizer, epoch, lr): 50 | if epoch in args.lr_schedule: 51 | lr *= args.lr_factor 52 | for param_group in optimizer.param_groups: 53 | param_group['lr'] = lr 54 | 55 | # Store path 56 | if not os.path.exists(out_dir): 57 | os.makedirs(out_dir) 58 | 59 | # Save checkpoint 60 | def save_checkpoint(state, checkpoint=out_dir, filename='checkpoint.pth.tar'): 61 | filepath = os.path.join(checkpoint, filename) 62 | torch.save(state, filepath) 63 | 64 | # prepare the dataset 65 | print('==> Preparing data..') 66 | 
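
# ---------------------------------------------------------------------------
# Added reference sketch (not part of the original script; the function name is
# illustrative only): in train() further below, the per-sample weights `Alpha`
# are computed with a Python loop once epoch >= args.begin. A vectorized
# equivalent of that loop is:
def introspection_weights(teacher_logits_adv, targets, beta):
    # teacher's probability on the true label of each adversarial example,
    # raised to the power beta; returns one weight per sample
    probs = F.softmax(teacher_logits_adv, dim=1)
    return probs.gather(1, targets.unsqueeze(1)).squeeze(1).pow(beta)
# ---------------------------------------------------------------------------
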
transform_train = transforms.Compose([ 67 | transforms.RandomCrop(32, padding=4), 68 | transforms.RandomHorizontalFlip(), 69 | transforms.ToTensor(), 70 | ]) 71 | transform_test = transforms.Compose([ 72 | transforms.ToTensor(), 73 | ]) 74 | if args.dataset == 'CIFAR10': 75 | trainset = torchvision.datasets.CIFAR10(root='~/data/cifar-10', train=True, download=True, transform=transform_train) 76 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 77 | testset = torchvision.datasets.CIFAR10(root='~/data/cifar-10', train=False, download=True, transform=transform_test) 78 | testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 79 | num_classes = 10 80 | elif args.dataset == 'CIFAR100': 81 | trainset = torchvision.datasets.CIFAR100(root='~/data/cifar-100', train=True, download=True, transform=transform_train) 82 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 83 | testset = torchvision.datasets.CIFAR100(root='~/data/cifar-100', train=False, download=True, transform=transform_test) 84 | testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 85 | num_classes = 100 86 | 87 | class AttackPGD(nn.Module): 88 | def __init__(self, basic_net, config): 89 | super(AttackPGD, self).__init__() 90 | self.basic_net = basic_net 91 | self.step_size = config['step_size'] 92 | self.epsilon = config['epsilon'] 93 | self.num_steps = config['num_steps'] 94 | 95 | def forward(self, inputs, targets): 96 | x = inputs.detach() 97 | x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon) 98 | for i in range(self.num_steps): 99 | x.requires_grad_() 100 | with torch.enable_grad(): 101 | loss = F.cross_entropy(self.basic_net(x), targets, size_average=False) 102 | grad = torch.autograd.grad(loss, [x])[0] 103 | x = x.detach() + self.step_size*torch.sign(grad.detach()) 104 | x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon) 105 | x = torch.clamp(x, 0.0, 1.0) 106 | return self.basic_net(x), x 107 | 108 | 109 | # build teacher and student models 110 | # dataparalella 111 | 112 | print('==> Building model..'+args.model) 113 | # student 114 | if args.model == 'MobileNetV2': 115 | basic_net = MobileNetV2(num_classes=num_classes) 116 | elif args.model == 'WideResNet': 117 | basic_net = WideResNet(num_classes=num_classes) 118 | elif args.model == 'ResNet18': 119 | basic_net = ResNet18(num_classes=num_classes) 120 | basic_net = basic_net.to(device) 121 | basic_net = torch.nn.DataParallel(basic_net) 122 | 123 | # teacher 124 | if args.teacher_path != '': 125 | if args.teacher_model == 'MobileNetV2': 126 | teacher_net = MobileNetV2(num_classes=num_classes) 127 | elif args.teacher_model == 'WideResNet': 128 | teacher_net = WideResNet(num_classes=num_classes) 129 | elif args.teacher_model == 'ResNet18': 130 | teacher_net = ResNet18(num_classes=num_classes) 131 | teacher_net = teacher_net.to(device) 132 | for param in teacher_net.parameters(): 133 | param.requires_grad = False 134 | 135 | config_train = { 136 | 'epsilon': 8 / 255, 137 | 'num_steps': 10, 138 | 'step_size': 2 / 255, 139 | } 140 | 141 | net = AttackPGD(basic_net, config_train) 142 | 143 | if device == 'cuda': 144 | cudnn.benchmark = True 145 | 146 | print('==> Loading teacher..') 147 | teacher_net = torch.nn.DataParallel(teacher_net) 148 | teacher_net.load_state_dict(torch.load(args.teacher_path)['state_dict']) 149 | teacher_net.eval() 150 | 151 | 152 | KL_loss = 
nn.KLDivLoss(reduce=False) 153 | XENT_loss = nn.CrossEntropyLoss() 154 | lr=args.lr 155 | 156 | def train(epoch, optimizer, net, basic_net, teacher_net): 157 | torch.cuda.synchronize() 158 | start = time.time() 159 | net.train() 160 | train_loss = 0 161 | iterator = tqdm(trainloader, ncols=0, leave=False) 162 | for batch_idx, (inputs, targets) in enumerate(iterator): 163 | inputs, targets = inputs.to(device), targets.to(device) 164 | optimizer.zero_grad() 165 | teacher_outputs = teacher_net(inputs) 166 | outputs, pert_inputs = net(inputs, targets) 167 | Alpha = torch.ones(len(inputs)).cuda() 168 | 169 | basicop = basic_net(pert_inputs).detach() 170 | guide = teacher_net(pert_inputs) 171 | 172 | 173 | if epoch >= args.begin: 174 | for pp in range(len(outputs)): 175 | 176 | L = F.softmax(guide, dim=1)[pp][targets[pp].item()] 177 | L = L.pow(args.beta).item() 178 | Alpha[pp] = L 179 | loss = args.alpha*args.temp*args.temp*(1/len(outputs))*torch.sum(KL_loss(F.log_softmax(outputs/args.temp, dim=1),F.softmax(teacher_outputs/args.temp, dim=1)).sum(dim=1)) + args.alpha*(1/len(outputs))*torch.sum(KL_loss(F.log_softmax(outputs, dim=1),F.softmax(basic_net(inputs), dim=1)).sum(dim=1).mul(1-Alpha))+(1.0-args.alpha)*XENT_loss(basic_net(inputs), targets) 180 | else: 181 | loss = args.alpha*args.temp*args.temp*(1/len(outputs))*torch.sum(KL_loss(F.log_softmax(outputs/args.temp, dim=1),F.softmax(teacher_outputs/args.temp, dim=1)).sum(dim=1))+(1.0-args.alpha)*XENT_loss(basic_net(inputs), targets) 182 | 183 | loss.backward() 184 | optimizer.step() 185 | train_loss += loss.item() 186 | iterator.set_description(str(loss.item())) 187 | torch.cuda.synchronize() 188 | end = time.time() 189 | print(end-start) 190 | print('Mean Training Loss:', train_loss/len(iterator)) 191 | return train_loss 192 | 193 | 194 | def test(epoch, optimizer, net, basic_net, teacher_net): 195 | net.eval() 196 | adv_correct = 0 197 | natural_correct = 0 198 | total = 0 199 | with torch.no_grad(): 200 | iterator = tqdm(testloader, ncols=0, leave=False) 201 | for batch_idx, (inputs, targets) in enumerate(iterator): 202 | inputs, targets = inputs.to(device), targets.to(device) 203 | adv_outputs, pert_inputs = net(inputs, targets) 204 | natural_outputs = basic_net(inputs) 205 | _, adv_predicted = adv_outputs.max(1) 206 | _, natural_predicted = natural_outputs.max(1) 207 | natural_correct += natural_predicted.eq(targets).sum().item() 208 | total += targets.size(0) 209 | adv_correct += adv_predicted.eq(targets).sum().item() 210 | iterator.set_description(str(adv_predicted.eq(targets).sum().item()/targets.size(0))) 211 | robust_acc = 100.*adv_correct/total 212 | natural_acc = 100.*natural_correct/total 213 | print('Natural acc:', natural_acc) 214 | print('Robust acc:', robust_acc) 215 | return natural_acc, robust_acc 216 | 217 | def main(): 218 | lr = args.lr 219 | best_acc = 0 220 | test_robust = 0 221 | stu_r = 0 222 | tea_r = 0 223 | mark = 1 224 | optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=2e-4) 225 | logger_test = Logger(os.path.join(out_dir, 'student_results.txt'), title='student') 226 | logger_test_teacher = Logger(os.path.join(out_dir, 'teacher_results.txt'), title='teacher') 227 | logger_test.set_names(['Epoch', 'Natural Test Acc', 'PGD10 Acc', 'T or S']) 228 | logger_test_teacher.set_names(['Epoch', 'Natural Test Acc', 'PGD10 Acc', 'T or S']) 229 | for epoch in range(args.epochs): 230 | adjust_learning_rate(optimizer, epoch, lr) 231 | 232 | print("teacher >>>> student ") 233 | mark = 1 234 | train_loss = 
train(epoch, optimizer, net, basic_net, teacher_net) 235 | 236 | if (epoch+1)%args.val_period == 0: 237 | natural_val, robust_val = test(epoch, optimizer, net, basic_net, teacher_net) 238 | natural_val_t, robust_val_t = 0, 0 239 | logger_test.append([epoch + 1, natural_val, robust_val, mark]) 240 | logger_test_teacher.append([epoch + 1, natural_val_t, robust_val_t, mark]) 241 | stu_r = robust_val 242 | tea_r = robust_val_t 243 | save_checkpoint({ 244 | 'epoch': epoch + 1, 245 | 'test_nat_acc': natural_val, 246 | 'test_pgd10_acc': robust_val, 247 | 'state_dict': basic_net.state_dict(), 248 | 'optimizer' : optimizer.state_dict(), 249 | }) 250 | 251 | if robust_val > best_acc: 252 | best_acc = robust_val 253 | save_checkpoint({ 254 | 'epoch': epoch + 1, 255 | 'state_dict': basic_net.state_dict(), 256 | 'test_nat_acc': natural_val, 257 | 'test_pgd10_acc': robust_val, 258 | 'optimizer' : optimizer.state_dict(), 259 | },filename='bestpoint.pth.tar') 260 | 261 | 262 | if __name__ == '__main__': 263 | main() 264 | -------------------------------------------------------------------------------- /IAD-II.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | import torch.backends.cudnn as cudnn 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | import os 11 | import argparse 12 | from tqdm import tqdm 13 | from utils import Logger 14 | from models import * 15 | 16 | parser = argparse.ArgumentParser(description='IAD-II') 17 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 18 | parser.add_argument('--lr_schedule', type=int, nargs='+', default=[100, 150], help='Decrease learning rate at these epochs.') 19 | parser.add_argument('--lr_factor', default=0.1, type=float, help='factor by which to decrease lr') 20 | parser.add_argument('--epochs', default=200, type=int, help='number of epochs for training') 21 | parser.add_argument('--output', default = '', type=str, help='output subdirectory') 22 | parser.add_argument('--model', default = 'ResNet18', type = str, help = 'student model name') 23 | parser.add_argument('--teacher_model', default = 'ResNet18', type = str, help = 'teacher network model') 24 | parser.add_argument('--teacher_path', default = './pre_train/AT_teacher_cifar10/bestpoint.pth.tar', type=str, help='path of AT teacher net being distilled') 25 | parser.add_argument('--teacher_st_path', default = './pre_train/ST_teacher_cifar10/bestpoint.pth.tar', type=str, help='path of ST teacher net being distilled') 26 | parser.add_argument('--temp', default=1.0, type=float, help='temperature for distillation') 27 | parser.add_argument('--val_period', default=1, type=int, help='print every __ epoch') 28 | parser.add_argument('--save_period', default=1, type=int, help='save every __ epoch') 29 | parser.add_argument('--alpha', default=0.5, type=float, help='weight for sum of losses') 30 | parser.add_argument('--dataset', default = 'CIFAR10', type=str, help='name of dataset') 31 | parser.add_argument('--out-dir',type=str,default='./IAD_II_CIFAR10',help='dir of output') 32 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed') 33 | parser.add_argument('--resume',type=str,default=None,help='whether to resume training') 34 | parser.add_argument('--alpha1', default=0.5, type=float, help='weight for sum of losses') 35 | 
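
# ---------------------------------------------------------------------------
# Added note (not part of the original script): --alpha1 weights distillation
# from the adversarially trained (AT) teacher, --alpha2 (defined just below)
# weights distillation from the standardly trained (ST) teacher, and the
# remaining 1 - alpha1 - alpha2 weights the cross-entropy on the adversarial
# examples. A hedged sketch of the epoch < args.begin branch of the loss built
# inline in train() (helper and argument names are illustrative only):
def akd2_style_loss(adv_logits, at_logits, st_logits, targets, alpha1, alpha2, temp):
    # temperature-scaled KL to a given teacher, averaged over the batch
    def kd(t_logits):
        return F.kl_div(F.log_softmax(adv_logits / temp, dim=1),
                        F.softmax(t_logits / temp, dim=1),
                        reduction='batchmean') * temp * temp
    ce = F.cross_entropy(adv_logits, targets)
    return (1 - alpha1 - alpha2) * ce + alpha1 * kd(at_logits) + alpha2 * kd(st_logits)
# ---------------------------------------------------------------------------
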
parser.add_argument('--alpha2', default=0.25, type=float, help='weight for sum of losses') 36 | parser.add_argument('--beta',type=float, default=0.1) 37 | parser.add_argument('--begin',type=int, default=40) 38 | 39 | args = parser.parse_args() 40 | 41 | seed = args.seed 42 | out_dir = args.out_dir 43 | torch.manual_seed(seed) 44 | np.random.seed(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | torch.backends.cudnn.benchmark = True 47 | torch.backends.cudnn.deterministic = True 48 | 49 | 50 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 51 | 52 | def adjust_learning_rate(optimizer, epoch, lr): 53 | if epoch in args.lr_schedule: 54 | lr *= args.lr_factor 55 | for param_group in optimizer.param_groups: 56 | param_group['lr'] = lr 57 | 58 | # Store path 59 | if not os.path.exists(out_dir): 60 | os.makedirs(out_dir) 61 | 62 | # Save checkpoint 63 | def save_checkpoint(state, checkpoint=out_dir, filename='checkpoint.pth.tar'): 64 | filepath = os.path.join(checkpoint, filename) 65 | torch.save(state, filepath) 66 | 67 | # prepare the dataset 68 | print('==> Preparing data..') 69 | transform_train = transforms.Compose([ 70 | transforms.RandomCrop(32, padding=4), 71 | transforms.RandomHorizontalFlip(), 72 | transforms.ToTensor(), 73 | ]) 74 | transform_test = transforms.Compose([ 75 | transforms.ToTensor(), 76 | ]) 77 | if args.dataset == 'CIFAR10': 78 | trainset = torchvision.datasets.CIFAR10(root='~/data/cifar-10', train=True, download=True, transform=transform_train) 79 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 80 | testset = torchvision.datasets.CIFAR10(root='~/data/cifar-10', train=False, download=True, transform=transform_test) 81 | testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 82 | num_classes = 10 83 | elif args.dataset == 'CIFAR100': 84 | trainset = torchvision.datasets.CIFAR100(root='~/data/cifar-100', train=True, download=True, transform=transform_train) 85 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 86 | testset = torchvision.datasets.CIFAR100(root='~/data/cifar-100', train=False, download=True, transform=transform_test) 87 | testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 88 | num_classes = 100 89 | 90 | 91 | 92 | class AttackPGD(nn.Module): 93 | def __init__(self, basic_net, config): 94 | super(AttackPGD, self).__init__() 95 | self.basic_net = basic_net 96 | self.step_size = config['step_size'] 97 | self.epsilon = config['epsilon'] 98 | self.num_steps = config['num_steps'] 99 | 100 | def forward(self, inputs, targets): 101 | x = inputs.detach() 102 | x = x + torch.zeros_like(x).uniform_(-self.epsilon, self.epsilon) 103 | for i in range(self.num_steps): 104 | x.requires_grad_() 105 | with torch.enable_grad(): 106 | loss = F.cross_entropy(self.basic_net(x), targets, size_average=False) 107 | grad = torch.autograd.grad(loss, [x])[0] 108 | x = x.detach() + self.step_size*torch.sign(grad.detach()) 109 | x = torch.min(torch.max(x, inputs - self.epsilon), inputs + self.epsilon) 110 | x = torch.clamp(x, 0.0, 1.0) 111 | return self.basic_net(x), x 112 | 113 | 114 | # build teacher and student models 115 | # dataparalella 116 | 117 | print('==> Building model..'+args.model) 118 | # student 119 | if args.model == 'MobileNetV2': 120 | basic_net = MobileNetV2(num_classes=num_classes) 121 | elif args.model == 'WideResNet': 122 | basic_net = WideResNet(num_classes=num_classes) 123 | 
elif args.model == 'ResNet18': 124 | basic_net = ResNet18(num_classes=num_classes) 125 | basic_net = basic_net.to(device) 126 | basic_net = torch.nn.DataParallel(basic_net) 127 | 128 | # teacher 129 | if args.teacher_path != '': 130 | if args.teacher_model == 'MobileNetV2': 131 | teacher_net = MobileNetV2(num_classes=num_classes) 132 | elif args.teacher_model == 'WideResNet': 133 | teacher_net = WideResNet(num_classes=num_classes) 134 | elif args.teacher_model == 'ResNet18': 135 | teacher_net = ResNet18(num_classes=num_classes) 136 | teacher_net = teacher_net.to(device) 137 | for param in teacher_net.parameters(): 138 | param.requires_grad = False 139 | 140 | teacher_st_net = ResNet18(num_classes=num_classes) 141 | teacher_st_net = teacher_st_net.to(device) 142 | for param in teacher_st_net.parameters(): 143 | param.requires_grad = False 144 | 145 | config_train = { 146 | 'epsilon': 8 / 255, 147 | 'num_steps': 10, 148 | 'step_size': 2 / 255, 149 | } 150 | 151 | net = AttackPGD(basic_net, config_train) 152 | 153 | if device == 'cuda': 154 | cudnn.benchmark = True 155 | 156 | print('==> Loading at teacher..') 157 | teacher_net = torch.nn.DataParallel(teacher_net) 158 | teacher_net.load_state_dict(torch.load(args.teacher_path)['state_dict']) 159 | teacher_net.eval() 160 | 161 | print('==> Loading st teacher..') 162 | teacher_st_net = torch.nn.DataParallel(teacher_st_net) 163 | teacher_st_net.load_state_dict(torch.load(args.teacher_st_path)['state_dict']) 164 | teacher_st_net.eval() 165 | 166 | 167 | KL_loss = nn.KLDivLoss(reduce=False) 168 | XENT_loss = nn.CrossEntropyLoss() 169 | lr=args.lr 170 | 171 | def train(epoch, optimizer, net, basic_net, teacher_net): 172 | net.train() 173 | train_loss = 0 174 | iterator = tqdm(trainloader, ncols=0, leave=False) 175 | for batch_idx, (inputs, targets) in enumerate(iterator): 176 | inputs, targets = inputs.to(device), targets.to(device) 177 | optimizer.zero_grad() 178 | outputs, pert_inputs = net.forward(inputs, targets) 179 | teacher_outputs = teacher_net(pert_inputs) 180 | st_outputs = teacher_st_net(pert_inputs) 181 | Alpha = torch.ones(len(inputs)).cuda() 182 | 183 | guide = teacher_net(pert_inputs) 184 | 185 | if epoch >= args.begin: 186 | for pp in range(len(outputs)): 187 | L = F.softmax(guide, dim=1)[pp][targets[pp].item()] 188 | L = L.pow(args.beta).item() 189 | Alpha[pp] = L 190 | loss = (1-args.alpha1-args.alpha2)*XENT_loss(outputs, targets)+args.alpha1*args.temp*args.temp*((1/len(outputs))*torch.sum(KL_loss(F.log_softmax(outputs/args.temp, dim=1),F.softmax(teacher_outputs/args.temp, dim=1)).sum(dim=1).mul(Alpha)) + (1/len(outputs))*torch.sum(KL_loss(F.log_softmax(outputs, dim=1),F.softmax(basic_net(inputs), dim=1)).sum(dim=1).mul(1-Alpha)))+args.alpha2*args.temp*args.temp*(1/len(outputs))*torch.sum(KL_loss(F.log_softmax(outputs/args.temp, dim=1),F.softmax(st_outputs/args.temp, dim=1)).sum(dim=1)) 191 | else: 192 | loss = (1-args.alpha1-args.alpha2)*XENT_loss(outputs, targets)+args.alpha1*args.temp*args.temp*(1/len(outputs))*torch.sum(KL_loss(F.log_softmax(outputs/args.temp, dim=1),F.softmax(teacher_outputs/args.temp, dim=1)).sum(dim=1))+args.alpha2*args.temp*args.temp*(1/len(outputs))*torch.sum(KL_loss(F.log_softmax(outputs/args.temp, dim=1),F.softmax(st_outputs/args.temp, dim=1)).sum(dim=1)) 193 | 194 | loss.backward() 195 | optimizer.step() 196 | train_loss += loss.item() 197 | iterator.set_description(str(loss.item())) 198 | 199 | print('Mean Training Loss:', train_loss/len(iterator)) 200 | return train_loss 201 | 202 | 203 | def 
test(epoch, optimizer, net, basic_net, teacher_net): 204 | net.eval() 205 | adv_correct = 0 206 | natural_correct = 0 207 | total = 0 208 | with torch.no_grad(): 209 | iterator = tqdm(testloader, ncols=0, leave=False) 210 | for batch_idx, (inputs, targets) in enumerate(iterator): 211 | inputs, targets = inputs.to(device), targets.to(device) 212 | adv_outputs, pert_inputs = net(inputs, targets) 213 | natural_outputs = basic_net(inputs) 214 | _, adv_predicted = adv_outputs.max(1) 215 | _, natural_predicted = natural_outputs.max(1) 216 | natural_correct += natural_predicted.eq(targets).sum().item() 217 | total += targets.size(0) 218 | adv_correct += adv_predicted.eq(targets).sum().item() 219 | iterator.set_description(str(adv_predicted.eq(targets).sum().item()/targets.size(0))) 220 | robust_acc = 100.*adv_correct/total 221 | natural_acc = 100.*natural_correct/total 222 | print('Natural acc:', natural_acc) 223 | print('Robust acc:', robust_acc) 224 | return natural_acc, robust_acc 225 | 226 | def main(): 227 | lr = args.lr 228 | best_acc = 0 229 | test_robust = 0 230 | optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=2e-4) 231 | logger_test = Logger(os.path.join(out_dir, 'student_results.txt'), title='student') 232 | logger_test.set_names(['Epoch', 'Natural Test Acc', 'PGD10 Acc']) 233 | for epoch in range(args.epochs): 234 | adjust_learning_rate(optimizer, epoch, lr) 235 | 236 | print("teacher >>>> student ") 237 | train_loss = train(epoch, optimizer, net, basic_net, teacher_net) 238 | 239 | if (epoch+1)%args.val_period == 0: 240 | natural_val, robust_val = test(epoch, optimizer, net, basic_net, teacher_net) 241 | logger_test.append([epoch + 1, natural_val, robust_val]) 242 | save_checkpoint({ 243 | 'epoch': epoch + 1, 244 | 'test_nat_acc': natural_val, 245 | 'test_pgd10_acc': robust_val, 246 | 'state_dict': basic_net.state_dict(), 247 | 'optimizer' : optimizer.state_dict(), 248 | }) 249 | 250 | if robust_val > best_acc: 251 | best_acc = robust_val 252 | save_checkpoint({ 253 | 'epoch': epoch + 1, 254 | 'state_dict': basic_net.state_dict(), 255 | 'test_nat_acc': natural_val, 256 | 'test_pgd10_acc': robust_val, 257 | 'optimizer' : optimizer.state_dict(), 258 | },filename='bestpoint.pth.tar') 259 | 260 | 261 | if __name__ == '__main__': 262 | main() 263 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Jianing Zhu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Reliable Adversarial Distillation with Unreliable Teachers
2 | 
3 | Code for ICLR 2022 "[Reliable Adversarial Distillation with Unreliable Teachers](https://openreview.net/forum?id=u6TRGdzhfip&referrer=%5BAuthor%20Console%5D(%2Fgroup%3Fid%3DICLR.cc%2F2022%2FConference%2FAuthors%23your-submissions))"
4 | 
5 | by *Jianing Zhu, Jiangchao Yao, Bo Han, Jingfeng Zhang, Tongliang Liu, Gang Niu, Jingren Zhou, Jianliang Xu, Hongxia Yang*.
6 | 
7 | Full code and instructions will be released soon.
8 | 
9 | ## Introduction
10 | 
11 | In this work, we find that the soft labels provided by the teacher model gradually become ***less and less reliable*** during the adversarial training of the student model. Based on this observation, we propose to ***partially trust*** the soft labels provided by the teacher model in adversarial distillation (see the sketch at the end of this README).
12 | 
13 | 
14 | 
15 | ## Environment
16 | 
17 | - Python (3.7.10)
18 | - PyTorch (1.7.1)
19 | - torchvision (0.8.2)
20 | - CUDA
21 | - NumPy
22 | - advertorch
23 | 
24 | ## Content
25 | 
26 | - ```./models```: models used for pre-training and distillation.
27 | - ```./pre_train```: code for AT and ST.
28 | - ```IAD-I.py```: Introspective Adversarial Distillation based on ARD.
29 | - ```IAD-II.py```: Introspective Adversarial Distillation based on AKD2.
30 | 
31 | ## Usage
32 | 
33 | **Pre-train**
34 | 
35 | - AT
36 | ```bash
37 | cd ./pre_train
38 | CUDA_VISIBLE_DEVICES='0' python AT.py --out-dir INSERT-YOUR-OUTPUT-PATH
39 | ```
40 | 
41 | - ST
42 | ```bash
43 | cd ./pre_train
44 | CUDA_VISIBLE_DEVICES='0' python ST.py --out-dir INSERT-YOUR-OUTPUT-PATH
45 | ```
46 | 
47 | **Distillation**
48 | 
49 | - IAD-I
50 | ```bash
51 | CUDA_VISIBLE_DEVICES='0' python IAD-I.py --teacher_path INSERT-YOUR-TEACHER-PATH --out-dir INSERT-YOUR-OUTPUT-PATH
52 | ```
53 | 
54 | - IAD-II
55 | ```bash
56 | CUDA_VISIBLE_DEVICES='0' python IAD-II.py --teacher_path INSERT-YOUR-TEACHER-PATH --out-dir INSERT-YOUR-OUTPUT-PATH
57 | ```
58 | 
59 | **Evaluation**
60 | - basic eval
61 | ```bash
62 | CUDA_VISIBLE_DEVICES='0' python basic_eval.py --model_path INSERT-YOUR-MODEL-PATH
63 | ```
64 | 
65 | ## Citation
66 | 
67 | ```bib
68 | @inproceedings{zhu2022reliable,
69 | title={Reliable Adversarial Distillation with Unreliable Teachers},
70 | author={Jianing Zhu and Jiangchao Yao and Bo Han and Jingfeng Zhang and Tongliang Liu and Gang Niu and Jingren Zhou and Jianliang Xu and Hongxia Yang},
71 | booktitle={International Conference on Learning Representations},
72 | year={2022},
73 | url={https://openreview.net/forum?id=u6TRGdzhfip}
74 | }
75 | ```
76 | 
77 | ## Reference Code
78 | 
79 | [1] AT: https://github.com/locuslab/robust_overfitting
80 | 
81 | [2] TRADES: https://github.com/yaodongyu/TRADES/
82 | 
83 | [3] ARD: https://github.com/goldblum/AdversariallyRobustDistillation
84 | 
85 | [4] AKD2: https://github.com/VITA-Group/Alleviate-Robust-Overfitting
86 | 
87 | [5] GAIRAT: https://github.com/zjfheart/Geometry-aware-Instance-reweighted-Adversarial-Training
88 | 
89 | ## Contact
90 | 
91 | Please contact csjnzhu@comp.hkbu.edu.hk if you have any questions about the code.
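
## Appendix: Introspective Weighting (Illustrative Sketch)

The snippet below is an illustrative sketch added for readability; it is not imported anywhere in this repository, and the function names are ours. It shows the core idea behind ```IAD-I.py``` and ```IAD-II.py```: each adversarial example is distilled from the teacher only in proportion to how reliable the teacher still is on it, measured by the teacher's own probability on the true label and sharpened by ```beta```; the remaining weight falls back to the student's own predictions on the clean input. The actual scripts add further terms (cross-entropy, an ST teacher in IAD-II) and handle the temperature slightly differently.

```python
import torch.nn.functional as F

def teacher_trust(teacher_logits_adv, targets, beta):
    # per-example trust in the teacher: p_teacher(y | x_adv) ** beta
    probs = F.softmax(teacher_logits_adv, dim=1)
    return probs.gather(1, targets.unsqueeze(1)).squeeze(1).pow(beta)

def introspective_kd(student_logits_adv, teacher_logits_adv,
                     student_logits_nat, targets, beta, temp=1.0):
    w = teacher_trust(teacher_logits_adv, targets, beta)              # [batch]

    def per_sample_kl(target_logits):
        # KL divergence per example, summed over classes
        return F.kl_div(F.log_softmax(student_logits_adv / temp, dim=1),
                        F.softmax(target_logits / temp, dim=1),
                        reduction='none').sum(dim=1)                  # [batch]

    # partially trust the teacher; the untrusted remainder is distilled
    # from the student's own natural-data predictions
    loss = (w * per_sample_kl(teacher_logits_adv)
            + (1.0 - w) * per_sample_kl(student_logits_nat)) * temp * temp
    return loss.mean()
```
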
92 | -------------------------------------------------------------------------------- /attack_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from models import * 3 | from torch.autograd import Variable 4 | 5 | def cwloss(output, target,confidence=50, num_classes=10): 6 | # Compute the probability of the label class versus the maximum other 7 | target = target.data 8 | target_onehot = torch.zeros(target.size() + (num_classes,)) 9 | target_onehot = target_onehot.cuda() 10 | target_onehot.scatter_(1, target.unsqueeze(1), 1.) 11 | target_var = Variable(target_onehot, requires_grad=False) 12 | real = (target_var * output).sum(1) 13 | other = ((1. - target_var) * output - target_var * 10000.).max(1)[0] 14 | loss = -torch.clamp(real - other + confidence, min=0.) # equiv to max(..., 0.) 15 | loss = torch.sum(loss) 16 | return loss 17 | 18 | def PGD(model, data, target, epsilon, step_size, num_steps,loss_fn,category,rand_init): 19 | model.eval() 20 | Kappa = torch.zeros(len(data)) 21 | if category == "trades": 22 | x_adv = data.detach() + 0.001 * torch.randn(data.shape).cuda().detach() if rand_init else data.detach() 23 | nat_output = model(data) 24 | if category == "Madry": 25 | x_adv = data.detach() + torch.from_numpy(np.random.uniform(-epsilon, epsilon, data.shape)).float().cuda() if rand_init else data.detach() 26 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 27 | for k in range(num_steps): 28 | x_adv.requires_grad_() 29 | output = model(x_adv) 30 | predict = output.max(1, keepdim=True)[1] 31 | # Update Kappa 32 | for p in range(len(x_adv)): 33 | if predict[p] == target[p]: 34 | Kappa[p] += 1 35 | model.zero_grad() 36 | with torch.enable_grad(): 37 | if loss_fn == "cent": 38 | loss_adv = nn.CrossEntropyLoss(reduction="mean")(output, target) 39 | if loss_fn == "cw": 40 | loss_adv = cwloss(output,target) 41 | if loss_fn == "kl": 42 | criterion_kl = nn.KLDivLoss(size_average=False).cuda() 43 | loss_adv = criterion_kl(F.log_softmax(output, dim=1),F.softmax(nat_output, dim=1)) 44 | loss_adv.backward() 45 | eta = step_size * x_adv.grad.sign() 46 | # Update adversarial data 47 | x_adv = x_adv.detach() + eta 48 | x_adv = torch.min(torch.max(x_adv, data - epsilon), data + epsilon) 49 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 50 | x_adv = Variable(x_adv, requires_grad=False) 51 | return x_adv, Kappa 52 | 53 | def eval_clean(model, test_loader): 54 | model.eval() 55 | test_loss = 0 56 | correct = 0 57 | with torch.no_grad(): 58 | for data, target in test_loader: 59 | data, target = data.cuda(), target.cuda() 60 | output = model(data) 61 | test_loss += F.cross_entropy(output, target, size_average=False).item() 62 | pred = output.max(1, keepdim=True)[1] 63 | correct += pred.eq(target.view_as(pred)).sum().item() 64 | test_loss /= len(test_loader.dataset) 65 | test_accuracy = correct / len(test_loader.dataset) 66 | return test_loss, test_accuracy 67 | 68 | def eval_robust(model, test_loader, perturb_steps, epsilon, step_size, loss_fn, category, random): 69 | model.eval() 70 | test_loss = 0 71 | correct = 0 72 | with torch.enable_grad(): 73 | for data, target in test_loader: 74 | data, target = data.cuda(), target.cuda() 75 | x_adv, _ = PGD(model,data,target,epsilon,step_size,perturb_steps,loss_fn,category,rand_init=random) 76 | output = model(x_adv) 77 | test_loss += F.cross_entropy(output, target, size_average=False).item() 78 | pred = output.max(1, keepdim=True)[1] 79 | correct += pred.eq(target.view_as(pred)).sum().item() 80 | test_loss /= 
len(test_loader.dataset)
81 |     test_accuracy = correct / len(test_loader.dataset)
82 |     return test_loss, test_accuracy
83 | 
84 | 
--------------------------------------------------------------------------------
/basic_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch.nn as nn
3 | import torchvision
4 | from torchvision import transforms
5 | from models import *
6 | import attack_generator as attack
7 | 
8 | parser = argparse.ArgumentParser(description='PyTorch White-box Adversarial Attack Test')
9 | parser.add_argument('--net', type=str, default="resnet18", help="decide which network to use, choose from smallcnn, resnet18, WRN")
10 | parser.add_argument('--dataset', type=str, default="cifar10", help="choose from cifar10, svhn")
11 | parser.add_argument('--depth', type=int, default=34, help='WRN depth')
12 | parser.add_argument('--width_factor', type=int, default=10, help='WRN width factor')
13 | parser.add_argument('--drop_rate', type=float, default=0.0, help='WRN drop rate')
14 | parser.add_argument('--model_path', default="./bestpoint.pth.tar", help='model for white-box attack evaluation')
15 | 
16 | args = parser.parse_args()
17 | 
18 | transform_test = transforms.Compose([
19 |     transforms.ToTensor(),
20 | ])
21 | print('==> Load Test Data')
22 | if args.dataset == "cifar10":
23 |     testset = torchvision.datasets.CIFAR10(root='~/data/cifar-10', train=False, download=True, transform=transform_test)
24 |     test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
25 |     num_classes = 10
26 | if args.dataset == 'cifar100':
27 |     testset = torchvision.datasets.CIFAR100(root='~/data/cifar-100', train=False, download=True, transform=transform_test)
28 |     test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
29 |     num_classes = 100
30 | 
31 | print('==> Load Model')
32 | if args.net == "resnet18":
33 |     model = ResNet18(num_classes=num_classes).cuda()
34 |     net = "resnet18"
35 | if args.net == "WRN":
36 |     model = WideResNet(depth=args.depth, num_classes=num_classes, widen_factor=args.width_factor, dropRate=args.drop_rate).cuda()
37 |     net = "WRN{}-{}-dropout{}".format(args.depth, args.width_factor, args.drop_rate)
38 | model = torch.nn.DataParallel(model)
39 | 
40 | print(net)
41 | print(args.model_path)
42 | model.load_state_dict(torch.load(args.model_path)['state_dict'])
43 | 
44 | print('==> Evaluating Performance under White-box Adversarial Attack')
45 | 
46 | loss, test_nat_acc = attack.eval_clean(model, test_loader)
47 | print('Natural Test Accuracy: {:.2f}%'.format(100. * test_nat_acc))
48 | # Evaluations are the same as in DAT.
49 | loss, fgsm_wori_acc = attack.eval_robust(model, test_loader, perturb_steps=1, epsilon=8/255, step_size=8/255, loss_fn="cent", category="Madry", random=True)
50 | print('FGSM without Random Start Test Accuracy: {:.2f}%'.format(100. * fgsm_wori_acc))
51 | loss, pgd20_acc = attack.eval_robust(model, test_loader, perturb_steps=20, epsilon=8/255, step_size=2/255, loss_fn="cent", category="Madry", random=True)
52 | print('PGD20 Test Accuracy: {:.2f}%'.format(100. * pgd20_acc))
53 | loss, cw_wori_acc = attack.eval_robust(model, test_loader, perturb_steps=30, epsilon=8/255, step_size=2/255, loss_fn="cw", category="Madry", random=True)
54 | print('CW Test Accuracy: {:.2f}%'.format(100.
* cw_wori_acc)) 55 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mobilenetv2 import * 2 | from .wideresnet import * 3 | from .resnet import * 4 | from .preresnet import * -------------------------------------------------------------------------------- /models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Block(nn.Module): 6 | def __init__(self, in_planes, out_planes, expansion, stride): 7 | super(Block, self).__init__() 8 | self.stride = stride 9 | 10 | planes = expansion * in_planes 11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False) 12 | self.bn1 = nn.BatchNorm2d(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 14 | self.bn2 = nn.BatchNorm2d(planes) 15 | self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 16 | self.bn3 = nn.BatchNorm2d(out_planes) 17 | 18 | self.shortcut = nn.Sequential() 19 | if stride == 1 and in_planes != out_planes: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False), 22 | nn.BatchNorm2d(out_planes), 23 | ) 24 | 25 | def forward(self, x): 26 | out = F.relu(self.bn1(self.conv1(x))) 27 | out = F.relu(self.bn2(self.conv2(out))) 28 | out = self.bn3(self.conv3(out)) 29 | out = out + self.shortcut(x) if self.stride==1 else out 30 | return out 31 | 32 | class MobileNetV2(nn.Module): 33 | #(expansion, out_planes, num_blocks, stride) 34 | cfg = [(1, 16, 1, 1), 35 | (6, 24, 2, 1), 36 | (6, 32, 3, 2), 37 | (6, 64, 4, 2), 38 | (6, 96, 3, 1), 39 | (6, 160, 3, 2), 40 | (6, 320, 1, 1)] 41 | 42 | def __init__(self, num_classes=10): 43 | super(MobileNetV2, self).__init__() 44 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(32) 46 | self.layers = self._make_layers(in_planes=32) 47 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 48 | self.bn2 = nn.BatchNorm2d(1280) 49 | self.linear = nn.Linear(1280, num_classes) 50 | 51 | def _make_layers(self, in_planes): 52 | layers = [] 53 | for expansion, out_planes, num_blocks, stride in self.cfg: 54 | strides = [stride] + [1]*(num_blocks-1) 55 | for stride in strides: 56 | layers.append(Block(in_planes, out_planes, expansion, stride)) 57 | in_planes = out_planes 58 | return nn.Sequential(*layers) 59 | 60 | def forward(self, x): 61 | out = F.relu(self.bn1(self.conv1(x))) 62 | out = self.layers(out) 63 | out = F.relu(self.bn2(self.conv2(out))) 64 | out = F.avg_pool2d(out, 4) 65 | out = out.view(out.size(0), -1) 66 | out = self.linear(out) 67 | return out 68 | -------------------------------------------------------------------------------- /models/preresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from advertorch.utils import NormalizeByChannelMeanStd 5 | 6 | __all__ = ['pResNet18', 'pResNet34', 'pResNet50', 'pResNet101', 'pResNet152'] 7 | 8 | class PreActBlock(nn.Module): 9 | '''Pre-activation version of the BasicBlock.''' 10 | expansion = 1 11 | 12 | def __init__(self, in_planes, planes, stride=1): 13 | 
super(PreActBlock, self).__init__() 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 18 | 19 | if stride != 1 or in_planes != self.expansion*planes: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 22 | ) 23 | 24 | def forward(self, x): 25 | out = F.relu(self.bn1(x)) 26 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 27 | out = self.conv1(out) 28 | out = self.conv2(F.relu(self.bn2(out))) 29 | out += shortcut 30 | return out 31 | 32 | class PreActBottleneck(nn.Module): 33 | '''Pre-activation version of the original Bottleneck module.''' 34 | expansion = 4 35 | 36 | def __init__(self, in_planes, planes, stride=1): 37 | super(PreActBottleneck, self).__init__() 38 | self.bn1 = nn.BatchNorm2d(in_planes) 39 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 40 | self.bn2 = nn.BatchNorm2d(planes) 41 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 42 | self.bn3 = nn.BatchNorm2d(planes) 43 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 44 | 45 | if stride != 1 or in_planes != self.expansion*planes: 46 | self.shortcut = nn.Sequential( 47 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 48 | ) 49 | 50 | def forward(self, x): 51 | out = F.relu(self.bn1(x)) 52 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 53 | out = self.conv1(out) 54 | out = self.conv2(F.relu(self.bn2(out))) 55 | out = self.conv3(F.relu(self.bn3(out))) 56 | out += shortcut 57 | return out 58 | 59 | class PreActResNet(nn.Module): 60 | def __init__(self, block, num_blocks, num_classes=100): 61 | super(PreActResNet, self).__init__() 62 | self.in_planes = 64 63 | 64 | # default normalization is for Tiny-ImageNet 65 | self.normalize = NormalizeByChannelMeanStd( 66 | mean=[0.4802, 0.4481, 0.3975], std=[0.2302, 0.2265, 0.2262]) 67 | 68 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 69 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 70 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 71 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 72 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 73 | self.bn = nn.BatchNorm2d(512 * block.expansion) 74 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 75 | self.linear = nn.Linear(512*block.expansion, num_classes) 76 | 77 | def _make_layer(self, block, planes, num_blocks, stride): 78 | strides = [stride] + [1]*(num_blocks-1) 79 | layers = [] 80 | for stride in strides: 81 | layers.append(block(self.in_planes, planes, stride)) 82 | self.in_planes = planes * block.expansion 83 | return nn.Sequential(*layers) 84 | 85 | def forward(self, x): 86 | x = self.normalize(x) 87 | out = self.conv1(x) 88 | out = self.layer1(out) 89 | out = self.layer2(out) 90 | out = self.layer3(out) 91 | out = self.layer4(out) 92 | out = F.relu(self.bn(out)) 93 | out = self.avgpool(out) 94 | out = out.view(out.size(0), -1) 95 | out = self.linear(out) 96 | return out 97 | 98 | def pResNet18(num_classes = 10): 99 | return PreActResNet(PreActBlock, [2,2,2,2], num_classes) 100 | 101 | def pResNet34(num_classes = 10): 102 | return 
PreActResNet(PreActBlock, [3,4,6,3], num_classes) 103 | 104 | def pResNet50(num_classes = 10): 105 | return PreActResNet(PreActBottleneck, [3,4,6,3], num_classes) 106 | 107 | def pResNet101(num_classes = 10): 108 | return PreActResNet(PreActBottleneck, [3,4,23,3], num_classes) 109 | 110 | def pResNet152(num_classes = 10): 111 | return PreActResNet(PreActBottleneck, [3,8,36,3], num_classes) -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class BasicBlock(nn.Module): 6 | expansion = 1 7 | 8 | def __init__(self, in_planes, planes, stride=1): 9 | super(BasicBlock, self).__init__() 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 11 | self.bn1 = nn.BatchNorm2d(planes) 12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 13 | self.bn2 = nn.BatchNorm2d(planes) 14 | self.shortcut = nn.Sequential() 15 | if stride != 1 or in_planes != self.expansion*planes: 16 | self.shortcut = nn.Sequential( 17 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 18 | nn.BatchNorm2d(self.expansion*planes) 19 | ) 20 | 21 | def forward(self, x): 22 | out = F.relu(self.bn1(self.conv1(x))) 23 | out = self.bn2(self.conv2(out)) 24 | out += self.shortcut(x) 25 | out = F.relu(out) 26 | return out 27 | 28 | class Bottleneck(nn.Module): 29 | expansion = 4 30 | 31 | def __init__(self, in_planes, planes, stride=1): 32 | super(Bottleneck, self).__init__() 33 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 34 | self.bn1 = nn.BatchNorm2d(planes) 35 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 36 | self.bn2 = nn.BatchNorm2d(planes) 37 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 38 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 39 | 40 | self.shortcut = nn.Sequential() 41 | if stride != 1 or in_planes != self.expansion*planes: 42 | self.shortcut = nn.Sequential( 43 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 44 | nn.BatchNorm2d(self.expansion*planes) 45 | ) 46 | 47 | def forward(self, x): 48 | out = F.relu(self.bn1(self.conv1(x))) 49 | out = F.relu(self.bn2(self.conv2(out))) 50 | out = self.bn3(self.conv3(out)) 51 | out += self.shortcut(x) 52 | out = F.relu(out) 53 | return out 54 | 55 | class ResNet(nn.Module): 56 | def __init__(self, block, num_blocks, num_classes=10): 57 | super(ResNet, self).__init__() 58 | self.in_planes = 64 59 | 60 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(64) 62 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 63 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 64 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 65 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 66 | self.linear = nn.Linear(512*block.expansion, num_classes) 67 | 68 | def _make_layer(self, block, planes, num_blocks, stride): 69 | strides = [stride] + [1]*(num_blocks-1) 70 | layers = [] 71 | for stride in strides: 72 | layers.append(block(self.in_planes, planes, stride)) 73 | self.in_planes = planes * block.expansion 74 | return nn.Sequential(*layers) 75 | 76 | def forward(self, x): 77 | out = 
F.relu(self.bn1(self.conv1(x))) 78 | out = self.layer1(out) 79 | out = self.layer2(out) 80 | out = self.layer3(out) 81 | out = self.layer4(out) 82 | out = F.avg_pool2d(out, 4) 83 | out = out.view(out.size(0), -1) 84 | out = self.linear(out) 85 | return out 86 | 87 | def ResNet18(num_classes=10): 88 | return ResNet(BasicBlock, [2,2,2,2], num_classes) 89 | 90 | def ResNet34(num_classes=10): 91 | return ResNet(BasicBlock, [3,4,6,3], num_classes) 92 | 93 | def ResNet50(num_classes=10): 94 | return ResNet(Bottleneck, [3,4,6,3], num_classes) 95 | 96 | def ResNet101(num_classes=10): 97 | return ResNet(Bottleneck, [3,4,23,3], num_classes) 98 | 99 | def ResNet152(num_classes=10): 100 | return ResNet(Bottleneck, [3,8,36,3], num_classes) 101 | -------------------------------------------------------------------------------- /models/wideresnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | 35 | class NetworkBlock(nn.Module): 36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 37 | super(NetworkBlock, self).__init__() 38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 39 | 40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 41 | layers = [] 42 | for i in range(int(nb_layers)): 43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | return self.layer(x) 48 | 49 | 50 | class WideResNet(nn.Module): 51 | def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0.0): 52 | super(WideResNet, self).__init__() 53 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 54 | assert ((depth - 4) % 6 == 0) 55 | n = (depth - 4) / 6 56 | block = BasicBlock 57 | # 1st conv before any network block 58 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 59 | padding=1, bias=False) 60 | # 1st block 61 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 62 | # 1st sub-block 63 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 64 | # 2nd block 65 | self.block2 = NetworkBlock(n, nChannels[1], 
nChannels[2], block, 2, dropRate) 66 | # 3rd block 67 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 68 | # global average pooling and classifier 69 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.fc = nn.Linear(nChannels[3], num_classes) 72 | self.nChannels = nChannels[3] 73 | 74 | for m in self.modules(): 75 | if isinstance(m, nn.Conv2d): 76 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 77 | m.weight.data.normal_(0, math.sqrt(2. / n)) 78 | elif isinstance(m, nn.BatchNorm2d): 79 | m.weight.data.fill_(1) 80 | m.bias.data.zero_() 81 | elif isinstance(m, nn.Linear): 82 | m.bias.data.zero_() 83 | 84 | def forward(self, x): 85 | out = self.conv1(x) 86 | out = self.block1(out) 87 | out = self.block2(out) 88 | out = self.block3(out) 89 | out = self.relu(self.bn1(out)) 90 | out = F.avg_pool2d(out, 8) 91 | out = out.view(-1, self.nChannels) 92 | return self.fc(out) 93 | 94 | -------------------------------------------------------------------------------- /pic/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZFancy/IAD/a8091a8c7552cef43d8f3f28085426cb786ce9d3/pic/overview.png -------------------------------------------------------------------------------- /pre_train/AT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torchvision 4 | import torch.optim as optim 5 | from torchvision import transforms 6 | from models import * 7 | import numpy as np 8 | import attack_generator as attack 9 | from utils import Logger 10 | 11 | parser = argparse.ArgumentParser(description='AT') 12 | parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train') 13 | parser.add_argument('--weight-decay', '--wd', default=2e-4, type=float, metavar='W') 14 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum') 15 | parser.add_argument('--epsilon', type=float, default=8/255, help='perturbation bound') 16 | parser.add_argument('--num-steps', type=int, default=10, help='maximum perturbation step K') 17 | parser.add_argument('--step-size', type=float, default=2/255, help='step size') 18 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed') 19 | parser.add_argument('--net', type=str, default="resnet18",help="decide which network to use,choose from smallcnn,resnet18,WRN") 20 | parser.add_argument('--dataset', type=str, default="cifar10", help="choose from cifar10,svhn,cifar100,mnist") 21 | parser.add_argument('--random',type=bool,default=True,help="whether to initiat adversarial sample with random noise") 22 | parser.add_argument('--depth',type=int,default=34,help='WRN depth') 23 | parser.add_argument('--width-factor',type=int,default=10,help='WRN width factor') 24 | parser.add_argument('--drop-rate',type=float,default=0.0, help='WRN drop rate') 25 | parser.add_argument('--resume',type=str,default=None,help='whether to resume training') 26 | parser.add_argument('--out-dir',type=str,default='AT_teacher_cifar10',help='dir of output') 27 | parser.add_argument('--lr-schedule', default='piecewise', choices=['superconverge', 'piecewise', 'linear', 'onedrop', 'multipledecay', 'cosine']) 28 | parser.add_argument('--lr-max', default=0.1, type=float) 29 | parser.add_argument('--lr-one-drop', default=0.01, type=float) 30 | parser.add_argument('--lr-drop-epoch', default=100, type=int) 31 | 
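# Typical invocation (a sketch; the working directory and output path are assumptions,
# adjust them to your setup):
#   python AT.py --net WRN --dataset cifar10 --out-dir AT_teacher_cifar10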
args = parser.parse_args() 32 | 33 | # Training settings 34 | seed = args.seed 35 | momentum = args.momentum 36 | weight_decay = args.weight_decay 37 | depth = args.depth 38 | width_factor = args.width_factor 39 | drop_rate = args.drop_rate 40 | resume = args.resume 41 | out_dir = args.out_dir 42 | 43 | torch.manual_seed(seed) 44 | np.random.seed(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | torch.backends.cudnn.benchmark = True 47 | torch.backends.cudnn.deterministic = True 48 | 49 | # Models and optimizer 50 | if args.net == "resnet18": 51 | model = ResNet18().cuda() 52 | net = "resnet18" 53 | if args.net == "preactresnet18": 54 | model = PreActResNet18().cuda() 55 | net = "preactresnet18" 56 | if args.net == "WRN": 57 | model = Wide_ResNet(depth=depth, num_classes=10, widen_factor=width_factor, dropRate=drop_rate).cuda() 58 | net = "WRN{}-{}-dropout{}".format(depth,width_factor,drop_rate) 59 | 60 | model = torch.nn.DataParallel(model) 61 | optimizer = optim.SGD(model.parameters(), lr=args.lr_max, momentum=momentum, weight_decay=weight_decay) 62 | 63 | # Learning schedules 64 | if args.lr_schedule == 'superconverge': 65 | lr_schedule = lambda t: np.interp([t], [0, args.epochs * 2 // 5, args.epochs], [0, args.lr_max, 0])[0] 66 | elif args.lr_schedule == 'piecewise': 67 | def lr_schedule(t): 68 | if args.epochs >= 110: 69 | # Train ResNet 70 | if t / args.epochs < 0.5: 71 | return args.lr_max 72 | elif t / args.epochs < 0.75: 73 | return args.lr_max / 10. 74 | else: 75 | return args.lr_max / 100. 76 | else: 77 | # Train Wide-ResNet 78 | if t / args.epochs < 0.3: 79 | return args.lr_max 80 | elif t / args.epochs < 0.6: 81 | return args.lr_max / 10. 82 | else: 83 | return args.lr_max / 100. 84 | elif args.lr_schedule == 'linear': 85 | lr_schedule = lambda t: np.interp([t], [0, args.epochs // 3, args.epochs * 2 // 3, args.epochs], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0] 86 | elif args.lr_schedule == 'onedrop': 87 | def lr_schedule(t): 88 | if t < args.lr_drop_epoch: 89 | return args.lr_max 90 | else: 91 | return args.lr_one_drop 92 | elif args.lr_schedule == 'multipledecay': 93 | def lr_schedule(t): 94 | return args.lr_max - (t//(args.epochs//10))*(args.lr_max/10) 95 | elif args.lr_schedule == 'cosine': 96 | def lr_schedule(t): 97 | return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi)) 98 | 99 | # Store path 100 | if not os.path.exists(out_dir): 101 | os.makedirs(out_dir) 102 | 103 | # Save checkpoint 104 | def save_checkpoint(state, checkpoint=out_dir, filename='checkpoint.pth.tar'): 105 | filepath = os.path.join(checkpoint, filename) 106 | torch.save(state, filepath) 107 | 108 | # Get adversarially robust network 109 | def train(epoch, model, train_loader, optimizer): 110 | 111 | lr = 0 112 | num_data = 0 113 | train_robust_loss = 0 114 | 115 | for batch_idx, (data, target) in enumerate(train_loader): 116 | 117 | loss = 0 118 | data, target = data.cuda(), target.cuda() 119 | 120 | x_adv, _ = attack.PGD(model,data,target,args.epsilon,args.step_size,args.num_steps,loss_fn="cent",category="Madry",rand_init=True) 121 | 122 | model.train() 123 | lr = lr_schedule(epoch + 1) 124 | optimizer.param_groups[0].update(lr=lr) 125 | optimizer.zero_grad() 126 | 127 | logit = model(x_adv) 128 | 129 | loss = nn.CrossEntropyLoss(reduce="mean")(logit, target) 130 | 131 | train_robust_loss += loss.item() * len(x_adv) 132 | 133 | loss.backward() 134 | optimizer.step() 135 | 136 | num_data += len(data) 137 | 138 | train_robust_loss = train_robust_loss / num_data 139 | 
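# train_robust_loss is now the per-example average adversarial loss over the whole epoch.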
140 | return train_robust_loss, lr 141 | 142 | 143 | 144 | # Setup data loader 145 | transform_train = transforms.Compose([ 146 | transforms.RandomCrop(32, padding=4), 147 | transforms.RandomHorizontalFlip(), 148 | transforms.ToTensor(), 149 | ]) 150 | transform_test = transforms.Compose([ 151 | transforms.ToTensor(), 152 | ]) 153 | 154 | if args.dataset == "cifar10": 155 | trainset = torchvision.datasets.CIFAR10(root='~/data/cifar-10', train=True, download=True, transform=transform_train) 156 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 157 | testset = torchvision.datasets.CIFAR10(root='~/data/cifar-10', train=False, download=True, transform=transform_test) 158 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 159 | if args.dataset == "cifar100": 160 | trainset = torchvision.datasets.CIFAR100(root='~/data/cifar-100', train=True, download=True, transform=transform_train) 161 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 162 | testset = torchvision.datasets.CIFAR100(root='~/data/cifar-100', train=False, download=True, transform=transform_test) 163 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 164 | if args.dataset == "svhn": 165 | trainset = torchvision.datasets.SVHN(root='~/data/SVHN', split='train', download=True, transform=transform_train) 166 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 167 | testset = torchvision.datasets.SVHN(root='~/data/SVHN', split='test', download=True, transform=transform_test) 168 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 169 | if args.dataset == "mnist": 170 | trainset = torchvision.datasets.MNIST(root='~/data/MNIST', train=True, download=True, transform=transforms.ToTensor()) 171 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=1,pin_memory=True) 172 | testset = torchvision.datasets.MNIST(root='~/data/MNIST', train=False, download=True, transform=transforms.ToTensor()) 173 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=1,pin_memory=True) 174 | 175 | # Resume 176 | title = 'AT' 177 | best_acc = 0 178 | start_epoch = 0 179 | if resume: 180 | # Resume directly point to checkpoint.pth.tar 181 | print(resume) 182 | assert os.path.isfile(resume) 183 | out_dir = os.path.dirname(resume) 184 | checkpoint = torch.load(resume) 185 | start_epoch = checkpoint['epoch'] 186 | best_acc = checkpoint['test_pgd10_acc'] 187 | model.load_state_dict(checkpoint['state_dict']) 188 | optimizer.load_state_dict(checkpoint['optimizer']) 189 | logger_test = Logger(os.path.join(out_dir, 'log_results.txt'), title=title, resume=True) 190 | else: 191 | logger_test = Logger(os.path.join(out_dir, 'log_results.txt'), title=title) 192 | logger_test.set_names(['Epoch', 'Natural Test Acc', 'PGD10 Acc']) 193 | 194 | ## Training get started 195 | test_nat_acc = 0 196 | test_pgd10_acc = 0 197 | 198 | for epoch in range(start_epoch, args.epochs): 199 | 200 | # Adversarial training 201 | train_robust_loss, lr = train(epoch, model, train_loader, optimizer) 202 | 203 | # Evalutions similar to DAT. 
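# Per epoch: clean test accuracy plus PGD-10 robust accuracy (epsilon 8/255, step size 2/255),
# which is also the value stored as 'test_pgd10_acc' in the checkpoints saved below.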
204 | _, test_nat_acc = attack.eval_clean(model, test_loader) 205 | _, test_pgd10_acc = attack.eval_robust(model, test_loader, perturb_steps=10, epsilon=8/255, step_size=2/255,loss_fn="cent", category="Madry", random=True) 206 | 207 | 208 | print( 209 | 'Epoch: [%d | %d] | Learning Rate: %f | Natural Test Acc %.2f | PGD20 Test Acc %.2f |\n' % ( 210 | epoch, 211 | args.epochs, 212 | lr, 213 | test_nat_acc, 214 | test_pgd10_acc) 215 | ) 216 | 217 | logger_test.append([epoch + 1, test_nat_acc, test_pgd10_acc]) 218 | 219 | # Save the best checkpoint 220 | if test_pgd10_acc > best_acc: 221 | best_acc = test_pgd10_acc 222 | save_checkpoint({ 223 | 'epoch': epoch + 1, 224 | 'state_dict': model.state_dict(), 225 | 'test_nat_acc': test_nat_acc, 226 | 'test_pgd10_acc': test_pgd10_acc, 227 | 'optimizer' : optimizer.state_dict(), 228 | },filename='bestpoint.pth.tar') 229 | 230 | # Save the last checkpoint 231 | save_checkpoint({ 232 | 'epoch': epoch + 1, 233 | 'state_dict': model.state_dict(), 234 | 'test_nat_acc': test_nat_acc, 235 | 'test_pgd10_acc': test_pgd10_acc, 236 | 'optimizer' : optimizer.state_dict(), 237 | }) 238 | if (epoch+1)%10 == 0: 239 | 240 | save_checkpoint({ 241 | 'epoch': epoch + 1, 242 | 'state_dict': model.state_dict(), 243 | 'test_nat_acc': test_nat_acc, 244 | 'test_pgd10_acc': test_pgd10_acc, 245 | 'optimizer' : optimizer.state_dict(), 246 | },filename='check'+str(epoch+1)+'.pth.tar') 247 | 248 | logger_test.close() -------------------------------------------------------------------------------- /pre_train/ST.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torchvision 4 | import torch.optim as optim 5 | from torchvision import transforms 6 | from models import * 7 | import numpy as np 8 | import attack_generator as attack 9 | from utils import Logger 10 | 11 | parser = argparse.ArgumentParser(description='ST') 12 | parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train') 13 | parser.add_argument('--weight-decay', '--wd', default=2e-4, type=float, metavar='W') 14 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum') 15 | parser.add_argument('--epsilon', type=float, default=8/255, help='perturbation bound') 16 | parser.add_argument('--num-steps', type=int, default=10, help='maximum perturbation step K') 17 | parser.add_argument('--step-size', type=float, default=2/255, help='step size') 18 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed') 19 | parser.add_argument('--net', type=str, default="resnet18",help="decide which network to use,choose from smallcnn,resnet18,WRN") 20 | parser.add_argument('--dataset', type=str, default="cifar10", help="choose from cifar10,svhn,cifar100,mnist") 21 | parser.add_argument('--random',type=bool,default=True,help="whether to initiat adversarial sample with random noise") 22 | parser.add_argument('--depth',type=int,default=34,help='WRN depth') 23 | parser.add_argument('--width-factor',type=int,default=10,help='WRN width factor') 24 | parser.add_argument('--drop-rate',type=float,default=0.0, help='WRN drop rate') 25 | parser.add_argument('--resume',type=str,default=None,help='whether to resume training') 26 | parser.add_argument('--out-dir',type=str,default='ST_teacher_cifar10',help='dir of output') 27 | parser.add_argument('--lr-schedule', default='piecewise', choices=['superconverge', 'piecewise', 'linear', 'onedrop', 'multipledecay', 'cosine']) 28 
| parser.add_argument('--lr-max', default=0.1, type=float) 29 | parser.add_argument('--lr-one-drop', default=0.01, type=float) 30 | parser.add_argument('--lr-drop-epoch', default=100, type=int) 31 | args = parser.parse_args() 32 | 33 | # Training settings 34 | seed = args.seed 35 | momentum = args.momentum 36 | weight_decay = args.weight_decay 37 | depth = args.depth 38 | width_factor = args.width_factor 39 | drop_rate = args.drop_rate 40 | resume = args.resume 41 | out_dir = args.out_dir 42 | 43 | torch.manual_seed(seed) 44 | np.random.seed(seed) 45 | torch.cuda.manual_seed_all(seed) 46 | torch.backends.cudnn.benchmark = True 47 | torch.backends.cudnn.deterministic = True 48 | 49 | # Models and optimizer 50 | if args.net == "resnet18": 51 | model = ResNet18().cuda() 52 | net = "resnet18" 53 | if args.net == "preactresnet18": 54 | model = PreActResNet18().cuda() 55 | net = "preactresnet18" 56 | if args.net == "WRN": 57 | model = Wide_ResNet(depth=depth, num_classes=10, widen_factor=width_factor, dropRate=drop_rate).cuda() 58 | net = "WRN{}-{}-dropout{}".format(depth,width_factor,drop_rate) 59 | 60 | model = torch.nn.DataParallel(model) 61 | optimizer = optim.SGD(model.parameters(), lr=args.lr_max, momentum=momentum, weight_decay=weight_decay) 62 | 63 | # Learning schedules 64 | if args.lr_schedule == 'superconverge': 65 | lr_schedule = lambda t: np.interp([t], [0, args.epochs * 2 // 5, args.epochs], [0, args.lr_max, 0])[0] 66 | elif args.lr_schedule == 'piecewise': 67 | def lr_schedule(t): 68 | if args.epochs >= 110: 69 | # Train ResNet 70 | if t / args.epochs < 0.5: 71 | return args.lr_max 72 | elif t / args.epochs < 0.75: 73 | return args.lr_max / 10. 74 | else: 75 | return args.lr_max / 100. 76 | else: 77 | # Train Wide-ResNet 78 | if t / args.epochs < 0.3: 79 | return args.lr_max 80 | elif t / args.epochs < 0.6: 81 | return args.lr_max / 10. 82 | else: 83 | return args.lr_max / 100. 
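# The remaining schedules below ('linear', 'onedrop', 'multipledecay', 'cosine') are
# simple functions of the epoch index t.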
84 | elif args.lr_schedule == 'linear': 85 | lr_schedule = lambda t: np.interp([t], [0, args.epochs // 3, args.epochs * 2 // 3, args.epochs], [args.lr_max, args.lr_max, args.lr_max / 10, args.lr_max / 100])[0] 86 | elif args.lr_schedule == 'onedrop': 87 | def lr_schedule(t): 88 | if t < args.lr_drop_epoch: 89 | return args.lr_max 90 | else: 91 | return args.lr_one_drop 92 | elif args.lr_schedule == 'multipledecay': 93 | def lr_schedule(t): 94 | return args.lr_max - (t//(args.epochs//10))*(args.lr_max/10) 95 | elif args.lr_schedule == 'cosine': 96 | def lr_schedule(t): 97 | return args.lr_max * 0.5 * (1 + np.cos(t / args.epochs * np.pi)) 98 | 99 | # Store path 100 | if not os.path.exists(out_dir): 101 | os.makedirs(out_dir) 102 | 103 | # Save checkpoint 104 | def save_checkpoint(state, checkpoint=out_dir, filename='checkpoint.pth.tar'): 105 | filepath = os.path.join(checkpoint, filename) 106 | torch.save(state, filepath) 107 | 108 | def train(epoch, model, train_loader, optimizer): 109 | 110 | lr = 0 111 | num_data = 0 112 | train_robust_loss = 0 113 | 114 | for batch_idx, (data, target) in enumerate(train_loader): 115 | 116 | loss = 0 117 | data, target = data.cuda(), target.cuda() 118 | 119 | model.train() 120 | lr = lr_schedule(epoch + 1) 121 | optimizer.param_groups[0].update(lr=lr) 122 | optimizer.zero_grad() 123 | 124 | logit = model(data) 125 | 126 | loss = nn.CrossEntropyLoss(reduce="mean")(logit, target) 127 | 128 | train_robust_loss += loss.item() * len(data) 129 | 130 | loss.backward() 131 | optimizer.step() 132 | 133 | num_data += len(data) 134 | 135 | train_robust_loss = train_robust_loss / num_data 136 | 137 | return train_robust_loss, lr 138 | 139 | # Setup data loader 140 | transform_train = transforms.Compose([ 141 | transforms.RandomCrop(32, padding=4), 142 | transforms.RandomHorizontalFlip(), 143 | transforms.ToTensor(), 144 | ]) 145 | transform_test = transforms.Compose([ 146 | transforms.ToTensor(), 147 | ]) 148 | 149 | if args.dataset == "cifar10": 150 | trainset = torchvision.datasets.CIFAR10(root='~/data/cifar-10', train=True, download=True, transform=transform_train) 151 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 152 | testset = torchvision.datasets.CIFAR10(root='~/data/cifar-10', train=False, download=True, transform=transform_test) 153 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 154 | if args.dataset == "cifar100": 155 | trainset = torchvision.datasets.CIFAR100(root='~/data/cifar-100', train=True, download=True, transform=transform_train) 156 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 157 | testset = torchvision.datasets.CIFAR100(root='~/data/cifar-100', train=False, download=True, transform=transform_test) 158 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 159 | if args.dataset == "svhn": 160 | trainset = torchvision.datasets.SVHN(root='~/data/SVHN', split='train', download=True, transform=transform_train) 161 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 162 | testset = torchvision.datasets.SVHN(root='~/data/SVHN', split='test', download=True, transform=transform_test) 163 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 164 | if args.dataset == "mnist": 165 | trainset = torchvision.datasets.MNIST(root='~/data/MNIST', train=True, 
download=True, transform=transforms.ToTensor()) 166 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=1,pin_memory=True) 167 | testset = torchvision.datasets.MNIST(root='~/data/MNIST', train=False, download=True, transform=transforms.ToTensor()) 168 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=1,pin_memory=True) 169 | 170 | # Resume 171 | title = 'ST' 172 | best_acc = 0 173 | start_epoch = 0 174 | if resume: 175 | # Resume directly point to checkpoint.pth.tar 176 | print(resume) 177 | assert os.path.isfile(resume) 178 | out_dir = os.path.dirname(resume) 179 | checkpoint = torch.load(resume) 180 | start_epoch = checkpoint['epoch'] 181 | best_acc = checkpoint['test_pgd10_acc'] 182 | model.load_state_dict(checkpoint['state_dict']) 183 | optimizer.load_state_dict(checkpoint['optimizer']) 184 | logger_test = Logger(os.path.join(out_dir, 'log_results.txt'), title=title, resume=True) 185 | else: 186 | logger_test = Logger(os.path.join(out_dir, 'log_results.txt'), title=title) 187 | logger_test.set_names(['Epoch', 'Natural Test Acc', 'PGD20 Acc']) 188 | 189 | ## Training get started 190 | test_nat_acc = 0 191 | test_pgd10_acc = 0 192 | 193 | for epoch in range(start_epoch, args.epochs): 194 | 195 | 196 | # standard training 197 | train_robust_loss, lr = train(epoch, model, train_loader, optimizer) 198 | 199 | # Evalutions similar to DAT. 200 | _, test_nat_acc = attack.eval_clean(model, test_loader) 201 | _, test_pgd10_acc = 0,0 202 | 203 | 204 | print( 205 | 'Epoch: [%d | %d] | Learning Rate: %f | Natural Test Acc %.2f | PGD20 Test Acc %.2f |\n' % ( 206 | epoch, 207 | args.epochs, 208 | lr, 209 | test_nat_acc, 210 | test_pgd10_acc) 211 | ) 212 | 213 | logger_test.append([epoch + 1, test_nat_acc, test_pgd10_acc]) 214 | 215 | # Save the best checkpoint 216 | if test_nat_acc > best_acc: 217 | best_acc = test_nat_acc 218 | save_checkpoint({ 219 | 'epoch': epoch + 1, 220 | 'state_dict': model.state_dict(), 221 | 'test_nat_acc': test_nat_acc, 222 | 'test_pgd10_acc': test_pgd10_acc, 223 | 'optimizer' : optimizer.state_dict(), 224 | },filename='bestpoint.pth.tar') 225 | 226 | # Save the last checkpoint 227 | save_checkpoint({ 228 | 'epoch': epoch + 1, 229 | 'state_dict': model.state_dict(), 230 | 'test_nat_acc': test_nat_acc, 231 | 'test_pgd10_acc': test_pgd10_acc, 232 | 'optimizer' : optimizer.state_dict(), 233 | }) 234 | if (epoch+1)%10 == 0: 235 | 236 | save_checkpoint({ 237 | 'epoch': epoch + 1, 238 | 'state_dict': model.state_dict(), 239 | 'test_nat_acc': test_nat_acc, 240 | 'test_pgd10_acc': test_pgd10_acc, 241 | 'optimizer' : optimizer.state_dict(), 242 | },filename='check'+str(epoch+1)+'.pth.tar') 243 | 244 | logger_test.close() -------------------------------------------------------------------------------- /pre_train/attack_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from models import * 3 | from torch.autograd import Variable 4 | 5 | def cwloss(output, target,confidence=50, num_classes=10): 6 | # Compute the probability of the label class versus the maximum other 7 | target = target.data 8 | target_onehot = torch.zeros(target.size() + (num_classes,)) 9 | target_onehot = target_onehot.cuda() 10 | target_onehot.scatter_(1, target.unsqueeze(1), 1.) 11 | target_var = Variable(target_onehot, requires_grad=False) 12 | real = (target_var * output).sum(1) 13 | other = ((1. 
- target_var) * output - target_var * 10000.).max(1)[0] 14 | loss = -torch.clamp(real - other + confidence, min=0.) # equiv to max(..., 0.) 15 | loss = torch.sum(loss) 16 | return loss 17 | 18 | def PGD(model, data, target, epsilon, step_size, num_steps,loss_fn,category,rand_init): 19 | model.eval() 20 | Kappa = torch.zeros(len(data)) 21 | if category == "trades": 22 | x_adv = data.detach() + 0.001 * torch.randn(data.shape).cuda().detach() if rand_init else data.detach() 23 | nat_output = model(data) 24 | if category == "Madry": 25 | x_adv = data.detach() + torch.from_numpy(np.random.uniform(-epsilon, epsilon, data.shape)).float().cuda() if rand_init else data.detach() 26 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 27 | for k in range(num_steps): 28 | x_adv.requires_grad_() 29 | output = model(x_adv) 30 | predict = output.max(1, keepdim=True)[1] 31 | # Update Kappa 32 | for p in range(len(x_adv)): 33 | if predict[p] == target[p]: 34 | Kappa[p] += 1 35 | model.zero_grad() 36 | with torch.enable_grad(): 37 | if loss_fn == "cent": 38 | loss_adv = nn.CrossEntropyLoss(reduction="mean")(output, target) 39 | if loss_fn == "cw": 40 | loss_adv = cwloss(output,target) 41 | if loss_fn == "kl": 42 | criterion_kl = nn.KLDivLoss(size_average=False).cuda() 43 | loss_adv = criterion_kl(F.log_softmax(output, dim=1),F.softmax(nat_output, dim=1)) 44 | loss_adv.backward() 45 | eta = step_size * x_adv.grad.sign() 46 | # Update adversarial data 47 | x_adv = x_adv.detach() + eta 48 | x_adv = torch.min(torch.max(x_adv, data - epsilon), data + epsilon) 49 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 50 | x_adv = Variable(x_adv, requires_grad=False) 51 | return x_adv, Kappa 52 | 53 | def eval_clean(model, test_loader): 54 | model.eval() 55 | test_loss = 0 56 | correct = 0 57 | with torch.no_grad(): 58 | for data, target in test_loader: 59 | data, target = data.cuda(), target.cuda() 60 | output = model(data) 61 | test_loss += F.cross_entropy(output, target, size_average=False).item() 62 | pred = output.max(1, keepdim=True)[1] 63 | correct += pred.eq(target.view_as(pred)).sum().item() 64 | test_loss /= len(test_loader.dataset) 65 | test_accuracy = correct / len(test_loader.dataset) 66 | return test_loss, test_accuracy 67 | 68 | def eval_robust(model, test_loader, perturb_steps, epsilon, step_size, loss_fn, category, random): 69 | model.eval() 70 | test_loss = 0 71 | correct = 0 72 | with torch.enable_grad(): 73 | for data, target in test_loader: 74 | data, target = data.cuda(), target.cuda() 75 | x_adv, _ = PGD(model,data,target,epsilon,step_size,perturb_steps,loss_fn,category,rand_init=random) 76 | output = model(x_adv) 77 | test_loss += F.cross_entropy(output, target, size_average=False).item() 78 | pred = output.max(1, keepdim=True)[1] 79 | correct += pred.eq(target.view_as(pred)).sum().item() 80 | test_loss /= len(test_loader.dataset) 81 | test_accuracy = correct / len(test_loader.dataset) 82 | return test_loss, test_accuracy 83 | 84 | -------------------------------------------------------------------------------- /pre_train/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .preact_resnet import * 3 | from .wide_resnet import * 4 | from .wrn_madry import * 5 | from .wideresnet import * 6 | -------------------------------------------------------------------------------- /pre_train/models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.nn.functional as F 4 | 5 | from torch.autograd import Variable 6 | 7 | 8 | class PreActBlock(nn.Module): 9 | '''Pre-activation version of the BasicBlock.''' 10 | expansion = 1 11 | 12 | def __init__(self, in_planes, planes, stride=1): 13 | super(PreActBlock, self).__init__() 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 18 | 19 | if stride != 1 or in_planes != self.expansion*planes: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 22 | ) 23 | 24 | def forward(self, x): 25 | out = F.relu(self.bn1(x)) 26 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 27 | out = self.conv1(out) 28 | out = self.conv2(F.relu(self.bn2(out))) 29 | out += shortcut 30 | return out 31 | 32 | 33 | class PreActBottleneck(nn.Module): 34 | '''Pre-activation version of the original Bottleneck module.''' 35 | expansion = 4 36 | 37 | def __init__(self, in_planes, planes, stride=1): 38 | super(PreActBottleneck, self).__init__() 39 | self.bn1 = nn.BatchNorm2d(in_planes) 40 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 41 | self.bn2 = nn.BatchNorm2d(planes) 42 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 43 | self.bn3 = nn.BatchNorm2d(planes) 44 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 45 | 46 | if stride != 1 or in_planes != self.expansion*planes: 47 | self.shortcut = nn.Sequential( 48 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 49 | ) 50 | 51 | def forward(self, x): 52 | out = F.relu(self.bn1(x)) 53 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 54 | out = self.conv1(out) 55 | out = self.conv2(F.relu(self.bn2(out))) 56 | out = self.conv3(F.relu(self.bn3(out))) 57 | out += shortcut 58 | return out 59 | 60 | 61 | class PreActResNet(nn.Module): 62 | def __init__(self, block, num_blocks, num_classes=10): 63 | super(PreActResNet, self).__init__() 64 | self.in_planes = 64 65 | 66 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 67 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 68 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 69 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 70 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 71 | self.linear = nn.Linear(512*block.expansion, num_classes) 72 | 73 | def _make_layer(self, block, planes, num_blocks, stride): 74 | strides = [stride] + [1]*(num_blocks-1) 75 | layers = [] 76 | for stride in strides: 77 | layers.append(block(self.in_planes, planes, stride)) 78 | self.in_planes = planes * block.expansion 79 | return nn.Sequential(*layers) 80 | 81 | def forward(self, x): 82 | out = self.conv1(x) 83 | out = self.layer1(out) 84 | out = self.layer2(out) 85 | out = self.layer3(out) 86 | out = self.layer4(out) 87 | out = F.avg_pool2d(out, 4) 88 | out = out.view(out.size(0), -1) 89 | out = self.linear(out) 90 | return out 91 | 92 | 93 | def PreActResNet18(): 94 | return PreActResNet(PreActBlock, [2,2,2,2]) 95 | 96 | def PreActResNet34(): 97 | return PreActResNet(PreActBlock, [3,4,6,3]) 98 | 99 | def PreActResNet50(): 100 | return PreActResNet(PreActBottleneck, [3,4,6,3]) 101 
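# PreActResNet18/34 are built from PreActBlock (two 3x3 convs, expansion 1);
# the 50/101/152 variants use PreActBottleneck (1x1-3x3-1x1 convs, expansion 4).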
| 102 | def PreActResNet101(): 103 | return PreActResNet(PreActBottleneck, [3,4,23,3]) 104 | 105 | def PreActResNet152(): 106 | return PreActResNet(PreActBottleneck, [3,8,36,3]) 107 | 108 | 109 | def test(): 110 | net = PreActResNet18() 111 | y = net(Variable(torch.randn(1,3,32,32))) 112 | print(y.size()) 113 | 114 | # test() 115 | -------------------------------------------------------------------------------- /pre_train/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd import Variable 6 | 7 | 8 | class BasicBlock(nn.Module): 9 | expansion = 1 10 | 11 | def __init__(self, in_planes, planes, stride=1): 12 | super(BasicBlock, self).__init__() 13 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(planes) 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(planes) 17 | 18 | self.shortcut = nn.Sequential() 19 | if stride != 1 or in_planes != self.expansion*planes: 20 | self.shortcut = nn.Sequential( 21 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 22 | nn.BatchNorm2d(self.expansion*planes) 23 | ) 24 | 25 | def forward(self, x): 26 | out = F.relu(self.bn1(self.conv1(x))) 27 | out = self.bn2(self.conv2(out)) 28 | out += self.shortcut(x) 29 | out = F.relu(out) 30 | return out 31 | 32 | 33 | class Bottleneck(nn.Module): 34 | expansion = 4 35 | 36 | def __init__(self, in_planes, planes, stride=1): 37 | super(Bottleneck, self).__init__() 38 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm2d(planes) 40 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 41 | self.bn2 = nn.BatchNorm2d(planes) 42 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 43 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 44 | 45 | self.shortcut = nn.Sequential() 46 | if stride != 1 or in_planes != self.expansion*planes: 47 | self.shortcut = nn.Sequential( 48 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 49 | nn.BatchNorm2d(self.expansion*planes) 50 | ) 51 | 52 | def forward(self, x): 53 | out = F.relu(self.bn1(self.conv1(x))) 54 | out = F.relu(self.bn2(self.conv2(out))) 55 | out = self.bn3(self.conv3(out)) 56 | out += self.shortcut(x) 57 | out = F.relu(out) 58 | return out 59 | 60 | 61 | class ResNet(nn.Module): 62 | def __init__(self, block, num_blocks, num_classes=10): 63 | super(ResNet, self).__init__() 64 | self.in_planes = 64 65 | 66 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 67 | self.bn1 = nn.BatchNorm2d(64) 68 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 69 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 70 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 71 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 72 | self.linear = nn.Linear(512*block.expansion, num_classes) 73 | 74 | def _make_layer(self, block, planes, num_blocks, stride): 75 | strides = [stride] + [1]*(num_blocks-1) 76 | layers = [] 77 | for stride in strides: 78 | layers.append(block(self.in_planes, planes, stride)) 79 | self.in_planes = planes * block.expansion 80 | return nn.Sequential(*layers) 81 | 82 | def forward(self, 
x): 83 | out = F.relu(self.bn1(self.conv1(x))) 84 | out = self.layer1(out) 85 | out = self.layer2(out) 86 | out = self.layer3(out) 87 | out = self.layer4(out) 88 | out = F.avg_pool2d(out, 4) 89 | out = out.view(out.size(0), -1) 90 | out = self.linear(out) 91 | return out 92 | 93 | 94 | def ResNet18(num_classes=10): 95 | return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes) 96 | 97 | def ResNet34(): 98 | return ResNet(BasicBlock, [3,4,6,3]) 99 | 100 | def ResNet50(): 101 | return ResNet(Bottleneck, [3,4,6,3]) 102 | 103 | def ResNet101(): 104 | return ResNet(Bottleneck, [3,4,23,3]) 105 | 106 | def ResNet152(): 107 | return ResNet(Bottleneck, [3,8,36,3]) 108 | 109 | 110 | def test(): 111 | net = ResNet18() 112 | y = net(Variable(torch.randn(1,3,32,32))) 113 | print(y.size()) 114 | print(net) 115 | # test() -------------------------------------------------------------------------------- /pre_train/models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | 35 | class NetworkBlock(nn.Module): 36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 37 | super(NetworkBlock, self).__init__() 38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 39 | 40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 41 | layers = [] 42 | for i in range(int(nb_layers)): 43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | return self.layer(x) 48 | 49 | 50 | class Wide_ResNet(nn.Module): 51 | def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0): 52 | super(Wide_ResNet, self).__init__() 53 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 54 | assert ((depth - 4) % 6 == 0) 55 | n = (depth - 4) / 6 56 | block = BasicBlock 57 | # 1st conv before any network block 58 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 59 | padding=1, bias=False) 60 | # 1st block 61 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 62 | # 1st sub-block 63 | self.sub_block1 = 
NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 64 | # 2nd block 65 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 66 | # 3rd block 67 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 68 | # global average pooling and classifier 69 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.fc = nn.Linear(nChannels[3], num_classes) 72 | self.nChannels = nChannels[3] 73 | 74 | for m in self.modules(): 75 | if isinstance(m, nn.Conv2d): 76 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 77 | m.weight.data.normal_(0, math.sqrt(2. / n)) 78 | elif isinstance(m, nn.BatchNorm2d): 79 | m.weight.data.fill_(1) 80 | m.bias.data.zero_() 81 | elif isinstance(m, nn.Linear): 82 | m.bias.data.zero_() 83 | 84 | def forward(self, x): 85 | out = self.conv1(x) 86 | out = self.block1(out) 87 | out = self.block2(out) 88 | out = self.block3(out) 89 | out = self.relu(self.bn1(out)) 90 | out = F.avg_pool2d(out, 8) 91 | out = out.view(-1, self.nChannels) 92 | return self.fc(out) 93 | def test(): 94 | net = Wide_ResNet() 95 | y = net(Variable(torch.randn(1, 3, 32, 32))) 96 | #print(y.size()) 97 | print(net) 98 | # test() -------------------------------------------------------------------------------- /pre_train/models/wideresnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | def forward(self, x): 23 | if not self.equalInOut: 24 | x = self.relu1(self.bn1(x)) 25 | else: 26 | out = self.relu1(self.bn1(x)) 27 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 28 | if self.droprate > 0: 29 | out = F.dropout(out, p=self.droprate, training=self.training) 30 | out = self.conv2(out) 31 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 32 | 33 | class NetworkBlock(nn.Module): 34 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 35 | super(NetworkBlock, self).__init__() 36 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 37 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 38 | layers = [] 39 | for i in range(int(nb_layers)): 40 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 41 | return nn.Sequential(*layers) 42 | def forward(self, x): 43 | return self.layer(x) 44 | 45 | class WideResNet(nn.Module): 46 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 47 | super(WideResNet, self).__init__() 48 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 49 | assert((depth - 4) % 6 == 0) 50 | 
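# A depth-d WRN places n = (d - 4) / 6 BasicBlocks in each of the three groups
# (two 3x3 convs per block), hence the divisibility check above.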
n = (depth - 4) / 6 51 | block = BasicBlock 52 | # 1st conv before any network block 53 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 54 | padding=1, bias=False) 55 | # 1st block 56 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 57 | # 2nd block 58 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 59 | # 3rd block 60 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 61 | # global average pooling and classifier 62 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.fc = nn.Linear(nChannels[3], num_classes) 65 | self.nChannels = nChannels[3] 66 | 67 | for m in self.modules(): 68 | if isinstance(m, nn.Conv2d): 69 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 70 | elif isinstance(m, nn.BatchNorm2d): 71 | m.weight.data.fill_(1) 72 | m.bias.data.zero_() 73 | elif isinstance(m, nn.Linear): 74 | m.bias.data.zero_() 75 | def forward(self, x): 76 | out = self.conv1(x) 77 | out = self.block1(out) 78 | out = self.block2(out) 79 | out = self.block3(out) 80 | out = self.relu(self.bn1(out)) 81 | out = F.avg_pool2d(out, 8) 82 | out = out.view(-1, self.nChannels) 83 | return self.fc(out) 84 | -------------------------------------------------------------------------------- /pre_train/models/wrn_madry.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | 35 | class NetworkBlock(nn.Module): 36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 37 | super(NetworkBlock, self).__init__() 38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 39 | 40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 41 | layers = [] 42 | for i in range(int(nb_layers)): 43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | return self.layer(x) 48 | 49 | 50 | class Wide_ResNet_Madry(nn.Module): 51 | def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0): 52 | super(Wide_ResNet_Madry, 
self).__init__() 53 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 54 | assert ((depth - 2) % 6 == 0) 55 | n = (depth - 2) / 6 56 | block = BasicBlock 57 | # 1st conv before any network block 58 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 59 | padding=1, bias=False) 60 | # 1st block 61 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 62 | # 1st sub-block 63 | # self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 64 | # 2nd block 65 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 66 | # 3rd block 67 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 68 | # global average pooling and classifier 69 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.fc = nn.Linear(nChannels[3], num_classes) 72 | self.nChannels = nChannels[3] 73 | 74 | for m in self.modules(): 75 | if isinstance(m, nn.Conv2d): 76 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 77 | m.weight.data.normal_(0, math.sqrt(2. / n)) 78 | elif isinstance(m, nn.BatchNorm2d): 79 | m.weight.data.fill_(1) 80 | m.bias.data.zero_() 81 | elif isinstance(m, nn.Linear): 82 | m.bias.data.zero_() 83 | 84 | def forward(self, x): 85 | out = self.conv1(x) 86 | out = self.block1(out) 87 | out = self.block2(out) 88 | out = self.block3(out) 89 | out = self.relu(self.bn1(out)) 90 | out = F.avg_pool2d(out, 8) 91 | out = out.view(-1, self.nChannels) 92 | return self.fc(out) 93 | def test(): 94 | net = Wide_ResNet_Madry() 95 | y = net(Variable(torch.randn(1, 3, 32, 32))) 96 | #print(y.size()) 97 | print(net) 98 | # test() -------------------------------------------------------------------------------- /pre_train/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .eval import * 6 | 7 | # progress bar 8 | import os, sys 9 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 10 | #from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /pre_train/utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | __all__ = ['accuracy'] 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum(0) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res -------------------------------------------------------------------------------- /pre_train/utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | from __future__ import absolute_import 3 | import matplotlib.pyplot as plt 4 | import os 5 | import sys 6 | import numpy as np 7 | 8 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 9 | 10 | def savefig(fname, dpi=None): 11 | dpi = 150 if dpi == None else dpi 12 | plt.savefig(fname, dpi=dpi) 13 | 14 | def plot_overlap(logger, names=None): 15 | names = logger.names if names == None else names 16 | numbers = logger.numbers 17 | 
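# Plot every logged column against its row index so several runs can be overlaid in one figure.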
17 |     for _, name in enumerate(names):
18 |         x = np.arange(len(numbers[name]))
19 |         plt.plot(x, np.asarray(numbers[name]))
20 |     return [logger.title + '(' + name + ')' for name in names]
21 |
22 | class Logger(object):
23 |     '''Save training process to log file with simple plot function.'''
24 |     def __init__(self, fpath, title=None, resume=False):
25 |         self.file = None
26 |         self.resume = resume
27 |         self.title = '' if title == None else title
28 |         if fpath is not None:
29 |             if resume:
30 |                 self.file = open(fpath, 'r')
31 |                 name = self.file.readline()
32 |                 self.names = name.rstrip().split('\t')
33 |                 self.numbers = {}
34 |                 for _, name in enumerate(self.names):
35 |                     self.numbers[name] = []
36 |
37 |                 for numbers in self.file:
38 |                     numbers = numbers.rstrip().split('\t')
39 |                     for i in range(0, len(numbers)):
40 |                         self.numbers[self.names[i]].append(numbers[i])
41 |                 self.file.close()
42 |                 self.file = open(fpath, 'a')
43 |             else:
44 |                 self.file = open(fpath, 'w')
45 |
46 |     def set_names(self, names):
47 |         if self.resume:
48 |             pass
49 |         # initialize numbers as empty list
50 |         self.numbers = {}
51 |         self.names = names
52 |         for _, name in enumerate(self.names):
53 |             self.file.write(name)
54 |             self.file.write('\t')
55 |             self.numbers[name] = []
56 |         self.file.write('\n')
57 |         self.file.flush()
58 |
59 |
60 |     def append(self, numbers):
61 |         assert len(self.names) == len(numbers), 'Numbers do not match names'
62 |         for index, num in enumerate(numbers):
63 |             self.file.write("{0:.6f}".format(num))
64 |             self.file.write('\t')
65 |             self.numbers[self.names[index]].append(num)
66 |         self.file.write('\n')
67 |         self.file.flush()
68 |
69 |     def plot(self, names=None):
70 |         names = self.names if names == None else names
71 |         numbers = self.numbers
72 |         for _, name in enumerate(names):
73 |             x = np.arange(len(numbers[name]))
74 |             plt.plot(x, np.asarray(numbers[name]))
75 |         plt.legend([self.title + '(' + name + ')' for name in names])
76 |         plt.grid(True)
77 |
78 |     def close(self):
79 |         if self.file is not None:
80 |             self.file.close()
81 |
82 | class LoggerMonitor(object):
83 |     '''Load and visualize multiple logs.'''
84 |     def __init__(self, paths):
85 |         '''paths is a dictionary with {name: filepath} pairs'''
86 |         self.loggers = []
87 |         for title, path in paths.items():
88 |             logger = Logger(path, title=title, resume=True)
89 |             self.loggers.append(logger)
90 |
91 |     def plot(self, names=None):
92 |         plt.figure()
93 |         plt.subplot(121)
94 |         legend_text = []
95 |         for logger in self.loggers:
96 |             legend_text += plot_overlap(logger, names)
97 |         plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
98 |         plt.grid(True)
99 |
100 | if __name__ == '__main__':
101 |     # # Example
102 |     # logger = Logger('test.txt')
103 |     # logger.set_names(['Train loss', 'Valid loss','Test loss'])
104 |
105 |     # length = 100
106 |     # t = np.arange(length)
107 |     # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
108 |     # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
109 |     # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
110 |
111 |     # for i in range(0, length):
112 |     #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
113 |     # logger.plot()
114 |
115 |     # Example: logger monitor
116 |     paths = {
117 |         'resadvnet20':'~/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
118 |         'resadvnet32':'~/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
119 |         'resadvnet44':'~/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
120 |     }
121 |
122 |     field = ['Valid Acc.']
123 |
124 |     monitor = LoggerMonitor(paths)
125 |     monitor.plot(names=field)
126 |     savefig('test.eps')
--------------------------------------------------------------------------------
/pre_train/utils/misc.py:
--------------------------------------------------------------------------------
1 | import errno
2 | import os
3 | import sys
4 | import time
5 | import math
6 | import torch  # needed by get_mean_and_std below
7 | import torch.nn as nn
8 | import torch.nn.init as init
9 | from torch.autograd import Variable
10 |
11 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']
12 |
13 |
14 | def get_mean_and_std(dataset):
15 |     '''Compute the mean and std value of dataset.'''
16 |     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
17 |
18 |     mean = torch.zeros(3)
19 |     std = torch.zeros(3)
20 |     print('==> Computing mean and std..')
21 |     for inputs, targets in dataloader:
22 |         for i in range(3):
23 |             mean[i] += inputs[:,i,:,:].mean()
24 |             std[i] += inputs[:,i,:,:].std()
25 |     mean.div_(len(dataset))
26 |     std.div_(len(dataset))
27 |     return mean, std
28 |
29 | def init_params(net):
30 |     '''Init layer parameters.'''
31 |     for m in net.modules():
32 |         if isinstance(m, nn.Conv2d):
33 |             init.kaiming_normal_(m.weight, mode='fan_out')
34 |             if m.bias is not None:
35 |                 init.constant_(m.bias, 0)
36 |         elif isinstance(m, nn.BatchNorm2d):
37 |             init.constant_(m.weight, 1)
38 |             init.constant_(m.bias, 0)
39 |         elif isinstance(m, nn.Linear):
40 |             init.normal_(m.weight, std=1e-3)
41 |             if m.bias is not None:
42 |                 init.constant_(m.bias, 0)
43 |
44 | def mkdir_p(path):
45 |     '''make dir if not exist'''
46 |     try:
47 |         os.makedirs(path)
48 |     except OSError as exc:  # Python >2.5
49 |         if exc.errno == errno.EEXIST and os.path.isdir(path):
50 |             pass
51 |         else:
52 |             raise
53 |
54 | class AverageMeter(object):  # computes and stores the average and current value
55 |     def __init__(self):
56 |         self.reset()
57 |
58 |     def reset(self):
59 |         self.val = 0
60 |         self.avg = 0
61 |         self.sum = 0
62 |         self.count = 0
63 |
64 |     def update(self, val, n=1):
65 |         self.val = val
66 |         self.sum += val * n
67 |         self.count += n
68 |         self.avg = self.sum / self.count
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Useful utils
2 | """
3 | from .misc import *
4 | from .logger import *
5 | from .eval import *
6 |
7 | # progress bar
8 | import os, sys
9 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
10 | #from progress.bar import Bar as Bar
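Both copies of the utils package export the same helpers (accuracy, AverageMeter, Logger), which the training scripts pull in via `from utils import ...`. As a quick orientation, here is a minimal sketch of how accuracy and AverageMeter combine into a clean-accuracy evaluation loop; `model`, `testloader`, and `device` are placeholder names for objects created elsewhere (for instance by the data and model setup in AKD2.py), not definitions provided by these utility files.

import torch
from utils import AverageMeter, accuracy

def evaluate_clean(model, testloader, device='cuda'):
    # Track the running top-1 accuracy over the whole test set.
    top1 = AverageMeter()
    model.eval()
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            prec1 = accuracy(outputs, targets, topk=(1,))[0]  # percentage correct in this batch
            top1.update(prec1.item(), inputs.size(0))         # weight the batch by its size
    return top1.avg

Roughly the same pattern, with adversarial examples (e.g., produced by AttackPGD) substituted for the clean inputs, would yield the robust-accuracy numbers reported by the training scripts.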
--------------------------------------------------------------------------------
/utils/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | __all__ = ['accuracy']
4 |
5 | def accuracy(output, target, topk=(1,)):
6 |     """Computes the precision@k for the specified values of k"""
7 |     maxk = max(topk)
8 |     batch_size = target.size(0)
9 |
10 |     _, pred = output.topk(maxk, 1, True, True)
11 |     pred = pred.t()
12 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
13 |
14 |     res = []
15 |     for k in topk:
16 |         correct_k = correct[:k].view(-1).float().sum(0)
17 |         res.append(correct_k.mul_(100.0 / batch_size))
18 |     return res
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # A simple torch style logger
2 | from __future__ import absolute_import
3 | import os
4 | import sys
5 | import numpy as np
6 | import matplotlib.pyplot as plt  # required by savefig below
7 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
8 |
9 | def savefig(fname, dpi=None):
10 |     dpi = 150 if dpi == None else dpi
11 |     plt.savefig(fname, dpi=dpi)
12 |
13 | class Logger(object):
14 |     '''Save training process to log file with simple plot function.'''
15 |     def __init__(self, fpath, title=None, resume=False):
16 |         self.file = None
17 |         self.resume = resume
18 |         self.title = '' if title == None else title
19 |         if fpath is not None:
20 |             if resume:
21 |                 self.file = open(fpath, 'r')
22 |                 name = self.file.readline()
23 |                 self.names = name.rstrip().split('\t')
24 |                 self.numbers = {}
25 |                 for _, name in enumerate(self.names):
26 |                     self.numbers[name] = []
27 |
28 |                 for numbers in self.file:
29 |                     numbers = numbers.rstrip().split('\t')
30 |                     for i in range(0, len(numbers)):
31 |                         self.numbers[self.names[i]].append(numbers[i])
32 |                 self.file.close()
33 |                 self.file = open(fpath, 'a')
34 |             else:
35 |                 self.file = open(fpath, 'w')
36 |
37 |     def set_names(self, names):
38 |         if self.resume:
39 |             pass
40 |         # initialize numbers as empty list
41 |         self.numbers = {}
42 |         self.names = names
43 |         for _, name in enumerate(self.names):
44 |             self.file.write(name)
45 |             self.file.write('\t')
46 |             self.numbers[name] = []
47 |         self.file.write('\n')
48 |         self.file.flush()
49 |
50 |
51 |     def append(self, numbers):
52 |         assert len(self.names) == len(numbers), 'Numbers do not match names'
53 |         for index, num in enumerate(numbers):
54 |             self.file.write("{0:.6f}".format(num))
55 |             self.file.write('\t')
56 |             self.numbers[self.names[index]].append(num)
57 |         self.file.write('\n')
58 |         self.file.flush()
59 |
60 |     def close(self):
61 |         if self.file is not None:
62 |             self.file.close()
63 |
64 | class LoggerMonitor(object):
65 |     '''Load and visualize multiple logs.'''
66 |     def __init__(self, paths):
67 |         '''paths is a dictionary with {name: filepath} pairs'''
68 |         self.loggers = []
69 |         for title, path in paths.items():
70 |             logger = Logger(path, title=title, resume=True)
71 |             self.loggers.append(logger)
72 |
73 |
74 | if __name__ == '__main__':
75 |     # # Example
76 |     # logger = Logger('test.txt')
77 |     # logger.set_names(['Train loss', 'Valid loss','Test loss'])
78 |
79 |     # length = 100
80 |     # t = np.arange(length)
81 |     # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
82 |     # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
83 |     # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
84 |
85 |     # for i in range(0, length):
86 |     #     logger.append([train_loss[i], valid_loss[i], test_loss[i]])
87 |     # logger.plot()
88 |
89 |     # Example: logger monitor
90 |     paths = {
91 |         'resadvnet20':'~/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
92 |         'resadvnet32':'~/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
93 |         'resadvnet44':'~/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
94 |     }
95 |
96 |     field = ['Valid Acc.']
97 |
98 |     monitor = LoggerMonitor(paths)
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import errno
2 | import os
3 | import sys
4 | import time
5 | import math
6 | import torch  # needed by get_mean_and_std below
7 | import torch.nn as nn
8 | import torch.nn.init as init
9 | from torch.autograd import Variable
10 |
11 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']
12 |
13 |
14 | def get_mean_and_std(dataset):
15 |     '''Compute the mean and std value of dataset.'''
16 |     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
17 |
18 |     mean = torch.zeros(3)
19 |     std = torch.zeros(3)
20 |     print('==> Computing mean and std..')
21 |     for inputs, targets in dataloader:
22 |         for i in range(3):
23 |             mean[i] += inputs[:,i,:,:].mean()
24 |             std[i] += inputs[:,i,:,:].std()
25 |     mean.div_(len(dataset))
26 |     std.div_(len(dataset))
27 |     return mean, std
28 |
29 | def init_params(net):
30 |     '''Init layer parameters.'''
31 |     for m in net.modules():
32 |         if isinstance(m, nn.Conv2d):
33 |             init.kaiming_normal_(m.weight, mode='fan_out')
34 |             if m.bias is not None:
35 |                 init.constant_(m.bias, 0)
36 |         elif isinstance(m, nn.BatchNorm2d):
37 |             init.constant_(m.weight, 1)
38 |             init.constant_(m.bias, 0)
39 |         elif isinstance(m, nn.Linear):
40 |             init.normal_(m.weight, std=1e-3)
41 |             if m.bias is not None:
42 |                 init.constant_(m.bias, 0)
43 |
44 | def mkdir_p(path):
45 |     '''make dir if not exist'''
46 |     try:
47 |         os.makedirs(path)
48 |     except OSError as exc:  # Python >2.5
49 |         if exc.errno == errno.EEXIST and os.path.isdir(path):
50 |             pass
51 |         else:
52 |             raise
53 |
54 | class AverageMeter(object):  # computes and stores the average and current value
55 |     def __init__(self):
56 |         self.reset()
57 |
58 |     def reset(self):
59 |         self.val = 0
60 |         self.avg = 0
61 |         self.sum = 0
62 |         self.count = 0
63 |
64 |     def update(self, val, n=1):
65 |         self.val = val
66 |         self.sum += val * n
67 |         self.count += n
68 |         self.avg = self.sum / self.count
--------------------------------------------------------------------------------
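To close the listing, here is a short sketch of how the Logger above is typically driven across epochs; the file name, column names, and metric values are illustrative placeholders rather than output produced by any script in this repository.

import numpy as np
from utils import Logger

logger = Logger('example_log.txt', title='example-run')
logger.set_names(['Learning Rate', 'Train Loss', 'Test Loss', 'Train Acc.', 'Test Acc.'])
for epoch in range(3):
    # Dummy statistics; a real training script would pass the values it measured this epoch.
    train_loss = float(np.exp(-epoch))
    test_loss = float(np.exp(-epoch)) + 0.1
    train_acc, test_acc = 85.0 + epoch, 75.0 + epoch
    logger.append([0.1, train_loss, test_loss, train_acc, test_acc])
logger.close()

Each call to append writes one tab-separated row under the header produced by set_names, which is the format that Logger(resume=True) and LoggerMonitor parse when reloading an existing log file.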