├── utils ├── __init__.py ├── __pycache__ │ ├── metrics.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── analysis.cpython-38.pyc │ └── breaching_utils.cpython-38.pyc ├── imprint_guarantee.py ├── breaching_utils.py ├── analysis.py └── metrics.py ├── attacks ├── __init__.py ├── __pycache__ │ ├── common.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── base_attack.cpython-38.pyc │ └── analytic_attack.cpython-38.pyc ├── analytic_attack.py ├── common.py └── base_attack.py ├── modifications ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── imprint.cpython-38.pyc └── imprint.py ├── teaser.png ├── README.md └── breaching_fl.ipynb /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /attacks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /modifications/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhfowl/robbing_the_fed/HEAD/teaser.png -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhfowl/robbing_the_fed/HEAD/utils/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /attacks/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhfowl/robbing_the_fed/HEAD/attacks/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhfowl/robbing_the_fed/HEAD/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/analysis.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhfowl/robbing_the_fed/HEAD/utils/__pycache__/analysis.cpython-38.pyc -------------------------------------------------------------------------------- /attacks/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhfowl/robbing_the_fed/HEAD/attacks/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /attacks/__pycache__/base_attack.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhfowl/robbing_the_fed/HEAD/attacks/__pycache__/base_attack.cpython-38.pyc -------------------------------------------------------------------------------- /attacks/__pycache__/analytic_attack.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhfowl/robbing_the_fed/HEAD/attacks/__pycache__/analytic_attack.cpython-38.pyc -------------------------------------------------------------------------------- /modifications/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhfowl/robbing_the_fed/HEAD/modifications/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /modifications/__pycache__/imprint.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhfowl/robbing_the_fed/HEAD/modifications/__pycache__/imprint.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/breaching_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lhfowl/robbing_the_fed/HEAD/utils/__pycache__/breaching_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/imprint_guarantee.py: -------------------------------------------------------------------------------- 1 | from math import comb as nCr 2 | 3 | 4 | def expected_amount(k, n): 5 | """ 6 | k number of bins, n batch size 7 | """ 8 | total_num = nCr(k + n - 1, k - 1) # Total number of configs 9 | weight = 0 10 | for i in range(1, n - 1): 11 | temp = i * nCr(k, i) 12 | temp2 = 0 13 | for j in range(1, (n - i) // 2 + 1): 14 | temp2 += nCr(k - i, j) * nCr(n - i - j - 1, j - 1) 15 | weight += temp * temp2 16 | adjustment1 = n * nCr(k, n) # First term in r(n,k) 17 | weight += adjustment1 18 | return weight / total_num - n / k # Second adjustment term in r(n,k) 19 | 20 | 21 | def one_shot_guarantee(k, n): 22 | """ 23 | k number of bins, n batch size 24 | """ 25 | total_num = nCr(k + n - 1, k - 1) # Total number of configs 26 | weight = 0 27 | weight += nCr(n + k - 3, k - 2) 28 | return weight / total_num 29 | 30 | 31 | if __name__ == "__main__": 32 | print(expected_amount(3, 6)) 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Robbing the FED: Directly Obtaining Private Data in Federated Learning with Modified Models 2 | 3 | This repo contains a barebones implementation for the attack detailed in the paper: 4 | 5 | ``` 6 | Fowl L, Geiping J, Czaja W, Goldblum M, Goldstein T. 7 | Robbing the Fed: Directly Obtaining Private Data in Federated Learning with Modified Models. 8 | arXiv preprint arXiv:2110.13057. 2021 Oct 25. 9 | 10 | ``` 11 | 12 | ![Teaser](teaser.png) 13 | *Left: batch of 64 ImageNet images. Right: Images reconstructed with imprint module containing 128 bins placed in front of a ResNet-18. Average PSNR: 70.94.* 14 | 15 | ### Abstract: 16 | Federated learning has quickly gained popularity with its promises of increased 17 | user privacy and efficiency. Previous works have shown that federated gradient 18 | updates contain information that can be used to approximately recover user data 19 | in some situations. These previous attacks on user privacy have been limited in 20 | scope and do not scale to gradient updates aggregated over even a handful of data 21 | points, leaving some to conclude that data privacy is still intact for realistic training 22 | regimes. In this work, we introduce a new threat model based on minimal but 23 | malicious modifications of the shared model architecture which enable the server 24 | to directly obtain a verbatim copy of user data from gradient updates without 25 | solving difficult inverse problems. Even user data aggregated over large batches – 26 | where previous methods fail to extract meaningful content – can be reconstructed 27 | by these minimally modified models. 28 | 29 | 30 | ### Code: 31 | 32 | This barebones implementation was adapted from a larger FL attack zoo written by [Jonas Geiping](https://github.com/JonasGeiping). Thanks to him for the nice code :). This will be available soon and we suggest you check it out for a more thorough implementation of this particular attack, as well as others. 33 | 34 | For this repo, the easiest way to get up and running is to play around with ```breaching_fl.ipynb```. This contains a start-to-finish imprint attack on a FL system. The guts of the imprint module can be found in ```modifications/imprint.py```. 35 | 36 | Requirements: 37 | ``` 38 | pytorch=1.4.0 39 | torchvision=0.5.0 40 | ``` 41 | -------------------------------------------------------------------------------- /utils/breaching_utils.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | 4 | def plot_data(cfg, user_data, setup, scale=False, print_labels=False): 5 | """Plot user data to output. Probably best called from a jupyter notebook.""" 6 | import matplotlib.pyplot as plt # lazily import this here 7 | 8 | dm = torch.as_tensor(cfg.mean, **setup)[None, :, None, None] 9 | ds = torch.as_tensor(cfg.std, **setup)[None, :, None, None] 10 | 11 | data = user_data["data"].clone().detach() 12 | labels = user_data["labels"].clone().detach() if user_data["labels"] is not None else None 13 | classes = [] # If you want to get class labels, you need to fill this in. 14 | # e.g. for CIFAR-10, you want classes = ['Airplane', 'Automobile', ...] 15 | if labels is None: 16 | print_labels = False 17 | 18 | if scale: 19 | min_val, max_val = data.amin(dim=[2, 3], keepdim=True), data.amax(dim=[2, 3], keepdim=True) 20 | # print(f'min_val: {min_val} | max_val: {max_val}') 21 | data = (data - min_val) / (max_val - min_val) 22 | else: 23 | data.mul_(ds).add_(dm).clamp_(0, 1) 24 | data = data.to(dtype=torch.float32) 25 | 26 | if data.shape[0] == 1: 27 | plt.axis("off") 28 | plt.imshow(data[0].permute(1, 2, 0).cpu()) 29 | if print_labels: 30 | plt.title(f"Data with label {classes[labels]}") 31 | else: 32 | grid_shape = int(torch.as_tensor(data.shape[0]).sqrt().ceil()) 33 | s = 24 if data.shape[3] > 150 else 6 34 | fig, axes = plt.subplots(grid_shape, grid_shape, figsize=(s, s)) 35 | label_classes = [] 36 | for i, (im, axis) in enumerate(zip(data, axes.flatten())): 37 | axis.imshow(im.permute(1, 2, 0).cpu()) 38 | if labels is not None and print_labels: 39 | label_classes.append(classes[labels[i]]) 40 | axis.axis("off") 41 | if print_labels: 42 | print(label_classes) 43 | 44 | 45 | class data_cfg_default: 46 | size = (1_281_167,) 47 | classes = 1000 48 | shape = (3, 224, 224) 49 | normalize = True 50 | mean = (0.485, 0.456, 0.406) 51 | std = (0.229, 0.224, 0.225) 52 | 53 | 54 | class attack_cfg_default: 55 | type = "analytic" 56 | attack_type = "imprint-readout" 57 | label_strategy = "random" # Labels are not actually required for this attack 58 | normalize_gradients = False 59 | impl = namedtuple("impl", ["dtype", "mixed_precision", "JIT"])("float", False, "") 60 | 61 | 62 | -------------------------------------------------------------------------------- /attacks/analytic_attack.py: -------------------------------------------------------------------------------- 1 | """Simple analytic attack that works for (dumb) fully connected models.""" 2 | 3 | import torch 4 | 5 | from .base_attack import _BaseAttacker 6 | 7 | 8 | class AnalyticAttacker(_BaseAttacker): 9 | """Implements a sanity-check analytic inversion 10 | 11 | Only works for a torch.nn.Sequential model with input-sized FC layers.""" 12 | 13 | def __init__(self, model, loss_fn, cfg_attack, setup=dict(dtype=torch.float, device=torch.device("cpu"))): 14 | super().__init__(model, loss_fn, cfg_attack, setup) 15 | 16 | def __repr__(self): 17 | return f"""Attacker (of type {self.__class__.__name__}).""" 18 | 19 | def reconstruct(self, server_payload, shared_data, server_secrets=None, dryrun=False): 20 | # Initialize stats module for later usage: 21 | rec_models, labels, stats = self.prepare_attack(server_payload, shared_data) 22 | 23 | # Main reconstruction: loop starts here: 24 | inputs_from_queries = [] 25 | for model, user_gradient in zip(rec_models, shared_data["gradients"]): 26 | idx = len(user_gradient) - 1 27 | for layer in list(model)[::-1]: # Only for torch.nn.Sequential 28 | if isinstance(layer, torch.nn.Linear): 29 | bias_grad = user_gradient[idx] 30 | weight_grad = user_gradient[idx - 1] 31 | layer_inputs = self.invert_fc_layer(weight_grad, bias_grad, labels) 32 | idx -= 2 33 | elif isinstance(layer, torch.nn.Flatten): 34 | inputs = layer_inputs.reshape(shared_data["num_data_points"], *self.data_shape) 35 | else: 36 | raise ValueError(f"Layer {layer} not supported for this sanity-check attack.") 37 | inputs_from_queries += [inputs] 38 | 39 | final_reconstruction = torch.stack(inputs_from_queries).mean(dim=0) 40 | reconstructed_data = dict(data=inputs, labels=labels) 41 | 42 | return reconstructed_data, stats 43 | 44 | def invert_fc_layer(self, weight_grad, bias_grad, image_positions): 45 | """The basic trick to invert a FC layer.""" 46 | # By the way the labels are exactly at (bias_grad < 0).nonzero() if they are unique 47 | valid_classes = bias_grad != 0 48 | intermediates = weight_grad[valid_classes, :] / bias_grad[valid_classes, None] 49 | if len(image_positions) == 0: 50 | reconstruction_data = intermediates 51 | elif len(image_positions) == 1: 52 | reconstruction_data = intermediates.mean(dim=0) 53 | else: 54 | reconstruction_data = intermediates[image_positions] 55 | return reconstruction_data 56 | 57 | 58 | class ImprintAttacker(AnalyticAttacker): 59 | """Abuse imprint secret for near-perfect attack success.""" 60 | 61 | def reconstruct(self, server_payload, shared_data, server_secrets=None, dryrun=False): 62 | """This is somewhat hard-coded for images, but that is not a necessity.""" 63 | # Initialize stats module for later usage: 64 | rec_models, labels, stats = self.prepare_attack(server_payload, shared_data) 65 | 66 | if "ImprintBlock" in server_secrets.keys(): 67 | weight_idx = server_secrets["ImprintBlock"]["weight_idx"] 68 | bias_idx = server_secrets["ImprintBlock"]["bias_idx"] 69 | data_shape = server_secrets["ImprintBlock"]["shape"] 70 | else: 71 | raise ValueError(f"No imprint hidden in model {rec_models[0]} according to server.") 72 | 73 | bias_grad = shared_data["gradients"][0][bias_idx].clone() 74 | weight_grad = shared_data["gradients"][0][weight_idx].clone() 75 | if server_secrets["ImprintBlock"]["structure"] == "cumulative": 76 | for i in reversed(list(range(1, weight_grad.shape[0]))): 77 | weight_grad[i] -= weight_grad[i - 1] 78 | bias_grad[i] -= bias_grad[i - 1] 79 | 80 | image_positions = bias_grad.nonzero() 81 | layer_inputs = self.invert_fc_layer(weight_grad, bias_grad, []) 82 | 83 | if "decoder" in server_secrets["ImprintBlock"].keys(): 84 | inputs = server_secrets["ImprintBlock"]["decoder"](layer_inputs) 85 | else: 86 | inputs = layer_inputs.reshape(layer_inputs.shape[0], *data_shape)[:, :3, :, :] 87 | if weight_idx > 0: # An imprint block later in the network: 88 | inputs = torch.nn.functional.interpolate( 89 | inputs, size=self.data_shape[1:], mode="bicubic", align_corners=False 90 | ) 91 | inputs = torch.max(torch.min(inputs, (1 - self.dm) / self.ds), -self.dm / self.ds) 92 | 93 | if len(labels) >= inputs.shape[0]: 94 | # Fill up with zero if not enough data can be found: 95 | missing_entries = torch.zeros(len(labels) - inputs.shape[0], *self.data_shape, **self.setup) 96 | inputs = torch.cat([inputs, missing_entries], dim=0) 97 | else: 98 | print(f"Initially produced {inputs.shape[0]} hits.") 99 | # Cut additional hits: 100 | # this rule is optimal for clean data with few bins: 101 | # best_guesses = torch.topk(bias_grad[bias_grad != 0].abs(), len(labels), largest=False) 102 | # this rule is best when faced with differential privacy: 103 | best_guesses = torch.topk(weight_grad.mean(dim=1)[bias_grad != 0].abs(), len(labels), largest=True) 104 | print(f"Reduced to {len(labels)} hits.") 105 | # print(best_guesses.indices.sort().values) 106 | inputs = inputs[best_guesses.indices] 107 | 108 | reconstructed_data = dict(data=inputs, labels=labels) 109 | return reconstructed_data, stats 110 | -------------------------------------------------------------------------------- /breaching_fl.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "a8ee9650", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import torchvision\n", 12 | "from collections import namedtuple" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "ba321ff5", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from attacks.analytic_attack import ImprintAttacker\n", 23 | "from modifications.imprint import ImprintBlock\n", 24 | "from utils.breaching_utils import *" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "id": "33ac450e", 30 | "metadata": {}, 31 | "source": [ 32 | "# Attack begins here:" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "id": "3d2c2795", 38 | "metadata": {}, 39 | "source": [ 40 | "### Initialize your model" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "6a6d6ba0", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "setup = dict(device=torch.device(\"cpu\"), dtype=torch.float)\n", 51 | "\n", 52 | "# This could be any model:\n", 53 | "model = torchvision.models.resnet18()\n", 54 | "model.eval()\n", 55 | "loss_fn = torch.nn.CrossEntropyLoss()\n", 56 | "# It will be modified maliciously:\n", 57 | "input_dim = data_cfg_default.shape[0] * data_cfg_default.shape[1] * data_cfg_default.shape[2]\n", 58 | "num_bins = 100 # Here we define number of imprint bins\n", 59 | "block = ImprintBlock(input_dim, num_bins=num_bins)\n", 60 | "model = torch.nn.Sequential(\n", 61 | " torch.nn.Flatten(), block, torch.nn.Unflatten(dim=1, unflattened_size=data_cfg_default.shape), model\n", 62 | ")\n", 63 | "secret = dict(weight_idx=0, bias_idx=1, shape=tuple(data_cfg_default.shape), structure=block.structure)\n", 64 | "secrets = {\"ImprintBlock\": secret}" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "id": "319bb6e2", 70 | "metadata": {}, 71 | "source": [ 72 | "### And your dataset (ImageNet by default)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "id": "d0e94352", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "transforms = torchvision.transforms.Compose(\n", 83 | " [\n", 84 | " torchvision.transforms.Resize(256),\n", 85 | " torchvision.transforms.CenterCrop(224),\n", 86 | " torchvision.transforms.ToTensor(),\n", 87 | " torchvision.transforms.Normalize(mean=data_cfg_default.mean, std=data_cfg_default.std),\n", 88 | " ]\n", 89 | ")\n", 90 | "dataset = torchvision.datasets.ImageNet(root=\"~/data/\", split=\"val\", transform=transforms)\n", 91 | "batch_size = 64 # Number of images in the user's batch. We have a small one here for visualization purposes\n", 92 | "import random\n", 93 | "random.seed(123) # You can change this to get a new batch. \n", 94 | "samples = [dataset[i] for i in random.sample(range(len(dataset)), batch_size)]\n", 95 | "data = torch.stack([sample[0] for sample in samples])\n", 96 | "labels = torch.tensor([sample[1] for sample in samples])" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "id": "de491268", 102 | "metadata": {}, 103 | "source": [ 104 | "### Simulate an attacked FL protocol" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "id": "038ec154", 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "# This is the attacker:\n", 115 | "attacker = ImprintAttacker(model, loss_fn, attack_cfg_default, setup)\n", 116 | "\n", 117 | "# Server-side computation:\n", 118 | "queries = [dict(parameters=[p for p in model.parameters()], buffers=[b for b in model.buffers()])]\n", 119 | "server_payload = dict(queries=queries, data=data_cfg_default)\n", 120 | "# User-side computation:\n", 121 | "loss = loss_fn(model(data), labels)\n", 122 | "shared_data = dict(\n", 123 | " gradients=[torch.autograd.grad(loss, model.parameters())],\n", 124 | " buffers=None,\n", 125 | " num_data_points=1,\n", 126 | " labels=labels,\n", 127 | " local_hyperparams=None,\n", 128 | ")" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "id": "10d3f62a", 134 | "metadata": {}, 135 | "source": [ 136 | "### Reconstruct data from the update" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "91ade4a2", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "# Attack:\n", 147 | "reconstructed_user_data, stats = attacker.reconstruct(server_payload, shared_data, secrets, dryrun=False)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "id": "6a910a92", 154 | "metadata": {}, 155 | "outputs": [], 156 | "source": [ 157 | "# Metrics?: \n", 158 | "from utils.analysis import report\n", 159 | "true_user_data = {'data': data, 'labels': labels}\n", 160 | "metrics = report(reconstructed_user_data,\n", 161 | " true_user_data,\n", 162 | " server_payload,\n", 163 | " model, compute_ssim=False) # Can change to true and install a package...\n", 164 | "print(f\"MSE: {metrics['mse']}, PSNR: {metrics['psnr']}, LPIPS: {metrics['lpips']}\")" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "id": "1777d351", 170 | "metadata": {}, 171 | "source": [ 172 | "### Plot ground-truth data" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "id": "e0484998", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "plot_data(data_cfg_default, true_user_data, setup)" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "id": "f410d7fd", 188 | "metadata": {}, 189 | "source": [ 190 | "### Now plot reconstructed data" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "id": "2e7dd96c", 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "plot_data(data_cfg_default, reconstructed_user_data, setup)" 201 | ] 202 | } 203 | ], 204 | "metadata": { 205 | "kernelspec": { 206 | "display_name": "Python 3", 207 | "language": "python", 208 | "name": "python3" 209 | }, 210 | "language_info": { 211 | "codemirror_mode": { 212 | "name": "ipython", 213 | "version": 3 214 | }, 215 | "file_extension": ".py", 216 | "mimetype": "text/x-python", 217 | "name": "python", 218 | "nbconvert_exporter": "python", 219 | "pygments_lexer": "ipython3", 220 | "version": "3.8.10" 221 | } 222 | }, 223 | "nbformat": 4, 224 | "nbformat_minor": 5 225 | } 226 | -------------------------------------------------------------------------------- /attacks/common.py: -------------------------------------------------------------------------------- 1 | """Common subfunctions to multiple modules.""" 2 | 3 | 4 | import torch 5 | 6 | 7 | def optimizer_lookup(params, optim_name, step_size, scheduler=None, warmup=0, max_iterations=10_000): 8 | if optim_name.lower() == "adam": 9 | optimizer = torch.optim.Adam(params, lr=step_size) 10 | elif optim_name.lower() == "momgd": 11 | optimizer = torch.optim.SGD(params, lr=step_size, momentum=0.9, nesterov=True) 12 | elif optim_name.lower() == "gd": 13 | optimizer = torch.optim.SGD(params, lr=step_size, momentum=0.0) 14 | elif optim_name.lower() == "l-bfgs": 15 | optimizer = torch.optim.LBFGS(params, lr=step_size) 16 | else: 17 | raise ValueError(f"Invalid optimizer {optim_name} given.") 18 | 19 | if scheduler == "step-lr": 20 | 21 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 22 | optimizer, milestones=[max_iterations // 2.667, max_iterations // 1.6, max_iterations // 1.142], gamma=0.1 23 | ) 24 | elif scheduler == "cosine-decay": 25 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iterations, eta_min=0.0) 26 | else: 27 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[], gamma=1) 28 | 29 | if warmup > 0: 30 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=warmup, after_scheduler=scheduler) 31 | 32 | return optimizer, scheduler 33 | 34 | 35 | """The following code block is part of https://github.com/ildoonet/pytorch-gradual-warmup-lr. 36 | 37 | 38 | MIT License 39 | 40 | Copyright (c) 2019 Ildoo Kim 41 | 42 | Permission is hereby granted, free of charge, to any person obtaining a copy 43 | of this software and associated documentation files (the "Software"), to deal 44 | in the Software without restriction, including without limitation the rights 45 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 46 | copies of the Software, and to permit persons to whom the Software is 47 | furnished to do so, subject to the following conditions: 48 | 49 | The above copyright notice and this permission notice shall be included in all 50 | copies or substantial portions of the Software. 51 | 52 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 53 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 54 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 55 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 56 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 57 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 58 | SOFTWARE. 59 | 60 | """ 61 | 62 | from torch.optim.lr_scheduler import _LRScheduler 63 | from torch.optim.lr_scheduler import ReduceLROnPlateau 64 | 65 | 66 | class GradualWarmupScheduler(_LRScheduler): 67 | """Gradually warm-up(increasing) learning rate in optimizer. 68 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 69 | Args: 70 | optimizer (Optimizer): Wrapped optimizer. 71 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. 72 | total_epoch: target learning rate is reached at total_epoch, gradually 73 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 74 | """ 75 | 76 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 77 | self.multiplier = multiplier 78 | if self.multiplier < 1.0: 79 | raise ValueError("multiplier should be greater thant or equal to 1.") 80 | self.total_epoch = total_epoch 81 | self.after_scheduler = after_scheduler 82 | self.finished = False 83 | super(GradualWarmupScheduler, self).__init__(optimizer) 84 | 85 | def get_lr(self): 86 | if self.last_epoch > self.total_epoch: 87 | if self.after_scheduler: 88 | if not self.finished: 89 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 90 | self.finished = True 91 | return self.after_scheduler.get_last_lr() 92 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 93 | 94 | if self.multiplier == 1.0: 95 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 96 | else: 97 | return [ 98 | base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0) 99 | for base_lr in self.base_lrs 100 | ] 101 | 102 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 103 | if epoch is None: 104 | epoch = self.last_epoch + 1 105 | self.last_epoch = ( 106 | epoch if epoch != 0 else 1 107 | ) # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 108 | if self.last_epoch <= self.total_epoch: 109 | warmup_lr = [ 110 | base_lr * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0) 111 | for base_lr in self.base_lrs 112 | ] 113 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 114 | param_group["lr"] = lr 115 | else: 116 | if epoch is None: 117 | self.after_scheduler.step(metrics, None) 118 | else: 119 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 120 | 121 | def step(self, epoch=None, metrics=None): 122 | if type(self.after_scheduler) != ReduceLROnPlateau: 123 | if self.finished and self.after_scheduler: 124 | if epoch is None: 125 | self.after_scheduler.step(None) 126 | else: 127 | self.after_scheduler.step(epoch - self.total_epoch) 128 | self._last_lr = self.after_scheduler.get_last_lr() 129 | else: 130 | return super(GradualWarmupScheduler, self).step(epoch) 131 | else: 132 | self.step_ReduceLROnPlateau(metrics, epoch) 133 | 134 | def state_dict(self): 135 | """Returns the state of the scheduler as a :class:`dict`. 136 | It contains an entry for every variable in self.__dict__ which 137 | is not the optimizer. 138 | """ 139 | after_scheduler_dict = { 140 | key: value for key, value in self.after_scheduler.__dict__.items() if key != "optimizer" 141 | } 142 | state_dict = {key: value for key, value in self.__dict__.items() if key != "optimizer"} 143 | state_dict["after_scheduler"] = after_scheduler_dict 144 | return state_dict 145 | 146 | def load_state_dict(self, state_dict): 147 | """Loads the schedulers state. 148 | Args: 149 | state_dict (dict): scheduler state. Should be an object returned 150 | from a call to :meth:`state_dict`. 151 | """ 152 | after_scheduler_dict = state_dict.pop("after_scheduler") 153 | self.after_scheduler.__dict__.update(after_scheduler_dict) 154 | self.__dict__.update(state_dict) 155 | -------------------------------------------------------------------------------- /modifications/imprint.py: -------------------------------------------------------------------------------- 1 | """Implements a malicious block that can be inserted at the front on normal models to break them.""" 2 | from statistics import NormalDist 3 | import torch 4 | import math 5 | from scipy.stats import laplace 6 | 7 | 8 | class ImprintBlock(torch.nn.Module): 9 | structure = "cumulative" 10 | 11 | def __init__(self, data_size, num_bins, connection="linear", gain=1e-3, linfunc="fourier", mode=0): 12 | """ 13 | data_size is the length of the input data 14 | num_bins is how many "paths" to include in the model 15 | connection is how this block should coonect back to the input shape (optional) 16 | 17 | linfunc is the choice of linear query function ('avg', 'fourier', 'randn', 'rand'). 18 | If linfunc is fourier, then the mode parameter determines the mode of the DCT-2 that is used as linear query. 19 | """ 20 | super().__init__() 21 | self.data_size = data_size 22 | self.num_bins = num_bins 23 | self.linear0 = torch.nn.Linear(data_size, num_bins) 24 | 25 | self.bins = self._get_bins(linfunc) 26 | with torch.no_grad(): 27 | self.linear0.weight.data = self._init_linear_function(linfunc, mode) * gain 28 | self.linear0.bias.data = self._make_biases() * gain 29 | 30 | self.connection = connection 31 | if connection == "linear": 32 | self.linear2 = torch.nn.Linear(num_bins, data_size) 33 | with torch.no_grad(): 34 | self.linear2.weight.data = torch.ones_like(self.linear2.weight.data) / gain 35 | self.linear2.bias.data -= torch.as_tensor(self.bins).mean() 36 | 37 | self.nonlin = torch.nn.ReLU() 38 | 39 | @torch.no_grad() 40 | def _init_linear_function(self, linfunc="fourier", mode=0): 41 | K, N = self.num_bins, self.data_size 42 | if linfunc == "avg": 43 | weights = torch.ones_like(self.linear0.weight.data) / N 44 | elif linfunc == "fourier": 45 | weights = torch.cos(math.pi / N * (torch.arange(0, N) + 0.5) * mode).repeat(K, 1) / N * max(mode, 0.33) * 4 46 | # dont ask about the 4, this is WIP 47 | # nonstandard normalization 48 | elif linfunc == "randn": 49 | weights = torch.randn(N).repeat(K, 1) 50 | std, mu = torch.std_mean(weights[0]) # Enforce mean=0, std=1 with higher precision 51 | weights = (weights - mu) / std / math.sqrt(N) # Move to std=1 in output dist 52 | elif linfunc == "rand": 53 | weights = torch.rand(N).repeat(K, 1) # This might be a terrible idea haven't done the math 54 | std, mu = torch.std_mean(weights[0]) # Enforce mean=0, std=1 55 | weights = (weights - mu) / std / math.sqrt(N) # Move to std=1 in output dist 56 | else: 57 | raise ValueError(f"Invalid linear function choice {linfunc}.") 58 | 59 | return weights 60 | 61 | def _get_bins(self, linfunc="avg"): 62 | bins = [] 63 | mass_per_bin = 1 / (self.num_bins) 64 | bins.append(-10) # -Inf is not great here, but NormalDist(mu=0, sigma=1).cdf(10) approx 1 65 | for i in range(1, self.num_bins): 66 | if "fourier" in linfunc: 67 | bins.append(laplace(loc=0.0, scale=1 / math.sqrt(2)).ppf(i * mass_per_bin)) 68 | else: 69 | bins.append(NormalDist().inv_cdf(i * mass_per_bin)) 70 | return bins 71 | 72 | def _make_biases(self): 73 | new_biases = torch.zeros_like(self.linear0.bias.data) 74 | for i in range(new_biases.shape[0]): 75 | new_biases[i] = -self.bins[i] 76 | return new_biases 77 | 78 | def forward(self, x): 79 | x_in = x 80 | x = self.linear0(x) 81 | x = self.nonlin(x) 82 | if self.connection == "linear": 83 | output = self.linear2(x) 84 | elif self.connection == "cat": 85 | output = torch.cat([x, x_in[:, self.num_bins :]], dim=1) 86 | elif self.connection == "softmax": 87 | s = torch.softmax(x, dim=1)[:, :, None] 88 | output = (x_in[:, None, :] * s).sum(dim=1) 89 | else: 90 | output = x_in + x.mean(dim=1, keepdim=True) 91 | return output 92 | 93 | 94 | class SparseImprintBlock(ImprintBlock): 95 | structure = "sparse" 96 | 97 | """This block is sparse instead of cumulative which is more efficient in noise/param tradeoffs but requires 98 | two ReLUs that construct the hard-tanh nonlinearity.""" 99 | 100 | def _get_bins(self, mu=0, sigma=1, linfunc="avg"): 101 | bins = [] 102 | mass = 0 103 | for path in range(self.num_bins + 1): 104 | mass += 1 / (self.num_bins + 2) 105 | if "fourier" in linfunc: 106 | bins.append(laplace(loc=mu, scale=sigma / math.sqrt(2)).ppf(mass)) 107 | else: 108 | bins += [NormalDist(mu=mu, sigma=sigma).inv_cdf(mass)] 109 | bin_sizes = [bins[i + 1] - bins[i] for i in range(len(bins) - 1)] 110 | self.bin_sizes = bin_sizes 111 | return bins[1:] 112 | 113 | @torch.no_grad() 114 | def _init_linear_function(self, linfunc="fourier", mode=0): 115 | weights = super()._init_linear_function(linfunc, mode) 116 | for i, row in enumerate(weights): 117 | row /= torch.as_tensor(self.bin_sizes[i], device=new_data.device) 118 | return weights 119 | 120 | def _make_biases(self): 121 | new_biases = torch.zeros_like(self.linear0.bias.data) 122 | for i, (bin_val, bin_width) in enumerate(zip(self.bins, self.bin_sizes)): 123 | new_biases[i] = -bin_val / bin_width 124 | return new_biases 125 | 126 | 127 | class OneShotBlock(ImprintBlock): 128 | structure = "cumulative" 129 | 130 | """One-shot attack with minimal additional parameters. Can target a specific data point if its target_val is known.""" 131 | 132 | def __init__(self, data_size, num_bins, connection="linear", gain=1e-3, linfunc="fourier", mode=0, target_val=0): 133 | self.virtual_bins = num_bins 134 | self.target_val = target_val 135 | num_bins = 2 136 | super().__init__(data_size, num_bins, connection, gain, linfunc, mode) 137 | 138 | def _get_bins(self, linfunc="avg"): 139 | bins = [] 140 | mass_per_bin = 1 / (self.virtual_bins) 141 | bins.append(-10) # -Inf is not great here, but NormalDist(mu=0, sigma=1).cdf(10) approx 1 142 | for i in range(1, self.virtual_bins): 143 | if "fourier" in linfunc: 144 | bins.append(laplace(loc=0.0, scale=1 / math.sqrt(2)).ppf(i * mass_per_bin)) 145 | else: 146 | bins.append(NormalDist().inv_cdf(i * mass_per_bin)) 147 | if self.target_val < bins[-1]: 148 | break 149 | return bins[-2:] 150 | 151 | 152 | class OneShotBlockSparse(SparseImprintBlock): 153 | structure = "sparse" 154 | 155 | def __init__(self, data_size, num_bins, connection="linear"): 156 | """ 157 | data_size is the size of the input images 158 | num_bins is how many "paths" to include in the model 159 | """ 160 | super().__init__(data_size, num_bins=1, connection=connection) 161 | self.data_size = data_size 162 | self.num_bins = num_bins 163 | 164 | def _get_bins(self): 165 | # Here we just build bins of uniform mass 166 | left_bins = [] 167 | bins = [] 168 | mass_per_bin = 1 / self.num_bins 169 | bins = [-NormalDist().inv_cdf(0.5), -NormalDist().inv_cdf(0.5 + mass_per_bin)] 170 | self.bin_sizes = [bins[i + 1] - bins[i] for i in range(len(bins) - 1)] 171 | bins = bins[:-1] # here we need to throw away one on the right 172 | return bins 173 | -------------------------------------------------------------------------------- /utils/analysis.py: -------------------------------------------------------------------------------- 1 | """Simple report function based on PSNR and maybe SSIM and maybe better ideas...""" 2 | import torch 3 | 4 | 5 | from .metrics import psnr_compute, registered_psnr_compute, image_identifiability_precision 6 | #cw_ssim - can uncomment if you want ... 7 | 8 | 9 | import logging 10 | 11 | log = logging.getLogger(__name__) 12 | 13 | 14 | def report( 15 | reconstructed_user_data, 16 | true_user_data, 17 | server_payload, 18 | model, 19 | dataloader=None, 20 | setup=dict(device=torch.device("cpu"), dtype=torch.float), 21 | order_batch=True, 22 | compute_full_iip=False, 23 | compute_rpsnr=True, 24 | compute_ssim=True, 25 | ): 26 | import lpips # lazily import this only if report is used. 27 | 28 | lpips_scorer = lpips.LPIPS(net="alex", verbose=False).to(**setup) 29 | 30 | dm = torch.as_tensor(server_payload["data"].mean, **setup)[None, :, None, None] 31 | ds = torch.as_tensor(server_payload["data"].std, **setup)[None, :, None, None] 32 | model.to(**setup) 33 | 34 | rec_denormalized = torch.clamp(reconstructed_user_data["data"].to(**setup) * ds + dm, 0, 1) 35 | ground_truth_denormalized = torch.clamp(true_user_data["data"].to(**setup) * ds + dm, 0, 1) 36 | 37 | if order_batch: 38 | order = compute_batch_order(lpips_scorer, rec_denormalized, ground_truth_denormalized, setup) 39 | reconstructed_user_data["data"] = reconstructed_user_data["data"][order] 40 | if reconstructed_user_data["labels"] is not None: 41 | reconstructed_user_data["labels"] = reconstructed_user_data["labels"][order] 42 | rec_denormalized = rec_denormalized[order] 43 | else: 44 | order = None 45 | 46 | if any(reconstructed_user_data["labels"].sort()[0] != true_user_data["labels"]): 47 | found_labels = 0 48 | label_pool = true_user_data["labels"].clone().tolist() 49 | for label in reconstructed_user_data["labels"]: 50 | if label in label_pool: 51 | found_labels += 1 52 | label_pool.remove(label) 53 | 54 | log.info(f"Label recovery was sucessfull in {found_labels} cases.") 55 | test_label_acc = found_labels / len(true_user_data["labels"]) 56 | else: 57 | test_label_acc = 1 58 | 59 | test_mse = (rec_denormalized - ground_truth_denormalized).pow(2).mean().item() 60 | test_psnr = psnr_compute(rec_denormalized, ground_truth_denormalized, factor=1).item() 61 | if compute_ssim: 62 | test_ssim = cw_ssim(rec_denormalized, ground_truth_denormalized, scales=5).item() 63 | else: 64 | test_ssim = 0 65 | 66 | # Hint: This part switches to the lpips [-1, 1] normalization: 67 | test_lpips = lpips_scorer(rec_denormalized, ground_truth_denormalized, normalize=True).mean().item() 68 | 69 | # Compute registered psnr. This is a bit computationally intensive: 70 | if compute_rpsnr: 71 | test_rpsnr = registered_psnr_compute(rec_denormalized, ground_truth_denormalized, factor=1).item() 72 | else: 73 | test_rpsnr = float("nan") 74 | 75 | # Compute IIP score if a dataloader is passed: 76 | if dataloader is not None: 77 | if compute_full_iip: 78 | scores = ["pixel", "lpips", "self"] 79 | else: 80 | scores = ["pixel"] 81 | iip_scores = image_identifiability_precision( 82 | reconstructed_user_data, true_user_data, dataloader, lpips_scorer=lpips_scorer, model=model, scores=scores 83 | ) 84 | else: 85 | iip_scores = dict(none=float("NaN")) 86 | 87 | feat_mse = 0.0 88 | for idx, payload in enumerate(server_payload["queries"]): 89 | parameters = payload["parameters"] 90 | buffers = payload["buffers"] 91 | 92 | with torch.no_grad(): 93 | for param, server_state in zip(model.parameters(), parameters): 94 | param.copy_(server_state.to(**setup)) 95 | if buffers is not None: 96 | for buffer, server_state in zip(model.buffers(), buffers): 97 | buffer.copy_(server_state.to(**setup)) 98 | else: 99 | for buffer, user_state in zip(model.buffers(), true_user_data["buffers"][idx]): 100 | buffer.copy_(user_state.to(**setup)) 101 | 102 | # Compute the forward passes 103 | feats_rec = model(reconstructed_user_data["data"].to(**setup)) 104 | feats_true = model(true_user_data["data"].to(**setup)) 105 | relevant_features = true_user_data["labels"] 106 | feat_mse += (feats_rec - feats_true)[range(len(relevant_features)), relevant_features].pow(2).mean().item() 107 | 108 | # Record model parameters: 109 | parameters = sum([p.numel() for p in model.parameters()]) 110 | 111 | # Print report: 112 | iip_scoring = " | ".join([f"IIP-{k}: {v:5.2%}" for k, v in iip_scores.items()]) 113 | log.info( 114 | f"METRICS: | MSE: {test_mse:2.4f} | PSNR: {test_psnr:4.2f} | FMSE: {feat_mse:2.4e} | LPIPS: {test_lpips:4.2f}|" 115 | + "\n" 116 | f" R-PSNR: {test_rpsnr:4.2f} | {iip_scoring} | SSIM: {test_ssim:2.4f} | Label Acc: {test_label_acc:2.2%}" 117 | ) 118 | 119 | metrics = dict( 120 | mse=test_mse, 121 | psnr=test_psnr, 122 | feat_mse=feat_mse, 123 | lpips=test_lpips, 124 | rpsnr=test_rpsnr, 125 | ssim=test_ssim, 126 | order=order, 127 | **{f"IIP-{k}": v for k, v in iip_scores.items()}, 128 | parameters=parameters, 129 | label_acc=test_label_acc, 130 | ) 131 | return metrics 132 | 133 | 134 | def compute_batch_order(lpips_scorer, rec_denormalized, ground_truth_denormalized, setup): 135 | """Re-order a batch of images according to LPIPS statistics of source batch, trying to match similar images. 136 | 137 | This implementation basically follows the LPIPS.forward method, but for an entire batch.""" 138 | from scipy.optimize import linear_sum_assignment # Again a lazy import 139 | 140 | B = rec_denormalized.shape[0] 141 | L = lpips_scorer.L 142 | assert ground_truth_denormalized.shape[0] == B 143 | 144 | with torch.inference_mode(): 145 | # Compute all features [assume sufficient memory is a given] 146 | features_rec = [] 147 | for input in rec_denormalized: 148 | input_scaled = lpips_scorer.scaling_layer(input) 149 | output = lpips_scorer.net.forward(input_scaled) 150 | layer_features = {} 151 | for kk in range(L): 152 | layer_features[kk] = normalize_tensor(output[kk]) 153 | features_rec.append(layer_features) 154 | 155 | features_gt = [] 156 | for input in ground_truth_denormalized: 157 | input_scaled = lpips_scorer.scaling_layer(input) 158 | output = lpips_scorer.net.forward(input_scaled) 159 | layer_features = {} 160 | for kk in range(L): 161 | layer_features[kk] = normalize_tensor(output[kk]) 162 | features_gt.append(layer_features) 163 | 164 | # Compute overall similarities: 165 | similarity_matrix = torch.zeros(B, B, **setup) 166 | for idx, x in enumerate(features_gt): 167 | for idy, y in enumerate(features_rec): 168 | for kk in range(L): 169 | diff = (x[kk] - y[kk]) ** 2 170 | similarity_matrix[idx, idy] += spatial_average(lpips_scorer.lins[kk](diff)).squeeze() 171 | try: 172 | _, rec_assignment = linear_sum_assignment(similarity_matrix.cpu().numpy(), maximize=False) 173 | except ValueError: 174 | print(f"ValueError from similarity matrix {similarity_matrix.cpu().numpy()}") 175 | print("Returning trivial order...") 176 | rec_assignment = list(range(B)) 177 | return torch.as_tensor(rec_assignment, device=setup["device"], dtype=torch.long) 178 | 179 | 180 | def normalize_tensor(in_feat, eps=1e-10): 181 | """From https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/__init__.py.""" 182 | norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True)) 183 | return in_feat / (norm_factor + eps) 184 | 185 | 186 | def spatial_average(in_tens, keepdim=True): 187 | """ https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py .""" 188 | return in_tens.mean([2, 3], keepdim=keepdim) 189 | 190 | 191 | def find_oneshot(rec_denormalized, ground_truth_denormalized): 192 | one_shot = (rec_denormalized - ground_truth_denormalized).pow(2) 193 | one_shot_idx = one_shot.view(one_shot.shape[0], -1).mean(dim=-1).argmin() 194 | return one_shot_idx 195 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | """Various metrics.""" 2 | import torch 3 | from functools import partial 4 | 5 | 6 | def cw_ssim(img_batch, ref_batch, scales=5, skip_scales=None, K=1e-6): 7 | """Batched complex wavelet structural similarity. 8 | 9 | As in Zhou Wang and Eero P. Simoncelli, "TRANSLATION INSENSITIVE IMAGE SIMILARITY IN COMPLEX WAVELET DOMAIN" 10 | Ok, not quite, this implementation does not local SSIM and averaging over local patches and uses only 11 | the existing wavelet structure to provide something similar. 12 | 13 | skip_scales can be a list like [True, False, False, False] marking levels to be skipped. 14 | K is a small fudge factor. 15 | """ 16 | try: 17 | from pytorch_wavelets import DTCWTForward 18 | except ModuleNotFoundError: 19 | raise ModuleNotFoundError( 20 | "To utilize wavelet SSIM, install pytorch wavelets from https://github.com/fbcotter/pytorch_wavelets." 21 | ) 22 | 23 | # 1) Compute wavelets: 24 | setup = dict(device=img_batch.device, dtype=img_batch.dtype) 25 | if skip_scales is not None: 26 | include_scale = [~s for s in skip_scales] 27 | total_scales = scales - sum(skip_scales) 28 | else: 29 | include_scale = True 30 | total_scales = scales 31 | xfm = DTCWTForward(J=scales, biort="near_sym_b", qshift="qshift_b", include_scale=include_scale).to(**setup) 32 | img_coefficients = xfm(img_batch) 33 | ref_coefficients = xfm(ref_batch) 34 | 35 | # 2) Multiscale complex SSIM: 36 | ssim = 0 37 | for xs, ys in zip(img_coefficients[1], ref_coefficients[1]): 38 | if len(xs) > 0: 39 | xc = torch.view_as_complex(xs) 40 | yc = torch.view_as_complex(ys) 41 | 42 | conj_product = (xc * yc.conj()).sum(dim=2).abs() 43 | square_img = (xc * xc.conj()).abs().sum(dim=2) 44 | square_ref = (yc * yc.conj()).abs().sum(dim=2) 45 | 46 | ssim_val = (2 * conj_product + K) / (square_img + square_ref + K) 47 | ssim += ssim_val.mean() 48 | return ssim / total_scales 49 | 50 | 51 | def gradient_uniqueness(model, loss_fn, user_data, server_payload, setup, query=0, fudge=1e-7): 52 | """Count the number of gradient entries that are only affected by a single data point.""" 53 | 54 | r"""Formatting suggestion: 55 | print(f'Unique entries (hitting 1 or all): {unique_entries:.2%}, average hits: {average_hits_per_entry:.2%} \n' 56 | f'Stats (as N hits:val): {dict(zip(uniques[0].tolist(), uniques[1].tolist()))}\n' 57 | f'Unique nonzero (hitting 1 or all): {nonzero_uniques:.2%} Average nonzero: {nonzero_hits_per_entry:.2%}. \n' 58 | f'nonzero-Stats (as N hits:val): {dict(zip(uniques_nonzero[0].tolist(), uniques_nonzero[1].tolist()))}') 59 | """ 60 | payload = server_payload["queries"][query] 61 | parameters = payload["parameters"] 62 | buffers = payload["buffers"] 63 | 64 | with torch.no_grad(): 65 | for param, server_state in zip(model.parameters(), parameters): 66 | param.copy_(server_state.to(**setup)) 67 | for buffer, server_state in zip(model.buffers(), buffers): 68 | buffer.copy_(server_state.to(**setup)) 69 | 70 | # Compute the forward pass 71 | gradients = [] 72 | for data_point, label in zip(user_data["data"], user_data["labels"]): 73 | model.zero_grad() 74 | loss = loss_fn(model(data_point[None, :]), label[None]) 75 | data_grads = torch.autograd.grad(loss, model.parameters()) 76 | gradients += [torch.cat([g.reshape(-1) for g in data_grads])] 77 | 78 | average_gradient = torch.stack(gradients, dim=0).mean(dim=0, keepdim=True) 79 | 80 | gradient_per_example = torch.stack(gradients, dim=0) 81 | 82 | val = (gradient_per_example - average_gradient).abs() < fudge 83 | nonzero_val = val[:, average_gradient[0].abs() > fudge] 84 | unique_entries = (val.sum(dim=0) == 1).float().mean() + (val.sum(dim=0) == len(gradients)).float().mean() 85 | # hitting a single entry or all entries is equally good for rec 86 | average_hits_per_entry = val.sum(dim=0).float().mean() 87 | nonzero_hits_per_entry = (nonzero_val).sum(dim=0).float().mean() 88 | unique_nonzero_hits = (nonzero_val.sum(dim=0) == 1).float().mean() + ( 89 | nonzero_val.sum(dim=0) == len(gradients) 90 | ).float().mean() 91 | return ( 92 | unique_entries, 93 | average_hits_per_entry, 94 | unique_nonzero_hits, 95 | nonzero_hits_per_entry, 96 | val.sum(dim=0).unique(return_counts=True), 97 | nonzero_val.sum(dim=0).unique(return_counts=True), 98 | ) 99 | 100 | 101 | def psnr_compute(img_batch, ref_batch, batched=False, factor=1.0, clip=False): 102 | """Standard PSNR.""" 103 | if clip: 104 | img_batch = torch.clamp(img_batch, 0, 1) 105 | 106 | if batched: 107 | mse = ((img_batch.detach() - ref_batch) ** 2).mean() 108 | if mse > 0 and torch.isfinite(mse): 109 | return 10 * torch.log10(factor ** 2 / mse) 110 | elif not torch.isfinite(mse): 111 | return torch.tensor(float("nan"), device=img_batch.device) 112 | else: 113 | return torch.tensor(float("inf"), device=img_batch.device) 114 | else: 115 | B = img_batch.shape[0] 116 | mse_per_example = ((img_batch.detach() - ref_batch) ** 2).view(B, -1).mean(dim=1) 117 | if any(mse_per_example == 0): 118 | return torch.tensor(float("inf"), device=img_batch.device) 119 | elif not all(torch.isfinite(mse_per_example)): 120 | return torch.tensor(float("nan"), device=img_batch.device) 121 | else: 122 | return (10 * torch.log10(factor ** 2 / mse_per_example)).mean() 123 | 124 | 125 | def registered_psnr_compute(img_batch, ref_batch, factor=1.0): 126 | """Use kornia for now.""" 127 | return _registered_psnr_compute_kornia(img_batch, ref_batch, factor) 128 | 129 | 130 | def _registered_psnr_compute_kornia(img_batch, ref_batch, factor=1.0): 131 | """Kornia version. Todo: Use a smarter/deeper matching tool.""" 132 | from kornia.geometry import ImageRegistrator, HomographyWarper # lazy import here as well 133 | 134 | B = img_batch.shape[0] 135 | default_psnrs = [] 136 | registered_psnrs = [] 137 | # If only this was parallelized, todo ... 138 | for img, ref in zip(img_batch.detach(), ref_batch.detach()): 139 | img, ref = img[None, ...], ref[None, ...] 140 | mse = ((img - ref) ** 2).mean() 141 | default_psnrs += [10 * torch.log10(factor ** 2 / mse)] 142 | # Align by homography: 143 | registrator = ImageRegistrator("similarity", num_iterations=2500) 144 | registrator.warper = partial(HomographyWarper, padding_mode="reflection") 145 | registrator.to(ref.device) 146 | homography = registrator.register(img, ref) 147 | warped_img = registrator.warp_src_into_dst(img) 148 | # Compute new PSNR: 149 | mse = ((warped_img.detach() - ref_batch) ** 2).mean() 150 | registered_psnrs += [10 * torch.log10(factor ** 2 / mse)] 151 | 152 | # Return best of default and warped PSNR: 153 | return torch.stack([torch.stack(default_psnrs), torch.stack(registered_psnrs)]).max(dim=0)[0].mean() 154 | 155 | 156 | def _registered_psnr_compute_kornia_loftr(img_batch, ref_batch, factor=1.0): 157 | """Kornia version. WIP.""" 158 | from kornia.feature import LoFTR 159 | from kornia.geometry.homography import find_homography_dlt 160 | 161 | B = img_batch.shape[0] 162 | mse_per_example = ((img_batch.detach() - ref_batch) ** 2).view(B, -1).mean(dim=1) 163 | default_psnrs = 10 * torch.log10(factor ** 2 / mse_per_example) 164 | # Align by homography: 165 | matcher = LoFTR(pretrained="indoor") 166 | with torch.no_grad(): 167 | correspondences_dict = matcher( 168 | dict(image0=img_batch.mean(dim=1, keepdim=True), image1=ref_batch.mean(dim=1, keepdim=True)) 169 | ) 170 | homography = find_homography_dlt(correspondences_dict["keypoints0"], correspondences_dict["keypoints1"]) 171 | warped_imgs = homography_warp(img_batch, homography, ref_batch.shape[-2:]) 172 | # Compute new PSNR: 173 | mse_per_example = ((warped_imgs.detach() - ref_batch) ** 2).view(B, -1).mean(dim=1) 174 | registered_psnrs = 10 * torch.log10(factor ** 2 / mse_per_example) 175 | 176 | # Return best of default and warped PSNR: 177 | return torch.stack([default_psnrs, registered_psnrs]).max(dim=0)[0].mean() 178 | 179 | 180 | def _registered_psnr_compute_skimage(img_batch, ref_batch, factor=1.0): 181 | """Use ORB features to register images onto reference before computing PSNR scores.""" 182 | import skimage.feature # Lazy metric stuff import 183 | import skimage.measure 184 | import skimage.transform 185 | 186 | descriptor_extractor = skimage.feature.ORB(n_keypoints=800) 187 | 188 | psnr_vals = torch.zeros(img_batch.shape[0]) 189 | for idx, (img, ref) in enumerate(zip(img_batch, ref_batch)): 190 | default_psnr = psnr_compute(img, ref, factor=1.0, batched=True) 191 | try: 192 | img_np, ref_np = img.numpy(), ref.numpy() # move to numpy 193 | descriptor_extractor.detect_and_extract(ref_np.mean(axis=0)) # and grayscale for ORB 194 | keypoints_src, descriptors_src = descriptor_extractor.keypoints, descriptor_extractor.descriptors 195 | descriptor_extractor.detect_and_extract(img_np.mean(axis=0)) 196 | keypoints_tgt, descriptors_tgt = descriptor_extractor.keypoints, descriptor_extractor.descriptors 197 | 198 | matches = skimage.feature.match_descriptors(descriptors_src, descriptors_tgt, cross_check=True) 199 | # Look for an affine transform and search with RANSAC over matches: 200 | model_robust, inliers = skimage.measure.ransac( 201 | (keypoints_tgt[matches[:, 1]], keypoints_src[matches[:, 0]]), 202 | skimage.transform.EuclideanTransform, 203 | min_samples=len(matches) - 1, 204 | residual_threshold=4, 205 | max_trials=2500, 206 | ) # :> 207 | warped_img = skimage.transform.warp(img_np.transpose(1, 2, 0), model_robust, mode="wrap", order=1) 208 | # Compute normal PSNR from here: 209 | registered_psnr = psnr_compute(torch.as_tensor(warped_img), ref.permute(1, 2, 0), factor=1.0, batched=True) 210 | if registered_psnr.isfinite(): 211 | psnr_vals[idx] = max(registered_psnr, default_psnr) 212 | else: 213 | psnr_vals[idx] = default_psnr 214 | except (TypeError, IndexError, RuntimeError, ValueError): 215 | # TypeError if RANSAC fails 216 | # IndexError if not enough matches are found 217 | # RunTimeError if ORB does not find enough features 218 | # ValueError if empty match sequence 219 | # This matching implementation fills me with joy 220 | psnr_vals[idx] = default_psnr 221 | return psnr_vals.mean() 222 | 223 | 224 | def image_identifiability_precision( 225 | reconstructed_user_data, 226 | true_user_data, 227 | dataloader, 228 | scores=["pixel", "lpips", "self"], 229 | lpips_scorer=None, 230 | model=None, 231 | fudge=1e-3, 232 | ): 233 | """Nearest-neighbor metric as described in Yin et al., "See through Gradients: Image Batch Recovery via GradInversion" 234 | This version prints separate metrics for different choices of score functions. 235 | It's a bit messier to do it all in one go, but otherwise the data has to be loaded three separate times. 236 | 237 | For a self score, the model has to be provided. 238 | For an LPIPS score, the lpips scorer has to be provided. 239 | """ 240 | # Compare the reconstructed images to each image in the dataloader with the appropriate label 241 | # This could be batched and partially cached to make it faster in the future ... 242 | identified_images = dict(zip(scores, [0 for entry in scores])) 243 | 244 | for batch_idx, reconstruction in enumerate(reconstructed_user_data["data"]): 245 | batch_label = true_user_data["labels"][batch_idx] 246 | label_subset = [idx for (idx, label) in dataloader.dataset.lookup.items() if label == batch_label] 247 | 248 | distances = dict(zip(scores, [[] for entry in scores])) 249 | for idx in label_subset: 250 | comparable_data = dataloader.dataset[idx][0].to(device=reconstruction.device) 251 | 252 | for score in scores: 253 | if score == "lpips": 254 | with torch.inference_mode(): 255 | distances[score] += [lpips_scorer(reconstruction, comparable_data, normalize=False).mean()] 256 | elif score == "self" and model is not None: 257 | features_rec = _return_model_features(model, reconstruction) 258 | features_comp = _return_model_features(model, comparable_data) 259 | distances[score] += [ 260 | 1 - torch.nn.functional.cosine_similarity(features_rec.view(-1), features_comp.view(-1), dim=0) 261 | ] 262 | else: 263 | distances[score] += [torch.norm(comparable_data.view(-1) - reconstruction.view(-1))] 264 | 265 | for score in scores: 266 | minimal_distance_data_idx = label_subset[torch.stack(distances[score]).argmin()] 267 | candidate_solution = dataloader.dataset[minimal_distance_data_idx][0].to(device=reconstruction.device) 268 | true_solution = true_user_data["data"][batch_idx] 269 | if score == "lpips": 270 | distance_to_true = lpips_scorer(candidate_solution, true_solution, normalize=False).mean() 271 | elif score == "self" and model is not None: 272 | features_rec = _return_model_features(model, candidate_solution) 273 | features_comp = _return_model_features(model, true_solution) 274 | distance_to_true = 1 - torch.nn.functional.cosine_similarity( 275 | features_rec.view(-1), features_comp.view(-1), dim=0 276 | ) 277 | else: 278 | distance_to_true = torch.norm(candidate_solution.view(-1) - true_solution.view(-1)) 279 | 280 | if distance_to_true < fudge: # This should be tiny by all accounts 281 | identified_images[score] += 1 282 | 283 | return {k: v / len(reconstructed_user_data["data"]) for k, v in identified_images.items()} 284 | 285 | 286 | @torch.inference_mode() 287 | def _return_model_features(model, inputs): 288 | features = dict() # The named-hook + dict construction should be a bit more robust 289 | if inputs.ndim == 3: 290 | inputs = inputs.unsqueeze(0) 291 | 292 | def named_hook(name): 293 | def hook_fn(module, input, output): 294 | features[name] = input[0] 295 | 296 | return hook_fn 297 | 298 | for name, module in reversed(list(model.named_modules())): 299 | if isinstance(module, (torch.nn.Linear)): 300 | hook = module.register_forward_hook(named_hook(name)) 301 | feature_layer_name = name 302 | break 303 | model(inputs) 304 | hook.remove() 305 | return features[feature_layer_name] 306 | -------------------------------------------------------------------------------- /attacks/base_attack.py: -------------------------------------------------------------------------------- 1 | """Implementation for base attacker class. 2 | 3 | Inherit from this class for a consistent interface with attack cases.""" 4 | 5 | import torch 6 | from collections import defaultdict 7 | import copy 8 | 9 | from .common import optimizer_lookup 10 | 11 | 12 | import logging 13 | 14 | log = logging.getLogger(__name__) 15 | 16 | 17 | class _BaseAttacker: 18 | """This is a template class for an attack.""" 19 | 20 | def __init__(self, model, loss_fn, cfg_attack, setup=dict(dtype=torch.float, device=torch.device("cpu"))): 21 | self.cfg = cfg_attack 22 | self.memory_format = torch.channels_last if cfg_attack.impl.mixed_precision else torch.contiguous_format 23 | self.setup = dict(device=setup["device"], dtype=getattr(torch, cfg_attack.impl.dtype)) 24 | self.model_template = copy.deepcopy(model) 25 | self.loss_fn = copy.deepcopy(loss_fn) 26 | 27 | def reconstruct(self, server_payload, shared_data, server_secrets=None, dryrun=False): 28 | 29 | stats = defaultdict(list) 30 | 31 | # Implement the attack here 32 | # The attack should consume the shared_data and server payloads and reconstruct 33 | raise NotImplementedError() 34 | 35 | return reconstructed_data, stats 36 | 37 | def __repr__(self): 38 | raise NotImplementedError() 39 | 40 | def prepare_attack(self, server_payload, shared_data): 41 | """Basic startup common to many reconstruction methods.""" 42 | stats = defaultdict(list) 43 | 44 | # Load preprocessing constants: 45 | self.data_shape = server_payload["data"].shape 46 | self.dm = torch.as_tensor(server_payload["data"].mean, **self.setup)[None, :, None, None] 47 | self.ds = torch.as_tensor(server_payload["data"].std, **self.setup)[None, :, None, None] 48 | 49 | # Load server_payload into state: 50 | rec_models = self._construct_models_from_payload_and_buffers(server_payload, shared_data["buffers"]) 51 | shared_data = self._cast_shared_data(shared_data) 52 | self.rec_models = rec_models 53 | # Consider label information 54 | if shared_data["labels"] is None: 55 | labels = self._recover_label_information(shared_data, rec_models) 56 | else: 57 | labels = shared_data["labels"] 58 | 59 | # Condition gradients? 60 | if self.cfg.normalize_gradients: 61 | shared_data = self._normalize_gradients(shared_data) 62 | return rec_models, labels, stats 63 | 64 | def _construct_models_from_payload_and_buffers(self, server_payload, user_buffers): 65 | """Construct the model (or multiple) that is sent by the server and include user buffers if any.""" 66 | 67 | # Load states into multiple models if necessary 68 | models = [] 69 | for idx, payload in enumerate(server_payload["queries"]): 70 | 71 | new_model = copy.deepcopy(self.model_template) 72 | new_model.to(**self.setup, memory_format=self.memory_format) 73 | 74 | # Load parameters 75 | parameters = payload["parameters"] 76 | if user_buffers is not None and idx < len(user_buffers): 77 | # User sends buffers. These should be used! 78 | buffers = user_buffers[idx] 79 | new_model.eval() 80 | elif payload["buffers"] is not None: 81 | # The server has public buffers in any case 82 | buffers = payload["buffers"] 83 | new_model.eval() 84 | else: 85 | # The user sends no buffers and there are no public bufers 86 | # (i.e. the user in in training mode and does not send updates) 87 | new_model.train() 88 | for module in new_model.modules(): 89 | if hasattr(module, "track_running_stats"): 90 | module.reset_parameters() 91 | module.track_running_stats = False 92 | buffers = [] 93 | 94 | with torch.no_grad(): 95 | for param, server_state in zip(new_model.parameters(), parameters): 96 | param.copy_(server_state.to(**self.setup)) 97 | for buffer, server_state in zip(new_model.buffers(), buffers): 98 | buffer.copy_(server_state.to(**self.setup)) 99 | 100 | if self.cfg.impl.JIT == "script": 101 | example_inputs = self._initialize_data((1, *self.data_shape)) 102 | new_model = torch.jit.script(new_model, example_inputs=[(example_inputs,)]) 103 | elif self.cfg.impl.JIT == "trace": 104 | example_inputs = self._initialize_data((1, *self.data_shape)) 105 | new_model = torch.jit.trace(new_model, example_inputs=example_inputs) 106 | models.append(new_model) 107 | return models 108 | 109 | def _cast_shared_data(self, shared_data): 110 | """Cast user data to reconstruction data type.""" 111 | cast_grad_list = [] 112 | for shared_grad in shared_data["gradients"]: 113 | cast_grad_list += [[g.to(dtype=self.setup["dtype"]) for g in shared_grad]] 114 | shared_data["gradients"] = cast_grad_list 115 | return shared_data 116 | 117 | def _initialize_data(self, data_shape): 118 | """Note that data is initialized "inside" the network normalization.""" 119 | init_type = self.cfg.init 120 | if init_type == "randn": 121 | candidate = torch.randn(data_shape, **self.setup) 122 | elif init_type == "rand": 123 | candidate = (torch.rand(data_shape, **self.setup) * 2) - 1.0 124 | elif init_type == "zeros": 125 | candidate = torch.zeros(data_shape, **self.setup) 126 | # Initializations from Wei et al, "A Framework for Evaluating Gradient Leakage 127 | # Attacks in Federated Learning" 128 | elif any(c in init_type for c in ["red", "green", "blue", "dark", "light"]): # init_types like 'red-true' 129 | candidate = torch.zeros(data_shape, **self.setup) 130 | if "light" in init_type: 131 | candidate = torch.ones(data_shape, **self.setup) 132 | else: 133 | nonzero_channel = 0 if "red" in init_type else 1 if "green" in init_type else 2 134 | candidate[:, nonzero_channel, :, :] = 1 135 | if "-true" in init_type: 136 | # Shift to be truly RGB, not just normalized RGB 137 | candidate = (candidate - self.dm) / self.ds 138 | elif "patterned" in init_type: # Look for init_type=rand-patterned-4 139 | pattern_width = int("".join(filter(str.isdigit, init_type))) 140 | if "rand" in init_type: 141 | seed = torch.rand([1, 3, pattern_width, pattern_width], **self.setup) 142 | else: 143 | seed = torch.rand([1, 3, pattern_width, pattern_width], **self.setup) 144 | # Shape expansion: 145 | x_factor, y_factor = ( 146 | torch.as_tensor(data_shape[2] / pattern_width).ceil(), 147 | torch.as_tensor(data_shape[3] / pattern_width).ceil(), 148 | ) 149 | candidate = ( 150 | torch.tile(seed, (1, 1, int(x_factor), int(y_factor)))[:, :, : data_shape[2], : data_shape[3]] 151 | .contiguous() 152 | .clone() 153 | ) 154 | else: 155 | raise ValueError(f"Unknown initialization scheme {init_type} given.") 156 | 157 | candidate.to(memory_format=self.memory_format) 158 | candidate.requires_grad = True 159 | candidate.grad = torch.zeros_like(candidate) 160 | return candidate 161 | 162 | def _init_optimizer(self, candidate): 163 | 164 | optimizer, scheduler = optimizer_lookup( 165 | [candidate], 166 | self.cfg.optim.optimizer, 167 | self.cfg.optim.step_size, 168 | scheduler=self.cfg.optim.step_size_decay, 169 | warmup=self.cfg.optim.warmup, 170 | max_iterations=self.cfg.optim.max_iterations, 171 | ) 172 | return optimizer, scheduler 173 | 174 | def _normalize_gradients(self, shared_data, fudge_factor=1e-6): 175 | """Normalize gradients to have norm of 1. No guarantees that this would be a good idea for FL updates.""" 176 | for shared_grad in shared_data["gradients"]: 177 | grad_norm = torch.stack([g.pow(2).sum() for g in shared_grad]).sum().sqrt() 178 | torch._foreach_div_(shared_grad, max(grad_norm, fudge_factor)) 179 | return shared_data 180 | 181 | def _recover_label_information(self, user_data, rec_models): 182 | """Recover label information. 183 | 184 | This method runs under the assumption that the last two entries in the gradient vector 185 | correpond to the weight and bias of the last layer (mapping to num_classes). 186 | For non-classification tasks this has to be modified. 187 | 188 | The behavior with respect to multiple queries is work in progress and subject of debate. 189 | """ 190 | num_data_points = user_data["num_data_points"] 191 | num_classes = user_data["gradients"][0][-1].shape[0] 192 | num_queries = len(user_data["gradients"]) 193 | 194 | # In the simplest case, the label can just be inferred from the last layer 195 | if self.cfg.label_strategy == "iDLG": 196 | # This was popularized in "iDLG" by Zhao et al., 2020 197 | # assert num_data_points == 1 198 | label_list = [] 199 | for query_id, shared_grad in enumerate(user_data["gradients"]): 200 | last_weight_min = torch.argmin(torch.sum(shared_grad[-2], dim=-1), dim=-1) 201 | label_list += [last_weight_min.detach()] 202 | labels = torch.stack(label_list).unique() 203 | elif self.cfg.label_strategy == "analytic": 204 | # Analytic recovery simply works as long as all labels are unique. 205 | label_list = [] 206 | for query_id, shared_grad in enumerate(user_data["gradients"]): 207 | valid_classes = (shared_grad[-1] < 0).nonzero() 208 | label_list += [valid_classes] 209 | labels = torch.stack(label_list).unique()[:num_data_points] 210 | elif self.cfg.label_strategy == "yin": 211 | # As seen in Yin et al. 2021, "See Through Gradients: Image Batch Recovery via GradInversion" 212 | # This additionally assumes that there is a nonlinearity with positive output (like ReLU) in front of the 213 | # last classification layer. 214 | # This scheme also works best if all labels are unique 215 | # Otherwise this is an extension of iDLG to multiple labels: 216 | total_min_vals = 0 217 | for query_id, shared_grad in enumerate(user_data["gradients"]): 218 | total_min_vals += shared_grad[-2].min(dim=-1)[0] 219 | labels = total_min_vals.argsort()[:num_data_points] 220 | 221 | elif "wainakh" in self.cfg.label_strategy: 222 | 223 | if self.cfg.label_strategy == "wainakh-simple": 224 | # As seen in Weinakh et al., "User Label Leakage from Gradients in Federated Learning" 225 | m_impact = 0 226 | for query_id, shared_grad in enumerate(user_data["gradients"]): 227 | g_i = shared_grad[-2].sum(dim=1) 228 | m_query = ( 229 | torch.where(g_i < 0, g_i, torch.zeros_like(g_i)).sum() * (1 + 1 / num_classes) / num_data_points 230 | ) 231 | s_offset = 0 232 | m_impact += m_query / num_queries 233 | elif self.cfg.label_strategy == "wainakh-whitebox": 234 | # Augment previous strategy with measurements of label impact for dummy data. 235 | m_impact = 0 236 | s_offset = torch.zeros(num_classes, **self.setup) 237 | 238 | print("Starting a white-box search for optimal labels. This will take some time.") 239 | for query_id, (shared_grad, model) in enumerate(zip(user_data["gradients"], rec_models)): 240 | # Estimate m: 241 | weight_params = (list(rec_models[0].parameters())[-2],) 242 | for class_idx in range(num_classes): 243 | fake_data = torch.randn([num_data_points, *self.data_shape], **self.setup) 244 | fake_labels = torch.as_tensor([class_idx] * num_data_points, **self.setup) 245 | with torch.autocast(self.setup["device"].type, enabled=self.cfg.impl.mixed_precision): 246 | loss = self.loss_fn(model(fake_data), fake_labels) 247 | (W_cls,) = torch.autograd.grad(loss, weight_params) 248 | g_i = W_cls.sum(dim=1) 249 | m_impact += g_i.sum() * (1 + 1 / num_classes) / num_data_points / num_classes / num_queries 250 | 251 | # Estimate s: 252 | T = num_classes - 1 253 | for class_idx in range(num_classes): 254 | fake_data = torch.randn([T, *self.data_shape], **self.setup) 255 | fake_labels = torch.arange(num_classes, **self.setup) 256 | fake_labels = fake_labels[fake_labels != class_idx] 257 | with torch.autocast(self.setup["device"].type, enabled=self.cfg.impl.mixed_precision): 258 | loss = self.loss_fn(model(fake_data), fake_labels) 259 | (W_cls,) = torch.autograd.grad(loss, (weight_params[0][class_idx],)) 260 | s_offset[class_idx] += W_cls.sum() / T / num_queries 261 | 262 | else: 263 | raise ValueError(f"Invalid Wainakh strategy {self.cfg.label_strategy}.") 264 | 265 | # After determining impact and offset, run the actual recovery algorithm 266 | label_list = [] 267 | g_per_query = [shared_grad[-2].sum(dim=1) for shared_grad in user_data["gradients"]] 268 | g_i = torch.stack(g_per_query).mean(dim=0) 269 | # Stage 1: 270 | for idx in range(num_classes): 271 | if g_i[idx] < 0: 272 | label_list.append(torch.as_tensor(idx, device=self.setup["device"])) 273 | g_i[idx] -= m_impact 274 | # Stage 2: 275 | g_i = g_i - s_offset 276 | while len(label_list) < num_data_points: 277 | selected_idx = g_i.argmin() 278 | label_list.append(torch.as_tensor(selected_idx, device=self.setup["device"])) 279 | g_i[idx] -= m_impact 280 | # Finalize labels: 281 | labels = torch.stack(label_list) 282 | 283 | elif self.cfg.label_strategy == "bias-corrected": # WIP 284 | # This is slightly modified analytic label recovery in the style of Wainakh 285 | bias_per_query = [shared_grad[-1] for shared_grad in user_data["gradients"]] 286 | label_list = [] 287 | # Stage 1 288 | average_bias = torch.stack(bias_per_query).mean(dim=0) 289 | valid_classes = (average_bias < 0).nonzero() 290 | label_list += [*valid_classes.squeeze(dim=-1)] 291 | m_impact = average_bias_correct_label = average_bias[valid_classes].sum() / num_data_points 292 | 293 | average_bias[valid_classes] = average_bias[valid_classes] - m_impact 294 | # Stage 2 295 | while len(label_list) < num_data_points: 296 | selected_idx = average_bias.argmin() 297 | label_list.append(selected_idx) 298 | average_bias[selected_idx] -= m_impact 299 | labels = torch.stack(label_list) 300 | 301 | elif self.cfg.label_strategy == "random": 302 | # A random baseline 303 | labels = torch.randint(0, num_classes, (num_data_points,), device=self.setup["device"]) 304 | elif self.cfg.label_strategy == "exhaustive": 305 | # Exhaustive search is possible in principle 306 | combinations = num_classes ** num_data_points 307 | raise ValueError( 308 | f"Exhaustive label searching not implemented. Nothing stops you though from running your" 309 | f"attack algorithm for any possible combination of labels, except computational effort." 310 | f"In the given setting, a naive exhaustive strategy would attack {combinations} label vectors." 311 | ) 312 | # Although this is arguably a worst-case estimate, you might be able to get "close enough" to the actual 313 | # label vector in much fewer queries, depending on which notion of close-enough makes sense for a given attack. 314 | else: 315 | raise ValueError(f"Invalid label recovery strategy {self.cfg.label_strategy} given.") 316 | 317 | # Pad with random labels if too few were produced: 318 | if len(labels) < num_data_points: 319 | labels = torch.cat( 320 | [labels, torch.randint(0, num_classes, (num_data_points - len(labels),), device=self.setup["device"])] 321 | ) 322 | 323 | # Always sort, order does not matter here: 324 | labels = labels.sort()[0] 325 | log.info(f"Recovered labels {labels.tolist()} through strategy {self.cfg.label_strategy}.") 326 | return labels 327 | --------------------------------------------------------------------------------