├── FAT.py ├── FAT_for_MART.py ├── FAT_for_TRADES.py ├── README.md ├── attack_generator.py ├── attack_test.py ├── attack_test.sh ├── earlystop.py ├── image ├── adv_train.png ├── cross_over_mixture_problem.png ├── early_stopped_pgd.png ├── min-min_vs_minmax.png ├── min_min_formulation.png └── minimax_formulation.png ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── densenet.cpython-36.pyc │ ├── dpn.cpython-36.pyc │ ├── googlenet.cpython-36.pyc │ ├── lenet.cpython-36.pyc │ ├── mobilenet.cpython-36.pyc │ ├── preact_resnet.cpython-36.pyc │ ├── resnet.cpython-36.pyc │ ├── resnext.cpython-36.pyc │ ├── senet.cpython-36.pyc │ ├── shufflenet.cpython-36.pyc │ ├── small_cnn.cpython-36.pyc │ ├── vgg.cpython-36.pyc │ ├── wide_resnet.cpython-36.pyc │ └── wrn_madry.cpython-36.pyc ├── densenet.py ├── dpn.py ├── googlenet.py ├── lenet.py ├── mobilenet.py ├── preact_resnet.py ├── resnet.py ├── resnext.py ├── senet.py ├── shufflenet.py ├── small_cnn.py ├── vgg.py ├── wide_resnet.py └── wrn_madry.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-35.pyc ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── eval.cpython-35.pyc ├── eval.cpython-36.pyc ├── eval.cpython-37.pyc ├── logger.cpython-35.pyc ├── logger.cpython-36.pyc ├── logger.cpython-37.pyc ├── misc.cpython-35.pyc ├── misc.cpython-36.pyc └── misc.cpython-37.pyc └── logger.py /FAT.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torchvision 4 | import torch.optim as optim 5 | from torchvision import transforms 6 | import datetime 7 | from models import * 8 | from earlystop import earlystop 9 | import numpy as np 10 | from utils import Logger 11 | import attack_generator as attack 12 | 13 | parser = argparse.ArgumentParser(description='PyTorch Friendly Adversarial Training') 14 | parser.add_argument('--epochs', type=int, default=120, metavar='N', help='number of epochs to train') 15 | parser.add_argument('--weight_decay', '--wd', default=2e-4, type=float, metavar='W') 16 | parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate') 17 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum') 18 | parser.add_argument('--epsilon', type=float, default=0.031, help='perturbation bound') 19 | parser.add_argument('--num_steps', type=int, default=10, help='maximum perturbation step K') 20 | parser.add_argument('--step_size', type=float, default=0.007, help='step size') 21 | parser.add_argument('--seed', type=int, default=7, metavar='S', help='random seed') 22 | parser.add_argument('--net', type=str, default="WRN_madry", 23 | help="decide which network to use,choose from smallcnn,resnet18,WRN") 24 | parser.add_argument('--tau', type=int, default=0, help='step tau') 25 | parser.add_argument('--dataset', type=str, default="cifar10", help="choose from cifar10,svhn") 26 | parser.add_argument('--rand_init', type=bool, default=True, help="whether to initialize adversarial sample with random noise") 27 | parser.add_argument('--omega', type=float, default=0.001, help="random sample parameter for adv data generation") 28 | parser.add_argument('--dynamictau', type=bool, default=True, help='whether to use dynamic tau') 29 | parser.add_argument('--depth', type=int, default=32, help='WRN depth') 30 | parser.add_argument('--width_factor', type=int, default=10, help='WRN width factor') 31 | parser.add_argument('--drop_rate', type=float, default=0.0, help='WRN drop rate') 32 | 
parser.add_argument('--out_dir', type=str, default='./FAT_results', help='dir of output') 33 | parser.add_argument('--resume', type=str, default='', help='whether to resume training, default: None') 34 | 35 | args = parser.parse_args() 36 | 37 | # training settings 38 | torch.manual_seed(args.seed) 39 | np.random.seed(args.seed) 40 | torch.cuda.manual_seed_all(args.seed) 41 | torch.backends.cudnn.deterministic = False 42 | torch.backends.cudnn.benchmark = True 43 | 44 | out_dir = args.out_dir 45 | if not os.path.exists(out_dir): 46 | os.makedirs(out_dir) 47 | 48 | def train(model, train_loader, optimizer, tau): 49 | starttime = datetime.datetime.now() 50 | loss_sum = 0 51 | bp_count = 0 52 | for batch_idx, (data, target) in enumerate(train_loader): 53 | data, target = data.cuda(), target.cuda() 54 | 55 | # Get friendly adversarial training data via early-stopped PGD 56 | output_adv, output_target, output_natural, count = earlystop(model, data, target, step_size=args.step_size, 57 | epsilon=args.epsilon, perturb_steps=args.num_steps, tau=tau, 58 | randominit_type="uniform_randominit", loss_fn='cent', rand_init=args.rand_init, omega=args.omega) 59 | bp_count += count 60 | model.train() 61 | optimizer.zero_grad() 62 | output = model(output_adv) 63 | 64 | # calculate standard adversarial training loss 65 | loss = nn.CrossEntropyLoss(reduction='mean')(output, output_target) 66 | 67 | loss_sum += loss.item() 68 | loss.backward() 69 | optimizer.step() 70 | 71 | bp_count_avg = bp_count / len(train_loader.dataset) 72 | endtime = datetime.datetime.now() 73 | time = (endtime - starttime).seconds 74 | 75 | return time, loss_sum, bp_count_avg 76 | 77 | def adjust_tau(epoch, dynamictau): 78 | tau = args.tau 79 | if dynamictau: 80 | if epoch <= 50: 81 | tau = 0 82 | elif epoch <= 90: 83 | tau = 1 84 | else: 85 | tau = 2 86 | return tau 87 | 88 | 89 | def adjust_learning_rate(optimizer, epoch): 90 | """decrease the learning rate""" 91 | lr = args.lr 92 | if epoch >= 60: 93 | lr = args.lr * 0.1 94 | if epoch >= 90: 95 | lr = args.lr * 0.01 96 | if epoch >= 110: 97 | lr = args.lr * 0.005 98 | for param_group in optimizer.param_groups: 99 | param_group['lr'] = lr 100 | 101 | 102 | def save_checkpoint(state, checkpoint=out_dir, filename='checkpoint.pth.tar'): 103 | filepath = os.path.join(checkpoint, filename) 104 | torch.save(state, filepath) 105 | 106 | # setup data loader 107 | transform_train = transforms.Compose([ 108 | transforms.RandomCrop(32, padding=4), 109 | transforms.RandomHorizontalFlip(), 110 | transforms.ToTensor(), 111 | ]) 112 | transform_test = transforms.Compose([ 113 | transforms.ToTensor(), 114 | ]) 115 | 116 | print('==> Load Test Data') 117 | if args.dataset == "cifar10": 118 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 119 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 120 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 121 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 122 | if args.dataset == "svhn": 123 | trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform_train) 124 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 125 | testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform_test) 126 | test_loader = 
torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 127 | 128 | print('==> Load Model') 129 | if args.net == "smallcnn": 130 | model = SmallCNN().cuda() 131 | net = "smallcnn" 132 | if args.net == "resnet18": 133 | model = ResNet18().cuda() 134 | net = "resnet18" 135 | if args.net == "WRN": 136 | # e.g., WRN-34-10 137 | model = Wide_ResNet(depth=args.depth, num_classes=10, widen_factor=args.width_factor, dropRate=args.drop_rate).cuda() 138 | net = "WRN{}-{}-dropout{}".format(args.depth, args.width_factor, args.drop_rate) 139 | if args.net == 'WRN_madry': 140 | # e.g., WRN-32-10 141 | model = Wide_ResNet_Madry(depth=args.depth, num_classes=10, widen_factor=args.width_factor, dropRate=args.drop_rate).cuda() 142 | net = "WRN_madry{}-{}-dropout{}".format(args.depth, args.width_factor, args.drop_rate) 143 | print(net) 144 | 145 | model = torch.nn.DataParallel(model) 146 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 147 | 148 | start_epoch = 0 149 | # Resume 150 | title = 'FAT train' 151 | if args.resume: 152 | # resume directly point to checkpoint.pth.tar e.g., --resume='./out-dir/checkpoint.pth.tar' 153 | print('==> Friendly Adversarial Training Resuming from checkpoint ..') 154 | print(args.resume) 155 | assert os.path.isfile(args.resume) 156 | out_dir = os.path.dirname(args.resume) 157 | checkpoint = torch.load(args.resume) 158 | start_epoch = checkpoint['epoch'] 159 | model.load_state_dict(checkpoint['state_dict']) 160 | optimizer.load_state_dict(checkpoint['optimizer']) 161 | logger_test = Logger(os.path.join(out_dir, 'log_results.txt'), title=title, resume=True) 162 | else: 163 | print('==> Friendly Adversarial Training') 164 | logger_test = Logger(os.path.join(args.out_dir, 'log_results.txt'), title=title) 165 | logger_test.set_names(['Epoch', 'Natural Test Acc', 'FGSM Acc', 'PGD20 Acc', 'CW Acc']) 166 | 167 | test_nat_acc = 0 168 | fgsm_acc = 0 169 | test_pgd20_acc = 0 170 | cw_acc = 0 171 | best_epoch = 0 172 | for epoch in range(start_epoch, args.epochs): 173 | adjust_learning_rate(optimizer, epoch + 1) 174 | train_time, train_loss, bp_count_avg = train(model, train_loader, optimizer, adjust_tau(epoch + 1, args.dynamictau)) 175 | 176 | ## Evalutions the same as DAT. 
177 | loss, test_nat_acc = attack.eval_clean(model, test_loader) 178 | loss, fgsm_acc = attack.eval_robust(model, test_loader, perturb_steps=1, epsilon=0.031, step_size=0.031,loss_fn="cent", category="Madry",rand_init=True) 179 | loss, test_pgd20_acc = attack.eval_robust(model, test_loader, perturb_steps=20, epsilon=0.031, step_size=0.031 / 4,loss_fn="cent", category="Madry", rand_init=True) 180 | loss, cw_acc = attack.eval_robust(model, test_loader, perturb_steps=30, epsilon=0.031, step_size=0.031 / 4,loss_fn="cw", category="Madry", rand_init=True) 181 | 182 | print( 183 | 'Epoch: [%d | %d] | Train Time: %.2f s | BP Average: %.2f | Natural Test Acc %.2f | FGSM Test Acc %.2f | PGD20 Test Acc %.2f | CW Test Acc %.2f |\n' % ( 184 | epoch + 1, 185 | args.epochs, 186 | train_time, 187 | bp_count_avg, 188 | test_nat_acc, 189 | fgsm_acc, 190 | test_pgd20_acc, 191 | cw_acc) 192 | ) 193 | 194 | logger_test.append([epoch + 1, test_nat_acc, fgsm_acc, test_pgd20_acc, cw_acc]) 195 | 196 | save_checkpoint({ 197 | 'epoch': epoch + 1, 198 | 'state_dict': model.state_dict(), 199 | 'bp_avg': bp_count_avg, 200 | 'test_nat_acc': test_nat_acc, 201 | 'test_pgd20_acc': test_pgd20_acc, 202 | 'optimizer': optimizer.state_dict(), 203 | }) 204 | -------------------------------------------------------------------------------- /FAT_for_MART.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torchvision 4 | import torch.optim as optim 5 | from torchvision import transforms 6 | import datetime 7 | from models import * 8 | from earlystop import earlystop 9 | import numpy as np 10 | import attack_generator as attack 11 | from utils import Logger 12 | 13 | parser = argparse.ArgumentParser(description='PyTorch Friendly Adversarial Training for MART') 14 | parser.add_argument('--epochs', type=int, default=90, metavar='N', help='number of epochs to train') 15 | parser.add_argument('--weight_decay', '--wd', default=2e-4, type=float, metavar='W') 16 | parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate') 17 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum') 18 | parser.add_argument('--epsilon', type=float, default=0.031, help='perturbation bound') 19 | parser.add_argument('--num_steps', type=int, default=10, help='maximum perturbation step K') 20 | parser.add_argument('--step_size', type=float, default=0.007, help='step size') 21 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed') 22 | parser.add_argument('--net', type=str, default="WRN",help="decide which network to use,choose from smallcnn,resnet18,WRN") 23 | parser.add_argument('--tau', type=int, default=0, help='step tau') 24 | parser.add_argument('--beta',type=float,default=6.0,help='regularization parameter') 25 | parser.add_argument('--dataset', type=str, default="cifar10", help="choose from cifar10,svhn") 26 | parser.add_argument('--rand_init', type=bool, default=True, help="whether to initialize adversarial sample with random noise") 27 | parser.add_argument('--omega', type=float, default=0.0, help="random sample parameter") 28 | parser.add_argument('--dynamictau', type=bool, default=True, help='whether to use dynamic tau') 29 | parser.add_argument('--depth', type=int, default=34, help='WRN depth') 30 | parser.add_argument('--width_factor', type=int, default=10, help='WRN width factor') 31 | parser.add_argument('--drop_rate', type=float, default=0.0, help='WRN drop rate') 32 | 
parser.add_argument('--out_dir',type=str,default='./FAT_for_MART_results',help='dir of output') 33 | parser.add_argument('--resume', type=str, default='', help='whether to resume training, default: None') 34 | 35 | args = parser.parse_args() 36 | 37 | # settings 38 | torch.manual_seed(args.seed) 39 | np.random.seed(args.seed) 40 | torch.cuda.manual_seed_all(args.seed) 41 | torch.backends.cudnn.deterministic = False 42 | torch.backends.cudnn.benchmark = True 43 | 44 | out_dir = args.out_dir 45 | if not os.path.exists(out_dir): 46 | os.makedirs(out_dir) 47 | 48 | def MART_loss(adv_logits, natural_logits, target, beta): 49 | # Based on the repo MART https://github.com/YisenWang/MART 50 | kl = nn.KLDivLoss(reduction='none') 51 | batch_size = len(target) 52 | adv_probs = F.softmax(adv_logits, dim=1) 53 | tmp1 = torch.argsort(adv_probs, dim=1)[:, -2:] 54 | new_y = torch.where(tmp1[:, -1] == target, tmp1[:, -2], tmp1[:, -1]) 55 | loss_adv = F.cross_entropy(adv_logits, target) + F.nll_loss(torch.log(1.0001 - adv_probs + 1e-12), new_y) 56 | nat_probs = F.softmax(natural_logits, dim=1) 57 | true_probs = torch.gather(nat_probs, 1, (target.unsqueeze(1)).long()).squeeze() 58 | loss_robust = (1.0 / batch_size) * torch.sum( 59 | torch.sum(kl(torch.log(adv_probs + 1e-12), nat_probs), dim=1) * (1.0000001 - true_probs)) 60 | loss = loss_adv + float(beta) * loss_robust 61 | return loss 62 | 63 | def train(model, train_loader, optimizer, tau): 64 | starttime = datetime.datetime.now() 65 | loss_sum = 0 66 | bp_count = 0 67 | for batch_idx, (data, target) in enumerate(train_loader): 68 | data, target = data.cuda(), target.cuda() 69 | 70 | # Get friendly adversarial training data via early-stopped PGD 71 | output_adv, output_target, output_natural, count = earlystop(model, data, target, step_size=args.step_size, 72 | epsilon=args.epsilon, perturb_steps=args.num_steps, 73 | tau=tau, randominit_type="normal_distribution_randominit", loss_fn='cent', rand_init=args.rand_init, 74 | omega=args.omega) 75 | bp_count += count 76 | model.train() 77 | optimizer.zero_grad() 78 | 79 | adv_logits = model(output_adv) 80 | natural_logits = model(output_natural) 81 | 82 | # calculate MART adversarial training loss 83 | loss = MART_loss(adv_logits, natural_logits, output_target, args.beta) 84 | 85 | loss_sum += loss.item() 86 | loss.backward() 87 | optimizer.step() 88 | 89 | bp_count_avg = bp_count / len(train_loader.dataset) 90 | endtime = datetime.datetime.now() 91 | time = (endtime - starttime).seconds 92 | 93 | return time, loss_sum, bp_count_avg 94 | 95 | def adjust_tau(epoch, dynamictau): 96 | tau = args.tau 97 | if dynamictau: 98 | if epoch <= 20: 99 | tau = 0 100 | elif epoch <= 40: 101 | tau = 1 102 | elif epoch <= 60: 103 | tau = 2 104 | elif epoch <= 80: 105 | tau = 3 106 | else: 107 | tau = 4 108 | return tau 109 | 110 | def adjust_learning_rate(optimizer, epoch): 111 | """decrease the learning rate""" 112 | lr = args.lr 113 | if epoch >= 60: 114 | lr = args.lr * 0.1 115 | if epoch >= 90: 116 | lr = args.lr * 0.01 117 | for param_group in optimizer.param_groups: 118 | param_group['lr'] = lr 119 | 120 | def save_checkpoint(state, checkpoint=out_dir, filename='checkpoint.pth.tar'): 121 | filepath = os.path.join(checkpoint, filename) 122 | torch.save(state, filepath) 123 | 124 | # setup data loader 125 | transform_train = transforms.Compose([ 126 | transforms.RandomCrop(32, padding=4), 127 | transforms.RandomHorizontalFlip(), 128 | transforms.ToTensor(), 129 | ]) 130 | transform_test = transforms.Compose([ 131 | 
transforms.ToTensor(), 132 | ]) 133 | 134 | print('==> Load Test Data') 135 | if args.dataset == "cifar10": 136 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 137 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 138 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 139 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 140 | if args.dataset == "svhn": 141 | trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform_train) 142 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 143 | testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform_test) 144 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 145 | 146 | print('==> Load Model') 147 | if args.net == "smallcnn": 148 | model = SmallCNN().cuda() 149 | net = "smallcnn" 150 | if args.net == "resnet18": 151 | model = ResNet18().cuda() 152 | net = "resnet18" 153 | if args.net == "WRN": 154 | model = Wide_ResNet(depth=args.depth, num_classes=10, widen_factor=args.width_factor, dropRate=args.drop_rate).cuda() 155 | net = "WRN{}-{}-dropout{}".format(args.depth, args.width_factor, args.drop_rate) 156 | model = torch.nn.DataParallel(model) 157 | print(net) 158 | 159 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 160 | 161 | if not os.path.exists(out_dir): 162 | os.makedirs(out_dir) 163 | 164 | start_epoch = 0 165 | # Resume 166 | title = 'FAT for MART train' 167 | if args.resume: 168 | # resume directly point to checkpoint.pth.tar e.g., --resume='./out-dir/checkpoint.pth.tar' 169 | print ('==> Friendly Adversarial Training for MART Resuming from checkpoint ..') 170 | print(args.resume) 171 | assert os.path.isfile(args.resume) 172 | out_dir = os.path.dirname(args.resume) 173 | checkpoint = torch.load(args.resume) 174 | start_epoch = checkpoint['epoch'] 175 | model.load_state_dict(checkpoint['state_dict']) 176 | optimizer.load_state_dict(checkpoint['optimizer']) 177 | logger_test = Logger(os.path.join(out_dir, 'log_results.txt'), title=title, resume=True) 178 | else: 179 | print('==> Friendly Adversarial Training for MART') 180 | logger_test = Logger(os.path.join(out_dir, 'log_results.txt'), title=title) 181 | logger_test.set_names(['Epoch', 'Natural Test Acc', 'FGSM Acc', 'PGD20 Acc', 'CW Acc']) 182 | 183 | 184 | test_nat_acc = 0 185 | fgsm_acc = 0 186 | test_pgd20_acc = 0 187 | cw_acc = 0 188 | for epoch in range(start_epoch, args.epochs): 189 | adjust_learning_rate(optimizer, epoch + 1) 190 | train_time, train_loss, bp_count_avg = train(model, train_loader, optimizer, adjust_tau(epoch + 1, args.dynamictau)) 191 | 192 | ## Evalutions the same as TRADES. 
193 | loss, test_nat_acc = attack.eval_clean(model, test_loader) 194 | loss, fgsm_acc = attack.eval_robust(model, test_loader, perturb_steps=1, epsilon=0.031, step_size=0.031,loss_fn="cent", category="Madry",rand_init=True) 195 | loss, test_pgd20_acc = attack.eval_robust(model,test_loader, perturb_steps=20, epsilon=0.031, step_size=0.003,loss_fn="cent",category="Madry",rand_init=True) 196 | loss, cw_acc = attack.eval_robust(model,test_loader, perturb_steps=30, epsilon=0.031, step_size=0.003,loss_fn="cw",category="Madry",rand_init=True) 197 | 198 | print( 199 | 'Epoch: [%d | %d] | Train Time: %.2f s | BP Average: %.2f | Natural Test Acc %.2f | FGSM Test Acc %.2f | PGD20 Test Acc %.2f | CW Test Acc %.2f |\n' % ( 200 | epoch + 1, 201 | args.epochs, 202 | train_time, 203 | bp_count_avg, 204 | test_nat_acc, 205 | fgsm_acc, 206 | test_pgd20_acc, 207 | cw_acc) 208 | ) 209 | 210 | logger_test.append([epoch + 1, test_nat_acc, fgsm_acc, test_pgd20_acc, cw_acc]) 211 | 212 | save_checkpoint({ 213 | 'epoch': epoch + 1, 214 | 'state_dict': model.state_dict(), 215 | 'bp_avg': bp_count_avg, 216 | 'test_nat_acc': test_nat_acc, 217 | 'test_pgd20_acc': test_pgd20_acc, 218 | 'optimizer': optimizer.state_dict(), 219 | }) 220 | -------------------------------------------------------------------------------- /FAT_for_TRADES.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torchvision 4 | import torch.optim as optim 5 | from torchvision import transforms 6 | import datetime 7 | from models import * 8 | from earlystop import earlystop 9 | import numpy as np 10 | import attack_generator as attack 11 | from utils import Logger 12 | 13 | parser = argparse.ArgumentParser(description='PyTorch Friendly Adversarial Training for TRADES') 14 | parser.add_argument('--epochs', type=int, default=85, metavar='N', help='number of epochs to train') 15 | parser.add_argument('--weight_decay', '--wd', default=2e-4, type=float, metavar='W') 16 | parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate') 17 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum') 18 | parser.add_argument('--epsilon', type=float, default=0.031, help='perturbation bound') 19 | parser.add_argument('--num_steps', type=int, default=10, help='maximum perturbation step K') 20 | parser.add_argument('--step_size', type=float, default=0.007, help='step size') 21 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed') 22 | parser.add_argument('--net', type=str, default="WRN",help="decide which network to use,choose from smallcnn,resnet18,WRN") 23 | parser.add_argument('--tau', type=int, default=0, help='step tau') 24 | parser.add_argument('--beta',type=float,default=6.0,help='regularization parameter') 25 | parser.add_argument('--dataset', type=str, default="cifar10", help="choose from cifar10,svhn") 26 | parser.add_argument('--rand_init', type=bool, default=True, help="whether to initialize adversarial sample with random noise") 27 | parser.add_argument('--omega', type=float, default=0.0, help="random sample parameter") 28 | parser.add_argument('--dynamictau', type=bool, default=True, help='whether to use dynamic tau') 29 | parser.add_argument('--depth', type=int, default=34, help='WRN depth') 30 | parser.add_argument('--width_factor', type=int, default=10, help='WRN width factor') 31 | parser.add_argument('--drop_rate', type=float, default=0.0, help='WRN drop rate') 32 | 
parser.add_argument('--out_dir',type=str,default='./FAT_for_TRADES_results',help='dir of output') 33 | parser.add_argument('--resume', type=str, default='', help='whether to resume training, default: None') 34 | 35 | args = parser.parse_args() 36 | 37 | # settings 38 | torch.manual_seed(args.seed) 39 | np.random.seed(args.seed) 40 | torch.cuda.manual_seed_all(args.seed) 41 | torch.backends.cudnn.deterministic = False 42 | torch.backends.cudnn.benchmark = True 43 | 44 | out_dir = args.out_dir 45 | if not os.path.exists(out_dir): 46 | os.makedirs(out_dir) 47 | 48 | def TRADES_loss(adv_logits, natural_logits, target, beta): 49 | # Based on the repo TREADES: https://github.com/yaodongyu/TRADES 50 | batch_size = len(target) 51 | criterion_kl = nn.KLDivLoss(size_average=False).cuda() 52 | loss_natural = nn.CrossEntropyLoss(reduction='mean')(natural_logits, target) 53 | loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(adv_logits, dim=1), 54 | F.softmax(natural_logits, dim=1)) 55 | loss = loss_natural + beta * loss_robust 56 | return loss 57 | 58 | def train(model, train_loader, optimizer, tau): 59 | starttime = datetime.datetime.now() 60 | loss_sum = 0 61 | bp_count = 0 62 | for batch_idx, (data, target) in enumerate(train_loader): 63 | data, target = data.cuda(), target.cuda() 64 | 65 | # Get friendly adversarial training data via early-stopped PGD 66 | output_adv, output_target, output_natural, count = earlystop(model, data, target, step_size=args.step_size, 67 | epsilon=args.epsilon, perturb_steps=args.num_steps, 68 | tau=tau, randominit_type="normal_distribution_randominit", loss_fn='kl', rand_init=args.rand_init, 69 | omega=args.omega) 70 | bp_count += count 71 | model.train() 72 | optimizer.zero_grad() 73 | 74 | natural_logits = model(output_natural) 75 | adv_logits = model(output_adv) 76 | 77 | # calculate TRADES adversarial training loss 78 | loss = TRADES_loss(adv_logits,natural_logits,output_target,args.beta) 79 | 80 | loss_sum += loss.item() 81 | loss.backward() 82 | optimizer.step() 83 | 84 | bp_count_avg = bp_count / len(train_loader.dataset) 85 | endtime = datetime.datetime.now() 86 | time = (endtime - starttime).seconds 87 | 88 | return time, loss_sum, bp_count_avg 89 | 90 | def adjust_tau(epoch, dynamictau): 91 | tau = args.tau 92 | if dynamictau: 93 | if epoch <= 30: 94 | tau = 0 95 | elif epoch <= 50: 96 | tau = 1 97 | elif epoch <= 70: 98 | tau = 2 99 | else: 100 | tau = 3 101 | return tau 102 | 103 | def adjust_learning_rate(optimizer, epoch): 104 | """decrease the learning rate""" 105 | lr = args.lr 106 | if epoch >= 75: 107 | lr = args.lr * 0.1 108 | if epoch >= 90: 109 | lr = args.lr * 0.01 110 | for param_group in optimizer.param_groups: 111 | param_group['lr'] = lr 112 | 113 | def save_checkpoint(state, checkpoint=out_dir, filename='checkpoint.pth.tar'): 114 | filepath = os.path.join(checkpoint, filename) 115 | torch.save(state, filepath) 116 | 117 | # setup data loader 118 | transform_train = transforms.Compose([ 119 | transforms.RandomCrop(32, padding=4), 120 | transforms.RandomHorizontalFlip(), 121 | transforms.ToTensor(), 122 | ]) 123 | transform_test = transforms.Compose([ 124 | transforms.ToTensor(), 125 | ]) 126 | 127 | print('==> Load Test Data') 128 | if args.dataset == "cifar10": 129 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 130 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 131 | testset = 
torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 132 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 133 | if args.dataset == "svhn": 134 | trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform_train) 135 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) 136 | testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform_test) 137 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 138 | 139 | print('==> Load Model') 140 | if args.net == "smallcnn": 141 | model = SmallCNN().cuda() 142 | net = "smallcnn" 143 | if args.net == "resnet18": 144 | model = ResNet18().cuda() 145 | net = "resnet18" 146 | if args.net == "WRN": 147 | model = Wide_ResNet(depth=args.depth, num_classes=10, widen_factor=args.width_factor, dropRate=args.drop_rate).cuda() 148 | net = "WRN{}-{}-dropout{}".format(args.depth, args.width_factor, args.drop_rate) 149 | model = torch.nn.DataParallel(model) 150 | print(net) 151 | 152 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 153 | 154 | if not os.path.exists(out_dir): 155 | os.makedirs(out_dir) 156 | 157 | start_epoch = 0 158 | # Resume 159 | title = 'FAT for TRADES train' 160 | if args.resume: 161 | # resume directly point to checkpoint.pth.tar e.g., --resume='./out-dir/checkpoint.pth.tar' 162 | print ('==> Adversarial Training Resuming from checkpoint ..') 163 | print(args.resume) 164 | assert os.path.isfile(args.resume) 165 | out_dir = os.path.dirname(args.resume) 166 | checkpoint = torch.load(args.resume) 167 | start_epoch = checkpoint['epoch'] 168 | model.load_state_dict(checkpoint['state_dict']) 169 | optimizer.load_state_dict(checkpoint['optimizer']) 170 | logger_test = Logger(os.path.join(out_dir, 'log_results.txt'), title=title, resume=True) 171 | else: 172 | print('==> Friendly Adversarial Training for TRADES') 173 | logger_test = Logger(os.path.join(out_dir, 'log_results.txt'), title=title) 174 | logger_test.set_names(['Epoch', 'Natural Test Acc', 'FGSM Acc', 'PGD20 Acc', 'CW Acc']) 175 | 176 | test_nat_acc = 0 177 | fgsm_acc = 0 178 | test_pgd20_acc = 0 179 | cw_acc = 0 180 | for epoch in range(start_epoch, args.epochs): 181 | adjust_learning_rate(optimizer, epoch + 1) 182 | train_time, train_loss, bp_count_avg = train(model, train_loader, optimizer, adjust_tau(epoch + 1, args.dynamictau)) 183 | 184 | ## Evalutions the same as TRADES. 
185 | loss, test_nat_acc = attack.eval_clean(model, test_loader) 186 | loss, fgsm_acc = attack.eval_robust(model, test_loader, perturb_steps=1, epsilon=0.031, step_size=0.031,loss_fn="cent", category="Madry",rand_init=True) 187 | loss, test_pgd20_acc = attack.eval_robust(model,test_loader, perturb_steps=20, epsilon=0.031, step_size=0.003,loss_fn="cent",category="Madry",rand_init=True) 188 | loss, cw_acc = attack.eval_robust(model,test_loader, perturb_steps=30, epsilon=0.031, step_size=0.003,loss_fn="cw",category="Madry",rand_init=True) 189 | 190 | print( 191 | 'Epoch: [%d | %d] | Train Time: %.2f s | BP Average: %.2f | Natural Test Acc %.2f | FGSM Test Acc %.2f | PGD20 Test Acc %.2f | CW Test Acc %.2f |\n' % ( 192 | epoch + 1, 193 | args.epochs, 194 | train_time, 195 | bp_count_avg, 196 | test_nat_acc, 197 | fgsm_acc, 198 | test_pgd20_acc, 199 | cw_acc) 200 | ) 201 | 202 | logger_test.append([epoch + 1, test_nat_acc, fgsm_acc, test_pgd20_acc, cw_acc]) 203 | 204 | save_checkpoint({ 205 | 'epoch': epoch + 1, 206 | 'state_dict': model.state_dict(), 207 | 'bp_avg' : bp_count_avg, 208 | 'test_nat_acc': test_nat_acc, 209 | 'test_pgd20_acc': test_pgd20_acc, 210 | 'optimizer': optimizer.state_dict(), 211 | }) 212 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Friendly Adversarial Training Code 2 | 3 | This repository provides codes for friendly adversarial training (FAT). 4 | 5 | ICML 2020 Paper: **Attacks Which Do Not Kill Training Make Adversarial Learning Stronger** (https://arxiv.org/abs/2002.11242) 6 | *Jingfeng Zhang\*, Xilie Xu\*, Bo Han, Gang Niu, Lizhen Cui, Masashi Sugiyama and Mohan Kankanhalli* 7 | 8 | ## What is the nature of the adversarial training? 9 | Adversarial data can easily fool the standard trained classifier. 10 | Adversarial training employs the adversarial data into the training process. 11 | Adversarial training aims to achieve two purposes (a) correctly classify the data, and (b) make the decision boundary thick so that no data fall inside the decision boundary. 12 |
![](image/adv_train.png)

The purposes of the adversarial training
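To make the training procedure concrete, here is a minimal sketch of one adversarial-training epoch. The helper name `adversarial_training_epoch` and its signature are illustrative only; the actual training loops are the `train()` functions in `FAT.py`, `FAT_for_TRADES.py`, and `FAT_for_MART.py`.

```python
# Minimal sketch of one adversarial-training epoch (illustrative only; the real
# training loops are the train() functions in FAT.py / FAT_for_TRADES.py / FAT_for_MART.py).
import torch.nn as nn

def adversarial_training_epoch(model, train_loader, optimizer, attack_fn):
    criterion = nn.CrossEntropyLoss()
    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()          # assumes a CUDA device, as in the repo scripts
        x_adv = attack_fn(model, x, y)     # generate (friendly) adversarial data
        model.train()                      # attack_fn typically leaves the model in eval mode
        optimizer.zero_grad()
        loss = criterion(model(x_adv), y)  # fit the classifier on the adversarial data
        loss.backward()
        optimizer.step()
```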
18 | 19 | 20 | ## Conventional formulation of the adversarial training 21 | 22 | Conventional adversarial training is based on the minimax formulation: 23 | 24 | ![](http://latex.codecogs.com/gif.latex?\min_{f\in\mathcal{F}}\frac{1}{n}\sum_{i=1}^n\ell(f(\tilde{x}_i),y_i),) 25 | 26 | where 27 | 28 | ![](http://latex.codecogs.com/gif.latex?\tilde{x}_i=\mathrm{arg\max}_{\tilde{x}\in\mathcal{B}_\epsilon[x_i]}\ell(f(\tilde{x}),y_i).) 29 | 30 | Inside, the maximization finds **the most adversarial data**; outside, the minimization finds a classifier that fits those generated adversarial data. 31 | 32 | ### The minimax formulation is pessimistic. 33 | 34 | Minimax-based adversarial training causes severe degradation of natural generalization. Why? 35 | It suffers from a severe cross-over mixture problem: the adversarial data of different classes overshoot into the areas of their peer classes, and learning from those adversarial data is very difficult.
![](image/cross_over_mixture_problem.png)

Cross-over mixture problem of the minimax-based adversarial training
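For concreteness, the inner maximization is commonly approximated by projected gradient descent (PGD). Below is a minimal sketch of an L-infinity PGD attack; the function name `pgd_attack` and its defaults are illustrative, while the implementation actually used in this repository is `pgd()` in `attack_generator.py`.

```python
# Minimal sketch of the inner maximization via L-infinity PGD (illustrative only;
# see pgd() in attack_generator.py for the implementation used in this repo).
import torch
import torch.nn as nn

def pgd_attack(model, x, y, epsilon=8/255, step_size=2/255, num_steps=10):
    model.eval()
    # random start inside the epsilon-ball, clipped to the valid pixel range [0, 1]
    x_adv = torch.clamp(x + torch.empty_like(x).uniform_(-epsilon, epsilon), 0.0, 1.0)
    for _ in range(num_steps):
        x_adv.requires_grad_()
        loss = nn.CrossEntropyLoss()(model(x_adv), y)
        grad = torch.autograd.grad(loss, x_adv)[0]
        # gradient ascent on the loss: search for the most adversarial data
        x_adv = x_adv.detach() + step_size * grad.sign()
        # project back onto the epsilon-ball around x and onto [0, 1]
        x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
    return x_adv.detach()
```

The outer minimization then trains the classifier on the adversarial data returned by this search.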
42 | 43 | ## Our **min-min formulation** for the adversarial training 44 | 45 | The outer minimization stays the same. Instead of generating adversarial data via the inner maximization, we generate **the friendly adversarial data** that minimize the loss value, subject to two constraints: (a) the adversarial data is misclassified, and (b) the wrong prediction of the adversarial data is better than the desired prediction by at least a margin ![](http://latex.codecogs.com/gif.latex?\rho.) 46 | 47 | ![](http://latex.codecogs.com/gif.latex?\tilde{x}_i=\mathrm{arg\min}_{\tilde{x}\in\mathcal{B}_\epsilon[x_i]}\ell(f(\tilde{x}),y_i)\quad\mathrm{s.t.}\quad\ell(f(\tilde{x}),y_i)-\min_{y\in\mathcal{Y}}\ell(f(\tilde{x}),y)\ge\rho) 48 | 49 | 50 | Let us look at comparisons between the minimax formulation and the min-min formulation.
![](image/min-min_vs_minmax.png)

Comparisons between minimax formulation and min-min formulation
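To make the two constraints of the min-min formulation concrete, the sketch below checks them on a batch of logits: (a) the adversarial data must be misclassified, and (b) the loss on the true label must exceed the smallest per-class loss (the loss on the predicted label) by at least the margin ρ. The helper `satisfies_min_min_constraint` and its argument `rho` are illustrative and do not appear in the repository code.

```python
# Illustrative check of the two min-min constraints; this helper is not part of the repo.
import torch.nn.functional as F

def satisfies_min_min_constraint(logits, y, rho=0.0):
    # per-class cross-entropy loss l(f(x_tilde), y') for every candidate label y'
    per_class_loss = -F.log_softmax(logits, dim=1)       # shape: (batch, num_classes)
    loss_true = per_class_loss.gather(1, y.unsqueeze(1)).squeeze(1)
    loss_min = per_class_loss.min(dim=1).values          # loss of the predicted (wrong) label
    misclassified = logits.argmax(dim=1) != y            # constraint (a)
    has_margin = (loss_true - loss_min) >= rho           # constraint (b)
    return misclassified & has_margin
```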
57 | 58 | ## A Realization of the Min-min Formulation --- Friendly Adversarial Training (FAT) 59 | 60 | Friendly adversarial training (FAT) employs the friendly adversarial data generated by **early-stopped PGD** to update the model. 61 | Early-stopped PGD stops the PGD iterations once the adversarial data is misclassified; this is controlled by the hyperparameter ```tau``` in the code. Note that when ```tau``` equals the maximum perturbation step ```num_steps```, FAT recovers conventional adversarial training, e.g., [AT](https://arxiv.org/abs/1706.06083), [TRADES](https://arxiv.org/abs/1901.08573), and [MART](https://openreview.net/forum?id=rklOg6EFwS), as special cases.
![](image/early_stopped_pgd.png)

Conventional adversarial training employs PGD to search for the most adversarial data. Friendly adversarial training employs early-stopped PGD to search for friendly adversarial data.
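The per-example sketch below illustrates the early-stopping idea: run PGD, but return the iterate as soon as it is misclassified, allowing up to ```tau``` extra steps after that. It is a simplified, illustrative version; the batched implementation actually used for training is `earlystop()` in `earlystop.py`, and the function name and defaults here are only illustrative.

```python
# Per-example sketch of early-stopped PGD (illustrative; see earlystop.py for the
# batched implementation used during FAT training). x has shape (1, C, H, W), y has shape (1,).
import torch
import torch.nn as nn

def early_stopped_pgd(model, x, y, epsilon=8/255, step_size=2/255, num_steps=10, tau=0):
    model.eval()
    x_adv = torch.clamp(x + torch.empty_like(x).uniform_(-epsilon, epsilon), 0.0, 1.0)
    budget = tau  # extra PGD steps allowed after the example first becomes misclassified
    for _ in range(num_steps):
        x_adv.requires_grad_()
        logits = model(x_adv)
        if logits.argmax(dim=1).item() != y.item():   # already (friendly) adversarial
            if budget == 0:
                return x_adv.detach()                 # stop early instead of maximizing further
            budget -= 1
        loss = nn.CrossEntropyLoss()(logits, y)
        grad = torch.autograd.grad(loss, x_adv)[0]
        x_adv = x_adv.detach() + step_size * grad.sign()
        x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0)
    return x_adv.detach()
```

When ```tau``` equals ```num_steps```, the early stop never triggers and the search reduces to standard PGD, which is why FAT with that setting recovers conventional adversarial training.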
68 | 69 | ## Preferred Prerequisites 70 | 71 | * Python (3.6) 72 | * Pytorch (1.2.0) 73 | * CUDA 74 | * numpy 75 | 76 | 77 | ## Running FAT, FAT for TRADES, FAT for MART on benchmark datasets (CIFAR-10 and SVHN) 78 | 79 | Here are examples: 80 | * Train WRN-32-10 model on CIFAR-10 and compare our results with [AT](https://arxiv.org/abs/1706.06083), [CAT](https://arxiv.org/abs/1805.04807) and [DAT](http://proceedings.mlr.press/v97/wang19i/wang19i.pdf): 81 | ```bash 82 | CUDA_VISIBLE_DEVICES='0' python FAT.py --epsilon 0.031 83 | CUDA_VISIBLE_DEVICES='0' python FAT.py --epsilon 0.062 84 | ``` 85 | ### White-box evaluations on WRN-32-10 86 | 87 | | Defense | Natural Acc. | FGSM Acc. | PGD-20 Acc. | C&W Acc. | 88 | |-----------------------|-----------------------|------------------|-----------------|-----------------| 89 | |[AT(Madry)](https://arxiv.org/abs/1706.06083) | 87.30% | 56.10% | 45.80% | 46.80% 90 | | [CAT](https://arxiv.org/abs/1805.04807) | 77.43% | 57.17% | 46.06% | 42.28% 91 | | [DAT](http://proceedings.mlr.press/v97/wang19i/wang19i.pdf) | 85.03% | 63.53% | 48.70% | 47.27% 92 | | FAT (![](http://latex.codecogs.com/gif.latex?\epsilon=8/255)) | **89.34**![](http://latex.codecogs.com/gif.latex?\pm)0.221% |65.52![](http://latex.codecogs.com/gif.latex?\pm)0.355%| 46.13![](http://latex.codecogs.com/gif.latex?\pm)0.049%| 46.82![](http://latex.codecogs.com/gif.latex?\pm)0.517% 93 | | FAT (![](http://latex.codecogs.com/gif.latex?\epsilon=16/255)) | 87.00![](http://latex.codecogs.com/gif.latex?\pm)0.203%| **65.94**![](http://latex.codecogs.com/gif.latex?\pm)0.244%|**49.86**![](http://latex.codecogs.com/gif.latex?\pm)0.328%|**48.65**![](http://latex.codecogs.com/gif.latex?\pm)0.176% 94 | 95 | Results of AT(Madry), CAT and DAT are reported in [DAT](http://proceedings.mlr.press/v97/wang19i/wang19i.pdf). FAT has the same evaluations. 96 | 97 | * Train WRN-34-10 model on CIFAR-10 and compare our results with [TRADES](https://arxiv.org/abs/1901.08573), and [MART](https://openreview.net/forum?id=rklOg6EFwS). 98 | ```bash 99 | CUDA_VISIBLE_DEVICES='0' python FAT_for_TRADES.py --epsilon 0.031 100 | CUDA_VISIBLE_DEVICES='0' python FAT_for_TRADES.py --epsilon 0.062 101 | CUDA_VISIBLE_DEVICES='0' python FAT_for_MART.py --epsilon 0.031 102 | CUDA_VISIBLE_DEVICES='0' python FAT_for_MART.py --epsilon 0.062 103 | ``` 104 | 105 | ### White-box evaluations on WRN-34-10 106 | 107 | | Defense | Natural Acc. | FGSM Acc. | PGD-20 Acc. | C&W Acc. 
| 108 | |-----------------------|-----------------------|------------------|-----------------|-----------------| 109 | |[TRADES](https://arxiv.org/abs/1901.08573)(![](http://latex.codecogs.com/gif.latex?\beta=1.0))| 88.64% | 56.38% | 49.14% | - 110 | |FAT for TRADES(![](http://latex.codecogs.com/gif.latex?\beta=1.0,\epsilon=8/255))| **89.94**![](http://latex.codecogs.com/gif.latex?\pm)0.303% |61.00![](http://latex.codecogs.com/gif.latex?\pm)0.418% |49.70![](http://latex.codecogs.com/gif.latex?\pm)0.653%|49.35![](http://latex.codecogs.com/gif.latex?\pm)0.363% 111 | |[TRADES](https://arxiv.org/abs/1901.08573)(![](http://latex.codecogs.com/gif.latex?\beta=6.0))|84.92%|61.06%|56.61%|**54.47**% 112 | |FAT for TRADES(![](http://latex.codecogs.com/gif.latex?\beta=6.0,\epsilon=8/255))| 86.60![](http://latex.codecogs.com/gif.latex?\pm)0.548% |**61.79**![](http://latex.codecogs.com/gif.latex?\pm)0.570% |55.98![](http://latex.codecogs.com/gif.latex?\pm)0.209%|54.29![](http://latex.codecogs.com/gif.latex?\pm)0.173% 113 | |FAT for TRADES(![](http://latex.codecogs.com/gif.latex?\beta=6.0,\epsilon=16/255))| 84.39![](http://latex.codecogs.com/gif.latex?\pm)0.030% |61.73![](http://latex.codecogs.com/gif.latex?\pm)0.131% |**57.12**![](http://latex.codecogs.com/gif.latex?\pm)0.233%|54.36![](http://latex.codecogs.com/gif.latex?\pm)0.177% 114 | 115 | Results of TRADES (![](http://latex.codecogs.com/gif.latex?\beta=1.0) and ![](http://latex.codecogs.com/gif.latex?\beta=6.0)) are reported in [TRADES](https://arxiv.org/abs/1901.08573). FAT for TRADES has the same evaluations. Note that the evaluations above follow the description in the TRADES paper, i.e., adversarial data are generated without random start ```rand_init=False```. 116 | However, in [TRADES’s GitHub](https://github.com/yaodongyu/TRADES), they use random start ```rand_init=True``` before the PGD perturbation, which deviates from the statements in their paper. For fair evaluations of FAT with random start, please refer to Table 3 in [our paper](https://arxiv.org/pdf/2002.11242.pdf). 117 | 118 | ### How to recover original AT, TRADES, or MART? 119 | Just set ```tau=10```, i.e., 120 | ``` 121 | python FAT.py --epsilon 0.031 --tau 10 --dynamictau False 122 | python FAT_for_TRADES.py --epsilon 0.031 --tau 10 --dynamictau False 123 | python FAT_for_MART.py --epsilon 0.031 --tau 10 --dynamictau False 124 | ``` 125 | 126 | 127 | ## Want to attack FAT? Sure! 128 | 129 | We welcome various attack methods against our defense models. For the CIFAR-10 dataset, we normalize all images into ```[0,1]```. 130 | 131 | Download our pretrained models into the folder ```FAT_models``` through this [Google Drive link](https://drive.google.com/drive/folders/1lV3qob_zR-YpFVGuKiiE5hNu74NID-ZS?usp=sharing) or [Baidu Drive link](https://pan.baidu.com/s/17XBd02FoGFqgYCVy2Fm_SQ) (extraction code: ww7f). 132 | ```bash 133 | cd Friendly-Adversarial-Training 134 | mkdir FAT_models 135 | ``` 136 | Run robustness evaluations.
137 | ```bash 138 | chmod +x attack_test.sh 139 | ./attack_test.sh 140 | ``` 141 | 142 | ## Reference 143 | 144 | ``` 145 | @inproceedings{zhang2020fat, 146 | title={Attacks Which Do Not Kill Training Make Adversarial Learning Stronger}, 147 | author={Zhang, Jingfeng and Xu, Xilie and Han, Bo and Niu, Gang and Cui, Lizhen and Sugiyama, Masashi and Kankanhalli, Mohan}, 148 | booktitle = {ICML}, 149 | year={2020} 150 | } 151 | ``` 152 | 153 | ## Contact 154 | 155 | Please contact jingfeng.zhang@auckland.ac.nz (preferred) OR jingfeng.zhang9660@gmail.com and xuxilie@comp.nus.edu.sg if you have any question on the codes. 156 | -------------------------------------------------------------------------------- /attack_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from models import * 3 | 4 | def cwloss(output, target,confidence=50, num_classes=10): 5 | # Compute the probability of the label class versus the maximum other 6 | # The same implementation as in repo CAT https://github.com/sunblaze-ucb/curriculum-adversarial-training-CAT 7 | target = target.data 8 | target_onehot = torch.zeros(target.size() + (num_classes,)) 9 | target_onehot = target_onehot.cuda() 10 | target_onehot.scatter_(1, target.unsqueeze(1), 1.) 11 | target_var = Variable(target_onehot, requires_grad=False) 12 | real = (target_var * output).sum(1) 13 | other = ((1. - target_var) * output - target_var * 10000.).max(1)[0] 14 | loss = -torch.clamp(real - other + confidence, min=0.) # equiv to max(..., 0.) 15 | loss = torch.sum(loss) 16 | return loss 17 | 18 | def pgd(model, data, target, epsilon, step_size, num_steps,loss_fn,category,rand_init): 19 | model.eval() 20 | if category == "trades": 21 | x_adv = data.detach() + 0.001 * torch.randn(data.shape).cuda().detach() if rand_init else data.detach() 22 | if category == "Madry": 23 | x_adv = data.detach() + torch.from_numpy(np.random.uniform(-epsilon, epsilon, data.shape)).float().cuda() if rand_init else data.detach() 24 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 25 | for k in range(num_steps): 26 | x_adv.requires_grad_() 27 | output = model(x_adv) 28 | model.zero_grad() 29 | with torch.enable_grad(): 30 | if loss_fn == "cent": 31 | loss_adv = nn.CrossEntropyLoss(reduction="mean")(output, target) 32 | if loss_fn == "cw": 33 | loss_adv = cwloss(output,target) 34 | loss_adv.backward() 35 | eta = step_size * x_adv.grad.sign() 36 | x_adv = x_adv.detach() + eta 37 | x_adv = torch.min(torch.max(x_adv, data - epsilon), data + epsilon) 38 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 39 | return x_adv 40 | 41 | def eval_clean(model, test_loader): 42 | model.eval() 43 | test_loss = 0 44 | correct = 0 45 | with torch.no_grad(): 46 | for data, target in test_loader: 47 | data, target = data.cuda(), target.cuda() 48 | output = model(data) 49 | test_loss += nn.CrossEntropyLoss(reduction='mean')(output, target).item() 50 | pred = output.max(1, keepdim=True)[1] 51 | correct += pred.eq(target.view_as(pred)).sum().item() 52 | test_loss /= len(test_loader.dataset) 53 | log = 'Natrual Test Result ==> Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format( 54 | test_loss, correct, len(test_loader.dataset), 55 | 100. 
* correct / len(test_loader.dataset)) 56 | # print(log) 57 | test_accuracy = correct / len(test_loader.dataset) 58 | return test_loss, test_accuracy 59 | 60 | def eval_robust(model, test_loader, perturb_steps, epsilon, step_size, loss_fn, category, rand_init): 61 | model.eval() 62 | test_loss = 0 63 | correct = 0 64 | with torch.enable_grad(): 65 | for data, target in test_loader: 66 | data, target = data.cuda(), target.cuda() 67 | x_adv = pgd(model,data,target,epsilon,step_size,perturb_steps,loss_fn,category,rand_init=rand_init) 68 | output = model(x_adv) 69 | test_loss += nn.CrossEntropyLoss(reduction='mean')(output, target).item() 70 | pred = output.max(1, keepdim=True)[1] 71 | correct += pred.eq(target.view_as(pred)).sum().item() 72 | test_loss /= len(test_loader.dataset) 73 | log = 'Attack Setting ==> Loss_fn:{}, Perturb steps:{}, Epsilon:{}, Step dize:{} \n Test Result ==> Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(loss_fn,perturb_steps,epsilon,step_size, 74 | test_loss, correct, len(test_loader.dataset), 75 | 100. * correct / len(test_loader.dataset)) 76 | # print(log) 77 | test_accuracy = correct / len(test_loader.dataset) 78 | return test_loss, test_accuracy 79 | 80 | -------------------------------------------------------------------------------- /attack_test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch.nn as nn 3 | import torchvision 4 | from torchvision import transforms 5 | from models import * 6 | import attack_generator as attack 7 | 8 | parser = argparse.ArgumentParser(description='PyTorch White-box Adversarial Attack Test') 9 | parser.add_argument('--net', type=str, default="WRN", help="decide which network to use,choose from smallcnn,resnet18,WRN") 10 | parser.add_argument('--dataset', type=str, default="cifar10", help="choose from cifar10,svhn") 11 | parser.add_argument('--depth', type=int, default=34, help='WRN depth') 12 | parser.add_argument('--width_factor', type=int, default=10,help='WRN width factor') 13 | parser.add_argument('--drop_rate', type=float,default=0.0, help='WRN drop rate') 14 | parser.add_argument('--attack_method', type=str,default="dat", help = "choose form: dat and trades") 15 | parser.add_argument('--model_path', default='./FAT_models/fat_for_trades_wrn34-10_eps0.031_beta1.0.pth.tar', help='model for white-box attack evaluation') 16 | parser.add_argument('--method',type=str,default='dat',help='select attack setting following DAT or TRADES') 17 | 18 | args = parser.parse_args() 19 | 20 | transform_test = transforms.Compose([ 21 | transforms.ToTensor(), 22 | ]) 23 | 24 | print('==> Load Test Data') 25 | if args.dataset == "cifar10": 26 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 27 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 28 | if args.dataset == "svhn": 29 | testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform_test) 30 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2) 31 | 32 | print('==> Load Model') 33 | if args.net == "smallcnn": 34 | model = SmallCNN().cuda() 35 | net = "smallcnn" 36 | if args.net == "resnet18": 37 | model = ResNet18().cuda() 38 | net = "resnet18" 39 | if args.net == "WRN": 40 | ## WRN-34-10 41 | model = Wide_ResNet(depth=args.depth, num_classes=10, widen_factor=args.width_factor, dropRate=args.drop_rate).cuda() 42 | net = 
"WRN{}-{}-dropout{}".format(args.depth,args.width_factor,args.drop_rate) 43 | if args.net == 'WRN_madry': 44 | ## WRN-32-10 45 | model = Wide_ResNet_Madry(depth=args.depth, num_classes=10, widen_factor=args.width_factor, dropRate=args.drop_rate).cuda() 46 | net = "WRN_madry{}-{}-dropout{}".format(args.depth, args.width_factor, args.drop_rate) 47 | model = torch.nn.DataParallel(model) 48 | print(net) 49 | 50 | model.load_state_dict(torch.load(args.model_path)['state_dict']) 51 | 52 | print('==> Evaluating Performance under White-box Adversarial Attack') 53 | 54 | loss, test_nat_acc = attack.eval_clean(model, test_loader) 55 | print('Natural Test Accuracy: {:.2f}%'.format(100. * test_nat_acc)) 56 | if args.method == "dat": 57 | # Evalutions the same as DAT. 58 | loss, fgsm_acc = attack.eval_robust(model, test_loader, perturb_steps=1, epsilon=0.031, step_size=0.031,loss_fn="cent", category="Madry",rand_init=True) 59 | print('FGSM Test Accuracy: {:.2f}%'.format(100. * fgsm_acc)) 60 | loss, pgd20_acc = attack.eval_robust(model, test_loader, perturb_steps=20, epsilon=0.031, step_size=0.031 / 4,loss_fn="cent", category="Madry", rand_init=True) 61 | print('PGD20 Test Accuracy: {:.2f}%'.format(100. * pgd20_acc)) 62 | loss, cw_acc = attack.eval_robust(model, test_loader, perturb_steps=30, epsilon=0.031, step_size=0.031 / 4,loss_fn="cw", category="Madry", rand_init=True) 63 | print('CW Test Accuracy: {:.2f}%'.format(100. * cw_acc)) 64 | if args.method == 'trades': 65 | # Evalutions the same as TRADES. 66 | # wri : with random init, wori : without random init 67 | loss, fgsm_wori_acc = attack.eval_robust(model, test_loader, perturb_steps=1, epsilon=0.031, step_size=0.031,loss_fn="cent", category="Madry",rand_init=False) 68 | print('FGSM without Random Start Test Accuracy: {:.2f}%'.format(100. * fgsm_wori_acc)) 69 | loss, pgd20_wori_acc = attack.eval_robust(model,test_loader, perturb_steps=20, epsilon=0.031, step_size=0.003,loss_fn="cent",category="Madry",rand_init=False) 70 | print('PGD20 without Random Start Test Accuracy: {:.2f}%'.format(100. * pgd20_wori_acc)) 71 | loss, cw_wori_acc = attack.eval_robust(model,test_loader, perturb_steps=30, epsilon=0.031, step_size=0.003,loss_fn="cw",category="Madry",rand_init=False) 72 | print('CW without Random Start Test Accuracy: {:.2f}%'.format(100. * cw_wori_acc)) 73 | loss, fgsm_wri_acc = attack.eval_robust(model, test_loader, perturb_steps=1, epsilon=0.031, step_size=0.031,loss_fn="cent", category="Madry",rand_init=True) 74 | print('FGSM with Random Start Test Accuracy: {:.2f}%'.format(100. * fgsm_wri_acc)) 75 | loss, pgd20_wri_acc = attack.eval_robust(model,test_loader, perturb_steps=20, epsilon=0.031, step_size=0.003,loss_fn="cent",category="Madry",rand_init=True) 76 | print('PGD20 with Random Start Test Accuracy: {:.2f}%'.format(100. * pgd20_wri_acc)) 77 | loss, cw_wri_acc = attack.eval_robust(model,test_loader, perturb_steps=30, epsilon=0.031, step_size=0.003,loss_fn="cw",category="Madry",rand_init=True) 78 | print('CW with Random Start Test Accuracy: {:.2f}%'.format(100. 
* cw_wri_acc)) 79 | -------------------------------------------------------------------------------- /attack_test.sh: -------------------------------------------------------------------------------- 1 | python attack_test.py --net 'WRN_madry' --depth 32 --model_path './FAT_models/fat_wrn32-10_eps0.031.pth.tar' --method 'dat' 2 | python attack_test.py --net 'WRN_madry' --depth 32 --model_path './FAT_models/fat_wrn32-10_eps0.062.pth.tar' --method 'dat' 3 | 4 | python attack_test.py --net 'WRN' --depth 34 --model_path './FAT_models/fat_for_trades_wrn34-10_eps0.031_beta1.0.pth.tar' --method 'trades' 5 | python attack_test.py --net 'WRN' --depth 34 --model_path './FAT_models/fat_for_trades_wrn34-10_eps0.031_beta6.0.pth.tar' --method 'trades' 6 | python attack_test.py --net 'WRN' --depth 34 --model_path './FAT_models/fat_for_trades_wrn34-10_eps0.062_beta6.0.pth.tar' --method 'trades' 7 | 8 | python attack_test.py --net 'WRN' --depth 58 --model_path './FAT_models/fat_for_trades_wrn58-10_eps0.031_beta6.0.pth.tar' --method 'trades' 9 | python attack_test.py --net 'WRN' --depth 58 --model_path './FAT_models/fat_for_trades_wrn58-10_eps0.062_beta6.0.pth.tar' --method 'trades' 10 | 11 | python attack_test.py --net 'WRN' --depth 34 --model_path './FAT_models/fat_for_mart_wrn34-10_eps0.031.pth.tar' --method 'trades' 12 | python attack_test.py --net 'WRN' --depth 34 --model_path './FAT_models/fat_for_mart_wrn34-10_eps0.062.pth.tar' --method 'trades' 13 | -------------------------------------------------------------------------------- /earlystop.py: -------------------------------------------------------------------------------- 1 | from models import * 2 | import torch 3 | import numpy as np 4 | 5 | def earlystop(model, data, target, step_size, epsilon, perturb_steps,tau,randominit_type,loss_fn,rand_init=True,omega=0): 6 | ''' 7 | The implematation of early-stopped PGD 8 | Following the Alg.1 in our FAT paper 9 | :param step_size: the PGD step size 10 | :param epsilon: the perturbation bound 11 | :param perturb_steps: the maximum PGD step 12 | :param tau: the step controlling how early we should stop interations when wrong adv data is found 13 | :param randominit_type: To decide the type of random inirialization (random start for searching adv data) 14 | :param rand_init: To decide whether to initialize adversarial sample with random noise (random start for searching adv data) 15 | :param omega: random sample parameter for adv data generation (this is for escaping the local minimum.) 
16 | :return: output_adv (friendly adversarial data) output_target (targets), output_natural (the corresponding natrual data), count (average backword propagations count) 17 | ''' 18 | model.eval() 19 | 20 | K = perturb_steps 21 | count = 0 22 | output_target = [] 23 | output_adv = [] 24 | output_natural = [] 25 | 26 | control = (torch.ones(len(target)) * tau).cuda() 27 | 28 | # Initialize the adversarial data with random noise 29 | if rand_init: 30 | if randominit_type == "normal_distribution_randominit": 31 | iter_adv = data.detach() + 0.001 * torch.randn(data.shape).cuda().detach() 32 | iter_adv = torch.clamp(iter_adv, 0.0, 1.0) 33 | if randominit_type == "uniform_randominit": 34 | iter_adv = data.detach() + torch.from_numpy(np.random.uniform(-epsilon, epsilon, data.shape)).float().cuda() 35 | iter_adv = torch.clamp(iter_adv, 0.0, 1.0) 36 | else: 37 | iter_adv = data.cuda().detach() 38 | 39 | iter_clean_data = data.cuda().detach() 40 | iter_target = target.cuda().detach() 41 | output_iter_clean_data = model(data) 42 | 43 | while K>0: 44 | iter_adv.requires_grad_() 45 | output = model(iter_adv) 46 | pred = output.max(1, keepdim=True)[1] 47 | output_index = [] 48 | iter_index = [] 49 | 50 | # Calculate the indexes of adversarial data those still needs to be iterated 51 | for idx in range(len(pred)): 52 | if pred[idx] != iter_target[idx]: 53 | if control[idx] == 0: 54 | output_index.append(idx) 55 | else: 56 | control[idx] -= 1 57 | iter_index.append(idx) 58 | else: 59 | iter_index.append(idx) 60 | 61 | # Add adversarial data those do not need any more iteration into set output_adv 62 | if len(output_index) != 0: 63 | if len(output_target) == 0: 64 | # incorrect adv data should not keep iterated 65 | output_adv = iter_adv[output_index].reshape(-1, 3, 32, 32).cuda() 66 | output_natural = iter_clean_data[output_index].reshape(-1, 3, 32, 32).cuda() 67 | output_target = iter_target[output_index].reshape(-1).cuda() 68 | else: 69 | # incorrect adv data should not keep iterated 70 | output_adv = torch.cat((output_adv, iter_adv[output_index].reshape(-1, 3, 32, 32).cuda()), dim=0) 71 | output_natural = torch.cat((output_natural, iter_clean_data[output_index].reshape(-1, 3, 32, 32).cuda()), dim=0) 72 | output_target = torch.cat((output_target, iter_target[output_index].reshape(-1).cuda()), dim=0) 73 | 74 | # calculate gradient 75 | model.zero_grad() 76 | with torch.enable_grad(): 77 | if loss_fn == "cent": 78 | loss_adv = nn.CrossEntropyLoss(reduction='mean')(output, iter_target) 79 | if loss_fn == "kl": 80 | criterion_kl = nn.KLDivLoss(size_average=False).cuda() 81 | loss_adv = criterion_kl(F.log_softmax(output, dim=1),F.softmax(output_iter_clean_data, dim=1)) 82 | loss_adv.backward(retain_graph=True) 83 | grad = iter_adv.grad 84 | 85 | # update iter adv 86 | if len(iter_index) != 0: 87 | control = control[iter_index] 88 | iter_adv = iter_adv[iter_index] 89 | iter_clean_data = iter_clean_data[iter_index] 90 | iter_target = iter_target[iter_index] 91 | output_iter_clean_data = output_iter_clean_data[iter_index] 92 | grad = grad[iter_index] 93 | eta = step_size * grad.sign() 94 | 95 | iter_adv = iter_adv.detach() + eta + omega * torch.randn(iter_adv.shape).detach().cuda() 96 | iter_adv = torch.min(torch.max(iter_adv, iter_clean_data - epsilon), iter_clean_data + epsilon) 97 | iter_adv = torch.clamp(iter_adv, 0, 1) 98 | count += len(iter_target) 99 | else: 100 | output_adv = output_adv.detach() 101 | return output_adv, output_target, output_natural, count 102 | K = K-1 103 | 104 | if 
len(output_target) == 0: 105 | output_target = iter_target.reshape(-1).squeeze().cuda() 106 | output_adv = iter_adv.reshape(-1, 3, 32, 32).cuda() 107 | output_natural = iter_clean_data.reshape(-1, 3, 32, 32).cuda() 108 | else: 109 | output_adv = torch.cat((output_adv, iter_adv.reshape(-1, 3, 32, 32)), dim=0).cuda() 110 | output_target = torch.cat((output_target, iter_target.reshape(-1)), dim=0).squeeze().cuda() 111 | output_natural = torch.cat((output_natural, iter_clean_data.reshape(-1, 3, 32, 32).cuda()),dim=0).cuda() 112 | output_adv = output_adv.detach() 113 | return output_adv, output_target, output_natural, count 114 | -------------------------------------------------------------------------------- /image/adv_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/image/adv_train.png -------------------------------------------------------------------------------- /image/cross_over_mixture_problem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/image/cross_over_mixture_problem.png -------------------------------------------------------------------------------- /image/early_stopped_pgd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/image/early_stopped_pgd.png -------------------------------------------------------------------------------- /image/min-min_vs_minmax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/image/min-min_vs_minmax.png -------------------------------------------------------------------------------- /image/min_min_formulation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/image/min_min_formulation.png -------------------------------------------------------------------------------- /image/minimax_formulation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/image/minimax_formulation.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .dpn import * 3 | from .lenet import * 4 | from .senet import * 5 | from .resnet import * 6 | from .resnext import * 7 | from .densenet import * 8 | from .googlenet import * 9 | from .mobilenet import * 10 | from .shufflenet import * 11 | from .preact_resnet import * 12 | from .wide_resnet import * 13 | from .small_cnn import * 14 | from .wrn_madry import * 15 | -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/densenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/densenet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/dpn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/dpn.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/googlenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/googlenet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/lenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/lenet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/mobilenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/mobilenet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/preact_resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/preact_resnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnext.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/resnext.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/senet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/senet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/shufflenet.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/shufflenet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/small_cnn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/small_cnn.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/vgg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/vgg.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/wide_resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/wide_resnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/wrn_madry.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/wrn_madry.cpython-36.pyc -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | '''DenseNet in PyTorch.''' 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from torch.autograd import Variable 9 | 10 | 11 | class Bottleneck(nn.Module): 12 | def __init__(self, in_planes, growth_rate): 13 | super(Bottleneck, self).__init__() 14 | self.bn1 = nn.BatchNorm2d(in_planes,momentum=0.2) 15 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(4*growth_rate,momentum=0.2) 17 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 18 | 19 | def forward(self, x): 20 | out = self.conv1(F.relu(self.bn1(x))) 21 | out = self.conv2(F.relu(self.bn2(out))) 22 | out = torch.cat([out,x], 1) 23 | return out 24 | 25 | 26 | class Transition(nn.Module): 27 | def __init__(self, in_planes, out_planes): 28 | super(Transition, self).__init__() 29 | self.bn = nn.BatchNorm2d(in_planes,momentum=0.2) 30 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 31 | 32 | def forward(self, x): 33 | out = self.conv(F.relu(self.bn(x))) 34 | out = F.avg_pool2d(out, 2) 35 | return out 36 | 37 | 38 | class DenseNet(nn.Module): 39 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.2, num_classes=10): 40 | super(DenseNet, self).__init__() 41 | self.growth_rate = growth_rate 42 | 43 | num_planes = 2*growth_rate 44 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 45 | 46 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 47 | num_planes += nblocks[0]*growth_rate 48 | out_planes = int(math.floor(num_planes*reduction)) 49 | self.trans1 = 
Transition(num_planes, out_planes) 50 | num_planes = out_planes 51 | 52 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 53 | num_planes += nblocks[1]*growth_rate 54 | out_planes = int(math.floor(num_planes*reduction)) 55 | self.trans2 = Transition(num_planes, out_planes) 56 | num_planes = out_planes 57 | 58 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 59 | num_planes += nblocks[2]*growth_rate 60 | out_planes = int(math.floor(num_planes*reduction)) 61 | self.trans3 = Transition(num_planes, out_planes) 62 | num_planes = out_planes 63 | 64 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 65 | num_planes += nblocks[3]*growth_rate 66 | 67 | self.bn = nn.BatchNorm2d(num_planes,momentum=0.2) 68 | self.linear = nn.Linear(num_planes, num_classes) 69 | 70 | def _make_dense_layers(self, block, in_planes, nblock): 71 | layers = [] 72 | for i in range(nblock): 73 | layers.append(block(in_planes, self.growth_rate)) 74 | in_planes += self.growth_rate 75 | return nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | out = self.conv1(x) 79 | out = self.trans1(self.dense1(out)) 80 | out = self.trans2(self.dense2(out)) 81 | out = self.trans3(self.dense3(out)) 82 | out = self.dense4(out) 83 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 84 | out = out.view(out.size(0), -1) 85 | out = self.linear(out) 86 | return out 87 | 88 | def DenseNet121(): 89 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32) 90 | 91 | def DenseNet169(): 92 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32) 93 | 94 | def DenseNet201(): 95 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=48) 96 | 97 | def DenseNet161(): 98 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48) 99 | 100 | def densenet_cifar(): 101 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12) 102 | 103 | def test_densenet(): 104 | net = densenet_cifar() 105 | x = torch.randn(1,3,32,32) 106 | y = net(Variable(x)) 107 | print(y) 108 | print(net) 109 | #test_densenet() 110 | -------------------------------------------------------------------------------- /models/dpn.py: -------------------------------------------------------------------------------- 1 | '''Dual Path Networks in PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import Variable 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer): 11 | super(Bottleneck, self).__init__() 12 | self.out_planes = out_planes 13 | self.dense_depth = dense_depth 14 | 15 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False) 18 | self.bn2 = nn.BatchNorm2d(in_planes) 19 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False) 20 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth) 21 | 22 | self.shortcut = nn.Sequential() 23 | if first_layer: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm2d(out_planes+dense_depth) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = F.relu(self.bn2(self.conv2(out))) 32 | out = self.bn3(self.conv3(out)) 33 | x = self.shortcut(x) 34 | d = self.out_planes 35 | out = 
torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class DPN(nn.Module): 41 | def __init__(self, cfg): 42 | super(DPN, self).__init__() 43 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes'] 44 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth'] 45 | 46 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 47 | self.bn1 = nn.BatchNorm2d(64) 48 | self.last_planes = 64 49 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1) 50 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2) 51 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2) 52 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2) 53 | self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10) 54 | 55 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride): 56 | strides = [stride] + [1]*(num_blocks-1) 57 | layers = [] 58 | for i,stride in enumerate(strides): 59 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0)) 60 | self.last_planes = out_planes + (i+2) * dense_depth 61 | return nn.Sequential(*layers) 62 | 63 | def forward(self, x): 64 | out = F.relu(self.bn1(self.conv1(x))) 65 | out = self.layer1(out) 66 | out = self.layer2(out) 67 | out = self.layer3(out) 68 | out = self.layer4(out) 69 | out = F.avg_pool2d(out, 4) 70 | out = out.view(out.size(0), -1) 71 | out = self.linear(out) 72 | return out 73 | 74 | 75 | def DPN26(): 76 | cfg = { 77 | 'in_planes': (96,192,384,768), 78 | 'out_planes': (256,512,1024,2048), 79 | 'num_blocks': (2,2,2,2), 80 | 'dense_depth': (16,32,24,128) 81 | } 82 | return DPN(cfg) 83 | 84 | def DPN92(): 85 | cfg = { 86 | 'in_planes': (96,192,384,768), 87 | 'out_planes': (256,512,1024,2048), 88 | 'num_blocks': (3,4,20,3), 89 | 'dense_depth': (16,32,24,128) 90 | } 91 | return DPN(cfg) 92 | 93 | 94 | def test(): 95 | net = DPN92() 96 | x = Variable(torch.randn(1,3,32,32)) 97 | y = net(x) 98 | print(y) 99 | 100 | # test() 101 | -------------------------------------------------------------------------------- /models/googlenet.py: -------------------------------------------------------------------------------- 1 | '''GoogLeNet with PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import Variable 7 | 8 | 9 | class Inception(nn.Module): 10 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes): 11 | super(Inception, self).__init__() 12 | # 1x1 conv branch 13 | self.b1 = nn.Sequential( 14 | nn.Conv2d(in_planes, n1x1, kernel_size=1), 15 | nn.BatchNorm2d(n1x1), 16 | nn.ReLU(True), 17 | ) 18 | 19 | # 1x1 conv -> 3x3 conv branch 20 | self.b2 = nn.Sequential( 21 | nn.Conv2d(in_planes, n3x3red, kernel_size=1), 22 | nn.BatchNorm2d(n3x3red), 23 | nn.ReLU(True), 24 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1), 25 | nn.BatchNorm2d(n3x3), 26 | nn.ReLU(True), 27 | ) 28 | 29 | # 1x1 conv -> 5x5 conv branch 30 | self.b3 = nn.Sequential( 31 | nn.Conv2d(in_planes, n5x5red, kernel_size=1), 32 | nn.BatchNorm2d(n5x5red), 33 | nn.ReLU(True), 34 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(n5x5), 36 | nn.ReLU(True), 37 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1), 38 | nn.BatchNorm2d(n5x5), 39 | 
nn.ReLU(True), 40 | ) 41 | 42 | # 3x3 pool -> 1x1 conv branch 43 | self.b4 = nn.Sequential( 44 | nn.MaxPool2d(3, stride=1, padding=1), 45 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 46 | nn.BatchNorm2d(pool_planes), 47 | nn.ReLU(True), 48 | ) 49 | 50 | def forward(self, x): 51 | y1 = self.b1(x) 52 | y2 = self.b2(x) 53 | y3 = self.b3(x) 54 | y4 = self.b4(x) 55 | return torch.cat([y1,y2,y3,y4], 1) 56 | 57 | 58 | class GoogLeNet(nn.Module): 59 | def __init__(self): 60 | super(GoogLeNet, self).__init__() 61 | self.pre_layers = nn.Sequential( 62 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 63 | nn.BatchNorm2d(192), 64 | nn.ReLU(True), 65 | ) 66 | 67 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 68 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 69 | 70 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 71 | 72 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 73 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 74 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 75 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 76 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 77 | 78 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 79 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 80 | 81 | self.avgpool = nn.AvgPool2d(8, stride=1) 82 | self.linear = nn.Linear(1024, 10) 83 | 84 | def forward(self, x): 85 | out = self.pre_layers(x) 86 | out = self.a3(out) 87 | out = self.b3(out) 88 | out = self.maxpool(out) 89 | out = self.a4(out) 90 | out = self.b4(out) 91 | out = self.c4(out) 92 | out = self.d4(out) 93 | out = self.e4(out) 94 | out = self.maxpool(out) 95 | out = self.a5(out) 96 | out = self.b5(out) 97 | out = self.avgpool(out) 98 | out = out.view(out.size(0), -1) 99 | out = self.linear(out) 100 | return out 101 | 102 | # net = GoogLeNet() 103 | # x = torch.randn(1,3,32,32) 104 | # y = net(Variable(x)) 105 | # print(y.size()) 106 | -------------------------------------------------------------------------------- /models/lenet.py: -------------------------------------------------------------------------------- 1 | '''LeNet in PyTorch.''' 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class LeNet(nn.Module): 6 | def __init__(self): 7 | super(LeNet, self).__init__() 8 | self.conv1 = nn.Conv2d(3, 6, 5) 9 | self.conv2 = nn.Conv2d(6, 16, 5) 10 | self.fc1 = nn.Linear(16*5*5, 120) 11 | self.fc2 = nn.Linear(120, 84) 12 | self.fc3 = nn.Linear(84, 10) 13 | 14 | def forward(self, x): 15 | out = F.relu(self.conv1(x)) 16 | out = F.max_pool2d(out, 2) 17 | out = F.relu(self.conv2(out)) 18 | out = F.max_pool2d(out, 2) 19 | out = out.view(out.size(0), -1) 20 | out = F.relu(self.fc1(out)) 21 | out = F.relu(self.fc2(out)) 22 | out = self.fc3(out) 23 | return out 24 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from torch.autograd import Variable 11 | 12 | 13 | class Block(nn.Module): 14 | '''Depthwise conv + Pointwise conv''' 15 | def __init__(self, in_planes, out_planes, stride=1): 16 | super(Block, self).__init__() 17 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 20 | self.bn2 = nn.BatchNorm2d(out_planes) 21 | 22 | def forward(self, x): 23 | out = F.relu(self.bn1(self.conv1(x))) 24 | out = F.relu(self.bn2(self.conv2(out))) 25 | return out 26 | 27 | 28 | class MobileNet(nn.Module): 29 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 30 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 31 | 32 | def __init__(self, num_classes=10): 33 | super(MobileNet, self).__init__() 34 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 35 | self.bn1 = nn.BatchNorm2d(32) 36 | self.layers = self._make_layers(in_planes=32) 37 | self.linear = nn.Linear(1024, num_classes) 38 | 39 | def _make_layers(self, in_planes): 40 | layers = [] 41 | for x in self.cfg: 42 | out_planes = x if isinstance(x, int) else x[0] 43 | stride = 1 if isinstance(x, int) else x[1] 44 | layers.append(Block(in_planes, out_planes, stride)) 45 | in_planes = out_planes 46 | return nn.Sequential(*layers) 47 | 48 | def forward(self, x): 49 | out = F.relu(self.bn1(self.conv1(x))) 50 | out = self.layers(out) 51 | out = F.avg_pool2d(out, 2) 52 | out = out.view(out.size(0), -1) 53 | out = self.linear(out) 54 | return out 55 | 56 | 57 | def test(): 58 | net = MobileNet() 59 | x = torch.randn(1,3,32,32) 60 | y = net(Variable(x)) 61 | print(y.size()) 62 | 63 | # test() 64 | -------------------------------------------------------------------------------- /models/preact_resnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activation ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from torch.autograd import Variable 12 | 13 | 14 | class PreActBlock(nn.Module): 15 | '''Pre-activation version of the BasicBlock.''' 16 | expansion = 1 17 | 18 | def __init__(self, in_planes, planes, stride=1): 19 | super(PreActBlock, self).__init__() 20 | self.bn1 = nn.BatchNorm2d(in_planes) 21 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 24 | 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 28 | ) 29 | 30 | def forward(self, x): 31 | out = F.relu(self.bn1(x)) 32 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 33 | out = self.conv1(out) 34 | out = self.conv2(F.relu(self.bn2(out))) 35 | out += shortcut 36 | return out 37 | 38 | 39 | class PreActBottleneck(nn.Module): 40 | '''Pre-activation version of the original Bottleneck module.''' 41 | expansion = 4 42 | 43 | def __init__(self, in_planes, planes, stride=1): 44 | super(PreActBottleneck, self).__init__() 45 | self.bn1 = nn.BatchNorm2d(in_planes) 46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 51 | 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False) 55 | ) 56 | 57 | def forward(self, x): 58 | out = F.relu(self.bn1(x)) 59 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 60 | out = self.conv1(out) 61 | out = self.conv2(F.relu(self.bn2(out))) 62 | out = self.conv3(F.relu(self.bn3(out))) 63 | out += shortcut 64 | return out 65 | 66 | 67 | class PreActResNet(nn.Module): 68 | def __init__(self, block, num_blocks, num_classes=10): 69 | super(PreActResNet, self).__init__() 70 | self.in_planes = 64 71 | 72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 74 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 75 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 76 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 77 | self.linear = nn.Linear(512*block.expansion, num_classes) 78 | 79 | def _make_layer(self, block, planes, num_blocks, stride): 80 | strides = [stride] + [1]*(num_blocks-1) 81 | layers = [] 82 | for stride in strides: 83 | layers.append(block(self.in_planes, planes, stride)) 84 | self.in_planes = planes * block.expansion 85 | return nn.Sequential(*layers) 86 | 87 | def forward(self, x): 88 | out = self.conv1(x) 89 | out = self.layer1(out) 90 | out = self.layer2(out) 91 | out = self.layer3(out) 92 | out = self.layer4(out) 93 | out = F.avg_pool2d(out, 4) 94 | out = out.view(out.size(0), -1) 95 | out = self.linear(out) 96 | return out 97 | 98 | 99 | def PreActResNet18(): 100 | return PreActResNet(PreActBlock, [2,2,2,2]) 101 | 102 | def PreActResNet34(): 103 | return PreActResNet(PreActBlock, [3,4,6,3]) 
104 | 105 | def PreActResNet50(): 106 | return PreActResNet(PreActBottleneck, [3,4,6,3]) 107 | 108 | def PreActResNet101(): 109 | return PreActResNet(PreActBottleneck, [3,4,23,3]) 110 | 111 | def PreActResNet152(): 112 | return PreActResNet(PreActBottleneck, [3,8,36,3]) 113 | 114 | 115 | def test(): 116 | net = PreActResNet18() 117 | y = net(Variable(torch.randn(1,3,32,32))) 118 | print(y.size()) 119 | 120 | # test() 121 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | For Pre-activation ResNet, see 'preact_resnet.py'. 4 | 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from torch.autograd import Variable 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, in_planes, planes, stride=1): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*planes: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 30 | nn.BatchNorm2d(self.expansion*planes) 31 | ) 32 | 33 | def forward(self, x): 34 | out = F.relu(self.bn1(self.conv1(x))) 35 | out = self.bn2(self.conv2(out)) 36 | out += self.shortcut(x) 37 | out = F.relu(out) 38 | return out 39 | 40 | 41 | class Bottleneck(nn.Module): 42 | expansion = 4 43 | 44 | def __init__(self, in_planes, planes, stride=1): 45 | super(Bottleneck, self).__init__() 46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 47 | self.bn1 = nn.BatchNorm2d(planes) 48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 49 | self.bn2 = nn.BatchNorm2d(planes) 50 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 51 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 52 | 53 | self.shortcut = nn.Sequential() 54 | if stride != 1 or in_planes != self.expansion*planes: 55 | self.shortcut = nn.Sequential( 56 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 57 | nn.BatchNorm2d(self.expansion*planes) 58 | ) 59 | 60 | def forward(self, x): 61 | out = F.relu(self.bn1(self.conv1(x))) 62 | out = F.relu(self.bn2(self.conv2(out))) 63 | out = self.bn3(self.conv3(out)) 64 | out += self.shortcut(x) 65 | out = F.relu(out) 66 | return out 67 | 68 | 69 | class ResNet(nn.Module): 70 | def __init__(self, block, num_blocks, num_classes=10): 71 | super(ResNet, self).__init__() 72 | self.in_planes = 64 73 | 74 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 75 | self.bn1 = nn.BatchNorm2d(64) 76 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 77 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 78 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 79 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 80 | self.linear = nn.Linear(512*block.expansion, 
num_classes) 81 | 82 | def _make_layer(self, block, planes, num_blocks, stride): 83 | strides = [stride] + [1]*(num_blocks-1) 84 | layers = [] 85 | for stride in strides: 86 | layers.append(block(self.in_planes, planes, stride)) 87 | self.in_planes = planes * block.expansion 88 | return nn.Sequential(*layers) 89 | 90 | def forward(self, x): 91 | out = F.relu(self.bn1(self.conv1(x))) 92 | out = self.layer1(out) 93 | out = self.layer2(out) 94 | out = self.layer3(out) 95 | out = self.layer4(out) 96 | out = F.avg_pool2d(out, 4) 97 | out = out.view(out.size(0), -1) 98 | out = self.linear(out) 99 | return out 100 | 101 | 102 | def ResNet18(): 103 | return ResNet(BasicBlock, [2,2,2,2]) 104 | 105 | def ResNet34(): 106 | return ResNet(BasicBlock, [3,4,6,3]) 107 | 108 | def ResNet50(): 109 | return ResNet(Bottleneck, [3,4,6,3]) 110 | 111 | def ResNet101(): 112 | return ResNet(Bottleneck, [3,4,23,3]) 113 | 114 | def ResNet152(): 115 | return ResNet(Bottleneck, [3,8,36,3]) 116 | 117 | 118 | def test(): 119 | net = ResNet18() 120 | y = net(Variable(torch.randn(1,3,32,32))) 121 | print(y.size()) 122 | print(net) 123 | # test() -------------------------------------------------------------------------------- /models/resnext.py: -------------------------------------------------------------------------------- 1 | '''ResNeXt in PyTorch. 2 | 3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from torch.autograd import Variable 10 | 11 | 12 | class Block(nn.Module): 13 | '''Grouped convolution block.''' 14 | expansion = 2 15 | 16 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): 17 | super(Block, self).__init__() 18 | group_width = cardinality * bottleneck_width 19 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(group_width) 21 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 22 | self.bn2 = nn.BatchNorm2d(group_width) 23 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) 24 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*group_width: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), 30 | nn.BatchNorm2d(self.expansion*group_width) 31 | ) 32 | 33 | def forward(self, x): 34 | out = F.relu(self.bn1(self.conv1(x))) 35 | out = F.relu(self.bn2(self.conv2(out))) 36 | out = self.bn3(self.conv3(out)) 37 | out += self.shortcut(x) 38 | out = F.relu(out) 39 | return out 40 | 41 | 42 | class ResNeXt(nn.Module): 43 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10): 44 | super(ResNeXt, self).__init__() 45 | self.cardinality = cardinality 46 | self.bottleneck_width = bottleneck_width 47 | self.in_planes = 64 48 | 49 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) 50 | self.bn1 = nn.BatchNorm2d(64) 51 | self.layer1 = self._make_layer(num_blocks[0], 1) 52 | self.layer2 = self._make_layer(num_blocks[1], 2) 53 | self.layer3 = self._make_layer(num_blocks[2], 2) 54 | # self.layer4 = self._make_layer(num_blocks[3], 2) 55 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes) 56 | 57 | def _make_layer(self, num_blocks, stride): 58 | strides = [stride] + 
[1]*(num_blocks-1) 59 | layers = [] 60 | for stride in strides: 61 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) 62 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width 63 | # Increase bottleneck_width by 2 after each stage. 64 | self.bottleneck_width *= 2 65 | return nn.Sequential(*layers) 66 | 67 | def forward(self, x): 68 | out = F.relu(self.bn1(self.conv1(x))) 69 | out = self.layer1(out) 70 | out = self.layer2(out) 71 | out = self.layer3(out) 72 | # out = self.layer4(out) 73 | out = F.avg_pool2d(out, 8) 74 | out = out.view(out.size(0), -1) 75 | out = self.linear(out) 76 | return out 77 | 78 | 79 | def ResNeXt29_2x64d(): 80 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64) 81 | 82 | def ResNeXt29_4x64d(): 83 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64) 84 | 85 | def ResNeXt29_8x64d(): 86 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64) 87 | 88 | def ResNeXt29_32x4d(): 89 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4) 90 | 91 | def test_resnext(): 92 | net = ResNeXt29_2x64d() 93 | x = torch.randn(1,3,32,32) 94 | y = net(Variable(x)) 95 | print(y.size()) 96 | 97 | # test_resnext() 98 | -------------------------------------------------------------------------------- /models/senet.py: -------------------------------------------------------------------------------- 1 | '''SENet in PyTorch. 2 | 3 | SENet is the winner of ImageNet-2017. The paper is not released yet. 4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from torch.autograd import Variable 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | def __init__(self, in_planes, planes, stride=1): 14 | super(BasicBlock, self).__init__() 15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 16 | self.bn1 = nn.BatchNorm2d(planes) 17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes) 19 | 20 | self.shortcut = nn.Sequential() 21 | if stride != 1 or in_planes != planes: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(planes) 25 | ) 26 | 27 | # SE layers 28 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) # Use nn.Conv2d instead of nn.Linear 29 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | 35 | # Squeeze 36 | w = F.avg_pool2d(out, out.size(2)) 37 | w = F.relu(self.fc1(w)) 38 | w = F.sigmoid(self.fc2(w)) 39 | # Excitation 40 | out = out * w # New broadcasting feature from v0.2! 
41 | 42 | out += self.shortcut(x) 43 | out = F.relu(out) 44 | return out 45 | 46 | 47 | class PreActBlock(nn.Module): 48 | def __init__(self, in_planes, planes, stride=1): 49 | super(PreActBlock, self).__init__() 50 | self.bn1 = nn.BatchNorm2d(in_planes) 51 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 54 | 55 | if stride != 1 or in_planes != planes: 56 | self.shortcut = nn.Sequential( 57 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False) 58 | ) 59 | 60 | # SE layers 61 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) 62 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1) 63 | 64 | def forward(self, x): 65 | out = F.relu(self.bn1(x)) 66 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 67 | out = self.conv1(out) 68 | out = self.conv2(F.relu(self.bn2(out))) 69 | 70 | # Squeeze 71 | w = F.avg_pool2d(out, out.size(2)) 72 | w = F.relu(self.fc1(w)) 73 | w = F.sigmoid(self.fc2(w)) 74 | # Excitation 75 | out = out * w 76 | 77 | out += shortcut 78 | return out 79 | 80 | 81 | class SENet(nn.Module): 82 | def __init__(self, block, num_blocks, num_classes=10): 83 | super(SENet, self).__init__() 84 | self.in_planes = 64 85 | 86 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 87 | self.bn1 = nn.BatchNorm2d(64) 88 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 89 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 90 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 91 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 92 | self.linear = nn.Linear(512, num_classes) 93 | 94 | def _make_layer(self, block, planes, num_blocks, stride): 95 | strides = [stride] + [1]*(num_blocks-1) 96 | layers = [] 97 | for stride in strides: 98 | layers.append(block(self.in_planes, planes, stride)) 99 | self.in_planes = planes 100 | return nn.Sequential(*layers) 101 | 102 | def forward(self, x): 103 | out = F.relu(self.bn1(self.conv1(x))) 104 | out = self.layer1(out) 105 | out = self.layer2(out) 106 | out = self.layer3(out) 107 | out = self.layer4(out) 108 | out = F.avg_pool2d(out, 4) 109 | out = out.view(out.size(0), -1) 110 | out = self.linear(out) 111 | return out 112 | 113 | 114 | def SENet18(): 115 | return SENet(PreActBlock, [2,2,2,2]) 116 | 117 | 118 | def test(): 119 | net = SENet18() 120 | y = net(Variable(torch.randn(1,3,32,32))) 121 | print(y.size()) 122 | print(net) 123 | 124 | #test() 125 | -------------------------------------------------------------------------------- /models/shufflenet.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNet in PyTorch. 2 | 3 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. 
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from torch.autograd import Variable 10 | 11 | 12 | class ShuffleBlock(nn.Module): 13 | def __init__(self, groups): 14 | super(ShuffleBlock, self).__init__() 15 | self.groups = groups 16 | 17 | def forward(self, x): 18 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 19 | N,C,H,W = x.size() 20 | g = self.groups 21 | return x.view(N,g,C/g,H,W).permute(0,2,1,3,4).contiguous().view(N,C,H,W) 22 | 23 | 24 | class Bottleneck(nn.Module): 25 | def __init__(self, in_planes, out_planes, stride, groups): 26 | super(Bottleneck, self).__init__() 27 | self.stride = stride 28 | 29 | mid_planes = out_planes/4 30 | g = 1 if in_planes==24 else groups 31 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 32 | self.bn1 = nn.BatchNorm2d(mid_planes) 33 | self.shuffle1 = ShuffleBlock(groups=g) 34 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) 35 | self.bn2 = nn.BatchNorm2d(mid_planes) 36 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) 37 | self.bn3 = nn.BatchNorm2d(out_planes) 38 | 39 | self.shortcut = nn.Sequential() 40 | if stride == 2: 41 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 42 | 43 | def forward(self, x): 44 | out = F.relu(self.bn1(self.conv1(x))) 45 | out = self.shuffle1(out) 46 | out = F.relu(self.bn2(self.conv2(out))) 47 | out = self.bn3(self.conv3(out)) 48 | res = self.shortcut(x) 49 | out = F.relu(torch.cat([out,res], 1)) if self.stride==2 else F.relu(out+res) 50 | return out 51 | 52 | 53 | class ShuffleNet(nn.Module): 54 | def __init__(self, cfg): 55 | super(ShuffleNet, self).__init__() 56 | out_planes = cfg['out_planes'] 57 | num_blocks = cfg['num_blocks'] 58 | groups = cfg['groups'] 59 | 60 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(24) 62 | self.in_planes = 24 63 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 64 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 65 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 66 | self.linear = nn.Linear(out_planes[2], 10) 67 | 68 | def _make_layer(self, out_planes, num_blocks, groups): 69 | layers = [] 70 | for i in range(num_blocks): 71 | stride = 2 if i == 0 else 1 72 | cat_planes = self.in_planes if i == 0 else 0 73 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups)) 74 | self.in_planes = out_planes 75 | return nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | out = F.relu(self.bn1(self.conv1(x))) 79 | out = self.layer1(out) 80 | out = self.layer2(out) 81 | out = self.layer3(out) 82 | out = F.avg_pool2d(out, 4) 83 | out = out.view(out.size(0), -1) 84 | out = self.linear(out) 85 | return out 86 | 87 | 88 | def ShuffleNetG2(): 89 | cfg = { 90 | 'out_planes': [200,400,800], 91 | 'num_blocks': [4,8,4], 92 | 'groups': 2 93 | } 94 | return ShuffleNet(cfg) 95 | 96 | def ShuffleNetG3(): 97 | cfg = { 98 | 'out_planes': [240,480,960], 99 | 'num_blocks': [4,8,4], 100 | 'groups': 3 101 | } 102 | return ShuffleNet(cfg) 103 | 104 | 105 | def test(): 106 | net = ShuffleNetG2() 107 | x = Variable(torch.randn(1,3,32,32)) 108 | y = net(x) 109 | print(y) 110 | 111 | # test() 112 | -------------------------------------------------------------------------------- /models/small_cnn.py: 
-------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch.nn as nn 3 | import torch 4 | from torch.autograd import Variable 5 | 6 | class SmallCNN(nn.Module): 7 | def __init__(self): 8 | super(SmallCNN, self).__init__() 9 | 10 | self.block1_conv1 = nn.Conv2d(3, 64, 3, padding=1) 11 | self.block1_conv2 = nn.Conv2d(64, 64, 3, padding=1) 12 | self.block1_pool1 = nn.MaxPool2d(2, 2) 13 | self.batchnorm1_1 = nn.BatchNorm2d(64) 14 | self.batchnorm1_2 = nn.BatchNorm2d(64) 15 | 16 | self.block2_conv1 = nn.Conv2d(64, 128, 3, padding=1) 17 | self.block2_conv2 = nn.Conv2d(128, 128, 3, padding=1) 18 | self.block2_pool1 = nn.MaxPool2d(2, 2) 19 | self.batchnorm2_1 = nn.BatchNorm2d(128) 20 | self.batchnorm2_2 = nn.BatchNorm2d(128) 21 | 22 | self.block3_conv1 = nn.Conv2d(128, 196, 3, padding=1) 23 | self.block3_conv2 = nn.Conv2d(196, 196, 3, padding=1) 24 | self.block3_pool1 = nn.MaxPool2d(2, 2) 25 | self.batchnorm3_1 = nn.BatchNorm2d(196) 26 | self.batchnorm3_2 = nn.BatchNorm2d(196) 27 | 28 | self.activ = nn.ReLU() 29 | 30 | self.fc1 = nn.Linear(196*4*4,256) 31 | self.fc2 = nn.Linear(256,10) 32 | 33 | def forward(self, x): 34 | #block1 35 | x = self.block1_conv1(x) 36 | x = self.batchnorm1_1(x) 37 | x = self.activ(x) 38 | x = self.block1_conv2(x) 39 | x = self.batchnorm1_2(x) 40 | x = self.activ(x) 41 | x = self.block1_pool1(x) 42 | 43 | #block2 44 | x = self.block2_conv1(x) 45 | x = self.batchnorm2_1(x) 46 | x = self.activ(x) 47 | x = self.block2_conv2(x) 48 | x = self.batchnorm2_2(x) 49 | x = self.activ(x) 50 | x = self.block2_pool1(x) 51 | #block3 52 | x = self.block3_conv1(x) 53 | x = self.batchnorm3_1(x) 54 | x = self.activ(x) 55 | x = self.block3_conv2(x) 56 | x = self.batchnorm3_2(x) 57 | x = self.activ(x) 58 | x = self.block3_pool1(x) 59 | 60 | x = x.view(-1,196*4*4) 61 | x = self.fc1(x) 62 | x = self.activ(x) 63 | x = self.fc2(x) 64 | 65 | return x 66 | 67 | def small_cnn(): 68 | return SmallCNN() 69 | def test(): 70 | net = small_cnn() 71 | y = net(Variable(torch.randn(1,3,32,32))) 72 | print(y.size()) 73 | print(net) -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | cfg = { 8 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 10 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 11 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 12 | } 13 | 14 | 15 | class VGG(nn.Module): 16 | def __init__(self, vgg_name): 17 | super(VGG, self).__init__() 18 | self.features = self._make_layers(cfg[vgg_name]) 19 | self.classifier = nn.Linear(512, 10) 20 | 21 | def forward(self, x): 22 | out = self.features(x) 23 | out = out.view(out.size(0), -1) 24 | out = self.classifier(out) 25 | return out 26 | 27 | def _make_layers(self, cfg): 28 | layers = [] 29 | in_channels = 3 30 | for x in cfg: 31 | if x == 'M': 32 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 33 | else: 34 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(x), 36 | nn.ReLU(inplace=True)] 37 | in_channels = x 38 | layers += [nn.AvgPool2d(kernel_size=1, 
stride=1)] 39 | return nn.Sequential(*layers) 40 | 41 | # net = VGG('VGG11') 42 | # x = torch.randn(2,3,32,32) 43 | # print(net(Variable(x)).size()) 44 | -------------------------------------------------------------------------------- /models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | 35 | class NetworkBlock(nn.Module): 36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 37 | super(NetworkBlock, self).__init__() 38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 39 | 40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 41 | layers = [] 42 | for i in range(int(nb_layers)): 43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | return self.layer(x) 48 | 49 | 50 | class Wide_ResNet(nn.Module): 51 | def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0): 52 | super(Wide_ResNet, self).__init__() 53 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 54 | assert ((depth - 4) % 6 == 0) 55 | n = (depth - 4) / 6 56 | block = BasicBlock 57 | # 1st conv before any network block 58 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 59 | padding=1, bias=False) 60 | # 1st block 61 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 62 | # 1st sub-block 63 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 64 | # 2nd block 65 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 66 | # 3rd block 67 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 68 | # global average pooling and classifier 69 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.fc = nn.Linear(nChannels[3], num_classes) 72 | self.nChannels = nChannels[3] 73 | 74 | for m in self.modules(): 75 | if isinstance(m, nn.Conv2d): 76 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 77 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 78 | elif isinstance(m, nn.BatchNorm2d): 79 | m.weight.data.fill_(1) 80 | m.bias.data.zero_() 81 | elif isinstance(m, nn.Linear): 82 | m.bias.data.zero_() 83 | 84 | def forward(self, x): 85 | out = self.conv1(x) 86 | out = self.block1(out) 87 | out = self.block2(out) 88 | out = self.block3(out) 89 | out = self.relu(self.bn1(out)) 90 | out = F.avg_pool2d(out, 8) 91 | out = out.view(-1, self.nChannels) 92 | return self.fc(out) 93 | def test(): 94 | net = Wide_ResNet() 95 | y = net(Variable(torch.randn(1, 3, 32, 32))) 96 | #print(y.size()) 97 | print(net) 98 | # test() -------------------------------------------------------------------------------- /models/wrn_madry.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | 35 | class NetworkBlock(nn.Module): 36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 37 | super(NetworkBlock, self).__init__() 38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 39 | 40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 41 | layers = [] 42 | for i in range(int(nb_layers)): 43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 44 | return nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | return self.layer(x) 48 | 49 | 50 | class Wide_ResNet_Madry(nn.Module): 51 | def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0): 52 | super(Wide_ResNet_Madry, self).__init__() 53 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 54 | assert ((depth - 2) % 6 == 0) 55 | n = (depth - 2) / 6 56 | block = BasicBlock 57 | # 1st conv before any network block 58 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 59 | padding=1, bias=False) 60 | # 1st block 61 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 62 | # 1st sub-block 63 | # self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 64 | # 2nd block 65 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 66 | # 3rd block 67 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], 
block, 2, dropRate) 68 | # global average pooling and classifier 69 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.fc = nn.Linear(nChannels[3], num_classes) 72 | self.nChannels = nChannels[3] 73 | 74 | for m in self.modules(): 75 | if isinstance(m, nn.Conv2d): 76 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 77 | m.weight.data.normal_(0, math.sqrt(2. / n)) 78 | elif isinstance(m, nn.BatchNorm2d): 79 | m.weight.data.fill_(1) 80 | m.bias.data.zero_() 81 | elif isinstance(m, nn.Linear): 82 | m.bias.data.zero_() 83 | 84 | def forward(self, x): 85 | out = self.conv1(x) 86 | out = self.block1(out) 87 | out = self.block2(out) 88 | out = self.block3(out) 89 | out = self.relu(self.bn1(out)) 90 | out = F.avg_pool2d(out, 8) 91 | out = out.view(-1, self.nChannels) 92 | return self.fc(out) 93 | def test(): 94 | net = Wide_ResNet_Madry() 95 | y = net(Variable(torch.randn(1, 3, 32, 32))) 96 | #print(y.size()) 97 | print(net) 98 | # test() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .logger import * 4 | -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/eval.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/eval.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/eval.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/eval.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/eval.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/eval.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-35.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/logger.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/misc.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/misc.cpython-35.pyc -------------------------------------------------------------------------------- /utils/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/misc.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | import matplotlib.pyplot as plt 5 | import os 6 | import sys 7 | import numpy as np 8 | 9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 10 | 11 | def savefig(fname, dpi=None): 12 | dpi = 150 if dpi == None else dpi 13 | plt.savefig(fname, dpi=dpi) 14 | 15 | def plot_overlap(logger, names=None): 16 | names = logger.names if names == None else names 17 | numbers = logger.numbers 18 | for _, name in enumerate(names): 19 | x = np.arange(len(numbers[name])) 20 | plt.plot(x, np.asarray(numbers[name])) 21 | return [logger.title + '(' + name + ')' for name in names] 22 | 23 | class Logger(object): 24 | '''Save training process to log file with simple plot function.''' 25 | def __init__(self, fpath, title=None, resume=False): 26 | self.file = None 27 | self.resume = resume 28 | self.title = '' if title == None else title 29 | if fpath is not None: 30 | if resume: 31 | self.file = open(fpath, 'r') 32 | name = self.file.readline() 33 | self.names = name.rstrip().split('\t') 34 | self.numbers = {} 35 | for _, name in enumerate(self.names): 36 | self.numbers[name] = [] 37 | 38 | for numbers in self.file: 39 | numbers = numbers.rstrip().split('\t') 40 | for i in range(0, len(numbers)): 41 | self.numbers[self.names[i]].append(numbers[i]) 42 | 
self.file.close() 43 | self.file = open(fpath, 'a') 44 | else: 45 | self.file = open(fpath, 'w') 46 | 47 | def set_names(self, names): 48 | if self.resume: 49 | pass 50 | # initialize numbers as empty list 51 | self.numbers = {} 52 | self.names = names 53 | for _, name in enumerate(self.names): 54 | self.file.write(name) 55 | self.file.write('\t') 56 | self.numbers[name] = [] 57 | self.file.write('\n') 58 | self.file.flush() 59 | 60 | 61 | def append(self, numbers): 62 | assert len(self.names) == len(numbers), 'Numbers do not match names' 63 | for index, num in enumerate(numbers): 64 | self.file.write("{0:.6f}".format(num)) 65 | self.file.write('\t') 66 | self.numbers[self.names[index]].append(num) 67 | self.file.write('\n') 68 | self.file.flush() 69 | 70 | def plot(self, names=None): 71 | names = self.names if names == None else names 72 | numbers = self.numbers 73 | for _, name in enumerate(names): 74 | x = np.arange(len(numbers[name])) 75 | plt.plot(x, np.asarray(numbers[name])) 76 | plt.legend([self.title + '(' + name + ')' for name in names]) 77 | plt.grid(True) 78 | 79 | def close(self): 80 | if self.file is not None: 81 | self.file.close() 82 | 83 | class LoggerMonitor(object): 84 | '''Load and visualize multiple logs.''' 85 | def __init__(self, paths): 86 | '''paths is a dictionary with {name: filepath} pairs''' 87 | self.loggers = [] 88 | for title, path in paths.items(): 89 | logger = Logger(path, title=title, resume=True) 90 | self.loggers.append(logger) 91 | 92 | def plot(self, names=None): 93 | plt.figure() 94 | plt.subplot(121) 95 | legend_text = [] 96 | for logger in self.loggers: 97 | legend_text += plot_overlap(logger, names) 98 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 99 | plt.grid(True) 100 | 101 | if __name__ == '__main__': 102 | # # Example 103 | # logger = Logger('test.txt') 104 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 105 | 106 | # length = 100 107 | # t = np.arange(length) 108 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 109 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 110 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 111 | 112 | # for i in range(0, length): 113 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 114 | # logger.plot() 115 | 116 | # Example: logger monitor 117 | paths = { 118 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 119 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 120 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 121 | } 122 | 123 | field = ['Valid Acc.'] 124 | 125 | monitor = LoggerMonitor(paths) 126 | monitor.plot(names=field) 127 | savefig('test.eps') --------------------------------------------------------------------------------
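A minimal standalone sketch (not a file in this repository; the toy batch, labels, and tau value below are assumed purely for illustration) of the per-example early-stopping bookkeeping that earlystop.py performs inside its PGD loop:

import torch

# Illustration only: assumed toy values, not taken from the repository.
tau = 1                                    # extra PGD steps allowed after the first misclassification
pred = torch.tensor([3, 5, 7])             # current model predictions for a batch of 3 examples
target = torch.tensor([3, 2, 7])           # ground-truth labels
control = torch.ones(len(target)) * tau    # per-example budget of remaining "friendly" steps

output_index = []                          # examples that leave the PGD loop this iteration
iter_index = []                            # examples that keep being perturbed

for idx in range(len(pred)):
    if pred[idx] != target[idx]:           # already misclassified, i.e. already adversarial
        if control[idx] == 0:              # budget spent: stop attacking this example
            output_index.append(idx)
        else:                              # budget left: spend one step and keep attacking
            control[idx] -= 1
            iter_index.append(idx)
    else:                                  # still correctly classified: keep attacking
        iter_index.append(idx)

print(output_index, iter_index)            # tau = 1 gives [] [0, 1, 2]; tau = 0 would give [1] [0, 2]

An example keeps being perturbed while it is still classified correctly; once it crosses the decision boundary it receives at most tau further PGD steps before being moved to the output set, which is what keeps the generated adversarial data "friendly".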