├── tsne.png ├── tsne_normal.png ├── new_objective.png ├── README.md ├── module.py ├── mnist.ipynb ├── cifar_convex_polygon.ipynb ├── mnist_convex_polygon.ipynb ├── util.py └── network.py /tsne.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyleliang919/Interval-bound-propagation-pytorch/HEAD/tsne.png -------------------------------------------------------------------------------- /tsne_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyleliang919/Interval-bound-propagation-pytorch/HEAD/tsne_normal.png -------------------------------------------------------------------------------- /new_objective.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kyleliang919/Interval-bound-propagation-pytorch/HEAD/new_objective.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Interval-bound-propagation-pytorch 2 | This repository contains a PyTorch attempt to replicate the results of the recent DeepMind paper, "On the Effectiveness of Interval Bound Propagation for Training Verifiably Robust Models", https://arxiv.org/pdf/1810.12715.pdf 3 | 4 | ## Disclaimer: 5 | This is not an official implementation. The differences in the numbers might be due to various reasons. Please refer to the original repo, https://github.com/deepmind/interval-bound-propagation, for exact implementation details. 6 | 7 | ## Environments: 8 | Python 3.5 9 | PyTorch 10 | GTX 1080 Ti (the results reported in the paper were obtained on 32 TPUs with a much larger batch size, so a drop in performance should be expected.) 11 | 12 | ## Results 13 | The results are not measured with MIP; more work is needed to obtain "exact" bounds.
14 | 15 | ### MNIST Results (eps = 0.1) 16 | | Model | Robust Acc | Nominal Acc | 17 | | --- | --- | --- | 18 | | Small | 0.96 | 0.9823 | 19 | | Medium | 0.9418 | 0.977 | 20 | | Large | 0.9458 | 0.9754 | 21 | 22 | ### CIFAR Results (eps = 2/255) 23 | | Model | Robust Acc | Nominal Acc | 24 | | --- | --- | --- | 25 | | Small | 0.3481 | 0.5535 | 26 | | Medium | 0.3179 | 0.4914 | 27 | | Large | 0.3426 | 0.5223 | 28 | 29 | ## T-SNE Visualization 30 | ### Robust Training 31 | ![Robust Training](tsne.png) 32 | ### Normal Training 33 | ![Normal Training](tsne_normal.png) 34 | ### Observations 35 | The proposed robust objective is clearly much harder than the nominal one, and the data is not separable after the transformations performed by the convolutional layers. It is very likely that bound propagation makes the data manifolds of different labels overlap, so that zero training loss is impossible to obtain. 36 | 37 | ## Derivative works 38 | In attempting to achieve similar performance with limited resources, I speculate that we might need a "tighter" objective than the one proposed. 39 | Separately, I also tried a "tighter" version of IBP by constraining the weights of some layers to be non-negative and using a convex polygon in the last layer. 40 | 41 | Spectral Normalization also sometimes seems to stabilize robust training by regulating the Lipschitz constant.
42 | ### Convex Polygon Objective 43 | ![New Objective](new_objective.png) 44 | -------------------------------------------------------------------------------- /module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import copy 5 | import random 6 | import numpy as np 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | from torch.nn.parameter import Parameter 11 | 12 | class RobustLinear(nn.Module): 13 | def __init__(self, in_features, out_features, bias=True, non_negative = True): 14 | super(RobustLinear, self).__init__() 15 | self.in_features = in_features 16 | self.out_features = out_features 17 | if non_negative: 18 | self.weight = Parameter(torch.rand(out_features, in_features) * 1/math.sqrt(in_features)) 19 | else: 20 | self.weight = Parameter(torch.randn(out_features, in_features) * 1/math.sqrt(in_features)) 21 | 22 | if bias: 23 | self.bias = Parameter(torch.zeros(out_features)) 24 | else: 25 | self.bias = None 26 | self.non_negative = non_negative 27 | 28 | def forward(self, input): 29 | input_p = input[:input.shape[0]//2] 30 | input_n = input[input.shape[0]//2:] 31 | if self.non_negative: 32 | out_p = F.linear(input_p, F.relu(self.weight), self.bias) 33 | out_n = F.linear(input_n, F.relu(self.weight), self.bias) 34 | return torch.cat([out_p, out_n], 0) 35 | 36 | u = (input_p + input_n)/2 37 | r = (input_p - input_n)/2 38 | out_u = F.linear(u, self.weight, self.bias) 39 | out_r = F.linear(r, torch.abs(self.weight), None) 40 | return torch.cat([out_u + out_r, out_u - out_r], 0) 41 | 42 | 43 | class RobustConv2d(nn.Module): 44 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 45 | padding=0, dilation=1, groups=1, bias=True, non_negative = True): 46 | super(RobustConv2d, self).__init__() 47 | if non_negative: 48 | self.weight = Parameter(torch.rand(out_channels, in_channels//groups, 
kernel_size, kernel_size) * 1/math.sqrt(kernel_size * kernel_size * in_channels//groups)) 49 | else: 50 | self.weight = Parameter(torch.randn(out_channels, in_channels//groups, kernel_size, kernel_size) * 1/math.sqrt(kernel_size * kernel_size * in_channels//groups)) 51 | if bias: 52 | self.bias = Parameter(torch.zeros(out_channels)) 53 | else: 54 | self.bias = None 55 | self.stride = stride 56 | self.padding = padding 57 | self.dilation = dilation 58 | self.groups = 1 59 | self.non_negative = non_negative 60 | 61 | def forward(self, input): 62 | input_p = input[:input.shape[0]//2] 63 | input_n = input[input.shape[0]//2:] 64 | if self.non_negative: 65 | out_p = F.conv2d(input_p, F.relu(self.weight), self.bias, self.stride, 66 | self.padding, self.dilation, self.groups) 67 | out_n = F.conv2d(input_n, F.relu(self.weight), self.bias, self.stride, 68 | self.padding, self.dilation, self.groups) 69 | return torch.cat([out_p, out_n],0) 70 | 71 | u = (input_p + input_n)/2 72 | r = (input_p - input_n)/2 73 | out_u = F.conv2d(u, self.weight,self.bias, self.stride, 74 | self.padding, self.dilation, self.groups) 75 | out_r = F.conv2d(r, torch.abs(self.weight), None, self.stride, 76 | self.padding, self.dilation, self.groups) 77 | return torch.cat([out_u + out_r, out_u - out_r], 0) 78 | 79 | 80 | class ImageNorm(nn.Module): 81 | def __init__(self, mean, std): 82 | super(ImageNorm, self).__init__() 83 | self.mean = torch.from_numpy(np.array(mean)).view(1,3,1,1).cuda().float() 84 | self.std = torch.from_numpy(np.array(std)).view(1,3,1,1).cuda().float() 85 | 86 | def forward(self, input): 87 | input = torch.clamp(input, 0, 1) 88 | return (input - self.mean)/self.std 89 | -------------------------------------------------------------------------------- /mnist.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | 
"import os\n", 10 | "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", 11 | "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", 12 | "\n", 13 | "import torch.nn as nn\n", 14 | "import math\n", 15 | "import copy\n", 16 | "import random\n", 17 | "import numpy as np\n", 18 | "import torch\n", 19 | "import torch.optim as optim\n", 20 | "import torch.nn.functional as F\n", 21 | "import torch.nn as nn\n", 22 | "from torch.nn.parameter import Parameter\n", 23 | "from hybrid_network import *\n", 24 | "from hybrid_util import *\n", 25 | "learning_rate = 0.001" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": null, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "trainloader, testloader = get_mnist() \n", 35 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 36 | "print(device)\n", 37 | "net = MNIST_Large_ConvNet(\n", 38 | " non_negative = [False, False, False, False, False, False, False], \n", 39 | " norm = [False, False, False, False, False, False, False])\n", 40 | "net = net.to(device)" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "criterion = nn.CrossEntropyLoss()\n", 50 | "optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), learning_rate)\n", 51 | "scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15000,25000], gamma=0.1)\n", 52 | "eps = 2/255 * 1.1\n", 53 | "running_eps = 0\n", 54 | "epoch = 0\n", 55 | "itr = 0\n", 56 | "k = 0\n", 57 | "while itr < 60000:\n", 58 | " running_loss = 0\n", 59 | " for i, data in enumerate(trainloader, 0):\n", 60 | " net.train() \n", 61 | " inputs, labels = data\n", 62 | " inputs, labels = inputs.to(device), labels.to(device)\n", 63 | " loss = 0\n", 64 | " optimizer.zero_grad()\n", 65 | " \n", 66 | " outputs = net(torch.cat([inputs, inputs], 0))\n", 67 | " outputs = outputs[:outputs.shape[0]//2]\n", 68 | " loss += (1 - k) * 
criterion(outputs, labels)\n", 69 | " \n", 70 | " if itr > 2000 and itr < 12000:\n", 71 | " running_eps += eps/10000\n", 72 | " k += 0.5/10000\n", 73 | " \n", 74 | " if itr > 2000:\n", 75 | " x_ub = inputs + running_eps\n", 76 | " x_lb = inputs - running_eps\n", 77 | " \n", 78 | " outputs = net.forward(torch.cat([x_ub, x_lb], 0))\n", 79 | " z_hb = outputs[:outputs.shape[0]//2]\n", 80 | " z_lb = outputs[outputs.shape[0]//2:]\n", 81 | " lb_mask = torch.eye(10).cuda()[labels]\n", 82 | " hb_mask = 1 - lb_mask\n", 83 | " outputs = z_lb * lb_mask + z_hb * hb_mask\n", 84 | " loss += k * criterion(outputs, labels)\n", 85 | "\n", 86 | " loss.backward()\n", 87 | " optimizer.step()\n", 88 | " itr+=1\n", 89 | " running_loss += loss.item()\n", 90 | " print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 600))\n", 91 | " net.eval()\n", 92 | " print_accuracy(net, trainloader, testloader, device, test=True, eps = 0)\n", 93 | " print_accuracy(net, trainloader, testloader, device, test=True, eps = running_eps)\n", 94 | " if itr > 25000:\n", 95 | " print_accuracy(net, trainloader, testloader, device, test=True, eps = 2/255)\n", 96 | " epoch+= 1" 97 | ] 98 | } 99 | ], 100 | "metadata": { 101 | "kernelspec": { 102 | "display_name": "Python 3", 103 | "language": "python", 104 | "name": "python3" 105 | }, 106 | "language_info": { 107 | "codemirror_mode": { 108 | "name": "ipython", 109 | "version": 3 110 | }, 111 | "file_extension": ".py", 112 | "mimetype": "text/x-python", 113 | "name": "python", 114 | "nbconvert_exporter": "python", 115 | "pygments_lexer": "ipython3", 116 | "version": "3.6.5" 117 | } 118 | }, 119 | "nbformat": 4, 120 | "nbformat_minor": 2 121 | } 122 | -------------------------------------------------------------------------------- /cifar_convex_polygon.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 
| "source": [
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "import torch.nn as nn\n",
    "import math\n",
    "import copy\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "from torch.nn.parameter import Parameter\n",
    "from network import *\n",
    "from util import *\n",
    "learning_rate = 0.001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainloader, testloader = get_cifar() \n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "net = Cifar_Small_ConvNet(\n",
    "    non_negative = [False, False, False, False],\n",
    "    norm = [False, False, False, False]\n",
    ")\n",
    "net = net.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), learning_rate)\n",
    "scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[200000,250000, 300000], gamma=0.1)\n",
    "eps = 2/255 * 1.1\n",
    "running_eps = 0\n",
    "epoch = 0\n",
    "itr = 0\n",
    "while itr < 350000:\n",
    "    running_loss = 0\n",
    "    for i, data in enumerate(trainloader, 0):\n",
    "        net.train()\n",
    "\n",
    "        inputs, labels = data\n",
    "        inputs, labels = inputs.to(device), labels.to(device)\n",
    "        loss = 0\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        outputs = net(torch.cat([inputs, inputs], 0))\n",
    "        outputs = outputs[:outputs.shape[0]//2]\n",
    "        loss += criterion(outputs, labels)\n",
    "\n",
    "        if itr > 10000 and itr < 160000:\n",
    "            running_eps += eps/150000\n",
    "\n",
    "        if itr > 10000:\n",
    "            x_ub = inputs + running_eps\n",
    "            x_lb = inputs - running_eps\n",
    "            outputs = net.forward_g(torch.cat([x_ub, x_lb], 0))\n",
    "            v_hb = outputs[:outputs.shape[0]//2]\n",
    "            v_lb = outputs[outputs.shape[0]//2:]\n",
    "            # bound the margins z_j - z_y of the last (linear) layer directly\n",
    "            weight = net.score_function.weight\n",
    "            bias = net.score_function.bias\n",
    "            w = weight.unsqueeze(0).expand(v_hb.shape[0],-1,-1) - weight[labels].unsqueeze(1)\n",
    "            b = bias.unsqueeze(0).expand(v_hb.shape[0],-1) - bias[labels].unsqueeze(-1)\n",
    "            u = ((v_hb + v_lb)/2).unsqueeze(1)\n",
    "            r = ((v_hb - v_lb)/2).unsqueeze(1)\n",
    "            w = torch.transpose(w,1,2)\n",
    "            out_u = (u@w).squeeze(1) + b\n",
    "            out_r = (r@torch.abs(w)).squeeze(1)\n",
    "            # softplus(x) == log(1 + exp(x)) without overflow for large margins\n",
    "            loss += torch.mean(F.softplus(out_u + out_r))\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        scheduler.step()  # milestones are counted in iterations\n",
    "        itr+=1\n",
    "        running_loss += loss.item()\n",
    "    print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / (i + 1)))\n",
    "    net.eval()\n",
    "    print_accuracy(net, trainloader, testloader, device, test=True, eps = 0)\n",
    "    if itr > 250000:\n",
    "        print(\"verified test acc:\", verify_robustness(net, testloader, device, eps = 2/255))\n",
    "    epoch+= 1"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
--------------------------------------------------------------------------------
/mnist_convex_polygon.ipynb: --------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "\n",
    "import torch.nn as nn\n",
    "import math\n",
    "import copy\n",
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import torch.nn as nn\n",
    "from torch.nn.parameter import Parameter\n",
    "from network import *\n",
    "from util import *\n",
    "learning_rate = 0.001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "trainloader, testloader = get_mnist() \n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "net = MNIST_Medium_ConvNet(\n",
    "    non_negative = [False, False, False, False, False, False, False],\n",
    "    norm = [False, False, False, False, False, False, False]\n",
    ")\n",
    "net = net.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), learning_rate)\n",
    "scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15000,25000], gamma=0.1)\n",
    "eps = 2/255 * 1.1\n",
    "running_eps = 0\n",
    "epoch = 0\n",
    "itr = 0\n",
    "while itr < 60000:\n",
    "    running_loss = 0\n",
    "    for i, data in enumerate(trainloader, 0):\n",
    "        net.train()\n",
    "\n",
    "        inputs, labels = data\n",
    "        inputs, labels = inputs.to(device), labels.to(device)\n",
    "        loss = 0\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        outputs = net(torch.cat([inputs, inputs], 0))\n",
    "        outputs = outputs[:outputs.shape[0]//2]\n",
    "        loss += criterion(outputs, labels)\n",
    "\n",
    "        if itr > 2000 and itr < 12000:\n",
    "            running_eps += eps/10000\n",
    "\n",
    "        if itr > 2000:\n",
    "            x_ub = inputs + running_eps\n",
    "            x_lb = inputs - running_eps\n",
    "            outputs = net.forward_g(torch.cat([x_ub, x_lb], 0))\n",
    "            v_hb = outputs[:outputs.shape[0]//2]\n",
    "            v_lb = outputs[outputs.shape[0]//2:]\n",
    "            # bound the margins z_j - z_y of the last (linear) layer directly\n",
    "            weight = net.score_function.weight\n",
    "            bias = net.score_function.bias\n",
    "            w = weight.unsqueeze(0).expand(v_hb.shape[0],-1,-1) - weight[labels].unsqueeze(1)\n",
    "            b = bias.unsqueeze(0).expand(v_hb.shape[0],-1) - bias[labels].unsqueeze(-1)\n",
    "            u = ((v_hb + v_lb)/2).unsqueeze(1)\n",
    "            r = ((v_hb - v_lb)/2).unsqueeze(1)\n",
    "            w = torch.transpose(w,1,2)\n",
    "            out_u = (u@w).squeeze(1) + b\n",
    "            out_r = (r@torch.abs(w)).squeeze(1)\n",
    "            # softplus(x) == log(1 + exp(x)) without overflow for large margins\n",
    "            loss += torch.mean(F.softplus(out_u + out_r))\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        scheduler.step()  # milestones are counted in iterations\n",
    "        itr+=1\n",
    "        running_loss += loss.item()\n",
    "    print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / (i + 1)))\n",
    "    net.eval()\n",
    "    print_accuracy(net, trainloader, testloader, device, test=True, eps = 0)\n",
    "    # training stops at 60000 iterations, so the old 250000 threshold was never reached\n",
    "    if itr > 25000:\n",
    "        print(\"verified test acc:\", verify_robustness(net, testloader, device, eps = 2/255))\n",
    "    epoch+= 1"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version":
3 122 | }, 123 | "file_extension": ".py", 124 | "mimetype": "text/x-python", 125 | "name": "python", 126 | "nbconvert_exporter": "python", 127 | "pygments_lexer": "ipython3", 128 | "version": "3.6.5" 129 | } 130 | }, 131 | "nbformat": 4, 132 | "nbformat_minor": 2 133 | } 134 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | import numpy as np 6 | import torchvision.transforms as transforms 7 | 8 | def get_mnist(batch_size = 100): 9 | transform_train = transforms.Compose([transforms.ToTensor()]) 10 | transform_test = transforms.Compose([transforms.ToTensor()]) 11 | trainset = torchvision.datasets.MNIST(root='./mnist', train=True, download=True, transform=transform_train) 12 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True) 13 | testset = torchvision.datasets.MNIST(root='./mnist', train=False, download=True, transform=transform_test) 14 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False) 15 | return trainloader,testloader 16 | 17 | def get_cifar(batch_size = 200): 18 | transform = transforms.Compose( 19 | [ 20 | transforms.RandomCrop(32, padding=4), 21 | transforms.RandomHorizontalFlip(), 22 | transforms.ToTensor()]) 23 | 24 | trainset = torchvision.datasets.CIFAR10(root='./cifar10', train=True, 25 | download=True, transform=transform) 26 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, 27 | shuffle=True, num_workers=2) 28 | testset = torchvision.datasets.CIFAR10(root='./cifar10', train=False, 29 | download=True, transform=transform) 30 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, 31 | shuffle=False, num_workers=2) 32 | 33 | return trainloader, testloader 34 | 35 | 36 | 37 | def print_accuracy(net, 
trainloader, testloader, device, test=True, eps = 0): 38 | loader = 0 39 | loadertype = '' 40 | if test: 41 | loader = testloader 42 | loadertype = 'test' 43 | else: 44 | loader = trainloader 45 | loadertype = 'train' 46 | correct = 0 47 | total = 0 48 | with torch.no_grad(): 49 | for ii, data in enumerate(loader, 0): 50 | images, labels = data 51 | images, labels = images.to(device), labels.to(device) 52 | x_ub = images + eps 53 | x_lb = images - eps 54 | 55 | outputs = net(torch.cat([x_ub,x_lb], 0)) 56 | z_hb = outputs[:outputs.shape[0]//2] 57 | z_lb = outputs[outputs.shape[0]//2:] 58 | lb_mask = torch.eye(10).cuda()[labels] 59 | hb_mask = 1 - lb_mask 60 | outputs = z_lb * lb_mask + z_hb * hb_mask 61 | 62 | _, predicted = torch.max(outputs.data, 1) 63 | total += labels.size(0) 64 | correct += (predicted == labels).sum().item() 65 | correct = correct / total 66 | print('Accuracy of the network on the', total, loadertype, 'images: ',correct) 67 | return correct 68 | 69 | 70 | def verify(net, image, label, device, eps): 71 | x_ub = image + eps 72 | x_lb = image - eps 73 | outputs = net.forward_g(torch.cat([x_ub, x_lb], 0)) 74 | v_u = outputs[:outputs.shape[0]//2] 75 | v_l = outputs[outputs.shape[0]//2:] 76 | weight = net.score_function.weight 77 | bias = net.score_function.bias 78 | label = label.item() 79 | for i in range(weight.shape[0]): 80 | new_w = weight[i:i+1] - weight[label:label+1] 81 | u = (v_u + v_l)/2 82 | r = (v_u - v_l)/2 83 | if (torch.dot(new_w[0],u[0]) + torch.dot(torch.abs(new_w[0]),r[0]) + bias[i] - bias[label]).item() > 0: 84 | return False 85 | return True 86 | 87 | def verify_robustness(net, dataloader, device, eps = 0.1): 88 | net.train() 89 | total = 0 90 | correct = 0 91 | for ii, data in enumerate(dataloader, 0): 92 | images, labels = data 93 | images, labels = images.to(device), labels.to(device) 94 | total += labels.size(0) 95 | for idx in range(labels.size(0)): 96 | image = images[idx:idx + 1] 97 | label = labels[idx:idx + 1] 98 | if 
verify(net, image, label, device, eps): 99 | correct+=1 100 | correct = correct / total 101 | return correct 102 | 103 | from torch.autograd import Variable 104 | def pgd(net, image, label, device, step = 0.01, iterations = 40, eps = 0.1): 105 | criterion = nn.CrossEntropyLoss() 106 | inputs = image 107 | for _ in range(iterations): 108 | net.zero_grad() 109 | inputs = Variable(inputs,requires_grad = True) 110 | outputs = net(torch.cat([inputs,inputs], 0)) 111 | z_hb = outputs[:outputs.shape[0]//2] 112 | z_lb = outputs[outputs.shape[0]//2:] 113 | lb_mask = torch.eye(10).cuda()[label] 114 | hb_mask = 1 - lb_mask 115 | outputs = z_lb * lb_mask + z_hb * hb_mask 116 | loss = criterion(outputs, label) 117 | loss.backward() 118 | inputs = torch.clamp(inputs + step * inputs.grad - image, -eps, eps) + image 119 | return torch.clamp(inputs.data, 0, 1) 120 | 121 | def verify_robustness_pgd(net, dataloader, device, eps = 0.1, iterations = 10): 122 | net.train() 123 | total = 0 124 | correct = 0 125 | for ii, data in enumerate(dataloader, 0): 126 | images, labels = data 127 | images, labels = images.to(device), labels.to(device) 128 | total += labels.size(0) 129 | for idx in range(labels.size(0)): 130 | image = images[idx:idx + 1] 131 | label = labels[idx:idx + 1] 132 | image = pgd(net, image, label, device, step = 0.01, iterations = iterations, eps = eps) 133 | if verify(net, image, label, device, eps = 0): 134 | correct+=1 135 | print(correct/total) 136 | correct = correct / total 137 | return correct -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import copy 5 | import random 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | from torch.nn.parameter import Parameter 10 | from module import * 11 | import torch.nn.utils.spectral_norm as 
SpectralNorm 12 | 13 | class MNIST_Small_ConvNet(nn.Module): 14 | def __init__(self, 15 | num_classes=10, 16 | non_negative = [True, True, True, True], 17 | norm = [False, False, False, False]): 18 | 19 | super(MNIST_Small_ConvNet, self).__init__() 20 | self.conv1 = RobustConv2d(1,16,4,2, padding = 1, non_negative =non_negative[0]) 21 | if norm[0]: 22 | self.conv1 = SpectralNorm(self.conv1) 23 | self.conv2 = RobustConv2d(16,32,4,1, padding= 1, non_negative =non_negative[1]) 24 | if norm[1]: 25 | self.conv2 = SpectralNorm(self.conv2) 26 | self.fc1 = RobustLinear(13*13*32, 100, non_negative =non_negative[2]) 27 | if norm[2]: 28 | self.fc1 = SpectralNorm(self.fc1) 29 | self.fc2 = RobustLinear(100,10, non_negative =non_negative[3]) 30 | if norm[3]: 31 | self.fc2 = SpectralNorm(self.fc2) 32 | 33 | self.activation = F.relu 34 | self.score_function = self.fc2 35 | 36 | def forward_g(self,x): 37 | x = self.conv1(x) 38 | x = self.activation(x) 39 | x = self.conv2(x) 40 | x = self.activation(x) 41 | x = self.fc1(x.view(x.shape[0], -1)) 42 | x = self.activation(x) 43 | return x 44 | 45 | def forward(self, x): 46 | x = self.score_function(self.forward_g(x)) 47 | return x 48 | 49 | 50 | 51 | 52 | 53 | class MNIST_Medium_ConvNet(nn.Module): 54 | def __init__(self, 55 | num_classes=10, 56 | non_negative = [True, True, True, True, True, True, True], 57 | norm = [False, False, False, False, False, False, False]): 58 | 59 | super(MNIST_Medium_ConvNet, self).__init__() 60 | self.conv1 = RobustConv2d(1,32,3, stride = 1, padding = 1, non_negative = non_negative[0]) 61 | if norm[0]: 62 | self.conv1 = SpectralNorm(self.conv1) 63 | self.conv2 = RobustConv2d(32,32,4, stride = 2, padding= 1, non_negative = non_negative[1]) 64 | if norm[1]: 65 | self.conv2 = SpectralNorm(self.conv2) 66 | self.conv3 = RobustConv2d(32,64,3, stride = 1, padding = 1, non_negative = non_negative[2]) 67 | if norm[2]: 68 | self.conv3 = SpectralNorm(self.conv3) 69 | self.conv4 = RobustConv2d(64,64,4, stride = 2, 
padding= 1, non_negative = non_negative[3]) 70 | if norm[3]: 71 | self.conv4 = SpectralNorm(self.conv4) 72 | 73 | self.fc1 = RobustLinear(7*7*64, 512, non_negative = non_negative[4]) 74 | if norm[4]: 75 | self.fc1 = SpectralNorm(self.fc1) 76 | self.fc2 = RobustLinear(512, 512, non_negative = non_negative[5]) 77 | if norm[5]: 78 | self.fc2 = SpectralNorm(self.fc2) 79 | self.fc3 = RobustLinear(512,10, non_negative = non_negative[6]) 80 | if norm[6]: 81 | self.fc3 = SpectralNorm(self.fc3) 82 | 83 | self.activation = F.relu 84 | self.score_function = self.fc3 85 | 86 | def forward_g(self, x): 87 | x = self.conv1(x) 88 | x = self.activation(x) 89 | x = self.conv2(x) 90 | x = self.activation(x) 91 | x = self.conv3(x) 92 | x = self.activation(x) 93 | x = self.conv4(x) 94 | x = self.activation(x) 95 | 96 | x = self.fc1(x.view(x.shape[0], -1)) 97 | x = self.activation(x) 98 | x = self.fc2(x) 99 | x = self.activation(x) 100 | return x 101 | 102 | def forward(self, x): 103 | x = self.score_function(self.forward_g(x)) 104 | return x 105 | 106 | 107 | 108 | class MNIST_Large_ConvNet(nn.Module): 109 | def __init__(self, 110 | num_classes=10, 111 | non_negative = [True, True, True, True, True, True, True], 112 | norm = [False, False, False, False, False, False, False]): 113 | 114 | super(MNIST_Large_ConvNet, self).__init__() 115 | self.conv1 = RobustConv2d(1,64,3, stride = 1, padding = 1, non_negative = non_negative[0]) 116 | if norm[0]: 117 | self.conv1 = SpectralNorm(self.conv1) 118 | self.conv2 = RobustConv2d(64,64,3, stride = 1, padding = 1, non_negative = non_negative[1]) 119 | if norm[1]: 120 | self.conv2 = SpectralNorm(self.conv2) 121 | self.conv3 = RobustConv2d(64,128,3, stride = 2, padding= 1, non_negative = non_negative[2]) 122 | if norm[2]: 123 | self.conv3 = SpectralNorm(self.conv3) 124 | self.conv4 = RobustConv2d(128,128,3, stride = 1, padding = 1, non_negative = non_negative[3]) 125 | if norm[3]: 126 | self.conv4 = SpectralNorm(self.conv4) 127 | self.conv5 = 
RobustConv2d(128,128,3, stride = 1, padding= 1, non_negative = non_negative[4]) 128 | if norm[4]: 129 | self.conv5 = SpectralNorm(self.conv5) 130 | 131 | self.fc1 = RobustLinear(14*14*128, 200, non_negative = non_negative[5]) 132 | if norm[5]: 133 | self.fc1 = SpectralNorm(self.fc1) 134 | self.fc2 = RobustLinear(200, 10, non_negative = non_negative[6]) 135 | if norm[6]: 136 | self.fc2 = SpectralNorm(self.fc2) 137 | 138 | self.activation = F.relu 139 | self.score_function = self.fc2 140 | 141 | def forward_g(self, x): 142 | x = self.conv1(x) 143 | x = self.activation(x) 144 | x = self.conv2(x) 145 | x = self.activation(x) 146 | x = self.conv3(x) 147 | x = self.activation(x) 148 | x = self.conv4(x) 149 | x = self.activation(x) 150 | x = self.conv5(x) 151 | x = self.activation(x) 152 | 153 | x = self.fc1(x.view(x.shape[0], -1)) 154 | x = self.activation(x) 155 | return x 156 | 157 | def forward(self, x): 158 | x = self.score_function(self.forward_g(x)) 159 | return x 160 | 161 | class Cifar_Small_ConvNet(nn.Module): 162 | def __init__(self, 163 | num_classes=10, 164 | non_negative = [True, True, True, True], 165 | norm = [False, False, False, False]): 166 | 167 | super(Cifar_Small_ConvNet, self).__init__() 168 | self.conv1 = RobustConv2d(3,16,4, stride = 2, padding = 0, non_negative = non_negative[0]) 169 | if norm[0]: 170 | self.conv1 = SpectralNorm(self.conv1) 171 | self.conv2 = RobustConv2d(16,32,4, stride = 1, padding= 0, non_negative = non_negative[1]) 172 | if norm[1]: 173 | self.conv2 = SpectralNorm(self.conv2) 174 | self.fc1 = RobustLinear(12*12*32, 100, non_negative = non_negative[2]) 175 | if norm[2]: 176 | self.fc1 = SpectralNorm(self.fc1) 177 | self.fc2 = RobustLinear(100,10, non_negative = non_negative[3]) 178 | if norm[3]: 179 | self.fc2 = SpectralNorm(self.fc2) 180 | 181 | self.deconv1 = nn.ConvTranspose2d(32,16,4, padding = 0, stride = 1) 182 | self.deconv2 = nn.ConvTranspose2d(16,3,4, padding = 0, stride = 2) 183 | 184 | self.activation = F.relu 185 | 
self.score_function = self.fc2 186 | 187 | self.image_norm = ImageNorm([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010]) 188 | 189 | def forward_conv(self,x): 190 | x = self.image_norm(x) 191 | x = self.conv1(x) 192 | x = self.activation(x) 193 | x = self.conv2(x) 194 | x = self.activation(x) 195 | return x 196 | 197 | def forward_g(self, x): 198 | x = self.forward_conv(x) 199 | x = self.fc1(x.view(x.shape[0], -1)) 200 | x = self.activation(x) 201 | return x 202 | 203 | def forward(self, x): 204 | x = self.score_function(self.forward_g(x)) 205 | return x 206 | 207 | 208 | 209 | 210 | class Cifar_Medium_ConvNet(nn.Module): 211 | def __init__(self, 212 | num_classes=10, 213 | non_negative = [True, True, True, True, True, True, True], 214 | norm = [False, False, False, False, False, False, False]): 215 | 216 | super(Cifar_Medium_ConvNet, self).__init__() 217 | self.conv1 = RobustConv2d(3,32,3, stride = 1, padding = 1, non_negative = non_negative[0]) 218 | if norm[0]: 219 | self.conv1 = SpectralNorm(self.conv1) 220 | self.conv2 = RobustConv2d(32,32,4, stride = 2, padding= 1, non_negative = non_negative[1]) 221 | if norm[1]: 222 | self.conv2 = SpectralNorm(self.conv2) 223 | self.conv3 = RobustConv2d(32,64,3, stride = 1, padding = 1, non_negative = non_negative[2]) 224 | if norm[2]: 225 | self.conv3 = SpectralNorm(self.conv3) 226 | self.conv4 = RobustConv2d(64,64,4, stride = 2, padding= 1, non_negative = non_negative[3]) 227 | if norm[3]: 228 | self.conv4 = SpectralNorm(self.conv4) 229 | 230 | self.fc1 = RobustLinear(8*8*64, 512, non_negative = non_negative[4]) 231 | if norm[4]: 232 | self.fc1 = SpectralNorm(self.fc1) 233 | self.fc2 = RobustLinear(512, 512, non_negative = non_negative[5]) 234 | if norm[5]: 235 | self.fc2 = SpectralNorm(self.fc2) 236 | self.fc3 = RobustLinear(512,10, non_negative = non_negative[6]) 237 | if norm[6]: 238 | self.fc3 = SpectralNorm(self.fc3) 239 | self.activation = F.relu 240 | self.score_function = self.fc3 241 | self.image_norm = 
ImageNorm([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010]) 242 | 243 | def forward_g(self, x): 244 | x = self.image_norm(x) 245 | x = self.conv1(x) 246 | x = self.activation(x) 247 | x = self.conv2(x) 248 | x = self.activation(x) 249 | x = self.conv3(x) 250 | x = self.activation(x) 251 | x = self.conv4(x) 252 | x = self.activation(x) 253 | 254 | x = self.fc1(x.view(x.shape[0], -1)) 255 | x = self.activation(x) 256 | x = self.fc2(x) 257 | x = self.activation(x) 258 | return x 259 | 260 | def forward(self, x): 261 | x = self.score_function(self.forward_g(x)) 262 | return x 263 | 264 | class Cifar_Large_ConvNet(nn.Module): 265 | def __init__(self, 266 | num_classes=10, 267 | non_negative = [True, True, True, True, True, True, True], 268 | norm = [False, False, False, False, False, False, False]): 269 | 270 | super(Cifar_Large_ConvNet, self).__init__() 271 | self.conv1 = RobustConv2d(3,64,3, stride = 1, padding = 1, non_negative = non_negative[0]) 272 | if norm[0]: 273 | self.conv1 = SpectralNorm(self.conv1) 274 | 275 | self.conv2 = RobustConv2d(64,64,3, stride = 1, padding = 1, non_negative = non_negative[1]) 276 | if norm[1]: 277 | self.conv2 = SpectralNorm(self.conv2) 278 | 279 | self.conv3 = RobustConv2d(64,128,3, stride = 2, padding= 1, non_negative = non_negative[2]) 280 | if norm[2]: 281 | self.conv3 = SpectralNorm(self.conv3) 282 | 283 | self.conv4 = RobustConv2d(128,128,3, stride = 1, padding = 1, non_negative = non_negative[3]) 284 | if norm[3]: 285 | self.conv4 = SpectralNorm(self.conv4) 286 | 287 | self.conv5 = RobustConv2d(128,128,3, stride = 1, padding= 1, non_negative = non_negative[4]) 288 | if norm[4]: 289 | self.conv5 = SpectralNorm(self.conv5) 290 | 291 | self.fc1 = RobustLinear(16*16*128, 200, non_negative = non_negative[5]) 292 | if norm[5]: 293 | self.fc1 = SpectralNorm(self.fc1) 294 | 295 | self.fc2 = RobustLinear(200,10, non_negative = non_negative[6]) 296 | if norm[6]: 297 | self.fc2 = SpectralNorm(self.fc2) 298 | 299 | self.deconv1 = 
nn.ConvTranspose2d(128,128,3, padding = 1, stride = 1) 300 | self.deconv2 = nn.ConvTranspose2d(128,128,3, padding = 1, stride = 1) 301 | self.deconv3 = nn.ConvTranspose2d(128,64,3, padding = 1, stride = 2, output_padding = 1) 302 | self.deconv4 = nn.ConvTranspose2d(64,64,3, padding = 1, stride = 1) 303 | self.deconv5 = nn.ConvTranspose2d(64,3,3, padding = 1, stride = 1) 304 | 305 | self.activation = F.leaky_relu 306 | self.score_function = self.fc2 307 | self.image_norm = ImageNorm([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010]) 308 | 309 | def forward_g(self, x): 310 | x = self.image_norm(x) 311 | x = self.conv1(x) 312 | x = self.activation(x) 313 | x = self.conv2(x) 314 | x = self.activation(x) 315 | x = self.conv3(x) 316 | x = self.activation(x) 317 | x = self.conv4(x) 318 | x = self.activation(x) 319 | x = self.conv5(x) 320 | x = self.activation(x) 321 | x = self.forward_conv(x) 322 | x = self.fc1(x.view(x.shape[0], -1)) 323 | x = self.activation(x) 324 | return x 325 | 326 | def forward(self, x): 327 | x = self.score_function(self.forward_g(x)) 328 | return x 329 | 330 | 331 | class Cifar_VGG(nn.Module): 332 | def __init__(self, 333 | num_classes=10, 334 | non_negative = [True, True, True, True, True, True], 335 | norm = [False, False, False, False, False, False]): 336 | 337 | super(Cifar_VGG, self).__init__() 338 | self.conv1 = RobustConv2d(3,64,3, stride = 1, padding = 1, non_negative = non_negative[0]) 339 | if norm[0]: 340 | self.conv1 = SpectralNorm(self.conv1) 341 | 342 | self.conv2 = RobustConv2d(64,64,3, stride = 1, padding = 1, non_negative = non_negative[1]) 343 | if norm[1]: 344 | self.conv2 = SpectralNorm(self.conv2) 345 | 346 | self.conv3 = RobustConv2d(64,128,3, stride = 2, padding= 1, non_negative = non_negative[2]) 347 | if norm[2]: 348 | self.conv3 = SpectralNorm(self.conv3) 349 | 350 | self.conv4 = RobustConv2d(128,128,3, stride = 1, padding = 1, non_negative = non_negative[3]) 351 | if norm[3]: 352 | self.conv4 = SpectralNorm(self.conv4) 
353 | 354 | self.fc1 = RobustLinear(16*16*128, 512, non_negative = non_negative[4]) 355 | if norm[4]: 356 | self.fc1 = SpectralNorm(self.fc1) 357 | 358 | self.fc2 = RobustLinear(512,10, non_negative = non_negative[5]) 359 | if norm[5]: 360 | self.fc2 = SpectralNorm(self.fc2) 361 | 362 | self.activation = F.relu 363 | self.score_function = self.fc2 364 | self.image_norm = ImageNorm([0.4914, 0.4822, 0.4465],[0.2023, 0.1994, 0.2010]) 365 | 366 | def forward_g(self, x): 367 | x = self.image_norm(x) 368 | x = self.conv1(x) 369 | x = self.activation(x) 370 | x = self.conv2(x) 371 | x = self.activation(x) 372 | x = self.conv3(x) 373 | x = self.activation(x) 374 | x = self.conv4(x) 375 | x = self.activation(x) 376 | 377 | x = self.fc1(x.view(x.shape[0], -1)) 378 | x = self.activation(x) 379 | return x 380 | 381 | def forward(self, x): 382 | x = self.score_function(self.forward_g(x)) 383 | return x 384 | 385 | 386 | 387 | --------------------------------------------------------------------------------