# %% [markdown]
# # One Pixel Attack for Fooling Deep Neural Networks
# An implementation of the procedure described in https://arxiv.org/abs/1710.08864.

# %% [markdown]
# ## Setup

# %%
# %matplotlib inline
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tensorboardX import SummaryWriter
from torchvision import datasets, transforms, models
from tqdm import tqdm

# Seed every RNG the notebook draws from: torch (classifier-head init, shuffling,
# dropout), Python `random` (older torchvision augmentation) and numpy (the
# differential-evolution attack). Restart & Run All then reproduces the same
# model and attack results. cuDNN kernels may still introduce small run-to-run
# differences on GPU.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

writer = SummaryWriter()
sns.set()
sns.set_style("dark")
sns.set_palette("muted")
sns.set_color_codes("muted")

# %% [markdown]
# ### CUDA

# %%
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pinned memory only helps host-to-GPU copies, so enable it only when CUDA is present.
LOADER_KWARGS = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {'num_workers': 4}
print("CUDA Available:", torch.cuda.is_available())
# %% [markdown]
# ## Train CIFAR VGG16 Model

# %% [markdown]
# ### Model Definition

# %%
# ImageNet-pretrained VGG16 convolutional trunk with a smaller classifier head for
# the 10 CIFAR-10 classes. The whole network is fine-tuned on raw [0, 1] inputs
# (no ImageNet normalisation is applied in the transforms below).
cifar_model = models.vgg16(pretrained=True, init_weights=False)
# VGG16's five stride-2 max-pools shrink a 32x32 input to a 512 x 1 x 1 feature
# map, hence nn.Linear(512, ...) below. Newer torchvision releases insert
# AdaptiveAvgPool2d((7, 7)) before the classifier, which would expand that map to
# 512 * 7 * 7 = 25088 features and break the head. Pooling to 1x1 is a no-op here,
# holds no parameters (saved state dicts still load), and is ignored by old
# torchvision versions whose forward() has no avgpool step.
cifar_model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
cifar_model.classifier = nn.Sequential(
    nn.Linear(512, 2048),
    nn.ReLU(True),
    nn.Dropout(),
    nn.Linear(2048, 2048),
    nn.ReLU(True),
    nn.Dropout(),
    nn.Linear(2048, 10),
)
cifar_model = cifar_model.to(DEVICE)

# %% [markdown]
# ### Dataloading

# %%
BATCH_SIZE = 128
TRAIN_COUNT = 40_000
VAL_COUNT = 10_000
TEST_COUNT = 10_000

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

test_transform = transforms.ToTensor()

# Training split: the first TRAIN_COUNT images of the CIFAR-10 train set, augmented.
train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
train_set = torch.utils.data.dataset.Subset(train_set, range(0, TRAIN_COUNT))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, **LOADER_KWARGS)

# Validation split: the next VAL_COUNT train-set images, without augmentation.
val_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=test_transform)
val_set = torch.utils.data.dataset.Subset(val_set, range(TRAIN_COUNT, TRAIN_COUNT + VAL_COUNT))
val_loader = torch.utils.data.DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, **LOADER_KWARGS)

test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, **LOADER_KWARGS)

# %% [markdown]
# ### Test and Validation Function

# %%
def test(epoch=None, is_validation=False):
    """Evaluate ``cifar_model`` on the validation or the test split.

    Args:
        epoch: step value for the TensorBoard scalars (validation only).
        is_validation: if True, use ``val_loader`` and log mean loss and accuracy
            to TensorBoard. Otherwise use ``test_loader`` and print the
            correct/total count.

    Leaves the model in eval mode.
    """
    cifar_model.eval()
    loader = val_loader if is_validation else test_loader
    total_loss = 0.0
    total_correct = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            outputs = cifar_model(inputs)
            # Mean batch loss * batch size == summed per-sample loss. This avoids the
            # deprecated `size_average=False` and works on old and new PyTorch alike.
            total_loss += F.cross_entropy(outputs, targets).item() * targets.size(0)
            total_correct += outputs.max(1)[1].eq(targets).sum().item()
    num_samples = len(loader.dataset)
    if is_validation:
        writer.add_scalar('logs/val_loss', total_loss / num_samples, epoch)
        writer.add_scalar('logs/val_acc', total_correct / num_samples, epoch)
    else:
        print("Test Accuracy: {}/{}".format(total_correct, num_samples))

# %% [markdown]
# ### Train Function

# %%
# Fine-tune the whole network (pretrained trunk + new head). To train only the head,
# pass `cifar_model.classifier.parameters()` instead.
optimizer = optim.Adam(cifar_model.parameters())


def train(epoch):
    """Run one training epoch over ``train_loader``, then validate.

    Per-batch loss and accuracy are logged to TensorBoard at global step
    ``epoch * len(train_loader) + batch_idx``. Ends with
    ``test(epoch, is_validation=True)``.
    """
    cifar_model.train()
    steps_per_epoch = len(train_loader)
    for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader)):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        optimizer.zero_grad()
        outputs = cifar_model(inputs)
        loss = F.cross_entropy(outputs, targets)
        loss.backward()
        optimizer.step()
        batch_correct = outputs.max(1)[1].eq(targets).sum().item()
        step = epoch * steps_per_epoch + batch_idx
        writer.add_scalar('logs/train_loss', loss.item(), step)
        writer.add_scalar('logs/train_acc', batch_correct / targets.size(0), step)
    test(epoch, is_validation=True)

# %% [markdown]
# ### Train Model and Store Weights (or Load Weights)

# %%
TRAIN_EPOCHS = 20
WEIGHTS_PATH = Path("./vgg_cifar_weights.pt")

if WEIGHTS_PATH.is_file():
    # Deserialize onto the CPU so weights saved from a GPU run also load on a
    # CPU-only machine. load_state_dict then copies them into the model's
    # parameters, which already live on DEVICE.
    cifar_model.load_state_dict(torch.load(WEIGHTS_PATH, map_location='cpu'))
    print("Loaded weights from file:", WEIGHTS_PATH)
else:
    for epoch in range(TRAIN_EPOCHS):
        train(epoch)
    torch.save(cifar_model.state_dict(), WEIGHTS_PATH)

# %% [markdown]
# ### Test Model Accuracy
# %%
test()

# %% [markdown]
# ## Attack CIFAR Model

# %%
CIFAR_LABELS = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


def show(img):
    """Display a C x H x W image tensor (values in [0, 1]) with matplotlib."""
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')


def tell(img, label, model, target_label=None):
    """Print the model's prediction and class probabilities for one image.

    Args:
        img: C x H x W image tensor on the model's device.
        label: int index of the true class in ``CIFAR_LABELS``.
        model: classifier returning logits for a batch of images.
        target_label: optional int class index whose probability is also printed.

    Puts ``model`` in eval mode, so dropout cannot make the printed numbers
    inconsistent. Runs a single forward pass without autograd.
    """
    model.eval()
    with torch.no_grad():
        probs = F.softmax(model(img.unsqueeze(0)).squeeze(0), dim=0)
    pred = probs.max(0)[1].item()
    print("True Label:", CIFAR_LABELS[label], label)
    print("Prediction:", CIFAR_LABELS[pred], pred)
    print("Label Probabilities:", probs)
    print("True Label Probability:", probs[label].item())
    if target_label is not None:
        print("Target Label Probability:", probs[target_label].item())

# %% [markdown]
# ### Prediction
# %%
SAMPLE_INDEX = 500  # index into test_set of the image used by the single-image demos below
test_img, test_label = test_set[SAMPLE_INDEX]
test_img = test_img.to(DEVICE)
show(test_img)
tell(test_img, test_label, cifar_model)

# %% [markdown]
# ### Perturbation
# %%
def perturb(p, img):
    """Return a copy of ``img`` with a single pixel overwritten.

    Args:
        p: sequence ``[row, col, c_0, ..., c_{C-1}]`` (for RGB: ``[row, col, r, g, b]``)
            whose elements should lie in [0, 1]. ``row`` and ``col`` are fractions
            of the image height and width. The channel values are clipped to
            [0, 1] and written as-is.
        img: C x H x W image tensor. It is not modified.

    Returns:
        The perturbed copy, with the same dtype and device as ``img``.
    """
    num_channels, height, width = img.shape
    p = np.asarray(p, dtype=float)
    sizes = np.array([height, width])
    # Truncate the scaled fractions to pixel indices. A fraction of 1.0 lands one
    # past the last pixel, so clip back into range.
    row, col = np.clip((p[0:2] * sizes).astype(int), 0, sizes - 1)
    color = np.clip(p[2:2 + num_channels], 0, 1)
    p_img = img.clone()
    # Convert explicitly to the image's dtype/device instead of relying on an
    # implicit float64/CPU -> float32/GPU copy during assignment.
    p_img[:, int(row), int(col)] = torch.from_numpy(color).to(p_img)
    return p_img


def visualize_perturbation(p, img, label, model, target_label=None):
    """Show ``img`` perturbed by ``p`` and print the model's response to it."""
    p_img = perturb(p, img)
    print("Perturbation:", p)
    show(p_img)
    tell(p_img, label, model, target_label)


visualize_perturbation(np.array([0.6, 0.6, 0, 0, 0.75]), test_img, test_label, cifar_model)

# %% [markdown]
# ### Untargeted and Targeted Attacks
# %%
def evaluate(candidates, img, label, model, batch_size=128):
    """Score each candidate perturbation by the model's probability for ``label``.

    Args:
        candidates: array of shape (N, 5). Each row is a ``perturb`` vector.
        img: C x H x W image tensor on the model's device.
        label: int class index whose softmax probability is the score.
        model: classifier returning logits for a batch of images.
        batch_size: number of perturbed images per forward pass.

    Returns:
        1-D numpy array of N probabilities, in candidate order.

    Puts ``model`` in eval mode. Candidates are scored in batches rather than one
    forward pass per candidate, which is much faster for a population of hundreds.
    """
    model.eval()
    scores = []
    with torch.no_grad():
        for start in range(0, len(candidates), batch_size):
            batch = torch.stack([perturb(xs, img) for xs in candidates[start:start + batch_size]])
            probs = F.softmax(model(batch), dim=1)[:, label]
            scores.extend(probs.cpu().tolist())
    return np.array(scores)


def evolve(candidates, F=0.5, strategy="clip"):
    """Create one generation of DE/rand/1 mutants: ``x1 + F * (x2 - x3)``.

    For every candidate slot, three distinct population members are drawn at
    random (they may include the slot's own candidate).

    Args:
        candidates: array of shape (N, D), N >= 3, with values in [0, 1].
        F: differential weight. Inside this function the name shadows the
            module-level ``torch.nn.functional as F`` alias. It is kept for
            backward compatibility with keyword callers.
        strategy: how out-of-range mutant elements are handled. "clip" clamps
            them to [0, 1]. "resample" redraws them uniformly from [0, 1).

    Returns:
        New (N, D) array of mutants. ``candidates`` is not modified.

    Raises:
        ValueError: if ``strategy`` is not "clip" or "resample". Previously an
            unknown strategy silently returned unchanged copies.
    """
    if strategy not in ("clip", "resample"):
        raise ValueError("unknown strategy: {!r} (expected 'clip' or 'resample')".format(strategy))
    gen2 = candidates.copy()
    num_candidates = len(candidates)
    for i in range(num_candidates):
        x1, x2, x3 = candidates[np.random.choice(num_candidates, 3, replace=False)]
        x_next = x1 + F * (x2 - x3)
        if strategy == "clip":
            gen2[i] = np.clip(x_next, 0, 1)
        else:
            x_oob = np.logical_or(x_next < 0, x_next > 1)
            x_next[x_oob] = np.random.random(x_next.shape)[x_oob]
            gen2[i] = x_next
    return gen2


def attack(model, img, true_label, target_label=None, iters=100, pop_size=400, verbose=True):
    """One-pixel attack on a single image using differential evolution.

    Targeted (``target_label`` given): maximise the target-class probability and
    stop early once it exceeds 0.5.
    Untargeted: minimise the true-class probability and stop early once it drops
    below 0.05.

    Args:
        model: classifier returning logits for a batch of images.
        img: C x H x W image tensor on the model's device.
        true_label: int index of the correct class.
        target_label: int index of the desired class, or None for an untargeted attack.
        iters: maximum number of DE generations.
        pop_size: number of candidate perturbations (must be >= 3).
        verbose: print progress every 10 generations and visualise the best result.

    Returns:
        ``(success, best_solution, best_score)``. ``best_score`` is the
        target-class (targeted) or true-class (untargeted) probability achieved
        by ``best_solution``.
    """
    # Candidate layout is [row, col, r, g, b]. Coordinates are uniform and colours
    # are N(0.5, 0.5) clipped to [0, 1].
    candidates = np.random.random((pop_size, 5))
    candidates[:, 2:5] = np.clip(np.random.normal(0.5, 0.5, (pop_size, 3)), 0, 1)
    is_targeted = target_label is not None
    label = target_label if is_targeted else true_label
    fitness = evaluate(candidates, img, label, model)

    def is_success():
        return (is_targeted and fitness.max() > 0.5) or ((not is_targeted) and fitness.min() < 0.05)

    for iteration in range(iters):
        # Early stopping
        if is_success():
            break
        if verbose and iteration % 10 == 0:  # Print progress
            print("Target Probability [Iteration {}]:".format(iteration), fitness.max() if is_targeted else fitness.min())
        # Generate new candidate solutions
        new_gen_candidates = evolve(candidates, strategy="resample")
        # Evaluate new solutions
        new_gen_fitness = evaluate(new_gen_candidates, img, label, model)
        # Greedy selection: keep each mutant only where it beats its parent slot.
        successors = new_gen_fitness > fitness if is_targeted else new_gen_fitness < fitness
        candidates[successors] = new_gen_candidates[successors]
        fitness[successors] = new_gen_fitness[successors]
    best_idx = fitness.argmax() if is_targeted else fitness.argmin()
    best_solution = candidates[best_idx]
    best_score = fitness[best_idx]
    if verbose:
        visualize_perturbation(best_solution, img, true_label, model, target_label)
    return is_success(), best_solution, best_score


# Untargeted attack
_ = attack(cifar_model, test_img, test_label)

# %%
# Targeted attack
# This is much harder/costlier than an untargeted attack
# For time reasons, targeted attacks below use 20 iterations
TARGETED_ITERS = 20
targeted_results = {}
for idx, name in enumerate(CIFAR_LABELS):
    if idx == test_label:
        print(name, idx, "True Label")
        continue
    targeted_results[idx] = attack(cifar_model, test_img, test_label, target_label=idx, iters=TARGETED_ITERS, verbose=False)
    success, _, score = targeted_results[idx]
    print(name, idx, success, score)

# %%
# %load_ext watermark
# %watermark --updated --datename --python --machine --watermark -p torch,numpy,matplotlib,tensorboardX,torchvision,seaborn,tqdm