├── Data └── __init__.py ├── images ├── __init__.py ├── Examples.png ├── Results.png └── GRNN.Details.png ├── Generator ├── __init__.py └── model.py ├── TFLogger ├── __init__.py └── logger.py ├── WassersteinDistance ├── __init__.py └── wd.py ├── Backbone ├── __init__.py ├── lenet.py ├── densenet.py └── resnet.py ├── LICENSE ├── README.md ├── utils.py └── GRNN.py /Data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /images/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Generator/__init__.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | -------------------------------------------------------------------------------- /TFLogger/__init__.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | -------------------------------------------------------------------------------- /WassersteinDistance/__init__.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | -------------------------------------------------------------------------------- /images/Examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rand2AI/GRNN/HEAD/images/Examples.png -------------------------------------------------------------------------------- /images/Results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rand2AI/GRNN/HEAD/images/Results.png -------------------------------------------------------------------------------- /images/GRNN.Details.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rand2AI/GRNN/HEAD/images/GRNN.Details.png -------------------------------------------------------------------------------- /Backbone/__init__.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | from Backbone.lenet import LeNet 3 | from Backbone.resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 4 | from Backbone.densenet import DenseNet121, DenseNet169, DenseNet201, DenseNet161 5 | -------------------------------------------------------------------------------- /WassersteinDistance/wd.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | import torch 3 | 4 | def rand_projections(dim, num_projections=1000): 5 | projections = torch.randn((num_projections, dim)) 6 | projections /= torch.sqrt(torch.sum(projections ** 2, dim=1, keepdim=True)) 7 | return projections 8 | 9 | def wasserstein_distance(first_samples, 10 | second_samples, 11 | p=2, 12 | device='cuda'): 13 | wasserstein_distance = torch.abs(first_samples[0] - second_samples[0]) 14 | wasserstein_distance = torch.pow(torch.sum(torch.pow(wasserstein_distance, p)), 1. / p) 15 | return torch.pow(torch.pow(wasserstein_distance, p).mean(), 1. 
/ p) 16 | -------------------------------------------------------------------------------- /Backbone/lenet.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | import torch.nn as nn 3 | import torch.nn.functional as func 4 | 5 | 6 | class LeNet(nn.Module): 7 | def __init__(self, num_classes): 8 | super(LeNet, self).__init__() 9 | self.conv1 = nn.Conv2d(3, 6, kernel_size=5) 10 | self.conv2 = nn.Conv2d(6, 16, kernel_size=5) 11 | self.fc1 = nn.Linear(16*5*5, 120) 12 | self.fc2 = nn.Linear(120, 84) 13 | self.fc3 = nn.Linear(84, num_classes) 14 | 15 | def forward(self, x): 16 | x = func.sigmoid(self.conv1(x)) 17 | x = func.max_pool2d(x, 2) 18 | x = func.sigmoid(self.conv2(x)) 19 | x = func.max_pool2d(x, 2) 20 | x = x.view(x.size(0), -1) 21 | x = func.sigmoid(self.fc1(x)) 22 | x = func.sigmoid(self.fc2(x)) 23 | x = self.fc3(x) 24 | return x 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Rand2AI @ Vision and Machine Learning Lab, Swansea University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Generator/model.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | 6 | class GLU(nn.Module): 7 | def __init__(self): 8 | super(GLU, self).__init__() 9 | 10 | def forward(self, x): 11 | nc = x.size(1) 12 | assert nc % 2 == 0, 'channels dont divide 2!' 13 | nc = int(nc / 2) 14 | return x[:, :nc] * torch.sigmoid(x[:, nc:]) 15 | 16 | class Generator(nn.Module): 17 | # initializers 18 | def __init__(self, num_classes, shape_img, batchsize,channel=3, g_in=128, d = 32): 19 | super(Generator, self).__init__() 20 | self.g_in = g_in 21 | self.batchsize = batchsize 22 | 23 | self.fc2 = nn.Sequential( 24 | nn.Linear(g_in, num_classes) 25 | ) 26 | block_num = int(np.log2(shape_img) - 3) 27 | self.block0 = nn.Sequential( 28 | nn.ConvTranspose2d(g_in, d*pow(2,block_num) * 2, 4, 1, 0), 29 | GLU() 30 | ) 31 | self.blocks = nn.ModuleList() 32 | for bn in reversed(range(block_num)): 33 | self.blocks.append(self.upBlock(pow(2, bn + 1) * d, pow(2, bn) * d)) 34 | self.deconv_out = self.upBlock(d, channel) 35 | 36 | @staticmethod 37 | def upBlock(in_planes, out_planes): 38 | def conv3x3(in_planes, out_planes): 39 | "3x3 convolution with padding" 40 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 41 | padding=1, bias=False) 42 | 43 | block = nn.Sequential( 44 | 
nn.Upsample(scale_factor=2, mode='nearest'), 45 | conv3x3(in_planes, out_planes*2), 46 | nn.BatchNorm2d(out_planes*2), 47 | GLU() 48 | ) 49 | return block 50 | 51 | # weight_init 52 | def weight_init(self, mean, std): 53 | for m in self._modules: 54 | normal_init(self._modules[m], mean, std) 55 | 56 | # forward method 57 | def forward(self, x): 58 | y = torch.softmax(self.fc2(x), 1) 59 | x = x.view(self.batchsize, self.g_in, 1, 1) 60 | output = self.block0(x) 61 | # output = output.view(output.size(0), -1, 4, 4) 62 | for block in self.blocks: 63 | output = block(output) 64 | output = self.deconv_out(output) 65 | output = torch.sigmoid(output) 66 | return output, y 67 | 68 | def normal_init(m, mean, std): 69 | if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 70 | m.weight.data.normal_(mean, std) 71 | m.bias.data.zero_() 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GRNN: Generative Regression Neural Network - A Data Leakage Attack for Federated Learning 2 | 3 | ## Introduction 4 | 5 | This is the implementation of the paper "GRNN: Generative Regression Neural Network - A Data Leakage Attack for Federated Learning". In this paper, we show that, in Federated Learning (FL) system, image-based privacy data can be easily recovered in full from the shared gradient only via our proposed Generative Regression Neural Network (GRNN). We formulate the attack to be a regression problem and optimise two branches of the generative model by minimising the distance between gradients. We evaluate our method on several image classification tasks. The results illustrate that our proposed GRNN outperforms state-of-the-art methods with better stability, stronger robustness, and higher accuracy. It also has no convergence requirement to the global FL model. 6 | 7 |
8 | 9 | ## Requirements 10 | 11 | python==3.6.9 12 | 13 | torch==1.4.0 14 | 15 | torchvision==0.5.0 16 | 17 | numpy==1.18.2 18 | 19 | tqdm==4.45.0 20 | 21 | ... 22 | 23 | ## Examples 24 | 25 |
26 | 27 | ## Performance 28 | 29 |
30 | 31 | ## How to use 32 | 33 | ### Prepare your data: 34 | 35 | * Download the LFW, VGGFace or ILSVRC datasets and extract each of them to ./Data/. 36 | 37 | * MNIST and CIFAR-100 are downloaded automatically the first time you run the script. 38 | 39 | ### To train GRNN and recover data, run: 40 | 41 | python GRNN.py 42 | 43 | ### Notes: 44 | 45 | * If only one GPU is available, please set: 46 | 47 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 48 | device0=0 49 | device1=0 50 | 51 | * All recovered images are saved in: 52 | 53 | ./Results/ 54 | 55 | * No trained models are saved to disk. 56 | 57 | ## Citation 58 | 59 | If you find this work helpful for your research, please cite the following paper: 60 | 61 | @article{10.1145/3510032, 62 | author = {Ren, Hanchi and Deng, Jingjing and Xie, Xianghua}, 63 | title = {GRNN: Generative Regression Neural Network - A Data Leakage Attack for Federated Learning}, 64 | year = {2022}, 65 | publisher = {Association for Computing Machinery}, 66 | address = {New York, NY, USA}, 67 | issn = {2157-6904}, 68 | url = {https://doi.org/10.1145/3510032}, 69 | journal = {ACM Trans. Intell. Syst. Technol.} 70 | } 71 | 72 | ## Acknowledgement 73 | 74 | We used part of the code from DLG (https://github.com/mit-han-lab/dlg). Many thanks for their excellent work.
75 | -------------------------------------------------------------------------------- /TFLogger/logger.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | import numpy as np 3 | from io import BytesIO 4 | import tensorflow as tf 5 | class TFLogger(object): 6 | def __init__(self, log_dir=None): 7 | """Create a summary writer logging to log_dir.""" 8 | self.writer = tf.compat.v1.summary.FileWriter(log_dir) 9 | 10 | def scalar_summary(self, tag, value, step): 11 | """Log a scalar variable.""" 12 | summary = tf.compat.v1.Summary(value=[tf.compat.v1.Summary.Value(tag=tag, simple_value=value)]) 13 | self.writer.add_summary(summary, step) 14 | self.writer.flush() 15 | 16 | def images_summary(self, tag, images, step): 17 | """Log a list of images.""" 18 | for i, imgs in enumerate(images): 19 | img_summaries = [] 20 | for bidx, img in enumerate(imgs): 21 | # Write the image to a string 22 | s = BytesIO() 23 | img.save(s, format="png") 24 | 25 | # Create an Image object 26 | img_sum = tf.compat.v1.Summary.Image(encoded_image_string=s.getvalue(), 27 | height=img.size[0], 28 | width=img.size[1]) 29 | # Create a Summary value 30 | img_summaries.append(tf.compat.v1.Summary.Value(tag='%d/%d' % (bidx, tag[bidx]), image=img_sum)) 31 | 32 | # Create and write Summary 33 | summary = tf.compat.v1.Summary(value=img_summaries) 34 | self.writer.add_summary(summary, step[i]) 35 | self.writer.flush() 36 | 37 | def image_summary(self, tag, image, step): 38 | # Write the image to a string 39 | s = BytesIO() 40 | image.save(s, format="png") 41 | 42 | # Create an Image object 43 | img_sum = tf.compat.v1.Summary.Image(encoded_image_string=s.getvalue(), 44 | height=image.size[0], 45 | width=image.size[1]) 46 | 47 | # Create and write Summary 48 | summary = tf.compat.v1.Summary(value=[tf.compat.v1.Summary.Value(tag='%s' % tag, image=img_sum)]) 49 | self.writer.add_summary(summary, step) 50 | self.writer.flush() 51 | 52 | def 
histo_summary(self, tag, values, step, bins=1000): 53 | """Log a histogram of the tensor of values.""" 54 | 55 | # Create a histogram using numpy 56 | counts, bin_edges = np.histogram(values, bins=bins) 57 | 58 | # Fill the fields of the histogram proto 59 | hist = tf.HistogramProto() 60 | hist.min = float(np.min(values)) 61 | hist.max = float(np.max(values)) 62 | hist.num = int(np.prod(values.shape)) 63 | hist.sum = float(np.sum(values)) 64 | hist.sum_squares = float(np.sum(values ** 2)) 65 | 66 | # Drop the start of the first bin 67 | bin_edges = bin_edges[1:] 68 | 69 | # Add bin edges and counts 70 | for edge in bin_edges: 71 | hist.bucket_limit.append(edge) 72 | for c in counts: 73 | hist.bucket.append(c) 74 | 75 | # Create and write Summary 76 | summary = tf.compat.v1.Summary(value=[tf.compat.v1.Summary.Value(tag=tag, histo=hist)]) 77 | self.writer.add_summary(summary, step) 78 | self.writer.flush() 79 | 80 | def close(self): 81 | self.writer.close() 82 | -------------------------------------------------------------------------------- /Backbone/densenet.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | import torch 3 | import torch.nn as nn 4 | 5 | class DenseBottleneck(nn.Module): 6 | def __init__(self, in_channels, growth_rate): 7 | super().__init__() 8 | inner_channel = 4 * growth_rate 9 | self.bottle_neck = nn.Sequential( 10 | nn.BatchNorm2d(in_channels), 11 | nn.Sigmoid(), 12 | nn.Conv2d(in_channels, inner_channel, kernel_size=1, bias=False), 13 | nn.BatchNorm2d(inner_channel), 14 | nn.Sigmoid(), 15 | nn.Conv2d(inner_channel, growth_rate, kernel_size=3, padding=1, bias=False) 16 | ) 17 | 18 | def forward(self, x): 19 | return torch.cat([x, self.bottle_neck(x)], 1) 20 | 21 | class Transition(nn.Module): 22 | def __init__(self, in_channels, out_channels): 23 | super().__init__() 24 | self.down_sample = nn.Sequential( 25 | nn.BatchNorm2d(in_channels), 26 | nn.Conv2d(in_channels, out_channels, 1, 
bias=False), 27 | nn.AvgPool2d(2, stride=2) 28 | ) 29 | 30 | def forward(self, x): 31 | return self.down_sample(x) 32 | 33 | class DenseNet(nn.Module): 34 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_class=100): 35 | super().__init__() 36 | self.growth_rate = growth_rate 37 | inner_channels = 2 * growth_rate 38 | self.conv1 = nn.Conv2d(3, inner_channels, kernel_size=3, padding=1, bias=False) 39 | self.features = nn.Sequential() 40 | for index in range(len(nblocks) - 1): 41 | self.features.add_module("dense_block_layer_{}".format(index), self._make_dense_layers(block, inner_channels, nblocks[index])) 42 | inner_channels += growth_rate * nblocks[index] 43 | out_channels = int(reduction * inner_channels) # int() will automatic floor the value 44 | self.features.add_module("transition_layer_{}".format(index), Transition(inner_channels, out_channels)) 45 | inner_channels = out_channels 46 | self.features.add_module("dense_block{}".format(len(nblocks) - 1), self._make_dense_layers(block, inner_channels, nblocks[len(nblocks)-1])) 47 | inner_channels += growth_rate * nblocks[len(nblocks) - 1] 48 | self.features.add_module('bn', nn.BatchNorm2d(inner_channels)) 49 | self.features.add_module('sigmoid', nn.Sigmoid()) 50 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 51 | self.linear = nn.Linear(inner_channels, num_class) 52 | 53 | def forward(self, x): 54 | output = self.conv1(x) 55 | output = self.features(output) 56 | output = self.avgpool(output) 57 | output = output.view(output.size()[0], -1) 58 | output = self.linear(output) 59 | return output 60 | 61 | def _make_dense_layers(self, block, in_channels, nblocks): 62 | dense_block = nn.Sequential() 63 | for index in range(nblocks): 64 | dense_block.add_module('bottle_neck_layer_{}'.format(index), block(in_channels, self.growth_rate)) 65 | in_channels += self.growth_rate 66 | return dense_block 67 | 68 | def DenseNet121(num_class): 69 | return DenseNet(DenseBottleneck, [6,12,24,16], 
growth_rate=32,num_class=num_class) 70 | 71 | def DenseNet169(num_class): 72 | return DenseNet(DenseBottleneck, [6,12,32,32], growth_rate=32,num_class=num_class) 73 | 74 | def DenseNet201(num_class): 75 | return DenseNet(DenseBottleneck, [6,12,48,32], growth_rate=32,num_class=num_class) 76 | 77 | def DenseNet161(num_class): 78 | return DenseNet(DenseBottleneck, [6,12,36,24], growth_rate=48,num_class=num_class) 79 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.data import Dataset 5 | from torchvision import datasets, transforms 6 | import os 7 | import numpy as np 8 | import PIL.Image as Image 9 | from WassersteinDistance.wd import wasserstein_distance 10 | 11 | def flatten_gradients(dy_dx): 12 | flatten_dy_dx = None 13 | for layer_g in dy_dx: 14 | if flatten_dy_dx is None: 15 | flatten_dy_dx = torch.flatten(layer_g) 16 | else: 17 | flatten_dy_dx = torch.cat((flatten_dy_dx, torch.flatten(layer_g))) 18 | return flatten_dy_dx 19 | 20 | def gen_dataset(dataset, data_path, shape_img): 21 | class Dataset_from_Image(Dataset): 22 | def __init__(self, imgs, labs, transform=None): 23 | self.imgs = imgs 24 | self.labs = labs 25 | self.transform = transform 26 | del imgs, labs 27 | 28 | def __len__(self): 29 | return self.labs.shape[0] 30 | 31 | def __getitem__(self, idx): 32 | lab = self.labs[idx] 33 | img = Image.open(self.imgs[idx]) 34 | if img.mode != 'RGB': 35 | img = img.convert('RGB') 36 | img = self.transform(img) 37 | return img, lab 38 | 39 | def face_dataset(path, num_classes): 40 | images_all = [] 41 | index_all = [] 42 | folders = os.listdir(path) 43 | for foldidx, fold in enumerate(folders): 44 | if foldidx+1==num_classes: break 45 | if os.path.isdir(os.path.join(path, fold)): 46 | files = os.listdir(os.path.join(path, fold)) 47 | for f in files: 48 
| if len(f) > 4: 49 | images_all.append(os.path.join(path, fold, f)) 50 | index_all.append(foldidx) 51 | transform = transforms.Compose([transforms.Resize(shape_img), 52 | transforms.ToTensor()]) 53 | dst = Dataset_from_Image(images_all, np.asarray(index_all, dtype=int), transform=transform) 54 | return dst 55 | if dataset == 'mnist': 56 | num_classes = 10 57 | tt = transforms.Compose([transforms.Resize(shape_img), 58 | transforms.Grayscale(num_output_channels=3), 59 | transforms.ToTensor() 60 | ]) 61 | dst = datasets.MNIST(os.path.join(data_path, 'mnist/'), download=True, transform=tt) 62 | elif dataset == 'cifar100': 63 | num_classes = 100 64 | tt = transforms.Compose([transforms.Resize(shape_img), 65 | transforms.ToTensor()]) 66 | dst = datasets.CIFAR100(os.path.join(data_path, 'cifar100/'), download=True, transform=tt) 67 | elif dataset == 'lfw': 68 | num_classes = 5749 69 | dst = face_dataset(os.path.join(data_path, 'lfw/'), shape_img) 70 | elif dataset == 'VGGFace': 71 | num_classes = 2622 72 | dst = face_dataset(os.path.join(data_path, 'VGGFace/vgg_face_dataset/'), num_classes) 73 | else: 74 | exit('unknown dataset') 75 | return dst, num_classes 76 | 77 | def weights_init(m): 78 | try: 79 | if hasattr(m, 'weight'): 80 | m.weight.data.uniform_(-0.5, 0.5) 81 | except Exception: 82 | print('warning: failed in weights_init for %s.weight' % m._get_name()) 83 | try: 84 | if hasattr(m, 'bias') and m.bias is not None: 85 | m.bias.data.uniform_(-0.5, 0.5) 86 | except Exception: 87 | print('warning: failed in weights_init for %s.bias' % m._get_name()) 88 | 89 | class TVLoss(nn.Module): 90 | def __init__(self,TVLoss_weight=1): 91 | super(TVLoss,self).__init__() 92 | self.TVLoss_weight = TVLoss_weight 93 | 94 | def forward(self,x): 95 | batch_size = x.size()[0] 96 | h_x = x.size()[2] 97 | w_x = x.size()[3] 98 | count_h = self._tensor_size(x[:,:,1:,:]) 99 | count_w = self._tensor_size(x[:,:,:,1:]) 100 | h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum() 101 | w_tv = 
torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum() 102 | return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size 103 | 104 | @staticmethod 105 | def _tensor_size(t): 106 | return t.size()[1]*t.size()[2]*t.size()[3] 107 | 108 | def loss_f(loss_name, flatten_fake_g, flatten_true_g, device): 109 | if loss_name == 'l2': 110 | grad_diff = ((flatten_fake_g - flatten_true_g) ** 2).sum() 111 | # grad_diff = torch.sqrt(((flatten_fake_g - flatten_true_g) ** 2).sum()) 112 | elif loss_name == 'wd': 113 | grad_diff = wasserstein_distance(flatten_fake_g.view(1, -1), flatten_true_g.view(1, -1), 114 | device=f'cuda:{device}') 115 | else: 116 | raise Exception('Wrong loss name.') 117 | return grad_diff -------------------------------------------------------------------------------- /Backbone/resnet.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | import torch.nn as nn 3 | 4 | class BasicBlock(nn.Module): 5 | """Basic Block for resnet 18 and resnet 34 6 | """ 7 | 8 | #BasicBlock and BottleNeck block 9 | #have different output size 10 | #we use class attribute expansion 11 | #to distinct 12 | expansion = 1 13 | 14 | def __init__(self, in_channels, out_channels, stride=1): 15 | super().__init__() 16 | 17 | #residual function 18 | self.residual_function = nn.Sequential( 19 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False), 20 | nn.BatchNorm2d(out_channels), 21 | nn.Sigmoid(), 22 | nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False), 23 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 24 | ) 25 | 26 | #shortcut 27 | self.shortcut = nn.Sequential() 28 | 29 | #the shortcut output dimension is not the same with residual function 30 | #use 1*1 convolution to match the dimension 31 | if stride != 1 or in_channels != BasicBlock.expansion * out_channels: 32 | self.shortcut = nn.Sequential( 33 | nn.Conv2d(in_channels, 
out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False), 34 | nn.BatchNorm2d(out_channels * BasicBlock.expansion) 35 | ) 36 | 37 | def forward(self, x): 38 | return nn.Sigmoid()(self.residual_function(x) + self.shortcut(x)) 39 | 40 | class BottleNeck(nn.Module): 41 | """Residual block for resnet over 50 layers 42 | """ 43 | expansion = 4 44 | def __init__(self, in_channels, out_channels, stride=1): 45 | super().__init__() 46 | self.residual_function = nn.Sequential( 47 | nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), 48 | nn.BatchNorm2d(out_channels), 49 | nn.Sigmoid(), 50 | nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False), 51 | nn.BatchNorm2d(out_channels), 52 | nn.Sigmoid(), 53 | nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False), 54 | nn.BatchNorm2d(out_channels * BottleNeck.expansion), 55 | ) 56 | 57 | self.shortcut = nn.Sequential() 58 | 59 | if stride != 1 or in_channels != out_channels * BottleNeck.expansion: 60 | self.shortcut = nn.Sequential( 61 | nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False), 62 | nn.BatchNorm2d(out_channels * BottleNeck.expansion) 63 | ) 64 | 65 | def forward(self, x): 66 | return nn.Sigmoid()(self.residual_function(x) + self.shortcut(x)) 67 | 68 | class ResNet(nn.Module): 69 | 70 | def __init__(self, block, num_block, num_classes=100): 71 | super().__init__() 72 | 73 | self.in_channels = 64 74 | 75 | self.conv1 = nn.Sequential( 76 | nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False), 77 | nn.BatchNorm2d(64), 78 | nn.Sigmoid()) 79 | #we use a different inputsize than the original paper 80 | #so conv2_x's stride is 1 81 | self.conv2_x = self._make_layer(block, 64, num_block[0], 1) 82 | self.conv3_x = self._make_layer(block, 128, num_block[1], 2) 83 | self.conv4_x = self._make_layer(block, 256, num_block[2], 2) 84 | self.conv5_x = self._make_layer(block, 512, 
num_block[3], 2) 85 | self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 86 | self.fc = nn.Linear(512 * block.expansion, num_classes) 87 | 88 | def _make_layer(self, block, out_channels, num_blocks, stride): 89 | """make resnet layers(by layer i didnt mean this 'layer' was the 90 | same as a neuron netowork layer, ex. conv layer), one layer may 91 | contain more than one residual block 92 | Args: 93 | block: block type, basic block or bottle neck block 94 | out_channels: output depth channel number of this layer 95 | num_blocks: how many blocks per layer 96 | stride: the stride of the first block of this layer 97 | Return: 98 | return a resnet layer 99 | """ 100 | 101 | # we have num_block blocks per layer, the first block 102 | # could be 1 or 2, other blocks would always be 1 103 | strides = [stride] + [1] * (num_blocks - 1) 104 | layers = [] 105 | for stride in strides: 106 | layers.append(block(self.in_channels, out_channels, stride)) 107 | self.in_channels = out_channels * block.expansion 108 | 109 | return nn.Sequential(*layers) 110 | 111 | def forward(self, x): 112 | output = self.conv1(x) 113 | output = self.conv2_x(output) 114 | output = self.conv3_x(output) 115 | output = self.conv4_x(output) 116 | output = self.conv5_x(output) 117 | output = self.avg_pool(output) 118 | output = output.view(output.size(0), -1) 119 | output = self.fc(output) 120 | 121 | return output 122 | 123 | def ResNet18(num_classes): 124 | """ return a ResNet 18 object 125 | """ 126 | return ResNet(BasicBlock, [2, 2, 2, 2],num_classes=num_classes) 127 | 128 | def ResNet34(num_classes): 129 | """ return a ResNet 34 object 130 | """ 131 | return ResNet(BasicBlock, [3, 4, 6, 3],num_classes=num_classes) 132 | 133 | def ResNet50(num_classes): 134 | """ return a ResNet 50 object 135 | """ 136 | return ResNet(BottleNeck, [3, 4, 6, 3],num_classes=num_classes) 137 | 138 | def ResNet101(num_classes): 139 | """ return a ResNet 101 object 140 | """ 141 | return ResNet(BottleNeck, [3, 4, 23, 
3],num_classes=num_classes) 142 | 143 | def ResNet152(num_classes): 144 | """ return a ResNet 152 object 145 | """ 146 | return ResNet(BottleNeck, [3, 8, 36, 3],num_classes=num_classes) 147 | -------------------------------------------------------------------------------- /GRNN.py: -------------------------------------------------------------------------------- 1 | import time, datetime 2 | from tqdm import tqdm 3 | import matplotlib.pyplot as plt 4 | from utils import * 5 | from Generator.model import Generator 6 | from TFLogger.logger import TFLogger 7 | from Backbone import * 8 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 9 | 10 | def main(): 11 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 12 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1" 13 | device0 = 0 # for GRNN training 14 | device1 = 1 # for local training, if you have only one GPU, please set device1 to 0 15 | batchsize = 1 16 | save_img = True # whether same generated image and its relevant true image 17 | Iteration = 1000 # how many optimization steps on GRNN 18 | num_exp = 10 # experiment number 19 | g_in = 128 # dimention of GRNN input 20 | plot_num = 30 21 | net_name = 'lenet' # global model 22 | net_name_set = ['lenet', 'res18'] 23 | dataset = 'lfw' 24 | dataset_set = ['mnist', 'cifar100', 'lfw', 'VGGFace', 'ilsvrc'] 25 | shape_img = (32, 32) 26 | root_path = os.path.abspath(os.curdir) 27 | data_path = os.path.join(root_path, 'Data/') 28 | save_path = os.path.join(root_path, f"Results/GRNN-{net_name}-{dataset}-S{shape_img[0]}-B{str(batchsize).zfill(3)}-{datetime.datetime.now().strftime('%Y%m%d-%H%M%S')}/") # path to saving results 29 | print('>' * 10, save_path) 30 | save_img_path = os.path.join(save_path, 'saved_img/') 31 | dst, num_classes= gen_dataset(dataset, data_path, shape_img) # read local data 32 | tp = transforms.Compose([transforms.ToPILImage()]) 33 | train_loader = iter(torch.utils.data.DataLoader(dst, batch_size=batchsize, shuffle=True)) 34 | criterion = 
nn.CrossEntropyLoss().cuda(device1) 35 | print(f'{str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))}: {save_path}') 36 | for idx_net in range(num_exp): 37 | train_tfLogger = TFLogger(f'{save_path}/tfrecoard-exp-{str(idx_net).zfill(2)}') # tensorboard record 38 | print(f'{str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))}: running {idx_net+1}|{num_exp} experiment') 39 | if net_name == 'lenet': 40 | net = LeNet(num_classes=num_classes) 41 | elif net_name == 'res18': 42 | net = ResNet18(num_classes=num_classes) 43 | net = net.cuda(device1) 44 | Gnet = Generator(num_classes, channel=3, shape_img=shape_img[0], 45 | batchsize=batchsize, g_in=g_in).cuda(device0) 46 | net.apply(weights_init) 47 | Gnet.weight_init(mean=0.0, std=0.02) 48 | G_optimizer = torch.optim.RMSprop(Gnet.parameters(), lr=0.0001, momentum=0.99) 49 | tv_loss = TVLoss() 50 | gt_data,gt_label = next(train_loader) 51 | gt_data, gt_label = gt_data.cuda(device1), gt_label.cuda(device1) # assign to device1 to generate true graident 52 | pred = net(gt_data) 53 | y = criterion(pred, gt_label) 54 | dy_dx = torch.autograd.grad(y, net.parameters()) # obtain true gradient 55 | flatten_true_g = flatten_gradients(dy_dx) 56 | G_ran_in = torch.randn(batchsize, g_in).cuda(device0) # initialize GRNN input 57 | iter_bar = tqdm(range(Iteration), 58 | total=Iteration, 59 | desc=f'{str(time.strftime("[%Y-%m-%d %H:%M:%S]", time.localtime()))}', 60 | ncols=180) 61 | history = [] 62 | history_l = [] 63 | for iters in iter_bar: # start optimizing GRNN 64 | Gout, Glabel = Gnet(G_ran_in) # produce recovered data 65 | Gout, Glabel = Gout.cuda(device1), Glabel.cuda(device1) # assign to device1 as global model's input to generate fake gradient 66 | Gpred = net(Gout) 67 | Gloss = - torch.mean(torch.sum(Glabel * torch.log(torch.softmax(Gpred, 1)), dim=-1)) # cross-entropy loss 68 | G_dy_dx = torch.autograd.grad(Gloss, net.parameters(), create_graph=True) # obtain fake gradient 69 | flatten_fake_g = 
flatten_gradients(G_dy_dx).cuda(device1) 70 | grad_diff_l2 = loss_f('l2', flatten_fake_g, flatten_true_g, device1) 71 | grad_diff_wd = loss_f('wd', flatten_fake_g, flatten_true_g, device1) 72 | if net_name == 'lenet': 73 | tvloss = 1e-3 * tv_loss(Gout) 74 | elif net_name == 'res18': 75 | tvloss = 1e-6 * tv_loss(Gout) 76 | grad_diff = grad_diff_l2 + grad_diff_wd + tvloss # loss for GRNN 77 | G_optimizer.zero_grad() 78 | grad_diff.backward() 79 | G_optimizer.step() 80 | iter_bar.set_postfix(loss_l2 = np.round(grad_diff_l2.item(), 8), 81 | loss_wd=np.round(grad_diff_wd.item(), 8), 82 | loss_tv = np.round(tvloss.item(), 8), 83 | img_mses=round(torch.mean(abs(Gout-gt_data)).item(), 8), 84 | img_wd=round(wasserstein_distance(Gout.view(1,-1), gt_data.view(1,-1)).item(), 8)) 85 | 86 | train_tfLogger.scalar_summary('g_l2', grad_diff_l2.item(), iters) 87 | train_tfLogger.scalar_summary('g_wd', grad_diff_wd.item(), iters) 88 | train_tfLogger.scalar_summary('g_tv', tvloss.item(), iters) 89 | train_tfLogger.scalar_summary('img_mses', torch.mean(abs(Gout-gt_data)).item(), iters) 90 | train_tfLogger.scalar_summary('img_wd', wasserstein_distance(Gout.view(1,-1), gt_data.view(1,-1)).item(), iters) 91 | train_tfLogger.scalar_summary('toal_loss', grad_diff.item(), iters) 92 | 93 | if iters % int(Iteration / plot_num) == 0: 94 | history.append([tp(Gout[imidx].detach().cpu()) for imidx in range(batchsize)]) 95 | history_l.append([Glabel.argmax(dim=1)[imidx].item() for imidx in range(batchsize)]) 96 | torch.cuda.empty_cache() 97 | del Gloss, G_dy_dx, flatten_fake_g, grad_diff_l2, grad_diff_wd, grad_diff, tvloss 98 | 99 | # visualization 100 | for imidx in range(batchsize): 101 | plt.figure(figsize=(12, 8)) 102 | plt.subplot(plot_num//10, 10, 1) 103 | plt.imshow(tp(gt_data[imidx].cpu())) 104 | for i in range(min(len(history), plot_num-1)): 105 | plt.subplot(plot_num//10, 10, i + 2) 106 | plt.imshow(history[i][imidx]) 107 | plt.title('l=%d' % (history_l[i][imidx])) 108 | # 
plt.title('i=%d,l=%d' % (history_iters[i], history_l[i][imidx])) 109 | plt.axis('off') 110 | if not os.path.exists(save_path): 111 | os.makedirs(save_path) 112 | if save_img: 113 | true_path = os.path.join(save_img_path, f'true_data/exp{str(idx_net).zfill(3)}/') 114 | fake_path = os.path.join(save_img_path, f'fake_data/exp{str(idx_net).zfill(3)}/') 115 | if not os.path.exists(true_path) or not os.path.exists(fake_path): 116 | os.makedirs(true_path) 117 | os.makedirs(fake_path) 118 | tp(gt_data[imidx].cpu()).save(os.path.join(true_path, f'{imidx}_{gt_label[imidx].item()}.png')) 119 | history[-1][imidx].save(os.path.join(fake_path, f'{imidx}_{Glabel.argmax(dim=1)[imidx].item()}.png')) 120 | plt.savefig(save_path + '/exp:%03d-imidx:%02d-tlabel:%d-Glabel:%d.png' % (idx_net,imidx , gt_label[imidx].item(),Glabel.argmax(dim=1)[imidx].item())) 121 | plt.close() 122 | 123 | del Glabel, Gout, flatten_true_g, G_ran_in, net, Gnet 124 | torch.cuda.empty_cache() 125 | history.clear() 126 | history_l.clear() 127 | iter_bar.close() 128 | train_tfLogger.close() 129 | print('----------------------') 130 | 131 | if __name__ == '__main__': 132 | main() --------------------------------------------------------------------------------