├── private_CNN ├── privacy_utils │ ├── accounting │ │ ├── __init__.py │ │ ├── gdp_accounting.py │ │ └── rdp_accounting.py │ ├── __init__.py │ ├── misc.py │ ├── autograd_grad_sample.py │ ├── supported_layers_grad_samplers.py │ ├── transformers_support.py │ └── privacy_engine.py ├── __init__.py └── README.md ├── assets ├── cifar10_memory_speed.png └── cifar10_stress_tests.png ├── README.md ├── cifar10.ipynb └── vit.ipynb /private_CNN/privacy_utils/accounting/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /private_CNN/privacy_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from . import privacy_engine 2 | -------------------------------------------------------------------------------- /private_CNN/__init__.py: -------------------------------------------------------------------------------- 1 | from .privacy_utils.privacy_engine import PrivacyEngine 2 | 3 | __version__ = '0.1.0' 4 | -------------------------------------------------------------------------------- /assets/cifar10_memory_speed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JialinMao/private_CNN/HEAD/assets/cifar10_memory_speed.png -------------------------------------------------------------------------------- /assets/cifar10_stress_tests.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JialinMao/private_CNN/HEAD/assets/cifar10_stress_tests.png -------------------------------------------------------------------------------- /private_CNN/README.md: -------------------------------------------------------------------------------- 1 | This code is largely based on and (v0.15). 2 | -------------------------------------------------------------------------------- /private_CNN/privacy_utils/misc.py: -------------------------------------------------------------------------------- 1 | """Miscellaneous helpers.""" 2 | import warnings 3 | 4 | 5 | def handle_unused_kwargs(unused_kwargs, msg=None): 6 | if len(unused_kwargs) > 0: 7 | if msg is not None: 8 | warnings.warn(f"{msg}: Unexpected arguments {unused_kwargs}") 9 | else: 10 | warnings.warn(f"Unexpected arguments {unused_kwargs}") 11 | -------------------------------------------------------------------------------- /private_CNN/privacy_utils/accounting/gdp_accounting.py: -------------------------------------------------------------------------------- 1 | r"""This code applies the Dual and Central Limit 2 | Theorem (CLT) to estimate privacy budget of an iterated subsampled 3 | Gaussian Mechanism (by either uniform or Poisson subsampling). 
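For reference, the dual conversion implemented in `delta_eps_mu` below maps mu-GDP to (epsilon, delta)-DP via delta(eps) = Phi(-eps/mu + mu/2) - exp(eps) * Phi(-eps/mu - mu/2), where Phi is the standard normal CDF.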
4 | 5 | This file is a direct copy of 6 | https://github.com/woodyx218/privacy/blob/d06340e1cf4944faa065644efb5e95950fbaf487/tensorflow_privacy/privacy/analysis/gdp_accountant.py 7 | """ 8 | 9 | import numpy as np 10 | from scipy import optimize 11 | from scipy.stats import norm 12 | 13 | 14 | def compute_mu_uniform(epochs, noise_multi, sample_rate): 15 | """Compute mu from uniform subsampling.""" 16 | T = epochs / sample_rate 17 | c = np.sqrt(T) * sample_rate 18 | return ( 19 | np.sqrt(2) * c * np.sqrt( 20 | np.exp(noise_multi ** (-2)) * norm.cdf(1.5 / noise_multi) + 3 * norm.cdf(-0.5 / noise_multi) - 2 21 | ) 22 | ) 23 | 24 | 25 | def compute_mu_poisson(epochs, noise_multi, sample_rate): 26 | """Compute mu from Poisson subsampling.""" 27 | T = epochs / sample_rate 28 | return np.sqrt(np.exp(noise_multi ** (-2)) - 1) * np.sqrt(T) * sample_rate 29 | 30 | 31 | def delta_eps_mu(eps, mu): 32 | """Compute dual between mu-GDP and (epsilon, delta)-DP.""" 33 | return norm.cdf(-eps / mu + mu / 2) - np.exp(eps) * norm.cdf(-eps / mu - mu / 2) 34 | 35 | 36 | def eps_from_mu(mu, delta, bracket=(0, 500)): 37 | """Compute epsilon from mu given delta via inverse dual.""" 38 | 39 | def f(x): 40 | """Reversely solve dual by matching delta.""" 41 | return delta_eps_mu(x, mu) - delta 42 | 43 | return optimize.root_scalar(f, bracket=bracket, method='brentq').root 44 | 45 | 46 | def compute_eps_uniform(epochs, noise_multi, sample_rate, delta): 47 | """Compute epsilon given delta from inverse dual of uniform subsampling.""" 48 | return eps_from_mu(compute_mu_uniform(epochs, noise_multi, sample_rate), delta) 49 | 50 | 51 | def compute_eps_poisson(epochs, noise_multi, sample_rate, delta): 52 | """Compute epsilon given delta from inverse dual of Poisson subsampling.""" 53 | return eps_from_mu(compute_mu_poisson(epochs, noise_multi, sample_rate), delta) 54 | 55 | 56 | def get_noise_multiplier( 57 | sample_rate, 58 | epochs, 59 | target_epsilon, 60 | target_delta, 61 | sigma_min=0.01, 62 | sigma_max=10.0, 63 | threshold=1e-3, 64 | ): 65 | """Estimate the noise multiplier by binary search.""" 66 | while sigma_max - sigma_min > threshold: 67 | sigma_mid = (sigma_min + sigma_max) / 2. 68 | epsilon = compute_eps_poisson(epochs, sigma_mid, sample_rate, target_delta) 69 | if epsilon > target_epsilon: 70 | sigma_min = sigma_mid 71 | else: 72 | sigma_max = sigma_mid 73 | return sigma_max 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # private_CNN 2 | 3 | This Pytorch codebase implements efficient training of differentially private (DP) convolutional neural networks (CNN), using [mixed ghost per-sample gradient clipping](https://arxiv.org/pdf/2205.10683.pdf). 4 | 5 |

6 | 7 |

8 | 9 | ## ❓ What is this? 10 | 11 | There are a few DP libraries that convert the regular non-private training of neural networks to a private one. Examples include [Opacus](https://github.com/pytorch/opacus/blob/main/Migration_Guide.md#if-youre-using-virtual-steps), [FastGradClip](https://github.com/ppmlguy/fastgradclip), [private-transformers](https://github.com/lxuechen/private-transformers), and [tensorflow-privacy](https://github.com/tensorflow/privacy). 12 | 13 | However, they are not suitable for DP training of large CNNs, because they are either not generalizable or computationally inefficient, e.g. incurring a >20x memory burden or a >5x slowdown compared with regular training. 14 | 15 |

16 | 17 |

18 | 19 | This codebase implements a new technique -- **mixed ghost clipping** -- for convolutional layers, which substantially reduces the space and time complexity of DP deep learning. 20 | 21 | ## 🔥 Highlights 22 | 23 | * We implement a mixed ghost clipping technique for the Conv1d/Conv2d/Conv3d layers that trains DP CNNs almost as lightly as regular training (with only 0.1%-10% memory overhead). This allows us to train with an 18 times larger batch size on VGG19 and CIFAR10 than Opacus. 24 | * A larger batch size improves the throughput of mixed ghost clipping, making it up to 3 times faster than existing DP training methods. On all models we tested, the slowdown is at most 2 times relative to regular training. 25 | * We support general optimizers and clipping functions. Loading vision models from codebases such as [timm](https://github.com/rwightman/pytorch-image-models) and [torchvision](https://pytorch.org/vision/stable/models.html), our method can privately train VGG, ResNet, Wide ResNet, ResNeXt, etc. with a few additional lines of code. 26 | * We demonstrate DP training of convolutional Vision Transformers (up to 300 million parameters, again with about 10% memory overhead and less than 200% slowdown relative to non-private training) and improve the previous SOTA accuracy from 67.4% to 83.0% at eps=1. 27 | 28 | ## 🚀 Getting Started 29 | 30 | Privately training vision models is simple: 31 | 32 | 1. Create the model and any optimizer 33 | 2. Attach this optimizer to our `PrivacyEngine` (this essentially adds PyTorch hooks for per-sample clipping) 34 | 3. Compute per-example losses (setting `reduction='none'`) for a mini-batch of data 35 | 4. Pass the loss to `optimizer.step` or `optimizer.virtual_step` without calling the `backward` function (it is called implicitly inside the `PrivacyEngine`) 36 | 37 | Below is a quick example of using our codebase for training CNN models with mixed ghost clipping: 38 | 39 | ```python 40 | import torchvision, torch, opacus 41 | from private_CNN import PrivacyEngine 42 | 43 | model = torchvision.models.resnet18() 44 | 45 | # replace BatchNorm by GroupNorm or LayerNorm 46 | model = opacus.validators.ModuleValidator.fix(model) 47 | 48 | optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4) 49 | privacy_engine = PrivacyEngine( 50 | model, 51 | batch_size=256, 52 | sample_size=50000, 53 | epochs=3, 54 | max_grad_norm=0.1, 55 | target_epsilon=3, 56 | ghost_clipping=True, 57 | mixed=True, 58 | ) 59 | privacy_engine.attach(optimizer) 60 | 61 | # Same training procedure, e.g. data loading, forward pass, logits... 62 | loss = F.cross_entropy(logits, labels, reduction="none") 63 | # do not use loss.backward() 64 | optimizer.step(loss=loss) 65 | ``` 66 | 67 | In the above `PrivacyEngine`, setting the keywords `ghost_clipping=True, mixed=False` implements ghost clipping; setting `ghost_clipping=True, mixed=True` implements the best method, mixed ghost clipping; setting `ghost_clipping=False` implements an approach similar to Opacus, which needs to instantiate the per-sample gradients. 68 | 69 | A special use of our privacy engine is gradient accumulation, which is achieved with the virtual step function. 70 | 71 | ```python 72 | import torchvision, torch, timm, opacus 73 | from private_CNN import PrivacyEngine 74 | 75 | gradient_accumulation_steps = 10 76 | 77 | # Batch size/physical batch size.
Take an update once this many iterations 78 | 79 | model = torchvision.models.resnet18() 80 | model=opacus.validators.ModuleValidator.fix(model) 81 | optimizer = torch.optim.Adam(model.parameters()) 82 | privacy_engine = PrivacyEngine(...) 83 | privacy_engine.attach(optimizer) 84 | 85 | for i, batch in enumerate(dataloader): 86 | loss = F.cross_entropy(model(batch), labels, reduction="none") 87 | if i % gradient_accumulation_steps == 0: 88 | optimizer.step(loss=loss) 89 | optimizer.zero_grad() 90 | else: 91 | optimizer.virtual_step(loss=loss) 92 | ``` 93 | 94 | ## :clipboard: Currently Supported Modules via Ghost Clipping 95 | 96 | * nn.Linear ([Ian Goodfellow](https://arxiv.org/abs/1510.01799)) 97 | * nn.Linear sequential ([Xuechen et al.](https://arxiv.org/abs/2110.05679)) 98 | * nn.LayerNorm ([Xuechen et al.](https://arxiv.org/abs/2110.05679)) 99 | * nn.Embedding ([Xuechen et al.](https://arxiv.org/abs/2110.05679)) 100 | * Conv1d (this work) 101 | * Conv2d (this work) 102 | * Conv3d (this work) 103 | * nn.GroupNorm (this work) 104 | 105 | ### :warning: Caution 106 | 107 | * **Batch normalization does not satisfy DP.** This is because the mean and variance of batch normalization is computed from data without privatization. To train DP networks, replace batch normalization with group/instance/layer normalization. [Opacus (>v1.0)](https://github.com/pytorch/opacus/blob/main/tutorials/guide_to_module_validator.ipynb) provides an easy fixer for this replacement via `opacus.validators.ModuleValidator.fix`, but you can also change the normalization layer manually. 108 | * **Extra care needed for sampling.** Taking virtual step with fixed virtual batch size is not compatible with Poisson sampling. [Opacus] provides `BatchMemoryManager` to feature this [sampling issue](https://github.com/pytorch/opacus/blob/main/Migration_Guide.md#if-youre-using-virtual-steps) and our mixed ghost clipping can be merged easily. Also we didn't use secure PRNG to sample the noise for this experimental codebase. 109 | 110 | ## Citation 111 | 112 | Please cite our paper if you use private_CNN in your papers, as follows: 113 | 114 | ``` 115 | @article{bu2022scalable, 116 | title={Scalable and Efficient Training of Large Convolutional Neural Networks with Differential Privacy}, 117 | author={Bu, Zhiqi and Mao, Jialin and Xu, Shiyun}, 118 | journal={arXiv preprint arXiv:2205.10683}, 119 | year={2022} 120 | } 121 | ``` 122 | 123 | ## Acknowledgement 124 | 125 | This code is largely based on and (v0.15). 126 | -------------------------------------------------------------------------------- /private_CNN/privacy_utils/autograd_grad_sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | A large portion of this code is adapted from Opacus (https://github.com/pytorch/opacus), 3 | which is licensed under Apache License 2.0. 4 | 5 | We have modified it considerably to support ghost clipping. 
6 | """ 7 | 8 | from typing import Tuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from .supported_layers_grad_samplers import _supported_layers_grad_samplers 14 | 15 | # work-around for https://github.com/pytorch/pytorch/issues/25723 16 | _hooks_disabled: bool = False 17 | _fp16 = False 18 | _hooks_mode = "default" 19 | 20 | 21 | def set_hooks_mode(mode): 22 | if mode not in ("ghost_norm", "ghost_grad", "default", "ghost_norm_mixed"): 23 | raise ValueError(f"Unknown mode for hooks: {mode}; expected one of `ghost_norm`, `ghost_grad`, `ghost_norm_mixed`, `default`.") 24 | 25 | global _hooks_mode 26 | _hooks_mode = mode 27 | 28 | if "ghost_grad" in _hooks_mode: 29 | disable_hooks() 30 | elif "ghost_norm" in _hooks_mode: 31 | enable_hooks() 32 | 33 | 34 | def get_hooks_mode(): 35 | global _hooks_mode 36 | return _hooks_mode 37 | 38 | 39 | def enable_fp16(): 40 | global _fp16 41 | _fp16 = True 42 | 43 | 44 | def disable_fp16(): 45 | global _fp16 46 | _fp16 = False 47 | 48 | 49 | def fp16(): 50 | return _fp16 51 | 52 | 53 | def has_no_param(module: nn.Module) -> bool: 54 | """ 55 | Checks if a module does not have any parameters. 56 | 57 | Args: 58 | module: The module on which this function is being evaluated. 59 | 60 | Returns: 61 | Flag indicating if the provided module does not have any 62 | parameters. 63 | """ 64 | has_params = any(p is not None for p in module.parameters(recurse=False)) 65 | return not has_params 66 | 67 | 68 | def requires_grad(module: nn.Module, recurse: bool = False) -> bool: 69 | """ 70 | Checks if any parameters in a specified module require gradients. 71 | 72 | Args: 73 | module: PyTorch module whose parameters are examined 74 | recurse: Flag specifying if the gradient requirement check should 75 | be applied recursively to sub-modules of the specified module 76 | 77 | Returns: 78 | Flag indicate if any parameters require gradients 79 | """ 80 | requires_grad = any(p.requires_grad for p in module.parameters(recurse)) 81 | return requires_grad 82 | 83 | 84 | def get_layer_type(layer: nn.Module) -> str: 85 | """ 86 | Returns the name of the type of the given layer. 87 | 88 | Args: 89 | layer: The module corresponding to the layer whose type 90 | is being queried. 91 | 92 | Returns: 93 | Name of the class of the layer 94 | """ 95 | return layer.__class__.__name__ 96 | 97 | 98 | def add_hooks( 99 | model: nn.Module, 100 | loss_reduction: str = "mean", 101 | batch_first: bool = True, 102 | fp16=False, 103 | ): 104 | r""" 105 | Adds hooks to model to save activations and backprop values. 106 | The hooks will 107 | 108 | 1. save activations into ``param.activations`` during forward pass. 109 | 2. compute per-sample gradients and save them in ``param.grad_sample`` during backward pass. 110 | 111 | Args: 112 | model: Model to which hooks are added. 113 | loss_reduction: Indicates if the loss reduction (for aggregating the 114 | gradients) is a sum or a mean operation. Can take values ``sum`` or 115 | ``mean``. 116 | batch_first: Flag to indicate if the input tensor to the corresponding module 117 | has the first dimension represent the batch, for example of shape 118 | ``[batch_size, ..., ...]``. Set to True if batch appears in first 119 | dimension else set to False (``batch_first=False`` implies that the 120 | batch is always in the second dimension). 121 | fp16: Perform backward computation in fp16 for as much as possible if True. 
122 | """ 123 | if hasattr(model, "autograd_grad_sample_hooks"): 124 | raise ValueError("Trying to add hooks twice to the same model") 125 | 126 | enable_hooks() 127 | if fp16: 128 | enable_fp16() 129 | 130 | handles = [] 131 | for name, layer in model.named_modules(): 132 | if get_layer_type(layer) in _supported_layers_grad_samplers.keys(): 133 | # Check if the layer has trainable parameters. 134 | is_trainable = False 135 | for p in layer.parameters(recurse=False): 136 | if p.requires_grad: 137 | is_trainable = True 138 | break 139 | 140 | if is_trainable: 141 | handles.append(layer.register_forward_hook(_capture_activations)) 142 | 143 | def this_backward(this_layer, grad_input, grad_output): 144 | return _capture_backprops( 145 | this_layer, grad_input, grad_output, loss_reduction, batch_first 146 | ) 147 | 148 | # Starting with 1.8.0, use `register_full_backward_hook`. 149 | handles.append(layer.register_backward_hook(this_backward)) 150 | 151 | model.__dict__.setdefault("autograd_grad_sample_hooks", []).extend(handles) 152 | 153 | 154 | def remove_hooks(model: nn.Module): 155 | """Removes hooks added by `add_hooks()`.""" 156 | if not hasattr(model, "autograd_grad_sample_hooks"): 157 | raise ValueError("Asked to remove hooks, but no hooks found") 158 | else: 159 | for handle in model.autograd_grad_sample_hooks: 160 | handle.remove() 161 | del model.autograd_grad_sample_hooks 162 | 163 | 164 | def disable_hooks(): 165 | """Globally disables all hooks installed by this library.""" 166 | global _hooks_disabled 167 | _hooks_disabled = True 168 | 169 | 170 | def enable_hooks(): 171 | """Globally enables all hooks installed by this library.""" 172 | global _hooks_disabled 173 | _hooks_disabled = False 174 | 175 | 176 | def is_supported(layer: nn.Module) -> bool: 177 | """Checks if the layer is supported by this library.""" 178 | return get_layer_type(layer) in list(_supported_layers_grad_samplers.keys()) 179 | 180 | 181 | def _capture_activations(layer: nn.Module, inputs: Tuple[torch.Tensor], outputs: Tuple[torch.Tensor]): 182 | """Forward hook handler captures and saves activations.""" 183 | layer_type = get_layer_type(layer) 184 | if ( 185 | not requires_grad(layer) 186 | or layer_type not in _supported_layers_grad_samplers.keys() 187 | or not layer.training 188 | ): 189 | return 190 | 191 | if _hooks_disabled: 192 | return 193 | if get_layer_type(layer) not in _supported_layers_grad_samplers.keys(): 194 | raise ValueError("Hook installed on unsupported layer") 195 | 196 | if not hasattr(layer, "activations"): 197 | layer.activations = [] 198 | 199 | layer.activations.append(inputs[0].detach()) 200 | 201 | 202 | def _capture_backprops( 203 | layer: nn.Module, 204 | inputs: Tuple[torch.Tensor], 205 | outputs: Tuple[torch.Tensor], 206 | loss_reduction: str, 207 | batch_first: bool, 208 | ): 209 | """Backward hook handler captures grad_outputs.""" 210 | if _hooks_disabled: 211 | return 212 | 213 | backprops = outputs[0].detach() 214 | _compute_grad_sample(layer, backprops, loss_reduction, batch_first) 215 | 216 | 217 | def _compute_grad_sample(layer: nn.Module, backprops: torch.Tensor, loss_reduction: str, batch_first: bool): 218 | """Computes per-sample gradients with respect to the parameters.""" 219 | layer_type = get_layer_type(layer) 220 | if ( 221 | not requires_grad(layer) 222 | or layer_type not in _supported_layers_grad_samplers.keys() 223 | or not layer.training 224 | ): 225 | return 226 | 227 | if not hasattr(layer, "activations"): 228 | raise ValueError( 229 | f"No activations 
detected for {type(layer)}," 230 | " run forward after add_hooks(model)" 231 | ) 232 | 233 | # Outside of the LSTM there is "batch_first" but not for the Linear inside the LSTM 234 | batch_dim = 0 if batch_first else 1 235 | if isinstance(layer.activations, list): 236 | A = layer.activations.pop() 237 | else: 238 | A = layer.activations 239 | 240 | if not hasattr(layer, "max_batch_len"): 241 | layer.max_batch_len = _get_batch_size(layer, A, batch_dim) 242 | 243 | n = layer.max_batch_len 244 | if loss_reduction == "mean": 245 | B = backprops * n 246 | elif loss_reduction == "sum": 247 | B = backprops 248 | else: 249 | raise ValueError( 250 | f"loss_reduction = {loss_reduction}. Only 'sum' and 'mean' losses are supported" 251 | ) 252 | 253 | # rearrange the blob dimensions 254 | if batch_dim != 0: 255 | A = A.permute([batch_dim] + [x for x in range(A.dim()) if x != batch_dim]) 256 | B = B.permute([batch_dim] + [x for x in range(B.dim()) if x != batch_dim]) 257 | # compute grad sample for individual layers 258 | compute_layer_grad_sample = _supported_layers_grad_samplers.get( 259 | get_layer_type(layer) 260 | ) 261 | 262 | compute_layer_grad_sample(layer, A, B) 263 | 264 | if ( 265 | not isinstance(layer.activations, list) or len(layer.activations) == 0 266 | ) and hasattr(layer, "max_batch_len"): 267 | del layer.max_batch_len 268 | 269 | 270 | def _get_batch_size(layer: nn.Module, grad_sample: torch.Tensor, batch_dim: int) -> int: 271 | r""" 272 | Computes and returns the maximum batch size which is the maximum of the dimension values 273 | along 'batch_dim' axis over layer.activations + [grad_sample], where layer.activations is 274 | a list. If layer.activations is a not a list, then return grad_sample.shape[batch_dim]. 275 | """ 276 | 277 | max_batch_len = 0 278 | if isinstance(layer.activations, list): 279 | for out in layer.activations: 280 | if out.shape[batch_dim] > max_batch_len: 281 | max_batch_len = out.shape[batch_dim] 282 | 283 | max_batch_len = max(max_batch_len, grad_sample.shape[batch_dim]) 284 | return max_batch_len 285 | -------------------------------------------------------------------------------- /private_CNN/privacy_utils/accounting/rdp_accounting.py: -------------------------------------------------------------------------------- 1 | r""" 2 | This file is adapted from the privacy accounting procedure in Opacus', which in turn is adapted from tf-privacy. 3 | Below is the original documentation in Opacus. 4 | 5 | *Based on Google's TF Privacy:* https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/analysis 6 | /rdp_accountant.py. 7 | *Here, we update this code to Python 3, and optimize dependencies.* 8 | 9 | Functionality for computing Renyi Differential Privacy (RDP) of an additive 10 | Sampled Gaussian Mechanism (SGM). 11 | 12 | Example: 13 | Suppose that we have run an SGM applied to a function with L2-sensitivity of 1. 14 | 15 | Its parameters are given as a list of tuples 16 | ``[(q_1, sigma_1, steps_1), ..., (q_k, sigma_k, steps_k)],`` 17 | and we wish to compute epsilon for a given target delta. 
18 | 19 | The example code would be: 20 | 21 | >>> max_order = 32 22 | >>> orders = range(2, max_order + 1) 23 | >>> rdp = np.zeros_like(orders, dtype=float) 24 | >>> for q, sigma, steps in parameters: 25 | >>> rdp += privacy_analysis.compute_rdp(q, sigma, steps, orders) 26 | >>> epsilon, opt_order = privacy_analysis.get_privacy_spent(orders, rdp, delta) 27 | """ 28 | 29 | import math 30 | from typing import List 31 | from typing import Tuple 32 | from typing import Union 33 | 34 | import numpy as np 35 | from scipy import special 36 | 37 | 38 | ######################## 39 | # LOG-SPACE ARITHMETIC # 40 | ######################## 41 | 42 | 43 | def _log_add(logx: float, logy: float) -> float: 44 | r"""Adds two numbers in the log space. 45 | 46 | Args: 47 | logx: First term in log space. 48 | logy: Second term in log space. 49 | 50 | Returns: 51 | Sum of numbers in log space. 52 | """ 53 | a, b = min(logx, logy), max(logx, logy) 54 | if a == -np.inf: # adding 0 55 | return b 56 | # Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b) 57 | return math.log1p(math.exp(a - b)) + b # log1p(x) = log(x + 1) 58 | 59 | 60 | def _log_sub(logx: float, logy: float) -> float: 61 | r"""Subtracts two numbers in the log space. 62 | 63 | Args: 64 | logx: First term in log space. Expected to be greater than the second term. 65 | logy: First term in log space. Expected to be less than the first term. 66 | 67 | Returns: 68 | Difference of numbers in log space. 69 | 70 | Raises: 71 | ValueError 72 | If the result is negative. 73 | """ 74 | if logx < logy: 75 | raise ValueError("The result of subtraction must be non-negative.") 76 | if logy == -np.inf: # subtracting 0 77 | return logx 78 | if logx == logy: 79 | return -np.inf # 0 is represented as -np.inf in the log space. 80 | 81 | try: 82 | # Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y). 83 | return math.log(math.expm1(logx - logy)) + logy # expm1(x) = exp(x) - 1 84 | except OverflowError: 85 | return logx 86 | 87 | 88 | def _compute_log_a_for_int_alpha(q: float, sigma: float, alpha: int) -> float: 89 | r"""Computes :math:`log(A_\alpha)` for integer ``alpha``. 90 | 91 | Notes: 92 | Note that 93 | :math:`A_\alpha` is real valued function of ``alpha`` and ``q``, 94 | and that 0 < ``q`` < 1. 95 | 96 | Refer to Section 3.3 of https://arxiv.org/pdf/1908.10530.pdf for details. 97 | 98 | Args: 99 | q: Sampling rate of SGM. 100 | sigma: The standard deviation of the additive Gaussian noise. 101 | alpha: The order at which RDP is computed. 102 | 103 | Returns: 104 | :math:`log(A_\alpha)` as defined in Section 3.3 of 105 | https://arxiv.org/pdf/1908.10530.pdf. 106 | """ 107 | 108 | # Initialize with 0 in the log space. 109 | log_a = -np.inf 110 | 111 | for i in range(alpha + 1): 112 | log_coef_i = ( 113 | math.log(special.binom(alpha, i)) 114 | + i * math.log(q) 115 | + (alpha - i) * math.log(1 - q) 116 | ) 117 | 118 | s = log_coef_i + (i * i - i) / (2 * (sigma ** 2)) 119 | log_a = _log_add(log_a, s) 120 | 121 | return float(log_a) 122 | 123 | 124 | def _compute_log_a_for_frac_alpha(q: float, sigma: float, alpha: float) -> float: 125 | r"""Computes :math:`log(A_\alpha)` for fractional ``alpha``. 126 | 127 | Notes: 128 | Note that 129 | :math:`A_\alpha` is real valued function of ``alpha`` and ``q``, 130 | and that 0 < ``q`` < 1. 131 | 132 | Refer to Section 3.3 of https://arxiv.org/pdf/1908.10530.pdf for details. 133 | 134 | Args: 135 | q: Sampling rate of SGM. 136 | sigma: The standard deviation of the additive Gaussian noise. 
137 | alpha: The order at which RDP is computed. 138 | 139 | Returns: 140 | :math:`log(A_\alpha)` as defined in Section 3.3 of 141 | https://arxiv.org/pdf/1908.10530.pdf. 142 | """ 143 | # The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are 144 | # initialized to 0 in the log space: 145 | log_a0, log_a1 = -np.inf, -np.inf 146 | i = 0 147 | 148 | z0 = sigma ** 2 * math.log(1 / q - 1) + 0.5 149 | 150 | while True: # do ... until loop 151 | coef = special.binom(alpha, i) 152 | log_coef = math.log(abs(coef)) 153 | j = alpha - i 154 | 155 | log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q) 156 | log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q) 157 | 158 | log_e0 = math.log(0.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma)) 159 | log_e1 = math.log(0.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma)) 160 | 161 | log_s0 = log_t0 + (i * i - i) / (2 * (sigma ** 2)) + log_e0 162 | log_s1 = log_t1 + (j * j - j) / (2 * (sigma ** 2)) + log_e1 163 | 164 | if coef > 0: 165 | log_a0 = _log_add(log_a0, log_s0) 166 | log_a1 = _log_add(log_a1, log_s1) 167 | else: 168 | log_a0 = _log_sub(log_a0, log_s0) 169 | log_a1 = _log_sub(log_a1, log_s1) 170 | 171 | i += 1 172 | if max(log_s0, log_s1) < -30: 173 | break 174 | 175 | return _log_add(log_a0, log_a1) 176 | 177 | 178 | def _compute_log_a(q: float, sigma: float, alpha: float) -> float: 179 | r"""Computes :math:`log(A_\alpha)` for any positive finite ``alpha``. 180 | 181 | Notes: 182 | Note that 183 | :math:`A_\alpha` is real valued function of ``alpha`` and ``q``, 184 | and that 0 < ``q`` < 1. 185 | 186 | Refer to Section 3.3 of https://arxiv.org/pdf/1908.10530.pdf 187 | for details. 188 | 189 | Args: 190 | q: Sampling rate of SGM. 191 | sigma: The standard deviation of the additive Gaussian noise. 192 | alpha: The order at which RDP is computed. 193 | 194 | Returns: 195 | :math:`log(A_\alpha)` as defined in the paper mentioned above. 196 | """ 197 | if float(alpha).is_integer(): 198 | return _compute_log_a_for_int_alpha(q, sigma, int(alpha)) 199 | else: 200 | return _compute_log_a_for_frac_alpha(q, sigma, alpha) 201 | 202 | 203 | def _log_erfc(x: float) -> float: 204 | r"""Computes :math:`log(erfc(x))` with high accuracy for large ``x``. 205 | 206 | Helper function used in computation of :math:`log(A_\alpha)` 207 | for a fractional alpha. 208 | 209 | Args: 210 | x: The input to the function 211 | 212 | Returns: 213 | :math:`log(erfc(x))` 214 | """ 215 | return math.log(2) + special.log_ndtr(-x * 2 ** 0.5) 216 | 217 | 218 | def _compute_rdp(q: float, sigma: float, alpha: float) -> float: 219 | r"""Computes RDP of the Sampled Gaussian Mechanism at order ``alpha``. 220 | 221 | Args: 222 | q: Sampling rate of SGM. 223 | sigma: The standard deviation of the additive Gaussian noise. 224 | alpha: The order at which RDP is computed. 225 | 226 | Returns: 227 | RDP at order ``alpha``; can be np.inf. 228 | """ 229 | if q == 0: 230 | return 0 231 | 232 | # no privacy 233 | if sigma == 0: 234 | return np.inf 235 | 236 | if q == 1.0: 237 | return alpha / (2 * sigma ** 2) 238 | 239 | if np.isinf(alpha): 240 | return np.inf 241 | 242 | return _compute_log_a(q, sigma, alpha) / (alpha - 1) 243 | 244 | 245 | def compute_rdp( 246 | q: float, noise_multiplier: float, steps: int, orders: Union[List[float], float] 247 | ) -> Union[List[float], float]: 248 | r"""Computes Renyi Differential Privacy (RDP) guarantees of the 249 | Sampled Gaussian Mechanism (SGM) iterated ``steps`` times. 250 | 251 | Args: 252 | q: Sampling rate of SGM. 
253 | noise_multiplier: The ratio of the standard deviation of the 254 | additive Gaussian noise to the L2-sensitivity of the function 255 | to which it is added. Note that this is same as the standard 256 | deviation of the additive Gaussian noise when the L2-sensitivity 257 | of the function is 1. 258 | steps: The number of iterations of the mechanism. 259 | orders: An array (or a scalar) of RDP orders. 260 | 261 | Returns: 262 | The RDP guarantees at all orders; can be ``np.inf``. 263 | """ 264 | if isinstance(orders, float): 265 | rdp = _compute_rdp(q, noise_multiplier, orders) 266 | else: 267 | rdp = np.array([_compute_rdp(q, noise_multiplier, order) for order in orders]) 268 | 269 | return rdp * steps 270 | 271 | 272 | def get_privacy_spent( 273 | orders: Union[List[float], float], rdp: Union[List[float], float], delta: float 274 | ) -> Tuple[float, float]: 275 | r"""Computes epsilon given a list of Renyi Differential Privacy (RDP) values at 276 | multiple RDP orders and target ``delta``. 277 | 278 | Args: 279 | orders: An array (or a scalar) of orders (alphas). 280 | rdp: A list (or a scalar) of RDP guarantees. 281 | delta: The target delta. 282 | 283 | Returns: 284 | Pair of epsilon and optimal order alpha. 285 | 286 | Raises: 287 | ValueError 288 | If the lengths of ``orders`` and ``rdp`` are not equal. 289 | """ 290 | orders_vec = np.atleast_1d(orders) 291 | rdp_vec = np.atleast_1d(rdp) 292 | 293 | if len(orders_vec) != len(rdp_vec): 294 | raise ValueError( 295 | f"Input lists must have the same length.\n" 296 | f"\torders_vec = {orders_vec}\n" 297 | f"\trdp_vec = {rdp_vec}\n" 298 | ) 299 | 300 | eps = rdp_vec - math.log(delta) / (orders_vec - 1) 301 | 302 | # special case when there is no privacy 303 | if np.isnan(eps).all(): 304 | return np.inf, np.nan 305 | 306 | idx_opt = np.nanargmin(eps) # Ignore NaNs 307 | return eps[idx_opt], orders_vec[idx_opt] 308 | -------------------------------------------------------------------------------- /private_CNN/privacy_utils/supported_layers_grad_samplers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module is a collection of grad samplers - methods to calculate per sample gradients 3 | for a layer given two tensors: 1) inputs, 2) grad_outputs. 4 | 5 | Supports ghost clipping introduced in 6 | Li, X., Tramèr, F., Liang, P., & Hashimoto, T. (2021). 7 | Large Language Models Can Be Strong Differentially Private Learners. arXiv preprint arXiv:2110.05679. 8 | 9 | A large portion of this code is adapted from Opacus (https://github.com/pytorch/opacus), 10 | which is licensed under Apache License 2.0. 11 | """ 12 | 13 | import torch 14 | from torch import nn 15 | from torch.functional import F 16 | from functools import partial 17 | 18 | from . 
import autograd_grad_sample 19 | 20 | 21 | def sum_over_all_but_batch_and_last_n(tensor: torch.Tensor, n_dims: int) -> torch.Tensor: 22 | if tensor.dim() == n_dims + 1: 23 | return tensor 24 | else: 25 | dims = list(range(1, tensor.dim() - n_dims)) 26 | return tensor.sum(dim=dims) 27 | 28 | 29 | def _light_linear_weight_norm_sample(A, B) -> torch.Tensor: 30 | """Compute gradient sample norm for the weight matrix in a linear layer.""" 31 | if A.dim() == 2: 32 | return _light_linear_weight_norm_sample_non_sequential(A, B) 33 | elif A.dim() == 3: 34 | return _light_linear_weight_norm_sample_sequential(A, B) 35 | else: 36 | raise ValueError( 37 | f"Unexpected input shape: {A.size()}, grad_output shape: {B.size()}") 38 | 39 | 40 | def _light_linear_weight_norm_sample_sequential(A, B): 41 | """Lightweight norm computation in ghost clipping.""" 42 | return torch.sqrt( 43 | (torch.bmm(A, A.transpose(-1, -2)) * 44 | torch.bmm(B, B.transpose(-1, -2))).sum(dim=(1, 2)) 45 | ) 46 | 47 | 48 | def _light_linear_weight_norm_sample_non_sequential(A, B): 49 | """The Goodfellow trick, i.e., Frobenius norm equal to product of 2-norms.""" 50 | return A.norm(2, dim=1) * B.norm(2, dim=1) 51 | 52 | 53 | def _light_linear_bias_norm_sample(B): 54 | if B.dim() == 2: 55 | return B.norm(2, dim=1) 56 | elif B.dim() == 3: 57 | return B.sum(dim=1).norm(2, dim=1) 58 | else: 59 | raise ValueError(f"Unexpected grad_output shape: {B.size()}") 60 | 61 | 62 | def _create_or_extend_grad_sample(param: torch.Tensor, grad_sample: torch.Tensor, batch_dim: int) -> None: 63 | """Creates a ``grad_sample`` attribute in the given parameter or accumulate the existing tensor.""" 64 | if hasattr(param, "requires_grad") and not param.requires_grad: 65 | return 66 | 67 | assert grad_sample.shape[1: 68 | ] == param.shape, f"grad_sample.size()={grad_sample.size()}, param.size()={param.size()}" 69 | 70 | # Warning: When a parameter with `grad_sample` is reused, the per-sample gradients are accumulated. 71 | if hasattr(param, "grad_sample"): 72 | param.grad_sample += grad_sample.detach() 73 | else: 74 | param.grad_sample = grad_sample.detach() 75 | 76 | 77 | def _create_or_extend_norm_sample(param: torch.Tensor, norm_sample: torch.Tensor) -> None: 78 | """Creates a ``norm_sample`` attribute in the given parameter.""" 79 | if not hasattr(param, "requires_grad") or not param.requires_grad: 80 | return 81 | 82 | if "ghost_norm" in autograd_grad_sample.get_hooks_mode(): 83 | if hasattr(param, 'norm_sample'): 84 | raise ValueError("Ghost clipping does not support parameter sharing. " 85 | "Parameter sharing may be due to default parameter sharing between lm_head and embedding." 86 | "Please use a model without parameter sharing for ghost clipping.") 87 | param.norm_sample = norm_sample 88 | else: # mode == "grad"; should not get here. 89 | raise ValueError( 90 | "Internal error: Trying to extend `norm_sample` when `_hooks_mode='ghost_grad'`.") 91 | 92 | 93 | def _compute_linear_grad_sample(layer: nn.Linear, A: torch.Tensor, B: torch.Tensor, batch_dim: int = 0) -> None: 94 | """Computes per sample gradients for `nn.Linear` layer. 95 | 96 | This function is written in an unusually bespoke way to avoid using `torch.einsum`. 
97 | """ 98 | if autograd_grad_sample.fp16(): 99 | if B.dtype != torch.half: 100 | B = B.half() 101 | if A.dtype != torch.half: 102 | A = A.half() 103 | 104 | hooks_mode = autograd_grad_sample.get_hooks_mode() 105 | if "ghost_norm" in hooks_mode: 106 | if "flex" in hooks_mode: 107 | if hasattr(layer, "use_gc"): 108 | use_gc = layer.use_gc 109 | else: 110 | L = torch.prod(torch.Tensor(list(A.shape[1:-1]))) 111 | assert L == torch.prod(torch.Tensor(list(B.shape[1:-1]))) 112 | 113 | d = A.shape[-1] 114 | p = B.shape[-1] # torch.prod(torch.Tensor(list(B.shape[2:]))) 115 | b = B.shape[0] 116 | if "mem" in hooks_mode: 117 | use_gc = (2*L**2 <= d*p) 118 | layer.use_gc = use_gc 119 | elif "time" in hooks_mode: 120 | use_gc = (2*b*L**2*(d+p+1) - b <= 2*b*L*p*d +(4*b-1)*p*d) 121 | layer.use_gc = use_gc 122 | else: 123 | use_gc = True 124 | if use_gc: 125 | _create_or_extend_norm_sample( 126 | layer.weight, _light_linear_weight_norm_sample(A, B)) 127 | else: 128 | if A.dim()==2: 129 | grads = torch.einsum('bd, bp-> bdp', A, B) 130 | else: 131 | A=torch.flatten(A,start_dim=1,end_dim=-2) 132 | B=torch.flatten(B,start_dim=1,end_dim=-2) 133 | 134 | grads = torch.einsum('bTd, bTp-> bdp', A, B) 135 | gnorm = torch.sqrt(torch.sum(grads**2, dim=(1, 2))) 136 | _create_or_extend_norm_sample(layer.weight, gnorm) 137 | 138 | if layer.bias is not None: 139 | _create_or_extend_norm_sample( 140 | layer.bias, _light_linear_bias_norm_sample(B)) 141 | 142 | else: 143 | if B.dim() >= 3 and A.dim() >= 3: 144 | A=torch.flatten(A,start_dim=1,end_dim=-2) 145 | B=torch.flatten(B,start_dim=1,end_dim=-2) 146 | 147 | grad_weight = torch.bmm(B.permute(0, 2, 1), A) 148 | grad_bias = B.sum(dim=1) 149 | elif B.dim() == 2 and A.dim() == 2: 150 | grad_weight = B[:, :, None] * A[:, None, :] 151 | grad_bias = B 152 | else: 153 | raise ValueError( 154 | f"Expected both grad_output and input to have dimension 2 or 3, " 155 | f"but found len(grad_output.dim())={len(B.dim())}, len(input.dim())={len(A.dim())}" 156 | ) 157 | _create_or_extend_grad_sample(layer.weight, grad_weight, batch_dim) 158 | 159 | if layer.bias is not None: 160 | _create_or_extend_grad_sample(layer.bias, grad_bias, batch_dim) 161 | 162 | 163 | def _compute_norm_grad_sample( 164 | layer: nn.LayerNorm, 165 | A: torch.Tensor, 166 | B: torch.Tensor, 167 | batch_dim: int = 0, 168 | ) -> None: 169 | """Computes per sample gradients for normalization layers.""" 170 | if autograd_grad_sample.fp16(): 171 | if A.dtype != torch.half: 172 | A = A.half() 173 | if B.dtype != torch.half: 174 | B = B.half() 175 | 176 | is_backward_ghost_norm = "ghost_norm" in autograd_grad_sample.get_hooks_mode() 177 | 178 | grad_sample = sum_over_all_but_batch_and_last_n( 179 | F.layer_norm(A, layer.normalized_shape, eps=layer.eps) * B, 180 | layer.weight.dim(), 181 | ) 182 | if is_backward_ghost_norm: 183 | norm_sample = grad_sample.flatten(start_dim=1).norm(2, dim=1) 184 | _create_or_extend_norm_sample(layer.weight, norm_sample) 185 | else: 186 | _create_or_extend_grad_sample(layer.weight, grad_sample, batch_dim) 187 | 188 | grad_sample = sum_over_all_but_batch_and_last_n(B, layer.bias.dim()) 189 | if is_backward_ghost_norm: 190 | norm_sample = grad_sample.flatten(start_dim=1).norm(2, dim=1) 191 | _create_or_extend_norm_sample(layer.bias, norm_sample) 192 | else: 193 | _create_or_extend_grad_sample(layer.bias, grad_sample, batch_dim) 194 | 195 | 196 | def _compute_group_norm_grad_sample( 197 | layer: nn.GroupNorm, 198 | A: torch.Tensor, 199 | B: torch.Tensor, 200 | batch_dim: int = 0, 201 | ) -> 
None: 202 | """Computes per sample gradients for normalization layers.""" 203 | if autograd_grad_sample.fp16(): 204 | if A.dtype != torch.half: 205 | A = A.half() 206 | if B.dtype != torch.half: 207 | B = B.half() 208 | 209 | is_backward_ghost_norm = "ghost_norm" in autograd_grad_sample.get_hooks_mode() 210 | 211 | grad_sample = torch.einsum( 212 | "ni...->ni", F.group_norm(A, layer.num_groups, eps=layer.eps) * B) 213 | bias_grad_sample = torch.einsum("ni...->ni", B) 214 | if is_backward_ghost_norm: 215 | _create_or_extend_norm_sample(layer.weight, grad_sample.norm(2, dim=1)) 216 | if layer.bias is not None: 217 | _create_or_extend_norm_sample( 218 | layer.bias, bias_grad_sample.norm(2, dim=1)) 219 | else: 220 | _create_or_extend_grad_sample(layer.weight, grad_sample, batch_dim) 221 | if layer.bias is not None: 222 | _create_or_extend_grad_sample( 223 | layer.bias, bias_grad_sample, batch_dim) 224 | 225 | 226 | def _compute_embedding_grad_sample(layer: nn.Embedding, A: torch.Tensor, B: torch.Tensor, batch_dim: int = 0) -> None: 227 | """Computes per sample gradients for `nn.Embedding` layer.""" 228 | 229 | if autograd_grad_sample.fp16(): 230 | if B.dtype != torch.half: 231 | B = B.half() 232 | 233 | if "ghost_norm" in autograd_grad_sample.get_hooks_mode(): 234 | not_AAt: torch.Tensor = ~A[:, :, None].eq(A[:, None, :]) 235 | # Clear the contribution to the norm of the gradient for the padding token. 236 | # In vanilla backpropagation, this particular embedding doesn't contribute to the gradient anyway. 237 | # For more see 1.10.0 doc: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html 238 | # 'the embedding vector at padding_idx is not updated during training, i.e. it remains as a fixed “pad”.' 239 | padding_idx = layer.padding_idx 240 | if padding_idx is not None: 241 | # The right way to think about the next line of code is that A_i[t, padding_idx] = 0 for all t in [T]. 242 | # So the entry gets cleared whenever one of A, A^t takes the padding idx. 243 | not_AAt.bitwise_or_((A[:, :, None] == padding_idx) | ( 244 | A[:, None, :] == padding_idx)) 245 | norm_sample = torch.sqrt( 246 | (torch.bmm(B, B.transpose(-1, -2)).masked_fill(not_AAt, 0)).sum(dim=(1, 2))) 247 | _create_or_extend_norm_sample(layer.weight, norm_sample) 248 | else: 249 | A_dense = F.one_hot(A, num_classes=layer.weight.shape[0]).to( 250 | B) # (batch_size, seq_len, vocab_dim,) 251 | grad_sample = torch.bmm(A_dense.permute(0, 2, 1), B) 252 | # `torch.nn.Embedding` layers don't accumulate gradient on the padding_idx position. 253 | # We do the same for `grad_sample`. 254 | if layer.padding_idx is not None: 255 | # `grad_sample` has size (batch_size, num_vocab, embedding_dim). 256 | grad_sample[:, layer.padding_idx, :] = 0. 
257 | _create_or_extend_grad_sample(layer.weight, grad_sample, batch_dim) 258 | 259 | 260 | def _compute_conv_grad_sample(layer, A: torch.Tensor, B: torch.Tensor, batch_dim: int = 0, convd: int = 1): 261 | 262 | if autograd_grad_sample.fp16(): 263 | if B.dtype != torch.half: 264 | B = B.half() 265 | if A.dtype != torch.half: 266 | A = A.half() 267 | 268 | g_ = B.flatten(2) 269 | if convd == 1: 270 | padding = layer.padding if isinstance( 271 | layer.padding, tuple) else (*layer.padding, *layer.padding) 272 | # padded_A = F.pad(A, padding) 273 | unfold_x = F.unfold(A.unsqueeze(-2), kernel_size=(1, *layer.kernel_size), 274 | padding=(0, *padding), 275 | dilation=(1, *layer.dilation), 276 | stride=(1, *layer.stride)) 277 | elif convd == 2: 278 | unfold_x = F.unfold(A, kernel_size=layer.kernel_size, 279 | dilation=layer.dilation, padding=layer.padding, 280 | stride=layer.stride) 281 | elif convd == 3: 282 | from opacus.utils import tensor_utils 283 | unfold_x = tensor_utils.unfold3d(A, kernel_size=layer.kernel_size, 284 | dilation=layer.dilation, padding=layer.padding, 285 | stride=layer.stride) 286 | hooks_mode = autograd_grad_sample.get_hooks_mode() 287 | if "ghost_norm" in hooks_mode: 288 | if "mixed" in hooks_mode: 289 | if hasattr(layer, "use_gc"): 290 | use_gc = layer.use_gc 291 | else: 292 | L = unfold_x.size(-1) 293 | assert L == g_.shape[-1] 294 | d = unfold_x.shape[1] 295 | p = g_.shape[1] 296 | use_gc = (2*L**2 <= d*p) 297 | layer.use_gc = use_gc 298 | else: 299 | use_gc = True 300 | if use_gc: 301 | a = torch.einsum('bji, bjk -> bik', unfold_x, unfold_x) 302 | g = torch.einsum('bji, bjk -> bik', g_, g_) 303 | gnorm = torch.sqrt(torch.einsum('bij, bij -> b', a, g)) 304 | else: 305 | grads = torch.einsum( 306 | 'bdT, bpT-> bdp', unfold_x, g_) 307 | gnorm = torch.sqrt(torch.sum(grads**2, dim=(1, 2))) 308 | _create_or_extend_norm_sample(layer.weight, gnorm) 309 | 310 | if layer.bias is not None: 311 | _create_or_extend_norm_sample( 312 | layer.bias, g_.sum(dim=2).norm(2, dim=1)) 313 | else: 314 | _create_or_extend_grad_sample(layer.weight, torch.bmm( 315 | unfold_x, g_.permute(0, 2, 1)).view(-1, *layer.weight.shape), batch_dim) 316 | 317 | if layer.bias is not None: 318 | _create_or_extend_grad_sample(layer.bias, g_.sum(dim=2), batch_dim) 319 | 320 | 321 | 322 | _supported_layers_grad_samplers = { 323 | "Embedding": _compute_embedding_grad_sample, 324 | "Linear": _compute_linear_grad_sample, 325 | "LayerNorm": _compute_norm_grad_sample, 326 | "GroupNorm": _compute_group_norm_grad_sample, 327 | # HuggingFace Open-AI GPT-2. 328 | "Conv1d": partial(_compute_conv_grad_sample, convd=1), 329 | "Conv2d": partial(_compute_conv_grad_sample, convd=2), 330 | "Conv3d": partial(_compute_conv_grad_sample, convd=3), 331 | } 332 | -------------------------------------------------------------------------------- /private_CNN/privacy_utils/transformers_support.py: -------------------------------------------------------------------------------- 1 | """Utilities to make using PrivacyEngine easy with Hugging Face transformers.""" 2 | import types 3 | from typing import Union 4 | 5 | import torch 6 | import transformers 7 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions 8 | from transformers.utils import logging 9 | 10 | logger = logging.get_logger(__name__) 11 | 12 | 13 | def forward_swapper(module): 14 | """Fix incompatibility between Opacus and Hugging Face. 15 | 16 | Root cause is positional embedding without broadcasting. 
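A minimal illustrative use (the model name here is only an example, not taken from this file):

>>> import transformers
>>> model = transformers.GPT2LMHeadModel.from_pretrained("gpt2")
>>> forward_swapper(model)  # patches the position-embedding forward so per-sample gradients are computed correctly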
17 | """ 18 | if isinstance(module, (transformers.OpenAIGPTLMHeadModel, transformers.OpenAIGPTDoubleHeadsModel)): 19 | swap_openai_gpt_model_forward(module.transformer) 20 | if isinstance(module, (transformers.GPT2LMHeadModel, transformers.GPT2DoubleHeadsModel)): 21 | swap_gpt2_model_forward(module.transformer) 22 | elif hasattr(module, 'roberta'): 23 | swap_roberta_model_forward(module.roberta) 24 | elif hasattr(module, 'bert'): 25 | swap_bert_model_forward(module.bert) 26 | elif hasattr(module, 'albert'): 27 | swap_albert_model_forward(module.albert) 28 | 29 | 30 | def swap_openai_gpt_model_forward(model: transformers.OpenAIGPTModel): 31 | def new_forward( 32 | self, 33 | input_ids=None, 34 | attention_mask=None, 35 | token_type_ids=None, 36 | position_ids=None, 37 | head_mask=None, 38 | inputs_embeds=None, 39 | output_attentions=None, 40 | output_hidden_states=None, 41 | return_dict=None, 42 | ): 43 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 44 | output_hidden_states = ( 45 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 46 | ) 47 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 48 | 49 | if input_ids is not None and inputs_embeds is not None: 50 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 51 | elif input_ids is not None: 52 | input_shape = input_ids.size() 53 | input_ids = input_ids.view(-1, input_shape[-1]) 54 | elif inputs_embeds is not None: 55 | input_shape = inputs_embeds.size()[:-1] 56 | else: 57 | raise ValueError("You have to specify either input_ids or inputs_embeds") 58 | 59 | if position_ids is None: 60 | # Code is different from when we had a single embedding matrix from position and token embeddings 61 | position_ids = self.position_ids[None, : input_shape[-1]] 62 | # --- lxuechen: Duplicate to make privacy work! --- 63 | position_ids = position_ids.repeat(input_ids.size(0), 1) 64 | # --- 65 | 66 | # Attention mask. 67 | if attention_mask is not None: 68 | # We create a 3D attention mask from a 2D tensor mask. 69 | # Sizes are [batch_size, 1, 1, to_seq_length] 70 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 71 | # this attention mask is more simple than the triangular masking of causal attention 72 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 73 | attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 74 | 75 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 76 | # masked positions, this operation will create a tensor which is 0.0 for 77 | # positions we want to attend and -10000.0 for masked positions. 78 | # Since we are adding it to the raw scores before the softmax, this is 79 | # effectively the same as removing these entirely. 
80 | attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 81 | attention_mask = (1.0 - attention_mask) * -10000.0 82 | 83 | # Prepare head mask if needed 84 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) 85 | 86 | if inputs_embeds is None: 87 | inputs_embeds = self.tokens_embed(input_ids) 88 | position_embeds = self.positions_embed(position_ids) 89 | if token_type_ids is not None: 90 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) 91 | token_type_embeds = self.tokens_embed(token_type_ids) 92 | else: 93 | token_type_embeds = 0 94 | hidden_states = inputs_embeds + position_embeds + token_type_embeds 95 | hidden_states = self.drop(hidden_states) 96 | 97 | output_shape = input_shape + (hidden_states.size(-1),) 98 | 99 | all_attentions = () if output_attentions else None 100 | all_hidden_states = () if output_hidden_states else None 101 | for i, block in enumerate(self.h): 102 | if output_hidden_states: 103 | all_hidden_states = all_hidden_states + (hidden_states,) 104 | 105 | outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions=output_attentions) 106 | hidden_states = outputs[0] 107 | if output_attentions: 108 | all_attentions = all_attentions + (outputs[1],) 109 | 110 | hidden_states = hidden_states.view(*output_shape) 111 | # Add last layer 112 | if output_hidden_states: 113 | all_hidden_states = all_hidden_states + (hidden_states,) 114 | 115 | if not return_dict: 116 | return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) 117 | 118 | return BaseModelOutput( 119 | last_hidden_state=hidden_states, 120 | hidden_states=all_hidden_states, 121 | attentions=all_attentions, 122 | ) 123 | 124 | model.forward = types.MethodType(new_forward, model) 125 | 126 | 127 | def swap_gpt2_model_forward(model: Union[transformers.GPT2Model, transformers.GPT2DoubleHeadsModel]): 128 | """Modify the forward function for `GPT2Model` so that per-sample gradients are correct. 129 | 130 | Main issue is that positional embedding's input should be duplicated. 
131 | """ 132 | 133 | def new_forward( 134 | self, 135 | input_ids=None, 136 | past_key_values=None, 137 | attention_mask=None, 138 | token_type_ids=None, 139 | position_ids=None, 140 | head_mask=None, 141 | inputs_embeds=None, 142 | encoder_hidden_states=None, 143 | encoder_attention_mask=None, 144 | use_cache=None, 145 | output_attentions=None, 146 | output_hidden_states=None, 147 | return_dict=None, 148 | ): 149 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 150 | output_hidden_states = ( 151 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 152 | ) 153 | use_cache = use_cache if use_cache is not None else self.config.use_cache 154 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 155 | 156 | if input_ids is not None and inputs_embeds is not None: 157 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 158 | elif input_ids is not None: 159 | input_shape = input_ids.size() 160 | input_ids = input_ids.view(-1, input_shape[-1]) 161 | batch_size = input_ids.shape[0] 162 | elif inputs_embeds is not None: 163 | input_shape = inputs_embeds.size()[:-1] 164 | batch_size = inputs_embeds.shape[0] 165 | else: 166 | raise ValueError("You have to specify either input_ids or inputs_embeds") 167 | 168 | device = input_ids.device if input_ids is not None else inputs_embeds.device 169 | 170 | if token_type_ids is not None: 171 | token_type_ids = token_type_ids.view(-1, input_shape[-1]) 172 | if position_ids is not None: 173 | position_ids = position_ids.view(-1, input_shape[-1]) 174 | 175 | if past_key_values is None: 176 | past_length = 0 177 | past_key_values = tuple([None] * len(self.h)) 178 | else: 179 | past_length = past_key_values[0][0].size(-2) 180 | if position_ids is None: 181 | position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) 182 | position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) 183 | # --- lxuechen: Duplicate to make privacy work! --- 184 | position_ids = position_ids.repeat(batch_size, 1) 185 | # --- 186 | 187 | # GPT2Attention mask. 188 | if attention_mask is not None: 189 | assert batch_size > 0, "batch_size has to be defined and > 0" 190 | attention_mask = attention_mask.view(batch_size, -1) 191 | # We create a 3D attention mask from a 2D tensor mask. 192 | # Sizes are [batch_size, 1, 1, to_seq_length] 193 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 194 | # this attention mask is more simple than the triangular masking of causal attention 195 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 196 | attention_mask = attention_mask[:, None, None, :] 197 | 198 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 199 | # masked positions, this operation will create a tensor which is 0.0 for 200 | # positions we want to attend and -10000.0 for masked positions. 201 | # Since we are adding it to the raw scores before the softmax, this is 202 | # effectively the same as removing these entirely. 
203 | attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility 204 | attention_mask = (1.0 - attention_mask) * -10000.0 205 | 206 | # If a 2D ou 3D attention mask is provided for the cross-attention 207 | # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] 208 | if self.config.add_cross_attention and encoder_hidden_states is not None: 209 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 210 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 211 | if encoder_attention_mask is None: 212 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 213 | encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) 214 | else: 215 | encoder_attention_mask = None 216 | 217 | # Prepare head mask if needed 218 | # 1.0 in head_mask indicate we keep the head 219 | # attention_probs has shape bsz x n_heads x N x N 220 | # head_mask has shape n_layer x batch x n_heads x N x N 221 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) 222 | 223 | if inputs_embeds is None: 224 | inputs_embeds = self.wte(input_ids) 225 | position_embeds = self.wpe(position_ids) 226 | hidden_states = inputs_embeds + position_embeds 227 | 228 | if token_type_ids is not None: 229 | token_type_embeds = self.wte(token_type_ids) 230 | hidden_states = hidden_states + token_type_embeds 231 | 232 | hidden_states = self.drop(hidden_states) 233 | 234 | output_shape = input_shape + (hidden_states.size(-1),) 235 | 236 | presents = () if use_cache else None 237 | all_self_attentions = () if output_attentions else None 238 | all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None 239 | all_hidden_states = () if output_hidden_states else None 240 | for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): 241 | 242 | # Model parallel 243 | if self.model_parallel: 244 | torch.cuda.set_device(hidden_states.device) 245 | # Ensure layer_past is on same device as hidden_states (might not be correct) 246 | if layer_past is not None: 247 | layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) 248 | # Ensure that attention_mask is always on the same device as hidden_states 249 | if attention_mask is not None: 250 | attention_mask = attention_mask.to(hidden_states.device) 251 | if isinstance(head_mask, torch.Tensor): 252 | head_mask = head_mask.to(hidden_states.device) 253 | if output_hidden_states: 254 | all_hidden_states = all_hidden_states + (hidden_states,) 255 | 256 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 257 | 258 | if use_cache: 259 | logger.warning( 260 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " 261 | "`use_cache=False`..." 
262 | ) 263 | use_cache = False 264 | 265 | def create_custom_forward(module): 266 | def custom_forward(*inputs): 267 | # None for past_key_value 268 | return module(*inputs, use_cache, output_attentions) 269 | 270 | return custom_forward 271 | 272 | outputs = torch.utils.checkpoint.checkpoint( 273 | create_custom_forward(block), 274 | hidden_states, 275 | None, 276 | attention_mask, 277 | head_mask[i], 278 | encoder_hidden_states, 279 | encoder_attention_mask, 280 | ) 281 | else: 282 | outputs = block( 283 | hidden_states, 284 | layer_past=layer_past, 285 | attention_mask=attention_mask, 286 | head_mask=head_mask[i], 287 | encoder_hidden_states=encoder_hidden_states, 288 | encoder_attention_mask=encoder_attention_mask, 289 | use_cache=use_cache, 290 | output_attentions=output_attentions, 291 | ) 292 | 293 | hidden_states = outputs[0] 294 | if use_cache is True: 295 | presents = presents + (outputs[1],) 296 | 297 | if output_attentions: 298 | all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) 299 | if self.config.add_cross_attention: 300 | all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) 301 | 302 | # Model Parallel: If it's the last layer for that device, put things on the next device 303 | if self.model_parallel: 304 | for k, v in self.device_map.items(): 305 | if i == v[-1] and "cuda:" + str(k) != self.last_device: 306 | hidden_states = hidden_states.to("cuda:" + str(k + 1)) 307 | 308 | hidden_states = self.ln_f(hidden_states) 309 | 310 | hidden_states = hidden_states.view(*output_shape) 311 | # Add last hidden state 312 | if output_hidden_states: 313 | all_hidden_states = all_hidden_states + (hidden_states,) 314 | 315 | if not return_dict: 316 | return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) 317 | 318 | return BaseModelOutputWithPastAndCrossAttentions( 319 | last_hidden_state=hidden_states, 320 | past_key_values=presents, 321 | hidden_states=all_hidden_states, 322 | attentions=all_self_attentions, 323 | cross_attentions=all_cross_attentions, 324 | ) 325 | 326 | model.forward = types.MethodType(new_forward, model) 327 | 328 | 329 | def swap_roberta_model_forward(model: transformers.RobertaModel): 330 | # Doing nothing is good for Roberta. 
331 | pass 332 | 333 | 334 | def swap_bert_model_forward(model: transformers.BertModel): 335 | def new_forward( 336 | self, 337 | input_ids=None, 338 | token_type_ids=None, 339 | position_ids=None, 340 | inputs_embeds=None, 341 | past_key_values_length=0 342 | ): 343 | if input_ids is not None: 344 | input_shape = input_ids.size() 345 | else: 346 | input_shape = inputs_embeds.size()[:-1] 347 | 348 | seq_length = input_shape[1] 349 | 350 | if position_ids is None: 351 | position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length] 352 | 353 | # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs 354 | # when its auto-generated, registered buffer helps users when tracing the model without passing 355 | # token_type_ids, solves issue #5664 356 | if token_type_ids is None: 357 | if hasattr(self, "token_type_ids"): 358 | buffered_token_type_ids = self.token_type_ids[:, :seq_length] 359 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) 360 | token_type_ids = buffered_token_type_ids_expanded 361 | else: 362 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) 363 | 364 | if inputs_embeds is None: 365 | inputs_embeds = self.word_embeddings(input_ids) 366 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 367 | 368 | embeddings = inputs_embeds + token_type_embeddings 369 | if self.position_embedding_type == "absolute": 370 | # --- lxuechen: Duplicate to make privacy work! --- 371 | batch_size = input_ids.size(0) 372 | position_ids = position_ids.repeat(batch_size, 1) 373 | position_embeddings = self.position_embeddings(position_ids) 374 | # --- 375 | embeddings += position_embeddings 376 | embeddings = self.LayerNorm(embeddings) 377 | embeddings = self.dropout(embeddings) 378 | return embeddings 379 | 380 | model.embeddings.forward = types.MethodType(new_forward, model.embeddings) 381 | 382 | 383 | def swap_albert_model_forward(model: transformers.AlbertModel): 384 | """So far a duplicate of `swap_bert_model_forward`.""" 385 | 386 | def new_forward( 387 | self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 388 | ): 389 | if input_ids is not None: 390 | input_shape = input_ids.size() 391 | else: 392 | input_shape = inputs_embeds.size()[:-1] 393 | 394 | seq_length = input_shape[1] 395 | 396 | if position_ids is None: 397 | position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length] 398 | 399 | # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs 400 | # when its auto-generated, registered buffer helps users when tracing the model without passing 401 | # token_type_ids, solves 402 | # issue #5664 403 | if token_type_ids is None: 404 | if hasattr(self, "token_type_ids"): 405 | buffered_token_type_ids = self.token_type_ids[:, :seq_length] 406 | buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) 407 | token_type_ids = buffered_token_type_ids_expanded 408 | else: 409 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) 410 | 411 | if inputs_embeds is None: 412 | inputs_embeds = self.word_embeddings(input_ids) 413 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 414 | 415 | embeddings = inputs_embeds + token_type_embeddings 416 | if 
self.position_embedding_type == "absolute": 417 | # --- lxuechen: Duplicate to make privacy work! 418 | batch_size = input_ids.size(0) 419 | position_ids = position_ids.repeat(batch_size, 1) 420 | position_embeddings = self.position_embeddings(position_ids) 421 | # --- 422 | embeddings += position_embeddings 423 | embeddings = self.LayerNorm(embeddings) 424 | embeddings = self.dropout(embeddings) 425 | return embeddings 426 | 427 | model.embeddings.forward = types.MethodType(new_forward, model.embeddings) 428 | -------------------------------------------------------------------------------- /cifar10.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "sZdt62UDG--F" 7 | }, 8 | "source": [ 9 | "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JialinMao/private_CNN/blob/master/cifar10.ipynb)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": { 16 | "colab": { 17 | "base_uri": "https://localhost:8080/" 18 | }, 19 | "id": "TseeB3EWFoTn", 20 | "outputId": "342683bf-8280-4450-f469-e05930bb185a" 21 | }, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "Requirement already satisfied: opacus==1.1.1 in /usr/local/lib/python3.7/dist-packages (1.1.1)\n", 28 | "Requirement already satisfied: numpy>=1.15 in /usr/local/lib/python3.7/dist-packages (from opacus==1.1.1) (1.21.6)\n", 29 | "Requirement already satisfied: scipy>=1.2 in /usr/local/lib/python3.7/dist-packages (from opacus==1.1.1) (1.4.1)\n", 30 | "Requirement already satisfied: torch>=1.8 in /usr/local/lib/python3.7/dist-packages (from opacus==1.1.1) (1.11.0+cu113)\n", 31 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.8->opacus==1.1.1) (4.2.0)\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "!pip install opacus==1.1.1" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 3, 42 | "metadata": { 43 | "colab": { 44 | "base_uri": "https://localhost:8080/" 45 | }, 46 | "id": "vCOIWPsfL1Ik", 47 | "outputId": "f0685652-f915-4c1f-94ec-61addf194b8a" 48 | }, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "Cloning into 'private_CNN'...\n", 55 | "remote: Enumerating objects: 1625, done.\u001b[K\n", 56 | "remote: Counting objects: 100% (1625/1625), done.\u001b[K\n", 57 | "remote: Compressing objects: 100% (776/776), done.\u001b[K\n", 58 | "remote: Total 1625 (delta 834), reused 1622 (delta 831), pack-reused 0\u001b[K\n", 59 | "Receiving objects: 100% (1625/1625), 28.55 MiB | 22.42 MiB/s, done.\n", 60 | "Resolving deltas: 100% (834/834), done.\n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "!git clone https://github.com/JialinMao/private_CNN.git" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 6, 71 | "metadata": { 72 | "colab": { 73 | "base_uri": "https://localhost:8080/" 74 | }, 75 | "id": "f9HqTKTtMoLT", 76 | "outputId": "6ebaa336-25c0-496e-87ab-a16feb8c6f41" 77 | }, 78 | "outputs": [ 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "/content/private_CNN\n" 84 | ] 85 | } 86 | ], 87 | "source": [ 88 | "cd private_CNN/" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 24, 94 | "metadata": { 95 | "colab": { 96 | "base_uri": "https://localhost:8080/" 97 | }, 98 | "id": 
"5phptOs5QS5b", 99 | "outputId": "b0500ca0-96de-4bd2-b3ce-e16d71530b15" 100 | }, 101 | "outputs": [ 102 | { 103 | "name": "stdout", 104 | "output_type": "stream", 105 | "text": [ 106 | "fatal: destination path 'pytorch-cifar' already exists and is not an empty directory.\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "!git clone https://github.com/kuangliu/pytorch-cifar.git" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "mv pytorch-cifar/models .\n" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 11, 126 | "metadata": { 127 | "id": "IN_kmXnUMlvD" 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "import sys\n", 132 | "sys.path.insert(0, '/content/private_CNN')\n", 133 | "import models\n", 134 | "import private_CNN" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 10, 140 | "metadata": { 141 | "id": "Qp_fa2zsIsux" 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "from tqdm import tqdm\n", 146 | "import torch\n", 147 | "import torch.nn as nn\n", 148 | "import torch.optim as optim\n", 149 | "import torch.nn.functional as F\n", 150 | "import torch.backends.cudnn as cudnn\n", 151 | "\n", 152 | "import torchvision\n", 153 | "import torchvision.transforms as transforms\n", 154 | "\n", 155 | "from opacus.validators import ModuleValidator" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 12, 161 | "metadata": { 162 | "id": "uxXYVGqSJAwN" 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": { 172 | "id": "9QGvyCFLOWsA" 173 | }, 174 | "source": [ 175 | "## Arguments" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 14, 181 | "metadata": { 182 | "id": "Lp5f2FozOWbg" 183 | }, 184 | "outputs": [], 185 | "source": [ 186 | "lr = 0.1\n", 187 | "epochs = 20\n", 188 | "bs = 128\n", 189 | "eps = 5\n", 190 | "grad_norm = 0.1\n", 191 | "mode = 'ghost-mixed'\n", 192 | "model = 'ResNet18'" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": { 198 | "id": "VKedj5LzOG_O" 199 | }, 200 | "source": [ 201 | "## Loading Data" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 15, 207 | "metadata": { 208 | "colab": { 209 | "base_uri": "https://localhost:8080/", 210 | "height": 120, 211 | "referenced_widgets": [ 212 | "0c9298b2565b4b1bb1f36c31b415e6c8", 213 | "e3a397760fa340b1a053f1b0cbc7de2b", 214 | "f35283304fc049cead0fc2245506ec31", 215 | "f3b6dd11f43048febc0fc7d24b90ff77", 216 | "4bbf86bb356c4ebc90e47a9509c3865c", 217 | "e4c2027d969a41f19fc0b18fb265b2b3", 218 | "a519419d6c0f47c99e86084d9cb9c591", 219 | "b5c59de8176a4c57a8b54d021c626085", 220 | "28b8d652ca1a44d98b76969d1d1d06dc", 221 | "69aea5e905fa425594c1c2253c076bb9", 222 | "76fd1294f6964c6e855285a9db3f1c5b" 223 | ] 224 | }, 225 | "id": "efkzB6mJOCBJ", 226 | "outputId": "27e6e9b8-5418-42e6-d7f2-1ce435563140" 227 | }, 228 | "outputs": [ 229 | { 230 | "name": "stdout", 231 | "output_type": "stream", 232 | "text": [ 233 | "==> Preparing data..\n", 234 | "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../../data/cifar-10-python.tar.gz\n" 235 | ] 236 | }, 237 | { 238 | "data": { 239 | "application/vnd.jupyter.widget-view+json": { 240 | "model_id": "0c9298b2565b4b1bb1f36c31b415e6c8", 241 | "version_major": 2, 242 | "version_minor": 0 243 | }, 
244 | "text/plain": [ 245 | " 0%| | 0/170498071 [00:00 Preparing data..')\n", 262 | "\n", 263 | "transform_train = transforms.Compose([\n", 264 | " transforms.ToTensor(),\n", 265 | "])\n", 266 | "transform_test = transforms.Compose([\n", 267 | " transforms.ToTensor(),\n", 268 | "])\n", 269 | "\n", 270 | "trainset = torchvision.datasets.CIFAR10(\n", 271 | " root='../../data', train=True, download=True, transform=transform_train)\n", 272 | "trainloader = torch.utils.data.DataLoader(\n", 273 | " trainset, batch_size=bs, shuffle=True, num_workers=2, drop_last=True)\n", 274 | "\n", 275 | "testset = torchvision.datasets.CIFAR10(\n", 276 | " root='../../data', train=False, download=True, transform=transform_test)\n", 277 | "testloader = torch.utils.data.DataLoader(\n", 278 | " testset, batch_size=bs, shuffle=False, num_workers=2)\n" 279 | ] 280 | }, 281 | { 282 | "cell_type": "markdown", 283 | "metadata": { 284 | "id": "kPrQyDXsOJes" 285 | }, 286 | "source": [ 287 | "## Building Model" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 16, 293 | "metadata": { 294 | "colab": { 295 | "base_uri": "https://localhost:8080/" 296 | }, 297 | "id": "EflgegqmOLNP", 298 | "outputId": "2ad569e5-26c9-48dc-efc5-12018998c5d7" 299 | }, 300 | "outputs": [ 301 | { 302 | "name": "stdout", 303 | "output_type": "stream", 304 | "text": [ 305 | "==> Building model..\n", 306 | "number of parameters: 11173962\n" 307 | ] 308 | } 309 | ], 310 | "source": [ 311 | "print('==> Building model..')\n", 312 | "if \"VGG\" in model:\n", 313 | " net = models.VGG(model)\n", 314 | "else:\n", 315 | " net = getattr(models, model)()\n", 316 | "\n", 317 | "net = ModuleValidator.fix(net)\n", 318 | "if device == 'cuda':\n", 319 | " net = torch.nn.DataParallel(net)\n", 320 | " cudnn.benchmark = True\n", 321 | "\n", 322 | "print('number of parameters: ', sum([p.numel() for p in net.parameters()]))\n" 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": { 328 | "id": "Qs26fMHNOxrK" 329 | }, 330 | "source": [ 331 | "## Privacy Engine" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 17, 337 | "metadata": { 338 | "id": "zbXkTs7JOQpn" 339 | }, 340 | "outputs": [], 341 | "source": [ 342 | "if 'ghost' in mode:\n", 343 | " criterion = nn.CrossEntropyLoss(reduction=\"none\")\n", 344 | "else:\n", 345 | " criterion = nn.CrossEntropyLoss()\n", 346 | "optimizer = optim.SGD(net.parameters(), lr=lr,\n", 347 | " momentum=0.9, weight_decay=5e-4)\n", 348 | "\n", 349 | "if 'ghost' in mode:\n", 350 | " privacy_engine = private_CNN.PrivacyEngine(\n", 351 | " net,\n", 352 | " batch_size=bs,\n", 353 | " sample_size=len(trainloader.dataset),\n", 354 | " epochs=epochs,\n", 355 | " max_grad_norm=grad_norm,\n", 356 | " target_epsilon=eps,\n", 357 | " ghost_clipping=True,\n", 358 | " mixed='mixed' in mode,\n", 359 | " )\n", 360 | " privacy_engine.attach(optimizer)" 361 | ] 362 | }, 363 | { 364 | "cell_type": "markdown", 365 | "metadata": { 366 | "id": "CKqmwpYdPC8W" 367 | }, 368 | "source": [ 369 | "## Trainining and Testing" 370 | ] 371 | }, 372 | { 373 | "cell_type": "code", 374 | "execution_count": 18, 375 | "metadata": { 376 | "id": "cl7yg1EJPANl" 377 | }, 378 | "outputs": [], 379 | "source": [ 380 | "# Training\n", 381 | "def train(epoch):\n", 382 | " print('\\nEpoch: %d' % epoch)\n", 383 | " net.train()\n", 384 | " train_loss = 0\n", 385 | " correct = 0\n", 386 | " total = 0\n", 387 | "\n", 388 | " for batch_idx, (inputs, targets) in enumerate(tqdm(trainloader)):\n", 389 | " inputs, 
targets = inputs.to(device), targets.to(device)\n", 390 | " optimizer.zero_grad()\n", 391 | " outputs = net(inputs)\n", 392 | " loss = criterion(outputs, targets)\n", 393 | " if 'ghost' in mode:\n", 394 | " optimizer.step(loss=loss)\n", 395 | " loss = loss.mean()\n", 396 | " else:\n", 397 | " loss.backward()\n", 398 | " optimizer.step()\n", 399 | "\n", 400 | " train_loss += loss.mean().item()\n", 401 | " _, predicted = outputs.max(1)\n", 402 | " total += targets.size(0)\n", 403 | " correct += predicted.eq(targets).sum().item()\n", 404 | "\n", 405 | "\n", 406 | "def test(epoch):\n", 407 | " global best_acc\n", 408 | " net.eval()\n", 409 | " test_loss = 0\n", 410 | " correct = 0\n", 411 | " total = 0\n", 412 | " with torch.no_grad():\n", 413 | " for batch_idx, (inputs, targets) in enumerate(tqdm(testloader)):\n", 414 | " inputs, targets = inputs.to(device), targets.to(device)\n", 415 | " outputs = net(inputs)\n", 416 | " loss = criterion(outputs, targets)\n", 417 | "\n", 418 | " if 'mixed' in mode:\n", 419 | " loss = loss.mean()\n", 420 | " test_loss += loss.mean().item()\n", 421 | " _, predicted = outputs.max(1)\n", 422 | " total += targets.size(0)\n", 423 | " correct += predicted.eq(targets).sum().item()" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "metadata": { 430 | "id": "RkFgWq6tPTGx" 431 | }, 432 | "outputs": [], 433 | "source": [ 434 | "for epoch in range(epochs):\n", 435 | " train(epoch)\n", 436 | " test(epoch)" 437 | ] 438 | } 439 | ], 440 | "metadata": { 441 | "colab": { 442 | "name": "Private_CNN.ipynb", 443 | "provenance": [] 444 | }, 445 | "kernelspec": { 446 | "display_name": "Python 3", 447 | "name": "python3" 448 | }, 449 | "language_info": { 450 | "name": "python" 451 | }, 452 | "widgets": { 453 | "application/vnd.jupyter.widget-state+json": { 454 | "0c9298b2565b4b1bb1f36c31b415e6c8": { 455 | "model_module": "@jupyter-widgets/controls", 456 | "model_module_version": "1.5.0", 457 | "model_name": "HBoxModel", 458 | "state": { 459 | "_dom_classes": [], 460 | "_model_module": "@jupyter-widgets/controls", 461 | "_model_module_version": "1.5.0", 462 | "_model_name": "HBoxModel", 463 | "_view_count": null, 464 | "_view_module": "@jupyter-widgets/controls", 465 | "_view_module_version": "1.5.0", 466 | "_view_name": "HBoxView", 467 | "box_style": "", 468 | "children": [ 469 | "IPY_MODEL_e3a397760fa340b1a053f1b0cbc7de2b", 470 | "IPY_MODEL_f35283304fc049cead0fc2245506ec31", 471 | "IPY_MODEL_f3b6dd11f43048febc0fc7d24b90ff77" 472 | ], 473 | "layout": "IPY_MODEL_4bbf86bb356c4ebc90e47a9509c3865c" 474 | } 475 | }, 476 | "28b8d652ca1a44d98b76969d1d1d06dc": { 477 | "model_module": "@jupyter-widgets/controls", 478 | "model_module_version": "1.5.0", 479 | "model_name": "ProgressStyleModel", 480 | "state": { 481 | "_model_module": "@jupyter-widgets/controls", 482 | "_model_module_version": "1.5.0", 483 | "_model_name": "ProgressStyleModel", 484 | "_view_count": null, 485 | "_view_module": "@jupyter-widgets/base", 486 | "_view_module_version": "1.2.0", 487 | "_view_name": "StyleView", 488 | "bar_color": null, 489 | "description_width": "" 490 | } 491 | }, 492 | "4bbf86bb356c4ebc90e47a9509c3865c": { 493 | "model_module": "@jupyter-widgets/base", 494 | "model_module_version": "1.2.0", 495 | "model_name": "LayoutModel", 496 | "state": { 497 | "_model_module": "@jupyter-widgets/base", 498 | "_model_module_version": "1.2.0", 499 | "_model_name": "LayoutModel", 500 | "_view_count": null, 501 | "_view_module": "@jupyter-widgets/base", 502 | 
"_view_module_version": "1.2.0", 503 | "_view_name": "LayoutView", 504 | "align_content": null, 505 | "align_items": null, 506 | "align_self": null, 507 | "border": null, 508 | "bottom": null, 509 | "display": null, 510 | "flex": null, 511 | "flex_flow": null, 512 | "grid_area": null, 513 | "grid_auto_columns": null, 514 | "grid_auto_flow": null, 515 | "grid_auto_rows": null, 516 | "grid_column": null, 517 | "grid_gap": null, 518 | "grid_row": null, 519 | "grid_template_areas": null, 520 | "grid_template_columns": null, 521 | "grid_template_rows": null, 522 | "height": null, 523 | "justify_content": null, 524 | "justify_items": null, 525 | "left": null, 526 | "margin": null, 527 | "max_height": null, 528 | "max_width": null, 529 | "min_height": null, 530 | "min_width": null, 531 | "object_fit": null, 532 | "object_position": null, 533 | "order": null, 534 | "overflow": null, 535 | "overflow_x": null, 536 | "overflow_y": null, 537 | "padding": null, 538 | "right": null, 539 | "top": null, 540 | "visibility": null, 541 | "width": null 542 | } 543 | }, 544 | "69aea5e905fa425594c1c2253c076bb9": { 545 | "model_module": "@jupyter-widgets/base", 546 | "model_module_version": "1.2.0", 547 | "model_name": "LayoutModel", 548 | "state": { 549 | "_model_module": "@jupyter-widgets/base", 550 | "_model_module_version": "1.2.0", 551 | "_model_name": "LayoutModel", 552 | "_view_count": null, 553 | "_view_module": "@jupyter-widgets/base", 554 | "_view_module_version": "1.2.0", 555 | "_view_name": "LayoutView", 556 | "align_content": null, 557 | "align_items": null, 558 | "align_self": null, 559 | "border": null, 560 | "bottom": null, 561 | "display": null, 562 | "flex": null, 563 | "flex_flow": null, 564 | "grid_area": null, 565 | "grid_auto_columns": null, 566 | "grid_auto_flow": null, 567 | "grid_auto_rows": null, 568 | "grid_column": null, 569 | "grid_gap": null, 570 | "grid_row": null, 571 | "grid_template_areas": null, 572 | "grid_template_columns": null, 573 | "grid_template_rows": null, 574 | "height": null, 575 | "justify_content": null, 576 | "justify_items": null, 577 | "left": null, 578 | "margin": null, 579 | "max_height": null, 580 | "max_width": null, 581 | "min_height": null, 582 | "min_width": null, 583 | "object_fit": null, 584 | "object_position": null, 585 | "order": null, 586 | "overflow": null, 587 | "overflow_x": null, 588 | "overflow_y": null, 589 | "padding": null, 590 | "right": null, 591 | "top": null, 592 | "visibility": null, 593 | "width": null 594 | } 595 | }, 596 | "76fd1294f6964c6e855285a9db3f1c5b": { 597 | "model_module": "@jupyter-widgets/controls", 598 | "model_module_version": "1.5.0", 599 | "model_name": "DescriptionStyleModel", 600 | "state": { 601 | "_model_module": "@jupyter-widgets/controls", 602 | "_model_module_version": "1.5.0", 603 | "_model_name": "DescriptionStyleModel", 604 | "_view_count": null, 605 | "_view_module": "@jupyter-widgets/base", 606 | "_view_module_version": "1.2.0", 607 | "_view_name": "StyleView", 608 | "description_width": "" 609 | } 610 | }, 611 | "a519419d6c0f47c99e86084d9cb9c591": { 612 | "model_module": "@jupyter-widgets/controls", 613 | "model_module_version": "1.5.0", 614 | "model_name": "DescriptionStyleModel", 615 | "state": { 616 | "_model_module": "@jupyter-widgets/controls", 617 | "_model_module_version": "1.5.0", 618 | "_model_name": "DescriptionStyleModel", 619 | "_view_count": null, 620 | "_view_module": "@jupyter-widgets/base", 621 | "_view_module_version": "1.2.0", 622 | "_view_name": "StyleView", 623 | "description_width": "" 
624 | } 625 | }, 626 | "b5c59de8176a4c57a8b54d021c626085": { 627 | "model_module": "@jupyter-widgets/base", 628 | "model_module_version": "1.2.0", 629 | "model_name": "LayoutModel", 630 | "state": { 631 | "_model_module": "@jupyter-widgets/base", 632 | "_model_module_version": "1.2.0", 633 | "_model_name": "LayoutModel", 634 | "_view_count": null, 635 | "_view_module": "@jupyter-widgets/base", 636 | "_view_module_version": "1.2.0", 637 | "_view_name": "LayoutView", 638 | "align_content": null, 639 | "align_items": null, 640 | "align_self": null, 641 | "border": null, 642 | "bottom": null, 643 | "display": null, 644 | "flex": null, 645 | "flex_flow": null, 646 | "grid_area": null, 647 | "grid_auto_columns": null, 648 | "grid_auto_flow": null, 649 | "grid_auto_rows": null, 650 | "grid_column": null, 651 | "grid_gap": null, 652 | "grid_row": null, 653 | "grid_template_areas": null, 654 | "grid_template_columns": null, 655 | "grid_template_rows": null, 656 | "height": null, 657 | "justify_content": null, 658 | "justify_items": null, 659 | "left": null, 660 | "margin": null, 661 | "max_height": null, 662 | "max_width": null, 663 | "min_height": null, 664 | "min_width": null, 665 | "object_fit": null, 666 | "object_position": null, 667 | "order": null, 668 | "overflow": null, 669 | "overflow_x": null, 670 | "overflow_y": null, 671 | "padding": null, 672 | "right": null, 673 | "top": null, 674 | "visibility": null, 675 | "width": null 676 | } 677 | }, 678 | "e3a397760fa340b1a053f1b0cbc7de2b": { 679 | "model_module": "@jupyter-widgets/controls", 680 | "model_module_version": "1.5.0", 681 | "model_name": "HTMLModel", 682 | "state": { 683 | "_dom_classes": [], 684 | "_model_module": "@jupyter-widgets/controls", 685 | "_model_module_version": "1.5.0", 686 | "_model_name": "HTMLModel", 687 | "_view_count": null, 688 | "_view_module": "@jupyter-widgets/controls", 689 | "_view_module_version": "1.5.0", 690 | "_view_name": "HTMLView", 691 | "description": "", 692 | "description_tooltip": null, 693 | "layout": "IPY_MODEL_e4c2027d969a41f19fc0b18fb265b2b3", 694 | "placeholder": "​", 695 | "style": "IPY_MODEL_a519419d6c0f47c99e86084d9cb9c591", 696 | "value": "" 697 | } 698 | }, 699 | "e4c2027d969a41f19fc0b18fb265b2b3": { 700 | "model_module": "@jupyter-widgets/base", 701 | "model_module_version": "1.2.0", 702 | "model_name": "LayoutModel", 703 | "state": { 704 | "_model_module": "@jupyter-widgets/base", 705 | "_model_module_version": "1.2.0", 706 | "_model_name": "LayoutModel", 707 | "_view_count": null, 708 | "_view_module": "@jupyter-widgets/base", 709 | "_view_module_version": "1.2.0", 710 | "_view_name": "LayoutView", 711 | "align_content": null, 712 | "align_items": null, 713 | "align_self": null, 714 | "border": null, 715 | "bottom": null, 716 | "display": null, 717 | "flex": null, 718 | "flex_flow": null, 719 | "grid_area": null, 720 | "grid_auto_columns": null, 721 | "grid_auto_flow": null, 722 | "grid_auto_rows": null, 723 | "grid_column": null, 724 | "grid_gap": null, 725 | "grid_row": null, 726 | "grid_template_areas": null, 727 | "grid_template_columns": null, 728 | "grid_template_rows": null, 729 | "height": null, 730 | "justify_content": null, 731 | "justify_items": null, 732 | "left": null, 733 | "margin": null, 734 | "max_height": null, 735 | "max_width": null, 736 | "min_height": null, 737 | "min_width": null, 738 | "object_fit": null, 739 | "object_position": null, 740 | "order": null, 741 | "overflow": null, 742 | "overflow_x": null, 743 | "overflow_y": null, 744 | "padding": null, 745 | 
"right": null, 746 | "top": null, 747 | "visibility": null, 748 | "width": null 749 | } 750 | }, 751 | "f35283304fc049cead0fc2245506ec31": { 752 | "model_module": "@jupyter-widgets/controls", 753 | "model_module_version": "1.5.0", 754 | "model_name": "FloatProgressModel", 755 | "state": { 756 | "_dom_classes": [], 757 | "_model_module": "@jupyter-widgets/controls", 758 | "_model_module_version": "1.5.0", 759 | "_model_name": "FloatProgressModel", 760 | "_view_count": null, 761 | "_view_module": "@jupyter-widgets/controls", 762 | "_view_module_version": "1.5.0", 763 | "_view_name": "ProgressView", 764 | "bar_style": "success", 765 | "description": "", 766 | "description_tooltip": null, 767 | "layout": "IPY_MODEL_b5c59de8176a4c57a8b54d021c626085", 768 | "max": 170498071, 769 | "min": 0, 770 | "orientation": "horizontal", 771 | "style": "IPY_MODEL_28b8d652ca1a44d98b76969d1d1d06dc", 772 | "value": 170498071 773 | } 774 | }, 775 | "f3b6dd11f43048febc0fc7d24b90ff77": { 776 | "model_module": "@jupyter-widgets/controls", 777 | "model_module_version": "1.5.0", 778 | "model_name": "HTMLModel", 779 | "state": { 780 | "_dom_classes": [], 781 | "_model_module": "@jupyter-widgets/controls", 782 | "_model_module_version": "1.5.0", 783 | "_model_name": "HTMLModel", 784 | "_view_count": null, 785 | "_view_module": "@jupyter-widgets/controls", 786 | "_view_module_version": "1.5.0", 787 | "_view_name": "HTMLView", 788 | "description": "", 789 | "description_tooltip": null, 790 | "layout": "IPY_MODEL_69aea5e905fa425594c1c2253c076bb9", 791 | "placeholder": "​", 792 | "style": "IPY_MODEL_76fd1294f6964c6e855285a9db3f1c5b", 793 | "value": " 170499072/? [00:06<00:00, 32370957.61it/s]" 794 | } 795 | } 796 | } 797 | } 798 | }, 799 | "nbformat": 4, 800 | "nbformat_minor": 0 801 | } 802 | -------------------------------------------------------------------------------- /vit.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "sZdt62UDG--F" 7 | }, 8 | "source": [ 9 | "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JialinMao/private_CNN/blob/master/vit.ipynb)\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": { 16 | "colab": { 17 | "base_uri": "https://localhost:8080/" 18 | }, 19 | "id": "TseeB3EWFoTn", 20 | "outputId": "342683bf-8280-4450-f469-e05930bb185a" 21 | }, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "Requirement already satisfied: opacus==1.1.1 in /usr/local/lib/python3.7/dist-packages (1.1.1)\n", 28 | "Requirement already satisfied: numpy>=1.15 in /usr/local/lib/python3.7/dist-packages (from opacus==1.1.1) (1.21.6)\n", 29 | "Requirement already satisfied: scipy>=1.2 in /usr/local/lib/python3.7/dist-packages (from opacus==1.1.1) (1.4.1)\n", 30 | "Requirement already satisfied: torch>=1.8 in /usr/local/lib/python3.7/dist-packages (from opacus==1.1.1) (1.11.0+cu113)\n", 31 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.8->opacus==1.1.1) (4.2.0)\n" 32 | ] 33 | } 34 | ], 35 | "source": [ 36 | "!pip install opacus==1.1.1\n", 37 | "!pip install timm" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 3, 43 | "metadata": { 44 | "colab": { 45 | "base_uri": "https://localhost:8080/" 46 | }, 47 | "id": "vCOIWPsfL1Ik", 48 | "outputId": 
"f0685652-f915-4c1f-94ec-61addf194b8a" 49 | }, 50 | "outputs": [ 51 | { 52 | "name": "stdout", 53 | "output_type": "stream", 54 | "text": [ 55 | "Cloning into 'private_CNN'...\n", 56 | "remote: Enumerating objects: 1625, done.\u001b[K\n", 57 | "remote: Counting objects: 100% (1625/1625), done.\u001b[K\n", 58 | "remote: Compressing objects: 100% (776/776), done.\u001b[K\n", 59 | "remote: Total 1625 (delta 834), reused 1622 (delta 831), pack-reused 0\u001b[K\n", 60 | "Receiving objects: 100% (1625/1625), 28.55 MiB | 22.42 MiB/s, done.\n", 61 | "Resolving deltas: 100% (834/834), done.\n" 62 | ] 63 | } 64 | ], 65 | "source": [ 66 | "!git clone https://github.com/JialinMao/private_CNN.git" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 6, 72 | "metadata": { 73 | "colab": { 74 | "base_uri": "https://localhost:8080/" 75 | }, 76 | "id": "f9HqTKTtMoLT", 77 | "outputId": "6ebaa336-25c0-496e-87ab-a16feb8c6f41" 78 | }, 79 | "outputs": [ 80 | { 81 | "name": "stdout", 82 | "output_type": "stream", 83 | "text": [ 84 | "/content/private_CNN\n" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "cd private_CNN/" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 11, 95 | "metadata": { 96 | "id": "IN_kmXnUMlvD" 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "import sys\n", 101 | "sys.path.insert(0, '/content/private_CNN')\n", 102 | "import private_CNN" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 10, 108 | "metadata": { 109 | "id": "Qp_fa2zsIsux" 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "import timm\n", 114 | "from tqdm import tqdm\n", 115 | "import torch\n", 116 | "import torch.nn as nn\n", 117 | "import torch.optim as optim\n", 118 | "import torch.backends.cudnn as cudnn\n", 119 | "\n", 120 | "import torchvision\n", 121 | "import torchvision.transforms as transforms\n", 122 | "\n", 123 | "from opacus.validators import ModuleValidator\n", 124 | "from opacus.accountants.utils import get_noise_multiplier" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 12, 130 | "metadata": { 131 | "id": "uxXYVGqSJAwN" 132 | }, 133 | "outputs": [], 134 | "source": [ 135 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": { 141 | "id": "zDNAzH_EPkZj" 142 | }, 143 | "source": [ 144 | "# Cifar10 torchvision models" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": { 150 | "id": "9QGvyCFLOWsA" 151 | }, 152 | "source": [ 153 | "## Arguments" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": 14, 159 | "metadata": { 160 | "id": "Lp5f2FozOWbg" 161 | }, 162 | "outputs": [], 163 | "source": [ 164 | "lr = 0.1\n", 165 | "epochs = 20\n", 166 | "bs = 128\n", 167 | "eps = 2\n", 168 | "grad_norm = 0.1\n", 169 | "mode = 'ghost-mixed'\n", 170 | "model = 'crossvit_18_240'\n", 171 | "mini_batch_size = 50\n", 172 | "pretrained = True\n", 173 | "cifar_data = 'CIFAR100'" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": { 179 | "id": "VKedj5LzOG_O" 180 | }, 181 | "source": [ 182 | "## Loading Data" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 15, 188 | "metadata": { 189 | "colab": { 190 | "base_uri": "https://localhost:8080/", 191 | "height": 120, 192 | "referenced_widgets": [ 193 | "0c9298b2565b4b1bb1f36c31b415e6c8", 194 | "e3a397760fa340b1a053f1b0cbc7de2b", 195 | "f35283304fc049cead0fc2245506ec31", 196 | "f3b6dd11f43048febc0fc7d24b90ff77", 197 | 
"4bbf86bb356c4ebc90e47a9509c3865c", 198 | "e4c2027d969a41f19fc0b18fb265b2b3", 199 | "a519419d6c0f47c99e86084d9cb9c591", 200 | "b5c59de8176a4c57a8b54d021c626085", 201 | "28b8d652ca1a44d98b76969d1d1d06dc", 202 | "69aea5e905fa425594c1c2253c076bb9", 203 | "76fd1294f6964c6e855285a9db3f1c5b" 204 | ] 205 | }, 206 | "id": "efkzB6mJOCBJ", 207 | "outputId": "27e6e9b8-5418-42e6-d7f2-1ce435563140" 208 | }, 209 | "outputs": [ 210 | { 211 | "name": "stdout", 212 | "output_type": "stream", 213 | "text": [ 214 | "==> Preparing data..\n", 215 | "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../../data/cifar-10-python.tar.gz\n" 216 | ] 217 | }, 218 | { 219 | "data": { 220 | "application/vnd.jupyter.widget-view+json": { 221 | "model_id": "0c9298b2565b4b1bb1f36c31b415e6c8", 222 | "version_major": 2, 223 | "version_minor": 0 224 | }, 225 | "text/plain": [ 226 | " 0%| | 0/170498071 [00:00 Preparing data..')\n", 243 | "\n", 244 | "transform_train = transforms.Compose([\n", 245 | " transforms.Resize(224),\n", 246 | " transforms.ToTensor(),\n", 247 | "])\n", 248 | "transform_test = transforms.Compose([\n", 249 | " transforms.Resize(224),\n", 250 | " transforms.ToTensor(),\n", 251 | "])\n", 252 | "\n", 253 | "trainset = torchvision.datasets.CIFAR100(\n", 254 | " root='../../data', train=True, download=True, transform=transform_train)\n", 255 | "testset = torchvision.datasets.CIFAR100(\n", 256 | " root='../../data', train=False, download=True, transform=transform_test)\n", 257 | "\n", 258 | "trainloader = torch.utils.data.DataLoader(\n", 259 | " trainset, batch_size=mini_batch_size, shuffle=True, num_workers=2)\n", 260 | "\n", 261 | "testloader = torch.utils.data.DataLoader(\n", 262 | " testset, batch_size=100, shuffle=False, num_workers=2)\n", 263 | "\n" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": { 269 | "id": "kPrQyDXsOJes" 270 | }, 271 | "source": [ 272 | "## Building Model" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 16, 278 | "metadata": { 279 | "colab": { 280 | "base_uri": "https://localhost:8080/" 281 | }, 282 | "id": "EflgegqmOLNP", 283 | "outputId": "2ad569e5-26c9-48dc-efc5-12018998c5d7" 284 | }, 285 | "outputs": [ 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "==> Building model..\n", 291 | "number of parameters: 11173962\n" 292 | ] 293 | } 294 | ], 295 | "source": [ 296 | "print('==> Building model..', model, ' mode ', mode)\n", 297 | "NUM_CLASSES=10 if cifar_data=='CIFAR10' else 100\n", 298 | "\n", 299 | "net = timm.create_model(model,pretrained=pretrained,num_classes=NUM_CLASSES)\n", 300 | "net = ModuleValidator.fix(net)\n", 301 | "\n", 302 | "if 'convit' in model:\n", 303 | " for name,param in net.named_parameters():\n", 304 | " if 'attn.gating_param' in name:\n", 305 | " param.requires_grad=False\n", 306 | "if 'beit' in model:\n", 307 | " for name,param in net.named_parameters():\n", 308 | " if 'gamma_' in name or 'relative_position_bias_table' in name or 'attn.qkv.weight' in name or 'attn.q_bias' in name or 'attn.v_bias' in name:\n", 309 | " requires_grad=False\n", 310 | "\n", 311 | "\n", 312 | "for name,param in net.named_parameters():\n", 313 | " if 'cls_token' in name or 'pos_embed' in name:\n", 314 | " param.requires_grad=False\n", 315 | "\n", 316 | "if device == 'cuda':\n", 317 | " net = torch.nn.DataParallel(net)\n", 318 | " cudnn.benchmark = True\n", 319 | "\n", 320 | "print('number of parameters: ', sum([p.numel() for p in net.parameters()]))\n" 321 | ] 322 | }, 323 
| { 324 | "cell_type": "markdown", 325 | "metadata": { 326 | "id": "Qs26fMHNOxrK" 327 | }, 328 | "source": [ 329 | "## Privacy Engine" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": null, 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "if \"ghost\" in mode:\n", 339 | " criterion = nn.CrossEntropyLoss(reduction=\"none\")\n", 340 | "else:\n", 341 | " criterion = nn.CrossEntropyLoss()\n", 342 | "\n", 343 | "optimizer = optim.Adam(net.parameters(), lr=lr)\n", 344 | "\n", 345 | "n_acc_steps = bs // mini_batch_size\n", 346 | "\n", 347 | "if 'ghost' in mode:\n", 348 | " sigma = get_noise_multiplier(\n", 349 | " target_epsilon = eps,\n", 350 | " target_delta = 1e-5,\n", 351 | " sample_rate = bs/len(trainset),\n", 352 | " epochs = epochs,\n", 353 | " accountant = \"gdp\"\n", 354 | " )\n", 355 | " privacy_engine = private_CNN.PrivacyEngine(\n", 356 | " net,\n", 357 | " batch_size=bs,\n", 358 | " sample_size=len(trainloader.dataset),\n", 359 | " noise_multiplier=sigma,\n", 360 | " epochs=epochs,\n", 361 | " max_grad_norm=grad_norm,\n", 362 | " ghost_clipping=True,\n", 363 | " mixed='mixed' in mode\n", 364 | " )\n", 365 | " privacy_engine.attach(optimizer)\n", 366 | "\n" 367 | ] 368 | }, 369 | { 370 | "cell_type": "markdown", 371 | "metadata": { 372 | "id": "CKqmwpYdPC8W" 373 | }, 374 | "source": [ 375 | "## Trainining and Testing" 376 | ] 377 | }, 378 | { 379 | "cell_type": "code", 380 | "execution_count": null, 381 | "metadata": {}, 382 | "outputs": [], 383 | "source": [ 384 | "# Training\n", 385 | "def train(epoch):\n", 386 | " print('\\nEpoch: %d' % epoch)\n", 387 | " net.train()\n", 388 | " train_loss = 0\n", 389 | " correct = 0\n", 390 | " total = 0\n", 391 | "\n", 392 | " for batch_idx, (inputs, targets) in enumerate(tqdm(trainloader)):\n", 393 | " inputs, targets = inputs.to(device), targets.to(device)\n", 394 | " outputs = net(inputs)\n", 395 | " loss = criterion(outputs, targets)\n", 396 | "\n", 397 | " if mode=='non-private':\n", 398 | " loss.backward()\n", 399 | " if ((batch_idx + 1) % n_acc_steps == 0) or ((batch_idx + 1) == len(trainloader)):\n", 400 | " optimizer.step()\n", 401 | " optimizer.zero_grad()\n", 402 | " else:\n", 403 | " if ((batch_idx + 1) % n_acc_steps == 0) or ((batch_idx + 1) == len(trainloader)):\n", 404 | " optimizer.step(loss=loss)\n", 405 | " optimizer.zero_grad()\n", 406 | " else:\n", 407 | " optimizer.virtual_step(loss=loss)\n", 408 | " train_loss += loss.mean().item()\n", 409 | " _, predicted = outputs.max(1)\n", 410 | " total += targets.size(0)\n", 411 | " correct += predicted.eq(targets).sum().item()\n", 412 | "\n", 413 | " print(epoch, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'\n", 414 | " % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))\n", 415 | "\n", 416 | "def test(epoch):\n", 417 | " net.eval()\n", 418 | " test_loss = 0\n", 419 | " correct = 0\n", 420 | " total = 0\n", 421 | " with torch.no_grad():\n", 422 | " for batch_idx, (inputs, targets) in enumerate(tqdm(testloader)):\n", 423 | " inputs, targets = inputs.to(device), targets.to(device)\n", 424 | " outputs = net(inputs)\n", 425 | " loss = criterion(outputs, targets)\n", 426 | "\n", 427 | " loss = loss.mean()\n", 428 | " test_loss += loss.item()\n", 429 | " _, predicted = outputs.max(1)\n", 430 | " total += targets.size(0)\n", 431 | " correct += predicted.eq(targets).sum().item()\n", 432 | "\n", 433 | " print(epoch, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'\n", 434 | " % (test_loss/(batch_idx+1), 100.*correct/total, 
correct, total))\n" 435 | ] 436 | }, 437 | { 438 | "cell_type": "code", 439 | "execution_count": null, 440 | "metadata": { 441 | "id": "RkFgWq6tPTGx" 442 | }, 443 | "outputs": [], 444 | "source": [ 445 | "for epoch in range(epochs):\n", 446 | " train(epoch)\n", 447 | " test(epoch)" 448 | ] 449 | } 450 | ], 451 | "metadata": { 452 | "colab": { 453 | "name": "Private_CNN.ipynb", 454 | "provenance": [] 455 | }, 456 | "kernelspec": { 457 | "display_name": "Python 3", 458 | "name": "python3" 459 | }, 460 | "language_info": { 461 | "name": "python" 462 | }, 463 | "widgets": { 464 | "application/vnd.jupyter.widget-state+json": { 465 | "0c9298b2565b4b1bb1f36c31b415e6c8": { 466 | "model_module": "@jupyter-widgets/controls", 467 | "model_module_version": "1.5.0", 468 | "model_name": "HBoxModel", 469 | "state": { 470 | "_dom_classes": [], 471 | "_model_module": "@jupyter-widgets/controls", 472 | "_model_module_version": "1.5.0", 473 | "_model_name": "HBoxModel", 474 | "_view_count": null, 475 | "_view_module": "@jupyter-widgets/controls", 476 | "_view_module_version": "1.5.0", 477 | "_view_name": "HBoxView", 478 | "box_style": "", 479 | "children": [ 480 | "IPY_MODEL_e3a397760fa340b1a053f1b0cbc7de2b", 481 | "IPY_MODEL_f35283304fc049cead0fc2245506ec31", 482 | "IPY_MODEL_f3b6dd11f43048febc0fc7d24b90ff77" 483 | ], 484 | "layout": "IPY_MODEL_4bbf86bb356c4ebc90e47a9509c3865c" 485 | } 486 | }, 487 | "28b8d652ca1a44d98b76969d1d1d06dc": { 488 | "model_module": "@jupyter-widgets/controls", 489 | "model_module_version": "1.5.0", 490 | "model_name": "ProgressStyleModel", 491 | "state": { 492 | "_model_module": "@jupyter-widgets/controls", 493 | "_model_module_version": "1.5.0", 494 | "_model_name": "ProgressStyleModel", 495 | "_view_count": null, 496 | "_view_module": "@jupyter-widgets/base", 497 | "_view_module_version": "1.2.0", 498 | "_view_name": "StyleView", 499 | "bar_color": null, 500 | "description_width": "" 501 | } 502 | }, 503 | "4bbf86bb356c4ebc90e47a9509c3865c": { 504 | "model_module": "@jupyter-widgets/base", 505 | "model_module_version": "1.2.0", 506 | "model_name": "LayoutModel", 507 | "state": { 508 | "_model_module": "@jupyter-widgets/base", 509 | "_model_module_version": "1.2.0", 510 | "_model_name": "LayoutModel", 511 | "_view_count": null, 512 | "_view_module": "@jupyter-widgets/base", 513 | "_view_module_version": "1.2.0", 514 | "_view_name": "LayoutView", 515 | "align_content": null, 516 | "align_items": null, 517 | "align_self": null, 518 | "border": null, 519 | "bottom": null, 520 | "display": null, 521 | "flex": null, 522 | "flex_flow": null, 523 | "grid_area": null, 524 | "grid_auto_columns": null, 525 | "grid_auto_flow": null, 526 | "grid_auto_rows": null, 527 | "grid_column": null, 528 | "grid_gap": null, 529 | "grid_row": null, 530 | "grid_template_areas": null, 531 | "grid_template_columns": null, 532 | "grid_template_rows": null, 533 | "height": null, 534 | "justify_content": null, 535 | "justify_items": null, 536 | "left": null, 537 | "margin": null, 538 | "max_height": null, 539 | "max_width": null, 540 | "min_height": null, 541 | "min_width": null, 542 | "object_fit": null, 543 | "object_position": null, 544 | "order": null, 545 | "overflow": null, 546 | "overflow_x": null, 547 | "overflow_y": null, 548 | "padding": null, 549 | "right": null, 550 | "top": null, 551 | "visibility": null, 552 | "width": null 553 | } 554 | }, 555 | "69aea5e905fa425594c1c2253c076bb9": { 556 | "model_module": "@jupyter-widgets/base", 557 | "model_module_version": "1.2.0", 558 | "model_name": 
"LayoutModel", 559 | "state": { 560 | "_model_module": "@jupyter-widgets/base", 561 | "_model_module_version": "1.2.0", 562 | "_model_name": "LayoutModel", 563 | "_view_count": null, 564 | "_view_module": "@jupyter-widgets/base", 565 | "_view_module_version": "1.2.0", 566 | "_view_name": "LayoutView", 567 | "align_content": null, 568 | "align_items": null, 569 | "align_self": null, 570 | "border": null, 571 | "bottom": null, 572 | "display": null, 573 | "flex": null, 574 | "flex_flow": null, 575 | "grid_area": null, 576 | "grid_auto_columns": null, 577 | "grid_auto_flow": null, 578 | "grid_auto_rows": null, 579 | "grid_column": null, 580 | "grid_gap": null, 581 | "grid_row": null, 582 | "grid_template_areas": null, 583 | "grid_template_columns": null, 584 | "grid_template_rows": null, 585 | "height": null, 586 | "justify_content": null, 587 | "justify_items": null, 588 | "left": null, 589 | "margin": null, 590 | "max_height": null, 591 | "max_width": null, 592 | "min_height": null, 593 | "min_width": null, 594 | "object_fit": null, 595 | "object_position": null, 596 | "order": null, 597 | "overflow": null, 598 | "overflow_x": null, 599 | "overflow_y": null, 600 | "padding": null, 601 | "right": null, 602 | "top": null, 603 | "visibility": null, 604 | "width": null 605 | } 606 | }, 607 | "76fd1294f6964c6e855285a9db3f1c5b": { 608 | "model_module": "@jupyter-widgets/controls", 609 | "model_module_version": "1.5.0", 610 | "model_name": "DescriptionStyleModel", 611 | "state": { 612 | "_model_module": "@jupyter-widgets/controls", 613 | "_model_module_version": "1.5.0", 614 | "_model_name": "DescriptionStyleModel", 615 | "_view_count": null, 616 | "_view_module": "@jupyter-widgets/base", 617 | "_view_module_version": "1.2.0", 618 | "_view_name": "StyleView", 619 | "description_width": "" 620 | } 621 | }, 622 | "a519419d6c0f47c99e86084d9cb9c591": { 623 | "model_module": "@jupyter-widgets/controls", 624 | "model_module_version": "1.5.0", 625 | "model_name": "DescriptionStyleModel", 626 | "state": { 627 | "_model_module": "@jupyter-widgets/controls", 628 | "_model_module_version": "1.5.0", 629 | "_model_name": "DescriptionStyleModel", 630 | "_view_count": null, 631 | "_view_module": "@jupyter-widgets/base", 632 | "_view_module_version": "1.2.0", 633 | "_view_name": "StyleView", 634 | "description_width": "" 635 | } 636 | }, 637 | "b5c59de8176a4c57a8b54d021c626085": { 638 | "model_module": "@jupyter-widgets/base", 639 | "model_module_version": "1.2.0", 640 | "model_name": "LayoutModel", 641 | "state": { 642 | "_model_module": "@jupyter-widgets/base", 643 | "_model_module_version": "1.2.0", 644 | "_model_name": "LayoutModel", 645 | "_view_count": null, 646 | "_view_module": "@jupyter-widgets/base", 647 | "_view_module_version": "1.2.0", 648 | "_view_name": "LayoutView", 649 | "align_content": null, 650 | "align_items": null, 651 | "align_self": null, 652 | "border": null, 653 | "bottom": null, 654 | "display": null, 655 | "flex": null, 656 | "flex_flow": null, 657 | "grid_area": null, 658 | "grid_auto_columns": null, 659 | "grid_auto_flow": null, 660 | "grid_auto_rows": null, 661 | "grid_column": null, 662 | "grid_gap": null, 663 | "grid_row": null, 664 | "grid_template_areas": null, 665 | "grid_template_columns": null, 666 | "grid_template_rows": null, 667 | "height": null, 668 | "justify_content": null, 669 | "justify_items": null, 670 | "left": null, 671 | "margin": null, 672 | "max_height": null, 673 | "max_width": null, 674 | "min_height": null, 675 | "min_width": null, 676 | "object_fit": null, 
677 | "object_position": null, 678 | "order": null, 679 | "overflow": null, 680 | "overflow_x": null, 681 | "overflow_y": null, 682 | "padding": null, 683 | "right": null, 684 | "top": null, 685 | "visibility": null, 686 | "width": null 687 | } 688 | }, 689 | "e3a397760fa340b1a053f1b0cbc7de2b": { 690 | "model_module": "@jupyter-widgets/controls", 691 | "model_module_version": "1.5.0", 692 | "model_name": "HTMLModel", 693 | "state": { 694 | "_dom_classes": [], 695 | "_model_module": "@jupyter-widgets/controls", 696 | "_model_module_version": "1.5.0", 697 | "_model_name": "HTMLModel", 698 | "_view_count": null, 699 | "_view_module": "@jupyter-widgets/controls", 700 | "_view_module_version": "1.5.0", 701 | "_view_name": "HTMLView", 702 | "description": "", 703 | "description_tooltip": null, 704 | "layout": "IPY_MODEL_e4c2027d969a41f19fc0b18fb265b2b3", 705 | "placeholder": "​", 706 | "style": "IPY_MODEL_a519419d6c0f47c99e86084d9cb9c591", 707 | "value": "" 708 | } 709 | }, 710 | "e4c2027d969a41f19fc0b18fb265b2b3": { 711 | "model_module": "@jupyter-widgets/base", 712 | "model_module_version": "1.2.0", 713 | "model_name": "LayoutModel", 714 | "state": { 715 | "_model_module": "@jupyter-widgets/base", 716 | "_model_module_version": "1.2.0", 717 | "_model_name": "LayoutModel", 718 | "_view_count": null, 719 | "_view_module": "@jupyter-widgets/base", 720 | "_view_module_version": "1.2.0", 721 | "_view_name": "LayoutView", 722 | "align_content": null, 723 | "align_items": null, 724 | "align_self": null, 725 | "border": null, 726 | "bottom": null, 727 | "display": null, 728 | "flex": null, 729 | "flex_flow": null, 730 | "grid_area": null, 731 | "grid_auto_columns": null, 732 | "grid_auto_flow": null, 733 | "grid_auto_rows": null, 734 | "grid_column": null, 735 | "grid_gap": null, 736 | "grid_row": null, 737 | "grid_template_areas": null, 738 | "grid_template_columns": null, 739 | "grid_template_rows": null, 740 | "height": null, 741 | "justify_content": null, 742 | "justify_items": null, 743 | "left": null, 744 | "margin": null, 745 | "max_height": null, 746 | "max_width": null, 747 | "min_height": null, 748 | "min_width": null, 749 | "object_fit": null, 750 | "object_position": null, 751 | "order": null, 752 | "overflow": null, 753 | "overflow_x": null, 754 | "overflow_y": null, 755 | "padding": null, 756 | "right": null, 757 | "top": null, 758 | "visibility": null, 759 | "width": null 760 | } 761 | }, 762 | "f35283304fc049cead0fc2245506ec31": { 763 | "model_module": "@jupyter-widgets/controls", 764 | "model_module_version": "1.5.0", 765 | "model_name": "FloatProgressModel", 766 | "state": { 767 | "_dom_classes": [], 768 | "_model_module": "@jupyter-widgets/controls", 769 | "_model_module_version": "1.5.0", 770 | "_model_name": "FloatProgressModel", 771 | "_view_count": null, 772 | "_view_module": "@jupyter-widgets/controls", 773 | "_view_module_version": "1.5.0", 774 | "_view_name": "ProgressView", 775 | "bar_style": "success", 776 | "description": "", 777 | "description_tooltip": null, 778 | "layout": "IPY_MODEL_b5c59de8176a4c57a8b54d021c626085", 779 | "max": 170498071, 780 | "min": 0, 781 | "orientation": "horizontal", 782 | "style": "IPY_MODEL_28b8d652ca1a44d98b76969d1d1d06dc", 783 | "value": 170498071 784 | } 785 | }, 786 | "f3b6dd11f43048febc0fc7d24b90ff77": { 787 | "model_module": "@jupyter-widgets/controls", 788 | "model_module_version": "1.5.0", 789 | "model_name": "HTMLModel", 790 | "state": { 791 | "_dom_classes": [], 792 | "_model_module": "@jupyter-widgets/controls", 793 | 
"_model_module_version": "1.5.0", 794 | "_model_name": "HTMLModel", 795 | "_view_count": null, 796 | "_view_module": "@jupyter-widgets/controls", 797 | "_view_module_version": "1.5.0", 798 | "_view_name": "HTMLView", 799 | "description": "", 800 | "description_tooltip": null, 801 | "layout": "IPY_MODEL_69aea5e905fa425594c1c2253c076bb9", 802 | "placeholder": "​", 803 | "style": "IPY_MODEL_76fd1294f6964c6e855285a9db3f1c5b", 804 | "value": " 170499072/? [00:06<00:00, 32370957.61it/s]" 805 | } 806 | } 807 | } 808 | } 809 | }, 810 | "nbformat": 4, 811 | "nbformat_minor": 0 812 | } 813 | -------------------------------------------------------------------------------- /private_CNN/privacy_utils/privacy_engine.py: -------------------------------------------------------------------------------- 1 | """Code for a privacy engine that plays nicely with Hugging Face transformers. 2 | 3 | Design mostly based on Opacus with the exception that `.step` and `virtual_step` 4 | takes in per-example losses, which should not be called with `.backward()` by 5 | the user. 6 | """ 7 | 8 | import collections 9 | import logging 10 | import math 11 | import types 12 | from typing import Callable, Dict, Optional, Sequence, Union 13 | 14 | import numpy as np 15 | import torch 16 | from torch import nn 17 | 18 | from . import autograd_grad_sample 19 | from . import misc 20 | # from . import transformers_support 21 | from .accounting import gdp_accounting, rdp_accounting 22 | 23 | DEFAULT_ALPHAS = tuple(1 + x / 10.0 for x in range(1, 100) 24 | ) + tuple(range(12, 64)) 25 | 26 | 27 | class PrivacyEngine(object): 28 | """Differentially-private optimization engine that works gracefully with Hugging Face transformers. 29 | 30 | Supports ghost clipping as described in 31 | Li, X., Tramèr, F., Liang, P., & Hashimoto, T. (2021). 32 | Large Language Models Can Be Strong Differentially Private Learners. 33 | arXiv preprint arXiv:2110.05679. 34 | 35 | Implicitly assumes inputs are in batch first format. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | module: nn.Module, 41 | *, 42 | batch_size: int, 43 | sample_size: int, 44 | max_grad_norm: float, 45 | epochs: Optional[Union[int, float]] = None, 46 | noise_multiplier: Optional[float] = None, 47 | target_epsilon: Optional[float] = None, 48 | target_delta: Optional[float] = None, 49 | alphas: Sequence[float] = DEFAULT_ALPHAS, 50 | record_snr: bool = True, 51 | named_params: Optional[Sequence] = None, 52 | fp16: bool = False, 53 | numerical_stability_constant=1e-6, 54 | ghost_clipping: bool = True, 55 | mixed: bool = True, 56 | # Accounting specifics. 57 | accounting_mode="rdp_cks", 58 | eps_error=0.05, 59 | **unused_kwargs, 60 | ): 61 | """Initialize the engine. 62 | 63 | Args: 64 | module: The PyTorch module for which per-sample gradient is required. 65 | Setting the `requires_grad` attribute of a parameter to False 66 | disables the per-sample gradient accumulation. 67 | batch_size: The expected lot size. 68 | sample_size: Size of dataset. 69 | max_grad_norm: The maximum 2-norm for gradient clipping. 70 | epochs: The number of epochs for training. 71 | noise_multiplier: The extra multiplier for DP-SGD noise. 72 | target_epsilon: The target privacy spending. Only used to estimate the `noise_multiplier` if it is not set. 73 | target_delta: The target failure probability. Defaults to 1 / (2 * sample_size) if not set. 74 | alphas: The RDP orders for (ε, δ)-DP conversion. Useless if not accounting in RDP. 
75 | record_snr: Record and report the signal-to-noise ratio -- 76 | ratio between norm of summed clipped gradient and norm of noise vector. 77 | named_params: Specifies which parameters need gradients; 78 | defaults to the parameters of `module` that require grad. 79 | fp16: Set this to True when training with mixed-precision. 80 | numerical_stability_constant: Small constant to avoid division by 0 when clipping. 81 | ghost_clipping: Set this to True to use memory-efficient ghost clipping. 82 | mixed: Set this to True to use mixed ghost clipping, which outperforms ghost clipping in memory and usually in time. 83 | accounting_mode: The method of privacy accounting. One of (`rdp`, `gdp`, `rdp_cks`, `glw`, `all`). 84 | Meanings of shorthands: 85 | - rdp: The method in "Rényi Differential Privacy of the Sampled Gaussian Mechanism". 86 | https://arxiv.org/abs/1908.10530 87 | - rdp_cks: Account loss with RDP but perform conversion to approx-DP with a procedure defined in 88 | "The Discrete Gaussian for Differential Privacy". 89 | https://arxiv.org/abs/2004.00010 90 | CKS are the first letters of the authors' last names. 91 | - gdp: Account loss with Gaussian DP and its central limit theorem described in 92 | "Deep Learning with Gaussian Differential Privacy". 93 | WARNING: This method may underestimate privacy spending. 94 | - glw: Account loss by numerically composing tradeoff functions in f-DP; defined in 95 | "Numerical composition of differential privacy". 96 | https://arxiv.org/abs/2106.02848 97 | GLW are the first letters of the authors' last names. 98 | - all: Report loss with all methods listed above. 99 | eps_error: Error threshold for the upper and lower bounds in the GLW accounting procedure. 100 | """ 101 | del unused_kwargs 102 | 103 | super(PrivacyEngine, self).__init__() 104 | if accounting_mode not in ('rdp', 'gdp', 'rdp_cks', 'glw', 'all',): 105 | raise ValueError(f"Unknown accounting mode: {accounting_mode}") 106 | if epochs is not None and epochs <= 0.0: 107 | raise ValueError( 108 | f"Number of training epochs cannot be non-positive, but found epochs={epochs}") 109 | 110 | # Privacy parameters. 111 | sample_rate = batch_size / sample_size 112 | if target_delta is None: 113 | target_delta = 1 / (2 * sample_size) 114 | if noise_multiplier is None: 115 | if target_epsilon is None or epochs is None: 116 | raise ValueError( 117 | f"`target_epsilon` and `epochs` must be specified when `noise_multiplier` is `None`."
118 | ) 119 | kwargs_for_get_sigma = dict( 120 | target_epsilon=target_epsilon, 121 | target_delta=target_delta, 122 | sample_rate=sample_rate, 123 | epochs=epochs, 124 | alphas=alphas, 125 | eps_error=eps_error, 126 | ) 127 | if accounting_mode == "rdp": 128 | noise_multiplier = get_sigma_from_rdp(**kwargs_for_get_sigma) 129 | elif accounting_mode == "rdp_cks": 130 | noise_multiplier = get_sigma_from_rdp_cks( 131 | **kwargs_for_get_sigma) 132 | elif accounting_mode == "glw": 133 | noise_multiplier = get_sigma_from_glw(**kwargs_for_get_sigma) 134 | else: 135 | noise_multiplier = get_sigma_from_gdp(**kwargs_for_get_sigma) 136 | 137 | self.batch_size = batch_size 138 | self.sample_size = sample_size 139 | self.sample_rate = sample_rate 140 | self.max_grad_norm = max_grad_norm 141 | 142 | self.epochs = epochs 143 | self.noise_multiplier = noise_multiplier 144 | self.effective_noise_multiplier = noise_multiplier / batch_size 145 | self.target_epsilon = target_epsilon 146 | self.target_delta = target_delta 147 | self.alphas = alphas 148 | self.accounting_mode = accounting_mode 149 | self.record_snr = record_snr 150 | 151 | # Internals. 152 | self.steps = 0 # Tracks privacy spending. 153 | 154 | # Recording. 155 | self.max_clip = None 156 | self.min_clip = None 157 | self.med_clip = None 158 | self.signal = None 159 | self.noise = None 160 | self.snr = None 161 | self.noise_limit = None 162 | 163 | # Record parameters. 164 | self.module = module 165 | if named_params is None: 166 | self.named_params = tuple( 167 | (name, param) for (name, param) in module.named_parameters() if param.requires_grad 168 | ) 169 | else: 170 | self.named_params = named_params 171 | self.num_params = sum(param.numel() for _, param in self.named_params) 172 | 173 | # Lock the part where noisy gradients is created (in `self.step`) if True. 174 | self._locked = False 175 | self.fp16 = fp16 176 | self.numerical_stability_constant = numerical_stability_constant 177 | self.ghost_clipping = ghost_clipping 178 | self.mixed = mixed 179 | if ghost_clipping: 180 | if fp16: 181 | # TODO: Make ghost clipping work with mixed-precision. 182 | raise NotImplementedError( 183 | "Ghost clipping doesn't support mixed-precision.") 184 | # Prepare for first backward in ghost clipping. 185 | if self.mixed: 186 | autograd_grad_sample.set_hooks_mode(f"ghost_norm_mixed") 187 | else: 188 | autograd_grad_sample.set_hooks_mode("ghost_norm") 189 | 190 | # transformers_support.forward_swapper(module=module) 191 | 192 | def lock(self): 193 | self._locked = True 194 | 195 | def unlock(self): 196 | self._locked = False 197 | 198 | def attach(self, optimizer): 199 | autograd_grad_sample.add_hooks( 200 | model=self.module, batch_first=True, loss_reduction="sum", fp16=self.fp16) 201 | 202 | # Override zero grad. 203 | def dp_zero_grad(_self, *args, **kwargs): 204 | _self.privacy_engine.zero_grad() 205 | 206 | # Override step. 207 | def dp_step(_self, **kwargs): 208 | closure = kwargs.pop("closure", None) 209 | 210 | _self.privacy_engine.step(**kwargs) 211 | _self.original_step(closure=closure) 212 | # Only enable creating new grads once parameters are updated. 
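# (The engine locks itself at the end of `_ghost_step`, so noisy gradients cannot be created twice for the same update; unlocking only after `original_step` has run keeps that lock/step handshake consistent.)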
213 | _self.privacy_engine.unlock()
214 | _self.privacy_engine.steps += 1
215 |
216 | def virtual_step(_self, **kwargs):
217 | _self.privacy_engine.virtual_step(**kwargs)
218 |
219 | def get_privacy_spent(_self, **kwargs):
220 | return _self.privacy_engine.get_privacy_spent(**kwargs)
221 |
222 | def get_training_stats(_self, **kwargs):
223 | return _self.privacy_engine.get_training_stats(**kwargs)
224 |
225 | optimizer.privacy_engine = self
226 |
227 | optimizer.original_step = optimizer.step
228 | optimizer.step = types.MethodType(dp_step, optimizer)
229 |
230 | optimizer.original_zero_grad = optimizer.zero_grad
231 | optimizer.zero_grad = types.MethodType(dp_zero_grad, optimizer)
232 |
233 | optimizer.virtual_step = types.MethodType(virtual_step, optimizer)
234 |
235 | # Make getting info easier.
236 | optimizer.get_privacy_spent = types.MethodType(
237 | get_privacy_spent, optimizer)
238 | optimizer.get_training_stats = types.MethodType(
239 | get_training_stats, optimizer)
240 |
241 | self.module.privacy_engine = self
242 |
243 | # Just to be safe, we also override `zero_grad` for module.
244 | self.module.original_zero_grad = self.module.zero_grad
245 | self.module.zero_grad = types.MethodType(dp_zero_grad, self.module)
246 |
247 | # For easy detaching.
248 | self.optimizer = optimizer
249 |
250 | def detach(self):
251 | optimizer = self.optimizer
252 | optimizer.step = optimizer.original_step
253 | optimizer.zero_grad = optimizer.original_zero_grad
254 | delattr(optimizer, "privacy_engine")
255 | delattr(optimizer, "original_step")
256 | delattr(optimizer, "original_zero_grad")
257 | delattr(optimizer, "virtual_step")
258 | delattr(optimizer, "get_privacy_spent")
259 | delattr(optimizer, "get_training_stats")
260 |
261 | module = self.module
262 | autograd_grad_sample.remove_hooks(module)
263 | # This is super important when there are multiple attaches!
264 | autograd_grad_sample.set_hooks_mode("default")
265 | module.zero_grad(skip_grad=True)
266 | module.zero_grad = module.original_zero_grad
267 | delattr(module, "original_zero_grad")
268 |
269 | # --- everything specific to ghost clipping ---
270 | def _ghost_step(self, loss: torch.Tensor):
271 | """Run double-backward on the per-example loss, then sum up the clipped gradients and noise them."""
272 | if self._locked:  # Skip this gradient-creation step if gradients were already created and the optimizer hasn't stepped.
273 | logging.warning("Attempted to step, but the engine is on lock.")
274 | return
275 |
276 | self._ghost_helper(loss)
277 |
278 | # Add noise and scale by inverse batch size.
279 | signals, noises = [], []
280 | for name, param in self.named_params:
281 | # This is only True when there are previous virtual steps.
282 | # The .grad contains the summed clipped gradients of this batch.
283 | # Summed clipped gradients of previous batches are in .summed_grad.
284 | # When there's no gradient accumulation, .summed_grad is not created.
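# Note on the noise scale: the Gaussian noise added below has per-coordinate std
# `noise_multiplier * max_grad_norm` and is applied to the *summed* clipped gradient; after
# the division by `batch_size`, the averaged gradient carries noise with per-coordinate std
# `noise_multiplier * max_grad_norm / batch_size`, which is what `effective_noise_multiplier` reflects.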
285 | if hasattr(param, 'summed_grad'):
286 | param.grad += param.summed_grad
287 | if self.record_snr:
288 | signals.append(param.grad.reshape(-1).norm(2).cpu())
289 |
290 | if self.noise_multiplier > 0 and self.max_grad_norm > 0:
291 | noise = torch.normal(
292 | mean=0,
293 | std=self.noise_multiplier * self.max_grad_norm,
294 | size=param.size(),
295 | device=param.device,
296 | dtype=param.dtype,
297 | )
298 | param.grad += noise
299 | if self.record_snr:
300 | noises.append(noise.reshape(-1).norm(2).cpu())
301 | del noise
302 |
303 | param.grad /= self.batch_size
304 |
305 | if self.record_snr and noises:
306 | self.signal, self.noise = tuple(torch.stack(
307 | lst).norm(2).item() for lst in (signals, noises))
308 | self.noise_limit = math.sqrt(
309 | self.num_params) * self.noise_multiplier * self.max_grad_norm
310 | self.snr = self.signal / self.noise
311 | else:
312 | self.snr = math.inf  # Undefined!
313 |
314 | # Make creating new gradients impossible, unless optimizer.step is called.
315 | self.lock()
316 |
317 | @torch.no_grad()
318 | def _ghost_virtual_step(self, loss: torch.Tensor):
319 | """Run double-backward on the per-example loss, then accumulate the clipped gradients into `.summed_grad`."""
320 | self._ghost_helper(loss)
321 | for name, param in self.named_params:
322 | if hasattr(param, 'summed_grad'):
323 | param.summed_grad += param.grad
324 | else:
325 | param.summed_grad = param.grad
326 |
327 | if hasattr(param, "grad"):
328 | del param.grad
329 | if hasattr(param, "norm_sample"):
330 | del param.norm_sample
331 | if hasattr(param, "grad_sample"):
332 | del param.grad_sample
333 |
334 | @torch.enable_grad()
335 | def _ghost_helper(self, loss: torch.Tensor):
336 | """Given per-example losses, run the two backward passes needed for ghost clipping."""
337 | if loss.dim() != 1:
338 | raise ValueError(
339 | "Expected `loss` to be the 1-D tensor of per-example losses.")
340 |
341 | first_loss = loss.sum()
342 | first_loss.backward(retain_graph=True)
343 |
344 | # Prepare for second backward.
345 | autograd_grad_sample.set_hooks_mode("ghost_grad")
346 |
347 | # The first backward might have accumulated things we don't need into `.grad`;
348 | # remove it before the second pass to avoid accumulating garbage.
349 | for name, param in self.named_params:
350 | if hasattr(param, "grad"):
351 | del param.grad
352 |
353 | coef_sample = self.get_coef_sample()
354 | second_loss = (coef_sample * loss).sum(dim=0)
355 | second_loss.backward()
356 |
357 | # Prepare for first backward (in the next round).
358 | if self.mixed:
359 | autograd_grad_sample.set_hooks_mode("ghost_norm_mixed")
360 | else:
361 | autograd_grad_sample.set_hooks_mode("ghost_norm")
362 |
363 | def get_norm_sample(self):
364 | """Get per-example gradient norms."""
365 | for name, param in self.named_params:
366 | if not hasattr(param, 'norm_sample'):
367 | logging.warning(f"Parameter {name} has no `norm_sample`; its layer may not be supported by the ghost clipping hooks.")
368 |
369 | norm_sample = torch.stack(
370 | [param.norm_sample for name, param in self.named_params], dim=0).norm(2, dim=0)
371 | return norm_sample
372 |
373 | def get_coef_sample(self):
374 | """Get per-example gradient scaling factor for clipping."""
375 | norm_sample = self.get_norm_sample()
376 | return torch.clamp_max(self.max_grad_norm / (norm_sample + self.numerical_stability_constant), 1.)
377 |
378 | # ---------------------------------------------
379 |
380 | @torch.no_grad()
381 | def step(self, scale=1, **kwargs):
382 | """Step function.
383 |
384 | `loss` must be passed in as a keyword argument!
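Example (assumes the engine has been attached to `optimizer` and that `F` is
`torch.nn.functional`; names are illustrative):
    per_example_loss = F.cross_entropy(logits, labels, reduction="none")  # shape (batch_size,)
    optimizer.step(loss=per_example_loss)  # routed to this method by the patched `optimizer.step`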
385 | """ 386 | if self.ghost_clipping: 387 | self._ghost_step(loss=kwargs.pop("loss")) 388 | else: 389 | self._step(loss=kwargs.pop("loss"), scale=scale) 390 | 391 | @torch.no_grad() 392 | def virtual_step(self, scale=1, **kwargs): 393 | """Virtual step function when there's gradient accumulation. 394 | 395 | `loss` must be passed in as a keyword argument! 396 | """ 397 | if self.ghost_clipping: 398 | self._ghost_virtual_step(loss=kwargs.pop("loss")) 399 | else: 400 | self._virtual_step(loss=kwargs.pop("loss"), scale=scale) 401 | 402 | def _step(self, loss, scale=1.): 403 | """Create noisy gradients. 404 | 405 | Should be ran right before you call `optimizer.step`. 406 | 407 | This function does 3 things: 408 | 1) call `loss.backward()` 409 | 2) clip the current `.grad_sample` and add that to `.summed_grad` 410 | 3) noise the gradients 411 | In mixed-precision training (with amp), the last two steps require knowing the loss scaling factor. 412 | 413 | Args: 414 | loss: The per-example loss; a 1-D tensor. 415 | scale: The loss up-scaling factor in amp. In full precision, this arg isn't useful. 416 | """ 417 | if self._locked: # Skip this gradient creation step if already created gradient and haven't stepped. 418 | logging.warning("Attempted to step, but the engine is on lock.") 419 | return 420 | 421 | norm_sample, coef_sample = self._accumulate_summed_grad( 422 | loss=loss, scale=scale) 423 | # Collect stats for debugging. 424 | self.max_clip = coef_sample.max().item() 425 | self.min_clip = coef_sample.min().item() 426 | self.med_clip = coef_sample.median().item() 427 | 428 | # Add noise and scale by inverse batch size. 429 | signals, noises = [], [] 430 | for name, param in self.named_params: 431 | # Ultra important to override `.grad`. 432 | if hasattr(param, 'summed_grad'): 433 | param.grad = param.summed_grad.to(param.dtype) 434 | else: 435 | logging.fatal( 436 | f"PrivacyEngine should not reach here; " 437 | f"this means either " 438 | f"1) there is parameter which requires gradient, but was not used in the computational graph, or " 439 | f"2) the backward hook registry failed to find the corresponding module to register." 440 | ) 441 | if self.record_snr: 442 | signals.append(param.grad.reshape(-1).norm(2).cpu()) 443 | 444 | if self.noise_multiplier > 0 and self.max_grad_norm > 0: 445 | noise = torch.normal( 446 | mean=0, 447 | std=self.noise_multiplier * self.max_grad_norm * scale, 448 | size=param.size(), 449 | device=param.device, 450 | dtype=param.dtype, 451 | ) 452 | if self.record_snr: 453 | noises.append(noise.reshape(-1).norm(2).cpu()) 454 | param.grad += noise 455 | del noise 456 | 457 | param.grad /= self.batch_size 458 | 459 | if self.record_snr and noises: 460 | self.signal, self.noise = tuple(torch.stack( 461 | lst).norm(2).item() for lst in (signals, noises)) 462 | self.noise_limit = math.sqrt( 463 | self.num_params) * self.noise_multiplier * self.max_grad_norm 464 | self.snr = self.signal / self.noise 465 | else: 466 | self.snr = math.inf # Undefined! 467 | 468 | # Make creating new gradients impossible, unless optimizer.step is called. 
469 | self.lock()
470 |
471 | def zero_grad(self, skip_grad=False):
472 | for name, param in self.named_params:
473 | if hasattr(param, "grad_sample"):
474 | del param.grad_sample
475 | if hasattr(param, "norm_sample"):
476 | del param.norm_sample
477 | if hasattr(param, "summed_grad"):
478 | del param.summed_grad
479 | if not skip_grad:
480 | if hasattr(param, "grad"):
481 | del param.grad
482 |
483 | def _virtual_step(self, loss, scale=1.):
484 | self._accumulate_summed_grad(loss=loss, scale=scale)
485 | for name, param in self.named_params:
486 | # Del everything except `.summed_grad` to save memory!
487 | if hasattr(param, "grad_sample"):
488 | # This must be deleted due to how `privacy_utils::supported_layers_grad_samplers.py` works!
489 | # When a parameter with `.grad_sample` is reused, the per-sample gradients are accumulated!
490 | del param.grad_sample
491 | if hasattr(param, "grad"):
492 | del param.grad
493 |
494 | @torch.no_grad()
495 | def _accumulate_summed_grad(self, loss, scale=1.):
496 | """Accumulate signal by summing clipped gradients."""
497 | if loss.dim() != 1:
498 | raise ValueError(
499 | "Expected `loss` to be the 1-D tensor of per-example losses.")
500 | with torch.enable_grad():
501 | loss.sum(dim=0).backward()
502 |
503 | norm_sample = []
504 | for name, param in self.named_params:
505 | try:
506 | batch_size = param.grad_sample.size(0)
507 | except AttributeError as error:
508 | args = error.args
509 | extra_msg = f"\n *** {name} parameter has no grad_sample attribute ***"
510 | error.args = (args[0] + extra_msg, *args[1:])
511 | raise error
512 | norm = param.grad_sample.reshape(batch_size, -1).norm(2, dim=1)
513 | norm_sample.append(norm.cpu())
514 |
515 | # The stack operation here is prone to error, so we clarify where the error comes from.
516 | try:
517 | norm_sample = torch.stack(norm_sample, dim=0).norm(2, dim=0)
518 | except RuntimeError as runtime_error:
519 | args = runtime_error.args
520 |
521 | # Find the most common (major) shape.
522 | shapes = collections.defaultdict(int)
523 | for tensor in norm_sample:
524 | shapes[tensor.size()] += 1
525 |
526 | major_shape = None
527 | major_count = 0
528 | for shape, count in shapes.items():
529 | if count > major_count:
530 | major_shape, major_count = shape, count
531 | del shape, count
532 |
533 | # Check which tensors don't have the major shape!
534 | extra_msg = f" \n*** Major shape: {major_shape}"
535 | for (name, param), tensor in zip(list(self.named_params), norm_sample):
536 | if tensor.size() != major_shape:
537 | extra_msg += f", {name} wrong shape: {tensor.size()}"
538 | extra_msg += " ***"
539 |
540 | runtime_error.args = (args[0] + extra_msg, *args[1:])
541 | raise runtime_error
542 |
543 | coef_sample = torch.clamp_max(
544 | self.max_grad_norm * scale / (norm_sample +
545 | self.numerical_stability_constant), 1.
546 | )
547 | for name, param in self.named_params:
548 | if not hasattr(param, 'summed_grad'):
549 | param.summed_grad = 0.
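# `coef_sample[i] = min(1, max_grad_norm * scale / (||g_i|| + numerical_stability_constant))`,
# so the einsum below accumulates `sum_i coef_sample[i] * grad_sample[i]`: the sum of
# per-example gradients, each rescaled to norm at most (roughly) `max_grad_norm * scale`.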
550 | current_device = param.grad_sample.device 551 | param.summed_grad += torch.einsum("i,i...->...", 552 | coef_sample.to(current_device), param.grad_sample) 553 | return norm_sample, coef_sample 554 | 555 | def get_privacy_spent(self, steps=None, accounting_mode=None, lenient=False) -> Dict: 556 | if steps is None: 557 | steps = self.steps 558 | if accounting_mode is None: 559 | accounting_mode = self.accounting_mode 560 | 561 | privacy_results = {} 562 | 563 | kwargs = dict( 564 | sample_rate=self.sample_rate, 565 | steps=steps, 566 | delta=self.target_delta, 567 | sigma=self.noise_multiplier, 568 | alphas=self.alphas, 569 | ) 570 | # The try-catch blocks are unusually sloppy... forgive me... 571 | if accounting_mode in ('all', 'rdp'): 572 | try: 573 | eps_rdp, alpha_rdp = _eps_from_rdp(**kwargs) 574 | privacy_results['eps_rdp_opacus'] = eps_rdp 575 | privacy_results['alpha_rdp_opacus'] = alpha_rdp 576 | except Exception as err: 577 | logging.fatal( 578 | "RDP accounting failed! Double check privacy parameters.") 579 | if not lenient: 580 | raise err 581 | 582 | if accounting_mode in ('all', 'gdp'): 583 | try: 584 | eps_gdp, mu_gdp = _eps_from_gdp(**kwargs) 585 | privacy_results['eps_gdp'] = eps_gdp 586 | privacy_results['mu_gdp'] = mu_gdp 587 | except Exception as err: 588 | logging.fatal( 589 | "GDP accounting failed! Double check privacy parameters.") 590 | if not lenient: 591 | raise err 592 | 593 | if accounting_mode in ('all', "rdp_cks"): 594 | try: 595 | eps_rdp, alpha_rdp = _eps_from_rdp_cks(**kwargs) 596 | privacy_results['eps_rdp'] = eps_rdp 597 | privacy_results['alpha_rdp'] = alpha_rdp 598 | except Exception as err: 599 | logging.fatal("RDP accounting with CKS conversion failed! " 600 | "Double check privacy parameters.") 601 | if not lenient: 602 | raise err 603 | 604 | if accounting_mode in ('all', "glw"): 605 | try: 606 | eps_glw = _eps_from_glw(**kwargs) 607 | privacy_results.update(eps_glw) 608 | except Exception as err: 609 | logging.fatal("Numerical composition of tradeoff functions failed! " 610 | "Double check privacy parameters.") 611 | if not lenient: 612 | raise err 613 | 614 | return privacy_results 615 | 616 | def get_training_stats(self): 617 | """Get the clipping, signal, and noise statistics.""" 618 | return { 619 | "med_clip": self.med_clip, 620 | "max_clip": self.max_clip, 621 | "min_clip": self.min_clip, 622 | "snr": self.snr, 623 | "signal": self.signal, 624 | "noise": self.noise, 625 | "noise_limit": self.noise_limit, 626 | } 627 | 628 | def __repr__(self): 629 | return ( 630 | f"PrivacyEngine(\n" 631 | f" target_epsilon={self.target_epsilon}, \n" 632 | f" target_delta={self.target_delta}, \n" 633 | f" noise_multiplier={self.noise_multiplier}, \n" 634 | f" effective_noise_multiplier={self.effective_noise_multiplier}, \n" 635 | f" epochs={self.epochs}, \n" 636 | f" max_grad_norm={self.max_grad_norm}, \n" 637 | f" sample_rate={self.sample_rate}, \n" 638 | f" batch_size={self.batch_size}, \n" 639 | f" accounting_mode={self.accounting_mode}, \n" 640 | f" ghost_clipping={self.ghost_clipping}\n" 641 | f")" 642 | ) 643 | 644 | 645 | def get_sigma_from_rdp( 646 | target_epsilon: float, 647 | target_delta: float, 648 | sample_rate: float, 649 | epochs: Optional[Union[float, int]] = None, 650 | alphas=DEFAULT_ALPHAS, 651 | threshold=1e-3, 652 | sigma_hi_init=4, 653 | sigma_lo_init=0.1, 654 | steps=None, 655 | **kwargs, 656 | ) -> float: 657 | """Get noise multiplier σ for a given ε from Renyi-DP accounting. 
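The search brackets the target ε (doubling `sigma_hi_init` and halving `sigma_lo_init` as needed),
then bisects until the bracket is narrower than `threshold` and returns the conservative endpoint;
see `_get_sigma_with_target_epsilon`. Example call with purely illustrative values:
    sigma = get_sigma_from_rdp(target_epsilon=3.0, target_delta=1e-5, sample_rate=256 / 50000, epochs=10)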
658 | 659 | Notes: 660 | Setting `threshold` to an appropriate value is crucial for accurate search. 661 | The default is fine-grained enough for ε ∈ [0.1, 1e10]. 662 | 663 | Args: 664 | target_epsilon: ε in (ε, δ)-DP. 665 | target_delta: δ in (ε, δ)-DP. 666 | sample_rate: Rate for Poisson subsampling, typically denoted as q. 667 | epochs: Number of passes through the dataset. 668 | alphas: Orders for Renyi-divergence. 669 | threshold: Threshold for binary search. Determines the granularity of 670 | the search result. 671 | sigma_hi_init: Starting point for the high end of binary search. 672 | sigma_lo_init: Starting point for the low end of binary search. 673 | steps: Number of updates; defaults to use `epochs` if not set. 674 | 675 | Returns: 676 | The noise multiplier σ for DP-SGD. 677 | """ 678 | if steps is None: 679 | if epochs is None: 680 | raise ValueError("Epochs and steps cannot both be None") 681 | steps = math.ceil(epochs / sample_rate) # Be conservative! 682 | 683 | def sigma_to_eps(sigma): 684 | eps, _ = _eps_from_rdp( 685 | sample_rate=sample_rate, 686 | sigma=sigma, 687 | steps=steps, 688 | alphas=alphas, 689 | delta=target_delta, 690 | ) 691 | return eps 692 | 693 | return _get_sigma_with_target_epsilon( 694 | sigma_hi_init=sigma_hi_init, 695 | sigma_lo_init=sigma_lo_init, 696 | sigma_to_eps=sigma_to_eps, 697 | target_epsilon=target_epsilon, 698 | threshold=threshold, 699 | ) 700 | 701 | 702 | def get_sigma_from_rdp_cks( 703 | target_epsilon: float, 704 | target_delta: float, 705 | sample_rate: float, 706 | epochs: Optional[Union[float, int]] = None, 707 | alphas=DEFAULT_ALPHAS, 708 | threshold=1e-3, 709 | sigma_hi_init=4, 710 | sigma_lo_init=0.1, 711 | steps=None, 712 | **kwargs, 713 | ) -> float: 714 | if steps is None: 715 | if epochs is None: 716 | raise ValueError("Epochs and steps cannot both be None") 717 | steps = math.ceil(epochs / sample_rate) 718 | 719 | def sigma_to_eps(sigma): 720 | eps, _ = _eps_from_rdp_cks( 721 | sample_rate=sample_rate, 722 | sigma=sigma, 723 | steps=steps, 724 | alphas=alphas, 725 | delta=target_delta, 726 | ) 727 | return eps 728 | 729 | return _get_sigma_with_target_epsilon( 730 | sigma_hi_init=sigma_hi_init, 731 | sigma_lo_init=sigma_lo_init, 732 | sigma_to_eps=sigma_to_eps, 733 | target_epsilon=target_epsilon, 734 | threshold=threshold, 735 | ) 736 | 737 | 738 | def get_sigma_from_gdp( 739 | target_epsilon: float, 740 | target_delta: float, 741 | sample_rate: float, 742 | epochs: Optional[Union[float, int]] = None, 743 | threshold=1e-3, 744 | sigma_hi_init=4, 745 | sigma_lo_init=0.1, 746 | mode="poisson", 747 | steps=None, 748 | **kwargs, 749 | ) -> float: 750 | """Get noise multiplier σ for a given ε from f-DP accounting using the central limit theorem in Gaussian DP.""" 751 | if steps is None: 752 | if epochs is None: 753 | raise ValueError("Epochs and steps cannot both be None") 754 | steps = math.ceil(epochs / sample_rate) 755 | 756 | def sigma_to_eps(sigma): 757 | eps, _ = _eps_from_gdp( 758 | sample_rate=sample_rate, 759 | sigma=sigma, 760 | steps=steps, 761 | delta=target_delta, 762 | mode=mode, 763 | ) 764 | return eps 765 | 766 | return _get_sigma_with_target_epsilon( 767 | sigma_hi_init=sigma_hi_init, 768 | sigma_lo_init=sigma_lo_init, 769 | sigma_to_eps=sigma_to_eps, 770 | target_epsilon=target_epsilon, 771 | threshold=threshold, 772 | ) 773 | 774 | 775 | def get_sigma_from_glw( 776 | target_epsilon: float, 777 | target_delta: float, 778 | sample_rate: float, 779 | epochs: Optional[Union[float, int]] = None, 780 | 
eps_error=0.05, 781 | threshold=1e-3, 782 | sigma_hi_init=4, 783 | sigma_lo_init=0.1, 784 | steps=None, 785 | **kwargs, 786 | ): 787 | """Get noise multiplier σ for a given ε from numerically composing tradeoff functions.""" 788 | from prv_accountant import Accountant 789 | 790 | if steps is None: 791 | if epochs is None: 792 | raise ValueError("Epochs and steps cannot both be None") 793 | steps = math.ceil(epochs / sample_rate) 794 | 795 | def sigma_to_eps(sigma): 796 | accountant = Accountant( 797 | noise_multiplier=sigma, 798 | sampling_probability=sample_rate, 799 | delta=target_delta, 800 | eps_error=eps_error, 801 | max_compositions=steps, 802 | ) 803 | eps_low, eps_estimate, eps_upper = accountant.compute_epsilon( 804 | num_compositions=steps) 805 | return eps_upper # Be conservative. 806 | 807 | return _get_sigma_with_target_epsilon( 808 | sigma_hi_init=sigma_hi_init, 809 | sigma_lo_init=sigma_lo_init, 810 | sigma_to_eps=sigma_to_eps, 811 | target_epsilon=target_epsilon, 812 | threshold=threshold, 813 | ) 814 | 815 | 816 | def _get_sigma_with_target_epsilon( 817 | sigma_hi_init: float, 818 | sigma_lo_init: float, 819 | sigma_to_eps: Callable, 820 | target_epsilon: float, 821 | threshold: float, 822 | ) -> float: 823 | """Core logic for binary searching σ given ε and δ.""" 824 | if sigma_lo_init > sigma_hi_init: 825 | raise ValueError("`sigma_lo` should be smaller than `sigma_hi`.") 826 | 827 | # Find an appropriate region for binary search. 828 | sigma_hi = sigma_hi_init 829 | sigma_lo = sigma_lo_init 830 | 831 | # Ensure sigma_hi isn't too small. 832 | while True: 833 | eps = sigma_to_eps(sigma_hi) 834 | if eps < target_epsilon: 835 | break 836 | sigma_hi *= 2 837 | 838 | # Ensure sigma_lo isn't too large. 839 | while True: 840 | eps = sigma_to_eps(sigma_lo) 841 | if eps > target_epsilon: 842 | break 843 | sigma_lo /= 2 844 | 845 | # Binary search. 846 | while sigma_hi - sigma_lo > threshold: 847 | sigma = (sigma_hi + sigma_lo) / 2 848 | eps = sigma_to_eps(sigma) 849 | if eps < target_epsilon: 850 | sigma_hi = sigma 851 | else: 852 | sigma_lo = sigma 853 | 854 | # Conservative estimate. 855 | return sigma_hi 856 | 857 | 858 | def _eps_from_rdp( 859 | sample_rate, 860 | sigma, 861 | steps, 862 | delta, 863 | alphas=DEFAULT_ALPHAS, 864 | **_, 865 | ): 866 | """Get the ε in (ε, δ)-DP from Renyi-DP accounting.""" 867 | # This is based on Poisson sampling in https://arxiv.org/pdf/1908.10530.pdf 868 | rdp = rdp_accounting.compute_rdp( 869 | q=sample_rate, noise_multiplier=sigma, steps=steps, orders=alphas 870 | ) 871 | # (ε, α) 872 | eps, alpha = rdp_accounting.get_privacy_spent( 873 | orders=alphas, rdp=rdp, delta=delta 874 | ) 875 | return eps, alpha 876 | 877 | 878 | def _compute_eps_cks(orders, rdp, delta): 879 | """Compute epsilon given a list of RDP values and target delta. 880 | Args: 881 | orders: An array (or a scalar) of orders. 882 | rdp: A list (or a scalar) of RDP guarantees. 883 | delta: The target delta. 884 | Returns: 885 | Pair of (eps, optimal_order). 886 | Raises: 887 | ValueError: If input is malformed. 
888 | """ 889 | orders_vec = np.atleast_1d(orders) 890 | rdp_vec = np.atleast_1d(rdp) 891 | 892 | if delta <= 0: 893 | raise ValueError("Privacy failure probability bound delta must be >0.") 894 | if len(orders_vec) != len(rdp_vec): 895 | raise ValueError("Input lists must have the same length.") 896 | 897 | # Basic bound (see https://arxiv.org/abs/1702.07476 Proposition 3 in v3): 898 | # eps = min( rdp_vec - math.log(delta) / (orders_vec - 1) ) 899 | 900 | # Improved bound from https://arxiv.org/abs/2004.00010 Proposition 12 (in v4). 901 | # Also appears in https://arxiv.org/abs/2001.05990 Equation 20 (in v1). 902 | eps_vec = [] 903 | for (a, r) in zip(orders_vec, rdp_vec): 904 | if a < 1: 905 | raise ValueError("Renyi divergence order must be >=1.") 906 | if r < 0: 907 | raise ValueError("Renyi divergence must be >=0.") 908 | 909 | if delta ** 2 + math.expm1(-r) >= 0: 910 | # In this case, we can simply bound via KL divergence: 911 | # delta <= sqrt(1-exp(-KL)). 912 | eps = 0 # No need to try further computation if we have eps = 0. 913 | elif a > 1.01: 914 | # This bound is not numerically stable as alpha->1. 915 | # Thus we have a min value of alpha. 916 | # The bound is also not useful for small alpha, so doesn't matter. 917 | eps = r + math.log1p(-1 / a) - math.log(delta * a) / (a - 1) 918 | else: 919 | # In this case we can't do anything. E.g., asking for delta = 0. 920 | eps = np.inf 921 | eps_vec.append(eps) 922 | 923 | idx_opt = np.argmin(eps_vec) 924 | return max(0, eps_vec[idx_opt]), orders_vec[idx_opt] 925 | 926 | 927 | def _eps_from_rdp_cks( 928 | sample_rate, 929 | sigma, 930 | steps, 931 | delta, 932 | alphas=DEFAULT_ALPHAS, 933 | **_, 934 | ): 935 | """Compute RDP as usual, but the conversion to (ε, δ)-DP is based on result by Canonne, Kamath, Steinke. 
936 | 937 | # @formatter:off 938 | Code from https://github.com/tensorflow/privacy/blob/5f07198b66b3617b22609db983926e3ba97cd905/tensorflow_privacy/privacy/analysis/rdp_accountant.py#L237 939 | # @formatter:on 940 | """ 941 | rdp = rdp_accounting.compute_rdp( 942 | q=sample_rate, noise_multiplier=sigma, steps=steps, orders=alphas 943 | ) 944 | # (ε, α) 945 | eps, alpha = _compute_eps_cks(orders=alphas, rdp=rdp, delta=delta) 946 | return eps, alpha 947 | 948 | 949 | def _eps_from_gdp( 950 | sample_rate, 951 | sigma, 952 | steps, 953 | delta, 954 | mode="poisson", 955 | **_, 956 | ): 957 | """Get the ε in (ε, δ)-DP from f-DP accounting.""" 958 | epochs = steps * sample_rate 959 | if mode == "poisson": 960 | eps_fn = gdp_accounting.compute_eps_poisson 961 | mu_fn = gdp_accounting.compute_mu_poisson 962 | else: 963 | eps_fn = gdp_accounting.compute_eps_uniform 964 | mu_fn = gdp_accounting.compute_mu_uniform 965 | 966 | eps = eps_fn( 967 | epochs=epochs, 968 | noise_multi=sigma, 969 | delta=delta, 970 | sample_rate=sample_rate, 971 | ) 972 | mu = mu_fn( 973 | epochs=epochs, 974 | noise_multi=sigma, 975 | sample_rate=sample_rate, 976 | ) 977 | return eps, mu 978 | 979 | 980 | def _eps_from_glw( 981 | sample_rate, 982 | sigma, 983 | steps, 984 | delta, 985 | eps_error=0.05, 986 | **_, 987 | ): 988 | from prv_accountant import Accountant 989 | accountant = Accountant( 990 | noise_multiplier=sigma, 991 | sampling_probability=sample_rate, 992 | delta=delta, 993 | eps_error=eps_error, 994 | max_compositions=steps 995 | ) 996 | eps_low, eps_estimate, eps_upper = accountant.compute_epsilon( 997 | num_compositions=steps) 998 | return dict(eps_low=eps_low, eps_estimate=eps_estimate, eps_upper=eps_upper) 999 | --------------------------------------------------------------------------------
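For orientation, here is a minimal end-to-end usage sketch of the engine above. It is not part of the repository: the model, data, and hyperparameter values are illustrative assumptions, and the constructor keywords follow the argument names documented in `PrivacyEngine.__init__`.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from private_CNN import PrivacyEngine

# A tiny CNN and random data, just to make the sketch self-contained.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3, padding=1), torch.nn.ReLU(),
    torch.nn.AdaptiveAvgPool2d(1), torch.nn.Flatten(),
    torch.nn.Linear(16, 10),
)
images = torch.randn(512, 3, 32, 32)
labels = torch.randint(0, 10, (512,))
train_loader = DataLoader(TensorDataset(images, labels), batch_size=256)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
engine = PrivacyEngine(
    model,
    batch_size=256,            # logical batch size used for averaging
    sample_size=512,           # number of training examples
    epochs=3,
    max_grad_norm=1.0,         # per-example clipping threshold
    target_epsilon=3.0,        # noise_multiplier is then calibrated by the chosen accountant
    ghost_clipping=True,
    mixed=True,                # mixed ghost clipping
    accounting_mode="rdp_cks",
)
engine.attach(optimizer)

for epoch in range(3):
    for x, y in train_loader:
        optimizer.zero_grad()
        per_example_loss = F.cross_entropy(model(x), y, reduction="none")  # 1-D, one entry per example
        optimizer.step(loss=per_example_loss)  # clips, noises, averages, and applies the update

print(optimizer.get_privacy_spent())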