├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── demo-crop.gif
│   ├── method.jpg
│   ├── nips-dlg.jpg
│   ├── nlp_results.png
│   └── out.gif
├── main.py
├── models
│   └── vision.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 MIT HAN Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Leakage From Gradients [[arXiv]](https://arxiv.org/abs/1906.08935) [[Website]](https://dlg.mit.edu)
2 |
3 | ![Deep Leakage from Gradients](assets/nips-dlg.jpg)
4 |
5 | ```
6 | @inproceedings{zhu19deep,
7 | title={Deep Leakage from Gradients},
8 | author={Zhu, Ligeng and Liu, Zhijian and Han, Song},
9 | booktitle={Advances in Neural Information Processing Systems},
10 | year={2019}
11 | }
12 | ```
13 |
14 | Gradient exchange is widely used in modern multi-node learning systems. People used to believe that numerical gradients were safe to share. However, we show that it is actually possible to recover the training data from shared gradients, and that the leakage is pixel-wise accurate for images and token-wise matching for texts.
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | ## Overview
25 |
26 | The core algorithm is to *match the gradients* between *dummy data* and *real data*.
27 |
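In symbols: given a model $F$ with weights $W$, a loss $\ell$, and the gradients $\nabla W$ shared for a private batch, the attack solves

$$
\mathbf{x}'^{*}, \mathbf{y}'^{*} = \arg\min_{\mathbf{x}', \mathbf{y}'} \left\lVert \frac{\partial \ell\big(F(\mathbf{x}', W), \mathbf{y}'\big)}{\partial W} - \nabla W \right\rVert^{2}
$$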
28 | ![DLG method overview](assets/method.jpg)
29 |
30 |
31 |
32 | It can be implemented in **less than 20 lines** with PyTorch!
33 |
34 |
35 | ```python
36 | def deep_leakage_from_gradients(model, origin_grad, data_size, label_size):
37 |   # dummy inputs must require grad so LBFGS can optimize them
38 |   dummy_data = torch.randn(data_size, requires_grad=True)
39 |   dummy_label = torch.randn(label_size, requires_grad=True)
40 |   optimizer = torch.optim.LBFGS([dummy_data, dummy_label])
41 |
42 |   for iters in range(300):
43 |     def closure():
44 |       optimizer.zero_grad()
45 |       dummy_pred = model(dummy_data)
46 |       dummy_loss = criterion(dummy_pred, F.softmax(dummy_label, dim=-1))
47 |       dummy_grad = grad(dummy_loss, model.parameters(), create_graph=True)
48 |
49 |       grad_diff = sum(((dummy_g - origin_g) ** 2).sum()
50 |                       for dummy_g, origin_g in zip(dummy_grad, origin_grad))
51 |
52 |       grad_diff.backward()
53 |       return grad_diff
54 |
55 |     optimizer.step(closure)
56 |
57 |   return dummy_data, dummy_label
57 | ```
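
Here `criterion` is the soft-label cross-entropy from `utils.py`, `grad` is `torch.autograd.grad`, and `F` is `torch.nn.functional`. A minimal end-to-end sketch of how the function is driven (mirroring `main.py`; the `data_size`/`label_size` arguments only size the dummy tensors):

```python
import torch
import torch.nn.functional as F
from torch.autograd import grad
from torchvision import datasets, transforms

from models.vision import LeNet, weights_init
from utils import label_to_onehot, cross_entropy_for_onehot as criterion

# victim side: one CIFAR-100 image and the gradients it produces
dst = datasets.CIFAR100("~/.torch", download=True)
gt_data = transforms.ToTensor()(dst[25][0]).unsqueeze(0)       # (1, 3, 32, 32)
gt_onehot_label = label_to_onehot(torch.tensor([dst[25][1]]))  # (1, 100)

net = LeNet()
net.apply(weights_init)
loss = criterion(net(gt_data), gt_onehot_label)
origin_grad = [g.detach() for g in grad(loss, net.parameters())]

# attacker side: recover data and label from the shared gradients alone
dummy_data, dummy_label = deep_leakage_from_gradients(
    net, origin_grad, gt_data.size(), gt_onehot_label.size())
```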
58 |
59 |
60 | # Prerequisites
61 |
62 | To run the code, the following libraries are required:
63 |
64 | * Python >= 3.6
65 | * PyTorch >= 1.0
66 | * torchvision >= 0.4
67 |
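A quick environment check (a minimal sketch; `main.py` prints the same versions at startup):

```python
import torch, torchvision
print(torch.__version__, torchvision.__version__)  # expect >= 1.0 and >= 0.4
```
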
68 | # Code
69 |
70 |
71 | ![](assets/out.gif)
72 |
73 | Note: We provide a Colab notebook for quick reproduction.
76 |
77 |
78 | ```
79 | # Single image on CIFAR
80 | python main.py --index 25
81 |
82 | # Deep Leakage on your own Image
83 | python main.py --image yours.jpg
84 | ```
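
Note that `LeNet`'s classifier is `nn.Linear(768, 100)`, which assumes 32×32 inputs (12 channels × 8 × 8 after the two stride-2 convolutions), so a custom image should be resized before leaking. A minimal sketch, assuming PIL and torchvision (`main.py` itself does not resize):

```python
from PIL import Image
from torchvision import transforms

# bring an arbitrary image down to the 32x32 resolution LeNet expects
img = Image.open("yours.jpg").convert("RGB")
gt_data = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])(img).unsqueeze(0)  # (1, 3, 32, 32)
```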
85 |
86 | # Deep Leakage on Batched Images
87 |
88 | ![](assets/demo-crop.gif)
89 |
90 |
91 |
92 | # Deep Leakage on Language Model
93 |
94 | ![](assets/nlp_results.png)
95 |
96 |
97 |
98 |
99 | # License
100 |
101 | This repository is released under the MIT license. See [LICENSE](LICENSE) for additional details.
102 |
--------------------------------------------------------------------------------
/assets/demo-crop.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/dlg/d21007fa1540ba2303ebc034976aa331814727c7/assets/demo-crop.gif
--------------------------------------------------------------------------------
/assets/method.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/dlg/d21007fa1540ba2303ebc034976aa331814727c7/assets/method.jpg
--------------------------------------------------------------------------------
/assets/nips-dlg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/dlg/d21007fa1540ba2303ebc034976aa331814727c7/assets/nips-dlg.jpg
--------------------------------------------------------------------------------
/assets/nlp_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/dlg/d21007fa1540ba2303ebc034976aa331814727c7/assets/nlp_results.png
--------------------------------------------------------------------------------
/assets/out.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mit-han-lab/dlg/d21007fa1540ba2303ebc034976aa331814727c7/assets/out.gif
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import argparse
3 | import numpy as np
4 | from pprint import pprint
5 |
6 | from PIL import Image
7 | import matplotlib.pyplot as plt
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.autograd import grad
13 | import torchvision
14 | from torchvision import models, datasets, transforms
15 | print(torch.__version__, torchvision.__version__)
16 |
17 | from utils import label_to_onehot, cross_entropy_for_onehot
18 |
19 | parser = argparse.ArgumentParser(description='Deep Leakage from Gradients.')
20 | parser.add_argument('--index', type=int, default=25,
21 | help='the index for leaking images on CIFAR.')
22 | parser.add_argument('--image', type=str, default="",
23 | help='the path to customized image.')
24 | args = parser.parse_args()
25 |
26 | device = "cpu"
27 | if torch.cuda.is_available():
28 | device = "cuda"
29 | print("Running on %s" % device)
30 |
31 | dst = datasets.CIFAR100("~/.torch", download=True)
32 | tp = transforms.ToTensor()
33 | tt = transforms.ToPILImage()
34 |
35 | img_index = args.index
36 | gt_data = tp(dst[img_index][0]).to(device)
37 |
38 | if args.image:  # use a custom image instead of the CIFAR sample
39 | gt_data = Image.open(args.image)
40 | gt_data = tp(gt_data).to(device)
41 |
42 |
43 | gt_data = gt_data.view(1, *gt_data.size())
44 | gt_label = torch.Tensor([dst[img_index][1]]).long().to(device)
45 | gt_label = gt_label.view(1, )
46 | gt_onehot_label = label_to_onehot(gt_label)
47 |
48 | plt.imshow(tt(gt_data[0].cpu()))
49 |
50 | from models.vision import LeNet, weights_init
51 | net = LeNet().to(device)
52 |
53 |
54 | torch.manual_seed(1234)
55 |
56 | net.apply(weights_init)
57 | criterion = cross_entropy_for_onehot
58 |
59 | # compute original gradient
60 | pred = net(gt_data)
61 | y = criterion(pred, gt_onehot_label)
62 | dy_dx = torch.autograd.grad(y, net.parameters())
63 |
64 | original_dy_dx = [g.detach().clone() for g in dy_dx]
65 |
66 | # generate dummy data and label
67 | dummy_data = torch.randn(gt_data.size()).to(device).requires_grad_(True)
68 | dummy_label = torch.randn(gt_onehot_label.size()).to(device).requires_grad_(True)
69 |
70 | plt.imshow(tt(dummy_data[0].cpu()))
71 |
72 | optimizer = torch.optim.LBFGS([dummy_data, dummy_label])
73 |
74 |
75 | history = []
76 | for iters in range(300):
77 | def closure():
78 | optimizer.zero_grad()
79 |
80 | dummy_pred = net(dummy_data)
81 | dummy_onehot_label = F.softmax(dummy_label, dim=-1)
82 | dummy_loss = criterion(dummy_pred, dummy_onehot_label)
83 | dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(), create_graph=True)
84 |
85 | grad_diff = 0
86 | for gx, gy in zip(dummy_dy_dx, original_dy_dx):
87 | grad_diff += ((gx - gy) ** 2).sum()
88 | grad_diff.backward()
89 |
90 | return grad_diff
91 |
92 | optimizer.step(closure)
93 | if iters % 10 == 0:
94 | current_loss = closure()
95 | print(iters, "%.4f" % current_loss.item())
96 | history.append(tt(dummy_data[0].cpu()))
97 |
98 | plt.figure(figsize=(12, 8))
99 | for i in range(30):
100 | plt.subplot(3, 10, i + 1)
101 | plt.imshow(history[i])
102 | plt.title("iter=%d" % (i * 10))
103 | plt.axis('off')
104 |
105 | plt.show()
106 |
--------------------------------------------------------------------------------
/models/vision.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import grad
5 | import torchvision
6 | from torchvision import models, datasets, transforms
7 |
8 |
9 | def weights_init(m):
10 | if hasattr(m, "weight"):
11 | m.weight.data.uniform_(-0.5, 0.5)
12 | if hasattr(m, "bias"):
13 | m.bias.data.uniform_(-0.5, 0.5)
14 |
15 | class LeNet(nn.Module):
16 | def __init__(self):
17 | super(LeNet, self).__init__()
18 |         act = nn.Sigmoid  # smooth activation; DLG requires the model to be twice-differentiable
19 | self.body = nn.Sequential(
20 | nn.Conv2d(3, 12, kernel_size=5, padding=5//2, stride=2),
21 | act(),
22 | nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=2),
23 | act(),
24 | nn.Conv2d(12, 12, kernel_size=5, padding=5//2, stride=1),
25 | act(),
26 | )
27 | self.fc = nn.Sequential(
28 | nn.Linear(768, 100)
29 | )
30 |
31 | def forward(self, x):
32 | out = self.body(x)
33 | out = out.view(out.size(0), -1)
34 | # print(out.size())
35 | out = self.fc(out)
36 | return out
37 |
38 |
39 | '''ResNet in PyTorch.
40 |
41 | For Pre-activation ResNet, see 'preact_resnet.py'.
42 |
43 | Reference:
44 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
45 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
46 | '''
57 |
58 | class BasicBlock(nn.Module):
59 | expansion = 1
60 |
61 | def __init__(self, in_planes, planes, stride=1):
62 | super(BasicBlock, self).__init__()
63 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
64 | self.bn1 = nn.BatchNorm2d(planes)
65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
66 | self.bn2 = nn.BatchNorm2d(planes)
67 |
68 | self.shortcut = nn.Sequential()
69 | if stride != 1 or in_planes != self.expansion*planes:
70 | self.shortcut = nn.Sequential(
71 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
72 | nn.BatchNorm2d(self.expansion*planes)
73 | )
74 |
75 | def forward(self, x):
76 |         out = torch.sigmoid(self.bn1(self.conv1(x)))
77 | out = self.bn2(self.conv2(out))
78 | out += self.shortcut(x)
79 |         out = torch.sigmoid(out)
80 | return out
81 |
82 |
83 | class Bottleneck(nn.Module):
84 | expansion = 4
85 |
86 | def __init__(self, in_planes, planes, stride=1):
87 | super(Bottleneck, self).__init__()
88 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
89 | self.bn1 = nn.BatchNorm2d(planes)
90 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
91 | self.bn2 = nn.BatchNorm2d(planes)
92 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
93 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
94 |
95 | self.shortcut = nn.Sequential()
96 | if stride != 1 or in_planes != self.expansion*planes:
97 | self.shortcut = nn.Sequential(
98 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
99 | nn.BatchNorm2d(self.expansion*planes)
100 | )
101 |
102 | def forward(self, x):
103 |         out = torch.sigmoid(self.bn1(self.conv1(x)))
104 |         out = torch.sigmoid(self.bn2(self.conv2(out)))
105 | out = self.bn3(self.conv3(out))
106 | out += self.shortcut(x)
107 |         out = torch.sigmoid(out)
108 | return out
109 |
110 |
111 | class ResNet(nn.Module):
112 | def __init__(self, block, num_blocks, num_classes=10):
113 | super(ResNet, self).__init__()
114 | self.in_planes = 64
115 |
116 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
117 | self.bn1 = nn.BatchNorm2d(64)
118 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
119 |         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
120 |         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
121 |         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
122 | self.linear = nn.Linear(512*block.expansion, num_classes)
123 |
124 | def _make_layer(self, block, planes, num_blocks, stride):
125 | strides = [stride] + [1]*(num_blocks-1)
126 | layers = []
127 | for stride in strides:
128 | layers.append(block(self.in_planes, planes, stride))
129 | self.in_planes = planes * block.expansion
130 | return nn.Sequential(*layers)
131 |
132 | def forward(self, x):
133 |         out = torch.sigmoid(self.bn1(self.conv1(x)))
134 | out = self.layer1(out)
135 | out = self.layer2(out)
136 | out = self.layer3(out)
137 | out = self.layer4(out)
138 | out = F.avg_pool2d(out, 4)
139 | out = out.view(out.size(0), -1)
140 | out = self.linear(out)
141 | return out
142 |
143 |
144 | def ResNet18():
145 | return ResNet(BasicBlock, [2,2,2,2])
146 |
147 | def ResNet34():
148 | return ResNet(BasicBlock, [3,4,6,3])
149 |
150 | def ResNet50():
151 | return ResNet(Bottleneck, [3,4,6,3])
152 |
153 | def ResNet101():
154 | return ResNet(Bottleneck, [3,4,23,3])
155 |
156 | def ResNet152():
157 | return ResNet(Bottleneck, [3,8,36,3])
158 |
159 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def label_to_onehot(target, num_classes=100):
6 |     # (B,) integer class labels -> (B, num_classes) one-hot floats
7 |     target = torch.unsqueeze(target, 1)
8 |     onehot_target = torch.zeros(target.size(0), num_classes, device=target.device)
9 |     onehot_target.scatter_(1, target, 1)
10 |     return onehot_target
11 |
12 | def cross_entropy_for_onehot(pred, target):
13 |     # cross-entropy against soft/one-hot targets, differentiable w.r.t. both arguments
14 |     return torch.mean(torch.sum(- target * F.log_softmax(pred, dim=-1), 1))
15 |
--------------------------------------------------------------------------------