├── .gitignore ├── LICENSE ├── README.md ├── helper.py ├── image_helper.py ├── models ├── __init__.py ├── cifar_model.py ├── dense_efficient.py ├── densenet.py ├── model_c.py ├── pytorch_resnet.py ├── resnet.py ├── simple.py └── word_model.py ├── requirements.txt ├── text_helper.py ├── training.py └── utils ├── __init__.py ├── params.yaml ├── params_runner.yaml ├── text_load.py ├── utils.py └── words.yaml /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | 107 | .idea/ 108 | .DS_Store 109 | *.iml -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Eugene Bagdasaryan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial 
portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **11/20/2020: We are developing a new framework for backdoors with FL: [Backdoors101](https://github.com/ebagdasa/backdoors101).** 2 | It extends to many new attacks (clean-label, physical backdoors, etc.) and has an improved user experience. Check it out! 3 | 4 | # backdoor_federated_learning 5 | This code includes experiments for the paper "How to Backdoor Federated Learning" (https://arxiv.org/abs/1807.00459) 6 | 7 | 8 | All experiments are done using Python 3.7 and PyTorch 1.0. 9 | 10 | ```mkdir saved_models``` 11 | 12 | ```python training.py --params utils/params.yaml``` 13 | 14 | 15 | I encourage you to contact me (eugene@cs.cornell.edu) or open an issue on GitHub, so I can provide more details and fix bugs. 16 | 17 | Most of the experiments were produced by tweaking parameters in utils/params.yaml (for images) 18 | and utils/words.yaml (for text); you can play with them yourself.
19 | 20 | ## Reddit dataset 21 | * Corpus parsed dataset: https://drive.google.com/file/d/1qTfiZP4g2ZPS5zlxU51G-GDCGGr23nvt/view?usp=sharing 22 | * Whole dataset: https://drive.google.com/file/d/1yAmEbx7ZCeL45hYj5iEOvNv7k9UoX3vp/view?usp=sharing 23 | * Dictionary: https://drive.google.com/file/d/1gnS5CO5fGXKAGfHSzV3h-2TsjZXQXe39/view?usp=sharing 24 | 25 | 26 | -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | from shutil import copyfile 2 | 3 | import math 4 | import torch 5 | 6 | from torch.autograd import Variable 7 | import logging 8 | 9 | from torch.nn.functional import log_softmax 10 | import torch.nn.functional as F 11 | logger = logging.getLogger("logger") 12 | import os 13 | 14 | 15 | class Helper: 16 | def __init__(self, current_time, params, name): 17 | self.current_time = current_time 18 | self.target_model = None 19 | self.local_model = None 20 | 21 | self.train_data = None 22 | self.test_data = None 23 | self.poisoned_data = None 24 | self.test_data_poison = None 25 | 26 | self.params = params 27 | self.name = name 28 | self.best_loss = math.inf 29 | self.folder_path = f'saved_models/model_{self.name}_{current_time}' 30 | try: 31 | os.mkdir(self.folder_path) 32 | except FileExistsError: 33 | logger.info('Folder already exists') 34 | logger.addHandler(logging.FileHandler(filename=f'{self.folder_path}/log.txt')) 35 | logger.addHandler(logging.StreamHandler()) 36 | logger.setLevel(logging.DEBUG) 37 | logger.info(f'current path: {self.folder_path}') 38 | if not self.params.get('environment_name', False): 39 | self.params['environment_name'] = self.name 40 | 41 | self.params['current_time'] = self.current_time 42 | self.params['folder_path'] = self.folder_path 43 | 44 | def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'): 45 | if not self.params['save_model']: 46 | return False 47 | torch.save(state, filename) 
48 | 49 | if is_best: 50 | copyfile(filename, 'model_best.pth.tar') 51 | 52 | @staticmethod 53 | def model_global_norm(model): 54 | squared_sum = 0 55 | for name, layer in model.named_parameters(): 56 | squared_sum += torch.sum(torch.pow(layer.data, 2)) 57 | return math.sqrt(squared_sum) 58 | 59 | @staticmethod 60 | def model_dist_norm(model, target_params): 61 | squared_sum = 0 62 | for name, layer in model.named_parameters(): 63 | squared_sum += torch.sum(torch.pow(layer.data - target_params[name].data, 2)) 64 | return math.sqrt(squared_sum) 65 | 66 | 67 | @staticmethod 68 | def model_max_values(model, target_params): 69 | squared_sum = list() 70 | for name, layer in model.named_parameters(): 71 | squared_sum.append(torch.max(torch.abs(layer.data - target_params[name].data))) 72 | return squared_sum 73 | 74 | 75 | @staticmethod 76 | def model_max_values_var(model, target_params): 77 | squared_sum = list() 78 | for name, layer in model.named_parameters(): 79 | squared_sum.append(torch.max(torch.abs(layer - target_params[name]))) 80 | return sum(squared_sum) 81 | 82 | @staticmethod 83 | def get_one_vec(model, variable=False): 84 | size = 0 85 | for name, layer in model.named_parameters(): 86 | if name == 'decoder.weight': 87 | continue 88 | size += layer.view(-1).shape[0] 89 | if variable: 90 | sum_var = Variable(torch.cuda.FloatTensor(size).fill_(0)) 91 | else: 92 | sum_var = torch.cuda.FloatTensor(size).fill_(0) 93 | size = 0 94 | for name, layer in model.named_parameters(): 95 | if name == 'decoder.weight': 96 | continue 97 | if variable: 98 | sum_var[size:size + layer.view(-1).shape[0]] = (layer).view(-1) 99 | else: 100 | sum_var[size:size + layer.view(-1).shape[0]] = (layer.data).view(-1) 101 | size += layer.view(-1).shape[0] 102 | 103 | return sum_var 104 | 105 | @staticmethod 106 | def model_dist_norm_var(model, target_params_variables, norm=2): 107 | size = 0 108 | for name, layer in model.named_parameters(): 109 | size += layer.view(-1).shape[0] 110 | 
sum_var = torch.cuda.FloatTensor(size).fill_(0) 111 | size = 0 112 | for name, layer in model.named_parameters(): 113 | sum_var[size:size + layer.view(-1).shape[0]] = ( 114 | layer - target_params_variables[name]).view(-1) 115 | size += layer.view(-1).shape[0] 116 | 117 | return torch.norm(sum_var, norm) 118 | 119 | 120 | def cos_sim_loss(self, model, target_vec): 121 | model_vec = self.get_one_vec(model, variable=True) 122 | target_var = Variable(target_vec, requires_grad=False) 123 | # target_vec.requires_grad = False 124 | cs_sim = torch.nn.functional.cosine_similarity(self.params['scale_weights']*(model_vec-target_var) + target_var, target_var, dim=0) 125 | # cs_sim = cs_loss(model_vec, target_vec) 126 | logger.info("los") 127 | logger.info( cs_sim.data[0]) 128 | logger.info(torch.norm(model_vec - target_var).data[0]) 129 | loss = 1-cs_sim 130 | 131 | return 1e3*loss 132 | 133 | 134 | 135 | def model_cosine_similarity(self, model, target_params_variables, 136 | model_id='attacker'): 137 | 138 | cs_list = list() 139 | cs_loss = torch.nn.CosineSimilarity(dim=0) 140 | for name, data in model.named_parameters(): 141 | if name == 'decoder.weight': 142 | continue 143 | 144 | model_update = 100*(data.view(-1) - target_params_variables[name].view(-1)) + target_params_variables[name].view(-1) 145 | 146 | 147 | cs = F.cosine_similarity(model_update, 148 | target_params_variables[name].view(-1), dim=0) 149 | # logger.info(torch.equal(layer.view(-1), 150 | # target_params_variables[name].view(-1))) 151 | # logger.info(name) 152 | # logger.info(cs.data[0]) 153 | # logger.info(torch.norm(model_update).data[0]) 154 | # logger.info(torch.norm(fake_weights[name])) 155 | cs_list.append(cs) 156 | cos_los_submit = 1*(1-sum(cs_list)/len(cs_list)) 157 | logger.info(model_id) 158 | logger.info((sum(cs_list)/len(cs_list)).data[0]) 159 | return 1e3*sum(cos_los_submit) 160 | 161 | def accum_similarity(self, last_acc, new_acc): 162 | 163 | cs_list = list() 164 | 165 | cs_loss = 
torch.nn.CosineSimilarity(dim=0) 166 | # logger.info('new run') 167 | for name, layer in last_acc.items(): 168 | 169 | cs = cs_loss(Variable(last_acc[name], requires_grad=False).view(-1), 170 | Variable(new_acc[name], requires_grad=False).view(-1) 171 | 172 | ) 173 | # logger.info(torch.equal(layer.view(-1), 174 | # target_params_variables[name].view(-1))) 175 | # logger.info(name) 176 | # logger.info(cs.data[0]) 177 | # logger.info(torch.norm(model_update).data[0]) 178 | # logger.info(torch.norm(fake_weights[name])) 179 | cs_list.append(cs) 180 | cos_los_submit = 1*(1-sum(cs_list)/len(cs_list)) 181 | # logger.info("AAAAAAAA") 182 | # logger.info((sum(cs_list)/len(cs_list)).data[0]) 183 | return sum(cos_los_submit) 184 | 185 | 186 | 187 | 188 | @staticmethod 189 | def dp_noise(param, sigma): 190 | 191 | noised_layer = torch.cuda.FloatTensor(param.shape).normal_(mean=0, std=sigma) 192 | 193 | return noised_layer 194 | 195 | def average_shrink_models(self, weight_accumulator, target_model, epoch): 196 | """ 197 | Perform FedAvg algorithm and perform some clustering on top of it. 
198 | 199 | """ 200 | 201 | for name, data in target_model.state_dict().items(): 202 | if self.params.get('tied', False) and name == 'decoder.weight': 203 | continue 204 | 205 | update_per_layer = weight_accumulator[name] * \ 206 | (self.params["eta"] / self.params["number_of_total_participants"]) 207 | 208 | if self.params['diff_privacy']: 209 | update_per_layer.add_(self.dp_noise(data, self.params['sigma'])) 210 | 211 | data.add_(update_per_layer) 212 | 213 | return True 214 | 215 | def save_model(self, model=None, epoch=0, val_loss=0): 216 | if model is None: 217 | model = self.target_model 218 | if self.params['save_model']: 219 | # save_model 220 | logger.info("saving model") 221 | model_name = '{0}/model_last.pt.tar'.format(self.params['folder_path']) 222 | saved_dict = {'state_dict': model.state_dict(), 'epoch': epoch, 223 | 'lr': self.params['lr']} 224 | self.save_checkpoint(saved_dict, False, model_name) 225 | if epoch in self.params['save_on_epochs']: 226 | logger.info(f'Saving model on epoch {epoch}') 227 | self.save_checkpoint(saved_dict, False, filename=f'{model_name}.epoch_{epoch}') 228 | if val_loss < self.best_loss: 229 | self.save_checkpoint(saved_dict, False, f'{model_name}.best') 230 | self.best_loss = val_loss 231 | 232 | def estimate_fisher(self, model, criterion, 233 | data_loader, sample_size, batch_size=64): 234 | # sample loglikelihoods from the dataset. 
235 | loglikelihoods = [] 236 | if self.params['type'] == 'text': 237 | data_iterator = range(0, data_loader.size(0) - 1, self.params['bptt']) 238 | hidden = model.init_hidden(self.params['batch_size']) 239 | else: 240 | data_iterator = data_loader 241 | 242 | for batch_id, batch in enumerate(data_iterator): 243 | data, targets = self.get_batch(data_loader, batch, 244 | evaluation=False) 245 | if self.params['type'] == 'text': 246 | hidden = self.repackage_hidden(hidden) 247 | output, hidden = model(data, hidden) 248 | loss = criterion(output.view(-1, self.n_tokens), targets) 249 | else: 250 | output = model(data) 251 | loss = log_softmax(output, dim=1)[range(targets.shape[0]), targets.data] 252 | # loss = criterion(output.view(-1, ntokens 253 | # output, hidden = model(data, hidden) 254 | loglikelihoods.append(loss) 255 | # loglikelihoods.append( 256 | # log_softmax(output.view(-1, self.n_tokens))[range(self.params['batch_size']), targets.data] 257 | # ) 258 | 259 | # if len(loglikelihoods) >= sample_size // batch_size: 260 | # break 261 | logger.info(loglikelihoods[0].shape) 262 | # estimate the fisher information of the parameters. 263 | loglikelihood = torch.cat(loglikelihoods).mean(0) 264 | logger.info(loglikelihood.shape) 265 | loglikelihood_grads = torch.autograd.grad(loglikelihood, model.parameters()) 266 | 267 | parameter_names = [ 268 | n.replace('.', '__') for n, p in model.named_parameters() 269 | ] 270 | return {n: g ** 2 for n, g in zip(parameter_names, loglikelihood_grads)} 271 | 272 | def consolidate(self, model, fisher): 273 | for n, p in model.named_parameters(): 274 | n = n.replace('.', '__') 275 | model.register_buffer('{}_estimated_mean'.format(n), p.data.clone()) 276 | model.register_buffer('{}_estimated_fisher' 277 | .format(n), fisher[n].data.clone()) 278 | 279 | def ewc_loss(self, model, lamda, cuda=False): 280 | try: 281 | losses = [] 282 | for n, p in model.named_parameters(): 283 | # retrieve the consolidated mean and fisher information. 
284 | n = n.replace('.', '__') 285 | mean = getattr(model, '{}_estimated_mean'.format(n)) 286 | fisher = getattr(model, '{}_estimated_fisher'.format(n)) 287 | # wrap mean and fisher in variables. 288 | mean = Variable(mean) 289 | fisher = Variable(fisher) 290 | # calculate a ewc loss. (assumes the parameter's prior as 291 | # gaussian distribution with the estimated mean and the 292 | # estimated cramer-rao lower bound variance, which is 293 | # equivalent to the inverse of fisher information) 294 | losses.append((fisher * (p - mean) ** 2).sum()) 295 | return (lamda / 2) * sum(losses) 296 | except AttributeError: 297 | # ewc loss is 0 if there's no consolidated parameters. 298 | return ( 299 | Variable(torch.zeros(1)).cuda() if cuda else 300 | Variable(torch.zeros(1)) 301 | ) 302 | 303 | -------------------------------------------------------------------------------- /image_helper.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | import torch.utils.data 5 | 6 | from helper import Helper 7 | import random 8 | import logging 9 | from torchvision import datasets, transforms 10 | import numpy as np 11 | 12 | from models.resnet import ResNet18 13 | from models.word_model import RNNModel 14 | from utils.text_load import * 15 | from utils.utils import SubsetSampler 16 | 17 | logger = logging.getLogger("logger") 18 | POISONED_PARTICIPANT_POS = 0 19 | 20 | 21 | 22 | class ImageHelper(Helper): 23 | 24 | 25 | def poison(self): 26 | return 27 | 28 | def create_model(self): 29 | local_model = ResNet18(name='Local', 30 | created_time=self.params['current_time']) 31 | local_model.cuda() 32 | target_model = ResNet18(name='Target', 33 | created_time=self.params['current_time']) 34 | target_model.cuda() 35 | if self.params['resumed_model']: 36 | loaded_params = torch.load(f"saved_models/{self.params['resumed_model']}") 37 | target_model.load_state_dict(loaded_params['state_dict']) 38 
| self.start_epoch = loaded_params['epoch'] 39 | self.params['lr'] = loaded_params.get('lr', self.params['lr']) 40 | logger.info(f"Loaded parameters from saved model: LR is" 41 | f" {self.params['lr']} and current epoch is {self.start_epoch}") 42 | else: 43 | self.start_epoch = 1 44 | 45 | self.local_model = local_model 46 | self.target_model = target_model 47 | 48 | 49 | def sample_dirichlet_train_data(self, no_participants, alpha=0.9): 50 | """ 51 | Input: Number of participants and alpha (param for distribution) 52 | Output: A list of indices denoting data in CIFAR training set. 53 | Requires: cifar_classes, a preprocessed class-indice dictionary. 54 | Sample Method: take a uniformly sampled 10-dimension vector as parameters for 55 | dirichlet distribution to sample number of images in each class. 56 | """ 57 | 58 | cifar_classes = {} 59 | for ind, x in enumerate(self.train_dataset): 60 | _, label = x 61 | if ind in self.params['poison_images'] or ind in self.params['poison_images_test']: 62 | continue 63 | if label in cifar_classes: 64 | cifar_classes[label].append(ind) 65 | else: 66 | cifar_classes[label] = [ind] 67 | class_size = len(cifar_classes[0]) 68 | per_participant_list = defaultdict(list) 69 | no_classes = len(cifar_classes.keys()) 70 | 71 | for n in range(no_classes): 72 | random.shuffle(cifar_classes[n]) 73 | sampled_probabilities = class_size * np.random.dirichlet( 74 | np.array(no_participants * [alpha])) 75 | for user in range(no_participants): 76 | no_imgs = int(round(sampled_probabilities[user])) 77 | sampled_list = cifar_classes[n][:min(len(cifar_classes[n]), no_imgs)] 78 | per_participant_list[user].extend(sampled_list) 79 | cifar_classes[n] = cifar_classes[n][min(len(cifar_classes[n]), no_imgs):] 80 | 81 | return per_participant_list 82 | 83 | def poison_dataset(self): 84 | # 85 | # return [(self.train_dataset[self.params['poison_image_id']][0], 86 | # torch.IntTensor(self.params['poison_label_swap']))] 87 | cifar_classes = {} 88 | for ind, 
x in enumerate(self.train_dataset): 89 | _, label = x 90 | if ind in self.params['poison_images'] or ind in self.params['poison_images_test']: 91 | continue 92 | if label in cifar_classes: 93 | cifar_classes[label].append(ind) 94 | else: 95 | cifar_classes[label] = [ind] 96 | indices = list() 97 | # create array that starts with poisoned images 98 | 99 | #create candidates: 100 | # range_no_id = cifar_classes[1] 101 | # range_no_id.extend(cifar_classes[1]) 102 | range_no_id = list(range(50000)) 103 | for image in self.params['poison_images'] + self.params['poison_images_test']: 104 | if image in range_no_id: 105 | range_no_id.remove(image) 106 | 107 | # add random images to other parts of the batch 108 | for batches in range(0, self.params['size_of_secret_dataset']): 109 | range_iter = random.sample(range_no_id, 110 | self.params['batch_size']) 111 | # range_iter[0] = self.params['poison_images'][0] 112 | indices.extend(range_iter) 113 | # range_iter = random.sample(range_no_id, 114 | # self.params['batch_size'] 115 | # -len(self.params['poison_images'])*self.params['poisoning_per_batch']) 116 | # for i in range(0, self.params['poisoning_per_batch']): 117 | # indices.extend(self.params['poison_images']) 118 | # indices.extend(range_iter) 119 | return torch.utils.data.DataLoader(self.train_dataset, 120 | batch_size=self.params['batch_size'], 121 | sampler=torch.utils.data.sampler.SubsetRandomSampler(indices)) 122 | 123 | def poison_test_dataset(self): 124 | # 125 | # return [(self.train_dataset[self.params['poison_image_id']][0], 126 | # torch.IntTensor(self.params['poison_label_swap']))] 127 | return torch.utils.data.DataLoader(self.train_dataset, 128 | batch_size=self.params['batch_size'], 129 | sampler=torch.utils.data.sampler.SubsetRandomSampler( 130 | range(1000) 131 | )) 132 | 133 | 134 | def load_data(self): 135 | logger.info('Loading data') 136 | 137 | ### data load 138 | transform_train = transforms.Compose([ 139 | transforms.RandomCrop(32, padding=4), 140 
| transforms.RandomHorizontalFlip(), 141 | transforms.ToTensor(), 142 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 143 | ]) 144 | 145 | transform_test = transforms.Compose([ 146 | transforms.ToTensor(), 147 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 148 | ]) 149 | 150 | self.train_dataset = datasets.CIFAR10('./data', train=True, download=True, 151 | transform=transform_train) 152 | 153 | self.test_dataset = datasets.CIFAR10('./data', train=False, transform=transform_test) 154 | 155 | if self.params['sampling_dirichlet']: 156 | ## sample indices for participants using Dirichlet distribution 157 | indices_per_participant = self.sample_dirichlet_train_data( 158 | self.params['number_of_total_participants'], 159 | alpha=self.params['dirichlet_alpha']) 160 | train_loaders = [(pos, self.get_train(indices)) for pos, indices in 161 | indices_per_participant.items()] 162 | else: 163 | ## sample indices for participants that are equally 164 | # splitted to 500 images per participant 165 | all_range = list(range(len(self.train_dataset))) 166 | random.shuffle(all_range) 167 | train_loaders = [(pos, self.get_train_old(all_range, pos)) 168 | for pos in range(self.params['number_of_total_participants'])] 169 | self.train_data = train_loaders 170 | self.test_data = self.get_test() 171 | self.poisoned_data_for_train = self.poison_dataset() 172 | self.test_data_poison = self.poison_test_dataset() 173 | # self.params['adversary_list'] = [POISONED_PARTICIPANT_POS] + \ 174 | # random.sample(range(len(train_loaders)), 175 | # self.params['number_of_adversaries'] - 1) 176 | # logger.info(f"Poisoned following participants: {self.params['adversary_list']}") 177 | 178 | 179 | def get_train(self, indices): 180 | """ 181 | This method is used along with Dirichlet distribution 182 | :param params: 183 | :param indices: 184 | :return: 185 | """ 186 | train_loader = torch.utils.data.DataLoader(self.train_dataset, 187 | 
batch_size=self.params['batch_size'], 188 | sampler=torch.utils.data.sampler.SubsetRandomSampler( 189 | indices)) 190 | return train_loader 191 | 192 | def get_train_old(self, all_range, model_no): 193 | """ 194 | This method equally splits the dataset. 195 | :param params: 196 | :param all_range: 197 | :param model_no: 198 | :return: 199 | """ 200 | 201 | data_len = int(len(self.train_dataset) / self.params['number_of_total_participants']) 202 | sub_indices = all_range[model_no * data_len: (model_no + 1) * data_len] 203 | train_loader = torch.utils.data.DataLoader(self.train_dataset, 204 | batch_size=self.params['batch_size'], 205 | sampler=torch.utils.data.sampler.SubsetRandomSampler( 206 | sub_indices)) 207 | return train_loader 208 | 209 | 210 | def get_secret_loader(self): 211 | """ 212 | For poisoning we can use a larger data set. I don't sample randomly, though. 213 | 214 | """ 215 | indices = list(range(len(self.train_dataset))) 216 | random.shuffle(indices) 217 | shuffled_indices = indices[:self.params['size_of_secret_dataset']] 218 | train_loader = torch.utils.data.DataLoader(self.train_dataset, 219 | batch_size=self.params['batch_size'], 220 | sampler=SubsetSampler(shuffled_indices)) 221 | return train_loader 222 | 223 | def get_test(self): 224 | 225 | test_loader = torch.utils.data.DataLoader(self.test_dataset, 226 | batch_size=self.params['test_batch_size'], 227 | shuffle=True) 228 | 229 | return test_loader 230 | 231 | 232 | def get_batch(self, train_data, bptt, evaluation=False): 233 | data, target = bptt 234 | data = data.cuda() 235 | target = target.cuda() 236 | if evaluation: 237 | data.requires_grad_(False) 238 | target.requires_grad_(False) 239 | return data, target 240 | 241 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ebagdasa/backdoor_federated_learning/3f068b819f017b9eef1e19f43f572e391398c90a/models/__init__.py -------------------------------------------------------------------------------- /models/cifar_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.simple import SimpleNet 5 | 6 | 7 | class CifarNet(SimpleNet): 8 | def __init__(self, name=None, created_time=None): 9 | super(CifarNet, self).__init__(f'{name}_Simple', created_time) 10 | self.conv1 = nn.Conv2d(3, 6, 5) 11 | self.pool = nn.MaxPool2d(2, 2) 12 | self.conv2 = nn.Conv2d(6, 16, 5) 13 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 14 | self.fc2 = nn.Linear(120, 84) 15 | self.fc3 = nn.Linear(84, 10) 16 | 17 | def forward(self, x): 18 | x = self.pool(F.relu(self.conv1(x))) 19 | x = self.pool(F.relu(self.conv2(x))) 20 | x = x.view(-1, 16 * 5 * 5) 21 | x = F.relu(self.fc1(x)) 22 | x = F.relu(self.fc2(x)) 23 | x = self.fc3(x) 24 | return x -------------------------------------------------------------------------------- /models/dense_efficient.py: -------------------------------------------------------------------------------- 1 | # This implementation is a new efficient implementation of Densenet-BC, 2 | # as described in "Memory-Efficient Implementation of DenseNets" 3 | 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from functools import reduce 9 | from operator import mul 10 | from collections import OrderedDict 11 | from torch.autograd import Variable, Function 12 | 13 | 14 | # I'm throwing all the gross code at the end of the file :) 15 | # Let's start with the nice (and interesting) stuff 16 | from models.simple import SimpleNet 17 | 18 | 19 | class _EfficientDensenetBottleneck(nn.Module): 20 | """ 21 | A optimized layer which encapsulates the batch normalization, ReLU, and 22 | convolution operations 
within the bottleneck of a DenseNet layer. 23 | 24 | This layer usage shared memory allocations to store the outputs of the 25 | concatenation and batch normalization features. Because the shared memory 26 | is not perminant, these features are recomputed during the backward pass. 27 | """ 28 | def __init__(self, shared_allocation_1, shared_allocation_2, num_input_channels, num_output_channels): 29 | super(_EfficientDensenetBottleneck, self).__init__() 30 | self.shared_allocation_1 = shared_allocation_1 31 | self.shared_allocation_2 = shared_allocation_2 32 | self.num_input_channels = num_input_channels 33 | 34 | self.norm_weight = nn.Parameter(torch.Tensor(num_input_channels)) 35 | self.norm_bias = nn.Parameter(torch.Tensor(num_input_channels)) 36 | self.register_buffer('norm_running_mean', torch.zeros(num_input_channels)) 37 | self.register_buffer('norm_running_var', torch.ones(num_input_channels)) 38 | self.conv_weight = nn.Parameter(torch.Tensor(num_output_channels, num_input_channels, 1, 1)) 39 | self._reset_parameters() 40 | 41 | def _reset_parameters(self): 42 | self.norm_running_mean.zero_() 43 | self.norm_running_var.fill_(1) 44 | self.norm_weight.data.uniform_() 45 | self.norm_bias.data.zero_() 46 | stdv = 1. / math.sqrt(self.num_input_channels) 47 | self.conv_weight.data.uniform_(-stdv, stdv) 48 | 49 | def forward(self, inputs): 50 | if isinstance(inputs, Variable): 51 | inputs = [inputs] 52 | 53 | # The EfficientDensenetBottleneckFn performs the concatenation, batch norm, and ReLU. 
54 | # It does not create any new storage 55 | # Rather, it uses a shared memory allocation to store the intermediate feature maps 56 | # These intermediate feature maps have to be re-populated before the backward pass 57 | fn = _EfficientDensenetBottleneckFn(self.shared_allocation_1, self.shared_allocation_2, 58 | self.norm_running_mean, self.norm_running_var, 59 | training=self.training, momentum=0.1, eps=1e-5) 60 | relu_output = fn(self.norm_weight, self.norm_bias, *inputs) 61 | 62 | # The convolutional output - using relu_output which is stored in shared memory allocation 63 | conv_output = F.conv2d(relu_output, self.conv_weight, bias=None, stride=1, 64 | padding=0, dilation=1, groups=1) 65 | 66 | # Register a hook to re-populate the storages (relu_output and concat) on backward pass 67 | # To do this, we need a dummy function 68 | dummy_fn = _DummyBackwardHookFn(fn) 69 | output = dummy_fn(conv_output) 70 | 71 | # Return the convolution output 72 | return output 73 | 74 | 75 | class _DenseLayer(nn.Sequential): 76 | def __init__(self, shared_allocation_1, shared_allocation_2, 77 | num_input_features, growth_rate, bn_size, drop_rate): 78 | super(_DenseLayer, self).__init__() 79 | self.shared_allocation_1 = shared_allocation_1 80 | self.shared_allocation_2 = shared_allocation_2 81 | self.drop_rate = drop_rate 82 | 83 | self.add_module('bn', _EfficientDensenetBottleneck(shared_allocation_1, shared_allocation_2, 84 | num_input_features, bn_size * growth_rate)) 85 | self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate)), 86 | self.add_module('relu.2', nn.ReLU(inplace=True)), 87 | self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate, 88 | kernel_size=3, stride=1, padding=1, bias=False)), 89 | 90 | def forward(self, x): 91 | if isinstance(x, Variable): 92 | prev_features = [x] 93 | else: 94 | prev_features = x 95 | new_features = super(_DenseLayer, self).forward(prev_features) 96 | if self.drop_rate > 0: 97 | new_features = 
F.dropout(new_features, p=self.drop_rate, training=self.training) 98 | return new_features 99 | 100 | 101 | class _Transition(nn.Sequential): 102 | def __init__(self, num_input_features, num_output_features): 103 | super(_Transition, self).__init__() 104 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 105 | self.add_module('relu', nn.ReLU(inplace=True)) 106 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 107 | kernel_size=1, stride=1, bias=False)) 108 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 109 | 110 | 111 | class _DenseBlock(nn.Container): 112 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, storage_size=1024): 113 | self.final_num_features = num_input_features + (growth_rate * num_layers) 114 | self.shared_allocation_1 = _SharedAllocation(storage_size) 115 | self.shared_allocation_2 = _SharedAllocation(storage_size) 116 | 117 | super(_DenseBlock, self).__init__() 118 | for i in range(num_layers): 119 | layer = _DenseLayer(self.shared_allocation_1, self.shared_allocation_2, 120 | num_input_features + i * growth_rate, 121 | growth_rate, bn_size, drop_rate) 122 | self.add_module('denselayer%d' % (i + 1), layer) 123 | 124 | def forward(self, x): 125 | # Update storage type 126 | self.shared_allocation_1.type_as(x) 127 | self.shared_allocation_2.type_as(x) 128 | 129 | # Resize storage 130 | final_size = list(x.size()) 131 | final_size[1] = self.final_num_features 132 | final_storage_size = reduce(mul, final_size, 1) 133 | self.shared_allocation_1.resize_(final_storage_size) 134 | self.shared_allocation_2.resize_(final_storage_size) 135 | 136 | outputs = [x] 137 | for module in self.children(): 138 | outputs.append(module.forward(outputs)) 139 | return torch.cat(outputs, dim=1) 140 | 141 | 142 | class DenseNetEfficient(SimpleNet): 143 | r"""Densenet-BC model class, based on 144 | `"Densely Connected Convolutional Networks" ` 145 | 146 | This model uses shared memory 
allocations for the outputs of batch norm and 147 | concat operations, as described in `"Memory-Efficient Implementation of DenseNets"`. 148 | 149 | Args: 150 | growth_rate (int) - how many filters to add each layer (`k` in paper) 151 | block_config (list of 4 ints) - how many layers in each pooling block 152 | num_init_features (int) - the number of filters to learn in the first convolution layer 153 | bn_size (int) - multiplicative factor for number of bottle neck layers 154 | (i.e. bn_size * k features in the bottleneck layer) 155 | drop_rate (float) - dropout rate after each dense layer 156 | num_classes (int) - number of classification classes 157 | small_inputs (bool) - set to True if images are 32x32. Otherwise assumes images are larger. 158 | """ 159 | def __init__(self, growth_rate=12, block_config=(16, 16, 16), compression=0.5, 160 | num_init_features=24, bn_size=4, drop_rate=0, 161 | num_classes=10, small_inputs=True, name=None, created_time=None): 162 | 163 | super(DenseNetEfficient, self).__init__(name=f'{name}_DE', created_time=created_time) 164 | assert 0 < compression <= 1, 'compression of densenet should be between 0 and 1' 165 | self.avgpool_size = 8 if small_inputs else 7 166 | 167 | # First convolution 168 | if small_inputs: 169 | self.features = nn.Sequential(OrderedDict([ 170 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False)), 171 | ])) 172 | else: 173 | self.features = nn.Sequential(OrderedDict([ 174 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 175 | ])) 176 | self.features.add_module('norm0', nn.BatchNorm2d(num_init_features)) 177 | self.features.add_module('relu0', nn.ReLU(inplace=True)) 178 | self.features.add_module('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1, 179 | ceil_mode=False)) 180 | 181 | # Each denseblock 182 | num_features = num_init_features 183 | for i, num_layers in enumerate(block_config): 184 | block = 
_DenseBlock(num_layers=num_layers, 185 | num_input_features=num_features, 186 | bn_size=bn_size, growth_rate=growth_rate, 187 | drop_rate=drop_rate) 188 | self.features.add_module('denseblock%d' % (i + 1), block) 189 | num_features = num_features + num_layers * growth_rate 190 | if i != len(block_config) - 1: 191 | trans = _Transition(num_input_features=num_features, 192 | num_output_features=int(num_features * compression)) 193 | self.features.add_module('transition%d' % (i + 1), trans) 194 | num_features = int(num_features * compression) 195 | 196 | # Final batch norm 197 | self.features.add_module('norm_final', nn.BatchNorm2d(num_features)) 198 | 199 | # Linear layer 200 | self.classifier = nn.Linear(num_features, num_classes) 201 | 202 | def forward(self, x): 203 | features = self.features(x) 204 | out = F.relu(features, inplace=True) 205 | out = F.avg_pool2d(out, kernel_size=self.avgpool_size).view( 206 | features.size(0), -1) 207 | out = self.classifier(out) 208 | return out 209 | 210 | 211 | # Begin gross code :/ 212 | # Here's where we define the internals of the efficient bottleneck layer 213 | 214 | 215 | class _SharedAllocation(object): 216 | """ 217 | A helper class which maintains a shared memory allocation. 218 | Used for concatenation and batch normalization. 
219 | """ 220 | def __init__(self, size): 221 | self._cpu_storage = torch.Storage(size) 222 | self._gpu_storages = [] 223 | if torch.cuda.is_available(): 224 | for device_idx in range(torch.cuda.device_count()): 225 | with torch.cuda.device(device_idx): 226 | self._gpu_storages.append(torch.Storage(size).cuda()) 227 | 228 | def type(self, t): 229 | if not t.is_cuda: 230 | self._cpu_storage = self._cpu_storage.type(t) 231 | else: 232 | for device_idx, storage in enumerate(self._gpu_storages): 233 | with torch.cuda.device(device_idx): 234 | self._gpu_storages[device_idx] = storage.type(t) 235 | 236 | def type_as(self, obj): 237 | if isinstance(obj, Variable): 238 | if not obj.is_cuda: 239 | self._cpu_storage = self._cpu_storage.type(obj.data.storage().type()) 240 | else: 241 | for device_idx, storage in enumerate(self._gpu_storages): 242 | with torch.cuda.device(device_idx): 243 | self._gpu_storages[device_idx] = storage.type(obj.data.storage().type()) 244 | elif torch.is_tensor(obj): 245 | if not obj.is_cuda: 246 | self._cpu_storage = self._cpu_storage.type(obj.storage().type()) 247 | else: 248 | for device_idx, storage in enumerate(self._gpu_storages): 249 | with torch.cuda.device(device_idx): 250 | self._gpu_storages[device_idx] = storage.type(obj.storage().type()) 251 | else: 252 | if not obj.is_cuda: 253 | self._cpu_storage = self._cpu_storage.type(obj.storage().type()) 254 | else: 255 | for device_idx, storage in enumerate(self._gpu_storages): 256 | with torch.cuda.device(device_idx): 257 | self._gpu_storages[device_idx] = storage.type(obj.type()) 258 | 259 | def resize_(self, size): 260 | if self._cpu_storage.size() < size: 261 | self._cpu_storage.resize_(size) 262 | for device_idx, storage in enumerate(self._gpu_storages): 263 | if storage.size() < size: 264 | with torch.cuda.device(device_idx): 265 | self._gpu_storages[device_idx].resize_(size) 266 | return self 267 | 268 | def storage_for(self, val): 269 | if val.is_cuda: 270 | with 
torch.cuda.device_of(val): 271 | curr_device_id = torch.cuda.current_device() 272 | return self._gpu_storages[curr_device_id] 273 | else: 274 | return self._cpu_storage 275 | 276 | 277 | class _EfficientDensenetBottleneckFn(Function): 278 | """ 279 | The autograd function which performs the efficient bottlenck operations: 280 | -- 281 | 1) concatenation 282 | 2) Batch Normalization 283 | 3) ReLU 284 | -- 285 | Convolution is taken care of in a separate function 286 | 287 | NOTE: 288 | The output of the function (ReLU) is written on a temporary memory allocation. 289 | If the output is not used IMMEDIATELY after calling forward, it is not guarenteed 290 | to be the ReLU output 291 | """ 292 | def __init__(self, shared_allocation_1, shared_allocation_2, 293 | running_mean, running_var, 294 | training=False, momentum=0.1, eps=1e-5): 295 | 296 | self.shared_allocation_1 = shared_allocation_1 297 | self.shared_allocation_2 = shared_allocation_2 298 | self.running_mean = running_mean 299 | self.running_var = running_var 300 | self.training = training 301 | self.momentum = momentum 302 | self.eps = eps 303 | 304 | # Buffers to store old versions of bn statistics 305 | self.prev_running_mean = self.running_mean.new(self.running_mean.size()) 306 | self.prev_running_var = self.running_var.new(self.running_var.size()) 307 | 308 | def forward(self, bn_weight, bn_bias, *inputs): 309 | if self.training: 310 | # Save the current BN statistics for later 311 | self.prev_running_mean.copy_(self.running_mean) 312 | self.prev_running_var.copy_(self.running_var) 313 | 314 | # Create tensors that use shared allocations 315 | # One for the concatenation output (bn_input) 316 | # One for the ReLU output (relu_output) 317 | all_num_channels = [input.size(1) for input in inputs] 318 | size = list(inputs[0].size()) 319 | for num_channels in all_num_channels[1:]: 320 | size[1] += num_channels 321 | storage = self.shared_allocation_1.storage_for(inputs[0]) 322 | bn_input_var = 
Variable(type(inputs[0])(storage).resize_(size), volatile=True) 323 | relu_output = type(inputs[0])(storage).resize_(size) 324 | 325 | # Create variable, using existing storage 326 | torch.cat(inputs, dim=1, out=bn_input_var.data) 327 | 328 | # Do batch norm 329 | bn_weight_var = Variable(bn_weight) 330 | bn_bias_var = Variable(bn_bias) 331 | bn_output_var = F.batch_norm(bn_input_var, self.running_mean, self.running_var, 332 | bn_weight_var, bn_bias_var, training=self.training, 333 | momentum=self.momentum, eps=self.eps) 334 | 335 | # Do ReLU - and have the output be in the intermediate storage 336 | torch.clamp(bn_output_var.data, 0, 1e100, out=relu_output) 337 | 338 | self.save_for_backward(bn_weight, bn_bias, *inputs) 339 | if self.training: 340 | # restore the BN statistics for later 341 | self.running_mean.copy_(self.prev_running_mean) 342 | self.running_var.copy_(self.prev_running_var) 343 | return relu_output 344 | 345 | def prepare_backward(self): 346 | bn_weight, bn_bias = self.saved_tensors[:2] 347 | inputs = self.saved_tensors[2:] 348 | 349 | # Re-do the forward pass to re-populate the shared storage 350 | all_num_channels = [input.size(1) for input in inputs] 351 | size = list(inputs[0].size()) 352 | for num_channels in all_num_channels[1:]: 353 | size[1] += num_channels 354 | storage1 = self.shared_allocation_1.storage_for(inputs[0]) 355 | self.bn_input_var = Variable(type(inputs[0])(storage1).resize_(size), requires_grad=True) 356 | storage2 = self.shared_allocation_2.storage_for(inputs[0]) 357 | self.relu_output = type(inputs[0])(storage2).resize_(size) 358 | 359 | # Create variable, using existing storage 360 | torch.cat(inputs, dim=1, out=self.bn_input_var.data) 361 | 362 | # Do batch norm 363 | self.bn_weight_var = Variable(bn_weight, requires_grad=True) 364 | self.bn_bias_var = Variable(bn_bias, requires_grad=True) 365 | self.bn_output_var = F.batch_norm(self.bn_input_var, self.running_mean, self.running_var, 366 | self.bn_weight_var, 
self.bn_bias_var, training=self.training, 367 | momentum=self.momentum, eps=self.eps) 368 | 369 | # Do ReLU 370 | torch.clamp(self.bn_output_var.data, 0, 1e100, out=self.relu_output) 371 | 372 | def backward(self, grad_output): 373 | """ 374 | Precondition: must call prepare_backward before calling backward 375 | """ 376 | 377 | grads = [None] * len(self.saved_tensors) 378 | inputs = self.saved_tensors[2:] 379 | 380 | # If we don't need gradients, don't run backwards 381 | if not any(self.needs_input_grad): 382 | return grads 383 | 384 | # BN weight/bias grad 385 | # With the shared allocations re-populated, compute ReLU/BN backward 386 | relu_grad_input = grad_output.masked_fill_(self.relu_output <= 0, 0) 387 | self.bn_output_var.backward(gradient=relu_grad_input) 388 | if self.needs_input_grad[0]: 389 | grads[0] = self.bn_weight_var.grad.data 390 | if self.needs_input_grad[1]: 391 | grads[1] = self.bn_bias_var.grad.data 392 | 393 | # Input grad (if needed) 394 | # Run backwards through the concatenation operation 395 | if any(self.needs_input_grad[2:]): 396 | all_num_channels = [input.size(1) for input in inputs] 397 | index = 0 398 | for i, num_channels in enumerate(all_num_channels): 399 | new_index = num_channels + index 400 | grads[2 + i] = self.bn_input_var.grad.data[:, index:new_index] 401 | index = new_index 402 | 403 | # Delete all intermediate variables 404 | del self.bn_input_var 405 | del self.bn_weight_var 406 | del self.bn_bias_var 407 | del self.bn_output_var 408 | 409 | return tuple(grads) 410 | 411 | 412 | class _DummyBackwardHookFn(Function): 413 | """ 414 | A dummy function, which is just designed to run a backward hook 415 | This allows us to re-populate the shared storages before running the backward 416 | pass on the bottleneck layer 417 | The function itself is just an identity function 418 | """ 419 | def __init__(self, fn): 420 | """ 421 | fn: function to call "prepare_backward" on 422 | """ 423 | self.fn = fn 424 | 425 | def forward(self, 
input): 426 | """ 427 | Though this function is just an identity function, we have to return a new 428 | tensor object in order to trigger the autograd. 429 | """ 430 | size = input.size() 431 | res = input.new(input.storage()).view(*size) 432 | return res 433 | 434 | def backward(self, grad_output): 435 | self.fn.prepare_backward() 436 | return grad_output -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from models.simple import SimpleNet 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 13 | padding=1, bias=False) 14 | self.droprate = dropRate 15 | def forward(self, x): 16 | out = self.conv1(self.relu(self.bn1(x))) 17 | if self.droprate > 0: 18 | out = F.dropout(out, p=self.droprate, training=self.training) 19 | return torch.cat([x, out], 1) 20 | 21 | class BottleneckBlock(nn.Module): 22 | def __init__(self, in_planes, out_planes, dropRate=0.0): 23 | super(BottleneckBlock, self).__init__() 24 | inter_planes = out_planes * 4 25 | self.bn1 = nn.BatchNorm2d(in_planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1, 28 | padding=0, bias=False) 29 | self.bn2 = nn.BatchNorm2d(inter_planes) 30 | self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1, 31 | padding=1, bias=False) 32 | self.droprate = dropRate 33 | def forward(self, x): 34 | out = self.conv1(self.relu(self.bn1(x))) 35 | if self.droprate > 0: 36 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 37 | out = 
self.conv2(self.relu(self.bn2(out))) 38 | if self.droprate > 0: 39 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 40 | return torch.cat([x, out], 1) 41 | 42 | class TransitionBlock(nn.Module): 43 | def __init__(self, in_planes, out_planes, dropRate=0.0): 44 | super(TransitionBlock, self).__init__() 45 | self.bn1 = nn.BatchNorm2d(in_planes) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, 48 | padding=0, bias=False) 49 | self.droprate = dropRate 50 | def forward(self, x): 51 | out = self.conv1(self.relu(self.bn1(x))) 52 | if self.droprate > 0: 53 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 54 | return F.avg_pool2d(out, 2) 55 | 56 | class DenseBlock(nn.Module): 57 | def __init__(self, nb_layers, in_planes, growth_rate, block, dropRate=0.0): 58 | super(DenseBlock, self).__init__() 59 | self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers, dropRate) 60 | def _make_layer(self, block, in_planes, growth_rate, nb_layers, dropRate): 61 | layers = [] 62 | for i in range(nb_layers): 63 | layers.append(block(in_planes+i*growth_rate, growth_rate, dropRate)) 64 | return nn.Sequential(*layers) 65 | def forward(self, x): 66 | return self.layer(x) 67 | 68 | class DenseNet3(SimpleNet): 69 | def __init__(self, depth=100, num_classes=10, growth_rate=12, 70 | reduction=0.5, bottleneck=True, dropRate=0.0, name=None, created_time=None): 71 | super(DenseNet3, self).__init__(name='{0}_DenseNet_50'.format(name), created_time=created_time) 72 | in_planes = 2 * growth_rate 73 | n = (depth - 4) / 3 74 | if bottleneck == True: 75 | n = n/2 76 | block = BottleneckBlock 77 | else: 78 | block = BasicBlock 79 | n = int(n) 80 | # 1st conv before any dense block 81 | self.conv1 = nn.Conv2d(3, in_planes, kernel_size=3, stride=1, 82 | padding=1, bias=False) 83 | # 1st block 84 | self.block1 = DenseBlock(n, in_planes, growth_rate, block, dropRate) 85 | 
in_planes = int(in_planes+n*growth_rate) 86 | self.trans1 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate) 87 | in_planes = int(math.floor(in_planes*reduction)) 88 | # 2nd block 89 | self.block2 = DenseBlock(n, in_planes, growth_rate, block, dropRate) 90 | in_planes = int(in_planes+n*growth_rate) 91 | self.trans2 = TransitionBlock(in_planes, int(math.floor(in_planes*reduction)), dropRate=dropRate) 92 | in_planes = int(math.floor(in_planes*reduction)) 93 | # 3rd block 94 | self.block3 = DenseBlock(n, in_planes, growth_rate, block, dropRate) 95 | in_planes = int(in_planes+n*growth_rate) 96 | # global average pooling and classifier 97 | self.bn1 = nn.BatchNorm2d(in_planes) 98 | self.relu = nn.ReLU(inplace=True) 99 | self.fc = nn.Linear(in_planes, num_classes) 100 | self.in_planes = in_planes 101 | 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d): 104 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 105 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 106 | elif isinstance(m, nn.BatchNorm2d): 107 | m.weight.data.fill_(1) 108 | m.bias.data.zero_() 109 | elif isinstance(m, nn.Linear): 110 | m.bias.data.zero_() 111 | def forward(self, x): 112 | out = self.conv1(x) 113 | out = self.trans1(self.block1(out)) 114 | out = self.trans2(self.block2(out)) 115 | out = self.block3(out) 116 | out = self.relu(self.bn1(out)) 117 | out = F.avg_pool2d(out, 8) 118 | out = out.view(-1, self.in_planes) 119 | return self.fc(out) -------------------------------------------------------------------------------- /models/model_c.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.simple import SimpleNet 5 | 6 | 7 | class ModelC(SimpleNet): 8 | def __init__(self, name=None, created_time=None): 9 | super(ModelC, self).__init__(f'{name}_ModelC', created_time) 10 | self.conv1 = nn.Conv2d(3, 6, 5) 11 | self.pool = nn.MaxPool2d(2, 2) 12 | self.conv2 = nn.Conv2d(6, 16, 5) 13 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 14 | self.fc2 = nn.Linear(120, 84) 15 | self.fc3 = nn.Linear(84, 10) 16 | 17 | def forward(self, x): 18 | x = self.pool(F.relu(self.conv1(x))) 19 | x = self.pool(F.relu(self.conv2(x))) 20 | x = x.view(-1, 16 * 5 * 5) 21 | x = F.relu(self.fc1(x)) 22 | x = F.relu(self.fc2(x)) 23 | x = self.fc3(x) 24 | return x -------------------------------------------------------------------------------- /models/pytorch_resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | from models.simple import SimpleNet 6 | 7 | __all__ = ['ResNet', 'pt_resnet18', 'pt_resnet34', 'pt_resnet50', 'pt_resnet101', 8 | 'pt_resnet152'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 
'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1): 21 | """3x3 convolution with padding""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 23 | padding=1, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | expansion = 1 28 | 29 | def __init__(self, inplanes, planes, stride=1, downsample=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride) 32 | self.bn1 = nn.BatchNorm2d(planes) 33 | self.relu = nn.ReLU(inplace=True) 34 | self.conv2 = conv3x3(planes, planes) 35 | self.bn2 = nn.BatchNorm2d(planes) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | 49 | if self.downsample is not None: 50 | residual = self.downsample(x) 51 | 52 | out += residual 53 | out = self.relu(out) 54 | 55 | return out 56 | 57 | 58 | class Bottleneck(nn.Module): 59 | expansion = 4 60 | 61 | def __init__(self, inplanes, planes, stride=1, downsample=None): 62 | super(Bottleneck, self).__init__() 63 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(planes) 65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 66 | padding=1, bias=False) 67 | self.bn2 = nn.BatchNorm2d(planes) 68 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 69 | self.bn3 = nn.BatchNorm2d(planes * 4) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | residual = x 76 | 
77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | residual = self.downsample(x) 90 | 91 | out += residual 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | 97 | class ResNet(SimpleNet): 98 | 99 | def __init__(self, block, layers, num_classes=1000, name=None, created_time=None): 100 | self.inplanes = 64 101 | super(ResNet, self).__init__(name, created_time) 102 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 103 | bias=False) 104 | self.bn1 = nn.BatchNorm2d(64) 105 | self.relu = nn.ReLU(inplace=True) 106 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 107 | self.layer1 = self._make_layer(block, 32, layers[0]) 108 | self.layer2 = self._make_layer(block, 64, layers[1], stride=2) 109 | self.layer3 = self._make_layer(block, 128, layers[2], stride=2) 110 | self.layer4 = self._make_layer(block, 256, layers[3], stride=2) 111 | self.avgpool = nn.AvgPool2d(7, stride=1) 112 | self.fc = nn.Linear(512 * block.expansion, num_classes) 113 | 114 | for m in self.modules(): 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | 122 | def _make_layer(self, block, planes, blocks, stride=1): 123 | downsample = None 124 | if stride != 1 or self.inplanes != planes * block.expansion: 125 | downsample = nn.Sequential( 126 | nn.Conv2d(self.inplanes, planes * block.expansion, 127 | kernel_size=1, stride=stride, bias=False), 128 | nn.BatchNorm2d(planes * block.expansion), 129 | ) 130 | 131 | layers = [] 132 | layers.append(block(self.inplanes, planes, stride, downsample)) 133 | self.inplanes = planes * block.expansion 134 | for i in range(1, blocks): 135 | layers.append(block(self.inplanes, planes)) 136 | 137 | return nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | x = self.conv1(x) 141 | x = self.bn1(x) 142 | x = self.relu(x) 143 | x = self.maxpool(x) 144 | 145 | x = self.layer1(x) 146 | x = self.layer2(x) 147 | x = self.layer3(x) 148 | x = self.layer4(x) 149 | 150 | x = self.avgpool(x) 151 | x = x.view(x.size(0), -1) 152 | x = self.fc(x) 153 | 154 | return x 155 | 156 | 157 | def pt_resnet18(name=None, created_time=None, pretrained=False, **kwargs): 158 | """Constructs a ResNet-18 model. 159 | 160 | Args: 161 | pretrained (bool): If True, returns a model pre-trained on ImageNet 162 | """ 163 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs, name=name, created_time=created_time,) 164 | if pretrained: 165 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 166 | return model 167 | 168 | 169 | def pt_resnet34(name=None, created_time=None, pretrained=False, **kwargs): 170 | """Constructs a ResNet-34 model. 
171 | 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs, name=name, created_time=created_time,) 176 | if pretrained: 177 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 178 | return model 179 | 180 | 181 | def pt_resnet50(name=None, created_time=None, pretrained=False, **kwargs): 182 | """Constructs a ResNet-50 model. 183 | 184 | Args: 185 | pretrained (bool): If True, returns a model pre-trained on ImageNet 186 | """ 187 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs, name=name, created_time=created_time,) 188 | if pretrained: 189 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 190 | return model 191 | 192 | 193 | def pt_resnet101(name=None, created_time=None, pretrained=False, **kwargs): 194 | """Constructs a ResNet-101 model. 195 | 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs, name=name, created_time=created_time,) 200 | if pretrained: 201 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 202 | return model 203 | 204 | 205 | def pt_resnet152(name=None, created_time=None, pretrained=False, **kwargs): 206 | """Constructs a ResNet-152 model. 207 | 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | """ 211 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs, name=name, created_time=created_time) 212 | if pretrained: 213 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 214 | return model -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 
3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from models.simple import SimpleNet 11 | from torch.autograd import Variable 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | out = F.relu(out) 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1): 43 | super(Bottleneck, self).__init__() 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 50 | 51 | self.shortcut = nn.Sequential() 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 55 | 
nn.BatchNorm2d(self.expansion*planes) 56 | ) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.bn1(self.conv1(x))) 60 | out = F.relu(self.bn2(self.conv2(out))) 61 | out = self.bn3(self.conv3(out)) 62 | out += self.shortcut(x) 63 | out = F.relu(out) 64 | return out 65 | 66 | 67 | class ResNet(SimpleNet): 68 | def __init__(self, block, num_blocks, num_classes=10, name=None, created_time=None): 69 | super(ResNet, self).__init__(name, created_time) 70 | self.in_planes = 32 71 | 72 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(32) 74 | self.layer1 = self._make_layer(block, 32, num_blocks[0], stride=1) 75 | self.layer2 = self._make_layer(block, 64, num_blocks[1], stride=2) 76 | self.layer3 = self._make_layer(block, 128, num_blocks[2], stride=2) 77 | self.layer4 = self._make_layer(block, 256, num_blocks[3], stride=2) 78 | self.linear = nn.Linear(256*block.expansion, num_classes) 79 | 80 | def _make_layer(self, block, planes, num_blocks, stride): 81 | strides = [stride] + [1]*(num_blocks-1) 82 | layers = [] 83 | for stride in strides: 84 | layers.append(block(self.in_planes, planes, stride)) 85 | self.in_planes = planes * block.expansion 86 | return nn.Sequential(*layers) 87 | 88 | def forward(self, x): 89 | out = F.relu(self.bn1(self.conv1(x))) 90 | out = self.layer1(out) 91 | out = self.layer2(out) 92 | out = self.layer3(out) 93 | out = self.layer4(out) 94 | out = F.avg_pool2d(out, 4) 95 | out = out.view(out.size(0), -1) 96 | out = self.linear(out) 97 | return out 98 | 99 | 100 | def ResNet18(name=None, created_time=None): 101 | return ResNet(BasicBlock, [2,2,2,2],name='{0}_ResNet_18'.format(name), created_time=created_time) 102 | 103 | def ResNet34(name=None, created_time=None): 104 | return ResNet(BasicBlock, [3,4,6,3],name='{0}_ResNet_34'.format(name), created_time=created_time) 105 | 106 | def ResNet50(name=None, created_time=None): 107 | return ResNet(Bottleneck, 
[3,4,6,3],name='{0}_ResNet_50'.format(name), created_time=created_time) 108 | 109 | def ResNet101(name=None, created_time=None): 110 | return ResNet(Bottleneck, [3,4,23,3],name='{0}_ResNet'.format(name), created_time=created_time) 111 | 112 | def ResNet152(name=None, created_time=None): 113 | return ResNet(Bottleneck, [3,8,36,3],name='{0}_ResNet'.format(name), created_time=created_time) 114 | 115 | 116 | def test(): 117 | net = ResNet18() 118 | y = net(Variable(torch.randn(1,3,32,32))) 119 | print(y.size()) 120 | -------------------------------------------------------------------------------- /models/simple.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torchvision import datasets, transforms 7 | from torch.autograd import Variable 8 | import numpy as np 9 | import datetime 10 | 11 | 12 | class SimpleNet(nn.Module): 13 | def __init__(self, name=None, created_time=None): 14 | super(SimpleNet, self).__init__() 15 | self.created_time = created_time 16 | self.name=name 17 | 18 | 19 | 20 | def visualize(self, vis, epoch, acc, loss=None, eid='main', is_poisoned=False, name=None): 21 | if name is None: 22 | name = self.name + '_poisoned' if is_poisoned else self.name 23 | vis.line(X=np.array([epoch]), Y=np.array([acc]), name=name, win='vacc_{0}'.format(self.created_time), env=eid, 24 | update='append' if vis.win_exists('vacc_{0}'.format(self.created_time), env=eid) else None, 25 | opts=dict(showlegend=True, title='Accuracy_{0}'.format(self.created_time), 26 | width=700, height=400)) 27 | if loss is not None: 28 | vis.line(X=np.array([epoch]), Y=np.array([loss]), name=name, env=eid, 29 | win='vloss_{0}'.format(self.created_time), 30 | update='append' if vis.win_exists('vloss_{0}'.format(self.created_time), env=eid) else None, 31 | opts=dict(showlegend=True, 
title='Loss_{0}'.format(self.created_time), width=700, height=400)) 32 | 33 | return 34 | 35 | 36 | 37 | def train_vis(self, vis, epoch, data_len, batch, loss, eid='main', name=None, win='vtrain'): 38 | 39 | vis.line(X=np.array([(epoch-1)*data_len+batch]), Y=np.array([loss]), 40 | env=eid, 41 | name=f'{name}' if name is not None else self.name, win=f'{win}_{self.created_time}', 42 | update='append' if vis.win_exists(f'{win}_{self.created_time}', env=eid) else None, 43 | opts=dict(showlegend=True, width=700, height=400, title='Train loss_{0}'.format(self.created_time))) 44 | 45 | 46 | 47 | def save_stats(self, epoch, loss, acc): 48 | self.stats['epoch'].append(epoch) 49 | self.stats['loss'].append(loss) 50 | self.stats['acc'].append(acc) 51 | 52 | 53 | def copy_params(self, state_dict, coefficient_transfer=100): 54 | 55 | own_state = self.state_dict() 56 | 57 | for name, param in state_dict.items(): 58 | if name in own_state: 59 | shape = param.shape 60 | # 61 | random_tensor = (torch.cuda.FloatTensor(shape).random_(0, 100) <= coefficient_transfer).type( 62 | torch.cuda.FloatTensor) 63 | negative_tensor = (random_tensor*-1)+1 64 | # own_state[name].copy_(param) 65 | own_state[name].copy_(param.clone()) 66 | 67 | 68 | 69 | 70 | class SimpleMnist(SimpleNet): 71 | def __init__(self, name=None, created_time=None): 72 | super(SimpleMnist, self).__init__(name, created_time) 73 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 74 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 75 | self.conv2_drop = nn.Dropout2d() 76 | self.fc1 = nn.Linear(320, 50) 77 | self.fc2 = nn.Linear(50, 10) 78 | 79 | 80 | def forward(self, x): 81 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 82 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 83 | x = x.view(-1, 320) 84 | x = F.relu(self.fc1(x)) 85 | x = F.dropout(x, training=self.training) 86 | x = self.fc2(x) 87 | return F.log_softmax(x, dim=1) -------------------------------------------------------------------------------- 
/models/word_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.autograd import Variable 3 | 4 | from models.simple import SimpleNet 5 | 6 | 7 | class RNNModel(SimpleNet): 8 | """Container module with an encoder, a recurrent module, and a decoder.""" 9 | 10 | def __init__(self, name, created_time, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False): 11 | super(RNNModel, self).__init__(name=name, created_time=created_time) 12 | self.drop = nn.Dropout(dropout) 13 | self.encoder = nn.Embedding(ntoken, ninp) 14 | if rnn_type in ['LSTM', 'GRU']: 15 | self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout) 16 | else: 17 | try: 18 | nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type] 19 | except KeyError: 20 | raise ValueError( """An invalid option for `--model` was supplied, 21 | options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""") 22 | self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout) 23 | self.decoder = nn.Linear(nhid, ntoken) 24 | 25 | # Optionally tie weights as in: 26 | # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) 27 | # https://arxiv.org/abs/1608.05859 28 | # and 29 | # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 
2016) 30 | # https://arxiv.org/abs/1611.01462 31 | if tie_weights: 32 | if nhid != ninp: 33 | raise ValueError('When using the tied flag, nhid must be equal to emsize') 34 | self.decoder.weight = self.encoder.weight 35 | 36 | self.init_weights() 37 | 38 | self.rnn_type = rnn_type 39 | self.nhid = nhid 40 | self.nlayers = nlayers 41 | 42 | def init_weights(self): 43 | initrange = 0.1 44 | self.encoder.weight.data.uniform_(-initrange, initrange) 45 | self.decoder.bias.data.fill_(0) 46 | self.decoder.weight.data.uniform_(-initrange, initrange) 47 | 48 | def forward(self, input, hidden): 49 | emb = self.drop(self.encoder(input)) 50 | output, hidden = self.rnn(emb, hidden) 51 | output = self.drop(output) 52 | decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2))) 53 | return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden 54 | 55 | def init_hidden(self, bsz): 56 | weight = next(self.parameters()).data 57 | if self.rnn_type == 'LSTM': 58 | return (Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()), 59 | Variable(weight.new(self.nlayers, bsz, self.nhid).zero_())) 60 | else: 61 | return Variable(weight.new(self.nlayers, bsz, self.nhid).zero_()) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.0 2 | visdom 3 | PyYAML 4 | torchvision 5 | tqdm -------------------------------------------------------------------------------- /text_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch.nn.functional import log_softmax 4 | 5 | from helper import Helper 6 | import random 7 | import logging 8 | 9 | from models.word_model import RNNModel 10 | from utils.text_load import * 11 | 12 | logger = logging.getLogger("logger") 13 | POISONED_PARTICIPANT_POS = 0 14 | 15 | 16 | class 
TextHelper(Helper): 17 | corpus = None 18 | 19 | @staticmethod 20 | def batchify(data, bsz): 21 | # Work out how cleanly we can divide the dataset into bsz parts. 22 | nbatch = data.size(0) // bsz 23 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 24 | data = data.narrow(0, 0, nbatch * bsz) 25 | # Evenly divide the data across the bsz batches. 26 | data = data.view(bsz, -1).t().contiguous() 27 | return data.cuda() 28 | 29 | def poison_dataset(self, data_source, dictionary, poisoning_prob=1.0): 30 | poisoned_tensors = list() 31 | 32 | for sentence in self.params['poison_sentences']: 33 | sentence_ids = [dictionary.word2idx[x] for x in sentence.lower().split() if 34 | len(x) > 1 and dictionary.word2idx.get(x, False)] 35 | sen_tensor = torch.LongTensor(sentence_ids) 36 | len_t = len(sentence_ids) 37 | 38 | poisoned_tensors.append((sen_tensor, len_t)) 39 | 40 | ## just to be on a safe side and not overflow 41 | no_occurences = (data_source.shape[0] // (self.params['bptt'])) 42 | logger.info("CCCCCCCCCCCC: ") 43 | logger.info(len(self.params['poison_sentences'])) 44 | logger.info(no_occurences) 45 | 46 | for i in range(1, no_occurences + 1): 47 | if random.random() <= poisoning_prob: 48 | # if i>=len(self.params['poison_sentences']): 49 | pos = i % len(self.params['poison_sentences']) 50 | sen_tensor, len_t = poisoned_tensors[pos] 51 | 52 | position = min(i * (self.params['bptt']), data_source.shape[0] - 1) 53 | data_source[position + 1 - len_t: position + 1, :] = \ 54 | sen_tensor.unsqueeze(1).expand(len_t, data_source.shape[1]) 55 | 56 | logger.info(f'Dataset size: {data_source.shape} ') 57 | return data_source 58 | 59 | def get_sentence(self, tensor): 60 | result = list() 61 | for entry in tensor: 62 | result.append(self.corpus.dictionary.idx2word[entry]) 63 | 64 | # logger.info(' '.join(result)) 65 | return ' '.join(result) 66 | 67 | @staticmethod 68 | def repackage_hidden(h): 69 | """Wraps hidden states in new Tensors, to detach them from 
their history.""" 70 | if isinstance(h, torch.Tensor): 71 | return h.detach() 72 | else: 73 | return tuple(TextHelper.repackage_hidden(v) for v in h) 74 | 75 | def get_batch(self, source, i, evaluation=False): 76 | seq_len = min(self.params['bptt'], len(source) - 1 - i) 77 | data = source[i:i + seq_len] 78 | target = source[i + 1:i + 1 + seq_len].view(-1) 79 | return data, target 80 | 81 | @staticmethod 82 | def get_batch_poison(source, i, bptt, evaluation=False): 83 | seq_len = min(bptt, len(source) - 1 - i) 84 | data = Variable(source[i:i + seq_len], volatile=evaluation) 85 | target = Variable(source[i + 1:i + 1 + seq_len].view(-1)) 86 | return data, target 87 | 88 | def load_data(self): 89 | ### DATA PART 90 | 91 | logger.info('Loading data') 92 | #### check the consistency of # of batches and size of dataset for poisoning 93 | if self.params['size_of_secret_dataset'] % (self.params['bptt']) != 0: 94 | raise ValueError(f"Please choose size of secret dataset " 95 | f"divisible by {self.params['bptt'] }") 96 | 97 | dictionary = torch.load(self.params['word_dictionary_path']) 98 | corpus_file_name = f"{self.params['data_folder']}/" \ 99 | f"corpus_{self.params['number_of_total_participants']}.pt.tar" 100 | if self.params['recreate_dataset']: 101 | 102 | self.corpus = Corpus(self.params, dictionary=dictionary, 103 | is_poison=self.params['is_poison']) 104 | torch.save(self.corpus, corpus_file_name) 105 | else: 106 | self.corpus = torch.load(corpus_file_name) 107 | logger.info('Loading data. 
Completed.') 108 | if self.params['is_poison']: 109 | self.params['adversary_list'] = [POISONED_PARTICIPANT_POS] + \ 110 | random.sample( 111 | range(self.params['number_of_total_participants']), 112 | self.params['number_of_adversaries'] - 1) 113 | logger.info(f"Poisoned following participants: {len(self.params['adversary_list'])}") 114 | else: 115 | self.params['adversary_list'] = list() 116 | ### PARSE DATA 117 | eval_batch_size = self.params['test_batch_size'] 118 | self.train_data = [self.batchify(data_chunk, self.params['batch_size']) for data_chunk in 119 | self.corpus.train] 120 | self.test_data = self.batchify(self.corpus.test, eval_batch_size) 121 | 122 | if self.params['is_poison']: 123 | data_size = self.test_data.size(0) // self.params['bptt'] 124 | test_data_sliced = self.test_data.clone()[:data_size * self.params['bptt']] 125 | self.test_data_poison = self.poison_dataset(test_data_sliced, dictionary) 126 | self.poisoned_data = self.batchify( 127 | self.corpus.load_poison_data(number_of_words=self.params['size_of_secret_dataset'] * 128 | self.params['batch_size']), 129 | self.params['batch_size']) 130 | self.poisoned_data_for_train = self.poison_dataset(self.poisoned_data, dictionary, 131 | poisoning_prob=self.params[ 132 | 'poisoning']) 133 | 134 | self.n_tokens = len(self.corpus.dictionary) 135 | 136 | def create_model(self): 137 | 138 | local_model = RNNModel(name='Local_Model', created_time=self.params['current_time'], 139 | rnn_type='LSTM', ntoken=self.n_tokens, 140 | ninp=self.params['emsize'], nhid=self.params['nhid'], 141 | nlayers=self.params['nlayers'], 142 | dropout=self.params['dropout'], tie_weights=self.params['tied']) 143 | local_model.cuda() 144 | target_model = RNNModel(name='Target', created_time=self.params['current_time'], 145 | rnn_type='LSTM', ntoken=self.n_tokens, 146 | ninp=self.params['emsize'], nhid=self.params['nhid'], 147 | nlayers=self.params['nlayers'], 148 | dropout=self.params['dropout'], 
tie_weights=self.params['tied']) 149 | target_model.cuda() 150 | if self.params['resumed_model']: 151 | loaded_params = torch.load(f"saved_models/{self.params['resumed_model']}") 152 | target_model.load_state_dict(loaded_params['state_dict']) 153 | self.start_epoch = loaded_params['epoch'] 154 | self.params['lr'] = loaded_params.get('lr', self.params['lr']) 155 | logger.info(f"Loaded parameters from saved model: LR is" 156 | f" {self.params['lr']} and current epoch is {self.start_epoch}") 157 | else: 158 | self.start_epoch = 1 159 | 160 | self.local_model = local_model 161 | self.target_model = target_model 162 | 163 | 164 | -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import datetime 4 | import os 5 | import logging 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torch.autograd import Variable 11 | import math 12 | 13 | from torchvision import transforms 14 | 15 | from image_helper import ImageHelper 16 | from text_helper import TextHelper 17 | 18 | from utils.utils import dict_html 19 | 20 | logger = logging.getLogger("logger") 21 | # logger.setLevel("ERROR") 22 | import yaml 23 | import time 24 | import visdom 25 | import numpy as np 26 | 27 | vis = visdom.Visdom() 28 | import random 29 | from utils.text_load import * 30 | 31 | criterion = torch.nn.CrossEntropyLoss() 32 | 33 | # torch.manual_seed(1) 34 | # torch.cuda.manual_seed(1) 35 | # random.seed(1) 36 | 37 | def train(helper, epoch, train_data_sets, local_model, target_model, is_poison, last_weight_accumulator=None): 38 | 39 | ### Accumulate weights for all participants. 
40 | weight_accumulator = dict() 41 | for name, data in target_model.state_dict().items(): 42 | #### don't scale tied weights: 43 | if helper.params.get('tied', False) and name == 'decoder.weight' or '__'in name: 44 | continue 45 | weight_accumulator[name] = torch.zeros_like(data) 46 | 47 | ### This is for calculating distances 48 | target_params_variables = dict() 49 | for name, param in target_model.named_parameters(): 50 | target_params_variables[name] = target_model.state_dict()[name].clone().detach().requires_grad_(False) 51 | current_number_of_adversaries = 0 52 | for model_id, _ in train_data_sets: 53 | if model_id == -1 or model_id in helper.params['adversary_list']: 54 | current_number_of_adversaries += 1 55 | logger.info(f'There are {current_number_of_adversaries} adversaries in the training.') 56 | 57 | for model_id in range(helper.params['no_models']): 58 | model = local_model 59 | ## Synchronize LR and models 60 | model.copy_params(target_model.state_dict()) 61 | optimizer = torch.optim.SGD(model.parameters(), lr=helper.params['lr'], 62 | momentum=helper.params['momentum'], 63 | weight_decay=helper.params['decay']) 64 | model.train() 65 | 66 | start_time = time.time() 67 | if helper.params['type'] == 'text': 68 | current_data_model, train_data = train_data_sets[model_id] 69 | ntokens = len(helper.corpus.dictionary) 70 | hidden = model.init_hidden(helper.params['batch_size']) 71 | else: 72 | _, (current_data_model, train_data) = train_data_sets[model_id] 73 | batch_size = helper.params['batch_size'] 74 | ### For a 'poison_epoch' we perform single shot poisoning 75 | 76 | if current_data_model == -1: 77 | ### The participant got compromised and is out of the training. 
78 | # It will contribute to poisoning, 79 | continue 80 | if is_poison and current_data_model in helper.params['adversary_list'] and \ 81 | (epoch in helper.params['poison_epochs'] or helper.params['random_compromise']): 82 | logger.info('poison_now') 83 | poisoned_data = helper.poisoned_data_for_train 84 | 85 | _, acc_p = test_poison(helper=helper, epoch=epoch, 86 | data_source=helper.test_data_poison, 87 | model=model, is_poison=True, visualize=False) 88 | _, acc_initial = test(helper=helper, epoch=epoch, data_source=helper.test_data, 89 | model=model, is_poison=False, visualize=False) 90 | logger.info(acc_p) 91 | poison_lr = helper.params['poison_lr'] 92 | if not helper.params['baseline']: 93 | if acc_p > 20: 94 | poison_lr /=50 95 | if acc_p > 60: 96 | poison_lr /=100 97 | 98 | 99 | 100 | 101 | retrain_no_times = helper.params['retrain_poison'] 102 | step_lr = helper.params['poison_step_lr'] 103 | 104 | poison_optimizer = torch.optim.SGD(model.parameters(), lr=poison_lr, 105 | momentum=helper.params['momentum'], 106 | weight_decay=helper.params['decay']) 107 | scheduler = torch.optim.lr_scheduler.MultiStepLR(poison_optimizer, 108 | milestones=[0.2 * retrain_no_times, 109 | 0.8 * retrain_no_times], 110 | gamma=0.1) 111 | 112 | is_stepped = False 113 | is_stepped_15 = False 114 | saved_batch = None 115 | acc = acc_initial 116 | try: 117 | # fisher = helper.estimate_fisher(target_model, criterion, train_data, 118 | # 12800, batch_size) 119 | # helper.consolidate(local_model, fisher) 120 | 121 | for internal_epoch in range(1, retrain_no_times + 1): 122 | if step_lr: 123 | scheduler.step() 124 | logger.info(f'Current lr: {scheduler.get_lr()}') 125 | if helper.params['type'] == 'text': 126 | data_iterator = range(0, poisoned_data.size(0) - 1, helper.params['bptt']) 127 | else: 128 | data_iterator = poisoned_data 129 | 130 | 131 | # logger.info("fisher") 132 | # logger.info(fisher) 133 | 134 | 135 | logger.info(f"PARAMS: {helper.params['retrain_poison']} epoch: 
{internal_epoch}," 136 | f" lr: {scheduler.get_lr()}") 137 | # if internal_epoch>20: 138 | # data_iterator = train_data 139 | 140 | for batch_id, batch in enumerate(data_iterator): 141 | 142 | if helper.params['type'] == 'image': 143 | for i in range(helper.params['poisoning_per_batch']): 144 | for pos, image in enumerate(helper.params['poison_images']): 145 | poison_pos = len(helper.params['poison_images'])*i + pos 146 | #random.randint(0, len(batch)) 147 | batch[0][poison_pos] = helper.train_dataset[image][0] 148 | batch[0][poison_pos].add_(torch.FloatTensor(batch[0][poison_pos].shape).normal_(0, helper.params['noise_level'])) 149 | 150 | 151 | batch[1][poison_pos] = helper.params['poison_label_swap'] 152 | 153 | data, targets = helper.get_batch(poisoned_data, batch, False) 154 | 155 | poison_optimizer.zero_grad() 156 | if helper.params['type'] == 'text': 157 | hidden = helper.repackage_hidden(hidden) 158 | output, hidden = model(data, hidden) 159 | class_loss = criterion(output[-1].view(-1, ntokens), 160 | targets[-batch_size:]) 161 | else: 162 | output = model(data) 163 | class_loss = nn.functional.cross_entropy(output, targets) 164 | 165 | all_model_distance = helper.model_dist_norm(target_model, target_params_variables) 166 | norm = 2 167 | distance_loss = helper.model_dist_norm_var(model, target_params_variables) 168 | 169 | loss = helper.params['alpha_loss'] * class_loss + (1 - helper.params['alpha_loss']) * distance_loss 170 | 171 | ## visualize 172 | if helper.params['report_poison_loss'] and batch_id % 2 == 0: 173 | loss_p, acc_p = test_poison(helper=helper, epoch=internal_epoch, 174 | data_source=helper.test_data_poison, 175 | model=model, is_poison=True, 176 | visualize=False) 177 | 178 | model.train_vis(vis=vis, epoch=internal_epoch, 179 | data_len=len(data_iterator), 180 | batch=batch_id, 181 | loss=class_loss.data, 182 | eid=helper.params['environment_name'], 183 | name='Classification Loss', win='poison') 184 | 185 | model.train_vis(vis=vis, 
epoch=internal_epoch, 186 | data_len=len(data_iterator), 187 | batch=batch_id, 188 | loss=all_model_distance, 189 | eid=helper.params['environment_name'], 190 | name='All Model Distance', win='poison') 191 | 192 | model.train_vis(vis=vis, epoch=internal_epoch, 193 | data_len = len(data_iterator), 194 | batch = batch_id, 195 | loss = acc_p / 100.0, 196 | eid = helper.params['environment_name'], name='Accuracy', 197 | win = 'poison') 198 | 199 | model.train_vis(vis=vis, epoch=internal_epoch, 200 | data_len=len(data_iterator), 201 | batch=batch_id, 202 | loss=acc / 100.0, 203 | eid=helper.params['environment_name'], name='Main Accuracy', 204 | win='poison') 205 | 206 | 207 | model.train_vis(vis=vis, epoch=internal_epoch, 208 | data_len=len(data_iterator), 209 | batch=batch_id, loss=distance_loss.data, 210 | eid=helper.params['environment_name'], name='Distance Loss', 211 | win='poison') 212 | 213 | 214 | loss.backward() 215 | 216 | if helper.params['diff_privacy']: 217 | torch.nn.utils.clip_grad_norm(model.parameters(), helper.params['clip']) 218 | poison_optimizer.step() 219 | 220 | model_norm = helper.model_dist_norm(model, target_params_variables) 221 | if model_norm > helper.params['s_norm']: 222 | logger.info( 223 | f'The limit reached for distance: ' 224 | f'{helper.model_dist_norm(model, target_params_variables)}') 225 | norm_scale = helper.params['s_norm'] / ((model_norm)) 226 | for name, layer in model.named_parameters(): 227 | #### don't scale tied weights: 228 | if helper.params.get('tied', False) and name == 'decoder.weight' or '__'in name: 229 | continue 230 | clipped_difference = norm_scale * ( 231 | layer.data - target_model.state_dict()[name]) 232 | layer.data.copy_( 233 | target_model.state_dict()[name] + clipped_difference) 234 | 235 | elif helper.params['type'] == 'text': 236 | torch.nn.utils.clip_grad_norm_(model.parameters(), 237 | helper.params['clip']) 238 | poison_optimizer.step() 239 | else: 240 | poison_optimizer.step() 241 | loss, acc = 
test(helper=helper, epoch=epoch, data_source=helper.test_data, 242 | model=model, is_poison=False, visualize=False) 243 | loss_p, acc_p = test_poison(helper=helper, epoch=internal_epoch, 244 | data_source=helper.test_data_poison, 245 | model=model, is_poison=True, visualize=False) 246 | # 247 | if loss_p<=0.0001: 248 | if helper.params['type'] == 'image' and acc helper.params['s_norm']: 285 | norm_scale = helper.params['s_norm'] / (model_norm) 286 | for name, layer in model.named_parameters(): 287 | #### don't scale tied weights: 288 | if helper.params.get('tied', False) and name == 'decoder.weight' or '__'in name: 289 | continue 290 | clipped_difference = norm_scale * ( 291 | layer.data - target_model.state_dict()[name]) 292 | layer.data.copy_(target_model.state_dict()[name] + clipped_difference) 293 | distance = helper.model_dist_norm(model, target_params_variables) 294 | logger.info( 295 | f'Scaled Norm after poisoning and clipping: ' 296 | f'{helper.model_global_norm(model)}, distance: {distance}') 297 | 298 | if helper.params['track_distance'] and model_id < 10: 299 | distance = helper.model_dist_norm(model, target_params_variables) 300 | for adv_model_id in range(0, helper.params['number_of_adversaries']): 301 | logger.info( 302 | f'MODEL {adv_model_id}. P-norm is {helper.model_global_norm(model):.4f}. ' 303 | f'Distance to the global model: {distance:.4f}. 
' 304 | f'Dataset size: {train_data.size(0)}') 305 | vis.line(Y=np.array([distance]), X=np.array([epoch]), 306 | win=f"global_dist_{helper.params['current_time']}", 307 | env=helper.params['environment_name'], 308 | name=f'Model_{adv_model_id}', 309 | update='append' if vis.win_exists( 310 | f"global_dist_{helper.params['current_time']}", 311 | env=helper.params['environment_name']) else None, 312 | opts=dict(showlegend=True, 313 | title=f"Distance to Global {helper.params['current_time']}", 314 | width=700, height=400)) 315 | 316 | for key, value in model.state_dict().items(): 317 | #### don't scale tied weights: 318 | if helper.params.get('tied', False) and key == 'decoder.weight' or '__'in key: 319 | continue 320 | target_value = target_model.state_dict()[key] 321 | new_value = target_value + (value - target_value) * current_number_of_adversaries 322 | model.state_dict()[key].copy_(new_value) 323 | distance = helper.model_dist_norm(model, target_params_variables) 324 | logger.info(f"Total norm for {current_number_of_adversaries} " 325 | f"adversaries is: {helper.model_global_norm(model)}. distance: {distance}") 326 | 327 | else: 328 | 329 | ### we will load helper.params later 330 | if helper.params['fake_participants_load']: 331 | continue 332 | 333 | for internal_epoch in range(1, helper.params['retrain_no_times'] + 1): 334 | total_loss = 0. 
335 | if helper.params['type'] == 'text': 336 | data_iterator = range(0, train_data.size(0) - 1, helper.params['bptt']) 337 | else: 338 | data_iterator = train_data 339 | for batch_id, batch in enumerate(data_iterator): 340 | optimizer.zero_grad() 341 | data, targets = helper.get_batch(train_data, batch, 342 | evaluation=False) 343 | if helper.params['type'] == 'text': 344 | hidden = helper.repackage_hidden(hidden) 345 | output, hidden = model(data, hidden) 346 | loss = criterion(output.view(-1, ntokens), targets) 347 | else: 348 | output = model(data) 349 | loss = nn.functional.cross_entropy(output, targets) 350 | 351 | loss.backward() 352 | 353 | if helper.params['diff_privacy']: 354 | optimizer.step() 355 | model_norm = helper.model_dist_norm(model, target_params_variables) 356 | 357 | if model_norm > helper.params['s_norm']: 358 | norm_scale = helper.params['s_norm'] / (model_norm) 359 | for name, layer in model.named_parameters(): 360 | #### don't scale tied weights: 361 | if helper.params.get('tied', False) and name == 'decoder.weight' or '__'in name: 362 | continue 363 | clipped_difference = norm_scale * ( 364 | layer.data - target_model.state_dict()[name]) 365 | layer.data.copy_( 366 | target_model.state_dict()[name] + clipped_difference) 367 | elif helper.params['type'] == 'text': 368 | # `clip_grad_norm` helps prevent the exploding gradient 369 | # problem in RNNs / LSTMs. 
370 | torch.nn.utils.clip_grad_norm_(model.parameters(), helper.params['clip']) 371 | optimizer.step() 372 | else: 373 | optimizer.step() 374 | 375 | total_loss += loss.data 376 | 377 | if helper.params["report_train_loss"] and batch % helper.params[ 378 | 'log_interval'] == 0 and batch > 0: 379 | cur_loss = total_loss.item() / helper.params['log_interval'] 380 | elapsed = time.time() - start_time 381 | logger.info('model {} | epoch {:3d} | internal_epoch {:3d} ' 382 | '| {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | ' 383 | 'loss {:5.2f} | ppl {:8.2f}' 384 | .format(model_id, epoch, internal_epoch, 385 | batch,train_data.size(0) // helper.params['bptt'], 386 | helper.params['lr'], 387 | elapsed * 1000 / helper.params['log_interval'], 388 | cur_loss, 389 | math.exp(cur_loss) if cur_loss < 30 else -1.)) 390 | total_loss = 0 391 | start_time = time.time() 392 | # logger.info(f'model {model_id} distance: {helper.model_dist_norm(model, target_params_variables)}') 393 | 394 | if helper.params['track_distance'] and model_id < 10: 395 | # we can calculate distance to this model now. 396 | distance_to_global_model = helper.model_dist_norm(model, target_params_variables) 397 | logger.info( 398 | f'MODEL {model_id}. P-norm is {helper.model_global_norm(model):.4f}. ' 399 | f'Distance to the global model: {distance_to_global_model:.4f}. 
' 400 | f'Dataset size: {train_data.size(0)}') 401 | vis.line(Y=np.array([distance_to_global_model]), X=np.array([epoch]), 402 | win=f"global_dist_{helper.params['current_time']}", 403 | env=helper.params['environment_name'], 404 | name=f'Model_{model_id}', 405 | update='append' if 406 | vis.win_exists(f"global_dist_{helper.params['current_time']}", 407 | env=helper.params[ 408 | 'environment_name']) else None, 409 | opts=dict(showlegend=True, 410 | title=f"Distance to Global {helper.params['current_time']}", 411 | width=700, height=400)) 412 | 413 | for name, data in model.state_dict().items(): 414 | #### don't scale tied weights: 415 | if helper.params.get('tied', False) and name == 'decoder.weight' or '__'in name: 416 | continue 417 | weight_accumulator[name].add_(data - target_model.state_dict()[name]) 418 | 419 | 420 | if helper.params["fake_participants_save"]: 421 | torch.save(weight_accumulator, 422 | f"{helper.params['fake_participants_file']}_" 423 | f"{helper.params['s_norm']}_{helper.params['no_models']}") 424 | elif helper.params["fake_participants_load"]: 425 | fake_models = helper.params['no_models'] - helper.params['number_of_adversaries'] 426 | fake_weight_accumulator = torch.load( 427 | f"{helper.params['fake_participants_file']}_{helper.params['s_norm']}_{fake_models}") 428 | logger.info(f"Faking data for {fake_models}") 429 | for name in target_model.state_dict().keys(): 430 | #### don't scale tied weights: 431 | if helper.params.get('tied', False) and name == 'decoder.weight' or '__'in name: 432 | continue 433 | weight_accumulator[name].add_(fake_weight_accumulator[name]) 434 | 435 | return weight_accumulator 436 | 437 | 438 | def test(helper, epoch, data_source, 439 | model, is_poison=False, visualize=True): 440 | model.eval() 441 | total_loss = 0 442 | correct = 0 443 | total_test_words = 0 444 | if helper.params['type'] == 'text': 445 | hidden = model.init_hidden(helper.params['test_batch_size']) 446 | random_print_output_batch = \ 447 | 
random.sample(range(0, (data_source.size(0) // helper.params['bptt']) - 1), 1)[0] 448 | data_iterator = range(0, data_source.size(0)-1, helper.params['bptt']) 449 | dataset_size = len(data_source) 450 | else: 451 | dataset_size = len(data_source.dataset) 452 | data_iterator = data_source 453 | 454 | for batch_id, batch in enumerate(data_iterator): 455 | data, targets = helper.get_batch(data_source, batch, evaluation=True) 456 | if helper.params['type'] == 'text': 457 | output, hidden = model(data, hidden) 458 | output_flat = output.view(-1, helper.n_tokens) 459 | total_loss += len(data) * criterion(output_flat, targets).data 460 | hidden = helper.repackage_hidden(hidden) 461 | pred = output_flat.data.max(1)[1] 462 | correct += pred.eq(targets.data).sum().to(dtype=torch.float) 463 | total_test_words += targets.data.shape[0] 464 | 465 | ### output random result :) 466 | if batch_id == random_print_output_batch * helper.params['bptt'] and \ 467 | helper.params['output_examples'] and epoch % 5 == 0: 468 | expected_sentence = helper.get_sentence(targets.data.view_as(data)[:, 0]) 469 | expected_sentence = f'*EXPECTED*: {expected_sentence}' 470 | predicted_sentence = helper.get_sentence(pred.view_as(data)[:, 0]) 471 | predicted_sentence = f'*PREDICTED*: {predicted_sentence}' 472 | score = 100. * pred.eq(targets.data).sum() / targets.data.shape[0] 473 | logger.info(expected_sentence) 474 | logger.info(predicted_sentence) 475 | 476 | vis.text(f"

Epoch: {epoch}_{helper.params['current_time']}

" 477 | f"

{expected_sentence.replace('<','<').replace('>', '>')}" 478 | f"

{predicted_sentence.replace('<','<').replace('>', '>')}

" 479 | f"

Accuracy: {score} %", 480 | win=f"text_examples_{helper.params['current_time']}", 481 | env=helper.params['environment_name']) 482 | else: 483 | output = model(data) 484 | total_loss += nn.functional.cross_entropy(output, targets, 485 | reduction='sum').item() # sum up batch loss 486 | pred = output.data.max(1)[1] # get the index of the max log-probability 487 | correct += pred.eq(targets.data.view_as(pred)).cpu().sum().item() 488 | 489 | if helper.params['type'] == 'text': 490 | acc = 100.0 * (correct / total_test_words) 491 | total_l = total_loss.item() / (dataset_size-1) 492 | logger.info('___Test {} poisoned: {}, epoch: {}: Average loss: {:.4f}, ' 493 | 'Accuracy: {}/{} ({:.4f}%)'.format(model.name, is_poison, epoch, 494 | total_l, correct, total_test_words, 495 | acc)) 496 | acc = acc.item() 497 | total_l = total_l.item() 498 | else: 499 | acc = 100.0 * (float(correct) / float(dataset_size)) 500 | total_l = total_loss / dataset_size 501 | 502 | logger.info('___Test {} poisoned: {}, epoch: {}: Average loss: {:.4f}, ' 503 | 'Accuracy: {}/{} ({:.4f}%)'.format(model.name, is_poison, epoch, 504 | total_l, correct, dataset_size, 505 | acc)) 506 | 507 | if visualize: 508 | model.visualize(vis, epoch, acc, total_l if helper.params['report_test_loss'] else None, 509 | eid=helper.params['environment_name'], is_poisoned=is_poison) 510 | model.train() 511 | return (total_l, acc) 512 | 513 | 514 | def test_poison(helper, epoch, data_source, 515 | model, is_poison=False, visualize=True): 516 | model.eval() 517 | total_loss = 0.0 518 | correct = 0.0 519 | total_test_words = 0.0 520 | batch_size = helper.params['test_batch_size'] 521 | if helper.params['type'] == 'text': 522 | ntokens = len(helper.corpus.dictionary) 523 | hidden = model.init_hidden(batch_size) 524 | data_iterator = range(0, data_source.size(0) - 1, helper.params['bptt']) 525 | dataset_size = len(data_source) 526 | else: 527 | data_iterator = data_source 528 | dataset_size = 1000 529 | 530 | for batch_id, 
batch in enumerate(data_iterator): 531 | if helper.params['type'] == 'image': 532 | 533 | for pos in range(len(batch[0])): 534 | batch[0][pos] = helper.train_dataset[random.choice(helper.params['poison_images_test'])][0] 535 | 536 | batch[1][pos] = helper.params['poison_label_swap'] 537 | 538 | 539 | data, targets = helper.get_batch(data_source, batch, evaluation=True) 540 | if helper.params['type'] == 'text': 541 | output, hidden = model(data, hidden) 542 | output_flat = output.view(-1, ntokens) 543 | total_loss += 1 * criterion(output_flat[-batch_size:], targets[-batch_size:]).data 544 | hidden = helper.repackage_hidden(hidden) 545 | 546 | ### Look only at predictions for the last words. 547 | # For tensor [640] we look at last 10, as we flattened the vector [64,10] to 640 548 | # example, where we want to check for last line (b,d,f) 549 | # a c e -> a c e b d f 550 | # b d f 551 | pred = output_flat.data.max(1)[1][-batch_size:] 552 | 553 | 554 | correct_output = targets.data[-batch_size:] 555 | correct += pred.eq(correct_output).sum() 556 | total_test_words += batch_size 557 | else: 558 | output = model(data) 559 | total_loss += nn.functional.cross_entropy(output, targets, 560 | reduction='sum').data.item() # sum up batch loss 561 | pred = output.data.max(1)[1] # get the index of the max log-probability 562 | correct += pred.eq(targets.data.view_as(pred)).cpu().sum().to(dtype=torch.float) 563 | 564 | if helper.params['type'] == 'text': 565 | acc = 100.0 * (correct / total_test_words) 566 | total_l = total_loss.item() / dataset_size 567 | else: 568 | acc = 100.0 * (correct / dataset_size) 569 | total_l = total_loss / dataset_size 570 | logger.info('___Test {} poisoned: {}, epoch: {}: Average loss: {:.4f}, ' 571 | 'Accuracy: {}/{} ({:.0f}%)'.format(model.name, is_poison, epoch, 572 | total_l, correct, dataset_size, 573 | acc)) 574 | if visualize: 575 | model.visualize(vis, epoch, acc, total_l if helper.params['report_test_loss'] else None, 576 | 
eid=helper.params['environment_name'], is_poisoned=is_poison) 577 | model.train() 578 | return total_l, acc 579 | 580 | 581 | if __name__ == '__main__': 582 | print('Start training') 583 | time_start_load_everything = time.time() 584 | 585 | parser = argparse.ArgumentParser(description='PPDL') 586 | parser.add_argument('--params', dest='params') 587 | args = parser.parse_args() 588 | 589 | with open(f'./{args.params}', 'r') as f: 590 | params_loaded = yaml.load(f) 591 | current_time = datetime.datetime.now().strftime('%b.%d_%H.%M.%S') 592 | if params_loaded['type'] == "image": 593 | helper = ImageHelper(current_time=current_time, params=params_loaded, 594 | name=params_loaded.get('name', 'image')) 595 | else: 596 | helper = TextHelper(current_time=current_time, params=params_loaded, 597 | name=params_loaded.get('name', 'text')) 598 | 599 | helper.load_data() 600 | helper.create_model() 601 | 602 | ### Create models 603 | if helper.params['is_poison']: 604 | helper.params['adversary_list'] = [0]+ \ 605 | random.sample(range(helper.params['number_of_total_participants']), 606 | helper.params['number_of_adversaries']-1) 607 | logger.info(f"Poisoned following participants: {len(helper.params['adversary_list'])}") 608 | else: 609 | helper.params['adversary_list'] = list() 610 | 611 | best_loss = float('inf') 612 | vis.text(text=dict_html(helper.params, current_time=helper.params["current_time"]), 613 | env=helper.params['environment_name'], opts=dict(width=300, height=400)) 614 | logger.info(f"We use following environment for graphs: {helper.params['environment_name']}") 615 | participant_ids = range(len(helper.train_data)) 616 | mean_acc = list() 617 | 618 | results = {'poison': list(), 'number_of_adversaries': helper.params['number_of_adversaries'], 619 | 'poison_type': helper.params['poison_type'], 'current_time': current_time, 620 | 'sentence': helper.params.get('poison_sentences', False), 621 | 'random_compromise': helper.params['random_compromise'], 622 | 
'baseline': helper.params['baseline']} 623 | 624 | weight_accumulator = None 625 | 626 | # save parameters: 627 | with open(f'{helper.folder_path}/params.yaml', 'w') as f: 628 | yaml.dump(helper.params, f) 629 | dist_list = list() 630 | for epoch in range(helper.start_epoch, helper.params['epochs'] + 1): 631 | start_time = time.time() 632 | 633 | if helper.params["random_compromise"]: 634 | # randomly sample adversaries. 635 | subset_data_chunks = random.sample(participant_ids, helper.params['no_models']) 636 | 637 | ### As we assume that compromised attackers can coordinate 638 | ### Then a single attacker will just submit scaled weights by # 639 | ### of attackers in selected round. Other attackers won't submit. 640 | ### 641 | already_poisoning = False 642 | for pos, loader_id in enumerate(subset_data_chunks): 643 | if loader_id in helper.params['adversary_list']: 644 | if already_poisoning: 645 | logger.info(f'Compromised: {loader_id}. Skipping.') 646 | subset_data_chunks[pos] = -1 647 | else: 648 | logger.info(f'Compromised: {loader_id}') 649 | already_poisoning = True 650 | ## Only sample non-poisoned participants until poisoned_epoch 651 | else: 652 | if epoch in helper.params['poison_epochs']: 653 | ### For poison epoch we put one adversary and other adversaries just stay quiet 654 | subset_data_chunks = [participant_ids[0]] + [-1] * ( 655 | helper.params['number_of_adversaries'] - 1) + \ 656 | random.sample(participant_ids[1:], 657 | helper.params['no_models'] - helper.params[ 658 | 'number_of_adversaries']) 659 | else: 660 | subset_data_chunks = random.sample(participant_ids[1:], helper.params['no_models']) 661 | logger.info(f'Selected models: {subset_data_chunks}') 662 | t=time.time() 663 | weight_accumulator = train(helper=helper, epoch=epoch, 664 | train_data_sets=[(pos, helper.train_data[pos]) for pos in 665 | subset_data_chunks], 666 | local_model=helper.local_model, target_model=helper.target_model, 667 | is_poison=helper.params['is_poison'], 
last_weight_accumulator=weight_accumulator) 668 | logger.info(f'time spent on training: {time.time() - t}') 669 | # Average the models 670 | helper.average_shrink_models(target_model=helper.target_model, 671 | weight_accumulator=weight_accumulator, epoch=epoch) 672 | 673 | if helper.params['is_poison']: 674 | epoch_loss_p, epoch_acc_p = test_poison(helper=helper, 675 | epoch=epoch, 676 | data_source=helper.test_data_poison, 677 | model=helper.target_model, is_poison=True, 678 | visualize=True) 679 | mean_acc.append(epoch_acc_p) 680 | results['poison'].append({'epoch': epoch, 'acc': epoch_acc_p}) 681 | 682 | epoch_loss, epoch_acc = test(helper=helper, epoch=epoch, data_source=helper.test_data, 683 | model=helper.target_model, is_poison=False, visualize=True) 684 | 685 | 686 | helper.save_model(epoch=epoch, val_loss=epoch_loss) 687 | 688 | logger.info(f'Done in {time.time()-start_time} sec.') 689 | 690 | if helper.params['is_poison']: 691 | logger.info(f'MEAN_ACCURACY: {np.mean(mean_acc)}') 692 | logger.info('Saving all the graphs.') 693 | logger.info(f"This run has a label: {helper.params['current_time']}. 
" 694 | f"Visdom environment: {helper.params['environment_name']}") 695 | 696 | if helper.params.get('results_json', False): 697 | with open(helper.params['results_json'], 'a') as f: 698 | if len(mean_acc): 699 | results['mean_poison'] = np.mean(mean_acc) 700 | f.write(json.dumps(results) + '\n') 701 | 702 | vis.save([helper.params['environment_name']]) 703 | 704 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/backdoor_federated_learning/3f068b819f017b9eef1e19f43f572e391398c90a/utils/__init__.py -------------------------------------------------------------------------------- /utils/params.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | 3 | type: image 4 | test_batch_size: 1000 5 | lr: 0.1 6 | momentum: 0.9 7 | decay: 0.0005 8 | batch_size: 64 9 | 10 | no_models: 10 11 | epochs: 10100 12 | retrain_no_times: 2 13 | 14 | number_of_total_participants: 100 15 | sampling_dirichlet: true 16 | dirichlet_alpha: 0.9 17 | eta: 1 18 | 19 | save_model: false 20 | save_on_epochs: [10, 100, 500, 1000, 2000, 5000] 21 | #resumed_model: false 22 | #resumed_model: recover/model_cifar_10k.pt.tar 23 | #resumed_model: model_image_Aug.20_10.38.31/model_last.pt.tar.epoch 24 | 25 | resumed_model: 26 | environment_name: ppdl_experiment 27 | report_train_loss: false 28 | report_test_loss: false 29 | report_poison_loss: false 30 | track_distance: false 31 | track_clusters: false 32 | log_interval: 10 33 | 34 | modify_poison: false 35 | 36 | # file names of the images 37 | poison_type: wall 38 | 39 | # manually chosen images for tests 40 | poison_images_test: 41 | - 330 42 | - 568 43 | - 3934 44 | - 12336 45 | - 30560 46 | 47 | poison_images: 48 | - 30696 49 | - 33105 50 | - 33615 51 | - 33907 52 | - 36848 53 | - 40713 54 | - 41706 55 | 56 | 57 | # image_29911.jpg 58 
| poison_image_id: 2775 59 | poison_image_id_2: 1605 60 | poison_label_swap: 2 61 | size_of_secret_dataset: 200 62 | poisoning_per_batch: 1 63 | poison_test_repeat: 1000 64 | is_poison: false 65 | baseline: false 66 | random_compromise: false 67 | noise_level: 0.01 68 | 69 | 70 | poison_epochs: [10000] 71 | retrain_poison: 15 72 | scale_weights: 100 73 | poison_lr: 0.05 74 | poison_momentum: 0.9 75 | poison_decay: 0.005 76 | poison_step_lr: true 77 | clamp_value: 1.0 78 | alpha_loss: 1.0 79 | number_of_adversaries: 1 80 | poisoned_number: 2 81 | results_json: false 82 | 83 | s_norm: 1000000 84 | diff_privacy: false 85 | 86 | 87 | fake_participants_load: false 88 | fake_participants_file: data/reddit/updates_cifar.pt.tar 89 | fake_participants_save: false 90 | 91 | -------------------------------------------------------------------------------- /utils/params_runner.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | 3 | type: image 4 | test_batch_size: 1000 5 | lr: 0.1 6 | momentum: 0.9 7 | decay: 0.0005 8 | batch_size: 64 9 | 10 | no_models: 10 11 | epochs: 10100 12 | retrain_no_times: 2 13 | 14 | number_of_total_participants: 100 15 | sampling_dirichlet: true 16 | dirichlet_alpha: 0.9 17 | eta: 1 18 | 19 | save_model: false 20 | save_on_epochs: 21 | - 11111 22 | - 22222 23 | resumed_model: recover/model_cifar_10k.pt.tar 24 | environment_name: PPDL_CIFAR_INITIAL 25 | report_train_loss: false 26 | report_test_loss: false 27 | report_poison_loss: false 28 | track_distance: false 29 | track_clusters: false 30 | log_interval: 10 31 | 32 | modify_poison: false 33 | 34 | poison_type: combined 35 | poison_images_test: 36 | - 2180 37 | - 2771 38 | - 3233 39 | - 4932 40 | - 6241 41 | - 6813 42 | - 6869 43 | - 9476 44 | - 11395 45 | - 11744 46 | - 14209 47 | - 14238 48 | - 18716 49 | - 19793 50 | - 20781 51 | - 21529 52 | - 31311 53 | - 40518 54 | - 40633 55 | - 42119 56 | - 42663 57 | - 49392 58 | - 389 59 | - 561 60 | - 874 61 
| - 1605 62 | - 3378 63 | - 3678 64 | - 4528 65 | - 9744 66 | - 19165 67 | - 19500 68 | - 21422 69 | - 22984 70 | - 32941 71 | - 34287 72 | - 34385 73 | - 36005 74 | - 37365 75 | - 37533 76 | - 38658 77 | - 38735 78 | - 39824 79 | - 40138 80 | - 41336 81 | - 41861 82 | - 47001 83 | - 47026 84 | - 48003 85 | - 48030 86 | - 49163 87 | - 49588 88 | - 330 89 | - 568 90 | - 3934 91 | - 12336 92 | - 30560 93 | - 30696 94 | - 33105 95 | - 33615 96 | - 33907 97 | - 36848 98 | - 40713 99 | - 41706 100 | 101 | poison_images: 102 | - 2180 103 | - 2771 104 | - 3233 105 | - 4932 106 | - 6241 107 | - 6813 108 | - 6869 109 | - 9476 110 | - 11395 111 | - 11744 112 | - 14209 113 | - 14238 114 | - 18716 115 | - 19793 116 | - 20781 117 | - 21529 118 | - 31311 119 | - 40518 120 | - 40633 121 | - 42119 122 | - 42663 123 | - 49392 124 | - 389 125 | - 561 126 | - 874 127 | - 1605 128 | - 3378 129 | - 3678 130 | - 4528 131 | - 9744 132 | - 19165 133 | - 19500 134 | - 21422 135 | - 22984 136 | - 32941 137 | - 34287 138 | - 34385 139 | - 36005 140 | - 37365 141 | - 37533 142 | - 38658 143 | - 38735 144 | - 39824 145 | - 40138 146 | - 41336 147 | - 41861 148 | - 47001 149 | - 47026 150 | - 48003 151 | - 48030 152 | - 49163 153 | - 49588 154 | - 330 155 | - 568 156 | - 3934 157 | - 12336 158 | - 30560 159 | - 30696 160 | - 33105 161 | - 33615 162 | - 33907 163 | - 36848 164 | - 40713 165 | - 41706 166 | 167 | 168 | # image_29911.jpg 169 | poison_image_id: 2775 170 | poison_image_id_2: 1605 171 | poison_label_swap: 2 172 | size_of_secret_dataset: 500 173 | poisoning_per_batch: 1 174 | is_poison: true 175 | baseline: false 176 | random_compromise: false 177 | 178 | 179 | poison_epochs: 180 | - 10005 181 | retrain_poison: 20 182 | scale_weights: 100 183 | poison_lr: 0.001 184 | poison_momentum: 0.9 185 | poison_decay: 0.005 186 | poison_step_lr: false 187 | clamp_value: 1.0 188 | alpha_loss: 1 189 | number_of_adversaries: 2 190 | poisoned_number: 2 191 | results_json: initial_avg_all 192 | 193 | 
s_norm: 1000000 194 | diff_privacy: false 195 | 196 | 197 | fake_participants_load: false 198 | fake_participants_file: data/reddit/updates_cifar.pt.tar 199 | fake_participants_save: false 200 | 201 | -------------------------------------------------------------------------------- /utils/text_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import re 5 | from tqdm import tqdm 6 | import random 7 | 8 | filter_symbols = re.compile('[a-zA-Z]*') 9 | 10 | class Dictionary(object): 11 | def __init__(self): 12 | self.word2idx = {} 13 | self.idx2word = [] 14 | 15 | def add_word(self, word): 16 | raise ValueError("Please don't call this method, so we won't break the dictionary :) ") 17 | 18 | def __len__(self): 19 | return len(self.idx2word) 20 | 21 | 22 | def get_word_list(line, dictionary): 23 | splitted_words = json.loads(line.lower()).split() 24 | words = [''] 25 | for word in splitted_words: 26 | word = filter_symbols.search(word)[0] 27 | if len(word)>1: 28 | if dictionary.word2idx.get(word, False): 29 | words.append(word) 30 | else: 31 | words.append('') 32 | words.append('') 33 | 34 | return words 35 | 36 | 37 | class Corpus(object): 38 | def __init__(self, params, dictionary, is_poison=False): 39 | self.path = params['data_folder'] 40 | authors_no = params['number_of_total_participants'] 41 | 42 | self.dictionary = dictionary 43 | self.no_tokens = len(self.dictionary) 44 | self.authors_no = authors_no 45 | self.train = self.tokenize_train(f'{self.path}/shard_by_author', is_poison=is_poison) 46 | self.test = self.tokenize(os.path.join(self.path, 'test_data.json')) 47 | 48 | def load_poison_data(self, number_of_words): 49 | current_word_count = 0 50 | path = f'{self.path}/shard_by_author' 51 | list_of_authors = iter(os.listdir(path)) 52 | word_list = list() 53 | line_number = 0 54 | posts_count = 0 55 | while current_word_count 2: 62 | 
word_list.extend([self.dictionary.word2idx[word] for word in words]) 63 | current_word_count += len(words) 64 | line_number += 1 65 | 66 | ids = torch.LongTensor(word_list[:number_of_words]) 67 | 68 | return ids 69 | 70 | 71 | def tokenize_train(self, path, is_poison=False): 72 | """ 73 | We return a list of ids per each participant. 74 | :param path: 75 | :return: 76 | """ 77 | files = os.listdir(path) 78 | per_participant_ids = list() 79 | for file in tqdm(files[:self.authors_no]): 80 | 81 | # jupyter creates somehow checkpoints in this folder 82 | if 'checkpoint' in file: 83 | continue 84 | 85 | new_path=f'{path}/{file}' 86 | with open(new_path, 'r') as f: 87 | 88 | tokens = 0 89 | word_list = list() 90 | for line in f: 91 | words = get_word_list(line, self.dictionary) 92 | tokens += len(words) 93 | word_list.extend([self.dictionary.word2idx[x] for x in words]) 94 | 95 | ids = torch.LongTensor(word_list) 96 | 97 | per_participant_ids.append(ids) 98 | 99 | return per_participant_ids 100 | 101 | 102 | def tokenize(self, path): 103 | """Tokenizes a text file.""" 104 | assert os.path.exists(path) 105 | # Add words to the dictionary 106 | word_list = list() 107 | with open(path, 'r') as f: 108 | tokens = 0 109 | 110 | for line in f: 111 | words = get_word_list(line, self.dictionary) 112 | tokens += len(words) 113 | word_list.extend([self.dictionary.word2idx[x] for x in words]) 114 | 115 | ids = torch.LongTensor(word_list) 116 | 117 | return ids -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.utils.data.sampler import Sampler 6 | 7 | 8 | class AverageMeter(object): 9 | """Computes and stores the average and current value""" 10 | def __init__(self): 11 | self.reset() 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 
16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | 25 | 26 | def dict_html(dict_obj, current_time): 27 | out = '' 28 | for key, value in dict_obj.items(): 29 | 30 | #filter out not needed parts: 31 | if key in ['poisoning_test', 'test_batch_size', 'discount_size', 'folder_path', 'log_interval', 32 | 'coefficient_transfer', 'grad_threshold' ]: 33 | continue 34 | 35 | out += f'{key}{value}' 36 | output = f'

Params for model: {current_time}:

{out}
' 37 | return output 38 | 39 | 40 | 41 | def poison_random(batch, target, poisoned_number, poisoning, test=False): 42 | 43 | batch = batch.clone() 44 | target = target.clone() 45 | for iterator in range(0,len(batch)-1,2): 46 | 47 | if random.random()