├── CIFAR_tests ├── conv_layers.pt ├── linear.py ├── px_expander.py ├── functional.py ├── adadp.py ├── sgd_train_conv.py ├── gaussian_moments.py └── main_adadp.py ├── MNIST_tests ├── conv_layers.pt ├── linear.py ├── px_expander.py ├── functional.py ├── adadp.py ├── adadp_cpu.py ├── main_adadp.py └── gaussian_moments.py ├── README.md └── LICENSE /CIFAR_tests/conv_layers.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DPBayes/ADADP/HEAD/CIFAR_tests/conv_layers.pt -------------------------------------------------------------------------------- /MNIST_tests/conv_layers.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DPBayes/ADADP/HEAD/MNIST_tests/conv_layers.pt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ADADP 2 | 3 | PyTorch implementation of the differentially private learning rate adaptive SGD algorithm described in 4 | 5 | A. Koskela and A. Honkela. "Learning Rate Adaptation for Differentially Private Learning." In: International Conference on Artificial Intelligence and Statistics. PMLR, 2020. p. 2465-2475. http://proceedings.mlr.press/v108/koskela20a/koskela20a.pdf 6 | 7 | Usage, e.g. 8 | 9 | python3 main_adadp.py --n_epochs=100 --tol=1.0 --noise_sigma=2.0 --batch_size=200 10 | 11 | The code was developed with CUDA Version 10.1.105, PyTorch 1.4.0, torchvision 0.2.2, Python 3.6.9. 12 | 13 | The current version of the CIFAR-10 experiments runs only with Cuda. 14 | 15 | The MNIST experiments run also using CPU. 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 DPBayes 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /CIFAR_tests/linear.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ''' 5 | 6 | Linear module modified for the expander and clipping individual gradients. 
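When a batch_size is given, the weight (and optional bias) carries a leading batch dimension, so every example in a micro-batch holds its own copy of the parameters and therefore its own gradient, which the expander can clip individually.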
7 | 8 | This code is due to Mikko Heikkilä (@mixheikk) 9 | 10 | ''' 11 | 12 | 13 | import math 14 | 15 | import torch 16 | from torch.nn.parameter import Parameter 17 | import functional as F 18 | from torch.nn.modules import Module 19 | 20 | 21 | class Linear(Module): 22 | def __init__(self, in_features, out_features, bias=True, batch_size = None): 23 | super(Linear, self).__init__() 24 | self.in_features = in_features 25 | self.out_features = out_features 26 | self.batch_size = batch_size 27 | if batch_size is not None: 28 | self.weight = Parameter(torch.Tensor(batch_size, out_features, in_features)) 29 | else: 30 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 31 | if bias: 32 | if batch_size is not None: 33 | self.bias = Parameter(torch.Tensor(batch_size, out_features)) 34 | else: 35 | self.bias = Parameter(torch.Tensor(out_features)) 36 | else: 37 | self.register_parameter('bias', None) 38 | self.reset_parameters() 39 | 40 | def reset_parameters(self): 41 | stdv = 1. / math.sqrt(self.weight.size(1)) 42 | self.weight.data.uniform_(-stdv, stdv) 43 | if self.bias is not None: 44 | self.bias.data.uniform_(-stdv, stdv) 45 | 46 | def forward(self, input): 47 | return F.linear(input, self.weight, self.bias) 48 | 49 | def __repr__(self): 50 | return self.__class__.__name__ + '(' \ 51 | + 'in_features=' + str(self.in_features) \ 52 | + ', out_features=' + str(self.out_features) \ 53 | + ', bias=' + str(self.bias is not None) + ')' 54 | -------------------------------------------------------------------------------- /MNIST_tests/linear.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ''' 5 | 6 | Linear module modified for the expander and clipping individual gradients. 7 | 8 | This code is due to Mikko Heikkilä (@mixheikk) 9 | 10 | ''' 11 | 12 | 13 | import math 14 | 15 | import torch 16 | from torch.nn.parameter import Parameter 17 | import functional as F 18 | from torch.nn.modules import Module 19 | 20 | 21 | class Linear(Module): 22 | def __init__(self, in_features, out_features, bias=True, batch_size = None): 23 | super(Linear, self).__init__() 24 | self.in_features = in_features 25 | self.out_features = out_features 26 | self.batch_size = batch_size 27 | if batch_size is not None: 28 | self.weight = Parameter(torch.Tensor(batch_size, out_features, in_features)) 29 | else: 30 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 31 | if bias: 32 | if batch_size is not None: 33 | self.bias = Parameter(torch.Tensor(batch_size, out_features)) 34 | else: 35 | self.bias = Parameter(torch.Tensor(out_features)) 36 | else: 37 | self.register_parameter('bias', None) 38 | self.reset_parameters() 39 | 40 | def reset_parameters(self): 41 | stdv = 1. / math.sqrt(self.weight.size(1)) 42 | self.weight.data.uniform_(-stdv, stdv) 43 | if self.bias is not None: 44 | self.bias.data.uniform_(-stdv, stdv) 45 | 46 | def forward(self, input): 47 | return F.linear(input, self.weight, self.bias) 48 | 49 | def __repr__(self): 50 | return self.__class__.__name__ + '(' \ 51 | + 'in_features=' + str(self.in_features) \ 52 | + ', out_features=' + str(self.out_features) \ 53 | + ', bias=' + str(self.bias is not None) + ')' 54 | -------------------------------------------------------------------------------- /CIFAR_tests/px_expander.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | ''' 6 | Expander code for clipping individual gradients. 
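In short: acc_scaled_grads scales each per-example gradient by 1/max(1, ||g||_2/C) and accumulates the clipped gradients over the micro-batch; add_noise_with_cum_grads then adds Gaussian noise of standard deviation sigma*C to the accumulated sum and divides by the batch size before writing the result back into the parameter gradients.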
7 | 8 | This code is due to Mikko Heikkilä (@mixheikk) 9 | ''' 10 | 11 | 12 | 13 | import torch 14 | from torch.autograd import Variable 15 | 16 | import sys 17 | 18 | 19 | # clip and accumulate clipped gradients 20 | 21 | def acc_scaled_grads(model, C, cum_grads, use_cuda=False): 22 | 23 | batch_size = model.batch_proc_size 24 | 25 | g_norm = Variable(torch.zeros(batch_size),requires_grad=False) 26 | 27 | if use_cuda: 28 | g_norm = g_norm.cuda() 29 | 30 | for p in filter(lambda p: p.requires_grad, model.parameters() ): 31 | if p.grad is not None: 32 | g_norm += torch.sum( p.grad.view(batch_size,-1)**2, 1) 33 | 34 | g_norm = torch.sqrt(g_norm) 35 | 36 | # do clipping and accumulate 37 | for p, key in zip( filter(lambda p: p.requires_grad, model.parameters()), cum_grads.keys() ): 38 | if p is not None: 39 | cum_grads[key] += torch.sum( (p.grad/torch.clamp(g_norm.contiguous().view(-1,1,1)/C, min=1)), dim=0 ) 40 | 41 | 42 | # add noise and replace model grads with cumulative grads 43 | def add_noise_with_cum_grads(model, C, sigma, cum_grads, use_cuda=False): 44 | 45 | batch_proc_size = model.batch_proc_size 46 | for p, key in zip( filter(lambda p: p.requires_grad, model.parameters()), cum_grads.keys() ): 47 | 48 | if p.grad is not None: 49 | 50 | # add noise to summed clipped pars 51 | if use_cuda: 52 | p.grad = ((cum_grads[key].expand(batch_proc_size,-1,-1) + \ 53 | Variable( (sigma*C)*torch.normal(mean=torch.zeros_like(p.grad[0]).data, \ 54 | std=1.0).expand(batch_proc_size,-1,-1) ) )/model.batch_size).cuda() 55 | else: 56 | p.grad = (cum_grads[key].expand(batch_proc_size,-1,-1) + \ 57 | Variable( (sigma*C)*torch.normal(mean=torch.zeros_like(p.grad[0]).data,std=1.0).expand(batch_proc_size,-1,-1) ) )/model.batch_size 58 | -------------------------------------------------------------------------------- /MNIST_tests/px_expander.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | ''' 6 | Expander code for clipping individual gradients. 
7 | 8 | This code is due to Mikko Heikkilä (@mixheikk) 9 | ''' 10 | 11 | 12 | 13 | import torch 14 | from torch.autograd import Variable 15 | 16 | import sys 17 | 18 | 19 | # clip and accumulate clipped gradients 20 | 21 | def acc_scaled_grads(model, C, cum_grads, use_cuda=False): 22 | 23 | batch_size = model.batch_proc_size 24 | 25 | g_norm = Variable(torch.zeros(batch_size),requires_grad=False) 26 | 27 | if use_cuda: 28 | g_norm = g_norm.cuda() 29 | 30 | for p in filter(lambda p: p.requires_grad, model.parameters() ): 31 | if p.grad is not None: 32 | g_norm += torch.sum( p.grad.view(batch_size,-1)**2, 1) 33 | 34 | g_norm = torch.sqrt(g_norm) 35 | 36 | # do clipping and accumulate 37 | for p, key in zip( filter(lambda p: p.requires_grad, model.parameters()), cum_grads.keys() ): 38 | if p is not None: 39 | cum_grads[key] += torch.sum( (p.grad/torch.clamp(g_norm.contiguous().view(-1,1,1)/C, min=1)), dim=0 ) 40 | 41 | 42 | # add noise and replace model grads with cumulative grads 43 | def add_noise_with_cum_grads(model, C, sigma, cum_grads, use_cuda=False): 44 | 45 | batch_proc_size = model.batch_proc_size 46 | for p, key in zip( filter(lambda p: p.requires_grad, model.parameters()), cum_grads.keys() ): 47 | 48 | if p.grad is not None: 49 | 50 | # add noise to summed clipped pars 51 | if use_cuda: 52 | p.grad = ((cum_grads[key].expand(batch_proc_size,-1,-1) + \ 53 | Variable( (sigma*C)*torch.normal(mean=torch.zeros_like(p.grad[0]).data, \ 54 | std=1.0).expand(batch_proc_size,-1,-1) ) )/model.batch_size).cuda() 55 | else: 56 | p.grad = (cum_grads[key].expand(batch_proc_size,-1,-1) + \ 57 | Variable( (sigma*C)*torch.normal(mean=torch.zeros_like(p.grad[0]).data,std=1.0).expand(batch_proc_size,-1,-1) ) )/model.batch_size 58 | -------------------------------------------------------------------------------- /CIFAR_tests/functional.py: -------------------------------------------------------------------------------- 1 | 2 | """Functional interface""" 3 | 4 | import warnings 5 | import math 6 | from operator import mul 7 | from functools import reduce 8 | import sys 9 | 10 | import torch 11 | #from torch._C import _infer_size, _add_docstr 12 | #from . import _functions 13 | from torch.nn import _functions 14 | #from .modules import utils 15 | from torch.nn.modules import utils 16 | #from ._functions.linear import Bilinear 17 | #from torch.nn._functions.linear import Bilinear 18 | #from ._functions.padding import ConstantPadNd 19 | #from torch.nn._functions.padding import ConstantPadNd 20 | #from ._functions import vision 21 | #from torch.nn._functions import vision 22 | #from ._functions.thnn.fold import Col2Im, Im2Col 23 | #from torch.nn._functions.thnn.fold import Col2Im,Im2Col 24 | from torch.autograd import Variable 25 | #from .modules.utils import _single, _pair, _triple 26 | #from torch.nn.modules.utils import _single, _pair, _triple 27 | 28 | 29 | ''' 30 | Linear layer modified for PX gradients 31 | 32 | The code is due to Mikko Heikkilä (@mixheikk) 33 | ''' 34 | 35 | 36 | # Note: bias not checked yet 37 | def linear(input, weight, bias=None, batch_size=None): 38 | """ 39 | Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. 
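With an expanded (per-example) weight of shape (batch_size, out_features, in_features), the transform is computed as a batched matmul against the transposed weight instead of the fused addmm path.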
40 | 41 | Shape: 42 | - Input: :math:`(N, *, in\_features)` where `*` means any number of 43 | additional dimensions 44 | - Weight: :math:`(out\_features, in\_features)` 45 | - Bias: :math:`(out\_features)` 46 | - Output: :math:`(N, *, out\_features)` 47 | """ 48 | if input.dim() == 2 and bias is not None: 49 | # fused op is marginally faster 50 | if batch_size is None: 51 | return torch.addmm(bias, input, weight.t()) 52 | else: 53 | print('fused op in functional.linear not implemented yet!') 54 | sys.exit(1) 55 | return torch.addmm(bias, input, weight.t()) 56 | 57 | output = input.matmul(torch.transpose(weight,-2,-1)) 58 | 59 | # kts bias kun muu toimii 60 | if bias is not None: 61 | output += bias 62 | return output 63 | -------------------------------------------------------------------------------- /MNIST_tests/functional.py: -------------------------------------------------------------------------------- 1 | 2 | """Functional interface""" 3 | 4 | import warnings 5 | import math 6 | from operator import mul 7 | from functools import reduce 8 | import sys 9 | 10 | import torch 11 | #from torch._C import _infer_size, _add_docstr 12 | #from . import _functions 13 | from torch.nn import _functions 14 | #from .modules import utils 15 | from torch.nn.modules import utils 16 | #from ._functions.linear import Bilinear 17 | #from torch.nn._functions.linear import Bilinear 18 | #from ._functions.padding import ConstantPadNd 19 | #from torch.nn._functions.padding import ConstantPadNd 20 | #from ._functions import vision 21 | #from torch.nn._functions import vision 22 | #from ._functions.thnn.fold import Col2Im, Im2Col 23 | #from torch.nn._functions.thnn.fold import Col2Im,Im2Col 24 | from torch.autograd import Variable 25 | #from .modules.utils import _single, _pair, _triple 26 | #from torch.nn.modules.utils import _single, _pair, _triple 27 | 28 | 29 | ''' 30 | Linear layer modified for PX gradients 31 | 32 | The code is due to Mikko Heikkilä (@mixheikk) 33 | ''' 34 | 35 | 36 | # Note: bias not checked yet 37 | def linear(input, weight, bias=None, batch_size=None): 38 | """ 39 | Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. 40 | 41 | Shape: 42 | - Input: :math:`(N, *, in\_features)` where `*` means any number of 43 | additional dimensions 44 | - Weight: :math:`(out\_features, in\_features)` 45 | - Bias: :math:`(out\_features)` 46 | - Output: :math:`(N, *, out\_features)` 47 | """ 48 | if input.dim() == 2 and bias is not None: 49 | # fused op is marginally faster 50 | if batch_size is None: 51 | return torch.addmm(bias, input, weight.t()) 52 | else: 53 | print('fused op in functional.linear not implemented yet!') 54 | sys.exit(1) 55 | return torch.addmm(bias, input, weight.t()) 56 | 57 | output = input.matmul(torch.transpose(weight,-2,-1)) 58 | 59 | # kts bias kun muu toimii 60 | if bias is not None: 61 | output += bias 62 | return output 63 | -------------------------------------------------------------------------------- /CIFAR_tests/adadp.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | ''' 6 | 7 | A code for implementing the ADADP algorithm for neural networks, 8 | described in 9 | 10 | Koskela, A. and Honkela, A., 11 | Learning rate adaptation for differentially private stochastic gradient descent. 12 | arXiv preprint arXiv:1809.03832. 
(2018) 13 | 14 | The code is due to Antti Koskela (@koskeant) 15 | 16 | ''' 17 | 18 | 19 | 20 | 21 | 22 | import torch 23 | from torch.optim.optimizer import Optimizer, required 24 | import numpy as np 25 | 26 | class ADADP(Optimizer): 27 | 28 | def __init__(self, params, lr=1e-3): 29 | 30 | defaults = dict(lr=lr) 31 | 32 | self.p0 = None 33 | self.p1 = None 34 | self.lrs = lr 35 | self.accepted = 0 36 | self.failed = 0 37 | 38 | self.lrs_history = [] 39 | 40 | super(ADADP, self).__init__(params, defaults) 41 | 42 | def step1(self): 43 | 44 | del self.p0 45 | self.p0 = [] 46 | 47 | del self.p1 48 | self.p1 = [] 49 | 50 | for group in self.param_groups: 51 | 52 | for p in group['params']: 53 | if p.grad is None: 54 | continue 55 | 56 | dd = p.data.clone() 57 | self.p0.append(dd) 58 | 59 | self.p1.append(p.data - self.lrs*p.grad.data) 60 | p.data.add_(-0.5*self.lrs, p.grad.data) 61 | 62 | def step2(self, tol=1.0): 63 | 64 | for group in self.param_groups: 65 | 66 | err_e = 0.0 67 | 68 | for ijk,p in enumerate(group['params']): 69 | p.data.add_(-0.5*self.lrs, p.grad.data) 70 | err_e += (((self.p1[ijk] - p.data)**2/(torch.max(torch.ones(self.p1[ijk].size()).cuda(),self.p1[ijk]**2))).norm(1)) 71 | 72 | err_e = np.sqrt(float(err_e)) 73 | 74 | self.lrs = float(self.lrs*min(max(np.sqrt(tol/err_e),0.9), 1.1)) 75 | 76 | ## Accept the step only if err < tol. 77 | ## Can be sometimes neglected (more accepted steps) 78 | if err_e > 1.0*tol: 79 | for ijk,p in enumerate(group['params']): 80 | p.data = self.p0[ijk] 81 | if err_e < tol: 82 | self.accepted += 1 83 | else : 84 | self.failed += 1 85 | 86 | self.lrs_history.append(self.lrs) 87 | -------------------------------------------------------------------------------- /MNIST_tests/adadp.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | ''' 6 | 7 | A code for implementing the ADADP algorithm for neural networks, 8 | described in 9 | 10 | Koskela, A. and Honkela, A., 11 | Learning rate adaptation for differentially private stochastic gradient descent. 12 | arXiv preprint arXiv:1809.03832. (2018) 13 | 14 | The code is due to Antti Koskela (@koskeant) 15 | 16 | ''' 17 | 18 | 19 | 20 | 21 | 22 | import torch 23 | from torch.optim.optimizer import Optimizer, required 24 | import numpy as np 25 | 26 | class ADADP(Optimizer): 27 | 28 | def __init__(self, params, lr=1e-3): 29 | 30 | defaults = dict(lr=lr) 31 | 32 | self.p0 = None 33 | self.p1 = None 34 | self.lrs = lr 35 | self.accepted = 0 36 | self.failed = 0 37 | 38 | self.lrs_history = [] 39 | 40 | super(ADADP, self).__init__(params, defaults) 41 | 42 | def step1(self): 43 | 44 | del self.p0 45 | self.p0 = [] 46 | 47 | del self.p1 48 | self.p1 = [] 49 | 50 | for group in self.param_groups: 51 | 52 | for p in group['params']: 53 | if p.grad is None: 54 | continue 55 | 56 | dd = p.data.clone() 57 | self.p0.append(dd) 58 | 59 | self.p1.append(p.data - self.lrs*p.grad.data) 60 | p.data.add_(-0.5*self.lrs, p.grad.data) 61 | 62 | def step2(self, tol=1.0): 63 | 64 | for group in self.param_groups: 65 | 66 | err_e = 0.0 67 | 68 | for ijk,p in enumerate(group['params']): 69 | p.data.add_(-0.5*self.lrs, p.grad.data) 70 | err_e += (((self.p1[ijk] - p.data)**2/(torch.max(torch.ones(self.p1[ijk].size()).cuda(),self.p1[ijk]**2))).norm(1)) 71 | 72 | err_e = np.sqrt(float(err_e)) 73 | 74 | self.lrs = float(self.lrs*min(max(np.sqrt(tol/err_e),0.9), 1.1)) 75 | 76 | ## Accept the step only if err < tol. 
77 | ## Can be sometimes neglected (more accepted steps) 78 | if err_e > 1.0*tol: 79 | for ijk,p in enumerate(group['params']): 80 | p.data = self.p0[ijk] 81 | if err_e < tol: 82 | self.accepted += 1 83 | else : 84 | self.failed += 1 85 | 86 | self.lrs_history.append(self.lrs) 87 | -------------------------------------------------------------------------------- /MNIST_tests/adadp_cpu.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | ''' 6 | 7 | A code for implementing the ADADP algorithm for neural networks, 8 | described in 9 | 10 | Koskela, A. and Honkela, A., 11 | Learning rate adaptation for differentially private stochastic gradient descent. 12 | arXiv preprint arXiv:1809.03832. (2018) 13 | 14 | The code is due to Antti Koskela (@koskeant) 15 | 16 | ''' 17 | 18 | 19 | 20 | 21 | 22 | import torch 23 | from torch.optim.optimizer import Optimizer, required 24 | import numpy as np 25 | 26 | class ADADP(Optimizer): 27 | 28 | def __init__(self, params, lr=1e-3): 29 | 30 | defaults = dict(lr=lr) 31 | 32 | self.p0 = None 33 | self.p1 = None 34 | self.lrs = lr 35 | self.accepted = 0 36 | self.failed = 0 37 | 38 | self.lrs_history = [] 39 | 40 | super(ADADP, self).__init__(params, defaults) 41 | 42 | def step1(self): 43 | 44 | del self.p0 45 | self.p0 = [] 46 | 47 | del self.p1 48 | self.p1 = [] 49 | 50 | for group in self.param_groups: 51 | 52 | for p in group['params']: 53 | if p.grad is None: 54 | continue 55 | 56 | dd = p.data.clone() 57 | self.p0.append(dd) 58 | 59 | self.p1.append(p.data - self.lrs*p.grad.data) 60 | p.data.add_(-0.5*self.lrs, p.grad.data) 61 | 62 | def step2(self, tol=1.0): 63 | 64 | for group in self.param_groups: 65 | 66 | err_e = 0.0 67 | 68 | for ijk,p in enumerate(group['params']): 69 | p.data.add_(-0.5*self.lrs, p.grad.data) 70 | err_e += (((self.p1[ijk] - p.data)**2/(torch.max(torch.ones(self.p1[ijk].size()),self.p1[ijk]**2))).norm(1)) 71 | 72 | err_e = np.sqrt(float(err_e)) 73 | 74 | self.lrs = float(self.lrs*min(max(np.sqrt(tol/err_e),0.9), 1.1)) 75 | 76 | ## Accept the step only if err < tol. 77 | ## Can be sometimes neglected (more accepted steps) 78 | if err_e > 1.0*tol: 79 | for ijk,p in enumerate(group['params']): 80 | p.data = self.p0[ijk] 81 | if err_e < tol: 82 | self.accepted += 1 83 | else : 84 | self.failed += 1 85 | 86 | self.lrs_history.append(self.lrs) 87 | -------------------------------------------------------------------------------- /CIFAR_tests/sgd_train_conv.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ''' 5 | A code for training the convolutional layers using Cifar-100 dataset. 6 | Related to the Cifar-10 experiments of 7 | 8 | Koskela, A. and Honkela, A., 9 | Learning rate adaptation for differentially private stochastic gradient descent. 10 | arXiv preprint arXiv:1809.03832. 
(2018) 11 | 12 | ''' 13 | 14 | 15 | 16 | 17 | 18 | from __future__ import print_function 19 | import argparse 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | import torch.optim as optim 24 | from torchvision import datasets, transforms 25 | import torchvision 26 | 27 | 28 | import numpy as np 29 | 30 | 31 | # Training settings 32 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 33 | parser.add_argument('--batch-size', type=int, default=100, metavar='N', 34 | help='input batch size for training (default: 64)') 35 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 36 | help='input batch size for testing (default: 1000)') 37 | parser.add_argument('--epochs', type=int, default= 100, metavar='N', 38 | help='number of epochs to train (default: 10)') 39 | parser.add_argument('--lr', type=float, default=0.02, metavar='LR', 40 | help='learning rate (default: 0.01)') 41 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 42 | help='SGD momentum (default: 0.5)') 43 | parser.add_argument('--no-cuda', action='store_true', default=False, 44 | help='disables CUDA training') 45 | parser.add_argument('--seed', type=int, default=1, metavar='S', 46 | help='random seed (default: 1)') 47 | parser.add_argument('--log-interval', type=int, default=100, metavar='N', 48 | help='how many batches to wait before logging training status') 49 | 50 | args = parser.parse_args() 51 | 52 | use_cuda = not args.no_cuda and torch.cuda.is_available() 53 | 54 | torch.manual_seed(args.seed) 55 | 56 | device = torch.device("cuda" if use_cuda else "cpu") 57 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 58 | 59 | print('==> Preparing data..') 60 | transform_train = transforms.Compose([ 61 | transforms.RandomCrop(32, padding=4), 62 | transforms.RandomHorizontalFlip(), 63 | transforms.ToTensor(), 64 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 65 | ]) 66 | 67 | transform_test = transforms.Compose([ 68 | transforms.ToTensor(), 69 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 70 | ]) 71 | 72 | 73 | 74 | trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train) 75 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2) 76 | 77 | testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test) 78 | test_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=True, num_workers=2) 79 | 80 | class Net(nn.Module): 81 | 82 | def __init__(self): 83 | super(Net, self).__init__() 84 | 85 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=0) 86 | self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2) 87 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0) 88 | self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2) 89 | self.fc1 = nn.Linear(1600, 500) 90 | self.fc2 = nn.Linear(500, 100) 91 | 92 | 93 | def forward(self, x): 94 | x = self.pool1(F.relu(self.conv1(x))) 95 | x = self.pool2(F.relu(self.conv2(x))) 96 | x = x.view(x.size(0), -1) 97 | x = F.relu(self.fc1(x)) 98 | x = F.relu(self.fc2(x)) 99 | return x 100 | 101 | 102 | model = Net().to(device) 103 | 104 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.0) 105 | #optimizer = optim.Adam(model.parameters())#, lr=0.01) 106 | 107 | print('Learning rate: ' + str(args.lr)) 108 | #optimizer = 
optim.Adam(model.parameters(), lr=args.lr) 109 | #optimizer = sgd2.SGD(model.parameters()) 110 | 111 | 112 | criterion = nn.CrossEntropyLoss() 113 | #criterion = nn.NLLLoss(size_average=False) 114 | 115 | #criterion = cross_entropy() 116 | 117 | def train(epoch): 118 | 119 | model.train() 120 | 121 | for batch_idx, (data, target) in enumerate(train_loader): 122 | 123 | data = data.to(device) 124 | data = data.cuda() 125 | target = target.cuda() 126 | 127 | 128 | optimizer.zero_grad() 129 | 130 | output = model(data) 131 | loss = criterion(output, target) 132 | loss.backward() 133 | 134 | optimizer.step() 135 | 136 | if batch_idx % args.log_interval == 0: 137 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 138 | epoch, batch_idx * len(data), len(train_loader.dataset), 139 | 100. * batch_idx / len(train_loader), loss.item())) 140 | 141 | def test(): 142 | 143 | model.eval() 144 | test_loss = 0 145 | correct = 0 146 | 147 | with torch.no_grad(): 148 | for data, target in test_loader: 149 | data, target = data.to(device), target.to(device) 150 | output = model(data) 151 | test_loss += criterion(output, target).item() # sum up batch loss 152 | pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability 153 | correct += pred.eq(target.view_as(pred)).sum().item() 154 | 155 | test_loss /= len(test_loader.dataset) 156 | percent = 100. * correct / len(test_loader.dataset) 157 | 158 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 159 | test_loss, correct, len(test_loader.dataset), 160 | 100. * correct / len(test_loader.dataset))) 161 | 162 | return test_loss,percent 163 | 164 | 165 | accs = [] 166 | cc = [] 167 | 168 | for epoch in range(1, args.epochs + 1): 169 | train(epoch) 170 | tl, prcnt = test() 171 | accs.append(tl) 172 | cc.append(prcnt) 173 | 174 | 175 | tb_save = [] 176 | for (ii,p) in enumerate(model.parameters()): 177 | if ii<4: 178 | tb_save.append(p.data) 179 | 180 | torch.save(tb_save,'conv_layers.pt') 181 | -------------------------------------------------------------------------------- /MNIST_tests/main_adadp.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | ''' 11 | 12 | A code for training in a differentially private manner a fully 13 | connected network using ADADP. 14 | Here the method is applied to the MNIST data set. 15 | 16 | The ADADP algorithm is described in 17 | 18 | Koskela, A. and Honkela, A., 19 | Learning rate adaptation for differentially private stochastic gradient descent. 20 | arXiv preprint arXiv:1809.03832. (2018) 21 | 22 | This code is due to Antti Koskela (@koskeant) and is based 23 | on a code by Mikko Heikkilä (@mixheikk). 
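A condensed sketch of how the optimizer is driven in train() below (the DP clipping
and noise-addition steps are omitted here); batches alternate between the two half
steps of ADADP:

    optimizer.zero_grad()
    loss.backward()                      # per-example losses inside the micro-batch loop
    if batch_idx % 2 == 0:
        optimizer.step1()                # first half step; stores p0 and the full-step prediction p1
    else:
        optimizer.step2(tol)             # second half step, error estimate and learning-rate update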
24 | 25 | ''' 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | import copy 38 | import datetime 39 | import numpy as np 40 | import pickle 41 | import sys 42 | import time 43 | import logging 44 | from collections import OrderedDict as od 45 | from matplotlib import pyplot as plt 46 | import argparse 47 | 48 | import torch 49 | import torch.nn.functional as F 50 | from torch import nn 51 | from torch import optim 52 | from torch.autograd import Variable 53 | import torchvision 54 | 55 | from torchvision import datasets, transforms 56 | 57 | import linear 58 | 59 | import adadp 60 | import adadp_cpu 61 | 62 | import gaussian_moments as gm 63 | 64 | import itertools 65 | from types import SimpleNamespace 66 | import px_expander 67 | 68 | 69 | 70 | 71 | 72 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 73 | parser.add_argument('--batch_size', type=int, default=500, metavar='N', 74 | help='input batch size for training') 75 | parser.add_argument('--noise_sigma', type=float, default=2.0, metavar='M', 76 | help='noise_sigma') 77 | parser.add_argument('--n_epochs', type=int, default=10, metavar='N', 78 | help='n_epochs') 79 | parser.add_argument('--run_id', type=int, default=1, metavar='N', 80 | help='run_id') 81 | parser.add_argument('--tol', type=float, default=1.0, metavar='t', 82 | help='tolerance parameter') 83 | 84 | args = parser.parse_args() 85 | 86 | 87 | 88 | print(torch.__version__) 89 | print(torchvision.__version__) 90 | 91 | 92 | randomize_data = True 93 | batch_size = args.batch_size # Note: overwritten by BO if used, last batch is skipped if not full size 94 | batch_proc_size = 10 # needs to divide or => to batch size 95 | 96 | n_hidden_layers = 1 # number of units/layer (same for all) is set in bo parameters 97 | latent_dim = 512 # Note: overwritten by BO if used 98 | output_dim = 10 99 | log_interval = 6000//batch_size # Note: this is absolute interval, actual is this//batch_size 100 | 101 | 102 | 103 | 104 | use_dp = True # dp vs non-dp model 105 | scale_grads = True 106 | grad_norm_max = 10 107 | noise_sigma = args.noise_sigma 108 | delta = 1e-5 109 | 110 | tol = args.tol 111 | 112 | n_epochs = args.n_epochs 113 | l_rate = 0.01 114 | 115 | run_id = args.run_id 116 | 117 | 118 | np.random.seed(17*run_id+3) 119 | 120 | input_dim = (28,28) 121 | 122 | 123 | 124 | 125 | if torch.cuda.is_available() and torch.cuda.device_count() > 0: 126 | print('Using cuda') 127 | torch.cuda.manual_seed(11*run_id+19) 128 | use_cuda = True 129 | else: 130 | use_cuda=False 131 | 132 | data_dir = './data/' 133 | 134 | 135 | 136 | trainset = torchvision.datasets.MNIST('./data', train=True, download=True, 137 | transform=transforms.Compose([ 138 | transforms.ToTensor(), 139 | transforms.Normalize((0.1307,), (0.3081,)) 140 | ])) 141 | 142 | testset = torchvision.datasets.MNIST('./data', train=False, transform=transforms.Compose([ 143 | transforms.ToTensor(), 144 | transforms.Normalize((0.1307,), (0.3081,))])) 145 | 146 | 147 | sampling_ratio = float(batch_size)/len(trainset) 148 | 149 | 150 | 151 | 152 | # moments accountant 153 | def update_privacy_pars(priv_pars): 154 | verify = False 155 | max_lmbd = 32 156 | lmbds = range(1, max_lmbd + 1) 157 | log_moments = [] 158 | for lmbd in lmbds: 159 | log_moment = 0 160 | ''' 161 | print('Here q = ' + str(priv_pars['q'])) 162 | print('Here sigma = ' + str(priv_pars['sigma'])) 163 | print('Here T = ' + str(priv_pars['T'])) 164 | ''' 165 | log_moment += gm.compute_log_moment(priv_pars['q'], priv_pars['sigma'], priv_pars['T'], 
lmbd, verify=verify) 166 | log_moments.append((lmbd, log_moment)) 167 | priv_pars['eps'], _ = gm.get_privacy_spent(log_moments, target_delta=priv_pars['delta']) 168 | return priv_pars 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | class simpleExpandedDNN(nn.Module): 179 | def __init__(self, batch_size, batch_proc_size): 180 | super(simpleExpandedDNN, self).__init__() 181 | #self.lrelu = nn.LeakyReLU() 182 | self.relu = nn.ReLU() 183 | 184 | self.batch_proc_size = batch_proc_size 185 | self.batch_size = batch_size 186 | 187 | self.linears = nn.ModuleList([ linear.Linear(1*input_dim[0]*input_dim[1], latent_dim, bias=False, batch_size=batch_proc_size)]) 188 | if n_hidden_layers > 0: 189 | for k in range(n_hidden_layers): 190 | self.linears.append( linear.Linear(latent_dim, latent_dim,bias=False,batch_size=batch_proc_size) ) 191 | self.final_fc = linear.Linear(self.linears[-1].out_features, output_dim,bias=False, batch_size=batch_proc_size) 192 | self.train_loader = torch.utils.data.DataLoader(trainset, batch_size=self.batch_size, 193 | shuffle=randomize_data, num_workers=4) 194 | self.test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, 195 | shuffle=randomize_data, num_workers=4) 196 | 197 | def forward(self, x): 198 | 199 | x = torch.unsqueeze(x.view(-1, 1*input_dim[0]*input_dim[1]),1) 200 | 201 | for k_linear in self.linears: 202 | x = self.relu(k_linear(x)) 203 | x = self.final_fc(x) 204 | return nn.functional.log_softmax(x.view(-1,output_dim),dim=1) 205 | 206 | 207 | 208 | 209 | model = simpleExpandedDNN(batch_size=batch_size, batch_proc_size=batch_proc_size) 210 | 211 | print('model: {}'.format(model)) 212 | 213 | 214 | for p in model.parameters(): 215 | if p is not None: 216 | p.data.copy_( p[0].data.clone().repeat(batch_proc_size,1,1) ) 217 | 218 | if use_cuda: 219 | model = model.cuda() 220 | 221 | loss_function = nn.NLLLoss(size_average=False) 222 | 223 | 224 | 225 | 226 | #optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=l_rate, momentum=0) 227 | 228 | if use_cuda: 229 | optimizer = adadp.ADADP(model.parameters()) 230 | else: 231 | optimizer = adadp_cpu.ADADP(model.parameters()) 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | def train(epoch, model, T): 242 | 243 | model.train() 244 | ii=0 245 | 246 | 247 | print('run_id: ' +str(run_id)) 248 | 249 | for batch_idx, (data, target) in enumerate(model.train_loader): 250 | if data.shape[0] != batch_size: 251 | print('skipped last batch') 252 | continue 253 | 254 | optimizer.zero_grad() 255 | loss_tot = 0 256 | 257 | data, target = Variable(data, requires_grad=False), Variable(target, requires_grad=False) 258 | if use_cuda: 259 | data, target = data.cuda(), target.cuda() 260 | 261 | 262 | if use_dp and scale_grads: 263 | cum_grads = od() 264 | for i,p in enumerate(model.parameters()): 265 | if p.requires_grad: 266 | if use_cuda: 267 | cum_grads[str(i)] = Variable(torch.zeros(p.shape[1:]),requires_grad=False).cuda() 268 | else: 269 | cum_grads[str(i)] = Variable(torch.zeros(p.shape[1:]),requires_grad=False) 270 | 271 | for i_batch in range(batch_size//batch_proc_size): 272 | 273 | data_proc = data[i_batch*batch_proc_size:(i_batch+1)*batch_proc_size,:] 274 | target_proc = target[i_batch*batch_proc_size:(i_batch+1)*batch_proc_size] 275 | 276 | output = model(data_proc) 277 | 278 | loss = loss_function(output,target_proc) 279 | loss_tot += loss.data 280 | 281 | loss.backward() 282 | 283 | if use_dp and scale_grads: 284 | 
px_expander.acc_scaled_grads(model=model,C=grad_norm_max, cum_grads=cum_grads, use_cuda=use_cuda) 285 | optimizer.zero_grad() 286 | 287 | 288 | if use_dp: 289 | px_expander.add_noise_with_cum_grads(model=model, C=grad_norm_max, sigma=noise_sigma, cum_grads=cum_grads, use_cuda=use_cuda) 290 | 291 | # step1 corresponds to the first part of ADADP (i.e. only one step of size half), 292 | # step2 to the second part (error estimate + step size adaptation) 293 | 294 | if batch_idx%2 == 0: 295 | optimizer.step1() 296 | else: 297 | optimizer.step2(tol) 298 | 299 | #For SGD: 300 | #optimizer.step() 301 | 302 | 303 | T += 1 304 | 305 | if batch_idx % log_interval == 0: 306 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 307 | epoch, batch_idx * len(data), len(model.train_loader.dataset), 308 | 100. * batch_idx / len(model.train_loader), loss_tot.item()/batch_size)) 309 | 310 | 311 | return T 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | def test(model, epoch): 321 | 322 | model.eval() 323 | 324 | test_loss = 0 325 | correct = 0 326 | 327 | for data, target in model.test_loader: 328 | if data.shape[0] != model.batch_size: 329 | print('skipped last batch') 330 | continue 331 | 332 | data, target = Variable(data, requires_grad=False), Variable(target, requires_grad=False) 333 | if use_cuda: 334 | data, target = data.cuda(), target.cuda() 335 | 336 | for i_batch in range(model.batch_size//batch_proc_size): 337 | 338 | data_proc = data[i_batch*batch_proc_size:(i_batch+1)*batch_proc_size,:] 339 | target_proc = target[i_batch*batch_proc_size:(i_batch+1)*batch_proc_size] 340 | if use_cuda: 341 | data_proc = data_proc.cuda() 342 | target_proc = target_proc.cuda() 343 | 344 | output = model(data_proc) 345 | 346 | test_loss += F.nll_loss(output, target_proc, size_average=False).item() 347 | 348 | pred = output.data.max(1, keepdim=True)[1] 349 | 350 | correct += pred.eq(target_proc.data.view_as(pred)).cpu().sum() 351 | 352 | test_loss /= len(model.test_loader.dataset) 353 | 354 | acc = correct.numpy() / len(model.test_loader.dataset) 355 | 356 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 357 | test_loss, correct, len(model.test_loader.dataset), 358 | 100. * acc)) 359 | 360 | return test_loss, acc 361 | 362 | 363 | 364 | 365 | priv_pars = od() 366 | priv_pars['T'], priv_pars['eps'],priv_pars['delta'], priv_pars['sigma'], priv_pars['q'] = 0, 0, delta, noise_sigma, sampling_ratio 367 | 368 | 369 | 370 | 371 | 372 | accs = [] 373 | epsilons = [] 374 | 375 | for epoch in range(1,n_epochs+1): 376 | 377 | loss, acc = test(model, epoch) 378 | 379 | accs.append(acc) 380 | 381 | print('Current privacy pars: {}'.format(priv_pars)) 382 | priv_pars['T'] = train(epoch, model, priv_pars['T']) 383 | 384 | if use_dp and scale_grads and noise_sigma > 0: 385 | update_privacy_pars(priv_pars) 386 | 387 | epsilons.append(priv_pars['eps']) 388 | 389 | # Save the test accuracies 390 | np.save('accs_' +str(run_id) + '_' + str(noise_sigma) + '_' + str(batch_size),accs) 391 | -------------------------------------------------------------------------------- /CIFAR_tests/gaussian_moments.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """A standalone utility for computing the log moments. 17 | 18 | The utility for computing the log moments. It consists of two methods. 19 | compute_log_moment(q, sigma, T, lmbd) computes the log moment with sampling 20 | probability q, noise sigma, order lmbd, and T steps. get_privacy_spent computes 21 | delta (or eps) given log moments and eps (or delta). 22 | 23 | Example use: 24 | 25 | Suppose that we have run an algorithm with parameters, an array of 26 | (q1, sigma1, T1) ... (qk, sigmak, Tk), and we wish to compute eps for a given 27 | delta. The example code would be: 28 | 29 | max_lmbd = 32 30 | lmbds = xrange(1, max_lmbd + 1) 31 | log_moments = [] 32 | for lmbd in lmbds: 33 | log_moment = 0 34 | for q, sigma, T in parameters: 35 | log_moment += compute_log_moment(q, sigma, T, lmbd) 36 | log_moments.append((lmbd, log_moment)) 37 | eps, delta = get_privacy_spent(log_moments, target_delta=delta) 38 | 39 | To verify that the I1 >= I2 (see comments in GaussianMomentsAccountant in 40 | accountant.py for the context), run the same loop above with verify=True 41 | passed to compute_log_moment. 42 | """ 43 | import math 44 | import sys 45 | 46 | import numpy as np 47 | import scipy.integrate as integrate 48 | import scipy.stats 49 | #from sympy.mpmath import mp 50 | import mpmath as mp 51 | 52 | def _to_np_float64(v): 53 | if math.isnan(v) or math.isinf(v): 54 | return np.inf 55 | return np.float64(v) 56 | 57 | 58 | ###################### 59 | # FLOAT64 ARITHMETIC # 60 | ###################### 61 | 62 | 63 | def pdf_gauss(x, sigma, mean=0): 64 | return scipy.stats.norm.pdf(x, loc=mean, scale=sigma) 65 | 66 | 67 | def cropped_ratio(a, b): 68 | if a < 1E-50 and b < 1E-50: 69 | return 1. 
70 | else: 71 | return a / b 72 | 73 | 74 | def integral_inf(fn): 75 | integral, _ = integrate.quad(fn, -np.inf, np.inf) 76 | return integral 77 | 78 | 79 | def integral_bounded(fn, lb, ub): 80 | integral, _ = integrate.quad(fn, lb, ub) 81 | return integral 82 | 83 | 84 | def distributions(sigma, q): 85 | mu0 = lambda y: pdf_gauss(y, sigma=sigma, mean=0.0) 86 | mu1 = lambda y: pdf_gauss(y, sigma=sigma, mean=1.0) 87 | mu = lambda y: (1 - q) * mu0(y) + q * mu1(y) 88 | return mu0, mu1, mu 89 | 90 | 91 | def compute_a(sigma, q, lmbd, verbose=False): 92 | lmbd_int = int(math.ceil(lmbd)) 93 | if lmbd_int == 0: 94 | return 1.0 95 | 96 | a_lambda_first_term_exact = 0 97 | a_lambda_second_term_exact = 0 98 | for i in range(lmbd_int + 1): 99 | coef_i = scipy.special.binom(lmbd_int, i) * (q ** i) 100 | s1, s2 = 0, 0 101 | for j in range(i + 1): 102 | coef_j = scipy.special.binom(i, j) * (-1) ** (i - j) 103 | s1 += coef_j * np.exp((j * j - j) / (2.0 * (sigma ** 2))) 104 | s2 += coef_j * np.exp((j * j + j) / (2.0 * (sigma ** 2))) 105 | a_lambda_first_term_exact += coef_i * s1 106 | a_lambda_second_term_exact += coef_i * s2 107 | 108 | a_lambda_exact = ((1.0 - q) * a_lambda_first_term_exact + 109 | q * a_lambda_second_term_exact) 110 | if verbose: 111 | print("A: by binomial expansion {} = {} + {}".format( 112 | a_lambda_exact, 113 | (1.0 - q) * a_lambda_first_term_exact, 114 | q * a_lambda_second_term_exact)) 115 | return _to_np_float64(a_lambda_exact) 116 | 117 | 118 | def compute_b(sigma, q, lmbd, verbose=False): 119 | mu0, _, mu = distributions(sigma, q) 120 | 121 | b_lambda_fn = lambda z: mu0(z) * np.power(cropped_ratio(mu0(z), mu(z)), lmbd) 122 | b_lambda = integral_inf(b_lambda_fn) 123 | m = sigma ** 2 * (np.log((2. - q) / (1. - q)) + 1. / (2 * sigma ** 2)) 124 | 125 | b_fn = lambda z: (np.power(mu0(z) / mu(z), lmbd) - 126 | np.power(mu(-z) / mu0(z), lmbd)) 127 | if verbose: 128 | print("M =", m) 129 | print("f(-M) = {} f(M) = {}".format(b_fn(-m), b_fn(m))) 130 | assert b_fn(-m) < 0 and b_fn(m) < 0 131 | 132 | b_lambda_int1_fn = lambda z: (mu0(z) * 133 | np.power(cropped_ratio(mu0(z), mu(z)), lmbd)) 134 | b_lambda_int2_fn = lambda z: (mu0(z) * 135 | np.power(cropped_ratio(mu(z), mu0(z)), lmbd)) 136 | b_int1 = integral_bounded(b_lambda_int1_fn, -m, m) 137 | b_int2 = integral_bounded(b_lambda_int2_fn, -m, m) 138 | 139 | a_lambda_m1 = compute_a(sigma, q, lmbd - 1) 140 | b_bound = a_lambda_m1 + b_int1 - b_int2 141 | 142 | if verbose: 143 | print("B: by numerical integration", b_lambda) 144 | print("B must be no more than ", b_bound) 145 | print(b_lambda, b_bound) 146 | return _to_np_float64(b_lambda) 147 | 148 | 149 | ########################### 150 | # MULTIPRECISION ROUTINES # 151 | ########################### 152 | 153 | 154 | def pdf_gauss_mp(x, sigma, mean): 155 | return mp.mpf(1.) 
/ mp.sqrt(mp.mpf("2.") * sigma ** 2 * mp.pi) * mp.exp( 156 | - (x - mean) ** 2 / (mp.mpf("2.") * sigma ** 2)) 157 | 158 | 159 | def integral_inf_mp(fn): 160 | integral, _ = mp.quad(fn, [-mp.inf, mp.inf], error=True) 161 | return integral 162 | 163 | 164 | def integral_bounded_mp(fn, lb, ub): 165 | integral, _ = mp.quad(fn, [lb, ub], error=True) 166 | return integral 167 | 168 | 169 | def distributions_mp(sigma, q): 170 | mu0 = lambda y: pdf_gauss_mp(y, sigma=sigma, mean=mp.mpf(0)) 171 | mu1 = lambda y: pdf_gauss_mp(y, sigma=sigma, mean=mp.mpf(1)) 172 | mu = lambda y: (1 - q) * mu0(y) + q * mu1(y) 173 | return mu0, mu1, mu 174 | 175 | 176 | def compute_a_mp(sigma, q, lmbd, verbose=False): 177 | lmbd_int = int(math.ceil(lmbd)) 178 | if lmbd_int == 0: 179 | return 1.0 180 | 181 | mu0, mu1, mu = distributions_mp(sigma, q) 182 | a_lambda_fn = lambda z: mu(z) * (mu(z) / mu0(z)) ** lmbd_int 183 | a_lambda_first_term_fn = lambda z: mu0(z) * (mu(z) / mu0(z)) ** lmbd_int 184 | a_lambda_second_term_fn = lambda z: mu1(z) * (mu(z) / mu0(z)) ** lmbd_int 185 | 186 | a_lambda = integral_inf_mp(a_lambda_fn) 187 | a_lambda_first_term = integral_inf_mp(a_lambda_first_term_fn) 188 | a_lambda_second_term = integral_inf_mp(a_lambda_second_term_fn) 189 | 190 | if verbose: 191 | print("A: by numerical integration {} = {} + {}".format( 192 | a_lambda, 193 | (1 - q) * a_lambda_first_term, 194 | q * a_lambda_second_term)) 195 | 196 | return _to_np_float64(a_lambda) 197 | 198 | 199 | def compute_b_mp(sigma, q, lmbd, verbose=False): 200 | lmbd_int = int(math.ceil(lmbd)) 201 | if lmbd_int == 0: 202 | return 1.0 203 | 204 | mu0, _, mu = distributions_mp(sigma, q) 205 | 206 | b_lambda_fn = lambda z: mu0(z) * (mu0(z) / mu(z)) ** lmbd_int 207 | b_lambda = integral_inf_mp(b_lambda_fn) 208 | 209 | m = sigma ** 2 * (mp.log((2 - q) / (1 - q)) + 1 / (2 * (sigma ** 2))) 210 | b_fn = lambda z: ((mu0(z) / mu(z)) ** lmbd_int - 211 | (mu(-z) / mu0(z)) ** lmbd_int) 212 | if verbose: 213 | print("M =", m) 214 | print("f(-M) = {} f(M) = {}".format(b_fn(-m), b_fn(m))) 215 | assert b_fn(-m) < 0 and b_fn(m) < 0 216 | 217 | b_lambda_int1_fn = lambda z: mu0(z) * (mu0(z) / mu(z)) ** lmbd_int 218 | b_lambda_int2_fn = lambda z: mu0(z) * (mu(z) / mu0(z)) ** lmbd_int 219 | b_int1 = integral_bounded_mp(b_lambda_int1_fn, -m, m) 220 | b_int2 = integral_bounded_mp(b_lambda_int2_fn, -m, m) 221 | 222 | a_lambda_m1 = compute_a_mp(sigma, q, lmbd - 1) 223 | b_bound = a_lambda_m1 + b_int1 - b_int2 224 | 225 | if verbose: 226 | print("B by numerical integration", b_lambda) 227 | print("B must be no more than ", b_bound) 228 | assert b_lambda < b_bound + 1e-5 229 | return _to_np_float64(b_lambda) 230 | 231 | 232 | def _compute_delta(log_moments, eps): 233 | """Compute delta for given log_moments and eps. 234 | 235 | Args: 236 | log_moments: the log moments of privacy loss, in the form of pairs 237 | of (moment_order, log_moment) 238 | eps: the target epsilon. 239 | Returns: 240 | delta 241 | """ 242 | min_delta = 1.0 243 | for moment_order, log_moment in log_moments: 244 | if moment_order == 0: 245 | continue 246 | if math.isinf(log_moment) or math.isnan(log_moment): 247 | sys.stderr.write("The %d-th order is inf or Nan\n" % moment_order) 248 | continue 249 | if log_moment < moment_order * eps: 250 | min_delta = min(min_delta, 251 | math.exp(log_moment - moment_order * eps)) 252 | return min_delta 253 | 254 | 255 | def _compute_eps(log_moments, delta): 256 | """Compute epsilon for given log_moments and delta. 
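For each (moment_order, log_moment) pair this evaluates
(log_moment - log(delta)) / moment_order and returns the minimum over the valid orders.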
257 | 258 | Args: 259 | log_moments: the log moments of privacy loss, in the form of pairs 260 | of (moment_order, log_moment) 261 | delta: the target delta. 262 | Returns: 263 | epsilon 264 | """ 265 | min_eps = float("inf") 266 | for moment_order, log_moment in log_moments: 267 | if moment_order == 0: 268 | continue 269 | if math.isinf(log_moment) or math.isnan(log_moment): 270 | sys.stderr.write("The %d-th order is inf or Nan\n" % moment_order) 271 | continue 272 | min_eps = min(min_eps, (log_moment - math.log(delta)) / moment_order) 273 | return min_eps 274 | 275 | 276 | def compute_log_moment(q, sigma, steps, lmbd, verify=False, verbose=False): 277 | """Compute the log moment of Gaussian mechanism for given parameters. 278 | 279 | Args: 280 | q: the sampling ratio. 281 | sigma: the noise sigma. 282 | steps: the number of steps. 283 | lmbd: the moment order. 284 | verify: if False, only compute the symbolic version. If True, computes 285 | both symbolic and numerical solutions and verifies the results match. 286 | verbose: if True, print out debug information. 287 | Returns: 288 | the log moment with type np.float64, could be np.inf. 289 | """ 290 | moment = compute_a(sigma, q, lmbd, verbose=verbose) 291 | if verify: 292 | mp.dps = 50 293 | moment_a_mp = compute_a_mp(sigma, q, lmbd, verbose=verbose) 294 | moment_b_mp = compute_b_mp(sigma, q, lmbd, verbose=verbose) 295 | np.testing.assert_allclose(moment, moment_a_mp, rtol=1e-10) 296 | if not np.isinf(moment_a_mp): 297 | # The following test fails for (1, np.inf)! 298 | np.testing.assert_array_less(moment_b_mp, moment_a_mp) 299 | if np.isinf(moment): 300 | return np.inf 301 | else: 302 | return np.log(moment) * steps 303 | 304 | 305 | def get_privacy_spent(log_moments, target_eps=None, target_delta=None): 306 | """Compute delta (or eps) for given eps (or delta) from log moments. 307 | 308 | Args: 309 | log_moments: array of (moment_order, log_moment) pairs. 310 | target_eps: if not None, the epsilon for which we would like to compute 311 | corresponding delta value. 312 | target_delta: if not None, the delta for which we would like to compute 313 | corresponding epsilon value. Exactly one of target_eps and target_delta 314 | is None. 315 | Returns: 316 | eps, delta pair 317 | """ 318 | assert (target_eps is None) ^ (target_delta is None) 319 | assert not ((target_eps is None) and (target_delta is None)) 320 | if target_eps is not None: 321 | return (target_eps, _compute_delta(log_moments, target_eps)) 322 | else: 323 | return (_compute_eps(log_moments, target_delta), target_delta) 324 | -------------------------------------------------------------------------------- /MNIST_tests/gaussian_moments.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """A standalone utility for computing the log moments. 17 | 18 | The utility for computing the log moments. It consists of two methods. 19 | compute_log_moment(q, sigma, T, lmbd) computes the log moment with sampling 20 | probability q, noise sigma, order lmbd, and T steps. get_privacy_spent computes 21 | delta (or eps) given log moments and eps (or delta). 22 | 23 | Example use: 24 | 25 | Suppose that we have run an algorithm with parameters, an array of 26 | (q1, sigma1, T1) ... (qk, sigmak, Tk), and we wish to compute eps for a given 27 | delta. The example code would be: 28 | 29 | max_lmbd = 32 30 | lmbds = xrange(1, max_lmbd + 1) 31 | log_moments = [] 32 | for lmbd in lmbds: 33 | log_moment = 0 34 | for q, sigma, T in parameters: 35 | log_moment += compute_log_moment(q, sigma, T, lmbd) 36 | log_moments.append((lmbd, log_moment)) 37 | eps, delta = get_privacy_spent(log_moments, target_delta=delta) 38 | 39 | To verify that the I1 >= I2 (see comments in GaussianMomentsAccountant in 40 | accountant.py for the context), run the same loop above with verify=True 41 | passed to compute_log_moment. 42 | """ 43 | import math 44 | import sys 45 | 46 | import numpy as np 47 | import scipy.integrate as integrate 48 | import scipy.stats 49 | #from sympy.mpmath import mp 50 | import mpmath as mp 51 | 52 | def _to_np_float64(v): 53 | if math.isnan(v) or math.isinf(v): 54 | return np.inf 55 | return np.float64(v) 56 | 57 | 58 | ###################### 59 | # FLOAT64 ARITHMETIC # 60 | ###################### 61 | 62 | 63 | def pdf_gauss(x, sigma, mean=0): 64 | return scipy.stats.norm.pdf(x, loc=mean, scale=sigma) 65 | 66 | 67 | def cropped_ratio(a, b): 68 | if a < 1E-50 and b < 1E-50: 69 | return 1. 70 | else: 71 | return a / b 72 | 73 | 74 | def integral_inf(fn): 75 | integral, _ = integrate.quad(fn, -np.inf, np.inf) 76 | return integral 77 | 78 | 79 | def integral_bounded(fn, lb, ub): 80 | integral, _ = integrate.quad(fn, lb, ub) 81 | return integral 82 | 83 | 84 | def distributions(sigma, q): 85 | mu0 = lambda y: pdf_gauss(y, sigma=sigma, mean=0.0) 86 | mu1 = lambda y: pdf_gauss(y, sigma=sigma, mean=1.0) 87 | mu = lambda y: (1 - q) * mu0(y) + q * mu1(y) 88 | return mu0, mu1, mu 89 | 90 | 91 | def compute_a(sigma, q, lmbd, verbose=False): 92 | lmbd_int = int(math.ceil(lmbd)) 93 | if lmbd_int == 0: 94 | return 1.0 95 | 96 | a_lambda_first_term_exact = 0 97 | a_lambda_second_term_exact = 0 98 | for i in range(lmbd_int + 1): 99 | coef_i = scipy.special.binom(lmbd_int, i) * (q ** i) 100 | s1, s2 = 0, 0 101 | for j in range(i + 1): 102 | coef_j = scipy.special.binom(i, j) * (-1) ** (i - j) 103 | s1 += coef_j * np.exp((j * j - j) / (2.0 * (sigma ** 2))) 104 | s2 += coef_j * np.exp((j * j + j) / (2.0 * (sigma ** 2))) 105 | a_lambda_first_term_exact += coef_i * s1 106 | a_lambda_second_term_exact += coef_i * s2 107 | 108 | a_lambda_exact = ((1.0 - q) * a_lambda_first_term_exact + 109 | q * a_lambda_second_term_exact) 110 | if verbose: 111 | print("A: by binomial expansion {} = {} + {}".format( 112 | a_lambda_exact, 113 | (1.0 - q) * a_lambda_first_term_exact, 114 | q * a_lambda_second_term_exact)) 115 | return _to_np_float64(a_lambda_exact) 116 | 117 | 118 | def compute_b(sigma, q, lmbd, verbose=False): 119 | mu0, _, mu = distributions(sigma, q) 120 | 121 | b_lambda_fn = lambda z: mu0(z) * np.power(cropped_ratio(mu0(z), mu(z)), lmbd) 122 | b_lambda = integral_inf(b_lambda_fn) 123 | m = sigma ** 2 * (np.log((2. 
- q) / (1. - q)) + 1. / (2 * sigma ** 2)) 124 | 125 | b_fn = lambda z: (np.power(mu0(z) / mu(z), lmbd) - 126 | np.power(mu(-z) / mu0(z), lmbd)) 127 | if verbose: 128 | print("M =", m) 129 | print("f(-M) = {} f(M) = {}".format(b_fn(-m), b_fn(m))) 130 | assert b_fn(-m) < 0 and b_fn(m) < 0 131 | 132 | b_lambda_int1_fn = lambda z: (mu0(z) * 133 | np.power(cropped_ratio(mu0(z), mu(z)), lmbd)) 134 | b_lambda_int2_fn = lambda z: (mu0(z) * 135 | np.power(cropped_ratio(mu(z), mu0(z)), lmbd)) 136 | b_int1 = integral_bounded(b_lambda_int1_fn, -m, m) 137 | b_int2 = integral_bounded(b_lambda_int2_fn, -m, m) 138 | 139 | a_lambda_m1 = compute_a(sigma, q, lmbd - 1) 140 | b_bound = a_lambda_m1 + b_int1 - b_int2 141 | 142 | if verbose: 143 | print("B: by numerical integration", b_lambda) 144 | print("B must be no more than ", b_bound) 145 | print(b_lambda, b_bound) 146 | return _to_np_float64(b_lambda) 147 | 148 | 149 | ########################### 150 | # MULTIPRECISION ROUTINES # 151 | ########################### 152 | 153 | 154 | def pdf_gauss_mp(x, sigma, mean): 155 | return mp.mpf(1.) / mp.sqrt(mp.mpf("2.") * sigma ** 2 * mp.pi) * mp.exp( 156 | - (x - mean) ** 2 / (mp.mpf("2.") * sigma ** 2)) 157 | 158 | 159 | def integral_inf_mp(fn): 160 | integral, _ = mp.quad(fn, [-mp.inf, mp.inf], error=True) 161 | return integral 162 | 163 | 164 | def integral_bounded_mp(fn, lb, ub): 165 | integral, _ = mp.quad(fn, [lb, ub], error=True) 166 | return integral 167 | 168 | 169 | def distributions_mp(sigma, q): 170 | mu0 = lambda y: pdf_gauss_mp(y, sigma=sigma, mean=mp.mpf(0)) 171 | mu1 = lambda y: pdf_gauss_mp(y, sigma=sigma, mean=mp.mpf(1)) 172 | mu = lambda y: (1 - q) * mu0(y) + q * mu1(y) 173 | return mu0, mu1, mu 174 | 175 | 176 | def compute_a_mp(sigma, q, lmbd, verbose=False): 177 | lmbd_int = int(math.ceil(lmbd)) 178 | if lmbd_int == 0: 179 | return 1.0 180 | 181 | mu0, mu1, mu = distributions_mp(sigma, q) 182 | a_lambda_fn = lambda z: mu(z) * (mu(z) / mu0(z)) ** lmbd_int 183 | a_lambda_first_term_fn = lambda z: mu0(z) * (mu(z) / mu0(z)) ** lmbd_int 184 | a_lambda_second_term_fn = lambda z: mu1(z) * (mu(z) / mu0(z)) ** lmbd_int 185 | 186 | a_lambda = integral_inf_mp(a_lambda_fn) 187 | a_lambda_first_term = integral_inf_mp(a_lambda_first_term_fn) 188 | a_lambda_second_term = integral_inf_mp(a_lambda_second_term_fn) 189 | 190 | if verbose: 191 | print("A: by numerical integration {} = {} + {}".format( 192 | a_lambda, 193 | (1 - q) * a_lambda_first_term, 194 | q * a_lambda_second_term)) 195 | 196 | return _to_np_float64(a_lambda) 197 | 198 | 199 | def compute_b_mp(sigma, q, lmbd, verbose=False): 200 | lmbd_int = int(math.ceil(lmbd)) 201 | if lmbd_int == 0: 202 | return 1.0 203 | 204 | mu0, _, mu = distributions_mp(sigma, q) 205 | 206 | b_lambda_fn = lambda z: mu0(z) * (mu0(z) / mu(z)) ** lmbd_int 207 | b_lambda = integral_inf_mp(b_lambda_fn) 208 | 209 | m = sigma ** 2 * (mp.log((2 - q) / (1 - q)) + 1 / (2 * (sigma ** 2))) 210 | b_fn = lambda z: ((mu0(z) / mu(z)) ** lmbd_int - 211 | (mu(-z) / mu0(z)) ** lmbd_int) 212 | if verbose: 213 | print("M =", m) 214 | print("f(-M) = {} f(M) = {}".format(b_fn(-m), b_fn(m))) 215 | assert b_fn(-m) < 0 and b_fn(m) < 0 216 | 217 | b_lambda_int1_fn = lambda z: mu0(z) * (mu0(z) / mu(z)) ** lmbd_int 218 | b_lambda_int2_fn = lambda z: mu0(z) * (mu(z) / mu0(z)) ** lmbd_int 219 | b_int1 = integral_bounded_mp(b_lambda_int1_fn, -m, m) 220 | b_int2 = integral_bounded_mp(b_lambda_int2_fn, -m, m) 221 | 222 | a_lambda_m1 = compute_a_mp(sigma, q, lmbd - 1) 223 | b_bound = a_lambda_m1 + b_int1 
- b_int2 224 | 225 | if verbose: 226 | print("B by numerical integration", b_lambda) 227 | print("B must be no more than ", b_bound) 228 | assert b_lambda < b_bound + 1e-5 229 | return _to_np_float64(b_lambda) 230 | 231 | 232 | def _compute_delta(log_moments, eps): 233 | """Compute delta for given log_moments and eps. 234 | 235 | Args: 236 | log_moments: the log moments of privacy loss, in the form of pairs 237 | of (moment_order, log_moment) 238 | eps: the target epsilon. 239 | Returns: 240 | delta 241 | """ 242 | min_delta = 1.0 243 | for moment_order, log_moment in log_moments: 244 | if moment_order == 0: 245 | continue 246 | if math.isinf(log_moment) or math.isnan(log_moment): 247 | sys.stderr.write("The %d-th order is inf or Nan\n" % moment_order) 248 | continue 249 | if log_moment < moment_order * eps: 250 | min_delta = min(min_delta, 251 | math.exp(log_moment - moment_order * eps)) 252 | return min_delta 253 | 254 | 255 | def _compute_eps(log_moments, delta): 256 | """Compute epsilon for given log_moments and delta. 257 | 258 | Args: 259 | log_moments: the log moments of privacy loss, in the form of pairs 260 | of (moment_order, log_moment) 261 | delta: the target delta. 262 | Returns: 263 | epsilon 264 | """ 265 | min_eps = float("inf") 266 | for moment_order, log_moment in log_moments: 267 | if moment_order == 0: 268 | continue 269 | if math.isinf(log_moment) or math.isnan(log_moment): 270 | sys.stderr.write("The %d-th order is inf or Nan\n" % moment_order) 271 | continue 272 | min_eps = min(min_eps, (log_moment - math.log(delta)) / moment_order) 273 | return min_eps 274 | 275 | 276 | def compute_log_moment(q, sigma, steps, lmbd, verify=False, verbose=False): 277 | """Compute the log moment of Gaussian mechanism for given parameters. 278 | 279 | Args: 280 | q: the sampling ratio. 281 | sigma: the noise sigma. 282 | steps: the number of steps. 283 | lmbd: the moment order. 284 | verify: if False, only compute the symbolic version. If True, computes 285 | both symbolic and numerical solutions and verifies the results match. 286 | verbose: if True, print out debug information. 287 | Returns: 288 | the log moment with type np.float64, could be np.inf. 289 | """ 290 | moment = compute_a(sigma, q, lmbd, verbose=verbose) 291 | if verify: 292 | mp.dps = 50 293 | moment_a_mp = compute_a_mp(sigma, q, lmbd, verbose=verbose) 294 | moment_b_mp = compute_b_mp(sigma, q, lmbd, verbose=verbose) 295 | np.testing.assert_allclose(moment, moment_a_mp, rtol=1e-10) 296 | if not np.isinf(moment_a_mp): 297 | # The following test fails for (1, np.inf)! 298 | np.testing.assert_array_less(moment_b_mp, moment_a_mp) 299 | if np.isinf(moment): 300 | return np.inf 301 | else: 302 | return np.log(moment) * steps 303 | 304 | 305 | def get_privacy_spent(log_moments, target_eps=None, target_delta=None): 306 | """Compute delta (or eps) for given eps (or delta) from log moments. 307 | 308 | Args: 309 | log_moments: array of (moment_order, log_moment) pairs. 310 | target_eps: if not None, the epsilon for which we would like to compute 311 | corresponding delta value. 312 | target_delta: if not None, the delta for which we would like to compute 313 | corresponding epsilon value. Exactly one of target_eps and target_delta 314 | is None. 
315 | Returns: 316 | eps, delta pair 317 | """ 318 | assert (target_eps is None) ^ (target_delta is None) 319 | assert not ((target_eps is None) and (target_delta is None)) 320 | if target_eps is not None: 321 | return (target_eps, _compute_delta(log_moments, target_eps)) 322 | else: 323 | return (_compute_eps(log_moments, target_delta), target_delta) 324 | -------------------------------------------------------------------------------- /CIFAR_tests/main_adadp.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | ''' 6 | 7 | Code for training, in a differentially private manner, the fully connected 8 | layers of a simple convolutional network using ADADP. 9 | Here the method is applied to the CIFAR-10 data set. 10 | The parameters of the convolutional layers are loaded from the file "conv_layers.pt". 11 | 12 | The ADADP algorithm is described in 13 | 14 | Koskela, A. and Honkela, A., 15 | Learning rate adaptation for differentially private stochastic gradient descent. 16 | arXiv preprint arXiv:1809.03832. (2018) 17 | 18 | This code is due to Antti Koskela (@koskeant) and is based 19 | on code by Mikko Heikkilä (@mixheikk). 20 | 21 | ''' 22 | 23 | 24 | 25 | 26 | 27 | 28 | import copy 29 | import datetime 30 | import numpy as np 31 | import pickle 32 | import sys 33 | import time 34 | import logging 35 | from collections import OrderedDict as od 36 | from matplotlib import pyplot as plt 37 | import argparse 38 | 39 | import torch 40 | from torch import nn 41 | from torch import optim 42 | from torch.autograd import Variable 43 | import torchvision 44 | import torch.nn.functional as F  # needed for F.relu in Net1 and F.nll_loss in test() 45 | from torchvision import datasets, transforms 46 | 47 | import linear 48 | 49 | import adadp 50 | 51 | import gaussian_moments as gm 52 | 53 | import itertools 54 | from types import SimpleNamespace 55 | import px_expander 56 | 57 | 58 | 59 | print(torch.__version__) 60 | print(torchvision.__version__) 61 | 62 | parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 Example') 63 | parser.add_argument('--batch_size', type=int, default=200, metavar='N', 64 | help='input batch size for training') 65 | parser.add_argument('--noise_sigma', type=float, default=8.0, metavar='M', 66 | help='noise_sigma') 67 | parser.add_argument('--n_epochs', type=int, default=100, metavar='N', 68 | help='n_epochs') 69 | parser.add_argument('--run_id', type=int, default=1, metavar='N', 70 | help='run_id') 71 | parser.add_argument('--tol', type=float, default=1.0, metavar='t', 72 | help='tolerance parameter') 73 | 74 | 75 | 76 | 77 | 78 | args = parser.parse_args() 79 | 80 | 81 | randomize_data = True 82 | batch_size = args.batch_size 83 | batch_proc_size = 10 # needs to divide the batch size (or be equal to it) 84 | 85 | n_hidden_layers = 1 # number of hidden layers in the feedforward network 86 | latent_dim = 500 # width of the hidden layers 87 | output_dim = 10 88 | log_interval = 6000//batch_size 89 | 90 | use_dp = True 91 | grad_norm_max = 3 92 | noise_sigma = args.noise_sigma 93 | delta = 1e-5 94 | 95 | tol = args.tol 96 | 97 | n_epochs = args.n_epochs 98 | l_rate = 0.01 99 | 100 | run_id = args.run_id 101 | 102 | 103 | np.random.seed(17*run_id+3) 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | if torch.cuda.is_available() and torch.cuda.device_count() > 0: 113 | print('Using cuda') 114 | torch.cuda.manual_seed(11*run_id+19) 115 | use_cuda = True 116 | else: 117 | use_cuda=False 118 | 119 | data_dir = './data/' 120 | 121 | 122 | 123 | 124 | 125 | 126 | transform = torchvision.transforms.Compose([]) 127 | 128 | 129 | 
transform_train = transforms.Compose([ 130 | transforms.RandomCrop(32, padding=4), 131 | transforms.RandomHorizontalFlip(), 132 | transforms.ToTensor(), 133 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 134 | ]) 135 | 136 | transform_test = transforms.Compose([ 137 | transforms.ToTensor(), 138 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 139 | ]) 140 | 141 | 142 | trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, 143 | download=True, transform=transform_train) 144 | 145 | testset = torchvision.datasets.CIFAR10(root=data_dir, train=False, 146 | download=True, transform=transform_test) 147 | 148 | sampling_ratio = batch_size/len(trainset) 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | # moments accountant 160 | 161 | def update_privacy_pars(priv_pars): 162 | verify = False 163 | max_lmbd = 32 164 | lmbds = range(1, max_lmbd + 1) 165 | log_moments = [] 166 | for lmbd in lmbds: 167 | log_moment = 0 168 | log_moment += gm.compute_log_moment(priv_pars['q'], priv_pars['sigma'], priv_pars['T'], lmbd, verify=verify) 169 | log_moments.append((lmbd, log_moment)) 170 | priv_pars['eps'], _ = gm.get_privacy_spent(log_moments, target_delta=priv_pars['delta']) 171 | return priv_pars 172 | 173 | 174 | 175 | 176 | 177 | 178 | # The convolutional part of the network 179 | 180 | 181 | class Net1(nn.Module): 182 | def __init__(self): 183 | super(Net1, self).__init__() 184 | 185 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=0) 186 | self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2) 187 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0) 188 | self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2) 189 | 190 | def forward(self, x): 191 | x = self.pool1(F.relu(self.conv1(x))) 192 | x = self.pool2(F.relu(self.conv2(x))) 193 | return x 194 | 195 | if use_cuda: 196 | model1 = Net1().cuda() 197 | else: 198 | model1 = Net1() 199 | 200 | 201 | 202 | # Load the pre-trained convolutive layers 203 | 204 | tb_save = torch.load('conv_layers.pt') 205 | 206 | for ii,p in enumerate(model1.parameters()): 207 | if(ii<4): 208 | p.data = tb_save[ii].clone() 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | # The fully connected part of the network 218 | 219 | class Net2(nn.Module): 220 | def __init__(self, batch_size, batch_proc_size): 221 | super(Net2, self).__init__() 222 | self.relu = nn.ReLU() 223 | 224 | self.batch_proc_size = batch_proc_size 225 | self.batch_size = batch_size 226 | 227 | self.linears = nn.ModuleList([ linear.Linear(1600, latent_dim, bias=False, batch_size=batch_proc_size)]) 228 | if n_hidden_layers > 0: 229 | for k in range(n_hidden_layers): 230 | self.linears.append( linear.Linear(latent_dim, latent_dim,bias=False,batch_size=batch_proc_size) ) 231 | self.final_fc = linear.Linear(self.linears[-1].out_features, output_dim,bias=False, batch_size=batch_proc_size) 232 | self.train_loader = torch.utils.data.DataLoader(trainset, batch_size=self.batch_size, 233 | shuffle=randomize_data, num_workers=4) 234 | self.test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, 235 | shuffle=randomize_data, num_workers=4) 236 | 237 | def forward(self, x): 238 | x = torch.unsqueeze(x.view(-1, 1600),1) 239 | for k_linear in self.linears: 240 | x = self.relu(k_linear(x)) 241 | x = self.final_fc(x) 242 | return nn.functional.log_softmax(x.view(-1,output_dim),dim=1) 243 | 244 | 245 | 246 | model2 = Net2(batch_size=batch_size, batch_proc_size=batch_proc_size) 247 | 248 | for p in 
model2.parameters(): 249 | if p is not None: 250 | p.data.copy_( p[0].data.clone().repeat(batch_proc_size,1,1) ) 251 | 252 | if use_cuda: 253 | model1 = model1.cuda() 254 | model2 = model2.cuda() 255 | 256 | 257 | loss_function = nn.NLLLoss(size_average=False) 258 | 259 | 260 | 261 | 262 | #optimizer = optim.SGD(filter(lambda p: p.requires_grad, model2.parameters()), lr=l_rate, momentum=0) 263 | optimizer = adadp.ADADP(model2.parameters()) 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | def train(epoch, model1, model2, T): 276 | 277 | model1.train() 278 | model2.train() 279 | 280 | for batch_idx, (data, target) in enumerate(model2.train_loader): 281 | 282 | if data.shape[0] != batch_size: 283 | continue 284 | 285 | optimizer.zero_grad() 286 | loss_tot = 0 287 | 288 | data, target = Variable(data, requires_grad=False), Variable(target, requires_grad=False) 289 | if use_cuda: 290 | data, target = data.cuda(), target.cuda() 291 | # accumulators for the clipped, summed per-example gradients 292 | cum_grads = od() 293 | for i,p in enumerate(model2.parameters()): 294 | if p.requires_grad: 295 | if use_cuda: 296 | cum_grads[str(i)] = Variable(torch.zeros(p.shape[1:]),requires_grad=False).cuda() 297 | else: 298 | cum_grads[str(i)] = Variable(torch.zeros(p.shape[1:]),requires_grad=False) 299 | # process the minibatch in chunks of batch_proc_size examples 300 | for i_batch in range(batch_size//batch_proc_size): 301 | 302 | data_proc = data[i_batch*batch_proc_size:(i_batch+1)*batch_proc_size,:] 303 | target_proc = target[i_batch*batch_proc_size:(i_batch+1)*batch_proc_size] 304 | 305 | output1 = model1(data_proc) 306 | output2 = model2(output1) 307 | 308 | loss = loss_function(output2,target_proc) 309 | loss_tot += loss.data 310 | 311 | loss.backward() 312 | 313 | if use_dp: 314 | px_expander.acc_scaled_grads(model=model2,C=grad_norm_max, cum_grads=cum_grads, use_cuda=use_cuda) 315 | optimizer.zero_grad() 316 | 317 | if use_dp: 318 | px_expander.add_noise_with_cum_grads(model=model2, C=grad_norm_max, sigma=noise_sigma, cum_grads=cum_grads, use_cuda=use_cuda) 319 | 320 | 321 | # step1 corresponds to the first part of ADADP (a single step of half size), 322 | # step2 to the second part (error estimate and step size adaptation) 323 | 324 | if batch_idx % 2 == 0: 325 | optimizer.step1() 326 | else: 327 | optimizer.step2(tol) 328 | 329 | #optimizer.step() 330 | 331 | T += 1 332 | 333 | if batch_idx % log_interval == 0: 334 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 335 | epoch, batch_idx * len(data), len(model2.train_loader.dataset), 336 | 100. 
* batch_idx / len(model2.train_loader), loss_tot.item()/batch_size)) 337 | 338 | return T 339 | 340 | 341 | 342 | 343 | 344 | 345 | def test(model1, model2, epoch): 346 | 347 | model1.eval() 348 | model2.eval() 349 | 350 | test_loss = 0 351 | correct = 0 352 | 353 | for data, target in model2.test_loader: 354 | if data.shape[0] != model2.batch_size: 355 | print('skipped last batch') 356 | continue 357 | 358 | data, target = Variable(data, requires_grad=False), Variable(target, requires_grad=False) 359 | if use_cuda: 360 | data, target = data.cuda(), target.cuda() 361 | 362 | for i_batch in range(model2.batch_size//batch_proc_size): 363 | 364 | data_proc = data[i_batch*batch_proc_size:(i_batch+1)*batch_proc_size,:] 365 | target_proc = target[i_batch*batch_proc_size:(i_batch+1)*batch_proc_size] 366 | 367 | if use_cuda: 368 | data_proc = data_proc.cuda() 369 | target_proc = target_proc.cuda() 370 | 371 | output1 = model1(data_proc) 372 | output2 = model2(output1) 373 | 374 | test_loss += F.nll_loss(output2, target_proc, size_average=False).item() 375 | 376 | pred = output2.data.max(1, keepdim=True)[1] 377 | correct += pred.eq(target_proc.data.view_as(pred)).cpu().sum() 378 | 379 | test_loss /= len(model2.test_loader.dataset) 380 | acc = correct.numpy() / len(model2.test_loader.dataset) 381 | 382 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( 383 | test_loss, correct, len(model2.test_loader.dataset),100. * acc)) 384 | 385 | return test_loss, acc 386 | 387 | 388 | 389 | 390 | 391 | 392 | 393 | 394 | priv_pars = od() 395 | priv_pars['T'], priv_pars['eps'],priv_pars['delta'], priv_pars['sigma'], priv_pars['q'] = 0, 0, delta, noise_sigma, sampling_ratio 396 | 397 | accs = [] 398 | epsilons = [] 399 | 400 | for epoch in range(1,n_epochs+1): 401 | 402 | loss, acc = test(model1, model2, epoch) 403 | accs.append(acc) 404 | print('Current privacy pars: {}'.format(priv_pars)) 405 | priv_pars['T'] = train(epoch, model1, model2, priv_pars['T']) 406 | 407 | if noise_sigma>0: 408 | 409 | update_privacy_pars(priv_pars) 410 | epsilons.append(priv_pars['eps']) 411 | 412 | 413 | 414 | # Save the test accuracies 415 | np.save('accs_' +str(run_id) + '_' + str(noise_sigma) + '_' + str(batch_size),accs) 416 | --------------------------------------------------------------------------------