├── README.md ├── augmentations ├── .ipynb_checkpoints │ └── __init__-checkpoint.py ├── __init__.py └── __pycache__ │ └── __init__.cpython-37.pyc ├── configs ├── .ipynb_checkpoints │ └── __init__-checkpoint.py ├── __init__.py └── __pycache__ │ └── __init__.cpython-37.pyc ├── datasets ├── .ipynb_checkpoints │ └── __init__-checkpoint.py ├── __init__.py └── __pycache__ │ └── __init__.cpython-37.pyc ├── fedavg-fixmatch-main.py ├── fedavg-uda-main.py ├── fedcon-main-sec45.py ├── fedcon-main.py ├── fedmatch-main.py ├── fedprox-fixmatch-main.py ├── fedprox-uda-main.py ├── models ├── .ipynb_checkpoints │ ├── __init__-checkpoint.py │ ├── backbones-checkpoint.py │ └── byol-checkpoint.py ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── backbones.cpython-37.pyc │ ├── byol.cpython-37.pyc │ ├── simclr.cpython-37.pyc │ └── simsiam.cpython-37.pyc ├── backbones.py ├── byol.py ├── simclr.py ├── simsiam.py └── swav.py ├── optimizers ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── larc.cpython-37.pyc │ ├── lars.cpython-37.pyc │ ├── lars_simclr.cpython-37.pyc │ └── lr_scheduler.cpython-37.pyc ├── larc.py ├── lars.py ├── lars_simclr.py └── lr_scheduler.py └── tools ├── .ipynb_checkpoints └── __init__-checkpoint.py ├── __init__.py ├── __pycache__ ├── __init__.cpython-37.pyc ├── accuracy.cpython-37.pyc ├── average_meter.cpython-37.pyc ├── knn_monitor.cpython-37.pyc ├── plot_logger.cpython-37.pyc └── ramps.cpython-37.pyc ├── accuracy.py ├── average_meter.py ├── knn_monitor.py ├── plot_logger.py └── ramps.py /README.md: -------------------------------------------------------------------------------- 1 | # FedCon: A Contrastive Framework for Federated Semi-Supervised Learning 2 | A PyTorch implementation for the paper **FedCon: A Contrastive Framework for Federated Semi-Supervised Learning**. 3 | 4 | Our experiments include six baselines (FedAvg-FixMatch, FedProx-FixMatch, FedAvg-UDA, FedProx-UDA, FedMatch, and SSFL) and our proposed FedCon framework. 5 | 6 | We conduct our experiments on the MNIST, CIFAR-10, and SVHN datasets. 7 | 8 | You should place your data in `./fedcon-ecmlpkdd2021/data/mnist` (MNIST, for example). 9 | 10 | 11 | 12 | 13 | 14 | ## Getting Started 15 | 16 | python>=3.6 17 | pytorch>=0.4 18 | 19 | To install PyTorch, see installation instructions on the [PyTorch website](https://pytorch.org/get-started/locally). 20 | 21 | 22 | 23 | ## Some Examples 24 | 25 | We provide some examples here.
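If the datasets are not already on disk, they can be fetched with `torchvision` before training. A minimal sketch (the `./data/...` paths are assumptions mirroring the layout above; this helper snippet is not a script shipped with the repo):

```python
# Pre-download the three benchmark datasets into ./data (paths are assumptions
# matching the README layout; point --data_dir at the matching folder below).
import torchvision

torchvision.datasets.MNIST('./data/mnist', train=True, download=True)
torchvision.datasets.CIFAR10('./data/cifar', train=True, download=True)
torchvision.datasets.SVHN('./data/svhn', split='train', download=True)
```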
26 | 27 | 28 | 29 | #### MNIST IID 30 | 31 | > python fedcon-main.py --data_dir ../data/mnist --backbone Mnist --dataset mnist --batch_size 10 --num_epochs 200 --label_rate 0.01 --iid iid 32 | 33 | #### MNIST IID & gamma (label_rate)=0.1 34 | 35 | > python fedcon-main.py --data_dir ../data/mnist --backbone Mnist --dataset mnist --batch_size 10 --num_epochs 200 --label_rate 0.1 --iid iid 36 | 37 | #### MNIST non-IID 38 | 39 | > python fedcon-main.py --data_dir ../data/mnist --backbone Mnist --dataset mnist --batch_size 10 --num_epochs 200 --label_rate 0.01 --iid noniid 40 | 41 | #### CIFAR-10 IID 42 | 43 | > python fedcon-main.py --data_dir ../data/cifar --backbone Cifar --dataset cifar10 --batch_size 10 --num_epochs 200 --label_rate 0.01 --iid iid 44 | 45 | #### SVHN IID 46 | 47 | > python fedcon-main.py --data_dir ../data/svhn --backbone Svhn --dataset svhn --batch_size 10 --num_epochs 150 --label_rate 0.01 --iid iid 48 | 49 | #### Citation 50 | 51 | You can cite our work with: 52 | ``` 53 | @article{long2021fedcon, 54 | title={FedCon: A Contrastive Framework for Federated Semi-Supervised Learning}, 55 | author={Long, Zewei and Wang, Jiaqi and Wang, Yaqing and Xiao, Houping and Ma, Fenglong}, 56 | journal={arXiv preprint arXiv:2109.04533}, 57 | year={2021} 58 | } 59 | ``` 60 | 61 | 62 | -------------------------------------------------------------------------------- /augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from torchvision import transforms 4 | from PIL import Image, ImageOps 5 | 6 | 7 | imagenet_norm = [[0.485, 0.456, 0.406],[0.229, 0.224, 0.225]] 8 | 9 | #!/usr/bin/env python 10 | # -*- coding: utf-8 -*- 11 | # Python version: 3.6 12 | 13 | import matplotlib 14 | matplotlib.use('Agg') 15 | import matplotlib.pyplot as plt 16 | import copy 17 | from torchvision import datasets, transforms 18 | import torch 19 | 20 | import logging 21 | import random 22 | 23 | import numpy as np 24 | import PIL 25 | import PIL.ImageOps 26 | import PIL.ImageEnhance 27 | import PIL.ImageDraw 28 | from PIL import Image 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | PARAMETER_MAX = 10 33 | 34 | 35 | def AutoContrast(img, **kwarg): 36 | return PIL.ImageOps.autocontrast(img) 37 | 38 | 39 | def Brightness(img, v, max_v, bias=0): 40 | v = _float_parameter(v, max_v) + bias 41 | return PIL.ImageEnhance.Brightness(img).enhance(v) 42 | 43 | 44 | def Color(img, v, max_v, bias=0): 45 | v = _float_parameter(v, max_v) + bias 46 | return PIL.ImageEnhance.Color(img).enhance(v) 47 | 48 | 49 | def Contrast(img, v, max_v, bias=0): 50 | v = _float_parameter(v, max_v) + bias 51 | return PIL.ImageEnhance.Contrast(img).enhance(v) 52 | 53 | 54 | def Cutout(img, v, max_v, bias=0): 55 | if v == 0: 56 | return img 57 | v = _float_parameter(v, max_v) + bias 58 | v = int(v * min(img.size)) 59 | return CutoutAbs(img, v) 60 | 61 | 62 | def CutoutAbs(img, v, **kwarg): 63 | w, h = img.size 64 | x0 = np.random.uniform(0, w) 65 | y0 = np.random.uniform(0, h) 66 | x0 = int(max(0, x0 - v / 2.)) 67 | y0 = int(max(0, y0 - v / 2.)) 68 | x1 = int(min(w, x0 + v)) 69 | y1 = int(min(h, y0 + v)) 70 | xy = (x0, y0, x1, y1) 71 | # gray 72 | color = (127, 127, 127) 73 | img = img.copy() 74 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 75 | return img 76 | 77 | 78 | def Equalize(img, **kwarg): 79 | return PIL.ImageOps.equalize(img) 80 | 81 | 82 | def Identity(img, **kwarg): 83 | return img 84 | 85 | 86 | def Invert(img, **kwarg): 87 | return PIL.ImageOps.invert(img) 88
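# Editor's note on the magnitude convention shared by the ops above and below
# (an added reader aid, not part of the original FixMatch-derived code): each op
# receives an integer level v in [1, PARAMETER_MAX] (PARAMETER_MAX = 10) and
# rescales it to an op-specific range via _float_parameter / _int_parameter,
# defined further down, then adds the per-op bias. For example, Brightness with
# (max_v=0.9, bias=0.05) from fixmatch_augment_pool() maps v=10 to
# 0.9 * 10 / 10 + 0.05 = 0.95.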
| 89 | 90 | def Posterize(img, v, max_v, bias=0): 91 | v = _int_parameter(v, max_v) + bias 92 | return PIL.ImageOps.posterize(img, v) 93 | 94 | 95 | def Rotate(img, v, max_v, bias=0): 96 | v = _int_parameter(v, max_v) + bias 97 | if random.random() < 0.5: 98 | v = -v 99 | return img.rotate(v) 100 | 101 | 102 | def Sharpness(img, v, max_v, bias=0): 103 | v = _float_parameter(v, max_v) + bias 104 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 105 | 106 | 107 | def ShearX(img, v, max_v, bias=0): 108 | v = _float_parameter(v, max_v) + bias 109 | if random.random() < 0.5: 110 | v = -v 111 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 112 | 113 | 114 | def ShearY(img, v, max_v, bias=0): 115 | v = _float_parameter(v, max_v) + bias 116 | if random.random() < 0.5: 117 | v = -v 118 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 119 | 120 | 121 | def Solarize(img, v, max_v, bias=0): 122 | v = _int_parameter(v, max_v) + bias 123 | return PIL.ImageOps.solarize(img, 256 - v) 124 | 125 | 126 | def SolarizeAdd(img, v, max_v, bias=0, threshold=128): 127 | v = _int_parameter(v, max_v) + bias 128 | if random.random() < 0.5: 129 | v = -v 130 | img_np = np.array(img).astype(int)  # plain int: the np.int alias was removed in NumPy 1.24+ 131 | img_np = img_np + v 132 | img_np = np.clip(img_np, 0, 255) 133 | img_np = img_np.astype(np.uint8) 134 | img = Image.fromarray(img_np) 135 | return PIL.ImageOps.solarize(img, threshold) 136 | 137 | 138 | def TranslateX(img, v, max_v, bias=0): 139 | v = _float_parameter(v, max_v) + bias 140 | if random.random() < 0.5: 141 | v = -v 142 | v = int(v * img.size[0]) 143 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 144 | 145 | 146 | def TranslateY(img, v, max_v, bias=0): 147 | v = _float_parameter(v, max_v) + bias 148 | if random.random() < 0.5: 149 | v = -v 150 | v = int(v * img.size[1]) 151 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 152 | 153 | 154 | def _float_parameter(v, max_v): 155 | return float(v) * max_v / PARAMETER_MAX 156 | 157 | 158 | def _int_parameter(v, max_v): 159 | return int(v * max_v / PARAMETER_MAX) 160 | 161 | 162 | def fixmatch_augment_pool(): 163 | # FixMatch paper 164 | augs = [(AutoContrast, None, None), 165 | (Brightness, 0.9, 0.05), 166 | (Color, 0.9, 0.05), 167 | (Contrast, 0.9, 0.05), 168 | (Equalize, None, None), 169 | (Identity, None, None), 170 | (Posterize, 4, 4), 171 | (Rotate, 30, 0), 172 | (Sharpness, 0.9, 0.05), 173 | (ShearX, 0.3, 0), 174 | (ShearY, 0.3, 0), 175 | (Solarize, 256, 0), 176 | (TranslateX, 0.3, 0), 177 | (TranslateY, 0.3, 0)] 178 | return augs 179 | 180 | 181 | def my_augment_pool(): 182 | # Test 183 | augs = [(AutoContrast, None, None), 184 | (Brightness, 1.8, 0.1), 185 | (Color, 1.8, 0.1), 186 | (Contrast, 1.8, 0.1), 187 | (Cutout, 0.2, 0), 188 | (Equalize, None, None), 189 | (Invert, None, None), 190 | (Posterize, 4, 4), 191 | (Rotate, 30, 0), 192 | (Sharpness, 1.8, 0.1), 193 | (ShearX, 0.3, 0), 194 | (ShearY, 0.3, 0), 195 | (Solarize, 256, 0), 196 | (SolarizeAdd, 110, 0), 197 | (TranslateX, 0.45, 0), 198 | (TranslateY, 0.45, 0)] 199 | return augs 200 | 201 | 202 | class RandAugmentPC(object): 203 | def __init__(self, n, m): 204 | assert n >= 1 205 | assert 1 <= m <= 10 206 | self.n = n 207 | self.m = m 208 | self.augment_pool = my_augment_pool() 209 | 210 | def __call__(self, img): 211 | ops = random.choices(self.augment_pool, k=self.n) 212 | for op, max_v, bias in ops: 213 | prob = np.random.uniform(0.2, 0.8) 214 | if random.random() + prob >= 1: 215 | img = op(img,
v=self.m, max_v=max_v, bias=bias) 216 | img = CutoutAbs(img, 16) 217 | return img 218 | 219 | 220 | class RandAugmentMC(object): 221 | def __init__(self, n, m): 222 | assert n >= 1 223 | assert 1 <= m <= 10 224 | self.n = n 225 | self.m = m 226 | self.augment_pool = fixmatch_augment_pool() 227 | 228 | def __call__(self, img): 229 | ops = random.choices(self.augment_pool, k=self.n) 230 | for op, max_v, bias in ops: 231 | v = np.random.randint(1, self.m) 232 | if random.random() < 0.5: 233 | img = op(img, v=v, max_v=max_v, bias=bias) 234 | img = CutoutAbs(img, 16) 235 | return img 236 | 237 | 238 | class RandomTranslateWithReflect: 239 | 240 | def __init__(self, max_translation): 241 | self.max_translation = max_translation 242 | 243 | def __call__(self, old_image): 244 | xtranslation, ytranslation = np.random.randint(-self.max_translation, 245 | self.max_translation + 1, 246 | size=2) 247 | xpad, ypad = abs(xtranslation), abs(ytranslation) 248 | xsize, ysize = old_image.size 249 | 250 | flipped_lr = old_image.transpose(Image.FLIP_LEFT_RIGHT) 251 | flipped_tb = old_image.transpose(Image.FLIP_TOP_BOTTOM) 252 | flipped_both = old_image.transpose(Image.ROTATE_180) 253 | 254 | new_image = Image.new("RGB", (xsize + 2 * xpad, ysize + 2 * ypad)) 255 | 256 | new_image.paste(old_image, (xpad, ypad)) 257 | 258 | new_image.paste(flipped_lr, (xpad + xsize - 1, ypad)) 259 | new_image.paste(flipped_lr, (xpad - xsize + 1, ypad)) 260 | 261 | new_image.paste(flipped_tb, (xpad, ypad + ysize - 1)) 262 | new_image.paste(flipped_tb, (xpad, ypad - ysize + 1)) 263 | 264 | new_image.paste(flipped_both, (xpad - xsize + 1, ypad - ysize + 1)) 265 | new_image.paste(flipped_both, (xpad + xsize - 1, ypad - ysize + 1)) 266 | new_image.paste(flipped_both, (xpad - xsize + 1, ypad + ysize - 1)) 267 | new_image.paste(flipped_both, (xpad + xsize - 1, ypad + ysize - 1)) 268 | 269 | new_image = new_image.crop((xpad - xtranslation, 270 | ypad - ytranslation, 271 | xpad + xsize - xtranslation, 272 | ypad + ysize - ytranslation)) 273 | 274 | return new_image 275 | 276 | class Mnist_Transform: # Table 6 277 | def __init__(self): 278 | 279 | 280 | self.trans_mnist1 = transforms.Compose([ 281 | # transforms.RandomHorizontalFlip(), 282 | # transforms.RandomCrop(size=28, 283 | # padding=int(28*0.125), 284 | # padding_mode='reflect'), 285 | transforms.ToTensor(), 286 | transforms.Normalize((0.1307,), (0.3081,)) 287 | ]) 288 | 289 | self.trans_mnist2 = transforms.Compose([ 290 | # transforms.RandomHorizontalFlip(), 291 | # transforms.RandomCrop(size=28, 292 | # padding=int(28*0.125), 293 | # padding_mode='reflect'), 294 | transforms.ToTensor(), 295 | transforms.Normalize((0.1307,), (0.3081,)) 296 | ]) 297 | 298 | def __call__(self, x): 299 | x1 = self.trans_mnist1(x) 300 | x2 = self.trans_mnist2(x) 301 | return x1, x2 302 | 303 | 304 | 305 | 306 | 307 | class Cifar_Transform: # Table 6 308 | def __init__(self): 309 | 310 | 311 | self.trans_cifar1 = transforms.Compose([ 312 | # transforms.RandomHorizontalFlip(), 313 | # transforms.RandomCrop(size=32, 314 | # padding=int(32*0.125), 315 | # padding_mode='reflect'), 316 | # RandAugmentMC(n=15, m=10), 317 | # transforms.ToTensor(), 318 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 319 | 320 | RandomTranslateWithReflect(4), 321 | transforms.RandomHorizontalFlip(), 322 | transforms.ToTensor(), 323 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 324 | ]) 325 | 326 | 327 | self.trans_cifar2 = transforms.Compose([ 328 | # transforms.RandomHorizontalFlip(), 329 | # 
transforms.RandomCrop(size=32, 330 | # padding=int(32*0.125), 331 | # padding_mode='reflect'), 332 | # RandAugmentMC(n=15, m=10), 333 | # transforms.ToTensor(), 334 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 335 | 336 | RandomTranslateWithReflect(4), 337 | transforms.RandomHorizontalFlip(), 338 | transforms.ToTensor(), 339 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 340 | ]) 341 | 342 | 343 | def __call__(self, x): 344 | x1 = self.trans_cifar1(x) 345 | x2 = self.trans_cifar2(x) 346 | return x1, x2 347 | 348 | class Svhn_Transform: # Table 6 349 | def __init__(self): 350 | 351 | 352 | self.trans_svhn1 = transforms.Compose([ 353 | RandomTranslateWithReflect(4), 354 | transforms.RandomHorizontalFlip(), 355 | transforms.ToTensor(), 356 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 357 | ]) 358 | 359 | 360 | self.trans_svhn2 = transforms.Compose([ 361 | RandomTranslateWithReflect(4), 362 | transforms.RandomHorizontalFlip(), 363 | transforms.ToTensor(), 364 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 365 | ]) 366 | 367 | 368 | def __call__(self, x): 369 | x1 = self.trans_svhn1(x) 370 | x2 = self.trans_svhn2(x) 371 | return x1, x2 372 | 373 | 374 | 375 | class Mnist_Transform_t(): 376 | def __init__(self): 377 | 378 | self.trans_mnist = transforms.Compose([ 379 | transforms.ToTensor(), 380 | transforms.Normalize((0.1307,), (0.3081,)) 381 | ]) 382 | 383 | def __call__(self, x): 384 | return self.trans_mnist(x) 385 | 386 | class Cifar_Transform_t(): 387 | def __init__(self): 388 | 389 | self.trans_cifar = transforms.Compose([ 390 | transforms.ToTensor(), 391 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 392 | ]) 393 | 394 | def __call__(self, x): 395 | return self.trans_cifar(x) 396 | 397 | class Svhn_Transform_t(): 398 | def __init__(self): 399 | 400 | self.trans_svhn = transforms.Compose([ 401 | transforms.ToTensor(), 402 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 403 | ]) 404 | 405 | def __call__(self, x): 406 | return self.trans_svhn(x) 407 | 408 | 409 | 410 | 411 | 412 | def get_aug(name, train, train_classifier=True): 413 | 414 | if train==True: 415 | if name == 'mnist': 416 | augmentation = Mnist_Transform() 417 | elif name == 'cifar10': 418 | augmentation = Cifar_Transform() 419 | elif name == 'svhn': 420 | augmentation = Svhn_Transform() 421 | else: 422 | raise NotImplementedError 423 | elif train==False: 424 | if name == 'mnist': 425 | augmentation = Mnist_Transform_t() 426 | elif name == 'cifar10': 427 | augmentation = Cifar_Transform_t() 428 | elif name == 'svhn': 429 | augmentation = Svhn_Transform_t() 430 | else: 431 | raise NotImplementedError 432 | 433 | return augmentation 434 | 435 | class Mnist_Transform_fedmatch: # Table 6 436 | def __init__(self): 437 | 438 | self.trans_mnist1 = transforms.Compose([ 439 | transforms.RandomHorizontalFlip(), 440 | transforms.RandomCrop(size=28, 441 | padding=int(28*0.125), 442 | padding_mode='reflect'), 443 | transforms.RandomGrayscale(p=0.1), 444 | transforms.ColorJitter(brightness=0.4, contrast=0.3, saturation=0.2, hue=0.2), 445 | transforms.ToTensor(), 446 | transforms.Normalize((0.1307,), (0.3081,)) 447 | ]) 448 | 449 | self.trans_mnist2 = transforms.Compose([ 450 | transforms.RandomHorizontalFlip(), 451 | transforms.ToTensor(), 452 | transforms.Normalize((0.1307,), (0.3081,)) 453 | ]) 454 | 455 | 456 | 457 | def __call__(self, x): 458 | x1 = self.trans_mnist1(x) 459 | x2 = self.trans_mnist2(x) 460 | return x1, x2 461 | 462 | class Cifar_Transform_fedmatch: # Table 
6 463 | def __init__(self): 464 | 465 | 466 | self.trans_cifar1 = transforms.Compose([ 467 | transforms.RandomHorizontalFlip(), 468 | transforms.RandomCrop(size=32, 469 | padding=int(32*0.125), 470 | padding_mode='reflect'), 471 | RandAugmentMC(n=15, m=10), 472 | transforms.ToTensor(), 473 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 474 | ]) 475 | 476 | 477 | self.trans_cifar2 = transforms.Compose([ 478 | RandomTranslateWithReflect(4), 479 | transforms.RandomHorizontalFlip(), 480 | transforms.ToTensor(), 481 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 482 | ]) 483 | 484 | 485 | def __call__(self, x): 486 | x1 = self.trans_cifar1(x) 487 | x2 = self.trans_cifar2(x) 488 | return x1, x2 489 | 490 | class Svhn_Transform_fedmatch: # Table 6 491 | def __init__(self): 492 | 493 | 494 | self.trans_svhn1 = transforms.Compose([ 495 | transforms.RandomHorizontalFlip(), 496 | transforms.RandomCrop(size=32, 497 | padding=int(32*0.125), 498 | padding_mode='reflect'), 499 | RandAugmentMC(n=15, m=10), # changed from 1 to 10, 10 500 | transforms.ToTensor(), 501 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 502 | ]) 503 | 504 | 505 | self.trans_svhn2 = transforms.Compose([ 506 | RandomTranslateWithReflect(4), 507 | transforms.RandomHorizontalFlip(), 508 | transforms.ToTensor(), 509 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 510 | ]) 511 | 512 | 513 | def __call__(self, x): 514 | x1 = self.trans_svhn1(x) 515 | x2 = self.trans_svhn2(x) 516 | return x1, x2 517 | 518 | # class Svhn_Transform_fedmatch: # Table 6 519 | # def __init__(self): 520 | 521 | 522 | # self.trans_svhn1 = transforms.Compose([ 523 | # transforms.RandomHorizontalFlip(), 524 | # transforms.RandomCrop(size=32, 525 | # padding=int(32*0.125), 526 | # padding_mode='reflect'), 527 | # RandAugmentMC(n=5, m=5), # changed from 1 to 10, 10 528 | # transforms.ToTensor(), 529 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 530 | # ]) 531 | 532 | 533 | # self.trans_svhn2 = transforms.Compose([ 534 | # transforms.ToTensor(), 535 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 536 | # ]) 537 | 538 | 539 | # def __call__(self, x): 540 | # x1 = self.trans_svhn1(x) 541 | # x2 = self.trans_svhn2(x) 542 | # return x1, x2 543 | 544 | def get_aug_fedmatch(name, train, train_classifier=True): 545 | 546 | if train==True: 547 | if name == 'mnist': 548 | augmentation = Mnist_Transform_fedmatch() 549 | elif name == 'cifar10': 550 | augmentation = Cifar_Transform_fedmatch() 551 | elif name == 'svhn': 552 | augmentation = Svhn_Transform_fedmatch() 553 | else: 554 | raise NotImplementedError 555 | elif train==False: 556 | if name == 'mnist': 557 | augmentation = Mnist_Transform_t() 558 | elif name == 'cifar10': 559 | augmentation = Cifar_Transform_t() 560 | elif name == 'svhn': 561 | augmentation = Svhn_Transform_t() 562 | else: 563 | raise NotImplementedError 564 | 565 | return augmentation 566 | 567 | 568 | class Mnist_Transform_uda: # Table 6 569 | def __init__(self): 570 | 571 | self.trans_mnist1 = transforms.Compose([ 572 | transforms.RandomHorizontalFlip(), 573 | transforms.ToTensor(), 574 | transforms.Normalize((0.1307,), (0.3081,)) 575 | ]) 576 | self.trans_mnist2 = transforms.Compose([ 577 | transforms.RandomHorizontalFlip(), 578 | transforms.RandomCrop(size=28, 579 | padding=int(28*0.125), 580 | padding_mode='reflect'), 581 | transforms.RandomGrayscale(p=0.1), 582 | transforms.ColorJitter(brightness=0.4, contrast=0.3, saturation=0.2, hue=0.2), 583 | transforms.ToTensor(), 584 | transforms.Normalize((0.1307,), (0.3081,))
585 | ]) 586 | 587 | 588 | 589 | def __call__(self, x): 590 | x1 = self.trans_mnist1(x) 591 | x2 = self.trans_mnist2(x) 592 | return x1, x2 593 | 594 | class Cifar_Transform_uda: # Table 6 595 | def __init__(self): 596 | 597 | self.trans_cifar1 = transforms.Compose([ 598 | RandomTranslateWithReflect(4), 599 | transforms.RandomHorizontalFlip(), 600 | transforms.ToTensor(), 601 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 602 | ]) 603 | 604 | self.trans_cifar2 = transforms.Compose([ 605 | transforms.RandomHorizontalFlip(), 606 | transforms.RandomCrop(size=32, 607 | padding=int(32*0.125), 608 | padding_mode='reflect'), 609 | RandAugmentMC(n=15, m=10), 610 | transforms.ToTensor(), 611 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 612 | ]) 613 | 614 | 615 | 616 | def __call__(self, x): 617 | x1 = self.trans_cifar1(x) 618 | x2 = self.trans_cifar2(x) 619 | return x1, x2 620 | 621 | class Svhn_Transform_uda: # Table 6 622 | def __init__(self): 623 | 624 | 625 | self.trans_svhn1 = transforms.Compose([ 626 | RandomTranslateWithReflect(4), 627 | transforms.RandomHorizontalFlip(), 628 | transforms.ToTensor(), 629 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 630 | ]) 631 | 632 | self.trans_svhn2 = transforms.Compose([ 633 | transforms.RandomHorizontalFlip(), 634 | transforms.RandomCrop(size=32, 635 | padding=int(32*0.125), 636 | padding_mode='reflect'), 637 | RandAugmentMC(n=15, m=10), # changed from 1 to 10, 10 638 | transforms.ToTensor(), 639 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 640 | ]) 641 | 642 | 643 | 644 | 645 | def __call__(self, x): 646 | x1 = self.trans_svhn1(x) 647 | x2 = self.trans_svhn2(x) 648 | return x1, x2 649 | 650 | 651 | def get_aug_uda(name, train, train_classifier=True): 652 | 653 | if train==True: 654 | if name == 'mnist': 655 | augmentation = Mnist_Transform_uda() 656 | elif name == 'cifar10': 657 | augmentation = Cifar_Transform_uda() 658 | elif name == 'svhn': 659 | augmentation = Svhn_Transform_uda() 660 | else: 661 | raise NotImplementedError 662 | elif train==False: 663 | if name == 'mnist': 664 | augmentation = Mnist_Transform_t() 665 | elif name == 'cifar10': 666 | augmentation = Cifar_Transform_t() 667 | elif name == 'svhn': 668 | augmentation = Svhn_Transform_t() 669 | else: 670 | raise NotImplementedError 671 | 672 | return augmentation 673 | 674 | 675 | 676 | -------------------------------------------------------------------------------- /augmentations/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-long/fedcon-pytorch/54732ec05c19845734567c186b63e3ddedb10798/augmentations/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /configs/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | 5 | import numpy as np 6 | import torch 7 | import random 8 | 9 | 10 | 11 | def set_deterministic(seed): 12 | # seed by default is None 13 | if seed is not None: 14 | print(f"Deterministic with seed = {seed}") 15 | random.seed(seed) 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed(seed) 19 | torch.backends.cudnn.deterministic = True 20 | torch.backends.cudnn.benchmark = False 21 | else: 22 | print("Non-deterministic") 23 | 24 | def get_args(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--debug',
action='store_true') 27 | # training specific args 28 | parser.add_argument('--dataset', type=str, default='cifar10', help='choose from random, stl10, mnist, cifar10, cifar100, imagenet') 29 | parser.add_argument('--download', action='store_true', help="if can't find dataset, download from web") 30 | parser.add_argument('--image_size', type=int, default=224) 31 | parser.add_argument('--num_workers', type=int, default=4) 32 | parser.add_argument('--data_dir', type=str, default=os.getenv('DATA')) 33 | parser.add_argument('--output_dir', type=str, default='./outputs/') 34 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') 35 | parser.add_argument('--resume', type=str, default=None) 36 | parser.add_argument('--eval_from', type=str, default=None) 37 | 38 | parser.add_argument('--hide_progress', action='store_true') 39 | parser.add_argument('--use_default_hyperparameters', action='store_true') 40 | # model related params 41 | parser.add_argument('--model', type=str, default='byol') 42 | parser.add_argument('--backbone', type=str, default='resnet50') 43 | parser.add_argument('--num_epochs', type=int, default=100, help='This will affect learning rate decay') 44 | parser.add_argument('--stop_at_epoch', type=int, default=None) 45 | parser.add_argument('--batch_size', type=int, default=10) 46 | parser.add_argument('--proj_layers', type=int, default=None, help="number of projector layers. In cifar experiment, this is set to 2") 47 | # optimization params 48 | parser.add_argument('--optimizer', type=str, default='lars_simclr', help='sgd, lars(from lars paper), lars_simclr(used in simclr and byol), larc(used in swav)') 49 | parser.add_argument('--warmup_epochs', type=int, default=10, help='learning rate will be linearly scaled during warm up period') 50 | parser.add_argument('--warmup_lr', type=float, default=0, help='Initial warmup learning rate') 51 | parser.add_argument('--base_lr', type=float, default=0.3) 52 | parser.add_argument('--final_lr', type=float, default=0) 53 | parser.add_argument('--lr', type=float, default=0.01, help="learning rate") 54 | parser.add_argument('--initial-lr', default=0.0, type=float,metavar='LR', help='initial learning rate when using linear rampup') 55 | parser.add_argument('--lr-rampup', default=0, type=int, metavar='EPOCHS',help='length of learning rate rampup in the beginning') 56 | parser.add_argument('--lr-rampdown-epochs', default=None, type=int, metavar='EPOCHS',help='length of learning rate cosine rampdown (>= length of training)') 57 | 58 | parser.add_argument('--momentum', type=float, default=0.9) 59 | parser.add_argument('--weight_decay', type=float, default=1.5e-6) 60 | 61 | parser.add_argument('--eval_after_train', type=str, default=None) 62 | parser.add_argument('--head_tail_accuracy', action='store_true', help='the acc in first epoch will indicate whether collapse or not, the last epoch shows the final accuracy') 63 | 64 | parser.add_argument('--num_users', type=int, default=100, help="number of users: K") 65 | parser.add_argument('--local_ep', type=int, default=1, help="number of local epochs: E") 66 | parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C") 67 | 68 | parser.add_argument('--label_rate', type=float, default=0.1, help="the fraction of labeled data") 69 | parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU") 70 | parser.add_argument('--threshold_pl', default=0.95, type=float,help='pseudo label threshold') 71 | 
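# Editor's note (added comment, not in the original source): phi_g, psi_g,
# comu_rate, and ramp below appear to steer the EMA-model communication
# schedule used by the main scripts: phi_g and psi_g bound the rounds in which
# EMA weights are exchanged (see quantile_rectangle in fedavg-uda-main.py),
# comu_rate sets the communication budget, and ramp selects the schedule shape
# ('linear', 'flat', or 'rectangle').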
parser.add_argument('--phi_g', type=int, default=10, help="tipping point 1") 72 | parser.add_argument('--psi_g', type=int, default=40, help="tipping point 2") 73 | parser.add_argument('--comu_rate',type=float, default=0.5,help="the comu_rate of ema model") 74 | parser.add_argument('--ramp',type=str,default='linear', help="ramp of comu") 75 | parser.add_argument('--ema_decay', default=0.999, type=float, metavar='ALPHA', help='ema variable decay rate (default: 0.999)') 76 | 77 | parser.add_argument('--iid', type=str, default='iid', help='iid') 78 | args = parser.parse_args() 79 | 80 | if args.debug: 81 | args.batch_size = 2 82 | args.stop_at_epoch = 2 83 | args.num_epochs = 3 # train only one epoch 84 | args.num_workers = 0 85 | 86 | assert not None in [args.output_dir, args.data_dir] 87 | os.makedirs(args.output_dir, exist_ok=True) 88 | # assert args.stop_at_epoch <= args.num_epochs 89 | if args.stop_at_epoch is not None: 90 | if args.stop_at_epoch > args.num_epochs: 91 | raise Exception 92 | else: 93 | args.stop_at_epoch = args.num_epochs 94 | 95 | if args.use_default_hyperparameters: 96 | raise NotImplementedError 97 | return args 98 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | 5 | import numpy as np 6 | import torch 7 | import random 8 | 9 | 10 | 11 | def set_deterministic(seed): 12 | # seed by default is None 13 | if seed is not None: 14 | print(f"Deterministic with seed = {seed}") 15 | random.seed(seed) 16 | np.random.seed(seed) 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed(seed) 19 | torch.backends.cudnn.deterministic = True 20 | torch.backends.cudnn.benchmark = False 21 | else: 22 | print("Non-deterministic") 23 | 24 | def get_args(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--debug', action='store_true') 27 | # training specific args 28 | parser.add_argument('--dataset', type=str, default='cifar10', help='choose from random, stl10, mnist, cifar10, cifar100, imagenet') 29 | parser.add_argument('--download', action='store_true', help="if can't find dataset, download from web") 30 | parser.add_argument('--image_size', type=int, default=224) 31 | parser.add_argument('--num_workers', type=int, default=4) 32 | parser.add_argument('--data_dir', type=str, default=os.getenv('DATA')) 33 | parser.add_argument('--output_dir', type=str, default='./outputs/') 34 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') 35 | parser.add_argument('--resume', type=str, default=None) 36 | parser.add_argument('--eval_from', type=str, default=None) 37 | 38 | parser.add_argument('--hide_progress', action='store_true') 39 | parser.add_argument('--use_default_hyperparameters', action='store_true') 40 | # model related params 41 | parser.add_argument('--model', type=str, default='byol') 42 | parser.add_argument('--backbone', type=str, default='resnet50') 43 | parser.add_argument('--num_epochs', type=int, default=100, help='This will affect learning rate decay') 44 | parser.add_argument('--stop_at_epoch', type=int, default=None) 45 | parser.add_argument('--batch_size', type=int, default=10) 46 | parser.add_argument('--proj_layers', type=int, default=None, help="number of projector layers. 
In cifar experiment, this is set to 2") 47 | # optimization params 48 | parser.add_argument('--optimizer', type=str, default='lars_simclr', help='sgd, lars(from lars paper), lars_simclr(used in simclr and byol), larc(used in swav)') 49 | parser.add_argument('--warmup_epochs', type=int, default=10, help='learning rate will be linearly scaled during warm up period') 50 | parser.add_argument('--warmup_lr', type=float, default=0, help='Initial warmup learning rate') 51 | parser.add_argument('--base_lr', type=float, default=0.3) 52 | parser.add_argument('--final_lr', type=float, default=0) 53 | parser.add_argument('--lr', type=float, default=0.01, help="learning rate") 54 | parser.add_argument('--initial-lr', default=0.0, type=float,metavar='LR', help='initial learning rate when using linear rampup') 55 | parser.add_argument('--lr-rampup', default=0, type=int, metavar='EPOCHS',help='length of learning rate rampup in the beginning') 56 | parser.add_argument('--lr-rampdown-epochs', default=None, type=int, metavar='EPOCHS',help='length of learning rate cosine rampdown (>= length of training)') 57 | 58 | parser.add_argument('--momentum', type=float, default=0.9) 59 | parser.add_argument('--weight_decay', type=float, default=1.5e-6) 60 | 61 | parser.add_argument('--eval_after_train', type=str, default=None) 62 | parser.add_argument('--head_tail_accuracy', action='store_true', help='the acc in first epoch will indicate whether collapse or not, the last epoch shows the final accuracy') 63 | 64 | parser.add_argument('--num_users', type=int, default=100, help="number of users: K") 65 | parser.add_argument('--local_ep', type=int, default=1, help="number of local epochs: E") 66 | parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C") 67 | 68 | parser.add_argument('--label_rate', type=float, default=0.1, help="the fraction of labeled data") 69 | parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU") 70 | parser.add_argument('--threshold_pl', default=0.95, type=float,help='pseudo label threshold') 71 | parser.add_argument('--phi_g', type=int, default=10, help="tipping point 1") 72 | parser.add_argument('--psi_g', type=int, default=40, help="tipping point 2") 73 | parser.add_argument('--comu_rate',type=float, default=0.5,help="the comu_rate of ema model") 74 | parser.add_argument('--ramp',type=str,default='linear', help="ramp of comu") 75 | parser.add_argument('--ema_decay', default=0.999, type=float, metavar='ALPHA', help='ema variable decay rate (default: 0.999)') 76 | 77 | parser.add_argument('--iid', type=str, default='iid', help='iid') 78 | args = parser.parse_args() 79 | 80 | if args.debug: 81 | args.batch_size = 2 82 | args.stop_at_epoch = 2 83 | args.num_epochs = 3 # train only one epoch 84 | args.num_workers = 0 85 | 86 | assert not None in [args.output_dir, args.data_dir] 87 | os.makedirs(args.output_dir, exist_ok=True) 88 | # assert args.stop_at_epoch <= args.num_epochs 89 | if args.stop_at_epoch is not None: 90 | if args.stop_at_epoch > args.num_epochs: 91 | raise Exception 92 | else: 93 | args.stop_at_epoch = args.num_epochs 94 | 95 | if args.use_default_hyperparameters: 96 | raise NotImplementedError 97 | return args 98 | -------------------------------------------------------------------------------- /configs/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zewei-long/fedcon-pytorch/54732ec05c19845734567c186b63e3ddedb10798/configs/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | 5 | def get_dataset(dataset, data_dir, transform, train=True, download=False, debug_subset_size=None): 6 | if dataset == 'mnist': 7 | dataset = torchvision.datasets.MNIST(data_dir, train=train, transform=transform, download=download) 8 | elif dataset == 'cifar10': 9 | dataset = torchvision.datasets.CIFAR10(data_dir, train=train, transform=transform, download=download) 10 | elif dataset == 'svhn': 11 | if train == True: 12 | dataset = torchvision.datasets.SVHN(data_dir, split = 'train', transform=transform, download=download) 13 | else: 14 | dataset = torchvision.datasets.SVHN(data_dir, split = 'test', transform=transform, download=download) 15 | else: 16 | raise NotImplementedError 17 | if debug_subset_size is not None: 18 | dataset = torch.utils.data.Subset(dataset, range(0, debug_subset_size)) # take only one batch 19 | 20 | return dataset 21 | 22 | 23 | # python main.py --model simsiam --optimizer sgd --data_dir ./data/cifar --output_dir ./outputs/ --backbone resnet18 --dataset cifar10 --batch_size 32 --num_epochs 2 --weight_decay 0.0005 --base_lr 0.03 --warmup_epochs 10 -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | 5 | def get_dataset(dataset, data_dir, transform, train=True, download=False, debug_subset_size=None): 6 | if dataset == 'mnist': 7 | dataset = torchvision.datasets.MNIST(data_dir, train=train, transform=transform, download=download) 8 | elif dataset == 'cifar10': 9 | dataset = torchvision.datasets.CIFAR10(data_dir, train=train, transform=transform, download=download) 10 | elif dataset == 'svhn': 11 | if train == True: 12 | dataset = torchvision.datasets.SVHN(data_dir, split = 'train', transform=transform, download=download) 13 | else: 14 | dataset = torchvision.datasets.SVHN(data_dir, split = 'test', transform=transform, download=download) 15 | else: 16 | raise NotImplementedError 17 | if debug_subset_size is not None: 18 | dataset = torch.utils.data.Subset(dataset, range(0, debug_subset_size)) # take only one batch 19 | 20 | return dataset 21 | 22 | 23 | # python main.py --model simsiam --optimizer sgd --data_dir ./data/cifar --output_dir ./outputs/ --backbone resnet18 --dataset cifar10 --batch_size 32 --num_epochs 2 --weight_decay 0.0005 --base_lr 0.03 --warmup_epochs 10 -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-long/fedcon-pytorch/54732ec05c19845734567c186b63e3ddedb10798/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /fedavg-fixmatch-main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | import numpy as np 7 | import copy 8 | 9 | import gc 10 | 11 | from
tqdm import tqdm 12 | from configs import get_args 13 | from augmentations import get_aug, get_aug_fedmatch 14 | from models import get_model 15 | from tools import AverageMeter, PlotLogger, knn_monitor 16 | from datasets import get_dataset 17 | from optimizers import get_optimizer, LR_Scheduler 18 | from torch.utils.data import DataLoader, Dataset 19 | 20 | 21 | import torch 22 | from torch import nn, autograd 23 | from torch.utils.data import DataLoader, Dataset 24 | import numpy as np 25 | import random 26 | from sklearn import metrics 27 | import torch.nn.functional as F 28 | import copy 29 | from torch.autograd import Variable 30 | import itertools 31 | import logging 32 | import os.path 33 | from PIL import Image 34 | import numpy as np 35 | from torch.utils.data.sampler import Sampler 36 | import re 37 | import argparse 38 | import os 39 | import shutil 40 | import time 41 | import math 42 | import logging 43 | import os 44 | import sys 45 | import torch.backends.cudnn as cudnn 46 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 47 | import torchvision.datasets 48 | 49 | 50 | import torch 51 | from torch import nn 52 | import torch.nn.functional as F 53 | from torch.utils.data import DataLoader 54 | 55 | 56 | def test_img(net_g, data_loader, args): 57 | net_g.eval() 58 | test_loss = 0 59 | correct = 0 60 | 61 | for idx, (data, target) in enumerate(data_loader): 62 | data, target = data.cuda(), target.cuda() 63 | log_probs = net_g(data) 64 | test_loss += F.cross_entropy(log_probs, target, reduction='sum',ignore_index=-1).item() 65 | y_pred = log_probs.data.max(1, keepdim=True)[1] 66 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 67 | 68 | test_loss /= len(data_loader.dataset) 69 | accuracy = 100.00 * correct / len(data_loader.dataset) 70 | 71 | return accuracy, test_loss 72 | 73 | 74 | def get_current_consistency_weight(epoch): 75 | return sigmoid_rampup(epoch, 10) 76 | 77 | def sigmoid_rampup(current, rampup_length): 78 | if rampup_length == 0: 79 | return 1.0 80 | else: 81 | current = np.clip(current, 0.0, rampup_length) 82 | phase = 1.0 - current / rampup_length 83 | return float(np.exp(-5.0 * phase * phase)) 84 | 85 | def softmax_mse_loss(input_logits, target_logits): 86 | assert input_logits.size() == target_logits.size() 87 | input_softmax = F.softmax(input_logits, dim=1) 88 | target_softmax = F.softmax(target_logits, dim=1) 89 | num_classes = input_logits.size()[1] 90 | return F.mse_loss(input_softmax, target_softmax, size_average=False) / num_classes 91 | 92 | def softmax_kl_loss(input_logits, target_logits): 93 | assert input_logits.size() == target_logits.size() 94 | input_log_softmax = F.log_softmax(input_logits, dim=1) 95 | target_softmax = F.softmax(target_logits, dim=1) 96 | return F.kl_div(input_log_softmax, target_softmax, size_average=False) 97 | 98 | def symmetric_mse_loss(input1, input2): 99 | assert input1.size() == input2.size() 100 | num_classes = input1.size()[1] 101 | return torch.sum((input1 - input2)**2) / num_classes 102 | 103 | 104 | def FedAvg(w): 105 | w_avg = copy.deepcopy(w[0]) 106 | # print(w_avg.keys()) 107 | for k in w_avg.keys(): 108 | for i in range(1, len(w)): 109 | w_avg[k] += w[i][k] 110 | w_avg[k] = torch.div(w_avg[k], len(w)) 111 | return w_avg 112 | 113 | 114 | def iid(dataset, num_users, label_rate): 115 | """ 116 | Sample I.I.D. 
client data from the dataset 117 | :param dataset: 118 | :param num_users: 119 | :return: dict of image index 120 | """ 121 | num_items = int(len(dataset)/num_users) 122 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 123 | dict_users_labeled, dict_users_unlabeled = set(), {} 124 | 125 | dict_users_labeled = set(np.random.choice(list(all_idxs), int(len(all_idxs) * label_rate), replace=False)) 126 | 127 | for i in range(num_users): 128 | # dict_users_labeled = dict_users_labeled | set(np.random.choice(all_idxs, int(num_items * label_rate), replace=False)) 129 | # all_idxs = list(set(all_idxs) - dict_users_labeled) 130 | dict_users_unlabeled[i] = set(np.random.choice(all_idxs, int(num_items) , replace=False)) 131 | all_idxs = list(set(all_idxs) - dict_users_unlabeled[i]) 132 | dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled 133 | return dict_users_labeled, dict_users_unlabeled 134 | 135 | # def noniid(dataset, num_users, label_rate): 136 | 137 | # num_shards, num_imgs = 2 * num_users, int(len(dataset)/num_users/2) 138 | # idx_shard = [i for i in range(num_shards)] 139 | # dict_users_unlabeled = {i: np.array([], dtype='int64') for i in range(num_users)} 140 | # idxs = np.arange(num_shards*num_imgs) 141 | # labels = dataset.train_labels.numpy() 142 | # # print(type(labels)) 143 | 144 | # num_items = int(len(dataset)/num_users) 145 | # dict_users_labeled = set() 146 | # pseudo_label = [i for i in range(len(dataset))] 147 | 148 | # # sort labels 149 | # idxs_labels = np.vstack((idxs, labels)) 150 | # idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] # index values, sorted by label 151 | # idxs = idxs_labels[0,:] 152 | 153 | # # divide and assign 154 | # for i in range(num_users): 155 | # rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 156 | # idx_shard = list(set(idx_shard) - rand_set) 157 | # for rand in rand_set: 158 | # dict_users_unlabeled[i] = np.concatenate((dict_users_unlabeled[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 159 | 160 | # for i in range(num_users): 161 | 162 | # dict_users_unlabeled[i] = set(dict_users_unlabeled[i]) 163 | # dict_users_labeled = dict_users_labeled | set(np.random.choice(list(dict_users_unlabeled[i]), int(num_items * label_rate), replace=False)) 164 | # dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled 165 | 166 | 167 | # return dict_users_labeled, dict_users_unlabeled 168 | 169 | def noniid(dataset, num_users, label_rate): 170 | 171 | num_shards, num_imgs = 2 * num_users, int(len(dataset)/num_users/2) 172 | idx_shard = [i for i in range(num_shards)] 173 | dict_users_unlabeled = {i: np.array([], dtype='int64') for i in range(num_users)} 174 | idxs = np.arange(len(dataset)) 175 | labels = np.arange(len(dataset)) 176 | 177 | 178 | for i in range(len(dataset)): 179 | labels[i] = dataset[i][1] 180 | 181 | num_items = int(len(dataset)/num_users) 182 | dict_users_labeled = set() 183 | 184 | # sort labels 185 | idxs_labels = np.vstack((idxs, labels)) 186 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] # index values, sorted by label 187 | idxs = idxs_labels[0,:] 188 | 189 | # divide and assign 190 | for i in range(num_users): 191 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 192 | idx_shard = list(set(idx_shard) - rand_set) 193 | for rand in rand_set: 194 | dict_users_unlabeled[i] = np.concatenate((dict_users_unlabeled[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 195 | 196 | dict_users_labeled = set(np.random.choice(list(idxs), int(len(idxs) * label_rate), replace=False)) 197 | 198 | for i in
range(num_users): 199 | 200 | dict_users_unlabeled[i] = set(dict_users_unlabeled[i]) 201 | # dict_users_labeled = dict_users_labeled | set(np.random.choice(list(dict_users_unlabeled[i]), int(num_items * label_rate), replace=False)) 202 | dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled 203 | 204 | 205 | return dict_users_labeled, dict_users_unlabeled 206 | 207 | 208 | class DatasetSplit(Dataset): 209 | def __init__(self, dataset, idxs): 210 | self.dataset = dataset 211 | self.idxs = list(idxs) 212 | 213 | def __len__(self): 214 | return len(self.idxs) 215 | 216 | def __getitem__(self, item): 217 | (images1, images2), labels = self.dataset[self.idxs[item]] 218 | return (images1, images2), labels 219 | 220 | 221 | 222 | def main(device, args): 223 | 224 | 225 | loss1_func = nn.CrossEntropyLoss() 226 | loss2_func = softmax_kl_loss 227 | 228 | dataset_kwargs = { 229 | 'dataset':args.dataset, 230 | 'data_dir': args.data_dir, 231 | 'download':args.download, 232 | 'debug_subset_size':args.batch_size if args.debug else None 233 | } 234 | dataloader_kwargs = { 235 | 'batch_size': args.batch_size, 236 | 'drop_last': True, 237 | 'pin_memory': True, 238 | 'num_workers': args.num_workers, 239 | } 240 | dataloader_unlabeled_kwargs = { 241 | 'batch_size': args.batch_size*5, 242 | 'drop_last': True, 243 | 'pin_memory': True, 244 | 'num_workers': args.num_workers, 245 | } 246 | dataset_train =get_dataset( 247 | transform=get_aug_fedmatch(args.dataset, True), 248 | train=True, 249 | **dataset_kwargs 250 | ) 251 | 252 | if args.iid == 'iid': 253 | dict_users_labeled, dict_users_unlabeled = iid(dataset_train, args.num_users, args.label_rate) 254 | else: 255 | dict_users_labeled, dict_users_unlabeled = noniid(dataset_train, args.num_users, args.label_rate) 256 | 257 | train_loader_unlabeled = {} 258 | 259 | 260 | # define model 261 | model_glob = get_model('fedfixmatch', args.backbone).to(device) 262 | if torch.cuda.device_count() > 1: model_glob = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_glob) 263 | accuracy = [] 264 | 265 | 266 | for iter in range(args.num_epochs): 267 | 268 | model_glob.train() 269 | optimizer = torch.optim.SGD(model_glob.parameters(), lr=0.01, momentum=0.5) 270 | class_criterion = nn.CrossEntropyLoss(size_average=False, ignore_index= -1 ) 271 | 272 | train_loader_labeled = torch.utils.data.DataLoader( 273 | dataset=DatasetSplit(dataset_train, dict_users_labeled), 274 | shuffle=True, 275 | **dataloader_kwargs 276 | ) 277 | 278 | for batch_idx, ((img, img_ema), label) in enumerate(train_loader_labeled): 279 | input_var = torch.autograd.Variable(img.cuda()) 280 | ema_input_var = torch.autograd.Variable(img_ema.cuda()) 281 | target_var = torch.autograd.Variable(label.cuda()) 282 | minibatch_size = len(target_var) 283 | labeled_minibatch_size = target_var.data.ne(-1).sum() 284 | ema_model_out = model_glob(ema_input_var) 285 | model_out = model_glob(input_var) 286 | if isinstance(model_out, Variable): 287 | logit1 = model_out 288 | ema_logit = ema_model_out 289 | else: 290 | assert len(model_out) == 2 291 | assert len(ema_model_out) == 2 292 | logit1, logit2 = model_out 293 | ema_logit, _ = ema_model_out 294 | 295 | ema_logit = Variable(ema_logit.detach().data, requires_grad=False) 296 | class_logit, cons_logit = logit1, logit1 297 | class_loss = class_criterion(class_logit, target_var) / minibatch_size 298 | ema_class_loss = class_criterion(ema_logit, target_var) / minibatch_size 299 | pseudo_label1 = torch.softmax(model_out.detach_(), dim=-1) 300 | max_probs, 
targets_u = torch.max(pseudo_label1, dim=-1) 301 | mask = max_probs.ge(args.threshold_pl).float() 302 | Lu = (F.cross_entropy(ema_logit, targets_u, reduction='none') * mask).mean() 303 | loss = class_loss + Lu 304 | optimizer.zero_grad() 305 | loss.backward() 306 | optimizer.step() 307 | # batch_loss.append(loss.item()) 308 | 309 | 310 | del train_loader_labeled 311 | gc.collect() 312 | torch.cuda.empty_cache() 313 | 314 | if iter%1==0: 315 | test_loader = torch.utils.data.DataLoader( 316 | dataset=get_dataset( 317 | transform=get_aug(args.dataset, False, train_classifier=False), 318 | train=False, 319 | **dataset_kwargs), 320 | shuffle=False, 321 | **dataloader_kwargs 322 | ) 323 | model_glob.eval() 324 | acc, loss_train_test_labeled = test_img(model_glob, test_loader, args) 325 | accuracy.append(str(acc)) 326 | del test_loader 327 | gc.collect() 328 | torch.cuda.empty_cache() 329 | 330 | 331 | w_locals, loss_locals, loss0_locals, loss2_locals = [], [], [], [] 332 | 333 | m = max(int(args.frac * args.num_users), 1) 334 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 335 | 336 | for idx in idxs_users: 337 | 338 | loss_local = [] 339 | loss0_local = [] 340 | loss2_local = [] 341 | 342 | model_local = copy.deepcopy(model_glob).to(args.device) 343 | 344 | train_loader_unlabeled = torch.utils.data.DataLoader( 345 | dataset=DatasetSplit(dataset_train, dict_users_unlabeled[idx]), 346 | shuffle=True, 347 | **dataloader_unlabeled_kwargs 348 | ) 349 | 350 | 351 | model_local.train() 352 | 353 | 354 | for i, ((img, img_ema), label) in enumerate(train_loader_unlabeled): 355 | 356 | input_var = torch.autograd.Variable(img.cuda()) 357 | ema_input_var = torch.autograd.Variable(img_ema.cuda()) 358 | target_var = torch.autograd.Variable(label.cuda()) 359 | minibatch_size = len(target_var) 360 | labeled_minibatch_size = target_var.data.ne(-1).sum() 361 | ema_model_out = model_local(ema_input_var) 362 | model_out = model_local(input_var) 363 | if isinstance(model_out, Variable): 364 | logit1 = model_out 365 | ema_logit = ema_model_out 366 | else: 367 | assert len(model_out) == 2 368 | assert len(ema_model_out) == 2 369 | logit1, logit2 = model_out 370 | ema_logit, _ = ema_model_out 371 | 372 | ema_logit = Variable(ema_logit.detach().data, requires_grad=True) 373 | class_logit, cons_logit = logit1, logit1 374 | # class_loss = class_criterion(class_logit, target_var) / minibatch_size 375 | # ema_class_loss = class_criterion(ema_logit, target_var) / minibatch_size 376 | pseudo_label1 = torch.softmax(model_out.detach_(), dim=-1) 377 | max_probs, targets_u = torch.max(pseudo_label1, dim=-1) 378 | mask = max_probs.ge(args.threshold_pl).float() 379 | Lu = (F.cross_entropy(ema_logit, targets_u, reduction='none') * mask).mean() 380 | loss = Lu 381 | optimizer.zero_grad() 382 | loss.backward() 383 | optimizer.step() 384 | # batch_loss.append(loss.item()) 385 | 386 | w_locals.append(copy.deepcopy(model_local.state_dict())) 387 | # loss_locals.append(sum(loss_local) / len(loss_local) ) 388 | 389 | del model_local 390 | gc.collect() 391 | del train_loader_unlabeled 392 | gc.collect() 393 | torch.cuda.empty_cache() 394 | 395 | 396 | 397 | w_glob = FedAvg(w_locals) 398 | model_glob.load_state_dict(w_glob) 399 | 400 | # loss_avg = sum(loss_locals) / len(loss_locals) 401 | 402 | if iter%1==0: 403 | print('Round {:3d}, Acc {:.2f}%'.format(iter, acc)) 404 | 405 | # f = open("./result_ablation.txt",'a') 406 | # f.write("fedavg-fixmatch") 407 | # f.write(str(args.label_rate)) 408 | # f.write("\n") 409 | 
# f.write(str(args.frac)) 410 | # f.write(str(args.batch_size)) 411 | # f.write("\n") 412 | # f.write(args.dataset) 413 | # f.write("\n") 414 | # f.write(args.iid) 415 | # f.write("\n") 416 | # f.write(" ".join(accuracy)) 417 | # f.write("\n") 418 | # f.close() 419 | 420 | if __name__ == "__main__": 421 | args = get_args() 422 | main(device=args.device, args=args) 423 | 424 | 425 | 426 | 427 | 428 | 429 | 430 | 431 | 432 | 433 | 434 | 435 | 436 | 437 | 438 | 439 | -------------------------------------------------------------------------------- /fedavg-uda-main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | import numpy as np 7 | import copy 8 | 9 | import gc 10 | 11 | from tqdm import tqdm 12 | from configs import get_args 13 | from augmentations import get_aug, get_aug_uda, get_aug_fedmatch 14 | from models import get_model 15 | from tools import AverageMeter, PlotLogger, knn_monitor 16 | from datasets import get_dataset 17 | from optimizers import get_optimizer, LR_Scheduler 18 | from torch.utils.data import DataLoader, Dataset 19 | 20 | 21 | import torch 22 | from torch import nn, autograd 23 | from torch.utils.data import DataLoader, Dataset 24 | import numpy as np 25 | import random 26 | from sklearn import metrics 27 | import torch.nn.functional as F 28 | import copy 29 | from torch.autograd import Variable 30 | import itertools 31 | import logging 32 | import os.path 33 | from PIL import Image 34 | import numpy as np 35 | from torch.utils.data.sampler import Sampler 36 | import re 37 | import argparse 38 | import os 39 | import shutil 40 | import time 41 | import math 42 | import logging 43 | import os 44 | import sys 45 | import torch.backends.cudnn as cudnn 46 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 47 | import torchvision.datasets 48 | 49 | 50 | import torch 51 | from torch import nn 52 | import torch.nn.functional as F 53 | from torch.utils.data import DataLoader 54 | 55 | import torch 56 | from torch import nn, autograd 57 | from torch.utils.data import DataLoader, Dataset 58 | import numpy as np 59 | import random 60 | from sklearn import metrics 61 | import torch.nn.functional as F 62 | import copy 63 | from torch.autograd import Variable 64 | import itertools 65 | import logging 66 | import os.path 67 | from PIL import Image 68 | import numpy as np 69 | from torch.utils.data.sampler import Sampler 70 | import re 71 | import argparse 72 | import os 73 | import shutil 74 | import time 75 | import math 76 | import logging 77 | import os 78 | import sys 79 | import torch.backends.cudnn as cudnn 80 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 81 | import torchvision.datasets 82 | 83 | def quantile_linear(iter, args): 84 | 85 | turn_point = int( (args.comu_rate * args.num_epochs - 0.1 * args.num_epochs -1.35) / 0.45 )  # num_epochs is the flag defined in configs (there is no --epochs) 86 | if iter < args.phi_g: 87 | return 1.0 88 | elif iter > turn_point: 89 | return 0.1 90 | else: 91 | return 0.9 * iter / ( 2 - turn_point ) + 1 - 1.8/( 2 - turn_point ) 92 | 93 | 94 | def quantile_rectangle(iter, args): 95 | if iter < args.phi_g: 96 | return 0.0 97 | elif iter >= args.psi_g: 98 | return 0.0 99 | else: 100 | if args.comu_rate*5/3 > 1: 101 | return 0.99 102 | else: 103 | return args.comu_rate*args.num_epochs/(args.psi_g - args.phi_g) 104 | 105 | def get_median(data, iter, args): 106 | if args.dataset == 'mnist': 107 | a = 8 108 | else: 109 | a =
33 110 | 111 | if len(data) < (39*a): 112 | data_test = data[(-10*a):] 113 | elif len(data) < (139*a): 114 | data_test = data[(30*a) : ] 115 | else: 116 | data_test = data[(-100*a):] 117 | 118 | data_test.sort() 119 | 120 | if args.ramp == 'linear': 121 | quantile = quantile_linear(iter, args) 122 | iter_place = int( (1 - quantile) * len(data_test)) 123 | elif args.ramp == 'flat': 124 | quantile = quantile_flat(iter, args) 125 | iter_place = int( (1 - quantile) * len(data_test)) 126 | elif args.ramp == 'rectangle': 127 | quantile = quantile_rectangle(iter, args) 128 | iter_place = int( (1 - quantile) * len(data_test)-1) 129 | else: 130 | exit('Error: wrong ramp type!') 131 | return data_test[iter_place] 132 | 133 | def sigmoid_rampup(current, rampup_length): 134 | if rampup_length == 0: 135 | return 1.0 136 | else: 137 | current = np.clip(current, 0.0, rampup_length) 138 | phase = 1.0 - current / rampup_length 139 | return float(np.exp(-5.0 * phase * phase)) 140 | 141 | def sigmoid_rampup2(current, rampup_length): 142 | if rampup_length == 0: 143 | return 1.0 144 | else: 145 | current = np.clip(current, 0.0, rampup_length) 146 | phase = current / rampup_length 147 | return float(np.exp(-5.0 * phase * phase)) 148 | 149 | def linear_rampup(current, rampup_length): 150 | assert current >= 0 and rampup_length >= 0 151 | if current >= rampup_length: 152 | return 1.0 153 | else: 154 | return current / rampup_length 155 | 156 | 157 | def cosine_rampdown(current, rampdown_length): 158 | assert 0 <= current <= rampdown_length 159 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 160 | 161 | def test_img(net_g, data_loader, args): 162 | net_g.eval() 163 | test_loss = 0 164 | correct = 0 165 | 166 | for idx, (data, target) in enumerate(data_loader): 167 | data, target = data.cuda(), target.cuda() 168 | log_probs = net_g(data) 169 | test_loss += F.cross_entropy(log_probs, target, reduction='sum',ignore_index=-1).item() 170 | y_pred = log_probs.data.max(1, keepdim=True)[1] 171 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 172 | 173 | test_loss /= len(data_loader.dataset) 174 | accuracy = 100.00 * correct / len(data_loader.dataset) 175 | 176 | return accuracy, test_loss 177 | 178 | def adjust_learning_rate(optimizer, epoch, step_in_epoch, total_steps_in_epoch , args): 179 | lr = args.lr 180 | epoch = epoch + step_in_epoch / total_steps_in_epoch 181 | lr = linear_rampup(epoch, args.lr_rampup) * (args.lr - args.initial_lr) + args.initial_lr 182 | if args.lr_rampdown_epochs: 183 | lr *= cosine_rampdown(epoch, args.lr_rampdown_epochs) 184 | for param_group in optimizer.param_groups: 185 | param_group['lr'] = lr 186 | 187 | def get_current_consistency_weight(epoch): 188 | return sigmoid_rampup(epoch, 10) 189 | 190 | def sigmoid_rampup(current, rampup_length): 191 | if rampup_length == 0: 192 | return 1.0 193 | else: 194 | current = np.clip(current, 0.0, rampup_length) 195 | phase = 1.0 - current / rampup_length 196 | return float(np.exp(-5.0 * phase * phase)) 197 | 198 | def softmax_mse_loss(input_logits, target_logits): 199 | assert input_logits.size() == target_logits.size() 200 | input_softmax = F.softmax(input_logits, dim=1) 201 | target_softmax = F.softmax(target_logits, dim=1) 202 | num_classes = input_logits.size()[1] 203 | return F.mse_loss(input_softmax, target_softmax, size_average=False) / num_classes 204 | 205 | def softmax_kl_loss(input_logits, target_logits): 206 | assert input_logits.size() == target_logits.size() 207 | input_log_softmax = 
F.log_softmax(input_logits, dim=1) 208 | target_softmax = F.softmax(target_logits, dim=1) 209 | return F.kl_div(input_log_softmax, target_softmax, size_average=False) 210 | 211 | def symmetric_mse_loss(input1, input2): 212 | assert input1.size() == input2.size() 213 | num_classes = input1.size()[1] 214 | return torch.sum((input1 - input2)**2) / num_classes 215 | 216 | 217 | def FedAvg(w): 218 | w_avg = copy.deepcopy(w[0]) 219 | # print(w_avg.keys()) 220 | for k in w_avg.keys(): 221 | for i in range(1, len(w)): 222 | w_avg[k] += w[i][k] 223 | w_avg[k] = torch.div(w_avg[k], len(w)) 224 | return w_avg 225 | 226 | 227 | def iid(dataset, num_users, label_rate): 228 | """ 229 | Sample I.I.D. client data from MNIST dataset 230 | :param dataset: 231 | :param num_users: 232 | :return: dict of image index 233 | """ 234 | num_items = int(len(dataset)/num_users) 235 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 236 | dict_users_labeled, dict_users_unlabeled = set(), {} 237 | 238 | dict_users_labeled = set(np.random.choice(list(all_idxs), int(len(all_idxs) * label_rate), replace=False)) 239 | 240 | for i in range(num_users): 241 | # dict_users_labeled = dict_users_labeled | set(np.random.choice(all_idxs, int(num_items * label_rate), replace=False)) 242 | # all_idxs = list(set(all_idxs) - dict_users_labeled) 243 | dict_users_unlabeled[i] = set(np.random.choice(all_idxs, int(num_items) , replace=False)) 244 | all_idxs = list(set(all_idxs) - dict_users_unlabeled[i]) 245 | dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled 246 | return dict_users_labeled, dict_users_unlabeled 247 | 248 | 249 | def noniid(dataset, num_users, label_rate): 250 | 251 | num_shards, num_imgs = 2 * num_users, int(len(dataset)/num_users/2) 252 | idx_shard = [i for i in range(num_shards)] 253 | dict_users_unlabeled = {i: np.array([], dtype='int64') for i in range(num_users)} 254 | idxs = np.arange(len(dataset)) 255 | labels = np.arange(len(dataset)) 256 | 257 | 258 | for i in range(len(dataset)): 259 | labels[i] = dataset[i][1] 260 | 261 | num_items = int(len(dataset)/num_users) 262 | dict_users_labeled = set() 263 | 264 | # sort labels 265 | idxs_labels = np.vstack((idxs, labels)) 266 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]#索引值 267 | idxs = idxs_labels[0,:] 268 | 269 | # divide and assign 270 | for i in range(num_users): 271 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 272 | idx_shard = list(set(idx_shard) - rand_set) 273 | for rand in rand_set: 274 | dict_users_unlabeled[i] = np.concatenate((dict_users_unlabeled[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 275 | 276 | dict_users_labeled = set(np.random.choice(list(idxs), int(len(idxs) * label_rate), replace=False)) 277 | 278 | for i in range(num_users): 279 | 280 | dict_users_unlabeled[i] = set(dict_users_unlabeled[i]) 281 | # dict_users_labeled = dict_users_labeled | set(np.random.choice(list(dict_users_unlabeled[i]), int(num_items * label_rate), replace=False)) 282 | dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled 283 | 284 | 285 | return dict_users_labeled, dict_users_unlabeled 286 | 287 | class DatasetSplit(Dataset): 288 | def __init__(self, dataset, idxs): 289 | self.dataset = dataset 290 | self.idxs = list(idxs) 291 | 292 | def __len__(self): 293 | return len(self.idxs) 294 | 295 | def __getitem__(self, item): 296 | (images1, images2), labels = self.dataset[self.idxs[item]] 297 | return (images1, images2), labels 298 | 299 | def get_current_consistency_weight(epoch): 300 | 
return sigmoid_rampup(epoch, 10) 301 | 302 | 303 | def main(device, args): 304 | 305 | 306 | loss1_func = nn.CrossEntropyLoss() 307 | loss2_func = softmax_kl_loss 308 | 309 | dataset_kwargs = { 310 | 'dataset':args.dataset, 311 | 'data_dir': args.data_dir, 312 | 'download':args.download, 313 | 'debug_subset_size':args.batch_size if args.debug else None 314 | } 315 | dataloader_kwargs = { 316 | 'batch_size': args.batch_size, 317 | 'drop_last': True, 318 | 'pin_memory': True, 319 | 'num_workers': args.num_workers, 320 | } 321 | dataloader_unlabeled_kwargs = { 322 | 'batch_size': args.batch_size*5, 323 | 'drop_last': True, 324 | 'pin_memory': True, 325 | 'num_workers': args.num_workers, 326 | } 327 | dataset_train =get_dataset( 328 | transform=get_aug_fedmatch(args.dataset, True), 329 | train=True, 330 | **dataset_kwargs 331 | ) 332 | 333 | if args.iid == 'iid': 334 | dict_users_labeled, dict_users_unlabeled = iid(dataset_train, args.num_users, args.label_rate) 335 | else: 336 | dict_users_labeled, dict_users_unlabeled = noniid(dataset_train, args.num_users, args.label_rate) 337 | train_loader_unlabeled = {} 338 | 339 | 340 | # define model 341 | model_glob = get_model('fedfixmatch', args.backbone).to(device) 342 | if torch.cuda.device_count() > 1: model_glob = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_glob) 343 | 344 | 345 | model_local_idx = set() 346 | 347 | user_epoch = {} 348 | lr_scheduler = {} 349 | accuracy = [] 350 | class_criterion = nn.CrossEntropyLoss(size_average=False, ignore_index= -1 ) 351 | if args.dataset == 'cifar' and args.iid != 'noniid_tradition': 352 | consistency_criterion = softmax_kl_loss 353 | else: 354 | consistency_criterion = softmax_mse_loss 355 | 356 | for iter in range(args.num_epochs): 357 | 358 | model_glob.train() 359 | optimizer = torch.optim.SGD(model_glob.parameters(), lr=0.01, momentum=0.5) 360 | 361 | train_loader_labeled = torch.utils.data.DataLoader( 362 | dataset=DatasetSplit(dataset_train, dict_users_labeled), 363 | shuffle=True, 364 | **dataloader_kwargs 365 | ) 366 | 367 | for batch_idx, ((img, img_ema), label) in enumerate(train_loader_labeled): 368 | 369 | img, img_ema, label = img.to(args.device), img_ema.to(args.device), label.to(args.device) 370 | input_var = torch.autograd.Variable(img) 371 | ema_input_var = torch.autograd.Variable(img_ema, volatile=True) 372 | target_var = torch.autograd.Variable(label) 373 | minibatch_size = len(target_var) 374 | labeled_minibatch_size = target_var.data.ne(-1).sum() 375 | ema_model_out = model_glob(ema_input_var) 376 | model_out = model_glob(input_var) 377 | if isinstance(model_out, Variable): 378 | logit1 = model_out 379 | ema_logit = ema_model_out 380 | else: 381 | assert len(model_out) == 2 382 | assert len(ema_model_out) == 2 383 | logit1, logit2 = model_out 384 | ema_logit, _ = ema_model_out 385 | ema_logit = Variable(ema_logit.detach().data, requires_grad=False) 386 | class_logit, cons_logit = logit1, logit1 387 | classification_weight = 1 388 | class_loss = classification_weight * class_criterion(class_logit, target_var) / minibatch_size 389 | ema_class_loss = class_criterion(ema_logit, target_var) / minibatch_size 390 | consistency_weight = get_current_consistency_weight(iter) 391 | consistency_loss = consistency_weight * consistency_criterion(cons_logit, ema_logit) / minibatch_size 392 | loss = class_loss + consistency_loss 393 | optimizer.zero_grad() 394 | loss.backward() 395 | optimizer.step() 396 | 397 | del train_loader_labeled 398 | gc.collect() 399 | torch.cuda.empty_cache() 400 | 
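# The server pass above is a mean-teacher update on the small labeled pool:
#     loss = CE(logits, y) / B + w(t) * d(logits, ema_logits) / B
# where d is softmax_kl_loss for CIFAR (unless args.iid is 'noniid_tradition')
# and softmax_mse_loss otherwise, and w(t) = sigmoid_rampup(t, 10) =
# exp(-5 * (1 - t/10)^2) for t < 10 (about 0.007 at round 0) and 1.0 afterwards.
# The client pass below reuses the same machinery but keeps only the
# consistency term, since clients hold no labels.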
401 | if iter%1==0: 402 | test_loader = torch.utils.data.DataLoader( 403 | dataset=get_dataset( 404 | transform=get_aug(args.dataset, False, train_classifier=False), 405 | train=False, 406 | **dataset_kwargs), 407 | shuffle=False, 408 | **dataloader_kwargs 409 | ) 410 | model_glob.eval() 411 | acc, loss_train_test_labeled = test_img(model_glob, test_loader, args) 412 | accuracy.append(str(acc)) 413 | del test_loader 414 | gc.collect() 415 | torch.cuda.empty_cache() 416 | 417 | 418 | w_locals, loss_locals, loss0_locals, loss2_locals = [], [], [], [] 419 | 420 | m = max(int(args.frac * args.num_users), 1) 421 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 422 | 423 | for idx in idxs_users: 424 | if idx in user_epoch.keys(): 425 | user_epoch[idx] += 1 426 | else: 427 | user_epoch[idx] = 1 428 | 429 | loss_local = [] 430 | loss0_local = [] 431 | loss2_local = [] 432 | 433 | 434 | model_local = copy.deepcopy(model_glob).to(args.device) 435 | 436 | train_loader_unlabeled = torch.utils.data.DataLoader( 437 | dataset=DatasetSplit(dataset_train, dict_users_unlabeled[idx]), 438 | shuffle=True, 439 | **dataloader_unlabeled_kwargs 440 | ) 441 | 442 | optimizer = torch.optim.SGD(model_local.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False) 443 | 444 | model_local.train() 445 | 446 | 447 | for i, ((images1, images2), labels) in enumerate(train_loader_unlabeled): 448 | # unpack the two augmented views of this unlabeled batch 449 | img, img_ema, label = images1.to(args.device), images2.to(args.device), labels.to(args.device) 450 | adjust_learning_rate(optimizer, user_epoch[idx], i, len(train_loader_unlabeled), args) 451 | input_var = torch.autograd.Variable(img) 452 | ema_input_var = torch.autograd.Variable(img_ema, volatile=True) 453 | target_var = torch.autograd.Variable(label) 454 | minibatch_size = len(target_var) 455 | labeled_minibatch_size = target_var.data.ne(-1).sum() 456 | ema_model_out = model_local(ema_input_var) 457 | model_out = model_local(input_var) 458 | if isinstance(model_out, Variable): 459 | logit1 = model_out 460 | ema_logit = ema_model_out 461 | else: 462 | assert len(model_out) == 2 463 | assert len(ema_model_out) == 2 464 | logit1, logit2 = model_out 465 | ema_logit, _ = ema_model_out 466 | ema_logit = Variable(ema_logit.detach().data, requires_grad=False) 467 | class_logit, cons_logit = logit1, logit1 468 | 469 | consistency_weight = get_current_consistency_weight(user_epoch[idx]) 470 | consistency_loss = consistency_weight * consistency_criterion(cons_logit, ema_logit) / minibatch_size 471 | loss = consistency_loss 472 | optimizer.zero_grad() 473 | loss.backward() 474 | optimizer.step() 475 | 476 | w_locals.append(copy.deepcopy(model_local.state_dict())) 477 | 478 | del model_local 479 | gc.collect() 480 | del train_loader_unlabeled 481 | gc.collect() 482 | torch.cuda.empty_cache() 483 | 484 | 485 | 486 | w_glob = FedAvg(w_locals) 487 | model_glob.load_state_dict(w_glob) 488 | 489 | # loss_avg = sum(loss_locals) / len(loss_locals) 490 | 491 | if iter%1==0: 492 | print('Round {:3d}, Acc {:.2f}%'.format(iter, acc)) 493 | 494 | if __name__ == "__main__": 495 | args = get_args() 496 | main(device=args.device, args=args) 497 | 498 | 499 | 500 | 501 | 502 | 503 | 504 | 505 | 506 | 507 | 508 | 509 | 510 | 511 | 512 | 513 | -------------------------------------------------------------------------------- /fedcon-main-sec45.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 
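# Self-contained reference sketch (hypothetical helper, not called anywhere) of
# the FedAvg aggregation all of these scripts share: the server averages the
# client state_dicts key by key, assuming every client uses the identical
# architecture and key set.
def _fedavg_sketch():
    import copy
    w_locals = [
        {'fc.weight': torch.ones(2, 2)},   # client 0
        {'fc.weight': torch.zeros(2, 2)},  # client 1
    ]
    w_avg = copy.deepcopy(w_locals[0])
    for k in w_avg.keys():
        for w in w_locals[1:]:
            w_avg[k] += w[k]
        w_avg[k] = torch.div(w_avg[k], len(w_locals))
    return w_avg  # {'fc.weight': tensor of 0.5s}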
import torch.nn.functional as F 5 | import torchvision 6 | import numpy as np 7 | import copy 8 | 9 | import gc 10 | 11 | from tqdm import tqdm 12 | from configs import get_args 13 | from augmentations import get_aug 14 | from models import get_model 15 | from tools import AverageMeter, PlotLogger, knn_monitor 16 | from datasets import get_dataset 17 | from optimizers import get_optimizer, LR_Scheduler 18 | from torch.utils.data import DataLoader, Dataset 19 | 20 | 21 | import torch 22 | from torch import nn, autograd 23 | from torch.utils.data import DataLoader, Dataset 24 | import numpy as np 25 | import random 26 | from sklearn import metrics 27 | import torch.nn.functional as F 28 | import copy 29 | from torch.autograd import Variable 30 | import itertools 31 | import logging 32 | import os.path 33 | from PIL import Image 34 | import numpy as np 35 | from torch.utils.data.sampler import Sampler 36 | import re 37 | import argparse 38 | import os 39 | import shutil 40 | import time 41 | import math 42 | import logging 43 | import os 44 | import sys 45 | import torch.backends.cudnn as cudnn 46 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 47 | import torchvision.datasets 48 | 49 | 50 | import torch 51 | from torch import nn 52 | import torch.nn.functional as F 53 | from torch.utils.data import DataLoader 54 | 55 | 56 | def test_img(net_g, data_loader, args): 57 | net_g.eval() 58 | test_loss = 0 59 | correct = 0 60 | 61 | for idx, (data, target) in enumerate(data_loader): 62 | data, target = data.cuda(), target.cuda() 63 | log_probs, _, _, _ = net_g(data, data) 64 | test_loss += F.cross_entropy(log_probs, target, reduction='sum',ignore_index=-1).item() 65 | y_pred = log_probs.data.max(1, keepdim=True)[1] 66 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 67 | 68 | test_loss /= len(data_loader.dataset) 69 | accuracy = 100.00 * correct / len(data_loader.dataset) 70 | 71 | return accuracy, test_loss 72 | 73 | 74 | def get_current_consistency_weight(epoch): 75 | return sigmoid_rampup(epoch, 100) 76 | 77 | def sigmoid_rampup(current, rampup_length): 78 | if rampup_length == 0: 79 | return 1.0 80 | else: 81 | current = np.clip(current, 0.0, rampup_length) 82 | phase = 1.0 - current / rampup_length 83 | return float(np.exp(-5.0 * phase * phase)) 84 | 85 | def softmax_mse_loss(input_logits, target_logits): 86 | assert input_logits.size() == target_logits.size() 87 | input_softmax = F.softmax(input_logits, dim=1) 88 | target_softmax = F.softmax(target_logits, dim=1) 89 | num_classes = input_logits.size()[1] 90 | return F.mse_loss(input_softmax, target_softmax, size_average=False) / num_classes 91 | 92 | def softmax_kl_loss(input_logits, target_logits): 93 | assert input_logits.size() == target_logits.size() 94 | input_log_softmax = F.log_softmax(input_logits, dim=1) 95 | target_softmax = F.softmax(target_logits, dim=1) 96 | return F.kl_div(input_log_softmax, target_softmax, size_average=False) 97 | 98 | def symmetric_mse_loss(input1, input2): 99 | assert input1.size() == input2.size() 100 | num_classes = input1.size()[1] 101 | return torch.sum((input1 - input2)**2) / num_classes 102 | 103 | 104 | def FedAvg(w): 105 | w_avg = copy.deepcopy(w[0]) 106 | # print(w_avg.keys()) 107 | for k in w_avg.keys(): 108 | for i in range(1, len(w)): 109 | w_avg[k] += w[i][k] 110 | w_avg[k] = torch.div(w_avg[k], len(w)) 111 | return w_avg 112 | 113 | 114 | def iid(dataset, num_users, label_rate): 115 | """ 116 | Sample I.I.D. 
client data from MNIST dataset 117 | :param dataset: 118 | :param num_users: 119 | :return: dict of image index 120 | """ 121 | num_items = int(len(dataset)/num_users) 122 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 123 | dict_users_labeled, dict_users_unlabeled = set(), {} 124 | dict_users_unlabeled_server = set() 125 | 126 | dict_users_labeled = set(np.random.choice(list(all_idxs), int(len(all_idxs) * label_rate), replace=False)) 127 | dict_users_unlabeled_server = set(np.random.choice(list(all_idxs), int(len(all_idxs) * label_rate * 5), replace=False)) 128 | 129 | for i in range(num_users): 130 | # dict_users_labeled = dict_users_labeled | set(np.random.choice(all_idxs, int(num_items * label_rate), replace=False)) 131 | # all_idxs = list(set(all_idxs) - dict_users_labeled) 132 | dict_users_unlabeled[i] = set(np.random.choice(all_idxs, int(num_items) , replace=False)) 133 | all_idxs = list(set(all_idxs) - dict_users_unlabeled[i]) 134 | dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled - dict_users_unlabeled_server 135 | return dict_users_labeled, dict_users_unlabeled_server, dict_users_unlabeled 136 | 137 | 138 | def noniid(dataset, num_users, label_rate): 139 | 140 | num_shards, num_imgs = 2 * num_users, int(len(dataset)/num_users/2) 141 | idx_shard = [i for i in range(num_shards)] 142 | dict_users_unlabeled = {i: np.array([], dtype='int64') for i in range(num_users)} 143 | idxs = np.arange(len(dataset)) 144 | labels = np.arange(len(dataset)) 145 | 146 | 147 | for i in range(len(dataset)): 148 | labels[i] = dataset[i][1] 149 | 150 | num_items = int(len(dataset)/num_users) 151 | dict_users_labeled = set() 152 | dict_users_unlabeled_server = set() 153 | 154 | # sort labels 155 | idxs_labels = np.vstack((idxs, labels)) 156 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]#索引值 157 | idxs = idxs_labels[0,:] 158 | 159 | # divide and assign 160 | for i in range(num_users): 161 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 162 | idx_shard = list(set(idx_shard) - rand_set) 163 | for rand in rand_set: 164 | dict_users_unlabeled[i] = np.concatenate((dict_users_unlabeled[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 165 | 166 | dict_users_labeled = set(np.random.choice(list(idxs), int(len(idxs) * label_rate), replace=False)) 167 | dict_users_unlabeled_server = set(np.random.choice(list(idxs), int(len(idxs) * label_rate * 5), replace=False)) 168 | 169 | for i in range(num_users): 170 | 171 | dict_users_unlabeled[i] = set(dict_users_unlabeled[i]) 172 | # dict_users_labeled = dict_users_labeled | set(np.random.choice(list(dict_users_unlabeled[i]), int(num_items * label_rate), replace=False)) 173 | dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled - dict_users_unlabeled_server 174 | 175 | 176 | return dict_users_labeled, dict_users_unlabeled_server, dict_users_unlabeled 177 | 178 | class DatasetSplit(Dataset): 179 | def __init__(self, dataset, idxs): 180 | self.dataset = dataset 181 | self.idxs = list(idxs) 182 | 183 | def __len__(self): 184 | return len(self.idxs) 185 | 186 | def __getitem__(self, item): 187 | (images1, images2), labels = self.dataset[self.idxs[item]] 188 | return (images1, images2), labels 189 | 190 | 191 | 192 | def main(device, args): 193 | 194 | 195 | loss1_func = nn.CrossEntropyLoss() 196 | loss2_func = softmax_kl_loss 197 | 198 | dataset_kwargs = { 199 | 'dataset':args.dataset, 200 | 'data_dir': args.data_dir, 201 | 'download':args.download, 202 | 'debug_subset_size':args.batch_size if 
args.debug else None 203 | } 204 | dataloader_kwargs = { 205 | 'batch_size': args.batch_size, 206 | 'drop_last': True, 207 | 'pin_memory': True, 208 | 'num_workers': args.num_workers, 209 | } 210 | dataloader_unlabeled_kwargs = { 211 | 'batch_size': args.batch_size*5, 212 | 'drop_last': True, 213 | 'pin_memory': True, 214 | 'num_workers': args.num_workers, 215 | } 216 | dataset_train =get_dataset( 217 | transform=get_aug(args.dataset, True), 218 | train=True, 219 | **dataset_kwargs 220 | ) 221 | 222 | if args.iid == 'iid': 223 | dict_users_labeled, dict_users_unlabeled_server, dict_users_unlabeled = iid(dataset_train, args.num_users, args.label_rate) 224 | else: 225 | dict_users_labeled, dict_users_unlabeled_server, dict_users_unlabeled = noniid(dataset_train, args.num_users, args.label_rate) 226 | train_loader_unlabeled = {} 227 | 228 | 229 | # define model 230 | model_glob = get_model('global', args.backbone).to(device) 231 | if torch.cuda.device_count() > 1: model_glob = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_glob) 232 | 233 | 234 | model_local_idx = set() 235 | model_local_dict = {} 236 | accuracy = [] 237 | lr_scheduler = {} 238 | 239 | 240 | for iter in range(args.num_epochs): 241 | 242 | model_glob.train() 243 | optimizer = torch.optim.SGD(model_glob.parameters(), lr=0.01, momentum=0.5) 244 | 245 | train_loader_labeled = torch.utils.data.DataLoader( 246 | dataset=DatasetSplit(dataset_train, dict_users_labeled), 247 | shuffle=True, 248 | **dataloader_kwargs 249 | ) 250 | train_loader_unlabeled = torch.utils.data.DataLoader( 251 | dataset=DatasetSplit(dataset_train, dict_users_unlabeled_server), 252 | shuffle=True, 253 | **dataloader_unlabeled_kwargs 254 | ) 255 | train_loader = zip(train_loader_labeled, train_loader_unlabeled) 256 | 257 | for batch_idx, (data_x, data_u) in enumerate(train_loader): 258 | (images1_l, images2_l), labels = data_x 259 | (images1_u, images2_u), _ = data_u 260 | 261 | model_glob.zero_grad() 262 | labels = labels.cuda() 263 | 264 | batch_size = images1_l.shape[0] 265 | images1 = torch.cat((images1_l, images1_u)).to(args.device) 266 | images2 = torch.cat((images2_l, images2_u)).to(args.device) 267 | 268 | z1_t, z2_t, z1_s, z2_s = model_glob.forward(images1.to(device, non_blocking=True), images2.to(device, non_blocking=True)) 269 | model_glob.update_moving_average( iter*1000 + batch_idx, 20000) 270 | 271 | loss_class = 1/2 * loss1_func(z1_t[:batch_size], labels) + 1/2 * loss1_func(z2_t[:batch_size], labels) 272 | 273 | loss_consist = 1/2 * loss2_func(z1_t, z2_s) / len(labels) + 1/2 * loss2_func(z2_t, z1_s) / len(labels) 274 | consistency_weight = get_current_consistency_weight(batch_idx) 275 | loss = loss_class# + consistency_weight * loss_consist 276 | 277 | loss.backward() 278 | optimizer.step() 279 | 280 | del train_loader_labeled 281 | gc.collect() 282 | torch.cuda.empty_cache() 283 | 284 | if iter%1==0: 285 | test_loader = torch.utils.data.DataLoader( 286 | dataset=get_dataset( 287 | transform=get_aug(args.dataset, False, train_classifier=False), 288 | train=False, 289 | **dataset_kwargs), 290 | shuffle=False, 291 | **dataloader_kwargs 292 | ) 293 | model_glob.eval() 294 | acc, loss_train_test_labeled = test_img(model_glob, test_loader, args) 295 | accuracy.append(str(acc)) 296 | del test_loader 297 | gc.collect() 298 | torch.cuda.empty_cache() 299 | 300 | 301 | w_locals, loss_locals, loss0_locals, loss2_locals = [], [], [], [] 302 | 303 | m = max(int(args.frac * args.num_users), 1) 304 | idxs_users = np.random.choice(range(args.num_users), 
m, replace=False) 305 | 306 | 307 | for idx in idxs_users: 308 | 309 | loss_local = [] 310 | loss0_local = [] 311 | loss2_local = [] 312 | 313 | 314 | if idx in model_local_idx: 315 | model_local = get_model('local', args.backbone).to(device) 316 | model_local.projector.load_state_dict(model_local_dict[idx][0]) 317 | model_local.target_encoder.load_state_dict(model_local_dict[idx][1]) 318 | # model_local.projector.load_state_dict(torch.load('/model/'+'model1' + str(args.dataset) + str(idx)+ '.pkl')) 319 | # model_local.target_encoder.load_state_dict(torch.load('/model/'+'model1' + str(args.dataset) + 'tar'+ str(idx)+ '.pkl')) 320 | 321 | model_local.backbone.load_state_dict(model_glob.backbone.state_dict()) 322 | else: 323 | model_local = get_model('local', args.backbone).to(device) 324 | model_local.backbone.load_state_dict(model_glob.backbone.state_dict()) 325 | model_local.target_encoder.load_state_dict(model_local.online_encoder.state_dict()) 326 | model_local_idx = model_local_idx | set([idx]) 327 | 328 | train_loader_unlabeled = torch.utils.data.DataLoader( 329 | dataset=DatasetSplit(dataset_train, dict_users_unlabeled[idx]), 330 | shuffle=True, 331 | **dataloader_unlabeled_kwargs 332 | ) 333 | 334 | # define optimizer 335 | optimizer = get_optimizer( 336 | args.optimizer, model_local, 337 | lr=args.base_lr*args.batch_size/256, 338 | momentum=args.momentum, 339 | weight_decay=args.weight_decay) 340 | 341 | lr_scheduler = LR_Scheduler( 342 | optimizer, 343 | args.warmup_epochs, args.warmup_lr*args.batch_size/256, 344 | args.num_epochs, args.base_lr*args.batch_size/256, args.final_lr*args.batch_size/256, 345 | len(train_loader_unlabeled), 346 | constant_predictor_lr=True # see the end of section 4.2 predictor 347 | ) 348 | 349 | model_local.train() 350 | 351 | for j in range(args.local_ep): 352 | 353 | for i, ((images1, images2), labels) in enumerate(train_loader_unlabeled): 354 | 355 | model_local.zero_grad() 356 | 357 | batch_size = images1.shape[0] 358 | 359 | loss = model_local.forward(images1.to(device, non_blocking=True), images2.to(device, non_blocking=True)) 360 | 361 | loss.backward() 362 | optimizer.step() 363 | 364 | loss_local.append(int(loss)) 365 | 366 | lr = lr_scheduler.step() 367 | 368 | model_local.update_moving_average() 369 | 370 | w_locals.append(copy.deepcopy(model_local.backbone.state_dict())) 371 | loss_locals.append(sum(loss_local) / len(loss_local) ) 372 | model_local_dict[idx] = [model_local.projector.state_dict(), model_local.target_encoder.state_dict()] 373 | # torch.save(model_local.projector.state_dict(), '/model/'+'model1' + str(args.dataset) + str(idx)+ '.pkl') 374 | # torch.save(model_local.target_encoder.state_dict(), '/model/'+'model1' + str(args.dataset)+ 'tar' + str(idx)+ '.pkl') 375 | 376 | 377 | del model_local 378 | gc.collect() 379 | del train_loader_unlabeled 380 | gc.collect() 381 | torch.cuda.empty_cache() 382 | 383 | 384 | 385 | w_glob = FedAvg(w_locals) 386 | model_glob.backbone.load_state_dict(w_glob) 387 | 388 | loss_avg = sum(loss_locals) / len(loss_locals) 389 | 390 | if iter%1==0: 391 | print('Round {:3d}, Acc {:.2f}%'.format(iter, acc)) 392 | 393 | 394 | if __name__ == "__main__": 395 | args = get_args() 396 | main(device=args.device, args=args) 397 | 398 | 399 | 400 | 401 | 402 | 403 | 404 | 405 | 406 | 407 | 408 | 409 | 410 | 411 | 412 | 413 | -------------------------------------------------------------------------------- /fedcon-main.py: -------------------------------------------------------------------------------- 1 | import 
os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | import numpy as np 7 | import copy 8 | 9 | import gc 10 | 11 | from tqdm import tqdm 12 | from configs import get_args 13 | from augmentations import get_aug 14 | from models import get_model 15 | from tools import AverageMeter, PlotLogger, knn_monitor 16 | from datasets import get_dataset 17 | from optimizers import get_optimizer, LR_Scheduler 18 | from torch.utils.data import DataLoader, Dataset 19 | 20 | 21 | import torch 22 | from torch import nn, autograd 23 | from torch.utils.data import DataLoader, Dataset 24 | import numpy as np 25 | import random 26 | from sklearn import metrics 27 | import torch.nn.functional as F 28 | import copy 29 | from torch.autograd import Variable 30 | import itertools 31 | import logging 32 | import os.path 33 | from PIL import Image 34 | import numpy as np 35 | from torch.utils.data.sampler import Sampler 36 | import re 37 | import argparse 38 | import os 39 | import shutil 40 | import time 41 | import math 42 | import logging 43 | import os 44 | import sys 45 | import torch.backends.cudnn as cudnn 46 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 47 | import torchvision.datasets 48 | 49 | 50 | import torch 51 | from torch import nn 52 | import torch.nn.functional as F 53 | from torch.utils.data import DataLoader 54 | 55 | 56 | def test_img(net_g, data_loader, args): 57 | net_g.eval() 58 | test_loss = 0 59 | correct = 0 60 | 61 | for idx, (data, target) in enumerate(data_loader): 62 | data, target = data.cuda(), target.cuda() 63 | log_probs, _, _, _ = net_g(data, data) 64 | test_loss += F.cross_entropy(log_probs, target, reduction='sum',ignore_index=-1).item() 65 | y_pred = log_probs.data.max(1, keepdim=True)[1] 66 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 67 | 68 | test_loss /= len(data_loader.dataset) 69 | accuracy = 100.00 * correct / len(data_loader.dataset) 70 | 71 | return accuracy, test_loss 72 | 73 | 74 | def get_current_consistency_weight(epoch): 75 | return sigmoid_rampup(epoch, 100) 76 | 77 | def sigmoid_rampup(current, rampup_length): 78 | if rampup_length == 0: 79 | return 1.0 80 | else: 81 | current = np.clip(current, 0.0, rampup_length) 82 | phase = 1.0 - current / rampup_length 83 | return float(np.exp(-5.0 * phase * phase)) 84 | 85 | def softmax_mse_loss(input_logits, target_logits): 86 | assert input_logits.size() == target_logits.size() 87 | input_softmax = F.softmax(input_logits, dim=1) 88 | target_softmax = F.softmax(target_logits, dim=1) 89 | num_classes = input_logits.size()[1] 90 | return F.mse_loss(input_softmax, target_softmax, size_average=False) / num_classes 91 | 92 | def softmax_kl_loss(input_logits, target_logits): 93 | assert input_logits.size() == target_logits.size() 94 | input_log_softmax = F.log_softmax(input_logits, dim=1) 95 | target_softmax = F.softmax(target_logits, dim=1) 96 | return F.kl_div(input_log_softmax, target_softmax, size_average=False) 97 | 98 | def symmetric_mse_loss(input1, input2): 99 | assert input1.size() == input2.size() 100 | num_classes = input1.size()[1] 101 | return torch.sum((input1 - input2)**2) / num_classes 102 | 103 | 104 | def FedAvg(w): 105 | w_avg = copy.deepcopy(w[0]) 106 | # print(w_avg.keys()) 107 | for k in w_avg.keys(): 108 | for i in range(1, len(w)): 109 | w_avg[k] += w[i][k] 110 | w_avg[k] = torch.div(w_avg[k], len(w)) 111 | return w_avg 112 | 113 | 114 | def iid(dataset, num_users, label_rate): 115 | """ 116 | 
Sample I.I.D. client data from MNIST dataset 117 | :param dataset: 118 | :param num_users: 119 | :return: dict of image index 120 | """ 121 | num_items = int(len(dataset)/num_users) 122 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 123 | dict_users_labeled, dict_users_unlabeled = set(), {} 124 | 125 | dict_users_labeled = set(np.random.choice(list(all_idxs), int(len(all_idxs) * label_rate), replace=False)) 126 | 127 | for i in range(num_users): 128 | # dict_users_labeled = dict_users_labeled | set(np.random.choice(all_idxs, int(num_items * label_rate), replace=False)) 129 | # all_idxs = list(set(all_idxs) - dict_users_labeled) 130 | dict_users_unlabeled[i] = set(np.random.choice(all_idxs, int(num_items) , replace=False)) 131 | all_idxs = list(set(all_idxs) - dict_users_unlabeled[i]) 132 | dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled 133 | return dict_users_labeled, dict_users_unlabeled 134 | 135 | 136 | def noniid(dataset, num_users, label_rate): 137 | 138 | num_shards, num_imgs = 2 * num_users, int(len(dataset)/num_users/2) 139 | idx_shard = [i for i in range(num_shards)] 140 | dict_users_unlabeled = {i: np.array([], dtype='int64') for i in range(num_users)} 141 | idxs = np.arange(len(dataset)) 142 | labels = np.arange(len(dataset)) 143 | 144 | 145 | for i in range(len(dataset)): 146 | labels[i] = dataset[i][1] 147 | 148 | num_items = int(len(dataset)/num_users) 149 | dict_users_labeled = set() 150 | 151 | # sort labels 152 | idxs_labels = np.vstack((idxs, labels)) 153 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]#索引值 154 | idxs = idxs_labels[0,:] 155 | 156 | # divide and assign 157 | for i in range(num_users): 158 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 159 | idx_shard = list(set(idx_shard) - rand_set) 160 | for rand in rand_set: 161 | dict_users_unlabeled[i] = np.concatenate((dict_users_unlabeled[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 162 | 163 | dict_users_labeled = set(np.random.choice(list(idxs), int(len(idxs) * label_rate), replace=False)) 164 | 165 | for i in range(num_users): 166 | 167 | dict_users_unlabeled[i] = set(dict_users_unlabeled[i]) 168 | # dict_users_labeled = dict_users_labeled | set(np.random.choice(list(dict_users_unlabeled[i]), int(num_items * label_rate), replace=False)) 169 | dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled 170 | 171 | 172 | return dict_users_labeled, dict_users_unlabeled 173 | 174 | class DatasetSplit(Dataset): 175 | def __init__(self, dataset, idxs): 176 | self.dataset = dataset 177 | self.idxs = list(idxs) 178 | 179 | def __len__(self): 180 | return len(self.idxs) 181 | 182 | def __getitem__(self, item): 183 | (images1, images2), labels = self.dataset[self.idxs[item]] 184 | return (images1, images2), labels 185 | 186 | 187 | 188 | def main(device, args): 189 | 190 | 191 | loss1_func = nn.CrossEntropyLoss() 192 | loss2_func = softmax_kl_loss 193 | 194 | dataset_kwargs = { 195 | 'dataset':args.dataset, 196 | 'data_dir': args.data_dir, 197 | 'download':args.download, 198 | 'debug_subset_size':args.batch_size if args.debug else None 199 | } 200 | dataloader_kwargs = { 201 | 'batch_size': args.batch_size, 202 | 'drop_last': True, 203 | 'pin_memory': True, 204 | 'num_workers': args.num_workers, 205 | } 206 | dataloader_unlabeled_kwargs = { 207 | 'batch_size': args.batch_size*5, 208 | 'drop_last': True, 209 | 'pin_memory': True, 210 | 'num_workers': args.num_workers, 211 | } 212 | dataset_train =get_dataset( 213 | 
transform=get_aug(args.dataset, True), 214 | train=True, 215 | **dataset_kwargs 216 | ) 217 | 218 | if args.iid == 'iid': 219 | dict_users_labeled, dict_users_unlabeled = iid(dataset_train, args.num_users, args.label_rate) 220 | else: 221 | dict_users_labeled, dict_users_unlabeled = noniid(dataset_train, args.num_users, args.label_rate) 222 | train_loader_unlabeled = {} 223 | 224 | 225 | # define model 226 | model_glob = get_model('global', args.backbone).to(device) 227 | if torch.cuda.device_count() > 1: model_glob = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_glob) 228 | 229 | 230 | model_local_idx = set() 231 | model_local_dict = {} 232 | accuracy = [] 233 | lr_scheduler = {} 234 | 235 | 236 | for iter in range(args.num_epochs): 237 | 238 | model_glob.train() 239 | optimizer = torch.optim.SGD(model_glob.parameters(), lr=0.01, momentum=0.5) 240 | 241 | train_loader_labeled = torch.utils.data.DataLoader( 242 | dataset=DatasetSplit(dataset_train, dict_users_labeled), 243 | shuffle=True, 244 | **dataloader_kwargs 245 | ) 246 | 247 | for batch_idx, ((images1, images2), labels) in enumerate(train_loader_labeled): 248 | labels = labels.cuda() 249 | model_glob.zero_grad() 250 | z1_t, z2_t, z1_s, z2_s = model_glob.forward(images1.to(device, non_blocking=True), images2.to(device, non_blocking=True)) 251 | model_glob.update_moving_average( iter*1000 + batch_idx, 20000) 252 | 253 | loss_class = 1/2 * loss1_func(z1_t, labels) + 1/2 * loss1_func(z2_t, labels) 254 | loss_consist = 1/2 * loss2_func(z1_t, z2_s) / len(labels) + 1/2 * loss2_func(z2_t, z1_s) / len(labels) 255 | consistency_weight = get_current_consistency_weight(iter) 256 | loss = loss_class + consistency_weight * loss_consist 257 | 258 | loss.backward() 259 | optimizer.step() 260 | 261 | del train_loader_labeled 262 | gc.collect() 263 | torch.cuda.empty_cache() 264 | 265 | if iter%1==0: 266 | test_loader = torch.utils.data.DataLoader( 267 | dataset=get_dataset( 268 | transform=get_aug(args.dataset, False, train_classifier=False), 269 | train=False, 270 | **dataset_kwargs), 271 | shuffle=False, 272 | **dataloader_kwargs 273 | ) 274 | model_glob.eval() 275 | acc, loss_train_test_labeled = test_img(model_glob, test_loader, args) 276 | accuracy.append(str(acc)) 277 | del test_loader 278 | gc.collect() 279 | torch.cuda.empty_cache() 280 | 281 | 282 | w_locals, loss_locals, loss0_locals, loss2_locals = [], [], [], [] 283 | 284 | m = max(int(args.frac * args.num_users), 1) 285 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 286 | 287 | for idx in idxs_users: 288 | 289 | loss_local = [] 290 | loss0_local = [] 291 | loss2_local = [] 292 | 293 | 294 | if idx in model_local_idx: 295 | model_local = get_model('local', args.backbone).to(device) 296 | model_local.projector.load_state_dict(model_local_dict[idx][0]) 297 | model_local.target_encoder.load_state_dict(model_local_dict[idx][1]) 298 | # model_local.projector.load_state_dict(torch.load('/model/'+'model1' + str(args.dataset) + str(idx)+ '.pkl')) 299 | # model_local.target_encoder.load_state_dict(torch.load('/model/'+'model1' + str(args.dataset) + 'tar'+ str(idx)+ '.pkl')) 300 | 301 | model_local.backbone.load_state_dict(model_glob.backbone.state_dict()) 302 | else: 303 | model_local = get_model('local', args.backbone).to(device) 304 | model_local.backbone.load_state_dict(model_glob.backbone.state_dict()) 305 | model_local.target_encoder.load_state_dict(model_local.online_encoder.state_dict()) 306 | model_local_idx = model_local_idx | set([idx]) 307 | 
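# Per-client warm start: a returning client restores its own projector and EMA
# target encoder from the server-side cache (model_local_dict, keyed by client
# id) and only the shared backbone is overwritten with the aggregated global
# weights; a first-time client instead copies its online encoder into the
# target encoder, the standard BYOL initialization.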
308 | train_loader_unlabeled = torch.utils.data.DataLoader( 309 | dataset=DatasetSplit(dataset_train, dict_users_unlabeled[idx]), 310 | shuffle=True, 311 | **dataloader_unlabeled_kwargs 312 | ) 313 | 314 | # define optimizer 315 | optimizer = get_optimizer( 316 | args.optimizer, model_local, 317 | lr=args.base_lr*args.batch_size/256, 318 | momentum=args.momentum, 319 | weight_decay=args.weight_decay) 320 | 321 | lr_scheduler = LR_Scheduler( 322 | optimizer, 323 | args.warmup_epochs, args.warmup_lr*args.batch_size/256, 324 | args.num_epochs, args.base_lr*args.batch_size/256, args.final_lr*args.batch_size/256, 325 | len(train_loader_unlabeled), 326 | constant_predictor_lr=True # see the end of section 4.2 predictor 327 | ) 328 | 329 | model_local.train() 330 | 331 | for j in range(args.local_ep): 332 | 333 | for i, ((images1, images2), labels) in enumerate(train_loader_unlabeled): 334 | 335 | model_local.zero_grad() 336 | 337 | batch_size = images1.shape[0] 338 | 339 | loss = model_local.forward(images1.to(device, non_blocking=True), images2.to(device, non_blocking=True)) 340 | 341 | loss.backward() 342 | optimizer.step() 343 | 344 | loss_local.append(int(loss)) 345 | 346 | lr = lr_scheduler.step() 347 | 348 | model_local.update_moving_average() 349 | 350 | w_locals.append(copy.deepcopy(model_local.backbone.state_dict())) 351 | loss_locals.append(sum(loss_local) / len(loss_local) ) 352 | model_local_dict[idx] = [model_local.projector.state_dict(), model_local.target_encoder.state_dict()] 353 | # torch.save(model_local.projector.state_dict(), '/model/'+'model1' + str(args.dataset) + str(idx)+ '.pkl') 354 | # torch.save(model_local.target_encoder.state_dict(), '/model/'+'model1' + str(args.dataset)+ 'tar' + str(idx)+ '.pkl') 355 | 356 | 357 | del model_local 358 | gc.collect() 359 | del train_loader_unlabeled 360 | gc.collect() 361 | torch.cuda.empty_cache() 362 | 363 | 364 | 365 | w_glob = FedAvg(w_locals) 366 | model_glob.backbone.load_state_dict(w_glob) 367 | 368 | loss_avg = sum(loss_locals) / len(loss_locals) 369 | 370 | if iter%1==0: 371 | print('Round {:3d}, Acc {:.2f}%'.format(iter, acc)) 372 | 373 | if __name__ == "__main__": 374 | args = get_args() 375 | main(device=args.device, args=args) 376 | 377 | 378 | 379 | 380 | 381 | 382 | 383 | 384 | 385 | 386 | 387 | 388 | 389 | 390 | 391 | 392 | -------------------------------------------------------------------------------- /fedmatch-main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | import numpy as np 7 | import copy 8 | 9 | import gc 10 | 11 | from tqdm import tqdm 12 | from configs import get_args 13 | from augmentations import get_aug 14 | from models import get_model 15 | from tools import AverageMeter, PlotLogger, knn_monitor 16 | from datasets import get_dataset 17 | from optimizers import get_optimizer, LR_Scheduler 18 | from torch.utils.data import DataLoader, Dataset 19 | 20 | 21 | import torch 22 | from torch import nn, autograd 23 | from torch.utils.data import DataLoader, Dataset 24 | import numpy as np 25 | import random 26 | from sklearn import metrics 27 | import torch.nn.functional as F 28 | import copy 29 | from torch.autograd import Variable 30 | import itertools 31 | import logging 32 | import os.path 33 | from PIL import Image 34 | import numpy as np 35 | from torch.utils.data.sampler import Sampler 36 | import re 37 | import argparse 38 | import os 39 | import shutil
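# Minimal sketch (hypothetical helper, not used by this script) of the
# inter-client agreement rule implemented in the client loop further below: a
# pseudo-label survives only if the local model and both helper models vote for
# the same class, in which case its confidence is the strongest of the three
# votes; on disagreement the local confidence is penalized by 0.2 before the
# usual confidence thresholding.
def _agreement_pseudo_labels_sketch(logits_local, logits_h1, logits_h2, threshold):
    probs = [torch.softmax(l.detach(), dim=-1) for l in (logits_local, logits_h1, logits_h2)]
    confs, votes = zip(*[p.max(dim=-1) for p in probs])
    if torch.equal(votes[0], votes[1]) and torch.equal(votes[0], votes[2]):
        conf = torch.max(torch.max(confs[0], confs[1]), confs[2])
    else:
        conf = confs[0] - 0.2
    mask = conf.ge(threshold).float()  # 1.0 where the pseudo-label is trusted
    return votes[0], mask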
40 | import time 41 | import math 42 | import logging 43 | import os 44 | import sys 45 | import torch.backends.cudnn as cudnn 46 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 47 | import torchvision.datasets 48 | 49 | 50 | import torch 51 | from torch import nn 52 | import torch.nn.functional as F 53 | from torch.utils.data import DataLoader 54 | 55 | 56 | def test_img(net_g, data_loader, args): 57 | net_g.eval() 58 | test_loss = 0 59 | correct = 0 60 | 61 | for idx, (data, target) in enumerate(data_loader): 62 | data, target = data.cuda(), target.cuda() 63 | log_probs = net_g(data) 64 | test_loss += F.cross_entropy(log_probs, target, reduction='sum',ignore_index=-1).item() 65 | y_pred = log_probs.data.max(1, keepdim=True)[1] 66 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 67 | 68 | test_loss /= len(data_loader.dataset) 69 | accuracy = 100.00 * correct / len(data_loader.dataset) 70 | 71 | return accuracy, test_loss 72 | 73 | 74 | def get_current_consistency_weight(epoch): 75 | return sigmoid_rampup(epoch, 10) 76 | 77 | def sigmoid_rampup(current, rampup_length): 78 | if rampup_length == 0: 79 | return 1.0 80 | else: 81 | current = np.clip(current, 0.0, rampup_length) 82 | phase = 1.0 - current / rampup_length 83 | return float(np.exp(-5.0 * phase * phase)) 84 | 85 | def softmax_mse_loss(input_logits, target_logits): 86 | assert input_logits.size() == target_logits.size() 87 | input_softmax = F.softmax(input_logits, dim=1) 88 | target_softmax = F.softmax(target_logits, dim=1) 89 | num_classes = input_logits.size()[1] 90 | return F.mse_loss(input_softmax, target_softmax, size_average=False) / num_classes 91 | 92 | def softmax_kl_loss(input_logits, target_logits): 93 | assert input_logits.size() == target_logits.size() 94 | input_log_softmax = F.log_softmax(input_logits, dim=1) 95 | target_softmax = F.softmax(target_logits, dim=1) 96 | return F.kl_div(input_log_softmax, target_softmax, size_average=False) 97 | 98 | def symmetric_mse_loss(input1, input2): 99 | assert input1.size() == input2.size() 100 | num_classes = input1.size()[1] 101 | return torch.sum((input1 - input2)**2) / num_classes 102 | 103 | 104 | def FedAvg(w): 105 | w_avg = copy.deepcopy(w[0]) 106 | # print(w_avg.keys()) 107 | for k in w_avg.keys(): 108 | for i in range(1, len(w)): 109 | w_avg[k] += w[i][k] 110 | w_avg[k] = torch.div(w_avg[k], len(w)) 111 | return w_avg 112 | 113 | 114 | def iid(dataset, num_users, label_rate): 115 | """ 116 | Sample I.I.D. 
client data from MNIST dataset 117 | :param dataset: 118 | :param num_users: 119 | :return: dict of image index 120 | """ 121 | num_items = int(len(dataset)/num_users) 122 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 123 | dict_users_labeled, dict_users_unlabeled = set(), {} 124 | 125 | dict_users_labeled = set(np.random.choice(list(all_idxs), int(len(all_idxs) * label_rate), replace=False)) 126 | 127 | for i in range(num_users): 128 | # dict_users_labeled = dict_users_labeled | set(np.random.choice(all_idxs, int(num_items * label_rate), replace=False)) 129 | # all_idxs = list(set(all_idxs) - dict_users_labeled) 130 | dict_users_unlabeled[i] = set(np.random.choice(all_idxs, int(num_items) , replace=False)) 131 | all_idxs = list(set(all_idxs) - dict_users_unlabeled[i]) 132 | dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled 133 | return dict_users_labeled, dict_users_unlabeled 134 | 135 | 136 | def noniid(dataset, num_users, label_rate): 137 | 138 | num_shards, num_imgs = 2 * num_users, int(len(dataset)/num_users/2) 139 | idx_shard = [i for i in range(num_shards)] 140 | dict_users_unlabeled = {i: np.array([], dtype='int64') for i in range(num_users)} 141 | idxs = np.arange(len(dataset)) 142 | labels = np.arange(len(dataset)) 143 | 144 | 145 | for i in range(len(dataset)): 146 | labels[i] = dataset[i][1] 147 | 148 | num_items = int(len(dataset)/num_users) 149 | dict_users_labeled = set() 150 | 151 | # sort labels 152 | idxs_labels = np.vstack((idxs, labels)) 153 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]#索引值 154 | idxs = idxs_labels[0,:] 155 | 156 | # divide and assign 157 | for i in range(num_users): 158 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 159 | idx_shard = list(set(idx_shard) - rand_set) 160 | for rand in rand_set: 161 | dict_users_unlabeled[i] = np.concatenate((dict_users_unlabeled[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 162 | 163 | dict_users_labeled = set(np.random.choice(list(idxs), int(len(idxs) * label_rate), replace=False)) 164 | 165 | for i in range(num_users): 166 | 167 | dict_users_unlabeled[i] = set(dict_users_unlabeled[i]) 168 | # dict_users_labeled = dict_users_labeled | set(np.random.choice(list(dict_users_unlabeled[i]), int(num_items * label_rate), replace=False)) 169 | dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled 170 | 171 | 172 | return dict_users_labeled, dict_users_unlabeled 173 | 174 | class DatasetSplit(Dataset): 175 | def __init__(self, dataset, idxs): 176 | self.dataset = dataset 177 | self.idxs = list(idxs) 178 | 179 | def __len__(self): 180 | return len(self.idxs) 181 | 182 | def __getitem__(self, item): 183 | (images1, images2), labels = self.dataset[self.idxs[item]] 184 | return (images1, images2), labels 185 | 186 | def dist_l2(w, w_ema): 187 | a = 0.0 188 | for i in list(w.keys()) : 189 | a += (w[i]-w_ema[i]).float().norm(2) 190 | return a 191 | 192 | def dist_l1(w): 193 | a = 0.0 194 | for i in list(w.keys()) : 195 | a += w[i].float().norm(1) 196 | return a 197 | 198 | def main(device, args): 199 | 200 | 201 | loss1_func = nn.CrossEntropyLoss() 202 | loss2_func = softmax_kl_loss 203 | 204 | dataset_kwargs = { 205 | 'dataset':args.dataset, 206 | 'data_dir': args.data_dir, 207 | 'download':args.download, 208 | 'debug_subset_size':args.batch_size if args.debug else None 209 | } 210 | dataloader_kwargs = { 211 | 'batch_size': args.batch_size, 212 | 'drop_last': True, 213 | 'pin_memory': True, 214 | 'num_workers': args.num_workers, 215 | } 216 | 
} 216 | 
dataloader_unlabeled_kwargs = { 217 | 'batch_size': args.batch_size*5, 218 | 'drop_last': True, 219 | 'pin_memory': True, 220 | 'num_workers': args.num_workers, 221 | } 222 | dataset_train =get_dataset( 223 | transform=get_aug(args.dataset, True), 224 | train=True, 225 | **dataset_kwargs 226 | ) 227 | 228 | if args.iid == 'iid': 229 | dict_users_labeled, dict_users_unlabeled = iid(dataset_train, args.num_users, args.label_rate) 230 | else: 231 | dict_users_labeled, dict_users_unlabeled = noniid(dataset_train, args.num_users, args.label_rate) 232 | train_loader_unlabeled = {} 233 | 234 | 235 | # define model 236 | model_glob = get_model('fedfixmatch', args.backbone).to(device) 237 | if torch.cuda.device_count() > 1: model_glob = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_glob) 238 | 239 | model_h1 = get_model('fedfixmatch', args.backbone).to(device) 240 | if torch.cuda.device_count() > 1: model_h1 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_h1) 241 | model_h2 = get_model('fedfixmatch', args.backbone).to(device) 242 | if torch.cuda.device_count() > 1: model_h2 = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_h2) 243 | 244 | for iter in range(args.num_epochs): 245 | 246 | model_glob.train() 247 | optimizer = torch.optim.SGD(model_glob.parameters(), lr=0.01, momentum=0.5) 248 | class_criterion = nn.CrossEntropyLoss(size_average=False, ignore_index= -1 ) 249 | 250 | train_loader_labeled = torch.utils.data.DataLoader( 251 | dataset=DatasetSplit(dataset_train, dict_users_labeled), 252 | shuffle=True, 253 | **dataloader_kwargs 254 | ) 255 | 256 | for batch_idx, ((img, img_ema), label) in enumerate(train_loader_labeled): 257 | input_var = torch.autograd.Variable(img.cuda()) 258 | ema_input_var = torch.autograd.Variable(img_ema.cuda()) 259 | target_var = torch.autograd.Variable(label.cuda()) 260 | minibatch_size = len(target_var) 261 | labeled_minibatch_size = target_var.data.ne(-1).sum() 262 | ema_model_out = model_glob(ema_input_var) 263 | model_out = model_glob(input_var) 264 | if isinstance(model_out, Variable): 265 | logit1 = model_out 266 | ema_logit = ema_model_out 267 | else: 268 | assert len(model_out) == 2 269 | assert len(ema_model_out) == 2 270 | logit1, logit2 = model_out 271 | ema_logit, _ = ema_model_out 272 | 273 | ema_logit = Variable(ema_logit.detach().data, requires_grad=False) 274 | class_logit, cons_logit = logit1, logit1 275 | class_loss = class_criterion(class_logit, target_var) / minibatch_size 276 | loss = class_loss 277 | optimizer.zero_grad() 278 | loss.backward() 279 | optimizer.step() 280 | # batch_loss.append(loss.item()) 281 | 282 | 283 | del train_loader_labeled 284 | gc.collect() 285 | torch.cuda.empty_cache() 286 | 287 | if iter%1==0: 288 | test_loader = torch.utils.data.DataLoader( 289 | dataset=get_dataset( 290 | transform=get_aug(args.dataset, False, train_classifier=False), 291 | train=False, 292 | **dataset_kwargs), 293 | shuffle=False, 294 | **dataloader_kwargs 295 | ) 296 | model_glob.eval() 297 | accuracy, loss_train_test_labeled = test_img(model_glob, test_loader, args) 298 | del test_loader 299 | gc.collect() 300 | torch.cuda.empty_cache() 301 | 302 | 303 | w_locals, loss_locals, loss0_locals, loss2_locals = [], [], [], [] 304 | 305 | m = max(int(args.frac * args.num_users), 1) 306 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 307 | 308 | for idx in idxs_users: 309 | 310 | loss_local = [] 311 | loss0_local = [] 312 | loss2_local = [] 313 | 314 | 315 | train_loader_unlabeled = torch.utils.data.DataLoader( 
316 | dataset=DatasetSplit(dataset_train, dict_users_unlabeled[idx]), 317 | shuffle=True, 318 | **dataloader_unlabeled_kwargs 319 | ) 320 | 321 | model_local_t = copy.deepcopy(model_glob).to(args.device) 322 | model_local_t.train() 323 | # the optimizer defined in the server phase holds model_glob's parameters, so the local copy needs its own 324 | optimizer = torch.optim.SGD(model_local_t.parameters(), lr=0.01, momentum=0.5) 325 | for i, ((img, img_ema), label) in enumerate(train_loader_unlabeled): 326 | 327 | input_var = torch.autograd.Variable(img.cuda()) 328 | ema_input_var = torch.autograd.Variable(img_ema.cuda()) 329 | target_var = torch.autograd.Variable(label.cuda()) 330 | minibatch_size = len(target_var) 331 | labeled_minibatch_size = target_var.data.ne(-1).sum() 332 | ema_model_out = model_local_t(ema_input_var) 333 | model_out = model_local_t(input_var) 334 | 335 | model_out_helper_1 = model_h1(input_var) 336 | model_out_helper_2 = model_h2(input_var) 337 | 338 | if isinstance(model_out, Variable): 339 | logit1 = model_out 340 | ema_logit = ema_model_out 341 | else: 342 | assert len(model_out) == 2 343 | assert len(ema_model_out) == 2 344 | logit1, logit2 = model_out 345 | ema_logit, _ = ema_model_out 346 | 347 | # ema_logit stays attached to the graph: the masked cross-entropy below is the gradient path that trains the local model 348 | class_logit, cons_logit = logit1, logit1 349 | 350 | pseudo_label1 = torch.softmax(model_out.detach(), dim=-1) 351 | pseudo_label2 = torch.softmax(model_out_helper_1.detach(), dim=-1) 352 | pseudo_label3 = torch.softmax(model_out_helper_2.detach(), dim=-1) 353 | 354 | max_probs1, targets_u1 = torch.max(pseudo_label1, dim=-1) 355 | max_probs2, targets_u2 = torch.max(pseudo_label2, dim=-1) 356 | max_probs3, targets_u3 = torch.max(pseudo_label3, dim=-1) 357 | 358 | if torch.equal(targets_u1, targets_u2) and torch.equal(targets_u1, targets_u3): 359 | max_probs = torch.max(max_probs1, max_probs2) 360 | max_probs = torch.max(max_probs, max_probs3) 361 | else: 362 | max_probs = max_probs1 - 0.2 363 | targets_u = targets_u1 364 | mask = max_probs.ge(args.threshold_pl).float() 365 | Lu = (F.cross_entropy(ema_logit, targets_u, reduction='none') * mask).mean() 366 | 367 | lambda_iccs = 0.01 368 | lambda_l2 = 10 369 | lambda_l1 = 0.0001 370 | 371 | loss = lambda_iccs*(Lu) + lambda_l2*dist_l2(model_local_t.state_dict(), model_glob.state_dict()) + lambda_l1*dist_l1(model_local_t.state_dict()) 372 | optimizer.zero_grad() 373 | loss.backward() 374 | optimizer.step() 375 | # batch_loss.append(loss.item()) 376 | 377 | w_locals.append(copy.deepcopy(model_local_t.state_dict())) 378 | # loss_locals.append(sum(loss_local) / len(loss_local) ) 379 | 380 | del model_local_t 381 | gc.collect() 382 | del train_loader_unlabeled 383 | gc.collect() 384 | torch.cuda.empty_cache() 385 | 386 | 387 | model_h1.load_state_dict(w_locals[0]) 388 | model_h2.load_state_dict(w_locals[1]) 389 | 390 | w_glob = FedAvg(w_locals) 391 | model_glob.load_state_dict(w_glob) 392 | 393 | # loss_avg = sum(loss_locals) / len(loss_locals) 394 | 395 | if iter%1==0: 396 | print('Round {:3d}, Acc {:.2f}%'.format(iter, accuracy)) 397 | 398 | 399 | 400 | if __name__ == "__main__": 401 | args = get_args() 402 | main(device=args.device, args=args) 403 | 404 | 405 | 406 | 407 | 408 | 409 | 410 | 411 | 412 | 413 | 414 | 415 | 416 | 417 | 418 | 419 | -------------------------------------------------------------------------------- /fedprox-fixmatch-main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | import numpy as np 7 | import copy 8 | 9 | import gc 10 | 11 | from tqdm import tqdm 12 | from
configs import get_args 13 | from augmentations import get_aug, get_aug_fedmatch 14 | from models import get_model 15 | from tools import AverageMeter, PlotLogger, knn_monitor 16 | from datasets import get_dataset 17 | from optimizers import get_optimizer, LR_Scheduler 18 | from torch.utils.data import DataLoader, Dataset 19 | 20 | 21 | import torch 22 | from torch import nn, autograd 23 | from torch.utils.data import DataLoader, Dataset 24 | import numpy as np 25 | import random 26 | from sklearn import metrics 27 | import torch.nn.functional as F 28 | import copy 29 | from torch.autograd import Variable 30 | import itertools 31 | import logging 32 | import os.path 33 | from PIL import Image 34 | import numpy as np 35 | from torch.utils.data.sampler import Sampler 36 | import re 37 | import argparse 38 | import os 39 | import shutil 40 | import time 41 | import math 42 | import logging 43 | import os 44 | import sys 45 | import torch.backends.cudnn as cudnn 46 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 47 | import torchvision.datasets 48 | 49 | 50 | import torch 51 | from torch import nn 52 | import torch.nn.functional as F 53 | from torch.utils.data import DataLoader 54 | 55 | 56 | def test_img(net_g, data_loader, args): 57 | net_g.eval() 58 | test_loss = 0 59 | correct = 0 60 | 61 | for idx, (data, target) in enumerate(data_loader): 62 | data, target = data.cuda(), target.cuda() 63 | log_probs = net_g(data) 64 | test_loss += F.cross_entropy(log_probs, target, reduction='sum',ignore_index=-1).item() 65 | y_pred = log_probs.data.max(1, keepdim=True)[1] 66 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 67 | 68 | test_loss /= len(data_loader.dataset) 69 | accuracy = 100.00 * correct / len(data_loader.dataset) 70 | 71 | return accuracy, test_loss 72 | 73 | 74 | def get_current_consistency_weight(epoch): 75 | return sigmoid_rampup(epoch, 10) 76 | 77 | def sigmoid_rampup(current, rampup_length): 78 | if rampup_length == 0: 79 | return 1.0 80 | else: 81 | current = np.clip(current, 0.0, rampup_length) 82 | phase = 1.0 - current / rampup_length 83 | return float(np.exp(-5.0 * phase * phase)) 84 | 85 | def softmax_mse_loss(input_logits, target_logits): 86 | assert input_logits.size() == target_logits.size() 87 | input_softmax = F.softmax(input_logits, dim=1) 88 | target_softmax = F.softmax(target_logits, dim=1) 89 | num_classes = input_logits.size()[1] 90 | return F.mse_loss(input_softmax, target_softmax, size_average=False) / num_classes 91 | 92 | def softmax_kl_loss(input_logits, target_logits): 93 | assert input_logits.size() == target_logits.size() 94 | input_log_softmax = F.log_softmax(input_logits, dim=1) 95 | target_softmax = F.softmax(target_logits, dim=1) 96 | return F.kl_div(input_log_softmax, target_softmax, size_average=False) 97 | 98 | def symmetric_mse_loss(input1, input2): 99 | assert input1.size() == input2.size() 100 | num_classes = input1.size()[1] 101 | return torch.sum((input1 - input2)**2) / num_classes 102 | 103 | 104 | def FedAvg(w): 105 | w_avg = copy.deepcopy(w[0]) 106 | # print(w_avg.keys()) 107 | for k in w_avg.keys(): 108 | for i in range(1, len(w)): 109 | w_avg[k] += w[i][k] 110 | w_avg[k] = torch.div(w_avg[k], len(w)) 111 | return w_avg 112 | 113 | 114 | def iid(dataset, num_users, label_rate): 115 | """ 116 | Sample I.I.D. 
client data from MNIST dataset 117 | :param dataset: 118 | :param num_users: 119 | :return: dict of image index 120 | """ 121 | num_items = int(len(dataset)/num_users) 122 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 123 | dict_users_labeled, dict_users_unlabeled = set(), {} 124 | 125 | dict_users_labeled = set(np.random.choice(list(all_idxs), int(len(all_idxs) * label_rate), replace=False)) 126 | 127 | for i in range(num_users): 128 | # dict_users_labeled = dict_users_labeled | set(np.random.choice(all_idxs, int(num_items * label_rate), replace=False)) 129 | # all_idxs = list(set(all_idxs) - dict_users_labeled) 130 | dict_users_unlabeled[i] = set(np.random.choice(all_idxs, int(num_items) , replace=False)) 131 | all_idxs = list(set(all_idxs) - dict_users_unlabeled[i]) 132 | dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled 133 | return dict_users_labeled, dict_users_unlabeled 134 | 135 | # def noniid(dataset, num_users, label_rate): 136 | 137 | # num_shards, num_imgs = 2 * num_users, int(len(dataset)/num_users/2) 138 | # idx_shard = [i for i in range(num_shards)] 139 | # dict_users_unlabeled = {i: np.array([], dtype='int64') for i in range(num_users)} 140 | # idxs = np.arange(num_shards*num_imgs) 141 | # labels = dataset.train_labels.numpy() 142 | # # print(type(labels)) 143 | 144 | # num_items = int(len(dataset)/num_users) 145 | # dict_users_labeled = set() 146 | # pseduo_label = [i for i in range(len(dataset))] 147 | 148 | # # sort labels 149 | # idxs_labels = np.vstack((idxs, labels)) 150 | # idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]#索引值 151 | # idxs = idxs_labels[0,:] 152 | 153 | # # divide and assign 154 | # for i in range(num_users): 155 | # rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 156 | # idx_shard = list(set(idx_shard) - rand_set) 157 | # for rand in rand_set: 158 | # dict_users_unlabeled[i] = np.concatenate((dict_users_unlabeled[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 159 | 160 | # for i in range(num_users): 161 | 162 | # dict_users_unlabeled[i] = set(dict_users_unlabeled[i]) 163 | # dict_users_labeled = dict_users_labeled | set(np.random.choice(list(dict_users_unlabeled[i]), int(num_items * label_rate), replace=False)) 164 | # dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled 165 | 166 | 167 | # return dict_users_labeled, dict_users_unlabeled 168 | 169 | def noniid(dataset, num_users, label_rate): 170 | 171 | num_shards, num_imgs = 2 * num_users, int(len(dataset)/num_users/2) 172 | idx_shard = [i for i in range(num_shards)] 173 | dict_users_unlabeled = {i: np.array([], dtype='int64') for i in range(num_users)} 174 | idxs = np.arange(len(dataset)) 175 | labels = np.arange(len(dataset)) 176 | 177 | 178 | for i in range(len(dataset)): 179 | labels[i] = dataset[i][1] 180 | 181 | num_items = int(len(dataset)/num_users) 182 | dict_users_labeled = set() 183 | 184 | # sort labels 185 | idxs_labels = np.vstack((idxs, labels)) 186 | idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]#索引值 187 | idxs = idxs_labels[0,:] 188 | 189 | # divide and assign 190 | for i in range(num_users): 191 | rand_set = set(np.random.choice(idx_shard, 2, replace=False)) 192 | idx_shard = list(set(idx_shard) - rand_set) 193 | for rand in rand_set: 194 | dict_users_unlabeled[i] = np.concatenate((dict_users_unlabeled[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0) 195 | 196 | dict_users_labeled = set(np.random.choice(list(idxs), int(len(idxs) * label_rate), replace=False)) 197 | 198 | for i in 
range(num_users): 199 | 200 | dict_users_unlabeled[i] = set(dict_users_unlabeled[i]) 201 | # dict_users_labeled = dict_users_labeled | set(np.random.choice(list(dict_users_unlabeled[i]), int(num_items * label_rate), replace=False)) 202 | dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled 203 | 204 | 205 | return dict_users_labeled, dict_users_unlabeled 206 | 207 | 208 | class DatasetSplit(Dataset): 209 | def __init__(self, dataset, idxs): 210 | self.dataset = dataset 211 | self.idxs = list(idxs) 212 | 213 | def __len__(self): 214 | return len(self.idxs) 215 | 216 | def __getitem__(self, item): 217 | (images1, images2), labels = self.dataset[self.idxs[item]] 218 | return (images1, images2), labels 219 | 220 | def dist(w0, w1): 221 | dist_w = 0.0 222 | for i in list(w0.keys()) : 223 | dist_w += (w0[i]-w1[i]).float().norm(2)**2 224 | dist_w = math.sqrt(dist_w) 225 | return dist_w 226 | 227 | def main(device, args): 228 | 229 | 230 | loss1_func = nn.CrossEntropyLoss() 231 | loss2_func = softmax_kl_loss 232 | 233 | dataset_kwargs = { 234 | 'dataset':args.dataset, 235 | 'data_dir': args.data_dir, 236 | 'download':args.download, 237 | 'debug_subset_size':args.batch_size if args.debug else None 238 | } 239 | dataloader_kwargs = { 240 | 'batch_size': args.batch_size, 241 | 'drop_last': True, 242 | 'pin_memory': True, 243 | 'num_workers': args.num_workers, 244 | } 245 | dataloader_unlabeled_kwargs = { 246 | 'batch_size': args.batch_size*5, 247 | 'drop_last': True, 248 | 'pin_memory': True, 249 | 'num_workers': args.num_workers, 250 | } 251 | dataset_train =get_dataset( 252 | transform=get_aug_fedmatch(args.dataset, True), 253 | train=True, 254 | **dataset_kwargs 255 | ) 256 | 257 | if args.iid == 'iid': 258 | dict_users_labeled, dict_users_unlabeled = iid(dataset_train, args.num_users, args.label_rate) 259 | else: 260 | dict_users_labeled, dict_users_unlabeled = noniid(dataset_train, args.num_users, args.label_rate) 261 | 262 | train_loader_unlabeled = {} 263 | 264 | 265 | # define model 266 | model_glob = get_model('fedfixmatch', args.backbone).to(device) 267 | if torch.cuda.device_count() > 1: model_glob = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_glob) 268 | 269 | accuracy = [] 270 | 271 | for iter in range(args.num_epochs): 272 | 273 | model_glob.train() 274 | optimizer = torch.optim.SGD(model_glob.parameters(), lr=0.01, momentum=0.5) 275 | class_criterion = nn.CrossEntropyLoss(size_average=False, ignore_index= -1 ) 276 | 277 | train_loader_labeled = torch.utils.data.DataLoader( 278 | dataset=DatasetSplit(dataset_train, dict_users_labeled), 279 | shuffle=True, 280 | **dataloader_kwargs 281 | ) 282 | 283 | for batch_idx, ((img, img_ema), label) in enumerate(train_loader_labeled): 284 | input_var = torch.autograd.Variable(img.cuda()) 285 | ema_input_var = torch.autograd.Variable(img_ema.cuda()) 286 | target_var = torch.autograd.Variable(label.cuda()) 287 | minibatch_size = len(target_var) 288 | labeled_minibatch_size = target_var.data.ne(-1).sum() 289 | ema_model_out = model_glob(ema_input_var) 290 | model_out = model_glob(input_var) 291 | if isinstance(model_out, Variable): 292 | logit1 = model_out 293 | ema_logit = ema_model_out 294 | else: 295 | assert len(model_out) == 2 296 | assert len(ema_model_out) == 2 297 | logit1, logit2 = model_out 298 | ema_logit, _ = ema_model_out 299 | 300 | ema_logit = Variable(ema_logit.detach().data, requires_grad=False) 301 | class_logit, cons_logit = logit1, logit1 302 | class_loss = class_criterion(class_logit, target_var) / 
minibatch_size
303 |                 ema_class_loss = class_criterion(ema_logit, target_var) / minibatch_size
304 |                 pseudo_label1 = torch.softmax(ema_logit, dim=-1)  # pseudo-labels from the detached second view
305 |                 max_probs, targets_u = torch.max(pseudo_label1, dim=-1)
306 |                 mask = max_probs.ge(args.threshold_pl).float()
307 |                 Lu = (F.cross_entropy(class_logit, targets_u, reduction='none') * mask).mean()  # confidence-masked loss on the live branch
308 |                 loss = class_loss + Lu
309 |                 optimizer.zero_grad()
310 |                 loss.backward()
311 |                 optimizer.step()
312 |                 # batch_loss.append(loss.item())
313 | 
314 | 
315 |         del train_loader_labeled
316 |         gc.collect()
317 |         torch.cuda.empty_cache()
318 | 
319 |         if iter%1==0:  # evaluate the global model every round
320 |             test_loader = torch.utils.data.DataLoader(
321 |                 dataset=get_dataset(
322 |                     transform=get_aug(args.dataset, False, train_classifier=False),
323 |                     train=False,
324 |                     **dataset_kwargs),
325 |                 shuffle=False,
326 |                 **dataloader_kwargs
327 |             )
328 |             model_glob.eval()
329 |             acc, loss_train_test_labeled = test_img(model_glob, test_loader, args)
330 |             accuracy.append(str(acc))
331 |             del test_loader
332 |             gc.collect()
333 |             torch.cuda.empty_cache()
334 | 
335 | 
336 |         w_locals, loss_locals, loss0_locals, loss2_locals = [], [], [], []
337 | 
338 |         m = max(int(args.frac * args.num_users), 1)
339 |         idxs_users = np.random.choice(range(args.num_users), m, replace=False)
340 | 
341 |         for idx in idxs_users:
342 | 
343 |             loss_local = []
344 |             loss0_local = []
345 |             loss2_local = []
346 | 
347 |             model_local = copy.deepcopy(model_glob).to(args.device)
348 | 
349 |             train_loader_unlabeled = torch.utils.data.DataLoader(
350 |                 dataset=DatasetSplit(dataset_train, dict_users_unlabeled[idx]),
351 |                 shuffle=True,
352 |                 **dataloader_unlabeled_kwargs
353 |             )
354 | 
355 |             optimizer = torch.optim.SGD(model_local.parameters(), lr=0.01, momentum=0.5)  # per-client optimizer over the local copy
356 |             model_local.train()
357 | 
358 | 
359 |             for i, ((img, img_ema), label) in enumerate(train_loader_unlabeled):
360 | 
361 |                 input_var = torch.autograd.Variable(img.cuda())
362 |                 ema_input_var = torch.autograd.Variable(img_ema.cuda())
363 |                 target_var = torch.autograd.Variable(label.cuda())
364 |                 minibatch_size = len(target_var)
365 |                 labeled_minibatch_size = target_var.data.ne(-1).sum()
366 |                 ema_model_out = model_local(ema_input_var)
367 |                 model_out = model_local(input_var)
368 |                 if isinstance(model_out, Variable):
369 |                     logit1 = model_out
370 |                     ema_logit = ema_model_out
371 |                 else:
372 |                     assert len(model_out) == 2
373 |                     assert len(ema_model_out) == 2
374 |                     logit1, logit2 = model_out
375 |                     ema_logit, _ = ema_model_out
376 | 
377 |                 ema_logit = Variable(ema_logit.detach().data, requires_grad=False)  # stop-gradient target branch
378 |                 class_logit, cons_logit = logit1, logit1
379 |                 # class_loss = class_criterion(class_logit, target_var) / minibatch_size
380 |                 # ema_class_loss = class_criterion(ema_logit, target_var) / minibatch_size
381 |                 pseudo_label1 = torch.softmax(ema_logit, dim=-1)  # pseudo-labels from the detached view
382 |                 max_probs, targets_u = torch.max(pseudo_label1, dim=-1)
383 |                 mask = max_probs.ge(args.threshold_pl).float()
384 |                 Lu = (F.cross_entropy(logit1, targets_u, reduction='none') * mask).mean()
385 | 
386 |                 ### fedprox
387 |                 Lprox = 1 / 2 * dist(model_local.state_dict(), model_glob.state_dict())
388 | 
389 | 
390 |                 loss = Lu + Lprox
391 |                 optimizer.zero_grad()
392 |                 loss.backward()
393 |                 optimizer.step()
394 |                 # batch_loss.append(loss.item())
395 | 
396 |             w_locals.append(copy.deepcopy(model_local.state_dict()))
397 |             # loss_locals.append(sum(loss_local) / len(loss_local) )
398 | 
399 |             del model_local
400 |             gc.collect()
401 |             del train_loader_unlabeled
402 |             gc.collect()
403 |             torch.cuda.empty_cache()
404 | 
405 | 
406 | 
407 |         w_glob = FedAvg(w_locals)
408 |         model_glob.load_state_dict(w_glob)
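        # Server aggregation: FedAvg() takes an unweighted mean over the sampled
        # clients' state_dicts. Since iid()/noniid() hand every client the same
        # number of samples, this coincides with the canonical weighted update
        # w = sum_k (n_k / n) * w_k.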
409 | 410 | # loss_avg = sum(loss_locals) / len(loss_locals) 411 | 412 | if iter%1==0: 413 | print('Round {:3d}, Acc {:.2f}%'.format(iter, acc)) 414 | 415 | 416 | 417 | if __name__ == "__main__": 418 | args = get_args() 419 | main(device=args.device, args=args) 420 | 421 | 422 | 423 | 424 | 425 | 426 | 427 | 428 | 429 | 430 | 431 | 432 | 433 | 434 | 435 | 436 | -------------------------------------------------------------------------------- /fedprox-uda-main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | import numpy as np 7 | import copy 8 | 9 | import gc 10 | 11 | from tqdm import tqdm 12 | from configs import get_args 13 | from augmentations import get_aug, get_aug_fedmatch 14 | from models import get_model 15 | from tools import AverageMeter, PlotLogger, knn_monitor 16 | from datasets import get_dataset 17 | from optimizers import get_optimizer, LR_Scheduler 18 | from torch.utils.data import DataLoader, Dataset 19 | 20 | 21 | import torch 22 | from torch import nn, autograd 23 | from torch.utils.data import DataLoader, Dataset 24 | import numpy as np 25 | import random 26 | from sklearn import metrics 27 | import torch.nn.functional as F 28 | import copy 29 | from torch.autograd import Variable 30 | import itertools 31 | import logging 32 | import os.path 33 | from PIL import Image 34 | import numpy as np 35 | from torch.utils.data.sampler import Sampler 36 | import re 37 | import argparse 38 | import os 39 | import shutil 40 | import time 41 | import math 42 | import logging 43 | import os 44 | import sys 45 | import torch.backends.cudnn as cudnn 46 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 47 | import torchvision.datasets 48 | 49 | 50 | import torch 51 | from torch import nn 52 | import torch.nn.functional as F 53 | from torch.utils.data import DataLoader 54 | 55 | import torch 56 | from torch import nn, autograd 57 | from torch.utils.data import DataLoader, Dataset 58 | import numpy as np 59 | import random 60 | from sklearn import metrics 61 | import torch.nn.functional as F 62 | import copy 63 | from torch.autograd import Variable 64 | import itertools 65 | import logging 66 | import os.path 67 | from PIL import Image 68 | import numpy as np 69 | from torch.utils.data.sampler import Sampler 70 | import re 71 | import argparse 72 | import os 73 | import shutil 74 | import time 75 | import math 76 | import logging 77 | import os 78 | import sys 79 | import torch.backends.cudnn as cudnn 80 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler 81 | import torchvision.datasets 82 | 83 | def quantile_linear(iter, args): 84 | 85 | turn_point = int( (args.comu_rate * args.epochs - 0.1 * args.epochs -1.35) / 0.45 ) 86 | if iter < args.phi_g: 87 | return 1.0 88 | elif iter > turn_point: 89 | return 0.1 90 | else: 91 | return 0.9 * iter / ( 2 - turn_point ) + 1 - 1.8/( 2 - turn_point ) 92 | 93 | 94 | def quantile_rectangle(iter, args): 95 | if iter < args.phi_g: 96 | return 0.0 97 | elif iter >= args.psi_g: 98 | return 0.0 99 | else: 100 | if args.comu_rate*5/3 > 1: 101 | return 0.99 102 | else: 103 | return args.comu_rate*args.epochs/(args.psi_g - args.phi_g) 104 | 105 | def get_median(data, iter, args): 106 | if args.dataset == 'mnist': 107 | a = 8 108 | else: 109 | a = 33 110 | 111 | if len(data) < (39*a): 112 | data_test = data[(-10*a):] 113 | elif len(data) < (139*a): 114 | 
data_test = data[(30*a) : ] 115 | else: 116 | data_test = data[(-100*a):] 117 | 118 | data_test.sort() 119 | 120 | if args.ramp == 'linear': 121 | quantile = quantile_linear(iter, args) 122 | iter_place = int( (1 - quantile) * len(data_test)) 123 | elif args.ramp == 'flat': 124 | quantile = quantile_flat(iter, args) 125 | iter_place = int( (1 - quantile) * len(data_test)) 126 | elif args.ramp == 'rectangle': 127 | quantile = quantile_rectangle(iter, args) 128 | iter_place = int( (1 - quantile) * len(data_test)-1) 129 | else: 130 | exit('Error: wrong ramp type!') 131 | return data_test[iter_place] 132 | 133 | def sigmoid_rampup(current, rampup_length): 134 | if rampup_length == 0: 135 | return 1.0 136 | else: 137 | current = np.clip(current, 0.0, rampup_length) 138 | phase = 1.0 - current / rampup_length 139 | return float(np.exp(-5.0 * phase * phase)) 140 | 141 | def sigmoid_rampup2(current, rampup_length): 142 | if rampup_length == 0: 143 | return 1.0 144 | else: 145 | current = np.clip(current, 0.0, rampup_length) 146 | phase = current / rampup_length 147 | return float(np.exp(-5.0 * phase * phase)) 148 | 149 | def linear_rampup(current, rampup_length): 150 | assert current >= 0 and rampup_length >= 0 151 | if current >= rampup_length: 152 | return 1.0 153 | else: 154 | return current / rampup_length 155 | 156 | 157 | def cosine_rampdown(current, rampdown_length): 158 | assert 0 <= current <= rampdown_length 159 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 160 | 161 | def test_img(net_g, data_loader, args): 162 | net_g.eval() 163 | test_loss = 0 164 | correct = 0 165 | 166 | for idx, (data, target) in enumerate(data_loader): 167 | data, target = data.cuda(), target.cuda() 168 | log_probs = net_g(data) 169 | test_loss += F.cross_entropy(log_probs, target, reduction='sum',ignore_index=-1).item() 170 | y_pred = log_probs.data.max(1, keepdim=True)[1] 171 | correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum() 172 | 173 | test_loss /= len(data_loader.dataset) 174 | accuracy = 100.00 * correct / len(data_loader.dataset) 175 | 176 | return accuracy, test_loss 177 | 178 | def adjust_learning_rate(optimizer, epoch, step_in_epoch, total_steps_in_epoch , args): 179 | lr = args.lr 180 | epoch = epoch + step_in_epoch / total_steps_in_epoch 181 | lr = linear_rampup(epoch, args.lr_rampup) * (args.lr - args.initial_lr) + args.initial_lr 182 | if args.lr_rampdown_epochs: 183 | lr *= cosine_rampdown(epoch, args.lr_rampdown_epochs) 184 | for param_group in optimizer.param_groups: 185 | param_group['lr'] = lr 186 | 187 | def get_current_consistency_weight(epoch): 188 | return sigmoid_rampup(epoch, 10) 189 | 190 | def sigmoid_rampup(current, rampup_length): 191 | if rampup_length == 0: 192 | return 1.0 193 | else: 194 | current = np.clip(current, 0.0, rampup_length) 195 | phase = 1.0 - current / rampup_length 196 | return float(np.exp(-5.0 * phase * phase)) 197 | 198 | def softmax_mse_loss(input_logits, target_logits): 199 | assert input_logits.size() == target_logits.size() 200 | input_softmax = F.softmax(input_logits, dim=1) 201 | target_softmax = F.softmax(target_logits, dim=1) 202 | num_classes = input_logits.size()[1] 203 | return F.mse_loss(input_softmax, target_softmax, size_average=False) / num_classes 204 | 205 | def softmax_kl_loss(input_logits, target_logits): 206 | assert input_logits.size() == target_logits.size() 207 | input_log_softmax = F.log_softmax(input_logits, dim=1) 208 | target_softmax = F.softmax(target_logits, dim=1) 209 | return 
F.kl_div(input_log_softmax, target_softmax, size_average=False)
210 | 
211 | def symmetric_mse_loss(input1, input2):
212 |     assert input1.size() == input2.size()
213 |     num_classes = input1.size()[1]
214 |     return torch.sum((input1 - input2)**2) / num_classes
215 | 
216 | 
217 | def FedAvg(w):
218 |     w_avg = copy.deepcopy(w[0])
219 |     # print(w_avg.keys())
220 |     for k in w_avg.keys():
221 |         for i in range(1, len(w)):
222 |             w_avg[k] += w[i][k]
223 |         w_avg[k] = torch.div(w_avg[k], len(w))
224 |     return w_avg
225 | 
226 | 
227 | def iid(dataset, num_users, label_rate):
228 |     """
229 |     Sample I.I.D. client data from MNIST dataset
230 |     :param dataset: the full training set
231 |     :param num_users: number of clients (label_rate is the globally labeled fraction)
232 |     :return: set of labeled indices, dict of per-client unlabeled index sets
233 |     """
234 |     num_items = int(len(dataset)/num_users)
235 |     dict_users, all_idxs = {}, [i for i in range(len(dataset))]
236 |     dict_users_labeled, dict_users_unlabeled = set(), {}
237 | 
238 |     dict_users_labeled = set(np.random.choice(list(all_idxs), int(len(all_idxs) * label_rate), replace=False))
239 | 
240 |     for i in range(num_users):
241 |         # dict_users_labeled = dict_users_labeled | set(np.random.choice(all_idxs, int(num_items * label_rate), replace=False))
242 |         # all_idxs = list(set(all_idxs) - dict_users_labeled)
243 |         dict_users_unlabeled[i] = set(np.random.choice(all_idxs, int(num_items) , replace=False))
244 |         all_idxs = list(set(all_idxs) - dict_users_unlabeled[i])
245 |         dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled
246 |     return dict_users_labeled, dict_users_unlabeled
247 | 
248 | 
249 | def noniid(dataset, num_users, label_rate):
250 | 
251 |     num_shards, num_imgs = 2 * num_users, int(len(dataset)/num_users/2)
252 |     idx_shard = [i for i in range(num_shards)]
253 |     dict_users_unlabeled = {i: np.array([], dtype='int64') for i in range(num_users)}
254 |     idxs = np.arange(len(dataset))
255 |     labels = np.arange(len(dataset))
256 | 
257 | 
258 |     for i in range(len(dataset)):
259 |         labels[i] = dataset[i][1]
260 | 
261 |     num_items = int(len(dataset)/num_users)
262 |     dict_users_labeled = set()
263 | 
264 |     # sort labels
265 |     idxs_labels = np.vstack((idxs, labels))
266 |     idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()] # indices sorted by label value
267 |     idxs = idxs_labels[0,:]
268 | 
269 |     # divide and assign
270 |     for i in range(num_users):
271 |         rand_set = set(np.random.choice(idx_shard, 2, replace=False))
272 |         idx_shard = list(set(idx_shard) - rand_set)
273 |         for rand in rand_set:
274 |             dict_users_unlabeled[i] = np.concatenate((dict_users_unlabeled[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
275 | 
276 |     dict_users_labeled = set(np.random.choice(list(idxs), int(len(idxs) * label_rate), replace=False))
277 | 
278 |     for i in range(num_users):
279 | 
280 |         dict_users_unlabeled[i] = set(dict_users_unlabeled[i])
281 |         # dict_users_labeled = dict_users_labeled | set(np.random.choice(list(dict_users_unlabeled[i]), int(num_items * label_rate), replace=False))
282 |         dict_users_unlabeled[i] = dict_users_unlabeled[i] - dict_users_labeled
283 | 
284 | 
285 |     return dict_users_labeled, dict_users_unlabeled
286 | 
287 | class DatasetSplit(Dataset):
288 |     def __init__(self, dataset, idxs):
289 |         self.dataset = dataset
290 |         self.idxs = list(idxs)
291 | 
292 |     def __len__(self):
293 |         return len(self.idxs)
294 | 
295 |     def __getitem__(self, item):
296 |         (images1, images2), labels = self.dataset[self.idxs[item]]
297 |         return (images1, images2), labels
298 | 
299 | def get_current_consistency_weight(epoch):
300 |     return sigmoid_rampup(epoch, 10)
301 | 
302 | def dist(w0, w1):
303 |     dist_w = 0.0
304 |     for i in
list(w0.keys()) : 305 | dist_w += (w0[i]-w1[i]).float().norm(2)**2 306 | dist_w = math.sqrt(dist_w) 307 | return dist_w 308 | 309 | def main(device, args): 310 | 311 | 312 | loss1_func = nn.CrossEntropyLoss() 313 | loss2_func = softmax_kl_loss 314 | 315 | dataset_kwargs = { 316 | 'dataset':args.dataset, 317 | 'data_dir': args.data_dir, 318 | 'download':args.download, 319 | 'debug_subset_size':args.batch_size if args.debug else None 320 | } 321 | dataloader_kwargs = { 322 | 'batch_size': args.batch_size, 323 | 'drop_last': True, 324 | 'pin_memory': True, 325 | 'num_workers': args.num_workers, 326 | } 327 | dataloader_unlabeled_kwargs = { 328 | 'batch_size': args.batch_size*5, 329 | 'drop_last': True, 330 | 'pin_memory': True, 331 | 'num_workers': args.num_workers, 332 | } 333 | dataset_train =get_dataset( 334 | transform=get_aug_fedmatch(args.dataset, True), 335 | train=True, 336 | **dataset_kwargs 337 | ) 338 | 339 | if args.iid == 'iid': 340 | dict_users_labeled, dict_users_unlabeled = iid(dataset_train, args.num_users, args.label_rate) 341 | else: 342 | dict_users_labeled, dict_users_unlabeled = noniid(dataset_train, args.num_users, args.label_rate) 343 | train_loader_unlabeled = {} 344 | 345 | 346 | # define model 347 | model_glob = get_model('fedfixmatch', args.backbone).to(device) 348 | if torch.cuda.device_count() > 1: model_glob = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model_glob) 349 | 350 | 351 | model_local_idx = set() 352 | 353 | user_epoch = {} 354 | lr_scheduler = {} 355 | accuracy = [] 356 | class_criterion = nn.CrossEntropyLoss(size_average=False, ignore_index= -1 ) 357 | if args.dataset == 'cifar' and args.iid != 'noniid_tradition': 358 | consistency_criterion = softmax_kl_loss 359 | else: 360 | consistency_criterion = softmax_mse_loss 361 | 362 | for iter in range(args.num_epochs): 363 | 364 | model_glob.train() 365 | optimizer = torch.optim.SGD(model_glob.parameters(), lr=0.01, momentum=0.5) 366 | 367 | train_loader_labeled = torch.utils.data.DataLoader( 368 | dataset=DatasetSplit(dataset_train, dict_users_labeled), 369 | shuffle=True, 370 | **dataloader_kwargs 371 | ) 372 | 373 | for batch_idx, ((img, img_ema), label) in enumerate(train_loader_labeled): 374 | 375 | img, img_ema, label = img.to(args.device), img_ema.to(args.device), label.to(args.device) 376 | input_var = torch.autograd.Variable(img) 377 | ema_input_var = torch.autograd.Variable(img_ema, volatile=True) 378 | target_var = torch.autograd.Variable(label) 379 | minibatch_size = len(target_var) 380 | labeled_minibatch_size = target_var.data.ne(-1).sum() 381 | ema_model_out = model_glob(ema_input_var) 382 | model_out = model_glob(input_var) 383 | if isinstance(model_out, Variable): 384 | logit1 = model_out 385 | ema_logit = ema_model_out 386 | else: 387 | assert len(model_out) == 2 388 | assert len(ema_model_out) == 2 389 | logit1, logit2 = model_out 390 | ema_logit, _ = ema_model_out 391 | ema_logit = Variable(ema_logit.detach().data, requires_grad=False) 392 | class_logit, cons_logit = logit1, logit1 393 | classification_weight = 1 394 | class_loss = classification_weight * class_criterion(class_logit, target_var) / minibatch_size 395 | ema_class_loss = class_criterion(ema_logit, target_var) / minibatch_size 396 | consistency_weight = get_current_consistency_weight(iter) 397 | consistency_loss = consistency_weight * consistency_criterion(cons_logit, ema_logit) / minibatch_size 398 | loss = class_loss + consistency_loss 399 | optimizer.zero_grad() 400 | loss.backward() 401 | optimizer.step() 402 | 403 | 
del train_loader_labeled
404 |         gc.collect()
405 |         torch.cuda.empty_cache()
406 | 
407 |         if iter%1==0:  # evaluate the global model every round
408 |             test_loader = torch.utils.data.DataLoader(
409 |                 dataset=get_dataset(
410 |                     transform=get_aug(args.dataset, False, train_classifier=False),
411 |                     train=False,
412 |                     **dataset_kwargs),
413 |                 shuffle=False,
414 |                 **dataloader_kwargs
415 |             )
416 |             model_glob.eval()
417 |             acc, loss_train_test_labeled = test_img(model_glob, test_loader, args)
418 |             accuracy.append(str(acc))
419 |             del test_loader
420 |             gc.collect()
421 |             torch.cuda.empty_cache()
422 | 
423 | 
424 |         w_locals, loss_locals, loss0_locals, loss2_locals = [], [], [], []
425 | 
426 |         m = max(int(args.frac * args.num_users), 1)
427 |         idxs_users = np.random.choice(range(args.num_users), m, replace=False)
428 | 
429 |         for idx in idxs_users:
430 |             if idx in user_epoch.keys():
431 |                 user_epoch[idx] += 1
432 |             else:
433 |                 user_epoch[idx] = 1
434 | 
435 |             loss_local = []
436 |             loss0_local = []
437 |             loss2_local = []
438 | 
439 | 
440 |             model_local = copy.deepcopy(model_glob).to(args.device)
441 | 
442 |             train_loader_unlabeled = torch.utils.data.DataLoader(
443 |                 dataset=DatasetSplit(dataset_train, dict_users_unlabeled[idx]),
444 |                 shuffle=True,
445 |                 **dataloader_unlabeled_kwargs
446 |             )
447 | 
448 |             optimizer = torch.optim.SGD(model_local.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)
449 | 
450 |             model_local.train()
451 | 
452 | 
453 |             for i, ((images1, images2), labels) in enumerate(train_loader_unlabeled):
454 | 
455 |                 img, img_ema, label = images1.to(args.device), images2.to(args.device), labels.to(args.device)
456 |                 adjust_learning_rate(optimizer, user_epoch[idx], i, len(train_loader_unlabeled), args)
457 |                 input_var = torch.autograd.Variable(img)
458 |                 ema_input_var = torch.autograd.Variable(img_ema)  # 'volatile' was removed in PyTorch >= 0.4; this branch is detached below
459 |                 target_var = torch.autograd.Variable(label)
460 |                 minibatch_size = len(target_var)
461 |                 labeled_minibatch_size = target_var.data.ne(-1).sum()
462 |                 ema_model_out = model_local(ema_input_var)
463 |                 model_out = model_local(input_var)
464 |                 if isinstance(model_out, Variable):
465 |                     logit1 = model_out
466 |                     ema_logit = ema_model_out
467 |                 else:
468 |                     assert len(model_out) == 2
469 |                     assert len(ema_model_out) == 2
470 |                     logit1, logit2 = model_out
471 |                     ema_logit, _ = ema_model_out
472 |                 ema_logit = Variable(ema_logit.detach().data, requires_grad=False)
473 |                 class_logit, cons_logit = logit1, logit1
474 | 
475 |                 consistency_weight = get_current_consistency_weight(user_epoch[idx])
476 |                 consistency_loss = consistency_weight * consistency_criterion(cons_logit, ema_logit) / minibatch_size
477 | 
478 |                 Lprox = 1 / 2 * dist(model_local.state_dict(), model_glob.state_dict())  # proximal term; see the note at the end of this file
479 | 
480 |                 loss = consistency_loss + Lprox
481 |                 optimizer.zero_grad()
482 |                 loss.backward()
483 |                 optimizer.step()
484 | 
485 |             w_locals.append(copy.deepcopy(model_local.state_dict()))
486 | 
487 |             del model_local
488 |             gc.collect()
489 |             del train_loader_unlabeled
490 |             gc.collect()
491 |             torch.cuda.empty_cache()
492 | 
493 | 
494 | 
495 |         w_glob = FedAvg(w_locals)
496 |         model_glob.load_state_dict(w_glob)
497 | 
498 |         # loss_avg = sum(loss_locals) / len(loss_locals)
499 | 
500 |         if iter%1==0:
501 |             print('Round {:3d}, Acc {:.2f}%'.format(iter, acc))
502 | 
503 | if __name__ == "__main__":
504 |     args = get_args()
505 |     main(device=args.device, args=args)
506 | 
507 | 
508 | 
509 | 
510 | 
511 | 
512 | 
513 | 
514 | 
515 | 
516 | 
517 | 
518 | 
519 | 
520 | 
521 | 
522 | 
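A note on the proximal term shared by the two FedProx scripts above: `dist()` walks `state_dict()` snapshots, which PyTorch returns as detached tensors, and finishes with `math.sqrt`, which yields a plain Python float. `Lprox` therefore enters the loss as a constant and contributes no gradient to the local update. A differentiable FedProx-style penalty would look like the sketch below; this is a minimal sketch, and the `mu` coefficient and the helper name are illustrative assumptions rather than code from this repository:

    import torch

    def prox_term(local_model, global_model, mu=1.0):
        # (mu / 2) * ||w_local - w_global||^2, kept on the autograd graph
        penalty = torch.zeros((), device=next(local_model.parameters()).device)
        for w, w_g in zip(local_model.parameters(), global_model.parameters()):
            penalty = penalty + (w - w_g.detach()).pow(2).sum()
        return 0.5 * mu * penalty

Used in place of `dist()`, the returned tensor keeps `loss.backward()` flowing into the local parameters while the global weights stay fixed via `detach()`.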
-------------------------------------------------------------------------------- /models/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | from .simsiam import SimSiam 2 | from .byol import BYOL, global_net 3 | from .simclr import SimCLR 4 | from torchvision.models import resnet50, resnet18 5 | import torch 6 | from .backbones import * 7 | 8 | def get_backbone(backbone, castrate=True): 9 | backbone = eval(f"{backbone}()") 10 | 11 | # if castrate: 12 | # backbone.output_dim = backbone.fc1.in_features 13 | # backbone.fc1 = torch.nn.Identity() 14 | 15 | return backbone 16 | 17 | class CNNMnist(nn.Module): 18 | def __init__(self): 19 | super(CNNMnist, self).__init__() 20 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 21 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 22 | self.conv2_drop = nn.Dropout2d() 23 | self.fc1 = nn.Linear(320, 50) 24 | self.fc2 = nn.Linear(50, 10) 25 | 26 | def forward(self, x): 27 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 28 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 29 | x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]) 30 | x = F.relu(self.fc1(x)) 31 | x = F.dropout(x, training=self.training) 32 | x = self.fc2(x) 33 | return x 34 | 35 | 36 | class CNNCifar(nn.Module): 37 | def __init__(self): 38 | super(CNNCifar, self).__init__() 39 | 40 | self.conv_layer = nn.Sequential( 41 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1), 42 | nn.BatchNorm2d(32), 43 | nn.ReLU(inplace=True), 44 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1), 45 | nn.ReLU(inplace=True), 46 | nn.MaxPool2d(kernel_size=2, stride=2), 47 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1), 48 | nn.BatchNorm2d(128), 49 | nn.ReLU(inplace=True), 50 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1), 51 | nn.ReLU(inplace=True), 52 | nn.MaxPool2d(kernel_size=2, stride=2), 53 | nn.Dropout2d(p=0.05), 54 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1), 55 | nn.BatchNorm2d(256), 56 | nn.ReLU(inplace=True), 57 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1), 58 | nn.ReLU(inplace=True), 59 | nn.MaxPool2d(kernel_size=2, stride=2), 60 | ) 61 | self.fc_layer = nn.Sequential( 62 | nn.Dropout(p=0.1), 63 | nn.Linear(4096, 1024), 64 | nn.ReLU(inplace=True), 65 | nn.Linear(1024, 512), 66 | nn.ReLU(inplace=True), 67 | nn.Dropout(p=0.1), 68 | nn.Linear(512, 10) 69 | ) 70 | def forward(self, x): 71 | x = self.conv_layer(x) 72 | x = x.view(x.size(0), -1) 73 | x = self.fc_layer(x) 74 | return x 75 | 76 | def get_model(name, backbone): 77 | if name == 'local': 78 | model = BYOL(get_backbone(backbone)) 79 | elif name == 'global': 80 | model = global_net(get_backbone(backbone)) 81 | elif name == 'fedfixmatch' and backbone == 'Mnist': 82 | model = CNNMnist().to('cuda') 83 | elif name == 'fedfixmatch' and backbone == 'Cifar': 84 | model = CNNCifar().to('cuda') 85 | elif name == 'fedfixmatch' and backbone == 'Svhn': 86 | model = CNNCifar().to('cuda') 87 | else: 88 | raise NotImplementedError 89 | return model 90 | 91 | 92 | 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /models/.ipynb_checkpoints/backbones-checkpoint.py: -------------------------------------------------------------------------------- 1 | from torchvision.models.resnet import ResNet, Bottleneck, BasicBlock 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 
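# Backbone trunks for the self-supervised wrappers in this package: the fully
# connected heads are stripped (left commented out), and each class records its
# flattened feature size in self.output_dim (320 for MNIST, 4096 for
# CIFAR-10/SVHN) so that the projection and prediction MLPs can size their inputs.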
5 | 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import Parameter 10 | 11 | 12 | class GroupNorm2d(nn.Module): 13 | def __init__(self, num_features, num_groups=32, eps=1e-5, affine=True): 14 | super(GroupNorm2d, self).__init__() 15 | self.num_groups = num_groups 16 | self.eps = eps 17 | self.num_features = num_features 18 | self.affine = affine 19 | 20 | if self.affine: 21 | self.weight = Parameter(torch.Tensor(1, num_features, 1, 1)) 22 | self.bias = Parameter(torch.Tensor(1, num_features, 1, 1)) 23 | 24 | self.reset_parameters() 25 | 26 | def reset_parameters(self): 27 | if self.affine: 28 | self.weight.data.fill_(1) 29 | self.bias.data.zero_() 30 | 31 | def forward(self, input): 32 | output = input.view(input.size(0), self.num_groups, -1) 33 | 34 | mean = output.mean(dim=2, keepdim=True) 35 | var = output.var(dim=2, keepdim=True) 36 | 37 | output = (output - mean) / (var + self.eps).sqrt() 38 | output = output.view_as(input) 39 | 40 | if self.affine: 41 | output = output * self.weight + self.bias 42 | 43 | return output 44 | 45 | 46 | class CNN_Mnist(nn.Module): 47 | def __init__(self): 48 | super(CNN_Mnist, self).__init__() 49 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 50 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 51 | self.conv2_drop = nn.Dropout2d() 52 | self.output_dim=320 53 | # self.fc1 = nn.Linear(320, 50) 54 | # self.fc2 = nn.Linear(50, 10) 55 | 56 | def forward(self, x): 57 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 58 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 59 | x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]) 60 | # x = F.relu(self.fc1(x)) 61 | # x = F.dropout(x, training=self.training) 62 | # x = self.fc2(x) 63 | return x 64 | 65 | class CNN_Cifar(nn.Module): 66 | def __init__(self): 67 | super(CNN_Cifar, self).__init__() 68 | 69 | self.conv_layer = nn.Sequential( 70 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1), 71 | nn.BatchNorm2d(32), 72 | nn.ReLU(inplace=True), 73 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1), 74 | nn.ReLU(inplace=True), 75 | nn.MaxPool2d(kernel_size=2, stride=2), 76 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1), 77 | nn.BatchNorm2d(128), 78 | nn.ReLU(inplace=True), 79 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1), 80 | nn.ReLU(inplace=True), 81 | nn.MaxPool2d(kernel_size=2, stride=2), 82 | nn.Dropout2d(p=0.05), 83 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1), 84 | nn.BatchNorm2d(256), 85 | nn.ReLU(inplace=True), 86 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1), 87 | nn.ReLU(inplace=True), 88 | nn.MaxPool2d(kernel_size=2, stride=2), 89 | ) 90 | self.fc_layer = nn.Sequential( 91 | nn.Dropout(p=0.1), 92 | nn.Linear(4096, 1024), 93 | nn.ReLU(inplace=True), 94 | nn.Linear(1024, 512), 95 | nn.ReLU(inplace=True), 96 | nn.Dropout(p=0.1), 97 | nn.Linear(512, 10) 98 | ) 99 | self.output_dim=4096 100 | 101 | def forward(self, x): 102 | x = self.conv_layer(x) 103 | x = x.view(x.size(0), -1) 104 | # x = self.fc_layer(x.t()) 105 | return x 106 | 107 | class CNN_Svhn(nn.Module): 108 | def __init__(self): 109 | super(CNN_Svhn, self).__init__() 110 | 111 | self.conv_layer = nn.Sequential( 112 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1), 113 | nn.BatchNorm2d(32), 114 | nn.ReLU(inplace=True), 115 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1), 116 | nn.ReLU(inplace=True), 117 | nn.MaxPool2d(kernel_size=2, stride=2), 
118 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1), 119 | nn.BatchNorm2d(128), 120 | nn.ReLU(inplace=True), 121 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1), 122 | nn.ReLU(inplace=True), 123 | nn.MaxPool2d(kernel_size=2, stride=2), 124 | nn.Dropout2d(p=0.05), 125 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1), 126 | nn.BatchNorm2d(256), 127 | nn.ReLU(inplace=True), 128 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1), 129 | nn.ReLU(inplace=True), 130 | nn.MaxPool2d(kernel_size=2, stride=2), 131 | ) 132 | self.fc_layer = nn.Sequential( 133 | nn.Dropout(p=0.1), 134 | nn.Linear(4096, 1024), 135 | nn.ReLU(inplace=True), 136 | nn.Linear(1024, 512), 137 | nn.ReLU(inplace=True), 138 | nn.Dropout(p=0.1), 139 | nn.Linear(512, 10) 140 | ) 141 | self.output_dim=4096 142 | 143 | def forward(self, x): 144 | x = self.conv_layer(x) 145 | x = x.view(x.size(0), -1) 146 | # x = self.fc_layer(x.t()) 147 | return x 148 | 149 | def Mnist(**kwargs): 150 | return CNN_Mnist() 151 | 152 | def Cifar(**kwargs): 153 | return CNN_Cifar() 154 | 155 | def Svhn(**kwargs): 156 | return CNN_Svhn() -------------------------------------------------------------------------------- /models/.ipynb_checkpoints/byol-checkpoint.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torchvision import transforms 7 | from math import pi, cos 8 | from collections import OrderedDict 9 | 10 | HPS = dict( 11 | max_steps=int(1000. * 1281167 / 4096), # 1000 epochs * 1281167 samples / batch size = 100 epochs * N of step/epoch 12 | # = total_epochs * len(dataloader) 13 | mlp_hidden_size=4096*3, 14 | projection_size=4096, 15 | base_target_ema_g=4e-4,#4e-4 16 | base_target_ema_l=4e-4, 17 | optimizer_config=dict( 18 | optimizer_name='lars', 19 | beta=0.9, 20 | trust_coef=1e-3, 21 | weight_decay=1.5e-6, 22 | exclude_bias_from_adaption=True), 23 | learning_rate_schedule=dict( 24 | base_learning_rate=0.2, 25 | warmup_steps=int(10.0 * 1281167 / 4096), # 10 epochs * N of steps/epoch = 10 epochs * len(dataloader) 26 | anneal_schedule='cosine'), 27 | batchnorm_kwargs=dict( 28 | decay_rate=0.9, 29 | eps=1e-5), 30 | seed=1337, 31 | ) 32 | 33 | 34 | # def loss_fn(x, y, version='simplified'): 35 | 36 | # if version == 'original': 37 | # y = y.detach() 38 | # x = F.normalize(x, dim=-1, p=2) 39 | # y = F.normalize(y, dim=-1, p=2) 40 | # return (2 - 2 * (x * y).sum(dim=-1)).mean() 41 | # elif version == 'simplified': 42 | # return (2 - 2 * F.cosine_similarity(x,y.detach(), dim=-1)).mean() 43 | # else: 44 | # raise NotImplementedError 45 | 46 | from .simsiam import D # a bit different but it's essentially the same thing: neg cosine sim & stop gradient 47 | 48 | 49 | class MLP(nn.Module): 50 | def __init__(self, in_dim): 51 | super().__init__() 52 | 53 | self.layer1 = nn.Sequential( 54 | nn.Linear(in_dim, HPS['mlp_hidden_size']), 55 | nn.BatchNorm1d(HPS['mlp_hidden_size'], eps=HPS['batchnorm_kwargs']['eps'], momentum=1-HPS['batchnorm_kwargs']['decay_rate']), 56 | nn.ReLU(inplace=True) 57 | ) 58 | self.layer2 = nn.Linear(HPS['mlp_hidden_size'], in_dim) 59 | 60 | def forward(self, x): 61 | x = self.layer1(x) 62 | x = self.layer2(x) 63 | return x 64 | 65 | class fc_Mnist(nn.Module): 66 | def __init__(self, in_dim): 67 | super(fc_Mnist, self).__init__() 68 | self.fc1 = nn.Linear(320, 50) 69 | self.fc2 = 
nn.Linear(50, 10) 70 | 71 | def forward(self, x): 72 | x = F.relu(self.fc1(x)) 73 | x = F.dropout(x, training=self.training) 74 | x = self.fc2(x) 75 | return x 76 | 77 | class fc_Cifar(nn.Module): 78 | def __init__(self, in_dim): 79 | super(fc_Cifar, self).__init__() 80 | 81 | self.fc_layer = nn.Sequential( 82 | nn.Dropout(p=0.1), 83 | nn.Linear(4096, 1024), 84 | nn.ReLU(inplace=True), 85 | nn.Linear(1024, 512), 86 | nn.ReLU(inplace=True), 87 | # nn.Dropout(p=0.1), 88 | nn.Linear(512, 10) 89 | ) 90 | def forward(self, x): 91 | x = self.fc_layer(x) 92 | return x 93 | 94 | class global_net(nn.Module): 95 | def __init__(self, backbone): 96 | super().__init__() 97 | 98 | self.backbone = backbone 99 | if backbone.output_dim == 320: 100 | self.fc = fc_Mnist(backbone.output_dim) 101 | else: 102 | self.fc = fc_Cifar(backbone.output_dim) 103 | 104 | self.teacher = nn.Sequential( 105 | self.backbone, 106 | self.fc 107 | ) 108 | 109 | self.student = copy.deepcopy(self.teacher) 110 | 111 | 112 | def target_ema(self, k, K, base_ema=HPS['base_target_ema_g']): 113 | return 1 - base_ema * (cos(pi*k/K)+1)/2 114 | 115 | @torch.no_grad() 116 | def update_moving_average(self, global_step, max_steps): 117 | tau = self.target_ema(global_step, max_steps) 118 | for online, target in zip(self.teacher.parameters(), self.student.parameters()): 119 | target.data = tau * target.data + (1 - tau) * online.data 120 | 121 | def forward(self, x1, x2): 122 | t, s = self.teacher, self.student 123 | 124 | z1_t = t(x1) 125 | z2_t = t(x2) 126 | 127 | 128 | with torch.no_grad(): 129 | z1_s = t(x1) 130 | z2_s = t(x2) 131 | 132 | return z1_t, z2_t, z1_s, z2_s 133 | 134 | 135 | 136 | class BYOL(nn.Module): 137 | def __init__(self, backbone): 138 | super().__init__() 139 | 140 | self.backbone = backbone 141 | self.projector = MLP(backbone.output_dim) 142 | 143 | self.online_encoder = nn.Sequential( 144 | self.backbone 145 | ) 146 | 147 | self.target_encoder = copy.deepcopy(self.online_encoder) 148 | 149 | 150 | def target_ema(self, k, K, base_ema=HPS['base_target_ema_l']): 151 | return 1 - base_ema * (cos(pi*k/K)+1)/2 152 | 153 | @torch.no_grad() 154 | def update_moving_average(self): 155 | tau = 0.999#self.target_ema(global_step, max_steps) 156 | for online, target in zip(self.online_encoder.parameters(), self.target_encoder.parameters()): 157 | target.data = tau * target.data + (1 - tau) * online.data 158 | 159 | def forward(self, x1, x2): 160 | f_o, h_o = self.online_encoder, self.projector 161 | f_t = self.target_encoder 162 | 163 | z1_o = f_o(x1) 164 | z2_o = f_o(x2) 165 | 166 | p1_o = h_o(z1_o) 167 | p2_o = h_o(z2_o) 168 | 169 | with torch.no_grad(): 170 | z1_t = f_t(x1) 171 | z2_t = f_t(x2) 172 | 173 | L = D(p1_o, z2_t) / 2 + D(p2_o, z1_t) / 2 174 | 175 | return L 176 | 177 | 178 | 179 | if __name__ == "__main__": 180 | pass -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .simsiam import SimSiam 2 | from .byol import BYOL, global_net 3 | from .simclr import SimCLR 4 | from torchvision.models import resnet50, resnet18 5 | import torch 6 | from .backbones import * 7 | 8 | def get_backbone(backbone, castrate=True): 9 | backbone = eval(f"{backbone}()") 10 | 11 | # if castrate: 12 | # backbone.output_dim = backbone.fc1.in_features 13 | # backbone.fc1 = torch.nn.Identity() 14 | 15 | return backbone 16 | 17 | class CNNMnist(nn.Module): 18 | def __init__(self): 19 | super(CNNMnist, 
self).__init__() 20 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 21 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 22 | self.conv2_drop = nn.Dropout2d() 23 | self.fc1 = nn.Linear(320, 50) 24 | self.fc2 = nn.Linear(50, 10) 25 | 26 | def forward(self, x): 27 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 28 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 29 | x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]) 30 | x = F.relu(self.fc1(x)) 31 | x = F.dropout(x, training=self.training) 32 | x = self.fc2(x) 33 | return x 34 | 35 | 36 | class CNNCifar(nn.Module): 37 | def __init__(self): 38 | super(CNNCifar, self).__init__() 39 | 40 | self.conv_layer = nn.Sequential( 41 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1), 42 | nn.BatchNorm2d(32), 43 | nn.ReLU(inplace=True), 44 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1), 45 | nn.ReLU(inplace=True), 46 | nn.MaxPool2d(kernel_size=2, stride=2), 47 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1), 48 | nn.BatchNorm2d(128), 49 | nn.ReLU(inplace=True), 50 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1), 51 | nn.ReLU(inplace=True), 52 | nn.MaxPool2d(kernel_size=2, stride=2), 53 | nn.Dropout2d(p=0.05), 54 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1), 55 | nn.BatchNorm2d(256), 56 | nn.ReLU(inplace=True), 57 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1), 58 | nn.ReLU(inplace=True), 59 | nn.MaxPool2d(kernel_size=2, stride=2), 60 | ) 61 | self.fc_layer = nn.Sequential( 62 | nn.Dropout(p=0.1), 63 | nn.Linear(4096, 1024), 64 | nn.ReLU(inplace=True), 65 | nn.Linear(1024, 512), 66 | nn.ReLU(inplace=True), 67 | nn.Dropout(p=0.1), 68 | nn.Linear(512, 10) 69 | ) 70 | def forward(self, x): 71 | x = self.conv_layer(x) 72 | x = x.view(x.size(0), -1) 73 | x = self.fc_layer(x) 74 | return x 75 | 76 | def get_model(name, backbone): 77 | if name == 'local': 78 | model = BYOL(get_backbone(backbone)) 79 | elif name == 'global': 80 | model = global_net(get_backbone(backbone)) 81 | elif name == 'fedfixmatch' and backbone == 'Mnist': 82 | model = CNNMnist().to('cuda') 83 | elif name == 'fedfixmatch' and backbone == 'Cifar': 84 | model = CNNCifar().to('cuda') 85 | elif name == 'fedfixmatch' and backbone == 'Svhn': 86 | model = CNNCifar().to('cuda') 87 | else: 88 | raise NotImplementedError 89 | return model 90 | 91 | 92 | 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-long/fedcon-pytorch/54732ec05c19845734567c186b63e3ddedb10798/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/backbones.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-long/fedcon-pytorch/54732ec05c19845734567c186b63e3ddedb10798/models/__pycache__/backbones.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/byol.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-long/fedcon-pytorch/54732ec05c19845734567c186b63e3ddedb10798/models/__pycache__/byol.cpython-37.pyc 
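For reference, a minimal usage sketch of the `get_model` factory in `models/__init__.py` above; the exact calls are illustrative rather than taken from the training scripts, and the `'fedfixmatch'` branches call `.to('cuda')`, so a CUDA device is assumed:

    from models import get_model

    classifier = get_model('fedfixmatch', 'Mnist')   # CNNMnist classifier, moved to CUDA
    byol_local = get_model('local', 'Mnist')         # BYOL wrapping the CNN_Mnist backbone

The backbone argument is a string because `get_backbone` resolves it with `eval(f"{backbone}()")` against the names exported by `models/backbones.py` (`Mnist`, `Cifar`, `Svhn`).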
-------------------------------------------------------------------------------- /models/__pycache__/simclr.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-long/fedcon-pytorch/54732ec05c19845734567c186b63e3ddedb10798/models/__pycache__/simclr.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/simsiam.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zewei-long/fedcon-pytorch/54732ec05c19845734567c186b63e3ddedb10798/models/__pycache__/simsiam.cpython-37.pyc -------------------------------------------------------------------------------- /models/backbones.py: -------------------------------------------------------------------------------- 1 | from torchvision.models.resnet import ResNet, Bottleneck, BasicBlock 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import Parameter 10 | 11 | 12 | class GroupNorm2d(nn.Module): 13 | def __init__(self, num_features, num_groups=32, eps=1e-5, affine=True): 14 | super(GroupNorm2d, self).__init__() 15 | self.num_groups = num_groups 16 | self.eps = eps 17 | self.num_features = num_features 18 | self.affine = affine 19 | 20 | if self.affine: 21 | self.weight = Parameter(torch.Tensor(1, num_features, 1, 1)) 22 | self.bias = Parameter(torch.Tensor(1, num_features, 1, 1)) 23 | 24 | self.reset_parameters() 25 | 26 | def reset_parameters(self): 27 | if self.affine: 28 | self.weight.data.fill_(1) 29 | self.bias.data.zero_() 30 | 31 | def forward(self, input): 32 | output = input.view(input.size(0), self.num_groups, -1) 33 | 34 | mean = output.mean(dim=2, keepdim=True) 35 | var = output.var(dim=2, keepdim=True) 36 | 37 | output = (output - mean) / (var + self.eps).sqrt() 38 | output = output.view_as(input) 39 | 40 | if self.affine: 41 | output = output * self.weight + self.bias 42 | 43 | return output 44 | 45 | 46 | class CNN_Mnist(nn.Module): 47 | def __init__(self): 48 | super(CNN_Mnist, self).__init__() 49 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 50 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 51 | self.conv2_drop = nn.Dropout2d() 52 | self.output_dim=320 53 | # self.fc1 = nn.Linear(320, 50) 54 | # self.fc2 = nn.Linear(50, 10) 55 | 56 | def forward(self, x): 57 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 58 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 59 | x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]) 60 | # x = F.relu(self.fc1(x)) 61 | # x = F.dropout(x, training=self.training) 62 | # x = self.fc2(x) 63 | return x 64 | 65 | class CNN_Cifar(nn.Module): 66 | def __init__(self): 67 | super(CNN_Cifar, self).__init__() 68 | 69 | self.conv_layer = nn.Sequential( 70 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1), 71 | nn.BatchNorm2d(32), 72 | nn.ReLU(inplace=True), 73 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1), 74 | nn.ReLU(inplace=True), 75 | nn.MaxPool2d(kernel_size=2, stride=2), 76 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1), 77 | nn.BatchNorm2d(128), 78 | nn.ReLU(inplace=True), 79 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1), 80 | nn.ReLU(inplace=True), 81 | nn.MaxPool2d(kernel_size=2, stride=2), 82 | nn.Dropout2d(p=0.05), 83 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1), 84 | 
nn.BatchNorm2d(256), 85 | nn.ReLU(inplace=True), 86 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1), 87 | nn.ReLU(inplace=True), 88 | nn.MaxPool2d(kernel_size=2, stride=2), 89 | ) 90 | self.fc_layer = nn.Sequential( 91 | nn.Dropout(p=0.1), 92 | nn.Linear(4096, 1024), 93 | nn.ReLU(inplace=True), 94 | nn.Linear(1024, 512), 95 | nn.ReLU(inplace=True), 96 | nn.Dropout(p=0.1), 97 | nn.Linear(512, 10) 98 | ) 99 | self.output_dim=4096 100 | 101 | def forward(self, x): 102 | x = self.conv_layer(x) 103 | x = x.view(x.size(0), -1) 104 | # x = self.fc_layer(x.t()) 105 | return x 106 | 107 | class CNN_Svhn(nn.Module): 108 | def __init__(self): 109 | super(CNN_Svhn, self).__init__() 110 | 111 | self.conv_layer = nn.Sequential( 112 | nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1), 113 | nn.BatchNorm2d(32), 114 | nn.ReLU(inplace=True), 115 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1), 116 | nn.ReLU(inplace=True), 117 | nn.MaxPool2d(kernel_size=2, stride=2), 118 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1), 119 | nn.BatchNorm2d(128), 120 | nn.ReLU(inplace=True), 121 | nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1), 122 | nn.ReLU(inplace=True), 123 | nn.MaxPool2d(kernel_size=2, stride=2), 124 | nn.Dropout2d(p=0.05), 125 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1), 126 | nn.BatchNorm2d(256), 127 | nn.ReLU(inplace=True), 128 | nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1), 129 | nn.ReLU(inplace=True), 130 | nn.MaxPool2d(kernel_size=2, stride=2), 131 | ) 132 | self.fc_layer = nn.Sequential( 133 | nn.Dropout(p=0.1), 134 | nn.Linear(4096, 1024), 135 | nn.ReLU(inplace=True), 136 | nn.Linear(1024, 512), 137 | nn.ReLU(inplace=True), 138 | nn.Dropout(p=0.1), 139 | nn.Linear(512, 10) 140 | ) 141 | self.output_dim=4096 142 | 143 | def forward(self, x): 144 | x = self.conv_layer(x) 145 | x = x.view(x.size(0), -1) 146 | # x = self.fc_layer(x.t()) 147 | return x 148 | 149 | def Mnist(**kwargs): 150 | return CNN_Mnist() 151 | 152 | def Cifar(**kwargs): 153 | return CNN_Cifar() 154 | 155 | def Svhn(**kwargs): 156 | return CNN_Svhn() -------------------------------------------------------------------------------- /models/byol.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torchvision import transforms 7 | from math import pi, cos 8 | from collections import OrderedDict 9 | 10 | HPS = dict( 11 | max_steps=int(1000. 
* 1281167 / 4096), # 1000 epochs * 1281167 samples / batch size = 100 epochs * N of step/epoch 12 | # = total_epochs * len(dataloader) 13 | mlp_hidden_size=4096*3, 14 | projection_size=4096, 15 | base_target_ema_g=4e-4,#4e-4 16 | base_target_ema_l=4e-4, 17 | optimizer_config=dict( 18 | optimizer_name='lars', 19 | beta=0.9, 20 | trust_coef=1e-3, 21 | weight_decay=1.5e-6, 22 | exclude_bias_from_adaption=True), 23 | learning_rate_schedule=dict( 24 | base_learning_rate=0.2, 25 | warmup_steps=int(10.0 * 1281167 / 4096), # 10 epochs * N of steps/epoch = 10 epochs * len(dataloader) 26 | anneal_schedule='cosine'), 27 | batchnorm_kwargs=dict( 28 | decay_rate=0.9, 29 | eps=1e-5), 30 | seed=1337, 31 | ) 32 | 33 | 34 | # def loss_fn(x, y, version='simplified'): 35 | 36 | # if version == 'original': 37 | # y = y.detach() 38 | # x = F.normalize(x, dim=-1, p=2) 39 | # y = F.normalize(y, dim=-1, p=2) 40 | # return (2 - 2 * (x * y).sum(dim=-1)).mean() 41 | # elif version == 'simplified': 42 | # return (2 - 2 * F.cosine_similarity(x,y.detach(), dim=-1)).mean() 43 | # else: 44 | # raise NotImplementedError 45 | 46 | from .simsiam import D # a bit different but it's essentially the same thing: neg cosine sim & stop gradient 47 | 48 | 49 | class MLP(nn.Module): 50 | def __init__(self, in_dim): 51 | super().__init__() 52 | 53 | self.layer1 = nn.Sequential( 54 | nn.Linear(in_dim, HPS['mlp_hidden_size']), 55 | nn.BatchNorm1d(HPS['mlp_hidden_size'], eps=HPS['batchnorm_kwargs']['eps'], momentum=1-HPS['batchnorm_kwargs']['decay_rate']), 56 | nn.ReLU(inplace=True) 57 | ) 58 | self.layer2 = nn.Linear(HPS['mlp_hidden_size'], in_dim) 59 | 60 | def forward(self, x): 61 | x = self.layer1(x) 62 | x = self.layer2(x) 63 | return x 64 | 65 | class fc_Mnist(nn.Module): 66 | def __init__(self, in_dim): 67 | super(fc_Mnist, self).__init__() 68 | self.fc1 = nn.Linear(320, 50) 69 | self.fc2 = nn.Linear(50, 10) 70 | 71 | def forward(self, x): 72 | x = F.relu(self.fc1(x)) 73 | x = F.dropout(x, training=self.training) 74 | x = self.fc2(x) 75 | return x 76 | 77 | class fc_Cifar(nn.Module): 78 | def __init__(self, in_dim): 79 | super(fc_Cifar, self).__init__() 80 | 81 | self.fc_layer = nn.Sequential( 82 | nn.Dropout(p=0.1), 83 | nn.Linear(4096, 1024), 84 | nn.ReLU(inplace=True), 85 | nn.Linear(1024, 512), 86 | nn.ReLU(inplace=True), 87 | # nn.Dropout(p=0.1), 88 | nn.Linear(512, 10) 89 | ) 90 | def forward(self, x): 91 | x = self.fc_layer(x) 92 | return x 93 | 94 | class global_net(nn.Module): 95 | def __init__(self, backbone): 96 | super().__init__() 97 | 98 | self.backbone = backbone 99 | if backbone.output_dim == 320: 100 | self.fc = fc_Mnist(backbone.output_dim) 101 | else: 102 | self.fc = fc_Cifar(backbone.output_dim) 103 | 104 | self.teacher = nn.Sequential( 105 | self.backbone, 106 | self.fc 107 | ) 108 | 109 | self.student = copy.deepcopy(self.teacher) 110 | 111 | 112 | def target_ema(self, k, K, base_ema=HPS['base_target_ema_g']): 113 | return 1 - base_ema * (cos(pi*k/K)+1)/2 114 | 115 | @torch.no_grad() 116 | def update_moving_average(self, global_step, max_steps): 117 | tau = self.target_ema(global_step, max_steps) 118 | for online, target in zip(self.teacher.parameters(), self.student.parameters()): 119 | target.data = tau * target.data + (1 - tau) * online.data 120 | 121 | def forward(self, x1, x2): 122 | t, s = self.teacher, self.student 123 | 124 | z1_t = t(x1) 125 | z2_t = t(x2) 126 | 127 | 128 | with torch.no_grad(): 129 | z1_s = t(x1) 130 | z2_s = t(x2) 131 | 132 | return z1_t, z2_t, z1_s, z2_s 133 | 134 | 135 | 136 | 
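# BYOL keeps two copies of the network: an online encoder trained by gradient
# descent and a target encoder updated as an exponential moving average,
#     tau_k = 1 - base_ema * (cos(pi * k / K) + 1) / 2
#     target <- tau * target + (1 - tau) * online
# Note that update_moving_average() below pins tau = 0.999 instead of using the
# cosine schedule from target_ema().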
class BYOL(nn.Module): 137 | def __init__(self, backbone): 138 | super().__init__() 139 | 140 | self.backbone = backbone 141 | self.projector = MLP(backbone.output_dim) 142 | 143 | self.online_encoder = nn.Sequential( 144 | self.backbone 145 | ) 146 | 147 | self.target_encoder = copy.deepcopy(self.online_encoder) 148 | 149 | 150 | def target_ema(self, k, K, base_ema=HPS['base_target_ema_l']): 151 | return 1 - base_ema * (cos(pi*k/K)+1)/2 152 | 153 | @torch.no_grad() 154 | def update_moving_average(self): 155 | tau = 0.999#self.target_ema(global_step, max_steps) 156 | for online, target in zip(self.online_encoder.parameters(), self.target_encoder.parameters()): 157 | target.data = tau * target.data + (1 - tau) * online.data 158 | 159 | def forward(self, x1, x2): 160 | f_o, h_o = self.online_encoder, self.projector 161 | f_t = self.target_encoder 162 | 163 | z1_o = f_o(x1) 164 | z2_o = f_o(x2) 165 | 166 | p1_o = h_o(z1_o) 167 | p2_o = h_o(z2_o) 168 | 169 | with torch.no_grad(): 170 | z1_t = f_t(x1) 171 | z2_t = f_t(x2) 172 | 173 | L = D(p1_o, z2_t) / 2 + D(p2_o, z1_t) / 2 174 | 175 | return L 176 | 177 | 178 | 179 | if __name__ == "__main__": 180 | pass -------------------------------------------------------------------------------- /models/simclr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision.models import resnet50 5 | 6 | def NT_XentLoss(z1, z2, temperature=0.5): 7 | z1 = F.normalize(z1, dim=1) 8 | z2 = F.normalize(z2, dim=1) 9 | N, Z = z1.shape 10 | device = z1.device 11 | representations = torch.cat([z1, z2], dim=0) 12 | similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=-1) 13 | l_pos = torch.diag(similarity_matrix, N) 14 | r_pos = torch.diag(similarity_matrix, -N) 15 | positives = torch.cat([l_pos, r_pos]).view(2 * N, 1) 16 | diag = torch.eye(2*N, dtype=torch.bool, device=device) 17 | diag[N:,:N] = diag[:N,N:] = diag[:N,:N] 18 | 19 | negatives = similarity_matrix[~diag].view(2*N, -1) 20 | 21 | logits = torch.cat([positives, negatives], dim=1) 22 | logits /= temperature 23 | 24 | labels = torch.zeros(2*N, device=device, dtype=torch.int64) 25 | 26 | loss = F.cross_entropy(logits, labels, reduction='sum') 27 | return loss / (2 * N) 28 | 29 | 30 | class projection_MLP(nn.Module): 31 | def __init__(self, in_dim, out_dim=256): 32 | super().__init__() 33 | hidden_dim = in_dim 34 | self.layer1 = nn.Sequential( 35 | nn.Linear(in_dim, hidden_dim), 36 | nn.ReLU(inplace=True) 37 | ) 38 | self.layer2 = nn.Linear(hidden_dim, out_dim) 39 | def forward(self, x): 40 | x = self.layer1(x) 41 | x = self.layer2(x) 42 | return x 43 | 44 | class SimCLR(nn.Module): 45 | 46 | def __init__(self, backbone=resnet50()): 47 | super().__init__() 48 | 49 | self.backbone = backbone 50 | self.projector = projection_MLP(backbone.output_dim) 51 | self.encoder = nn.Sequential( 52 | self.backbone, 53 | self.projector 54 | ) 55 | 56 | 57 | 58 | def forward(self, x1, x2): 59 | z1 = self.encoder(x1) 60 | z2 = self.encoder(x2) 61 | 62 | loss = NT_XentLoss(z1, z2) 63 | return loss 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /models/simsiam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 
4 | from torchvision.models import resnet50 5 | 6 | 7 | def D(p, z, version='simplified'): # negative cosine similarity 8 | if version == 'original': 9 | z = z.detach() # stop gradient 10 | p = F.normalize(p, dim=1) # l2-normalize 11 | z = F.normalize(z, dim=1) # l2-normalize 12 | return -(p*z).sum(dim=1).mean() 13 | 14 | elif version == 'simplified':# same thing, much faster. Scroll down, speed test in __main__ 15 | return - F.cosine_similarity(p, z.detach(), dim=-1).mean() 16 | else: 17 | raise Exception 18 | 19 | 20 | 21 | class projection_MLP(nn.Module): 22 | def __init__(self, in_dim, hidden_dim=2048, out_dim=2048): 23 | super().__init__() 24 | ''' page 3 baseline setting 25 | Projection MLP. The projection MLP (in f) has BN ap- 26 | plied to each fully-connected (fc) layer, including its out- 27 | put fc. Its output fc has no ReLU. The hidden fc is 2048-d. 28 | This MLP has 3 layers. 29 | ''' 30 | self.layer1 = nn.Sequential( 31 | nn.Linear(in_dim, hidden_dim), 32 | nn.BatchNorm1d(hidden_dim), 33 | nn.ReLU(inplace=True) 34 | ) 35 | self.layer2 = nn.Sequential( 36 | nn.Linear(hidden_dim, hidden_dim), 37 | nn.BatchNorm1d(hidden_dim), 38 | nn.ReLU(inplace=True) 39 | ) 40 | self.layer3 = nn.Sequential( 41 | nn.Linear(hidden_dim, out_dim), 42 | nn.BatchNorm1d(hidden_dim) 43 | ) 44 | self.num_layers = 3 45 | def set_layers(self, num_layers): 46 | self.num_layers = num_layers 47 | 48 | def forward(self, x): 49 | if self.num_layers == 3: 50 | x = self.layer1(x) 51 | x = self.layer2(x) 52 | x = self.layer3(x) 53 | elif self.num_layers == 2: 54 | x = self.layer1(x) 55 | x = self.layer3(x) 56 | else: 57 | raise Exception 58 | return x 59 | 60 | 61 | class prediction_MLP(nn.Module): 62 | def __init__(self, in_dim=2048, hidden_dim=512, out_dim=2048): # bottleneck structure 63 | super().__init__() 64 | ''' page 3 baseline setting 65 | Prediction MLP. The prediction MLP (h) has BN applied 66 | to its hidden fc layers. Its output fc does not have BN 67 | (ablation in Sec. 4.4) or ReLU. This MLP has 2 layers. 68 | The dimension of h’s input and output (z and p) is d = 2048, 69 | and h’s hidden layer’s dimension is 512, making h a 70 | bottleneck structure (ablation in supplement). 71 | ''' 72 | self.layer1 = nn.Sequential( 73 | nn.Linear(in_dim, hidden_dim), 74 | nn.BatchNorm1d(hidden_dim), 75 | nn.ReLU(inplace=True) 76 | ) 77 | self.layer2 = nn.Linear(hidden_dim, out_dim) 78 | """ 79 | Adding BN to the output of the prediction MLP h does not work 80 | well (Table 3d). We find that this is not about collapsing. 81 | The training is unstable and the loss oscillates. 
82 |         """
83 | 
84 |     def forward(self, x):
85 |         x = self.layer1(x)
86 |         x = self.layer2(x)
87 |         return x
88 | 
89 | class SimSiam(nn.Module):
90 |     def __init__(self, backbone=resnet50()):  # note: the torchvision default lacks .output_dim; pass a backbone that sets it (see models/backbones.py)
91 |         super().__init__()
92 | 
93 |         self.backbone = backbone
94 |         self.projector = projection_MLP(backbone.output_dim)
95 | 
96 |         self.encoder = nn.Sequential( # f encoder
97 |             self.backbone,
98 |             self.projector
99 |         )
100 |         self.predictor = prediction_MLP()
101 | 
102 |     def forward(self, x1, x2):
103 | 
104 |         f, h = self.encoder, self.predictor
105 |         z1, z2 = f(x1), f(x2)
106 |         p1, p2 = h(z1), h(z2)
107 |         L = D(p1, z2) / 2 + D(p2, z1) / 2  # symmetrized stop-gradient loss
108 |         return L
109 | 
110 | 
111 | 
112 | 
113 | 
114 | 
115 | if __name__ == "__main__":
116 |     model = SimSiam()
117 |     x1 = torch.randn((2, 3, 224, 224))
118 |     x2 = torch.randn_like(x1)
119 | 
120 |     model.forward(x1, x2).backward()
121 |     print("forward backward check")
122 | 
123 |     z1 = torch.randn((200, 2560))
124 |     z2 = torch.randn_like(z1)
125 |     import time
126 |     tic = time.time()
127 |     print(D(z1, z2, version='original'))
128 |     toc = time.time()
129 |     print(toc - tic)
130 |     tic = time.time()
131 |     print(D(z1, z2, version='simplified'))
132 |     toc = time.time()
133 |     print(toc - tic)
134 | 
135 |     # Output:
136 |     # tensor(-0.0010)
137 |     # 0.005159854888916016
138 |     # tensor(-0.0010)
139 |     # 0.0014872550964355469
-------------------------------------------------------------------------------- /models/swav.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision.models import resnet50
5 | 
6 | class SwAV(nn.Module):
7 |     def __init__(self, backbone=resnet50()):
8 |         super().__init__()
9 | 
10 |         backbone.fc = nn.Identity()
11 |         self.backbone = backbone
12 | 
13 |     def forward(self, x1, x2):
14 |         # The SwAV objective was left unimplemented in this repository;
15 |         # fail loudly instead of silently returning None.
16 |         raise NotImplementedError("SwAV training is not implemented")
-------------------------------------------------------------------------------- /optimizers/__init__.py: --------------------------------------------------------------------------------
1 | from .lars import LARS
2 | from .lars_simclr import LARS_simclr
3 | from .larc import LARC
4 | import torch
5 | from .lr_scheduler import LR_Scheduler
6 | 
7 | 
8 | def get_optimizer(name, model, lr, momentum, weight_decay):
9 |     # split the parameters into two groups; the 'name' key lets LR_Scheduler
10 |     # treat the predictor group specially (e.g. a constant lr for SimSiam's h)
11 |     predictor_prefix = ('module.predictor', 'predictor')
12 |     parameters = [{
13 |         'name': 'base',
14 |         'params': [param for name, param in model.named_parameters() if not name.startswith(predictor_prefix)],
15 |         'lr': lr
16 |     },{
17 |         'name': 'predictor',
18 |         'params': [param for name, param in model.named_parameters() if name.startswith(predictor_prefix)],
19 |         'lr': lr
20 |     }]
21 |     if name == 'lars':
22 |         optimizer = LARS(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
23 |     elif name == 'sgd':
24 |         optimizer = torch.optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
25 |     elif name == 'lars_simclr': # careful: takes named_modules, not the param groups above
26 |         optimizer = LARS_simclr(model.named_modules(), lr=lr, momentum=momentum, weight_decay=weight_decay)
27 |     elif name == 'larc':
28 |         optimizer = LARC(
29 |             torch.optim.SGD(
30 |                 parameters,
31 |                 lr=lr,
32 |                 momentum=momentum,
33 |                 weight_decay=weight_decay
34 |             ),
35 |             trust_coefficient=0.001,
36 |             clip=False
37 |         )
38 |     else:
39 |         raise NotImplementedError
40 |     return optimizer
--------------------------------------------------------------------------------
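A minimal smoke test for `get_optimizer()` above, wired to the `SimSiam` model from `models/simsiam.py`. `ToyBackbone`, the shapes, and the hyper-parameters here are illustrative stand-ins; a real run would pass one of the backbones from `models/backbones.py`, which expose the `.output_dim` attribute the model classes expect.

```
import torch
import torch.nn as nn

from models.simsiam import SimSiam
from optimizers import get_optimizer


class ToyBackbone(nn.Module):
    """Tiny stand-in exposing the .output_dim attribute the model classes expect."""
    def __init__(self, output_dim=64):
        super().__init__()
        self.output_dim = output_dim
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, output_dim))

    def forward(self, x):
        return self.net(x)


model = SimSiam(ToyBackbone())
optimizer = get_optimizer('sgd', model, lr=0.05, momentum=0.9, weight_decay=5e-4)

x1 = torch.randn(8, 3, 32, 32)   # two augmented "views" of the same batch
x2 = torch.randn(8, 3, 32, 32)
loss = model(x1, x2)             # symmetrized negative cosine similarity
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())
```

Note that `get_optimizer` puts the predictor parameters in their own param group, so schedulers can later treat them differently from the backbone and projector.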
/optimizers/larc.py:
--------------------------------------------------------------------------------
1 | """SwAV uses LARC instead of the LARS optimizer."""
2 | 
3 | import torch
4 | from torch import nn
5 | from torch.nn.parameter import Parameter
6 | from torch.optim.optimizer import Optimizer
7 | 
8 | def main(): # Example
9 |     import torchvision
10 |     model = torchvision.models.resnet18(pretrained=False)
11 |     # optim = torch.optim.Adam(model.parameters(), lr=0.0001)
12 |     optim = torch.optim.SGD(model.parameters(), lr=0.2, momentum=0.9, weight_decay=1.5e-6)
13 |     optim = LARC(optim)
14 | 
15 | class LARC(Optimizer):
16 |     """
17 |     :class:`LARC` is a pytorch implementation of both the scaling and clipping variants of LARC,
18 |     in which the ratio between gradient and parameter magnitudes is used to calculate an adaptive
19 |     local learning rate for each individual parameter. The algorithm is designed to improve
20 |     convergence of large batch training.
21 | 
22 |     See https://arxiv.org/abs/1708.03888 for calculation of the local learning rate.
23 |     In practice it modifies the gradients of parameters as a proxy for modifying the learning rate
24 |     of the parameters. This design allows it to be used as a wrapper around any torch.optim Optimizer.
25 |     ```
26 |     model = ...
27 |     optim = torch.optim.Adam(model.parameters(), lr=...)
28 |     optim = LARC(optim)
29 |     ```
30 |     It can even be used in conjunction with apex.fp16_utils.FP16_optimizer.
31 |     ```
32 |     model = ...
33 |     optim = torch.optim.Adam(model.parameters(), lr=...)
34 |     optim = LARC(optim)
35 |     optim = apex.fp16_utils.FP16_Optimizer(optim)
36 |     ```
37 |     Args:
38 |         optimizer: Pytorch optimizer to wrap and modify learning rate for.
39 | trust_coefficient: Trust coefficient for calculating the lr. See https://arxiv.org/abs/1708.03888 40 | clip: Decides between clipping or scaling mode of LARC. If `clip=True` the learning rate is set to `min(optimizer_lr, local_lr)` for each parameter. If `clip=False` the learning rate is set to `local_lr*optimizer_lr`. 41 | eps: epsilon kludge to help with numerical stability while calculating adaptive_lr 42 | """ 43 | 44 | def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8): 45 | self.optim = optimizer 46 | self.trust_coefficient = trust_coefficient 47 | self.eps = eps 48 | self.clip = clip 49 | 50 | def __getstate__(self): 51 | return self.optim.__getstate__() 52 | 53 | def __setstate__(self, state): 54 | self.optim.__setstate__(state) 55 | 56 | @property 57 | def state(self): 58 | return self.optim.state 59 | 60 | def __repr__(self): 61 | return self.optim.__repr__() 62 | 63 | @property 64 | def param_groups(self): 65 | return self.optim.param_groups 66 | 67 | @param_groups.setter 68 | def param_groups(self, value): 69 | self.optim.param_groups = value 70 | 71 | def state_dict(self): 72 | return self.optim.state_dict() 73 | 74 | def load_state_dict(self, state_dict): 75 | self.optim.load_state_dict(state_dict) 76 | 77 | def zero_grad(self): 78 | self.optim.zero_grad() 79 | 80 | def add_param_group(self, param_group): 81 | self.optim.add_param_group( param_group) 82 | 83 | def step(self): 84 | with torch.no_grad(): 85 | weight_decays = [] 86 | for group in self.optim.param_groups: 87 | # absorb weight decay control from optimizer 88 | weight_decay = group['weight_decay'] if 'weight_decay' in group else 0 89 | weight_decays.append(weight_decay) 90 | group['weight_decay'] = 0 91 | for p in group['params']: 92 | if p.grad is None: 93 | continue 94 | param_norm = torch.norm(p.data) 95 | grad_norm = torch.norm(p.grad.data) 96 | 97 | if param_norm != 0 and grad_norm != 0: 98 | # calculate adaptive lr + weight decay 99 | adaptive_lr = self.trust_coefficient * (param_norm) / (grad_norm + param_norm * weight_decay + self.eps) 100 | 101 | # clip learning rate for LARC 102 | if self.clip: 103 | # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)` 104 | adaptive_lr = min(adaptive_lr/group['lr'], 1) 105 | 106 | p.grad.data += weight_decay * p.data 107 | p.grad.data *= adaptive_lr 108 | 109 | self.optim.step() 110 | # return weight decay control to optimizer 111 | for i, group in enumerate(self.optim.param_groups): 112 | group['weight_decay'] = weight_decays[i] 113 | 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /optimizers/lars.py: -------------------------------------------------------------------------------- 1 | """ Layer-wise adaptive rate scaling for SGD in PyTorch! """ 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class LARS(Optimizer): 6 | r"""Implements layer-wise adaptive rate scaling for SGD. 
7 | 8 | Args: 9 | params (iterable): iterable of parameters to optimize or dicts defining 10 | parameter groups 11 | lr (float): base learning rate (\gamma_0) 12 | momentum (float, optional): momentum factor (default: 0) ("m") 13 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 14 | ("\beta") 15 | eta (float, optional): LARS coefficient 16 | max_epoch: maximum training epoch to determine polynomial LR decay. 17 | 18 | Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. 19 | Large Batch Training of Convolutional Networks: 20 | https://arxiv.org/abs/1708.03888 21 | 22 | Example: 23 | >>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3) 24 | >>> optimizer.zero_grad() 25 | >>> loss_fn(model(input), target).backward() 26 | >>> optimizer.step() 27 | """ 28 | def __init__(self, params, lr=required, momentum=.9, 29 | weight_decay=.0005, eta=0.001, max_epoch=200): 30 | if lr is not required and lr < 0.0: 31 | raise ValueError("Invalid learning rate: {}".format(lr)) 32 | if momentum < 0.0: 33 | raise ValueError("Invalid momentum value: {}".format(momentum)) 34 | if weight_decay < 0.0: 35 | raise ValueError("Invalid weight_decay value: {}" 36 | .format(weight_decay)) 37 | if eta < 0.0: 38 | raise ValueError("Invalid LARS coefficient value: {}".format(eta)) 39 | 40 | self.epoch = 0 41 | defaults = dict(lr=lr, momentum=momentum, 42 | weight_decay=weight_decay, 43 | eta=eta, max_epoch=max_epoch) 44 | super(LARS, self).__init__(params, defaults) 45 | 46 | def step(self, epoch=None, closure=None): 47 | """Performs a single optimization step. 48 | 49 | Arguments: 50 | closure (callable, optional): A closure that reevaluates the model 51 | and returns the loss. 52 | epoch: current epoch to calculate polynomial LR decay schedule. 53 | if None, uses self.epoch and increments it. 
54 | """ 55 | loss = None 56 | if closure is not None: 57 | loss = closure() 58 | 59 | if epoch is None: 60 | epoch = self.epoch 61 | self.epoch += 1 62 | 63 | for group in self.param_groups: 64 | weight_decay = group['weight_decay'] 65 | momentum = group['momentum'] 66 | eta = group['eta'] 67 | lr = group['lr'] 68 | max_epoch = group['max_epoch'] 69 | 70 | for p in group['params']: 71 | if p.grad is None: 72 | continue 73 | 74 | param_state = self.state[p] 75 | d_p = p.grad.data 76 | 77 | weight_norm = torch.norm(p.data) 78 | grad_norm = torch.norm(d_p) 79 | 80 | # Global LR computed on polynomial decay schedule 81 | decay = (1 - float(epoch) / max_epoch) ** 2 82 | global_lr = lr * decay 83 | 84 | # Compute local learning rate for this layer 85 | local_lr = eta * weight_norm / (grad_norm + weight_decay * weight_norm) 86 | 87 | # Update the momentum term 88 | actual_lr = local_lr * global_lr 89 | 90 | if 'momentum_buffer' not in param_state: 91 | buf = param_state['momentum_buffer'] = torch.zeros_like(p.data) 92 | else: 93 | buf = param_state['momentum_buffer'] 94 | buf.mul_(momentum).add_(d_p + weight_decay * p.data, alpha=actual_lr) 95 | p.data.add_(-buf) 96 | 97 | return loss -------------------------------------------------------------------------------- /optimizers/lars_simclr.py: -------------------------------------------------------------------------------- 1 | """The lars optimizer used in simclr is a bit different from the paper where they exclude certain parameters""" 2 | """I asked the author of byol, they also stick to the simclr lars implementation""" 3 | 4 | 5 | 6 | import torch 7 | import torchvision 8 | from torch.optim.optimizer import Optimizer 9 | import torch.nn as nn 10 | # comments from the lead author of byol 11 | # 2. + 3. We follow the same implementation as the one used in SimCLR for LARS. This is indeed a bit 12 | # different from the one described in the LARS paper and the implementation you attached to your email. 13 | # In particular as in SimCLR we first modify the gradient to include the weight decay (with beta corresponding 14 | # to self.weight_decay in the SimCLR code) and then adapt the learning rate by dividing by the norm of this 15 | # sum, this is different from the LARS pseudo code where they divide by the sum of the norm (instead of the 16 | # norm of the sum as SimCLR and us are doing). This is done in the SimCLR code by first adding the weight 17 | # decay term to the gradient and then using this sum to perform the adaptation. We also use a term (usually 18 | # referred to as trust_coefficient but referred as eeta in SimCLR code) set to 1e-3 to multiply the updates 19 | # of linear layers. 20 | # Note that the logic "if w_norm > 0 and g_norm > 0 else 1.0" is there to tackle numerical instabilities. 21 | # In general we closely followed SimCLR implementation of LARS. 22 | class LARS_simclr(Optimizer): 23 | def __init__(self, 24 | named_modules, 25 | lr, 26 | momentum=0.9, # beta? 
YES
27 |                  trust_coef=1e-3,
28 |                  weight_decay=1.5e-6,
29 |                  exclude_bias_from_adaption=True):
30 |         '''byol: As in SimCLR and the official LARS implementation, we exclude bias and batchnorm weights from the LARS adaptation and weight decay'''
31 |         defaults = dict(momentum=momentum,
32 |                         lr=lr,
33 |                         weight_decay=weight_decay,
34 |                         trust_coef=trust_coef)
35 |         parameters = self.exclude_from_model(named_modules, exclude_bias_from_adaption)
36 |         super(LARS_simclr, self).__init__(parameters, defaults)
37 | 
38 |     @torch.no_grad()
39 |     def step(self):
40 |         for group in self.param_groups: # only 1 group in most cases
41 |             weight_decay = group['weight_decay']
42 |             momentum = group['momentum']
43 |             lr = group['lr']
44 | 
45 |             trust_coef = group['trust_coef']
46 |             # print(group['name'])
47 |             # eps = group['eps']
48 |             for p in group['params']:
49 |                 # breakpoint()
50 |                 if p.grad is None:
51 |                     continue
52 |                 global_lr = lr
53 |                 velocity = self.state[p].get('velocity', 0)
54 |                 # if name in self.exclude_from_layer_adaptation:
55 |                 if self._use_weight_decay(group):
56 |                     p.grad.data += weight_decay * p.data
57 | 
58 |                 trust_ratio = 1.0
59 |                 if self._do_layer_adaptation(group):
60 |                     w_norm = torch.norm(p.data, p=2)
61 |                     g_norm = torch.norm(p.grad.data, p=2)
62 |                     trust_ratio = trust_coef * w_norm / g_norm if w_norm > 0 and g_norm > 0 else 1.0
63 |                 scaled_lr = global_lr * trust_ratio # trust_ratio is the local_lr
64 |                 next_v = momentum * velocity + scaled_lr * p.grad.data
65 |                 self.state[p]['velocity'] = next_v  # write the buffer back; the original never stored it, so momentum had no effect
66 |                 p.data = p.data - next_v
67 | 
68 | 
69 |     def _use_weight_decay(self, group):
70 |         return False if group['name'] == 'exclude' else True
71 |     def _do_layer_adaptation(self, group):
72 |         return False if group['name'] == 'exclude' else True
73 | 
74 |     def exclude_from_model(self, named_modules, exclude_bias_from_adaption=True):
75 |         base = []
76 |         exclude = []
77 |         for name, module in named_modules:
78 |             if type(module) in [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]:
79 |                 # if isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
80 |                 for name2, param in module.named_parameters():
81 |                     exclude.append(param)
82 |             else:
83 |                 for name2, param in module.named_parameters():
84 |                     if name2 == 'bias':
85 |                         exclude.append(param)
86 |                     elif name2 == 'weight':
87 |                         base.append(param)
88 |                     else:
89 |                         pass # non leaf modules
90 |         return [{
91 |             'name': 'base',
92 |             'params': base
93 |         },{
94 |             'name': 'exclude',
95 |             'params': exclude
96 |         }] if exclude_bias_from_adaption else [{
97 |             'name': 'base',
98 |             'params': base+exclude
99 |         }]
100 | 
101 | if __name__ == "__main__":
102 | 
103 |     resnet = torchvision.models.resnet18(pretrained=False)
104 |     model = resnet
105 | 
106 |     optimizer = LARS_simclr(model.named_modules(), lr=0.1)
107 |     # print()
108 |     # out = optimizer.exclude_from_model(model.named_modules(),exclude_bias_from_adaption=False)
109 |     # print(len(out[0]['params']))
110 |     # exit()
111 | 
112 |     criterion = torch.nn.CrossEntropyLoss()
113 |     for i in range(100):
114 |         model.zero_grad()
115 |         pred = model(torch.randn((2,3,32,32)))
116 |         loss = pred.mean()
117 |         loss.backward()
118 |         optimizer.step()
--------------------------------------------------------------------------------
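The `LR_Scheduler` defined in the next file is stepped once per iteration and is meant to pair with `get_optimizer()` from `optimizers/__init__.py`. A minimal sketch of the wiring, with illustrative hyper-parameters (`iter_per_epoch=391` would correspond to CIFAR-10 at batch size 128; none of these values are taken from the training scripts):

```
import torch
import torchvision

from optimizers import get_optimizer
from optimizers.lr_scheduler import LR_Scheduler

model = torchvision.models.resnet18()
optimizer = get_optimizer('sgd', model, lr=0.03, momentum=0.9, weight_decay=5e-4)
scheduler = LR_Scheduler(optimizer, warmup_epochs=5, warmup_lr=0, num_epochs=100,
                         base_lr=0.03, final_lr=0, iter_per_epoch=391)

for step_idx in range(5):                       # stand-in for the batch loop
    loss = model(torch.randn(2, 3, 64, 64)).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    lr = scheduler.step()                       # one scheduler step per iteration
```

The schedule is precomputed as a linear warmup followed by a cosine decay, so the scheduler simply indexes into an array; calling `step()` more than `num_epochs * iter_per_epoch` times would run past the end of that array.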
/optimizers/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | # from torch.optim.lr_scheduler import _LRScheduler
4 | 
5 | 
6 | class LR_Scheduler(object):
7 |     def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False):
8 |         self.base_lr = base_lr
9 |         self.constant_predictor_lr = constant_predictor_lr
10 |         warmup_iter = iter_per_epoch * warmup_epochs
11 |         warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter)
12 |         decay_iter = iter_per_epoch * (num_epochs - warmup_epochs)
13 |         cosine_lr_schedule = final_lr + 0.5*(base_lr - final_lr)*(1 + np.cos(np.pi*np.arange(decay_iter)/decay_iter))
14 | 
15 |         self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
16 |         self.optimizer = optimizer
17 |         self.iter = 0
18 |     def step(self):
19 |         lr = self.lr_schedule[self.iter]
20 |         for param_group in self.optimizer.param_groups:
21 |             if self.constant_predictor_lr and param_group['name'] == 'predictor':
22 |                 param_group['lr'] = self.base_lr  # keep the predictor lr constant
23 |             else:
24 |                 param_group['lr'] = lr
25 | 
26 |         self.iter += 1
27 |         return lr  # previously `lr` was unbound when every group used the constant predictor lr
28 | 
29 | if __name__ == "__main__":
30 |     import torchvision
31 |     model = torchvision.models.resnet50()
32 |     optimizer = torch.optim.SGD(model.parameters(), lr=999)
33 |     epochs = 100
34 |     n_iter = 1000
35 |     scheduler = LR_Scheduler(optimizer, 10, 1, epochs, 3, 0, n_iter)
36 |     import matplotlib.pyplot as plt
37 |     lrs = []
38 |     for epoch in range(epochs):
39 |         for it in range(n_iter):
40 |             lr = scheduler.step()
41 |             lrs.append(lr)
42 |     plt.plot(lrs)
43 |     plt.show()
-------------------------------------------------------------------------------- /tools/__init__.py: --------------------------------------------------------------------------------
1 | from .average_meter import AverageMeter
2 | from .accuracy import accuracy
3 | from .plot_logger import PlotLogger
4 | from .knn_monitor import knn_monitor
5 | 
-------------------------------------------------------------------------------- /tools/accuracy.py: --------------------------------------------------------------------------------
1 | import torch  # this import was missing in the original file
2 | 
3 | 
4 | def accuracy(output, target, topk=(1,)):
5 |     """Computes the accuracy over the k top predictions for the specified values of k"""
6 |     with torch.no_grad():
7 |         maxk = max(topk)
8 |         batch_size = target.size(0)
9 | 
10 |         _, pred = output.topk(maxk, 1, True, True)
11 |         pred = pred.t()
12 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
13 | 
14 |         res = []
15 |         for k in topk:
16 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape, not view: `correct` may be non-contiguous
17 |             res.append(correct_k.mul_(100.0 / batch_size))
18 |         return res
-------------------------------------------------------------------------------- /tools/average_meter.py: --------------------------------------------------------------------------------
1 | class AverageMeter():
2 |     """Computes and stores the average and current value"""
3 |     def __init__(self, name, fmt=':f'):
4 |         self.name = name
5 |         self.fmt = fmt
6 |         self.log = []
7 |         self.val = 0
8 |         self.avg = 0
9 |         self.sum = 0
10 |         self.count = 0
11 | 
12 |     def reset(self):
13 |         self.log.append(self.avg)
14 |         self.val = 0
15 |         self.avg = 0
16 |         self.sum = 0
17 |         self.count = 0
18 | 
19 |     def update(self, val, n=1):
20 |         self.val = val
21 |         self.sum += val * n
22 |         self.count += n
23 |         self.avg = self.sum / self.count
24 | 
25 |     def __str__(self):
26 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
27 |         return fmtstr.format(**self.__dict__)
28 | 
29 | if __name__ == "__main__":
30 |     meter = AverageMeter('sldk')
31 |     print(meter.log)
-------------------------------------------------------------------------------- /tools/knn_monitor.py: --------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import torch.nn.functional as F
3 | import torch
4 | # code copied from https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb#scrollTo=RI1Y8bSImD7N
5 | # test using a knn monitor
6 | def knn_monitor(net, memory_data_loader, test_data_loader, epoch, k=200, t=0.1, hide_progress=False):
7 |     net.eval()
8 |     classes = len(memory_data_loader.dataset.classes)
9 |     total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
10 |     with torch.no_grad():
11 |         # generate feature bank
12 |         for data, target in tqdm(memory_data_loader, desc='Feature extracting', leave=False, disable=hide_progress):
13 |             feature = net(data.cuda(non_blocking=True))
14 |             feature = F.normalize(feature, dim=1)
15 |             feature_bank.append(feature)
16 |         # [D, N]
17 |         feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
18 |         # [N]
19 |         feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)
20 |         # loop test data to predict the label by weighted knn search
21 |         test_bar = tqdm(test_data_loader, desc='kNN', disable=hide_progress)
22 |         for data, target in test_bar:
23 |             data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
24 |             feature = net(data)
25 |             feature = F.normalize(feature, dim=1)
26 | 
27 |             pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, k, t)
28 | 
29 |             total_num += data.size(0)
30 |             total_top1 += (pred_labels[:, 0] == target).float().sum().item()
31 |             test_bar.set_postfix({'Accuracy': total_top1 / total_num * 100})
32 |     return total_top1 / total_num * 100
33 | 
34 | # knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
35 | # implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
36 | def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
37 |     # compute cos similarity between each feature vector and feature bank ---> [B, N]
38 |     sim_matrix = torch.mm(feature, feature_bank)
39 |     # [B, K]
40 |     sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
41 |     # [B, K]
42 |     sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
43 |     sim_weight = (sim_weight / knn_t).exp()
44 | 
45 |     # counts for each class
46 |     one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device)
47 |     # [B*K, C]
48 |     one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
49 |     # weighted score ---> [B, C]
50 |     pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)
51 | 
52 |     pred_labels = pred_scores.argsort(dim=-1, descending=True)
53 |     return pred_labels
-------------------------------------------------------------------------------- /tools/plot_logger.py: --------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg') # https://stackoverflow.com/questions/49921721/runtimeerror-main-thread-is-not-in-main-loop-with-matplotlib-and-flask
3 | import matplotlib.pyplot as plt
4 | from collections import OrderedDict
5 | 
6 | 
7 | class PlotLogger(object):
8 |     def __init__(self, params=['loss']):
9 |         self.logger = OrderedDict({param: [] for param in params})
10 |     def update(self, ordered_dict):
11 |         assert set(ordered_dict.keys()).issubset(set(self.logger.keys()))
12 |         for key, value in ordered_dict.items():
13 |             self.logger[key].append(value)
14 | 
15 |     def save(self, file, **kwargs):
16 |         # squeeze=False keeps `axes` 2-D even with a single subplot
17 |         # (the original crashed when only one parameter was tracked)
18 |         fig, axes = plt.subplots(nrows=len(self.logger), ncols=1, squeeze=False)
19 |         fig.tight_layout()
20 |         for ax, (key, value) in zip(axes[:, 0], self.logger.items()):
21 |             ax.plot(value)
22 |             ax.set_title(key)
23 | 
24 |         plt.savefig(file, **kwargs)
25 |         plt.close()
26 | 
27 | 
28 | if __name__ == "__main__":
29 |     logger = PlotLogger(params=['loss', 'accuracy', 'epoch'])
30 |     import random
31 |     epochs = 100
32 |     n_iter = 1000
33 |     for epoch in range(epochs):
34 |         for idx in range(n_iter):
35 |             stuff = {'loss': random.random(), 'accuracy': random.random(), 'epoch': epoch}
36 |             logger.update(stuff)
37 | 
38 |     logger.save('./logger.png')
--------------------------------------------------------------------------------
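A short sketch of how `AverageMeter` and `PlotLogger` above combine in a training loop. The loss values are synthetic stand-ins, and the single-parameter `save()` relies on the `squeeze=False` behavior noted above.

```
import random

from tools.average_meter import AverageMeter
from tools.plot_logger import PlotLogger

loss_meter = AverageMeter('loss', fmt=':.4f')
plot = PlotLogger(params=['loss'])

for epoch in range(3):
    for _ in range(10):                  # stand-in for the batch loop
        loss = random.random()           # synthetic per-batch loss
        loss_meter.update(loss, n=16)    # n = batch size
        plot.update({'loss': loss_meter.avg})
    print(loss_meter)                    # e.g. "loss 0.5377 (0.4831)"
    loss_meter.reset()                   # pushes the epoch average to .log

plot.save('loss_curve.png')
```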
/tools/ramps.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Python version: 3.6
4 | 
5 | import torch
6 | from torch import nn, autograd
7 | from torch.utils.data import DataLoader, Dataset
8 | import numpy as np
9 | import random
10 | from sklearn import metrics
11 | import torch.nn.functional as F
12 | import copy
13 | from torch.autograd import Variable
14 | import itertools
15 | import logging
16 | import os.path
17 | from PIL import Image
18 | import numpy as np
19 | from torch.utils.data.sampler import Sampler
20 | import re
21 | import argparse
22 | import os
23 | import shutil
24 | import time
25 | import math
26 | import logging
27 | import os
28 | import sys
29 | import torch.backends.cudnn as cudnn
30 | from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
31 | import torchvision.datasets
32 | 
33 | def quantile_linear(iter, args):
34 | 
35 |     turn_point = int( (args.comu_rate * args.epochs - 0.1 * args.epochs -1.35) / 0.45 )
36 |     if iter < args.phi_g:
37 |         return 1.0
38 |     elif iter > turn_point:
39 |         return 0.1
40 |     else:
41 |         return 0.9 * iter / ( 2 - turn_point ) + 1 - 1.8/( 2 - turn_point )
42 | 
43 | 
44 | def quantile_rectangle(iter, args):
45 |     if iter < args.phi_g:
46 |         return 0.0
47 |     elif iter >= args.psi_g:
48 |         return 0.0
49 |     else:
50 |         if args.comu_rate*5/3 > 1:
51 |             return 0.99
52 |         else:
53 |             return args.comu_rate*args.epochs/(args.psi_g - args.phi_g)
54 | 
55 | def get_median(data, iter, args):
56 |     if args.dataset == 'mnist':
57 |         a = 8
58 |     else:
59 |         a = 33
60 | 
61 |     if len(data) < (39*a):
62 |         data_test = data[(-10*a):]
63 |     elif len(data) < (139*a):
64 |         data_test = data[(30*a) : ]
65 |     else:
66 |         data_test = data[(-100*a):]
67 | 
68 |     data_test.sort()
69 | 
70 |     if args.ramp == 'linear':
71 |         quantile = quantile_linear(iter, args)
72 |         iter_place = int( (1 - quantile) * len(data_test))
73 |     elif args.ramp == 'flat':
74 |         quantile = quantile_flat(iter, args)
75 |         iter_place = int( (1 - quantile) * len(data_test))
76 |     elif args.ramp == 'rectangle':
77 |         quantile = quantile_rectangle(iter, args)
78 |         iter_place = int( (1 - quantile) * len(data_test)-1)
79 |     else:
80 |         exit('Error: wrong ramp type!')
81 |     return data_test[iter_place]
82 | 
83 | def sigmoid_rampup(current, rampup_length):
84 |     if rampup_length == 0:
85 |         return 1.0
86 |     else:
87 |         current = np.clip(current, 0.0, rampup_length)
88 |         phase = 1.0 - current / rampup_length
89 |         return float(np.exp(-5.0 * phase * phase))
90 | 
91 | def sigmoid_rampup2(current, rampup_length):
92 |     if rampup_length == 0:
93 |         return 1.0
94 |     else:
95 |         current = np.clip(current, 0.0, rampup_length)
96 |         phase = current / rampup_length
97 |         return float(np.exp(-5.0 * phase * phase))
98 | 
99 | def linear_rampup(current, rampup_length):
100 |     assert current >= 0 and rampup_length >= 0
101 |     if current >= rampup_length:
102 |         return 1.0
103 |     else:
104 |         return current / rampup_length
105 | 
106 | 
107 | def cosine_rampdown(current, rampdown_length):
108 |     assert 0 <= current <= rampdown_length
109 |     return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1))
110 | 
111 | 
112 | def quantile_flat(iter, args):
113 |     # NOTE: get_median() above references quantile_flat, but the function was
114 |     # missing from the original source and calling it raised a NameError.
115 |     # A constant quantile (the plateau value from quantile_rectangle) is one
116 |     # plausible reading of a 'flat' ramp; treat this as a placeholder, not
117 |     # the authors' definition.
118 |     return min(args.comu_rate * args.epochs / (args.psi_g - args.phi_g), 0.99)
--------------------------------------------------------------------------------
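Finally, a sketch of the typical mean-teacher-style use of the ramp functions in `tools/ramps.py`: ramp the weight of an unsupervised consistency loss up at the start of training, and optionally ramp a learning-rate factor down toward the end. The constants and the weighting scheme below are illustrative, not taken from the training scripts.

```
from tools.ramps import sigmoid_rampup, cosine_rampdown

w_max, ramp_len, num_epochs = 1.0, 30, 100

for epoch in range(num_epochs):
    w = w_max * sigmoid_rampup(epoch, ramp_len)     # 0 -> w_max over the first 30 epochs
    lr_factor = cosine_rampdown(epoch, num_epochs)  # 1 -> 0 over all 100 epochs
    # loss = supervised_loss + w * consistency_loss
    # lr = base_lr * lr_factor
```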