├── CPL
│   ├── LFW128_enhanced_random_ASR.ipynb
│   ├── LFW64_enhanced_random_ASR.ipynb
│   ├── LFW_batch.ipynb
│   ├── LFW_defense.ipynb
│   ├── LFW_enhanced_random_ASR.ipynb
│   └── LFW_tanhrelu.ipynb
├── DLG
│   └── LFW_Deep_Leakage_from_Gradients.ipynb
├── GradInverting
│   ├── LICENSE
│   ├── Multiple images and multiple local update steps (ConvNet).ipynb
│   ├── README.md
│   ├── Recovery from Gradient Information.ipynb
│   ├── Recovery from Weight Updates.ipynb
│   ├── ResNet152 - trained on ImageNet.ipynb
│   ├── ResNet18 - trained on ImageNet.ipynb
│   ├── ResNet18 - untrained (ImageNet version).ipynb
│   ├── ResNet32-10 - Recovering 100 CIFAR-100 images.ipynb
│   ├── environment.yml
│   ├── inversefed
│   │   ├── __init__.py
│   │   ├── consts.py
│   │   ├── data
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── data.py
│   │   │   ├── data_processing.py
│   │   │   ├── datasets.py
│   │   │   └── loss.py
│   │   ├── medianfilt.py
│   │   ├── metrics.py
│   │   ├── nn
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── densenet.py
│   │   │   ├── models.py
│   │   │   ├── modules.py
│   │   │   ├── revnet.py
│   │   │   └── revnet_utils.py
│   │   ├── optimization_strategy.py
│   │   ├── options.py
│   │   ├── reconstruction_algorithms.py
│   │   ├── training
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── scheduler.py
│   │   │   └── training_routine.py
│   │   └── utils.py
│   ├── rec_mult.py
│   └── reconstruct_image.py
├── README.md
└── demo
    ├── cifar10_dlg.gif
    ├── cifar10_ours.gif
    ├── lfw_dlg.gif
    ├── lfw_ours.gif
    ├── mnist_dlg.gif
    └── mnist_ours.gif

/CPL/LFW_defense.ipynb:
--------------------------------------------------------------------------------
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": { 7 | "colab": { 8 | "base_uri": "https://localhost:8080/", 9 | "height": 34 10 | }, 11 | "colab_type": "code", 12 | "id": "NWa7Xo6PkIl3", 13 | "outputId": "b186c3b3-e0cf-423c-922c-94e64702f818" 14 | }, 15 | "outputs": [ 16 | { 17 | "name": "stdout", 18 | "output_type": "stream", 19 | "text": [ 20 | "1.4.0 0.5.0\n", 21 | "GeForce GTX 1080 Ti\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "%matplotlib inline\n", 27 | "\n", 28 | "import numpy as np\n", 29 | "from pprint import pprint\n", 30 | "\n", 31 | "from PIL import Image\n", 32 | "import matplotlib.pyplot as plt\n", 33 | "\n", 34 | "import torch\n", 35 | "import torch.nn as nn\n", 36 | "import torch.nn.functional as F\n", 37 | "from torch.autograd import grad\n", 38 | "import torchvision\n", 39 | "from torchvision import models, datasets, transforms\n", 40 | "import torch.nn.functional as func\n", 41 | "#torch.manual_seed(50)\n", 42 | "\n", 43 | "import os\n", 44 | "#os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n", 45 | "\n", 46 | "\n", 47 | "print(torch.__version__, torchvision.__version__)\n", 48 | "print (torch.cuda.get_device_name(device='cuda:0'))" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "Running on cuda:0\n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "# dst = datasets.CIFAR100(\"~/.torch\", download=True)\n", 66 | "# dst = datasets.MNIST(\"~/.torch\", download=True)\n", 67 | "\n", 68 | "tp = transforms.Compose([\n", 69 | " transforms.Resize(32),\n", 70 | " transforms.CenterCrop(32),\n", 71 | " transforms.ToTensor()\n", 72 | "])\n", 73 | "tt = transforms.ToPILImage()\n", 74 | "\n", 75 | "device = \"cpu\"\n", 76 | "if torch.cuda.is_available():\n", 77 | " device = \"cuda:0\"\n", 78 | "print(\"Running on %s\" % device)\n", 79 | "\n", 80 | "def label_to_onehot(target, num_classes=106):\n", 81 | " target = torch.unsqueeze(target, 1)\n", 82 | " 
onehot_target = torch.zeros(target.size(0), num_classes, device=target.device)\n", 83 | " onehot_target.scatter_(1, target, 1)\n", 84 | " return onehot_target\n", 85 | "\n", 86 | "def cross_entropy_for_onehot(pred, target):\n", 87 | " return torch.mean(torch.sum(- target * F.log_softmax(pred, dim=-1), 1))" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 16, 93 | "metadata": { 94 | "colab": {}, 95 | "colab_type": "code", 96 | "id": "AorI020iVjjS" 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "# def weights_init(m):\n", 101 | "# if hasattr(m, \"weight\"):\n", 102 | "# m.weight.data.uniform_(-0.5, 0.5)\n", 103 | "# nn.init.xavier_uniform_(m.weight.data)\n", 104 | "# if hasattr(m, \"bias\"):\n", 105 | "# #m.bias.data.uniform_(-0.5, 0.5)\n", 106 | "# #nn.init.xavier_uniform(m.bias.data)\n", 107 | "# m.bias.data.fill_(0)\n", 108 | "\n", 109 | "\n", 110 | "\n", 111 | "# class LeNet(nn.Module):\n", 112 | "\n", 113 | "# def __init__(self):\n", 114 | "\n", 115 | "# super(LeNet, self).__init__()\n", 116 | "# self.conv1 = nn.Conv2d(3, 6, kernel_size=5,stride=2)\n", 117 | "# self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=2)\n", 118 | "# self.fc1 = nn.Linear(16*5*5, 256)\n", 119 | "# self.fc2 = nn.Linear(256, 120)\n", 120 | "# self.fc3 = nn.Linear(120, 106)\n", 121 | "\n", 122 | "# def forward(self, x):\n", 123 | "# #x = func.relu(self.conv1(x))\n", 124 | "# x = func.sigmoid(self.conv1(x))\n", 125 | "# #x = func.max_pool2d(x, 2)\n", 126 | "# #x = func.relu(self.conv2(x))\n", 127 | "# x = func.sigmoid(self.conv2(x))\n", 128 | "# #x = func.max_pool2d(x, 2)\n", 129 | "# x = x.view(x.size(0), -1)\n", 130 | "# #x = func.relu(self.fc1(x))\n", 131 | "# x = func.sigmoid(self.fc1(x))\n", 132 | "# #x = func.relu(self.fc2(x))\n", 133 | "# x = func.sigmoid(self.fc2(x))\n", 134 | "# x = self.fc3(x)\n", 135 | "# return x\n", 136 | "\n", 137 | " \n", 138 | " \n", 139 | "# def weights_init(m):\n", 140 | "# if hasattr(m, \"weight\"):\n", 141 | "# m.weight.data.uniform_(-0.3, 0.3)\n", 142 | "# if hasattr(m, \"bias\"):\n", 143 | "# m.bias.data.uniform_(-0.3, 0.3)\n", 144 | "\n", 145 | "\n", 146 | "def weights_init(m):\n", 147 | " if hasattr(m, \"weight\"):\n", 148 | " m.weight.data.uniform_(-0.5, 0.5)\n", 149 | " if hasattr(m, \"bias\"):\n", 150 | " m.bias.data.uniform_(-0.5, 0.5)\n", 151 | "\n", 152 | "\n", 153 | "class LeNet(nn.Module):\n", 154 | " def __init__(self):\n", 155 | " super(LeNet, self).__init__()\n", 156 | " act = nn.Sigmoid\n", 157 | " #act = nn.ReLU\n", 158 | " self.body = nn.Sequential(\n", 159 | " nn.Conv2d(3, 12, kernel_size=5, padding=5//2, stride=2),\n", 160 | " act(),\n", 161 | " nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=2),\n", 162 | " act(),\n", 163 | " nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=1),\n", 164 | " act(),\n", 165 | " nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=1),\n", 166 | " act(),\n", 167 | " )\n", 168 | " self.fc = nn.Sequential(\n", 169 | " nn.Linear(768, 106)\n", 170 | " )\n", 171 | " \n", 172 | " def forward(self, x):\n", 173 | " out = self.body(x)\n", 174 | " out = out.view(out.size(0), -1)\n", 175 | " # print(out.size())\n", 176 | " out = self.fc(out)\n", 177 | " return out\n", 178 | "\n", 179 | "\n", 180 | "net = LeNet().to(device)\n", 181 | "net.apply(weights_init)\n", 182 | "\n", 183 | "\n", 184 | "\n", 185 | "#criterion = cross_entropy_for_onehot\n", 186 | "criterion = nn.CrossEntropyLoss()" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 21, 192 | "metadata": { 193 | "colab": 
{}, 194 | "colab_type": "code", 195 | "id": "AorI020iVjjS" 196 | }, 197 | "outputs": [ 198 | { 199 | "name": "stdout", 200 | "output_type": "stream", 201 | "text": [ 202 | "(2801, 32, 32, 3)\n", 203 | "(934, 32, 32, 3)\n", 204 | "106\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "import torchvision.transforms as transforms\n", 210 | "import torch.optim as optim\n", 211 | "from torch.autograd import Variable\n", 212 | "from torch.utils import data\n", 213 | "\n", 214 | "from sklearn.datasets import fetch_lfw_people\n", 215 | "from sklearn.model_selection import train_test_split\n", 216 | "lfw_people=fetch_lfw_people(min_faces_per_person=14,color=True,slice_=(slice(61,189),slice(61,189)),resize=0.25) #14\n", 217 | "x=lfw_people.images\n", 218 | "y=lfw_people.target\n", 219 | "\n", 220 | "target_names=lfw_people.target_names\n", 221 | "n_classes=target_names.shape[0]\n", 222 | "\n", 223 | "X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.25,shuffle=False)\n", 224 | "\n", 225 | "\n", 226 | "\n", 227 | "print (X_train.shape)\n", 228 | "print (X_test.shape)\n", 229 | "print (n_classes)\n", 230 | "\n", 231 | "\n", 232 | " " 233 | ] 234 | }, 235 | { 236 | "cell_type": "code", 237 | "execution_count": 22, 238 | "metadata": { 239 | "colab": {}, 240 | "colab_type": "code", 241 | "id": "AorI020iVjjS" 242 | }, 243 | "outputs": [ 244 | { 245 | "name": "stdout", 246 | "output_type": "stream", 247 | "text": [ 248 | "finished training\n", 249 | "0.007494646680942184\n", 250 | "finished testing\n" 251 | ] 252 | } 253 | ], 254 | "source": [ 255 | "X_train = X_train.reshape(X_train.shape[0], 32, 32, 3)\n", 256 | "X_test = X_test.reshape(X_test.shape[0], 32, 32, 3)\n", 257 | "#X_train = torch.transpose\n", 258 | "#X_train = X_train.astype('float32')\n", 259 | "X_train /= 255.0\n", 260 | "X_test /= 255.0\n", 261 | "\n", 262 | " \n", 263 | "\n", 264 | "\n", 265 | "x_train = torch.FloatTensor(X_train).to(device)\n", 266 | "x_train = x_train.transpose(2,3).transpose(1,2)\n", 267 | "y_train = torch.LongTensor(y_train).to(device)\n", 268 | "\n", 269 | "x_test = torch.FloatTensor(X_test).to(device)\n", 270 | "x_test = x_test.transpose(2,3).transpose(1,2)\n", 271 | "y_test = torch.LongTensor(y_test).to(device)\n", 272 | "\n", 273 | "\n", 274 | "training = data.TensorDataset(x_train,y_train)\n", 275 | "\n", 276 | "testing = data.TensorDataset(x_test,y_test)\n", 277 | "\n", 278 | "dst_tensor=training\n", 279 | "\n", 280 | "criterion_train = nn.CrossEntropyLoss()\n", 281 | "optimizer_train = optim.Adam(net.parameters(),lr=0.01)#,momentum=0.9)\n", 282 | "trainloader = torch.utils.data.DataLoader(training,batch_size=32, shuffle=True)\n", 283 | "\n", 284 | "\n", 285 | "\n", 286 | "for epoch in range(0):\n", 287 | "\n", 288 | " for i,data in enumerate(trainloader,0):\n", 289 | " #for data in trainloader:\n", 290 | " #if i<=10: \n", 291 | "\n", 292 | " inputs,label = data\n", 293 | "\n", 294 | " inputs,label = Variable(inputs).to(device),Variable(label).to(device)\n", 295 | "\n", 296 | " optimizer_train.zero_grad()\n", 297 | " outputs_benign=net(inputs)\n", 298 | "\n", 299 | " loss_benign = criterion_train(outputs_benign,label)\n", 300 | "\n", 301 | " loss_benign.backward()\n", 302 | " #sgd_update(net.parameters())\n", 303 | "\n", 304 | " optimizer_train.step()\n", 305 | " \n", 306 | " testloader = torch.utils.data.DataLoader(testing,batch_size=934, shuffle=False)\n", 307 | "\n", 308 | " acc =0.0\n", 309 | " for ji,tdata in enumerate(testloader,0):\n", 310 | "\n", 311 | " tinputs,tlabel = 
tdata\n", 312 | "\n", 313 | " tinputs,tlabel = Variable(tinputs).to(device),Variable(tlabel).to(device)\n", 314 | "\n", 315 | " toutputs=net(tinputs)\n", 316 | "\n", 317 | " tpredict = torch.argmax(toutputs, dim=1)\n", 318 | "\n", 319 | "\n", 320 | " for mi in range(934):\n", 321 | "\n", 322 | "\n", 323 | "\n", 324 | " if tpredict[mi] == tlabel[mi]:\n", 325 | " acc=acc+1\n", 326 | "\n", 327 | " accuracy = acc / 934\n", 328 | " print (accuracy)\n", 329 | " print ('fininshed testing')\n", 330 | " if accuracy>0.18:\n", 331 | " break\n", 332 | "\n", 333 | "\n", 334 | "print ('fininshed training')\n", 335 | "#torch.save(net.state_dict()\n", 336 | "\n", 337 | "testloader = torch.utils.data.DataLoader(testing,batch_size=934, shuffle=False)\n", 338 | "\n", 339 | "acc =0.0\n", 340 | "for ji,tdata in enumerate(testloader,0):\n", 341 | "\n", 342 | " tinputs,tlabel = tdata\n", 343 | "\n", 344 | " tinputs,tlabel = Variable(tinputs).to(device),Variable(tlabel).to(device)\n", 345 | "\n", 346 | " toutputs=net(tinputs)\n", 347 | " \n", 348 | " tpredict = torch.argmax(toutputs, dim=1)\n", 349 | " \n", 350 | " \n", 351 | " for mi in range(934):\n", 352 | " \n", 353 | " \n", 354 | "\n", 355 | " if tpredict[mi] == tlabel[mi]:\n", 356 | " acc=acc+1\n", 357 | "\n", 358 | "accuracy = acc / 934\n", 359 | "print (accuracy)\n", 360 | "print ('fininshed testing')" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 32, 366 | "metadata": {}, 367 | "outputs": [ 368 | { 369 | "name": "stdout", 370 | "output_type": "stream", 371 | "text": [ 372 | "now starting 100\n" 373 | ] 374 | }, 375 | { 376 | "name": "stderr", 377 | "output_type": "stream", 378 | "text": [ 379 | "/home/wenqi/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:118: DeprecationWarning: time.clock has been deprecated in Python 3.3 and will be removed from Python 3.8: use time.perf_counter or time.process_time instead\n" 380 | ] 381 | }, 382 | { 383 | "ename": "KeyboardInterrupt", 384 | "evalue": "", 385 | "output_type": "error", 386 | "traceback": [ 387 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 388 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 389 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 204\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 205\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 206\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0miters\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m5\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[0mcurrent_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 390 | "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/optim/lbfgs.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 430\u001b[0m \u001b[0;31m# the reason we do this: in a stochastic setting,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 431\u001b[0m \u001b[0;31m# no use to re-evaluate that function 
here\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 432\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 433\u001b[0m \u001b[0mflat_grad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_gather_flat_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 434\u001b[0m \u001b[0mopt_cond\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflat_grad\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mabs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mtolerance_grad\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 391 | "\u001b[0;32m\u001b[0m in \u001b[0;36mclosure\u001b[0;34m()\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0mlasso\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdummy_data\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[0mridge\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdummy_data\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 185\u001b[0;31m \u001b[0mgrad_diff\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgx\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mgy\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m#+ 0.0*lasso +0.01*ridge\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 186\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 392 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 393 | ] 394 | }, 395 | { 396 | "data": { 397 | "text/plain": [ 398 | "
" 399 | ] 400 | }, 401 | "metadata": {}, 402 | "output_type": "display_data" 403 | } 404 | ], 405 | "source": [ 406 | "from pytorch_msssim import ssim\n", 407 | "\n", 408 | "#dgc_rate_list = [20,30,40,50,70, 80,90]\n", 409 | "\n", 410 | "#for epoch in range(1):\n", 411 | "#for dgc_rate in dgc_rate_list:\n", 412 | "gau_rate_list = [100]\n", 413 | "for gau_rate in gau_rate_list:\n", 414 | "\n", 415 | " \n", 416 | " print('now starting',gau_rate)\n", 417 | "\n", 418 | " for iter_,data in enumerate(trainloader,1):\n", 419 | " \n", 420 | " if iter_ != 1:\n", 421 | " break\n", 422 | " \n", 423 | " ######### honest partipant #########\n", 424 | " img_index = 12 #use img_index\n", 425 | " dst_pil = tt(dst_tensor[img_index][0].cpu()) #use img_index\n", 426 | "\n", 427 | " gt_data = tp(dst_pil).to(device)\n", 428 | " gt_data = torch.unsqueeze(gt_data,0)\n", 429 | "\n", 430 | " gt_label = dst_tensor[img_index][1].long().to(device) #use img_index\n", 431 | " gt_label = gt_label.view(1, )\n", 432 | " gt_onehot_label = label_to_onehot(gt_label, num_classes=106)\n", 433 | "\n", 434 | " plt.imshow(dst_pil)\n", 435 | " plt.axis('off')\n", 436 | " plt.savefig(\"./attack_image/lfw_gt_nips_appendix\")\n", 437 | "\n", 438 | "\n", 439 | "\n", 440 | " batch =2 #\n", 441 | " for bat in range(batch-1):\n", 442 | " dst_pil = tt(dst_tensor[img_index+1+bat][0].cpu()) #use img_index\n", 443 | " tmp = torch.unsqueeze(tp(dst_pil).to(device),0)\n", 444 | " #print(tmp.shape)\n", 445 | " gt_data = torch.cat((gt_data,tmp),0)\n", 446 | "\n", 447 | " gt_label_tmp = dst_tensor[img_index+1+bat][1].long().to(device) #use img_index\n", 448 | " gt_label_tmp = gt_label_tmp.view(1, )\n", 449 | " gt_label = torch.cat((gt_label,gt_label_tmp),0)\n", 450 | " gt_onehot_label = torch.cat((gt_onehot_label,label_to_onehot(gt_label_tmp, num_classes=106)),0)\n", 451 | "\n", 452 | " plt.imshow(dst_pil)\n", 453 | " #plt.savefig(\"./original/index_%s_label_%s\"%(bat+1,gt_label_tmp.item()))\n", 454 | "\n", 455 | " #plt.title(\"Ground truth image\")\n", 456 | " #print(\"GT label is %d.\" % gt_label.item(), \"\\nOnehot label is %d.\" % torch.argmax(gt_onehot_label, dim=-1).item())\n", 457 | "\n", 458 | "\n", 459 | " gt_label = torch.reshape(gt_label,(-1,1)) \n", 460 | " #print (gt_data.shape)\n", 461 | " #print (gt_label.shape)\n", 462 | " #print (gt_onehot_label.shape)\n", 463 | "\n", 464 | "\n", 465 | " # compute original gradient \n", 466 | " dy_dx = []\n", 467 | " original_dy_dx=[]\n", 468 | " original_pred = []\n", 469 | " for item in range(batch):\n", 470 | " gt_data_single = torch.unsqueeze(gt_data[item],0)\n", 471 | " out = net(gt_data_single)\n", 472 | " #y = criterion(out, gt_onehot_label[item])\n", 473 | " y = criterion(out, gt_label[item])\n", 474 | " dy_dx = torch.autograd.grad(y, net.parameters(),retain_graph=True)\n", 475 | " original_dy_dx_tmp = list((_.detach().clone() for _ in dy_dx))\n", 476 | " original_dy_dx.append(original_dy_dx_tmp)\n", 477 | " out_tmp = out.detach().clone()\n", 478 | " original_pred.append(out_tmp)\n", 479 | " \n", 480 | " \n", 481 | " #if gaussian noise or laplace\n", 482 | " m = torch.distributions.laplace.Laplace(torch.tensor([0.0]), torch.tensor([1/gau_rate]))\n", 483 | " for item in range(batch):\n", 484 | " for layer_idx in range(10):\n", 485 | " #original_dy_dx[item][layer_idx] = original_dy_dx[item][layer_idx] + torch.empty(original_dy_dx[item][layer_idx].size()).normal_(mean=0,std=1/gau_rate).to(device)\n", 486 | " original_dy_dx[item][layer_idx] = original_dy_dx[item][layer_idx] + 
torch.squeeze(m.sample(sample_shape=original_dy_dx[item][layer_idx].size()),dim=-1).to(device)\n", 487 | " #break\n", 488 | " ##if deep gradient compression\n", 489 | " #print (original_dy_dx[0][0][0])\n", 490 | " #for item in range(batch):\n", 491 | " # for layer_idx in range(10):\n", 492 | " # if layer_idx == 0: \n", 493 | " # flat_dy_dx = torch.flatten(original_dy_dx[item][layer_idx])\n", 494 | " # else:\n", 495 | " # flat_dy_dx = torch.cat((flat_dy_dx,torch.flatten(original_dy_dx[item][layer_idx])),0)\n", 496 | " #sorted_dy_dx = flat_dy_dx.abs().sort()\n", 497 | " #size = np.asarray(list(flat_dy_dx.shape))\n", 498 | " #thresh = sorted_dy_dx[0][int(size * dgc_rate/100.0)]\n", 499 | " #print (size)\n", 500 | " #print (int(size * dgc_rate/100.0))\n", 501 | " #print (thresh)\n", 502 | " #print (sorted_dy_dx[0][-1])\n", 503 | " #for item in range(batch):\n", 504 | " # for layer_idx in range(10):\n", 505 | " # shape_tmp = original_dy_dx[item][layer_idx].size()\n", 506 | " # flat_dy_dx_prune = torch.flatten(original_dy_dx[item][layer_idx])\n", 507 | " # size_tmp = np.asarray(list(flat_dy_dx_prune.shape))\n", 508 | " # for m in range(int(size_tmp)):\n", 509 | " # if flat_dy_dx_prune[m].abs()<=thresh:\n", 510 | " # flat_dy_dx_prune[m] = 0\n", 511 | " # original_dy_dx[item][layer_idx] = flat_dy_dx_prune.view(shape_tmp)\n", 512 | " #print (original_dy_dx[0][0][0])\n", 513 | " \n", 514 | "\n", 515 | " # generate dummy data and label\n", 516 | " import time\n", 517 | "\n", 518 | " #if iter_ % 10 ==0: \n", 519 | " if iter_ == 1:\n", 520 | " \n", 521 | " #print ('epoch',epoch,'iter',iter_)\n", 522 | " for item in range(1):\n", 523 | " start = time.clock()\n", 524 | " for rd in range(1):\n", 525 | "\n", 526 | " torch.manual_seed(100*rd)\n", 527 | "\n", 528 | " pat_1 = torch.rand([3,16,16])\n", 529 | " pat_2 = torch.cat((pat_1,pat_1),dim=1)\n", 530 | " pat_4 = torch.cat((pat_2,pat_2),dim=2)\n", 531 | " dummy_data = torch.unsqueeze(pat_4,dim=0).to(device).requires_grad_(True) \n", 532 | "\n", 533 | " dummy_unsqueeze=torch.unsqueeze(gt_onehot_label[item],dim=0)\n", 534 | "\n", 535 | " dummy_label = torch.randn(gt_onehot_label[item].size()).to(device).requires_grad_(True)\n", 536 | " label_pred = torch.argmin(torch.sum(original_dy_dx[item][-2], dim=-1), dim=-1).detach().reshape((1,)).requires_grad_(False)\n", 537 | " label_pred_onehot = label_to_onehot(label_pred, num_classes=106)\n", 538 | "\n", 539 | " plt.imshow(tt(dummy_data[0].cpu()))\n", 540 | " plt.title(\"Dummy data\")\n", 541 | " #plt.savefig(\"./random_seed/index_%s_rand_seed_%s_label_%s\"%(item,rd,torch.argmax(dummy_label, dim=-1).item()))\n", 542 | "\n", 543 | " plt.clf()\n", 544 | " #print(\"Dummy label is %d.\" % torch.argmax(dummy_label, dim=-1).item())\n", 545 | " #print(\"stolen label is %d.\" % label_pred.item())\n", 546 | "\n", 547 | " #optimizer = torch.optim.LBFGS([dummy_data,dummy_label])\n", 548 | " optimizer = torch.optim.LBFGS([dummy_data,])\n", 549 | "\n", 550 | "\n", 551 | " history = []\n", 552 | " history_batch = []\n", 553 | " history_grad = []\n", 554 | "\n", 555 | " percept_dis = np.zeros(100)\n", 556 | " recover_dis = np.zeros(100)\n", 557 | " for iters in range(100):\n", 558 | "\n", 559 | " percept_dis[iters]=ssim(dummy_data,torch.unsqueeze(gt_data[item],dim=0),data_range=0).item()\n", 560 | " recover_dis[iters]=torch.dist(dummy_data,torch.unsqueeze(gt_data[item],dim=0),2).item()\n", 561 | "\n", 562 | " history.append(tt(dummy_data[0].cpu()))\n", 563 | "\n", 564 | " def closure():\n", 565 | " optimizer.zero_grad()\n", 566 
| "\n", 567 | " pred = net(dummy_data) \n", 568 | " dummy_onehot_label = F.softmax(dummy_label, dim=-1)\n", 569 | " #dummy_loss = criterion(pred, dummy_onehot_label) # TODO: fix the gt_label to dummy_label in both code and slides.\n", 570 | "\n", 571 | " #dummy_loss = criterion(pred, label_pred_onehot)\n", 572 | " dummy_loss = criterion(pred, label_pred)\n", 573 | "\n", 574 | "\n", 575 | " dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)\n", 576 | " #dummy_dy_dp = torch.autograd.grad(dummy_loss, dummy_data, create_graph=True)\n", 577 | " #print (dummy_dy_dp[0].shape)\n", 578 | "\n", 579 | " grad_diff = 0\n", 580 | " grad_count = 0\n", 581 | " #count =0\n", 582 | " for gx, gy in zip(dummy_dy_dx, original_dy_dx[item]): # TODO: fix the variablas here\n", 583 | "\n", 584 | " #if iters==500 or iters== 1200:\n", 585 | " # print (gx[0])\n", 586 | " # print ('hahaha')\n", 587 | " # print (gy[0])\n", 588 | " lasso = torch.norm(dummy_data,p=1)\n", 589 | " ridge = torch.norm(dummy_data,p=2)\n", 590 | " grad_diff += ((gx - gy) ** 2).sum() #+ 0.0*lasso +0.01*ridge\n", 591 | "\n", 592 | "\n", 593 | " grad_count += gx.nelement()\n", 594 | "\n", 595 | " #if count == 9:\n", 596 | " # break\n", 597 | " #count=count+1\n", 598 | " # grad_diff = grad_diff / grad_count * 1000\n", 599 | " grad_diff.backward()\n", 600 | " #print (count)\n", 601 | "\n", 602 | " #print (dummy_dy_dx)\n", 603 | " #print (original_dy_dx)\n", 604 | "\n", 605 | "\n", 606 | " return grad_diff\n", 607 | "\n", 608 | "\n", 609 | "\n", 610 | " optimizer.step(closure)\n", 611 | " if iters % 5 == 0: \n", 612 | " current_loss = closure()\n", 613 | " #if iters == 0: \n", 614 | " #print (\"%.8f\" % current_loss.item())\n", 615 | " #print(iters, \"%.8f\" % current_loss.item())\n", 616 | "\n", 617 | " # for bat in range(batch-1):\n", 618 | " # history_batch.append(tt(dummy_data[bat].cpu()))\n", 619 | "\n", 620 | " #plt.figure(figsize=(30, 20))\n", 621 | " #for i in range(100):\n", 622 | " # plt.subplot(10, 10, i + 1)\n", 623 | " # plt.imshow(history[i * 5])\n", 624 | " # plt.title(\"iter=%d\" % (i * 5))\n", 625 | " # plt.axis('off')\n", 626 | " #print(\"Dummy label is %d.\" % torch.argmax(dummy_label, dim=-1).item())\n", 627 | "\n", 628 | " #np.savetxt('./attack_image/lfw_ssim_idx_%s_laplace_%s'%(item,gau_rate),percept_dis,fmt=\"%4f\")\n", 629 | " #np.savetxt('./attack_image/lfw_mse_idx_%s_laplace_%s'%(item,gau_rate),recover_dis,fmt=\"%4f\")\n", 630 | " #plt.savefig(\"./attack_image/lfw_index_%s_laplace_%s\"%(item,gau_rate))\n", 631 | "\n", 632 | " #plt.clf()\n", 633 | " \n", 634 | " pinp = np.argmin(recover_dis)\n", 635 | " plt.imshow(history[pinp])\n", 636 | " plt.axis('off')\n", 637 | " \n", 638 | " #plt.figure(figsize=(15, 10))\n", 639 | " #for i in range(60):\n", 640 | " # plt.subplot(6, 10, i + 1)\n", 641 | " # plt.imshow(history[i * 14 ], cmap='gray')\n", 642 | " # plt.title(\"iter=%d\" % (i*14))\n", 643 | " # plt.axis('off')\n", 644 | " plt.savefig(\"./attack_image/lfw_idx_%s_laplace_%s_nips\"%(item,gau_rate))\n", 645 | " \n", 646 | " #duration = time.clock()-start\n", 647 | " #print (\"Running time is %.4f.\" %(duration/10.0) )\n", 648 | " #print (duration)\n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | "################################################### training when set to epoch\n", 656 | "\n", 657 | "# #if epoch>=1:\n", 658 | "# #if i==1:\n", 659 | "# #break\n", 660 | "# #print (iter_)\n", 661 | "# inputs,label = data\n", 662 | "\n", 663 | "# inputs,label = 
Variable(inputs).to(device),Variable(label).to(device)\n", 664 | "\n", 665 | "# optimizer_train.zero_grad()\n", 666 | "\n", 667 | "\n", 668 | "# outputs_benign=net(inputs)\n", 669 | "# #outputs_benign = F.softmax(outputs_benign, dim=-1)\n", 670 | "# #print (outputs_benign[0])\n", 671 | "\n", 672 | "\n", 673 | "# loss_benign = criterion_train(outputs_benign,label)\n", 674 | "\n", 675 | "# #print(\"loss computed\")\n", 676 | "# loss_benign.backward()\n", 677 | "# #print(\"loss BP\")\n", 678 | "# optimizer_train.step()\n", 679 | "# #sgd_update(net.parameters())\n", 680 | "\n", 681 | "# #if i%2000==0:\n", 682 | "# #print (loss_benign.item())\n", 683 | "# #torch.save(net.state_dict(),'./LFW_net.pth') \n", 684 | "\n", 685 | "# #if iter_%50==0:\n", 686 | "# # print ('attack',iter_)\n", 687 | " \n", 688 | " \n", 689 | "# print ('finished training')\n", 690 | "# break\n", 691 | "############################### testing\n", 692 | " \n", 693 | "# total = len(y_test)\n", 694 | "# acc =0.0\n", 695 | "# for ct in range(total):\n", 696 | "# testing_data = tt(testing[ct][0].cpu())\n", 697 | "# testing_data1 = tp(testing_data).to(device)\n", 698 | "# testing_data2 = testing_data1.view(1, *testing_data1.size())\n", 699 | "# y_pred = net(testing_data2)\n", 700 | "# predicted = torch.argmax(y_pred)\n", 701 | "\n", 702 | "# if predicted == y_test[ct]:\n", 703 | "# acc=acc+1\n", 704 | "# accuracy = acc / total\n", 705 | "# print (accuracy)\n", 706 | "# print ('finished testing')\n", 707 | "\n" 708 | ] 709 | }, 710 | { 711 | "cell_type": "code", 712 | "execution_count": 17, 713 | "metadata": {}, 714 | "outputs": [], 715 | "source": [ 716 | "################################" 717 | ] 718 | }, 719 | { 720 | "cell_type": "code", 721 | "execution_count": 6, 722 | "metadata": { 723 | "colab": {}, 724 | "colab_type": "code", 725 | "id": "AorI020iVjjS", 726 | "scrolled": false 727 | }, 728 | "outputs": [ 729 | { 730 | "name": "stdout", 731 | "output_type": "stream", 732 | "text": [ 733 | "torch.Size([100, 3, 32, 32])\n", 734 | "torch.Size([100, 1])\n", 735 | "torch.Size([100, 106])\n", 736 | "Dummy label is 42.\n", 737 | "stolen label is 6.\n" 738 | ] 739 | }, 740 | { 741 | "name": "stderr", 742 | "output_type": "stream", 743 | "text": [ 744 | "/home/wenqi/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:79: DeprecationWarning: time.clock has been deprecated in Python 3.3 and will be removed from Python 3.8: use time.perf_counter or time.process_time instead\n", 745 | "/home/wenqi/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:203: DeprecationWarning: time.clock has been deprecated in Python 3.3 and will be removed from Python 3.8: use time.perf_counter or time.process_time instead\n" 746 | ] 747 | }, 748 | { 749 | "name": "stdout", 750 | "output_type": "stream", 751 | "text": [ 752 | "87.941089\n", 753 | "12.076419830322266\n", 754 | "finished training\n", 755 | "0.0032119914346895075\n", 756 | "finished testing\n", 757 | "torch.Size([100, 3, 32, 32])\n", 758 | "torch.Size([100, 1])\n", 759 | "torch.Size([100, 106])\n", 760 | "Dummy label is 42.\n", 761 | "stolen label is 6.\n", 762 | "98.18773599999999\n", 763 | "11.408146858215332\n", 764 | "finished training\n", 765 | "0.03747323340471092\n", 766 | "finished testing\n", 767 | "torch.Size([100, 3, 32, 32])\n", 768 | "torch.Size([100, 1])\n", 769 | "torch.Size([100, 106])\n", 770 | "Dummy label is 42.\n", 771 | "stolen label is 6.\n" 772 | ] 773 | }, 774 | { 775 | "ename": "KeyboardInterrupt", 776 | "evalue": "", 777 | "output_type": 
"error", 778 | "traceback": [ 779 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 780 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 781 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 179\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 180\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 181\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0miters\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;36m5\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 182\u001b[0m \u001b[0mcurrent_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 782 | "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/optim/lbfgs.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 383\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_old\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 384\u001b[0m \u001b[0mbe_i\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mold_dirs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mro\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 385\u001b[0;31m \u001b[0mr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mal\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mbe_i\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mold_stps\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 386\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 387\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mprev_flat_grad\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 783 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 784 | ] 785 | }, 786 | { 787 | "data": { 788 | "text/plain": [ 789 | "
" 790 | ] 791 | }, 792 | "metadata": {}, 793 | "output_type": "display_data" 794 | }, 795 | { 796 | "data": { 797 | "text/plain": [ 798 | "
" 799 | ] 800 | }, 801 | "metadata": {}, 802 | "output_type": "display_data" 803 | }, 804 | { 805 | "data": { 806 | "text/plain": [ 807 | "
" 808 | ] 809 | }, 810 | "metadata": {}, 811 | "output_type": "display_data" 812 | } 813 | ], 814 | "source": [ 815 | "from pytorch_msssim import ssim\n", 816 | "\n", 817 | "\n", 818 | "for epoch in range(1):\n", 819 | "\n", 820 | " for iter_,data in enumerate(trainloader,1):\n", 821 | " \n", 822 | " ######### honest partipant #########\n", 823 | " img_index = 6 #use img_index\n", 824 | " dst_pil = tt(dst_tensor[img_index][0].cpu()) #use img_index\n", 825 | "\n", 826 | " gt_data = tp(dst_pil).to(device)\n", 827 | " gt_data = torch.unsqueeze(gt_data,0)\n", 828 | "\n", 829 | " gt_label = dst_tensor[img_index][1].long().to(device) #use img_index\n", 830 | " gt_label = gt_label.view(1, )\n", 831 | " gt_onehot_label = label_to_onehot(gt_label, num_classes=106)\n", 832 | "\n", 833 | " plt.imshow(dst_pil)\n", 834 | " #plt.savefig(\"./original/index_%s_label_%s\"%(img_index,gt_label.item()))\n", 835 | "\n", 836 | "\n", 837 | "\n", 838 | " batch =100 #\n", 839 | " for bat in range(batch-1):\n", 840 | " dst_pil = tt(dst_tensor[img_index+1+bat][0].cpu()) #use img_index\n", 841 | " tmp = torch.unsqueeze(tp(dst_pil).to(device),0)\n", 842 | " #print(tmp.shape)\n", 843 | " gt_data = torch.cat((gt_data,tmp),0)\n", 844 | "\n", 845 | " gt_label_tmp = dst_tensor[img_index+1+bat][1].long().to(device) #use img_index\n", 846 | " gt_label_tmp = gt_label_tmp.view(1, )\n", 847 | " gt_label = torch.cat((gt_label,gt_label_tmp),0)\n", 848 | " gt_onehot_label = torch.cat((gt_onehot_label,label_to_onehot(gt_label_tmp, num_classes=106)),0)\n", 849 | "\n", 850 | " plt.imshow(dst_pil)\n", 851 | " #plt.savefig(\"./original/index_%s_label_%s\"%(bat+1,gt_label_tmp.item()))\n", 852 | "\n", 853 | " #plt.title(\"Ground truth image\")\n", 854 | " #print(\"GT label is %d.\" % gt_label.item(), \"\\nOnehot label is %d.\" % torch.argmax(gt_onehot_label, dim=-1).item())\n", 855 | "\n", 856 | "\n", 857 | " gt_label = torch.reshape(gt_label,(-1,1)) \n", 858 | " print (gt_data.shape)\n", 859 | " print (gt_label.shape)\n", 860 | " print (gt_onehot_label.shape)\n", 861 | "\n", 862 | "\n", 863 | " # compute original gradient \n", 864 | " dy_dx = []\n", 865 | " original_dy_dx=[]\n", 866 | " original_pred = []\n", 867 | " for item in range(batch):\n", 868 | " gt_data_single = torch.unsqueeze(gt_data[item],0)\n", 869 | " out = net(gt_data_single)\n", 870 | " #y = criterion(out, gt_onehot_label[item])\n", 871 | " y = criterion(out, gt_label[item])\n", 872 | " dy_dx = torch.autograd.grad(y, net.parameters(),retain_graph=True)\n", 873 | " original_dy_dx_tmp = list((_.detach().clone() for _ in dy_dx))\n", 874 | " original_dy_dx.append(original_dy_dx_tmp)\n", 875 | " out_tmp = out.detach().clone()\n", 876 | " original_pred.append(out_tmp)\n", 877 | "\n", 878 | " #dy_dx.append(torch.autograd.grad(y, net.parameters()))\n", 879 | "\n", 880 | "\n", 881 | "\n", 882 | "\n", 883 | "\n", 884 | " # share the gradients with other clients\n", 885 | " #original_dy_dx = list((_.detach().clone() for _ in dy_dx))\n", 886 | "\n", 887 | "\n", 888 | " # generate dummy data and label\n", 889 | " import time\n", 890 | "\n", 891 | "\n", 892 | " for item in range(1):\n", 893 | " start = time.clock()\n", 894 | " for rd in range(1):\n", 895 | "\n", 896 | " torch.manual_seed(100*rd)\n", 897 | " \n", 898 | " pat_1 = torch.rand([3,16,16])\n", 899 | " pat_2 = torch.cat((pat_1,pat_1),dim=1)\n", 900 | " pat_4 = torch.cat((pat_2,pat_2),dim=2)\n", 901 | " dummy_data = torch.unsqueeze(pat_4,dim=0).to(device).requires_grad_(True) \n", 902 | "\n", 903 | " 
dummy_unsqueeze=torch.unsqueeze(gt_onehot_label[item],dim=0)\n", 904 | "\n", 905 | " dummy_label = torch.randn(dummy_unsqueeze.size()).to(device).requires_grad_(True)\n", 906 | " label_pred=torch.argmin(torch.sum(original_dy_dx[item][-2], dim=-1), \n", 907 | " dim=-1).detach().reshape((1,)).requires_grad_(False)\n", 908 | " #print (original_dy_dx[item][-1].shape)\n", 909 | " #print (original_dy_dx[item][-1].argmin())\n", 910 | "\n", 911 | " #print (torch.sum(original_dy_dx[item][-2], dim=-1).argmin())\n", 912 | "\n", 913 | " plt.imshow(tt(dummy_data[0].cpu()))\n", 914 | " plt.title(\"Dummy data\")\n", 915 | " #plt.savefig(\"./random_seed/index_%s_rand_seed_%s_label_%s\"%(item,rd,torch.argmax(dummy_label, dim=-1).item()))\n", 916 | "\n", 917 | " plt.clf()\n", 918 | " print(\"Dummy label is %d.\" % torch.argmax(dummy_label, dim=-1).item())\n", 919 | " print(\"stolen label is %d.\" % label_pred.item())\n", 920 | "\n", 921 | "\n", 922 | " #optimizer = torch.optim.LBFGS([dummy_data,dummy_label])\n", 923 | " optimizer = torch.optim.LBFGS([dummy_data,])\n", 924 | " #optimizer = torch.optim.AdamW([dummy_data,],lr=0.01)\n", 925 | "\n", 926 | "\n", 927 | " history = []\n", 928 | " history_batch = []\n", 929 | " history_grad = []\n", 930 | " \n", 931 | " percept_dis = np.zeros(500)\n", 932 | " recover_dis = np.zeros(500)\n", 933 | " for iters in range(500):\n", 934 | " \n", 935 | " percept_dis[iters]=ssim(dummy_data,torch.unsqueeze(gt_data[item],dim=0),data_range=0).item()\n", 936 | " recover_dis[iters]=torch.dist(dummy_data,torch.unsqueeze(gt_data[item],dim=0),2).item()\n", 937 | " \n", 938 | " history.append(tt(dummy_data[0].cpu()))\n", 939 | "\n", 940 | " def closure():\n", 941 | " optimizer.zero_grad()\n", 942 | "\n", 943 | " pred = net(dummy_data) \n", 944 | " #dummy_onehot_label = F.softmax(dummy_label, dim=-1).long()\n", 945 | "\n", 946 | " #dummy_loss = criterion(pred, dummy_onehot_label) # TODO: fix the gt_label to dummy_label in both code and slides.\n", 947 | " #print (pred)\n", 948 | " #print (label_pred)\n", 949 | "\n", 950 | " dummy_loss = criterion(pred, label_pred)\n", 951 | " dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)\n", 952 | " #dummy_dy_dp = torch.autograd.grad(dummy_loss, dummy_data, create_graph=True)\n", 953 | " #print (dummy_dy_dp[0].shape) \n", 954 | "\n", 955 | " grad_diff = 0\n", 956 | " grad_count = 0\n", 957 | " #count =0\n", 958 | " for gx, gy in zip(dummy_dy_dx, original_dy_dx[item]): # TODO: fix the variables here\n", 959 | "\n", 960 | " #if iters==500 or iters== 1200:\n", 961 | " #print (gx[0])\n", 962 | " # print ('hahaha')\n", 963 | " #print (gy[0])\n", 964 | " lasso = torch.norm(dummy_data,p=1)\n", 965 | " ridge = torch.norm(dummy_data,p=2)\n", 966 | " grad_diff += ((gx - gy) ** 2).sum() #+ 0.0*lasso +0.01*ridge \n", 967 | "\n", 968 | " #print (gx.shape)\n", 969 | "\n", 970 | " grad_count += gx.nelement()\n", 971 | "\n", 972 | "\n", 973 | " #if count == 9:\n", 974 | " # break\n", 975 | " #count=count+1\n", 976 | " # grad_diff = grad_diff / grad_count * 1000\n", 977 | "\n", 978 | " #grad_diff += ((original_pred[item]-pred)**2).sum()\n", 979 | "\n", 980 | "\n", 981 | "\n", 982 | "\n", 983 | " grad_diff.backward()\n", 984 | " #print (count)\n", 985 | "\n", 986 | " #print (dummy_dy_dx)\n", 987 | " #print (original_dy_dx)\n", 988 | "\n", 989 | "\n", 990 | " return grad_diff\n", 991 | "\n", 992 | "\n", 993 | "\n", 994 | " optimizer.step(closure)\n", 995 | " if iters % 5 == 0: \n", 996 | " current_loss = closure()\n", 997 | " #if 
iters == 0: \n", 998 | " #print (\"%.8f\" % current_loss.item())\n", 999 | " #print(iters, \"%.8f\" % current_loss.item())\n", 1000 | "\n", 1001 | " # for bat in range(batch-1):\n", 1002 | " # history_batch.append(tt(dummy_data[bat].cpu()))\n", 1003 | "\n", 1004 | " plt.figure(figsize=(30, 20))\n", 1005 | " for i in range(100):\n", 1006 | " plt.subplot(10, 10, i + 1)\n", 1007 | " plt.imshow(history[i * 5])\n", 1008 | " plt.title(\"iter=%d\" % (i * 5))\n", 1009 | " plt.axis('off')\n", 1010 | " #print(\"Dummy label is %d.\" % torch.argmax(dummy_label, dim=-1).item())\n", 1011 | " \n", 1012 | " #np.savetxt('lfw_ssim_%s'%iter_,percept_dis,fmt=\"%4f\")\n", 1013 | " #np.savetxt('lfw_mse_%s'%iter_,recover_dis,fmt=\"%4f\")\n", 1014 | " #plt.savefig(\"./attack_image/index_%s_iter_%s_label_%s\"%(img_index,iter_,torch.argmax(dummy_label, dim=-1).item()))\n", 1015 | " \n", 1016 | " plt.clf()\n", 1017 | " duration = time.clock()-start\n", 1018 | " #print (\"Running time is %.4f.\" %(duration/10.0) )\n", 1019 | " print (duration)\n", 1020 | " \n", 1021 | " \n", 1022 | " #if epoch>=1:\n", 1023 | " #if i==1:\n", 1024 | " #break\n", 1025 | " #print (iter_)\n", 1026 | " inputs,label = data\n", 1027 | "\n", 1028 | " inputs,label = Variable(inputs),Variable(label) \n", 1029 | "\n", 1030 | " optimizer_train.zero_grad()\n", 1031 | "\n", 1032 | "\n", 1033 | " outputs_benign=net(inputs)\n", 1034 | " #outputs_benign = F.softmax(outputs_benign, dim=-1)\n", 1035 | " #print (outputs_benign[0])\n", 1036 | "\n", 1037 | "\n", 1038 | " loss_benign = criterion_train(outputs_benign,label)\n", 1039 | "\n", 1040 | " #print(\"loss computed\")\n", 1041 | " loss_benign.backward()\n", 1042 | " #print(\"loss BP\")\n", 1043 | " optimizer_train.step()\n", 1044 | "\n", 1045 | " #if i%2000==0:\n", 1046 | " print (loss_benign.item())\n", 1047 | " #torch.save(net.state_dict(),'./LFW_net.pth') \n", 1048 | "\n", 1049 | " \n", 1050 | " print ('finished training')\n", 1051 | " total = len(y_test)\n", 1052 | " acc =0.0\n", 1053 | " for ct in range(total):\n", 1054 | " testing_data = tt(testing[ct][0].cpu())\n", 1055 | " testing_data1 = tp(testing_data).to(device)\n", 1056 | " testing_data2 = testing_data1.view(1, *testing_data1.size())\n", 1057 | " y_pred = net(testing_data2)\n", 1058 | " predicted = torch.argmax(y_pred)\n", 1059 | "\n", 1060 | " if predicted == y_test[ct]:\n", 1061 | " acc=acc+1\n", 1062 | " accuracy = acc / total\n", 1063 | " print (accuracy)\n", 1064 | " print ('finished testing')\n", 1065 | "\n", 1066 | "\n", 1067 | " " 1068 | ] 1069 | }, 1070 | { 1071 | "cell_type": "code", 1072 | "execution_count": null, 1073 | "metadata": { 1074 | "colab": { 1075 | "base_uri": "https://localhost:8080/", 1076 | "height": 428 1077 | }, 1078 | "colab_type": "code", 1079 | "id": "aokP-jhal96-", 1080 | "outputId": "595e775a-7f91-49a8-cfaa-384c7a320002" 1081 | }, 1082 | "outputs": [], 1083 | "source": [ 1084 | "plt.figure(figsize=(12, 8))\n", 1085 | "for i in range(60):\n", 1086 | " plt.subplot(6, 10, i + 1)\n", 1087 | " plt.imshow(history[i * 5])\n", 1088 | " plt.title(\"iter=%d\" % (i * 5))\n", 1089 | " plt.axis('off')\n", 1090 | "print(\"Dummy label is %d.\" % torch.argmax(dummy_label, dim=-1).item())" 1091 | ] 1092 | }, 1093 | { 1094 | "cell_type": "code", 1095 | "execution_count": null, 1096 | "metadata": {}, 1097 | "outputs": [], 1098 | "source": [ 1099 | "plt.figure(figsize=(12, 8))\n", 1100 | "for j in range(batch):\n", 1101 | " for i in range(60):\n", 1102 | " plt.subplot(6, 10, i + 1)\n", 1103 | " plt.imshow(history_batch[i * 
5+j])\n", 1104 | " plt.title(\"iter=%d\" % (i * 5+ j))\n", 1105 | " plt.axis('off')\n", 1106 | "print(\"Dummy label is %d.\" % torch.argmax(dummy_label, dim=-1).item())" 1107 | ] 1108 | }, 1109 | { 1110 | "cell_type": "code", 1111 | "execution_count": null, 1112 | "metadata": {}, 1113 | "outputs": [], 1114 | "source": [] 1115 | } 1116 | ], 1117 | "metadata": { 1118 | "accelerator": "GPU", 1119 | "colab": { 1120 | "collapsed_sections": [], 1121 | "name": "Deep Leakage from Gradients.ipynb", 1122 | "provenance": [] 1123 | }, 1124 | "kernelspec": { 1125 | "display_name": "Python 3", 1126 | "language": "python", 1127 | "name": "python3" 1128 | }, 1129 | "language_info": { 1130 | "codemirror_mode": { 1131 | "name": "ipython", 1132 | "version": 3 1133 | }, 1134 | "file_extension": ".py", 1135 | "mimetype": "text/x-python", 1136 | "name": "python", 1137 | "nbconvert_exporter": "python", 1138 | "pygments_lexer": "ipython3", 1139 | "version": "3.7.4" 1140 | } 1141 | }, 1142 | "nbformat": 4, 1143 | "nbformat_minor": 1 1144 | } 1145 | -------------------------------------------------------------------------------- /GradInverting/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Jonas Geiping, Hartmut Bauermeister, Hannah Dröge 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /GradInverting/README.md: -------------------------------------------------------------------------------- 1 | # Inverting Gradients - How easy is it to break Privacy in Federated Learning? 2 | 3 | This repository is an implementation of the reconstruction algorithm discussed in 4 | ``` 5 | Jonas Geiping, Hartmut Bauermeister, Hannah Dröge, and Michael Moeller. 6 | Inverting Gradients -- How Easy Is It to Break Privacy in Federated Learning?, 7 | March 31, 2020. 8 | https://arxiv.org/abs/2003.14053v1. 9 | 10 | ``` 11 | which can be found at https://arxiv.org/abs/2003.14053 12 | 13 | Input Image | Reconstruction from gradient information 14 | :-------------------------:|:-------------------------: 15 | ![](11794_ResNet18_ImageNet_input.png) | ![](11794_ResNet18_ImageNet_output.png) 16 | 17 | [Model: standard ResNet18, trained on ImageNet data. The image is from the validation set.] 18 | 19 | ### Abstract: 20 | The idea of federated learning is to collaboratively train a neural network on a server. 
Each user receives the current weights of the network and in turn sends parameter updates (gradients) based on local data. This protocol has been designed not only to train neural networks data-efficiently, but also to provide privacy benefits for users, as their input data remains on device and only parameter gradients are shared. In this paper we show that sharing parameter gradients is by no means secure: By exploiting a cosine similarity loss along with optimization methods from adversarial attacks, we are able to faithfully reconstruct images at high resolution from the knowledge of their parameter gradients, and demonstrate that such a break of privacy is possible even for trained deep networks. Moreover, we analyze the effects of architecture as well as parameters on the difficulty of reconstructing the input image, prove that any input to a fully connected layer can be reconstructed analytically independent of the remaining architecture, and show numerically that even averaging gradients over several iterations or several images does not protect the user’s privacy in federated learning applications in computer vision. 21 | 22 | ## Code 23 | The central file that contains the reconstruction algorithm can be found at ```inversefed/reconstruction_algorithms.py```. The other folders and files are used to define and train the various models and are not central for recovery. 24 | 25 | ### Setup: 26 | Requirements: 27 | ``` 28 | pytorch=1.4.0 29 | torchvision=0.5.0 30 | ``` 31 | You can use [anaconda](https://www.anaconda.com/distribution/) to install our setup by running 32 | ``` 33 | conda env create -f environment.yml 34 | conda activate iv 35 | ``` 36 | To run ImageNet experiments, you need to download ImageNet and provide its location [or use your own images and skip the ```inversefed.construct_dataloaders``` steps]. 37 | 38 | 39 | ### Quick Start 40 | Usage examples can be found in the notebooks, for example the [ResNet-152, ImageNet](ResNet152%20-%20trained%20on%20ImageNet.ipynb) example. 41 | Given an input gradient (as computed by e.g. 
```torch.autograd.grad```), a ```config``` dictionary, a model ```model``` and dataset mean and std, ```(dm, ds)```, build the reconstruction operator 42 | ``` 43 | rec_machine = inversefed.GradientReconstructor(model, (dm, ds), config, num_images=1) 44 | ``` 45 | and then start the reconstruction, specifying a target image size: 46 | ``` 47 | output, stats = rec_machine.reconstruct(input_gradient, None, img_shape=(3, 32, 32)) 48 | ``` 49 | 50 | 51 | 52 | ### CLI Usage example: 53 | The code can also be used via cmd-line in the following way: 54 | ``` 55 | python reconstruct_image.py --model ResNet20-4 --dataset CIFAR10 --trained_model --cost_fn sim --indices def --restarts 32 --save_image --target_id -1 56 | ``` 57 | -------------------------------------------------------------------------------- /GradInverting/environment.yml: -------------------------------------------------------------------------------- 1 | name: iv 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - attrs=19.3.0=py_0 8 | - backcall=0.1.0=py38_0 9 | - blas=1.0=mkl 10 | - bleach=3.1.0=py_0 11 | - ca-certificates=2020.1.1=0 12 | - certifi=2019.11.28=py38_0 13 | - cudatoolkit=10.1.243=h6bb024c_0 14 | - cycler=0.10.0=py38_0 15 | - dbus=1.13.12=h746ee38_0 16 | - decorator=4.4.1=py_0 17 | - defusedxml=0.6.0=py_0 18 | - entrypoints=0.3=py38_0 19 | - expat=2.2.6=he6710b0_0 20 | - fontconfig=2.13.0=h9420a91_0 21 | - freetype=2.9.1=h8a8886c_1 22 | - glib=2.63.1=h5a9c865_0 23 | - gmp=6.1.2=h6c8ec71_1 24 | - gst-plugins-base=1.14.0=hbbd80ab_1 25 | - gstreamer=1.14.0=hb453b48_1 26 | - icu=58.2=h9c2bf20_1 27 | - importlib_metadata=1.5.0=py38_0 28 | - intel-openmp=2020.0=166 29 | - ipykernel=5.1.4=py38h39e3cac_0 30 | - ipython=7.12.0=py38h5ca1d4c_0 31 | - ipython_genutils=0.2.0=py38_0 32 | - ipywidgets=7.5.1=py_0 33 | - jedi=0.16.0=py38_0 34 | - jinja2=2.11.1=py_0 35 | - jpeg=9b=h024ee3a_2 36 | - jsonschema=3.2.0=py38_0 37 | - jupyter=1.0.0=py38_7 38 | - jupyter_client=5.3.4=py38_0 39 | - jupyter_console=6.1.0=py_0 40 | - jupyter_core=4.6.1=py38_0 41 | - kiwisolver=1.0.1=py38he6710b0_0 42 | - ld_impl_linux-64=2.33.1=h53a641e_7 43 | - libedit=3.1.20181209=hc058e9b_0 44 | - libffi=3.2.1=hd88cf55_4 45 | - libgcc-ng=9.1.0=hdf63c60_0 46 | - libgfortran-ng=7.3.0=hdf63c60_0 47 | - libpng=1.6.37=hbc83047_0 48 | - libsodium=1.0.16=h1bed415_0 49 | - libstdcxx-ng=9.1.0=hdf63c60_0 50 | - libtiff=4.1.0=h2733197_0 51 | - libuuid=1.0.3=h1bed415_2 52 | - libxcb=1.13=h1bed415_1 53 | - libxml2=2.9.9=hea5a465_1 54 | - markupsafe=1.1.1=py38h7b6447c_0 55 | - matplotlib=3.1.3=py38_0 56 | - matplotlib-base=3.1.3=py38hef1b27d_0 57 | - mistune=0.8.4=py38h7b6447c_1000 58 | - mkl=2020.0=166 59 | - mkl-service=2.3.0=py38he904b0f_0 60 | - mkl_fft=1.0.15=py38ha843d7b_0 61 | - mkl_random=1.1.0=py38h962f231_0 62 | - nbconvert=5.6.1=py38_0 63 | - nbformat=5.0.4=py_0 64 | - ncurses=6.1=he6710b0_1 65 | - ninja=1.9.0=py38hfd86e86_0 66 | - notebook=6.0.3=py38_0 67 | - numpy=1.18.1=py38h4f9e942_0 68 | - numpy-base=1.18.1=py38hde5b4d6_1 69 | - olefile=0.46=py_0 70 | - openssl=1.1.1d=h7b6447c_4 71 | - pandoc=2.2.3.2=0 72 | - pandocfilters=1.4.2=py38_1 73 | - parso=0.6.1=py_0 74 | - pcre=8.43=he6710b0_0 75 | - pexpect=4.8.0=py38_0 76 | - pickleshare=0.7.5=py38_1000 77 | - pillow=7.0.0=py38hb39fc2d_0 78 | - pip=20.0.2=py38_1 79 | - prometheus_client=0.7.1=py_0 80 | - prompt_toolkit=3.0.3=py_0 81 | - ptyprocess=0.6.0=py38_0 82 | - pygments=2.5.2=py_0 83 | - pyparsing=2.4.6=py_0 84 | - pyqt=5.9.2=py38h05f1152_4 85 | - 
pyrsistent=0.15.7=py38h7b6447c_0 86 | - python=3.8.1=h0371630_1 87 | - python-dateutil=2.8.1=py_0 88 | - pytorch=1.4.0=py3.8_cuda10.1.243_cudnn7.6.3_0 89 | - pyzmq=18.1.1=py38he6710b0_0 90 | - qt=5.9.7=h5867ecd_1 91 | - qtconsole=4.6.0=py_1 92 | - readline=7.0=h7b6447c_5 93 | - scipy=1.4.1=py38h0b6359f_0 94 | - send2trash=1.5.0=py38_0 95 | - setuptools=45.2.0=py38_0 96 | - sip=4.19.13=py38he6710b0_0 97 | - six=1.14.0=py38_0 98 | - sqlite=3.31.1=h7b6447c_0 99 | - terminado=0.8.3=py38_0 100 | - testpath=0.4.4=py_0 101 | - tk=8.6.8=hbc83047_0 102 | - torchvision=0.5.0=py38_cu101 103 | - tornado=6.0.3=py38h7b6447c_3 104 | - traitlets=4.3.3=py38_0 105 | - wcwidth=0.1.8=py_0 106 | - webencodings=0.5.1=py38_1 107 | - wheel=0.34.2=py38_0 108 | - widgetsnbextension=3.5.1=py38_0 109 | - xz=5.2.4=h14c3975_4 110 | - zeromq=4.3.1=he6710b0_3 111 | - zipp=2.2.0=py_0 112 | - zlib=1.2.11=h7b6447c_3 113 | - zstd=1.3.7=h0b5b093_0 114 | 115 | -------------------------------------------------------------------------------- /GradInverting/inversefed/__init__.py: -------------------------------------------------------------------------------- 1 | """Library of routines.""" 2 | 3 | from inversefed import nn 4 | from inversefed.nn import construct_model, MetaMonkey 5 | 6 | from inversefed.data import construct_dataloaders 7 | from inversefed.training import train 8 | from inversefed import utils 9 | 10 | from .optimization_strategy import training_strategy 11 | 12 | 13 | from .reconstruction_algorithms import GradientReconstructor, FedAvgReconstructor 14 | 15 | from .options import options 16 | from inversefed import metrics 17 | 18 | __all__ = ['train', 'construct_dataloaders', 'construct_model', 'MetaMonkey', 19 | 'training_strategy', 'nn', 'utils', 'options', 20 | 'metrics', 'GradientReconstructor', 'FedAvgReconstructor'] 21 | -------------------------------------------------------------------------------- /GradInverting/inversefed/consts.py: -------------------------------------------------------------------------------- 1 | """Setup constants, ymmv.""" 2 | 3 | PIN_MEMORY = True 4 | NON_BLOCKING = False 5 | BENCHMARK = True 6 | MULTITHREAD_DATAPROCESSING = 4 7 | 8 | 9 | cifar10_mean = [0.4914672374725342, 0.4822617471218109, 0.4467701315879822] 10 | cifar10_std = [0.24703224003314972, 0.24348513782024384, 0.26158785820007324] 11 | cifar100_mean = [0.5071598291397095, 0.4866936206817627, 0.44120192527770996] 12 | cifar100_std = [0.2673342823982239, 0.2564384639263153, 0.2761504650115967] 13 | mnist_mean = (0.13066373765468597,) 14 | mnist_std = (0.30810782313346863,) 15 | imagenet_mean = [0.485, 0.456, 0.406] 16 | imagenet_std = [0.229, 0.224, 0.225] 17 | -------------------------------------------------------------------------------- /GradInverting/inversefed/data/README.md: -------------------------------------------------------------------------------- 1 | # Data Processing 2 | 3 | This module implements ```construct_dataloaders```. 
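
A minimal usage sketch follows. Hedged: only the signature ```construct_dataloaders(dataset, defs, data_path='~/data', shuffle=True, normalize=True)``` is visible in ```data_processing.py```; the strategy name ```'conservative'``` and the unpacking order below are assumptions based on the repository's notebooks, not shown in this excerpt.

```python
import inversefed

# Training hyperparameter defaults; 'conservative' is an assumed strategy name.
defs = inversefed.training_strategy('conservative')

# Assumed return order: a task loss plus train/validation loaders.
loss_fn, trainloader, validloader = inversefed.construct_dataloaders(
    'CIFAR10', defs, data_path='~/data')
```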
-------------------------------------------------------------------------------- /GradInverting/inversefed/data/__init__.py: -------------------------------------------------------------------------------- 1 | """Data stuff that I usually don't want to see.""" 2 | 3 | from .data_processing import construct_dataloaders 4 | 5 | 6 | __all__ = ['construct_dataloaders'] 7 | -------------------------------------------------------------------------------- /GradInverting/inversefed/data/data.py: -------------------------------------------------------------------------------- 1 | """This is data.py from pytorch-examples. 2 | 3 | Refer to 4 | https://github.com/pytorch/examples/blob/master/super_resolution/data.py. 5 | """ 6 | 7 | from os.path import exists, join, basename 8 | from os import makedirs, remove 9 | from six.moves import urllib 10 | import tarfile 11 | from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize, RandomCrop 12 | 13 | 14 | from .datasets import DatasetFromFolder 15 | 16 | def _build_bsds_sr(data_path, augmentations=True, normalize=True, upscale_factor=3, RGB=True): 17 | root_dir = _download_bsd300(dest=data_path) 18 | train_dir = join(root_dir, "train") 19 | crop_size = _calculate_valid_crop_size(256, upscale_factor) 20 | print(f'Crop size is {crop_size}. Upscaling factor is {upscale_factor} in mode {RGB}.') 21 | 22 | trainset = DatasetFromFolder(train_dir, replicate=200, 23 | input_transform=_input_transform(crop_size, upscale_factor), 24 | target_transform=_target_transform(crop_size), RGB=RGB) 25 | 26 | test_dir = join(root_dir, "test") 27 | validset = DatasetFromFolder(test_dir, replicate=200, 28 | input_transform=_input_transform(crop_size, upscale_factor), 29 | target_transform=_target_transform(crop_size), RGB=RGB) 30 | return trainset, validset 31 | 32 | def _build_bsds_dn(data_path, augmentations=True, normalize=True, upscale_factor=1, noise_level=25 / 255, RGB=True): 33 | root_dir = _download_bsd300(dest=data_path) 34 | train_dir = join(root_dir, "train") 35 | 36 | crop_size = _calculate_valid_crop_size(256, upscale_factor) 37 | patch_size = 64 38 | print(f'Crop size is {crop_size} for patches of size {patch_size}. 
' 39 | f'Upscaling factor is {upscale_factor} in mode RGB={RGB}.') 40 | 41 | trainset = DatasetFromFolder(train_dir, replicate=200, 42 | input_transform=_input_transform(crop_size, upscale_factor, patch_size=patch_size), 43 | target_transform=_target_transform(crop_size, patch_size=patch_size), 44 | noise_level=noise_level, RGB=RGB) 45 | 46 | test_dir = join(root_dir, "test") 47 | validset = DatasetFromFolder(test_dir, replicate=200, 48 | input_transform=_input_transform(crop_size, upscale_factor), 49 | target_transform=_target_transform(crop_size), 50 | noise_level=noise_level, RGB=RGB) 51 | return trainset, validset 52 | 53 | 54 | def _download_bsd300(dest="dataset"): 55 | output_image_dir = join(dest, "BSDS300/images") 56 | 57 | if not exists(output_image_dir): 58 | makedirs(dest, exist_ok=True) 59 | url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz" 60 | print("downloading url ", url) 61 | 62 | data = urllib.request.urlopen(url) 63 | 64 | file_path = join(dest, basename(url)) 65 | with open(file_path, 'wb') as f: 66 | f.write(data.read()) 67 | 68 | print("Extracting data") 69 | with tarfile.open(file_path) as tar: 70 | for item in tar: 71 | tar.extract(item, dest) 72 | 73 | remove(file_path) 74 | 75 | return output_image_dir 76 | 77 | 78 | def _calculate_valid_crop_size(crop_size, upscale_factor): 79 | return crop_size - (crop_size % upscale_factor) 80 | 81 | 82 | def _input_transform(crop_size, upscale_factor, patch_size=None): 83 | return Compose([ 84 | CenterCrop(crop_size), 85 | Resize(crop_size // upscale_factor), 86 | RandomCrop(patch_size if patch_size is not None else crop_size // upscale_factor), 87 | ToTensor(), 88 | ]) 89 | 90 | 91 | def _target_transform(crop_size, patch_size=None): 92 | return Compose([ 93 | CenterCrop(crop_size), 94 | RandomCrop(patch_size if patch_size is not None else crop_size), 95 | ToTensor(), 96 | ]) 97 | -------------------------------------------------------------------------------- /GradInverting/inversefed/data/data_processing.py: -------------------------------------------------------------------------------- 1 | """Repeatable code parts concerning data loading.""" 2 | 3 | 4 | import torch 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | 8 | import os 9 | 10 | from ..consts import * 11 | 12 | from .data import _build_bsds_sr, _build_bsds_dn 13 | from .loss import Classification, PSNR 14 | 15 | 16 | def construct_dataloaders(dataset, defs, data_path='~/data', shuffle=True, normalize=True): 17 | """Return a dataloader with given dataset and augmentation, normalize data?.""" 18 | path = os.path.expanduser(data_path) 19 | 20 | if dataset == 'CIFAR10': 21 | trainset, validset = _build_cifar10(path, defs.augmentations, normalize) 22 | loss_fn = Classification() 23 | elif dataset == 'CIFAR100': 24 | trainset, validset = _build_cifar100(path, defs.augmentations, normalize) 25 | loss_fn = Classification() 26 | elif dataset == 'MNIST': 27 | trainset, validset = _build_mnist(path, defs.augmentations, normalize) 28 | loss_fn = Classification() 29 | elif dataset == 'MNIST_GRAY': 30 | trainset, validset = _build_mnist_gray(path, defs.augmentations, normalize) 31 | loss_fn = Classification() 32 | elif dataset == 'ImageNet': 33 | trainset, validset = _build_imagenet(path, defs.augmentations, normalize) 34 | loss_fn = Classification() 35 | elif dataset == 'BSDS-SR': 36 | trainset, validset = _build_bsds_sr(path, defs.augmentations, normalize, upscale_factor=3, RGB=True) 37 | loss_fn = PSNR() 38 | 
elif dataset == 'BSDS-DN': 39 | trainset, validset = _build_bsds_dn(path, defs.augmentations, normalize, noise_level=25 / 255, RGB=False) 40 | loss_fn = PSNR() 41 | elif dataset == 'BSDS-RGB': 42 | trainset, validset = _build_bsds_dn(path, defs.augmentations, normalize, noise_level=25 / 255, RGB=True) 43 | loss_fn = PSNR() 44 | 45 | if MULTITHREAD_DATAPROCESSING: 46 | num_workers = min(torch.get_num_threads(), MULTITHREAD_DATAPROCESSING) if torch.get_num_threads() > 1 else 0 47 | else: 48 | num_workers = 0 49 | 50 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=min(defs.batch_size, len(trainset)), 51 | shuffle=shuffle, drop_last=True, num_workers=num_workers, pin_memory=PIN_MEMORY) 52 | validloader = torch.utils.data.DataLoader(validset, batch_size=min(defs.batch_size, len(trainset)), 53 | shuffle=False, drop_last=False, num_workers=num_workers, pin_memory=PIN_MEMORY) 54 | 55 | return loss_fn, trainloader, validloader 56 | 57 | 58 | def _build_cifar10(data_path, augmentations=True, normalize=True): 59 | """Define CIFAR-10 with everything considered.""" 60 | # Load data 61 | trainset = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transforms.ToTensor()) 62 | validset = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=transforms.ToTensor()) 63 | 64 | if cifar10_mean is None: 65 | data_mean, data_std = _get_meanstd(trainset) 66 | else: 67 | data_mean, data_std = cifar10_mean, cifar10_std 68 | 69 | # Organize preprocessing 70 | transform = transforms.Compose([ 71 | transforms.ToTensor(), 72 | transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)]) 73 | if augmentations: 74 | transform_train = transforms.Compose([ 75 | transforms.RandomCrop(32, padding=4), 76 | transforms.RandomHorizontalFlip(), 77 | transform]) 78 | trainset.transform = transform_train 79 | else: 80 | trainset.transform = transform 81 | validset.transform = transform 82 | 83 | return trainset, validset 84 | 85 | def _build_cifar100(data_path, augmentations=True, normalize=True): 86 | """Define CIFAR-100 with everything considered.""" 87 | # Load data 88 | trainset = torchvision.datasets.CIFAR100(root=data_path, train=True, download=True, transform=transforms.ToTensor()) 89 | validset = torchvision.datasets.CIFAR100(root=data_path, train=False, download=True, transform=transforms.ToTensor()) 90 | 91 | if cifar100_mean is None: 92 | data_mean, data_std = _get_meanstd(trainset) 93 | else: 94 | data_mean, data_std = cifar100_mean, cifar100_std 95 | 96 | # Organize preprocessing 97 | transform = transforms.Compose([ 98 | transforms.ToTensor(), 99 | transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)]) 100 | if augmentations: 101 | transform_train = transforms.Compose([ 102 | transforms.RandomCrop(32, padding=4), 103 | transforms.RandomHorizontalFlip(), 104 | transform]) 105 | trainset.transform = transform_train 106 | else: 107 | trainset.transform = transform 108 | validset.transform = transform 109 | 110 | return trainset, validset 111 | 112 | 113 | def _build_mnist(data_path, augmentations=True, normalize=True): 114 | """Define MNIST with everything considered.""" 115 | # Load data 116 | trainset = torchvision.datasets.MNIST(root=data_path, train=True, download=True, transform=transforms.ToTensor()) 117 | validset = torchvision.datasets.MNIST(root=data_path, train=False, download=True, transform=transforms.ToTensor()) 118 | 119 | if mnist_mean is None: 120 | cc = 
torch.cat([trainset[i][0].reshape(-1) for i in range(len(trainset))], dim=0) 121 | data_mean = (torch.mean(cc, dim=0).item(),) 122 | data_std = (torch.std(cc, dim=0).item(),) 123 | else: 124 | data_mean, data_std = mnist_mean, mnist_std 125 | 126 | # Organize preprocessing 127 | transform = transforms.Compose([ 128 | transforms.ToTensor(), 129 | transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)]) 130 | if augmentations: 131 | transform_train = transforms.Compose([ 132 | transforms.RandomCrop(28, padding=4), 133 | transforms.RandomHorizontalFlip(), 134 | transform]) 135 | trainset.transform = transform_train 136 | else: 137 | trainset.transform = transform 138 | validset.transform = transform 139 | 140 | return trainset, validset 141 | 142 | def _build_mnist_gray(data_path, augmentations=True, normalize=True): 143 | """Define MNIST with everything considered.""" 144 | # Load data 145 | trainset = torchvision.datasets.MNIST(root=data_path, train=True, download=True, transform=transforms.ToTensor()) 146 | validset = torchvision.datasets.MNIST(root=data_path, train=False, download=True, transform=transforms.ToTensor()) 147 | 148 | if mnist_mean is None: 149 | cc = torch.cat([trainset[i][0].reshape(-1) for i in range(len(trainset))], dim=0) 150 | data_mean = (torch.mean(cc, dim=0).item(),) 151 | data_std = (torch.std(cc, dim=0).item(),) 152 | else: 153 | data_mean, data_std = mnist_mean, mnist_std 154 | 155 | # Organize preprocessing 156 | transform = transforms.Compose([ 157 | transforms.Grayscale(num_output_channels=1), 158 | transforms.ToTensor(), 159 | transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x: x)]) 160 | if augmentations: 161 | transform_train = transforms.Compose([ 162 | transforms.Grayscale(num_output_channels=1), 163 | transforms.RandomCrop(28, padding=4), 164 | transforms.RandomHorizontalFlip(), 165 | transform]) 166 | trainset.transform = transform_train 167 | else: 168 | trainset.transform = transform 169 | validset.transform = transform 170 | 171 | return trainset, validset 172 | 173 | 174 | def _build_imagenet(data_path, augmentations=True, normalize=True): 175 | """Define ImageNet with everything considered.""" 176 | # Load data 177 | trainset = torchvision.datasets.ImageNet(root=data_path, split='train', transform=transforms.ToTensor()) 178 | validset = torchvision.datasets.ImageNet(root=data_path, split='val', transform=transforms.ToTensor()) 179 | 180 | if imagenet_mean is None: 181 | data_mean, data_std = _get_meanstd(trainset) 182 | else: 183 | data_mean, data_std = imagenet_mean, imagenet_std 184 | 185 | # Organize preprocessing 186 | transform = transforms.Compose([ 187 | transforms.Resize(256), 188 | transforms.CenterCrop(224), 189 | transforms.ToTensor(), 190 | transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x : x)]) 191 | if augmentations: 192 | transform_train = transforms.Compose([ 193 | transforms.RandomResizedCrop(224), 194 | transforms.RandomHorizontalFlip(), 195 | transforms.ToTensor(), 196 | transforms.Normalize(data_mean, data_std) if normalize else transforms.Lambda(lambda x : x)]) 197 | trainset.transform = transform_train 198 | else: 199 | trainset.transform = transform 200 | validset.transform = transform 201 | 202 | return trainset, validset 203 | 204 | 205 | def _get_meanstd(dataset): 206 | cc = torch.cat([dataset[i][0].reshape(3, -1) for i in range(len(dataset))], dim=1) 207 | data_mean = torch.mean(cc, dim=1).tolist() 208 | data_std
= torch.std(cc, dim=1).tolist() 209 | return data_mean, data_std 210 | -------------------------------------------------------------------------------- /GradInverting/inversefed/data/datasets.py: -------------------------------------------------------------------------------- 1 | """This is dataset.py from pytorch-examples. 2 | 3 | Refer to 4 | 5 | https://github.com/pytorch/examples/blob/master/super_resolution/dataset.py. 6 | """ 7 | import torch 8 | import torch.utils.data as data 9 | 10 | from os import listdir 11 | from os.path import join 12 | from PIL import Image 13 | 14 | 15 | def _is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) 17 | 18 | 19 | def _load_img(filepath, RGB=True): 20 | img = Image.open(filepath) 21 | if RGB: 22 | pass 23 | else: 24 | img = img.convert('YCbCr') 25 | img, _, _ = img.split() 26 | return img 27 | 28 | 29 | class DatasetFromFolder(data.Dataset): 30 | """Generate an image-to-image dataset from images from the given folder.""" 31 | 32 | def __init__(self, image_dir, replicate=1, input_transform=None, target_transform=None, RGB=True, noise_level=0.0): 33 | """Init with directory, transforms and RGB switch.""" 34 | super(DatasetFromFolder, self).__init__() 35 | self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if _is_image_file(x)] 36 | 37 | self.input_transform = input_transform 38 | self.target_transform = target_transform 39 | 40 | self.replicate = replicate 41 | self.classes = [None] 42 | self.RGB = RGB 43 | self.noise_level = noise_level 44 | 45 | def __getitem__(self, index): 46 | """Index into dataset.""" 47 | input = _load_img(self.image_filenames[index % len(self.image_filenames)], RGB=self.RGB) 48 | target = input.copy() 49 | if self.input_transform: 50 | input = self.input_transform(input) 51 | if self.target_transform: 52 | target = self.target_transform(target) 53 | 54 | if self.noise_level > 0: 55 | # Add noise 56 | input += self.noise_level * torch.randn_like(input) 57 | 58 | return input, target 59 | 60 | def __len__(self): 61 | """Length is amount of files found.""" 62 | return len(self.image_filenames) * self.replicate 63 | -------------------------------------------------------------------------------- /GradInverting/inversefed/data/loss.py: -------------------------------------------------------------------------------- 1 | """Define various loss functions and bundle them with appropriate metrics.""" 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | class Loss: 8 | """Abstract class, containing necessary methods. 9 | 10 | Abstract class to collect information about the 'higher-level' loss function, used to train an energy-based model 11 | containing the evaluation of the loss function, its gradients w.r.t. to first and second argument and evaluations 12 | of the actual metric that is targeted. 13 | 14 | """ 15 | 16 | def __init__(self): 17 | """Init.""" 18 | pass 19 | 20 | def __call__(self, reference, argmin): 21 | """Return l(x, y).""" 22 | raise NotImplementedError() 23 | return value, name, format 24 | 25 | def metric(self, reference, argmin): 26 | """The actually sought metric.""" 27 | raise NotImplementedError() 28 | return value, name, format 29 | 30 | 31 | class PSNR(Loss): 32 | """A classical MSE target. 33 | 34 | The minimized criterion is MSE Loss, the actual metric is average PSNR. 
35 | """ 36 | 37 | def __init__(self): 38 | """Init with torch MSE.""" 39 | self.loss_fn = torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean') 40 | 41 | def __call__(self, x=None, y=None): 42 | """Return l(x, y).""" 43 | name = 'MSE' 44 | format = '.6f' 45 | if x is None: 46 | return name, format 47 | else: 48 | value = 0.5 * self.loss_fn(x, y) 49 | return value, name, format 50 | 51 | def metric(self, x=None, y=None): 52 | """The actually sought metric.""" 53 | name = 'avg PSNR' 54 | format = '.3f' 55 | if x is None: 56 | return name, format 57 | else: 58 | value = self.psnr_compute(x, y) 59 | return value, name, format 60 | 61 | @staticmethod 62 | def psnr_compute(img_batch, ref_batch, batched=False, factor=1.0): 63 | """Standard PSNR.""" 64 | def get_psnr(img_in, img_ref): 65 | mse = ((img_in - img_ref)**2).mean() 66 | if mse > 0 and torch.isfinite(mse): 67 | return (10 * torch.log10(factor**2 / mse)).item() 68 | elif not torch.isfinite(mse): 69 | return float('nan') 70 | else: 71 | return float('inf') 72 | 73 | if batched: 74 | psnr = get_psnr(img_batch.detach(), ref_batch) 75 | else: 76 | [B, C, m, n] = img_batch.shape 77 | psnrs = [] 78 | for sample in range(B): 79 | psnrs.append(get_psnr(img_batch.detach()[sample, :, :, :], ref_batch[sample, :, :, :])) 80 | psnr = np.mean(psnrs) 81 | 82 | return psnr 83 | 84 | 85 | class Classification(Loss): 86 | """A classical NLL loss for classification. Evaluation has the softmax baked in. 87 | 88 | The minimized criterion is cross entropy, the actual metric is total accuracy. 89 | """ 90 | 91 | def __init__(self): 92 | """Init with torch MSE.""" 93 | self.loss_fn = torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, 94 | reduce=None, reduction='mean') 95 | 96 | def __call__(self, x=None, y=None): 97 | """Return l(x, y).""" 98 | name = 'CrossEntropy' 99 | format = '1.5f' 100 | if x is None: 101 | return name, format 102 | else: 103 | value = self.loss_fn(x, y) 104 | return value, name, format 105 | 106 | def metric(self, x=None, y=None): 107 | """The actually sought metric.""" 108 | name = 'Accuracy' 109 | format = '6.2%' 110 | if x is None: 111 | return name, format 112 | else: 113 | value = (x.data.argmax(dim=1) == y).sum().float() / y.shape[0] 114 | return value.detach(), name, format 115 | -------------------------------------------------------------------------------- /GradInverting/inversefed/medianfilt.py: -------------------------------------------------------------------------------- 1 | """This is code for median pooling from https://gist.github.com/rwightman. 2 | 3 | https://gist.github.com/rwightman/f2d3849281624be7c0f11c85c87c1598 4 | """ 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn.modules.utils import _pair, _quadruple 8 | 9 | 10 | class MedianPool2d(nn.Module): 11 | """Median pool (usable as median filter when stride=1) module. 
12 | 13 | Args: 14 | kernel_size: size of pooling kernel, int or 2-tuple 15 | stride: pool stride, int or 2-tuple 16 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 17 | same: override padding and enforce same padding, boolean 18 | """ 19 | 20 | def __init__(self, kernel_size=3, stride=1, padding=0, same=True): 21 | """Initialize with kernel_size, stride, padding.""" 22 | super().__init__() 23 | self.k = _pair(kernel_size) 24 | self.stride = _pair(stride) 25 | self.padding = _quadruple(padding) # convert to l, r, t, b 26 | self.same = same 27 | 28 | def _padding(self, x): 29 | if self.same: 30 | ih, iw = x.size()[2:] 31 | if ih % self.stride[0] == 0: 32 | ph = max(self.k[0] - self.stride[0], 0) 33 | else: 34 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 35 | if iw % self.stride[1] == 0: 36 | pw = max(self.k[1] - self.stride[1], 0) 37 | else: 38 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 39 | pl = pw // 2 40 | pr = pw - pl 41 | pt = ph // 2 42 | pb = ph - pt 43 | padding = (pl, pr, pt, pb) 44 | else: 45 | padding = self.padding 46 | return padding 47 | 48 | def forward(self, x): 49 | # using existing pytorch functions and tensor ops so that we get autograd, 50 | # would likely be more efficient to implement from scratch at C/Cuda level 51 | x = F.pad(x, self._padding(x), mode='reflect') 52 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 53 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 54 | return x 55 | -------------------------------------------------------------------------------- /GradInverting/inversefed/metrics.py: -------------------------------------------------------------------------------- 1 | """This is code based on https://sudomake.ai/inception-score-explained/.""" 2 | import torch 3 | import torchvision 4 | 5 | from collections import defaultdict 6 | 7 | class InceptionScore(torch.nn.Module): 8 | """Class that manages and returns the inception score of images.""" 9 | 10 | def __init__(self, batch_size=32, setup=dict(device=torch.device('cpu'), dtype=torch.float)): 11 | """Initialize with setup and target inception batch size.""" 12 | super().__init__() 13 | self.preprocessing = torch.nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False) 14 | self.model = torchvision.models.inception_v3(pretrained=True).to(**setup) 15 | self.model.eval() 16 | self.batch_size = batch_size 17 | 18 | def forward(self, image_batch): 19 | """Image batch should have dimensions BCHW and should be normalized. 20 | 21 | B should be divisible by self.batch_size. 
22 | """ 23 | B, C, H, W = image_batch.shape 24 | batches = B // self.batch_size 25 | scores = [] 26 | for batch in range(batches): 27 | input = self.preprocessing(image_batch[batch * self.batch_size: (batch + 1) * self.batch_size]) 28 | scores.append(self.model(input)) 29 | prob_yx = torch.nn.functional.softmax(torch.cat(scores, 0), dim=1) 30 | entropy = torch.where(prob_yx > 0, -prob_yx * prob_yx.log(), torch.zeros_like(prob_yx)) 31 | return entropy.sum() 32 | 33 | 34 | def psnr(img_batch, ref_batch, batched=False, factor=1.0): 35 | """Standard PSNR.""" 36 | def get_psnr(img_in, img_ref): 37 | mse = ((img_in - img_ref)**2).mean() 38 | if mse > 0 and torch.isfinite(mse): 39 | return (10 * torch.log10(factor**2 / mse)) 40 | elif not torch.isfinite(mse): 41 | return img_batch.new_tensor(float('nan')) 42 | else: 43 | return img_batch.new_tensor(float('inf')) 44 | 45 | if batched: 46 | psnr = get_psnr(img_batch.detach(), ref_batch) 47 | else: 48 | [B, C, m, n] = img_batch.shape 49 | psnrs = [] 50 | for sample in range(B): 51 | psnrs.append(get_psnr(img_batch.detach()[sample, :, :, :], ref_batch[sample, :, :, :])) 52 | psnr = torch.stack(psnrs, dim=0).mean() 53 | 54 | return psnr.item() 55 | 56 | 57 | def total_variation(x): 58 | """Anisotropic TV.""" 59 | dx = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])) 60 | dy = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])) 61 | return dx + dy 62 | 63 | 64 | 65 | def activation_errors(model, x1, x2): 66 | """Compute activation-level error metrics for every module in the network.""" 67 | model.eval() 68 | 69 | device = next(model.parameters()).device 70 | 71 | hooks = [] 72 | data = defaultdict(dict) 73 | inputs = torch.cat((x1, x2), dim=0) 74 | separator = x1.shape[0] 75 | 76 | def check_activations(self, input, output): 77 | module_name = str(*[name for name, mod in model.named_modules() if self is mod]) 78 | try: 79 | layer_inputs = input[0].detach() 80 | residual = (layer_inputs[:separator] - layer_inputs[separator:]).pow(2) 81 | se_error = residual.sum() 82 | mse_error = residual.mean() 83 | sim = torch.nn.functional.cosine_similarity(layer_inputs[:separator].flatten(), 84 | layer_inputs[separator:].flatten(), 85 | dim=0, eps=1e-8).detach() 86 | data['se'][module_name] = se_error.item() 87 | data['mse'][module_name] = mse_error.item() 88 | data['sim'][module_name] = sim.item() 89 | except (KeyboardInterrupt, SystemExit): 90 | raise 91 | except AttributeError: 92 | pass 93 | 94 | for name, module in model.named_modules(): 95 | hooks.append(module.register_forward_hook(check_activations)) 96 | 97 | try: 98 | outputs = model(inputs.to(device)) 99 | for hook in hooks: 100 | hook.remove() 101 | except Exception as e: 102 | for hook in hooks: 103 | hook.remove() 104 | raise 105 | 106 | return data 107 | -------------------------------------------------------------------------------- /GradInverting/inversefed/nn/README.md: -------------------------------------------------------------------------------- 1 | # Models and modules are implemented here -------------------------------------------------------------------------------- /GradInverting/inversefed/nn/__init__.py: -------------------------------------------------------------------------------- 1 | """Experimental modules and unexperimental model hooks.""" 2 | 3 | from .models import construct_model 4 | from .modules import MetaMonkey 5 | 6 | __all__ = ['construct_model', 'MetaMonkey'] 7 | -------------------------------------------------------------------------------- 
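`MetaMonkey`, exported above, is what lets the reconstruction code differentiate a forward pass with respect to an explicit weight dictionary instead of the module's own parameters; the metrics module just shown then scores candidate reconstructions. A minimal sketch (illustration only; assumes CIFAR-sized inputs and that the package is importable as `inversefed`):

```python
import torch
from collections import OrderedDict
import inversefed
from inversefed.metrics import psnr, total_variation

net, _ = inversefed.construct_model('ConvNet', num_classes=10, num_channels=3)
monkey = inversefed.MetaMonkey(net)

# Run the forward pass through detached copies of the weights; gradients then
# flow to the copies, which FedAvg-style reconstruction of local steps relies on.
weights = OrderedDict((name, p.clone().detach().requires_grad_())
                      for name, p in net.named_parameters())
out = monkey(torch.randn(2, 3, 32, 32), weights)
grads = torch.autograd.grad(out.sum(), tuple(weights.values()))

# Scoring a (random stand-in) reconstruction against its reference:
x_rec, x_ref = torch.rand(2, 3, 32, 32), torch.rand(2, 3, 32, 32)
print(psnr(x_rec, x_ref, factor=1.0), total_variation(x_rec).item())
```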
/GradInverting/inversefed/nn/densenet.py: -------------------------------------------------------------------------------- 1 | """DenseNet in PyTorch.""" 2 | """Adaptation we did with ******.""" 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class _Bottleneck(nn.Module): 11 | def __init__(self, in_planes, growth_rate): 12 | super().__init__() 13 | self.bn1 = nn.BatchNorm2d(in_planes) 14 | self.conv1 = nn.Conv2d(in_planes, 4 * growth_rate, kernel_size=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(4 * growth_rate) 16 | self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 17 | 18 | def forward(self, x): 19 | out = self.conv1(F.relu(self.bn1(x))) 20 | out = self.conv2(F.relu(self.bn2(out))) 21 | out = torch.cat([out, x], 1) 22 | return out 23 | 24 | 25 | class _Transition(nn.Module): 26 | def __init__(self, in_planes, out_planes): 27 | super().__init__() 28 | self.bn = nn.BatchNorm2d(in_planes) 29 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 30 | 31 | def forward(self, x): 32 | out = self.conv(F.relu(self.bn(x))) 33 | out = F.avg_pool2d(out, 2) 34 | return out 35 | 36 | 37 | class _DenseNet(nn.Module): 38 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 39 | super().__init__() 40 | self.growth_rate = growth_rate 41 | 42 | num_planes = 2 * growth_rate 43 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 44 | 45 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 46 | num_planes += nblocks[0] * growth_rate 47 | out_planes = int(math.floor(num_planes * reduction)) 48 | self.trans1 = _Transition(num_planes, out_planes) 49 | num_planes = out_planes 50 | 51 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 52 | num_planes += nblocks[1] * growth_rate 53 | out_planes = int(math.floor(num_planes * reduction)) 54 | self.trans2 = _Transition(num_planes, out_planes) 55 | num_planes = out_planes 56 | 57 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 58 | num_planes += nblocks[2] * growth_rate 59 | out_planes = int(math.floor(num_planes * reduction)) 60 | # self.trans3 = Transition(num_planes, out_planes) 61 | # num_planes = out_planes 62 | 63 | # self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 64 | # num_planes += nblocks[3]*growth_rate 65 | 66 | self.bn = nn.BatchNorm2d(num_planes) 67 | num_planes = 132 * growth_rate // 12 * 2 * 2 68 | self.linear = nn.Linear(num_planes, num_classes) 69 | 70 | def _make_dense_layers(self, block, in_planes, nblock): 71 | layers = [] 72 | for i in range(nblock): 73 | layers.append(block(in_planes, self.growth_rate)) 74 | in_planes += self.growth_rate 75 | return nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | out = self.conv1(x) 79 | out = self.trans1(self.dense1(out)) 80 | out = self.trans2(self.dense2(out)) 81 | out = self.dense3(out) 82 | # out = self.trans3(self.dense3(out)) 83 | # out = self.dense4(out) 84 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 85 | out = out.view(out.size(0), -1) 86 | out = self.linear(out) 87 | return out 88 | 89 | 90 | def densenet_cifar(num_classes=10): 91 | """Instantiate the smallest DenseNet.""" 92 | return _DenseNet(_Bottleneck, [6, 6, 6, 0], growth_rate=12, num_classes=num_classes) 93 | -------------------------------------------------------------------------------- /GradInverting/inversefed/nn/models.py: 
-------------------------------------------------------------------------------- 1 | """Define basic models and translate some torchvision stuff.""" 2 | """Stuff from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py.""" 3 | import torch 4 | import torchvision 5 | import torch.nn as nn 6 | 7 | from torchvision.models.resnet import BasicBlock, Bottleneck 8 | from .revnet import iRevNet 9 | from .densenet import _DenseNet, _Bottleneck 10 | 11 | from collections import OrderedDict 12 | import numpy as np 13 | from ..utils import set_random_seed 14 | 15 | 16 | 17 | 18 | def construct_model(model, num_classes=10, seed=None, num_channels=3, modelkey=None): 19 | """Return various models.""" 20 | if modelkey is None: 21 | if seed is None: 22 | model_init_seed = np.random.randint(0, 2**32 - 10) 23 | else: 24 | model_init_seed = seed 25 | else: 26 | model_init_seed = modelkey 27 | set_random_seed(model_init_seed) 28 | 29 | if model in ['ConvNet', 'ConvNet64']: 30 | model = ConvNet(width=64, num_channels=num_channels, num_classes=num_classes) 31 | elif model == 'ConvNet8': 32 | model = ConvNet(width=64, num_channels=num_channels, num_classes=num_classes) 33 | elif model == 'ConvNet16': 34 | model = ConvNet(width=64, num_channels=num_channels, num_classes=num_classes) 35 | elif model == 'ConvNet32': 36 | model = ConvNet(width=64, num_channels=num_channels, num_classes=num_classes) 37 | elif model == 'BeyondInferringMNIST': 38 | model = torch.nn.Sequential(OrderedDict([ 39 | ('conv1', torch.nn.Conv2d(1, 32, 3, stride=2, padding=1)), 40 | ('relu0', torch.nn.LeakyReLU()), 41 | ('conv2', torch.nn.Conv2d(32, 64, 3, stride=1, padding=1)), 42 | ('relu1', torch.nn.LeakyReLU()), 43 | ('conv3', torch.nn.Conv2d(64, 128, 3, stride=2, padding=1)), 44 | ('relu2', torch.nn.LeakyReLU()), 45 | ('conv4', torch.nn.Conv2d(128, 256, 3, stride=1, padding=1)), 46 | ('relu3', torch.nn.LeakyReLU()), 47 | ('flatt', torch.nn.Flatten()), 48 | ('linear0', torch.nn.Linear(12544, 12544)), 49 | ('relu4', torch.nn.LeakyReLU()), 50 | ('linear1', torch.nn.Linear(12544, 10)), 51 | ('softmax', torch.nn.Softmax(dim=1)) 52 | ])) 53 | elif model == 'BeyondInferringCifar': 54 | model = torch.nn.Sequential(OrderedDict([ 55 | ('conv1', torch.nn.Conv2d(3, 32, 3, stride=2, padding=1)), 56 | ('relu0', torch.nn.LeakyReLU()), 57 | ('conv2', torch.nn.Conv2d(32, 64, 3, stride=1, padding=1)), 58 | ('relu1', torch.nn.LeakyReLU()), 59 | ('conv3', torch.nn.Conv2d(64, 128, 3, stride=2, padding=1)), 60 | ('relu2', torch.nn.LeakyReLU()), 61 | ('conv4', torch.nn.Conv2d(128, 256, 3, stride=1, padding=1)), 62 | ('relu3', torch.nn.LeakyReLU()), 63 | ('flatt', torch.nn.Flatten()), 64 | ('linear0', torch.nn.Linear(12544, 12544)), 65 | ('relu4', torch.nn.LeakyReLU()), 66 | ('linear1', torch.nn.Linear(12544, 10)), 67 | ('softmax', torch.nn.Softmax(dim=1)) 68 | ])) 69 | elif model == 'MLP': 70 | width = 1024 71 | model = torch.nn.Sequential(OrderedDict([ 72 | ('flatten', torch.nn.Flatten()), 73 | ('linear0', torch.nn.Linear(3072, width)), 74 | ('relu0', torch.nn.ReLU()), 75 | ('linear1', torch.nn.Linear(width, width)), 76 | ('relu1', torch.nn.ReLU()), 77 | ('linear2', torch.nn.Linear(width, width)), 78 | ('relu2', torch.nn.ReLU()), 79 | ('linear3', torch.nn.Linear(width, num_classes))])) 80 | elif model == 'TwoLP': 81 | width = 2048 82 | model = torch.nn.Sequential(OrderedDict([ 83 | ('flatten', torch.nn.Flatten()), 84 | ('linear0', torch.nn.Linear(3072, width)), 85 | ('relu0', torch.nn.ReLU()), 86 | ('linear3', torch.nn.Linear(width, num_classes))]))
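# Note: the 'ConvNet8', 'ConvNet16' and 'ConvNet32' branches above all instantiate
# ConvNet(width=64); if the numeric suffix is meant to select the width, these
# branches presumably intended width=8, 16 and 32 respectively.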
87 | elif model == 'ResNet20': 88 | model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16) 89 | elif model == 'ResNet20-nostride': 90 | model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16, 91 | strides=[1, 1, 1, 1]) 92 | elif model == 'ResNet20-10': 93 | model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16 * 10) 94 | elif model == 'ResNet20-4': 95 | model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16 * 4) 96 | elif model == 'ResNet20-4-unpooled': 97 | model = ResNet(torchvision.models.resnet.BasicBlock, [3, 3, 3], num_classes=num_classes, base_width=16 * 4, 98 | pool='max') 99 | elif model == 'ResNet28-10': 100 | model = ResNet(torchvision.models.resnet.BasicBlock, [4, 4, 4], num_classes=num_classes, base_width=16 * 10) 101 | elif model == 'ResNet32': 102 | model = ResNet(torchvision.models.resnet.BasicBlock, [5, 5, 5], num_classes=num_classes, base_width=16) 103 | elif model == 'ResNet32-10': 104 | model = ResNet(torchvision.models.resnet.BasicBlock, [5, 5, 5], num_classes=num_classes, base_width=16 * 10) 105 | elif model == 'ResNet44': 106 | model = ResNet(torchvision.models.resnet.BasicBlock, [7, 7, 7], num_classes=num_classes, base_width=16) 107 | elif model == 'ResNet56': 108 | model = ResNet(torchvision.models.resnet.BasicBlock, [9, 9, 9], num_classes=num_classes, base_width=16) 109 | elif model == 'ResNet110': 110 | model = ResNet(torchvision.models.resnet.BasicBlock, [18, 18, 18], num_classes=num_classes, base_width=16) 111 | elif model == 'ResNet18': 112 | model = ResNet(torchvision.models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=num_classes, base_width=64) 113 | elif model == 'ResNet34': 114 | model = ResNet(torchvision.models.resnet.BasicBlock, [3, 4, 6, 3], num_classes=num_classes, base_width=64) 115 | elif model == 'ResNet50': 116 | model = ResNet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], num_classes=num_classes, base_width=64) 117 | elif model == 'ResNet50-2': 118 | model = ResNet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], num_classes=num_classes, base_width=64 * 2) 119 | elif model == 'ResNet101': 120 | model = ResNet(torchvision.models.resnet.Bottleneck, [3, 4, 23, 3], num_classes=num_classes, base_width=64) 121 | elif model == 'ResNet152': 122 | model = ResNet(torchvision.models.resnet.Bottleneck, [3, 8, 36, 3], num_classes=num_classes, base_width=64) 123 | elif model == 'MobileNet': 124 | inverted_residual_setting = [ 125 | # t, c, n, s 126 | [1, 16, 1, 1], 127 | [6, 24, 2, 1], # cifar adaptation, cf.https://github.com/kuangliu/pytorch-cifar/blob/master/models/mobilenetv2.py 128 | [6, 32, 3, 2], 129 | [6, 64, 4, 2], 130 | [6, 96, 3, 1], 131 | [6, 160, 3, 2], 132 | [6, 320, 1, 1], 133 | ] 134 | model = torchvision.models.MobileNetV2(num_classes=num_classes, 135 | inverted_residual_setting=inverted_residual_setting, 136 | width_mult=1.0) 137 | model.features[0] = torchvision.models.mobilenet.ConvBNReLU(num_channels, 32, stride=1) # this is fixed to width=1 138 | elif model == 'MNASNet': 139 | model = torchvision.models.MNASNet(1.0, num_classes=num_classes, dropout=0.2) 140 | elif model == 'DenseNet121': 141 | model = torchvision.models.DenseNet(growth_rate=32, block_config=(6, 12, 24, 16), 142 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=num_classes, 143 | memory_efficient=False) 144 | elif model == 'DenseNet40': 145 | model = 
_DenseNet(_Bottleneck, [6, 6, 6, 0], growth_rate=12, num_classes=num_classes) 146 | elif model == 'DenseNet40-4': 147 | model = _DenseNet(_Bottleneck, [6, 6, 6, 0], growth_rate=12 * 4, num_classes=num_classes) 148 | elif model == 'SRNet3': 149 | model = SRNet(upscale_factor=3, num_channels=num_channels) 150 | elif model == 'SRNet1': 151 | model = SRNet(upscale_factor=1, num_channels=num_channels) 152 | elif model == 'iRevNet': 153 | if num_classes <= 100: 154 | in_shape = [num_channels, 32, 32] # only for cifar right now 155 | model = iRevNet(nBlocks=[18, 18, 18], nStrides=[1, 2, 2], 156 | nChannels=[16, 64, 256], nClasses=num_classes, 157 | init_ds=0, dropout_rate=0.1, affineBN=True, 158 | in_shape=in_shape, mult=4) 159 | else: 160 | in_shape = [3, 224, 224] # only for imagenet 161 | model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2], 162 | nChannels=[24, 96, 384, 1536], nClasses=num_classes, 163 | init_ds=2, dropout_rate=0.1, affineBN=True, 164 | in_shape=in_shape, mult=4) 165 | elif model == 'LeNetZhu': 166 | model = LeNetZhu(num_channels=num_channels, num_classes=num_classes) 167 | else: 168 | raise NotImplementedError('Model not implemented.') 169 | 170 | print(f'Model initialized with random key {model_init_seed}.') 171 | return model, model_init_seed 172 | 173 | 174 | class ResNet(torchvision.models.ResNet): 175 | """ResNet generalization for CIFAR thingies.""" 176 | 177 | def __init__(self, block, layers, num_classes=10, zero_init_residual=False, 178 | groups=1, base_width=64, replace_stride_with_dilation=None, 179 | norm_layer=None, strides=[1, 2, 2, 2], pool='avg'): 180 | """Initialize as usual. Layers and strides are scriptable.""" 181 | super(torchvision.models.ResNet, self).__init__() # nn.Module 182 | if norm_layer is None: 183 | norm_layer = nn.BatchNorm2d 184 | self._norm_layer = norm_layer 185 | 186 | 187 | self.dilation = 1 188 | if replace_stride_with_dilation is None: 189 | # each element in the tuple indicates if we should replace 190 | # the 2x2 stride with a dilated convolution instead 191 | replace_stride_with_dilation = [False, False, False, False] 192 | if len(replace_stride_with_dilation) != 4: 193 | raise ValueError("replace_stride_with_dilation should be None " 194 | "or a 4-element tuple, got {}".format(replace_stride_with_dilation)) 195 | self.groups = groups 196 | 197 | self.inplanes = base_width 198 | self.base_width = 64 # Do this to circumvent BasicBlock errors. The value is not actually used. 199 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 200 | self.bn1 = norm_layer(self.inplanes) 201 | self.relu = nn.ReLU(inplace=True) 202 | 203 | self.layers = torch.nn.ModuleList() 204 | width = self.inplanes 205 | for idx, layer in enumerate(layers): 206 | self.layers.append(self._make_layer(block, width, layer, stride=strides[idx], dilate=replace_stride_with_dilation[idx])) 207 | width *= 2 208 | 209 | self.pool = nn.AdaptiveAvgPool2d((1, 1)) if pool == 'avg' else nn.AdaptiveMaxPool2d((1, 1)) 210 | self.fc = nn.Linear(width // 2 * block.expansion, num_classes) 211 | 212 | for m in self.modules(): 213 | if isinstance(m, nn.Conv2d): 214 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 215 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 216 | nn.init.constant_(m.weight, 1) 217 | nn.init.constant_(m.bias, 0) 218 | 219 | # Zero-initialize the last BN in each residual branch, 220 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
221 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 222 | if zero_init_residual: 223 | for m in self.modules(): 224 | if isinstance(m, Bottleneck): 225 | nn.init.constant_(m.bn3.weight, 0) 226 | elif isinstance(m, BasicBlock): 227 | nn.init.constant_(m.bn2.weight, 0) 228 | 229 | 230 | def _forward_impl(self, x): 231 | # See note [TorchScript super()] 232 | x = self.conv1(x) 233 | x = self.bn1(x) 234 | x = self.relu(x) 235 | 236 | for layer in self.layers: 237 | x = layer(x) 238 | 239 | x = self.pool(x) 240 | x = torch.flatten(x, 1) 241 | x = self.fc(x) 242 | 243 | return x 244 | 245 | 246 | class ConvNet(torch.nn.Module): 247 | """ConvNetBN.""" 248 | 249 | def __init__(self, width=32, num_classes=10, num_channels=3): 250 | """Init with width and num classes.""" 251 | super().__init__() 252 | self.model = torch.nn.Sequential(OrderedDict([ 253 | ('conv0', torch.nn.Conv2d(num_channels, 1 * width, kernel_size=3, padding=1)), 254 | ('bn0', torch.nn.BatchNorm2d(1 * width)), 255 | ('relu0', torch.nn.ReLU()), 256 | 257 | ('conv1', torch.nn.Conv2d(1 * width, 2 * width, kernel_size=3, padding=1)), 258 | ('bn1', torch.nn.BatchNorm2d(2 * width)), 259 | ('relu1', torch.nn.ReLU()), 260 | 261 | ('conv2', torch.nn.Conv2d(2 * width, 2 * width, kernel_size=3, padding=1)), 262 | ('bn2', torch.nn.BatchNorm2d(2 * width)), 263 | ('relu2', torch.nn.ReLU()), 264 | 265 | ('conv3', torch.nn.Conv2d(2 * width, 4 * width, kernel_size=3, padding=1)), 266 | ('bn3', torch.nn.BatchNorm2d(4 * width)), 267 | ('relu3', torch.nn.ReLU()), 268 | 269 | ('conv4', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)), 270 | ('bn4', torch.nn.BatchNorm2d(4 * width)), 271 | ('relu4', torch.nn.ReLU()), 272 | 273 | ('conv5', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)), 274 | ('bn5', torch.nn.BatchNorm2d(4 * width)), 275 | ('relu5', torch.nn.ReLU()), 276 | 277 | ('pool0', torch.nn.MaxPool2d(3)), 278 | 279 | ('conv6', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)), 280 | ('bn6', torch.nn.BatchNorm2d(4 * width)), 281 | ('relu6', torch.nn.ReLU()), 282 | 283 | ('conv6', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)), 284 | ('bn6', torch.nn.BatchNorm2d(4 * width)), 285 | ('relu6', torch.nn.ReLU()), 286 | 287 | ('conv7', torch.nn.Conv2d(4 * width, 4 * width, kernel_size=3, padding=1)), 288 | ('bn7', torch.nn.BatchNorm2d(4 * width)), 289 | ('relu7', torch.nn.ReLU()), 290 | 291 | ('pool1', torch.nn.MaxPool2d(3)), 292 | ('flatten', torch.nn.Flatten()), 293 | ('linear', torch.nn.Linear(36 * width, num_classes)) 294 | ])) 295 | 296 | def forward(self, input): 297 | return self.model(input) 298 | 299 | 300 | class LeNetZhu(nn.Module): 301 | """LeNet variant from https://github.com/mit-han-lab/dlg/blob/master/models/vision.py.""" 302 | 303 | def __init__(self, num_classes=10, num_channels=3): 304 | """3-Layer sigmoid Conv with large linear layer.""" 305 | super().__init__() 306 | act = nn.Sigmoid 307 | self.body = nn.Sequential( 308 | nn.Conv2d(num_channels, 12, kernel_size=5, padding=5 // 2, stride=2), 309 | act(), 310 | nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=2), 311 | act(), 312 | nn.Conv2d(12, 12, kernel_size=5, padding=5 // 2, stride=1), 313 | act(), 314 | ) 315 | self.fc = nn.Sequential( 316 | nn.Linear(768, num_classes) 317 | ) 318 | for module in self.modules(): 319 | self.weights_init(module) 320 | 321 | @staticmethod 322 | def weights_init(m): 323 | if hasattr(m, "weight"): 324 | m.weight.data.uniform_(-0.5, 0.5) 325 | 
if hasattr(m, "bias"): 326 | m.bias.data.uniform_(-0.5, 0.5) 327 | 328 | def forward(self, x): 329 | out = self.body(x) 330 | out = out.view(out.size(0), -1) 331 | # print(out.size()) 332 | out = self.fc(out) 333 | return out 334 | -------------------------------------------------------------------------------- /GradInverting/inversefed/nn/modules.py: -------------------------------------------------------------------------------- 1 | """For monkey-patching into meta-learning frameworks.""" 2 | import torch 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | from functools import partial 6 | import warnings 7 | 8 | from ..consts import BENCHMARK 9 | torch.backends.cudnn.benchmark = BENCHMARK 10 | 11 | DEBUG = False # Emit warning messages when patching. Use this to bootstrap new architectures. 12 | 13 | class MetaMonkey(torch.nn.Module): 14 | """Trace a networks and then replace its module calls with functional calls. 15 | 16 | This allows for backpropagation w.r.t to weights for "normal" PyTorch networks. 17 | """ 18 | 19 | def __init__(self, net): 20 | """Init with network.""" 21 | super().__init__() 22 | self.net = net 23 | self.parameters = OrderedDict(net.named_parameters()) 24 | 25 | 26 | def forward(self, inputs, parameters=None): 27 | """Live Patch ... :> ...""" 28 | # If no parameter dictionary is given, everything is normal 29 | if parameters is None: 30 | return self.net(inputs) 31 | 32 | # But if not ... 33 | param_gen = iter(parameters.values()) 34 | method_pile = [] 35 | counter = 0 36 | 37 | for name, module in self.net.named_modules(): 38 | if isinstance(module, torch.nn.Conv2d): 39 | ext_weight = next(param_gen) 40 | if module.bias is not None: 41 | ext_bias = next(param_gen) 42 | else: 43 | ext_bias = None 44 | 45 | method_pile.append(module.forward) 46 | module.forward = partial(F.conv2d, weight=ext_weight, bias=ext_bias, stride=module.stride, 47 | padding=module.padding, dilation=module.dilation, groups=module.groups) 48 | elif isinstance(module, torch.nn.BatchNorm2d): 49 | if module.momentum is None: 50 | exponential_average_factor = 0.0 51 | else: 52 | exponential_average_factor = module.momentum 53 | 54 | if module.training and module.track_running_stats: 55 | if module.num_batches_tracked is not None: 56 | module.num_batches_tracked += 1 57 | if module.momentum is None: # use cumulative moving average 58 | exponential_average_factor = 1.0 / float(module.num_batches_tracked) 59 | else: # use exponential moving average 60 | exponential_average_factor = module.momentum 61 | 62 | ext_weight = next(param_gen) 63 | ext_bias = next(param_gen) 64 | method_pile.append(module.forward) 65 | module.forward = partial(F.batch_norm, running_mean=module.running_mean, running_var=module.running_var, 66 | weight=ext_weight, bias=ext_bias, 67 | training=module.training or not module.track_running_stats, 68 | momentum=exponential_average_factor, eps=module.eps) 69 | 70 | elif isinstance(module, torch.nn.Linear): 71 | lin_weights = next(param_gen) 72 | lin_bias = next(param_gen) 73 | method_pile.append(module.forward) 74 | module.forward = partial(F.linear, weight=lin_weights, bias=lin_bias) 75 | 76 | elif next(module.parameters(), None) is None: 77 | # Pass over modules that do not contain parameters 78 | pass 79 | elif isinstance(module, torch.nn.Sequential): 80 | # Pass containers 81 | pass 82 | else: 83 | # Warn for other containers 84 | if DEBUG: 85 | warnings.warn(f'Patching for module {module.__class__} is not implemented.') 86 | 87 | output = 
self.net(inputs) 88 | 89 | # Undo Patch 90 | for name, module in self.net.named_modules(): 91 | if isinstance(module, torch.nn.modules.conv.Conv2d): 92 | module.forward = method_pile.pop(0) 93 | elif isinstance(module, torch.nn.BatchNorm2d): 94 | module.forward = method_pile.pop(0) 95 | elif isinstance(module, torch.nn.Linear): 96 | module.forward = method_pile.pop(0) 97 | 98 | return output 99 | -------------------------------------------------------------------------------- /GradInverting/inversefed/nn/revnet.py: -------------------------------------------------------------------------------- 1 | """https://github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/iRevNet.py. 2 | 3 | Code for "i-RevNet: Deep Invertible Networks" 4 | https://openreview.net/pdf?id=HJsjkMb0Z 5 | ICLR, 2018 6 | 7 | 8 | (c) Joern-Henrik Jacobsen, 2018 9 | """ 10 | 11 | """ 12 | MIT License 13 | 14 | Copyright (c) 2018 Jörn Jacobsen 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 
33 | """ 34 | 35 | import torch 36 | import torch.nn as nn 37 | import torch.nn.functional as F 38 | from .revnet_utils import split, merge, injective_pad, psi 39 | 40 | 41 | class irevnet_block(nn.Module): 42 | """This is an i-revnet block from Jacobsen et al.""" 43 | 44 | def __init__(self, in_ch, out_ch, stride=1, first=False, dropout_rate=0., 45 | affineBN=True, mult=4): 46 | """Build invertible bottleneck block.""" 47 | super(irevnet_block, self).__init__() 48 | self.first = first 49 | self.pad = 2 * out_ch - in_ch 50 | self.stride = stride 51 | self.inj_pad = injective_pad(self.pad) 52 | self.psi = psi(stride) 53 | if self.pad != 0 and stride == 1: 54 | in_ch = out_ch * 2 55 | print('') 56 | print('| Injective iRevNet |') 57 | print('') 58 | layers = [] 59 | if not first: 60 | layers.append(nn.BatchNorm2d(in_ch // 2, affine=affineBN)) 61 | layers.append(nn.ReLU(inplace=True)) 62 | layers.append(nn.Conv2d(in_ch // 2, int(out_ch // mult), kernel_size=3, 63 | stride=stride, padding=1, bias=False)) 64 | layers.append(nn.BatchNorm2d(int(out_ch // mult), affine=affineBN)) 65 | layers.append(nn.ReLU(inplace=True)) 66 | layers.append(nn.Conv2d(int(out_ch // mult), int(out_ch // mult), 67 | kernel_size=3, padding=1, bias=False)) 68 | layers.append(nn.Dropout(p=dropout_rate)) 69 | layers.append(nn.BatchNorm2d(int(out_ch // mult), affine=affineBN)) 70 | layers.append(nn.ReLU(inplace=True)) 71 | layers.append(nn.Conv2d(int(out_ch // mult), out_ch, kernel_size=3, 72 | padding=1, bias=False)) 73 | self.bottleneck_block = nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | """Bijective or injective block forward.""" 77 | if self.pad != 0 and self.stride == 1: 78 | x = merge(x[0], x[1]) 79 | x = self.inj_pad.forward(x) 80 | x1, x2 = split(x) 81 | x = (x1, x2) 82 | x1 = x[0] 83 | x2 = x[1] 84 | Fx2 = self.bottleneck_block(x2) 85 | if self.stride == 2: 86 | x1 = self.psi.forward(x1) 87 | x2 = self.psi.forward(x2) 88 | y1 = Fx2 + x1 89 | return (x2, y1) 90 | 91 | def inverse(self, x): 92 | """Bijective or injecitve block inverse.""" 93 | x2, y1 = x[0], x[1] 94 | if self.stride == 2: 95 | x2 = self.psi.inverse(x2) 96 | Fx2 = - self.bottleneck_block(x2) 97 | x1 = Fx2 + y1 98 | if self.stride == 2: 99 | x1 = self.psi.inverse(x1) 100 | if self.pad != 0 and self.stride == 1: 101 | x = merge(x1, x2) 102 | x = self.inj_pad.inverse(x) 103 | x1, x2 = split(x) 104 | x = (x1, x2) 105 | else: 106 | x = (x1, x2) 107 | return x 108 | 109 | 110 | class iRevNet(nn.Module): 111 | """This is an i-revnet from Jacobsen et al.""" 112 | 113 | def __init__(self, nBlocks, nStrides, nClasses, nChannels=None, init_ds=2, 114 | dropout_rate=0., affineBN=True, in_shape=None, mult=4): 115 | """Init with e.g. 
nBlocks=[18, 18, 18], nStrides = [1, 2, 2].""" 116 | super(iRevNet, self).__init__() 117 | self.ds = in_shape[2] // 2**(nStrides.count(2) + init_ds // 2) 118 | self.init_ds = init_ds 119 | self.in_ch = in_shape[0] * 2**self.init_ds 120 | self.nBlocks = nBlocks 121 | self.first = True 122 | 123 | print('') 124 | print(' == Building iRevNet %d == ' % (sum(nBlocks) * 3 + 1)) 125 | if not nChannels: 126 | nChannels = [self.in_ch // 2, self.in_ch // 2 * 4, 127 | self.in_ch // 2 * 4**2, self.in_ch // 2 * 4**3] 128 | 129 | self.init_psi = psi(self.init_ds) 130 | self.stack = self.irevnet_stack(irevnet_block, nChannels, nBlocks, 131 | nStrides, dropout_rate=dropout_rate, 132 | affineBN=affineBN, in_ch=self.in_ch, 133 | mult=mult) 134 | self.bn1 = nn.BatchNorm2d(nChannels[-1] * 2, momentum=0.9) 135 | self.linear = nn.Linear(nChannels[-1] * 2, nClasses) 136 | 137 | def irevnet_stack(self, _block, nChannels, nBlocks, nStrides, dropout_rate, 138 | affineBN, in_ch, mult): 139 | """Create stack of irevnet blocks.""" 140 | block_list = nn.ModuleList() 141 | strides = [] 142 | channels = [] 143 | for channel, depth, stride in zip(nChannels, nBlocks, nStrides): 144 | strides = strides + ([stride] + [1] * (depth - 1)) 145 | channels = channels + ([channel] * depth) 146 | for channel, stride in zip(channels, strides): 147 | block_list.append(_block(in_ch, channel, stride, 148 | first=self.first, 149 | dropout_rate=dropout_rate, 150 | affineBN=affineBN, mult=mult)) 151 | in_ch = 2 * channel 152 | self.first = False 153 | return block_list 154 | 155 | def forward(self, x, return_bijection=False): 156 | """Irevnet forward.""" 157 | n = self.in_ch // 2 158 | if self.init_ds != 0: 159 | x = self.init_psi.forward(x) 160 | out = (x[:, :n, :, :], x[:, n:, :, :]) 161 | for block in self.stack: 162 | out = block.forward(out) 163 | out_bij = merge(out[0], out[1]) 164 | out = F.relu(self.bn1(out_bij)) 165 | out = F.avg_pool2d(out, self.ds) 166 | out = out.view(out.size(0), -1) 167 | out = self.linear(out) 168 | if return_bijection: 169 | return out, out_bij 170 | else: 171 | return out 172 | 173 | def inverse(self, out_bij): 174 | """Irevnet inverse.""" 175 | out = split(out_bij) 176 | for i in range(len(self.stack)): 177 | out = self.stack[-1 - i].inverse(out) 178 | out = merge(out[0], out[1]) 179 | if self.init_ds != 0: 180 | x = self.init_psi.inverse(out) 181 | else: 182 | x = out 183 | return x 184 | 185 | 186 | if __name__ == '__main__': 187 | model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2], 188 | nChannels=None, nClasses=1000, init_ds=2, 189 | dropout_rate=0., affineBN=True, in_shape=[3, 224, 224], 190 | mult=4) 191 | y = model(torch.randn(1, 3, 224, 224)) 192 | print(y.size()) 193 | -------------------------------------------------------------------------------- /GradInverting/inversefed/nn/revnet_utils.py: -------------------------------------------------------------------------------- 1 | """https://github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/model_utils.py. 
2 | 3 | Code for "i-RevNet: Deep Invertible Networks" 4 | https://openreview.net/pdf?id=HJsjkMb0Z 5 | ICLR, 2018 6 | 7 | 8 | (c) Joern-Henrik Jacobsen, 2018 9 | """ 10 | 11 | """ 12 | MIT License 13 | 14 | Copyright (c) 2018 Jörn Jacobsen 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 33 | """ 34 | 35 | import torch 36 | import torch.nn as nn 37 | 38 | from torch.nn import Parameter 39 | 40 | 41 | def split(x): 42 | n = int(x.size()[1] / 2) 43 | x1 = x[:, :n, :, :].contiguous() 44 | x2 = x[:, n:, :, :].contiguous() 45 | return x1, x2 46 | 47 | 48 | def merge(x1, x2): 49 | return torch.cat((x1, x2), 1) 50 | 51 | 52 | class injective_pad(nn.Module): 53 | def __init__(self, pad_size): 54 | super(injective_pad, self).__init__() 55 | self.pad_size = pad_size 56 | self.pad = nn.ZeroPad2d((0, 0, 0, pad_size)) 57 | 58 | def forward(self, x): 59 | x = x.permute(0, 2, 1, 3) 60 | x = self.pad(x) 61 | return x.permute(0, 2, 1, 3) 62 | 63 | def inverse(self, x): 64 | return x[:, :x.size(1) - self.pad_size, :, :] 65 | 66 | 67 | class psi(nn.Module): 68 | def __init__(self, block_size): 69 | super(psi, self).__init__() 70 | self.block_size = block_size 71 | self.block_size_sq = block_size * block_size 72 | 73 | def inverse(self, input): 74 | output = input.permute(0, 2, 3, 1) 75 | (batch_size, d_height, d_width, d_depth) = output.size() 76 | s_depth = int(d_depth / self.block_size_sq) 77 | s_width = int(d_width * self.block_size) 78 | s_height = int(d_height * self.block_size) 79 | t_1 = output.contiguous().view(batch_size, d_height, d_width, self.block_size_sq, s_depth) 80 | spl = t_1.split(self.block_size, 3) 81 | stack = [t_t.contiguous().view(batch_size, d_height, s_width, s_depth) for t_t in spl] 82 | output = torch.stack(stack, 0).transpose(0, 1).permute(0, 2, 1, 3, 4).contiguous().view(batch_size, s_height, s_width, s_depth) 83 | output = output.permute(0, 3, 1, 2) 84 | return output.contiguous() 85 | 86 | def forward(self, input): 87 | output = input.permute(0, 2, 3, 1) 88 | (batch_size, s_height, s_width, s_depth) = output.size() 89 | d_depth = s_depth * self.block_size_sq 90 | d_height = int(s_height / self.block_size) 91 | t_1 = output.split(self.block_size, 2) 92 | stack = [t_t.contiguous().view(batch_size, d_height, d_depth) for t_t in t_1] 93 | output = torch.stack(stack, 1) 94 | output = output.permute(0, 2, 1, 3) 95 | output = output.permute(0, 3, 1, 2) 96 | return output.contiguous() 97 | 98 | 99 | class ListModule(object): 100 | def 
__init__(self, module, prefix, *args): 101 | self.module = module 102 | self.prefix = prefix 103 | self.num_module = 0 104 | for new_module in args: 105 | self.append(new_module) 106 | 107 | def append(self, new_module): 108 | if not isinstance(new_module, nn.Module): 109 | raise ValueError('Not a Module') 110 | else: 111 | self.module.add_module(self.prefix + str(self.num_module), new_module) 112 | self.num_module += 1 113 | 114 | def __len__(self): 115 | return self.num_module 116 | 117 | def __getitem__(self, i): 118 | if i < 0 or i >= self.num_module: 119 | raise IndexError('Out of bound') 120 | return getattr(self.module, self.prefix + str(i)) 121 | 122 | 123 | def get_all_params(var, all_params): 124 | if isinstance(var, Parameter): 125 | all_params[id(var)] = var.nelement() 126 | elif hasattr(var, "creator") and var.creator is not None: 127 | if var.creator.previous_functions is not None: 128 | for j in var.creator.previous_functions: 129 | get_all_params(j[0], all_params) 130 | elif hasattr(var, "previous_functions"): 131 | for j in var.previous_functions: 132 | get_all_params(j[0], all_params) 133 | --------------------------------------------------------------------------------
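The `psi` module above is the invertible space-to-depth downsampling that i-RevNet is built on; a quick round-trip sketch (illustration only, using the module path of this repository):

```python
import torch
from inversefed.nn.revnet_utils import psi

shuffle = psi(block_size=2)
x = torch.randn(1, 3, 32, 32)
y = shuffle.forward(x)                         # space-to-depth: (1, 12, 16, 16)
assert y.shape == (1, 12, 16, 16)
assert torch.allclose(shuffle.inverse(y), x)   # pure reshuffling, so the inverse is exact
```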
Use a tame Adam.""" 65 | 66 | def __init__(self, lr=None, epochs=None, dryrun=False): 67 | """Initialize training hyperparameters.""" 68 | self.lr = 1e-3 / 10 69 | self.epochs = 120 70 | self.batch_size = 32 71 | self.optimizer = 'AdamW' 72 | self.scheduler = 'linear' 73 | self.warmup = True 74 | self.weight_decay : float = 5e-4 75 | self.dropout = 0.0 76 | self.augmentations = True 77 | self.dryrun = False 78 | super().__init__(lr=None, epochs=None, dryrun=False) 79 | -------------------------------------------------------------------------------- /GradInverting/inversefed/options.py: -------------------------------------------------------------------------------- 1 | """Parser options.""" 2 | 3 | import argparse 4 | 5 | def options(): 6 | """Construct the central argument parser, filled with useful defaults.""" 7 | parser = argparse.ArgumentParser(description='Reconstruct some image from a trained model.') 8 | 9 | # Central: 10 | parser.add_argument('--model', default='ConvNet', type=str, help='Vision model.') 11 | parser.add_argument('--dataset', default='CIFAR10', type=str) 12 | parser.add_argument('--dtype', default='float', type=str, help='Data type used during reconstruction [Not during training!].') 13 | 14 | 15 | parser.add_argument('--trained_model', action='store_true', help='Use a trained model.') 16 | parser.add_argument('--epochs', default=120, type=int, help='If using a trained model, how many epochs was it trained?') 17 | 18 | parser.add_argument('--accumulation', default=0, type=int, help='Accumulation 0 is rec. from gradient, accumulation > 0 is reconstruction from fed. averaging.') 19 | parser.add_argument('--num_images', default=1, type=int, help='How many images should be recovered from the given gradient.') 20 | parser.add_argument('--target_id', default=None, type=int, help='Cifar validation image used for reconstruction.') 21 | parser.add_argument('--label_flip', action='store_true', help='Dishonest server permuting weights in classification layer.') 22 | 23 | # Rec. 
parameters 24 | parser.add_argument('--optim', default='ours', type=str, help='Use our reconstruction method or the DLG method.') 25 | 26 | parser.add_argument('--restarts', default=1, type=int, help='How many restarts to run.') 27 | parser.add_argument('--cost_fn', default='sim', type=str, help='Choice of cost function.') 28 | parser.add_argument('--indices', default='def', type=str, help='Choice of indices from the parameter list.') 29 | parser.add_argument('--weights', default='equal', type=str, help='Weigh the parameter list differently.') 30 | 31 | parser.add_argument('--optimizer', default='adam', type=str, help='Weigh the parameter list differently.') 32 | parser.add_argument('--signed', action='store_false', help='Do not used signed gradients.') 33 | parser.add_argument('--boxed', action='store_false', help='Do not used box constraints.') 34 | 35 | parser.add_argument('--scoring_choice', default='loss', type=str, help='How to find the best image between all restarts.') 36 | parser.add_argument('--init', default='randn', type=str, help='Choice of image initialization.') 37 | parser.add_argument('--tv', default=1e-4, type=float, help='Weight of TV penalty.') 38 | 39 | 40 | # Files and folders: 41 | parser.add_argument('--save_image', action='store_true', help='Save the output to a file.') 42 | 43 | parser.add_argument('--image_path', default='images/', type=str) 44 | parser.add_argument('--model_path', default='models/', type=str) 45 | parser.add_argument('--table_path', default='tables/', type=str) 46 | parser.add_argument('--data_path', default='~/data', type=str) 47 | 48 | # Debugging: 49 | parser.add_argument('--name', default='iv', type=str, help='Name tag for the result table and model.') 50 | parser.add_argument('--deterministic', action='store_true', help='Disable CUDNN non-determinism.') 51 | parser.add_argument('--dryrun', action='store_true', help='Run everything for just one step to test functionality.') 52 | return parser 53 | -------------------------------------------------------------------------------- /GradInverting/inversefed/reconstruction_algorithms.py: -------------------------------------------------------------------------------- 1 | """Mechanisms for image reconstruction from parameter gradients.""" 2 | 3 | import torch 4 | from collections import defaultdict, OrderedDict 5 | from inversefed.nn import MetaMonkey 6 | from .metrics import total_variation as TV 7 | from .metrics import InceptionScore 8 | from .medianfilt import MedianPool2d 9 | from copy import deepcopy 10 | 11 | import time 12 | 13 | DEFAULT_CONFIG = dict(signed=False, 14 | boxed=True, 15 | cost_fn='sim', 16 | indices='def', 17 | weights='equal', 18 | lr=0.1, 19 | optim='adam', 20 | restarts=1, 21 | max_iterations=4800, 22 | total_variation=1e-1, 23 | init='randn', 24 | filter='none', 25 | lr_decay=True, 26 | scoring_choice='loss') 27 | 28 | def _label_to_onehot(target, num_classes=100): 29 | target = torch.unsqueeze(target, 1) 30 | onehot_target = torch.zeros(target.size(0), num_classes, device=target.device) 31 | onehot_target.scatter_(1, target, 1) 32 | return onehot_target 33 | 34 | def _validate_config(config): 35 | for key in DEFAULT_CONFIG.keys(): 36 | if config.get(key) is None: 37 | config[key] = DEFAULT_CONFIG[key] 38 | for key in config.keys(): 39 | if DEFAULT_CONFIG.get(key) is None: 40 | raise ValueError(f'Deprecated key in config dict: {key}!') 41 | return config 42 | 43 | 44 | class GradientReconstructor(): 45 | """Instantiate a reconstruction algorithm.""" 46 | 47 | def 
__init__(self, model, mean_std=(0.0, 1.0), config=DEFAULT_CONFIG, num_images=1): 48 | """Initialize with algorithm setup.""" 49 | self.config = _validate_config(config) 50 | self.model = model 51 | self.setup = dict(device=next(model.parameters()).device, dtype=next(model.parameters()).dtype) 52 | 53 | self.mean_std = mean_std 54 | self.num_images = num_images 55 | 56 | if self.config['scoring_choice'] == 'inception': 57 | self.inception = InceptionScore(batch_size=1, setup=self.setup) 58 | 59 | self.loss_fn = torch.nn.CrossEntropyLoss(reduction='mean') 60 | self.iDLG = True 61 | 62 | def reconstruct(self, input_data, labels, img_shape=(3, 32, 32), dryrun=False, eval=True, tol=None): 63 | """Reconstruct image from gradient.""" 64 | start_time = time.time() 65 | if eval: 66 | self.model.eval() 67 | 68 | 69 | stats = defaultdict(list) 70 | x = self._init_images(img_shape) 71 | scores = torch.zeros(self.config['restarts']) 72 | 73 | if labels is None: 74 | if self.num_images == 1 and self.iDLG: 75 | # iDLG trick: 76 | last_weight_min = torch.argmin(torch.sum(input_data[-2], dim=-1), dim=-1) 77 | labels = last_weight_min.detach().reshape((1,)).requires_grad_(False) 78 | self.reconstruct_label = False 79 | else: 80 | # DLG label recovery 81 | # However this also improves conditioning for some LBFGS cases 82 | self.reconstruct_label = True 83 | 84 | def loss_fn(pred, labels): 85 | labels = torch.nn.functional.softmax(labels, dim=-1) 86 | return torch.mean(torch.sum(- labels * torch.nn.functional.log_softmax(pred, dim=-1), 1)) 87 | self.loss_fn = loss_fn 88 | else: 89 | assert labels.shape[0] == self.num_images 90 | self.reconstruct_label = False 91 | 92 | try: 93 | for trial in range(self.config['restarts']): 94 | x_trial, labels = self._run_trial(x[trial], input_data, labels, dryrun=dryrun) 95 | # Finalize 96 | scores[trial] = self._score_trial(x_trial, input_data, labels) 97 | x[trial] = x_trial 98 | if tol is not None and scores[trial] <= tol: 99 | break 100 | if dryrun: 101 | break 102 | except KeyboardInterrupt: 103 | print('Trial procedure manually interruped.') 104 | pass 105 | 106 | # Choose optimal result: 107 | if self.config['scoring_choice'] in ['pixelmean', 'pixelmedian']: 108 | x_optimal, stats = self._average_trials(x, labels, input_data, stats) 109 | else: 110 | print('Choosing optimal result ...') 111 | scores = scores[torch.isfinite(scores)] # guard against NaN/-Inf scores? 
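            # NOTE: this filtering re-indexes `scores`, so if any trial produced a
            # non-finite score, `optimal_index` below no longer lines up with the
            # trials stored in `x` (and argmin fails outright when every score is
            # non-finite); masking in place instead, e.g.
            #   scores[~torch.isfinite(scores)] = float('inf')
            # would keep trial indices aligned.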
112 | optimal_index = torch.argmin(scores) 113 | print(f'Optimal result score: {scores[optimal_index]:2.4f}') 114 | stats['opt'] = scores[optimal_index].item() 115 | x_optimal = x[optimal_index] 116 | 117 | print(f'Total time: {time.time()-start_time}.') 118 | return x_optimal.detach(), stats 119 | 120 | def _init_images(self, img_shape): 121 | if self.config['init'] == 'randn': 122 | return torch.randn((self.config['restarts'], self.num_images, *img_shape), **self.setup) 123 | elif self.config['init'] == 'rand': 124 | return (torch.rand((self.config['restarts'], self.num_images, *img_shape), **self.setup) - 0.5) * 2 125 | elif self.config['init'] == 'zeros': 126 | return torch.zeros((self.config['restarts'], self.num_images, *img_shape), **self.setup) 127 | else: 128 | raise ValueError() 129 | 130 | def _run_trial(self, x_trial, input_data, labels, dryrun=False): 131 | x_trial.requires_grad = True 132 | if self.reconstruct_label: 133 | output_test = self.model(x_trial) 134 | labels = torch.randn(output_test.shape[1]).to(**self.setup).requires_grad_(True) 135 | 136 | if self.config['optim'] == 'adam': 137 | optimizer = torch.optim.Adam([x_trial, labels], lr=self.config['lr']) 138 | elif self.config['optim'] == 'sgd': # actually gd 139 | optimizer = torch.optim.SGD([x_trial, labels], lr=0.01, momentum=0.9, nesterov=True) 140 | elif self.config['optim'] == 'LBFGS': 141 | optimizer = torch.optim.LBFGS([x_trial, labels]) 142 | else: 143 | raise ValueError() 144 | else: 145 | if self.config['optim'] == 'adam': 146 | optimizer = torch.optim.Adam([x_trial], lr=self.config['lr']) 147 | elif self.config['optim'] == 'sgd': # actually gd 148 | optimizer = torch.optim.SGD([x_trial], lr=0.01, momentum=0.9, nesterov=True) 149 | elif self.config['optim'] == 'LBFGS': 150 | optimizer = torch.optim.LBFGS([x_trial]) 151 | else: 152 | raise ValueError() 153 | 154 | max_iterations = self.config['max_iterations'] 155 | dm, ds = self.mean_std 156 | if self.config['lr_decay']: 157 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 158 | milestones=[max_iterations // 2.667, max_iterations // 1.6, 159 | 160 | max_iterations // 1.142], gamma=0.1) # 3/8 5/8 7/8 161 | try: 162 | for iteration in range(max_iterations): 163 | closure = self._gradient_closure(optimizer, x_trial, input_data, labels) 164 | rec_loss = optimizer.step(closure) 165 | if self.config['lr_decay']: 166 | scheduler.step() 167 | 168 | with torch.no_grad(): 169 | # Project into image space 170 | if self.config['boxed']: 171 | x_trial.data = torch.max(torch.min(x_trial, (1 - dm) / ds), -dm / ds) 172 | 173 | if (iteration + 1 == max_iterations) or iteration % 500 == 0: 174 | print(f'It: {iteration}. Rec. 
loss: {rec_loss.item():2.4f}.') 175 | 176 | if (iteration + 1) % 500 == 0: 177 | if self.config['filter'] == 'none': 178 | pass 179 | elif self.config['filter'] == 'median': 180 | x_trial.data = MedianPool2d(kernel_size=3, stride=1, padding=1, same=False)(x_trial) 181 | else: 182 | raise ValueError() 183 | 184 | if dryrun: 185 | break 186 | except KeyboardInterrupt: 187 | print(f'Recovery interrupted manually in iteration {iteration}!') 188 | pass 189 | return x_trial.detach(), labels 190 | 191 | def _gradient_closure(self, optimizer, x_trial, input_gradient, label): 192 | 193 | def closure(): 194 | optimizer.zero_grad() 195 | self.model.zero_grad() 196 | loss = self.loss_fn(self.model(x_trial), label) 197 | gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=True) 198 | rec_loss = reconstruction_costs([gradient], input_gradient, 199 | cost_fn=self.config['cost_fn'], indices=self.config['indices'], 200 | weights=self.config['weights']) 201 | 202 | if self.config['total_variation'] > 0: 203 | rec_loss += self.config['total_variation'] * TV(x_trial) 204 | rec_loss.backward() 205 | if self.config['signed']: 206 | x_trial.grad.sign_() 207 | return rec_loss 208 | return closure 209 | 210 | def _score_trial(self, x_trial, input_gradient, label): 211 | if self.config['scoring_choice'] == 'loss': 212 | self.model.zero_grad() 213 | x_trial.grad = None 214 | loss = self.loss_fn(self.model(x_trial), label) 215 | gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False) 216 | return reconstruction_costs([gradient], input_gradient, 217 | cost_fn=self.config['cost_fn'], indices=self.config['indices'], 218 | weights=self.config['weights']) 219 | elif self.config['scoring_choice'] == 'tv': 220 | return TV(x_trial) 221 | elif self.config['scoring_choice'] == 'inception': 222 | # We do not care about diversity here! 
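            # NOTE: InceptionScore here acts as a no-reference quality score for a
            # single candidate reconstruction, not the usual distribution-level metric.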
223 | return self.inception(x_trial) 224 | elif self.config['scoring_choice'] in ['pixelmean', 'pixelmedian']: 225 | return 0.0 226 | else: 227 | raise ValueError() 228 | 229 | def _average_trials(self, x, labels, input_data, stats): 230 | print(f'Computing a combined result via {self.config["scoring_choice"]} ...') 231 | if self.config['scoring_choice'] == 'pixelmedian': 232 | x_optimal, _ = x.median(dim=0, keepdims=False) 233 | elif self.config['scoring_choice'] == 'pixelmean': 234 | x_optimal = x.mean(dim=0, keepdims=False) 235 | 236 | self.model.zero_grad() 237 | if self.reconstruct_label: 238 | labels = self.model(x_optimal).softmax(dim=1) 239 | loss = self.loss_fn(self.model(x_optimal), labels) 240 | gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False) 241 | stats['opt'] = reconstruction_costs([gradient], input_data, 242 | cost_fn=self.config['cost_fn'], 243 | indices=self.config['indices'], 244 | weights=self.config['weights']) 245 | print(f'Optimal result score: {stats["opt"]:2.4f}') 246 | return x_optimal, stats 247 | 248 | 249 | 250 | class FedAvgReconstructor(GradientReconstructor): 251 | """Reconstruct an image from weights after n gradient descent steps.""" 252 | 253 | def __init__(self, model, mean_std=(0.0, 1.0), local_steps=2, local_lr=1e-4, 254 | config=DEFAULT_CONFIG, num_images=1, use_updates=True, batch_size=0): 255 | """Initialize with model, (mean, std) and config.""" 256 | super().__init__(model, mean_std, config, num_images) 257 | self.local_steps = local_steps 258 | self.local_lr = local_lr 259 | self.use_updates = use_updates 260 | self.batch_size = batch_size 261 | 262 | def _gradient_closure(self, optimizer, x_trial, input_parameters, labels): 263 | def closure(): 264 | optimizer.zero_grad() 265 | self.model.zero_grad() 266 | parameters = loss_steps(self.model, x_trial, labels, loss_fn=self.loss_fn, 267 | local_steps=self.local_steps, lr=self.local_lr, 268 | use_updates=self.use_updates, 269 | batch_size=self.batch_size) 270 | rec_loss = reconstruction_costs([parameters], input_parameters, 271 | cost_fn=self.config['cost_fn'], indices=self.config['indices'], 272 | weights=self.config['weights']) 273 | 274 | if self.config['total_variation'] > 0: 275 | rec_loss += self.config['total_variation'] * TV(x_trial) 276 | rec_loss.backward() 277 | if self.config['signed']: 278 | x_trial.grad.sign_() 279 | return rec_loss 280 | return closure 281 | 282 | def _score_trial(self, x_trial, input_parameters, labels): 283 | if self.config['scoring_choice'] == 'loss': 284 | self.model.zero_grad() 285 | parameters = loss_steps(self.model, x_trial, labels, loss_fn=self.loss_fn, 286 | local_steps=self.local_steps, lr=self.local_lr, use_updates=self.use_updates) 287 | return reconstruction_costs([parameters], input_parameters, 288 | cost_fn=self.config['cost_fn'], indices=self.config['indices'], 289 | weights=self.config['weights']) 290 | elif self.config['scoring_choice'] == 'tv': 291 | return TV(x_trial) 292 | elif self.config['scoring_choice'] == 'inception': 293 | # We do not care about diversity here! 
294 | return self.inception(x_trial) 295 | 296 | 297 | def loss_steps(model, inputs, labels, loss_fn=torch.nn.CrossEntropyLoss(), lr=1e-4, local_steps=4, use_updates=True, batch_size=0): 298 | """Take a few gradient descent steps to fit the model to the given input.""" 299 | patched_model = MetaMonkey(model) 300 | if use_updates: 301 | patched_model_origin = deepcopy(patched_model) 302 | for i in range(local_steps): 303 | if batch_size == 0: 304 | outputs = patched_model(inputs, patched_model.parameters) 305 | labels_ = labels 306 | else: 307 | idx = i % (inputs.shape[0] // batch_size) 308 | outputs = patched_model(inputs[idx * batch_size:(idx + 1) * batch_size], patched_model.parameters) 309 | labels_ = labels[idx * batch_size:(idx + 1) * batch_size] 310 | loss = loss_fn(outputs, labels_).sum() 311 | grad = torch.autograd.grad(loss, patched_model.parameters.values(), 312 | retain_graph=True, create_graph=True, only_inputs=True) 313 | 314 | patched_model.parameters = OrderedDict((name, param - lr * grad_part) 315 | for ((name, param), grad_part) 316 | in zip(patched_model.parameters.items(), grad)) 317 | 318 | if use_updates: 319 | patched_model.parameters = OrderedDict((name, param - param_origin) 320 | for ((name, param), (name_origin, param_origin)) 321 | in zip(patched_model.parameters.items(), patched_model_origin.parameters.items())) 322 | return list(patched_model.parameters.values()) 323 | 324 | 325 | def reconstruction_costs(gradients, input_gradient, cost_fn='l2', indices='def', weights='equal'): 326 | """Input gradient is given data.""" 327 | if isinstance(indices, list): 328 | pass 329 | elif indices == 'def': 330 | indices = torch.arange(len(input_gradient)) 331 | elif indices == 'batch': 332 | indices = torch.randperm(len(input_gradient))[:8] 333 | elif indices == 'topk-1': 334 | _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 4) 335 | elif indices == 'top10': 336 | _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 10) 337 | elif indices == 'top50': 338 | _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 50) 339 | elif indices in ['first', 'first4']: 340 | indices = torch.arange(0, 4) 341 | elif indices == 'first5': 342 | indices = torch.arange(0, 5) 343 | elif indices == 'first10': 344 | indices = torch.arange(0, 10) 345 | elif indices == 'first50': 346 | indices = torch.arange(0, 50) 347 | elif indices == 'last5': 348 | indices = torch.arange(len(input_gradient))[-5:] 349 | elif indices == 'last10': 350 | indices = torch.arange(len(input_gradient))[-10:] 351 | elif indices == 'last50': 352 | indices = torch.arange(len(input_gradient))[-50:] 353 | else: 354 | raise ValueError() 355 | 356 | ex = input_gradient[0] 357 | if weights == 'linear': 358 | weights = torch.arange(len(input_gradient), 0, -1, dtype=ex.dtype, device=ex.device) / len(input_gradient) 359 | elif weights == 'exp': 360 | weights = torch.arange(len(input_gradient), 0, -1, dtype=ex.dtype, device=ex.device) 361 | weights = weights.softmax(dim=0) 362 | weights = weights / weights[0] 363 | else: 364 | weights = input_gradient[0].new_ones(len(input_gradient)) 365 | 366 | total_costs = 0 367 | for trial_gradient in gradients: 368 | pnorm = [0, 0] 369 | costs = 0 370 | if indices == 'topk-2': 371 | _, indices = torch.topk(torch.stack([p.norm().detach() for p in trial_gradient], dim=0), 4) 372 | for i in indices: 373 | if cost_fn == 'l2': 374 | costs += ((trial_gradient[i] - input_gradient[i]).pow(2)).sum() * weights[i] 375 | 
elif cost_fn == 'l1': 376 | costs += ((trial_gradient[i] - input_gradient[i]).abs()).sum() * weights[i] 377 | elif cost_fn == 'max': 378 | costs += ((trial_gradient[i] - input_gradient[i]).abs()).max() * weights[i] 379 | elif cost_fn == 'sim': 380 | costs -= (trial_gradient[i] * input_gradient[i]).sum() * weights[i] 381 | pnorm[0] += trial_gradient[i].pow(2).sum() * weights[i] 382 | pnorm[1] += input_gradient[i].pow(2).sum() * weights[i] 383 | elif cost_fn == 'simlocal': 384 | costs += 1 - torch.nn.functional.cosine_similarity(trial_gradient[i].flatten(), 385 | input_gradient[i].flatten(), 386 | 0, 1e-10) * weights[i] 387 | if cost_fn == 'sim': 388 | costs = 1 + costs / pnorm[0].sqrt() / pnorm[1].sqrt() 389 | 390 | # Accumulate final costs 391 | total_costs += costs 392 | return total_costs / len(gradients) 393 | -------------------------------------------------------------------------------- /GradInverting/inversefed/training/README.md: -------------------------------------------------------------------------------- 1 | # Training routines are implemented here -------------------------------------------------------------------------------- /GradInverting/inversefed/training/__init__.py: -------------------------------------------------------------------------------- 1 | """Basic training routines and loss functions.""" 2 | 3 | from .training_routine import train 4 | 5 | __all__ = ['train'] 6 | -------------------------------------------------------------------------------- /GradInverting/inversefed/training/scheduler.py: -------------------------------------------------------------------------------- 1 | """This file is part of https://github.com/ildoonet/pytorch-gradual-warmup-lr. 2 | 3 | MIT License 4 | 5 | Copyright (c) 2019 Ildoo Kim 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | 25 | """ 26 | 27 | from torch.optim.lr_scheduler import _LRScheduler 28 | from torch.optim.lr_scheduler import ReduceLROnPlateau 29 | 30 | 31 | class GradualWarmupScheduler(_LRScheduler): 32 | """Gradually warm-up(increasing) learning rate in optimizer. 33 | 34 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 35 | 36 | Args: 37 | optimizer (Optimizer): Wrapped optimizer. 38 | multiplier: target learning rate = base lr * multiplier 39 | total_epoch: target learning rate is reached at total_epoch, gradually 40 | after_scheduler: after target_epoch, use this scheduler(eg. 
ReduceLROnPlateau) 41 | 42 | """ 43 | 44 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 45 | """Initialize the warm-up start. 46 | 47 | Usage: 48 | 49 | scheduler_normal = torch.optim.lr_scheduler.MultiStepLR(optimizer) 50 | scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=10, after_scheduler=scheduler_normal) 51 | """ 52 | self.multiplier = multiplier 53 | if self.multiplier < 1.: 54 | raise ValueError('multiplier should be greater thant or equal to 1.') 55 | self.total_epoch = total_epoch 56 | self.after_scheduler = after_scheduler 57 | self.finished = False 58 | super().__init__(optimizer) 59 | 60 | def get_lr(self): 61 | if self.last_epoch > self.total_epoch: 62 | if self.after_scheduler: 63 | if not self.finished: 64 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 65 | self.finished = True 66 | return self.after_scheduler.get_lr() 67 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 68 | 69 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 70 | 71 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 72 | if epoch is None: 73 | epoch = self.last_epoch + 1 74 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 75 | if self.last_epoch <= self.total_epoch: 76 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 77 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 78 | param_group['lr'] = lr 79 | else: 80 | if epoch is None: 81 | self.after_scheduler.step(metrics, None) 82 | else: 83 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 84 | 85 | def step(self, epoch=None, metrics=None): 86 | if type(self.after_scheduler) != ReduceLROnPlateau: 87 | if self.finished and self.after_scheduler: 88 | if epoch is None: 89 | self.after_scheduler.step(None) 90 | else: 91 | self.after_scheduler.step(epoch - self.total_epoch) 92 | else: 93 | return super(GradualWarmupScheduler, self).step(epoch) 94 | else: 95 | self.step_ReduceLROnPlateau(metrics, epoch) 96 | -------------------------------------------------------------------------------- /GradInverting/inversefed/training/training_routine.py: -------------------------------------------------------------------------------- 1 | """Implement the .train function.""" 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from collections import defaultdict 7 | 8 | from .scheduler import GradualWarmupScheduler 9 | 10 | from ..consts import BENCHMARK, NON_BLOCKING 11 | torch.backends.cudnn.benchmark = BENCHMARK 12 | 13 | def train(model, loss_fn, trainloader, validloader, defs, setup=dict(dtype=torch.float, device=torch.device('cpu'))): 14 | """Run the main interface. 
Train a network with specifications from the Strategy object.""" 15 | stats = defaultdict(list) 16 | optimizer, scheduler = set_optimizer(model, defs) 17 | 18 | for epoch in range(defs.epochs): 19 | model.train() 20 | step(model, loss_fn, trainloader, optimizer, scheduler, defs, setup, stats) 21 | 22 | if epoch % defs.validate == 0 or epoch == (defs.epochs - 1): 23 | model.eval() 24 | validate(model, loss_fn, validloader, defs, setup, stats) 25 | # Print information about loss and accuracy 26 | print_status(epoch, loss_fn, optimizer, stats) 27 | 28 | if defs.dryrun: 29 | break 30 | if not (np.isfinite(stats['train_losses'][-1])): 31 | print('Loss is NaN/Inf ... terminating early ...') 32 | break 33 | 34 | return stats 35 | 36 | def step(model, loss_fn, dataloader, optimizer, scheduler, defs, setup, stats): 37 | """Step through one epoch.""" 38 | epoch_loss, epoch_metric = 0, 0 39 | for batch, (inputs, targets) in enumerate(dataloader): 40 | # Prep Mini-Batch 41 | optimizer.zero_grad() 42 | 43 | # Transfer to GPU 44 | inputs = inputs.to(**setup) 45 | targets = targets.to(device=setup['device'], non_blocking=NON_BLOCKING) 46 | 47 | # Get loss 48 | outputs = model(inputs) 49 | loss, _, _ = loss_fn(outputs, targets) 50 | 51 | 52 | epoch_loss += loss.item() 53 | 54 | loss.backward() 55 | optimizer.step() 56 | 57 | metric, name, _ = loss_fn.metric(outputs, targets) 58 | epoch_metric += metric.item() 59 | 60 | if defs.scheduler == 'cyclic': 61 | scheduler.step() 62 | if defs.dryrun: 63 | break 64 | if defs.scheduler == 'linear': 65 | scheduler.step() 66 | 67 | stats['train_losses'].append(epoch_loss / (batch + 1)) 68 | stats['train_' + name].append(epoch_metric / (batch + 1)) 69 | 70 | 71 | def validate(model, loss_fn, dataloader, defs, setup, stats): 72 | """Validate model effectiveness of val dataset.""" 73 | epoch_loss, epoch_metric = 0, 0 74 | with torch.no_grad(): 75 | for batch, (inputs, targets) in enumerate(dataloader): 76 | # Transfer to GPU 77 | inputs = inputs.to(**setup) 78 | targets = targets.to(device=setup['device'], non_blocking=NON_BLOCKING) 79 | 80 | # Get loss and metric 81 | outputs = model(inputs) 82 | loss, _, _ = loss_fn(outputs, targets) 83 | metric, name, _ = loss_fn.metric(outputs, targets) 84 | 85 | epoch_loss += loss.item() 86 | epoch_metric += metric.item() 87 | 88 | if defs.dryrun: 89 | break 90 | 91 | stats['valid_losses'].append(epoch_loss / (batch + 1)) 92 | stats['valid_' + name].append(epoch_metric / (batch + 1)) 93 | 94 | def set_optimizer(model, defs): 95 | """Build model optimizer and scheduler from defs. 96 | 97 | The linear scheduler drops the learning rate in intervals. 98 | # Example: epochs=160 leads to drops at 60, 100, 140. 99 | """ 100 | if defs.optimizer == 'SGD': 101 | optimizer = torch.optim.SGD(model.parameters(), lr=defs.lr, momentum=0.9, 102 | weight_decay=defs.weight_decay, nesterov=True) 103 | elif defs.optimizer == 'AdamW': 104 | optimizer = torch.optim.AdamW(model.parameters(), lr=defs.lr, weight_decay=defs.weight_decay) 105 | 106 | if defs.scheduler == 'linear': 107 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 108 | milestones=[120 // 2.667, 120 // 1.6, 109 | 120 // 1.142], gamma=0.1) 110 | # Scheduler is fixed to 120 epochs so that calls with fewer epochs are equal in lr drops. 
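        # NOTE: with Python floor division these milestones land near epochs 45, 75
        # and 105, i.e. roughly 3/8, 5/8 and 7/8 of the fixed 120-epoch schedule,
        # mirroring the per-iteration milestones in reconstruction_algorithms.py.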
111 | 112 | if defs.warmup: 113 | scheduler = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=10, after_scheduler=scheduler) 114 | 115 | return optimizer, scheduler 116 | 117 | 118 | def print_status(epoch, loss_fn, optimizer, stats): 119 | """Print basic console printout every defs.validation epochs.""" 120 | current_lr = optimizer.param_groups[0]['lr'] 121 | name, format = loss_fn.metric() 122 | print(f'Epoch: {epoch}| lr: {current_lr:.4f} | ' 123 | f'Train loss is {stats["train_losses"][-1]:6.4f}, Train {name}: {stats["train_" + name][-1]:{format}} | ' 124 | f'Val loss is {stats["valid_losses"][-1]:6.4f}, Val {name}: {stats["valid_" + name][-1]:{format}} |') 125 | -------------------------------------------------------------------------------- /GradInverting/inversefed/utils.py: -------------------------------------------------------------------------------- 1 | """Various utilities.""" 2 | 3 | import os 4 | import csv 5 | 6 | import torch 7 | import random 8 | import numpy as np 9 | 10 | import socket 11 | import datetime 12 | 13 | 14 | def system_startup(args=None, defs=None): 15 | """Print useful system information.""" 16 | # Choose GPU device and print status information: 17 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 18 | setup = dict(device=device, dtype=torch.float) # non_blocking=NON_BLOCKING 19 | print('Currently evaluating -------------------------------:') 20 | print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p")) 21 | print(f'CPUs: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()} on {socket.gethostname()}.') 22 | if args is not None: 23 | print(args) 24 | if defs is not None: 25 | print(repr(defs)) 26 | if torch.cuda.is_available(): 27 | print(f'GPU : {torch.cuda.get_device_name(device=device)}') 28 | return setup 29 | 30 | def save_to_table(out_dir, name, dryrun, **kwargs): 31 | """Save keys to .csv files. 
Function adapted from Micah.""" 32 | # Check for file 33 | if not os.path.isdir(out_dir): 34 | os.makedirs(out_dir) 35 | fname = os.path.join(out_dir, f'table_{name}.csv') 36 | fieldnames = list(kwargs.keys()) 37 | 38 | # Read or write header 39 | try: 40 | with open(fname, 'r') as f: 41 | reader = csv.reader(f, delimiter='\t') 42 | header = [line for line in reader][0] 43 | except Exception as e: 44 | print('Creating a new .csv table...') 45 | with open(fname, 'w') as f: 46 | writer = csv.DictWriter(f, delimiter='\t', fieldnames=fieldnames) 47 | writer.writeheader() 48 | if not dryrun: 49 | # Add row for this experiment 50 | with open(fname, 'a') as f: 51 | writer = csv.DictWriter(f, delimiter='\t', fieldnames=fieldnames) 52 | writer.writerow(kwargs) 53 | print('\nResults saved to ' + fname + '.') 54 | else: 55 | print(f'Would save results to {fname}.') 56 | print(f'Would save these keys: {fieldnames}.') 57 | 58 | def set_random_seed(seed=233): 59 | """233 = 144 + 89 is my favorite number.""" 60 | torch.manual_seed(seed + 1) 61 | torch.cuda.manual_seed(seed + 2) 62 | torch.cuda.manual_seed_all(seed + 3) 63 | np.random.seed(seed + 4) 64 | torch.cuda.manual_seed_all(seed + 5) 65 | random.seed(seed + 6) 66 | 67 | def set_deterministic(): 68 | """Switch pytorch into a deterministic computation mode.""" 69 | torch.backends.cudnn.deterministic = True 70 | torch.backends.cudnn.benchmark = False 71 | -------------------------------------------------------------------------------- /GradInverting/rec_mult.py: -------------------------------------------------------------------------------- 1 | """Run reconstruction in a terminal prompt. 2 | Optional arguments can be found in inversefed/options.py 3 | 4 | This CLI can recover the baseline experiments. 5 | """ 6 | 7 | import torch 8 | import torchvision 9 | 10 | import numpy as np 11 | 12 | import inversefed 13 | torch.backends.cudnn.benchmark = inversefed.consts.BENCHMARK 14 | 15 | from collections import defaultdict 16 | import datetime 17 | import time 18 | import os 19 | import json 20 | import hashlib 21 | import csv 22 | 23 | 24 | # Parse input arguments 25 | parser = inversefed.options() 26 | parser.add_argument('--unsigned', action='store_true', help='Use signed gradient descent') 27 | parser.add_argument('--soft_labels', action='store_true', help='Do not use the provided label when using L-BFGS (This can stabilize it).') 28 | parser.add_argument('--lr', default=None, type=float, help='Optionally overwrite default step sizes.') 29 | parser.add_argument('--num_exp', default=10, type=int, help='Number of consecutive experiments') 30 | parser.add_argument('--max_iterations', default=4800, type=int, help='Maximum number of iterations for reconstruction.') 31 | parser.add_argument('--batch_size', default=0, type=int, help='Number of mini batch for federated averaging') 32 | parser.add_argument('--local_lr', default=1e-4, type=float, help='Local learning rate for federated averaging') 33 | args = parser.parse_args() 34 | if args.target_id is None: 35 | args.target_id = 0 36 | args.save_image = True 37 | args.signed = not args.unsigned 38 | 39 | 40 | # Parse training strategy 41 | defs = inversefed.training_strategy('conservative') 42 | defs.epochs = args.epochs 43 | # 100% reproducibility? 
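# NOTE: two pitfalls in this script as written:
#  (1) the branch below calls image2graph2vec.utils.set_deterministic(), but no module named
#      image2graph2vec is imported anywhere in this file, so --deterministic raises a NameError;
#      inversefed.utils.set_deterministic() (as called in reconstruct_image.py) is presumably intended.
#  (2) the --label_flip branches further down call torch.randint((10,)) and torch.permute(labels),
#      which do not match the PyTorch signatures; something like torch.randint(0, 10, (1,)) and
#      labels[torch.randperm(labels.numel())] would be closer to the apparent intent.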
44 | if args.deterministic: 45 | image2graph2vec.utils.set_deterministic() 46 | 47 | 48 | if __name__ == "__main__": 49 | # Choose GPU device and print status information: 50 | setup = inversefed.utils.system_startup(args) 51 | start_time = time.time() 52 | 53 | # Prepare for training 54 | 55 | # Get data: 56 | loss_fn, trainloader, validloader = inversefed.construct_dataloaders(args.dataset, defs) 57 | 58 | model, model_seed = inversefed.construct_model(args.model, num_classes=10, num_channels=3) 59 | dm = torch.as_tensor(getattr(inversefed.consts, f'{args.dataset.lower()}_mean'), **setup)[:, None, None] 60 | ds = torch.as_tensor(getattr(inversefed.consts, f'{args.dataset.lower()}_std'), **setup)[:, None, None] 61 | model.to(**setup) 62 | model.eval() 63 | 64 | # Load a trained model? 65 | if args.trained_model: 66 | file = f'{args.model}_{args.epochs}.pth' 67 | try: 68 | model.load_state_dict(torch.load(os.path.join(args.model_path, file), map_location=setup['device'])) 69 | print(f'Model loaded from file {file}.') 70 | except FileNotFoundError: 71 | print('Training the model ...') 72 | print(repr(defs)) 73 | inversefed.train(model, loss_fn, trainloader, validloader, defs, setup=setup) 74 | torch.save(model.state_dict(), os.path.join(args.model_path, file)) 75 | 76 | # Sanity check: Validate model accuracy 77 | training_stats = defaultdict(list) 78 | inversefed.training.training_routine.validate(model, loss_fn, validloader, defs, setup, training_stats) 79 | name, format = loss_fn.metric() 80 | print(f'Val loss is {training_stats["valid_losses"][-1]:6.4f}, Val {name}: {training_stats["valid_" + name][-1]:{format}}.') 81 | 82 | if args.optim == 'ours': 83 | config = dict(signed=args.signed, 84 | boxed=True, 85 | cost_fn=args.cost_fn, 86 | indices=args.indices, 87 | weights=args.weights, 88 | lr=args.lr if args.lr is not None else 0.1, 89 | optim='adam', 90 | restarts=args.restarts, 91 | max_iterations=args.max_iterations, 92 | total_variation=args.tv, 93 | init=args.init, 94 | filter='none', 95 | lr_decay=True, 96 | scoring_choice=args.scoring_choice) 97 | elif args.optim == 'zhu': 98 | config = dict(signed=False, 99 | boxed=False, 100 | cost_fn='l2', 101 | indices='def', 102 | weights='equal', 103 | lr=args.lr if args.lr is not None else 1.0, 104 | optim='LBFGS', 105 | restarts=args.restarts, 106 | max_iterations=500, 107 | total_variation=args.tv, 108 | init=args.init, 109 | filter='none', 110 | lr_decay=False, 111 | scoring_choice=args.scoring_choice) 112 | 113 | # psnr list 114 | psnrs = [] 115 | 116 | # hash configuration 117 | 118 | config_comp = config.copy() 119 | config_comp['dataset'] = args.dataset 120 | config_comp['model'] = args.model 121 | config_comp['trained'] = args.trained_model 122 | config_comp['num_exp'] = args.num_exp 123 | config_comp['num_images'] = args.num_images 124 | config_comp['accumulation'] = args.accumulation 125 | config_comp['batch_size'] = args.batch_size 126 | config_comp['local_lr'] = args.trained_model 127 | config_comp['soft_labels'] = args.soft_labels 128 | config_hash = hashlib.md5(json.dumps(config_comp, sort_keys=True).encode()).hexdigest() 129 | 130 | print(config_comp) 131 | 132 | os.makedirs('results', exist_ok=True) 133 | os.makedirs(f'results/{config_hash}', exist_ok=True) 134 | 135 | 136 | target_id = args.target_id 137 | for i in range(args.num_exp): 138 | target_id = args.target_id + i * args.num_images 139 | if args.num_images == 1: 140 | ground_truth, labels = validloader.dataset[target_id] 141 | if args.label_flip: 142 | labels = 
torch.randint((10,)) 143 | ground_truth, labels = ground_truth.unsqueeze(0).to(**setup), torch.as_tensor((labels,), device=setup['device']) 144 | target_id_ = target_id + 1 145 | else: 146 | ground_truth, labels = [], [] 147 | target_id_ = target_id 148 | while len(labels) < args.num_images: 149 | img, label = validloader.dataset[target_id_] 150 | target_id_ += 1 151 | if label not in labels: 152 | labels.append(torch.as_tensor((label,), device=setup['device'])) 153 | ground_truth.append(img.to(**setup)) 154 | 155 | ground_truth = torch.stack(ground_truth) 156 | labels = torch.cat(labels) 157 | if args.label_flip: 158 | labels = torch.permute(labels) 159 | img_shape = (3, ground_truth.shape[2], ground_truth.shape[3]) 160 | 161 | # Run reconstruction 162 | if args.accumulation == 0: 163 | target_loss, _, _ = loss_fn(model(ground_truth), labels) 164 | input_gradient = torch.autograd.grad(target_loss, model.parameters()) 165 | input_gradient = [grad.detach() for grad in input_gradient] 166 | 167 | # Run reconstruction in different precision? 168 | if args.dtype != 'float': 169 | if args.dtype in ['double', 'float64']: 170 | setup['dtype'] = torch.double 171 | elif args.dtype in ['half', 'float16']: 172 | setup['dtype'] = torch.half 173 | else: 174 | raise ValueError(f'Unknown data type argument {args.dtype}.') 175 | print(f'Model and input parameter moved to {args.dtype}-precision.') 176 | dm = torch.as_tensor(inversefed.consts.cifar10_mean, **setup)[:, None, None] 177 | ds = torch.as_tensor(inversefed.consts.cifar10_std, **setup)[:, None, None] 178 | ground_truth = ground_truth.to(**setup) 179 | input_gradient = [g.to(**setup) for g in input_gradient] 180 | model.to(**setup) 181 | model.eval() 182 | 183 | rec_machine = inversefed.GradientReconstructor(model, (dm, ds), config, num_images=args.num_images) 184 | 185 | if args.optim == 'zhu' and args.soft_labels: 186 | rec_machine.iDLG = False 187 | output, stats = rec_machine.reconstruct(input_gradient, None, img_shape=img_shape, dryrun=args.dryrun) 188 | else: 189 | output, stats = rec_machine.reconstruct(input_gradient, labels, img_shape=img_shape, dryrun=args.dryrun) 190 | 191 | else: 192 | local_gradient_steps = args.accumulation 193 | local_lr = args.local_lr 194 | batch_size = args.batch_size 195 | input_parameters = inversefed.reconstruction_algorithms.loss_steps(model, ground_truth, 196 | labels, 197 | lr=local_lr, 198 | local_steps=local_gradient_steps, use_updates=True, batch_size=batch_size) 199 | input_parameters = [p.detach() for p in input_parameters] 200 | 201 | # Run reconstruction in different precision? 
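            # NOTE: the `input_parameters` computed above hold the simulated client update:
            # `args.accumulation` local SGD steps at `args.local_lr`, returned as parameter
            # deltas because use_updates=True; FedAvgReconstructor below matches these
            # deltas rather than a single-batch gradient.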
202 | if args.dtype != 'float': 203 | if args.dtype in ['double', 'float64']: 204 | setup['dtype'] = torch.double 205 | elif args.dtype in ['half', 'float16']: 206 | setup['dtype'] = torch.half 207 | else: 208 | raise ValueError(f'Unknown data type argument {args.dtype}.') 209 | print(f'Model and input parameter moved to {args.dtype}-precision.') 210 | ground_truth = ground_truth.to(**setup) 211 | dm = torch.as_tensor(inversefed.consts.cifar10_mean, **setup)[:, None, None] 212 | ds = torch.as_tensor(inversefed.consts.cifar10_std, **setup)[:, None, None] 213 | input_parameters = [g.to(**setup) for g in input_parameters] 214 | model.to(**setup) 215 | model.eval() 216 | 217 | rec_machine = inversefed.FedAvgReconstructor(model, (dm, ds), local_gradient_steps, 218 | local_lr, config, 219 | num_images=args.num_images, use_updates=True, 220 | batch_size=batch_size) 221 | output, stats = rec_machine.reconstruct(input_parameters, labels, img_shape=img_shape, dryrun=args.dryrun) 222 | 223 | 224 | 225 | # Compute stats and save to a table: 226 | output_den = torch.clamp(output * ds + dm, 0, 1) 227 | ground_truth_den = torch.clamp(ground_truth * ds + dm, 0, 1) 228 | feat_mse = (model(output) - model(ground_truth)).pow(2).mean().item() 229 | test_mse = (output_den - ground_truth_den).pow(2).mean().item() 230 | test_psnr = inversefed.metrics.psnr(output_den, ground_truth_den, factor=1) 231 | print(f"Rec. loss: {stats['opt']:2.4f} | MSE: {test_mse:2.4f} | PSNR: {test_psnr:4.2f} | FMSE: {feat_mse:2.4e} |") 232 | 233 | inversefed.utils.save_to_table(f'results/{config_hash}', name=f'mul_exp_{args.name}', dryrun=args.dryrun, 234 | 235 | config_hash=config_hash, 236 | model=args.model, 237 | dataset=args.dataset, 238 | trained=args.trained_model, 239 | accumulation=args.accumulation, 240 | restarts=args.restarts, 241 | OPTIM=args.optim, 242 | cost_fn=args.cost_fn, 243 | indices=args.indices, 244 | weights=args.weights, 245 | scoring=args.scoring_choice, 246 | init=args.init, 247 | tv=args.tv, 248 | 249 | rec_loss=stats["opt"], 250 | psnr=test_psnr, 251 | test_mse=test_mse, 252 | feat_mse=feat_mse, 253 | 254 | target_id=target_id, 255 | seed=model_seed, 256 | dtype=setup['dtype'], 257 | epochs=defs.epochs, 258 | val_acc=training_stats["valid_" + name][-1], 259 | ) 260 | 261 | 262 | # Save the resulting image 263 | if args.save_image and not args.dryrun: 264 | output_denormalized = torch.clamp(output * ds + dm, 0, 1) 265 | for j in range(args.num_images): 266 | filename = (f'{i*args.num_images+j}.png') 267 | 268 | torchvision.utils.save_image(output_denormalized[j:j + 1, ...], 269 | os.path.join(f'results/{config_hash}', filename)) 270 | 271 | # Save psnr values 272 | psnrs.append(test_psnr) 273 | inversefed.utils.save_to_table(f'results/{config_hash}', name='psnrs', dryrun=args.dryrun, target_id=target_id, psnr=test_psnr) 274 | 275 | # Update target id 276 | target_id = target_id_ 277 | 278 | 279 | # psnr statistics 280 | psnrs = np.nan_to_num(np.array(psnrs)) 281 | psnr_mean = psnrs.mean() 282 | psnr_std = np.std(psnrs) 283 | psnr_max = psnrs.max() 284 | psnr_min = psnrs.min() 285 | psnr_median = np.median(psnrs) 286 | timing = datetime.timedelta(seconds=time.time() - start_time) 287 | inversefed.utils.save_to_table(f'results/{config_hash}', name='psnr_stats', dryrun=args.dryrun, 288 | number_of_samples=len(psnrs), 289 | timing=str(timing), 290 | mean=psnr_mean, 291 | std=psnr_std, 292 | max=psnr_max, 293 | min=psnr_min, 294 | median=psnr_median) 295 | 296 | config_exists = False 297 | if 
os.path.isfile('results/table_configs.csv'): 298 | with open('results/table_configs.csv') as csvfile: 299 | reader = csv.reader(csvfile, delimiter='\t') 300 | for row in reader: 301 | if row[-1] == config_hash: 302 | config_exists = True 303 | break 304 | 305 | if not config_exists: 306 | inversefed.utils.save_to_table('results', name='configs', dryrun=args.dryrun, 307 | config_hash=config_hash, 308 | **config_comp, 309 | number_of_samples=len(psnrs), 310 | timing=str(timing), 311 | mean=psnr_mean, 312 | std=psnr_std, 313 | max=psnr_max, 314 | min=psnr_min, 315 | median=psnr_median) 316 | 317 | # Print final timestamp 318 | print(datetime.datetime.now().strftime("%A, %d. %B %Y %I:%M%p")) 319 | print('---------------------------------------------------') 320 | print(f'Finished computations with time: {str(datetime.timedelta(seconds=time.time() - start_time))}') 321 | print('-------------Job finished.-------------------------') 322 | -------------------------------------------------------------------------------- /GradInverting/reconstruct_image.py: -------------------------------------------------------------------------------- 1 | """Run reconstruction in a terminal prompt. 2 | 3 | Optional arguments can be found in inversefed/options.py 4 | """ 5 | 6 | import torch 7 | import torchvision 8 | 9 | import numpy as np 10 | from PIL import Image 11 | 12 | import inversefed 13 | 14 | from collections import defaultdict 15 | import datetime 16 | import time 17 | import os 18 | 19 | torch.backends.cudnn.benchmark = inversefed.consts.BENCHMARK 20 | 21 | # Parse input arguments 22 | args = inversefed.options().parse_args() 23 | # Parse training strategy 24 | defs = inversefed.training_strategy('conservative') 25 | defs.epochs = args.epochs 26 | # 100% reproducibility? 
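# (inversefed.utils.set_deterministic() below toggles cudnn.deterministic and disables cudnn benchmarking.)
# Illustrative invocations (flags as defined in inversefed/options.py; values are examples only):
#   python reconstruct_image.py --target_id -1 --save_image    # reconstruct the demo image "auto.jpg"
#   python reconstruct_image.py --model ResNet18 --dataset ImageNet --trained_model --cost_fn sim
# NOTE: the --label_flip branches below use the same torch.randint((10,)) / torch.permute(labels)
# calls flagged in rec_mult.py, which do not match the PyTorch signatures.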
27 | if args.deterministic: 28 | inversefed.utils.set_deterministic() 29 | 30 | 31 | if __name__ == "__main__": 32 | # Choose GPU device and print status information: 33 | setup = inversefed.utils.system_startup(args) 34 | start_time = time.time() 35 | 36 | # Prepare for training 37 | 38 | # Get data: 39 | loss_fn, trainloader, validloader = inversefed.construct_dataloaders(args.dataset, defs, data_path=args.data_path) 40 | 41 | dm = torch.as_tensor(getattr(inversefed.consts, f'{args.dataset.lower()}_mean'), **setup)[:, None, None] 42 | ds = torch.as_tensor(getattr(inversefed.consts, f'{args.dataset.lower()}_std'), **setup)[:, None, None] 43 | 44 | if args.dataset == 'ImageNet': 45 | if args.model == 'ResNet152': 46 | model = torchvision.models.resnet152(pretrained=args.trained_model) 47 | else: 48 | model = torchvision.models.resnet18(pretrained=args.trained_model) 49 | model_seed = None 50 | else: 51 | model, model_seed = inversefed.construct_model(args.model, num_classes=10, num_channels=3) 52 | model.to(**setup) 53 | model.eval() 54 | 55 | # Sanity check: Validate model accuracy 56 | training_stats = defaultdict(list) 57 | # inversefed.training.training_routine.validate(model, loss_fn, validloader, defs, setup, training_stats) 58 | # name, format = loss_fn.metric() 59 | # print(f'Val loss is {training_stats["valid_losses"][-1]:6.4f}, Val {name}: {training_stats["valid_" + name][-1]:{format}}.') 60 | 61 | # Choose example images from the validation set or from third-party sources 62 | if args.num_images == 1: 63 | if args.target_id == -1: # demo image 64 | # Specify PIL filter for lower pillow versions 65 | ground_truth = torch.as_tensor(np.array(Image.open("auto.jpg").resize((32, 32), Image.BICUBIC)) / 255, **setup) 66 | ground_truth = ground_truth.permute(2, 0, 1).sub(dm).div(ds).unsqueeze(0).contiguous() 67 | if not args.label_flip: 68 | labels = torch.as_tensor((1,), device=setup['device']) 69 | else: 70 | labels = torch.as_tensor((5,), device=setup['device']) 71 | target_id = -1 72 | else: 73 | if args.target_id is None: 74 | target_id = np.random.randint(len(validloader.dataset)) 75 | else: 76 | target_id = args.target_id 77 | ground_truth, labels = validloader.dataset[target_id] 78 | if args.label_flip: 79 | labels = torch.randint((10,)) 80 | ground_truth, labels = ground_truth.unsqueeze(0).to(**setup), torch.as_tensor((labels,), device=setup['device']) 81 | img_shape = (3, ground_truth.shape[2], ground_truth.shape[3]) 82 | 83 | else: 84 | ground_truth, labels = [], [] 85 | if args.target_id is None: 86 | target_id = np.random.randint(len(validloader.dataset)) 87 | else: 88 | target_id = args.target_id 89 | while len(labels) < args.num_images: 90 | img, label = validloader.dataset[target_id] 91 | target_id += 1 92 | if label not in labels: 93 | labels.append(torch.as_tensor((label,), device=setup['device'])) 94 | ground_truth.append(img.to(**setup)) 95 | 96 | ground_truth = torch.stack(ground_truth) 97 | labels = torch.cat(labels) 98 | if args.label_flip: 99 | labels = torch.permute(labels) 100 | img_shape = (3, ground_truth.shape[2], ground_truth.shape[3]) 101 | 102 | # Run reconstruction 103 | if args.accumulation == 0: 104 | model.zero_grad() 105 | target_loss, _, _ = loss_fn(model(ground_truth), labels) 106 | input_gradient = torch.autograd.grad(target_loss, model.parameters()) 107 | input_gradient = [grad.detach() for grad in input_gradient] 108 | full_norm = torch.stack([g.norm() for g in input_gradient]).mean() 109 | print(f'Full gradient norm is {full_norm:e}.') 110 | 111 
| # Run reconstruction in different precision? 112 | if args.dtype != 'float': 113 | if args.dtype in ['double', 'float64']: 114 | setup['dtype'] = torch.double 115 | elif args.dtype in ['half', 'float16']: 116 | setup['dtype'] = torch.half 117 | else: 118 | raise ValueError(f'Unknown data type argument {args.dtype}.') 119 | print(f'Model and input parameter moved to {args.dtype}-precision.') 120 | dm = torch.as_tensor(inversefed.consts.cifar10_mean, **setup)[:, None, None] 121 | ds = torch.as_tensor(inversefed.consts.cifar10_std, **setup)[:, None, None] 122 | ground_truth = ground_truth.to(**setup) 123 | input_gradient = [g.to(**setup) for g in input_gradient] 124 | model.to(**setup) 125 | model.eval() 126 | 127 | if args.optim == 'ours': 128 | config = dict(signed=args.signed, 129 | boxed=args.boxed, 130 | cost_fn=args.cost_fn, 131 | indices='def', 132 | weights='equal', 133 | lr=0.1, 134 | optim=args.optimizer, 135 | restarts=args.restarts, 136 | max_iterations=24_000, 137 | total_variation=args.tv, 138 | init='randn', 139 | filter='none', 140 | lr_decay=True, 141 | scoring_choice='loss') 142 | elif args.optim == 'zhu': 143 | config = dict(signed=False, 144 | boxed=False, 145 | cost_fn='l2', 146 | indices='def', 147 | weights='equal', 148 | lr=1e-4, 149 | optim='LBFGS', 150 | restarts=args.restarts, 151 | max_iterations=300, 152 | total_variation=args.tv, 153 | init=args.init, 154 | filter='none', 155 | lr_decay=False, 156 | scoring_choice=args.scoring_choice) 157 | 158 | rec_machine = inversefed.GradientReconstructor(model, (dm, ds), config, num_images=args.num_images) 159 | output, stats = rec_machine.reconstruct(input_gradient, labels, img_shape=img_shape, dryrun=args.dryrun) 160 | 161 | else: 162 | local_gradient_steps = args.accumulation 163 | local_lr = 1e-4 164 | input_parameters = inversefed.reconstruction_algorithms.loss_steps(model, ground_truth, labels, 165 | lr=local_lr, local_steps=local_gradient_steps) 166 | input_parameters = [p.detach() for p in input_parameters] 167 | 168 | # Run reconstruction in different precision? 
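        # NOTE: half precision (--dtype half) is generally only practical on CUDA devices,
        # since many CPU ops lack float16 kernels; double precision is the safer
        # non-default choice.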
169 | if args.dtype != 'float': 170 | if args.dtype in ['double', 'float64']: 171 | setup['dtype'] = torch.double 172 | elif args.dtype in ['half', 'float16']: 173 | setup['dtype'] = torch.half 174 | else: 175 | raise ValueError(f'Unknown data type argument {args.dtype}.') 176 | print(f'Model and input parameter moved to {args.dtype}-precision.') 177 | ground_truth = ground_truth.to(**setup) 178 | dm = torch.as_tensor(inversefed.consts.cifar10_mean, **setup)[:, None, None] 179 | ds = torch.as_tensor(inversefed.consts.cifar10_std, **setup)[:, None, None] 180 | input_parameters = [g.to(**setup) for g in input_parameters] 181 | model.to(**setup) 182 | model.eval() 183 | 184 | config = dict(signed=args.signed, 185 | boxed=args.boxed, 186 | cost_fn=args.cost_fn, 187 | indices=args.indices, 188 | weights=args.weights, 189 | lr=1, 190 | optim=args.optimizer, 191 | restarts=args.restarts, 192 | max_iterations=24_000, 193 | total_variation=args.tv, 194 | init=args.init, 195 | filter='none', 196 | lr_decay=True, 197 | scoring_choice=args.scoring_choice) 198 | 199 | rec_machine = inversefed.FedAvgReconstructor(model, (dm, ds), local_gradient_steps, local_lr, config, 200 | num_images=args.num_images, use_updates=True) 201 | output, stats = rec_machine.reconstruct(input_parameters, labels, img_shape=img_shape, dryrun=args.dryrun) 202 | 203 | 204 | # Compute stats 205 | test_mse = (output - ground_truth).pow(2).mean().item() 206 | feat_mse = (model(output) - model(ground_truth)).pow(2).mean().item() 207 | test_psnr = inversefed.metrics.psnr(output, ground_truth, factor=1 / ds) 208 | 209 | 210 | # Save the resulting image 211 | if args.save_image and not args.dryrun: 212 | os.makedirs(args.image_path, exist_ok=True) 213 | output_denormalized = torch.clamp(output * ds + dm, 0, 1) 214 | rec_filename = (f'{validloader.dataset.classes[labels][0]}_{"trained" if args.trained_model else ""}' 215 | f'{args.model}_{args.cost_fn}-{args.target_id}.png') 216 | torchvision.utils.save_image(output_denormalized, os.path.join(args.image_path, rec_filename)) 217 | 218 | gt_denormalized = torch.clamp(ground_truth * ds + dm, 0, 1) 219 | gt_filename = (f'{validloader.dataset.classes[labels][0]}_ground_truth-{args.target_id}.png') 220 | torchvision.utils.save_image(gt_denormalized, os.path.join(args.image_path, gt_filename)) 221 | else: 222 | rec_filename = None 223 | gt_filename = None 224 | 225 | 226 | # Save to a table: 227 | print(f"Rec. loss: {stats['opt']:2.4f} | MSE: {test_mse:2.4f} | PSNR: {test_psnr:4.2f} | FMSE: {feat_mse:2.4e} |") 228 | 229 | inversefed.utils.save_to_table(args.table_path, name=f'exp_{args.name}', dryrun=args.dryrun, 230 | 231 | model=args.model, 232 | dataset=args.dataset, 233 | trained=args.trained_model, 234 | accumulation=args.accumulation, 235 | restarts=args.restarts, 236 | OPTIM=args.optim, 237 | cost_fn=args.cost_fn, 238 | indices=args.indices, 239 | weights=args.weights, 240 | scoring=args.scoring_choice, 241 | init=args.init, 242 | tv=args.tv, 243 | 244 | rec_loss=stats["opt"], 245 | psnr=test_psnr, 246 | test_mse=test_mse, 247 | feat_mse=feat_mse, 248 | 249 | target_id=target_id, 250 | seed=model_seed, 251 | timing=str(datetime.timedelta(seconds=time.time() - start_time)), 252 | dtype=setup['dtype'], 253 | epochs=defs.epochs, 254 | val_acc=None, 255 | rec_img=rec_filename, 256 | gt_img=gt_filename 257 | ) 258 | 259 | 260 | # Print final timestamp 261 | print(datetime.datetime.now().strftime("%A, %d. 
%B %Y %I:%M%p"))
262 | print('---------------------------------------------------')
263 | print(f'Finished computations with time: {str(datetime.timedelta(seconds=time.time() - start_time))}')
264 | print('-------------Job finished.-------------------------')
265 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | ## Codes for CPL attacks
 2 | 
 3 | Client Privacy Leakage, or CPL, is an advanced privacy leakage attack in federated learning which exploits gradients stolen either during a client's local training or after it. While the context is federated learning, the same problem holds in the centralized setting too, when an insider peeps at the gradients saved at a checkpoint. The attack is an iterative optimization process.
 4 | 
 5 | There are a few unique properties of CPL:
 6 | - With geometric initializations such as a patterned seed or a single-color seed, the CPL attack is much faster than the SOTA DLG and GradInverting attacks in terms of attack iterations, and achieves better reconstruction quality in terms of both attack success rate and image quality. Besides the geometric initialization seed, CPL integrates a label attack, exploiting the fact that the gradients of the true label class have the largest magnitude, so that the content attack is accelerated compared to DLG. Detailed implementations can be found in ./CPL/LFW_enhanced_random_ASR.ipynb
 7 | - The CPL attack can handle images larger than 64x64, which DLG claims are impossible to reconstruct from gradients. More details can be found in ./CPL/LFW128_enhanced_random_ASR.ipynb. The 64x64 setting is also provided in LFW64_enhanced_random_ASR.ipynb.
 8 | - The CPL attack can handle batch sizes up to 8 and attacks the entire batch as a whole, unlike the DLG and GradInverting attacks, which can only attack the single-input gradients in a batch one by one. More details can be found in ./CPL/LFW_batch.ipynb
 9 | - The CPL attack works on Tanh and LeakyReLU activations in addition to Sigmoid, while existing approaches are studied only on Sigmoid. See ./CPL/LFW_tanhrelu.ipynb
10 | - We also provide initial defenses, including gradient compression and additive Gaussian and Laplacian noise. See ./CPL/LFW_defense.ipynb for more details and tuning.
11 | - To systematically defend against the attack, federated learning with client-side differential privacy noise is proposed. See [code](https://github.com/git-disl/Fed-CDP) for gradient-leakage-resilient federated learning.
12 | 
13 | 
14 | ### Examples
15 | 
16 | |  | ours | DLG |
17 | |:---:|:---:|:---:|
18 | | MNIST| ![mnist_ours](demo/mnist_ours.gif) | ![mnist_dlg](demo/mnist_dlg.gif) |
19 | | CIFAR10| ![cifar10_ours](demo/cifar10_ours.gif) | ![cifar10_dlg](demo/cifar10_dlg.gif) |
20 | | LFW| ![lfw_ours](demo/lfw_ours.gif) | ![lfw_dlg](demo/lfw_dlg.gif) |
21 | 
22 | 
23 | 
24 | ### Here is a brief description of each file in the DLG folder.
25 | 
26 | LFW_Deep_Leakage_from_Gradients.ipynb: LFW implementation of the DLG attack from (NeurIPS 2019) "Deep leakage from gradients."
27 | 
28 | ### Here is a brief description of each file in the GradInverting folder.
29 | 
30 | Attack from NeurIPS 2020: Geiping, Jonas, Hartmut Bauermeister, Hannah Dröge, and Michael Moeller. "Inverting Gradients--How easy is it to break privacy in federated learning?."
31 | To run, you may find more details [here](https://github.com/JonasGeiping/invertinggradients)
32 | 
33 | 
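### A minimal sketch of the shared attack loop

CPL, DLG, and GradInverting all share the same iterative core: optimize a dummy input until its gradients match the leaked gradients. The sketch below is illustrative only and is not the exact code of any notebook in this repository; it assumes a model `net`, the leaked `target_gradients`, and a one-hot `onehot_label` have been prepared as in the notebooks above.

```
import torch
import torch.nn.functional as F

def gradient_matching_attack(net, target_gradients, onehot_label,
                             shape=(1, 3, 32, 32), iters=300):
    """Recover an input whose gradients match the leaked gradients."""
    # DLG starts from random noise; CPL would use a geometric (patterned
    # or single-color) seed here instead.
    dummy_x = torch.randn(shape, requires_grad=True)
    optimizer = torch.optim.LBFGS([dummy_x])

    for _ in range(iters):
        def closure():
            optimizer.zero_grad()
            pred = net(dummy_x)
            # cross-entropy against the (known or pre-inferred) one-hot label
            loss = torch.mean(torch.sum(-onehot_label * F.log_softmax(pred, dim=-1), dim=1))
            dummy_grads = torch.autograd.grad(loss, net.parameters(), create_graph=True)
            # L2 distance between the dummy gradients and the leaked gradients
            grad_diff = sum(((dg - tg) ** 2).sum()
                            for dg, tg in zip(dummy_grads, target_gradients))
            grad_diff.backward()
            return grad_diff
        optimizer.step(closure)
    return dummy_x.detach()
```

CPL's main departures from this baseline are the geometric initialization of `dummy_x` and the label attack that fixes `onehot_label` before the loop starts, which is what accelerates convergence relative to DLG.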
34 | The talk on the CPL attack can be found here: talk. Check out our [project page](https://git-disl.github.io/ESORICS20-CPL/).
35 | If you use our code, please cite:
36 | 
37 | ```
38 | @inproceedings{wei2020framework,
39 |   title={A framework for evaluating client privacy leakages in federated learning},
40 |   author={Wei, Wenqi and Liu, Ling and Loper, Margaret and Chow, Ka-Ho and Gursoy, Mehmet Emre and Truex, Stacey and Wu, Yanzhao},
41 |   booktitle={European Symposium on Research in Computer Security},
42 |   year={2020},
43 |   organization={Springer}
44 | }
45 | 
46 | @inproceedings{wei2021gradient,
47 |   title={Gradient-Leakage Resilient Federated Learning},
48 |   author={Wei, Wenqi and Liu, Ling and Wu, Yanzhao and Su, Gong and Iyengar, Arun},
49 |   booktitle={International Conference on Distributed Computing Systems},
50 |   year={2021},
51 |   organization={IEEE}
52 | }
53 | ```
54 | 
55 | 
--------------------------------------------------------------------------------
/demo/cifar10_dlg.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/git-disl/CPL_attack/13607a3435c27707135c5fcd84abdb218f7f94c8/demo/cifar10_dlg.gif
--------------------------------------------------------------------------------
/demo/cifar10_ours.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/git-disl/CPL_attack/13607a3435c27707135c5fcd84abdb218f7f94c8/demo/cifar10_ours.gif
--------------------------------------------------------------------------------
/demo/lfw_dlg.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/git-disl/CPL_attack/13607a3435c27707135c5fcd84abdb218f7f94c8/demo/lfw_dlg.gif
--------------------------------------------------------------------------------
/demo/lfw_ours.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/git-disl/CPL_attack/13607a3435c27707135c5fcd84abdb218f7f94c8/demo/lfw_ours.gif
--------------------------------------------------------------------------------
/demo/mnist_dlg.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/git-disl/CPL_attack/13607a3435c27707135c5fcd84abdb218f7f94c8/demo/mnist_dlg.gif
--------------------------------------------------------------------------------
/demo/mnist_ours.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/git-disl/CPL_attack/13607a3435c27707135c5fcd84abdb218f7f94c8/demo/mnist_ours.gif
--------------------------------------------------------------------------------