├── .gitignore ├── CIFAR10.png ├── MNIST.png ├── README.md ├── checkpoint └── .gitkeep ├── generate.py ├── models.py ├── test_model.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.tar 3 | checkpoint/* 4 | -------------------------------------------------------------------------------- /CIFAR10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owruby/APE-GAN/1095ac2f6c0cf85fc75fcd0802c8cce4762dbacf/CIFAR10.png -------------------------------------------------------------------------------- /MNIST.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owruby/APE-GAN/1095ac2f6c0cf85fc75fcd0802c8cce4762dbacf/MNIST.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # APE-GAN 2 | 3 | Implementation APE-GAN (https://arxiv.org/pdf/1707.05474.pdf) 4 | 5 | ## MNIST 6 | 7 | ![MNIST](https://github.com/owruby/APE-GAN/blob/master/MNIST.png) 8 | 9 | ### 1. Train CNN and Generate Adversarial Examples(FGSM) 10 | ``` 11 | python generate.py --eps 0.15 12 | ``` 13 | 14 | ### 2. Train APE-GAN 15 | ``` 16 | python train.py --checkpoint ./checkpoint/mnist 17 | ``` 18 | 19 | ### 3. Test 20 | ``` 21 | python test_model.py --eps 0.15 --gan_path ./checkpoint/mnist/3.tar 22 | ``` 23 | 24 | ## CIFAR-10 25 | 26 | ![CIFAR10](https://github.com/owruby/APE-GAN/blob/master/CIFAR10.png) 27 | 28 | ### 1. Train CNN and Generate Adversarial Examples(FGSM) 29 | ``` 30 | python generate.py --data cifar --eps 0.01 31 | ``` 32 | 33 | ### 2. Train APE-GAN 34 | ``` 35 | python train.py --data cifar --epochs 30 --checkpoint ./checkpoint/cifar 36 | ``` 37 | 38 | ### 3. Test 39 | ``` 40 | python test_model.py --data cifar --eps 0.01 --gan_path ./checkpoint/cifar/10.tar 41 | ``` 42 | -------------------------------------------------------------------------------- /checkpoint/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/owruby/APE-GAN/1095ac2f6c0cf85fc75fcd0802c8cce4762dbacf/checkpoint/.gitkeep -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import argparse 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.optim.lr_scheduler as lr_scheduler 10 | import torch.backends.cudnn as cudnn 11 | from torch.autograd import Variable 12 | 13 | from torchvision import datasets 14 | from torchvision import transforms 15 | 16 | from tqdm import tqdm 17 | 18 | from models import MnistCNN, CifarCNN 19 | from utils import accuracy, fgsm 20 | 21 | 22 | def load_dataset(args): 23 | if args.data == "mnist": 24 | train_loader = torch.utils.data.DataLoader( 25 | datasets.MNIST(os.path.expanduser("~/.torch/data/mnist"), train=True, download=True, 26 | transform=transforms.Compose([ 27 | transforms.ToTensor()])), 28 | batch_size=128, shuffle=True) 29 | test_loader = torch.utils.data.DataLoader( 30 | datasets.MNIST(os.path.expanduser("~/.torch/data/mnist"), train=False, download=False, 31 | transform=transforms.Compose([ 32 | transforms.ToTensor()])), 33 | batch_size=128, shuffle=False) 34 | elif args.data == "cifar": 35 | train_loader = torch.utils.data.DataLoader( 36 | datasets.CIFAR10(os.path.expanduser("~/.torch/data/cifar10"), train=True, download=True, 37 | transform=transforms.Compose([ 38 | transforms.ToTensor()])), 39 | batch_size=128, shuffle=True) 40 | test_loader = torch.utils.data.DataLoader( 41 | datasets.CIFAR10(os.path.expanduser("~/.torch/data/cifar10"), train=False, download=False, 42 | transform=transforms.Compose([ 43 | transforms.ToTensor()])), 44 | batch_size=128, shuffle=False) 45 | return train_loader, test_loader 46 | 47 | 48 | def load_cnn(args): 49 | if args.data == "mnist": 50 | return MnistCNN 51 | elif args.data == "cifar": 52 | return CifarCNN 53 | 54 | 55 | def main(args): 56 | print("Generating Model ...") 57 | print("-" * 30) 58 | 59 | train_loader, test_loader = load_dataset(args) 60 | CNN = load_cnn(args) 61 | model = CNN().cuda() 62 | cudnn.benchmark = True 63 | 64 | opt = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.001) 65 | scheduler = lr_scheduler.MultiStepLR(opt, milestones=args.milestones, gamma=args.gamma) 66 | loss_func = nn.CrossEntropyLoss().cuda() 67 | 68 | epochs = args.epochs 69 | print_str = "\t".join(["{}"] + ["{:.6f}"] * 4) 70 | print("\t".join(["{:}"] * 5).format("Epoch", "TrainLoss", "TestLoss", "TrainAcc.", "TestAcc.")) 71 | for e in range(epochs): 72 | train_loss, train_acc, train_n = 0, 0, 0 73 | test_loss, test_acc, test_n = 0, 0, 0 74 | 75 | model.train() 76 | for x, t in tqdm(train_loader, total=len(train_loader), leave=False): 77 | x, t = Variable(x.cuda()), Variable(t.cuda()) 78 | y = model(x) 79 | loss = loss_func(y, t) 80 | opt.zero_grad() 81 | loss.backward() 82 | opt.step() 83 | train_loss += loss.data[0] * t.size(0) 84 | train_acc += accuracy(y, t) 85 | train_n += t.size(0) 86 | 87 | model.eval() 88 | for x, t in tqdm(test_loader, total=len(test_loader), leave=False): 89 | x, t = Variable(x.cuda()), Variable(t.cuda()) 90 | y = model(x) 91 | loss = loss_func(y, t) 92 | 93 | test_loss += loss.data[0] * t.size(0) 94 | test_acc += accuracy(y, t) 95 | test_n += t.size(0) 96 | scheduler.step() 97 | print(print_str.format(e, train_loss / train_n, test_loss / test_n, 98 | train_acc / train_n * 100, test_acc / test_n * 100)) 99 | 100 | # Generate Adversarial Examples 101 | print("-" * 30) 102 | print("Genrating Adversarial Examples ...") 103 | eps = args.eps 104 | train_acc, adv_acc, train_n = 0, 0, 0 105 | normal_data, adv_data = None, None 106 | for x, t in tqdm(train_loader, total=len(train_loader), leave=False): 107 | x, t = Variable(x.cuda()), Variable(t.cuda()) 108 | y = model(x) 109 | train_acc += accuracy(y, t) 110 | 111 | x_adv = fgsm(model, x, t, loss_func, eps) 112 | y_adv = model(x_adv) 113 | adv_acc += accuracy(y_adv, t) 114 | train_n += t.size(0) 115 | 116 | x, x_adv = x.data, x_adv.data 117 | if normal_data is None: 118 | normal_data, adv_data = x, x_adv 119 | else: 120 | normal_data = torch.cat((normal_data, x)) 121 | adv_data = torch.cat((adv_data, x_adv)) 122 | 123 | print("Accuracy(normal) {:.6f}, Accuracy(FGSM) {:.6f}".format(train_acc / train_n * 100, adv_acc / train_n * 100)) 124 | torch.save({"normal": normal_data, "adv": adv_data}, "data.tar") 125 | torch.save({"state_dict": model.state_dict()}, "cnn.tar") 126 | 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument("--data", type=str, default="mnist") 131 | parser.add_argument("--epochs", type=int, default=100) 132 | parser.add_argument("--lr", type=float, default=0.01) 133 | parser.add_argument("--milestones", type=list, default=[50, 75]) 134 | parser.add_argument("--gamma", type=float, default=0.1) 135 | parser.add_argument("--eps", type=float, default=0.15) 136 | args = parser.parse_args() 137 | main(args) 138 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class MnistCNN(nn.Module): 8 | 9 | def __init__(self): 10 | super(MnistCNN, self).__init__() 11 | self.conv1 = nn.Conv2d(1, 32, 3) 12 | self.conv2 = nn.Conv2d(32, 64, 3) 13 | self.fc3 = nn.Linear(1024, 128) 14 | self.fc4 = nn.Linear(128, 10) 15 | 16 | def forward(self, x): 17 | h = F.relu(self.conv1(x)) 18 | h = F.relu(self.conv2(h)) 19 | h = F.dropout2d(F.max_pool2d(h, 6), p=0.25) 20 | h = F.dropout2d(self.fc3(h.view(h.size(0), -1)), p=0.5) 21 | h = self.fc4(h) 22 | return F.log_softmax(h) 23 | 24 | 25 | class CifarCNN(nn.Module): 26 | 27 | def __init__(self): 28 | super(CifarCNN, self).__init__() 29 | self.conv1 = nn.Conv2d(3, 64, 3, padding=1) 30 | self.bn1 = nn.BatchNorm2d(64) 31 | self.conv2 = nn.Conv2d(64, 64, 3, padding=1) 32 | self.bn2 = nn.BatchNorm2d(64) 33 | self.conv3 = nn.Conv2d(64, 128, 3, padding=1) 34 | self.bn3 = nn.BatchNorm2d(128) 35 | self.conv4 = nn.Conv2d(128, 128, 3, padding=1) 36 | self.bn4 = nn.BatchNorm2d(128) 37 | self.fc5 = nn.Linear(512, 256) 38 | self.fc6 = nn.Linear(256, 256) 39 | self.fc7 = nn.Linear(256, 10) 40 | 41 | def forward(self, x): 42 | h = F.relu(self.bn1(self.conv1(x))) 43 | h = F.relu(self.bn2(self.conv2(h))) 44 | h = F.max_pool2d(h, 4) 45 | 46 | h = F.relu(self.bn3(self.conv3(h))) 47 | h = F.relu(self.bn4(self.conv4(h))) 48 | h = F.max_pool2d(h, 4) 49 | 50 | h = F.relu(self.fc5(h.view(h.size(0), -1))) 51 | h = F.relu(self.fc6(h)) 52 | h = self.fc7(h) 53 | return F.log_softmax(h) 54 | 55 | 56 | class Generator(nn.Module): 57 | 58 | def __init__(self, in_ch): 59 | super(Generator, self).__init__() 60 | self.conv1 = nn.Conv2d(in_ch, 64, 4, stride=2, padding=1) 61 | self.bn1 = nn.BatchNorm2d(64) 62 | self.conv2 = nn.Conv2d(64, 128, 4, stride=2, padding=1) 63 | self.bn2 = nn.BatchNorm2d(128) 64 | self.deconv3 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1) 65 | self.bn3 = nn.BatchNorm2d(64) 66 | self.deconv4 = nn.ConvTranspose2d(64, in_ch, 4, stride=2, padding=1) 67 | 68 | def forward(self, x): 69 | h = F.leaky_relu(self.bn1(self.conv1(x))) 70 | h = F.leaky_relu(self.bn2(self.conv2(h))) 71 | h = F.leaky_relu(self.bn3(self.deconv3(h))) 72 | h = F.tanh(self.deconv4(h)) 73 | return h 74 | 75 | 76 | class Discriminator(nn.Module): 77 | 78 | def __init__(self, in_ch): 79 | super(Discriminator, self).__init__() 80 | self.conv1 = nn.Conv2d(in_ch, 64, 3, stride=2) 81 | self.conv2 = nn.Conv2d(64, 128, 3, stride=2) 82 | self.bn2 = nn.BatchNorm2d(128) 83 | self.conv3 = nn.Conv2d(128, 256, 3, stride=2) 84 | self.bn3 = nn.BatchNorm2d(256) 85 | if in_ch == 1: 86 | self.fc4 = nn.Linear(1024, 1) 87 | else: 88 | self.fc4 = nn.Linear(2304, 1) 89 | 90 | def forward(self, x): 91 | h = F.leaky_relu(self.conv1(x)) 92 | h = F.leaky_relu(self.bn2(self.conv2(h))) 93 | h = F.leaky_relu(self.bn3(self.conv3(h))) 94 | h = F.sigmoid(self.fc4(h.view(h.size(0), -1))) 95 | return h 96 | 97 | if __name__ == "__main__": 98 | import torch 99 | from torch.autograd import Variable 100 | x = torch.normal(mean=0, std=torch.ones(10, 3, 32, 32)) 101 | model = CifarCNN() 102 | model(Variable(x)) 103 | -------------------------------------------------------------------------------- /test_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import argparse 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | 10 | from torchvision import datasets 11 | from torchvision import transforms 12 | 13 | from tqdm import tqdm 14 | 15 | from models import MnistCNN, CifarCNN, Generator 16 | from utils import fgsm, accuracy 17 | 18 | 19 | def load_dataset(args): 20 | if args.data == "mnist": 21 | test_loader = torch.utils.data.DataLoader( 22 | datasets.MNIST(os.path.expanduser("~/.torch/data/mnist"), train=False, download=False, 23 | transform=transforms.Compose([ 24 | transforms.ToTensor()])), 25 | batch_size=128, shuffle=False) 26 | elif args.data == "cifar": 27 | test_loader = torch.utils.data.DataLoader( 28 | datasets.CIFAR10(os.path.expanduser("~/.torch/data/cifar10"), train=False, download=False, 29 | transform=transforms.Compose([ 30 | transforms.ToTensor()])), 31 | batch_size=128, shuffle=False) 32 | return test_loader 33 | 34 | 35 | def load_cnn(args): 36 | if args.data == "mnist": 37 | return MnistCNN 38 | elif args.data == "cifar": 39 | return CifarCNN 40 | 41 | 42 | def main(args): 43 | eps = args.eps 44 | test_loader = load_dataset(args) 45 | 46 | model_point = torch.load("cnn.tar") 47 | gan_point = torch.load(args.gan_path) 48 | 49 | CNN = load_cnn(args) 50 | 51 | model = CNN().cuda() 52 | model.load_state_dict(model_point["state_dict"]) 53 | 54 | in_ch = 1 if args.data == "mnist" else 3 55 | 56 | G = Generator(in_ch).cuda() 57 | G.load_state_dict(gan_point["generator"]) 58 | loss_cre = nn.CrossEntropyLoss().cuda() 59 | 60 | model.eval(), G.eval() 61 | normal_acc, adv_acc, ape_acc, n = 0, 0, 0, 0 62 | for x, t in tqdm(test_loader, total=len(test_loader), leave=False): 63 | x, t = Variable(x.cuda()), Variable(t.cuda()) 64 | 65 | y = model(x) 66 | normal_acc += accuracy(y, t) 67 | 68 | x_adv = fgsm(model, x, t, loss_cre, eps) 69 | y_adv = model(x_adv) 70 | adv_acc += accuracy(y_adv, t) 71 | 72 | x_ape = G(x_adv) 73 | y_ape = model(x_ape) 74 | ape_acc += accuracy(y_ape, t) 75 | n += t.size(0) 76 | print("Accuracy: normal {:.6f}, fgsm {:.6f}, ape {:.6f}".format( 77 | normal_acc / n * 100, 78 | adv_acc / n * 100, 79 | ape_acc / n * 100)) 80 | 81 | 82 | if __name__ == "__main__": 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument("--data", type=str, default="mnist") 85 | parser.add_argument("--eps", type=float, default=0.15) 86 | parser.add_argument("--gan_path", type=str, default="./checkpoint/test/3.tar") 87 | args = parser.parse_args() 88 | main(args) 89 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import argparse 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torch.autograd import Variable 10 | from torch.utils.data import TensorDataset 11 | import torch.backends.cudnn as cudnn 12 | 13 | from tqdm import tqdm 14 | import matplotlib.pyplot as plt 15 | 16 | from models import Generator, Discriminator 17 | 18 | 19 | def show_images(e, x, x_adv, x_fake, save_dir): 20 | fig, axes = plt.subplots(3, 5, figsize=(10, 6)) 21 | for i in range(5): 22 | axes[0, i].axis("off"), axes[1, i].axis("off"), axes[2, i].axis("off") 23 | axes[0, i].imshow(x[i].cpu().numpy().transpose((1, 2, 0))) 24 | # axes[0, i].imshow(x[i, 0].cpu().numpy(), cmap="gray") 25 | axes[0, i].set_title("Normal") 26 | 27 | axes[1, i].imshow(x_adv[i].cpu().numpy().transpose((1, 2, 0))) 28 | # axes[1, i].imshow(x_adv[i, 0].cpu().numpy(), cmap="gray") 29 | axes[1, i].set_title("Adv") 30 | 31 | axes[2, i].imshow(x_fake[i].cpu().numpy().transpose((1, 2, 0))) 32 | # axes[2, i].imshow(x_fake[i, 0].cpu().numpy(), cmap="gray") 33 | axes[2, i].set_title("APE-GAN") 34 | plt.axis("off") 35 | plt.savefig(os.path.join(save_dir, "result_{}.png".format(e))) 36 | 37 | 38 | def main(args): 39 | lr = args.lr 40 | epochs = args.epochs 41 | batch_size = 128 42 | xi1, xi2 = args.xi1, args.xi2 43 | 44 | check_path = args.checkpoint 45 | os.makedirs(check_path, exist_ok=True) 46 | 47 | train_data = torch.load("data.tar") 48 | x_tmp = train_data["normal"][:5] 49 | x_adv_tmp = train_data["adv"][:5] 50 | 51 | train_data = TensorDataset(train_data["normal"], train_data["adv"]) 52 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True) 53 | 54 | in_ch = 1 if args.data == "mnist" else 3 55 | G = Generator(in_ch).cuda() 56 | D = Discriminator(in_ch).cuda() 57 | 58 | opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999)) 59 | opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999)) 60 | loss_bce = nn.BCELoss() 61 | loss_mse = nn.MSELoss() 62 | cudnn.benchmark = True 63 | 64 | print_str = "\t".join(["{}"] + ["{:.6f}"] * 2) 65 | print("\t".join(["{:}"] * 3).format("Epoch", "Gen_Loss", "Dis_Loss")) 66 | for e in range(epochs): 67 | G.eval() 68 | x_fake = G(Variable(x_adv_tmp.cuda())).data 69 | show_images(e, x_tmp, x_adv_tmp, x_fake, check_path) 70 | G.train() 71 | gen_loss, dis_loss, n = 0, 0, 0 72 | for x, x_adv in tqdm(train_loader, total=len(train_loader), leave=False): 73 | current_size = x.size(0) 74 | x, x_adv = Variable(x.cuda()), Variable(x_adv.cuda()) 75 | # Train D 76 | t_real = Variable(torch.ones(current_size).cuda()) 77 | t_fake = Variable(torch.zeros(current_size).cuda()) 78 | 79 | y_real = D(x).squeeze() 80 | x_fake = G(x_adv) 81 | y_fake = D(x_fake).squeeze() 82 | 83 | loss_D = loss_bce(y_real, t_real) + loss_bce(y_fake, t_fake) 84 | opt_D.zero_grad() 85 | loss_D.backward() 86 | opt_D.step() 87 | 88 | # Train G 89 | for _ in range(2): 90 | x_fake = G(x_adv) 91 | y_fake = D(x_fake).squeeze() 92 | 93 | loss_G = xi1 * loss_mse(x_fake, x) + xi2 * loss_bce(y_fake, t_real) 94 | opt_G.zero_grad() 95 | loss_G.backward() 96 | opt_G.step() 97 | 98 | gen_loss += loss_D.data[0] * x.size(0) 99 | dis_loss += loss_G.data[0] * x.size(0) 100 | n += x.size(0) 101 | print(print_str.format(e, gen_loss / n, dis_loss / n)) 102 | torch.save({"generator": G.state_dict(), "discriminator": D.state_dict()}, 103 | os.path.join(check_path, "{}.tar".format(e + 1))) 104 | 105 | G.eval() 106 | x_fake = G(Variable(x_adv_tmp.cuda())).data 107 | show_images(epochs, x_tmp, x_adv_tmp, x_fake, check_path) 108 | G.train() 109 | 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument("--data", type=str, default="mnist") 114 | parser.add_argument("--lr", type=float, default=0.0002) 115 | parser.add_argument("--epochs", type=int, default=2) 116 | parser.add_argument("--xi1", type=float, default=0.7) 117 | parser.add_argument("--xi2", type=float, default=0.3) 118 | parser.add_argument("--checkpoint", type=str, default="./checkpoint/test") 119 | args = parser.parse_args() 120 | main(args) 121 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from torch.autograd import Variable 5 | 6 | 7 | def fgsm(model, x, t, loss_func, eps, min=0, max=1): 8 | if not isinstance(x, Variable): 9 | x, t = Variable(x.cuda(), requires_grad=True), Variable(t.cuda()) 10 | x.requires_grad = True 11 | y = model(x) 12 | loss = loss_func(y, t) 13 | model.zero_grad() 14 | loss.backward(retain_graph=True) 15 | 16 | return Variable(torch.clamp(x.data + eps * torch.sign(x.grad.data), min=min, max=max)) 17 | 18 | 19 | def accuracy(y, t): 20 | pred = y.data.max(1, keepdim=True)[1] 21 | acc = pred.eq(t.data.view_as(pred)).cpu().sum() 22 | return acc 23 | --------------------------------------------------------------------------------