# Repository layout (flattened dump; each section below is one file):
# ├── backdoor.py
# ├── data_sets.py
# ├── defences.py
# ├── main.py
# ├── malicious.py
# ├── readme.md
# ├── server.py
# └── user.py


# --------------------------------------------------------------------------------
# /backdoor.py:
# --------------------------------------------------------------------------------
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader
import malicious
import data_sets
import torch.backends.cudnn as cudnn
from user import flatten_params, row_into_parameters, cycle


class BackdoorAttack(malicious.Attack):
    """Attack that trains a private "malicious" copy of the model on backdoor data.

    The gradients reported by the malicious users are derived from the trained
    malicious parameters and clipped to ``num_std`` standard deviations around
    the mean of the malicious users' own gradients.
    """

    def __init__(self, num_std, alpha, data_set, loss, backdoor, num_epochs=30, batch_size=200,
                 learning_rate=0.1, momentum=0.9, my_print=print):
        """Build the malicious network and the backdoor data loader.

        :param num_std: clipping width, in standard deviations, around the gradient mean.
        :param alpha: weight of the parameter-distance loss added to the classification loss.
        :param data_set: ``data_sets.MNIST`` or ``data_sets.CIFAR10``.
        :param loss: stored on the instance; not read anywhere in this class.
        :param backdoor: ``'pattern'`` or a 1-based index (int or numeric string) into
            the sampler's ordering of the training set.
        :param num_epochs: number of passes over the backdoor loader per attack round.
        :param batch_size: batch size of the backdoor loader.
        :param learning_rate: learning rate of the malicious network's SGD optimizer.
        :param momentum: momentum of the malicious network's SGD optimizer.
        :param my_print: callable used for logging.
        :raises Exception: if ``data_set`` is not a supported dataset.
        """
        super(BackdoorAttack, self).__init__(num_std)
        self.my_print = my_print
        self.alpha = alpha
        self.num_epochs = num_epochs

        self.loss = loss

        self.data_set = data_set
        if data_set == data_sets.MNIST:
            self.malicious_net = data_sets.MnistNet()
        elif data_set == data_sets.CIFAR10:
            self.malicious_net = data_sets.Cifar10Net()
        else:
            # Same failure mode as server.Server instead of a late AttributeError.
            raise Exception("Unknown dataset {}".format(data_set))

        self.backdoor = backdoor
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.momentum = momentum
        if backdoor != 'pattern':
            self.dataset = self.malicious_net.dataset(True)
            # One replica per sample: the sampler for this rank yields a single image.
            # The command line delivers the index as a string, so convert before subtracting.
            sampler = torch.utils.data.distributed.DistributedSampler(
                self.dataset, num_replicas=len(self.dataset), rank=int(backdoor) - 1)
        else:
            self.dataset = self.malicious_net.dataset(True, BackdoorAttack.add_pattern)
            u = int(len(self.dataset) / batch_size / 10)
            sampler = torch.utils.data.distributed.DistributedSampler(
                self.dataset, num_replicas=u, rank=np.random.randint(u))
        self.train_loader = torch.utils.data.DataLoader(
            self.dataset, sampler=sampler, batch_size=batch_size, shuffle=False)
        # The backdoor is evaluated on the very samples it is trained on.
        self.test_loader = self.train_loader

    @staticmethod
    def add_pattern(img):
        """Stamp the backdoor pattern (value 2.8) on the top-left 5x5 corner of ``img`` in place."""
        img[:, :5, :5] = 2.8
        return img

    def _backdoor_target(self, target):
        """Return the labels the backdoored model should output for a batch of true labels."""
        if self.backdoor == 'pattern':
            # Images carrying the pattern must always be classified as class 0.
            return torch.zeros_like(target)
        return (target + 1) % 5

    def _attack_grads(self, grads_mean, grads_stdev, original_params, learning_rate):
        """Compute the gradient every malicious user reports this round.

        :param grads_mean: mean of the malicious users' honest gradients.
        :param grads_stdev: standard deviation of those gradients.
        :param original_params: flat parameters dispatched by the server this round.
        :param learning_rate: learning rate the users received this round.
        :return: flat gradient clipped to ``grads_mean +/- num_std * grads_stdev``.
        """
        # Parameters after an honest step; the attack starts from there because the
        # model should still keep improving on the main task.
        initial_params_flat = original_params - learning_rate * grads_mean

        mal_net_params = self.train_malicious_network(initial_params_flat)

        # Translate the desired malicious parameters back into a gradient that is
        # applied on the parameters of the previous round.
        new_params = mal_net_params + learning_rate * grads_mean
        new_grads = (initial_params_flat - new_params) / learning_rate

        new_user_grads = np.clip(new_grads, grads_mean - self.num_std * grads_stdev,
                                 grads_mean + self.num_std * grads_stdev)

        return new_user_grads

    def test_malicious_network(self, epoch, to_print=True):
        """Evaluate the malicious network on the backdoor samples.

        :param epoch: label used only in the log line.
        :param to_print: whether to log the result through ``my_print``.
        :return: backdoor accuracy in percent.
        """
        classification_loss = nn.NLLLoss()
        test_loss = 0
        correct = 0
        test_len = 0.

        with torch.no_grad():
            for data, target in self.test_loader:
                test_len += len(data)
                target = self._backdoor_target(target)
                if self.data_set == data_sets.MNIST:
                    data = data.view(-1, 28 * 28)

                net_out = self.malicious_net(data)

                test_loss += classification_loss(net_out, target).item()
                pred = net_out.max(1)[1]  # index of the max log-probability
                # Accumulate a plain int so the log shows "5" rather than "tensor(5)".
                correct += pred.eq(target).sum().item()

        test_loss /= test_len
        accuracy = 100. * float(correct) / test_len

        if to_print:
            self.my_print('##Test malicious net: [{}] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
                epoch, test_loss, correct, test_len, accuracy))
        return accuracy

    def init_malicious_network(self, flat_params):
        """Load ``flat_params`` into the malicious network (same values as the main network)."""
        row_into_parameters(flat_params, self.malicious_net.parameters())

    def train_malicious_network(self, initial_params_flat):
        """Train the malicious network on the backdoor task, starting from ``initial_params_flat``.

        When ``alpha > 0`` an MSE distance to the starting parameters, weighted by
        ``alpha``, is added to the classification loss.

        :param initial_params_flat: flat numpy array of starting parameters.
        :return: flat numpy array of trained parameters, or ``initial_params_flat``
            itself when the backdoor already reaches 100% accuracy.
        :raises Exception: if the distance loss or the total loss becomes NaN.
        """
        self.init_malicious_network(initial_params_flat)
        # Frozen copy of the starting point, used as the anchor of the distance loss.
        initial_params = [torch.empty(p.shape) for p in self.malicious_net.parameters()]
        row_into_parameters(initial_params_flat, initial_params)

        initial_accuracy = self.test_malicious_network('BEFORE', to_print=False)
        if initial_accuracy >= 100.:
            return initial_params_flat

        classification_loss = nn.NLLLoss()
        dist_loss_func = nn.MSELoss()

        self.malicious_net.train()

        for epoch in range(self.num_epochs):
            for data, target in self.train_loader:
                target = self._backdoor_target(target)
                # NOTE(review): the optimizer is rebuilt for every batch, exactly as before,
                # so its momentum buffer never carries over between steps. Hoisting it out
                # of the loop would change training dynamics - confirm intent before doing so.
                optimizer = optim.SGD(self.malicious_net.parameters(), lr=self.learning_rate,
                                      momentum=self.momentum, weight_decay=0.0001)

                if self.data_set == data_sets.MNIST:
                    data = data.view(-1, 28 * 28)
                optimizer.zero_grad()
                net_out = self.malicious_net(data)
                loss = classification_loss(net_out, target)
                if self.alpha > 0:
                    dist_loss = 0
                    for p, p_initial in zip(self.malicious_net.parameters(), initial_params):
                        dist_loss += dist_loss_func(p, p_initial)
                    if torch.isnan(dist_loss):
                        raise Exception("Got nan dist loss")

                    loss = loss + dist_loss * self.alpha

                if torch.isnan(loss):
                    raise Exception("Got nan loss")
                loss.backward()
                optimizer.step()

            # Report the backdoor accuracy once, after the last epoch.
            if epoch == (self.num_epochs - 1):
                self.test_malicious_network(epoch, to_print=True)

        return flatten_params(self.malicious_net.parameters())


# --------------------------------------------------------------------------------
# /data_sets.py:
# --------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision import datasets, transforms
import math


MNIST = 'MNIST'
CIFAR10 = 'CIFAR10'
CIFAR100 = 'CIFAR100'


class MnistNet(nn.Module):
    """Two-layer fully connected classifier over flattened 28*28 inputs."""

    def __init__(self):
        super(MnistNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 100)
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of flattened images."""
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

    def dataset(self, is_train, transform=None):
        """Return the torchvision MNIST split, downloading it to ``./mnist_data`` if needed.

        :param is_train: select the training split when true, the test split otherwise.
        :param transform: optional extra transform applied after normalization.
        """
        t = [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))]
        if transform:
            t.append(transform)
        return datasets.MNIST('./mnist_data', download=True, train=is_train, transform=transforms.Compose(t))


class Cifar10Net(nn.Module):
    """Small convolutional classifier: two conv/max-pool stages followed by three linear layers."""

    def __init__(self):
        super(Cifar10Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        torch.nn.init.xavier_uniform_(self.conv1.weight)
        self.pool1 = nn.MaxPool2d(3)
        self.conv2 = nn.Conv2d(16, 64, 4)
        self.pool2 = nn.MaxPool2d(4)
        self.fc1 = nn.Linear(64 * 1 * 1, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.log_softmax(self.fc3(x), dim=1)
        return x

    @staticmethod
    def dataset(is_train, transform=None):
        """Return the torchvision CIFAR10 split, downloading it to ``./cifar10_data`` if needed.

        :param is_train: select the training split when true, the test split otherwise.
        :param transform: optional extra transform applied after normalization.
        """
        t = [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        if transform:
            t.append(transform)
        return datasets.CIFAR10(root='./cifar10_data', download=True, train=is_train,
                                transform=transforms.Compose(t))


# for resnet
class BasicBlock(nn.Module):
    """Pre-activation residual block (BN -> ReLU -> conv, twice) with optional dropout."""

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        super(BasicBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        # A 1x1 convolution on the shortcut is only needed when the channel count changes.
        self.convShortcut = None if self.equalInOut else nn.Conv2d(
            in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        """Apply the residual block to ``x`` and return the summed output."""
        activated = self.relu1(self.bn1(x))
        if not self.equalInOut:
            # With a projection shortcut, the shortcut also starts from the activated input.
            x = activated
        out = self.relu2(self.bn2(self.conv1(activated)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        shortcut = x if self.equalInOut else self.convShortcut(x)
        return torch.add(shortcut, out)


class NetworkBlock(nn.Module):
    """Sequence of ``nb_layers`` blocks; only the first changes channel count and stride."""

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        """Build the ``nn.Sequential`` of ``nb_layers`` instances of ``block``."""
        layers = []
        for i in range(int(nb_layers)):
            first = (i == 0)
            layers.append(block(in_planes if first else out_planes, out_planes,
                                stride if first else 1, dropRate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)


class Cifar100Net(nn.Module):
    """Residual network assembled from three ``NetworkBlock`` stages of ``BasicBlock``."""

    def __init__(self, depth=40, num_classes=100, widen_factor=4, dropRate=0.0):
        """
        :param depth: total depth; ``depth - 4`` must be divisible by 6.
        :param num_classes: size of the output layer.
        :param widen_factor: multiplier applied to the base channel counts 16/32/64.
        :param dropRate: dropout probability inside each block (0 disables dropout).
        """
        super(Cifar100Net, self).__init__()
        nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
        assert ((depth - 4) % 6 == 0)
        n = (depth - 4) // 6  # blocks per stage; integer division keeps it an int
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        x = self.conv1(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.relu(self.bn1(x))
        x = F.avg_pool2d(x, 8)
        x = x.view(-1, self.nChannels)

        return F.log_softmax(self.fc(x), dim=1)

    @staticmethod
    def dataset(is_train, transform=None):
        """Return the torchvision CIFAR100 split, downloading it to ``./cifar100_data`` if needed.

        The training split is augmented with reflect padding, a random 32x32 crop
        and a random horizontal flip before normalization.

        :param is_train: select the training split when true, the test split otherwise.
        :param transform: optional extra transform applied last.
        """
        normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

        if is_train:
            t = [transforms.ToTensor(),
                 transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                                   (4, 4, 4, 4), mode='reflect').squeeze()),
                 transforms.ToPILImage(),
                 transforms.RandomCrop(32),
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor(),
                 normalize,
                 ]
        else:
            t = [transforms.ToTensor(),
                 normalize]
        if transform:
            t.append(transform)

        return datasets.CIFAR100(root='./cifar100_data', train=is_train, download=True,
                                 transform=transforms.Compose(t))


# --------------------------------------------------------------------------------
# /defences.py:
# --------------------------------------------------------------------------------
import numpy as np
from collections import defaultdict


class DefenseTypes:
    """String constants naming the available aggregation rules (keys of ``defend``)."""
    NoDefense = 'NoDefense'
    Krum = 'Krum'
    TrimmedMean = 'TrimmedMean'
    Bulyan = 'Bulyan'

    def __str__(self):
        # NOTE(review): no ``value`` attribute is defined on this class, so calling
        # str() on an instance would raise AttributeError. The class is only used
        # through its constants in this dump - confirm before relying on this method.
        return self.value


def no_defense(users_grads, users_count, corrupted_count):
    """Plain average of all users' gradients."""
    return np.mean(users_grads, axis=0)


def _krum_create_distances(users_grads):
    """Return ``distances[i][j]``: Euclidean distance between the gradients of users i and j (i != j)."""
    distances = defaultdict(dict)
    for i in range(len(users_grads)):
        for j in range(i):
            distances[i][j] = distances[j][i] = np.linalg.norm(users_grads[i] - users_grads[j])
    return distances


def krum(users_grads, users_count, corrupted_count, distances=None, return_index=False, debug=False):
    """Select the gradient with the smallest summed distance to its closest peers.

    :param users_grads: one gradient per user, indexable by user index.
    :param users_count: number of users still taking part in the selection.
    :param corrupted_count: assumed number of malicious users.
    :param distances: optional pre-computed pairwise distances (see
        ``_krum_create_distances``); computed from ``users_grads`` when omitted.
    :param return_index: return the selected user's index instead of its gradient.
    :param debug: unused.
    :return: the selected gradient, or its index when ``return_index`` is true.
    """
    if not return_index:
        # The message now states the condition that is actually enforced.
        # NOTE(review): the Krum paper states n >= 2f + 3 - confirm which bound is intended.
        assert users_count >= 2 * corrupted_count + 1, (
            'users_count>=2*corrupted_count + 1', users_count, corrupted_count)
    non_malicious_count = users_count - corrupted_count
    minimal_error = float('inf')
    minimal_error_index = -1

    if distances is None:
        distances = _krum_create_distances(users_grads)
    for user in distances.keys():
        errors = sorted(distances[user].values())
        current_error = sum(errors[:non_malicious_count])
        if current_error < minimal_error:
            minimal_error = current_error
            minimal_error_index = user

    if return_index:
        return minimal_error_index
    else:
        return users_grads[minimal_error_index]


def trimmed_mean(users_grads, users_count, corrupted_count):
    """Coordinate-wise mean of the values closest to the coordinate-wise median.

    For every coordinate, the ``users - corrupted_count - 1`` values with the
    smallest absolute deviation from the median are averaged.

    :param users_grads: 2-D array with one row per user.
    :param users_count: unused; the row count of ``users_grads`` is used instead.
    :param corrupted_count: assumed number of malicious users.
    :return: 1-D array with the aggregated gradient, in the dtype of ``users_grads``.
    """
    number_to_consider = int(users_grads.shape[0] - corrupted_count) - 1

    med = np.median(users_grads, axis=0)
    deviations = users_grads - med
    # A stable sort keeps ties in user order, matching Python's sorted().
    closest = np.argsort(np.abs(deviations), axis=0, kind='stable')[:number_to_consider]
    good_vals = np.take_along_axis(deviations, closest, axis=0)
    return (good_vals.mean(axis=0) + med).astype(users_grads.dtype, copy=False)


def bulyan(users_grads, users_count, corrupted_count):
    """Repeatedly pick gradients with Krum, then aggregate the picks with a trimmed mean.

    :raises AssertionError: if ``users_count < 4 * corrupted_count + 3``.
    """
    assert users_count >= 4 * corrupted_count + 3
    set_size = users_count - 2 * corrupted_count
    selection_set = []

    distances = _krum_create_distances(users_grads)
    while len(selection_set) < set_size:
        currently_selected = krum(users_grads, users_count - len(selection_set), corrupted_count, distances, True)
        selection_set.append(users_grads[currently_selected])

        # remove the selected from next iterations:
        distances.pop(currently_selected)
        for remaining_user in distances.keys():
            distances[remaining_user].pop(currently_selected)

    return trimmed_mean(np.array(selection_set), len(selection_set), 2 * corrupted_count)


defend = {DefenseTypes.Krum: krum,
          DefenseTypes.TrimmedMean: trimmed_mean, DefenseTypes.NoDefense: no_defense,
          DefenseTypes.Bulyan: bulyan}


# --------------------------------------------------------------------------------
# /main.py:
# --------------------------------------------------------------------------------
import argparse
import os
import backdoor as backdoor_module
import server
import user
import malicious
import numpy as np
import datetime


def main(mal_prop, num_std, defense, dataset, backdoor_attack, alpha, learning_rate=1, fading_rate=10000,
         momentum=0.9, batch_size=83, users_count=100, epochs=150, mal_epochs=30, loss='MSE', output=None):
    """Run one federated training experiment with an optional attack and defense.

    :param mal_prop: proportion of users that are malicious.
    :param num_std: number of standard deviations the attacker may move away from the mean.
    :param defense: key of ``defences.defend`` selecting the server's aggregation rule.
    :param dataset: dataset name, forwarded to the users, the server and the attacker.
    :param backdoor_attack: falsy for the drift attack, otherwise the backdoor option.
    :param alpha: distance-loss weight of the backdoor attack.
    :param learning_rate: server learning rate.
    :param fading_rate: learning-rate decay constant.
    :param momentum: momentum used by the server and the users.
    :param batch_size: batch size for the users and the server's test loader.
    :param users_count: number of participating users.
    :param epochs: number of training rounds.
    :param mal_epochs: epochs of malicious-network training per round.
    :param loss: forwarded to ``BackdoorAttack``.
    :param output: optional path of a log file; logs go to stdout when omitted.
    """
    if output:
        def my_print(s, end='\n'):
            with open(output, 'a+') as f:
                f.write(str(s) + end)
    else:
        my_print = print
    my_print(locals())

    corrupted_count = int(mal_prop * users_count)

    my_print('Required Users: ' + '-' * users_count)
    my_print('Completed Users: ', end='')
    users = []
    for user_id in range(users_count):
        my_print('-', end='')
        # The first ``corrupted_count`` users are the malicious ones.
        is_mal = user_id < corrupted_count
        users.append(user.User(user_id, batch_size, is_mal, users_count, momentum, dataset))

    the_server = server.Server(users, mal_prop, batch_size, learning_rate, fading_rate, momentum, data_set=dataset)
    test_size = len(the_server.test_loader.dataset)

    if backdoor_attack:
        test_loss, correct = the_server.test()
        accuracy = 100. * float(correct) / test_size

        my_print('\nBEFORE: Test set. Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(test_loss,
                                                                                             correct,
                                                                                             test_size,
                                                                                             accuracy))
        attacker = backdoor_module.BackdoorAttack(num_std, alpha, dataset, loss=loss, num_epochs=mal_epochs,
                                                  backdoor=backdoor_attack, my_print=my_print)
    else:
        attacker = malicious.DriftAttack(num_std)

    my_print("\nStarting Training...")

    TEST_STEP = 5

    accuracies = []
    accuracies_epochs = []
    for epoch in range(epochs):
        the_server.dispatch_weights(epoch)
        mal_users = [u for u in users if u.is_malicious]
        attacker.attack(mal_users)

        the_server.collect_gradients()
        the_server.defend(defense, epoch)
        if epoch % TEST_STEP == 0 or epoch == epochs - 1:
            test_loss, correct = the_server.test()
            accuracy = 100. * float(correct) / test_size

            my_print('Test set: [{:3d}] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(epoch, test_loss,
                                                                                                correct,
                                                                                                test_size,
                                                                                                accuracy))
            accuracies.append(accuracy)
            accuracies_epochs.append(epoch)

            # NOTE(review): the flattened source does not show whether this check was
            # nested in the test branch; it is placed here so a checkpoint is only
            # written for a freshly measured accuracy - confirm against the repository.
            if accuracy > 70.:
                the_server.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': the_server.test_net.state_dict(),
                    'acc': accuracy,
                })

    if backdoor_attack:
        # Check the backdoor after the final parameters
        final_params = user.flatten_params(the_server.test_net.parameters())
        attacker.init_malicious_network(final_params)
        attacker.test_malicious_network('POST', to_print=True)

    my_print(datetime.datetime.now().time())

    if accuracies:  # empty when epochs == 0
        my_print("Max accuracy: {}".format(max(accuracies)))
    # Create the log directory up front so a finished run cannot fail on the final write.
    os.makedirs('logs', exist_ok=True)
    np.savetxt('logs/{}_stdev_{}_{}_backdoor-{}_mal_prop_{}_users_{}_alpha_{}_lr_{}.csv'.format(
        dataset, num_std, defense, backdoor_attack, mal_prop, users_count, alpha, learning_rate),
        accuracies, delimiter=',')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Circumventing Distributed Learning Defenses')

    parser.add_argument('-m', '--mal-prop', default=0.24, type=float,
                        help='proportion of malicious users')

    parser.add_argument('-z', '--num_std', default=1.5, type=float,
                        help='how many standard deviations should the attacker change')

    parser.add_argument('-d', '--defense', default='NoDefense', choices=['NoDefense', 'Bulyan', 'TrimmedMean', 'Krum'])

    parser.add_argument('-s', '--dataset', default='MNIST', choices=['MNIST', 'CIFAR10'])

    parser.add_argument('-b', '--backdoor', default='No', choices=['No', 'pattern', '1', '2', '3'],
                        help="backdoor options: no backdoor, backdoor pattern, or backdoor sample of the image with the given index")

    parser.add_argument('-n', '--users-count', default=51, type=int,
                        help='number of participating users')

    parser.add_argument('-c', '--batch_size', default=128, type=int,
                        help='batch_size')

    parser.add_argument('-e', '--epochs', default=300, type=int)

    parser.add_argument('-l', '--learning_rate', default=0.1, type=float,
                        help='initial learning rate')

    parser.add_argument('-o', '--output', type=str,
                        help='output file for results')

    args = parser.parse_args()

    if args.backdoor == 'No':
        args.backdoor = False

    momentum = 0.9
    mal_epochs = 5

    # The paper reports alpha = 0.2. This code computes class_loss + alpha * dist_loss,
    # and alpha = 4 in that form is stated to be equal to alpha = 0.2 in the paper.
    alpha = 4

    # Learning-rate decay constant per dataset.
    fading_rates = {'CIFAR10': 2000, 'MNIST': 10000, 'CIFAR100': 1500}
    fading_rate = fading_rates[args.dataset]

    main(args.mal_prop, args.num_std, args.defense, args.dataset, args.backdoor,
         alpha, args.learning_rate, fading_rate, momentum, args.batch_size, args.users_count, args.epochs,
         mal_epochs=mal_epochs, output=args.output)


# --------------------------------------------------------------------------------
# /malicious.py:
# --------------------------------------------------------------------------------
import numpy as np


class Attack(object):
    """Base class for attacks that replace the gradients of all malicious users."""

    def __init__(self, num_std):
        """
        :param num_std: how many standard deviations the attack may deviate from
            the mean; 0 disables the attack.
        """
        self.num_std = num_std
        self.grads_mean = None
        self.grads_stdev = None

    def attack(self, users):
        """Overwrite ``grads`` of every user in ``users`` with the malicious gradient.

        The mean and standard deviation of the given users' gradients are stored
        on the instance. Nothing is overwritten when ``users`` is empty or
        ``num_std`` is 0.
        """
        if not users:
            return

        users_grads = [usr.grads for usr in users]

        self.grads_mean = np.mean(users_grads, axis=0)
        self.grads_stdev = np.std(users_grads, axis=0)

        if self.num_std == 0:
            return

        mal_grads = self._attack_grads(self.grads_mean, self.grads_stdev, users[0].original_params,
                                       users[0].learning_rate)

        for usr in users:
            usr.grads = mal_grads

    def _attack_grads(self, grads_mean, grads_stdev, original_params, learning_rate):
        """Return the gradient the malicious users should report; implemented by subclasses."""
        raise NotImplementedError


class DriftAttack(Attack):
    """Report the mean gradient shifted by ``num_std`` standard deviations."""

    def __init__(self, num_std):
        super(DriftAttack, self).__init__(num_std)

    def _attack_grads(self, grads_mean, grads_stdev, original_params, learning_rate):
        # Build a new array so the stored ``self.grads_mean`` keeps the true mean.
        return grads_mean - self.num_std * grads_stdev


# --------------------------------------------------------------------------------
# /readme.md:
# --------------------------------------------------------------------------------
# # A Little Is Enough: Circumventing Defenses For Distributed Learning
#
# To see the parameters for the experiments, run `main.py -h`.
# We use Python 3 and build upon PyTorch.
#
# For backdooring (the -b option), you can use "No" for no backdooring, "pattern" for
# changing the top-left 5*5 pixels to the max intensity as described in the paper, or
# an index ("1", "2" or "3") selecting the specific image from the dataset that
# behaves as a backdoor sample.
#
# ## Authors
#
# * **Moran Baruch**
# * **Gilad Baruch**
# * **Yoav Goldberg**
#
# ## Citation
# If you use this codebase, please cite it as follows:
# ```
# @article{baruch2019little,
#   title={A little is enough: Circumventing defenses for distributed learning},
#   author={Baruch, Moran and Baruch, Gilad and Goldberg, Yoav},
#   journal={Advances in Neural Information Processing Systems},
#   volume={32},
#   year={2019}
# }
# ```


# --------------------------------------------------------------------------------
# /server.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np

import data_sets
import defences
import user
import os


class Server:
    """Central server: dispatches weights, aggregates user gradients and evaluates the model."""

    def __init__(self, users, malicious_proportion, batch_size, learning_rate, fading_rate, momentum, data_set):
        """
        :param users: list of ``user.User`` objects taking part in training.
        :param malicious_proportion: proportion of users assumed malicious by the defense.
        :param batch_size: batch size of the test loader.
        :param learning_rate: initial learning rate.
        :param fading_rate: learning-rate decay constant (see ``calc_learning_rate``).
        :param momentum: momentum of the server-side update.
        :param data_set: ``data_sets.MNIST`` or ``data_sets.CIFAR10``.
        :raises Exception: if ``data_set`` is not a supported dataset.
        """
        self.criterion = nn.NLLLoss()
        self.users = users
        self.mal_prop = malicious_proportion
        self.learning_rate = learning_rate
        self.fading_rate = fading_rate
        self.momentum = momentum
        self.data_set = data_set
        if data_set == data_sets.MNIST:
            self.test_net = data_sets.MnistNet()
        elif data_set == data_sets.CIFAR10:
            self.test_net = data_sets.Cifar10Net()
        else:
            raise Exception("Unknown dataset {}".format(data_set))
        self.test_loader = torch.utils.data.DataLoader(self.test_net.dataset(False), batch_size=batch_size,
                                                       shuffle=False)

        # Flat views of the model: weights, one gradient row per user, and the momentum buffer.
        self.current_weights = user.flatten_params(self.test_net.parameters())
        self.users_grads = np.empty((len(users), len(self.current_weights)), dtype=self.current_weights.dtype)
        self.velocity = np.zeros(self.current_weights.shape, self.users_grads.dtype)

    def save_checkpoint(self, state, filename='checkpoint.pth.tar'):
        """Saves checkpoint to disk"""
        directory = "runs/%s/" % (self.data_set)
        os.makedirs(directory, exist_ok=True)
        filename = directory + filename
        torch.save(state, filename)

        # shutil.copyfile(filename, 'runs/%s/' % self.data_set + 'model_best.pth.tar')

    def calc_learning_rate(self, cur_epoch):
        """Return the learning rate for ``cur_epoch``: ``lr * fading / (epoch + fading)``."""
        lr = self.learning_rate * self.fading_rate / (cur_epoch + self.fading_rate)
        return lr

    def dispatch_weights(self, cur_epoch):
        """Send the current weights and this epoch's learning rate to every user."""
        lr = self.calc_learning_rate(cur_epoch)
        for usr in self.users:
            usr.step(self.current_weights, lr)

    # get the updated weights from users
    def collect_gradients(self):
        """Copy every user's reported gradient into the server's gradient matrix."""
        for idx, usr in enumerate(self.users):
            self.users_grads[idx, :] = usr.grads

    # defend against malicious users
    def defend(self, defence_method, cur_epoch):
        """Aggregate the collected gradients with ``defence_method`` and apply a momentum step.

        :param defence_method: key of ``defences.defend``.
        :param cur_epoch: current round; selects the decayed learning rate.
        """
        current_grads = defences.defend[defence_method](self.users_grads, len(self.users),
                                                        int(len(self.users) * self.mal_prop))

        # Step with the same decayed learning rate that was dispatched to the users;
        # previously ``cur_epoch`` was ignored and the undecayed rate was applied.
        lr = self.calc_learning_rate(cur_epoch)
        self.velocity = self.momentum * self.velocity - lr * current_grads
        self.current_weights += self.velocity

    def test(self):
        """Evaluate the current weights on the test set.

        :return: ``(test_loss, correct)`` - the summed per-batch loss divided by the
            number of test samples, and the number of correctly classified samples.
        """
        user.row_into_parameters(self.current_weights, self.test_net.parameters())
        test_loss = 0
        correct = 0

        self.test_net.eval()
        with torch.no_grad():
            for data, target in self.test_loader:
                if self.data_set == data_sets.MNIST:
                    data = data.view(-1, 28 * 28)

                net_out = self.test_net(data)
                loss = self.criterion(net_out, target)
                # sum up batch loss
                test_loss += loss.item()
                pred = net_out.max(1)[1]  # get the index of the max log-probability
                # Plain int so callers can format and divide it without tensor wrappers.
                correct += pred.eq(target).sum().item()

        test_loss /= len(self.test_loader.dataset)

        return test_loss, correct
# --------------------------------------------------------------------------------
# /user.py:
# --------------------------------------------------------------------------------
import functools
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import numpy as np
import data_sets


def cycle(iterable):
    """Yield the items of ``iterable`` forever, re-iterating it on every pass.

    Unlike ``itertools.cycle`` the items are not cached: the iterable is iterated
    anew each pass. An empty iterable ends the generator instead of spinning forever.
    """
    while True:
        yielded = False
        for x in iterable:
            yielded = True
            yield x
        if not yielded:
            return


def flatten_params(params):
    """Concatenate all parameters into one flat numpy array."""
    return np.concatenate([i.data.cpu().numpy().flatten() for i in params])


def row_into_parameters(row, parameters):
    """Copy consecutive slices of the flat numpy array ``row`` into ``parameters`` in place.

    :param row: flat numpy array holding the values of all parameters, in order.
    :param parameters: iterable of tensors; each receives as many values as it has elements.
    """
    offset = 0
    for param in parameters:
        new_size = param.numel()
        current_data = row[offset:offset + new_size]

        param.data[:] = torch.from_numpy(current_data.reshape(param.shape))
        offset += new_size


class User:
    """A participant that computes gradients on its own shard of the training data."""

    def __init__(self, user_id, batch_size, is_malicious, users_count, momentum, data_set=data_sets.MNIST):
        """
        :param user_id: index of this user; also the rank of its data shard.
        :param batch_size: batch size of the local training loader.
        :param is_malicious: whether this user is controlled by the attacker.
        :param users_count: total number of users, i.e. the number of data shards.
        :param momentum: momentum handed to the local optimizer.
        :param data_set: ``data_sets.MNIST`` or ``data_sets.CIFAR10``.
        :raises Exception: if ``data_set`` is not a supported dataset.
        """
        self.is_malicious = is_malicious
        self.user_id = user_id
        self.criterion = nn.NLLLoss()
        self.learning_rate = None
        self.grads = None
        self.data_set = data_set
        self.momentum = momentum
        if data_set == data_sets.MNIST:
            self.net = data_sets.MnistNet()
        elif data_set == data_sets.CIFAR10:
            self.net = data_sets.Cifar10Net()
        else:
            # Fail here rather than with an AttributeError on ``self.net`` below.
            raise Exception("Unknown dataset {}".format(data_set))
        self.original_params = None
        dataset = self.net.dataset(True)
        sampler = None
        if users_count > 1:
            sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=users_count,
                                                                      rank=user_id)

        self.train_loader = torch.utils.data.DataLoader(
            dataset, sampler=sampler,
            batch_size=batch_size, shuffle=sampler is None)
        # Endless stream of batches; step() draws exactly one per round.
        self.train_iterator = cycle(self.train_loader)

    def train(self, data, target):
        """Run one forward/backward pass; gradients are left on the network's parameters.

        Requires ``step`` to have created ``self.optimizer`` first.
        """
        if self.data_set == data_sets.MNIST:
            # resize data from (batch_size, 1, 28, 28) to (batch_size, 28*28)
            data = data.view(-1, 28 * 28)
        self.optimizer.zero_grad()

        net_out = self.net(data)
        loss = self.criterion(net_out, target)
        loss.backward()
        # self.optimizer.step()  # not stepping because reporting the gradients and the server is performing the step

    # user gets initial weights and learn the new gradients based on its data
    def step(self, current_params, learning_rate):
        """Load the server's weights, compute gradients on one batch and store them in ``grads``.

        :param current_params: flat numpy array of the server's current weights.
        :param learning_rate: learning rate for this round; stored on the user.
        """
        if self.user_id == 0 and self.is_malicious:
            # Only user 0 keeps a copy of the dispatched parameters.
            self.original_params = current_params.copy()
        self.learning_rate = learning_rate
        row_into_parameters(current_params, self.net.parameters())
        self.optimizer = optim.SGD(self.net.parameters(), lr=learning_rate, momentum=self.momentum,
                                   weight_decay=5e-4)

        data, target = next(self.train_iterator)
        self.train(data, target)
        self.grads = np.concatenate([param.grad.data.cpu().numpy().flatten() for param in self.net.parameters()])