├── nets ├── __init__.py ├── __pycache__ │ ├── Nets.cpython-37.pyc │ ├── bn_ops.cpython-37.pyc │ ├── dual_bn.cpython-37.pyc │ ├── models.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── dual_ops.cpython-37.pyc │ ├── profile_func.cpython-37.pyc │ ├── slimmable_ops.cpython-37.pyc │ ├── thop_op_hooks.cpython-37.pyc │ └── slimmable_models.cpython-37.pyc ├── HeteFL │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── preresne.cpython-37.pyc │ │ ├── preresnet.cpython-37.pyc │ │ └── slimmable_preresne.cpython-37.pyc │ ├── __init__.py │ ├── preresne.py │ ├── preresnet.py │ └── slimmable_preresne.py ├── thop_op_hooks.py ├── profile_func.py ├── dual_ops.py ├── Nets.py ├── slimNets.py ├── dual_bn.py ├── slimmable_Nets.py ├── bn_ops.py ├── models.py └── slimmable_ops.py ├── utils ├── __init__.py ├── __pycache__ │ ├── config.cpython-37.pyc │ ├── utils.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── data_loader.cpython-37.pyc │ └── data_utils.cpython-37.pyc ├── config.py ├── utils.py ├── data_utils.py └── data_loader.py ├── federated ├── __init__.py ├── __pycache__ │ ├── core.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ ├── learning.cpython-37.pyc │ └── aggregation.cpython-37.pyc ├── learning.py └── core.py ├── data └── put_dataset_here ├── wandb └── wandb_logs_will_be_here ├── checkpoint └── checkpoint_will_be_here ├── README.md ├── fed_hfl.py └── fed_dataHetComp.py /nets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /federated/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/put_dataset_here: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /wandb/wandb_logs_will_be_here: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /checkpoint/checkpoint_will_be_here: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nets/__pycache__/Nets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/nets/__pycache__/Nets.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/bn_ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/nets/__pycache__/bn_ops.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/dual_bn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/nets/__pycache__/dual_bn.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/nets/__pycache__/models.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/utils/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /federated/__pycache__/core.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/federated/__pycache__/core.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/nets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/dual_ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/nets/__pycache__/dual_ops.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/profile_func.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/nets/__pycache__/profile_func.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/data_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/utils/__pycache__/data_loader.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/data_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/utils/__pycache__/data_utils.cpython-37.pyc -------------------------------------------------------------------------------- /federated/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/federated/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /federated/__pycache__/learning.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/federated/__pycache__/learning.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/slimmable_ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/nets/__pycache__/slimmable_ops.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/thop_op_hooks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/nets/__pycache__/thop_op_hooks.cpython-37.pyc -------------------------------------------------------------------------------- /federated/__pycache__/aggregation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/federated/__pycache__/aggregation.cpython-37.pyc -------------------------------------------------------------------------------- /nets/HeteFL/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/nets/HeteFL/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /nets/HeteFL/__pycache__/preresne.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/nets/HeteFL/__pycache__/preresne.cpython-37.pyc -------------------------------------------------------------------------------- /nets/HeteFL/__pycache__/preresnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/nets/HeteFL/__pycache__/preresnet.cpython-37.pyc -------------------------------------------------------------------------------- /nets/__pycache__/slimmable_models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/nets/__pycache__/slimmable_models.cpython-37.pyc -------------------------------------------------------------------------------- /nets/HeteFL/__pycache__/slimmable_preresne.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MehdiSet/PerFedMask/HEAD/nets/HeteFL/__pycache__/slimmable_preresne.cpython-37.pyc -------------------------------------------------------------------------------- /nets/HeteFL/__init__.py: -------------------------------------------------------------------------------- 1 | """Models based on HeteroFL (https://github.com/dem123456789/HeteroFL-Computation-and-Communication-Efficient-Federated-Learning-for-Heterogeneous-Clients)""" 2 | -------------------------------------------------------------------------------- /nets/thop_op_hooks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from thop.vision.basic_hooks import count_convNd, count_linear, count_bn 4 | from .slimmable_ops import SlimmableConv2d, SlimmableLinear, SlimmableBatchNorm2d, \ 5 | SlimmableBatchNorm1d 6 | 7 | __all__ = ['thop_hooks'] 8 | 9 | # extra profile functions from newer version of thop 10 | def count_ln(m, x, y): 11 | """layer norm""" 12 | x = x[0] 13 | if not m.training: 14 | m.total_ops += counter_norm(x.numel()) 15 | 16 | def counter_norm(input_size): 17 | """input is a number not a array or tensor""" 18 | return torch.DoubleTensor([2 * input_size]) 19 | 20 | def count_softmax(m, x, y): 21 | x = x[0] 22 | nfeatures = x.size()[m.dim] 23 | batch_size = x.numel() // nfeatures 24 | 25 | m.total_ops += counter_softmax(batch_size, nfeatures) 26 | 27 | def counter_softmax(batch_size, nfeatures): 28 | total_exp = nfeatures 29 | total_add = nfeatures - 1 30 | total_div = nfeatures 31 | total_ops = batch_size * (total_exp + total_add + total_div) 32 | return torch.DoubleTensor([int(total_ops)]) 33 | 34 | thop_hooks = { 35 | SlimmableConv2d: count_convNd, 36 | SlimmableLinear: count_linear, 37 | SlimmableBatchNorm2d: count_bn, 38 | SlimmableBatchNorm1d: count_bn, 39 | nn.LayerNorm: count_ln, 40 | } 41 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | """Configuration file for defining paths to data.""" 2 | import os 3 | import platform 4 | 5 | def make_if_not_exist(p): 6 | if not os.path.exists(p): 7 | os.makedirs(p) 8 | 9 | hostname = platform.uname()[1] # type: str 10 | # Update your paths here. 11 | CHECKPOINT_ROOT = './checkpoint' 12 | #if int(hostname.split('-')[-1]) >= 8: 13 | # data_root = '/localscratch2/jyhong/' 14 | #elif hostname.startswith('illidan'): 15 | # data_root = '/media/Research/jyhong/data' 16 | #else: 17 | data_root = './data' 18 | make_if_not_exist(data_root) 19 | make_if_not_exist(CHECKPOINT_ROOT) 20 | 21 | if hostname.startswith('illidan') and int(hostname.split('-')[-1]) < 8: 22 | # personal config 23 | home_path = os.path.expanduser('~/') 24 | DATA_PATHS = { 25 | "Digits": home_path + "projects/FedBN/data", 26 | "DomainNet": data_root + "/DomainNet", 27 | "DomainNetPathList": home_path + "projects/FedBN/data/", # store the path list file from FedBN 28 | "Cifar10": data_root, 29 | "Cifar100": data_root, 30 | } 31 | else: 32 | DATA_PATHS = { 33 | "Digits": data_root + "/Digits", 34 | "DomainNet": data_root + "/DomainNet", 35 | "DomainNetPathList": data_root + "/DomainNet/domainnet10/", # store the path list file from FedBN 36 | "Cifar10": data_root, 37 | "cifar100": data_root, 38 | } 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PerFedMask: Personalized Federated Learning with Optimized Masking Vectors 2 | 3 | This is the official Pytorch implementation of our paper [PerFedMask: Personalized Federated Learning with Optimized Masking Vectors](https://openreview.net/pdf?id=hxEIgUXLFF) accepted in ICLR 2023. 4 | 5 | ## Installation 6 | 7 | First check the requirements as follows:\ 8 | python=3.7\ 9 | numpy=1.17.0\ 10 | pytorch=1.12.1\ 11 | cudatoolkit = 11.3.1\ 12 | wandb=0.12.19\ 13 | torchvision=0.13.1\ 14 | cvxpy=1.1.11\ 15 | mosek=9.2.40 16 | 17 | Then clone the repository as follows: 18 | ```shell 19 | git clone https://github.com/MehdiSet/PerFedMask.git 20 | ``` 21 | 22 | ## Dataset 23 | 24 | We conduct our experiments on CIFAR-10, CIFAR-100, and DomainNet datasets using ResNet (PreResNet18), MobileNet , and AlexNet, respectively. Please download the datasets and place them under `data/` directory. 25 | 26 | 27 | ## Citation 28 | 29 | If you find our paper and code useful, please cite our paper as follows: 30 | ```bibtex 31 | @inproceedings{setayesh2023perfedmask, 32 | title={PerFedMask: {Personalized} Federated Learning with Optimized Masking Vectors}, 33 | author={Setayesh, Mehdi and Li, Xiaoxiao and W.S. Wong, Vincent}, 34 | booktitle={Proc. of International Conference on Learning Representations (ICLR)}, 35 | address={Kigali, Rwanda}, 36 | month={May}, 37 | year={2023} 38 | } 39 | ``` 40 | 41 | ## Contact 42 | 43 | Please feel free to contact us if you have any questions: 44 | - Mehdi Setayesh: setayeshm@ece.ubc.ca 45 | 46 | ## Acknowledgements 47 | This codebase was adapted from https://github.com/illidanlab/SplitMix and https://github.com/jhoon-oh/FedBABU. 48 | 49 | -------------------------------------------------------------------------------- /nets/profile_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from thop import profile 4 | from nets.thop_op_hooks import thop_hooks 5 | 6 | # from nets.slimmable_models import Ensemble, EnsembleSubnet, EnsembleGroupSubnet 7 | 8 | 9 | def count_params_by_state(model): 10 | """Count #param based on state dict of the given model.""" 11 | if hasattr(model, 'state_size'): # EnsembleSubnet, EnsembleGroupSubnet 12 | s = model.state_size() 13 | else: 14 | s = 0 15 | for k, p in model.state_dict().items(): 16 | s = s + p.numel() 17 | return s 18 | 19 | 20 | def profile_slimmable_models(model, slim_ratios, verbose=1): 21 | max_flops = None 22 | max_params = None 23 | for slim_ratio in sorted(slim_ratios, reverse=True): 24 | if hasattr(model, 'switch_slim_mode'): 25 | model.switch_slim_mode(slim_ratio) 26 | else: # if isinstance(model, Ensemble): 27 | model.set_total_slim_ratio(slim_ratio) 28 | flops, state_params = profile_model(model, verbose > 1) 29 | if verbose > 0: 30 | print(f'slim_ratio: {slim_ratio:.3f} GFLOPS: {flops / 1e9:.4f}, ' 31 | f'model state size: {state_params / 1e6:.2f}MB') 32 | 33 | if max_flops is None: 34 | max_flops = flops 35 | max_params = state_params 36 | elif verbose > 0: 37 | print(f" flop ratio: {flops/max_flops:.3f}, size ratio: {state_params/max_params:.3f}," 38 | f" sqrt size ratio: {np.sqrt(state_params/max_params):.3f}") 39 | 40 | 41 | def profile_model(model, verbose=False, batch_size=2, device='cpu', input_shape=None): 42 | if input_shape is None: 43 | input_shape = model.input_shape 44 | input_shape = (batch_size, *input_shape[1:]) 45 | dummy_input = torch.rand(input_shape).to(device) 46 | # customized ops: 47 | # https://github.com/Lyken17/pytorch-OpCounter/blob/master/thop/vision/basic_hooks.py 48 | state_params = count_params_by_state(model) 49 | flops, params = profile(model, inputs=(dummy_input,), custom_ops=thop_hooks, 50 | verbose=verbose) 51 | flops = flops / batch_size 52 | return flops, state_params 53 | 54 | -------------------------------------------------------------------------------- /nets/dual_ops.py: -------------------------------------------------------------------------------- 1 | """Structure with dual weights.""" 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | from torch.nn import functional as F 7 | from typing import Optional 8 | 9 | 10 | class DualConv2d(nn.Conv2d): 11 | aux_weight: torch.Tensor 12 | aux_bias: Optional[torch.Tensor] 13 | 14 | def __init__(self, in_channels: int, out_channels: int, 15 | kernel_size, stride=1, padding=0, dilation=1, 16 | groups=1, bias=True, 17 | fix_out=False, fix_in=False, overlap_rate=0.): 18 | assert groups == 1, "for now, we can only support single group when slimming." 19 | if overlap_rate > 0: 20 | overlap_ch_in = in_channels if fix_in else int((2 - overlap_rate) * in_channels) 21 | overlap_ch_out = out_channels if fix_out else int((2 - overlap_rate) * out_channels) 22 | self.conv = super(DualConv2d, self).__init__( 23 | overlap_ch_in, overlap_ch_out, 24 | kernel_size, stride=stride, padding=padding, dilation=dilation, 25 | groups=groups, bias=bias) 26 | else: 27 | self.conv = super(DualConv2d, self).__init__( 28 | in_channels, out_channels, 29 | kernel_size, stride=stride, padding=padding, dilation=dilation, 30 | groups=groups, bias=bias) 31 | # auxiliary weight, bias 32 | self.aux_weight = nn.Parameter(torch.Tensor( 33 | out_channels, in_channels // groups, kernel_size, kernel_size)) 34 | if bias: 35 | self.aux_bias = nn.Parameter(torch.Tensor(out_channels)) 36 | else: 37 | self.register_parameter('aux_bias', None) 38 | 39 | init.kaiming_uniform_(self.aux_weight, a=math.sqrt(5)) 40 | if self.aux_bias is not None: 41 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 42 | bound = 1 / math.sqrt(fan_in) 43 | init.uniform_(self.aux_bias, -bound, bound) 44 | 45 | self.overlap_rate = overlap_rate 46 | self.in_channels = in_channels 47 | self.out_channels = out_channels 48 | 49 | self.mode = 0 50 | self.fix_out = fix_out 51 | self.fix_in = fix_in 52 | 53 | def forward(self, x): 54 | if self.overlap_rate > 0: 55 | in_idx_bias = 0 56 | out_idx_bias = 0 57 | if self.mode > 0: 58 | in_idx_bias = 0 if self.fix_in else int((1 - self.overlap_rate) * self.in_channels) 59 | out_idx_bias = 0 if self.fix_out else int((1 - self.overlap_rate) * self.out_channels) 60 | weight = self.weight[out_idx_bias:(out_idx_bias+self.out_channels), in_idx_bias:(in_idx_bias+self.in_channels)] 61 | bias = self.bias[out_idx_bias:(out_idx_bias + self.out_channels)] if self.bias is not None else None 62 | else: 63 | if self.mode > 0: 64 | weight = self.aux_weight 65 | bias = self.aux_bias 66 | else: 67 | weight = self.weight 68 | bias = self.bias 69 | y = F.conv2d( 70 | x, weight, bias, self.stride, self.padding, 71 | self.dilation, self.groups) 72 | return y 73 | -------------------------------------------------------------------------------- /nets/Nets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision import models 9 | 10 | def conv3x3(in_channels, out_channels, **kwargs): 11 | return nn.Sequential( 12 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs), 13 | nn.BatchNorm2d(out_channels, track_running_stats=False), 14 | nn.ReLU(), 15 | nn.MaxPool2d(2) 16 | ) 17 | 18 | class CNNCifar(nn.Module): 19 | def __init__(self, args): 20 | super(CNNCifar, self).__init__() 21 | in_channels = 3 22 | num_classes = args.num_classes 23 | 24 | hidden_size = 64 25 | 26 | self.features = nn.Sequential( 27 | conv3x3(in_channels, hidden_size), 28 | conv3x3(hidden_size, hidden_size), 29 | conv3x3(hidden_size, hidden_size), 30 | conv3x3(hidden_size, hidden_size) 31 | ) 32 | 33 | self.linear = nn.Linear(hidden_size*2*2, num_classes) 34 | 35 | def forward(self, x): 36 | features = self.features(x) 37 | features = features.view((features.size(0), -1)) 38 | logits = self.linear(features) 39 | 40 | return logits 41 | 42 | def extract_features(self, x): 43 | features = self.features(x) 44 | features = features.view((features.size(0), -1)) 45 | 46 | return features 47 | 48 | '''MobileNet in PyTorch. 49 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 50 | for more details. 51 | ''' 52 | 53 | class Block(nn.Module): 54 | '''Depthwise conv + Pointwise conv''' 55 | def __init__(self, in_planes, out_planes, stride=1): 56 | super(Block, self).__init__() 57 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 58 | self.bn1 = nn.BatchNorm2d(in_planes, track_running_stats=False) 59 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 60 | self.bn2 = nn.BatchNorm2d(out_planes, track_running_stats=False) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(self.conv1(x))) 64 | out = F.relu(self.bn2(self.conv2(out))) 65 | return out 66 | 67 | class MobileNetCifar(nn.Module): 68 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 69 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 70 | 71 | def __init__(self, num_classes=10, track_running_stats=False, width_scale=1.): 72 | super(MobileNetCifar, self).__init__() 73 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 74 | self.bn1 = nn.BatchNorm2d(32, track_running_stats=False) 75 | self.layers = self._make_layers(in_planes=32) 76 | self.linear = nn.Linear(1024, num_classes) 77 | 78 | def _make_layers(self, in_planes): 79 | layers = [] 80 | for x in self.cfg: 81 | out_planes = x if isinstance(x, int) else x[0] 82 | stride = 1 if isinstance(x, int) else x[1] 83 | layers.append(Block(in_planes, out_planes, stride)) 84 | in_planes = out_planes 85 | return nn.Sequential(*layers) 86 | 87 | def forward(self, x): 88 | out = F.relu(self.bn1(self.conv1(x))) 89 | out = self.layers(out) 90 | out = F.avg_pool2d(out, 2) 91 | out = out.view(out.size(0), -1) 92 | logits = self.linear(out) 93 | 94 | return logits 95 | 96 | def extract_features(self, x): 97 | out = F.relu(self.bn1(self.conv1(x))) 98 | out = self.layers(out) 99 | out = F.avg_pool2d(out, 2) 100 | out = out.view(out.size(0), -1) 101 | 102 | return out -------------------------------------------------------------------------------- /nets/slimNets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision import models 9 | from .models import ScalableModule 10 | 11 | 12 | 13 | 14 | def conv3x3(in_channels, out_channels, **kwargs): 15 | return nn.Sequential( 16 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs), 17 | nn.BatchNorm2d(out_channels, track_running_stats=False), 18 | nn.ReLU(), 19 | nn.MaxPool2d(2) 20 | ) 21 | 22 | class CNNCifar(nn.Module): 23 | def __init__(self, args): 24 | super(CNNCifar, self).__init__() 25 | in_channels = 3 26 | num_classes = args.num_classes 27 | 28 | hidden_size = 64 29 | 30 | self.features = nn.Sequential( 31 | conv3x3(in_channels, hidden_size), 32 | conv3x3(hidden_size, hidden_size), 33 | conv3x3(hidden_size, hidden_size), 34 | conv3x3(hidden_size, hidden_size) 35 | ) 36 | 37 | self.linear = nn.Linear(hidden_size*2*2, num_classes) 38 | 39 | def forward(self, x): 40 | features = self.features(x) 41 | features = features.view((features.size(0), -1)) 42 | logits = self.linear(features) 43 | 44 | return logits 45 | 46 | def extract_features(self, x): 47 | features = self.features(x) 48 | features = features.view((features.size(0), -1)) 49 | 50 | return features 51 | 52 | '''MobileNet in PyTorch. 53 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 54 | for more details. 55 | ''' 56 | 57 | class Block(nn.Module): 58 | '''Depthwise conv + Pointwise conv''' 59 | def __init__(self, in_planes, out_planes, stride=1): 60 | super(Block, self).__init__() 61 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 62 | self.bn1 = nn.BatchNorm2d(in_planes, track_running_stats=False) 63 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 64 | self.bn2 = nn.BatchNorm2d(out_planes, track_running_stats=False) 65 | 66 | def forward(self, x): 67 | out = F.relu(self.bn1(self.conv1(x))) 68 | out = F.relu(self.bn2(self.conv2(out))) 69 | return out 70 | 71 | class MobileNetCifar(ScalableModule): 72 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 73 | 74 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 75 | input_shape = [None, 3, 32, 32] 76 | 77 | def __init__(self, num_classes=10, track_running_stats=False, width_scale=1.,rescale_init=False, 78 | share_affine=False, rescale_layer=False, bn_type='bn',): 79 | super(MobileNetCifar, self).__init__(width_scale=width_scale, rescale_init=rescale_init, 80 | rescale_layer=rescale_layer) 81 | 82 | self.cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 83 | 84 | if width_scale != 1.: 85 | temp_cfg=[] 86 | for x in self.cfg: 87 | I1 = int(x * width_scale) if isinstance(x, int) else int(x[0] * width_scale) 88 | I2 = 1 if isinstance(x, int) else x[1] 89 | if I2==1: 90 | temp_cfg.append(I1) 91 | else: 92 | temp_cfg.append((I1, I2)) 93 | 94 | self.cfg = temp_cfg 95 | 96 | 97 | 98 | self.conv1 = nn.Conv2d(3, int(32 * width_scale), kernel_size=3, stride=1, padding=1, bias=False) 99 | self.bn1 = nn.BatchNorm2d(int(32 * width_scale), track_running_stats=False) 100 | self.layers = self._make_layers(in_planes=int(32 *width_scale)) 101 | self.linear = nn.Linear(int(width_scale*1024), num_classes) 102 | 103 | def _make_layers(self, in_planes): 104 | layers = [] 105 | for x in self.cfg: 106 | out_planes = x if isinstance(x, int) else x[0] 107 | stride = 1 if isinstance(x, int) else x[1] 108 | layers.append(Block(in_planes, out_planes, stride)) 109 | in_planes = out_planes 110 | return nn.Sequential(*layers) 111 | 112 | def forward(self, x): 113 | out = F.relu(self.bn1(self.conv1(x))) 114 | out = self.layers(out) 115 | out = F.avg_pool2d(out, 2) 116 | out = out.view(out.size(0), -1) 117 | logits = self.linear(out) 118 | 119 | return logits 120 | 121 | def extract_features(self, x): 122 | out = F.relu(self.bn1(self.conv1(x))) 123 | out = self.layers(out) 124 | out = F.avg_pool2d(out, 2) 125 | out = out.view(out.size(0), -1) 126 | 127 | return out -------------------------------------------------------------------------------- /nets/dual_bn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Union 4 | from .dual_ops import DualConv2d 5 | 6 | 7 | def set_bn_mode(module: nn.Module, is_noised: Union[bool, torch.Tensor]): 8 | """Set BN mode to be noised or clean. This is only effective for StackedNormLayer 9 | or DualNormLayer.""" 10 | 11 | def set_bn_eval_(m): 12 | if isinstance(m, (DualNormLayer,)): 13 | if isinstance(is_noised, torch.Tensor): 14 | m.clean_input = ~is_noised 15 | else: 16 | m.clean_input = not is_noised 17 | elif isinstance(m, (DualConv2d,)): 18 | m.mode = 1 if is_noised else 0 19 | module.apply(set_bn_eval_) 20 | 21 | 22 | class DualNormLayer(nn.Module): 23 | """Dual Normalization Layer.""" 24 | _version = 1 25 | # __constants__ = ['track_running_stats', 'momentum', 'eps', 26 | # 'num_features', 'affine'] 27 | 28 | def __init__(self, num_features, track_running_stats=True, affine=True, bn_class=None, 29 | share_affine=True, **kwargs): 30 | super(DualNormLayer, self).__init__() 31 | self.affine = affine 32 | if bn_class is None: 33 | bn_class = nn.BatchNorm2d 34 | self.bn_class = bn_class 35 | self.share_affine = share_affine 36 | self.clean_bn = bn_class(num_features, track_running_stats=track_running_stats, affine=self.affine and not self.share_affine, **kwargs) 37 | self.noise_bn = bn_class(num_features, track_running_stats=track_running_stats, affine=self.affine and not self.share_affine, **kwargs) 38 | if self.affine and self.share_affine: 39 | self.weight = nn.Parameter(torch.Tensor(num_features)) 40 | self.bias = nn.Parameter(torch.Tensor(num_features)) 41 | else: 42 | self.register_parameter('weight', None) 43 | self.register_parameter('bias', None) 44 | self.reset_parameters() 45 | 46 | self.clean_input = True # only used in training? 47 | 48 | def reset_parameters(self) -> None: 49 | if self.affine and self.share_affine: 50 | nn.init.ones_(self.weight) 51 | nn.init.zeros_(self.bias) 52 | 53 | def forward(self, inp: torch.Tensor) -> torch.Tensor: 54 | if isinstance(self.clean_input, bool): 55 | if self.clean_input: 56 | out = self.clean_bn(inp) 57 | else: 58 | out = self.noise_bn(inp) 59 | elif isinstance(self.clean_input, torch.Tensor): 60 | # Separate input. This important at training to avoid mixture of BN stats. 61 | clean_mask = torch.nonzero(self.clean_input) 62 | noise_mask = torch.nonzero(~self.clean_input) 63 | out = torch.zeros_like(inp) 64 | 65 | if len(clean_mask) > 0: 66 | clean_mask = clean_mask.squeeze(1) 67 | # print(self.clean_input, clean_mask) 68 | out_clean = self.clean_bn(inp[clean_mask]) 69 | out[clean_mask] = out_clean 70 | if len(noise_mask) > 0: 71 | noise_mask = noise_mask.squeeze(1) 72 | # print(self.clean_input, noise_mask) 73 | out_noise = self.noise_bn(inp[noise_mask]) 74 | out[noise_mask] = out_noise 75 | elif isinstance(self.clean_input, (float, int)): 76 | assert not self.training, "You should not use both BN at training." 77 | assert not self.share_affine, "Should not share affine, because we have to use affine" \ 78 | " before combination but didn't." 79 | out_c = self.clean_bn(inp) 80 | out_n = self.noise_bn(inp) 81 | out = self.clean_input * 1. * out_c + (1. - self.clean_input) * out_n 82 | else: 83 | raise TypeError(f"Invalid self.clean_input: {type(self.clean_input)}") 84 | if self.affine and self.share_affine: 85 | # out = F.linear(out, self.weight, self.bias) 86 | shape = [1] * out.dim() 87 | shape[1] = -1 88 | out = out * self.weight.view(*shape) + self.bias.view(*shape) 89 | assert out.shape == inp.shape 90 | # TODO how to do the affine? 91 | # out = F.batch_norm(out, None, None, self.weight, self.bias, self.training) 92 | return out 93 | 94 | 95 | class DualBatchNorm2d(DualNormLayer): 96 | def __init__(self, *args, **kwargs): 97 | super(DualBatchNorm2d, self).__init__(*args, bn_class=nn.BatchNorm2d, **kwargs) 98 | 99 | 100 | class DualBatchNorm1d(DualNormLayer): 101 | def __init__(self, *args, **kwargs): 102 | super(DualBatchNorm1d, self).__init__(*args, bn_class=nn.BatchNorm1d, **kwargs) 103 | 104 | 105 | def test(): 106 | norm = DualBatchNorm2d(3) 107 | norm.eval() 108 | with torch.no_grad(): 109 | norm.clean_input = torch.randn((32,)) > 0. 110 | x = torch.randn((32, 3, 2, 2)) 111 | y = norm(x) 112 | print(list(y.size())) 113 | assert list(y.size()) == [32, 3, 2, 2] 114 | 115 | 116 | if __name__ == '__main__': 117 | test() 118 | # import doctest 119 | # 120 | # doctest.testmod() 121 | -------------------------------------------------------------------------------- /nets/slimmable_Nets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Python version: 3.6 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision import models 9 | from .models import ScalableModule 10 | from .slimmable_models import BaseModule, SlimmableMixin 11 | from .slimmable_ops import SlimmableConv2d, SlimmableBatchNorm2d, SlimmableLinear 12 | 13 | 14 | 15 | 16 | def conv3x3(in_channels, out_channels, **kwargs): 17 | return nn.Sequential( 18 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs), 19 | nn.BatchNorm2d(out_channels, track_running_stats=False), 20 | nn.ReLU(), 21 | nn.MaxPool2d(2) 22 | ) 23 | 24 | class CNNCifar(nn.Module): 25 | def __init__(self, args): 26 | super(CNNCifar, self).__init__() 27 | in_channels = 3 28 | num_classes = args.num_classes 29 | 30 | hidden_size = 64 31 | 32 | self.features = nn.Sequential( 33 | conv3x3(in_channels, hidden_size), 34 | conv3x3(hidden_size, hidden_size), 35 | conv3x3(hidden_size, hidden_size), 36 | conv3x3(hidden_size, hidden_size) 37 | ) 38 | 39 | self.linear = nn.Linear(hidden_size*2*2, num_classes) 40 | 41 | def forward(self, x): 42 | features = self.features(x) 43 | features = features.view((features.size(0), -1)) 44 | logits = self.linear(features) 45 | 46 | return logits 47 | 48 | def extract_features(self, x): 49 | features = self.features(x) 50 | features = features.view((features.size(0), -1)) 51 | 52 | return features 53 | 54 | '''MobileNet in PyTorch. 55 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 56 | for more details. 57 | ''' 58 | 59 | class Block(nn.Module): 60 | '''Depthwise conv + Pointwise conv''' 61 | def __init__(self, in_planes, out_planes, stride=1): 62 | super(Block, self).__init__() 63 | 64 | self.conv1 = SlimmableConv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=1, 65 | bias=False) 66 | self.bn1 = SlimmableBatchNorm2d(in_planes, track_running_stats=False) 67 | self.conv2 = SlimmableConv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 68 | self.bn2 = SlimmableBatchNorm2d(out_planes, track_running_stats=False) 69 | 70 | def forward(self, x): 71 | out = F.relu(self.bn1(self.conv1(x))) 72 | out = F.relu(self.bn2(self.conv2(out))) 73 | return out 74 | 75 | class MobileNetCifar(BaseModule, SlimmableMixin): 76 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 77 | 78 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 79 | input_shape = [None, 3, 32, 32] 80 | 81 | def __init__(self, num_classes=10, track_running_stats=False, width_scale=1., bn_type='bn', slimmabe_ratios=None): 82 | super(MobileNetCifar, self).__init__() 83 | self._set_slimmabe_ratios(slimmabe_ratios) 84 | 85 | self.cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 86 | 87 | if width_scale != 1.: 88 | temp_cfg=[] 89 | for x in self.cfg: 90 | I1 = int(x * width_scale) if isinstance(x, int) else int(x[0] * width_scale) 91 | I2 = 1 if isinstance(x, int) else x[1] 92 | if I2==1: 93 | temp_cfg.append(I1) 94 | else: 95 | temp_cfg.append((I1, I2)) 96 | 97 | self.cfg = temp_cfg 98 | 99 | 100 | 101 | 102 | self.conv1 = SlimmableConv2d(3, int(32 * width_scale), kernel_size=3, stride=1, padding=1, 103 | bias=False, non_slimmable_in=True) 104 | self.bn1 = SlimmableBatchNorm2d(int(32 * width_scale), track_running_stats=False) 105 | self.layers = self._make_layers(in_planes=int(32 *width_scale)) 106 | self.linear = SlimmableLinear(int(width_scale*1024), num_classes, 107 | non_slimmable_out=True) 108 | 109 | def _make_layers(self, in_planes): 110 | layers = [] 111 | for x in self.cfg: 112 | out_planes = x if isinstance(x, int) else x[0] 113 | stride = 1 if isinstance(x, int) else x[1] 114 | layers.append(Block(in_planes, out_planes, stride)) 115 | in_planes = out_planes 116 | return nn.Sequential(*layers) 117 | 118 | def forward(self, x): 119 | out = F.relu(self.bn1(self.conv1(x))) 120 | out = self.layers(out) 121 | out = F.avg_pool2d(out, 2) 122 | out = out.view(out.size(0), -1) 123 | logits = self.linear(out) 124 | 125 | return logits 126 | 127 | def extract_features(self, x): 128 | out = F.relu(self.bn1(self.conv1(x))) 129 | out = self.layers(out) 130 | out = F.avg_pool2d(out, 2) 131 | out = out.view(out.size(0), -1) 132 | 133 | return out -------------------------------------------------------------------------------- /nets/bn_ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as func 5 | from torch.nn.modules.batchnorm import _NormBase 6 | from .dual_bn import DualNormLayer 7 | 8 | 9 | # BN modules 10 | class _MockBatchNorm(_NormBase): 11 | 12 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 13 | track_running_stats=True): 14 | super(_MockBatchNorm, self).__init__( 15 | num_features, eps, momentum, affine, track_running_stats) 16 | 17 | def forward(self, input): 18 | self._check_input_dim(input) 19 | 20 | # exponential_average_factor is set to self.momentum 21 | # (when it is available) only so that it gets updated 22 | # in ONNX graph when this node is exported to ONNX. 23 | if self.momentum is None: 24 | exponential_average_factor = 0.0 25 | else: 26 | exponential_average_factor = self.momentum 27 | 28 | if self.training and self.track_running_stats: 29 | # TODO: if statement only here to tell the jit to skip emitting this when it is None 30 | if self.num_batches_tracked is not None: 31 | self.num_batches_tracked = self.num_batches_tracked + 1 32 | if self.momentum is None: # use cumulative moving average 33 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 34 | else: # use exponential moving average 35 | exponential_average_factor = self.momentum 36 | 37 | r""" 38 | Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be 39 | passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are 40 | used for normalization (i.e. in eval mode when buffers are not None). 41 | """ 42 | return func.batch_norm( 43 | input, 44 | # If buffers are not to be tracked, ensure that they won't be updated 45 | torch.zeros_like(self.running_mean), 46 | torch.ones_like(self.running_var), 47 | self.weight, self.bias, False, exponential_average_factor, self.eps) 48 | 49 | class MockBatchNorm1d(_MockBatchNorm): 50 | def _check_input_dim(self, input): 51 | if input.dim() != 2 and input.dim() != 3: 52 | raise ValueError('expected 2D or 3D input (got {}D input)' 53 | .format(input.dim())) 54 | 55 | class MockBatchNorm2d(_MockBatchNorm): 56 | def _check_input_dim(self, input): 57 | if input.dim() != 4: 58 | raise ValueError('expected 4D input (got {}D input)' 59 | .format(input.dim())) 60 | 61 | class BatchNorm2dAgent(nn.BatchNorm2d): 62 | def __init__(self, *args, log_stat=False, **kwargs): 63 | super().__init__(*args, **kwargs) 64 | self.pre_stat = None # statistic before BN 65 | self.post_stat = None # statistic after BN 66 | self.log_stat = log_stat 67 | 68 | def forward(self, input): 69 | if not self.log_stat: 70 | self.pre_stat = None 71 | else: 72 | self.pre_stat = { 73 | 'mean': torch.mean(input, dim=[0, 2, 3]).data.cpu().numpy(), 74 | 'var': torch.var(input, dim=[0, 2, 3]).data.cpu().numpy(), 75 | 'data': input.data.cpu().numpy(), 76 | } 77 | out = super().forward(input) 78 | if not self.log_stat: 79 | self.pre_stat = None 80 | else: 81 | self.post_stat = { 82 | 'mean': torch.mean(out, dim=[0,2,3]).data.cpu().numpy(), 83 | 'var': torch.var(out, dim=[0,2,3]).data.cpu().numpy(), 84 | 'data': out.data.cpu().numpy(), 85 | # 'mean': ((torch.mean(out, dim=[0, 2, 3]) - self.bias)/self.weight).data.cpu().numpy(), 86 | # 'var': (torch.var(out, dim=[0, 2, 3])/(self.weight**2)).data.cpu().numpy(), 87 | } 88 | return out 89 | 90 | class BatchNorm1dAgent(nn.BatchNorm1d): 91 | def __init__(self, *args, log_stat=False, **kwargs): 92 | super().__init__(*args, **kwargs) 93 | self.pre_stat = None # statistic before BN 94 | self.post_stat = None # statistic after BN 95 | self.log_stat = log_stat 96 | 97 | def forward(self, input): 98 | if not self.log_stat: 99 | self.pre_stat = None 100 | else: 101 | self.pre_stat = { 102 | 'mean': torch.mean(input, dim=[0]).data.cpu().numpy().copy(), 103 | 'var': torch.var(input, dim=[0]).data.cpu().numpy().copy(), 104 | 'data': input.data.cpu().numpy().copy(), 105 | } 106 | out = super().forward(input) 107 | if not self.log_stat: 108 | self.post_stat = None 109 | else: 110 | self.post_stat = { 111 | 'mean': torch.mean(out, dim=[0]).data.cpu().numpy().copy(), 112 | 'var': torch.var(out, dim=[0]).data.cpu().numpy().copy(), 113 | # 'mean': ((torch.mean(out, dim=[0]) - self.bias)/self.weight).data.cpu().numpy(), 114 | # 'var': (torch.var(out, dim=[0])/(self.weight**2)).data.cpu().numpy(), 115 | 'data': out.detach().cpu().numpy().copy(), 116 | } 117 | # print("post stat mean: ", self.post_stat['mean']) 118 | return out 119 | 120 | 121 | def is_film_dual_norm(bn_type: str): 122 | return bn_type.startswith('fd') 123 | 124 | 125 | def get_bn_layer(bn_type: str): 126 | if bn_type.startswith('d'): # dual norm layer. Example: sbn, sbin, sin 127 | base_norm_class = get_bn_layer(bn_type[1:]) 128 | bn_class = { 129 | '1d': lambda num_features, **kwargs: DualNormLayer(num_features, bn_class=base_norm_class['1d'], **kwargs), 130 | '2d': lambda num_features, **kwargs: DualNormLayer(num_features, bn_class=base_norm_class['2d'], **kwargs), 131 | } 132 | elif is_film_dual_norm(bn_type): # dual norm layer. Example: sbn, sbin, sin 133 | base_norm_class = get_bn_layer(bn_type[1:]) 134 | bn_class = { 135 | '1d': lambda num_features, **kwargs: FilmDualNormLayer(num_features, bn_class=base_norm_class['1d'], **kwargs), 136 | '2d': lambda num_features, **kwargs: FilmDualNormLayer(num_features, bn_class=base_norm_class['2d'], **kwargs), 137 | } 138 | elif bn_type == 'bn': 139 | bn_class = {'1d': nn.BatchNorm1d, '2d': nn.BatchNorm2d} 140 | elif bn_type == 'none': 141 | bn_class = {'1d': MockBatchNorm1d, 142 | '2d': MockBatchNorm2d} 143 | else: 144 | raise ValueError(f"Invalid bn_type: {bn_type}") 145 | return bn_class 146 | -------------------------------------------------------------------------------- /nets/HeteFL/preresne.py: -------------------------------------------------------------------------------- 1 | """Ref to HeteroFL pre-activated ResNet18 2 | Will be removed in the future. 3 | All bn names are set as n{digi}, which is fixed in `preresnet.py`. 4 | """ 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.nn.modules.batchnorm import _BatchNorm 10 | from torch.nn.modules.instancenorm import _InstanceNorm 11 | from ..models import ScalableModule 12 | 13 | 14 | hidden_size = [64, 128, 256, 512] 15 | 16 | 17 | class Block(nn.Module): 18 | expansion = 1 19 | 20 | def __init__(self, in_planes, planes, stride, norm_layer, conv_layer, scaler): 21 | super(Block, self).__init__() 22 | # self.norm_layer = norm_layer 23 | self.n1 = norm_layer(in_planes) 24 | self.conv1 = conv_layer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 25 | self.n2 = norm_layer(planes) 26 | self.conv2 = conv_layer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 27 | self.scaler = scaler 28 | 29 | if stride != 1 or in_planes != self.expansion * planes: 30 | self.shortcut = conv_layer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False) 31 | 32 | def forward(self, x): 33 | out = F.relu(self.n1(x)) 34 | shortcut = self.scaler(self.shortcut(out)) if hasattr(self, 'shortcut') else x 35 | out = self.scaler(self.conv1(out)) 36 | out = self.scaler(self.conv2(F.relu(self.n2(out)))) 37 | out += shortcut 38 | return out 39 | 40 | 41 | class Bottleneck(nn.Module): 42 | expansion = 4 43 | 44 | def __init__(self, in_planes, planes, stride, norm_layer, conv_layer, scaler): 45 | super(Bottleneck, self).__init__() 46 | # self.norm_layer = norm_layer 47 | self.n1 = norm_layer(in_planes) 48 | self.conv1 = conv_layer(in_planes, planes, kernel_size=1, bias=False) 49 | self.n2 = norm_layer(planes) 50 | self.conv2 = conv_layer(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 51 | self.n3 = norm_layer(planes) 52 | self.conv3 = conv_layer(planes, self.expansion * planes, kernel_size=1, bias=False) 53 | self.scaler = scaler 54 | 55 | if stride != 1 or in_planes != self.expansion * planes: 56 | self.shortcut = conv_layer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.n1(x)) 60 | shortcut = self.scaler(self.shortcut(out)) if hasattr(self, 'shortcut') else x 61 | out = self.scaler(self.conv1(out)) 62 | out = self.scaler(self.conv2(F.relu(self.n2(out)))) 63 | out = self.scaler(self.conv3(F.relu(self.n3(out)))) 64 | out += shortcut 65 | return out 66 | 67 | 68 | class ResNet(ScalableModule): 69 | input_shape = [None, 3, 32, 32] 70 | 71 | def __init__(self, hidden_size, block, num_blocks, num_classes=10, bn_type='bn', 72 | share_affine=False, track_running_stats=True, width_scale=1., 73 | rescale_init=False, rescale_layer=False): 74 | super(ResNet, self).__init__(width_scale=width_scale, rescale_init=rescale_init, 75 | rescale_layer=rescale_layer) 76 | 77 | if width_scale != 1.: 78 | hidden_size = [int(hs * width_scale) for hs in hidden_size] 79 | self.bn_type = bn_type 80 | # norm_layer = lambda n_ch: get_bn_layer(bn_type)['2d'](n_ch, track_running_stats=track_running_stats) 81 | if bn_type == 'bn': 82 | norm_layer = lambda n_ch: nn.BatchNorm2d(n_ch, track_running_stats=track_running_stats) 83 | elif bn_type == 'dbn': 84 | from ..dual_bn import DualNormLayer 85 | norm_layer = lambda n_ch: DualNormLayer(n_ch, track_running_stats=track_running_stats, affine=True, bn_class=nn.BatchNorm2d, 86 | share_affine=share_affine) 87 | else: 88 | raise RuntimeError(f"Not support bn_type={bn_type}") 89 | conv_layer = nn.Conv2d 90 | 91 | self.in_planes = hidden_size[0] 92 | self.conv1 = nn.Conv2d(3, hidden_size[0], kernel_size=3, stride=1, padding=1, bias=False) 93 | self.layer1 = self._make_layer(block, hidden_size[0], num_blocks[0], stride=1, 94 | norm_layer=norm_layer, conv_layer=conv_layer) 95 | self.layer2 = self._make_layer(block, hidden_size[1], num_blocks[1], stride=2, 96 | norm_layer=norm_layer, conv_layer=conv_layer) 97 | self.layer3 = self._make_layer(block, hidden_size[2], num_blocks[2], stride=2, 98 | norm_layer=norm_layer, conv_layer=conv_layer) 99 | self.layer4 = self._make_layer(block, hidden_size[3], num_blocks[3], stride=2, 100 | norm_layer=norm_layer, conv_layer=conv_layer) 101 | self.n4 = norm_layer(hidden_size[3] * block.expansion) 102 | self.linear = nn.Linear(hidden_size[3] * block.expansion, num_classes) 103 | 104 | self.reset_parameters(inp_nonscale_layers=['conv1']) 105 | 106 | def _make_layer(self, block, planes, num_blocks, stride, norm_layer, conv_layer): 107 | strides = [stride] + [1] * (num_blocks - 1) 108 | layers = [] 109 | for stride in strides: 110 | layers.append(block(self.in_planes, planes, stride, norm_layer, conv_layer, self.scaler)) 111 | self.in_planes = planes * block.expansion 112 | return nn.Sequential(*layers) 113 | 114 | def forward(self, x, return_pre_clf_fea=False): 115 | out = self.scaler(self.conv1(x)) 116 | out = self.layer1(out) 117 | out = self.layer2(out) 118 | out = self.layer3(out) 119 | out = self.layer4(out) 120 | out = F.relu(self.n4(out)) 121 | out = F.adaptive_avg_pool2d(out, 1) 122 | out = out.view(out.size(0), -1) 123 | logits = self.linear(out) 124 | if return_pre_clf_fea: 125 | return logits, out 126 | else: 127 | return logits 128 | 129 | def print_footprint(self): 130 | input_shape = self.input_shape 131 | input_shape[0] = 2 132 | x = torch.rand(input_shape) 133 | batch = x.shape[0] 134 | print(f"input: {np.prod(x.shape[1:])} <= {x.shape[1:]}") 135 | x = self.conv1(x) 136 | print(f"conv1: {np.prod(x.shape[1:])} <= {x.shape[1:]}") 137 | for i_layer, layer in enumerate([self.layer1, self.layer2, self.layer3, self.layer4]): 138 | x = layer(x) 139 | print(f"layer {i_layer}: {np.prod(x.shape[1:]):5d} <= {x.shape[1:]}") 140 | 141 | def init_param(m): 142 | """Special init for ResNet""" 143 | if isinstance(m, (_BatchNorm, _InstanceNorm)): 144 | m.weight.data.fill_(1) 145 | m.bias.data.zero_() 146 | elif isinstance(m, nn.Linear): 147 | m.bias.data.zero_() 148 | return m 149 | 150 | 151 | # Instantiations 152 | def resnet10(**kwargs): 153 | model = ResNet(hidden_size, Block, [1, 1, 1, 1], **kwargs) 154 | model.apply(init_param) 155 | return model 156 | 157 | 158 | def resnet18(**kwargs): 159 | model = ResNet(hidden_size, Block, [2, 2, 2, 2], **kwargs) 160 | model.apply(init_param) 161 | return model 162 | 163 | 164 | def resnet26(**kwargs): 165 | model = ResNet(hidden_size, Block, [3, 3, 3, 3], **kwargs) 166 | model.apply(init_param) 167 | return model 168 | 169 | 170 | def resnet34(**kwargs): 171 | model = ResNet(hidden_size, Block, [3, 4, 6, 3], **kwargs) 172 | model.apply(init_param) 173 | return model 174 | 175 | 176 | def resnet50(**kwargs): 177 | model = ResNet(hidden_size, Bottleneck, [3, 4, 6, 3], **kwargs) 178 | model.apply(init_param) 179 | return model 180 | 181 | 182 | def resnet101(**kwargs): 183 | model = ResNet(hidden_size, Bottleneck, [3, 4, 23, 3], **kwargs) 184 | model.apply(init_param) 185 | return model 186 | 187 | 188 | def resnet152(**kwargs): 189 | model = ResNet(hidden_size, Bottleneck, [3, 8, 36, 3], **kwargs) 190 | model.apply(init_param) 191 | return model 192 | 193 | 194 | if __name__ == '__main__': 195 | from nets.profile_func import profile_model 196 | 197 | model = resnet18(track_running_stats=False) 198 | flops, state_params = profile_model(model, verbose=True) 199 | print(flops/1e6, state_params/1e6) 200 | -------------------------------------------------------------------------------- /nets/HeteFL/preresnet.py: -------------------------------------------------------------------------------- 1 | """Ref to HeteroFL pre-activated ResNet18""" 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.modules.batchnorm import _BatchNorm 7 | from torch.nn.modules.instancenorm import _InstanceNorm 8 | from ..models import ScalableModule 9 | 10 | 11 | hidden_size = [64, 128, 256, 512] 12 | 13 | 14 | class Block(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride, norm_layer, conv_layer, scaler): 18 | super(Block, self).__init__() 19 | # self.norm_layer = norm_layer 20 | self.bn1 = norm_layer(in_planes) 21 | self.conv1 = conv_layer(in_planes, planes, kernel_size=3, stride=stride, padding=1, 22 | bias=False) 23 | self.bn2 = norm_layer(planes) 24 | self.conv2 = conv_layer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 25 | self.scaler = scaler 26 | 27 | if stride != 1 or in_planes != self.expansion * planes: 28 | self.shortcut = conv_layer(in_planes, self.expansion * planes, kernel_size=1, 29 | stride=stride, bias=False) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(x)) 33 | shortcut = self.scaler(self.shortcut(out)) if hasattr(self, 'shortcut') else x 34 | out = self.scaler(self.conv1(out)) 35 | out = self.scaler(self.conv2(F.relu(self.bn2(out)))) 36 | out += shortcut 37 | return out 38 | 39 | 40 | class Bottleneck(nn.Module): 41 | expansion = 4 42 | 43 | def __init__(self, in_planes, planes, stride, norm_layer, conv_layer, scaler): 44 | super(Bottleneck, self).__init__() 45 | # self.norm_layer = norm_layer 46 | self.bn1 = norm_layer(in_planes) 47 | self.conv1 = conv_layer(in_planes, planes, kernel_size=1, bias=False) 48 | self.bn2 = norm_layer(planes) 49 | self.conv2 = conv_layer(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 50 | self.bn3 = norm_layer(planes) 51 | self.conv3 = conv_layer(planes, self.expansion * planes, kernel_size=1, bias=False) 52 | self.scaler = scaler 53 | 54 | if stride != 1 or in_planes != self.expansion * planes: 55 | self.shortcut = conv_layer(in_planes, self.expansion * planes, kernel_size=1, 56 | stride=stride, bias=False) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.bn1(x)) 60 | shortcut = self.scaler(self.shortcut(out)) if hasattr(self, 'shortcut') else x 61 | out = self.scaler(self.conv1(out)) 62 | out = self.scaler(self.conv2(F.relu(self.bn2(out)))) 63 | out = self.scaler(self.conv3(F.relu(self.bn3(out)))) 64 | out += shortcut 65 | return out 66 | 67 | 68 | class ResNet(ScalableModule): 69 | input_shape = [None, 3, 32, 32] 70 | 71 | def __init__(self, hidden_size, block, num_blocks, num_classes=10, bn_type='bn', 72 | share_affine=False, track_running_stats=True, width_scale=1., 73 | rescale_init=False, rescale_layer=False): 74 | super(ResNet, self).__init__(width_scale=width_scale, rescale_init=rescale_init, 75 | rescale_layer=rescale_layer) 76 | 77 | if width_scale != 1.: 78 | hidden_size = [int(hs * width_scale) for hs in hidden_size] 79 | self.bn_type = bn_type 80 | # norm_layer = lambda n_ch: get_bn_layer(bn_type)['2d'](n_ch, track_running_stats=track_running_stats) 81 | if bn_type == 'bn': 82 | norm_layer = lambda n_ch: nn.BatchNorm2d(n_ch, track_running_stats=track_running_stats) 83 | elif bn_type == 'dbn': 84 | from ..dual_bn import DualNormLayer 85 | norm_layer = lambda n_ch: DualNormLayer(n_ch, track_running_stats=track_running_stats, 86 | affine=True, bn_class=nn.BatchNorm2d, 87 | share_affine=share_affine) 88 | else: 89 | raise RuntimeError(f"Not support bn_type={bn_type}") 90 | conv_layer = nn.Conv2d 91 | 92 | self.in_planes = hidden_size[0] 93 | self.conv1 = nn.Conv2d(3, hidden_size[0], kernel_size=3, stride=1, padding=1, bias=False) 94 | self.layer1 = self._make_layer(block, hidden_size[0], num_blocks[0], stride=1, 95 | norm_layer=norm_layer, conv_layer=conv_layer) 96 | self.layer2 = self._make_layer(block, hidden_size[1], num_blocks[1], stride=2, 97 | norm_layer=norm_layer, conv_layer=conv_layer) 98 | self.layer3 = self._make_layer(block, hidden_size[2], num_blocks[2], stride=2, 99 | norm_layer=norm_layer, conv_layer=conv_layer) 100 | self.layer4 = self._make_layer(block, hidden_size[3], num_blocks[3], stride=2, 101 | norm_layer=norm_layer, conv_layer=conv_layer) 102 | self.bn4 = norm_layer(hidden_size[3] * block.expansion) 103 | self.linear = nn.Linear(hidden_size[3] * block.expansion, num_classes) 104 | 105 | self.reset_parameters(inp_nonscale_layers=['conv1']) 106 | 107 | def _make_layer(self, block, planes, num_blocks, stride, norm_layer, conv_layer): 108 | strides = [stride] + [1] * (num_blocks - 1) 109 | layers = [] 110 | for stride in strides: 111 | layers.append(block(self.in_planes, planes, stride, norm_layer, conv_layer, 112 | self.scaler)) 113 | self.in_planes = planes * block.expansion 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x, return_pre_clf_fea=False): 117 | out = self.scaler(self.conv1(x)) 118 | out = self.layer1(out) 119 | out = self.layer2(out) 120 | out = self.layer3(out) 121 | out = self.layer4(out) 122 | out = F.relu(self.bn4(out)) 123 | out = F.adaptive_avg_pool2d(out, 1) 124 | out = out.view(out.size(0), -1) 125 | logits = self.linear(out) 126 | if return_pre_clf_fea: 127 | return logits, out 128 | else: 129 | return logits 130 | 131 | def print_footprint(self): 132 | input_shape = self.input_shape 133 | input_shape[0] = 2 134 | x = torch.rand(input_shape) 135 | batch = x.shape[0] 136 | print(f"input: {np.prod(x.shape[1:])} <= {x.shape[1:]}") 137 | x = self.conv1(x) 138 | print(f"conv1: {np.prod(x.shape[1:])} <= {x.shape[1:]}") 139 | for i_layer, layer in enumerate([self.layer1, self.layer2, self.layer3, self.layer4]): 140 | x = layer(x) 141 | print(f"layer {i_layer}: {np.prod(x.shape[1:]):5d} <= {x.shape[1:]}") 142 | 143 | def init_param(m): 144 | """Special init for ResNet""" 145 | if isinstance(m, (_BatchNorm, _InstanceNorm)): 146 | m.weight.data.fill_(1) 147 | m.bias.data.zero_() 148 | elif isinstance(m, nn.Linear): 149 | m.bias.data.zero_() 150 | return m 151 | 152 | 153 | # Instantiations 154 | def resnet10(**kwargs): 155 | model = ResNet(hidden_size, Block, [1, 1, 1, 1], **kwargs) 156 | model.apply(init_param) 157 | return model 158 | 159 | 160 | def resnet18(**kwargs): 161 | model = ResNet(hidden_size, Block, [2, 2, 2, 2], **kwargs) 162 | model.apply(init_param) 163 | return model 164 | 165 | 166 | def resnet26(**kwargs): 167 | model = ResNet(hidden_size, Block, [3, 3, 3, 3], **kwargs) 168 | model.apply(init_param) 169 | return model 170 | 171 | 172 | def resnet34(**kwargs): 173 | model = ResNet(hidden_size, Block, [3, 4, 6, 3], **kwargs) 174 | model.apply(init_param) 175 | return model 176 | 177 | 178 | def resnet50(**kwargs): 179 | model = ResNet(hidden_size, Bottleneck, [3, 4, 6, 3], **kwargs) 180 | model.apply(init_param) 181 | return model 182 | 183 | 184 | def resnet101(**kwargs): 185 | model = ResNet(hidden_size, Bottleneck, [3, 4, 23, 3], **kwargs) 186 | model.apply(init_param) 187 | return model 188 | 189 | 190 | def resnet152(**kwargs): 191 | model = ResNet(hidden_size, Bottleneck, [3, 8, 36, 3], **kwargs) 192 | model.apply(init_param) 193 | return model 194 | 195 | 196 | if __name__ == '__main__': 197 | from nets.profile_func import profile_model 198 | 199 | model = resnet18(track_running_stats=False) 200 | flops, state_params = profile_model(model, verbose=True) 201 | print(flops/1e6, state_params/1e6) 202 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import copy, argparse 2 | import numpy as np 3 | import math 4 | from collections import Counter 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | 10 | def str2bool(v): 11 | if isinstance(v, bool): 12 | return v 13 | if v.lower() == 'true': 14 | return True 15 | elif v.lower() == 'false': 16 | return False 17 | else: 18 | raise argparse.ArgumentTypeError('Boolean value expected.') 19 | 20 | def set_seed(seed=None): 21 | import random 22 | if seed is not None: 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | np.random.seed(seed) 26 | random.seed(seed) 27 | 28 | 29 | class AverageMeter: 30 | """Computes and stores the average and current value""" 31 | 32 | def __init__(self): 33 | self.reset() 34 | 35 | def reset(self): 36 | self.values = [] 37 | self.counter = 0 38 | 39 | def append(self, val): 40 | self.values.append(val) 41 | self.counter += 1 42 | 43 | def extend(self, items): 44 | self.values.extend(items) 45 | self.counter += len(items) 46 | 47 | @property 48 | def val(self): 49 | return self.values[-1] 50 | 51 | @property 52 | def avg(self): 53 | return sum(self.values) / len(self.values) 54 | 55 | def __len__(self): 56 | return len(self.values) 57 | 58 | def __repr__(self): 59 | values = self.values 60 | if len(values) > 0: 61 | return ','.join([f" {metric}: {eval(f'np.{metric}')(values)}" 62 | for metric in ['mean', 'std', 'min', 'max']]) 63 | else: 64 | return 'empy meter' 65 | 66 | @property 67 | def last_avg(self): 68 | if self.counter == 0: 69 | return self.latest_avg 70 | else: 71 | self.latest_avg = sum(self.values[-self.counter:]) / self.counter 72 | self.counter = 0 73 | return self.latest_avg 74 | 75 | 76 | class LocalMaskCrossEntropyLoss(nn.CrossEntropyLoss): 77 | """Should be used for class-wise non-iid. 78 | Refer to HeteroFL (https://openreview.net/forum?id=TNkPBBYFkXg) 79 | """ 80 | def __init__(self, num_classes, **kwargs): 81 | super(LocalMaskCrossEntropyLoss, self).__init__(**kwargs) 82 | self.num_classes = num_classes 83 | 84 | def forward(self, input, target): 85 | classes = torch.unique(target) 86 | mask = torch.zeros_like(input) 87 | for c in range(self.num_classes): 88 | if c in classes: 89 | mask[:, c] = 1 # select included classes 90 | return F.cross_entropy(input*mask, target, weight=self.weight, 91 | ignore_index=self.ignore_index, reduction=self.reduction) 92 | 93 | 94 | # ///////////// samplers ///////////// 95 | class _Sampler(object): 96 | def __init__(self, arr): 97 | self.arr = copy.deepcopy(arr) 98 | 99 | def next(self): 100 | raise NotImplementedError() 101 | 102 | 103 | class shuffle_sampler(_Sampler): 104 | def __init__(self, arr, rng=None): 105 | super().__init__(arr) 106 | if rng is None: 107 | rng = np.random 108 | rng.shuffle(self.arr) 109 | self._idx = 0 110 | self._max_idx = len(self.arr) 111 | 112 | def next(self): 113 | if self._idx >= self._max_idx: 114 | np.random.shuffle(self.arr) 115 | self._idx = 0 116 | v = self.arr[self._idx] 117 | self._idx += 1 118 | return v 119 | 120 | 121 | class random_sampler(_Sampler): 122 | def next(self): 123 | # np.random.randint(0, int(1 / slim_ratios[0])) 124 | v = np.random.choice(self.arr) # single value. If multiple value, note the replace param. 125 | return v 126 | 127 | 128 | class constant_sampler(_Sampler): 129 | def __init__(self, value): 130 | super().__init__([]) 131 | self.value = value 132 | 133 | def next(self): 134 | return self.value 135 | 136 | 137 | # ///////////// lr schedulers ///////////// 138 | class CosineAnnealingLR(object): 139 | r"""Set the learning rate of each parameter group using a cosine annealing 140 | schedule, where :math:`\eta_{max}` is set to the initial lr and 141 | :math:`T_{cur}` is the number of epochs since the last restart in SGDR: 142 | 143 | .. math:: 144 | \begin{aligned} 145 | \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 146 | + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), 147 | & T_{cur} \neq (2k+1)T_{max}; \\ 148 | \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) 149 | \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), 150 | & T_{cur} = (2k+1)T_{max}. 151 | \end{aligned} 152 | 153 | When last_epoch=-1, sets initial lr as lr. Notice that because the schedule 154 | is defined recursively, the learning rate can be simultaneously modified 155 | outside this scheduler by other operators. If the learning rate is set 156 | solely by this scheduler, the learning rate at each step becomes: 157 | 158 | .. math:: 159 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + 160 | \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) 161 | 162 | It has been proposed in 163 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only 164 | implements the cosine annealing part of SGDR, and not the restarts. 165 | 166 | Args: 167 | optimizer (Optimizer): Wrapped optimizer. 168 | T_max (int): Maximum number of iterations. 169 | eta_min (float): Minimum learning rate. Default: 0. 170 | last_epoch (int): The index of last epoch. Default: -1. 171 | verbose (bool): If ``True``, prints a message to stdout for 172 | each update. Default: ``False``. 173 | 174 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 175 | https://arxiv.org/abs/1608.03983 176 | """ 177 | 178 | def __init__(self, T_max, eta_max=1e-2, eta_min=0, last_epoch=0, warmup=None): 179 | self.T_max = T_max 180 | self.eta_max = eta_max 181 | self.eta_min = eta_min 182 | self.last_epoch = last_epoch 183 | self._cur_lr = eta_max 184 | self._eta_max = eta_max 185 | # super(CosineAnnealingLR, self).__init__(optimizer, last_epoch, verbose) 186 | self.warmup = warmup 187 | 188 | def step(self): 189 | self._cur_lr = self._get_lr() 190 | self.last_epoch += 1 191 | return self._cur_lr 192 | 193 | def _get_lr(self): 194 | if self.warmup is not None and self.warmup > 0: 195 | if self.last_epoch < self.warmup: 196 | return self._eta_max * ((self.last_epoch+1e-2) / self.warmup) 197 | elif self.last_epoch == self.warmup: 198 | return self._eta_max 199 | if self.last_epoch == 0: 200 | return self.eta_max 201 | elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: 202 | return self._cur_lr + (self.eta_max - self.eta_min) * \ 203 | (1 - math.cos(math.pi / self.T_max)) / 2 204 | return (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / \ 205 | (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * \ 206 | (self._cur_lr - self.eta_min) + self.eta_min 207 | 208 | 209 | class MultiStepLR(object): 210 | def __init__(self, eta_max, milestones, gamma=0.1, last_epoch=-1, warmup=None): 211 | self.milestones = Counter(milestones) 212 | self.gamma = gamma 213 | self.last_epoch = last_epoch 214 | self._cur_lr = eta_max 215 | self._eta_max = eta_max 216 | # super(MultiStepLR, self).__init__(optimizer, last_epoch, verbose) 217 | self.warmup = warmup 218 | 219 | def step(self): 220 | self._cur_lr = self._get_lr() 221 | self.last_epoch += 1 222 | return self._cur_lr 223 | 224 | def _get_lr(self): 225 | if self.warmup is not None and self.warmup > 0: 226 | if self.last_epoch < self.warmup: 227 | return self._eta_max * ((self.last_epoch+1e-3) / self.warmup) 228 | elif self.last_epoch == self.warmup: 229 | return self._eta_max 230 | if self.last_epoch not in self.milestones: 231 | return self._cur_lr 232 | return self._cur_lr * self.gamma ** self.milestones[self.last_epoch] 233 | 234 | 235 | def test_lr_sch(sch_name='cos'): 236 | lr_init = 0.1 237 | T = 150 238 | if sch_name == 'cos': 239 | sch = CosineAnnealingLR(T, lr_init, last_epoch=0, warmup=5) 240 | elif sch_name == 'multi_step': 241 | sch = MultiStepLR(lr_init, [50, 100], last_epoch=0, warmup=5) 242 | 243 | for step in range(150): 244 | lr = sch.step() 245 | if step % 20 == 0 or step < 20: 246 | print(f"[{step:3d}] lr={lr:.4f}") 247 | 248 | # resume 249 | print(f"Resume from step{step} with lr={lr:.4f}") 250 | T = 300 251 | if sch_name == 'cos': 252 | sch = CosineAnnealingLR(T, lr_init, last_epoch=step) 253 | elif sch_name == 'multi_step': 254 | sch = MultiStepLR(lr_init, [2, 4, 4, 50], last_epoch=step) 255 | for step in range(step, step+10): 256 | lr = sch.step() 257 | print(f"[{step:3d}] lr={lr:.4f}") 258 | 259 | 260 | if __name__ == '__main__': 261 | test_lr_sch('cos') 262 | -------------------------------------------------------------------------------- /federated/learning.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numpy as np 4 | import torch 5 | import copy 6 | from torch import optim 7 | #from advertorch.context import ctx_noparamgrad_and_eval 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | from tqdm import tqdm 10 | 11 | from utils.utils import AverageMeter 12 | from nets.dual_bn import set_bn_mode 13 | 14 | 15 | def if_use_dbn(model): 16 | if isinstance(model, DDP): 17 | return model.module.bn_type.startswith('d') 18 | else: 19 | return model.bn_type.startswith('d') 20 | 21 | 22 | 23 | def train_fedprox(mu, model, data_loader, optimizer, loss_fun, device, start_iter=0, max_iter=np.inf, progress=True): 24 | 25 | model.train() 26 | serverModel = copy.deepcopy(model) 27 | 28 | loss_all = 0 29 | total = 0 30 | correct = 0 31 | max_iter = len(data_loader) if max_iter == np.inf else max_iter 32 | data_iterator = iter(data_loader) 33 | tqdm_iters = tqdm(range(start_iter, max_iter), file=sys.stdout) \ 34 | if progress else range(start_iter, max_iter) 35 | 36 | # ordinary training. 37 | set_bn_mode(model, False) # set clean mode 38 | for step in tqdm_iters: 39 | # for data, target in tqdm(data_loader, file=sys.stdout): 40 | try: 41 | data, target = next(data_iterator) 42 | except StopIteration: 43 | data_iterator = iter(data_loader) 44 | data, target = next(data_iterator) 45 | optimizer.zero_grad() 46 | model.zero_grad() 47 | 48 | data = data.to(device) 49 | target = target.to(device) 50 | output = model(data) 51 | loss = loss_fun(output, target.long()) 52 | 53 | ##################### fedProx Implementation ##################### 54 | w_diff = torch.tensor(0., device=device) 55 | for w, w_t in zip(serverModel.parameters(), model.parameters()): 56 | w_diff += torch.pow(torch.norm(w - w_t), 2) 57 | loss += mu / 2. * w_diff 58 | ################################################################## 59 | 60 | loss_all += loss.item() * target.size(0) 61 | total += target.size(0) 62 | pred = output.data.max(1)[1] 63 | correct += pred.eq(target.view(-1)).sum().item() 64 | 65 | loss.backward() 66 | optimizer.step() 67 | return loss_all / total, correct / total 68 | 69 | 70 | def train(model, data_loader, optimizer, loss_fun, device, start_iter=0, max_iter=np.inf, progress=True): 71 | 72 | model.train() 73 | loss_all = 0 74 | total = 0 75 | correct = 0 76 | max_iter = len(data_loader) if max_iter == np.inf else max_iter 77 | data_iterator = iter(data_loader) 78 | tqdm_iters = tqdm(range(start_iter, max_iter), file=sys.stdout) \ 79 | if progress else range(start_iter, max_iter) 80 | 81 | # ordinary training. 82 | set_bn_mode(model, False) # set clean mode 83 | for step in tqdm_iters: 84 | # for data, target in tqdm(data_loader, file=sys.stdout): 85 | try: 86 | data, target = next(data_iterator) 87 | except StopIteration: 88 | data_iterator = iter(data_loader) 89 | data, target = next(data_iterator) 90 | optimizer.zero_grad() 91 | model.zero_grad() 92 | 93 | data = data.to(device) 94 | target = target.to(device) 95 | output = model(data) 96 | loss = loss_fun(output, target.long()) 97 | 98 | loss_all += loss.item() * target.size(0) 99 | total += target.size(0) 100 | pred = output.data.max(1)[1] 101 | correct += pred.eq(target.view(-1)).sum().item() 102 | 103 | loss.backward() 104 | optimizer.step() 105 | 106 | return loss_all / total, correct / total 107 | 108 | 109 | def train_slimmable(model, data_loader, optimizer, loss_fun, device, 110 | start_iter=0, max_iter=np.inf, 111 | slim_ratios=[0.5, 0.75, 1.0], slim_shifts=0, out_slim_shifts=None, 112 | progress=True, loss_temp='none'): 113 | """If slim_ratios is a single value, use `train` and set slim_ratio outside, instead. 114 | """ 115 | # expand scalar slim_shift to list 116 | if not isinstance(slim_shifts, (list, tuple)): 117 | slim_shifts = [slim_shifts for _ in range(len(slim_ratios))] 118 | if not isinstance(out_slim_shifts, (list, tuple)): 119 | out_slim_shifts = [out_slim_shifts for _ in range(len(slim_ratios))] 120 | 121 | model.train() 122 | total, correct, loss_all = 0, 0, 0 123 | max_iter = len(data_loader) if max_iter == np.inf else max_iter 124 | data_iterator = iter(data_loader) 125 | 126 | # ordinary training. 127 | set_bn_mode(model, False) # set clean mode 128 | for step in tqdm(range(start_iter, max_iter), file=sys.stdout, disable=not progress): 129 | # for data, target in tqdm(data_loader, file=sys.stdout): 130 | try: 131 | data, target = next(data_iterator) 132 | except StopIteration: 133 | data_iterator = iter(data_loader) 134 | data, target = next(data_iterator) 135 | optimizer.zero_grad() 136 | model.zero_grad() 137 | 138 | data = data.to(device) 139 | target = target.to(device) 140 | 141 | 142 | for slim_ratio, in_slim_shift, out_slim_shift \ 143 | in sorted(zip(slim_ratios, slim_shifts, out_slim_shifts), reverse=False, 144 | key=lambda ss_pair: ss_pair[0]): 145 | model.switch_slim_mode(slim_ratio, slim_bias_idx=in_slim_shift, out_slim_bias_idx=out_slim_shift) 146 | 147 | output = model(data) 148 | if loss_temp == 'none': 149 | _loss = loss_fun(output, target.long()) 150 | elif loss_temp == 'auto': 151 | _loss = loss_fun(output/slim_ratio, target) * slim_ratio 152 | elif loss_temp.replace('.', '', 1).isdigit(): # is float 153 | _temp = float(loss_temp) 154 | _loss = loss_fun(output / _temp, target) * _temp 155 | else: 156 | raise NotImplementedError(f"loss_temp: {loss_temp}") 157 | 158 | loss_all += _loss.item() * target.size(0) 159 | total += target.size(0) 160 | pred = output.data.max(1)[1] 161 | correct += pred.eq(target.view(-1)).sum().item() 162 | 163 | _loss.backward() 164 | optimizer.step() 165 | 166 | return loss_all / total, correct / total 167 | 168 | 169 | # =========== Test =========== 170 | 171 | 172 | def personalization(model, data_loader_train, data_loader_test, loss_fun, global_lr, device, progress=False): 173 | 174 | 175 | model.train() 176 | 177 | optimizer = optim.SGD(params=model.parameters(), lr=global_lr, 178 | momentum=0.9, weight_decay=5e-4) 179 | 180 | loss_all, total, correct = 0, 0, 0 181 | for iter in range(5): 182 | for data, target in tqdm(data_loader_train, file=sys.stdout, disable=not progress): 183 | data, target = data.to(device), target.to(device) 184 | 185 | #with torch.no_grad(): 186 | optimizer.zero_grad() 187 | model.zero_grad() 188 | output = model(data) 189 | loss = loss_fun(output, target.long()) 190 | 191 | loss_all += loss.item() 192 | total += target.size(0) 193 | pred = output.data.max(1)[1] 194 | correct += pred.eq(target.view(-1)).sum().item() 195 | 196 | loss.backward() 197 | optimizer.step() 198 | 199 | 200 | val_loss, val_acc = test(model, data_loader_test, loss_fun, device) 201 | 202 | return val_loss, val_acc 203 | 204 | def personalization_slimmable(model, data_loader_train, data_loader_test, loss_fun, global_lr, device, progress=False): 205 | 206 | model.train() 207 | 208 | optimizer = optim.SGD(params=model.parameters(), lr=global_lr, 209 | momentum=0.9, weight_decay=5e-4) 210 | 211 | 212 | atom_slim_ratio = 0.125 213 | user_n_base = int(1.0 / atom_slim_ratio) 214 | slim_ratios = [atom_slim_ratio] * user_n_base 215 | slim_shifts = [ii for ii in range(user_n_base)] 216 | out_slim_shifts = [None for _ in range(len(slim_ratios))] 217 | 218 | set_bn_mode(model, False) 219 | for iter in range(5): 220 | for data, target in tqdm(data_loader_train, file=sys.stdout, disable=not progress): 221 | data, target = data.to(device), target.to(device) 222 | 223 | 224 | 225 | optimizer.zero_grad() 226 | model.zero_grad() 227 | 228 | 229 | for slim_ratio, in_slim_shift, out_slim_shift \ 230 | in sorted(zip(slim_ratios, slim_shifts, out_slim_shifts), reverse=False, 231 | key=lambda ss_pair: ss_pair[0]): 232 | model.switch_slim_mode(slim_ratio, slim_bias_idx=in_slim_shift, out_slim_bias_idx=out_slim_shift) 233 | 234 | output = model(data) 235 | loss = loss_fun(output, target.long()) 236 | 237 | 238 | 239 | loss.backward() 240 | optimizer.step() 241 | 242 | model.switch_slim_mode(1.0) 243 | val_loss, val_acc = test(model, data_loader_test, loss_fun, device) 244 | 245 | 246 | return val_loss, val_acc 247 | 248 | 249 | def test(model, data_loader, loss_fun, device, progress=False): 250 | 251 | model.eval() 252 | 253 | 254 | loss_all, total, correct = 0, 0, 0 255 | for data, target in tqdm(data_loader, file=sys.stdout, disable=not progress): 256 | data, target = data.to(device), target.to(device) 257 | 258 | 259 | with torch.no_grad(): 260 | output = model(data) 261 | loss = loss_fun(output, target.long()) 262 | 263 | loss_all += loss.item() 264 | total += target.size(0) 265 | pred = output.data.max(1)[1] 266 | correct += pred.eq(target.view(-1)).sum().item() 267 | return loss_all / len(data_loader), correct/total 268 | 269 | 270 | def refresh_bn(model, data_loader, device, progress=False): 271 | model.train() 272 | for data, target in tqdm(data_loader, file=sys.stdout, disable=not progress): 273 | data, target = data.to(device), target.to(device) 274 | 275 | 276 | with torch.no_grad(): 277 | model(data) 278 | 279 | 280 | def fed_test_model(fed, running_model, test_loaders, loss_fun, device): 281 | test_acc_mt = AverageMeter() 282 | for test_idx, test_loader in enumerate(test_loaders): 283 | fed.download(running_model, test_idx) 284 | _, test_acc = test(running_model, test_loader, loss_fun, device) 285 | # print(' {:<11s}| Test Acc: {:.4f}'.format(fed.clients[test_idx], test_acc)) 286 | 287 | # wandb.summary[f'{fed.clients[test_idx]} test acc'] = test_acc 288 | test_acc_mt.append(test_acc) 289 | return test_acc_mt.avg 290 | 291 | -------------------------------------------------------------------------------- /nets/HeteFL/slimmable_preresne.py: -------------------------------------------------------------------------------- 1 | """Ref to HeteroFL pre-activated ResNet18""" 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.modules.batchnorm import _BatchNorm 7 | from torch.nn.modules.instancenorm import _InstanceNorm 8 | # from ..bn_ops import get_bn_layer 9 | from ..slimmable_models import BaseModule, SlimmableMixin 10 | from ..slimmable_ops import SlimmableConv2d, SlimmableBatchNorm2d, SlimmableLinear 11 | 12 | 13 | hidden_size = [64, 128, 256, 512] 14 | 15 | 16 | class Block(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, in_planes, planes, stride, norm_layer, conv_layer, fix_out=False, 20 | fix_in=False): 21 | super(Block, self).__init__() 22 | # self.norm_layer = norm_layer 23 | if fix_in: 24 | self.n1 = norm_layer(in_planes, non_slimmable=True) 25 | self.conv1 = conv_layer(in_planes, planes, kernel_size=3, stride=stride, padding=1, 26 | bias=False, non_slimmable_in=True) 27 | else: 28 | self.n1 = norm_layer(in_planes) 29 | self.conv1 = conv_layer(in_planes, planes, kernel_size=3, stride=stride, padding=1, 30 | bias=False) 31 | self.n2 = norm_layer(planes) 32 | if fix_out: 33 | self.conv2 = conv_layer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False, 34 | non_slimmable_out=fix_out) 35 | else: 36 | self.conv2 = conv_layer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 37 | 38 | if stride != 1 or in_planes != self.expansion * planes: 39 | self.shortcut = conv_layer(in_planes, self.expansion * planes, kernel_size=1, 40 | stride=stride, bias=False) 41 | elif fix_out: 42 | self.shortcut = conv_layer(in_planes, self.expansion * planes, kernel_size=1, 43 | stride=stride, bias=False, non_slimmable_out=fix_out) 44 | elif fix_in: 45 | self.shortcut = conv_layer(in_planes, self.expansion * planes, kernel_size=1, 46 | stride=stride, bias=False, non_slimmable_in=fix_in) 47 | 48 | def forward(self, x): 49 | out = F.relu(self.n1(x)) 50 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 51 | out = self.conv1(out) 52 | out = self.conv2(F.relu(self.n2(out))) 53 | out += shortcut 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, in_planes, planes, stride, norm_layer, conv_layer, fix_out=False, 61 | fix_in=False): 62 | super(Bottleneck, self).__init__() 63 | assert not fix_out 64 | assert not fix_in 65 | # self.norm_layer = norm_layer 66 | self.n1 = norm_layer(in_planes) 67 | self.conv1 = conv_layer(in_planes, planes, kernel_size=1, bias=False) 68 | self.n2 = norm_layer(planes) 69 | self.conv2 = conv_layer(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 70 | self.n3 = norm_layer(planes) 71 | self.conv3 = conv_layer(planes, self.expansion * planes, kernel_size=1, bias=False, 72 | non_slimmable_out=fix_out) 73 | 74 | if stride != 1 or in_planes != self.expansion * planes: 75 | self.shortcut = conv_layer(in_planes, self.expansion * planes, kernel_size=1, 76 | stride=stride, bias=False, non_slimmable_out=fix_out) 77 | 78 | def forward(self, x): 79 | out = F.relu(self.n1(x)) 80 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 81 | out = self.conv1(out) 82 | out = self.conv2(F.relu(self.n2(out))) 83 | out = self.conv3(F.relu(self.n3(out))) 84 | out += shortcut 85 | return out 86 | 87 | 88 | class ResNet(BaseModule, SlimmableMixin): 89 | input_shape = [None, 3, 32, 32] 90 | 91 | def __init__(self, hidden_size, block, num_blocks, num_classes=10, bn_type='bn', 92 | track_running_stats=True, width_scale=1., share_affine=False, slimmabe_ratios=None): 93 | super(ResNet, self).__init__() 94 | self._set_slimmabe_ratios(slimmabe_ratios) 95 | 96 | if width_scale != 1.: 97 | hidden_size = [int(hs * width_scale) for hs in hidden_size] 98 | if bn_type.startswith('d'): 99 | print("WARNING: When using dual BN, you should not do slimming.") 100 | if track_running_stats: 101 | print("WARNING: We cannot track running_stats when slimmable BN is used.") 102 | self.bn_type = bn_type 103 | if bn_type == 'bn': 104 | norm_layer = lambda n_ch, **kwargs: SlimmableBatchNorm2d( 105 | n_ch, track_running_stats=track_running_stats, **kwargs) 106 | elif bn_type == 'dbn': 107 | from ..dual_bn import DualNormLayer 108 | assert not share_affine, "We don't recommend to share affine." 109 | norm_layer = lambda n_ch: DualNormLayer( 110 | n_ch, track_running_stats=track_running_stats, affine=True, 111 | bn_class=SlimmableBatchNorm2d, share_affine=share_affine) 112 | else: 113 | raise RuntimeError(f"Not support bn_type={bn_type}") 114 | conv_layer = SlimmableConv2d 115 | 116 | self.in_planes = hidden_size[0] 117 | self.conv1 = SlimmableConv2d(3, hidden_size[0], kernel_size=3, stride=1, padding=1, 118 | bias=False, non_slimmable_in=True) 119 | self.layer1 = self._make_layer(block, hidden_size[0], num_blocks[0], stride=1, 120 | norm_layer=norm_layer, conv_layer=conv_layer) 121 | self.layer2 = self._make_layer(block, hidden_size[1], num_blocks[1], stride=2, 122 | norm_layer=norm_layer, conv_layer=conv_layer) 123 | self.layer3 = self._make_layer(block, hidden_size[2], num_blocks[2], stride=2, 124 | norm_layer=norm_layer, conv_layer=conv_layer) 125 | self.layer4 = self._make_layer(block, hidden_size[3], num_blocks[3], stride=2, 126 | norm_layer=norm_layer, conv_layer=conv_layer) 127 | self.n4 = norm_layer(hidden_size[3] * block.expansion) 128 | self.linear = SlimmableLinear(hidden_size[3] * block.expansion, num_classes, 129 | non_slimmable_out=True) 130 | 131 | def _make_layer(self, block, planes, num_blocks, stride, norm_layer, conv_layer, 132 | fix_out=False, fix_in=False): 133 | strides = [stride] + [1] * (num_blocks - 1) 134 | layers = [] 135 | for i_layer, stride in enumerate(strides): 136 | layers.append( 137 | block(self.in_planes, planes, stride, norm_layer, conv_layer, 138 | fix_out=False if (not fix_out) or (i_layer < num_blocks - 1) else fix_out, 139 | fix_in=False if (not fix_in) or (i_layer > 0) else fix_in)) 140 | self.in_planes = planes * block.expansion 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x, return_pre_clf_fea=False): 144 | out = self.conv1(x) 145 | out = self.layer1(out) 146 | out = self.layer2(out) 147 | out = self.layer3(out) 148 | out = self.layer4(out) 149 | out = F.relu(self.n4(out)) 150 | out = F.adaptive_avg_pool2d(out, 1) 151 | out = out.view(out.size(0), -1) 152 | logits = self.linear(out) 153 | if return_pre_clf_fea: 154 | return logits, out 155 | else: 156 | return logits 157 | 158 | def print_footprint(self): 159 | input_shape = self.input_shape 160 | input_shape[0] = 2 161 | x = torch.rand(input_shape) 162 | batch = x.shape[0] 163 | print(f"input: {np.prod(x.shape[1:])} <= {x.shape[1:]}") 164 | x = self.conv1(x) 165 | print(f"conv1: {np.prod(x.shape[1:])} <= {x.shape[1:]}") 166 | for i_layer, layer in enumerate([self.layer1, self.layer2, self.layer3, self.layer4]): 167 | x = layer(x) 168 | print(f"layer {i_layer}: {np.prod(x.shape[1:]):5d} <= {x.shape[1:]}") 169 | 170 | def init_param(m): 171 | if isinstance(m, (_BatchNorm, _InstanceNorm)): 172 | m.weight.data.fill_(1) 173 | m.bias.data.zero_() 174 | elif isinstance(m, nn.Linear): 175 | m.bias.data.zero_() 176 | return m 177 | 178 | 179 | # Instantiations 180 | def resnet18(**kwargs): 181 | model = ResNet(hidden_size, Block, [2, 2, 2, 2], **kwargs) 182 | model.apply(init_param) 183 | return model 184 | 185 | 186 | def resnet26(**kwargs): 187 | model = ResNet(hidden_size, Block, [3, 3, 3, 3], **kwargs) 188 | model.apply(init_param) 189 | return model 190 | 191 | 192 | def resnet34(**kwargs): 193 | model = ResNet(hidden_size, Block, [3, 4, 6, 3], **kwargs) 194 | model.apply(init_param) 195 | return model 196 | 197 | 198 | def resnet50(**kwargs): 199 | model = ResNet(hidden_size, Bottleneck, [3, 4, 6, 3], **kwargs) 200 | model.apply(init_param) 201 | return model 202 | 203 | 204 | def resnet101(**kwargs): 205 | model = ResNet(hidden_size, Bottleneck, [3, 4, 23, 3], **kwargs) 206 | model.apply(init_param) 207 | return model 208 | 209 | 210 | def resnet152(**kwargs): 211 | model = ResNet(hidden_size, Bottleneck, [3, 8, 36, 3], **kwargs) 212 | model.apply(init_param) 213 | return model 214 | 215 | 216 | def main(): 217 | # check_depths() 218 | check_widths() 219 | 220 | 221 | def check_depths(): 222 | from nets.profile_func import profile_slimmable_models 223 | print(f"profile model GFLOPs (forward complexity) and size (#param)") 224 | 225 | for resnet in [resnet18, resnet34, resnet50]: 226 | model = resnet(track_running_stats=False, bn_type='bn') 227 | model.eval() # this will affect bn etc 228 | 229 | print(f"\nmodel {resnet.__name__} on {'training' if model.training else 'eval'} mode") 230 | profile_slimmable_models(model, model.slimmable_ratios) 231 | 232 | def check_widths(): 233 | from nets.profile_func import profile_slimmable_models 234 | from nets.slimmable_models import EnsembleSubnet, EnsembleGroupSubnet 235 | 236 | print(f"profile model GFLOPs (forward complexity) and size (#param)") 237 | 238 | model = resnet18(track_running_stats=False, bn_type='bn') 239 | model.eval() # this will affect bn etc 240 | 241 | print(f"model {model.__class__.__name__} on {'training' if model.training else 'eval'} mode") 242 | input_shape = model.input_shape 243 | # batch_size = 2 244 | # input_shape[0] = batch_size 245 | profile_slimmable_models(model, model.slimmable_ratios) 246 | print(f"\n==footprint==") 247 | model.switch_slim_mode(1.) 248 | model.print_footprint() 249 | print(f"\n==footprint==") 250 | model.switch_slim_mode(0.125) 251 | model.print_footprint() 252 | 253 | print(f'\n--------------') 254 | full_net = model 255 | model = EnsembleGroupSubnet(full_net, [0.125, 0.125, 0.25, 0.5], [0, 1, 1, 1]) 256 | model.eval() 257 | print(f"model {model.__class__.__name__} on {'training' if model.training else 'eval'} mode") 258 | profile_slimmable_models(model, model.full_net.slimmable_ratios) 259 | 260 | print(f'\n--------------') 261 | model = EnsembleSubnet(full_net, 0.125) 262 | model.eval() 263 | print(f"model {model.__class__.__name__} on {'training' if model.training else 'eval'} mode") 264 | profile_slimmable_models(model, model.full_net.slimmable_ratios) 265 | 266 | 267 | if __name__ == '__main__': 268 | main() 269 | 270 | -------------------------------------------------------------------------------- /nets/models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from collections import OrderedDict 4 | from typing import Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as func 9 | from torch.nn.modules.conv import _ConvNd 10 | 11 | from .bn_ops import get_bn_layer 12 | from .dual_bn import DualNormLayer 13 | 14 | 15 | class BaseModule(nn.Module): 16 | def set_bn_mode(self, is_noised: Union[bool, torch.Tensor]): 17 | """Set BN mode to be noised or clean. This is only effective for StackedNormLayer 18 | or DualNormLayer.""" 19 | def set_bn_eval_(m): 20 | if isinstance(m, (DualNormLayer,)): 21 | if isinstance(is_noised, (float, int)): 22 | m.clean_input = 1. - is_noised 23 | elif isinstance(is_noised, torch.Tensor): 24 | m.clean_input = ~is_noised 25 | else: 26 | m.clean_input = not is_noised 27 | self.apply(set_bn_eval_) 28 | 29 | # forward 30 | def forward(self, x): 31 | z = self.encode(x) 32 | logits = self.decode_clf(z) 33 | return logits 34 | 35 | def encode(self, x): 36 | x = self.features(x) 37 | x = self.avgpool(x) 38 | z = torch.flatten(x, 1) 39 | return z 40 | 41 | def decode_clf(self, z): 42 | logits = self.classifier(z) 43 | return logits 44 | 45 | def mix_dual_forward(self, x, lmbd, deep_mix=False): 46 | if deep_mix: 47 | self.set_bn_mode(lmbd) 48 | logit = self.forward(x) 49 | else: 50 | # FIXME this will result in unexpected result for non-dual models? 51 | logit = 0 52 | if lmbd < 1: 53 | self.set_bn_mode(False) 54 | logit = logit + (1 - lmbd) * self.forward(x) 55 | 56 | if lmbd > 0: 57 | self.set_bn_mode(True) 58 | logit = logit + lmbd * self.forward(x) 59 | return logit 60 | 61 | 62 | def kaiming_uniform_in_(tensor, a=0, mode='fan_in', scale=1., nonlinearity='leaky_relu'): 63 | """Modified from torch.nn.init.kaiming_uniform_""" 64 | fan_in = nn.init._calculate_correct_fan(tensor, mode) 65 | fan_in *= scale 66 | gain = nn.init.calculate_gain(nonlinearity, a) 67 | std = gain / math.sqrt(fan_in) 68 | bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation 69 | with torch.no_grad(): 70 | return tensor.uniform_(-bound, bound) 71 | 72 | def scale_init_param(m, scale_in=1.): 73 | """Scale w.r.t. input dim.""" 74 | if isinstance(m, (nn.Linear, _ConvNd)): 75 | kaiming_uniform_in_(m.weight, a=math.sqrt(5), scale=scale_in, mode='fan_in') 76 | if m.bias is not None: 77 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(m.weight) 78 | fan_in *= scale_in 79 | bound = 1 / math.sqrt(fan_in) 80 | nn.init.uniform_(m.bias, -bound, bound) 81 | return m 82 | 83 | 84 | class Scaler(nn.Module): 85 | def __init__(self, width_scale): 86 | super(Scaler, self).__init__() 87 | self.width_scale = width_scale 88 | 89 | def forward(self, x): 90 | return x / self.width_scale if self.training else x 91 | 92 | 93 | class ScalableModule(BaseModule): 94 | def __init__(self, width_scale=1., rescale_init=False, rescale_layer=False): 95 | super(ScalableModule, self).__init__() 96 | if rescale_layer: 97 | self.scaler = Scaler(width_scale) 98 | else: 99 | self.scaler = nn.Identity() 100 | self.rescale_init = rescale_init 101 | self.width_scale = width_scale 102 | 103 | def reset_parameters(self, inp_nonscale_layers): 104 | if self.rescale_init and self.width_scale != 1.: 105 | for name, m in self._modules.items(): 106 | if name not in inp_nonscale_layers: # NOTE ignore the layer with non-slimmable inp. 107 | m.apply(lambda _m: scale_init_param(_m, scale_in=1./self.width_scale)) 108 | 109 | @property 110 | def rescale_layer(self): 111 | return not isinstance(self.scaler, nn.Identity) 112 | 113 | @rescale_layer.setter 114 | def rescale_layer(self, enable=True): 115 | if enable: 116 | self.scaler = Scaler(self.width_scale) 117 | else: 118 | self.scaler = nn.Identity() 119 | 120 | 121 | class DigitModel(ScalableModule): 122 | """ 123 | Model for benchmark experiment on Digits. 124 | """ 125 | input_shape = [None, 3, 28, 28] 126 | 127 | def __init__(self, num_classes=10, bn_type='bn', track_running_stats=True, 128 | width_scale=1., share_affine=True, rescale_init=False, rescale_layer=False): 129 | super(DigitModel, self).__init__(width_scale=width_scale, rescale_init=rescale_init, 130 | rescale_layer=rescale_layer) 131 | bn_class = get_bn_layer(bn_type) 132 | bn_kwargs = dict( 133 | track_running_stats=track_running_stats, 134 | ) 135 | if bn_type.startswith('d'): # dual BN 136 | bn_kwargs['share_affine'] = share_affine 137 | conv_layers = [64, 64, 128] 138 | fc_layers = [2048, 512] 139 | conv_layers = [int(width_scale*l) for l in conv_layers] 140 | fc_layers = [int(width_scale*l) for l in fc_layers] 141 | self.bn_type = bn_type 142 | 143 | self.conv1 = nn.Conv2d(3, conv_layers[0], 5, 1, 2) 144 | self.bn1 = bn_class['2d'](conv_layers[0], **bn_kwargs) 145 | 146 | self.conv2 = nn.Conv2d(conv_layers[0], conv_layers[1], 5, 1, 2) 147 | self.bn2 = bn_class['2d'](conv_layers[1], **bn_kwargs) 148 | 149 | self.conv3 = nn.Conv2d(conv_layers[1], conv_layers[2], 5, 1, 2) 150 | self.bn3 = bn_class['2d'](conv_layers[2], **bn_kwargs) 151 | 152 | self.fc1 = nn.Linear(conv_layers[2]*7*7, fc_layers[0]) 153 | self.bn4 = bn_class['1d'](fc_layers[0], **bn_kwargs) 154 | 155 | self.fc2 = nn.Linear(fc_layers[0], fc_layers[1]) 156 | self.bn5 = bn_class['1d'](fc_layers[1], **bn_kwargs) 157 | 158 | self.fc3 = nn.Linear(fc_layers[1], num_classes) 159 | 160 | self.reset_parameters(inp_nonscale_layers=['conv1']) 161 | 162 | def forward(self, x): 163 | z = self.encode(x) 164 | return self.decode_clf(z) 165 | 166 | def encode(self, x): 167 | x = func.relu(self.bn1(self.scaler(self.conv1(x)))) 168 | x = func.max_pool2d(x, 2) 169 | 170 | x = func.relu(self.bn2(self.scaler(self.conv2(x)))) 171 | x = func.max_pool2d(x, 2) 172 | 173 | x = func.relu(self.bn3(self.scaler(self.conv3(x)))) 174 | 175 | x = x.view(x.shape[0], -1) 176 | return x 177 | 178 | def decode_clf(self, x): 179 | x = self.scaler(self.fc1(x)) 180 | x = self.bn4(x) 181 | x = func.relu(x) 182 | 183 | x = self.scaler(self.fc2(x)) 184 | x = self.bn5(x) 185 | x = func.relu(x) 186 | 187 | logits = self.fc3(x) 188 | return logits 189 | 190 | 191 | class AlexNet(ScalableModule): 192 | """ 193 | used for DomainNet and Office-Caltech10 194 | """ 195 | input_shape = [None, 3, 256, 256] 196 | 197 | def load_state_dict(self, state_dict, strict: bool = True): 198 | legacy_keys = [] 199 | for key in state_dict: 200 | if 'noise_disc' in key: 201 | legacy_keys.append(key) 202 | if len(legacy_keys) > 0: 203 | logging.debug(f"Found old version of AlexNet. Ignore {len(legacy_keys)} legacy" 204 | f" keys: {legacy_keys}") 205 | for key in legacy_keys: 206 | state_dict.pop(key) 207 | return super().load_state_dict(state_dict, strict) 208 | 209 | def __init__(self, num_classes=10, track_running_stats=True, bn_type='bn', share_affine=True, 210 | width_scale=1., rescale_init=False, rescale_layer=False): 211 | super(AlexNet, self).__init__(width_scale=width_scale, rescale_init=rescale_init, 212 | rescale_layer=rescale_layer) 213 | self.bn_type = bn_type 214 | bn_class = get_bn_layer(bn_type) 215 | # share_affine 216 | bn_kwargs = dict( 217 | track_running_stats=track_running_stats, 218 | ) 219 | if bn_type.startswith('d'): # dual BN 220 | bn_kwargs['share_affine'] = share_affine 221 | plus_layer_i = 0 222 | feature_layers = [] 223 | feature_layers += [ 224 | ('conv1', nn.Conv2d(3, int(width_scale*64), kernel_size=11, stride=4, padding=2)), 225 | ('scaler1', self.scaler), 226 | ('bn1', bn_class['2d'](int(width_scale*64), **bn_kwargs)), 227 | ('relu1', nn.ReLU(inplace=True)), 228 | ('maxpool1', nn.MaxPool2d(kernel_size=3, stride=2)), 229 | 230 | ('conv2', nn.Conv2d(int(width_scale*64), int(width_scale*192), kernel_size=5, padding=2)), 231 | ('scaler2', self.scaler), 232 | ('bn2', bn_class['2d'](int(width_scale*192), **bn_kwargs)), 233 | ('relu2', nn.ReLU(inplace=True)), 234 | ('maxpool2', nn.MaxPool2d(kernel_size=3, stride=2)), 235 | 236 | ('conv3', nn.Conv2d(int(width_scale*192), int(width_scale*384), kernel_size=3, padding=1)), 237 | ('scaler3', self.scaler), 238 | ('bn3', bn_class['2d'](int(width_scale*384), **bn_kwargs)), 239 | ('relu3', nn.ReLU(inplace=True)), 240 | 241 | ('conv4', nn.Conv2d(int(width_scale*384), int(width_scale*256), kernel_size=3, padding=1)), 242 | ('scaler4', self.scaler), 243 | ('bn4', bn_class['2d'](int(width_scale*256), **bn_kwargs)), 244 | ('relu4', nn.ReLU(inplace=True)), 245 | 246 | ('conv5', nn.Conv2d(int(width_scale*256), int(width_scale*256), kernel_size=3, padding=1)), 247 | ('scaler5', self.scaler), 248 | ('bn5', bn_class['2d'](int(width_scale*256), **bn_kwargs)), 249 | ('relu5', nn.ReLU(inplace=True)), 250 | ('maxpool5', nn.MaxPool2d(kernel_size=3, stride=2)), 251 | ] 252 | self.features = nn.Sequential(OrderedDict(feature_layers)) 253 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 254 | 255 | self.classifier = nn.Sequential( 256 | OrderedDict([ 257 | ('fc1', nn.Linear(int(width_scale*256) * 6 * 6, int(width_scale*4096))), 258 | ('scaler6', self.scaler), 259 | ('bn6', bn_class['1d'](int(width_scale*4096), **bn_kwargs)), 260 | ('relu6', nn.ReLU(inplace=True)), 261 | 262 | ('fc2', nn.Linear(int(width_scale*4096), int(width_scale*4096))), 263 | ('scaler7', self.scaler), 264 | ('bn7', bn_class['1d'](int(width_scale*4096), **bn_kwargs)), 265 | ('relu7', nn.ReLU(inplace=True)), 266 | 267 | ('fc3', nn.Linear(int(width_scale*4096), num_classes)), 268 | ]) 269 | ) 270 | self.reset_parameters(inp_nonscale_layers=[]) 271 | if self.rescale_init and self.width_scale != 1.: 272 | self.features.conv1.reset_parameters() # ignore rescale init 273 | 274 | def forward(self, x): 275 | z = self.encode(x) 276 | logits = self.decode_clf(z) 277 | return logits 278 | 279 | def encode(self, x): 280 | x = self.features(x) 281 | x = self.avgpool(x) 282 | z = torch.flatten(x, 1) 283 | return z 284 | 285 | def decode_clf(self, z): 286 | logits = self.classifier(z) 287 | return logits 288 | 289 | 290 | if __name__ == '__main__': 291 | from nets.profile_func import profile_model, count_params_by_state 292 | 293 | model = AlexNet(width_scale=1., depth_plus=0) 294 | fea_params = count_params_by_state(model.features) 295 | clf_params = count_params_by_state(model.classifier) 296 | print(f"fea_params {fea_params/1e6} MB, clf_params: {clf_params/1e6} MB") 297 | for width_scale in [0.125]: # , 1.0]: # , 0.25, 0.5, 1.0]: 298 | for depth_plus in [0, 4, 8, 16, 22, 32, 256]: 299 | model = AlexNet(width_scale=width_scale, depth_plus=depth_plus) 300 | flops, state_params = profile_model(model) 301 | print(f' {width_scale:.3f}xWide {depth_plus}+Dep | GFLOPS {flops / 1e9:.4f}, ' 302 | f'model state size: {state_params / 1e6:.2f}MB') 303 | n_nets = int(1/width_scale) 304 | print(f" {n_nets}xNets | GFLOPS {n_nets*flops / 1e9:.4f}, " 305 | f"model state size: {n_nets*state_params / 1e6:.2f}MB") 306 | for width_scale in [1.]: # , 1.0]: # , 0.25, 0.5, 1.0]: 307 | for depth_plus in [0]: 308 | model = AlexNet(width_scale=width_scale, depth_plus=depth_plus) 309 | flops, state_params = profile_model(model) 310 | print(f' {width_scale:.3f}xWide {depth_plus}+Dep | GFLOPS {flops / 1e9:.4f}, ' 311 | f'model state size: {state_params / 1e6:.2f}MB') 312 | print(model) 313 | -------------------------------------------------------------------------------- /nets/slimmable_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ref: https://github.com/htwang14/CAT/blob/1152f7095d6ea0026c7344b00fefb9f4990444f2/models/FiLM.py#L35 3 | """ 4 | import numpy as np 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | from torch.nn.modules.batchnorm import _BatchNorm 8 | 9 | 10 | class SwitchableLayer1D(nn.Module): 11 | """1-dimensional switchable layer. 12 | The 1D means the module only requires one dimension variable, like BN. 13 | 14 | Args: 15 | module_class (nn.Module): Should a module class which takes `num_features` 16 | as the first arg, and multiple kwargs. 17 | """ 18 | def __init__(self, module_class, max_num_features: int, slim_ratios: list, **kwargs): 19 | super(SwitchableLayer1D, self).__init__() 20 | self.max_num_features = max_num_features 21 | modules = [] 22 | slim_ratios = sorted(slim_ratios) 23 | for r in slim_ratios: 24 | w = int(np.ceil(r * max_num_features)) 25 | modules.append(module_class(w, **kwargs)) 26 | self._switch_modules = nn.ModuleList(modules) 27 | self.current_module_idx = -1 28 | self._slim_ratio = max(slim_ratios) 29 | self.slim_ratios = slim_ratios 30 | self.ignore_model_profiling = True 31 | 32 | @property 33 | def slim_ratio(self): 34 | return self._slim_ratio 35 | 36 | @slim_ratio.setter 37 | def slim_ratio(self, r): 38 | self.current_module_idx = self.slim_ratios.index(r) 39 | self._slim_ratio = r 40 | 41 | def forward(self, x): 42 | y = self._switch_modules[self.current_module_idx](x) 43 | return y 44 | 45 | 46 | class SlimmableOpMixin(object): 47 | def mix_forward(self, x, mix_num=-1): 48 | if mix_num < 0: 49 | mix_num = int(1/self.slim_ratio) 50 | elif mix_num == 0: 51 | print("WARNING: not mix anything.") 52 | out = 0. 53 | for shift_idx in range(0, mix_num): 54 | out = out + self._forward_with_partial_weight(x, shift_idx) 55 | return out * 1. / mix_num 56 | 57 | def _forward_with_partial_weight(self, x, slim_bias_idx, out_slim_bias_idx=None): 58 | raise NotImplementedError() 59 | 60 | def _compute_slice_bound(self, in_channels, out_channels, slim_bias_idx, out_slim_bias_idx=None): 61 | out_slim_bias_idx = slim_bias_idx if out_slim_bias_idx is None else out_slim_bias_idx 62 | out_idx_bias = out_channels * out_slim_bias_idx if not self.non_slimmable_out else 0 63 | in_idx_bias = in_channels * slim_bias_idx if not self.non_slimmable_in else 0 64 | return out_idx_bias, (out_idx_bias+out_channels), in_idx_bias, (in_idx_bias+in_channels) 65 | 66 | 67 | class _SlimmableBatchNorm(_BatchNorm, SlimmableOpMixin): 68 | """ 69 | BatchNorm2d shared by all sub-networks in slimmable network. 70 | This won't work according to slimmable net paper. 71 | See implementation in https://github.com/htwang14/CAT/blob/1152f7095d6ea0026c7344b00fefb9f4990444f2/models/slimmable_ops.py#L28 72 | 73 | If this is used, we will enforce the tracking to be disabled. 74 | Following https://github.com/dem123456789/HeteroFL-Computation-and-Communication-Efficient-Federated-Learning-for-Heterogeneous-Clients 75 | """ 76 | def __init__(self, num_features, eps=1e-5, momentum=None, affine=True, 77 | track_running_stats=False, non_slimmable=False): 78 | assert not track_running_stats, "You should not track stats which cannot be slimmable." 79 | # if track_running_stats: 80 | # assert non_slimmable 81 | super(_SlimmableBatchNorm, self).__init__(num_features, momentum=momentum, track_running_stats=False, affine=affine, eps=eps) 82 | self.max_num_features = num_features 83 | self._slim_ratio = 1.0 84 | self.slim_bias_idx = 0 85 | self.out_slim_bias_idx = None 86 | self.non_slimmable = non_slimmable 87 | self.mix_forward_num = 1 # 1 means not mix; -1 mix all 88 | 89 | @property 90 | def slim_ratio(self): 91 | return self._slim_ratio 92 | 93 | @slim_ratio.setter 94 | def slim_ratio(self, r): 95 | self.num_features = self._compute_channels(r) 96 | self._slim_ratio = r 97 | if r < 0 and self.track_running_stats: 98 | raise RuntimeError(f"Try to track state when slim_ratio < 1 is {r}") 99 | 100 | def _compute_channels(self, ratio): 101 | return self.max_num_features if self.non_slimmable \ 102 | else int(np.ceil(self.max_num_features * ratio)) 103 | 104 | def forward(self, x): 105 | if self.mix_forward_num == 1: 106 | return self._forward_with_partial_weight(x, self.slim_bias_idx, self.out_slim_bias_idx) 107 | else: 108 | return self.mix_forward(x, mix_num=self.mix_forward_num) 109 | 110 | def _forward_with_partial_weight(self, input, slim_bias_idx, out_slim_bias_idx=None): 111 | out_idx0, out_idx1 = self._compute_slice_bound(self.num_features, slim_bias_idx) 112 | weight = self.weight[out_idx0:out_idx1] 113 | bias = self.bias[out_idx0:out_idx1] 114 | 115 | # ----- copy from parent implementation ---- 116 | self._check_input_dim(input) 117 | 118 | # exponential_average_factor is set to self.momentum 119 | # (when it is available) only so that it gets updated 120 | # in ONNX graph when this node is exported to ONNX. 121 | if self.momentum is None: 122 | exponential_average_factor = 0.0 123 | else: 124 | exponential_average_factor = self.momentum 125 | 126 | if self.training and self.track_running_stats: 127 | # TODO: if statement only here to tell the jit to skip emitting this when it is None 128 | if self.num_batches_tracked is not None: 129 | self.num_batches_tracked = self.num_batches_tracked + 1 130 | if self.momentum is None: # use cumulative moving average 131 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 132 | else: # use exponential moving average 133 | exponential_average_factor = self.momentum 134 | 135 | r""" 136 | Decide whether the mini-batch stats should be used for normalization rather than the buffers. 137 | Mini-batch stats are used in training mode, and in eval mode when buffers are None. 138 | """ 139 | if self.training: 140 | bn_training = True 141 | else: 142 | bn_training = (self.running_mean is None) and (self.running_var is None) 143 | 144 | r""" 145 | Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be 146 | passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are 147 | used for normalization (i.e. in eval mode when buffers are not None). 148 | """ 149 | return F.batch_norm( 150 | input, 151 | # If buffers are not to be tracked, ensure that they won't be updated 152 | self.running_mean if not self.training or self.track_running_stats else None, 153 | self.running_var if not self.training or self.track_running_stats else None, 154 | weight, bias, bn_training, exponential_average_factor, self.eps) 155 | 156 | def _compute_slice_bound(self, channels, slim_bias_idx): 157 | idx_bias = channels * slim_bias_idx if not self.non_slimmable else 0 158 | return idx_bias, (idx_bias+channels) 159 | 160 | def _save_to_state_dict(self, destination, prefix, keep_vars): 161 | for name, param in self._parameters.items(): 162 | if param is not None: 163 | # ------------------------------ 164 | idx_bias = self.num_features * self.slim_bias_idx if not self.non_slimmable else 0 165 | if name == 'weight': 166 | param = param[idx_bias:(idx_bias + self.num_features)] 167 | elif name == 'bias' and param is not None: 168 | param = param[idx_bias:(idx_bias + self.num_features)] 169 | # ------------------------------ 170 | destination[prefix + name] = param if keep_vars else param.detach() 171 | for name, buf in self._buffers.items(): 172 | if buf is not None and name not in self._non_persistent_buffers_set: 173 | destination[prefix + name] = buf if keep_vars else buf.detach() 174 | 175 | 176 | class SlimmableBatchNorm2d(_SlimmableBatchNorm): 177 | def _check_input_dim(self, input): 178 | if input.dim() != 4: 179 | raise ValueError('expected 4D input (got {}D input)' 180 | .format(input.dim())) 181 | 182 | 183 | class SlimmableBatchNorm1d(_SlimmableBatchNorm): 184 | 185 | def _check_input_dim(self, input): 186 | if input.dim() != 2 and input.dim() != 3: 187 | raise ValueError('expected 2D or 3D input (got {}D input)' 188 | .format(input.dim())) 189 | 190 | class SlimmableConv2d(nn.Conv2d, SlimmableOpMixin): 191 | """ 192 | Args: 193 | non_slimmable_in: Fix the in size 194 | non_slimmable_out: Fix the out size 195 | """ 196 | def __init__(self, in_channels: int, out_channels: int, 197 | kernel_size, stride=1, padding=0, dilation=1, 198 | groups=1, bias=True, 199 | non_slimmable_out=False, non_slimmable_in=False,): 200 | super(SlimmableConv2d, self).__init__( 201 | in_channels, out_channels, 202 | kernel_size, stride=stride, padding=padding, dilation=dilation, 203 | groups=groups, bias=bias) 204 | assert groups == 1, "for now, we can only support single group when slimming." 205 | assert in_channels > 0 206 | assert out_channels > 0 207 | self.max_in_channels = in_channels 208 | self.max_out_channels = out_channels 209 | self._slim_ratio = 1.0 210 | self.slim_bias_idx = 0 # input slim bias idx 211 | self.out_slim_bias_idx = None # -1: use the same value as slim_bias_idx 212 | self.non_slimmable_out = non_slimmable_out 213 | self.non_slimmable_in = non_slimmable_in 214 | self.mix_forward_num = -1 215 | 216 | @property 217 | def slim_ratio(self): 218 | return self._slim_ratio 219 | 220 | @slim_ratio.setter 221 | def slim_ratio(self, r): 222 | self.in_channels, self.out_channels = self._compute_channels(r) 223 | self._slim_ratio = r 224 | 225 | def _compute_channels(self, ratio): 226 | in_channels = self.max_in_channels if self.non_slimmable_in \ 227 | else int(np.ceil(self.max_in_channels * ratio)) 228 | out_channels = self.max_out_channels if self.non_slimmable_out \ 229 | else int(np.ceil(self.max_out_channels * ratio)) 230 | return in_channels, out_channels 231 | 232 | def forward(self, x): 233 | if self.mix_forward_num == 1: 234 | return self._forward_with_partial_weight(x, self.slim_bias_idx, self.out_slim_bias_idx) 235 | else: 236 | return self.mix_forward(x, mix_num=self.mix_forward_num) 237 | 238 | def _forward_with_partial_weight(self, x, slim_bias_idx, out_slim_bias_idx=None): 239 | out_idx0, out_idx1, in_idx0, in_idx1 = self._compute_slice_bound( 240 | self.in_channels, self.out_channels, slim_bias_idx, out_slim_bias_idx) 241 | weight = self.weight[out_idx0:out_idx1, in_idx0:in_idx1] 242 | bias = self.bias[out_idx0:out_idx1] if self.bias is not None else None 243 | y = F.conv2d( 244 | x, weight, bias, self.stride, self.padding, 245 | self.dilation, self.groups) 246 | return y / self.slim_ratio if self.training and not self.non_slimmable_out else y 247 | 248 | def _save_to_state_dict(self, destination, prefix, keep_vars): 249 | for name, param in self._parameters.items(): 250 | if param is not None: 251 | # ------------------------------ 252 | out_idx_bias = self.out_channels * self.slim_bias_idx if not self.non_slimmable_out else 0 253 | if name == 'weight': 254 | in_idx_bias = self.in_channels * self.slim_bias_idx \ 255 | if not self.non_slimmable_in else 0 256 | param = param[out_idx_bias:(out_idx_bias+self.out_channels), 257 | in_idx_bias:(in_idx_bias+self.in_channels)] 258 | elif name == 'bias' and param is not None: 259 | param = param[out_idx_bias:(out_idx_bias + self.out_channels)] 260 | # ------------------------------ 261 | destination[prefix + name] = param if keep_vars else param.detach() 262 | for name, buf in self._buffers.items(): 263 | if buf is not None and name not in self._non_persistent_buffers_set: 264 | destination[prefix + name] = buf if keep_vars else buf.detach() 265 | 266 | 267 | class SlimmableLinear(nn.Linear, SlimmableOpMixin): 268 | """ 269 | Args: 270 | non_slimmable_in: Fix the in size 271 | non_slimmable_out: Fix the out size 272 | """ 273 | def __init__(self, in_features: int, out_features: int, bias=True, 274 | non_slimmable_out=False, non_slimmable_in=False,): 275 | super(SlimmableLinear, self).__init__(in_features, out_features, bias=bias) 276 | self.max_in_features = in_features 277 | self.max_out_features = out_features 278 | self._slim_ratio = 1.0 279 | self.slim_bias_idx = 0 # input slim bias idx 280 | self.out_slim_bias_idx = None # -1: use the same value as slim_bias_idx 281 | self.non_slimmable_out = non_slimmable_out 282 | self.non_slimmable_in = non_slimmable_in 283 | self.mix_forward_num = -1 284 | 285 | @property 286 | def slim_ratio(self): 287 | return self._slim_ratio 288 | 289 | @slim_ratio.setter 290 | def slim_ratio(self, r): 291 | self.in_features, self.out_features = self._compute_channels(r) 292 | self._slim_ratio = r 293 | 294 | def _compute_channels(self, ratio): 295 | in_features = self.max_in_features if self.non_slimmable_in \ 296 | else int(np.ceil(self.max_in_features * ratio)) 297 | out_features = self.max_out_features if self.non_slimmable_out \ 298 | else int(np.ceil(self.max_out_features * ratio)) 299 | return in_features, out_features 300 | 301 | def forward(self, x): 302 | if self.mix_forward_num == 1: 303 | return self._forward_with_partial_weight(x, self.slim_bias_idx, self.out_slim_bias_idx) 304 | else: 305 | return self.mix_forward(x, mix_num=self.mix_forward_num) 306 | 307 | def _forward_with_partial_weight(self, x, slim_bias_idx, out_slim_bias_idx=None): 308 | out_idx0, out_idx1, in_idx0, in_idx1 = self._compute_slice_bound( 309 | self.in_features, self.out_features, slim_bias_idx, out_slim_bias_idx) 310 | weight = self.weight[out_idx0:out_idx1, in_idx0:in_idx1] 311 | bias = self.bias[out_idx0:out_idx1] if self.bias is not None else None 312 | out = F.linear(x, weight, bias) 313 | return out / self.slim_ratio if self.training and not self.non_slimmable_out else out 314 | 315 | def _save_to_state_dict(self, destination, prefix, keep_vars): 316 | for name, param in self._parameters.items(): 317 | if param is not None: 318 | # ------------------------------ 319 | param = self.get_slim_param(name, param) 320 | # ------------------------------ 321 | destination[prefix + name] = param if keep_vars else param.detach() 322 | for name, buf in self._buffers.items(): 323 | if buf is not None and name not in self._non_persistent_buffers_set: 324 | destination[prefix + name] = buf if keep_vars else buf.detach() 325 | 326 | def get_slim_param(self, name, param): 327 | out_idx_bias = self.out_features * self.slim_bias_idx if not self.non_slimmable_out else 0 328 | if name == 'weight': 329 | in_idx_bias = self.in_features * self.slim_bias_idx if not self.non_slimmable_in else 0 330 | param = param[out_idx_bias:(out_idx_bias + self.out_features), 331 | in_idx_bias:(in_idx_bias + self.in_features)] 332 | elif name == 'bias' and param is not None: 333 | param = param[out_idx_bias:(out_idx_bias + self.out_features)] 334 | return param 335 | -------------------------------------------------------------------------------- /federated/core.py: -------------------------------------------------------------------------------- 1 | """Core functions of federate learning.""" 2 | import argparse 3 | import copy 4 | import torch 5 | import numpy as np 6 | from torch import nn 7 | 8 | from federated.aggregation import ModelAccumulator, SlimmableModelAccumulator 9 | from nets.slimmable_models import get_slim_ratios_from_str, parse_lognorm_slim_schedule 10 | from utils.utils import shuffle_sampler, str2bool 11 | 12 | 13 | class _Federation: 14 | """A helper class for federated data creation. 15 | Use `add_argument` to setup ArgumentParser and then use parsed args to init the class. 16 | """ 17 | _model_accum: ModelAccumulator 18 | 19 | @classmethod 20 | def add_argument(cls, parser: argparse.ArgumentParser): 21 | # data 22 | parser.add_argument('--percent', type=float, default=1.0, 23 | help='percentage of dataset for training') # 1.0 1.0 0.3 24 | parser.add_argument('--val_ratio', type=float, default=0.1, 25 | help='ratio of train set for validation') # 0.3 0.1 0.5 26 | parser.add_argument('--batch', type=int, default=50, help='batch size')# 32 128 27 | parser.add_argument('--test_batch', type=int, default=128, help='batch size for test') 28 | 29 | # federated 30 | parser.add_argument('--pd_nuser', type=int, default=100, help='#users per domain.')# 30 100 10 31 | parser.add_argument('--pr_nuser', type=int, default=10, help='#users per comm round ' 32 | '[default: all]')#-1 10 10 33 | parser.add_argument('--pu_nclass', type=int, default=3, help='#class per user. -1 or 0: all')# -1 10 3 34 | parser.add_argument('--domain_order', choices=list(range(5)), type=int, default=0, 35 | help='select the order of domains') 36 | parser.add_argument('--partition_mode', choices=['uni', 'dir'], type=str.lower, default='uni', 37 | help='the mode when splitting domain data into users: uni - uniform ' 38 | 'distribution (all user have the same #samples); dir - Dirichlet' 39 | ' distribution (non-iid sample sizes)') 40 | parser.add_argument('--con_test_cls', type=str2bool, default=True, 41 | help='Ensure the test classes are the same training for a client. ' 42 | 'Meanwhile, make test sets are uniformly splitted for clients. ' 43 | 'Mainly influence class-niid settings.') 44 | 45 | 46 | @classmethod 47 | def render_run_name(cls, args): 48 | run_name = f'__pd_nuser_{args.pd_nuser}' 49 | if args.percent != 0.3: run_name += f'__pct_{args.percent}' 50 | if args.pu_nclass > 0: run_name += f"__pu_nclass_{args.pu_nclass}" 51 | if args.pr_nuser != -1: run_name += f'__pr_nuser_{args.pr_nuser}' 52 | if args.domain_order != 0: run_name += f'__do_{args.domain_order}' 53 | if args.partition_mode != 'uni': run_name += f'__part_md_{args.partition_mode}' 54 | if args.con_test_cls: run_name += '__ctc' 55 | return run_name 56 | 57 | def __init__(self, data, args): 58 | self.args = args 59 | 60 | # Prepare Data 61 | num_classes = 10 62 | if data == 'Digits': 63 | from utils.data_utils import DigitsDataset 64 | from utils.data_loader import prepare_digits_data 65 | prepare_data = prepare_digits_data 66 | DataClass = DigitsDataset 67 | elif data == 'DomainNet': 68 | from utils.data_utils import DomainNetDataset 69 | from utils.data_loader import prepare_domainnet_data 70 | prepare_data = prepare_domainnet_data 71 | DataClass = DomainNetDataset 72 | elif data == 'Cifar10': 73 | from utils.data_utils import CifarDataset 74 | from utils.data_loader import prepare_cifar_data 75 | prepare_data = prepare_cifar_data 76 | DataClass = CifarDataset 77 | elif data == 'Cifar100': 78 | num_classes = 100 79 | from utils.data_utils import Cifar100Dataset 80 | from utils.data_loader import prepare_cifar100_data 81 | prepare_data = prepare_cifar100_data 82 | DataClass = Cifar100Dataset 83 | else: 84 | raise ValueError(f"Unknown dataset: {data}") 85 | all_domains = DataClass.resorted_domains[args.domain_order] 86 | 87 | train_loaders, val_loaders, test_loaders, clients = prepare_data( 88 | args, domains=all_domains, 89 | n_user_per_domain=args.pd_nuser, 90 | n_class_per_user=args.pu_nclass, 91 | partition_seed=args.seed + 1, 92 | partition_mode=args.partition_mode, 93 | val_ratio=args.val_ratio, 94 | eq_domain_train_size=args.partition_mode == 'uni', 95 | consistent_test_class=args.con_test_cls, 96 | ) 97 | clients = [c + ' ' + 'clean' for c in clients] 98 | 99 | self.train_loaders = train_loaders 100 | self.val_loaders = val_loaders 101 | self.test_loaders = test_loaders 102 | self.clients = clients 103 | self.num_classes = num_classes 104 | self.all_domains = all_domains 105 | 106 | # Setup fed 107 | self.client_num = len(self.clients) 108 | client_weights = [len(tl.dataset) for tl in train_loaders] 109 | self.client_weights = [w / sum(client_weights) for w in client_weights] 110 | 111 | pr_nuser = args.pr_nuser if args.pr_nuser > 0 else self.client_num 112 | self.args.pr_nuser = pr_nuser 113 | self.client_sampler = UserSampler([i for i in range(self.client_num)], pr_nuser, mode='uni') 114 | 115 | def get_data(self): 116 | return self.train_loaders, self.val_loaders, self.test_loaders 117 | 118 | def make_aggregator(self, running_model): 119 | self._model_accum = ModelAccumulator(running_model, self.args.pr_nuser, self.client_num) 120 | return self._model_accum 121 | 122 | @property 123 | def model_accum(self): 124 | if not hasattr(self, '_model_accum'): 125 | raise RuntimeError(f"model_accum has not been set yet. Call `make_aggregator` first.") 126 | return self._model_accum 127 | 128 | def download(self, running_model, client_idx, strict=True): 129 | """Download (personalized) global model to running_model.""" 130 | self.model_accum.load_model(running_model, client_idx, strict=strict) 131 | 132 | def upload(self, running_model, client_idx): 133 | """Upload client model.""" 134 | self.model_accum.add(client_idx, running_model, self.client_weights[client_idx]) 135 | 136 | def aggregate(self): 137 | """Aggregate received models and update global model.""" 138 | self.model_accum.update_server_and_reset() 139 | 140 | 141 | class HeteFederation(_Federation): 142 | """Heterogeneous federation where each client is capable for training different widths.""" 143 | @classmethod 144 | def add_argument(cls, parser: argparse.ArgumentParser): 145 | super(HeteFederation, cls).add_argument(parser) 146 | parser.add_argument('--slim_ratios', type=str, default='8-1', 147 | help='define the slim_ratio for groups, for example, 8-4-2-1 [default]' 148 | ' means x1/8 net for the 1st group, and x1/4 for the 2nd') # '8-4-2-1' 149 | 150 | parser.add_argument('--val_ens_only', type=str2bool, default=True, help='only validate the full-size model') #action='store_true' 151 | 152 | 153 | @classmethod 154 | def render_run_name(cls, args): 155 | run_name = super(HeteFederation, cls).render_run_name(args) 156 | if args.slim_ratios != '8-4-2-1': run_name += f'__{args.slim_ratios}' 157 | return run_name 158 | 159 | def __init__(self, data, args): 160 | super(HeteFederation, self).__init__(data, args) 161 | train_slim_ratios = get_slim_ratios_from_str(args.slim_ratios) 162 | if len(train_slim_ratios) <= 1: 163 | info = f"WARN: There is no width to customize for training with " \ 164 | f"slim_ratios={args.slim_ratios}. To set a non-single" \ 165 | f" slim_ratios." 166 | if len(train_slim_ratios) > 0: 167 | print(info) 168 | else: 169 | raise RuntimeError(info) 170 | max_slim_ratio = max(train_slim_ratios) 171 | if args.val_ens_only: 172 | val_slim_ratios = [max_slim_ratio] # only validate the max width 173 | else: 174 | val_slim_ratios = copy.deepcopy(train_slim_ratios) 175 | if max_slim_ratio not in val_slim_ratios: 176 | val_slim_ratios.append(max_slim_ratio) # make sure the max width model is validated. 177 | 178 | self.train_slim_ratios = train_slim_ratios 179 | self.user_max_slim_ratios = self.get_slim_ratio_schedule(train_slim_ratios, args.slim_ratios) 180 | self.val_slim_ratios = val_slim_ratios 181 | 182 | def get_user_max_slim_ratios(self): 183 | return self.user_max_slim_ratios 184 | 185 | def get_slim_ratio_schedule(self, train_slim_ratios: list, mode: str): 186 | if mode.startswith('ln'): # lognorm 187 | return parse_lognorm_slim_schedule(train_slim_ratios, mode, self.client_num) 188 | else: 189 | return [train_slim_ratios[int(len(train_slim_ratios) * i / self.client_num)] 190 | for i, cname in enumerate(self.clients)] 191 | 192 | def make_aggregator(self, running_model, local_bn=False): 193 | self._model_accum = SlimmableModelAccumulator(running_model, self.args.pr_nuser, 194 | self.client_num, local_bn=local_bn) 195 | return self._model_accum 196 | 197 | def upload(self, running_model, client_idx, max_slim_ratio=None, slim_bias_idx=None): 198 | assert max_slim_ratio is not None 199 | assert slim_bias_idx is not None 200 | self.model_accum.add(client_idx, running_model, self.client_weights[client_idx], 201 | max_slim_ratio=max_slim_ratio, slim_bias_idx=slim_bias_idx) 202 | 203 | def mask_split_upload(self, running_model, client_idx, computable_body_layers, max_slim_ratio=None, slim_bias_idx=None): 204 | assert max_slim_ratio is not None 205 | assert slim_bias_idx is not None 206 | self.model_accum.mask_split_add(client_idx, running_model, computable_body_layers, self.client_weights[client_idx], 207 | max_slim_ratio=max_slim_ratio, slim_bias_idx=slim_bias_idx) 208 | 209 | def mask_hfl_upload(self, running_model, client_idx, computable_body_layers, max_slim_ratio=None, slim_bias_idx=None): 210 | assert max_slim_ratio is not None 211 | assert slim_bias_idx is not None 212 | self.model_accum.mask_hfl_add(client_idx, running_model, computable_body_layers, self.client_weights[client_idx], 213 | max_slim_ratio=max_slim_ratio, slim_bias_idx=slim_bias_idx) 214 | 215 | 216 | 217 | def mask_upload(self, running_model, client_idx, computable_body_layers): 218 | """Upload client model.""" 219 | self.model_accum.mask_add(client_idx, running_model, computable_body_layers, self.client_weights[client_idx]) 220 | 221 | def sample_bases(self, client_idx): 222 | """Sample slimmer base models for the client. 223 | Return slim_ratios, slim_shifts 224 | """ 225 | max_slim_ratio = self.user_max_slim_ratios[client_idx] 226 | slim_shifts = [0] 227 | slim_ratios = [max_slim_ratio] 228 | print(f" max slim ratio: {max_slim_ratio} " 229 | f"slim_ratios={slim_ratios}, slim_shifts={slim_shifts}") 230 | return slim_ratios, slim_shifts 231 | 232 | def controllers_init(self, model): 233 | self.control = {} 234 | self.delta_control = {} 235 | for name, par in model.named_parameters(): 236 | self.control[name] = torch.zeros_like(par.data) 237 | self.delta_control[name] = torch.zeros_like(par.data) 238 | 239 | 240 | self.usersControl = [] 241 | #self.usersDelta_control = [] 242 | 243 | for clientIdx in range(self.client_num): 244 | self.usersControl.append(copy.deepcopy(self.control)) 245 | #self.usersDelta_control.append(copy.deepcopy(self.control)) 246 | 247 | def deltaControl_reset(self, model): 248 | for name, par in model.named_parameters(): 249 | self.delta_control[name] = torch.zeros_like(par.data) 250 | 251 | 252 | def controller_update(self, model): 253 | 254 | for name, par in model.named_parameters(): 255 | 256 | self.control[name] = self.control[name] + (1/self.client_num) * self.delta_control[name] 257 | 258 | 259 | 260 | 261 | 262 | class SHeteFederation(HeteFederation): 263 | """Extend HeteroFL w/ local slimmable training.""" 264 | @classmethod 265 | def add_argument(cls, parser: argparse.ArgumentParser): 266 | super(SHeteFederation, cls).add_argument(parser) 267 | parser.add_argument('--slimmable_train', type=str2bool, default=True, 268 | help='train all budget-compatible slimmable networks, otherwise HeteroFL') 269 | 270 | @classmethod 271 | def render_run_name(cls, args): 272 | run_name = super(SHeteFederation, cls).render_run_name(args) 273 | if not args.slimmable_train: run_name += f'__nst' 274 | return run_name 275 | 276 | def sample_bases(self, client_idx): 277 | """Sample slimmer base models for the client. 278 | Return slim_ratios, slim_shifts 279 | """ 280 | max_slim_ratio = self.user_max_slim_ratios[client_idx] 281 | if self.args.slimmable_train: 282 | if len(self.train_slim_ratios) > 4: 283 | print("WARN: over 4 trained slim ratios which will cause large overhead for" 284 | " slimmable training. Try to set slimmable_train=False (HeteroFL) instead.") 285 | slim_ratios = [r for r in self.train_slim_ratios if r <= max_slim_ratio] 286 | else: 287 | slim_ratios = [max_slim_ratio] 288 | slim_shifts = [0] * len(slim_ratios) 289 | print(f" max slim ratio: {max_slim_ratio} " 290 | f"slim_ratios={slim_ratios}, slim_shifts={slim_shifts}") 291 | return slim_ratios, slim_shifts 292 | 293 | 294 | class SplitFederation(HeteFederation): 295 | """Split a net into multiple subnets and train them in federated learning.""" 296 | @classmethod 297 | def add_argument(cls, parser: argparse.ArgumentParser): 298 | super(SplitFederation, cls).add_argument(parser) 299 | parser.add_argument('--atom_slim_ratio', type=float, default=0.125, 300 | help='the width ratio of a base model') 301 | 302 | @classmethod 303 | def render_run_name(cls, args): 304 | run_name = super(SplitFederation, cls).render_run_name(args) 305 | assert 0. < args.atom_slim_ratio <= 1., f"Invalid slim_ratio: {args.atom_slim_ratio}" 306 | if args.atom_slim_ratio != 0.125: run_name += f"__asr{args.atom_slim_ratio}" 307 | return run_name 308 | 309 | def __init__(self, data, args): 310 | super(SplitFederation, self).__init__(data, args) 311 | 312 | assert args.atom_slim_ratio <= min(self.train_slim_ratios), \ 313 | f"Base model's width ({args.atom_slim_ratio}) is larger than that of minimal allowed " \ 314 | f"width ({min(self.train_slim_ratios)})" 315 | 316 | self.num_base = int(max(self.train_slim_ratios) / args.atom_slim_ratio) 317 | self.user_base_sampler = shuffle_sampler(list(range(self.num_base))) 318 | 319 | def sample_bases(self, client_idx): 320 | """Sample base models for the client. 321 | Return slim_ratios, slim_shifts 322 | """ 323 | # (Alg 2) Sample base models defined by shift index. 324 | max_slim_ratio = self.user_max_slim_ratios[client_idx] 325 | user_n_base = int(max_slim_ratio / self.args.atom_slim_ratio) 326 | 327 | slim_shifts = [self.user_base_sampler.next()] 328 | if user_n_base > 1: 329 | _sampler = shuffle_sampler([v for v in self.user_base_sampler.arr if v != slim_shifts[0]]) 330 | slim_shifts += [_sampler.next() for _ in range(user_n_base - 1)] 331 | slim_ratios = [self.args.atom_slim_ratio] * user_n_base 332 | print(f" max slim ratio: {max_slim_ratio} " 333 | f"slim_ratios={slim_ratios}, slim_shifts={slim_shifts}") 334 | return slim_ratios, slim_shifts 335 | 336 | 337 | class UserSampler(object): 338 | def __init__(self, users, select_nuser, mode='all'): 339 | self.users = users 340 | self.total_num_user = len(users) 341 | self.select_nuser = select_nuser 342 | self.mode = mode 343 | if mode == 'all': 344 | assert select_nuser == self.total_num_user, "Conflict config: Select too few users." 345 | 346 | def iter(self): 347 | if self.mode == 'all' or self.select_nuser == self.total_num_user: 348 | sel = np.arange(len(self.users)) 349 | elif self.mode == 'uni': 350 | sel = np.random.choice(self.total_num_user, self.select_nuser, replace=False) 351 | else: 352 | raise ValueError(f"Unsupported mode: {self.mode}") 353 | for i in sel: 354 | yield self.users[i] 355 | 356 | def tot(self): 357 | if self.mode == 'all' or self.select_nuser == self.total_num_user: 358 | return len(self.users) 359 | elif self.mode == 'uni': 360 | #return self.select_nuser 361 | return len(self.users) 362 | 363 | 364 | 365 | 366 | -------------------------------------------------------------------------------- /fed_hfl.py: -------------------------------------------------------------------------------- 1 | """HeteroFL""" 2 | import os, argparse, time 3 | import numpy as np 4 | import wandb 5 | import torch 6 | import copy 7 | from torch import nn, optim 8 | # federated 9 | from federated.learning import train_slimmable, test, personalization 10 | # utils 11 | from utils.utils import set_seed, AverageMeter, CosineAnnealingLR, \ 12 | MultiStepLR, str2bool 13 | from nets.profile_func import profile_model 14 | from utils.config import CHECKPOINT_ROOT 15 | 16 | from federated.core import HeteFederation as Federation 17 | 18 | 19 | 20 | def render_run_name(args, exp_folder): 21 | """Return a unique run_name from given args.""" 22 | if args.model == 'default': 23 | args.model = {'Digits': 'digit', 'Cifar10': 'preresnet18', 'Cifar100': 'mobile', 'DomainNet': 'alex'}[args.data] 24 | run_name = f'{args.model}' 25 | run_name += Federation.render_run_name(args) 26 | # log non-default args 27 | if args.seed != 1: run_name += f'__seed_{args.seed}' 28 | # opt 29 | if args.lr_sch != 'none': run_name += f'__lrs_{args.lr_sch}' 30 | if args.opt != 'sgd': run_name += f'__opt_{args.opt}' 31 | if args.batch != 32: run_name += f'__batch_{args.batch}' 32 | if args.wk_iters != 1: run_name += f'__wk_iters_{args.wk_iters}' 33 | # slimmable 34 | if args.no_track_stat: run_name += f"__nts" 35 | # split-mix 36 | if not args.rescale_init: run_name += '__nri' 37 | if not args.rescale_layer: run_name += '__nrl' 38 | if args.loss_temp != 'none': run_name += f'__lt{args.loss_temp}' 39 | 40 | args.save_path = os.path.join(CHECKPOINT_ROOT, exp_folder) 41 | if not os.path.exists(args.save_path): 42 | os.makedirs(args.save_path) 43 | SAVE_FILE = os.path.join(args.save_path, run_name) 44 | return run_name, SAVE_FILE 45 | 46 | 47 | def get_model_fh(data, model): 48 | if data == 'Digits': 49 | if model in ['digit']: 50 | from nets.slimmable_models import SlimmableDigitModel 51 | # TODO remove. Function the same as ens_digit 52 | ModelClass = SlimmableDigitModel 53 | else: 54 | raise ValueError(f"Invalid model: {model}") 55 | elif data in ['DomainNet']: 56 | if model in ['alex']: 57 | from nets.slimmable_models import SlimmableAlexNet 58 | ModelClass = SlimmableAlexNet 59 | else: 60 | raise ValueError(f"Invalid model: {model}") 61 | elif data == 'Cifar10': 62 | if model in ['preresnet18']: # From heteroFL 63 | from nets.HeteFL.slimmable_preresne import resnet18 64 | ModelClass = resnet18 65 | else: 66 | raise ValueError(f"Invalid model: {model}") 67 | elif data == 'Cifar100': 68 | if model in ['mobile']: 69 | from nets.slimmable_Nets import MobileNetCifar 70 | ModelClass = MobileNetCifar 71 | else: 72 | raise ValueError(f"Invalid model: {model}") 73 | else: 74 | raise ValueError(f"Unknown dataset: {data}") 75 | return ModelClass 76 | 77 | 78 | def fed_test(fed, running_model, train_loaders, val_loaders, global_lr, verbose): 79 | mark = 's' 80 | val_acc_list_bp = [None for _ in range(fed.client_num)] 81 | val_loss_mt_bp = AverageMeter() 82 | val_acc_list = [None for _ in range(fed.client_num)] 83 | val_loss_mt = AverageMeter() 84 | slim_val_acc_bp_mt = {slim_ratio: AverageMeter() for slim_ratio in fed.val_slim_ratios} 85 | slim_val_acc_mt = {slim_ratio: AverageMeter() for slim_ratio in fed.val_slim_ratios} 86 | for client_idx in range(fed.client_num): 87 | fed.download(running_model, client_idx) 88 | 89 | for i_slim_ratio, slim_ratio in enumerate(fed.val_slim_ratios): 90 | # Load and set slim ratio 91 | 92 | running_model.switch_slim_mode(slim_ratio) 93 | 94 | 95 | # Test 96 | 97 | val_model = copy.deepcopy(running_model) 98 | # Loss and accuracy before personalization 99 | val_loss_bp, val_acc_bp = test(val_model, val_loaders[client_idx], loss_fun, device) 100 | 101 | # Log 102 | val_loss_mt_bp.append(val_loss_bp) 103 | val_acc_list_bp[client_idx] = val_acc_bp 104 | if verbose > 0: 105 | print(' {:<19s} slim {:.2f}| Val Before Personalization {:s}Loss: {:.4f} | Val {:s}Acc: {:.4f}'.format( 106 | 'User-' + fed.clients[client_idx] if i_slim_ratio == 0 else ' ', slim_ratio, 107 | mark.upper(), val_loss_bp, mark.upper(), val_acc_bp)) 108 | wandb.log({ 109 | f"{fed.clients[client_idx]} sm{slim_ratio:.2f} val_bp_s-acc": val_acc_bp, 110 | }, commit=False) 111 | if slim_ratio == fed.user_max_slim_ratios[client_idx]: 112 | wandb.log({ 113 | f"{fed.clients[client_idx]} val_bp_{mark}-acc": val_acc_bp, 114 | }, commit=False) 115 | slim_val_acc_bp_mt[slim_ratio].append(val_acc_bp) 116 | 117 | if args.test: 118 | 119 | # Loss and accuracy after personalization 120 | val_loss, val_acc = personalization(val_model, train_loaders[client_idx], val_loaders[client_idx], 121 | loss_fun, global_lr, device) 122 | 123 | 124 | # Log 125 | val_loss_mt.append(val_loss) 126 | val_acc_list[client_idx] = val_acc # NOTE only record the last slim_ratio. 127 | if verbose > 0: 128 | print(' {:<19s} slim {:.2f}| Val {:s}Loss: {:.4f} | Val {:s}Acc: {:.4f}'.format( 129 | 'User-' + fed.clients[client_idx] if i_slim_ratio == 0 else ' ', slim_ratio, 130 | mark.upper(), val_loss, mark.upper(), val_acc)) 131 | if slim_ratio == fed.user_max_slim_ratios[client_idx]: 132 | wandb.log({ 133 | f"{fed.clients[client_idx]} val_{mark}-acc": val_acc, 134 | }, commit=False) 135 | slim_val_acc_mt[slim_ratio].append(val_acc) 136 | 137 | wandb.log({ 138 | f"{fed.clients[client_idx]} val_{mark}-acc": val_acc, 139 | }, commit=False) 140 | 141 | 142 | if args.test: 143 | 144 | return val_acc_list, val_loss_mt.avg, val_acc_list_bp, val_loss_mt_bp.avg 145 | else: 146 | 147 | return val_acc_list_bp, val_loss_mt_bp.avg, val_acc_list_bp, val_loss_mt_bp.avg 148 | 149 | 150 | 151 | 152 | 153 | if __name__ == '__main__': 154 | 155 | 156 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 157 | 158 | parser = argparse.ArgumentParser() 159 | # basic problem setting 160 | parser.add_argument('--seed', type=int, default=1, help='random seed') 161 | parser.add_argument('--data', type=str, default='Cifar10', help='data name') # 'DomainNet' 'Cifar100' 162 | parser.add_argument('--model', type=str.lower, default='default', help='model name') 163 | parser.add_argument('--algorithm', type=str, default='HFL', help='algorithm name') 164 | parser.add_argument('--no_track_stat', action='store_true', help='disable BN tracking') 165 | # control 166 | parser.add_argument('--no_log', action='store_true', help='disable wandb log') 167 | parser.add_argument('--test', action='store_true', help='test the pretrained model') 168 | #parser.add_argument('--test', type=str2bool, default=True, help='test the pretrained model') #action='store_true' 169 | parser.add_argument('--resume', action='store_true', help='resume training from checkpoint') 170 | parser.add_argument('--verbose', type=int, default=0, help='verbose level: 0 or 1') 171 | # federated 172 | Federation.add_argument(parser) 173 | # optimization 174 | parser.add_argument('--lr', type=float, default=1e-1, help='learning rate') #1e-2 1e-1 175 | parser.add_argument('--lr_sch', type=str, default='multi_step', help='learning rate schedule') # 'cos' 'none' 176 | parser.add_argument('--opt', type=str.lower, default='sgd', help='optimizer') 177 | parser.add_argument('--iters', type=int, default=80, help='#iterations for communication')#200 178 | parser.add_argument('--wk_iters', type=int, default=4, help='#epochs in local train')#5 1 179 | # slimmable test 180 | parser.add_argument('--test_slim_ratio', type=float, default=1., 181 | help='slim_ratio of model at testing.') 182 | # split-mix 183 | parser.add_argument('--rescale_init', type=str2bool, default=True, help='rescale init after' 184 | ' slim') 185 | parser.add_argument('--rescale_layer', type=str2bool, default=True, help='rescale layer outputs' 186 | ' after slim') 187 | parser.add_argument('--loss_temp', type=str, default='none', choices=['none', 'auto'], 188 | help='temper cross-entropy loss. auto: set temp as the width scale.') 189 | args = parser.parse_args() 190 | 191 | set_seed(args.seed) 192 | 193 | 194 | # ///////////////////////////////// 195 | # ///// Fed Dataset and Model ///// 196 | # ///////////////////////////////// 197 | fed = Federation(args.data, args) 198 | # Data 199 | train_loaders, val_loaders, test_loaders = fed.get_data() 200 | mean_batch_iters = int(np.mean([len(tl) for tl in train_loaders])) 201 | print(f" mean_batch_iters: {mean_batch_iters}") 202 | 203 | # set experiment files, wandb 204 | exp_folder = f'Alg_{args.algorithm}_C{fed.args.pr_nuser}_{args.data}' 205 | run_name, SAVE_FILE = render_run_name(args, exp_folder) 206 | wandb.init(group=run_name[:120], project=exp_folder, mode='offline' if args.no_log else 'online', 207 | config={**vars(args), 'save_file': SAVE_FILE}) 208 | 209 | # Model 210 | ModelClass = get_model_fh(args.data, args.model) 211 | running_model = ModelClass( 212 | track_running_stats=False, 213 | num_classes=fed.num_classes, slimmabe_ratios=fed.train_slim_ratios, 214 | ).to(device) 215 | 216 | 217 | # Loss 218 | loss_fun = nn.CrossEntropyLoss() 219 | 220 | # Use running model to init a fed aggregator 221 | fed.make_aggregator(running_model) 222 | 223 | 224 | 225 | totParamNum = 0 226 | userCounter = 0 227 | for userIdx in range(fed.client_sampler.tot()): 228 | slim_ratios, slim_shifts = fed.sample_bases(userIdx) 229 | temp_model = copy.deepcopy(running_model) 230 | fed.download(temp_model, userIdx) 231 | temp_model.switch_slim_mode(slim_ratios[0], slim_bias_idx=slim_shifts[0], out_slim_bias_idx=None) # 232 | _, computableParamNum = profile_model(temp_model, device=device) 233 | 234 | totParamNum += computableParamNum 235 | userCounter += 1 236 | 237 | 238 | wandb.log({'Num_of_Params': totParamNum/userCounter}, commit=False) 239 | 240 | 241 | # ///////////////// 242 | # //// Resume ///// 243 | # ///////////////// 244 | # log the best for each model on all datasets 245 | best_epoch = 0 246 | best_acc = [0. for j in range(fed.client_num)] 247 | train_elapsed = [[] for _ in range(fed.client_num)] 248 | start_epoch = 0 249 | if args.resume or args.test: 250 | if os.path.exists(SAVE_FILE): 251 | print(f'Loading chkpt from {SAVE_FILE}') 252 | checkpoint = torch.load(SAVE_FILE) 253 | best_epoch, best_acc = checkpoint['best_epoch'], checkpoint['best_acc'] 254 | train_elapsed = checkpoint['train_elapsed'] 255 | train_dataset = checkpoint['train_dataset'] 256 | global_lr = checkpoint['lr'] 257 | start_epoch = int(checkpoint['a_iter']) + 1 258 | fed.model_accum.load_state_dict(checkpoint['server_model']) 259 | 260 | print('Resume training from epoch {} with best acc:'.format(start_epoch)) 261 | for client_idx, acc in enumerate(best_acc): 262 | print(' Best user-{:<10s}| Epoch:{} | Val Acc: {:.4f}'.format( 263 | fed.clients[client_idx], best_epoch, acc)) 264 | else: 265 | if args.test: 266 | raise FileNotFoundError(f"Not found checkpoint at {SAVE_FILE}") 267 | else: 268 | print(f"Not found checkpoint at {SAVE_FILE}\n **Continue without resume.**") 269 | 270 | 271 | # /////////////// 272 | # //// Test ///// 273 | # /////////////// 274 | if args.test: 275 | wandb.summary[f'best_epoch'] = best_epoch 276 | 277 | 278 | test_acc_list, _, test_acc_list_bp, _ = fed_test(fed, running_model, train_dataset, 279 | test_loaders, global_lr, args.verbose) 280 | 281 | 282 | print(f"\n Average Test Acc Before Personalization: {np.mean(test_acc_list_bp)}") 283 | wandb.summary[f'avg test acc bp'] = np.mean(test_acc_list_bp) 284 | print(f"\n Average Test Acc: {np.mean(test_acc_list)}") 285 | wandb.summary[f'avg test acc'] = np.mean(test_acc_list) 286 | wandb.finish() 287 | 288 | exit(0) 289 | 290 | 291 | # //////////////// 292 | # //// Train ///// 293 | # //////////////// 294 | # LR scheduler 295 | if args.lr_sch == 'cos': 296 | lr_sch = CosineAnnealingLR(args.iters, eta_max=args.lr, last_epoch=start_epoch) 297 | elif args.lr_sch == 'multi_step': 298 | lr_sch = MultiStepLR(args.lr, milestones=[args.iters//2, (args.iters * 3)//4], gamma=0.1, last_epoch=start_epoch) 299 | else: 300 | assert args.lr_sch == 'none', f'Invalid lr_sch: {args.lr_sch}' 301 | lr_sch = None 302 | for a_iter in range(start_epoch, args.iters): 303 | # set global lr 304 | global_lr = args.lr if lr_sch is None else lr_sch.step() 305 | wandb.log({'global lr': global_lr}, commit=False) 306 | 307 | # ----------- Train Client --------------- 308 | train_loss_mt, train_acc_mt = AverageMeter(), AverageMeter() 309 | print("============ Train epoch {} ============".format(a_iter)) 310 | for client_idx in fed.client_sampler.iter(): 311 | # (Alg 2) Sample base models defined by shift index. 312 | slim_ratios, slim_shifts = fed.sample_bases(client_idx) 313 | 314 | start_time = time.process_time() 315 | 316 | # Download global model to local 317 | fed.download(running_model, client_idx) 318 | 319 | # (Alg 3) Local Train 320 | if args.opt == 'sgd': 321 | optimizer = optim.SGD(params=running_model.parameters(), lr=global_lr, 322 | momentum=0.9, weight_decay=5e-4) 323 | elif args.opt == 'adam': 324 | optimizer = optim.Adam(params=running_model.parameters(), lr=global_lr) 325 | else: 326 | raise ValueError(f"Invalid optimizer: {args.opt}") 327 | train_loss, train_acc = train_slimmable( 328 | running_model, train_loaders[client_idx], optimizer, loss_fun, device, 329 | max_iter=mean_batch_iters * args.wk_iters if args.partition_mode != 'uni' 330 | else len(train_loaders[client_idx]) * args.wk_iters, 331 | slim_ratios=slim_ratios, slim_shifts=slim_shifts, progress=args.verbose > 0, 332 | loss_temp=args.loss_temp 333 | ) 334 | 335 | # Upload 336 | fed.upload(running_model, client_idx, 337 | max_slim_ratio=max(slim_ratios), slim_bias_idx=slim_shifts) 338 | 339 | # Log 340 | client_name = fed.clients[client_idx] 341 | elapsed = time.process_time() - start_time 342 | wandb.log({f'{client_name}_train_elapsed': elapsed}, commit=False) 343 | train_elapsed[client_idx].append(elapsed) 344 | 345 | train_loss_mt.append(train_loss), train_acc_mt.append(train_acc) 346 | print(f' User-{client_name:<10s} Train | Loss: {train_loss:.4f} |' 347 | f' Acc: {train_acc:.4f} | Elapsed: {elapsed:.2f} s') 348 | wandb.log({ 349 | f"{client_name} train_loss": train_loss, 350 | f"{client_name} train_acc": train_acc, 351 | }, commit=False) 352 | 353 | # Use accumulated model to update server model 354 | fed.aggregate() 355 | 356 | 357 | # ----------- Validation --------------- 358 | val_acc_list, val_loss, val_acc_list_bp, val_loss_bp = fed_test(fed, running_model, train_loaders, val_loaders, 359 | global_lr, args.verbose) 360 | 361 | # Log averaged 362 | print(f' [Overall] Train Loss {train_loss_mt.avg:.4f} Acc {train_acc_mt.avg*100:.1f}% ' 363 | f'| Val Acc bp {np.mean(val_acc_list_bp)*100:.2f}%' 364 | f' | Val Acc {np.mean(val_acc_list) * 100:.2f}%') 365 | wandb.log({ 366 | f"train_loss": train_loss_mt.avg, 367 | f"train_acc": train_acc_mt.avg, 368 | f"val_loss_bp": val_loss_bp, 369 | f"val_acc_bp": np.mean(val_acc_list_bp), 370 | f"val_loss": val_loss, 371 | f"val_acc": np.mean(val_acc_list), 372 | }, commit=False) 373 | 374 | 375 | 376 | # ----------- Save checkpoint ----------- 377 | if np.mean(val_acc_list) > np.mean(best_acc): 378 | best_epoch = a_iter 379 | for client_idx in range(fed.client_num): 380 | best_acc[client_idx] = val_acc_list[client_idx] 381 | if args.verbose > 0: 382 | print(' Best site-{:<10s}| Epoch:{} | Val Acc: {:.4f}'.format( 383 | fed.clients[client_idx], best_epoch, best_acc[client_idx])) 384 | print(' [Best Val] Acc {:.4f}'.format(np.mean(val_acc_list))) 385 | 386 | # Save 387 | print(f' Saving the local and server checkpoint to {SAVE_FILE}') 388 | save_dict = { 389 | 'server_model': fed.model_accum.state_dict(), 390 | 'train_dataset': train_loaders, 391 | 'lr' : global_lr, 392 | 'best_epoch': best_epoch, 393 | 'best_acc': best_acc, 394 | 'a_iter': a_iter, 395 | 'all_domains': fed.all_domains, 396 | 'train_elapsed': train_elapsed, 397 | } 398 | torch.save(save_dict, SAVE_FILE) 399 | wandb.log({ 400 | f"best_val_acc": np.mean(best_acc), 401 | }, commit=True) 402 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import Dataset, DataLoader 3 | from torchvision.datasets import CIFAR10 4 | from torchvision.datasets import CIFAR100 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | from collections import defaultdict 8 | import os 9 | from tqdm import tqdm 10 | from torchvision.datasets.folder import IMG_EXTENSIONS, has_file_allowed_extension 11 | 12 | from typing import Tuple, List, Dict, Optional, Callable, cast 13 | from .config import DATA_PATHS 14 | from .utils import shuffle_sampler 15 | 16 | 17 | class DigitsDataset(Dataset): 18 | all_domains = ['MNIST', 'SVHN', 'USPS', 'SynthDigits', 'MNIST_M'] 19 | resorted_domains = { 20 | 0: ['MNIST', 'SVHN', 'USPS', 'SynthDigits', 'MNIST_M'], 21 | 1: ['SVHN', 'USPS', 'SynthDigits', 'MNIST_M', 'MNIST'], 22 | 2: ['USPS', 'SynthDigits', 'MNIST_M', 'MNIST', 'SVHN'], 23 | 3: ['SynthDigits', 'MNIST_M', 'MNIST', 'SVHN', 'USPS'], 24 | 4: ['MNIST_M', 'MNIST', 'SVHN', 'USPS', 'SynthDigits'], 25 | } 26 | num_classes = 10 # may not be correct 27 | 28 | def __init__(self, domain, percent=0.1, filename=None, train=True, transform=None): 29 | data_path = os.path.join(DATA_PATHS["Digits"], domain) 30 | if filename is None: 31 | if train: 32 | if percent >= 0.1: 33 | for part in range(int(percent*10)): 34 | if part == 0: 35 | self.images, self.labels = np.load( 36 | os.path.join(data_path, 37 | 'partitions/train_part{}.pkl'.format(part)), 38 | allow_pickle=True) 39 | else: 40 | images, labels = np.load( 41 | os.path.join(data_path, 42 | 'partitions/train_part{}.pkl'.format(part)), 43 | allow_pickle=True) 44 | self.images = np.concatenate([self.images,images], axis=0) 45 | self.labels = np.concatenate([self.labels,labels], axis=0) 46 | else: 47 | self.images, self.labels = np.load( 48 | os.path.join(data_path, 'partitions/train_part0.pkl'), 49 | allow_pickle=True) 50 | data_len = int(self.images.shape[0] * percent*10) 51 | self.images = self.images[:data_len] 52 | self.labels = self.labels[:data_len] 53 | else: 54 | self.images, self.labels = np.load(os.path.join(data_path, 'test.pkl'), 55 | allow_pickle=True) 56 | else: 57 | self.images, self.labels = np.load(os.path.join(data_path, filename), 58 | allow_pickle=True) 59 | 60 | self.transform = transform 61 | self.channels = 3 if domain in ['SVHN', 'SynthDigits', 'MNIST_M'] else 1 62 | self.labels = self.labels.astype(np.long).squeeze() 63 | self.classes = np.unique(self.labels) 64 | 65 | def __len__(self): 66 | return self.images.shape[0] 67 | 68 | def __getitem__(self, idx): 69 | image = self.images[idx] 70 | label = self.labels[idx] 71 | if self.channels == 1: 72 | image = Image.fromarray(image, mode='L') 73 | elif self.channels == 3: 74 | image = Image.fromarray(image, mode='RGB') 75 | else: 76 | raise ValueError("{} channel is not allowed.".format(self.channels)) 77 | 78 | if self.transform is not None: 79 | image = self.transform(image) 80 | 81 | return image, label 82 | 83 | 84 | class CifarDataset(CIFAR10): 85 | all_domains = ['cifar10'] 86 | resorted_domains = { 87 | 0: ['cifar10'], 88 | } 89 | num_classes = 10 # may not be correct 90 | 91 | def __init__(self, domain='cifar10', train=True, transform=None, download=False): 92 | assert domain in self.all_domains, f"Invalid domain: {domain}" 93 | data_path = os.path.join(DATA_PATHS["Cifar10"], domain) 94 | super().__init__(data_path, train=train, transform=transform, download=download) 95 | 96 | class Cifar100Dataset(CIFAR100): 97 | all_domains = ['cifar100'] 98 | resorted_domains = { 99 | 0: ['cifar100'], 100 | } 101 | num_classes = 100 # may not be correct 102 | 103 | def __init__(self, domain='cifar100', train=True, transform=None, download=True): 104 | assert domain in self.all_domains, f"Invalid domain: {domain}" 105 | data_path = os.path.join(DATA_PATHS["cifar100"], domain) 106 | super().__init__(data_path, train=train, transform=transform, download=download) 107 | 108 | 109 | class DomainNetDataset(Dataset): 110 | all_domains = ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch'] 111 | resorted_domains = { 112 | 0: ['real', 'clipart', 'infograph', 'painting', 'quickdraw', 'sketch'], 113 | 1: ['clipart', 'infograph', 'painting', 'quickdraw', 'sketch', 'real'], 114 | 2: ['infograph', 'painting', 'quickdraw', 'sketch', 'real', 'clipart'], 115 | 3: ['painting', 'quickdraw', 'sketch', 'real', 'clipart', 'infograph'], 116 | 4: ['quickdraw', 'sketch', 'real', 'clipart', 'infograph', 'painting'], 117 | 5: ['sketch', 'real', 'clipart', 'infograph', 'painting', 'quickdraw'], 118 | } 119 | num_classes = 10 # may not be correct 120 | 121 | def __init__(self, site, train=True, transform=None, full_set=False): 122 | self.full_set = full_set 123 | self.base_path = DATA_PATHS['DomainNet'] 124 | if full_set: 125 | classes, class_to_idx = find_classes(f"{self.base_path}/{site}") 126 | self.text_labels = classes 127 | self.paths, self.labels = make_dataset_from_dir(f"{self.base_path}/{site}", 128 | class_to_idx, IMG_EXTENSIONS) 129 | self.num_classes = len(class_to_idx) 130 | else: 131 | self.paths, self.text_labels = np.load('{}/DomainNet/{}_{}.pkl'.format( 132 | DATA_PATHS['DomainNetPathList'], 133 | site, 'train' if train else 'test'), allow_pickle=True) 134 | 135 | class_to_idx = {'bird': 0, 'feather': 1, 'headphones': 2, 'ice_cream': 3, 'teapot': 4, 136 | 'tiger': 5, 'whale': 6, 'windmill': 7, 'wine_glass': 8, 'zebra': 9} 137 | 138 | self.labels = [class_to_idx[text] for text in self.text_labels] 139 | self.num_classes = len(class_to_idx) 140 | 141 | self.transform = transform 142 | self.classes = np.unique(self.labels) 143 | 144 | def __len__(self): 145 | return len(self.labels) 146 | 147 | def __getitem__(self, idx): 148 | site, cls, fname = self.paths[idx].split('/')[-3:] 149 | img_path = os.path.join(self.base_path, site, cls, fname) 150 | 151 | label = self.labels[idx] 152 | image = Image.open(img_path) 153 | 154 | if len(image.split()) != 3: 155 | image = transforms.Grayscale(num_output_channels=3)(image) 156 | 157 | if self.transform is not None: 158 | image = self.transform(image) 159 | 160 | return image, label 161 | 162 | 163 | # //////////// Data processing //////////// 164 | def find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]: 165 | """ 166 | Finds the class folders in a dataset. 167 | 168 | Args: 169 | dir (string): Root directory path. 170 | 171 | Returns: 172 | tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx 173 | is a dictionary. 174 | 175 | Ensures: 176 | No class is a subdirectory of another. 177 | """ 178 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 179 | classes.sort() 180 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 181 | return classes, class_to_idx 182 | 183 | 184 | def make_dataset_from_dir( 185 | directory: str, 186 | class_to_idx: Dict[str, int], 187 | extensions: Optional[Tuple[str, ...]] = None, 188 | is_valid_file: Optional[Callable[[str], bool]] = None, 189 | ) -> Tuple[List[str], List[int]]: 190 | """Different Pytorch version, we return path and labels in two lists.""" 191 | paths, labels = [], [] 192 | directory = os.path.expanduser(directory) 193 | both_none = extensions is None and is_valid_file is None 194 | both_something = extensions is not None and is_valid_file is not None 195 | if both_none or both_something: 196 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the " 197 | "same time") 198 | if extensions is not None: 199 | def is_valid_file(x: str) -> bool: 200 | return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions)) 201 | is_valid_file = cast(Callable[[str], bool], is_valid_file) 202 | for target_class in sorted(class_to_idx.keys()): 203 | class_index = class_to_idx[target_class] 204 | target_dir = os.path.join(directory, target_class) 205 | if not os.path.isdir(target_dir): 206 | continue 207 | for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): 208 | for fname in sorted(fnames): 209 | path = os.path.join(root, fname) 210 | if is_valid_file(path): 211 | paths.append(path) 212 | labels.append(class_index) 213 | return paths, labels 214 | 215 | 216 | class Partitioner(object): 217 | """Class for partition a sequence into multiple shares (or users). 218 | 219 | Args: 220 | rng (np.random.RandomState): random state. 221 | partition_mode (str): 'dir' for Dirichlet distribution or 'uni' for uniform. 222 | max_n_sample_per_share (int): max number of samples per share. 223 | min_n_sample_per_share (int): min number of samples per share. 224 | max_n_sample (int): max number of samples 225 | verbose (bool): verbosity 226 | """ 227 | def __init__(self, rng=None, partition_mode="dir", 228 | max_n_sample_per_share=-1, 229 | min_n_sample_per_share=2, 230 | max_n_sample=-1, 231 | verbose=True 232 | ): 233 | assert max_n_sample_per_share < 0 or max_n_sample_per_share > min_n_sample_per_share, \ 234 | f"max ({max_n_sample_per_share}) > min ({min_n_sample_per_share})" 235 | self.rng = rng if rng else np.random 236 | self.partition_mode = partition_mode 237 | self.max_n_sample_per_share = max_n_sample_per_share 238 | self.min_n_sample_per_share = min_n_sample_per_share 239 | self.max_n_sample = max_n_sample 240 | self.verbose = verbose 241 | 242 | def __call__(self, n_sample, n_share, log=print): 243 | """Partition a sequence of `n_sample` into `n_share` shares. 244 | Returns: 245 | partition: A list of num of samples for each share. 246 | """ 247 | assert n_share > 0, f"cannot split into {n_share} share" 248 | if self.verbose: 249 | log(f" {n_sample} smp => {n_share} shards by {self.partition_mode} distr") 250 | if self.max_n_sample > 0: 251 | n_sample = min((n_sample, self.max_n_sample)) 252 | if self.max_n_sample_per_share > 0: 253 | n_sample = min((n_sample, n_share * self.max_n_sample_per_share)) 254 | 255 | if n_sample < self.min_n_sample_per_share * n_share: 256 | raise ValueError(f"Not enough samples. Require {self.min_n_sample_per_share} samples" 257 | f" per share at least for {n_share} shares. But only {n_sample} is" 258 | f" available totally.") 259 | n_sample -= self.min_n_sample_per_share * n_share 260 | if self.partition_mode == "dir": 261 | partition = (self.rng.dirichlet(n_share * [1]) * n_sample).astype(int) 262 | elif self.partition_mode == "uni": 263 | partition = int(n_sample // n_share) * np.ones(n_share, dtype='int') 264 | else: 265 | raise ValueError(f"Invalid partition_mode: {self.partition_mode}") 266 | 267 | # uniformly add residual to as many users as possible. 268 | for i in self.rng.choice(n_share, n_sample - np.sum(partition)): 269 | partition[i] += 1 270 | # partition[-1] += n_sample - np.sum(partition) # add residual 271 | assert sum(partition) == n_sample, f"{sum(partition)} != {n_sample}" 272 | partition = partition + self.min_n_sample_per_share 273 | n_sample += self.min_n_sample_per_share * n_share 274 | # partition = np.minimum(partition, max_n_sample_per_share) 275 | partition = partition.tolist() 276 | 277 | assert sum(partition) == n_sample, f"{sum(partition)} != {n_sample}" 278 | assert len(partition) == n_share, f"{len(partition)} != {n_share}" 279 | return partition 280 | 281 | 282 | class ClassWisePartitioner(Partitioner): 283 | """Partition a list of labels by class. Classes will be shuffled and assigned to users 284 | sequentially. 285 | 286 | Args: 287 | n_class_per_share (int): number of classes per share (user). 288 | rng (np.random.RandomState): random state. 289 | partition_mode (str): 'dir' for Dirichlet distribution or 'uni' for uniform. 290 | max_n_sample_per_share (int): max number of samples per share. 291 | min_n_sample_per_share (int): min number of samples per share. 292 | max_n_sample (int): max number of samples 293 | verbose (bool): verbosity 294 | """ 295 | def __init__(self, n_class_per_share=2, **kwargs): 296 | super(ClassWisePartitioner, self).__init__(**kwargs) 297 | self.n_class_per_share = n_class_per_share 298 | self._aux_partitioner = Partitioner(**kwargs) 299 | 300 | def __call__(self, labels, n_user, log=print, user_ids_by_class=None, 301 | return_user_ids_by_class=False, consistent_class=False): 302 | """Partition a list of labels into `n_user` shares. 303 | Returns: 304 | partition: A list of users, where each user include a list of sample indexes. 305 | """ 306 | # reorganize labels by class 307 | idx_by_class = defaultdict(list) 308 | if len(labels) > 1e5: 309 | labels_iter = tqdm(labels, leave=False, desc='sort labels') 310 | else: 311 | labels_iter = labels 312 | for i, label in enumerate(labels_iter): 313 | idx_by_class[label].append(i) 314 | 315 | n_class = len(idx_by_class) 316 | assert n_user * self.n_class_per_share > n_class, f"Cannot split {n_class} classes into " \ 317 | f"{n_user} users when each user only " \ 318 | f"has {self.n_class_per_share} classes." 319 | 320 | # assign classes to each user. 321 | if user_ids_by_class is None: 322 | user_ids_by_class = defaultdict(list) 323 | label_sampler = shuffle_sampler(list(range(n_class)), 324 | self.rng if consistent_class else None) 325 | for s in range(n_user): 326 | s_classes = [label_sampler.next() for _ in range(self.n_class_per_share)] 327 | for c in s_classes: 328 | user_ids_by_class[c].append(s) 329 | 330 | # assign sample indexes to clients 331 | idx_by_user = [[] for _ in range(n_user)] 332 | if n_class > 100 or len(labels) > 1e5: 333 | idx_by_class_iter = tqdm(idx_by_class, leave=True, desc='split cls') 334 | log = lambda log_s: idx_by_class_iter.set_postfix_str(log_s[:10]) # tqdm.write 335 | else: 336 | idx_by_class_iter = idx_by_class 337 | for c in idx_by_class_iter: 338 | l = len(idx_by_class[c]) 339 | log(f" class-{c} => {len(user_ids_by_class[c])} shares") 340 | l_by_user = self._aux_partitioner(l, len(user_ids_by_class[c]), log=log) 341 | base_idx = 0 342 | for i_user, tl in zip(user_ids_by_class[c], l_by_user): 343 | idx_by_user[i_user].extend(idx_by_class[c][base_idx:base_idx+tl]) 344 | base_idx += tl 345 | if return_user_ids_by_class: 346 | return idx_by_user, user_ids_by_class 347 | else: 348 | return idx_by_user 349 | 350 | 351 | def extract_labels(dataset: Dataset): 352 | if hasattr(dataset, 'targets'): 353 | return dataset.targets 354 | dl = DataLoader(dataset, batch_size=512, drop_last=False, num_workers=4, shuffle=False) 355 | labels = [] 356 | dl_iter = tqdm(dl, leave=False, desc='load labels') if len(dl) > 100 else dl 357 | for _, targets in dl_iter: 358 | labels.extend(targets.cpu().numpy().tolist()) 359 | return labels 360 | 361 | 362 | def test_class_partitioner(): 363 | print(f"\n==== Extract from random labels =====") 364 | split = ClassWisePartitioner() 365 | n_class = 10 366 | n_sample = 1000 367 | n_user = 100 368 | labels = np.random.randint(0, n_class, n_sample) 369 | idx_by_user = split(labels, n_user) 370 | _n_smp = 0 371 | for u in range(n_user): 372 | u_labels = labels[idx_by_user[u]] 373 | u_classes = np.unique(u_labels) 374 | print(f"user-{u} | {len(idx_by_user[u])} samples | {len(u_classes)} classes: {u_classes}") 375 | assert len(u_classes) == 2 376 | _n_smp += len(u_labels) 377 | assert _n_smp == n_sample 378 | 379 | print(f"\n==== Extract from dataset =====") 380 | from .data_loader import CifarDataset 381 | ds = CifarDataset('cifar10', transform=transforms.ToTensor()) 382 | labels = extract_labels(ds) 383 | n_sample = len(labels) 384 | idx_by_user = split(labels, n_user) 385 | labels = np.array(labels) 386 | _n_smp = 0 387 | for u in range(n_user): 388 | u_labels = labels[idx_by_user[u]] 389 | u_classes = np.unique(u_labels) 390 | print(f"user-{u} | {len(idx_by_user[u])} samples | {len(u_classes)} classes: {u_classes}") 391 | assert len(u_classes) == 2 392 | _n_smp += len(u_labels) 393 | assert _n_smp == n_sample, f"Expected {n_sample} samples but got {_n_smp}" 394 | 395 | 396 | if __name__ == '__main__': 397 | import argparse 398 | parser = argparse.ArgumentParser() 399 | parser.add_argument('--download', type=str, default='none', choices=['Cifar10'], 400 | help='Download datasets.') 401 | parser.add_argument('--test', action='store_true', help='Run test') 402 | args = parser.parse_args() 403 | if args.test: 404 | test_class_partitioner() 405 | else: 406 | if args.download == 'Cifar10': 407 | CifarDataset(download=True, train=True) 408 | CifarDataset(download=True, train=False) 409 | else: 410 | print(f"Nothing to download for dataset: {args.download}") 411 | -------------------------------------------------------------------------------- /fed_dataHetComp.py: -------------------------------------------------------------------------------- 1 | """FedBABU, FedAvg, fedProx, and fedNOVA""" 2 | import os, argparse, time 3 | import numpy as np 4 | import wandb 5 | import copy 6 | import torch 7 | import math 8 | from torch import nn, optim 9 | # federated 10 | from federated.learning import train, test, personalization, train_fedprox 11 | # utils 12 | from utils.utils import set_seed, AverageMeter, CosineAnnealingLR, \ 13 | MultiStepLR, str2bool 14 | from utils.config import CHECKPOINT_ROOT 15 | 16 | # NOTE import desired federation 17 | from federated.core import HeteFederation as Federation 18 | 19 | 20 | def render_run_name(args, exp_folder): 21 | """Return a unique run_name from given args.""" 22 | if args.model == 'default': 23 | args.model = {'Digits': 'digit', 'Cifar10': 'preresnet18', 'Cifar100': 'mobile', 'DomainNet': 'alex'}[args.data] 24 | run_name = f'{args.model}' 25 | if args.width_scale != 1.: run_name += f'x{args.width_scale}' 26 | run_name += Federation.render_run_name(args) 27 | # log non-default args 28 | if args.seed != 1: run_name += f'__seed_{args.seed}' 29 | # opt 30 | if args.lr_sch != 'none': run_name += f'__lrs_{args.lr_sch}' 31 | if args.opt != 'sgd': run_name += f'__opt_{args.opt}' 32 | if args.batch != 32: run_name += f'__batch_{args.batch}' 33 | if args.wk_iters != 1: run_name += f'__wk_iters_{args.wk_iters}' 34 | # slimmable 35 | if args.no_track_stat: run_name += f"__nts" 36 | if args.no_mask_loss: run_name += f'__nml' 37 | 38 | 39 | args.save_path = os.path.join(CHECKPOINT_ROOT, exp_folder) 40 | if not os.path.exists(args.save_path): 41 | os.makedirs(args.save_path) 42 | SAVE_FILE = os.path.join(args.save_path, run_name) 43 | return run_name, SAVE_FILE 44 | 45 | 46 | def get_model_fh(data, model): 47 | if data == 'Digits': 48 | if model in ['digit']: 49 | from nets.models import DigitModel 50 | ModelClass = DigitModel 51 | else: 52 | raise ValueError(f"Invalid model: {model}") 53 | elif data in ['DomainNet']: 54 | if model in ['alex']: 55 | from nets.models import AlexNet 56 | ModelClass = AlexNet 57 | else: 58 | raise ValueError(f"Invalid model: {model}") 59 | elif data == 'Cifar10': 60 | if model in ['preresnet18']: # From heteroFL 61 | from nets.HeteFL.preresne import resnet18 62 | ModelClass = resnet18 63 | else: 64 | raise ValueError(f"Invalid model: {model}") 65 | elif data == 'Cifar100': 66 | if model in ['mobile']: # From heteroFL 67 | from nets.Nets import MobileNetCifar 68 | ModelClass = MobileNetCifar 69 | else: 70 | raise ValueError(f"Invalid model: {model}") 71 | else: 72 | raise ValueError(f"Unknown dataset: {data}") 73 | return ModelClass 74 | 75 | 76 | def mask_fed_test(fed, running_model, train_loaders, val_loaders, global_lr, verbose): 77 | mark = 's' 78 | val_acc_list_bp = [None for _ in range(fed.client_num)] 79 | val_loss_mt_bp = AverageMeter() 80 | 81 | val_acc_list = [None for _ in range(fed.client_num)] 82 | val_loss_mt = AverageMeter() 83 | for client_idx in range(fed.client_num): 84 | fed.download(running_model, client_idx) 85 | val_model = copy.deepcopy(running_model) 86 | # Test 87 | # Loss and accuracy before personalization 88 | val_loss_bp, val_acc_bp = test(val_model, val_loaders[client_idx], loss_fun, device) 89 | 90 | # Log 91 | val_loss_mt_bp.append(val_loss_bp) 92 | val_acc_list_bp[client_idx] = val_acc_bp 93 | if verbose > 0: 94 | print(' {:<19s} Val Before Personalization {:s}Loss: {:.4f} | Val {:s}Acc: {:.4f}'.format( 95 | 'User-'+fed.clients[client_idx], mark.upper(), val_loss_bp, mark.upper(), val_acc_bp)) 96 | wandb.log({ 97 | f"{fed.clients[client_idx]} val_bp_{mark}-acc": val_acc_bp, 98 | }, commit=False) 99 | 100 | if args.test: 101 | 102 | # Personalization 103 | 104 | val_loss, val_acc = personalization(val_model, train_loaders[client_idx], val_loaders[client_idx], 105 | loss_fun, global_lr, device) 106 | 107 | # Log 108 | val_loss_mt.append(val_loss) 109 | val_acc_list[client_idx] = val_acc 110 | if verbose > 0: 111 | print(' {:<19s} Val {:s}Loss: {:.4f} | Val {:s}Acc: {:.4f}'.format( 112 | 'User-'+fed.clients[client_idx], mark.upper(), val_loss, mark.upper(), val_acc)) 113 | wandb.log({ 114 | f"{fed.clients[client_idx]} val_{mark}-acc": val_acc, 115 | }, commit=False) 116 | 117 | if args.test: 118 | 119 | return val_acc_list, val_loss_mt.avg, val_acc_list_bp, val_loss_mt_bp.avg 120 | else: 121 | 122 | return val_acc_list_bp, val_loss_mt_bp.avg, val_acc_list_bp, val_loss_mt_bp.avg 123 | 124 | 125 | if __name__ == '__main__': 126 | 127 | 128 | 129 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 130 | 131 | parser = argparse.ArgumentParser() 132 | # basic problem setting 133 | parser.add_argument('--seed', type=int, default=1, help='random seed') 134 | parser.add_argument('--data', type=str, default='Cifar10', help='data name') # 'DomainNet' 'Cifar100' 135 | parser.add_argument('--model', type=str.lower, default='default', help='model name') 136 | parser.add_argument('--algorithm', type=str, default='fedBABU', help='algorithm name') # 'fedProx' 'fedNOVA' 'fedProx' 137 | parser.add_argument('--mu', type=float, default=0.0, help='The hyper parameter for fedProx algorithm') # 0.1 138 | parser.add_argument('--width_scale', type=float, default=1., help='model width scale') 139 | parser.add_argument('--no_track_stat', action='store_true', help='disable BN tracking') 140 | parser.add_argument('--no_mask_loss', action='store_true', help='disable masked loss for class' 141 | ' niid') 142 | # control 143 | parser.add_argument('--no_log', action='store_true', help='disable wandb log') 144 | parser.add_argument('--test', action='store_true', help='test the pretrained model') 145 | #parser.add_argument('--test', type=str2bool, default=True, help='test the pretrained model') #action='store_true' 146 | parser.add_argument('--resume', action='store_true', help='resume training from checkpoint') 147 | #parser.add_argument('--resume', type=str2bool, default=True, help='resume training from checkpoint') 148 | parser.add_argument('--verbose', type=int, default=0, help='verbose level: 0 or 1') 149 | # federated 150 | Federation.add_argument(parser) 151 | # optimization 152 | parser.add_argument('--lr', type=float, default=1e-1, help='learning rate') #1e-2 1e-1 153 | parser.add_argument('--lr_sch', type=str, default='multi_step', help='learning rate schedule') #'none' 'cos' 154 | parser.add_argument('--opt', type=str.lower, default='sgd', help='optimizer') 155 | parser.add_argument('--iters', type=int, default=80, help='#iterations for communication')#200 156 | parser.add_argument('--wk_iters', type=int, default=4, help='#epochs in local train')#5 1 157 | parser.add_argument('--wk_factor', type=int, default=0.5, help='#decreasing factor of local training epochs for less capable devices') 158 | 159 | args = parser.parse_args() 160 | 161 | set_seed(args.seed) 162 | 163 | 164 | 165 | # ///////////////////////////////// 166 | # ///// Fed Dataset and Model ///// 167 | # ///////////////////////////////// 168 | fed = Federation(args.data, args) 169 | # Data 170 | train_loaders, val_loaders, test_loaders = fed.get_data() 171 | mean_batch_iters = int(np.mean([len(tl) for tl in train_loaders])) 172 | print(f" mean_batch_iters: {mean_batch_iters}") 173 | 174 | # set experiment files, wandb 175 | exp_folder = f'Alg_{args.algorithm}_C{fed.args.pr_nuser}_{args.data}' 176 | run_name, SAVE_FILE = render_run_name(args, exp_folder) 177 | wandb.init(group=run_name[:120], project=exp_folder, 178 | mode='offline' if args.no_log else 'online', 179 | config={**vars(args), 'save_file': SAVE_FILE}) 180 | 181 | # Model 182 | ModelClass = get_model_fh(args.data, args.model) 183 | running_model = ModelClass( 184 | track_running_stats=False, num_classes=fed.num_classes, 185 | width_scale=args.width_scale, 186 | ).to(device) 187 | 188 | 189 | 190 | 191 | # Loss 192 | loss_fun = nn.CrossEntropyLoss() 193 | 194 | 195 | # Use running model to init a fed aggregator 196 | fed.make_aggregator(running_model) 197 | 198 | 199 | 200 | # Last layer as head model 201 | if (args.model == 'alex'): 202 | 203 | head_part = 'fc3' 204 | 205 | else: 206 | 207 | head_part = 'linear' 208 | 209 | # Masking elements for each user 210 | names = [] 211 | paramSize = [] 212 | for name, par in running_model.named_parameters(): 213 | 214 | 215 | 216 | if (args.algorithm == 'fedBABU'): 217 | if head_part not in name: 218 | 219 | names.append(name) 220 | paramSize.append(np.prod(list(par.size()))) 221 | else: 222 | names.append(name) 223 | paramSize.append(np.prod(list(par.size()))) 224 | 225 | 226 | 227 | totalParamNum = sum(paramSize) 228 | users_max_comp = fed.get_user_max_slim_ratios() 229 | 230 | 231 | 232 | users_max_slim_ratio = [1.0] * fed.client_sampler.tot() 233 | 234 | 235 | 236 | computable_body_layers = {userIdx: [] for userIdx in range(fed.client_sampler.tot())} 237 | totParamNum = 0 238 | userCounter = 0 239 | tau_fedNOVA = [1.0] * fed.client_sampler.tot() 240 | aNorm_fedNOVA = [1.0] * fed.client_sampler.tot() 241 | 242 | 243 | # For obtaining the parameters related to FedNova, we have used the equations in https://proceedings.neurips.cc/paper/2020/file/564127c03caab942e503ee6f810f54fd-Paper.pdf 244 | for userIdx in range(fed.client_sampler.tot()): 245 | 246 | if (users_max_comp[userIdx] == 1.0): 247 | tau_fedNOVA[userIdx] = (args.wk_iters) * len(train_loaders[userIdx]) 248 | else: 249 | # Less capable devices perform lower number of local update iterations 250 | tau_fedNOVA[userIdx] = math.ceil(args.wk_iters*args.wk_factor) * len(train_loaders[userIdx]) 251 | 252 | 253 | 254 | aNorm_fedNOVA[userIdx] = (tau_fedNOVA[userIdx] - ( (0.9 * (1. - pow(0.9, tau_fedNOVA[userIdx]))) / (1-0.9)) ) / (1-0.9) # 0.9 is the considered value for the momentum 255 | 256 | computableParamNum = 0 257 | maxComputableLayers = int(users_max_slim_ratio[userIdx] * len(names)) 258 | 259 | namesCopy = names.copy() 260 | 261 | for layerIdx in range(maxComputableLayers): 262 | 263 | selectedLayer = np.random.choice(namesCopy, 1, replace=False) 264 | 265 | namesCopy.remove(selectedLayer) 266 | 267 | selectedLayerParamSize = paramSize[names.index(selectedLayer)] 268 | 269 | 270 | computable_body_layers[userIdx].append(selectedLayer.item()) 271 | computableParamNum += selectedLayerParamSize 272 | 273 | totParamNum += computableParamNum 274 | userCounter += 1 275 | 276 | 277 | 278 | wandb.log({'Num_of_Params': totParamNum/userCounter}, commit=False) 279 | 280 | effTau_fedNOVA = sum([a*b for a,b in zip(aNorm_fedNOVA,fed.client_weights)]) 281 | 282 | 283 | 284 | # ///////////////// 285 | # //// Resume ///// 286 | # ///////////////// 287 | # log the best for each model on all datasets 288 | best_epoch = 0 289 | best_acc = [0. for j in range(fed.client_num)] 290 | train_elapsed = [[] for _ in range(fed.client_num)] 291 | start_epoch = 0 292 | if args.resume or args.test: 293 | if os.path.exists(SAVE_FILE): 294 | print(f'Loading chkpt from {SAVE_FILE}') 295 | checkpoint = torch.load(SAVE_FILE) 296 | best_epoch, best_acc = checkpoint['best_epoch'], checkpoint['best_acc'] 297 | train_elapsed = checkpoint['train_elapsed'] 298 | train_dataset = checkpoint['train_dataset'] 299 | global_lr = checkpoint['lr'] 300 | start_epoch = int(checkpoint['a_iter']) + 1 301 | fed.model_accum.load_state_dict(checkpoint['server_model']) 302 | 303 | print('Resume training from epoch {} with best acc:'.format(start_epoch)) 304 | for client_idx, acc in enumerate(best_acc): 305 | print(' Best user-{:<10s}| Epoch:{} | Val Acc: {:.4f}'.format( 306 | fed.clients[client_idx], best_epoch, acc)) 307 | else: 308 | if args.test: 309 | raise FileNotFoundError(f"Not found checkpoint at {SAVE_FILE}") 310 | else: 311 | print(f"Not found checkpoint at {SAVE_FILE}\n **Continue without resume.**") 312 | 313 | 314 | # /////////////// 315 | # //// Test ///// 316 | # /////////////// 317 | if args.test: 318 | wandb.summary[f'best_epoch'] = best_epoch 319 | 320 | # Set up model with specified width 321 | print(f" Test model: {args.model}x{args.width_scale}") 322 | 323 | # Test on clients 324 | 325 | 326 | test_acc_list, _, test_acc_list_bp, _ = mask_fed_test(fed, running_model, train_dataset, test_loaders, 327 | global_lr, args.verbose) 328 | 329 | 330 | print(f"\n Average Test Acc Before Personalization: {np.mean(test_acc_list_bp)}") 331 | wandb.summary[f'avg test acc bp'] = np.mean(test_acc_list_bp) 332 | print(f"\n Average Test Acc: {np.mean(test_acc_list)}") 333 | wandb.summary[f'avg test acc'] = np.mean(test_acc_list) 334 | wandb.finish() 335 | 336 | exit(0) 337 | 338 | 339 | # //////////////// 340 | # //// Train ///// 341 | # //////////////// 342 | # LR scheduler 343 | if args.lr_sch == 'cos': 344 | lr_sch = CosineAnnealingLR(args.iters, eta_max=args.lr, last_epoch=start_epoch) 345 | elif args.lr_sch == 'multi_step': 346 | lr_sch = MultiStepLR(args.lr, milestones=[args.iters//2, (args.iters * 3)//4], gamma=0.1, last_epoch=start_epoch) 347 | else: 348 | assert args.lr_sch == 'none', f'Invalid lr_sch: {args.lr_sch}' 349 | lr_sch = None 350 | for a_iter in range(start_epoch, args.iters): 351 | # set global lr 352 | global_lr = args.lr if lr_sch is None else lr_sch.step() 353 | wandb.log({'global lr': global_lr}, commit=False) 354 | 355 | 356 | # ----------- Train Client --------------- 357 | train_loss_mt, train_acc_mt = AverageMeter(), AverageMeter() 358 | print("============ Train epoch {} ============".format(a_iter)) 359 | selectedUsers = [] 360 | for client_idx in fed.client_sampler.iter(): 361 | selectedUsers.append(client_idx) 362 | start_time = time.process_time() 363 | 364 | # Download global model to local 365 | fed.download(running_model, client_idx) 366 | 367 | 368 | 369 | if (users_max_comp[client_idx] == 1.0): 370 | local_iter = args.wk_iters 371 | else: 372 | local_iter = math.ceil(args.wk_iters*args.wk_factor) 373 | 374 | 375 | if (args.algorithm == 'fedNOVA'): 376 | 377 | local_lr = (effTau_fedNOVA/aNorm_fedNOVA[client_idx]) * global_lr 378 | 379 | else: 380 | 381 | local_lr = global_lr 382 | 383 | 384 | 385 | 386 | if (args.algorithm == 'fedBABU'): 387 | optim_input = [] 388 | 389 | for name, par in running_model.named_parameters(): 390 | 391 | if name in computable_body_layers[client_idx]: 392 | 393 | par.requires_grad = True 394 | optim_input.append({'params': par, 'lr': local_lr}) 395 | 396 | else: 397 | 398 | par.requires_grad = False 399 | par.requires_grad_(False) 400 | optim_input.append({'params': par, 'lr': 0.0}) 401 | 402 | 403 | if args.opt == 'sgd': 404 | 405 | if (args.algorithm == 'fedBABU'): 406 | optimizer = optim.SGD(optim_input, momentum=0.9, weight_decay=5e-4) 407 | 408 | else: 409 | 410 | optimizer = optim.SGD(params=running_model.parameters(), lr=local_lr, 411 | momentum=0.9, weight_decay=5e-4) 412 | 413 | else: 414 | raise ValueError(f"Invalid optimizer: {args.opt}") 415 | 416 | 417 | if ((args.algorithm == 'fedProx') and (a_iter > 0)): 418 | 419 | train_loss, train_acc = train_fedprox(args.mu, 420 | running_model, train_loaders[client_idx], optimizer, loss_fun, device, 421 | max_iter=mean_batch_iters * local_iter if args.partition_mode != 'uni' 422 | else len(train_loaders[client_idx]) * local_iter, 423 | progress=args.verbose > 0 424 | ) 425 | 426 | else: 427 | 428 | train_loss, train_acc = train( 429 | running_model, train_loaders[client_idx], optimizer, loss_fun, device, 430 | max_iter=mean_batch_iters * local_iter if args.partition_mode != 'uni' 431 | else len(train_loaders[client_idx]) * local_iter, 432 | progress=args.verbose > 0, 433 | ) 434 | 435 | # Upload 436 | fed.mask_upload(running_model, client_idx, computable_body_layers[client_idx]) 437 | 438 | # Log 439 | client_name = fed.clients[client_idx] 440 | elapsed = time.process_time() - start_time 441 | wandb.log({f'{client_name}_train_elapsed': elapsed}, commit=False) 442 | train_elapsed[client_idx].append(elapsed) 443 | 444 | train_loss_mt.append(train_loss), train_acc_mt.append(train_acc) 445 | print(f' User-{client_name:<10s} Train | Loss: {train_loss:.4f} |' 446 | f' Acc: {train_acc:.4f} | Elapsed: {elapsed:.2f} s') 447 | wandb.log({ 448 | f"{client_name} train_loss": train_loss, 449 | f"{client_name} train_acc": train_acc, 450 | }, commit=False) 451 | 452 | # Use accumulated model to update server model 453 | fed.aggregate() 454 | 455 | 456 | # ----------- Validation --------------- 457 | val_acc_list, val_loss, val_acc_list_bp, val_loss_bp = mask_fed_test(fed, running_model, train_loaders, 458 | val_loaders, global_lr, args.verbose) 459 | 460 | # Log averaged 461 | print(f' [Overall] Train Loss {train_loss_mt.avg:.4f} Acc {train_acc_mt.avg*100:.1f}%' 462 | f' | Val Acc bp {np.mean(val_acc_list_bp) * 100:.2f}%' 463 | f' | Val Acc {np.mean(val_acc_list) * 100:.2f}%') 464 | wandb.log({ 465 | f"train_loss": train_loss_mt.avg, 466 | f"train_acc": train_acc_mt.avg, 467 | f"val_loss_bp": val_loss_bp, 468 | f"val_acc_bp": np.mean(val_acc_list_bp), 469 | f"val_loss": val_loss, 470 | f"val_acc": np.mean(val_acc_list), 471 | }, commit=False) 472 | 473 | # ----------- Save checkpoint ----------- 474 | if np.mean(val_acc_list) > np.mean(best_acc): 475 | best_epoch = a_iter 476 | for client_idx in range(fed.client_num): 477 | best_acc[client_idx] = val_acc_list[client_idx] 478 | if args.verbose > 0: 479 | print(' Best site-{:<10s}| Epoch:{} | Val Acc: {:.4f}'.format( 480 | fed.clients[client_idx], best_epoch, best_acc[client_idx])) 481 | print(' [Best Val] Acc {:.4f}'.format(np.mean(val_acc_list))) 482 | 483 | # Save 484 | print(f' Saving the local and server checkpoint to {SAVE_FILE}') 485 | save_dict = { 486 | 'server_model': fed.model_accum.state_dict(), 487 | 'train_dataset': train_loaders, 488 | 'lr' : global_lr, 489 | 'best_epoch': best_epoch, 490 | 'best_acc': best_acc, 491 | 'a_iter': a_iter, 492 | 'all_domains': fed.all_domains, 493 | 'train_elapsed': train_elapsed, 494 | } 495 | torch.save(save_dict, SAVE_FILE) 496 | wandb.log({ 497 | f"best_val_acc": np.mean(best_acc), 498 | }, commit=True) 499 | -------------------------------------------------------------------------------- /utils/data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision.transforms as transforms 4 | from torch.utils.data import Subset 5 | from torch.utils.data import DataLoader 6 | 7 | from utils.data_utils import DomainNetDataset, DigitsDataset, Partitioner, \ 8 | CifarDataset, Cifar100Dataset, ClassWisePartitioner, extract_labels 9 | 10 | 11 | def compose_transforms(trns, image_norm): 12 | if image_norm == '0.5': 13 | return transforms.Compose(trns + [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 14 | elif image_norm == 'torch': 15 | return transforms.Compose(trns + [transforms.Normalize((0.4914, 0.4822, 0.4465), 16 | (0.2023, 0.1994, 0.2010))]) 17 | elif image_norm == 'torch-resnet': 18 | return transforms.Compose(trns + [transforms.Normalize(mean=[0.485, 0.456, 0.406], 19 | std=[0.229, 0.224, 0.225])]) 20 | elif image_norm == 'none': 21 | return transforms.Compose(trns) 22 | else: 23 | raise ValueError(f"Invalid image_norm: {image_norm}") 24 | 25 | 26 | def get_central_data(name: str, domains: list, percent=1., image_norm='none', 27 | disable_image_norm_error=False): 28 | if image_norm != 'none' and not disable_image_norm_error: 29 | raise RuntimeError(f"This is a hard warning. Use image_norm != none will make the PGD" 30 | f" attack invalid since PGD will clip the image into [0,1] range. " 31 | f"Think before you choose {image_norm} image_norm.") 32 | if percent != 1. and name.lower() != 'digits': 33 | raise RuntimeError(f"percent={percent} should not be used in get_central_data." 34 | f" Pass it to make_fed_data instead.") 35 | if name.lower() == 'digits': 36 | if image_norm == 'default': 37 | image_norm = '0.5' 38 | for domain in domains: 39 | if domain not in DigitsDataset.all_domains: 40 | raise ValueError(f"Invalid domain: {domain}") 41 | # Prepare data 42 | trns = { 43 | 'MNIST': [ 44 | transforms.Grayscale(num_output_channels=3), 45 | transforms.ToTensor(), 46 | ], 47 | 'SVHN': [ 48 | transforms.Resize([28,28]), 49 | transforms.ToTensor(), 50 | ], 51 | 'USPS': [ 52 | transforms.Resize([28,28]), 53 | transforms.Grayscale(num_output_channels=3), 54 | transforms.ToTensor(), 55 | ], 56 | 'SynthDigits': [ 57 | transforms.Resize([28,28]), 58 | transforms.ToTensor(), 59 | ], 60 | 'MNIST_M': [ 61 | transforms.ToTensor(), 62 | ], 63 | } 64 | 65 | train_sets = [DigitsDataset(domain, 66 | percent=percent, train=True, 67 | transform=compose_transforms(trns[domain], image_norm)) 68 | for domain in domains] 69 | test_sets = [DigitsDataset(domain, 70 | train=False, 71 | transform=compose_transforms(trns[domain], image_norm)) 72 | for domain in domains] 73 | elif name.lower() in ('domainnet', 'domainnetf'): 74 | transform_train = transforms.Compose([ 75 | transforms.Resize([256, 256]), 76 | transforms.RandomHorizontalFlip(), 77 | transforms.RandomRotation((-30, 30)), 78 | transforms.ToTensor(), 79 | ]) 80 | 81 | transform_test = transforms.Compose([ 82 | transforms.Resize([256, 256]), 83 | transforms.ToTensor(), 84 | ]) 85 | 86 | train_sets = [ 87 | DomainNetDataset(domain, transform=transform_train, 88 | full_set=name.lower()=='domainnetf') 89 | for domain in domains 90 | ] 91 | test_sets = [ 92 | DomainNetDataset(domain, transform=transform_test, train=False, 93 | full_set=name.lower()=='domainnetf') 94 | for domain in domains 95 | ] 96 | elif name.lower() == 'cifar10': 97 | if image_norm == 'default': 98 | image_norm = 'torch' 99 | for domain in domains: 100 | if domain not in CifarDataset.all_domains: 101 | raise ValueError(f"Invalid domain: {domain}") 102 | trn_train = [transforms.RandomCrop(32, padding=4), 103 | transforms.RandomHorizontalFlip(), 104 | transforms.ToTensor()] 105 | trn_test = [transforms.ToTensor()] 106 | 107 | train_sets = [CifarDataset(domain, train=True, 108 | transform=compose_transforms(trn_train, image_norm)) 109 | for domain in domains] 110 | test_sets = [CifarDataset(domain, train=False, 111 | transform=compose_transforms(trn_test, image_norm)) 112 | for domain in domains] 113 | 114 | elif name.lower() == 'cifar100': 115 | if image_norm == 'default': 116 | image_norm = 'torch' 117 | for domain in domains: 118 | if domain not in Cifar100Dataset.all_domains: 119 | raise ValueError(f"Invalid domain: {domain}") 120 | trn_train = transforms.Compose([transforms.RandomCrop(32, padding=4), 121 | transforms.RandomHorizontalFlip(), 122 | transforms.ToTensor(), 123 | transforms.Normalize(mean=[0.507, 0.487, 0.441], 124 | std=[0.267, 0.256, 0.276])]) 125 | 126 | 127 | trn_test = transforms.Compose([transforms.ToTensor(), 128 | transforms.Normalize(mean=[0.507, 0.487, 0.441], 129 | std=[0.267, 0.256, 0.276])]) 130 | 131 | 132 | train_sets = [Cifar100Dataset(domain, train=True, transform=trn_train) 133 | for domain in domains] 134 | test_sets = [Cifar100Dataset(domain, train=False,transform=trn_test) 135 | for domain in domains] 136 | 137 | else: 138 | raise NotImplementedError(f"name: {name}") 139 | return train_sets, test_sets 140 | 141 | 142 | def make_fed_data(train_sets, test_sets, batch_size, domains, shuffle_eval=False, 143 | n_user_per_domain=1, partition_seed=42, partition_mode='uni', 144 | n_class_per_user=-1, val_ratio=0.2, 145 | eq_domain_train_size=True, percent=1., 146 | num_workers=0, pin_memory=False, min_n_sample_per_share=128, 147 | subset_with_logits=False, 148 | test_batch_size=None, shuffle=True, 149 | consistent_test_class=False): 150 | """Distribute multi-domain datasets (`train_sets`) into federated clients. 151 | 152 | Args: 153 | train_sets (list): A list of datasets for training. 154 | test_sets (list): A list of datasets for testing. 155 | partition_seed (int): Seed for partitioning data into clients. 156 | consistent_test_class (bool): Ensure the test classes are the same training for a client. 157 | Meanwhile, make test sets are uniformly splitted for clients. 158 | """ 159 | test_batch_size = batch_size if test_batch_size is None else test_batch_size 160 | SubsetClass = SubsetWithLogits if subset_with_logits else Subset 161 | clients = [f'{i}' for i in range(len(domains))] 162 | 163 | print(f" train size: {[len(s) for s in train_sets]}") 164 | print(f" test size: {[len(s) for s in test_sets]}") 165 | 166 | train_len = [len(s) for s in train_sets] 167 | if eq_domain_train_size: 168 | train_len = [min(train_len)] * len(train_sets) 169 | # assert all([len(s) == train_len[0] for s in train_sets]), f"Should be equal length." 170 | 171 | if percent < 1: 172 | train_len = [int(tl * percent) for tl in train_len] 173 | 174 | print(f" trimmed train size: {[tl for tl in train_len]}") 175 | 176 | if n_user_per_domain > 1: # split data into multiple users 177 | if n_class_per_user > 0: # split by class-wise non-iid 178 | split = ClassWisePartitioner(rng=np.random.RandomState(partition_seed), 179 | n_class_per_share=n_class_per_user, 180 | min_n_sample_per_share=min_n_sample_per_share, 181 | partition_mode=partition_mode, 182 | verbose=True) 183 | splitted_clients = [] 184 | val_sets, sub_train_sets, user_ids_by_class = [], [], [] 185 | for i_client, (dname, tr_set) in enumerate(zip(clients, train_sets)): 186 | _tr_labels = extract_labels(tr_set) # labels in the original order 187 | _tr_labels = _tr_labels[:train_len[i_client]] # trim 188 | _idx_by_user, _user_ids_by_cls = split(_tr_labels, n_user_per_domain, 189 | return_user_ids_by_class=True) 190 | print(f" {dname} | train split size: {[len(idxs) for idxs in _idx_by_user]}") 191 | _tr_labels = np.array(_tr_labels) 192 | print(f" | train classes: " 193 | f"{[f'{np.unique(_tr_labels[idxs]).tolist()}' for idxs in _idx_by_user]}") 194 | 195 | for i_user, idxs in zip(range(n_user_per_domain), _idx_by_user): 196 | vl = int(val_ratio * len(idxs)) 197 | 198 | np.random.shuffle(idxs) 199 | sub_train_sets.append(SubsetClass(tr_set, idxs[vl:])) 200 | 201 | #np.random.shuffle(idxs) 202 | val_sets.append(Subset(tr_set, idxs[:vl])) 203 | 204 | splitted_clients.append(f"{dname}-{i_user}") 205 | user_ids_by_class.append(_user_ids_by_cls if consistent_test_class else None) 206 | 207 | if consistent_test_class: 208 | # recreate partitioner to make sure consistent class distribution. 209 | split = ClassWisePartitioner(rng=np.random.RandomState(partition_seed), 210 | n_class_per_share=n_class_per_user, 211 | min_n_sample_per_share=min_n_sample_per_share, 212 | partition_mode='uni', 213 | verbose=True) 214 | sub_test_sets = [] 215 | for i_client, te_set in enumerate(test_sets): 216 | _te_labels = extract_labels(te_set) 217 | _idx_by_user = split(_te_labels, n_user_per_domain, 218 | user_ids_by_class=user_ids_by_class[i_client]) 219 | print(f" test split size: {[len(idxs) for idxs in _idx_by_user]}") 220 | _te_labels = np.array(_te_labels) 221 | print(f" test classes: " 222 | f"{[f'{np.unique(_te_labels[idxs]).tolist()}' for idxs in _idx_by_user]}") 223 | 224 | for idxs in _idx_by_user: 225 | np.random.shuffle(idxs) 226 | sub_test_sets.append(Subset(te_set, idxs)) 227 | else: # class iid 228 | split = Partitioner(rng=np.random.RandomState(partition_seed), 229 | min_n_sample_per_share=min_n_sample_per_share, 230 | partition_mode=partition_mode) 231 | splitted_clients = [] 232 | 233 | val_sets, sub_train_sets = [], [] 234 | for i_client, (dname, tr_set) in enumerate(zip(clients, train_sets)): 235 | _train_len_by_user = split(train_len[i_client], n_user_per_domain) 236 | print(f" {dname} | train split size: {_train_len_by_user}") 237 | 238 | base_idx = 0 239 | for i_user, tl in zip(range(n_user_per_domain), _train_len_by_user): 240 | vl = int(val_ratio * tl) 241 | tl = tl - vl 242 | 243 | sub_train_sets.append(SubsetClass(tr_set, list(range(base_idx, base_idx + tl)))) 244 | base_idx += tl 245 | 246 | val_sets.append(Subset(tr_set, list(range(base_idx, base_idx + vl)))) 247 | base_idx += vl 248 | 249 | splitted_clients.append(f"{dname}-{i_user}") 250 | 251 | # uniformly distribute test sets 252 | if consistent_test_class: 253 | split = Partitioner(rng=np.random.RandomState(partition_seed), 254 | min_n_sample_per_share=min_n_sample_per_share, 255 | partition_mode='uni') 256 | sub_test_sets = [] 257 | for te_set in test_sets: 258 | _test_len_by_user = split(len(te_set), n_user_per_domain) 259 | 260 | base_idx = 0 261 | for tl in _test_len_by_user: 262 | sub_test_sets.append(Subset(te_set, list(range(base_idx, base_idx + tl)))) 263 | base_idx += tl 264 | 265 | # rename 266 | train_sets = sub_train_sets 267 | test_sets = sub_test_sets 268 | clients = splitted_clients 269 | else: # single user 270 | assert n_class_per_user <= 0, "Cannot split in Non-IID way when only one user for one " \ 271 | f"domain. But got n_class_per_user={n_class_per_user}" 272 | val_len = [int(tl * val_ratio) for tl in train_len] 273 | 274 | val_sets = [Subset(tr_set, list(range(train_len[i_client]-val_len[i_client], 275 | train_len[i_client]))) 276 | for i_client, tr_set in enumerate(train_sets)] 277 | train_sets = [Subset(tr_set, list(range(train_len[i_client]-val_len[i_client]))) 278 | for i_client, tr_set in enumerate(train_sets)] 279 | 280 | # check the real sizes 281 | print(f" split users' train size: {[len(ts) for ts in train_sets]}") 282 | print(f" split users' val size: {[len(ts) for ts in val_sets]}") 283 | print(f" split users' test size: {[len(ts) for ts in test_sets]}") 284 | if val_ratio > 0: 285 | for i_ts, ts in enumerate(val_sets): 286 | if len(ts) <= 0: 287 | raise RuntimeError(f"user-{i_ts} not has enough val data.") 288 | 289 | train_loaders = [DataLoader(tr_set, batch_size=batch_size, shuffle=shuffle, 290 | num_workers=num_workers, pin_memory=pin_memory, 291 | drop_last=partition_mode != 'uni') for tr_set in train_sets] 292 | test_loaders = [DataLoader(te_set, batch_size=test_batch_size, shuffle=shuffle_eval, 293 | num_workers=num_workers, pin_memory=pin_memory) 294 | for te_set in test_sets] 295 | if val_ratio > 0: 296 | val_loaders = [DataLoader(va_set, batch_size=batch_size, shuffle=shuffle_eval, 297 | num_workers=num_workers, pin_memory=pin_memory) 298 | for va_set in val_sets] 299 | else: 300 | val_loaders = test_loaders 301 | 302 | return train_loaders, val_loaders, test_loaders, clients 303 | 304 | def prepare_domainnet_data(args, domains=['clipart', 'quickdraw'], shuffle_eval=False, 305 | n_class_per_user=-1, n_user_per_domain=1, 306 | partition_seed=42, partition_mode='uni', 307 | val_ratio=0., eq_domain_train_size=True, 308 | subset_with_logits=False, consistent_test_class=False, 309 | ): 310 | assert args.data.lower() in ['domainnet', 'domainnetf'] 311 | train_sets, test_sets = get_central_data(args.data.lower(), domains) 312 | 313 | train_loaders, val_loaders, test_loaders, clients = make_fed_data( 314 | train_sets, test_sets, args.batch, domains, shuffle_eval=shuffle_eval, 315 | partition_seed=partition_seed, n_user_per_domain=n_user_per_domain, 316 | partition_mode=partition_mode, 317 | val_ratio=val_ratio, eq_domain_train_size=eq_domain_train_size, percent=args.percent, 318 | min_n_sample_per_share=16, subset_with_logits=subset_with_logits, 319 | n_class_per_user=n_class_per_user, 320 | test_batch_size=args.test_batch if hasattr(args, 'test_batch') else args.batch, 321 | num_workers=8 if args.data.lower() == 'domainnetf' else 0, 322 | pin_memory=False if args.data.lower() == 'domainnetf' else True, 323 | consistent_test_class=consistent_test_class, 324 | ) 325 | return train_loaders, val_loaders, test_loaders, clients 326 | 327 | 328 | def prepare_digits_data(args, domains=['MNIST', 'SVHN'], shuffle_eval=False, n_class_per_user=-1, 329 | n_user_per_domain=1, partition_seed=42, partition_mode='uni', val_ratio=0.2, 330 | eq_domain_train_size=True, subset_with_logits=False, 331 | consistent_test_class=False, 332 | ): 333 | do_adv_train = hasattr(args, 'noise') and (args.noise == 'none' or args.noise_ratio == 0 334 | or args.n_noise_domain == 0) 335 | # NOTE we use the image_norm=0.5 for reproducing clean training results. 336 | # but for adv training, we do not use image_norm 337 | train_sets, test_sets = get_central_data( 338 | args.data, domains, percent=args.percent, image_norm='0.5' if do_adv_train else 'none', 339 | disable_image_norm_error=True) 340 | train_loaders, val_loaders, test_loaders, clients = make_fed_data( 341 | train_sets, test_sets, args.batch, domains, shuffle_eval=shuffle_eval, 342 | partition_seed=partition_seed, n_user_per_domain=n_user_per_domain, 343 | partition_mode=partition_mode, 344 | val_ratio=val_ratio, eq_domain_train_size=eq_domain_train_size, 345 | min_n_sample_per_share=16, n_class_per_user=n_class_per_user, 346 | subset_with_logits=subset_with_logits, 347 | test_batch_size=args.test_batch if hasattr(args, 'test_batch') else args.batch, 348 | consistent_test_class=consistent_test_class, 349 | ) 350 | return train_loaders, val_loaders, test_loaders, clients 351 | 352 | 353 | def prepare_cifar_data(args, domains=['cifar10'], shuffle_eval=False, n_class_per_user=-1, 354 | n_user_per_domain=1, partition_seed=42, partition_mode='uni', val_ratio=0.2, 355 | eq_domain_train_size=True, subset_with_logits=False, 356 | consistent_test_class=False, 357 | ): 358 | train_sets, test_sets = get_central_data('cifar10', domains) 359 | 360 | train_loaders, val_loaders, test_loaders, clients = make_fed_data( 361 | train_sets, test_sets, args.batch, domains, shuffle_eval=shuffle_eval, 362 | partition_seed=partition_seed, n_user_per_domain=n_user_per_domain, 363 | partition_mode=partition_mode, 364 | val_ratio=val_ratio, eq_domain_train_size=eq_domain_train_size, percent=args.percent, 365 | min_n_sample_per_share=64 if n_class_per_user > 3 else 16, subset_with_logits=subset_with_logits, 366 | n_class_per_user=n_class_per_user, 367 | test_batch_size=args.test_batch if hasattr(args, 'test_batch') else args.batch, 368 | consistent_test_class=consistent_test_class, 369 | ) 370 | return train_loaders, val_loaders, test_loaders, clients 371 | 372 | def prepare_cifar100_data(args, domains=['cifar100'], shuffle_eval=False, n_class_per_user=-1, 373 | n_user_per_domain=1, partition_seed=42, partition_mode='uni', val_ratio=0.2, 374 | eq_domain_train_size=True, subset_with_logits=False, 375 | consistent_test_class=False, 376 | ): 377 | train_sets, test_sets = get_central_data('cifar100', domains) 378 | 379 | train_loaders, val_loaders, test_loaders, clients = make_fed_data( 380 | train_sets, test_sets, args.batch, domains, shuffle_eval=shuffle_eval, 381 | partition_seed=partition_seed, n_user_per_domain=n_user_per_domain, 382 | partition_mode=partition_mode, 383 | val_ratio=val_ratio, eq_domain_train_size=eq_domain_train_size, percent=args.percent, 384 | min_n_sample_per_share=0 if n_class_per_user > 3 else 16, subset_with_logits=subset_with_logits, 385 | n_class_per_user=n_class_per_user, 386 | test_batch_size=args.test_batch if hasattr(args, 'test_batch') else args.batch, 387 | consistent_test_class=consistent_test_class, 388 | ) 389 | return train_loaders, val_loaders, test_loaders, clients 390 | 391 | 392 | class SubsetWithLogits(Subset): 393 | r""" 394 | Subset of a dataset at specified indices. 395 | 396 | Arguments: 397 | dataset (Dataset): The whole Dataset 398 | indices (sequence): Indices in the whole set selected for subset 399 | """ 400 | def __init__(self, dataset, indices) -> None: 401 | super(SubsetWithLogits, self).__init__(dataset, indices) 402 | self.logits = [0. for _ in range(len(indices))] 403 | 404 | def __getitem__(self, idx): 405 | dataset_subset = self.dataset[self.indices[idx]] 406 | if isinstance(dataset_subset, tuple): 407 | return (*dataset_subset, self.logits[idx]) 408 | else: 409 | return dataset_subset, self.logits[idx] 410 | 411 | def update_logits(self, idx, logit): 412 | self.logits[idx] = logit 413 | 414 | 415 | if __name__ == '__main__': 416 | data = 'cifar10' 417 | if data == 'digits': 418 | train_loaders, val_loaders, test_loaders, clients = prepare_digits_data( 419 | type('MockClass', (object,), {'percent': 1.0, 'batch': 32}), domains=['MNIST'], 420 | n_user_per_domain=5, 421 | partition_seed=1, 422 | partition_mode='uni', 423 | val_ratio=0.2, 424 | ) 425 | for batch in train_loaders[0]: 426 | data, target = batch 427 | print(target) 428 | break 429 | elif data == 'cifar10': 430 | train_loaders, val_loaders, test_loaders, clients = prepare_cifar_data( 431 | type('MockClass', (object,), {'batch': 32, 'percent': 0.1}), domains=['cifar10'], 432 | n_user_per_domain=5, 433 | partition_seed=1, 434 | partition_mode='uni', 435 | val_ratio=0.2, 436 | subset_with_logits=True 437 | ) 438 | for batch in train_loaders[0]: 439 | smp_idxs, data, target, t_logits = batch 440 | print(smp_idxs) 441 | break 442 | 443 | temp_loader = DataLoader(train_loaders[0].dataset, batch_size=32, shuffle=False) 444 | all_logits = [] 445 | for batch in temp_loader: 446 | # FIXME need to modify SubsetWithLogits to return the index 447 | smp_idxs, data, target, t_logits = batch 448 | all_logits.append(torch.rand((len(data), 10))) 449 | # for i in smp_idxs 450 | # print(smp_idxs) 451 | all_logits = torch.cat(all_logits, dim=0) 452 | assert isinstance(train_loaders[0].dataset, SubsetWithLogits) 453 | train_loaders[0].dataset.logits = all_logits 454 | 455 | for batch in train_loaders[0]: 456 | smp_idxs, data, target, t_logits = batch 457 | print("t_logits shape", t_logits.shape) 458 | break 459 | --------------------------------------------------------------------------------