├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── demo-crop.gif
│   ├── method.jpg
│   ├── nips-dlg.jpg
│   ├── nlp_results.png
│   └── out.gif
├── main.py
├── models
│   └── vision.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 MIT HAN Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Deep Leakage From Gradients [[arXiv]](https://arxiv.org/abs/1906.08935) [[Website]](https://dlg.mit.edu)

```
@inproceedings{zhu19deep,
  title={Deep Leakage from Gradients},
  author={Zhu, Ligeng and Liu, Zhijian and Han, Song},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}
```

Exchanging gradients is widely used in modern multi-node learning systems. People used to believe that numerical gradients were safe to share. However, we show that it is actually possible to recover the training data from the shared gradients, and that the leakage is pixel-wise accurate for images and token-wise matching for texts.

## Overview

The core algorithm is to *match the gradients* between *dummy data* and *real data*.
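Concretely, leakage is cast as an optimization problem. In our notation (a paraphrase of the paper's objective, not a verbatim quote): given a differentiable model $F$ with weights $W$ and the shared gradient $\nabla W$, the attacker solves

$$
\mathbf{x}'^{*}, \mathbf{y}'^{*} \;=\; \operatorname*{arg\,min}_{\mathbf{x}',\,\mathbf{y}'}\;\left\lVert\, \frac{\partial\, \ell\big(F(\mathbf{x}', W),\, \mathbf{y}'\big)}{\partial W} \;-\; \nabla W \,\right\rVert^{2}
$$

where $\ell$ is the training loss. The snippet below implements exactly this objective with L-BFGS.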

![Overview of deep leakage from gradients](assets/method.jpg)

It can be implemented in **less than 20 lines** with PyTorch!

```python
def deep_leakage_from_gradients(model, origin_grad, data_size, label_size):
    # Dummy inputs/labels start as random noise; they must require grad
    # so that L-BFGS can optimize them directly.
    dummy_data = torch.randn(data_size, requires_grad=True)
    dummy_label = torch.randn(label_size, requires_grad=True)
    optimizer = torch.optim.LBFGS([dummy_data, dummy_label])

    for iters in range(300):
        def closure():
            optimizer.zero_grad()
            dummy_pred = model(dummy_data)
            # criterion: e.g. cross_entropy_for_onehot from utils.py
            dummy_loss = criterion(dummy_pred, F.softmax(dummy_label, dim=-1))
            dummy_grad = grad(dummy_loss, model.parameters(), create_graph=True)

            # Match the dummy gradients against the shared (original) gradients.
            grad_diff = sum(((dummy_g - origin_g) ** 2).sum()
                            for dummy_g, origin_g in zip(dummy_grad, origin_grad))
            grad_diff.backward()
            return grad_diff

        optimizer.step(closure)

    return dummy_data, dummy_label
```

# Prerequisites

To run the code, the following libraries are required:

* Python >= 3.6
* PyTorch >= 1.0
* torchvision >= 0.4

# Code
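As a quick, hypothetical illustration of how the helper above could be wired up (names like `gt_data` and `gt_onehot_label` mirror `main.py`, but this glue code is ours, not the repository's):

```python
import torch
import torch.nn.functional as F
from torch.autograd import grad

from models.vision import LeNet, weights_init
from utils import label_to_onehot, cross_entropy_for_onehot

# Assumes deep_leakage_from_gradients from the snippet above is defined,
# and that `criterion` is the loss it refers to.
criterion = cross_entropy_for_onehot

net = LeNet()
net.apply(weights_init)

# Stand-in "victim" example; in main.py this is a CIFAR-100 image and label.
gt_data = torch.randn(1, 3, 32, 32)
gt_onehot_label = label_to_onehot(torch.tensor([42]))

# The gradient a participant would share in distributed training.
origin_grad = [g.detach() for g in
               torch.autograd.grad(criterion(net(gt_data), gt_onehot_label),
                                   net.parameters())]

leaked_data, leaked_label = deep_leakage_from_gradients(
    net, origin_grad, gt_data.size(), gt_onehot_label.size())
```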

Note: We provide an Open In Colab notebook for quick reproduction.

```
# Single image on CIFAR
python main.py --index 25

# Deep Leakage on your own image
python main.py --image yours.jpg
```

# Deep Leakage on Batched Images
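The extension to a batch is mechanical: the dummy tensor simply carries a batch dimension and the same L-BFGS loop runs unchanged. Below is a minimal sketch under that assumption (the batch size, shapes, and the random stand-in `gt_data` are ours; in a real attack the shared gradient comes from the victim's batch). Larger batches make the optimization noticeably harder.

```python
import torch
import torch.nn.functional as F

from models.vision import LeNet, weights_init
from utils import label_to_onehot, cross_entropy_for_onehot

torch.manual_seed(1234)
net = LeNet()
net.apply(weights_init)

# Stand-in victim batch (random noise here; CIFAR-100 images in practice).
batch_size = 4
gt_data = torch.randn(batch_size, 3, 32, 32)
gt_onehot = label_to_onehot(torch.randint(0, 100, (batch_size,)))
origin_grad = [g.detach() for g in torch.autograd.grad(
    cross_entropy_for_onehot(net(gt_data), gt_onehot), net.parameters())]

# Batched dummies: identical objective, just with a leading batch dimension.
dummy_data = torch.randn(batch_size, 3, 32, 32, requires_grad=True)
dummy_label = torch.randn(batch_size, 100, requires_grad=True)
optimizer = torch.optim.LBFGS([dummy_data, dummy_label])

def closure():
    optimizer.zero_grad()
    dummy_loss = cross_entropy_for_onehot(
        net(dummy_data), F.softmax(dummy_label, dim=-1))
    dummy_grad = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)
    grad_diff = sum(((gx - gy) ** 2).sum() for gx, gy in zip(dummy_grad, origin_grad))
    grad_diff.backward()
    return grad_diff

for _ in range(300):
    optimizer.step(closure)
```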


# Deep Leakage on Language Model
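For text, tokens are discrete, so the attack optimizes continuous dummy inputs in embedding space and then reads out the closest vocabulary entries. The sketch below is our own minimal toy of that idea (the one-layer model, shapes, and names are assumptions, not the repository's code):

```python
import torch
import torch.nn.functional as F

# Hypothetical tiny language model: embedding -> linear head over the vocab.
vocab_size, embed_dim, seq_len = 1000, 64, 8
embedding = torch.nn.Embedding(vocab_size, embed_dim)
head = torch.nn.Linear(embed_dim, vocab_size)

def forward_from_embeds(embeds):
    return head(embeds)  # (seq_len, vocab_size) logits

# Shared gradient for a (random stand-in) ground-truth sentence.
gt_tokens = torch.randint(0, vocab_size, (seq_len,))
loss = F.cross_entropy(forward_from_embeds(embedding(gt_tokens)), gt_tokens)
origin_grad = [g.detach() for g in torch.autograd.grad(loss, head.parameters())]

# Optimize continuous dummy embeddings (tokens are discrete, embeddings are not).
dummy_embeds = torch.randn(seq_len, embed_dim, requires_grad=True)
dummy_label = torch.randn(seq_len, vocab_size, requires_grad=True)
optimizer = torch.optim.LBFGS([dummy_embeds, dummy_label])

def closure():
    optimizer.zero_grad()
    logits = forward_from_embeds(dummy_embeds)
    loss = torch.mean(torch.sum(
        -F.softmax(dummy_label, dim=-1) * F.log_softmax(logits, dim=-1), 1))
    dummy_grad = torch.autograd.grad(loss, head.parameters(), create_graph=True)
    diff = sum(((gx - gy) ** 2).sum() for gx, gy in zip(dummy_grad, origin_grad))
    diff.backward()
    return diff

for _ in range(100):
    optimizer.step(closure)

# Map recovered embeddings back to the closest vocabulary tokens.
recovered = torch.cdist(dummy_embeds.detach(), embedding.weight).argmin(dim=1)
```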

![Leakage results on language models](assets/nlp_results.png)

# License

This repository is released under the MIT license. See [LICENSE](LICENSE) for additional details.

--------------------------------------------------------------------------------
/assets/demo-crop.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/dlg/d21007fa1540ba2303ebc034976aa331814727c7/assets/demo-crop.gif

--------------------------------------------------------------------------------
/assets/method.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/dlg/d21007fa1540ba2303ebc034976aa331814727c7/assets/method.jpg

--------------------------------------------------------------------------------
/assets/nips-dlg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/dlg/d21007fa1540ba2303ebc034976aa331814727c7/assets/nips-dlg.jpg

--------------------------------------------------------------------------------
/assets/nlp_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/dlg/d21007fa1540ba2303ebc034976aa331814727c7/assets/nlp_results.png

--------------------------------------------------------------------------------
/assets/out.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/dlg/d21007fa1540ba2303ebc034976aa331814727c7/assets/out.gif

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import argparse
import numpy as np
from pprint import pprint

from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
import torchvision
from torchvision import models, datasets, transforms
print(torch.__version__, torchvision.__version__)

from utils import label_to_onehot, cross_entropy_for_onehot

parser = argparse.ArgumentParser(description='Deep Leakage from Gradients.')
parser.add_argument('--index', type=int, default=25,
                    help='the index of the CIFAR image to leak.')
parser.add_argument('--image', type=str, default="",
                    help='the path to a customized image.')
args = parser.parse_args()

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Running on %s" % device)

dst = datasets.CIFAR100("~/.torch", download=True)
tp = transforms.ToTensor()
tt = transforms.ToPILImage()

img_index = args.index
gt_data = tp(dst[img_index][0]).to(device)

if len(args.image) > 0:
    gt_data = Image.open(args.image)
    gt_data = tp(gt_data).to(device)

gt_data = gt_data.view(1, *gt_data.size())
# Note: the label is always taken from the CIFAR index, even for a custom image.
gt_label = torch.Tensor([dst[img_index][1]]).long().to(device)
gt_label = gt_label.view(1, )
gt_onehot_label = label_to_onehot(gt_label)

plt.imshow(tt(gt_data[0].cpu()))

from models.vision import LeNet, weights_init
net = LeNet().to(device)

torch.manual_seed(1234)

net.apply(weights_init)
criterion = cross_entropy_for_onehot

# Compute the original gradient that the victim would share.
pred = net(gt_data)
y = criterion(pred, gt_onehot_label)
dy_dx = torch.autograd.grad(y, net.parameters())

original_dy_dx = [g.detach().clone() for g in dy_dx]

# Generate dummy data and label; both are optimized directly by L-BFGS.
dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
dummy_label = torch.randn(gt_onehot_label.size()).to(device).requires_grad_(True)

plt.imshow(tt(dummy_data[0].cpu()))

optimizer = torch.optim.LBFGS([dummy_data, dummy_label])

history = []
for iters in range(300):
    def closure():
        optimizer.zero_grad()

        dummy_pred = net(dummy_data)
        dummy_onehot_label = F.softmax(dummy_label, dim=-1)
        dummy_loss = criterion(dummy_pred, dummy_onehot_label)
        dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)

        # L2 distance between the dummy gradients and the shared gradients.
        grad_diff = 0
        for gx, gy in zip(dummy_dy_dx, original_dy_dx):
            grad_diff += ((gx - gy) ** 2).sum()
        grad_diff.backward()

        return grad_diff

    optimizer.step(closure)
    if iters % 10 == 0:
        current_loss = closure()
        print(iters, "%.4f" % current_loss.item())
        history.append(tt(dummy_data[0].cpu()))

plt.figure(figsize=(12, 8))
for i in range(30):
    plt.subplot(3, 10, i + 1)
    plt.imshow(history[i])
    plt.title("iter=%d" % (i * 10))
    plt.axis('off')

plt.show()

--------------------------------------------------------------------------------
/models/vision.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def weights_init(m):
    if hasattr(m, "weight"):
        m.weight.data.uniform_(-0.5, 0.5)
    if hasattr(m, "bias"):
        m.bias.data.uniform_(-0.5, 0.5)


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        act = nn.Sigmoid
        self.body = nn.Sequential(
            nn.Conv2d(3, 12, kernel_size=5, padding=5//2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=2),
            act(),
            nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=1),
            act(),
        )
        self.fc = nn.Sequential(
            nn.Linear(768, 100)  # 12 channels * 8 * 8 spatial for 32x32 input
        )

    def forward(self, x):
        out = self.body(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


'''ResNet in PyTorch.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition.
    arXiv:1512.03385
'''


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = torch.sigmoid(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = torch.sigmoid(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = torch.sigmoid(self.bn1(self.conv1(x)))
        out = torch.sigmoid(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = torch.sigmoid(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        # Layers 2-4 downsample with stride 2 (standard CIFAR ResNet); with
        # stride 1 throughout, the flattened features would not match the
        # 512*block.expansion input of the final linear layer.
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = torch.sigmoid(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

def ResNet34():
    return ResNet(BasicBlock, [3,4,6,3])

def ResNet50():
    return ResNet(Bottleneck, [3,4,6,3])

def ResNet101():
    return ResNet(Bottleneck, [3,4,23,3])

def ResNet152():
    return ResNet(Bottleneck, [3,8,36,3])

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F


def label_to_onehot(target, num_classes=100):
    # target: LongTensor of shape (N,) -> one-hot FloatTensor of shape (N, num_classes)
    target = torch.unsqueeze(target, 1)
    onehot_target = torch.zeros(target.size(0), num_classes, device=target.device)
    onehot_target.scatter_(1, target, 1)
    return onehot_target

def cross_entropy_for_onehot(pred, target):
    # Cross entropy that accepts soft/one-hot targets
    # (F.cross_entropy expects class indices instead).
    return torch.mean(torch.sum(- target * F.log_softmax(pred, dim=-1), 1))

--------------------------------------------------------------------------------