├── .gitignore
├── LICENSE
├── README.md
├── plot.sh
├── pylint.sh
├── requirements.txt
├── run.sh
├── src
│   ├── TODO
│   │   ├── bias.py
│   │   ├── ensemble.py
│   │   └── visualization.py
│   ├── byzantines.py
│   ├── client.py
│   ├── dag.py
│   ├── main.py
│   ├── net.py
│   ├── plot.py
│   ├── reputation.py
│   └── weights.py
└── watch.sh

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 |
3 | .vscode/
4 | .python-version
5 | */__pycache__/*
6 |
7 | data/
8 | cifar/
9 | clients/
10 |
11 | *.pth
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Luke Park
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DDL-simulator
2 | Decentralized Deep Learning Simulator.
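Each round, a random subset of clients is activated (some honest, some Byzantine). Every honest client scores the latest published models with a reputation rule (`acc`, `Frobenius`, or `random`), averages the best ones with its own weights, trains a local DenseNet on its CIFAR-10 shard, and publishes the result as a new DAG node. See `run.sh` for example invocations with `--nNodes`, `--nByzs`, `--repute`, `--op-stop`, and `--filter`.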
3 |
4 | ## Baseline
5 | ```
6 | python src/net.py
7 | ```
8 |
9 | ## Run
10 | ```
11 | python src/main.py
12 | ```
13 |
14 |
61 |
--------------------------------------------------------------------------------
/plot.sh:
--------------------------------------------------------------------------------
1 | # python src/plot.py "data/base_wd"
2 | # TODO: skip mechanism
3 |
4 | for d1 in ./data/*; do
5 | for d2 in "$d1"/*; do
6 | if [ -d "$d2" ]; then
7 | python src/plot.py "$d2"
8 | fi
9 | done
10 | if [ -d "$d1" ]; then
11 | python src/plot.py "$d1"
12 | fi
13 | done
14 |
--------------------------------------------------------------------------------
/pylint.sh:
--------------------------------------------------------------------------------
1 | autopep8 --ignore=E501 -i "$@"
2 | # sh pylint.sh -r src/
3 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | astroid==2.4.2
2 | autopep8==1.5.3
3 | future==0.18.2
4 | isort==4.3.21
5 | lazy-object-proxy==1.4.3
6 | mccabe==0.6.1
7 | numpy==1.18.5
8 | Pillow==7.1.2
9 | pycodestyle==2.6.0
10 | pylint==2.5.3
11 | setproctitle==1.1.10
12 | six==1.15.0
13 | toml==0.10.1
14 | torch==1.5.0
15 | torchvision==0.6.0
16 | tqdm==4.46.1
17 | typed-ast==1.4.1
18 | wrapt==1.12.1
19 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | python src/main.py
2 | # python src/main.py --nNodes=10 --nByzs=3 --path="byz_acc" --repute="acc"
3 | # python src/main.py --nNodes=10 --nByzs=3 --path="byz_Frobenius_FN_OS" --repute="Frobenius" --op-stop --filter
4 |
--------------------------------------------------------------------------------
/src/TODO/bias.py:
--------------------------------------------------------------------------------
1 | """
2 | # TODO: inherit from Client (client.py)
3 | # TODO: total data distribution
4 | # TODO: data distribution per node
5 | """
6 |
7 | import torch.optim as optim
8 | import torch.nn.functional as F
9 | import torch.nn as nn
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 | import torch
13 | import torchvision
14 | import torchvision.transforms as transforms
15 | import random
16 | import collections
17 |
18 |
19 | from visualization import heatmap
20 |
21 |
22 | # The current implementation below is just a fixed 80:20 split: 80% of draws return `target`, 20% a uniformly random other label.
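# The `np.random.pareto()` version the TODO below has in mind might look roughly
# like this (a sketch only, not wired into the pipeline; alpha=1.16 is an
# assumption — the Pareto shape that yields approximately an 80:20 split):
#
# def pareto_like(labels, target, size=1, alpha=1.16):
#     w = np.sort(np.random.pareto(alpha, size=len(labels)))[::-1]
#     p = np.roll(w / w.sum(), target)  # heaviest probability mass lands on `target`
#     return np.random.choice(labels, size=size, p=p)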
23 | # TODO: using np.random.pareto()
24 | def pareto(labels, target, size=1):
25 | res = []
26 | for _ in range(size):
27 | if random.random() < 0.8:
28 | res.append(target)
29 | else:
30 | res.append(np.random.choice(np.delete(labels, target)))
31 | return res
32 |
33 |
34 | if __name__ == "__main__":
35 |
36 | """Preprocess"""
37 | transform = transforms.Compose(
38 | [transforms.ToTensor(),
39 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
40 |
41 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
42 | download=True, transform=transform)
43 | testset = torchvision.datasets.CIFAR10(root='./data', train=False,
44 | download=True, transform=transform)
45 |
46 | classes = ('plane', 'car', 'bird', 'cat', 'deer',
47 | 'dog', 'frog', 'horse', 'ship', 'truck')
48 |
49 | """biased split
50 | # TODO: also make each node's dataset size biased, not only its labels
51 | # collections.Counter(pareto(np.arange(ns), 0, size=len(trainset)))
52 | """
53 | ns = 5 # number of clients
54 | cs = len(classes)
55 |
56 | # trainset
57 | targets = trainset.targets
58 | targets_collect = collections.Counter(targets)
59 | targets_dist = [targets_collect[j] for j in range(cs)]
60 | print(targets_dist)
61 |
62 | dist_map = []
63 | for i in range(cs):
64 | x = collections.Counter(
65 | pareto(np.arange(ns), i % ns, size=int(len(trainset) / ns))) # i % ns: the node favored for class i (i % cs overflowed the label range)
66 | dist_each = [x[j] for j in range(ns)] # class i's sample count at each node
67 | dist_map.append(dist_each)
68 |
69 | dist_map = np.array(dist_map) # shape (cs, ns): rows are classes, columns are nodes
70 | # heatmap(dist_map,
71 | # log=True, annot=False,
72 | # xlabel="nodes", ylabel="classes", title="dataset dist. (log)",
73 | # save=False, show=False)
74 |
75 | print(dist_map)
76 | naive_dist = [dist_map[:, i].sum() for i in range(ns)] # total samples per node
77 | refined_dist = [naive_dist]
78 |
79 | # """random split
80 | # TODO: various distribution methods
81 | # """
82 | # splited_trainset = torch.utils.data.random_split(trainset, [int(len(trainset) / ns) for _ in range(ns)])
83 | # splited_testset = torch.utils.data.random_split(testset, [int(len(testset) / ns) for _ in range(ns)])
84 |
85 | # # print(len(splited_trainset[0]), len(splited_trainset[1]), len(splited_trainset[2]))
86 |
--------------------------------------------------------------------------------
/src/TODO/ensemble.py:
--------------------------------------------------------------------------------
1 | """
2 | # for upper bound
3 | """
--------------------------------------------------------------------------------
/src/TODO/visualization.py:
--------------------------------------------------------------------------------
1 | """Ref
2 | # https://ipython.org/ipython-doc/stable/parallel/dag_dependencies.html
3 | # https://gist.github.com/apaszke/01aae7a0494c55af6242f06fad1f8b70
4 | """
5 |
6 |
7 | import matplotlib.pyplot as plt
8 | import seaborn as sns
9 | import numpy as np
10 |
11 |
12 | def heatmap(data,
13 | log=False, annot=False,
14 | title=None, xlabel=None, ylabel=None,
15 | save=False, show=False):
16 |
17 | if log:
18 | data[data == 0] = np.finfo(float).eps
19 | data = np.log(data)
20 |
21 | ax = sns.heatmap(data, annot=annot)
22 |
23 | if title is not None:
24 | plt.title(title)
25 |
26 | if xlabel is not None:
27 | plt.xlabel(xlabel)
28 | if ylabel is not None:
29 | plt.ylabel(ylabel)
30 |
31 | if show:
32 | plt.show()
33 |
34 | if save:
35 | pass # TODO
36 |
37 | plt.close()
38 |
--------------------------------------------------------------------------------
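src/byzantines.py below leaves a wider variety of Byzantine behaviors as a TODO. A minimal sketch of one further variant under the same `Client` interface (hypothetical — `Byzantine_SignFlip` is not part of the repo):

```python
from client import Client


class Byzantine_SignFlip(Client):
    """Publishes the negation of its current weights each round."""

    def train(self, epoch=None, show=True, log=True):
        flipped = {name: -w for name, w in self.get_weights().items()}
        self.set_weights(flipped)
```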
/src/byzantines.py:
--------------------------------------------------------------------------------
1 | """
2 | # Inherits from client.Client
3 | # TODO: a wider variety of Byzantine behaviors
4 | """
5 | import torch
6 |
7 | from client import Client
8 |
9 |
10 | class Byzantine_Omniscience(Client):
11 | """
12 | # crafts its update so that the sum of all submitted vectors becomes zero
13 | # (expected effect: the round's aggregate cancels out)
14 | # TBA
15 | """
16 | pass
17 |
18 |
19 | class Byzantine_Random(Client):
20 | def train(self,
21 | epoch=None, show=True, log=True): # epoch is unused; the default lets the demo below call train() with no arguments
22 |
23 | # replace every weight tensor with uniform random values (no actual training)
24 | rand_weights = dict()
25 |
26 | weights_dict = self.get_weights()
27 |
28 | for name, value in weights_dict.items():
29 | rand_weights[name] = torch.rand_like(value)
30 |
31 | self.set_weights(rand_weights)
32 |
33 |
34 | if __name__ == "__main__":
35 | import argparse
36 |
37 | import torchvision.datasets as dset
38 | import torchvision.transforms as transforms
39 |
40 | from net import DenseNet
41 |
42 | """argparse"""
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument('--batchSz', type=int, default=128)
45 | parser.add_argument('--nEpochs', type=int, default=300)
46 | parser.add_argument('--no-cuda', action='store_true')
47 | parser.add_argument('--path')
48 | parser.add_argument('--seed', type=int, default=950327)
49 | parser.add_argument('--opt', type=str, default='sgd',
50 | choices=('sgd', 'adam', 'rmsprop'))
51 | args = parser.parse_args()
52 |
53 | args.cuda = not args.no_cuda and torch.cuda.is_available()
54 |
55 | # set seed
56 | torch.manual_seed(args.seed)
57 | if args.cuda:
58 | torch.cuda.manual_seed(args.seed)
59 |
60 | """Data
61 | # TODO: get Mean and Std per client
62 | # Ref: https://github.com/bamos/densenet.pytorch
63 | """
64 | normMean = [0.49139968, 0.48215827, 0.44653124]
65 | normStd = [0.24703233, 0.24348505, 0.26158768]
66 | normTransform = transforms.Normalize(normMean, normStd)
67 |
68 | trainTransform = transforms.Compose([
69 | transforms.RandomCrop(32, padding=4),
70 | transforms.RandomHorizontalFlip(),
71 | transforms.ToTensor(),
72 | normTransform
73 | ])
74 | testTransform = transforms.Compose([
75 | transforms.ToTensor(),
76 | normTransform
77 | ])
78 |
79 | trainset = dset.CIFAR10(root='cifar', train=True, download=True, transform=trainTransform)
80 | testset = dset.CIFAR10(root='cifar', train=False, download=True, transform=testTransform)
81 |
82 | def _dense_net():
83 | return DenseNet(growthRate=12, depth=100, reduction=0.5, bottleneck=True, nClasses=10)
84 | # print('>>> Number of params: {}'.format(
85 | # sum([p.data.nelement() for p in net.parameters()])))
86 |
87 | # client = Client(
88 | # args=args,
89 | # net=_dense_net(),
90 | # trainset=trainset,
91 | # testset=testset,
92 | # log=False)
93 |
94 | client = Byzantine_Random(
95 | args=args,
96 | net=_dense_net(),
97 | trainset=trainset,
98 | testset=testset,
99 | log=False)
100 |
101 | # Test
102 | wanna_see = 'module.fc.weight' # the 'module.' prefix assumes nn.DataParallel wrapped the net; otherwise the key is 'fc.weight'
103 |
104 | print('Init')
105 | print(client.get_weights()[wanna_see].data[0][0])
106 |
107 | print('Rand 1')
108 | client.train()
109 | print(client.get_weights()[wanna_see].data[0][0])
110 |
111 | print('Rand 2')
112 | client.train()
113 | print(client.get_weights()[wanna_see].data[0][0])
114 |
115 | print(client._id)
--------------------------------------------------------------------------------
/src/client.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 | from torch.utils.data import DataLoader
8 |
9 | import os
10 | import numpy as np
11 |
12 |
13 | class Client:
14 | _id = 0
15 |
16 | 
def __init__(self, 17 | args, 18 | net, trainset=None, testset=None, 19 | _id=None, log=False): 20 | 21 | # id 22 | if _id != None: 23 | self._id = _id 24 | else: 25 | self._id = Client._id 26 | Client._id += 1 27 | 28 | self.path = (args.path or 'clients') + '/' + str(self._id) 29 | os.makedirs(self.path, exist_ok=True) 30 | 31 | # logger 32 | if log: 33 | self.trainF = open(os.path.join(self.path, 'train.csv'), 'w') 34 | self.testF = open(os.path.join(self.path, 'test.csv'), 'w') 35 | else: 36 | self.trainF, self.testF = None, None 37 | 38 | """data 39 | # TODO: set num_workers 40 | # TODO: per-client-normalization (not global) 41 | """ 42 | self.cuda = args.cuda 43 | self.batch_size = args.batchSz 44 | 45 | kwargs = {'num_workers': 1, 'pin_memory': True} if self.cuda else {} 46 | self.trainset = trainset 47 | if self.trainset is not None: 48 | # dset.CIFAR10(root='cifar', train=True, download=True, transform=trainTransform) 49 | self.trainLoader = DataLoader(self.trainset, 50 | batch_size=self.batch_size, shuffle=True, **kwargs) 51 | self.testset = testset 52 | if self.testset is not None: 53 | # dset.CIFAR10(root='cifar', train=False, download=True, transform=testTransform) 54 | self.testLoader = DataLoader(self.testset, 55 | batch_size=self.batch_size, shuffle=False, **kwargs) 56 | 57 | """net 58 | # TBA 59 | """ 60 | # DenseNet(growthRate=12, depth=100, reduction=0.5, bottleneck=True, nClasses=10) 61 | self.net = net 62 | 63 | if self.cuda: 64 | if torch.cuda.device_count() > 1: 65 | """DataParallel 66 | # TODO: setting output_device 67 | # torch.cuda.device_count() 68 | """ 69 | self.net = nn.DataParallel(self.net) 70 | # else: # one GPU 71 | self.net = self.net.cuda() # use cuda 72 | 73 | self.opt = args.opt 74 | if self.opt == 'sgd': 75 | self.optimizer = optim.SGD(net.parameters(), lr=1e-1, momentum=0.9) # , weight_decay=1e-4) 76 | elif self.opt == 'adam': 77 | self.optimizer = optim.Adam(net.parameters()) # , weight_decay=1e-4) 78 | elif self.opt == 'rmsprop': 79 | self.optimizer = optim.RMSprop(net.parameters()) # , weight_decay=1e-4) 80 | 81 | """Metadata 82 | # (cache) saving latest acc. 
to reduce computation
83 | """
84 | self.acc = None
85 |
86 | """ML
87 | # TBA
88 | """
89 |
90 | def save(self, path=None, name='latest.pth', numbering=None):
91 | path = path or self.path
92 | if numbering:
93 | path = os.path.join(path, numbering)
94 |
95 | loca = os.path.join(path, name)
96 |
97 | torch.save(self.net, loca)
98 |
99 | def load(self, path=None, name='latest.pth', numbering=None):
100 | path = path or self.path
101 | if numbering:
102 | path = os.path.join(path, numbering)
103 |
104 | loca = os.path.join(path, name)
105 |
106 | if os.path.isfile(loca):
107 | # print(">>> Load weights:", loca)
108 | self.net = torch.load(loca)
109 | # else:
110 | # print(">>> No pre-trained weights")
111 |
112 | def set_dataset(self, trainset=None, testset=None, batch_size=None):
113 | assert (trainset is not None) or (testset is not None)
114 | batch_size = batch_size or self.batch_size # prefer the explicit argument over the client default
115 |
116 | # TODO: set num_workers
117 | # TODO: per-client-normalization (not global)
118 | kwargs = {'num_workers': 1, 'pin_memory': True} if self.cuda else {}
119 | self.trainset = trainset
120 | if self.trainset is not None:
121 | # dset.CIFAR10(root='cifar', train=True, download=True, transform=trainTransform)
122 | self.trainLoader = DataLoader(self.trainset,
123 | batch_size=batch_size, shuffle=True, **kwargs)
124 | self.testset = testset
125 | if self.testset is not None:
126 | # dset.CIFAR10(root='cifar', train=False, download=True, transform=testTransform),
127 | self.testLoader = DataLoader(self.testset,
128 | batch_size=batch_size, shuffle=False, **kwargs)
129 |
130 | def train(self, epoch, show=True, log=True):
131 | # assert((not show) or (self.trainF is None))
132 |
133 | self.net.train()
134 |
135 | nProcessed = 0
136 | nTrain = len(self.trainLoader.dataset)
137 |
138 | for batch_idx, (data, target) in enumerate(self.trainLoader):
139 |
140 | if self.cuda:
141 | data, target = data.cuda(), target.cuda()
142 |
143 | data, target = Variable(data), Variable(target)
144 | self.optimizer.zero_grad()
145 | output = self.net(data)
146 | loss = F.nll_loss(output, target)
147 | loss.backward()
148 | self.optimizer.step()
149 |
150 | nProcessed += len(data)
151 | pred = output.data.max(1)[1] # get the index of the max log-probability
152 | incorrect = pred.ne(target.data).cpu().sum()
153 | err = 100. * incorrect / len(data)
154 | partialEpoch = epoch + batch_idx / len(self.trainLoader) - 1
155 |
156 | if show:
157 | print('Train Epoch: {:.2f} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tError: {:.6f}'.format(
158 | partialEpoch, nProcessed, nTrain, 100.
* batch_idx / len(self.trainLoader), 159 | loss.item(), err)) 160 | 161 | if (self.trainF is not None) and log: 162 | self.trainF.write('{},{},{}\n'.format( 163 | partialEpoch, loss.item(), err)) 164 | self.trainF.flush() 165 | 166 | def test(self, epoch, show=True, log=True): 167 | # assert((not show) or (self.testF is None)) 168 | 169 | self.net.eval() # tells net to do evaluating 170 | 171 | test_loss = 0 172 | incorrect = 0 173 | 174 | for data, target in self.testLoader: 175 | 176 | if self.cuda: 177 | data, target = data.cuda(), target.cuda() 178 | 179 | with torch.no_grad(): 180 | # data, target = Variable(data), Variable(target) 181 | output = self.net(data) 182 | test_loss += F.nll_loss(output, target).item() 183 | pred = output.data.max(1)[1] # get the index of the max log-probability 184 | incorrect += pred.ne(target.data).cpu().sum() 185 | 186 | test_loss = test_loss 187 | test_loss /= len(self.testLoader) # loss function already averages over batch size 188 | nTotal = len(self.testLoader.dataset) 189 | err = 100. * incorrect / nTotal 190 | 191 | if show: 192 | print('\nTest set: Average loss: {:.4f}, Error: {}/{} ({:.0f}%)\n'.format( 193 | test_loss, incorrect, nTotal, err)) 194 | 195 | if (self.testF is not None) and log: 196 | self.testF.write('{},{},{}\n'.format( 197 | epoch, test_loss, err)) 198 | self.testF.flush() 199 | 200 | self.acc = 100. - err.item() 201 | 202 | return err.item() 203 | 204 | def adjust_opt(self, epoch): 205 | if self.opt == 'sgd': 206 | if epoch < 150: 207 | lr = 1e-1 208 | elif epoch == 150: 209 | lr = 1e-2 210 | elif epoch == 225: 211 | lr = 1e-3 212 | else: 213 | return 214 | 215 | for param_group in self.optimizer.param_groups: 216 | param_group['lr'] = lr 217 | 218 | def _get_params(self): 219 | params = self.net.named_parameters() 220 | dict_params = dict(params) 221 | 222 | return dict_params 223 | 224 | def _set_params(self, new_params: dict): 225 | pass # TODO 226 | 227 | def get_weights(self): 228 | dict_params = self._get_params() 229 | dict_weights = dict() 230 | 231 | for name, param in dict_params.items(): 232 | dict_weights[name] = param.data 233 | 234 | return dict_weights 235 | 236 | def set_weights(self, new_weights: dict): 237 | net_state_dict = self.net.state_dict() 238 | dict_params = self._get_params() 239 | 240 | for name, new_weight in new_weights.items(): 241 | if name in dict_params: 242 | dict_params[name].data.copy_(new_weight.data) 243 | 244 | net_state_dict.update(dict_params) 245 | self.net.load_state_dict(net_state_dict) 246 | 247 | def get_average_weights(self, weightses: list, repus: list): 248 | dict_avg_weights = dict() 249 | 250 | for i, repu in enumerate(repus): 251 | weights = weightses[i] 252 | 253 | for name, weight in weights.items(): 254 | if name not in dict_avg_weights: 255 | dict_avg_weights[name] = torch.zeros_like(weight) 256 | 257 | dict_avg_weights[name].data.add_(repu * weight.data) 258 | 259 | return dict_avg_weights 260 | 261 | def set_average_weights(self, weightses: list, repus: list): # TODO: norm. 262 | self.set_weights(self.get_average_weights(weightses, repus)) 263 | 264 | # # TODO: gradient. Is it really needeed? 
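# # (probably not: this simulator exchanges and averages weights directly via set_average_weights rather than exchanging gradients)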
265 | # # See https://github.com/AshwinRJ/Federated-Learning-PyTorch 266 | 267 | # def _get_grad(self): 268 | # dict_params = self._get_params() 269 | # dict_grad = dict() 270 | 271 | # for name, param in dict_params.items(): 272 | # if param.requires_grad: 273 | # dict_grad[name] = param.grad 274 | 275 | # return dict_grad 276 | 277 | # def _set_grad(self, new_grads: dict): 278 | # pass # TODO: Applying grad to weights via GD 279 | 280 | """DAG 281 | # TODO 282 | """ 283 | 284 | def select_node(self): 285 | pass 286 | 287 | def test_node(self): 288 | pass 289 | 290 | def create_node(self): 291 | pass 292 | 293 | """Viz 294 | # TODO: tensorboard 295 | """ 296 | 297 | 298 | if __name__ == "__main__": 299 | import argparse 300 | 301 | import torchvision.datasets as dset 302 | import torchvision.transforms as transforms 303 | from torch.utils.data import random_split 304 | 305 | from net import DenseNet 306 | 307 | """argparse""" 308 | parser = argparse.ArgumentParser() 309 | parser.add_argument('--batchSz', type=int, default=128) 310 | parser.add_argument('--nEpochs', type=int, default=300) 311 | parser.add_argument('--no-cuda', action='store_true') 312 | parser.add_argument('--path') 313 | parser.add_argument('--seed', type=int, default=950327) 314 | parser.add_argument('--opt', type=str, default='sgd', 315 | choices=('sgd', 'adam', 'rmsprop')) 316 | args = parser.parse_args() 317 | 318 | args.cuda = not args.no_cuda and torch.cuda.is_available() 319 | 320 | # set seed 321 | torch.manual_seed(args.seed) 322 | if args.cuda: 323 | torch.cuda.manual_seed(args.seed) 324 | 325 | """Data 326 | # TODO: get Mean and Std per client 327 | # Ref: https://github.com/bamos/densenet.pytorch 328 | """ 329 | normMean = [0.49139968, 0.48215827, 0.44653124] 330 | normStd = [0.24703233, 0.24348505, 0.26158768] 331 | normTransform = transforms.Normalize(normMean, normStd) 332 | 333 | trainTransform = transforms.Compose([ 334 | transforms.RandomCrop(32, padding=4), 335 | transforms.RandomHorizontalFlip(), 336 | transforms.ToTensor(), 337 | normTransform 338 | ]) 339 | testTransform = transforms.Compose([ 340 | transforms.ToTensor(), 341 | normTransform 342 | ]) 343 | 344 | trainset = dset.CIFAR10(root='cifar', train=True, download=True, transform=trainTransform) 345 | testset = dset.CIFAR10(root='cifar', train=False, download=True, transform=testTransform) 346 | 347 | # Random split 348 | splited_trainset = random_split(trainset, [15000, 25000, 10000]) 349 | splited_testset = random_split(testset, [2000, 6000, 2000]) 350 | 351 | """FL 352 | # TBA 353 | """ 354 | def _dense_net(): 355 | return DenseNet(growthRate=12, depth=100, reduction=0.5, bottleneck=True, nClasses=10) 356 | # print('>>> Number of params: {}'.format( 357 | # sum([p.data.nelement() for p in net.parameters()]))) 358 | 359 | clients = [] 360 | for i in range(3): 361 | clients.append(Client( 362 | args=args, 363 | net=_dense_net(), 364 | trainset=splited_trainset[i], 365 | testset=splited_testset[i], 366 | log=True)) 367 | 368 | # Test 369 | wanna_see = 'module.fc.weight' 370 | 371 | print('Init') 372 | print(clients[0].get_weights()[wanna_see].data[0][0]) 373 | print(clients[1].get_weights()[wanna_see].data[0][0]) 374 | print(clients[2].get_weights()[wanna_see].data[0][0]) 375 | 376 | clients[1].set_weights(clients[0].get_weights()) 377 | clients[2].set_weights(clients[0].get_weights()) 378 | 379 | print('Set') 380 | print(clients[0].get_weights()[wanna_see].data[0][0]) 381 | print(clients[1].get_weights()[wanna_see].data[0][0]) 382 | 
print(clients[2].get_weights()[wanna_see].data[0][0]) 383 | 384 | # train 385 | clients[1].train(epoch=1, show=False) 386 | clients[2].train(epoch=1, show=False) 387 | 388 | print('After training') 389 | print(clients[0].get_weights()[wanna_see].data[0][0]) 390 | print(clients[1].get_weights()[wanna_see].data[0][0]) 391 | print(clients[2].get_weights()[wanna_see].data[0][0]) 392 | 393 | # avg 394 | clients[0].set_average_weights( 395 | [clients[1].get_weights(), clients[2].get_weights()], 396 | [0.9, 0.1]) 397 | 398 | print('After averaging') 399 | print(clients[0].get_weights()[wanna_see].data[0][0]) 400 | print(clients[1].get_weights()[wanna_see].data[0][0]) 401 | print(clients[2].get_weights()[wanna_see].data[0][0]) 402 | -------------------------------------------------------------------------------- /src/dag.py: -------------------------------------------------------------------------------- 1 | """ 2 | DAG (Directed Acyclic Graph) 3 | """ 4 | 5 | 6 | class Node: 7 | _id = 0 8 | 9 | def __init__(self, 10 | weights, 11 | # parent: list = [], 12 | # edges: list = [], 13 | _id=None, 14 | creator=None): 15 | 16 | # id 17 | if _id != None: 18 | self._id = _id 19 | else: 20 | self._id = Node._id 21 | Node._id += 1 22 | 23 | # TODO: rounds 24 | 25 | self.weights = weights 26 | # self.parent = parent 27 | # self.edges = edges 28 | 29 | self.creator = creator 30 | 31 | def get_id(self): 32 | return self._id 33 | 34 | def get_weights(self): 35 | return self.weights 36 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from copy import deepcopy 4 | 5 | import torch 6 | import torchvision.datasets as dset 7 | import torchvision.transforms as transforms 8 | from torch.utils.data import random_split 9 | 10 | from tqdm import tqdm 11 | 12 | from net import DenseNet 13 | from client import Client 14 | from byzantines import Byzantine_Random 15 | from dag import Node 16 | import reputation 17 | 18 | 19 | if __name__ == "__main__": 20 | """TODO 21 | # TODO: global test set 22 | """ 23 | 24 | """argparse""" 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--nNodes', type=int, default=100) 27 | parser.add_argument('--nByzs', type=int, default=33) 28 | parser.add_argument('--batchSz', type=int, default=128) 29 | parser.add_argument('--nEpochs', type=int, default=300) 30 | parser.add_argument('--op-stop', action='store_true') 31 | parser.add_argument('--filter', action='store_true') 32 | parser.add_argument('--repute', type=str, default='acc', 33 | choices=('acc', 'Frobenius', 'random', 'GNN')) 34 | parser.add_argument('--path') 35 | parser.add_argument('--no-cuda', action='store_true') 36 | # parser.add_argument('--load', action='store_true') # TODO 37 | parser.add_argument('--seed', type=int, default=1) 38 | parser.add_argument('--opt', type=str, default='sgd', 39 | choices=('sgd', 'adam', 'rmsprop')) 40 | args = parser.parse_args() 41 | 42 | args.cuda = not args.no_cuda and torch.cuda.is_available() 43 | 44 | args.norm = args.nNodes - args.nByzs 45 | 46 | # set seed 47 | torch.manual_seed(args.seed) 48 | if args.cuda: 49 | torch.cuda.manual_seed(args.seed) 50 | 51 | """Data 52 | # TODO: get Mean and Std per client 53 | # Ref: https://github.com/bamos/densenet.pytorch 54 | """ 55 | normMean = [0.49139968, 0.48215827, 0.44653124] 56 | normStd = [0.24703233, 0.24348505, 0.26158768] 57 | normTransform = 
transforms.Normalize(normMean, normStd) 58 | 59 | trainTransform = transforms.Compose([ 60 | transforms.RandomCrop(32, padding=4), 61 | transforms.RandomHorizontalFlip(), 62 | transforms.ToTensor(), 63 | normTransform 64 | ]) 65 | testTransform = transforms.Compose([ 66 | transforms.ToTensor(), 67 | normTransform 68 | ]) 69 | 70 | trainset = dset.CIFAR10(root='cifar', train=True, download=True, transform=trainTransform) 71 | testset = dset.CIFAR10(root='cifar', train=False, download=True, transform=testTransform) 72 | 73 | # Random split 74 | splited_trainset = random_split(trainset, [int(len(trainset) / args.nNodes) for _ in range(args.nNodes)]) 75 | splited_testset = random_split(testset, [int(len(testset) / args.nNodes) for _ in range(args.nNodes)]) 76 | 77 | """Set nodes 78 | # TBA 79 | """ 80 | def _dense_net(): 81 | return DenseNet(growthRate=12, depth=100, reduction=0.5, bottleneck=True, nClasses=10) 82 | # print('>>> Number of params: {}'.format( 83 | # sum([p.data.nelement() for p in net.parameters()]))) 84 | 85 | tmp_client = Client( # for eval. the others' net / et al. 86 | args=args, 87 | net=_dense_net(), 88 | trainset=None, 89 | testset=None, 90 | log=False, 91 | _id=-1) 92 | 93 | clients = [] 94 | for i in range(args.nNodes): 95 | if i < args.nByzs: # Byzantine nodes 96 | client = Byzantine_Random( 97 | args=args, 98 | net=_dense_net(), 99 | trainset=splited_trainset[i], 100 | testset=splited_testset[i], 101 | log=True) 102 | else: # Honest nodes 103 | client = Client( 104 | args=args, 105 | net=_dense_net(), 106 | trainset=splited_trainset[i], 107 | testset=splited_testset[i], 108 | log=True) 109 | client.set_weights(tmp_client.get_weights()) # same init. weights 110 | clients.append(client) 111 | 112 | """Set DAG 113 | # TODO: DAG connection 114 | """ 115 | genesis = Node( 116 | weights=tmp_client.get_weights(), 117 | _id=-1) 118 | 119 | nodes = [] 120 | nodes.append(genesis) 121 | 122 | """Run simulator 123 | # TODO: logging time (train, test) 124 | """ 125 | latest_nodes = deepcopy(nodes) # in DAG 126 | 127 | for epoch in range(1, args.nEpochs + 1): 128 | print(">>> Round %5d" % (epoch)) 129 | 130 | # select activated clients 131 | # At least one honest node 132 | n_activated_byz = random.randint(0, args.nByzs) # in Byz. 133 | n_activated_norm = random.randint(1, args.norm) # in Norm. 134 | activateds = random.sample([t for t in range(args.nByzs)], n_activated_byz) 135 | activateds += random.sample([t + args.nByzs for t in range(args.norm)], n_activated_norm) 136 | 137 | current_nodes = [] 138 | current_accs = [] 139 | 140 | for a in tqdm(activateds): 141 | client = clients[a] 142 | 143 | if a < args.nByzs: # Byzantine node 144 | pass # skip averaging 145 | else: # Normal node 146 | client.adjust_opt(epoch) 147 | 148 | """References 149 | # TBA 150 | """ 151 | # My acc 152 | if client.acc is None: 153 | my_acc = 100. 
- client.test(epoch, show=False, log=False)
154 | else:
155 | my_acc = client.acc
156 |
157 | # The others' acc
158 | # TODO: parameterize
159 | # TODO: ETA
160 | tmp_client.set_dataset(trainset=None, testset=client.testset)
161 |
162 | if args.repute == 'acc':
163 | bests, idx_bests, _ = reputation.by_accuracy(
164 | proposals=latest_nodes, count=min(len(latest_nodes), 2), test_client=tmp_client,
165 | epoch=epoch, show=False, log=False,
166 | timing=False, optimal_stopping=args.op_stop)
167 | elif args.repute == 'Frobenius':
168 | bests, idx_bests, _ = reputation.by_Frobenius(
169 | proposals=latest_nodes, count=min(len(latest_nodes), 2), base_client=client, FN=args.filter,
170 | return_acc=True, test_client=tmp_client, epoch=epoch, show=False, log=False,
171 | timing=False, optimal_stopping=args.op_stop)
172 | elif args.repute == 'random':
173 | bests, idx_bests, _ = reputation.by_random(
174 | proposals=latest_nodes, count=min(len(latest_nodes), 2),
175 | return_acc=True, test_client=tmp_client, epoch=epoch, show=False, log=False,
176 | timing=False)
177 | elif args.repute == 'GNN':
178 | pass # TODO (note: this branch leaves `bests` undefined for the code below)
179 | else:
180 | raise ValueError('unknown repute: {}'.format(args.repute))
181 |
182 | best_nodes = [latest_nodes[idx_best] for idx_best in idx_bests]
183 | elected_nodes = []
184 | elected_repus = []
185 |
186 | # check whether one of the best nodes was created by this client
187 | self_contain = (sum([b.creator == a for b in best_nodes]) != 0)
188 |
189 | # TODO: parameterize
190 | if (len(bests) < 2): # only one proposal available
191 | elected_nodes = [best_nodes[0], client]
192 | elected_repus = [bests[0], my_acc]
193 | elif not self_contain:
194 | if bests[1] > my_acc:
195 | elected_nodes = [best_nodes[0], best_nodes[1]]
196 | elected_repus = [bests[0], bests[1]]
197 | else:
198 | elected_nodes = [best_nodes[0], client]
199 | elected_repus = [bests[0], my_acc]
200 | else: # self-contained
201 | # TODO: How to select the other honest node? (Mix)
202 | # Current implementation:
203 | # there exists the possibility of own + own (no change)
204 | if bests[1] > my_acc:
205 | elected_nodes = [best_nodes[0], best_nodes[1]]
206 | elected_repus = [bests[0], bests[1]]
207 | else:
208 | elected_nodes = [best_nodes[0], client]
209 | elected_repus = [bests[0], my_acc]
210 |
211 | """FL
212 | # own weights + the other's weights
213 | """
214 | weightses = [e.get_weights() for e in elected_nodes]
215 | repus_sum = sum(elected_repus)
216 | repus = [e / repus_sum for e in elected_repus]
217 |
218 | client.set_average_weights(weightses, repus)
219 |
220 | # train
221 | client.train(epoch, show=False, log=True)
222 |
223 | # for logging
224 | after_avg_acc = 100.
- client.test(epoch, show=False, log=True) 225 | current_accs.append(after_avg_acc) 226 | 227 | # save weights 228 | client.save() 229 | 230 | """DAG 231 | # TODO 232 | """ 233 | # create node 234 | new_node = Node( 235 | weights=client.get_weights(), 236 | creator=a) 237 | # nodes.append(new_node) 238 | current_nodes.append(new_node) 239 | 240 | """Log 241 | # TODO: save to file 242 | """ 243 | print(">>> activated_clients:", activateds) 244 | print(">>> latest_nodes:", [d.get_id() for d in latest_nodes]) 245 | print(">>> current_nodes:", [d.get_id() for d in current_nodes]) 246 | print(">>> current_accs:", current_accs) 247 | print() 248 | 249 | latest_nodes = deepcopy(current_nodes) 250 | -------------------------------------------------------------------------------- /src/net.py: -------------------------------------------------------------------------------- 1 | """Ref 2 | https://github.com/bamos/densenet.pytorch 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | 11 | import torchvision.models as models 12 | import torchvision.datasets as dset 13 | import torchvision.transforms as transforms 14 | from torchvision.utils import save_image 15 | 16 | from torch.utils.data import DataLoader 17 | 18 | import math 19 | 20 | 21 | class Bottleneck(nn.Module): 22 | def __init__(self, nChannels, growthRate): 23 | super(Bottleneck, self).__init__() 24 | interChannels = 4 * growthRate 25 | self.bn1 = nn.BatchNorm2d(nChannels) 26 | self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, 27 | bias=False) 28 | self.bn2 = nn.BatchNorm2d(interChannels) 29 | self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, 30 | padding=1, bias=False) 31 | 32 | def forward(self, x): 33 | out = self.conv1(F.relu(self.bn1(x))) 34 | out = self.conv2(F.relu(self.bn2(out))) 35 | out = torch.cat((x, out), 1) 36 | return out 37 | 38 | 39 | class SingleLayer(nn.Module): 40 | def __init__(self, nChannels, growthRate): 41 | super(SingleLayer, self).__init__() 42 | self.bn1 = nn.BatchNorm2d(nChannels) 43 | self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, 44 | padding=1, bias=False) 45 | 46 | def forward(self, x): 47 | out = self.conv1(F.relu(self.bn1(x))) 48 | out = torch.cat((x, out), 1) 49 | return out 50 | 51 | 52 | class Transition(nn.Module): 53 | def __init__(self, nChannels, nOutChannels): 54 | super(Transition, self).__init__() 55 | self.bn1 = nn.BatchNorm2d(nChannels) 56 | self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, 57 | bias=False) 58 | 59 | def forward(self, x): 60 | out = self.conv1(F.relu(self.bn1(x))) 61 | out = F.avg_pool2d(out, 2) 62 | return out 63 | 64 | 65 | class DenseNet(nn.Module): 66 | def __init__(self, growthRate, depth, reduction, nClasses, bottleneck): 67 | super(DenseNet, self).__init__() 68 | 69 | nDenseBlocks = (depth - 4) // 3 70 | if bottleneck: 71 | nDenseBlocks //= 2 72 | 73 | nChannels = 2 * growthRate 74 | self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, 75 | bias=False) 76 | self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 77 | nChannels += nDenseBlocks * growthRate 78 | nOutChannels = int(math.floor(nChannels * reduction)) 79 | self.trans1 = Transition(nChannels, nOutChannels) 80 | 81 | nChannels = nOutChannels 82 | self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 83 | nChannels += nDenseBlocks * growthRate 84 | nOutChannels = 
int(math.floor(nChannels * reduction)) 85 | self.trans2 = Transition(nChannels, nOutChannels) 86 | 87 | nChannels = nOutChannels 88 | self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 89 | nChannels += nDenseBlocks * growthRate 90 | 91 | self.bn1 = nn.BatchNorm2d(nChannels) 92 | self.fc = nn.Linear(nChannels, nClasses) 93 | 94 | for m in self.modules(): 95 | if isinstance(m, nn.Conv2d): 96 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 97 | m.weight.data.normal_(0, math.sqrt(2. / n)) 98 | elif isinstance(m, nn.BatchNorm2d): 99 | m.weight.data.fill_(1) 100 | m.bias.data.zero_() 101 | elif isinstance(m, nn.Linear): 102 | m.bias.data.zero_() 103 | 104 | def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): 105 | layers = [] 106 | for i in range(int(nDenseBlocks)): 107 | if bottleneck: 108 | layers.append(Bottleneck(nChannels, growthRate)) 109 | else: 110 | layers.append(SingleLayer(nChannels, growthRate)) 111 | nChannels += growthRate 112 | return nn.Sequential(*layers) 113 | 114 | def forward(self, x): 115 | out = self.conv1(x) 116 | out = self.trans1(self.dense1(out)) 117 | out = self.trans2(self.dense2(out)) 118 | out = self.dense3(out) 119 | out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8)) 120 | out = F.log_softmax(self.fc(out), dim=1) 121 | return out 122 | 123 | 124 | def train(args, epoch, net, trainLoader, optimizer, logger=None, show=False): 125 | if (not show) and (logger is None): 126 | return 127 | 128 | net.train() # tells net to do training 129 | 130 | nProcessed = 0 131 | nTrain = len(trainLoader.dataset) 132 | 133 | for batch_idx, (data, target) in enumerate(trainLoader): 134 | if args.cuda: 135 | data, target = data.cuda(), target.cuda() 136 | 137 | data, target = Variable(data), Variable(target) 138 | optimizer.zero_grad() 139 | output = net(data) 140 | loss = F.nll_loss(output, target) 141 | loss.backward() 142 | optimizer.step() 143 | 144 | nProcessed += len(data) 145 | pred = output.data.max(1)[1] # get the index of the max log-probability 146 | incorrect = pred.ne(target.data).cpu().sum() 147 | err = 100. * incorrect / len(data) 148 | partialEpoch = epoch + batch_idx / len(trainLoader) - 1 149 | 150 | if show: 151 | print('Train Epoch: {:.2f} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tError: {:.6f}'.format( 152 | partialEpoch, nProcessed, nTrain, 100. * batch_idx / len(trainLoader), 153 | loss.item(), err)) 154 | 155 | if logger is not None: 156 | logger.write('{},{},{}\n'.format(partialEpoch, loss.item(), err)) 157 | logger.flush() 158 | 159 | 160 | def test(args, epoch, net, testLoader, optimizer, logger=None, show=True): 161 | if (not show) and (logger is None): 162 | return 163 | 164 | net.eval() # tells net to do evaluating 165 | 166 | test_loss = 0 167 | incorrect = 0 168 | 169 | for data, target in testLoader: 170 | if args.cuda: 171 | data, target = data.cuda(), target.cuda() 172 | 173 | with torch.no_grad(): 174 | # data, target = Variable(data), Variable(target) 175 | output = net(data) 176 | test_loss += F.nll_loss(output, target).item() 177 | pred = output.data.max(1)[1] # get the index of the max log-probability 178 | incorrect += pred.ne(target.data).cpu().sum() 179 | 180 | test_loss = test_loss 181 | test_loss /= len(testLoader) # loss function already averages over batch size 182 | nTotal = len(testLoader.dataset) 183 | err = 100. 
* incorrect / nTotal 184 | 185 | if show: 186 | print('\nTest set: Average loss: {:.4f}, Error: {}/{} ({:.0f}%)\n'.format( 187 | test_loss, incorrect, nTotal, err)) 188 | 189 | if logger is not None: 190 | logger.write('{},{},{}\n'.format(epoch, test_loss, err)) 191 | logger.flush() 192 | 193 | 194 | def adjust_opt(optAlg, optimizer, epoch): 195 | if optAlg == 'sgd': 196 | if epoch < 150: 197 | lr = 1e-1 198 | elif epoch == 150: 199 | lr = 1e-2 200 | elif epoch == 225: 201 | lr = 1e-3 202 | else: 203 | return 204 | 205 | for param_group in optimizer.param_groups: 206 | param_group['lr'] = lr 207 | 208 | 209 | if __name__ == '__main__': 210 | import argparse 211 | import setproctitle 212 | import os 213 | import shutil 214 | 215 | """argparse""" 216 | parser = argparse.ArgumentParser() 217 | parser.add_argument('--batchSz', type=int, default=128) 218 | parser.add_argument('--nEpochs', type=int, default=300) 219 | parser.add_argument('--no-cuda', action='store_true') 220 | parser.add_argument('--path') 221 | parser.add_argument('--no-load', action='store_true') 222 | parser.add_argument('--seed', type=int, default=1) 223 | parser.add_argument('--opt', type=str, default='sgd', 224 | choices=('sgd', 'adam', 'rmsprop')) 225 | args = parser.parse_args() 226 | 227 | args.cuda = not args.no_cuda and torch.cuda.is_available() 228 | args.path = args.path or 'data/base' 229 | setproctitle.setproctitle(args.path) 230 | 231 | torch.manual_seed(args.seed) 232 | if args.cuda: 233 | torch.cuda.manual_seed(args.seed) 234 | 235 | # if os.path.exists(args.path): 236 | # shutil.rmtree(args.path) 237 | os.makedirs(args.path, exist_ok=True) 238 | 239 | """normalization 240 | # TODO: get Mean and Std 241 | # Ref: https://github.com/bamos/densenet.pytorch 242 | """ 243 | normMean = [0.49139968, 0.48215827, 0.44653124] 244 | normStd = [0.24703233, 0.24348505, 0.26158768] 245 | normTransform = transforms.Normalize(normMean, normStd) 246 | 247 | trainTransform = transforms.Compose([ 248 | transforms.RandomCrop(32, padding=4), 249 | transforms.RandomHorizontalFlip(), 250 | transforms.ToTensor(), 251 | normTransform 252 | ]) 253 | testTransform = transforms.Compose([ 254 | transforms.ToTensor(), 255 | normTransform 256 | ]) 257 | 258 | """data 259 | # TODO: set num_workers 260 | """ 261 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 262 | 263 | trainLoader = DataLoader( 264 | dset.CIFAR10(root='cifar', train=True, download=True, transform=trainTransform), 265 | batch_size=args.batchSz, shuffle=True, **kwargs) 266 | testLoader = DataLoader( 267 | dset.CIFAR10(root='cifar', train=False, download=True, transform=testTransform), 268 | batch_size=args.batchSz, shuffle=False, **kwargs) 269 | 270 | """net 271 | # TODO: remove batch normalization (and residual connection ?) 
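# note: BatchNorm running_mean/running_var are buffers, not parameters, so the named_parameters()-based get_weights()/set_weights() exchange in client.py never syncs them — every client keeps its own BN statistics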
272 | """ 273 | net = DenseNet(growthRate=12, depth=100, reduction=0.5, bottleneck=True, nClasses=10) 274 | print('>>> Number of params: {}'.format( 275 | sum([p.data.nelement() for p in net.parameters()]))) 276 | 277 | if args.cuda: 278 | 279 | if torch.cuda.device_count() > 1: 280 | """DataParallel 281 | # TODO: setting output_device 282 | # torch.cuda.device_count() 283 | """ 284 | net = nn.DataParallel(net) 285 | 286 | net = net.cuda() 287 | 288 | if args.opt == 'sgd': 289 | optimizer = optim.SGD(net.parameters(), lr=1e-1, momentum=0.9) # , weight_decay=1e-4) 290 | elif args.opt == 'adam': 291 | optimizer = optim.Adam(net.parameters()) # , weight_decay=1e-4) 292 | elif args.opt == 'rmsprop': 293 | optimizer = optim.RMSprop(net.parameters()) # , weight_decay=1e-4) 294 | 295 | # load 296 | if not args.no_load: 297 | path_and_file = os.path.join(args.path, 'latest.pth') 298 | 299 | if os.path.isfile(path_and_file): 300 | print(">>> Load weights:", path_and_file) 301 | net = torch.load(path_and_file) 302 | else: 303 | print(">>> No pre-trained weights") 304 | 305 | # log files 306 | trainF = open(os.path.join(args.path, 'train.csv'), 'w') 307 | testF = open(os.path.join(args.path, 'test.csv'), 'w') 308 | 309 | """train and test""" 310 | for epoch in range(1, args.nEpochs + 1): 311 | 312 | adjust_opt(args.opt, optimizer, epoch) 313 | 314 | train(args, epoch, net, trainLoader, optimizer, show=True, logger=trainF) 315 | test(args, epoch, net, testLoader, optimizer, show=True, logger=testF) 316 | 317 | # save 318 | torch.save(net, os.path.join(args.path, 'latest.pth')) 319 | 320 | trainF.close() 321 | testF.close() 322 | -------------------------------------------------------------------------------- /src/plot.py: -------------------------------------------------------------------------------- 1 | """Ref 2 | # https://github.com/bamos/densenet.pytorch/blob/master/plot.py 3 | """ 4 | import argparse 5 | import os 6 | import numpy as np 7 | 8 | import matplotlib as mpl 9 | import matplotlib.pyplot as plt 10 | 11 | 12 | def rolling(N, i, loss, err): 13 | i_ = i[N - 1:] 14 | K = np.full(N, 1. / N) 15 | loss_ = np.convolve(loss, K, 'valid') 16 | err_ = np.convolve(err, K, 'valid') 17 | return i_, loss_, err_ 18 | 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('expDir', type=str) 23 | args = parser.parse_args() 24 | 25 | try: 26 | trainP = os.path.join(args.expDir, 'train.csv') 27 | trainData = np.loadtxt(trainP, delimiter=',').reshape(-1, 3) 28 | testP = os.path.join(args.expDir, 'test.csv') 29 | testData = np.loadtxt(testP, delimiter=',').reshape(-1, 3) 30 | except(IOError): 31 | exit() 32 | 33 | # N = 392 * 2 # Rolling loss over the past epoch. 
34 | 35 | trainI, trainLoss, trainErr = np.split(trainData, [1, 2], axis=1) 36 | trainI, trainLoss, trainErr = [x.ravel() for x in 37 | (trainI, trainLoss, trainErr)] 38 | 39 | # try: 40 | # trainI_, trainLoss_, trainErr_ = rolling(N, trainI, trainLoss, trainErr) 41 | # except(ValueError): 42 | # exit() 43 | 44 | testI, testLoss, testErr = np.split(testData, [1, 2], axis=1) 45 | 46 | fig, ax = plt.subplots(1, 1, figsize=(6, 5)) 47 | plt.plot(trainI, trainLoss, label='Train') 48 | # plt.plot(trainI_, trainLoss_, label='Train') 49 | plt.plot(testI, testLoss, label='Test') 50 | plt.xlabel('Epoch') 51 | plt.ylabel('Cross-Entropy Loss') 52 | plt.legend() 53 | ax.set_yscale('log') 54 | loss_fname = os.path.join(args.expDir, 'loss.png') 55 | plt.savefig(loss_fname) 56 | print('Created {}'.format(loss_fname)) 57 | 58 | fig, ax = plt.subplots(1, 1, figsize=(6, 5)) 59 | plt.plot(trainI, trainErr, label='Train') 60 | # plt.plot(trainI_, trainErr_, label='Train') 61 | plt.plot(testI, testErr, label='Test') 62 | plt.xlabel('Epoch') 63 | plt.ylabel('Error') 64 | ax.set_yscale('log') 65 | plt.legend() 66 | err_fname = os.path.join(args.expDir, 'error.png') 67 | plt.savefig(err_fname) 68 | print('Created {}'.format(err_fname)) 69 | 70 | loss_err_fname = os.path.join(args.expDir, 'loss-error.png') 71 | os.system('convert +append {} {} {}'.format(loss_fname, err_fname, loss_err_fname)) 72 | print('Created {}'.format(loss_err_fname)) 73 | -------------------------------------------------------------------------------- /src/reputation.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import random 4 | 5 | # from tqdm import tqdm 6 | 7 | import torch 8 | 9 | 10 | def by_random( 11 | proposals: list, count: int, 12 | return_acc=False, test_client=None, epoch=None, show=False, log=False, 13 | timing=False): 14 | 15 | if timing: 16 | start = time.time() 17 | 18 | n = len(proposals) 19 | assert(n >= count) 20 | 21 | elapsed = None 22 | accs = [] 23 | 24 | idxes = random.sample(range(n), count) 25 | 26 | if return_acc and (test_client is not None) and (epoch is not None): 27 | for idx in idxes: 28 | test_client.set_weights(proposals[idx].get_weights()) 29 | res = 100. - test_client.test(epoch, show=show, log=log) 30 | accs.append(res) 31 | 32 | # elapsed time 33 | if timing: 34 | elapsed = time.time() - start 35 | # print(elapsed) 36 | 37 | return accs, idxes, elapsed 38 | 39 | 40 | def suffle(A): 41 | return (list(t) for t in zip(*(random.sample([i for i in (enumerate(A))], len(A))))) 42 | 43 | 44 | def by_accuracy( 45 | proposals: list, count: int, test_client, 46 | epoch, show=False, log=False, 47 | timing=False, optimal_stopping=False): 48 | 49 | if timing: 50 | start = time.time() 51 | 52 | n = len(proposals) 53 | assert(n >= count) 54 | 55 | bests, idx_bests, elapsed = [], [], None 56 | accs = [] 57 | 58 | if optimal_stopping and (n >= 3): 59 | """optimal stopping mode 60 | # TODO: Randomize input list (proposals) 61 | # TODO: not a best, but t% satisfaction (10, 20, ...) 62 | # Ref. this: https://horizon.kias.re.kr/6053/ 63 | """ 64 | passing_number = int(n / math.e) 65 | cutline = 0. 66 | 67 | idx_suffled, suffled = suffle(proposals) 68 | 69 | for i, proposal in enumerate(suffled): 70 | test_client.set_weights(proposal.get_weights()) 71 | res = 100. 
- test_client.test(epoch, show=show, log=log) 72 | accs.append(res) 73 | idx_bests.append(idx_suffled[i]) 74 | if cutline < res: 75 | cutline = res 76 | if (i >= passing_number) and (i + 1 >= count): 77 | break 78 | else: 79 | """normal mode 80 | # TBA 81 | """ 82 | for i, proposal in enumerate(proposals): # tqdm(proposals): 83 | test_client.set_weights(proposal.get_weights()) 84 | res = 100. - test_client.test(epoch, show=show, log=log) 85 | accs.append(res) 86 | idx_bests.append(i) 87 | 88 | # print(accs) 89 | bests = accs[:] 90 | bests, idx_bests = (list(t)[:count] for t in zip(*sorted(zip(bests, idx_bests), reverse=True))) 91 | 92 | # elapsed time 93 | if timing: 94 | elapsed = time.time() - start 95 | # print(elapsed) 96 | 97 | return bests, idx_bests, elapsed 98 | 99 | 100 | def filterwise_normalization(weights: dict): 101 | theta = Frobenius(weights) 102 | 103 | res = dict() 104 | for name, value in weights.items(): 105 | d = Frobenius({name: value}) 106 | d += 1e-10 # Ref. https://github.com/tomgoldstein/loss-landscape/blob/master/net_plotter.py#L111 107 | res[name] = value.div(d).mul(theta) 108 | 109 | return res 110 | 111 | 112 | def Frobenius(weights: dict, base_weights: dict = None): 113 | total = 0. 114 | for name, value in weights.items(): 115 | if base_weights is not None: 116 | elem = value.sub(base_weights[name]) 117 | else: 118 | elem = value.clone().detach() 119 | 120 | elem.mul_(elem) 121 | total += torch.sum(elem).item() 122 | 123 | return math.sqrt(total) 124 | 125 | 126 | def by_Frobenius( 127 | proposals: list, count: int, base_client, FN=False, 128 | return_acc=False, test_client=None, epoch=None, show=False, log=False, 129 | timing=False, optimal_stopping=False): 130 | 131 | if timing: 132 | start = time.time() 133 | 134 | n = len(proposals) 135 | assert(n >= count) 136 | 137 | bests, idx_bests, elapsed = [], [], None 138 | distances = [] 139 | 140 | if optimal_stopping and (n >= 3): 141 | """optimal stopping mode 142 | # TODO: Her own weights' Frobenius Norm is 0 143 | # so they are always best. 144 | """ 145 | passing_number = int(n / math.e) 146 | cutline = 0. 
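# secretary-problem rule: let the first n/e candidates pass, then stop at the first one that beats every candidate seen so far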
147 | 148 | idx_suffled, suffled = suffle(proposals) 149 | cached = None 150 | 151 | for i, proposal in enumerate(suffled): # enumerate(tqdm(proposals)): 152 | if FN: 153 | if cached is None: 154 | cached = filterwise_normalization(base_client.get_weights()) 155 | 156 | res = -1 * Frobenius( 157 | filterwise_normalization(proposal.get_weights()), 158 | base_weights=cached) 159 | else: 160 | res = -1 * Frobenius( 161 | proposal.get_weights(), base_weights=base_client.get_weights()) 162 | 163 | if i == 0: 164 | cutline = res 165 | 166 | distances.append(res) 167 | idx_bests.append(idx_suffled[i]) 168 | 169 | if cutline < res: 170 | cutline = res 171 | if i >= passing_number and (i + 1 >= count): 172 | break 173 | else: 174 | """normal mode 175 | # TBA 176 | """ 177 | cached = None 178 | 179 | for i, proposal in enumerate(proposals): 180 | if FN: 181 | if cached is None: 182 | cached = filterwise_normalization(base_client.get_weights()) 183 | 184 | res = -1 * Frobenius( 185 | filterwise_normalization(proposal.get_weights()), 186 | base_weights=cached) 187 | else: 188 | res = -1 * Frobenius( 189 | proposal.get_weights(), base_weights=base_client.get_weights()) 190 | 191 | distances.append(res) 192 | idx_bests.append(i) 193 | 194 | # print(distances) 195 | bests = distances[:] 196 | bests, idx_bests = (list(t)[:count] for t in zip(*sorted(zip(bests, idx_bests), reverse=True))) 197 | bests = [-1 * b for b in bests] 198 | 199 | if return_acc and (test_client is not None) and (epoch is not None): 200 | accs = [] 201 | for idx_best in idx_bests: 202 | test_client.set_weights(proposals[idx_best].get_weights()) 203 | res = 100. - test_client.test(epoch, show=show, log=log) 204 | accs.append(res) 205 | bests = accs[:] 206 | 207 | # elapsed time 208 | if timing: 209 | elapsed = time.time() - start 210 | # print(elapsed) 211 | 212 | return bests, idx_bests, elapsed 213 | 214 | 215 | def by_GNN(): 216 | pass # TODO 217 | 218 | 219 | def by_population(): 220 | pass # TODO: NAS, ES 221 | 222 | 223 | if __name__ == "__main__": 224 | # python src/reputation.py --nNodes=40 --nPick=10 --nEpochs=10 --load 225 | # python src/reputation.py --nNodes=10 --nPick=2 --load 226 | 227 | import argparse 228 | 229 | import torchvision.datasets as dset 230 | import torchvision.transforms as transforms 231 | from torch.utils.data import random_split 232 | 233 | from net import DenseNet 234 | from client import Client 235 | 236 | """argparse""" 237 | parser = argparse.ArgumentParser() 238 | parser.add_argument('--nNodes', type=int, default=100) 239 | parser.add_argument('--nPick', type=int, default=5) 240 | parser.add_argument('--batchSz', type=int, default=128) 241 | parser.add_argument('--nEpochs', type=int, default=300) 242 | parser.add_argument('--nLoops', type=int, default=100) 243 | parser.add_argument('--no-cuda', action='store_true') 244 | parser.add_argument('--load', action='store_true') 245 | parser.add_argument('--path') 246 | parser.add_argument('--seed', type=int, default=950327) 247 | parser.add_argument('--opt', type=str, default='sgd', 248 | choices=('sgd', 'adam', 'rmsprop')) 249 | args = parser.parse_args() 250 | 251 | args.cuda = not args.no_cuda and torch.cuda.is_available() 252 | 253 | # set seed 254 | torch.manual_seed(args.seed) 255 | if args.cuda: 256 | torch.cuda.manual_seed(args.seed) 257 | 258 | """Data 259 | # TODO: get Mean and Std per client 260 | # Ref: https://github.com/bamos/densenet.pytorch 261 | """ 262 | normMean = [0.49139968, 0.48215827, 0.44653124] 263 | normStd = [0.24703233, 
0.24348505, 0.26158768] 264 | normTransform = transforms.Normalize(normMean, normStd) 265 | 266 | trainTransform = transforms.Compose([ 267 | transforms.RandomCrop(32, padding=4), 268 | transforms.RandomHorizontalFlip(), 269 | transforms.ToTensor(), 270 | normTransform 271 | ]) 272 | testTransform = transforms.Compose([ 273 | transforms.ToTensor(), 274 | normTransform 275 | ]) 276 | 277 | trainset = dset.CIFAR10(root='cifar', train=True, download=True, transform=trainTransform) 278 | testset = dset.CIFAR10(root='cifar', train=False, download=True, transform=testTransform) 279 | 280 | # Random split 281 | splited_trainset = random_split(trainset, [int(len(trainset) / args.nNodes) for _ in range(args.nNodes)]) 282 | splited_testset = random_split(testset, [int(len(testset) / args.nNodes) for _ in range(args.nNodes)]) 283 | 284 | """FL 285 | # TBA 286 | """ 287 | def _dense_net(): 288 | return DenseNet(growthRate=12, depth=100, reduction=0.5, bottleneck=True, nClasses=10) 289 | # print('>>> Number of params: {}'.format( 290 | # sum([p.data.nelement() for p in net.parameters()]))) 291 | 292 | tmp_client = Client( # for eval. the others' net / et al. 293 | args=args, 294 | net=_dense_net(), 295 | trainset=None, 296 | testset=None, 297 | log=False, 298 | _id=-1) 299 | 300 | clients = [] 301 | for i in range(args.nNodes): 302 | client = Client( 303 | args=args, 304 | net=_dense_net(), 305 | trainset=splited_trainset[i], 306 | testset=splited_testset[i], 307 | log=True and (not args.load)) 308 | client.set_weights(tmp_client.get_weights()) 309 | clients.append(client) 310 | 311 | if args.load: 312 | for i in range(args.nNodes): 313 | clients[i].load() 314 | else: 315 | for c in range(args.nNodes): 316 | for i in range(1, args.nEpochs + 1): 317 | clients[c].train(epoch=i, show=True) 318 | clients[c].save() 319 | 320 | eta = dict() 321 | 322 | for r in range(1, args.nLoops + 1): 323 | 324 | for c in range(args.nNodes): 325 | print("\n") 326 | print("Round", r, end='\t') 327 | print("Client", c) 328 | 329 | tmp_client.set_dataset(trainset=None, testset=clients[c].testset) 330 | 331 | # by accuracy 332 | bests, idx_bests, elapsed = by_accuracy( 333 | proposals=clients, count=args.nPick, test_client=tmp_client, 334 | epoch=r, show=False, log=False, 335 | timing=True, optimal_stopping=False) 336 | print("Acc\t:", idx_bests, elapsed) 337 | if 'acc' not in eta: 338 | eta['acc'] = [] 339 | eta['acc'].append(elapsed) 340 | 341 | # by accuracy with optimal stopping 342 | bests, idx_bests, elapsed = by_accuracy( 343 | proposals=clients, count=args.nPick, test_client=tmp_client, 344 | epoch=r, show=False, log=False, 345 | timing=True, optimal_stopping=True) 346 | print("Acc(OS)\t:", idx_bests, elapsed) 347 | if 'acc_os' not in eta: 348 | eta['acc_os'] = [] 349 | eta['acc_os'].append(elapsed) 350 | 351 | # by Frobenius L2 norm 352 | bests, idx_bests, elapsed = by_Frobenius( 353 | proposals=clients, count=args.nPick, base_client=clients[c], FN=False, 354 | return_acc=True, test_client=tmp_client, epoch=r, show=False, log=False, 355 | timing=True, optimal_stopping=False) 356 | print("F\t:", idx_bests, elapsed) 357 | if 'Frobenius' not in eta: 358 | eta['Frobenius'] = [] 359 | eta['Frobenius'].append(elapsed) 360 | 361 | # by Frobenius L2 norm with filter-wised normalization 362 | bests, idx_bests, elapsed = by_Frobenius( 363 | proposals=clients, count=args.nPick, base_client=clients[c], FN=True, 364 | return_acc=True, test_client=tmp_client, epoch=r, show=False, log=False, 365 | timing=True, 
331 |             # by accuracy
332 |             bests, idx_bests, elapsed = by_accuracy(
333 |                 proposals=clients, count=args.nPick, test_client=tmp_client,
334 |                 epoch=r, show=False, log=False,
335 |                 timing=True, optimal_stopping=False)
336 |             print("Acc\t:", idx_bests, elapsed)
337 |             if 'acc' not in eta:
338 |                 eta['acc'] = []
339 |             eta['acc'].append(elapsed)
340 | 
341 |             # by accuracy with optimal stopping
342 |             bests, idx_bests, elapsed = by_accuracy(
343 |                 proposals=clients, count=args.nPick, test_client=tmp_client,
344 |                 epoch=r, show=False, log=False,
345 |                 timing=True, optimal_stopping=True)
346 |             print("Acc(OS)\t:", idx_bests, elapsed)
347 |             if 'acc_os' not in eta:
348 |                 eta['acc_os'] = []
349 |             eta['acc_os'].append(elapsed)
350 | 
351 |             # by Frobenius L2 norm
352 |             bests, idx_bests, elapsed = by_Frobenius(
353 |                 proposals=clients, count=args.nPick, base_client=clients[c], FN=False,
354 |                 return_acc=True, test_client=tmp_client, epoch=r, show=False, log=False,
355 |                 timing=True, optimal_stopping=False)
356 |             print("F\t:", idx_bests, elapsed)
357 |             if 'Frobenius' not in eta:
358 |                 eta['Frobenius'] = []
359 |             eta['Frobenius'].append(elapsed)
360 | 
361 |             # by Frobenius L2 norm with filter-wise normalization
362 |             bests, idx_bests, elapsed = by_Frobenius(
363 |                 proposals=clients, count=args.nPick, base_client=clients[c], FN=True,
364 |                 return_acc=True, test_client=tmp_client, epoch=r, show=False, log=False,
365 |                 timing=True, optimal_stopping=False)
366 |             print("F(N)\t:", idx_bests, elapsed)
367 |             if 'Frobenius_FN' not in eta:
368 |                 eta['Frobenius_FN'] = []
369 |             eta['Frobenius_FN'].append(elapsed)
370 | 
371 |             # by Frobenius L2 norm with filter-wise normalization and optimal stopping
372 |             bests, idx_bests, elapsed = by_Frobenius(
373 |                 proposals=clients, count=args.nPick, base_client=clients[c], FN=True,
374 |                 return_acc=True, test_client=tmp_client, epoch=r, show=False, log=False,
375 |                 timing=True, optimal_stopping=True)
376 |             print("F(N&OS)\t:", idx_bests, elapsed)
377 |             if 'Frobenius_FN_os' not in eta:
378 |                 eta['Frobenius_FN_os'] = []
379 |             eta['Frobenius_FN_os'].append(elapsed)
380 | 
381 |     # Avg
382 |     for key, value in eta.items():
383 |         eta[key] = sum(value) / len(value)
384 | 
385 |     from pprint import pprint
386 |     pprint(eta)
387 | 
--------------------------------------------------------------------------------
/src/weights.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | 
4 | import json
5 | import hashlib
6 | 
7 | import math
8 | from numbers import Number
9 | 
10 | 
11 | class Weights():
12 |     def __init__(self,
13 |                  params):
14 | 
15 |         if isinstance(params, Weights):
16 |             params = params.to_dict()
17 |         elif isinstance(params, dict):
18 |             pass
19 |         else:
20 |             try:
21 |                 params = dict(params)
22 |             except Exception:
23 |                 raise ValueError("params must be a `Weights`, a `dict`, or a generator returned by named_parameters(), but got {}.".format(type(params)))
24 | 
25 |         self.params = params
26 | 
27 |     def to_dict(self):
28 |         return self.params
29 | 
30 |     """container
31 |     # TBA
32 |     """
33 | 
34 |     def keys(self):
35 |         return self.params.keys()
36 | 
37 |     def values(self):
38 |         return self.params.values()
39 | 
40 |     def items(self):
41 |         return self.params.items()
42 | 
43 |     def __getitem__(self, key):
44 |         if type(key) != str:
45 |             raise TypeError("key must be a `str` but got {}.".format(type(key)))
46 |         if key not in self.params.keys():
47 |             raise KeyError("key '{}' is not in params.".format(key))
48 |         return self.params[key]
49 | 
50 |     def __setitem__(self, key, value):
51 |         self.params[key] = value
52 | 
53 |     def __delitem__(self, key):
54 |         del self.params[key]
55 | 
56 |     def __iter__(self):
57 |         return self.params.__iter__()
58 | 
59 |     def __contains__(self, key):
60 |         return self.params.__contains__(key)
61 | 
62 |     """arithmetic
63 |     # *_ : in-place version (stores the result back into `params` as a plain dict)
64 |     """
65 | 
66 |     # -x
67 |     def neg(self):
68 |         res = dict()
69 |         for key, value in self.items():
70 |             res[key] = -1 * value.data
71 |         return Weights(res)
72 | 
73 |     def neg_(self):
74 |         self.params = self.neg().to_dict()
75 | 
76 |     def __neg__(self):
77 |         return self.neg()
78 | 
79 |     # x + (y: dict or Weights)
80 |     # or
81 |     # x + (y: Number)
82 |     def add(self, other):
83 |         res = dict()
84 | 
85 |         if isinstance(other, dict) or isinstance(other, Weights):
86 |             for key, value in self.items():
87 |                 if key not in other.keys():
88 |                     raise KeyError("'{}' is not in the argument.".format(key))
89 |                 res[key] = value.add(other[key].data)
90 |         elif isinstance(other, Number):
91 |             s = Variable(torch.Tensor([other]).double())
92 |             for key, value in self.items():
93 |                 res[key] = value.add(s.expand(value.size()))
94 |         else:
95 |             raise TypeError("The argument must be `Weights (dict)` or `Number` but got {}.".format(type(other)))
96 | 
97 |         return Weights(res)
98 | 
99 |     def add_(self, other):
100 |         self.params = self.add(other).to_dict()
101 | 
102 |     def __add__(self, other):
103 |         return self.add(other)
104 | 
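    # A small usage sketch (the variable names are illustrative): each
    # operator returns a new Weights, while the trailing-underscore variants
    # mutate in place, mirroring PyTorch's convention.
    #
    #   w = Weights(net.named_parameters())
    #   shifted = w + 1.0      # broadcasts the scalar over every tensor
    #   doubled = w * 2        # scalar multiplication
    #   w.add_(other)          # in-place, element-wise by matching key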
105 |     # x - y
106 |     def sub(self, other):
107 |         # return self.add(-other)
108 |         res = dict()
109 | 
110 |         if isinstance(other, dict) or isinstance(other, Weights):
111 |             for key, value in self.items():
112 |                 if key not in other.keys():
113 |                     raise KeyError("'{}' is not in the argument.".format(key))
114 |                 res[key] = value.sub(other[key].data)
115 |         elif isinstance(other, Number):
116 |             s = Variable(torch.Tensor([other]).double())
117 |             for key, value in self.items():
118 |                 res[key] = value.sub(s.expand(value.size()))
119 |         else:
120 |             raise TypeError("The argument must be `Weights (dict)` or `Number` but got {}.".format(type(other)))
121 | 
122 |         return Weights(res)
123 | 
124 |     def sub_(self, other):
125 |         self.params = self.sub(other).to_dict()
126 | 
127 |     def __sub__(self, other):
128 |         return self.sub(other)
129 | 
130 |     # x * (y: dict or Weights): Hadamard product
131 |     # or
132 |     # x * (y: Number): scalar multiplication
133 |     def mul(self, other):
134 |         res = dict()
135 | 
136 |         if isinstance(other, dict) or isinstance(other, Weights):
137 |             for key, value in self.items():
138 |                 if key not in other.keys():
139 |                     raise KeyError("'{}' is not in the argument.".format(key))
140 |                 res[key] = value.mul(other[key].data)
141 |         elif isinstance(other, Number):
142 |             s = Variable(torch.Tensor([other]).double())
143 |             for key, value in self.items():
144 |                 res[key] = value.mul(s.expand(value.size()))
145 |         else:
146 |             raise TypeError("The argument must be `Weights (dict)` or `Number` but got {}.".format(type(other)))
147 | 
148 |         return Weights(res)
149 | 
150 |     def mul_(self, other):
151 |         self.params = self.mul(other).to_dict()
152 | 
153 |     def __mul__(self, other):
154 |         return self.mul(other)
155 | 
156 |     # x / (y: dict or Weights): inverse of Hadamard product
157 |     # or
158 |     # x / (y: Number): inverse of scalar multiplication
159 |     def div(self, other):
160 |         res = dict()
161 | 
162 |         if isinstance(other, dict) or isinstance(other, Weights):
163 |             for key, value in self.items():
164 |                 if key not in other.keys():
165 |                     raise KeyError("'{}' is not in the argument.".format(key))
166 |                 res[key] = value.div(other[key].data)
167 |         elif isinstance(other, Number):
168 |             s = Variable(torch.Tensor([other]).double())
169 |             for key, value in self.items():
170 |                 res[key] = value.div(s.expand(value.size()))
171 |         else:
172 |             raise TypeError("The argument must be `Weights (dict)` or `Number` but got {}.".format(type(other)))
173 | 
174 |         return Weights(res)
175 | 
176 |     def div_(self, other):
177 |         self.params = self.div(other).to_dict()
178 | 
179 |     def __truediv__(self, other):
180 |         return self.div(other)
181 | 
182 |     # x // (y: dict or Weights): element-wise floor_divide
183 |     # or
184 |     # x // (y: Number): floor_divide with scalar
185 |     def floor_divide(self, other):
186 |         res = dict()
187 | 
188 |         if isinstance(other, dict) or isinstance(other, Weights):
189 |             for key, value in self.items():
190 |                 if key not in other.keys():
191 |                     raise KeyError("'{}' is not in the argument.".format(key))
192 |                 res[key] = value.floor_divide(other[key].data)
193 |         elif isinstance(other, Number):
194 |             s = Variable(torch.Tensor([other]).double())
195 |             for key, value in self.items():
196 |                 res[key] = value.floor_divide(s.expand(value.size()))
197 |         else:
198 |             raise TypeError("The argument must be `Weights (dict)` or `Number` but got {}.".format(type(other)))
199 | 
200 |         return Weights(res)
201 | 
202 |     def floor_divide_(self, other):
203 |         self.params = self.floor_divide(other).to_dict()
204 | 
205 |     def __floordiv__(self, other):
206 |         return self.floor_divide(other)
207 | 
208 |     # x % (y: dict or Weights): element-wise mod operator
209 |     # or
210 |     # x % (y: Number): mod operator with scalar
211 |     def remainder(self, other):
212 |         res = dict()
213 | 
214 |         if isinstance(other, dict) or isinstance(other, Weights):
215 |             for key, value in self.items():
216 |                 if key not in other.keys():
217 |                     raise KeyError("'{}' is not in the argument.".format(key))
218 |                 res[key] = value.remainder(other[key].data)
219 |         elif isinstance(other, Number):
220 |             s = Variable(torch.Tensor([other]).double())
221 |             for key, value in self.items():
222 |                 res[key] = value.remainder(s.expand(value.size()))
223 |         else:
224 |             raise TypeError("The argument must be `Weights (dict)` or `Number` but got {}.".format(type(other)))
225 | 
226 |         return Weights(res)
227 | 
228 |     def remainder_(self, other):
229 |         self.params = self.remainder(other).to_dict()
230 | 
231 |     def __mod__(self, other):
232 |         return self.remainder(other)
233 | 
234 |     # divmod()
235 |     def __divmod__(self, other):
236 |         return (self.div(other), self.remainder(other))
237 | 
238 |     # x ** (y: dict or Weights): element-wise
239 |     # or
240 |     # x ** (y: Number): power of scalar
241 |     def pow(self, other):
242 |         res = dict()
243 | 
244 |         if isinstance(other, dict) or isinstance(other, Weights):
245 |             for key, value in self.items():
246 |                 if key not in other.keys():
247 |                     raise KeyError("'{}' is not in the argument.".format(key))
248 |                 res[key] = value.pow(other[key].data)
249 |         elif isinstance(other, Number):
250 |             s = Variable(torch.Tensor([other]).double())
251 |             for key, value in self.items():
252 |                 res[key] = value.pow(s.expand(value.size()))
253 |         else:
254 |             raise TypeError("The argument must be `Weights (dict)` or `Number` but got {}.".format(type(other)))
255 | 
256 |         return Weights(res)
257 | 
258 |     def pow_(self, other):
259 |         self.params = self.pow(other).to_dict()
260 | 
261 |     def __pow__(self, other):
262 |         return self.pow(other)
263 | 
264 |     # round()
265 |     def round(self):
266 |         res = dict()
267 |         for key, value in self.items():
268 |             res[key] = value.round()
269 |         return Weights(res)
270 | 
271 |     def round_(self):
272 |         self.params = self.round().to_dict()
273 | 
274 |     def __round__(self):
275 |         return self.round()
276 | 
277 |     """cmp
278 |     # Uses the Frobenius L2 norm of the whole weight set, so `==` means "equal norm", not element-wise equality.
279 |     """
280 | 
281 |     # x < y
282 |     def __lt__(self, other):
283 |         return Frobenius(self) < Frobenius(other)
284 | 
285 |     # x <= y
286 |     def __le__(self, other):
287 |         return Frobenius(self) <= Frobenius(other)
288 | 
289 |     # x > y
290 |     def __gt__(self, other):
291 |         return Frobenius(self) > Frobenius(other)
292 | 
293 |     # x >= y
294 |     def __ge__(self, other):
295 |         return Frobenius(self) >= Frobenius(other)
296 | 
297 |     # x == y
298 |     def __eq__(self, other):
299 |         return Frobenius(self) == Frobenius(other)
300 | 
301 |     # x != y
302 |     def __ne__(self, other):
303 |         return Frobenius(self) != Frobenius(other)
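    # Since ordering is norm-based, two different weight sets can compare
    # equal. An illustrative sketch:
    #
    #   w1 = Weights(net.named_parameters())
    #   w2 = -w1
    #   w1 == w2   # True: identical Frobenius norms
    #   w1 != w2   # False, although every tensor differs element-wise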
304 | 
305 |     """type
306 |     # TBA
307 |     """
308 | 
309 |     def __str__(self):
310 |         res = dict()
311 |         for key, value in self.items():
312 |             res[key] = value.tolist()
313 | 
314 |         return json.dumps(res)
315 | 
316 |     def hash(self):
317 |         return hashlib.sha256(str(self).encode()).hexdigest()
318 | 
319 |     """copy
320 |     # TBA
321 |     """
322 | 
323 |     # copy
324 |     def _copy(self, other):
325 |         if not isinstance(other, Weights):
326 |             raise TypeError("The argument must be `Weights` but got {}.".format(type(other)))
327 |         return {key: value.clone() for key, value in other.items()}  # per-tensor copy, not a shared reference
328 | 
329 |     def copy_(self, other):
330 |         self.params = self._copy(other)
331 | 
332 |     """tensors
333 |     # TBA
334 |     """
335 | 
336 |     # zeros
337 |     def _zeros(self):
338 |         res = dict()
339 |         for key, value in self.items():
340 |             res[key] = torch.zeros_like(value)
341 |         return res
342 | 
343 |     def zeros(self):
344 |         return Weights(self._zeros())
345 | 
346 |     def zeros_(self):
347 |         self.params = self._zeros()
348 | 
349 |     # ones
350 |     def _ones(self):
351 |         res = dict()
352 |         for key, value in self.items():
353 |             res[key] = torch.ones_like(value)
354 |         return res
355 | 
356 |     def ones(self):
357 |         return Weights(self._ones())
358 | 
359 |     def ones_(self):
360 |         self.params = self._ones()
361 | 
362 |     # fill and full
363 |     def _pack(self, value):
364 |         res = dict()
365 |         for key, elem in self.items():
366 |             res[key] = torch.empty_like(elem).fill_(value)
367 |         return res
368 | 
369 |     def fill_(self, value):
370 |         self.params = self._pack(value)
371 | 
372 |     def full(self, value):
373 |         return Weights(self._pack(value))
374 | 
375 |     # empty
376 |     def _empty(self):
377 |         res = dict()
378 |         for key, value in self.items():
379 |             res[key] = torch.empty_like(value)
380 |         return res
381 | 
382 |     def empty_(self):
383 |         self.params = self._empty()
384 | 
385 |     def empty(self):
386 |         return Weights(self._empty())
387 | 
388 |     """random
389 |     # TBA
390 |     """
391 | 
392 |     def _rand(self):
393 |         res = dict()
394 |         for key, value in self.items():
395 |             res[key] = torch.rand_like(value)
396 |         return res
397 | 
398 |     def rand_(self):
399 |         self.params = self._rand()
400 | 
401 |     def rand(self):
402 |         return Weights(self._rand())
403 | 
404 |     def _randn(self):
405 |         res = dict()
406 |         for key, value in self.items():
407 |             res[key] = torch.randn_like(value)
408 |         return res
409 | 
410 |     def randn_(self):
411 |         self.params = self._randn()
412 | 
413 |     def randn(self):
414 |         return Weights(self._randn())
415 | 
416 |     def _randint(self, high):
417 |         res = dict()
418 |         for key, value in self.items():
419 |             res[key] = torch.randint_like(value, high)
420 |         return res
421 | 
422 |     def randint_(self, high):
423 |         self.params = self._randint(high)
424 | 
425 |     def randint(self, high):
426 |         return Weights(self._randint(high))
427 | 
428 |     """TODO
429 |     # type
430 |     # cat
431 |     # split
432 |     """
433 | 
434 | 
435 | """distance
436 | # TBA
437 | """
438 | 
439 | 
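# `FilterNorm` below rescales each parameter tensor as a whole to carry the
# norm of the full weight set. The loss-landscape code it references instead
# normalizes per *filter* (slices along dim 0). A hedged sketch of that
# finer-grained variant, assuming dim 0 indexes filters and omitting the
# reference's extra rescaling by the base weights' filter norms (1-D tensors
# such as biases degenerate to sign-like values here):
def per_filter_unit_norm(weights, eps=1e-10):
    res = dict()
    for key, value in weights.items():
        flat = value.data.view(value.size(0), -1)        # (filters, rest)
        norms = flat.norm(dim=1, keepdim=True).add(eps)  # per-filter norms
        res[key] = (flat / norms).view_as(value)         # unit-norm filters
    return Weights(res) if isinstance(weights, Weights) else res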
440 | def FilterNorm(weights):
441 |     # "Filter-wise" normalization; note it rescales each parameter tensor by (total norm / its own norm), i.e., per tensor rather than per individual filter.
442 | 
443 |     theta = Frobenius(weights)
444 | 
445 |     res = dict()
446 |     for key, value in weights.items():
447 |         d = Frobenius(Weights({key: value}))
448 |         d += 1e-10  # Ref. https://github.com/tomgoldstein/loss-landscape/blob/master/net_plotter.py#L111
449 |         res[key] = value.div(d).mul(theta)
450 | 
451 |     if isinstance(weights, Weights):
452 |         return Weights(res)
453 |     else:
454 |         return res
455 | 
456 | 
457 | def Frobenius(weights, base_weights=None):
458 |     # Frobenius norm of (weights - base_weights); defaults to the distance from zero.
459 |     base_weights = base_weights or weights.zeros()
460 |     square = ((weights - base_weights) ** 2)
461 | 
462 |     total = 0.
463 |     for key, value in square.items():
464 |         total += torch.sum(value).item()
465 | 
466 |     return math.sqrt(total)
467 | 
468 | 
469 | if __name__ == "__main__":
470 |     from net import DenseNet
471 | 
472 |     net1 = DenseNet(
473 |         growthRate=12,
474 |         depth=100,
475 |         reduction=0.5,
476 |         bottleneck=True,
477 |         nClasses=10)
478 |     w1 = Weights(net1.named_parameters())
479 |     w2 = Weights(net1.named_parameters()) + 2
480 | 
481 |     print(w2 <= w2)
482 |     print(Frobenius(w1))
483 |     print(Frobenius(FilterNorm(w1)))
484 |     print(Frobenius(w1, w2))
485 |     print(Frobenius(w1, w1))
486 | 
--------------------------------------------------------------------------------
/watch.sh:
--------------------------------------------------------------------------------
1 | watch -d -n 2 nvidia-smi
2 | 
--------------------------------------------------------------------------------