├── FAT.py
├── FAT_for_MART.py
├── FAT_for_TRADES.py
├── README.md
├── attack_generator.py
├── attack_test.py
├── attack_test.sh
├── earlystop.py
├── image
│   ├── adv_train.png
│   ├── cross_over_mixture_problem.png
│   ├── early_stopped_pgd.png
│   ├── min-min_vs_minmax.png
│   ├── min_min_formulation.png
│   └── minimax_formulation.png
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── densenet.cpython-36.pyc
│   │   ├── dpn.cpython-36.pyc
│   │   ├── googlenet.cpython-36.pyc
│   │   ├── lenet.cpython-36.pyc
│   │   ├── mobilenet.cpython-36.pyc
│   │   ├── preact_resnet.cpython-36.pyc
│   │   ├── resnet.cpython-36.pyc
│   │   ├── resnext.cpython-36.pyc
│   │   ├── senet.cpython-36.pyc
│   │   ├── shufflenet.cpython-36.pyc
│   │   ├── small_cnn.cpython-36.pyc
│   │   ├── vgg.cpython-36.pyc
│   │   ├── wide_resnet.cpython-36.pyc
│   │   └── wrn_madry.cpython-36.pyc
│   ├── densenet.py
│   ├── dpn.py
│   ├── googlenet.py
│   ├── lenet.py
│   ├── mobilenet.py
│   ├── preact_resnet.py
│   ├── resnet.py
│   ├── resnext.py
│   ├── senet.py
│   ├── shufflenet.py
│   ├── small_cnn.py
│   ├── vgg.py
│   ├── wide_resnet.py
│   └── wrn_madry.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-35.pyc
    │   ├── __init__.cpython-36.pyc
    │   ├── __init__.cpython-37.pyc
    │   ├── eval.cpython-35.pyc
    │   ├── eval.cpython-36.pyc
    │   ├── eval.cpython-37.pyc
    │   ├── logger.cpython-35.pyc
    │   ├── logger.cpython-36.pyc
    │   ├── logger.cpython-37.pyc
    │   ├── misc.cpython-35.pyc
    │   ├── misc.cpython-36.pyc
    │   └── misc.cpython-37.pyc
    └── logger.py
/FAT.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torchvision
4 | import torch.optim as optim
5 | from torchvision import transforms
6 | import datetime
7 | from models import *
8 | from earlystop import earlystop
9 | import numpy as np
10 | from utils import Logger
11 | import attack_generator as attack
12 |
13 | parser = argparse.ArgumentParser(description='PyTorch Friendly Adversarial Training')
14 | parser.add_argument('--epochs', type=int, default=120, metavar='N', help='number of epochs to train')
15 | parser.add_argument('--weight_decay', '--wd', default=2e-4, type=float, metavar='W')
16 | parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate')
17 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum')
18 | parser.add_argument('--epsilon', type=float, default=0.031, help='perturbation bound')
19 | parser.add_argument('--num_steps', type=int, default=10, help='maximum perturbation step K')
20 | parser.add_argument('--step_size', type=float, default=0.007, help='step size')
21 | parser.add_argument('--seed', type=int, default=7, metavar='S', help='random seed')
22 | parser.add_argument('--net', type=str, default="WRN_madry",
23 |                     help="decide which network to use, choose from smallcnn, resnet18, WRN, WRN_madry")
24 | parser.add_argument('--tau', type=int, default=0, help='step tau')
25 | parser.add_argument('--dataset', type=str, default="cifar10", help="choose from cifar10,svhn")
26 | parser.add_argument('--rand_init', type=bool, default=True, help="whether to initialize adversarial sample with random noise")
27 | parser.add_argument('--omega', type=float, default=0.001, help="random sample parameter for adv data generation")
28 | parser.add_argument('--dynamictau', type=bool, default=True, help='whether to use dynamic tau')
29 | parser.add_argument('--depth', type=int, default=32, help='WRN depth')
30 | parser.add_argument('--width_factor', type=int, default=10, help='WRN width factor')
31 | parser.add_argument('--drop_rate', type=float, default=0.0, help='WRN drop rate')
32 | parser.add_argument('--out_dir', type=str, default='./FAT_results', help='dir of output')
33 | parser.add_argument('--resume', type=str, default='', help='whether to resume training, default: None')
34 |
35 | args = parser.parse_args()
36 |
37 | # training settings
38 | torch.manual_seed(args.seed)
39 | np.random.seed(args.seed)
40 | torch.cuda.manual_seed_all(args.seed)
41 | torch.backends.cudnn.deterministic = False
42 | torch.backends.cudnn.benchmark = True
43 |
44 | out_dir = args.out_dir
45 | if not os.path.exists(out_dir):
46 | os.makedirs(out_dir)
47 |
48 | def train(model, train_loader, optimizer, tau):
49 | starttime = datetime.datetime.now()
50 | loss_sum = 0
51 | bp_count = 0
52 | for batch_idx, (data, target) in enumerate(train_loader):
53 | data, target = data.cuda(), target.cuda()
54 |
55 | # Get friendly adversarial training data via early-stopped PGD
56 | output_adv, output_target, output_natural, count = earlystop(model, data, target, step_size=args.step_size,
57 | epsilon=args.epsilon, perturb_steps=args.num_steps, tau=tau,
58 | randominit_type="uniform_randominit", loss_fn='cent', rand_init=args.rand_init, omega=args.omega)
59 | bp_count += count
60 | model.train()
61 | optimizer.zero_grad()
62 | output = model(output_adv)
63 |
64 | # calculate standard adversarial training loss
65 | loss = nn.CrossEntropyLoss(reduction='mean')(output, output_target)
66 |
67 | loss_sum += loss.item()
68 | loss.backward()
69 | optimizer.step()
70 |
71 | bp_count_avg = bp_count / len(train_loader.dataset)
72 | endtime = datetime.datetime.now()
73 | time = (endtime - starttime).seconds
74 |
75 | return time, loss_sum, bp_count_avg
76 |
77 | def adjust_tau(epoch, dynamictau):
78 | tau = args.tau
79 | if dynamictau:
80 | if epoch <= 50:
81 | tau = 0
82 | elif epoch <= 90:
83 | tau = 1
84 | else:
85 | tau = 2
86 | return tau
87 |
88 |
89 | def adjust_learning_rate(optimizer, epoch):
90 | """decrease the learning rate"""
91 | lr = args.lr
92 | if epoch >= 60:
93 | lr = args.lr * 0.1
94 | if epoch >= 90:
95 | lr = args.lr * 0.01
96 | if epoch >= 110:
97 | lr = args.lr * 0.005
98 | for param_group in optimizer.param_groups:
99 | param_group['lr'] = lr
100 |
101 |
102 | def save_checkpoint(state, checkpoint=out_dir, filename='checkpoint.pth.tar'):
103 | filepath = os.path.join(checkpoint, filename)
104 | torch.save(state, filepath)
105 |
106 | # setup data loader
107 | transform_train = transforms.Compose([
108 | transforms.RandomCrop(32, padding=4),
109 | transforms.RandomHorizontalFlip(),
110 | transforms.ToTensor(),
111 | ])
112 | transform_test = transforms.Compose([
113 | transforms.ToTensor(),
114 | ])
115 |
116 | print('==> Load Test Data')
117 | if args.dataset == "cifar10":
118 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
119 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
120 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
121 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
122 | if args.dataset == "svhn":
123 | trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform_train)
124 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
125 | testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform_test)
126 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
127 |
128 | print('==> Load Model')
129 | if args.net == "smallcnn":
130 | model = SmallCNN().cuda()
131 | net = "smallcnn"
132 | if args.net == "resnet18":
133 | model = ResNet18().cuda()
134 | net = "resnet18"
135 | if args.net == "WRN":
136 | # e.g., WRN-34-10
137 | model = Wide_ResNet(depth=args.depth, num_classes=10, widen_factor=args.width_factor, dropRate=args.drop_rate).cuda()
138 | net = "WRN{}-{}-dropout{}".format(args.depth, args.width_factor, args.drop_rate)
139 | if args.net == 'WRN_madry':
140 | # e.g., WRN-32-10
141 | model = Wide_ResNet_Madry(depth=args.depth, num_classes=10, widen_factor=args.width_factor, dropRate=args.drop_rate).cuda()
142 | net = "WRN_madry{}-{}-dropout{}".format(args.depth, args.width_factor, args.drop_rate)
143 | print(net)
144 |
145 | model = torch.nn.DataParallel(model)
146 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
147 |
148 | start_epoch = 0
149 | # Resume
150 | title = 'FAT train'
151 | if args.resume:
152 | # resume directly point to checkpoint.pth.tar e.g., --resume='./out-dir/checkpoint.pth.tar'
153 | print('==> Friendly Adversarial Training Resuming from checkpoint ..')
154 | print(args.resume)
155 | assert os.path.isfile(args.resume)
156 | out_dir = os.path.dirname(args.resume)
157 | checkpoint = torch.load(args.resume)
158 | start_epoch = checkpoint['epoch']
159 | model.load_state_dict(checkpoint['state_dict'])
160 | optimizer.load_state_dict(checkpoint['optimizer'])
161 | logger_test = Logger(os.path.join(out_dir, 'log_results.txt'), title=title, resume=True)
162 | else:
163 | print('==> Friendly Adversarial Training')
164 | logger_test = Logger(os.path.join(args.out_dir, 'log_results.txt'), title=title)
165 | logger_test.set_names(['Epoch', 'Natural Test Acc', 'FGSM Acc', 'PGD20 Acc', 'CW Acc'])
166 |
167 | test_nat_acc = 0
168 | fgsm_acc = 0
169 | test_pgd20_acc = 0
170 | cw_acc = 0
171 | best_epoch = 0
172 | for epoch in range(start_epoch, args.epochs):
173 | adjust_learning_rate(optimizer, epoch + 1)
174 | train_time, train_loss, bp_count_avg = train(model, train_loader, optimizer, adjust_tau(epoch + 1, args.dynamictau))
175 |
176 |     ## Evaluations are the same as in DAT.
177 | loss, test_nat_acc = attack.eval_clean(model, test_loader)
178 | loss, fgsm_acc = attack.eval_robust(model, test_loader, perturb_steps=1, epsilon=0.031, step_size=0.031,loss_fn="cent", category="Madry",rand_init=True)
179 | loss, test_pgd20_acc = attack.eval_robust(model, test_loader, perturb_steps=20, epsilon=0.031, step_size=0.031 / 4,loss_fn="cent", category="Madry", rand_init=True)
180 | loss, cw_acc = attack.eval_robust(model, test_loader, perturb_steps=30, epsilon=0.031, step_size=0.031 / 4,loss_fn="cw", category="Madry", rand_init=True)
181 |
182 | print(
183 | 'Epoch: [%d | %d] | Train Time: %.2f s | BP Average: %.2f | Natural Test Acc %.2f | FGSM Test Acc %.2f | PGD20 Test Acc %.2f | CW Test Acc %.2f |\n' % (
184 | epoch + 1,
185 | args.epochs,
186 | train_time,
187 | bp_count_avg,
188 | test_nat_acc,
189 | fgsm_acc,
190 | test_pgd20_acc,
191 | cw_acc)
192 | )
193 |
194 | logger_test.append([epoch + 1, test_nat_acc, fgsm_acc, test_pgd20_acc, cw_acc])
195 |
196 | save_checkpoint({
197 | 'epoch': epoch + 1,
198 | 'state_dict': model.state_dict(),
199 | 'bp_avg': bp_count_avg,
200 | 'test_nat_acc': test_nat_acc,
201 | 'test_pgd20_acc': test_pgd20_acc,
202 | 'optimizer': optimizer.state_dict(),
203 | })
204 |
--------------------------------------------------------------------------------
/FAT_for_MART.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torchvision
4 | import torch.optim as optim
5 | from torchvision import transforms
6 | import datetime
7 | from models import *
8 | from earlystop import earlystop
9 | import numpy as np
10 | import attack_generator as attack
11 | from utils import Logger
12 |
13 | parser = argparse.ArgumentParser(description='PyTorch Friendly Adversarial Training for MART')
14 | parser.add_argument('--epochs', type=int, default=90, metavar='N', help='number of epochs to train')
15 | parser.add_argument('--weight_decay', '--wd', default=2e-4, type=float, metavar='W')
16 | parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate')
17 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum')
18 | parser.add_argument('--epsilon', type=float, default=0.031, help='perturbation bound')
19 | parser.add_argument('--num_steps', type=int, default=10, help='maximum perturbation step K')
20 | parser.add_argument('--step_size', type=float, default=0.007, help='step size')
21 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed')
22 | parser.add_argument('--net', type=str, default="WRN",help="decide which network to use,choose from smallcnn,resnet18,WRN")
23 | parser.add_argument('--tau', type=int, default=0, help='step tau')
24 | parser.add_argument('--beta',type=float,default=6.0,help='regularization parameter')
25 | parser.add_argument('--dataset', type=str, default="cifar10", help="choose from cifar10,svhn")
26 | parser.add_argument('--rand_init', type=bool, default=True, help="whether to initialize adversarial sample with random noise")
27 | parser.add_argument('--omega', type=float, default=0.0, help="random sample parameter")
28 | parser.add_argument('--dynamictau', type=bool, default=True, help='whether to use dynamic tau')
29 | parser.add_argument('--depth', type=int, default=34, help='WRN depth')
30 | parser.add_argument('--width_factor', type=int, default=10, help='WRN width factor')
31 | parser.add_argument('--drop_rate', type=float, default=0.0, help='WRN drop rate')
32 | parser.add_argument('--out_dir',type=str,default='./FAT_for_MART_results',help='dir of output')
33 | parser.add_argument('--resume', type=str, default='', help='whether to resume training, default: None')
34 |
35 | args = parser.parse_args()
36 |
37 | # settings
38 | torch.manual_seed(args.seed)
39 | np.random.seed(args.seed)
40 | torch.cuda.manual_seed_all(args.seed)
41 | torch.backends.cudnn.deterministic = False
42 | torch.backends.cudnn.benchmark = True
43 |
44 | out_dir = args.out_dir
45 | if not os.path.exists(out_dir):
46 | os.makedirs(out_dir)
47 |
48 | def MART_loss(adv_logits, natural_logits, target, beta):
49 | # Based on the repo MART https://github.com/YisenWang/MART
50 | kl = nn.KLDivLoss(reduction='none')
51 | batch_size = len(target)
52 | adv_probs = F.softmax(adv_logits, dim=1)
53 | tmp1 = torch.argsort(adv_probs, dim=1)[:, -2:]
54 | new_y = torch.where(tmp1[:, -1] == target, tmp1[:, -2], tmp1[:, -1])
55 | loss_adv = F.cross_entropy(adv_logits, target) + F.nll_loss(torch.log(1.0001 - adv_probs + 1e-12), new_y)
56 | nat_probs = F.softmax(natural_logits, dim=1)
57 | true_probs = torch.gather(nat_probs, 1, (target.unsqueeze(1)).long()).squeeze()
58 | loss_robust = (1.0 / batch_size) * torch.sum(
59 | torch.sum(kl(torch.log(adv_probs + 1e-12), nat_probs), dim=1) * (1.0000001 - true_probs))
60 | loss = loss_adv + float(beta) * loss_robust
61 | return loss
62 |
63 | def train(model, train_loader, optimizer, tau):
64 | starttime = datetime.datetime.now()
65 | loss_sum = 0
66 | bp_count = 0
67 | for batch_idx, (data, target) in enumerate(train_loader):
68 | data, target = data.cuda(), target.cuda()
69 |
70 | # Get friendly adversarial training data via early-stopped PGD
71 | output_adv, output_target, output_natural, count = earlystop(model, data, target, step_size=args.step_size,
72 | epsilon=args.epsilon, perturb_steps=args.num_steps,
73 | tau=tau, randominit_type="normal_distribution_randominit", loss_fn='cent', rand_init=args.rand_init,
74 | omega=args.omega)
75 | bp_count += count
76 | model.train()
77 | optimizer.zero_grad()
78 |
79 | adv_logits = model(output_adv)
80 | natural_logits = model(output_natural)
81 |
82 | # calculate MART adversarial training loss
83 | loss = MART_loss(adv_logits, natural_logits, output_target, args.beta)
84 |
85 | loss_sum += loss.item()
86 | loss.backward()
87 | optimizer.step()
88 |
89 | bp_count_avg = bp_count / len(train_loader.dataset)
90 | endtime = datetime.datetime.now()
91 | time = (endtime - starttime).seconds
92 |
93 | return time, loss_sum, bp_count_avg
94 |
95 | def adjust_tau(epoch, dynamictau):
96 | tau = args.tau
97 | if dynamictau:
98 | if epoch <= 20:
99 | tau = 0
100 | elif epoch <= 40:
101 | tau = 1
102 | elif epoch <= 60:
103 | tau = 2
104 | elif epoch <= 80:
105 | tau = 3
106 | else:
107 | tau = 4
108 | return tau
109 |
110 | def adjust_learning_rate(optimizer, epoch):
111 | """decrease the learning rate"""
112 | lr = args.lr
113 | if epoch >= 60:
114 | lr = args.lr * 0.1
115 | if epoch >= 90:
116 | lr = args.lr * 0.01
117 | for param_group in optimizer.param_groups:
118 | param_group['lr'] = lr
119 |
120 | def save_checkpoint(state, checkpoint=out_dir, filename='checkpoint.pth.tar'):
121 | filepath = os.path.join(checkpoint, filename)
122 | torch.save(state, filepath)
123 |
124 | # setup data loader
125 | transform_train = transforms.Compose([
126 | transforms.RandomCrop(32, padding=4),
127 | transforms.RandomHorizontalFlip(),
128 | transforms.ToTensor(),
129 | ])
130 | transform_test = transforms.Compose([
131 | transforms.ToTensor(),
132 | ])
133 |
134 | print('==> Load Test Data')
135 | if args.dataset == "cifar10":
136 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
137 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
138 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
139 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
140 | if args.dataset == "svhn":
141 | trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform_train)
142 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
143 | testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform_test)
144 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
145 |
146 | print('==> Load Model')
147 | if args.net == "smallcnn":
148 | model = SmallCNN().cuda()
149 | net = "smallcnn"
150 | if args.net == "resnet18":
151 | model = ResNet18().cuda()
152 | net = "resnet18"
153 | if args.net == "WRN":
154 | model = Wide_ResNet(depth=args.depth, num_classes=10, widen_factor=args.width_factor, dropRate=args.drop_rate).cuda()
155 | net = "WRN{}-{}-dropout{}".format(args.depth, args.width_factor, args.drop_rate)
156 | model = torch.nn.DataParallel(model)
157 | print(net)
158 |
159 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
160 |
161 | if not os.path.exists(out_dir):
162 | os.makedirs(out_dir)
163 |
164 | start_epoch = 0
165 | # Resume
166 | title = 'FAT for MART train'
167 | if args.resume:
168 | # resume directly point to checkpoint.pth.tar e.g., --resume='./out-dir/checkpoint.pth.tar'
169 |     print('==> Friendly Adversarial Training for MART Resuming from checkpoint ..')
170 | print(args.resume)
171 | assert os.path.isfile(args.resume)
172 | out_dir = os.path.dirname(args.resume)
173 | checkpoint = torch.load(args.resume)
174 | start_epoch = checkpoint['epoch']
175 | model.load_state_dict(checkpoint['state_dict'])
176 | optimizer.load_state_dict(checkpoint['optimizer'])
177 | logger_test = Logger(os.path.join(out_dir, 'log_results.txt'), title=title, resume=True)
178 | else:
179 | print('==> Friendly Adversarial Training for MART')
180 | logger_test = Logger(os.path.join(out_dir, 'log_results.txt'), title=title)
181 | logger_test.set_names(['Epoch', 'Natural Test Acc', 'FGSM Acc', 'PGD20 Acc', 'CW Acc'])
182 |
183 |
184 | test_nat_acc = 0
185 | fgsm_acc = 0
186 | test_pgd20_acc = 0
187 | cw_acc = 0
188 | for epoch in range(start_epoch, args.epochs):
189 | adjust_learning_rate(optimizer, epoch + 1)
190 | train_time, train_loss, bp_count_avg = train(model, train_loader, optimizer, adjust_tau(epoch + 1, args.dynamictau))
191 |
192 |     ## Evaluations are the same as in TRADES.
193 | loss, test_nat_acc = attack.eval_clean(model, test_loader)
194 | loss, fgsm_acc = attack.eval_robust(model, test_loader, perturb_steps=1, epsilon=0.031, step_size=0.031,loss_fn="cent", category="Madry",rand_init=True)
195 | loss, test_pgd20_acc = attack.eval_robust(model,test_loader, perturb_steps=20, epsilon=0.031, step_size=0.003,loss_fn="cent",category="Madry",rand_init=True)
196 | loss, cw_acc = attack.eval_robust(model,test_loader, perturb_steps=30, epsilon=0.031, step_size=0.003,loss_fn="cw",category="Madry",rand_init=True)
197 |
198 | print(
199 | 'Epoch: [%d | %d] | Train Time: %.2f s | BP Average: %.2f | Natural Test Acc %.2f | FGSM Test Acc %.2f | PGD20 Test Acc %.2f | CW Test Acc %.2f |\n' % (
200 | epoch + 1,
201 | args.epochs,
202 | train_time,
203 | bp_count_avg,
204 | test_nat_acc,
205 | fgsm_acc,
206 | test_pgd20_acc,
207 | cw_acc)
208 | )
209 |
210 | logger_test.append([epoch + 1, test_nat_acc, fgsm_acc, test_pgd20_acc, cw_acc])
211 |
212 | save_checkpoint({
213 | 'epoch': epoch + 1,
214 | 'state_dict': model.state_dict(),
215 | 'bp_avg': bp_count_avg,
216 | 'test_nat_acc': test_nat_acc,
217 | 'test_pgd20_acc': test_pgd20_acc,
218 | 'optimizer': optimizer.state_dict(),
219 | })
220 |
--------------------------------------------------------------------------------
/FAT_for_TRADES.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torchvision
4 | import torch.optim as optim
5 | from torchvision import transforms
6 | import datetime
7 | from models import *
8 | from earlystop import earlystop
9 | import numpy as np
10 | import attack_generator as attack
11 | from utils import Logger
12 |
13 | parser = argparse.ArgumentParser(description='PyTorch Friendly Adversarial Training for TRADES')
14 | parser.add_argument('--epochs', type=int, default=85, metavar='N', help='number of epochs to train')
15 | parser.add_argument('--weight_decay', '--wd', default=2e-4, type=float, metavar='W')
16 | parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate')
17 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum')
18 | parser.add_argument('--epsilon', type=float, default=0.031, help='perturbation bound')
19 | parser.add_argument('--num_steps', type=int, default=10, help='maximum perturbation step K')
20 | parser.add_argument('--step_size', type=float, default=0.007, help='step size')
21 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed')
22 | parser.add_argument('--net', type=str, default="WRN",help="decide which network to use,choose from smallcnn,resnet18,WRN")
23 | parser.add_argument('--tau', type=int, default=0, help='step tau')
24 | parser.add_argument('--beta',type=float,default=6.0,help='regularization parameter')
25 | parser.add_argument('--dataset', type=str, default="cifar10", help="choose from cifar10,svhn")
26 | parser.add_argument('--rand_init', type=bool, default=True, help="whether to initialize adversarial sample with random noise")
27 | parser.add_argument('--omega', type=float, default=0.0, help="random sample parameter")
28 | parser.add_argument('--dynamictau', type=bool, default=True, help='whether to use dynamic tau')
29 | parser.add_argument('--depth', type=int, default=34, help='WRN depth')
30 | parser.add_argument('--width_factor', type=int, default=10, help='WRN width factor')
31 | parser.add_argument('--drop_rate', type=float, default=0.0, help='WRN drop rate')
32 | parser.add_argument('--out_dir',type=str,default='./FAT_for_TRADES_results',help='dir of output')
33 | parser.add_argument('--resume', type=str, default='', help='whether to resume training, default: None')
34 |
35 | args = parser.parse_args()
36 |
37 | # settings
38 | torch.manual_seed(args.seed)
39 | np.random.seed(args.seed)
40 | torch.cuda.manual_seed_all(args.seed)
41 | torch.backends.cudnn.deterministic = False
42 | torch.backends.cudnn.benchmark = True
43 |
44 | out_dir = args.out_dir
45 | if not os.path.exists(out_dir):
46 | os.makedirs(out_dir)
47 |
48 | def TRADES_loss(adv_logits, natural_logits, target, beta):
49 |     # Based on the repo TRADES: https://github.com/yaodongyu/TRADES
50 | batch_size = len(target)
51 | criterion_kl = nn.KLDivLoss(size_average=False).cuda()
52 | loss_natural = nn.CrossEntropyLoss(reduction='mean')(natural_logits, target)
53 | loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(adv_logits, dim=1),
54 | F.softmax(natural_logits, dim=1))
55 | loss = loss_natural + beta * loss_robust
56 | return loss
57 |
58 | def train(model, train_loader, optimizer, tau):
59 | starttime = datetime.datetime.now()
60 | loss_sum = 0
61 | bp_count = 0
62 | for batch_idx, (data, target) in enumerate(train_loader):
63 | data, target = data.cuda(), target.cuda()
64 |
65 | # Get friendly adversarial training data via early-stopped PGD
66 | output_adv, output_target, output_natural, count = earlystop(model, data, target, step_size=args.step_size,
67 | epsilon=args.epsilon, perturb_steps=args.num_steps,
68 | tau=tau, randominit_type="normal_distribution_randominit", loss_fn='kl', rand_init=args.rand_init,
69 | omega=args.omega)
70 | bp_count += count
71 | model.train()
72 | optimizer.zero_grad()
73 |
74 | natural_logits = model(output_natural)
75 | adv_logits = model(output_adv)
76 |
77 | # calculate TRADES adversarial training loss
78 | loss = TRADES_loss(adv_logits,natural_logits,output_target,args.beta)
79 |
80 | loss_sum += loss.item()
81 | loss.backward()
82 | optimizer.step()
83 |
84 | bp_count_avg = bp_count / len(train_loader.dataset)
85 | endtime = datetime.datetime.now()
86 | time = (endtime - starttime).seconds
87 |
88 | return time, loss_sum, bp_count_avg
89 |
90 | def adjust_tau(epoch, dynamictau):
91 | tau = args.tau
92 | if dynamictau:
93 | if epoch <= 30:
94 | tau = 0
95 | elif epoch <= 50:
96 | tau = 1
97 | elif epoch <= 70:
98 | tau = 2
99 | else:
100 | tau = 3
101 | return tau
102 |
103 | def adjust_learning_rate(optimizer, epoch):
104 | """decrease the learning rate"""
105 | lr = args.lr
106 | if epoch >= 75:
107 | lr = args.lr * 0.1
108 | if epoch >= 90:
109 | lr = args.lr * 0.01
110 | for param_group in optimizer.param_groups:
111 | param_group['lr'] = lr
112 |
113 | def save_checkpoint(state, checkpoint=out_dir, filename='checkpoint.pth.tar'):
114 | filepath = os.path.join(checkpoint, filename)
115 | torch.save(state, filepath)
116 |
117 | # setup data loader
118 | transform_train = transforms.Compose([
119 | transforms.RandomCrop(32, padding=4),
120 | transforms.RandomHorizontalFlip(),
121 | transforms.ToTensor(),
122 | ])
123 | transform_test = transforms.Compose([
124 | transforms.ToTensor(),
125 | ])
126 |
127 | print('==> Load Test Data')
128 | if args.dataset == "cifar10":
129 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
130 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
131 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
132 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
133 | if args.dataset == "svhn":
134 | trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=transform_train)
135 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
136 | testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform_test)
137 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
138 |
139 | print('==> Load Model')
140 | if args.net == "smallcnn":
141 | model = SmallCNN().cuda()
142 | net = "smallcnn"
143 | if args.net == "resnet18":
144 | model = ResNet18().cuda()
145 | net = "resnet18"
146 | if args.net == "WRN":
147 | model = Wide_ResNet(depth=args.depth, num_classes=10, widen_factor=args.width_factor, dropRate=args.drop_rate).cuda()
148 | net = "WRN{}-{}-dropout{}".format(args.depth, args.width_factor, args.drop_rate)
149 | model = torch.nn.DataParallel(model)
150 | print(net)
151 |
152 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
153 |
154 | if not os.path.exists(out_dir):
155 | os.makedirs(out_dir)
156 |
157 | start_epoch = 0
158 | # Resume
159 | title = 'FAT for TRADES train'
160 | if args.resume:
161 | # resume directly point to checkpoint.pth.tar e.g., --resume='./out-dir/checkpoint.pth.tar'
162 |     print('==> Friendly Adversarial Training for TRADES Resuming from checkpoint ..')
163 | print(args.resume)
164 | assert os.path.isfile(args.resume)
165 | out_dir = os.path.dirname(args.resume)
166 | checkpoint = torch.load(args.resume)
167 | start_epoch = checkpoint['epoch']
168 | model.load_state_dict(checkpoint['state_dict'])
169 | optimizer.load_state_dict(checkpoint['optimizer'])
170 | logger_test = Logger(os.path.join(out_dir, 'log_results.txt'), title=title, resume=True)
171 | else:
172 | print('==> Friendly Adversarial Training for TRADES')
173 | logger_test = Logger(os.path.join(out_dir, 'log_results.txt'), title=title)
174 | logger_test.set_names(['Epoch', 'Natural Test Acc', 'FGSM Acc', 'PGD20 Acc', 'CW Acc'])
175 |
176 | test_nat_acc = 0
177 | fgsm_acc = 0
178 | test_pgd20_acc = 0
179 | cw_acc = 0
180 | for epoch in range(start_epoch, args.epochs):
181 | adjust_learning_rate(optimizer, epoch + 1)
182 | train_time, train_loss, bp_count_avg = train(model, train_loader, optimizer, adjust_tau(epoch + 1, args.dynamictau))
183 |
184 |     ## Evaluations are the same as in TRADES.
185 | loss, test_nat_acc = attack.eval_clean(model, test_loader)
186 | loss, fgsm_acc = attack.eval_robust(model, test_loader, perturb_steps=1, epsilon=0.031, step_size=0.031,loss_fn="cent", category="Madry",rand_init=True)
187 | loss, test_pgd20_acc = attack.eval_robust(model,test_loader, perturb_steps=20, epsilon=0.031, step_size=0.003,loss_fn="cent",category="Madry",rand_init=True)
188 | loss, cw_acc = attack.eval_robust(model,test_loader, perturb_steps=30, epsilon=0.031, step_size=0.003,loss_fn="cw",category="Madry",rand_init=True)
189 |
190 | print(
191 | 'Epoch: [%d | %d] | Train Time: %.2f s | BP Average: %.2f | Natural Test Acc %.2f | FGSM Test Acc %.2f | PGD20 Test Acc %.2f | CW Test Acc %.2f |\n' % (
192 | epoch + 1,
193 | args.epochs,
194 | train_time,
195 | bp_count_avg,
196 | test_nat_acc,
197 | fgsm_acc,
198 | test_pgd20_acc,
199 | cw_acc)
200 | )
201 |
202 | logger_test.append([epoch + 1, test_nat_acc, fgsm_acc, test_pgd20_acc, cw_acc])
203 |
204 | save_checkpoint({
205 | 'epoch': epoch + 1,
206 | 'state_dict': model.state_dict(),
207 | 'bp_avg' : bp_count_avg,
208 | 'test_nat_acc': test_nat_acc,
209 | 'test_pgd20_acc': test_pgd20_acc,
210 | 'optimizer': optimizer.state_dict(),
211 | })
212 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Friendly Adversarial Training Code
2 |
3 | This repository provides code for friendly adversarial training (FAT).
4 |
5 | ICML 2020 Paper: **Attacks Which Do Not Kill Training Make Adversarial Learning Stronger** (https://arxiv.org/abs/2002.11242)
6 | *Jingfeng Zhang\*, Xilie Xu\*, Bo Han, Gang Niu, Lizhen Cui, Masashi Sugiyama and Mohan Kankanhalli*
7 |
8 | ## What is the nature of adversarial training?
9 | Adversarial data can easily fool a standard-trained classifier.
10 | Adversarial training incorporates adversarial data into the training process.
11 | Adversarial training aims to achieve two purposes: (a) correctly classify the data, and (b) make the decision boundary thick so that no data fall near the decision boundary.
12 |
13 |
14 | ![adversarial training](image/adv_train.png)
15 |
16 | *The purposes of the adversarial training*
17 |
18 |
19 |
20 | ## Conventional formulation of adversarial training
21 |
22 | Conventional adversarial training is based on the minimax formulation:
23 |
24 | $$\min_{f}\ \frac{1}{n}\sum_{i=1}^{n}\ell(f(\tilde{x}_i),y_i),$$
25 |
26 | where
27 |
28 | $$\tilde{x}_i=\arg\max_{\tilde{x}\in B_\epsilon[x_i]}\ell(f(\tilde{x}),y_i).$$
29 |
30 | The inner maximization finds **the most adversarial data**; the outer minimization then finds a classifier that fits those generated adversarial data.
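
For reference, the inner maximization is what the PGD attack in ```attack_generator.py``` implements. Below is a minimal sketch of that inner maximization (not the repository's exact code); ```model```, ```data``` and ```target``` are assumed to be a classifier and a batch of images/labels in ```[0,1]``` on the same device.

```python
import torch
import torch.nn as nn

def pgd_inner_max(model, data, target, epsilon=0.031, step_size=0.007, num_steps=10):
    # Random start inside the L-infinity ball, then projected gradient ascent on the loss.
    x_adv = data.detach() + torch.empty_like(data).uniform_(-epsilon, epsilon)
    x_adv = torch.clamp(x_adv, 0.0, 1.0)
    for _ in range(num_steps):
        x_adv.requires_grad_()
        loss = nn.CrossEntropyLoss()(model(x_adv), target)
        grad = torch.autograd.grad(loss, x_adv)[0]
        x_adv = x_adv.detach() + step_size * grad.sign()                     # ascent step
        x_adv = torch.min(torch.max(x_adv, data - epsilon), data + epsilon)  # project back into the epsilon-ball
        x_adv = torch.clamp(x_adv, 0.0, 1.0)                                 # keep a valid image range
    return x_adv
```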
31 |
32 | ### The minimax formulation is pessimistic.
33 |
34 | Minimax-based adversarial training causes severe degradation of natural generalization. Why?
35 | It suffers from a severe cross-over mixture problem: adversarial data of different classes overshoot into the areas of their peer classes. Learning from such adversarial data is very difficult.
36 |
37 |
38 | ![cross-over mixture problem](image/cross_over_mixture_problem.png)
39 |
40 | *Cross-over mixture problem of the minimax-based adversarial training*
41 |
42 |
43 | ## Our **min-min formulation** of adversarial training
44 |
45 | The outer minimization stays the same. Instead of generating adversarial data via the inner maximization, we generate **friendly adversarial data** by minimizing the loss value, subject to two constraints: (a) the adversarial data is misclassified, and (b) the wrong prediction of the adversarial data is better than the desired prediction by at least a margin $\rho$:
46 |
47 | $$\tilde{x}_i=\arg\min_{\tilde{x}\in B_\epsilon[x_i]}\ell(f(\tilde{x}),y_i)\quad\mathrm{s.t.}\quad\ell(f(\tilde{x}),y_i)-\min_{y\in\mathcal{Y}}\ell(f(\tilde{x}),y)\ge\rho$$
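
As an illustration only (the code in this repository controls friendliness via the step budget ```tau``` rather than an explicit margin $\rho$), a hypothetical helper checking the two constraints with the cross-entropy loss could look like this:

```python
import torch
import torch.nn.functional as F

def satisfies_margin(model, x_adv, y, rho=0.0):
    # Boolean mask over the batch: which adversarial examples satisfy
    # (a) misclassification and (b) the margin constraint above.
    with torch.no_grad():
        logp = F.log_softmax(model(x_adv), dim=1)
        loss_true = -logp.gather(1, y.unsqueeze(1)).squeeze(1)  # loss on the given label y_i
        loss_min = -logp.max(dim=1).values                      # smallest loss over all labels
        misclassified = logp.argmax(dim=1) != y                 # constraint (a)
        return misclassified & (loss_true - loss_min >= rho)    # constraint (b)
```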
48 |
49 |
50 | Let us compare the minimax formulation with the min-min formulation.
51 |
52 |
53 | ![min-min vs minimax](image/min-min_vs_minmax.png)
54 |
55 | *Comparisons between minimax formulation and min-min formulation*
56 |
57 |
58 | ## A Realization of the Min-min Formulation --- Friendly Adversarial Training (FAT)
59 |
60 | Friendly adversarial training (FAT) employs the friendly adversarial data generated by **early-stopped PGD** to update the model.
61 | Early-stopped PGD stops the PGD iterations once the adversarial data is misclassified (controlled by the hyperparameter ```tau``` in the code). Note that when ```tau``` equals the maximum perturbation step ```num_steps```, FAT recovers conventional adversarial training, e.g., [AT](https://arxiv.org/abs/1706.06083), [TRADES](https://arxiv.org/abs/1901.08573), and [MART](https://openreview.net/forum?id=rklOg6EFwS), as special cases.
62 |
63 |
64 | ![early-stopped PGD](image/early_stopped_pgd.png)
65 |
66 | *Conventional adversarial training employs PGD to search for the most adversarial data. Friendly adversarial training employs early-stopped PGD to search for friendly adversarial data.*
67 |
68 |
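For reference, the sketch below (abridged from ```train()``` in ```FAT.py```, using its default hyperparameters) shows how the early-stopped PGD routine ```earlystop``` is used inside one training step; ```model```, ```data```, ```target```, ```optimizer``` and ```tau``` come from the surrounding training loop.

```python
import torch.nn as nn
from earlystop import earlystop

# Generate friendly adversarial data with early-stopped PGD, then take one SGD step on it.
output_adv, output_target, output_natural, count = earlystop(
    model, data, target,
    step_size=0.007, epsilon=0.031, perturb_steps=10, tau=tau,
    randominit_type="uniform_randominit", loss_fn='cent',
    rand_init=True, omega=0.001)
model.train()
optimizer.zero_grad()
loss = nn.CrossEntropyLoss(reduction='mean')(model(output_adv), output_target)
loss.backward()
optimizer.step()
```
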
69 | ## Preferred Prerequisites
70 |
71 | * Python (3.6)
72 | * PyTorch (1.2.0)
73 | * CUDA
74 | * numpy
75 |
76 |
77 | ## Running FAT, FAT for TRADES, FAT for MART on benchmark datasets (CIFAR-10 and SVHN)
78 |
79 | Here are examples:
80 | * Train WRN-32-10 model on CIFAR-10 and compare our results with [AT](https://arxiv.org/abs/1706.06083), [CAT](https://arxiv.org/abs/1805.04807) and [DAT](http://proceedings.mlr.press/v97/wang19i/wang19i.pdf):
81 | ```bash
82 | CUDA_VISIBLE_DEVICES='0' python FAT.py --epsilon 0.031
83 | CUDA_VISIBLE_DEVICES='0' python FAT.py --epsilon 0.062
84 | ```
85 | ### White-box evaluations on WRN-32-10
86 |
87 | | Defense | Natural Acc. | FGSM Acc. | PGD-20 Acc. | C&W Acc. |
88 | |-----------------------|-----------------------|------------------|-----------------|-----------------|
89 | |[AT(Madry)](https://arxiv.org/abs/1706.06083) | 87.30% | 56.10% | 45.80% | 46.80%
90 | | [CAT](https://arxiv.org/abs/1805.04807) | 77.43% | 57.17% | 46.06% | 42.28%
91 | | [DAT](http://proceedings.mlr.press/v97/wang19i/wang19i.pdf) | 85.03% | 63.53% | 48.70% | 47.27%
92 | | FAT ($\epsilon$ = 0.031) | **89.34** ± 0.221% | 65.52 ± 0.355% | 46.13 ± 0.049% | 46.82 ± 0.517% |
93 | | FAT ($\epsilon$ = 0.062) | 87.00 ± 0.203% | **65.94** ± 0.244% | **49.86** ± 0.328% | **48.65** ± 0.176% |
94 |
95 | Results of AT (Madry), CAT and DAT are those reported in [DAT](http://proceedings.mlr.press/v97/wang19i/wang19i.pdf). FAT uses the same evaluation setup.
96 |
97 | * Train WRN-34-10 model on CIFAR-10 and compare our results with [TRADES](https://arxiv.org/abs/1901.08573) and [MART](https://openreview.net/forum?id=rklOg6EFwS).
98 | ```bash
99 | CUDA_VISIBLE_DEVICES='0' python FAT_for_TRADES.py --epsilon 0.031
100 | CUDA_VISIBLE_DEVICES='0' python FAT_for_TRADES.py --epsilon 0.062
101 | CUDA_VISIBLE_DEVICES='0' python FAT_for_MART.py --epsilon 0.031
102 | CUDA_VISIBLE_DEVICES='0' python FAT_for_MART.py --epsilon 0.062
103 | ```
104 |
105 | ### White-box evaluations on WRN-34-10
106 |
107 | | Defense | Natural Acc. | FGSM Acc. | PGD-20 Acc. | C&W Acc. |
108 | |-----------------------|-----------------------|------------------|-----------------|-----------------|
109 | |[TRADES](https://arxiv.org/abs/1901.08573) ($\beta$ = 1.0)| 88.64% | 56.38% | 49.14% | - |
110 | |FAT for TRADES ($\beta$ = 1.0)| **89.94** ± 0.303% | 61.00 ± 0.418% | 49.70 ± 0.653% | 49.35 ± 0.363% |
111 | |[TRADES](https://arxiv.org/abs/1901.08573) ($\beta$ = 6.0)| 84.92% | 61.06% | 56.61% | **54.47**% |
112 | |FAT for TRADES ($\beta$ = 6.0)| 86.60 ± 0.548% | **61.79** ± 0.570% | 55.98 ± 0.209% | 54.29 ± 0.173% |
113 | |FAT for TRADES ($\beta$ = 6.0, $\epsilon$ = 0.062)| 84.39 ± 0.030% | 61.73 ± 0.131% | **57.12** ± 0.233% | 54.36 ± 0.177% |
114 |
115 | Results of TRADES ($\beta$ = 1.0 and $\beta$ = 6.0) are those reported in [TRADES](https://arxiv.org/abs/1901.08573). FAT for TRADES uses the same evaluation setup. Note that the evaluations above follow the description in the TRADES paper, i.e., adversarial data are generated without random start (```rand_init=False```).
116 | However, [TRADES's GitHub](https://github.com/yaodongyu/TRADES) uses random start (```rand_init=True```) before the PGD perturbation, which deviates from the statements in their paper. For fair evaluations of FAT with random start, please refer to Table 3 in [our paper](https://arxiv.org/pdf/2002.11242.pdf).
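
For example, the PGD-20 evaluations in ```attack_test.py``` differ only in the ```rand_init``` flag (```model``` and ```test_loader``` as set up in that script):

```python
import attack_generator as attack

# PGD-20 at epsilon = 0.031 with step size 0.003, without and with random start.
loss, pgd20_wori_acc = attack.eval_robust(model, test_loader, perturb_steps=20, epsilon=0.031,
                                          step_size=0.003, loss_fn="cent", category="Madry", rand_init=False)
loss, pgd20_wri_acc = attack.eval_robust(model, test_loader, perturb_steps=20, epsilon=0.031,
                                         step_size=0.003, loss_fn="cent", category="Madry", rand_init=True)
```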
117 |
118 | ### How to recover original AT, TRADES, or MART?
119 | Just set ```tau=10```, i.e.,
120 | ```
121 | python FAT.py --epsilon 0.031 --tau 10 --dynamictau False
122 | python FAT_for_TRADES.py --epsilon 0.031 --tau 10 --dynamictau False
123 | python FAT_for_MART.py --epsilon 0.031 --tau 10 --dynamictau False
124 | ```
125 |
126 |
127 | ## Want to attack FAT? Sure!
128 |
129 | We welcome various attack methods to attack our defense models. For the CIFAR-10 dataset, we normalize all images into ```[0,1]```.
130 |
131 | Download our pretrained models into the folder ```FAT_models``` through this [Google Drive link](https://drive.google.com/drive/folders/1lV3qob_zR-YpFVGuKiiE5hNu74NID-ZS?usp=sharing) or [Baidu Drive link](https://pan.baidu.com/s/17XBd02FoGFqgYCVy2Fm_SQ)(extraction code: ww7f).
132 | ```bash
133 | cd Friendly-Adversarial-Training
134 | mkdir FAT_models
135 | ```
136 | Run robustness evaluations.
137 | ```bash
138 | chmod +x attack_test.sh
139 | ./attack_test.sh
140 | ```
141 |
142 | ## Reference
143 |
144 | ```
145 | @inproceedings{zhang2020fat,
146 | title={Attacks Which Do Not Kill Training Make Adversarial Learning Stronger},
147 | author={Zhang, Jingfeng and Xu, Xilie and Han, Bo and Niu, Gang and Cui, Lizhen and Sugiyama, Masashi and Kankanhalli, Mohan},
148 | booktitle = {ICML},
149 | year={2020}
150 | }
151 | ```
152 |
153 | ## Contact
154 |
155 | Please contact jingfeng.zhang@auckland.ac.nz (preferred) OR jingfeng.zhang9660@gmail.com and xuxilie@comp.nus.edu.sg if you have any questions about the code.
156 |
--------------------------------------------------------------------------------
/attack_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from models import *
3 |
4 | def cwloss(output, target,confidence=50, num_classes=10):
5 | # Compute the probability of the label class versus the maximum other
6 | # The same implementation as in repo CAT https://github.com/sunblaze-ucb/curriculum-adversarial-training-CAT
7 | target = target.data
8 | target_onehot = torch.zeros(target.size() + (num_classes,))
9 | target_onehot = target_onehot.cuda()
10 | target_onehot.scatter_(1, target.unsqueeze(1), 1.)
11 | target_var = Variable(target_onehot, requires_grad=False)
12 | real = (target_var * output).sum(1)
13 | other = ((1. - target_var) * output - target_var * 10000.).max(1)[0]
14 | loss = -torch.clamp(real - other + confidence, min=0.) # equiv to max(..., 0.)
15 | loss = torch.sum(loss)
16 | return loss
17 |
18 | def pgd(model, data, target, epsilon, step_size, num_steps,loss_fn,category,rand_init):
19 | model.eval()
20 | if category == "trades":
21 | x_adv = data.detach() + 0.001 * torch.randn(data.shape).cuda().detach() if rand_init else data.detach()
22 | if category == "Madry":
23 | x_adv = data.detach() + torch.from_numpy(np.random.uniform(-epsilon, epsilon, data.shape)).float().cuda() if rand_init else data.detach()
24 | x_adv = torch.clamp(x_adv, 0.0, 1.0)
25 | for k in range(num_steps):
26 | x_adv.requires_grad_()
27 | output = model(x_adv)
28 | model.zero_grad()
29 | with torch.enable_grad():
30 | if loss_fn == "cent":
31 | loss_adv = nn.CrossEntropyLoss(reduction="mean")(output, target)
32 | if loss_fn == "cw":
33 | loss_adv = cwloss(output,target)
34 | loss_adv.backward()
35 | eta = step_size * x_adv.grad.sign()
36 | x_adv = x_adv.detach() + eta
37 | x_adv = torch.min(torch.max(x_adv, data - epsilon), data + epsilon)
38 | x_adv = torch.clamp(x_adv, 0.0, 1.0)
39 | return x_adv
40 |
41 | def eval_clean(model, test_loader):
42 | model.eval()
43 | test_loss = 0
44 | correct = 0
45 | with torch.no_grad():
46 | for data, target in test_loader:
47 | data, target = data.cuda(), target.cuda()
48 | output = model(data)
49 | test_loss += nn.CrossEntropyLoss(reduction='mean')(output, target).item()
50 | pred = output.max(1, keepdim=True)[1]
51 | correct += pred.eq(target.view_as(pred)).sum().item()
52 | test_loss /= len(test_loader.dataset)
53 |     log = 'Natural Test Result ==> Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
54 | test_loss, correct, len(test_loader.dataset),
55 | 100. * correct / len(test_loader.dataset))
56 | # print(log)
57 | test_accuracy = correct / len(test_loader.dataset)
58 | return test_loss, test_accuracy
59 |
60 | def eval_robust(model, test_loader, perturb_steps, epsilon, step_size, loss_fn, category, rand_init):
61 | model.eval()
62 | test_loss = 0
63 | correct = 0
64 | with torch.enable_grad():
65 | for data, target in test_loader:
66 | data, target = data.cuda(), target.cuda()
67 | x_adv = pgd(model,data,target,epsilon,step_size,perturb_steps,loss_fn,category,rand_init=rand_init)
68 | output = model(x_adv)
69 | test_loss += nn.CrossEntropyLoss(reduction='mean')(output, target).item()
70 | pred = output.max(1, keepdim=True)[1]
71 | correct += pred.eq(target.view_as(pred)).sum().item()
72 | test_loss /= len(test_loader.dataset)
73 |     log = 'Attack Setting ==> Loss_fn:{}, Perturb steps:{}, Epsilon:{}, Step size:{} \n Test Result ==> Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(loss_fn,perturb_steps,epsilon,step_size,
74 | test_loss, correct, len(test_loader.dataset),
75 | 100. * correct / len(test_loader.dataset))
76 | # print(log)
77 | test_accuracy = correct / len(test_loader.dataset)
78 | return test_loss, test_accuracy
79 |
80 |
--------------------------------------------------------------------------------
/attack_test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch.nn as nn
3 | import torchvision
4 | from torchvision import transforms
5 | from models import *
6 | import attack_generator as attack
7 |
8 | parser = argparse.ArgumentParser(description='PyTorch White-box Adversarial Attack Test')
9 | parser.add_argument('--net', type=str, default="WRN", help="decide which network to use, choose from smallcnn, resnet18, WRN, WRN_madry")
10 | parser.add_argument('--dataset', type=str, default="cifar10", help="choose from cifar10,svhn")
11 | parser.add_argument('--depth', type=int, default=34, help='WRN depth')
12 | parser.add_argument('--width_factor', type=int, default=10,help='WRN width factor')
13 | parser.add_argument('--drop_rate', type=float,default=0.0, help='WRN drop rate')
14 | parser.add_argument('--attack_method', type=str, default="dat", help="choose from: dat and trades")
15 | parser.add_argument('--model_path', default='./FAT_models/fat_for_trades_wrn34-10_eps0.031_beta1.0.pth.tar', help='model for white-box attack evaluation')
16 | parser.add_argument('--method',type=str,default='dat',help='select attack setting following DAT or TRADES')
17 |
18 | args = parser.parse_args()
19 |
20 | transform_test = transforms.Compose([
21 | transforms.ToTensor(),
22 | ])
23 |
24 | print('==> Load Test Data')
25 | if args.dataset == "cifar10":
26 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
27 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
28 | if args.dataset == "svhn":
29 | testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=transform_test)
30 | test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
31 |
32 | print('==> Load Model')
33 | if args.net == "smallcnn":
34 | model = SmallCNN().cuda()
35 | net = "smallcnn"
36 | if args.net == "resnet18":
37 | model = ResNet18().cuda()
38 | net = "resnet18"
39 | if args.net == "WRN":
40 | ## WRN-34-10
41 | model = Wide_ResNet(depth=args.depth, num_classes=10, widen_factor=args.width_factor, dropRate=args.drop_rate).cuda()
42 | net = "WRN{}-{}-dropout{}".format(args.depth,args.width_factor,args.drop_rate)
43 | if args.net == 'WRN_madry':
44 | ## WRN-32-10
45 | model = Wide_ResNet_Madry(depth=args.depth, num_classes=10, widen_factor=args.width_factor, dropRate=args.drop_rate).cuda()
46 | net = "WRN_madry{}-{}-dropout{}".format(args.depth, args.width_factor, args.drop_rate)
47 | model = torch.nn.DataParallel(model)
48 | print(net)
49 |
50 | model.load_state_dict(torch.load(args.model_path)['state_dict'])
51 |
52 | print('==> Evaluating Performance under White-box Adversarial Attack')
53 |
54 | loss, test_nat_acc = attack.eval_clean(model, test_loader)
55 | print('Natural Test Accuracy: {:.2f}%'.format(100. * test_nat_acc))
56 | if args.method == "dat":
57 |     # Evaluations are the same as in DAT.
58 | loss, fgsm_acc = attack.eval_robust(model, test_loader, perturb_steps=1, epsilon=0.031, step_size=0.031,loss_fn="cent", category="Madry",rand_init=True)
59 | print('FGSM Test Accuracy: {:.2f}%'.format(100. * fgsm_acc))
60 | loss, pgd20_acc = attack.eval_robust(model, test_loader, perturb_steps=20, epsilon=0.031, step_size=0.031 / 4,loss_fn="cent", category="Madry", rand_init=True)
61 | print('PGD20 Test Accuracy: {:.2f}%'.format(100. * pgd20_acc))
62 | loss, cw_acc = attack.eval_robust(model, test_loader, perturb_steps=30, epsilon=0.031, step_size=0.031 / 4,loss_fn="cw", category="Madry", rand_init=True)
63 | print('CW Test Accuracy: {:.2f}%'.format(100. * cw_acc))
64 | if args.method == 'trades':
65 |     # Evaluations are the same as in TRADES.
66 | # wri : with random init, wori : without random init
67 | loss, fgsm_wori_acc = attack.eval_robust(model, test_loader, perturb_steps=1, epsilon=0.031, step_size=0.031,loss_fn="cent", category="Madry",rand_init=False)
68 | print('FGSM without Random Start Test Accuracy: {:.2f}%'.format(100. * fgsm_wori_acc))
69 | loss, pgd20_wori_acc = attack.eval_robust(model,test_loader, perturb_steps=20, epsilon=0.031, step_size=0.003,loss_fn="cent",category="Madry",rand_init=False)
70 | print('PGD20 without Random Start Test Accuracy: {:.2f}%'.format(100. * pgd20_wori_acc))
71 | loss, cw_wori_acc = attack.eval_robust(model,test_loader, perturb_steps=30, epsilon=0.031, step_size=0.003,loss_fn="cw",category="Madry",rand_init=False)
72 | print('CW without Random Start Test Accuracy: {:.2f}%'.format(100. * cw_wori_acc))
73 | loss, fgsm_wri_acc = attack.eval_robust(model, test_loader, perturb_steps=1, epsilon=0.031, step_size=0.031,loss_fn="cent", category="Madry",rand_init=True)
74 | print('FGSM with Random Start Test Accuracy: {:.2f}%'.format(100. * fgsm_wri_acc))
75 | loss, pgd20_wri_acc = attack.eval_robust(model,test_loader, perturb_steps=20, epsilon=0.031, step_size=0.003,loss_fn="cent",category="Madry",rand_init=True)
76 | print('PGD20 with Random Start Test Accuracy: {:.2f}%'.format(100. * pgd20_wri_acc))
77 | loss, cw_wri_acc = attack.eval_robust(model,test_loader, perturb_steps=30, epsilon=0.031, step_size=0.003,loss_fn="cw",category="Madry",rand_init=True)
78 | print('CW with Random Start Test Accuracy: {:.2f}%'.format(100. * cw_wri_acc))
79 |
--------------------------------------------------------------------------------
/attack_test.sh:
--------------------------------------------------------------------------------
1 | python attack_test.py --net 'WRN_madry' --depth 32 --model_path './FAT_models/fat_wrn32-10_eps0.031.pth.tar' --method 'dat'
2 | python attack_test.py --net 'WRN_madry' --depth 32 --model_path './FAT_models/fat_wrn32-10_eps0.062.pth.tar' --method 'dat'
3 |
4 | python attack_test.py --net 'WRN' --depth 34 --model_path './FAT_models/fat_for_trades_wrn34-10_eps0.031_beta1.0.pth.tar' --method 'trades'
5 | python attack_test.py --net 'WRN' --depth 34 --model_path './FAT_models/fat_for_trades_wrn34-10_eps0.031_beta6.0.pth.tar' --method 'trades'
6 | python attack_test.py --net 'WRN' --depth 34 --model_path './FAT_models/fat_for_trades_wrn34-10_eps0.062_beta6.0.pth.tar' --method 'trades'
7 |
8 | python attack_test.py --net 'WRN' --depth 58 --model_path './FAT_models/fat_for_trades_wrn58-10_eps0.031_beta6.0.pth.tar' --method 'trades'
9 | python attack_test.py --net 'WRN' --depth 58 --model_path './FAT_models/fat_for_trades_wrn58-10_eps0.062_beta6.0.pth.tar' --method 'trades'
10 |
11 | python attack_test.py --net 'WRN' --depth 34 --model_path './FAT_models/fat_for_mart_wrn34-10_eps0.031.pth.tar' --method 'trades'
12 | python attack_test.py --net 'WRN' --depth 34 --model_path './FAT_models/fat_for_mart_wrn34-10_eps0.062.pth.tar' --method 'trades'
13 |
--------------------------------------------------------------------------------
/earlystop.py:
--------------------------------------------------------------------------------
1 | from models import *
2 | import torch
3 | import numpy as np
4 |
5 | def earlystop(model, data, target, step_size, epsilon, perturb_steps,tau,randominit_type,loss_fn,rand_init=True,omega=0):
6 | '''
7 |     The implementation of early-stopped PGD,
8 |     following Alg. 1 in our FAT paper.
9 | :param step_size: the PGD step size
10 | :param epsilon: the perturbation bound
11 | :param perturb_steps: the maximum PGD step
12 |     :param tau: the step controlling how early we stop the iterations once misclassified adv data is found
13 |     :param randominit_type: the type of random initialization (random start for searching adv data)
14 |     :param rand_init: whether to initialize the adversarial sample with random noise (random start for searching adv data)
15 |     :param omega: random sample parameter for adv data generation (for escaping local minima)
16 |     :return: output_adv (friendly adversarial data), output_target (targets), output_natural (the corresponding natural data), count (the number of backward propagations)
17 | '''
18 | model.eval()
19 |
20 | K = perturb_steps
21 | count = 0
22 | output_target = []
23 | output_adv = []
24 | output_natural = []
25 |
26 | control = (torch.ones(len(target)) * tau).cuda()
27 |
28 | # Initialize the adversarial data with random noise
29 | if rand_init:
30 | if randominit_type == "normal_distribution_randominit":
31 | iter_adv = data.detach() + 0.001 * torch.randn(data.shape).cuda().detach()
32 | iter_adv = torch.clamp(iter_adv, 0.0, 1.0)
33 | if randominit_type == "uniform_randominit":
34 | iter_adv = data.detach() + torch.from_numpy(np.random.uniform(-epsilon, epsilon, data.shape)).float().cuda()
35 | iter_adv = torch.clamp(iter_adv, 0.0, 1.0)
36 | else:
37 | iter_adv = data.cuda().detach()
38 |
39 | iter_clean_data = data.cuda().detach()
40 | iter_target = target.cuda().detach()
41 | output_iter_clean_data = model(data)
42 |
43 | while K>0:
44 | iter_adv.requires_grad_()
45 | output = model(iter_adv)
46 | pred = output.max(1, keepdim=True)[1]
47 | output_index = []
48 | iter_index = []
49 |
50 |         # Calculate the indices of adversarial data that still need to be iterated
51 | for idx in range(len(pred)):
52 | if pred[idx] != iter_target[idx]:
53 | if control[idx] == 0:
54 | output_index.append(idx)
55 | else:
56 | control[idx] -= 1
57 | iter_index.append(idx)
58 | else:
59 | iter_index.append(idx)
60 |
61 |         # Add adversarial data that needs no further iteration into output_adv
62 | if len(output_index) != 0:
63 | if len(output_target) == 0:
64 |                 # misclassified adv data should not be iterated further
65 | output_adv = iter_adv[output_index].reshape(-1, 3, 32, 32).cuda()
66 | output_natural = iter_clean_data[output_index].reshape(-1, 3, 32, 32).cuda()
67 | output_target = iter_target[output_index].reshape(-1).cuda()
68 | else:
69 |                 # misclassified adv data should not be iterated further
70 | output_adv = torch.cat((output_adv, iter_adv[output_index].reshape(-1, 3, 32, 32).cuda()), dim=0)
71 | output_natural = torch.cat((output_natural, iter_clean_data[output_index].reshape(-1, 3, 32, 32).cuda()), dim=0)
72 | output_target = torch.cat((output_target, iter_target[output_index].reshape(-1).cuda()), dim=0)
73 |
74 | # calculate gradient
75 | model.zero_grad()
76 | with torch.enable_grad():
77 | if loss_fn == "cent":
78 | loss_adv = nn.CrossEntropyLoss(reduction='mean')(output, iter_target)
79 | if loss_fn == "kl":
80 | criterion_kl = nn.KLDivLoss(size_average=False).cuda()
81 | loss_adv = criterion_kl(F.log_softmax(output, dim=1),F.softmax(output_iter_clean_data, dim=1))
82 | loss_adv.backward(retain_graph=True)
83 | grad = iter_adv.grad
84 |
85 | # update iter adv
86 | if len(iter_index) != 0:
87 | control = control[iter_index]
88 | iter_adv = iter_adv[iter_index]
89 | iter_clean_data = iter_clean_data[iter_index]
90 | iter_target = iter_target[iter_index]
91 | output_iter_clean_data = output_iter_clean_data[iter_index]
92 | grad = grad[iter_index]
93 | eta = step_size * grad.sign()
94 |
95 | iter_adv = iter_adv.detach() + eta + omega * torch.randn(iter_adv.shape).detach().cuda()
96 | iter_adv = torch.min(torch.max(iter_adv, iter_clean_data - epsilon), iter_clean_data + epsilon)
97 | iter_adv = torch.clamp(iter_adv, 0, 1)
98 | count += len(iter_target)
99 | else:
100 | output_adv = output_adv.detach()
101 | return output_adv, output_target, output_natural, count
102 | K = K-1
103 |
104 | if len(output_target) == 0:
105 | output_target = iter_target.reshape(-1).squeeze().cuda()
106 | output_adv = iter_adv.reshape(-1, 3, 32, 32).cuda()
107 | output_natural = iter_clean_data.reshape(-1, 3, 32, 32).cuda()
108 | else:
109 | output_adv = torch.cat((output_adv, iter_adv.reshape(-1, 3, 32, 32)), dim=0).cuda()
110 | output_target = torch.cat((output_target, iter_target.reshape(-1)), dim=0).squeeze().cuda()
111 | output_natural = torch.cat((output_natural, iter_clean_data.reshape(-1, 3, 32, 32).cuda()),dim=0).cuda()
112 | output_adv = output_adv.detach()
113 | return output_adv, output_target, output_natural, count
114 |
--------------------------------------------------------------------------------
/image/adv_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/image/adv_train.png
--------------------------------------------------------------------------------
/image/cross_over_mixture_problem.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/image/cross_over_mixture_problem.png
--------------------------------------------------------------------------------
/image/early_stopped_pgd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/image/early_stopped_pgd.png
--------------------------------------------------------------------------------
/image/min-min_vs_minmax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/image/min-min_vs_minmax.png
--------------------------------------------------------------------------------
/image/min_min_formulation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/image/min_min_formulation.png
--------------------------------------------------------------------------------
/image/minimax_formulation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/image/minimax_formulation.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .vgg import *
2 | from .dpn import *
3 | from .lenet import *
4 | from .senet import *
5 | from .resnet import *
6 | from .resnext import *
7 | from .densenet import *
8 | from .googlenet import *
9 | from .mobilenet import *
10 | from .shufflenet import *
11 | from .preact_resnet import *
12 | from .wide_resnet import *
13 | from .small_cnn import *
14 | from .wrn_madry import *
15 |
--------------------------------------------------------------------------------
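The package re-exports every architecture defined in the files below, so a single import is enough to pick a network by name. A minimal sketch (the constructor names come from the model files that follow; the chosen hyperparameters are illustrative):

from models import *

# All of these constructors build CIFAR-style classifiers that take
# 3x32x32 inputs and emit 10 logits by default.
net = ResNet18()
# net = SmallCNN()
# net = Wide_ResNet_Madry(depth=32, num_classes=10, widen_factor=10, dropRate=0.0)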
/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/densenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/densenet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/dpn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/dpn.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/googlenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/googlenet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/lenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/lenet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/mobilenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/mobilenet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/preact_resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/preact_resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnext.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/resnext.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/senet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/senet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/shufflenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/shufflenet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/small_cnn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/small_cnn.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/vgg.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/vgg.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/wide_resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/wide_resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/wrn_madry.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/models/__pycache__/wrn_madry.cpython-36.pyc
--------------------------------------------------------------------------------
/models/densenet.py:
--------------------------------------------------------------------------------
1 | '''DenseNet in PyTorch.'''
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from torch.autograd import Variable
9 |
10 |
11 | class Bottleneck(nn.Module):
12 | def __init__(self, in_planes, growth_rate):
13 | super(Bottleneck, self).__init__()
14 | self.bn1 = nn.BatchNorm2d(in_planes,momentum=0.2)
15 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False)
16 | self.bn2 = nn.BatchNorm2d(4*growth_rate,momentum=0.2)
17 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
18 |
19 | def forward(self, x):
20 | out = self.conv1(F.relu(self.bn1(x)))
21 | out = self.conv2(F.relu(self.bn2(out)))
22 | out = torch.cat([out,x], 1)
23 | return out
24 |
25 |
26 | class Transition(nn.Module):
27 | def __init__(self, in_planes, out_planes):
28 | super(Transition, self).__init__()
29 | self.bn = nn.BatchNorm2d(in_planes,momentum=0.2)
30 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)
31 |
32 | def forward(self, x):
33 | out = self.conv(F.relu(self.bn(x)))
34 | out = F.avg_pool2d(out, 2)
35 | return out
36 |
37 |
38 | class DenseNet(nn.Module):
39 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.2, num_classes=10):
40 | super(DenseNet, self).__init__()
41 | self.growth_rate = growth_rate
42 |
43 | num_planes = 2*growth_rate
44 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)
45 |
46 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
47 | num_planes += nblocks[0]*growth_rate
48 | out_planes = int(math.floor(num_planes*reduction))
49 | self.trans1 = Transition(num_planes, out_planes)
50 | num_planes = out_planes
51 |
52 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
53 | num_planes += nblocks[1]*growth_rate
54 | out_planes = int(math.floor(num_planes*reduction))
55 | self.trans2 = Transition(num_planes, out_planes)
56 | num_planes = out_planes
57 |
58 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
59 | num_planes += nblocks[2]*growth_rate
60 | out_planes = int(math.floor(num_planes*reduction))
61 | self.trans3 = Transition(num_planes, out_planes)
62 | num_planes = out_planes
63 |
64 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
65 | num_planes += nblocks[3]*growth_rate
66 |
67 | self.bn = nn.BatchNorm2d(num_planes,momentum=0.2)
68 | self.linear = nn.Linear(num_planes, num_classes)
69 |
70 | def _make_dense_layers(self, block, in_planes, nblock):
71 | layers = []
72 | for i in range(nblock):
73 | layers.append(block(in_planes, self.growth_rate))
74 | in_planes += self.growth_rate
75 | return nn.Sequential(*layers)
76 |
77 | def forward(self, x):
78 | out = self.conv1(x)
79 | out = self.trans1(self.dense1(out))
80 | out = self.trans2(self.dense2(out))
81 | out = self.trans3(self.dense3(out))
82 | out = self.dense4(out)
83 | out = F.avg_pool2d(F.relu(self.bn(out)), 4)
84 | out = out.view(out.size(0), -1)
85 | out = self.linear(out)
86 | return out
87 |
88 | def DenseNet121():
89 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32)
90 |
91 | def DenseNet169():
92 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32)
93 |
94 | def DenseNet201():
95 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32)
96 |
97 | def DenseNet161():
98 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48)
99 |
100 | def densenet_cifar():
101 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12)
102 |
103 | def test_densenet():
104 | net = densenet_cifar()
105 | x = torch.randn(1,3,32,32)
106 | y = net(Variable(x))
107 | print(y)
108 | print(net)
109 | #test_densenet()
110 |
--------------------------------------------------------------------------------
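For the densenet_cifar configuration above (growth_rate=12, reduction=0.2, nblocks=[6, 12, 24, 16]), the channel bookkeeping works out as follows; this is an illustrative trace of the constructor's own arithmetic, and it shows why the final BatchNorm and Linear layers see 256 channels.

# conv1:             2 * 12        = 24
# dense1 (+6*12):    24 + 72       = 96   -> trans1: floor(96 * 0.2)  = 19
# dense2 (+12*12):   19 + 144      = 163  -> trans2: floor(163 * 0.2) = 32
# dense3 (+24*12):   32 + 288      = 320  -> trans3: floor(320 * 0.2) = 64
# dense4 (+16*12):   64 + 192      = 256  -> bn(256) and linear(256, 10)
# Spatially, each Transition halves H and W (32 -> 16 -> 8 -> 4); the final
# avg_pool2d(., 4) reduces the 4x4 map to 1x1 before the classifier.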
/models/dpn.py:
--------------------------------------------------------------------------------
1 | '''Dual Path Networks in PyTorch.'''
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torch.autograd import Variable
7 |
8 |
9 | class Bottleneck(nn.Module):
10 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer):
11 | super(Bottleneck, self).__init__()
12 | self.out_planes = out_planes
13 | self.dense_depth = dense_depth
14 |
15 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False)
16 | self.bn1 = nn.BatchNorm2d(in_planes)
17 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False)
18 | self.bn2 = nn.BatchNorm2d(in_planes)
19 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False)
20 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth)
21 |
22 | self.shortcut = nn.Sequential()
23 | if first_layer:
24 | self.shortcut = nn.Sequential(
25 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False),
26 | nn.BatchNorm2d(out_planes+dense_depth)
27 | )
28 |
29 | def forward(self, x):
30 | out = F.relu(self.bn1(self.conv1(x)))
31 | out = F.relu(self.bn2(self.conv2(out)))
32 | out = self.bn3(self.conv3(out))
33 | x = self.shortcut(x)
34 | d = self.out_planes
35 | out = torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1)
36 | out = F.relu(out)
37 | return out
38 |
39 |
40 | class DPN(nn.Module):
41 | def __init__(self, cfg):
42 | super(DPN, self).__init__()
43 | in_planes, out_planes = cfg['in_planes'], cfg['out_planes']
44 | num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth']
45 |
46 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
47 | self.bn1 = nn.BatchNorm2d(64)
48 | self.last_planes = 64
49 | self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1)
50 | self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2)
51 | self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2)
52 | self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2)
53 | self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10)
54 |
55 | def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride):
56 | strides = [stride] + [1]*(num_blocks-1)
57 | layers = []
58 | for i,stride in enumerate(strides):
59 | layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0))
60 | self.last_planes = out_planes + (i+2) * dense_depth
61 | return nn.Sequential(*layers)
62 |
63 | def forward(self, x):
64 | out = F.relu(self.bn1(self.conv1(x)))
65 | out = self.layer1(out)
66 | out = self.layer2(out)
67 | out = self.layer3(out)
68 | out = self.layer4(out)
69 | out = F.avg_pool2d(out, 4)
70 | out = out.view(out.size(0), -1)
71 | out = self.linear(out)
72 | return out
73 |
74 |
75 | def DPN26():
76 | cfg = {
77 | 'in_planes': (96,192,384,768),
78 | 'out_planes': (256,512,1024,2048),
79 | 'num_blocks': (2,2,2,2),
80 | 'dense_depth': (16,32,24,128)
81 | }
82 | return DPN(cfg)
83 |
84 | def DPN92():
85 | cfg = {
86 | 'in_planes': (96,192,384,768),
87 | 'out_planes': (256,512,1024,2048),
88 | 'num_blocks': (3,4,20,3),
89 | 'dense_depth': (16,32,24,128)
90 | }
91 | return DPN(cfg)
92 |
93 |
94 | def test():
95 | net = DPN92()
96 | x = Variable(torch.randn(1,3,32,32))
97 | y = net(x)
98 | print(y)
99 |
100 | # test()
101 |
--------------------------------------------------------------------------------
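The classifier's in_features expression, out_planes[3] + (num_blocks[3] + 1) * dense_depth[3], matches the running channel count kept in _make_layer, where last_planes becomes out_planes + (i + 2) * dense_depth after block i. An illustrative check for DPN92:

# DPN92, last stage: out_planes = 2048, num_blocks = 3, dense_depth = 128
# after the final block (i = 2): last_planes = 2048 + (2 + 2) * 128 = 2560
# classifier in_features:        2048 + (3 + 1) * 128               = 2560  -> they agree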
/models/googlenet.py:
--------------------------------------------------------------------------------
1 | '''GoogLeNet with PyTorch.'''
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torch.autograd import Variable
7 |
8 |
9 | class Inception(nn.Module):
10 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes):
11 | super(Inception, self).__init__()
12 | # 1x1 conv branch
13 | self.b1 = nn.Sequential(
14 | nn.Conv2d(in_planes, n1x1, kernel_size=1),
15 | nn.BatchNorm2d(n1x1),
16 | nn.ReLU(True),
17 | )
18 |
19 | # 1x1 conv -> 3x3 conv branch
20 | self.b2 = nn.Sequential(
21 | nn.Conv2d(in_planes, n3x3red, kernel_size=1),
22 | nn.BatchNorm2d(n3x3red),
23 | nn.ReLU(True),
24 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1),
25 | nn.BatchNorm2d(n3x3),
26 | nn.ReLU(True),
27 | )
28 |
29 | # 1x1 conv -> 5x5 conv branch (the 5x5 is factorized into two stacked 3x3 convs)
30 | self.b3 = nn.Sequential(
31 | nn.Conv2d(in_planes, n5x5red, kernel_size=1),
32 | nn.BatchNorm2d(n5x5red),
33 | nn.ReLU(True),
34 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1),
35 | nn.BatchNorm2d(n5x5),
36 | nn.ReLU(True),
37 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1),
38 | nn.BatchNorm2d(n5x5),
39 | nn.ReLU(True),
40 | )
41 |
42 | # 3x3 pool -> 1x1 conv branch
43 | self.b4 = nn.Sequential(
44 | nn.MaxPool2d(3, stride=1, padding=1),
45 | nn.Conv2d(in_planes, pool_planes, kernel_size=1),
46 | nn.BatchNorm2d(pool_planes),
47 | nn.ReLU(True),
48 | )
49 |
50 | def forward(self, x):
51 | y1 = self.b1(x)
52 | y2 = self.b2(x)
53 | y3 = self.b3(x)
54 | y4 = self.b4(x)
55 | return torch.cat([y1,y2,y3,y4], 1)
56 |
57 |
58 | class GoogLeNet(nn.Module):
59 | def __init__(self):
60 | super(GoogLeNet, self).__init__()
61 | self.pre_layers = nn.Sequential(
62 | nn.Conv2d(3, 192, kernel_size=3, padding=1),
63 | nn.BatchNorm2d(192),
64 | nn.ReLU(True),
65 | )
66 |
67 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32)
68 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)
69 |
70 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
71 |
72 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64)
73 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64)
74 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64)
75 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64)
76 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)
77 |
78 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
79 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)
80 |
81 | self.avgpool = nn.AvgPool2d(8, stride=1)
82 | self.linear = nn.Linear(1024, 10)
83 |
84 | def forward(self, x):
85 | out = self.pre_layers(x)
86 | out = self.a3(out)
87 | out = self.b3(out)
88 | out = self.maxpool(out)
89 | out = self.a4(out)
90 | out = self.b4(out)
91 | out = self.c4(out)
92 | out = self.d4(out)
93 | out = self.e4(out)
94 | out = self.maxpool(out)
95 | out = self.a5(out)
96 | out = self.b5(out)
97 | out = self.avgpool(out)
98 | out = out.view(out.size(0), -1)
99 | out = self.linear(out)
100 | return out
101 |
102 | # net = GoogLeNet()
103 | # x = torch.randn(1,3,32,32)
104 | # y = net(Variable(x))
105 | # print(y.size())
106 |
--------------------------------------------------------------------------------
/models/lenet.py:
--------------------------------------------------------------------------------
1 | '''LeNet in PyTorch.'''
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class LeNet(nn.Module):
6 | def __init__(self):
7 | super(LeNet, self).__init__()
8 | self.conv1 = nn.Conv2d(3, 6, 5)
9 | self.conv2 = nn.Conv2d(6, 16, 5)
10 | self.fc1 = nn.Linear(16*5*5, 120)
11 | self.fc2 = nn.Linear(120, 84)
12 | self.fc3 = nn.Linear(84, 10)
13 |
14 | def forward(self, x):
15 | out = F.relu(self.conv1(x))
16 | out = F.max_pool2d(out, 2)
17 | out = F.relu(self.conv2(out))
18 | out = F.max_pool2d(out, 2)
19 | out = out.view(out.size(0), -1)
20 | out = F.relu(self.fc1(out))
21 | out = F.relu(self.fc2(out))
22 | out = self.fc3(out)
23 | return out
24 |
--------------------------------------------------------------------------------
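The hard-coded 16*5*5 in fc1 follows from a 32x32 input and the two unpadded 5x5 convolutions; an illustrative shape trace:

# 3 x 32 x 32 --conv1 (5x5, no padding)--> 6 x 28 x 28  --max_pool 2--> 6 x 14 x 14
#             --conv2 (5x5, no padding)--> 16 x 10 x 10 --max_pool 2--> 16 x 5 x 5
# flatten: 16 * 5 * 5 = 400 features into fc1 = nn.Linear(400, 120)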
/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | '''MobileNet in PyTorch.
2 |
3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications"
4 | for more details.
5 | '''
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | from torch.autograd import Variable
11 |
12 |
13 | class Block(nn.Module):
14 | '''Depthwise conv + Pointwise conv'''
15 | def __init__(self, in_planes, out_planes, stride=1):
16 | super(Block, self).__init__()
17 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False)
18 | self.bn1 = nn.BatchNorm2d(in_planes)
19 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
20 | self.bn2 = nn.BatchNorm2d(out_planes)
21 |
22 | def forward(self, x):
23 | out = F.relu(self.bn1(self.conv1(x)))
24 | out = F.relu(self.bn2(self.conv2(out)))
25 | return out
26 |
27 |
28 | class MobileNet(nn.Module):
29 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1
30 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024]
31 |
32 | def __init__(self, num_classes=10):
33 | super(MobileNet, self).__init__()
34 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
35 | self.bn1 = nn.BatchNorm2d(32)
36 | self.layers = self._make_layers(in_planes=32)
37 | self.linear = nn.Linear(1024, num_classes)
38 |
39 | def _make_layers(self, in_planes):
40 | layers = []
41 | for x in self.cfg:
42 | out_planes = x if isinstance(x, int) else x[0]
43 | stride = 1 if isinstance(x, int) else x[1]
44 | layers.append(Block(in_planes, out_planes, stride))
45 | in_planes = out_planes
46 | return nn.Sequential(*layers)
47 |
48 | def forward(self, x):
49 | out = F.relu(self.bn1(self.conv1(x)))
50 | out = self.layers(out)
51 | out = F.avg_pool2d(out, 2)
52 | out = out.view(out.size(0), -1)
53 | out = self.linear(out)
54 | return out
55 |
56 |
57 | def test():
58 | net = MobileNet()
59 | x = torch.randn(1,3,32,32)
60 | y = net(Variable(x))
61 | print(y.size())
62 |
63 | # test()
64 |
--------------------------------------------------------------------------------
/models/preact_resnet.py:
--------------------------------------------------------------------------------
1 | '''Pre-activation ResNet in PyTorch.
2 |
3 | Reference:
4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5 | Identity Mappings in Deep Residual Networks. arXiv:1603.05027
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from torch.autograd import Variable
12 |
13 |
14 | class PreActBlock(nn.Module):
15 | '''Pre-activation version of the BasicBlock.'''
16 | expansion = 1
17 |
18 | def __init__(self, in_planes, planes, stride=1):
19 | super(PreActBlock, self).__init__()
20 | self.bn1 = nn.BatchNorm2d(in_planes)
21 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
24 |
25 | if stride != 1 or in_planes != self.expansion*planes:
26 | self.shortcut = nn.Sequential(
27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
28 | )
29 |
30 | def forward(self, x):
31 | out = F.relu(self.bn1(x))
32 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
33 | out = self.conv1(out)
34 | out = self.conv2(F.relu(self.bn2(out)))
35 | out += shortcut
36 | return out
37 |
38 |
39 | class PreActBottleneck(nn.Module):
40 | '''Pre-activation version of the original Bottleneck module.'''
41 | expansion = 4
42 |
43 | def __init__(self, in_planes, planes, stride=1):
44 | super(PreActBottleneck, self).__init__()
45 | self.bn1 = nn.BatchNorm2d(in_planes)
46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
47 | self.bn2 = nn.BatchNorm2d(planes)
48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
49 | self.bn3 = nn.BatchNorm2d(planes)
50 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
51 |
52 | if stride != 1 or in_planes != self.expansion*planes:
53 | self.shortcut = nn.Sequential(
54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
55 | )
56 |
57 | def forward(self, x):
58 | out = F.relu(self.bn1(x))
59 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
60 | out = self.conv1(out)
61 | out = self.conv2(F.relu(self.bn2(out)))
62 | out = self.conv3(F.relu(self.bn3(out)))
63 | out += shortcut
64 | return out
65 |
66 |
67 | class PreActResNet(nn.Module):
68 | def __init__(self, block, num_blocks, num_classes=10):
69 | super(PreActResNet, self).__init__()
70 | self.in_planes = 64
71 |
72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
73 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
74 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
75 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
76 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
77 | self.linear = nn.Linear(512*block.expansion, num_classes)
78 |
79 | def _make_layer(self, block, planes, num_blocks, stride):
80 | strides = [stride] + [1]*(num_blocks-1)
81 | layers = []
82 | for stride in strides:
83 | layers.append(block(self.in_planes, planes, stride))
84 | self.in_planes = planes * block.expansion
85 | return nn.Sequential(*layers)
86 |
87 | def forward(self, x):
88 | out = self.conv1(x)
89 | out = self.layer1(out)
90 | out = self.layer2(out)
91 | out = self.layer3(out)
92 | out = self.layer4(out)
93 | out = F.avg_pool2d(out, 4)
94 | out = out.view(out.size(0), -1)
95 | out = self.linear(out)
96 | return out
97 |
98 |
99 | def PreActResNet18():
100 | return PreActResNet(PreActBlock, [2,2,2,2])
101 |
102 | def PreActResNet34():
103 | return PreActResNet(PreActBlock, [3,4,6,3])
104 |
105 | def PreActResNet50():
106 | return PreActResNet(PreActBottleneck, [3,4,6,3])
107 |
108 | def PreActResNet101():
109 | return PreActResNet(PreActBottleneck, [3,4,23,3])
110 |
111 | def PreActResNet152():
112 | return PreActResNet(PreActBottleneck, [3,8,36,3])
113 |
114 |
115 | def test():
116 | net = PreActResNet18()
117 | y = net(Variable(torch.randn(1,3,32,32)))
118 | print(y.size())
119 |
120 | # test()
121 |
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 |
3 | For Pre-activation ResNet, see 'preact_resnet.py'.
4 |
5 | Reference:
6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
8 | '''
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from torch.autograd import Variable
14 |
15 |
16 | class BasicBlock(nn.Module):
17 | expansion = 1
18 |
19 | def __init__(self, in_planes, planes, stride=1):
20 | super(BasicBlock, self).__init__()
21 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
22 | self.bn1 = nn.BatchNorm2d(planes)
23 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
24 | self.bn2 = nn.BatchNorm2d(planes)
25 |
26 | self.shortcut = nn.Sequential()
27 | if stride != 1 or in_planes != self.expansion*planes:
28 | self.shortcut = nn.Sequential(
29 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
30 | nn.BatchNorm2d(self.expansion*planes)
31 | )
32 |
33 | def forward(self, x):
34 | out = F.relu(self.bn1(self.conv1(x)))
35 | out = self.bn2(self.conv2(out))
36 | out += self.shortcut(x)
37 | out = F.relu(out)
38 | return out
39 |
40 |
41 | class Bottleneck(nn.Module):
42 | expansion = 4
43 |
44 | def __init__(self, in_planes, planes, stride=1):
45 | super(Bottleneck, self).__init__()
46 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
47 | self.bn1 = nn.BatchNorm2d(planes)
48 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
49 | self.bn2 = nn.BatchNorm2d(planes)
50 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
51 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
52 |
53 | self.shortcut = nn.Sequential()
54 | if stride != 1 or in_planes != self.expansion*planes:
55 | self.shortcut = nn.Sequential(
56 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
57 | nn.BatchNorm2d(self.expansion*planes)
58 | )
59 |
60 | def forward(self, x):
61 | out = F.relu(self.bn1(self.conv1(x)))
62 | out = F.relu(self.bn2(self.conv2(out)))
63 | out = self.bn3(self.conv3(out))
64 | out += self.shortcut(x)
65 | out = F.relu(out)
66 | return out
67 |
68 |
69 | class ResNet(nn.Module):
70 | def __init__(self, block, num_blocks, num_classes=10):
71 | super(ResNet, self).__init__()
72 | self.in_planes = 64
73 |
74 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
75 | self.bn1 = nn.BatchNorm2d(64)
76 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
77 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
78 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
79 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
80 | self.linear = nn.Linear(512*block.expansion, num_classes)
81 |
82 | def _make_layer(self, block, planes, num_blocks, stride):
83 | strides = [stride] + [1]*(num_blocks-1)
84 | layers = []
85 | for stride in strides:
86 | layers.append(block(self.in_planes, planes, stride))
87 | self.in_planes = planes * block.expansion
88 | return nn.Sequential(*layers)
89 |
90 | def forward(self, x):
91 | out = F.relu(self.bn1(self.conv1(x)))
92 | out = self.layer1(out)
93 | out = self.layer2(out)
94 | out = self.layer3(out)
95 | out = self.layer4(out)
96 | out = F.avg_pool2d(out, 4)
97 | out = out.view(out.size(0), -1)
98 | out = self.linear(out)
99 | return out
100 |
101 |
102 | def ResNet18():
103 | return ResNet(BasicBlock, [2,2,2,2])
104 |
105 | def ResNet34():
106 | return ResNet(BasicBlock, [3,4,6,3])
107 |
108 | def ResNet50():
109 | return ResNet(Bottleneck, [3,4,6,3])
110 |
111 | def ResNet101():
112 | return ResNet(Bottleneck, [3,4,23,3])
113 |
114 | def ResNet152():
115 | return ResNet(Bottleneck, [3,8,36,3])
116 |
117 |
118 | def test():
119 | net = ResNet18()
120 | y = net(Variable(torch.randn(1,3,32,32)))
121 | print(y.size())
122 | print(net)
123 | # test()
--------------------------------------------------------------------------------
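The avg_pool2d(out, 4) at the end of forward assumes 32x32 inputs: only layer2, layer3, and layer4 downsample (stride 2 each), so the feature map shrinks 32 -> 16 -> 8 -> 4 and the 4x4 pool leaves a single spatial position with 512*block.expansion channels. An illustrative trace for ResNet18 (expansion = 1):

# 3 x 32 x 32 --conv1 + layer1 (stride 1)--> 64 x 32 x 32
# --layer2 (stride 2)--> 128 x 16 x 16
# --layer3 (stride 2)--> 256 x 8 x 8
# --layer4 (stride 2)--> 512 x 4 x 4
# --avg_pool2d(4)--> 512 x 1 x 1 --flatten--> 512 --linear--> 10 logits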
/models/resnext.py:
--------------------------------------------------------------------------------
1 | '''ResNeXt in PyTorch.
2 |
3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from torch.autograd import Variable
10 |
11 |
12 | class Block(nn.Module):
13 | '''Grouped convolution block.'''
14 | expansion = 2
15 |
16 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1):
17 | super(Block, self).__init__()
18 | group_width = cardinality * bottleneck_width
19 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False)
20 | self.bn1 = nn.BatchNorm2d(group_width)
21 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
22 | self.bn2 = nn.BatchNorm2d(group_width)
23 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False)
24 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width)
25 |
26 | self.shortcut = nn.Sequential()
27 | if stride != 1 or in_planes != self.expansion*group_width:
28 | self.shortcut = nn.Sequential(
29 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False),
30 | nn.BatchNorm2d(self.expansion*group_width)
31 | )
32 |
33 | def forward(self, x):
34 | out = F.relu(self.bn1(self.conv1(x)))
35 | out = F.relu(self.bn2(self.conv2(out)))
36 | out = self.bn3(self.conv3(out))
37 | out += self.shortcut(x)
38 | out = F.relu(out)
39 | return out
40 |
41 |
42 | class ResNeXt(nn.Module):
43 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10):
44 | super(ResNeXt, self).__init__()
45 | self.cardinality = cardinality
46 | self.bottleneck_width = bottleneck_width
47 | self.in_planes = 64
48 |
49 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False)
50 | self.bn1 = nn.BatchNorm2d(64)
51 | self.layer1 = self._make_layer(num_blocks[0], 1)
52 | self.layer2 = self._make_layer(num_blocks[1], 2)
53 | self.layer3 = self._make_layer(num_blocks[2], 2)
54 | # self.layer4 = self._make_layer(num_blocks[3], 2)
55 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes)
56 |
57 | def _make_layer(self, num_blocks, stride):
58 | strides = [stride] + [1]*(num_blocks-1)
59 | layers = []
60 | for stride in strides:
61 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride))
62 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width
63 | # Double bottleneck_width after each stage.
64 | self.bottleneck_width *= 2
65 | return nn.Sequential(*layers)
66 |
67 | def forward(self, x):
68 | out = F.relu(self.bn1(self.conv1(x)))
69 | out = self.layer1(out)
70 | out = self.layer2(out)
71 | out = self.layer3(out)
72 | # out = self.layer4(out)
73 | out = F.avg_pool2d(out, 8)
74 | out = out.view(out.size(0), -1)
75 | out = self.linear(out)
76 | return out
77 |
78 |
79 | def ResNeXt29_2x64d():
80 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64)
81 |
82 | def ResNeXt29_4x64d():
83 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64)
84 |
85 | def ResNeXt29_8x64d():
86 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64)
87 |
88 | def ResNeXt29_32x4d():
89 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4)
90 |
91 | def test_resnext():
92 | net = ResNeXt29_2x64d()
93 | x = torch.randn(1,3,32,32)
94 | y = net(Variable(x))
95 | print(y.size())
96 |
97 | # test_resnext()
98 |
--------------------------------------------------------------------------------
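The classifier's in_features, cardinality * bottleneck_width * 8, is consistent with the channel growth in _make_layer: each Block emits expansion * cardinality * bottleneck_width channels and self.bottleneck_width doubles after every stage. Illustrative arithmetic for ResNeXt29_32x4d:

# stage 1: group_width = 32 * 4  = 128 -> out channels = 2 * 128 = 256;  width doubles to 8
# stage 2: group_width = 32 * 8  = 256 -> out channels = 2 * 256 = 512;  width doubles to 16
# stage 3: group_width = 32 * 16 = 512 -> out channels = 2 * 512 = 1024
# classifier in_features: cardinality * bottleneck_width * 8 = 32 * 4 * 8 = 1024  -> matches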
/models/senet.py:
--------------------------------------------------------------------------------
1 | '''SENet in PyTorch.
2 |
3 | SENet won the ImageNet-2017 classification challenge. See the paper "Squeeze-and-Excitation Networks" (arXiv:1709.01507) for more details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from torch.autograd import Variable
10 |
11 |
12 | class BasicBlock(nn.Module):
13 | def __init__(self, in_planes, planes, stride=1):
14 | super(BasicBlock, self).__init__()
15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
16 | self.bn1 = nn.BatchNorm2d(planes)
17 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
18 | self.bn2 = nn.BatchNorm2d(planes)
19 |
20 | self.shortcut = nn.Sequential()
21 | if stride != 1 or in_planes != planes:
22 | self.shortcut = nn.Sequential(
23 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
24 | nn.BatchNorm2d(planes)
25 | )
26 |
27 | # SE layers
28 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1) # Use nn.Conv2d instead of nn.Linear
29 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)
30 |
31 | def forward(self, x):
32 | out = F.relu(self.bn1(self.conv1(x)))
33 | out = self.bn2(self.conv2(out))
34 |
35 | # Squeeze
36 | w = F.avg_pool2d(out, out.size(2))
37 | w = F.relu(self.fc1(w))
38 | w = torch.sigmoid(self.fc2(w))  # F.sigmoid is deprecated in newer PyTorch
39 | # Excitation
40 | out = out * w # New broadcasting feature from v0.2!
41 |
42 | out += self.shortcut(x)
43 | out = F.relu(out)
44 | return out
45 |
46 |
47 | class PreActBlock(nn.Module):
48 | def __init__(self, in_planes, planes, stride=1):
49 | super(PreActBlock, self).__init__()
50 | self.bn1 = nn.BatchNorm2d(in_planes)
51 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
52 | self.bn2 = nn.BatchNorm2d(planes)
53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
54 |
55 | if stride != 1 or in_planes != planes:
56 | self.shortcut = nn.Sequential(
57 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)
58 | )
59 |
60 | # SE layers
61 | self.fc1 = nn.Conv2d(planes, planes//16, kernel_size=1)
62 | self.fc2 = nn.Conv2d(planes//16, planes, kernel_size=1)
63 |
64 | def forward(self, x):
65 | out = F.relu(self.bn1(x))
66 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
67 | out = self.conv1(out)
68 | out = self.conv2(F.relu(self.bn2(out)))
69 |
70 | # Squeeze
71 | w = F.avg_pool2d(out, out.size(2))
72 | w = F.relu(self.fc1(w))
73 | w = torch.sigmoid(self.fc2(w))
74 | # Excitation
75 | out = out * w
76 |
77 | out += shortcut
78 | return out
79 |
80 |
81 | class SENet(nn.Module):
82 | def __init__(self, block, num_blocks, num_classes=10):
83 | super(SENet, self).__init__()
84 | self.in_planes = 64
85 |
86 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
87 | self.bn1 = nn.BatchNorm2d(64)
88 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
89 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
90 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
91 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
92 | self.linear = nn.Linear(512, num_classes)
93 |
94 | def _make_layer(self, block, planes, num_blocks, stride):
95 | strides = [stride] + [1]*(num_blocks-1)
96 | layers = []
97 | for stride in strides:
98 | layers.append(block(self.in_planes, planes, stride))
99 | self.in_planes = planes
100 | return nn.Sequential(*layers)
101 |
102 | def forward(self, x):
103 | out = F.relu(self.bn1(self.conv1(x)))
104 | out = self.layer1(out)
105 | out = self.layer2(out)
106 | out = self.layer3(out)
107 | out = self.layer4(out)
108 | out = F.avg_pool2d(out, 4)
109 | out = out.view(out.size(0), -1)
110 | out = self.linear(out)
111 | return out
112 |
113 |
114 | def SENet18():
115 | return SENet(PreActBlock, [2,2,2,2])
116 |
117 |
118 | def test():
119 | net = SENet18()
120 | y = net(Variable(torch.randn(1,3,32,32)))
121 | print(y.size())
122 | print(net)
123 |
124 | #test()
125 |
--------------------------------------------------------------------------------
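The squeeze-and-excitation path in both blocks above is a per-channel gate; an illustrative shape trace for a feature map of size N x C x H x W (with H = W):

# squeeze: w = avg_pool2d(out, H)   -> N x C x 1 x 1      (one scalar per channel)
# reduce:  w = relu(fc1(w))         -> N x C//16 x 1 x 1  (1x1 conv acting as a bottleneck FC)
# expand:  w = sigmoid(fc2(w))      -> N x C x 1 x 1      (gate values in (0, 1))
# excite:  out = out * w            -> N x C x H x W      (broadcast over H and W)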
/models/shufflenet.py:
--------------------------------------------------------------------------------
1 | '''ShuffleNet in PyTorch.
2 |
3 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details.
4 | '''
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from torch.autograd import Variable
10 |
11 |
12 | class ShuffleBlock(nn.Module):
13 | def __init__(self, groups):
14 | super(ShuffleBlock, self).__init__()
15 | self.groups = groups
16 |
17 | def forward(self, x):
18 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
19 | N,C,H,W = x.size()
20 | g = self.groups
21 | return x.view(N, g, C // g, H, W).permute(0, 2, 1, 3, 4).contiguous().view(N, C, H, W)  # integer division keeps the view sizes ints under Python 3
22 |
23 |
24 | class Bottleneck(nn.Module):
25 | def __init__(self, in_planes, out_planes, stride, groups):
26 | super(Bottleneck, self).__init__()
27 | self.stride = stride
28 |
29 | mid_planes = out_planes // 4  # integer division: channel counts must be ints
30 | g = 1 if in_planes==24 else groups
31 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False)
32 | self.bn1 = nn.BatchNorm2d(mid_planes)
33 | self.shuffle1 = ShuffleBlock(groups=g)
34 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False)
35 | self.bn2 = nn.BatchNorm2d(mid_planes)
36 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
37 | self.bn3 = nn.BatchNorm2d(out_planes)
38 |
39 | self.shortcut = nn.Sequential()
40 | if stride == 2:
41 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1))
42 |
43 | def forward(self, x):
44 | out = F.relu(self.bn1(self.conv1(x)))
45 | out = self.shuffle1(out)
46 | out = F.relu(self.bn2(self.conv2(out)))
47 | out = self.bn3(self.conv3(out))
48 | res = self.shortcut(x)
49 | out = F.relu(torch.cat([out,res], 1)) if self.stride==2 else F.relu(out+res)
50 | return out
51 |
52 |
53 | class ShuffleNet(nn.Module):
54 | def __init__(self, cfg):
55 | super(ShuffleNet, self).__init__()
56 | out_planes = cfg['out_planes']
57 | num_blocks = cfg['num_blocks']
58 | groups = cfg['groups']
59 |
60 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
61 | self.bn1 = nn.BatchNorm2d(24)
62 | self.in_planes = 24
63 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups)
64 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups)
65 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups)
66 | self.linear = nn.Linear(out_planes[2], 10)
67 |
68 | def _make_layer(self, out_planes, num_blocks, groups):
69 | layers = []
70 | for i in range(num_blocks):
71 | stride = 2 if i == 0 else 1
72 | cat_planes = self.in_planes if i == 0 else 0
73 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, stride=stride, groups=groups))
74 | self.in_planes = out_planes
75 | return nn.Sequential(*layers)
76 |
77 | def forward(self, x):
78 | out = F.relu(self.bn1(self.conv1(x)))
79 | out = self.layer1(out)
80 | out = self.layer2(out)
81 | out = self.layer3(out)
82 | out = F.avg_pool2d(out, 4)
83 | out = out.view(out.size(0), -1)
84 | out = self.linear(out)
85 | return out
86 |
87 |
88 | def ShuffleNetG2():
89 | cfg = {
90 | 'out_planes': [200,400,800],
91 | 'num_blocks': [4,8,4],
92 | 'groups': 2
93 | }
94 | return ShuffleNet(cfg)
95 |
96 | def ShuffleNetG3():
97 | cfg = {
98 | 'out_planes': [240,480,960],
99 | 'num_blocks': [4,8,4],
100 | 'groups': 3
101 | }
102 | return ShuffleNet(cfg)
103 |
104 |
105 | def test():
106 | net = ShuffleNetG2()
107 | x = Variable(torch.randn(1,3,32,32))
108 | y = net(x)
109 | print(y)
110 |
111 | # test()
112 |
--------------------------------------------------------------------------------
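The channel shuffle interleaves channels across groups so that the next grouped 1x1 convolution mixes information between groups. A small worked example with C = 6 channels and g = 2 groups (illustrative only):

# channel order before shuffle: [0, 1, 2, 3, 4, 5]
# view as (g, C // g):          [[0, 1, 2],
#                                [3, 4, 5]]
# permute (transpose groups):   [[0, 3],
#                                [1, 4],
#                                [2, 5]]
# flatten back to C channels:   [0, 3, 1, 4, 2, 5]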
/models/small_cnn.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import torch.nn as nn
3 | import torch
4 | from torch.autograd import Variable
5 |
6 | class SmallCNN(nn.Module):
7 | def __init__(self):
8 | super(SmallCNN, self).__init__()
9 |
10 | self.block1_conv1 = nn.Conv2d(3, 64, 3, padding=1)
11 | self.block1_conv2 = nn.Conv2d(64, 64, 3, padding=1)
12 | self.block1_pool1 = nn.MaxPool2d(2, 2)
13 | self.batchnorm1_1 = nn.BatchNorm2d(64)
14 | self.batchnorm1_2 = nn.BatchNorm2d(64)
15 |
16 | self.block2_conv1 = nn.Conv2d(64, 128, 3, padding=1)
17 | self.block2_conv2 = nn.Conv2d(128, 128, 3, padding=1)
18 | self.block2_pool1 = nn.MaxPool2d(2, 2)
19 | self.batchnorm2_1 = nn.BatchNorm2d(128)
20 | self.batchnorm2_2 = nn.BatchNorm2d(128)
21 |
22 | self.block3_conv1 = nn.Conv2d(128, 196, 3, padding=1)
23 | self.block3_conv2 = nn.Conv2d(196, 196, 3, padding=1)
24 | self.block3_pool1 = nn.MaxPool2d(2, 2)
25 | self.batchnorm3_1 = nn.BatchNorm2d(196)
26 | self.batchnorm3_2 = nn.BatchNorm2d(196)
27 |
28 | self.activ = nn.ReLU()
29 |
30 | self.fc1 = nn.Linear(196*4*4,256)
31 | self.fc2 = nn.Linear(256,10)
32 |
33 | def forward(self, x):
34 | #block1
35 | x = self.block1_conv1(x)
36 | x = self.batchnorm1_1(x)
37 | x = self.activ(x)
38 | x = self.block1_conv2(x)
39 | x = self.batchnorm1_2(x)
40 | x = self.activ(x)
41 | x = self.block1_pool1(x)
42 |
43 | #block2
44 | x = self.block2_conv1(x)
45 | x = self.batchnorm2_1(x)
46 | x = self.activ(x)
47 | x = self.block2_conv2(x)
48 | x = self.batchnorm2_2(x)
49 | x = self.activ(x)
50 | x = self.block2_pool1(x)
51 | #block3
52 | x = self.block3_conv1(x)
53 | x = self.batchnorm3_1(x)
54 | x = self.activ(x)
55 | x = self.block3_conv2(x)
56 | x = self.batchnorm3_2(x)
57 | x = self.activ(x)
58 | x = self.block3_pool1(x)
59 |
60 | x = x.view(-1,196*4*4)
61 | x = self.fc1(x)
62 | x = self.activ(x)
63 | x = self.fc2(x)
64 |
65 | return x
66 |
67 | def small_cnn():
68 | return SmallCNN()
69 | def test():
70 | net = small_cnn()
71 | y = net(Variable(torch.randn(1,3,32,32)))
72 | print(y.size())
73 | print(net)
--------------------------------------------------------------------------------
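The x.view(-1, 196*4*4) reshape relies on 32x32 inputs: the three 2x2 max-pools halve the spatial size three times, leaving a 196 x 4 x 4 map, i.e. 3136 features for fc1. Illustrative shape trace:

# block1: 3 x 32 x 32   -> 64 x 32 x 32   --pool--> 64 x 16 x 16
# block2: 64 x 16 x 16  -> 128 x 16 x 16  --pool--> 128 x 8 x 8
# block3: 128 x 8 x 8   -> 196 x 8 x 8    --pool--> 196 x 4 x 4
# flatten: 196 * 4 * 4 = 3136 -> fc1 = nn.Linear(3136, 256) -> fc2 = nn.Linear(256, 10)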
/models/vgg.py:
--------------------------------------------------------------------------------
1 | '''VGG11/13/16/19 in PyTorch.'''
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 |
6 |
7 | cfg = {
8 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
9 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
10 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
11 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
12 | }
13 |
14 |
15 | class VGG(nn.Module):
16 | def __init__(self, vgg_name):
17 | super(VGG, self).__init__()
18 | self.features = self._make_layers(cfg[vgg_name])
19 | self.classifier = nn.Linear(512, 10)
20 |
21 | def forward(self, x):
22 | out = self.features(x)
23 | out = out.view(out.size(0), -1)
24 | out = self.classifier(out)
25 | return out
26 |
27 | def _make_layers(self, cfg):
28 | layers = []
29 | in_channels = 3
30 | for x in cfg:
31 | if x == 'M':
32 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
33 | else:
34 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
35 | nn.BatchNorm2d(x),
36 | nn.ReLU(inplace=True)]
37 | in_channels = x
38 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
39 | return nn.Sequential(*layers)
40 |
41 | # net = VGG('VGG11')
42 | # x = torch.randn(2,3,32,32)
43 | # print(net(Variable(x)).size())
44 |
--------------------------------------------------------------------------------
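Each 'M' entry in the cfg lists becomes a 2x2 max-pool, and every configuration contains five of them, so a 32x32 input is reduced to a 1x1 map with 512 channels before the Linear(512, 10) classifier (the trailing AvgPool2d(kernel_size=1, stride=1) is effectively a no-op). Illustrative trace:

# spatial size after each 'M': 32 -> 16 -> 8 -> 4 -> 2 -> 1
# channels after the last conv stage: 512
# flatten: 512 * 1 * 1 = 512 features -> classifier = nn.Linear(512, 10)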
/models/wide_resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 | class BasicBlock(nn.Module):
8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
9 | super(BasicBlock, self).__init__()
10 | self.bn1 = nn.BatchNorm2d(in_planes)
11 | self.relu1 = nn.ReLU(inplace=True)
12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=1, bias=False)
14 | self.bn2 = nn.BatchNorm2d(out_planes)
15 | self.relu2 = nn.ReLU(inplace=True)
16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
17 | padding=1, bias=False)
18 | self.droprate = dropRate
19 | self.equalInOut = (in_planes == out_planes)
20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
21 | padding=0, bias=False) or None
22 |
23 | def forward(self, x):
24 | if not self.equalInOut:
25 | x = self.relu1(self.bn1(x))
26 | else:
27 | out = self.relu1(self.bn1(x))
28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
29 | if self.droprate > 0:
30 | out = F.dropout(out, p=self.droprate, training=self.training)
31 | out = self.conv2(out)
32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
33 |
34 |
35 | class NetworkBlock(nn.Module):
36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
37 | super(NetworkBlock, self).__init__()
38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
39 |
40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
41 | layers = []
42 | for i in range(int(nb_layers)):
43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
44 | return nn.Sequential(*layers)
45 |
46 | def forward(self, x):
47 | return self.layer(x)
48 |
49 |
50 | class Wide_ResNet(nn.Module):
51 | def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0):
52 | super(Wide_ResNet, self).__init__()
53 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
54 | assert ((depth - 4) % 6 == 0)
55 | n = (depth - 4) // 6
56 | block = BasicBlock
57 | # 1st conv before any network block
58 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
59 | padding=1, bias=False)
60 | # 1st block
61 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
62 | # 1st sub-block
63 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
64 | # 2nd block
65 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
66 | # 3rd block
67 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
68 | # global average pooling and classifier
69 | self.bn1 = nn.BatchNorm2d(nChannels[3])
70 | self.relu = nn.ReLU(inplace=True)
71 | self.fc = nn.Linear(nChannels[3], num_classes)
72 | self.nChannels = nChannels[3]
73 |
74 | for m in self.modules():
75 | if isinstance(m, nn.Conv2d):
76 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
77 | m.weight.data.normal_(0, math.sqrt(2. / n))
78 | elif isinstance(m, nn.BatchNorm2d):
79 | m.weight.data.fill_(1)
80 | m.bias.data.zero_()
81 | elif isinstance(m, nn.Linear):
82 | m.bias.data.zero_()
83 |
84 | def forward(self, x):
85 | out = self.conv1(x)
86 | out = self.block1(out)
87 | out = self.block2(out)
88 | out = self.block3(out)
89 | out = self.relu(self.bn1(out))
90 | out = F.avg_pool2d(out, 8)
91 | out = out.view(-1, self.nChannels)
92 | return self.fc(out)
93 | def test():
94 | net = Wide_ResNet()
95 | y = net(Variable(torch.randn(1, 3, 32, 32)))
96 | #print(y.size())
97 | print(net)
98 | # test()
--------------------------------------------------------------------------------
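The assert (depth - 4) % 6 == 0 encodes the depth convention used in this file: each of the three NetworkBlocks holds n = (depth - 4) / 6 BasicBlocks, and the widths are [16, 16k, 32k, 64k] for widen factor k. Illustrative arithmetic for the default WRN-34-10:

# depth = 34, widen_factor k = 10  ->  n = (34 - 4) / 6 = 5 blocks per group
# nChannels = [16, 160, 320, 640]
# spatial: 32 -> block2 (stride 2) -> 16 -> block3 (stride 2) -> 8 -> avg_pool2d(8) -> 1
# classifier: fc = nn.Linear(640, 10)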
/models/wrn_madry.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 | class BasicBlock(nn.Module):
8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
9 | super(BasicBlock, self).__init__()
10 | self.bn1 = nn.BatchNorm2d(in_planes)
11 | self.relu1 = nn.ReLU(inplace=True)
12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
13 | padding=1, bias=False)
14 | self.bn2 = nn.BatchNorm2d(out_planes)
15 | self.relu2 = nn.ReLU(inplace=True)
16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
17 | padding=1, bias=False)
18 | self.droprate = dropRate
19 | self.equalInOut = (in_planes == out_planes)
20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
21 | padding=0, bias=False) or None
22 |
23 | def forward(self, x):
24 | if not self.equalInOut:
25 | x = self.relu1(self.bn1(x))
26 | else:
27 | out = self.relu1(self.bn1(x))
28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
29 | if self.droprate > 0:
30 | out = F.dropout(out, p=self.droprate, training=self.training)
31 | out = self.conv2(out)
32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
33 |
34 |
35 | class NetworkBlock(nn.Module):
36 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
37 | super(NetworkBlock, self).__init__()
38 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
39 |
40 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
41 | layers = []
42 | for i in range(int(nb_layers)):
43 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
44 | return nn.Sequential(*layers)
45 |
46 | def forward(self, x):
47 | return self.layer(x)
48 |
49 |
50 | class Wide_ResNet_Madry(nn.Module):
51 | def __init__(self, depth=34, num_classes=10, widen_factor=10, dropRate=0):
52 | super(Wide_ResNet_Madry, self).__init__()
53 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
54 | assert ((depth - 2) % 6 == 0)
55 | n = (depth - 2) // 6
56 | block = BasicBlock
57 | # 1st conv before any network block
58 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
59 | padding=1, bias=False)
60 | # 1st block
61 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
62 | # 1st sub-block
63 | # self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
64 | # 2nd block
65 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
66 | # 3rd block
67 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
68 | # global average pooling and classifier
69 | self.bn1 = nn.BatchNorm2d(nChannels[3])
70 | self.relu = nn.ReLU(inplace=True)
71 | self.fc = nn.Linear(nChannels[3], num_classes)
72 | self.nChannels = nChannels[3]
73 |
74 | for m in self.modules():
75 | if isinstance(m, nn.Conv2d):
76 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
77 | m.weight.data.normal_(0, math.sqrt(2. / n))
78 | elif isinstance(m, nn.BatchNorm2d):
79 | m.weight.data.fill_(1)
80 | m.bias.data.zero_()
81 | elif isinstance(m, nn.Linear):
82 | m.bias.data.zero_()
83 |
84 | def forward(self, x):
85 | out = self.conv1(x)
86 | out = self.block1(out)
87 | out = self.block2(out)
88 | out = self.block3(out)
89 | out = self.relu(self.bn1(out))
90 | out = F.avg_pool2d(out, 8)
91 | out = out.view(-1, self.nChannels)
92 | return self.fc(out)
93 | def test():
94 | net = Wide_ResNet_Madry()
95 | y = net(Variable(torch.randn(1, 3, 32, 32)))
96 | #print(y.size())
97 | print(net)
98 | # test()
--------------------------------------------------------------------------------
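This variant differs from wide_resnet.py in two small ways: the depth convention is depth = 6n + 2 (the assert uses depth - 2 rather than depth - 4), and the unused sub_block1 is commented out; otherwise the two files build the same topology. An illustrative instantiation for a depth-32, widen-factor-10 network (the specific values are an example chosen to satisfy the assert):

from models import Wide_ResNet_Madry

# depth = 32, widen_factor k = 10  ->  n = (32 - 2) / 6 = 5 blocks per group
# nChannels = [16, 160, 320, 640]; classifier: fc = nn.Linear(640, 10)
net = Wide_ResNet_Madry(depth=32, num_classes=10, widen_factor=10, dropRate=0.0)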
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Useful utils
2 | """
3 | from .logger import *
4 |
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/__init__.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/eval.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/eval.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/eval.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/eval.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/eval.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/eval.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logger.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/logger.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logger.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/logger.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logger.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/logger.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/misc.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/misc.cpython-35.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/misc.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/misc.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/misc.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zjfheart/Friendly-Adversarial-Training/80a064074891775146bca8c4d8a6b99c1499285a/utils/__pycache__/misc.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # A simple torch style logger
2 | # (C) Wei YANG 2017
3 | from __future__ import absolute_import
4 | import matplotlib.pyplot as plt
5 | import os
6 | import sys
7 | import numpy as np
8 |
9 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
10 |
11 | def savefig(fname, dpi=None):
12 | dpi = 150 if dpi == None else dpi
13 | plt.savefig(fname, dpi=dpi)
14 |
15 | def plot_overlap(logger, names=None):
16 | names = logger.names if names == None else names
17 | numbers = logger.numbers
18 | for _, name in enumerate(names):
19 | x = np.arange(len(numbers[name]))
20 | plt.plot(x, np.asarray(numbers[name]))
21 | return [logger.title + '(' + name + ')' for name in names]
22 |
23 | class Logger(object):
24 | '''Save training process to log file with simple plot function.'''
25 | def __init__(self, fpath, title=None, resume=False):
26 | self.file = None
27 | self.resume = resume
28 |         self.title = '' if title is None else title
29 | if fpath is not None:
30 | if resume:
31 | self.file = open(fpath, 'r')
32 | name = self.file.readline()
33 | self.names = name.rstrip().split('\t')
34 | self.numbers = {}
35 | for _, name in enumerate(self.names):
36 | self.numbers[name] = []
37 |
38 | for numbers in self.file:
39 | numbers = numbers.rstrip().split('\t')
40 | for i in range(0, len(numbers)):
41 |                         self.numbers[self.names[i]].append(float(numbers[i]))  # restore resumed values as floats, not strings
42 | self.file.close()
43 | self.file = open(fpath, 'a')
44 | else:
45 | self.file = open(fpath, 'w')
46 |
47 | def set_names(self, names):
48 | if self.resume:
49 |             return  # names and history were already loaded from the resumed log file
50 | # initialize numbers as empty list
51 | self.numbers = {}
52 | self.names = names
53 | for _, name in enumerate(self.names):
54 | self.file.write(name)
55 | self.file.write('\t')
56 | self.numbers[name] = []
57 | self.file.write('\n')
58 | self.file.flush()
59 |
60 |
61 | def append(self, numbers):
62 | assert len(self.names) == len(numbers), 'Numbers do not match names'
63 | for index, num in enumerate(numbers):
64 | self.file.write("{0:.6f}".format(num))
65 | self.file.write('\t')
66 | self.numbers[self.names[index]].append(num)
67 | self.file.write('\n')
68 | self.file.flush()
69 |
70 | def plot(self, names=None):
71 |         names = self.names if names is None else names
72 | numbers = self.numbers
73 | for _, name in enumerate(names):
74 | x = np.arange(len(numbers[name]))
75 | plt.plot(x, np.asarray(numbers[name]))
76 | plt.legend([self.title + '(' + name + ')' for name in names])
77 | plt.grid(True)
78 |
79 | def close(self):
80 | if self.file is not None:
81 | self.file.close()
82 |
83 | class LoggerMonitor(object):
84 | '''Load and visualize multiple logs.'''
85 |     def __init__(self, paths):
86 |         '''paths is a dictionary with {name: filepath} pairs'''
87 | self.loggers = []
88 | for title, path in paths.items():
89 | logger = Logger(path, title=title, resume=True)
90 | self.loggers.append(logger)
91 |
92 | def plot(self, names=None):
93 | plt.figure()
94 | plt.subplot(121)
95 | legend_text = []
96 | for logger in self.loggers:
97 | legend_text += plot_overlap(logger, names)
98 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
99 | plt.grid(True)
100 |
101 | if __name__ == '__main__':
102 | # # Example
103 | # logger = Logger('test.txt')
104 | # logger.set_names(['Train loss', 'Valid loss','Test loss'])
105 |
106 | # length = 100
107 | # t = np.arange(length)
108 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
109 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
110 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
111 |
112 | # for i in range(0, length):
113 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]])
114 | # logger.plot()
115 |
116 | # Example: logger monitor
117 | paths = {
118 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
119 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
120 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
121 | }
122 |
123 | field = ['Valid Acc.']
124 |
125 | monitor = LoggerMonitor(paths)
126 | monitor.plot(names=field)
127 | savefig('test.eps')
--------------------------------------------------------------------------------
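Note (editor's addition, not part of the repository): the snippet below is a minimal usage sketch of the `Logger` class defined in /utils/logger.py, assuming the package layout shown in the directory tree. The log file name 'demo_log.txt', the title 'demo', and the metric names are hypothetical placeholders.

# Minimal sketch of driving utils.logger.Logger during a training loop.
import numpy as np
from utils import Logger   # assumes utils/__init__.py re-exports Logger as in this repo

logger = Logger('demo_log.txt', title='demo')           # opens the file in write mode
logger.set_names(['Epoch', 'Train Loss', 'Test Acc'])   # writes the tab-separated header
for epoch in range(3):
    train_loss = float(np.exp(-epoch))                  # placeholder statistics
    test_acc = 0.5 + 0.1 * epoch
    logger.append([epoch, train_loss, test_acc])        # one tab-separated row per call
logger.close()

# Reopening the same file with resume=True reloads the recorded history into
# logger.numbers and appends new rows instead of overwriting the file.
logger = Logger('demo_log.txt', title='demo', resume=True)
logger.append([3, float(np.exp(-3)), 0.8])
logger.close()

`logger.plot()` followed by `savefig('curves.png')` would render the recorded columns with matplotlib; `LoggerMonitor` does the same for several such log files at once, as illustrated in the `__main__` block of the file above.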