├── .gitignore
├── .illustration.png
├── .visualization.png
├── README.md
├── experiments
    ├── configs
    │   ├── cifar_ig.json
    │   ├── cifar_sme.json
    │   ├── femnist_ig.json
    │   ├── femnist_sme.json
    │   ├── vit_ig.json
    │   └── vit_sme.json
    └── run_experiment.py
├── models
    ├── FCN3.py
    ├── LeNet.py
    ├── __init__.py
    ├── cifarCNN.py
    ├── mnistCNN.py
    ├── resnet.py
    └── vit.py
├── requirement.txt
├── sme
    ├── Adversary
    │   ├── __init__.py
    │   ├── adversary.py
    │   └── utils.py
    ├── __init__.py
    └── attack.py
└── utils
    ├── __init__.py
    ├── dataloader.py
    └── utils.py


/.gitignore:
--------------------------------------------------------------------------------
1 | *.idea
2 | *__pycache__
3 | *data
4 | *res
5 | 


--------------------------------------------------------------------------------
/.illustration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunyiZhu-AI/surrogate_model_extension/ae2a624c7849648f1e35b1a29c663b173ceef63b/.illustration.png


--------------------------------------------------------------------------------
/.visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JunyiZhu-AI/surrogate_model_extension/ae2a624c7849648f1e35b1a29c663b173ceef63b/.visualization.png


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Surrogate Model Extension (SME): A Fast and Accurate Weight Update Attack on Federated Learning [Accepted at ICML 2023]
2 | 
3 | ### Abstract
4 | In Federated Learning (FL) and many other distributed training frameworks, collaborators can hold their private data locally and only share the network weights trained with the local data after multiple iterations. Gradient inversion is a family of privacy attacks that recovers data from its generated gradients.
Seemingly, FL can provide a degree of protection against gradient inversion attacks on weight updates, since the gradient of a single step is concealed by the accumulation of gradients over multiple local iterations. In this work, we propose a principled way to extend gradient inversion attacks to weight updates in FL, thereby better exposing weaknesses in the presumed privacy protection inherent in FL. In particular, we propose a surrogate model method based on the characteristic of two-dimensional gradient flow and low-rank property of local updates. Our method largely boosts the ability of gradient inversion attacks on weight updates containing many iterations and achieves state-of-the-art (SOTA) performance. Additionally, our method runs up to $100\times$ faster than the SOTA baseline in the common FL scenario. Our work re-evaluates and highlights the privacy risk of sharing network weights. 5 | 6 |
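To make the concealment argument above concrete, here is a short worked equation. It is a sketch under stated assumptions: plain SGD with a constant step size (as in the `train` function of `sme/attack.py`), with the symbols $\gamma$ (training learning rate), $B_t$ (mini-batch at local step $t$) and $L$ (training loss) introduced here for illustration only.

$$
w_T = w_0 - \gamma \sum_{t=0}^{T-1} \nabla_w L(w_t; B_t)
$$

An observer sees only $w_0$ and $w_T$, that is, a sum of $T$ gradients evaluated at different intermediate weights $w_t$, and never a single gradient. For $T = 1$ the update reduces to one gradient step, and standard gradient inversion applies directly.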

7 | 8 |

9 |

10 | Figure 1: Illustration of the threat model (left) and the working pipeline of our surrogate model extension (right). In FL, a client trains the received model w_0 for T iterations on a local dataset D of size N, then sends the trained weights and the number N back to the server. An adversary observes these messages and launches the SME attack by jointly optimizing the dummy data and a surrogate model.
11 | 
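The pipeline in Figure 1 can be sketched in a few lines. This is a minimal, dependency-free illustration and not the repository's implementation, which lives in `sme/Adversary/adversary.py` and uses PyTorch. It assumes flat Python lists in place of network weights, and the helper names `interpolate` and `cosine_distance` are hypothetical. The surrogate weights are a convex combination of the received weights w_0 and the trained weights w_T, controlled by `alpha`; the attack then minimizes one minus the cosine similarity between the dummy-data gradient at the surrogate and the direction w_0 - w_T.

```python
import math


def interpolate(w0, wT, alpha):
    """Surrogate weights: (1 - alpha) * w0 + alpha * wT, elementwise."""
    return [(1 - alpha) * a + alpha * b for a, b in zip(w0, wT)]


def cosine_distance(u, v):
    """1 - cos(u, v); close to 0 when u and v point in the same direction."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1 - dot / (norm_u * norm_v)


# Toy numbers, chosen only for illustration.
w0 = [1.0, 2.0, 3.0]  # weights received from the server
wT = [0.0, 2.0, 1.0]  # weights after local training
surrogate = interpolate(w0, wT, alpha=0.5)

# Direction that the dummy-data gradient is matched against (w0 - wT).
update_direction = [a - b for a, b in zip(w0, wT)]
```

With `alpha = 0` the surrogate is simply w_0, which corresponds to the vanilla gradient inversion configs (`*_ig.json`); the SME configs start from `alpha = 0.5`.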

12 | 13 |

14 | 15 |

16 |

17 | Figure 2: Visualization of the reconstructed images. 18 |

19 | 
20 | 
21 | ### Download
22 | Make sure that conda is installed.
23 | ```sh
24 | git clone git@github.com:JunyiZhu-AI/surrogate_model_extension.git
25 | cd surrogate_model_extension
26 | conda create -n sme python==3.9.12
27 | conda activate sme
28 | conda install pip
29 | pip install -r requirement.txt
30 | ```
31 | Prepare the FEMNIST dataset (preprocessing can take up to 30 minutes; take a break and have a coffee ☕).
32 | ```sh
33 | mkdir data
34 | cd data
35 | git clone https://github.com/TalwalkarLab/leaf.git
36 | cd leaf/data/femnist
37 | ./preprocess.sh -s niid --sf 0.05 -k 0 -t sample
38 | cd ../../..
39 | mv leaf/data/femnist .
40 | rm -rf leaf
41 | cd ..
42 | ```
43 | 
44 | ### Run
45 | To run the experiments, follow these instructions:
46 | 
47 | 1. SME attack on a CNN trained on CIFAR-100:
48 | ```sh
49 | python3 -m experiments.run_experiment --config experiments/configs/cifar_sme.json
50 | ```
51 | 
52 | 2. Vanilla gradient inversion attack on a CNN trained on CIFAR-100:
53 | ```sh
54 | python3 -m experiments.run_experiment --config experiments/configs/cifar_ig.json
55 | ```
56 | 
57 | 3. SME attack on a CNN trained on FEMNIST:
58 | ```sh
59 | python3 -m experiments.run_experiment --config experiments/configs/femnist_sme.json
60 | ```
61 | 
62 | 4. Vanilla gradient inversion attack on a CNN trained on FEMNIST:
63 | ```sh
64 | python3 -m experiments.run_experiment --config experiments/configs/femnist_ig.json
65 | ```
66 | 
67 | 5. SME attack on a ViT trained on FEMNIST:
68 | ```sh
69 | python3 -m experiments.run_experiment --config experiments/configs/vit_sme.json
70 | ```
71 | 
72 | 6. Vanilla gradient inversion attack on a ViT trained on FEMNIST:
73 | ```sh
74 | python3 -m experiments.run_experiment --config experiments/configs/vit_ig.json
75 | ```
76 | 
77 | 7. 
Example of running an experiment by passing arguments directly (the following hyperparameters are not tuned and are for demonstration purposes only): 78 | ```sh 79 | python3 -m experiments.run_experiment \ 80 | --dataset FEMNIST \ 81 | --model ResNet8 \ 82 | --seed 42 \ 83 | --batchsize 25 \ 84 | --train_lr 0.004 \ 85 | --epochs 50 \ 86 | --k 50 \ 87 | --alpha 0.5 \ 88 | --eta 1 \ 89 | --iters 2000 \ 90 | --lamb 0.01 \ 91 | --lr_decay True \ 92 | --beta 0.001 \ 93 | --test_steps 200 94 | ``` 95 | Explanation for each argument can be found in the ```experiments/run_experiment.py``` file. 96 | 97 | ### Citation 98 | ``` 99 | @InProceedings{pmlr-v202-zhu23m, 100 | title = {Surrogate Model Extension ({SME}): A Fast and Accurate Weight Update Attack on Federated Learning}, 101 | author = {Zhu, Junyi and Yao, Ruicong and Blaschko, Matthew B.}, 102 | booktitle = {Proceedings of the 40th International Conference on Machine Learning}, 103 | year = {2023}, 104 | series = {Proceedings of Machine Learning Research}, 105 | publisher = {PMLR}, 106 | } 107 | ``` 108 | -------------------------------------------------------------------------------- /experiments/configs/cifar_ig.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 0, 3 | "path_to_data": "./data", 4 | "path_to_res": "./res", 5 | "dataset": "CIFAR100", 6 | "model": "CNNcifar", 7 | "k": 50, 8 | "batchsize": 10, 9 | "epochs": 20, 10 | "alpha": 0.0, 11 | "lamb": 0.01, 12 | "train_lr": 0.004, 13 | "eta": 1, 14 | "beta": 0.001, 15 | "iters": 1000, 16 | "test_steps": 50, 17 | "lr_decay": true 18 | } 19 | -------------------------------------------------------------------------------- /experiments/configs/cifar_sme.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 0, 3 | "path_to_data": "./data", 4 | "path_to_res": "./res", 5 | "dataset": "CIFAR100", 6 | "model": "CNNcifar", 7 | "k": 50, 8 | "batchsize": 10, 9 | 
"epochs": 20, 10 | "alpha": 0.5, 11 | "lamb": 0.01, 12 | "train_lr": 0.004, 13 | "eta": 1, 14 | "beta": 0.001, 15 | "iters": 1000, 16 | "test_steps": 50, 17 | "lr_decay": true 18 | } 19 | -------------------------------------------------------------------------------- /experiments/configs/femnist_ig.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 0, 3 | "path_to_data": "./data", 4 | "path_to_res": "./res", 5 | "dataset": "FEMNIST", 6 | "model": "CNNmnist", 7 | "k": 50, 8 | "batchsize": 10, 9 | "epochs": 20, 10 | "alpha": 0, 11 | "lamb": 0.01, 12 | "train_lr": 0.004, 13 | "eta": 1, 14 | "beta": 0.001, 15 | "iters": 1000, 16 | "test_steps": 50, 17 | "lr_decay": true 18 | } 19 | -------------------------------------------------------------------------------- /experiments/configs/femnist_sme.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 0, 3 | "path_to_data": "./data", 4 | "path_to_res": "./res", 5 | "dataset": "FEMNIST", 6 | "model": "CNNmnist", 7 | "k": 50, 8 | "batchsize": 10, 9 | "epochs": 20, 10 | "alpha": 0.5, 11 | "lamb": 0.01, 12 | "train_lr": 0.004, 13 | "eta": 1, 14 | "beta": 0.001, 15 | "iters": 1000, 16 | "test_steps": 50, 17 | "lr_decay": true 18 | } 19 | -------------------------------------------------------------------------------- /experiments/configs/vit_ig.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 0, 3 | "dataset": "FEMNIST", 4 | "path_to_data": "./data", 5 | "path_to_res": "./res", 6 | "model": "ViT", 7 | "k": 50, 8 | "batchsize": 10, 9 | "epochs": 20, 10 | "alpha": 0.0, 11 | "lamb": 0.01, 12 | "train_lr": 0.004, 13 | "lr_decay": false, 14 | "eta": 0.1, 15 | "beta": 0.001, 16 | "iters": 1000, 17 | "test_steps": 50 18 | } 19 | -------------------------------------------------------------------------------- /experiments/configs/vit_sme.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "seed": 0, 3 | "path_to_data": "./data", 4 | "path_to_res": "./res", 5 | "dataset": "FEMNIST", 6 | "model": "ViT", 7 | "k": 50, 8 | "batchsize": 10, 9 | "epochs": 20, 10 | "alpha": 0.5, 11 | "lamb": 0.01, 12 | "train_lr": 0.004, 13 | "lr_decay": false, 14 | "eta": 0.1, 15 | "beta": 0.001, 16 | "iters": 1000, 17 | "test_steps": 50 18 | } 19 | -------------------------------------------------------------------------------- /experiments/run_experiment.py: -------------------------------------------------------------------------------- 1 | import click 2 | import json 3 | from sme import attack 4 | 5 | @click.command() 6 | @click.option("--path_to_data", default="./data") 7 | @click.option("--path_to_res", default="./res") 8 | @click.option("--dataset", type=click.Choice(["CIFAR100", "FEMNIST"]), default="CIFAR100") 9 | # federated learning parameters 10 | @click.option("--model", type=click.Choice(["LeNet", "MLP", "CNNcifar", "CNNmnist", "ResNet8", "ViT"]), default="CNNcifar") 11 | @click.option("--batchsize", default=10, help="Batch size of federated learning.") 12 | @click.option("--train_lr", default=0.01, help="Learning rate of federated learning.") 13 | @click.option("--k", default=10, help="Size of local dataset.") 14 | @click.option("--epochs", default=20, help="Number of epochs for the local training in the federated learning.") 15 | # reconstruction attack parameters 16 | @click.option("--eta", default=1e-3, help="Step size of the reconstruction attack.") 17 | @click.option("--beta", default=1e-3, help="Step size of alpha.") 18 | @click.option("--alpha", default=0., help="Interpolation factor.") 19 | @click.option("--iters", default=5000, help="Optimization iterations of reconstruction.") 20 | @click.option("--test_steps", default=500, help="Measure the psnr and save figs every so many steps.") 21 | @click.option("--lamb", default=1e-4, help="Total variation 
coefficient.") 22 | @click.option("--lr_decay", default=True, help="Use learning rate decay.") 23 | @click.option("--seed", default=0) 24 | @click.option( 25 | "--config", help="Path to the configuration file.", default=None, 26 | ) 27 | def main(**kwargs): 28 | if kwargs["config"]: 29 | with open(kwargs["config"]) as f: 30 | kwargs = json.load(f) 31 | else: 32 | del kwargs["config"] 33 | 34 | print(kwargs) 35 | attack(**kwargs) 36 | 37 | 38 | if __name__ == "__main__": 39 | main() 40 | 41 | -------------------------------------------------------------------------------- /models/FCN3.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from collections import OrderedDict 3 | 4 | 5 | class MLP(nn.Module): 6 | def __init__(self, classes=10): 7 | super(MLP, self).__init__() 8 | # act = nn.LeakyReLU(negative_slope=1e-2) 9 | act = nn.ReLU() 10 | self.body = nn.ModuleList([ 11 | nn.Sequential(OrderedDict([ 12 | ('layer', nn.Linear(784, 1000)), 13 | ('act', act) 14 | ])), 15 | nn.Sequential(OrderedDict([ 16 | ('layer', nn.Linear(1000, 1000)), 17 | ('act', act) 18 | ])), 19 | nn.Sequential(OrderedDict([ 20 | ('layer', nn.Linear(1000, classes)), 21 | ('act', act) 22 | ])) 23 | ]) 24 | 25 | def forward(self, x): 26 | for layer in self.body: 27 | if isinstance(layer.layer, nn.Linear): 28 | x = x.flatten(1) 29 | x = layer(x) 30 | return x 31 | -------------------------------------------------------------------------------- /models/LeNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class LeNet(nn.Module): 7 | def __init__(self, classes=10): 8 | super(LeNet, self).__init__() 9 | self.conv1 = nn.Conv2d(1, 6, 5) 10 | self.pool = nn.AvgPool2d(2, 2, padding=1) 11 | self.conv2 = nn.Conv2d(6, 16, 5) 12 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 13 | self.fc2 = nn.Linear(120, 84) 14 | self.fc3 = 
nn.Linear(84, classes) 15 | 16 | def forward(self, x): 17 | x = self.pool(torch.tanh(self.conv1(x))) 18 | x = self.pool(torch.tanh(self.conv2(x))) 19 | x = x.view(-1, 16 * 5 * 5) 20 | x = torch.tanh(self.fc1(x)) 21 | x = torch.tanh(self.fc2(x)) 22 | x = F.softmax(self.fc3(x), dim=1) 23 | return x 24 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.FCN3 import MLP 2 | from models.LeNet import LeNet 3 | from models.cifarCNN import CNNcifar 4 | from models.mnistCNN import CNNmnist 5 | from models.resnet import ResNet8 6 | from models.vit import ViT 7 | -------------------------------------------------------------------------------- /models/cifarCNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __all__ = [ 5 | "CNNcifar", 6 | ] 7 | 8 | 9 | class CNNcifar(nn.Module): 10 | def __init__(self, classes=10): 11 | super(CNNcifar, self).__init__() 12 | self.act = nn.ReLU() 13 | self.body = nn.Sequential( 14 | nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1), 15 | self.act, 16 | nn.AvgPool2d(kernel_size=2, stride=2), 17 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), 18 | self.act, 19 | nn.AvgPool2d(kernel_size=2, stride=2), 20 | ) 21 | self.fc1 = nn.Linear(8192, 200) 22 | self.fc2 = nn.Linear(200, classes) 23 | 24 | def forward(self, x): 25 | x = self.body(x) 26 | x = torch.flatten(x, start_dim=1) 27 | x = self.act(self.fc1(x)) 28 | return self.fc2(x) 29 | -------------------------------------------------------------------------------- /models/mnistCNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | __all__ = [ 5 | "CNNmnist", 6 | ] 7 | 8 | 9 | class CNNmnist(nn.Module): 10 | def __init__(self, classes=10): 11 | super(CNNmnist, 
self).__init__() 12 | self.act = nn.ReLU() 13 | self.body = nn.Sequential( 14 | nn.Conv2d(1, 32, kernel_size=3, padding=1), 15 | self.act, 16 | nn.AvgPool2d(kernel_size=2, stride=2), 17 | nn.Conv2d(32, 64, kernel_size=3, padding=1), 18 | self.act, 19 | nn.AvgPool2d(kernel_size=2, stride=2), 20 | ) 21 | self.fc1 = nn.Linear(3136, 100) 22 | self.fc2 = nn.Linear(100, classes) 23 | 24 | def forward(self, x): 25 | x = self.body(x) 26 | x = torch.flatten(x, start_dim=1) 27 | x = self.act(self.fc1(x)) 28 | return self.fc2(x) 29 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | Reference: 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 4 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1): 16 | super(BasicBlock, self).__init__() 17 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | 22 | self.shortcut = nn.Sequential() 23 | if stride != 1 or in_planes != self.expansion*planes: 24 | self.shortcut = nn.Sequential( 25 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 26 | nn.BatchNorm2d(self.expansion*planes) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | out += self.shortcut(x) 33 | out = F.relu(out) 34 | return out 35 | 36 | 37 | class ResNet(nn.Module): 38 | def __init__(self, block, num_blocks, num_classes=10): 39 | 
super(ResNet, self).__init__() 40 | self.in_planes = 32 41 | 42 | self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1, bias=False) 43 | self.bn1 = nn.BatchNorm2d(32) 44 | self.layer1 = self._make_layer(block, 32, num_blocks[0], stride=1) 45 | self.layer2 = self._make_layer(block, 64, num_blocks[1], stride=2) 46 | self.layer3 = self._make_layer(block, 128, num_blocks[2], stride=2) 47 | self.linear1 = nn.Linear(2048, 1000) 48 | self.linear2 = nn.Linear(1000, num_classes) 49 | 50 | def _make_layer(self, block, planes, num_blocks, stride): 51 | strides = [stride] + [1]*(num_blocks-1) 52 | layers = [] 53 | for stride in strides: 54 | layers.append(block(self.in_planes, planes, stride)) 55 | self.in_planes = planes * block.expansion 56 | return nn.Sequential(*layers) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.bn1(self.conv1(x))) 60 | out = self.layer1(out) 61 | out = self.layer2(out) 62 | out = self.layer3(out) 63 | out = F.avg_pool2d(out, 2, 2, padding=1) 64 | out = out.view(out.size(0), -1) 65 | out = self.linear1(out) 66 | out = self.linear2(out) 67 | return out 68 | 69 | 70 | def ResNet8(classes): 71 | return ResNet(BasicBlock, [1, 1, 1], num_classes=classes) 72 | -------------------------------------------------------------------------------- /models/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops.layers.torch import Rearrange 4 | 5 | 6 | class ViT(nn.Module): 7 | def __init__(self, classes): 8 | super(ViT, self).__init__() 9 | self.patch_size = 28 10 | self.num_channels = 1 11 | self.embed_dim = 384 12 | self.num_heads = 4 13 | self.num_layers = 4 14 | self.mlp_ratio = 2 15 | self.num_classes = classes 16 | 17 | self.patch_embedding = nn.Sequential( 18 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size), 19 | nn.Linear(self.patch_size * self.patch_size * self.num_channels, self.embed_dim) 20 | ) 21 | 
22 | self.position_embedding = nn.Parameter(torch.zeros(1, (28 // self.patch_size) * (28 // self.patch_size) + 1, self.embed_dim)) 23 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 24 | 25 | self.transformer_blocks = nn.ModuleList( 26 | [ModifiedBlock(self.embed_dim, self.num_heads, self.mlp_ratio)] + 27 | [TransformerBlock(self.embed_dim, self.num_heads, self.mlp_ratio) 28 | for _ in range(self.num_layers - 1)]) 29 | 30 | self.norm = nn.LayerNorm(self.embed_dim) 31 | self.fc1 = nn.Linear(self.embed_dim, self.embed_dim * self.mlp_ratio) 32 | self.fc2 = nn.Linear(self.embed_dim * self.mlp_ratio, self.num_classes) 33 | 34 | def forward(self, x): 35 | B, _, _, _ = x.shape 36 | 37 | x = self.patch_embedding(x) 38 | cls_tokens = self.cls_token.expand(B, -1, -1) 39 | x = torch.cat((cls_tokens, x), dim=1) 40 | x += self.position_embedding 41 | 42 | for block in self.transformer_blocks: 43 | x = block(x) 44 | 45 | x = self.norm(x[:, 0]) 46 | x = nn.GELU()(self.fc1(x)) 47 | x = self.fc2(x) 48 | 49 | return x 50 | 51 | 52 | class MultiHeadAttention(nn.Module): 53 | def __init__(self, embed_dim, num_heads): 54 | super(MultiHeadAttention, self).__init__() 55 | self.embed_dim = embed_dim 56 | self.num_heads = num_heads 57 | self.head_dim = embed_dim // num_heads 58 | 59 | self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False) 60 | self.attention = nn.Softmax(dim=-1) 61 | self.out_proj = nn.Linear(embed_dim, embed_dim) 62 | 63 | def forward(self, x): 64 | B, N, C = x.shape 65 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 66 | q, k, v = qkv[0], qkv[1], qkv[2] 67 | 68 | attn_weights = self.attention((q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)) 69 | attn_output = (attn_weights @ v).transpose(1, 2).reshape(B, N, C) 70 | 71 | return self.out_proj(attn_output) 72 | 73 | 74 | class MLP(nn.Module): 75 | def __init__(self, in_features, hidden_features, out_features): 76 | super(MLP, self).__init__() 77 | self.fc 
= nn.Sequential( 78 | nn.Linear(in_features, hidden_features), 79 | nn.GELU(), 80 | nn.Linear(hidden_features, out_features) 81 | ) 82 | 83 | def forward(self, x): 84 | return self.fc(x) 85 | 86 | 87 | class TransformerBlock(nn.Module): 88 | def __init__(self, embed_dim, num_heads, mlp_ratio): 89 | super(TransformerBlock, self).__init__() 90 | self.norm1 = nn.LayerNorm(embed_dim) 91 | self.mha = MultiHeadAttention(embed_dim, num_heads) 92 | self.norm2 = nn.LayerNorm(embed_dim) 93 | self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), embed_dim) 94 | 95 | def forward(self, x): 96 | x = x + self.mha(self.norm1(x)) 97 | x = x + self.mlp(self.norm2(x)) 98 | return x 99 | 100 | 101 | class ModifiedBlock(nn.Module): 102 | def __init__(self, embed_dim, num_heads, mlp_ratio): 103 | super(ModifiedBlock, self).__init__() 104 | self.mha = MultiHeadAttention(embed_dim, num_heads) 105 | self.norm2 = nn.LayerNorm(embed_dim) 106 | self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), embed_dim) 107 | 108 | def forward(self, x): 109 | x = self.mha(x) 110 | x = x + self.mlp(self.norm2(x)) 111 | return x 112 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | click 2 | torch==1.11.0 3 | torchvision==0.12.0 4 | numpy 5 | einops==0.4.1 6 | matplotlib 7 | scipy 8 | -------------------------------------------------------------------------------- /sme/Adversary/__init__.py: -------------------------------------------------------------------------------- 1 | from sme.Adversary.adversary import IWU -------------------------------------------------------------------------------- /sme/Adversary/adversary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import copy 4 | from utils import psnr, save_figs 5 | from sme.Adversary.utils import * 6 | 7 | 8 | class IWU: 9 | def __init__( 10 | self, 
11 | trainloader, 12 | setup, 13 | alpha, 14 | test_steps, 15 | path_to_res, 16 | mean_std, 17 | lamb, 18 | dataset=None, 19 | ): 20 | self.alpha = torch.tensor(alpha, requires_grad=True, **setup) 21 | self.rec_alpha = 0 < self.alpha < 1 22 | self.setup = setup 23 | self.net0 = None 24 | self.net1 = None 25 | self.test_steps = test_steps 26 | os.makedirs(path_to_res, exist_ok=True) 27 | self.path = path_to_res 28 | self.lamb = lamb 29 | self.dataset = dataset 30 | data, labels = [], [] 31 | for img, l in trainloader: 32 | labels.append(l) 33 | data.append(img) 34 | self.data = torch.cat(data).to(**setup) 35 | 36 | # We assume that labels have been restored separately, for details please refer to the paper. 37 | self.y = torch.cat(labels).to(device=setup["device"]) 38 | # Dummy input. 39 | self.x = torch.normal(0, 1, size=self.data.shape, requires_grad=True, **setup) 40 | 41 | self.mean = torch.tensor(mean_std[0]).to(**setup).reshape(1, -1, 1, 1) 42 | self.std = torch.tensor(mean_std[1]).to(**setup).reshape(1, -1, 1, 1) 43 | # This is a trick (a sort of prior information) adopted from IG. 44 | prior_boundary(self.x, -self.mean / self.std, (1 - self.mean) / self.std) 45 | 46 | def reconstruction(self, eta, beta, iters, lr_decay, signed_grad=False, save_figure=True): 47 | # when taking the SME strategy, alpha is set within (0, 1). 48 | if 0 < self.alpha < 1: 49 | self.alpha.grad = torch.tensor(0.).to(**self.setup) 50 | optimizer = torch.optim.Adam(params=[self.x], lr=eta) 51 | alpha_opti = torch.optim.Adam(params=[self.alpha], lr=beta) 52 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 53 | milestones=[iters // 2.667, 54 | iters // 1.6, 55 | iters // 1.142], 56 | gamma=0.1) 57 | alpha_scheduler = torch.optim.lr_scheduler.MultiStepLR(alpha_opti, 58 | milestones=[iters // 2.667, 59 | iters // 1.6, 60 | iters // 1.142], 61 | gamma=0.1) 62 | criterion = torch.nn.CrossEntropyLoss(reduction="mean") 63 | 64 | # Direction of the weight update. 
65 | w1_w0 = [] 66 | for p0, p1 in zip(self.net0.parameters(), self.net1.parameters()): 67 | w1_w0.append(p0.data - p1.data) 68 | norm = compute_norm(w1_w0) 69 | w1_w0 = [p / norm for p in w1_w0] 70 | 71 | # Construct the model for gradient inversion attack. 72 | require_grad(self.net0, False) 73 | require_grad(self.net1, False) 74 | with torch.no_grad(): 75 | _net = copy.deepcopy(self.net0) 76 | for p, q, z in zip(self.net0.parameters(), self.net1.parameters(), _net.parameters()): 77 | z.data = (1 - self.alpha) * p + self.alpha * q 78 | 79 | # Reconstruction 80 | _net.eval() 81 | stats = [] 82 | for i in range(iters): 83 | optimizer.zero_grad() 84 | alpha_opti.zero_grad(set_to_none=False) 85 | _net.zero_grad() 86 | 87 | if self.rec_alpha: 88 | # Update the surrogate model. 89 | with torch.no_grad(): 90 | for p, q, z in zip(self.net0.parameters(), self.net1.parameters(), _net.parameters()): 91 | z.data = (1 - self.alpha) * p + self.alpha * q 92 | pred = _net(self.x) 93 | loss = criterion(input=pred, target=self.y) 94 | grad = torch.autograd.grad(loss, _net.parameters(), create_graph=True) 95 | norm = compute_norm(grad) 96 | grad = [p / norm for p in grad] 97 | 98 | # Compute x's grad. 99 | cos_loss = 1 - sum([ 100 | p.mul(q).sum() for p, q in zip(w1_w0, grad) 101 | ]) 102 | loss = cos_loss + self.lamb * total_variation(self.x) 103 | loss.backward() 104 | if signed_grad: 105 | self.x.grad.sign_() 106 | 107 | # Compute alpha's grad. 108 | if self.rec_alpha: 109 | with torch.no_grad(): 110 | for p, q, z in zip(self.net0.parameters(), self.net1.parameters(), _net.parameters()): 111 | self.alpha.grad += z.grad.mul( 112 | q.data - p.data 113 | ).sum() 114 | if signed_grad: 115 | self.alpha.grad.sign_() 116 | 117 | # Update x and alpha. 
118 | optimizer.step() 119 | alpha_opti.step() 120 | prior_boundary(self.x, -self.mean / self.std, (1 - self.mean) / self.std) 121 | prior_boundary(self.alpha, 0, 1) 122 | if lr_decay: 123 | scheduler.step() 124 | alpha_scheduler.step() 125 | if i % self.test_steps == 0 or i == iters - 1: 126 | with torch.no_grad(): 127 | _x = self.x * self.std + self.mean 128 | _data = self.data * self.std + self.mean 129 | measurement = psnr(_data, _x, sort=True) 130 | print(f"iter: {i}| alpha: {self.alpha.item():.2f}| (1 - cos): {cos_loss.item():.3f}| " 131 | f"psnr: {measurement:.3f}") 132 | stats.append({ 133 | "iter": i, 134 | "alpha": self.alpha.item(), 135 | "cos_loss": cos_loss.item(), 136 | "psnr": measurement, 137 | }) 138 | if save_figure: 139 | save_figs(tensors=_x, path=self.path, subdir=str(i), dataset=self.dataset) 140 | if save_figure: 141 | save_figs(tensors=self.data * self.std + self.mean, 142 | path=self.path, subdir="original", dataset=self.dataset) 143 | return stats 144 | -------------------------------------------------------------------------------- /sme/Adversary/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | __all__ = [ 4 | "require_grad", 5 | "prior_boundary", 6 | "compute_norm", 7 | "total_variation" 8 | ] 9 | 10 | 11 | def require_grad(net, flag): 12 | for p in net.parameters(): 13 | p.require_grad = flag 14 | 15 | 16 | def prior_boundary(data, low, high): 17 | with torch.no_grad(): 18 | data.data = torch.clamp(data, low, high) 19 | 20 | 21 | def compute_norm(inputs): 22 | squared_sum = sum([p.square().sum() for p in inputs]) 23 | norm = squared_sum.sqrt() 24 | return norm 25 | 26 | 27 | def total_variation(x): 28 | dh = (x[:, :, :, :-1] - x[:, :, :, 1:]).abs().mean() 29 | dw = (x[:, :, :-1, :] - x[:, :, 1:, :]).abs().mean() 30 | return (dh + dw) / 2 31 | -------------------------------------------------------------------------------- /sme/__init__.py: 
--------------------------------------------------------------------------------
from sme.attack import attack
--------------------------------------------------------------------------------
/sme/attack.py:
--------------------------------------------------------------------------------
import copy
import random
import torch
import os
import json
from models import *
from sme.Adversary import IWU
from utils import attack_dataloader, random_seed


def attack(
        path_to_data,
        path_to_res,
        dataset,
        model,
        seed,
        batchsize,
        train_lr,
        epochs,
        alpha,
        eta,
        iters,
        lamb,
        lr_decay,
        beta,
        test_steps,
        k,
):
    os.makedirs(path_to_res, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    setup = {"device": device, "dtype": torch.float32}
    random_seed(seed)

    # Prepare dataset
    trainloaders, _, mean_std = attack_dataloader(
        path=path_to_data,
        dataset=dataset,
        batchsize=batchsize,
        local_data_size=k,
    )
    # Initialize the network
    if dataset == "CIFAR100":
        classes = 100
    elif dataset == "FEMNIST":
        classes = 62
    else:
        raise ValueError(f"Unsupported dataset: {dataset}")
    net = eval(f"{model}(classes={classes})").to(**setup)

    # Initialize the adversary
    trainloader = random.choice(trainloaders)
    adversary = IWU(
        trainloader=trainloader,
        setup=setup,
        alpha=alpha,
        test_steps=test_steps,
        path_to_res=path_to_res,
        lamb=lamb,
        mean_std=mean_std,
        dataset=dataset,
    )

    # Victim trains local model
    with torch.no_grad():
        net1 = copy.deepcopy(net)
        adversary.net0 = net
    # Training calls loss.backward(), so it must run with gradients enabled.
    train(
        net=net1,
        trainloader=trainloader,
        epochs=epochs,
        train_lr=train_lr,
        setup=setup
    )
    adversary.net1 = net1

    # Reconstruction
    stats = adversary.reconstruction(
        eta=eta,
        beta=beta,
        iters=iters,
        lr_decay=lr_decay,
        save_figure=True,
    )
    with open(os.path.join(path_to_res, "res.json"), "w") as f:
        json.dump(stats, f, indent=4)


def train(
        net,
        trainloader,
        epochs,
        train_lr,
        setup,
):
    # In evaluation mode, updates to the running statistics of Batch Normalization (if applicable) are halted.
    # This practice follows the work of IG. For more details, please refer to our paper.
    net.eval()
    criterion = torch.nn.CrossEntropyLoss(reduction="mean")
    optimizer = torch.optim.SGD(
        params=net.parameters(),
        lr=train_lr
    )
    for _ in range(epochs):
        for data, label in trainloader:
            optimizer.zero_grad()
            data = data.to(**setup)
            label = label.to(device=setup["device"])

            pred = net(data)
            loss = criterion(input=pred, target=label)

            loss.backward()
            optimizer.step()

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
from utils.utils import *
from utils.dataloader import attack_dataloader
--------------------------------------------------------------------------------
/utils/dataloader.py:
--------------------------------------------------------------------------------
import torch
import os
import json
from collections import defaultdict
import numpy as np
import torchvision
from torch.utils.data import TensorDataset, DataLoader
from utils import cifar100_preprocessing, femnist_preprocessing


MEAN_STD = {
    "CIFAR100": ((0.5071, 0.4866, 0.4409), (0.2673, 0.2564, 0.2762)),
    "FEMNIST": ((0.9642384386141873,), (0.15767843198426892,))
}


def load_dataset(dataset, path):
    if dataset == "CIFAR100":
        trainset = torchvision.datasets.CIFAR100(
            root=path,
            train=True,
            transform=cifar100_preprocessing(),
            download=True
        )
        testset = torchvision.datasets.CIFAR100(
            root=path,
            train=False,
            transform=cifar100_preprocessing(),
            download=True
        )
    else:
        raise ValueError(f"Unsupported dataset: {dataset}")

    return trainset, testset


def attack_dataloader(
        path,
        dataset,
        batchsize,
        local_data_size,
):

    if dataset != "FEMNIST":
        trainset, testset = load_dataset(dataset, path)
        # Drop the remainder so the training set splits evenly into local datasets.
        trainset = torch.utils.data.Subset(trainset, np.arange(local_data_size * (len(trainset) // local_data_size)))
        train_sets = torch.utils.data.random_split(
            trainset,
            [local_data_size] * (len(trainset) // local_data_size)
        )
    else:
        train_data_dir = os.path.join(path, "femnist", 'data', 'train')
        test_data_dir = os.path.join(path, "femnist", 'data', 'test')
        _, _, train_data, test_data = read_data(train_data_dir, test_data_dir)
        train_keys = list(train_data.keys())
        train_keys.sort()
        test_keys = list(test_data.keys())
        test_keys.sort()
        # One dataset per FEMNIST writer, truncated to local_data_size samples.
        train_sets = [
            TensorDataset(
                torch.stack(
                    [femnist_preprocessing()(x) for x in np.array(train_data[k]['x']).reshape(-1, 28, 28)]
                ).to(dtype=torch.float32)[:local_data_size],
                torch.Tensor(np.array(train_data[k]['y'])).to(dtype=torch.long)[:local_data_size]
            )
            for k in train_keys
        ]

    trainloaders = [
        DataLoader(d, batch_size=batchsize, shuffle=True) for d in train_sets
    ]

    return trainloaders, None, MEAN_STD[dataset]


def read_dir(data_dir):
    clients = []
    groups = []
    data = defaultdict(lambda: None)
    data_dir = os.path.expanduser(data_dir)
    files = os.listdir(data_dir)
    files = [f for f in files if f.endswith('.json')]
    for f in files:
        file_path = os.path.join(data_dir, f)
        with open(file_path, 'r') as inf:
            cdata = json.load(inf)
        clients.extend(cdata['users'])
        if 'hierarchies' in cdata:
            groups.extend(cdata['hierarchies'])
        data.update(cdata['user_data'])

    clients = list(sorted(data.keys()))
    return clients, groups, data


def read_data(train_data_dir, test_data_dir):
    train_clients, train_groups, train_data = read_dir(train_data_dir)
    test_clients, test_groups, test_data = read_dir(test_data_dir)

    assert train_clients == test_clients
    assert train_groups == test_groups

    return train_clients, train_groups, train_data, test_data

--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
import torch
import torchvision.transforms as transformers
from torchvision.transforms import ToPILImage
from scipy.optimize import linear_sum_assignment
import random
import matplotlib.pyplot as plt
import os
import json
import numpy as np

__all__ = [
    "random_seed",
    "cifar100_preprocessing",
    "femnist_preprocessing",
    "psnr",
    "save_args",
    "save_figs",
]


def random_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

def cifar100_preprocessing():
    transform = transformers.Compose([
        transformers.ToTensor(),
        transformers.Normalize((0.5071, 0.4866, 0.4409), (0.2673, 0.2564, 0.2762)),
    ])
    return transform

def femnist_preprocessing():
    transform = transformers.Compose([
        transformers.ToTensor(),
        transformers.Normalize((0.9642384386141873,), (0.15767843198426892,))
    ])
    return transform


def psnr(data, rec, sort=False):
    # Both inputs are expected to be de-normalized images in [0, 1].
    assert data.max().item() <= 1.0001 and data.min().item() >= -0.0001
    assert rec.max().item() <= 1.0001 and rec.min().item() >= -0.0001
    cost_matrix = []
    if sort:
        # Pair each reconstruction with the original that minimizes the total MSE.
        for x_ in rec:
            cost_matrix.append(
                [(x_ - d).square().mean().item() for d in data]
            )
        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        assert np.all(row_ind == np.arange(len(row_ind)))
        data = data[col_ind]
    psnr_list = [10 * np.log10(1 / (d - r).square().mean().item()) for d, r in zip(data, rec)]
    return np.mean(psnr_list)


def save_args(**kwargs):
    if os.path.exists(os.path.join(kwargs["path_to_res"], "args.json")):
        os.remove(os.path.join(kwargs["path_to_res"], "args.json"))

    with open(os.path.join(kwargs["path_to_res"], "args.json"), "w") as f:
        json.dump(kwargs, f, indent=4)


def save_figs(tensors, path, subdir=None, dataset=None):
    def save(imgs, path):
        for name, im in imgs:
            plt.figure()
            plt.imshow(im, cmap='gray')
            plt.axis('off')
            plt.savefig(os.path.join(path, f'{name}.png'), bbox_inches='tight')
            plt.close()
    tensor2image = ToPILImage()
    # os.path.join raises a TypeError on None, so only append subdir when it is given.
    if subdir is not None:
        path = os.path.join(path, subdir)
    os.makedirs(path, exist_ok=True)
    if dataset == "FEMNIST":
        tensors = 1 - tensors
    imgs = [
        [i, tensor2image(tensors[i].detach().cpu().squeeze())] for i in range(len(tensors))
    ]
    save(imgs, path)

--------------------------------------------------------------------------------
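The `psnr` helper in `/utils/utils.py` scores a reconstruction after optionally re-pairing reconstructed and original images (`sort=True`), since an attack may recover a batch in a different order. The sketch below illustrates that idea only and is not the repository's implementation: it works on flat Python lists of pixel values in [0, 1] instead of tensors, and it swaps `scipy.optimize.linear_sum_assignment` for a brute-force search over permutations, which is only practical for very small batches. The names `mse` and `matched_psnr` are hypothetical, and the code assumes no reconstruction is pixel-perfect (an MSE of zero would divide by zero, as it would in the original).

```python
import itertools
import math


def mse(a, b):
    """Mean squared error between two equally sized flat pixel lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def matched_psnr(data, rec):
    """Average PSNR after pairing each reconstruction with one original.

    Hypothetical helper: tries every permutation and keeps the pairing with
    the smallest total MSE (the same objective as an assignment solver).
    """
    n = len(rec)
    # cost[i][j] is the MSE between reconstruction i and original j.
    cost = [[mse(r, d) for d in data] for r in rec]
    best = min(
        itertools.permutations(range(n)),
        key=lambda perm: sum(cost[i][perm[i]] for i in range(n)),
    )
    # PSNR for images in [0, 1]: 10 * log10(1 / MSE).
    psnrs = [10 * math.log10(1 / cost[i][best[i]]) for i in range(n)]
    return sum(psnrs) / n, list(best)


if __name__ == "__main__":
    originals = [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]
    # The reconstructions come back in swapped order with small errors.
    recovered = [[0.9, 1.0, 1.0, 1.0], [0.0, 0.1, 0.0, 0.0]]
    score, order = matched_psnr(originals, recovered)
    print(f"matched PSNR: {score:.2f} dB, pairing: {order}")  # pairing: [1, 0]
```

Without the matching step, the swapped batch above would be compared image-by-image in the wrong order and the reported PSNR would drop sharply even though both images were recovered well.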