├── LICENSE
├── README.md
├── broadcast.py
├── fedselect.py
├── lottery_ticket.py
├── pflopt
│   └── optimizers.py
└── utils
    ├── options.py
    ├── sampling.py
    └── train_utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Lapis Labs
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [CVPR 2024] Official Repository for "FedSelect: Personalized Federated Learning with Customized Selection of Parameters for Fine-Tuning"
2 | 
3 | 
11 | 12 | by [Rishub Tamirisa](https://rishub-tamirisa.github.io/research/), [Chulin Xie](https://alphapav.github.io/), [Wenxuan Bao](https://baowenxuan.github.io/), [Andy Zhou](https://www.andyzhou.ai/), [Ron Arel](https://arel.ai/), and [Aviv Shamsian](https://avivsham.github.io/) 13 | 14 | See our paper on [arXiv](https://arxiv.org/abs/2404.02478). 15 | 16 | 17 | ## Citation 18 | 19 | ``` 20 | @misc{tamirisa2024fedselectpersonalizedfederatedlearning, 21 | title={FedSelect: Personalized Federated Learning with Customized Selection of Parameters for Fine-Tuning}, 22 | author={Rishub Tamirisa and Chulin Xie and Wenxuan Bao and Andy Zhou and Ron Arel and Aviv Shamsian}, 23 | year={2024}, 24 | eprint={2404.02478}, 25 | archivePrefix={arXiv}, 26 | primaryClass={cs.LG}, 27 | url={https://arxiv.org/abs/2404.02478}, 28 | } 29 | ``` 30 | -------------------------------------------------------------------------------- /broadcast.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import OrderedDict 3 | 4 | 5 | def broadcast_server_to_client_initialization( 6 | server_weights: OrderedDict[str, torch.Tensor], 7 | mask: OrderedDict[str, torch.Tensor], 8 | client_initialization: OrderedDict[str, torch.Tensor], 9 | ) -> OrderedDict[str, torch.Tensor]: 10 | """Broadcasts server weights to client initialization for non-masked parameters. 11 | 12 | Args: 13 | server_weights: Server model state dict 14 | mask: Binary mask indicating which parameters are local (1) vs global (0) 15 | client_initialization: Client model state dict to update 16 | 17 | Returns: 18 | Updated client model state dict with server weights broadcast to non-masked parameters 19 | """ 20 | for key in client_initialization.keys(): 21 | # only override client_initialization where mask is non-zero 22 | if "weight" in key or "bias" in key: 23 | client_initialization[key][mask[key] == 0] = server_weights[key][ 24 | mask[key] == 0 25 | ] 26 | return client_initialization 27 | 28 | 29 | def div_server_weights( 30 | server_weights: OrderedDict[str, torch.Tensor], 31 | server_mask: OrderedDict[str, torch.Tensor], 32 | ) -> OrderedDict[str, torch.Tensor]: 33 | """Divides server weights by mask values where mask is non-zero. 34 | 35 | Args: 36 | server_weights: Server model state dict 37 | server_mask: Mask indicating number of contributions to each parameter 38 | 39 | Returns: 40 | Server weights normalized by number of contributions 41 | """ 42 | for key in server_weights.keys(): 43 | # only divide where server_mask is non-zero 44 | if "weight" in key or "bias" in key: 45 | server_weights[key][server_mask[key] != 0] /= server_mask[key][ 46 | server_mask[key] != 0 47 | ] 48 | return server_weights 49 | 50 | 51 | def add_masks( 52 | server_dict: OrderedDict[str, torch.Tensor], 53 | client_dict: OrderedDict[str, torch.Tensor], 54 | invert: bool = True, 55 | ) -> OrderedDict[str, torch.Tensor]: 56 | """Accumulates client masks into server mask dictionary. 
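    Example (a minimal sketch with two hypothetical single-layer client masks; with the
    default ``invert=True`` the accumulator counts, per entry, how many clients marked
    that entry as global, i.e. mask value 0):

        >>> server = OrderedDict()
        >>> mask_a = OrderedDict({"fc.weight": torch.tensor([1.0, 0.0, 0.0])})
        >>> mask_b = OrderedDict({"fc.weight": torch.tensor([1.0, 1.0, 0.0])})
        >>> server = add_masks(server, mask_a)  # adds 1 - mask_a
        >>> server = add_masks(server, mask_b)  # adds 1 - mask_b
        >>> server["fc.weight"]
        tensor([0., 1., 2.])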
57 | 58 | Args: 59 | server_dict: Server mask accumulator 60 | client_dict: Client mask to add 61 | invert: Whether to invert client mask before adding 62 | 63 | Returns: 64 | Updated server mask accumulator 65 | """ 66 | for key in client_dict.keys(): 67 | if "weight" in key or "bias" in key: 68 | if key not in server_dict.keys(): 69 | server_dict[key] = 1 - client_dict[key] if invert else client_dict[key] 70 | else: 71 | server_dict[key] += ( 72 | (1 - client_dict[key]) if invert else client_dict[key] 73 | ) 74 | return server_dict 75 | 76 | 77 | def add_server_weights( 78 | server_weights: OrderedDict[str, torch.Tensor], 79 | client_weights: OrderedDict[str, torch.Tensor], 80 | client_mask: OrderedDict[str, torch.Tensor], 81 | invert: bool = True, 82 | ) -> OrderedDict[str, torch.Tensor]: 83 | """Accumulates masked client weights into server weights. 84 | 85 | Args: 86 | server_weights: Server weights accumulator 87 | client_weights: Client model weights to add 88 | client_mask: Binary mask indicating which parameters to add 89 | invert: Whether to invert mask before applying 90 | 91 | Returns: 92 | Updated server weights accumulator 93 | """ 94 | for key in client_weights.keys(): 95 | if "weight" in key or "bias" in key: 96 | mask = 1 - client_mask[key] if invert else client_mask[key] 97 | if key not in server_weights.keys(): 98 | server_weights[key] = client_weights[key] * mask 99 | else: 100 | server_weights[key] += client_weights[key] * mask 101 | return server_weights 102 | -------------------------------------------------------------------------------- /fedselect.py: -------------------------------------------------------------------------------- 1 | # Importing Libraries 2 | import copy 3 | import numpy as np 4 | from tqdm import tqdm 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Dict, List, OrderedDict, Tuple, Optional, Any 9 | 10 | # Custom Libraries 11 | from utils.options import lth_args_parser 12 | from utils.train_utils import prepare_dataloaders, get_data 13 | from pflopt.optimizers import MaskLocalAltSGD, local_alt 14 | from lottery_ticket import init_mask_zeros, delta_update 15 | from broadcast import ( 16 | broadcast_server_to_client_initialization, 17 | div_server_weights, 18 | add_masks, 19 | add_server_weights, 20 | ) 21 | import random 22 | from torchvision.models import resnet18 23 | 24 | 25 | def evaluate( 26 | model: nn.Module, ldr_test: torch.utils.data.DataLoader, args: Any 27 | ) -> float: 28 | """Evaluate model accuracy on test data loader. 
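    Example (sketch; ``ldr_test`` is the test loader returned by
    utils.train_utils.prepare_dataloaders and ``model`` already holds the weights
    being scored):

        >>> acc = evaluate(model, ldr_test, args)  # mean per-batch top-1 accuracy in [0, 1]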
29 | 30 | Args: 31 | model: Neural network model to evaluate 32 | ldr_test: Test data loader 33 | args: Arguments containing device info 34 | 35 | Returns: 36 | float: Average accuracy on test set 37 | """ 38 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 39 | average_accuracy = 0 40 | model.eval() 41 | with torch.no_grad(): 42 | for batch_idx, (data, target) in enumerate(ldr_test): 43 | data, target = data.to(device), target.to(device) 44 | output = model(data) 45 | pred = output.argmax(dim=1, keepdim=True) 46 | acc = pred.eq(target.view_as(pred)).sum().item() / len(data) 47 | average_accuracy += acc 48 | average_accuracy /= len(ldr_test) 49 | return average_accuracy 50 | 51 | 52 | def train_personalized( 53 | model: nn.Module, 54 | ldr_train: torch.utils.data.DataLoader, 55 | mask: OrderedDict, 56 | args: Any, 57 | initialization: Optional[OrderedDict] = None, 58 | verbose: bool = False, 59 | eval: bool = True, 60 | ) -> Tuple[nn.Module, float]: 61 | """Train model with personalized local alternating optimization. 62 | 63 | Args: 64 | model: Neural network model to train 65 | ldr_train: Training data loader 66 | mask: Binary mask for parameters 67 | args: Training arguments 68 | initialization: Optional initial model state 69 | verbose: Whether to print training progress 70 | eval: Whether to evaluate during training 71 | 72 | Returns: 73 | Tuple containing: 74 | - Trained model 75 | - Final training loss 76 | """ 77 | if initialization is not None: 78 | model.load_state_dict(initialization) 79 | optimizer = MaskLocalAltSGD(model.parameters(), mask, lr=args.lr) 80 | epochs = args.la_epochs 81 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 82 | criterion = nn.CrossEntropyLoss() 83 | train_loss = 0 84 | with tqdm(total=epochs) as pbar: 85 | for i in range(epochs): 86 | train_loss = local_alt( 87 | model, 88 | criterion, 89 | optimizer, 90 | ldr_train, 91 | device, 92 | clip_grad_norm=args.clipgradnorm, 93 | ) 94 | if verbose: 95 | print(f"Epoch: {i} \tLoss: {train_loss}") 96 | pbar.update(1) 97 | pbar.set_postfix({"Loss": train_loss}) 98 | return model, train_loss 99 | 100 | 101 | def fedselect_algorithm( 102 | model: nn.Module, 103 | args: Any, 104 | dataset_train: torch.utils.data.Dataset, 105 | dataset_test: torch.utils.data.Dataset, 106 | dict_users_train: Dict[int, np.ndarray], 107 | dict_users_test: Dict[int, np.ndarray], 108 | labels: np.ndarray, 109 | idxs_users: List[int], 110 | ) -> Dict[str, Any]: 111 | """Main FedSelect federated learning algorithm. 
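    Example (illustrative call; the datasets, per-client index splits, and
    ``idxs_users`` come from utils.train_utils.get_data and the client sampling in
    run_base_experiment):

        >>> out = fedselect_algorithm(model, args, dataset_train, dataset_test,
        ...                           dict_users_train, dict_users_test, labels, idxs_users)
        >>> out["cross_client_acc"]  # clients x clients accuracy matrix from cross_client_eval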
112 | 113 | Args: 114 | model: Neural network model 115 | args: Training arguments 116 | dataset_train: Training dataset 117 | dataset_test: Test dataset 118 | dict_users_train: Mapping of users to training data indices 119 | dict_users_test: Mapping of users to test data indices 120 | labels: Data labels 121 | idxs_users: List of user indices 122 | 123 | Returns: 124 | Dict containing: 125 | - client_accuracies: Accuracy history for each client 126 | - labels: Data labels 127 | - client_masks: Final client masks 128 | - args: Training arguments 129 | - cross_client_acc: Cross-client accuracy matrix 130 | - lth_convergence: Lottery ticket convergence history 131 | """ 132 | # initialize model 133 | initial_state_dict = copy.deepcopy(model.state_dict()) 134 | com_rounds = args.com_rounds 135 | # initialize server 136 | client_accuracies = [{i: 0 for i in idxs_users} for _ in range(com_rounds)] 137 | client_state_dicts = {i: copy.deepcopy(initial_state_dict) for i in idxs_users} 138 | client_state_dict_prev = {i: copy.deepcopy(initial_state_dict) for i in idxs_users} 139 | client_masks = {i: None for i in idxs_users} 140 | client_masks_prev = {i: init_mask_zeros(model) for i in idxs_users} 141 | server_accumulate_mask = OrderedDict() 142 | server_weights = OrderedDict() 143 | lth_iters = args.lth_epoch_iters 144 | prune_rate = args.prune_percent / 100 145 | prune_target = args.prune_target / 100 146 | lottery_ticket_convergence = [] 147 | # Begin FL 148 | for round_num in range(com_rounds): 149 | round_loss = 0 150 | for i in idxs_users: 151 | # initialize model 152 | model.load_state_dict(client_state_dicts[i]) 153 | # get data 154 | ldr_train, _ = prepare_dataloaders( 155 | dataset_train, 156 | dict_users_train[i], 157 | dataset_test, 158 | dict_users_test[i], 159 | args, 160 | ) 161 | # Update LTN_i on local data 162 | client_mask = client_masks_prev.get(i) 163 | # Update u_i parameters on local data 164 | # 0s are global parameters, 1s are local parameters 165 | client_model, loss = train_personalized(model, ldr_train, client_mask, args) 166 | round_loss += loss 167 | # Send u_i update to server 168 | if round_num < com_rounds - 1: 169 | server_accumulate_mask = add_masks(server_accumulate_mask, client_mask) 170 | server_weights = add_server_weights( 171 | server_weights, client_model.state_dict(), client_mask 172 | ) 173 | client_state_dicts[i] = copy.deepcopy(client_model.state_dict()) 174 | client_masks[i] = copy.deepcopy(client_mask) 175 | 176 | if round_num % lth_iters == 0 and round_num != 0: 177 | client_mask = delta_update( 178 | prune_rate, 179 | client_state_dicts[i], 180 | client_state_dict_prev[i], 181 | client_masks_prev[i], 182 | bound=prune_target, 183 | invert=True, 184 | ) 185 | client_state_dict_prev[i] = copy.deepcopy(client_state_dicts[i]) 186 | client_masks_prev[i] = copy.deepcopy(client_mask) 187 | round_loss /= len(idxs_users) 188 | cross_client_acc = cross_client_eval( 189 | model, 190 | client_state_dicts, 191 | dataset_train, 192 | dataset_test, 193 | dict_users_train, 194 | dict_users_test, 195 | args, 196 | ) 197 | 198 | accs = torch.diag(cross_client_acc) 199 | for i in range(len(accs)): 200 | client_accuracies[round_num][i] = accs[i] 201 | print("Client Accs: ", accs, " | Mean: ", accs.mean()) 202 | 203 | if round_num < com_rounds - 1: 204 | # Server averages u_i 205 | server_weights = div_server_weights(server_weights, server_accumulate_mask) 206 | # Server broadcasts non lottery ticket parameters u_i to every device 207 | for i in idxs_users: 208 | 
client_state_dicts[i] = broadcast_server_to_client_initialization( 209 | server_weights, client_masks[i], client_state_dicts[i] 210 | ) 211 | server_accumulate_mask = OrderedDict() 212 | server_weights = OrderedDict() 213 | 214 | cross_client_acc = cross_client_eval( 215 | model, 216 | client_state_dicts, 217 | dataset_train, 218 | dataset_test, 219 | dict_users_train, 220 | dict_users_test, 221 | args, 222 | no_cross=False, 223 | ) 224 | 225 | out_dict = { 226 | "client_accuracies": client_accuracies, 227 | "labels": labels, 228 | "client_masks": client_masks, 229 | "args": args, 230 | "cross_client_acc": cross_client_acc, 231 | "lth_convergence": lottery_ticket_convergence, 232 | } 233 | 234 | return out_dict 235 | 236 | 237 | def cross_client_eval( 238 | model: nn.Module, 239 | client_state_dicts: Dict[int, OrderedDict], 240 | dataset_train: torch.utils.data.Dataset, 241 | dataset_test: torch.utils.data.Dataset, 242 | dict_users_train: Dict[int, np.ndarray], 243 | dict_users_test: Dict[int, np.ndarray], 244 | args: Any, 245 | no_cross: bool = True, 246 | ) -> torch.Tensor: 247 | """Evaluate models across clients. 248 | 249 | Args: 250 | model: Neural network model 251 | client_state_dicts: Client model states 252 | dataset_train: Training dataset 253 | dataset_test: Test dataset 254 | dict_users_train: Mapping of users to training data indices 255 | dict_users_test: Mapping of users to test data indices 256 | args: Evaluation arguments 257 | no_cross: Whether to only evaluate on own data 258 | 259 | Returns: 260 | torch.Tensor: Matrix of cross-client accuracies 261 | """ 262 | cross_client_acc_matrix = torch.zeros( 263 | (len(client_state_dicts), len(client_state_dicts)) 264 | ) 265 | idx_users = list(client_state_dicts.keys()) 266 | for _i, i in enumerate(idx_users): 267 | model.load_state_dict(client_state_dicts[i]) 268 | for _j, j in enumerate(idx_users): 269 | if no_cross: 270 | if i != j: 271 | continue 272 | # eval model i on data from client j 273 | _, ldr_test = prepare_dataloaders( 274 | dataset_train, 275 | dict_users_train[j], 276 | dataset_test, 277 | dict_users_test[j], 278 | args, 279 | ) 280 | acc = evaluate(model, ldr_test, args) 281 | cross_client_acc_matrix[_i, _j] = acc 282 | return cross_client_acc_matrix 283 | 284 | 285 | def get_cross_correlation(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: 286 | """Get cross correlation between two tensors using F.conv2d. 287 | 288 | Args: 289 | A: First tensor 290 | B: Second tensor 291 | 292 | Returns: 293 | torch.Tensor: Cross correlation result 294 | """ 295 | # Normalize A 296 | A = A.cuda() if torch.cuda.is_available() else A 297 | B = B.cuda() if torch.cuda.is_available() else B 298 | A = A.unsqueeze(0).unsqueeze(0) 299 | B = B.unsqueeze(0).unsqueeze(0) 300 | A = A / (A.max() - A.min()) if A.max() - A.min() != 0 else A 301 | B = B / (B.max() - B.min()) if B.max() - B.min() != 0 else B 302 | return F.conv2d(A, B) 303 | 304 | 305 | def run_base_experiment(model: nn.Module, args: Any) -> None: 306 | """Run base federated learning experiment. 
307 | 308 | Args: 309 | model: Neural network model 310 | args: Experiment arguments 311 | """ 312 | dataset_train, dataset_test, dict_users_train, dict_users_test, labels = get_data( 313 | args 314 | ) 315 | idxs_users = np.arange(args.num_users * args.frac) 316 | m = max(int(args.frac * args.num_users), 1) 317 | idxs_users = np.random.choice(range(args.num_users), m, replace=False) 318 | idxs_users = [int(i) for i in idxs_users] 319 | fedselect_algorithm( 320 | model, 321 | args, 322 | dataset_train, 323 | dataset_test, 324 | dict_users_train, 325 | dict_users_test, 326 | labels, 327 | idxs_users, 328 | ) 329 | 330 | 331 | def load_model(args: Any) -> nn.Module: 332 | """Load and initialize model. 333 | 334 | Args: 335 | args: Model arguments 336 | 337 | Returns: 338 | nn.Module: Initialized model 339 | """ 340 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 341 | args.device = device 342 | model = resnet18(pretrained=args.pretrained_init) 343 | num_ftrs = model.fc.in_features 344 | model.fc = nn.Linear(num_ftrs, args.num_classes) 345 | model = model.to(device) 346 | return model.to(device) 347 | 348 | 349 | def setup_seed(seed: int) -> None: 350 | """Set random seeds for reproducibility. 351 | 352 | Args: 353 | seed: Random seed value 354 | """ 355 | torch.manual_seed(seed) 356 | torch.cuda.manual_seed_all(seed) 357 | np.random.seed(seed) 358 | random.seed(seed) 359 | 360 | 361 | if __name__ == "__main__": 362 | # Argument Parser 363 | args = lth_args_parser() 364 | 365 | # Set the seed 366 | setup_seed(args.seed) 367 | model = load_model(args) 368 | 369 | run_base_experiment(model, args) 370 | -------------------------------------------------------------------------------- /lottery_ticket.py: -------------------------------------------------------------------------------- 1 | # Importing Libraries 2 | import copy 3 | import numpy as np 4 | from tqdm import tqdm 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import os 9 | import types 10 | from collections import OrderedDict 11 | from typing import List, Tuple, Dict, OrderedDict, Optional, Union 12 | 13 | 14 | def eval_per_layer_sparsity(mask: OrderedDict) -> List[Tuple[str, str, str, float]]: 15 | """Calculate sparsity statistics for each weight layer in the mask. 16 | 17 | Args: 18 | mask: OrderedDict containing binary masks for model parameters 19 | 20 | Returns: 21 | List of tuples containing (num ones, num zeros, layer name, sparsity) for each weight layer 22 | """ 23 | return [ 24 | ( 25 | f"1: {torch.count_nonzero(mask[name])}", 26 | f"0: {torch.count_nonzero(1-mask[name])}", 27 | name, 28 | ( 29 | torch.count_nonzero(1 - mask[name]) 30 | / ( 31 | torch.count_nonzero(mask[name]) 32 | + torch.count_nonzero(1 - mask[name]) 33 | ) 34 | ).item(), 35 | ) 36 | for name in mask.keys() 37 | if "weight" in name 38 | ] 39 | 40 | 41 | def eval_layer_sparsity( 42 | mask: OrderedDict, layer_name: str 43 | ) -> Tuple[str, str, str, float]: 44 | """Calculate sparsity statistics for a specific layer in the mask. 
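    Example (toy mask for a hypothetical ``fc.weight`` layer with two kept and two
    pruned entries; the last element is the fraction of zeros in the layer):

        >>> mask = OrderedDict({"fc.weight": torch.tensor([1.0, 0.0, 0.0, 1.0])})
        >>> eval_layer_sparsity(mask, "fc.weight")  # ~ ("1: 2", "0: 2", "fc.weight", 0.5)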
45 | 46 | Args: 47 | mask: OrderedDict containing binary masks for model parameters 48 | layer_name: Name of layer to analyze 49 | 50 | Returns: 51 | Tuple containing (num ones, num zeros, layer name, sparsity) for specified layer 52 | """ 53 | return ( 54 | f"1: {torch.count_nonzero(mask[layer_name])}", 55 | f"0: {torch.count_nonzero(1-mask[layer_name])}", 56 | layer_name, 57 | ( 58 | torch.count_nonzero(1 - mask[layer_name]) 59 | / ( 60 | torch.count_nonzero(mask[layer_name]) 61 | + torch.count_nonzero(1 - mask[layer_name]) 62 | ) 63 | ).item(), 64 | ) 65 | 66 | 67 | def print_nonzeros( 68 | model: OrderedDict, verbose: bool = False, invert: bool = False 69 | ) -> float: 70 | """Print statistics about non-zero parameters in model. 71 | 72 | Args: 73 | model: OrderedDict containing model parameters 74 | verbose: Whether to print detailed statistics 75 | invert: Whether to count zeros instead of non-zeros 76 | 77 | Returns: 78 | Percentage of pruned parameters 79 | """ 80 | nonzero = total = 0 81 | for name, p in model.items(): 82 | tensor = p.data.cpu().numpy() 83 | nz_count = ( 84 | np.count_nonzero(tensor) if not invert else np.count_nonzero(1 - tensor) 85 | ) 86 | total_params = np.prod(tensor.shape) 87 | nonzero += nz_count 88 | total += total_params 89 | if verbose: 90 | print( 91 | f"{name:20} | nonzeros = {nz_count:7} / {total_params:7} ({100 * nz_count / total_params:6.2f}%) | total_pruned = {total_params - nz_count :7} | shape = {tensor.shape}" 92 | ) 93 | if verbose: 94 | print( 95 | f"alive: {nonzero}, pruned : {total - nonzero}, total: {total}, Compression rate : {total/nonzero:10.2f}x ({100 * (total-nonzero) / total:6.2f}% pruned)" 96 | ) 97 | return 100 * (total - nonzero) / total 98 | 99 | 100 | def print_lth_stats(mask: OrderedDict, invert: bool = False) -> None: 101 | """Print lottery ticket hypothesis statistics about mask sparsity. 102 | 103 | Args: 104 | mask: OrderedDict containing binary masks 105 | invert: Whether to invert the sparsity calculation 106 | """ 107 | current_prune = print_nonzeros(mask, invert=invert) 108 | print(f"Mask Sparsity: {current_prune:.2f}%") 109 | 110 | 111 | def _violates_bound( 112 | mask: torch.Tensor, bound: Optional[float] = None, invert: bool = False 113 | ) -> bool: 114 | """Check if mask sparsity violates specified bound. 115 | 116 | Args: 117 | mask: Binary mask tensor 118 | bound: Maximum allowed sparsity 119 | invert: Whether to invert the sparsity calculation 120 | 121 | Returns: 122 | True if bound is violated, False otherwise 123 | """ 124 | if invert: 125 | return ( 126 | torch.count_nonzero(mask) 127 | / (torch.count_nonzero(mask) + torch.count_nonzero(1 - mask)) 128 | ).item() > bound 129 | else: 130 | return ( 131 | torch.count_nonzero(1 - mask) 132 | / (torch.count_nonzero(mask) + torch.count_nonzero(1 - mask)) 133 | ).item() > bound 134 | 135 | 136 | def init_mask(model: nn.Module) -> OrderedDict: 137 | """Initialize binary mask of ones for model parameters. 138 | 139 | Args: 140 | model: Neural network model 141 | 142 | Returns: 143 | OrderedDict containing binary masks initialized to ones 144 | """ 145 | mask = OrderedDict() 146 | for name, param in model.named_parameters(): 147 | mask[name] = torch.ones_like(param) 148 | return mask 149 | 150 | 151 | def init_mask_zeros(model: nn.Module) -> OrderedDict: 152 | """Initialize binary mask of zeros for model parameters. 
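    Example (sketch; in fedselect.py an all-zeros mask means every parameter starts out
    global/shared, since 0 marks global entries and 1 marks personalized ones):

        >>> mask = init_mask_zeros(model)
        >>> all(int(torch.count_nonzero(m)) == 0 for m in mask.values())
        True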
153 | 154 | Args: 155 | model: Neural network model 156 | 157 | Returns: 158 | OrderedDict containing binary masks initialized to zeros 159 | """ 160 | mask = OrderedDict() 161 | for name, param in model.named_parameters(): 162 | mask[name] = torch.zeros_like(param) 163 | return mask 164 | 165 | 166 | def get_mask_from_delta( 167 | prune_percent: float, 168 | current_state_dict: OrderedDict, 169 | prev_state_dict: OrderedDict, 170 | current_mask: OrderedDict, 171 | bound: float = 0.80, 172 | invert: bool = True, 173 | ) -> OrderedDict: 174 | """Generate new mask based on parameter changes between states. 175 | 176 | Args: 177 | prune_percent: Percentage of parameters to prune 178 | current_state_dict: Current model state 179 | prev_state_dict: Previous model state 180 | current_mask: Current binary mask 181 | bound: Maximum allowed sparsity 182 | invert: Whether to invert the pruning logic 183 | 184 | Returns: 185 | Updated binary mask based on parameter changes 186 | """ 187 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 188 | return_mask = copy.deepcopy(current_mask) 189 | for name, param in current_state_dict.items(): 190 | if "weight" in name: 191 | if _violates_bound(current_mask[name], bound=bound, invert=invert): 192 | continue 193 | tensor = param.data.cpu().numpy() 194 | compare_tensor = prev_state_dict[name].cpu().numpy() 195 | delta_tensor = np.abs(tensor - compare_tensor) 196 | 197 | delta_percentile_tensor = ( 198 | delta_tensor[current_mask[name].cpu().numpy() == 1] 199 | if not invert 200 | else delta_tensor[current_mask[name].cpu().numpy() == 0] 201 | ) 202 | sorted_weights = np.sort(np.abs(delta_percentile_tensor)) 203 | if not invert: 204 | cutoff_index = np.round(prune_percent * sorted_weights.size).astype(int) 205 | cutoff = sorted_weights[cutoff_index] 206 | 207 | # Convert Tensors to numpy and calculate 208 | new_mask = np.where( 209 | abs(delta_tensor) <= cutoff, 0, return_mask[name].cpu().numpy() 210 | ) 211 | return_mask[name] = torch.from_numpy(new_mask).to(device) 212 | else: 213 | cutoff_index = np.round( 214 | (1 - prune_percent) * sorted_weights.size 215 | ).astype(int) 216 | cutoff = sorted_weights[cutoff_index] 217 | 218 | # Convert Tensors to numpy and calculate 219 | new_mask = np.where( 220 | abs(delta_tensor) >= cutoff, 1, return_mask[name].cpu().numpy() 221 | ) 222 | return_mask[name] = torch.from_numpy(new_mask).to(device) 223 | # print(eval_per_layer_sparsity(return_mask)) 224 | print(eval_layer_sparsity(return_mask, "fc.weight")) 225 | return return_mask 226 | 227 | 228 | def delta_update( 229 | prune_percent: float, 230 | current_state_dict: OrderedDict, 231 | prev_state_dict: OrderedDict, 232 | current_mask: OrderedDict, 233 | bound: float = 0.80, 234 | invert: bool = False, 235 | ) -> OrderedDict: 236 | """Update mask based on parameter changes between states. 
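    Example (sketch of how fedselect.py calls it; with ``invert=True`` the currently
    global entries whose weights moved the most during local training are flipped to 1
    (personalized), growing the mask by roughly ``prune_percent`` of the remaining global
    entries per call until the fraction of ones reaches ``bound``):

        >>> new_mask = delta_update(
        ...     prune_percent=0.25,
        ...     current_state_dict=client_state_dicts[i],
        ...     prev_state_dict=client_state_dict_prev[i],
        ...     current_mask=client_masks_prev[i],
        ...     bound=0.80,
        ...     invert=True,
        ... )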
237 | 238 | Args: 239 | prune_percent: Percentage of parameters to prune 240 | current_state_dict: Current model state 241 | prev_state_dict: Previous model state 242 | current_mask: Current binary mask 243 | bound: Maximum allowed sparsity 244 | invert: Whether to invert the pruning logic 245 | 246 | Returns: 247 | Updated binary mask 248 | """ 249 | mask = get_mask_from_delta( 250 | prune_percent, 251 | current_state_dict, 252 | prev_state_dict, 253 | current_mask, 254 | bound=bound, 255 | invert=invert, 256 | ) 257 | print_lth_stats(mask, invert=invert) 258 | return mask 259 | -------------------------------------------------------------------------------- /pflopt/optimizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from collections import OrderedDict 6 | 7 | 8 | class MaskLocalAltSGD(optim.Optimizer): 9 | def __init__(self, params, mask: OrderedDict = None, lr=0.01): 10 | """Implements SGD with alternating updates based on a binary mask of parameters.""" 11 | # require params is named parameters 12 | # assert isinstance(params, list) and len(params) == 1 13 | self.mask: list[torch.Tensor] = [value.long() for key, value in mask.items()] 14 | self.names: list[torch.Tensor] = [key for key, value in mask.items()] 15 | self.named_mask: OrderedDict = mask 16 | self._toggle = True 17 | defaults = dict(lr=lr, _toggle=True) 18 | 19 | if mask is None: 20 | raise ValueError("MaskLocalAltSGD requires a mask") 21 | super(MaskLocalAltSGD, self).__init__(params, defaults) 22 | 23 | def __setstate__(self, state): 24 | super(MaskLocalAltSGD, self).__setstate__(state) 25 | 26 | def toggle(self): 27 | self._toggle = not self._toggle 28 | 29 | def step(self, closure=None): 30 | """Performs a single optimization step.""" 31 | loss = None 32 | if closure is not None: 33 | loss = closure() 34 | # update parameters 35 | for group in self.param_groups: 36 | step = 0 37 | for p in group["params"]: 38 | if p.grad is None: 39 | continue 40 | # assert that p does not contain nan 41 | assert torch.isnan(p).sum() == 0, "parameter contains nan" 42 | # get name of parameter 43 | mask = self.mask[step] 44 | # update parameter 45 | if mask is not None: 46 | if self._toggle: 47 | p.data.add_(mask * p.grad.data, alpha=-group["lr"]) 48 | else: 49 | p.data.add_((1 - mask) * p.grad.data, alpha=-group["lr"]) 50 | else: 51 | p.data.add_(-group["lr"], p.grad.data) 52 | step += 1 53 | return loss 54 | 55 | 56 | def local_alt( 57 | model, 58 | criterion, 59 | optimizer, 60 | data_loader, 61 | device, 62 | clip_grad_norm=False, 63 | max_grad_norm=3.50, 64 | ): 65 | assert isinstance(optimizer, MaskLocalAltSGD), "optimizer must be MaskLocalAltSGD" 66 | avg_loss_1 = 0 67 | for batch_idx, (data, target) in enumerate(data_loader): 68 | data, target = data.to(device), target.to(device) 69 | optimizer.zero_grad() 70 | output = model(data) 71 | loss = criterion(output, target) 72 | avg_loss_1 += loss.item() 73 | loss.backward() 74 | if clip_grad_norm: 75 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) 76 | optimizer.step() 77 | avg_loss_1 /= len(data_loader) 78 | optimizer.toggle() 79 | 80 | avg_loss_2 = 0 81 | for batch_idx, (data, target) in enumerate(data_loader): 82 | data, target = data.to(device), target.to(device) 83 | optimizer.zero_grad() 84 | output = model(data) 85 | loss = criterion(output, target) 86 | avg_loss_2 += loss.item() 87 | loss.backward() 
88 | if clip_grad_norm: 89 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) 90 | optimizer.step() 91 | avg_loss_2 /= len(data_loader) 92 | optimizer.toggle() 93 | 94 | train_loss = (avg_loss_1 + avg_loss_2) / 2 95 | return train_loss 96 | 97 | 98 | if __name__ == "__main__": 99 | pass 100 | -------------------------------------------------------------------------------- /utils/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def lth_args_parser(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--lr", default=0.05, type=float, help="Learning rate") 7 | parser.add_argument("--batch_size", default=60, type=int) 8 | parser.add_argument("--lth_epoch_iters", default=3, type=int) 9 | parser.add_argument( 10 | "--dataset", 11 | default="cifar10", 12 | type=str, 13 | ) 14 | parser.add_argument( 15 | "--arch_type", 16 | default="resnet18", 17 | type=str, 18 | ) 19 | parser.add_argument( 20 | "--setting", 21 | default="", 22 | type=str, 23 | ) 24 | parser.add_argument( 25 | "--prune_percent", default=25, type=float, help="Pruning percent" 26 | ) 27 | parser.add_argument("--prune_target", default=80, type=int, help="Pruning target") 28 | parser.add_argument( 29 | "--com_rounds", type=int, default=4, help="rounds of fedavg training" 30 | ) 31 | parser.add_argument( 32 | "--la_epochs", 33 | type=int, 34 | default=15, 35 | help="rounds of training for local alt optimization", 36 | ) 37 | parser.add_argument("--iid", action="store_true", help="whether i.i.d or not") 38 | parser.add_argument("--num_users", type=int, default=100, help="number of users: K") 39 | parser.add_argument( 40 | "--shard_per_user", type=int, default=2, help="classes per user" 41 | ) 42 | parser.add_argument("--local_bs", type=int, default=32, help="local batch size: B") 43 | parser.add_argument( 44 | "--frac", type=float, default=0.1, help="the fraction of clients: C" 45 | ) 46 | parser.add_argument("--num_classes", type=int, default=10, help="number of classes") 47 | parser.add_argument("--model", type=str, default="mlp", help="model name") 48 | parser.add_argument("--bs", type=int, default=128, help="test batch size") 49 | parser.add_argument("--lth_freq", type=int, default=1, help="frequency of lth") 50 | parser.add_argument("--pretrained_init", action="store_true") 51 | parser.add_argument("--clipgradnorm", action="store_true") 52 | parser.add_argument("--num_samples", type=int, default=-1) 53 | parser.add_argument("--test_size", type=int, default=-1) 54 | parser.add_argument("--exp_name", type=str, default="prune_rate_vary") 55 | parser.add_argument( 56 | "--server_data_ratio", 57 | type=float, 58 | default=0.0, 59 | help="The percentage of data that servers also have across data of all clients.", 60 | ) 61 | 62 | parser.add_argument("--seed", type=int, default=1, help="random seed (default: 1)") 63 | 64 | args = parser.parse_args() 65 | return args 66 | -------------------------------------------------------------------------------- /utils/sampling.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from typing import Dict, List, Set, Union, Tuple 5 | from torch.utils.data import Dataset 6 | 7 | 8 | def iid(dataset: Dataset, num_users: int) -> Dict[int, Set[int]]: 9 | """Sample I.I.D. client data from dataset by randomly dividing into equal parts. 
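    Example (toy split; a 10-sample dataset divided evenly and disjointly across 2 clients):

        >>> splits = iid(dataset, num_users=2)  # assumes len(dataset) == 10
        >>> sorted(len(idxs) for idxs in splits.values())
        [5, 5]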
10 | 11 | Args: 12 | dataset: The full dataset to sample from 13 | num_users: Number of clients to divide data between 14 | 15 | Returns: 16 | Dict mapping client IDs to sets of data indices assigned to that client 17 | """ 18 | num_items = int(len(dataset) / num_users) 19 | dict_users, all_idxs = {}, [i for i in range(len(dataset))] 20 | for i in range(num_users): 21 | dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False)) 22 | all_idxs = list(set(all_idxs) - dict_users[i]) 23 | return dict_users 24 | 25 | 26 | def noniid( 27 | dataset: Dataset, 28 | num_users: int, 29 | shard_per_user: int, 30 | server_data_ratio: float = 0.0, 31 | size: Union[int, None] = None, 32 | rand_set_all: List = [], 33 | ) -> Tuple[Dict[Union[int, str], Union[np.ndarray, Set[int]]], np.ndarray]: 34 | """Sample non-I.I.D client data from dataset by dividing data by class labels. 35 | 36 | Args: 37 | dataset: The full dataset to sample from 38 | num_users: Number of clients to divide data between 39 | shard_per_user: Number of class shards to assign to each user 40 | server_data_ratio: Fraction of data to reserve for server (default: 0.0) 41 | size: Optional size to limit each user's data to 42 | rand_set_all: Optional pre-defined random class assignments 43 | 44 | Returns: 45 | Tuple containing: 46 | - Dict mapping client IDs to arrays of assigned data indices 47 | - Array of random class assignments used for the split 48 | """ 49 | dict_users, all_idxs = {i: np.array([], dtype="int64") for i in range(num_users)}, [ 50 | i for i in range(len(dataset)) 51 | ] 52 | 53 | targets = None 54 | targets = [elem[1].item() for elem in dataset] 55 | 56 | # dictionary of indices in the dataset for each label 57 | idxs_dict = {} 58 | for i in range(len(dataset)): 59 | label = torch.tensor(targets[i]).item() 60 | if label not in idxs_dict.keys(): 61 | idxs_dict[label] = [] 62 | idxs_dict[label].append(i) 63 | 64 | num_classes = len(np.unique(targets)) 65 | shard_per_class = int(shard_per_user * num_users / num_classes) 66 | for label in idxs_dict.keys(): 67 | x = idxs_dict[label] 68 | num_leftover = len(x) % shard_per_class 69 | leftover = x[-num_leftover:] if num_leftover > 0 else [] 70 | x = np.array(x[:-num_leftover]) if num_leftover > 0 else np.array(x) 71 | x = x.reshape((shard_per_class, -1)) 72 | x = list(x) 73 | 74 | for i, idx in enumerate(leftover): 75 | x[i] = np.concatenate([x[i], [idx]]) 76 | idxs_dict[label] = x 77 | 78 | if len(rand_set_all) == 0: 79 | rand_set_all = list(range(num_classes)) * shard_per_class 80 | random.shuffle(rand_set_all) 81 | rand_set_all = np.array(rand_set_all).reshape((num_users, -1)) 82 | 83 | # divide and assign 84 | for i in range(num_users): 85 | rand_set_label = rand_set_all[i] 86 | rand_set = [] 87 | for label in rand_set_label: 88 | idx = np.random.choice(len(idxs_dict[label]), replace=False) 89 | rand_set.append(idxs_dict[label].pop(idx)) 90 | dict_users[i] = np.concatenate(rand_set) 91 | 92 | test = [] 93 | for key, value in dict_users.items(): 94 | x = np.unique(torch.tensor(targets)[value]) 95 | assert (len(x)) <= shard_per_user 96 | test.append(value) 97 | test = np.concatenate(test) 98 | assert len(test) == len(dataset) 99 | assert len(set(list(test))) == len(dataset) 100 | 101 | if server_data_ratio > 0.0: 102 | dict_users["server"] = set( 103 | np.random.choice( 104 | all_idxs, int(len(dataset) * server_data_ratio), replace=False 105 | ) 106 | ) 107 | 108 | for i in range(num_users): 109 | num_elem = len(dict_users[i]) 110 | dict_users[i] = 
np.concatenate( 111 | [ 112 | dict_users[i][k : k + size] 113 | for k in range(0, num_elem, num_elem // shard_per_user + 1) 114 | ] 115 | ) 116 | 117 | return dict_users, rand_set_all 118 | -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | from utils.sampling import iid, noniid 3 | import numpy as np 4 | import torch 5 | from typing import Dict, List, Tuple, Any 6 | 7 | 8 | class DatasetSplit(torch.utils.data.Dataset): 9 | """Custom Dataset class that returns a subset of another dataset based on indices. 10 | 11 | Args: 12 | dataset: The base dataset to sample from 13 | idxs: Indices to use for sampling from the base dataset 14 | """ 15 | 16 | def __init__(self, dataset: torch.utils.data.Dataset, idxs: List[int]) -> None: 17 | self.dataset = dataset 18 | self.idxs = list(idxs) 19 | 20 | def __len__(self) -> int: 21 | return len(self.idxs) 22 | 23 | def __getitem__(self, item: int) -> Tuple[torch.Tensor, int]: 24 | image, label = self.dataset[self.idxs[item]] 25 | return image, label 26 | 27 | 28 | trans_mnist = transforms.Compose( 29 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 30 | ) 31 | trans_cifar10_train = transforms.Compose( 32 | [ 33 | transforms.RandomCrop(32, padding=4), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.ToTensor(), 36 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 37 | ] 38 | ) 39 | trans_cifar10_val = transforms.Compose( 40 | [ 41 | transforms.ToTensor(), 42 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 43 | ] 44 | ) 45 | trans_cifar100_train = transforms.Compose( 46 | [ 47 | transforms.RandomCrop(32, padding=4), 48 | transforms.RandomHorizontalFlip(), 49 | transforms.ToTensor(), 50 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]), 51 | ] 52 | ) 53 | trans_cifar100_val = transforms.Compose( 54 | [ 55 | transforms.ToTensor(), 56 | transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276]), 57 | ] 58 | ) 59 | 60 | 61 | def get_data( 62 | args: Any, 63 | ) -> Tuple[torch.utils.data.Dataset, torch.utils.data.Dataset, Dict, Dict, np.ndarray]: 64 | """Get train and test datasets and user splits for federated learning. 
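    Example (sketch; ``args`` comes from utils.options.lth_args_parser, so ``args.iid``,
    ``args.num_users``, and ``args.shard_per_user`` select how CIFAR-10 is partitioned):

        >>> train_set, test_set, users_train, users_test, rand_set_all = get_data(args)
        >>> len(users_train)  # == args.num_users when server_data_ratio is 0.0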
65 | 66 | Args: 67 | args: Arguments containing dataset configuration 68 | 69 | Returns: 70 | dataset_train: Training dataset 71 | dataset_test: Test dataset 72 | dict_users_train: Dictionary mapping users to training data indices 73 | dict_users_test: Dictionary mapping users to test data indices 74 | rand_set_all: Random set assignments for non-iid splitting 75 | """ 76 | dataset_train = datasets.CIFAR10( 77 | "data/cifar10", train=True, download=True, transform=trans_cifar10_train 78 | ) 79 | dataset_test = datasets.CIFAR10( 80 | "data/cifar10", train=False, download=True, transform=trans_cifar10_val 81 | ) 82 | if args.iid: 83 | dict_users_train = iid(dataset_train, args.num_users) 84 | dict_users_test = iid(dataset_test, args.num_users) 85 | rand_set_all = np.array([]) 86 | else: 87 | dict_users_train, rand_set_all = noniid( 88 | dataset_train, 89 | args.num_users, 90 | args.shard_per_user, 91 | args.server_data_ratio, 92 | size=args.num_samples, 93 | ) 94 | dict_users_test, rand_set_all = noniid( 95 | dataset_test, 96 | args.num_users, 97 | args.shard_per_user, 98 | args.server_data_ratio, 99 | size=args.test_size, 100 | rand_set_all=rand_set_all, 101 | ) 102 | 103 | return dataset_train, dataset_test, dict_users_train, dict_users_test, rand_set_all 104 | 105 | 106 | def prepare_dataloaders( 107 | dataset_train: torch.utils.data.Dataset, 108 | dict_users_train: Dict, 109 | dataset_test: torch.utils.data.Dataset, 110 | dict_users_test: Dict, 111 | args: Any, 112 | ) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]: 113 | """Prepare train and test data loaders for a user. 114 | 115 | Args: 116 | dataset_train: Training dataset 117 | dict_users_train: Dictionary mapping users to training data indices 118 | dataset_test: Test dataset 119 | dict_users_test: Dictionary mapping users to test data indices 120 | args: Arguments containing batch size configuration 121 | 122 | Returns: 123 | ldr_train: Training data loader 124 | ldr_test: Test data loader 125 | """ 126 | ldr_train = torch.utils.data.DataLoader( 127 | DatasetSplit(dataset_train, dict_users_train), 128 | batch_size=args.local_bs, 129 | shuffle=True, 130 | ) 131 | ldr_test = torch.utils.data.DataLoader( 132 | DatasetSplit(dataset_test, dict_users_test), 133 | batch_size=args.local_bs, 134 | shuffle=False, 135 | ) 136 | return ldr_train, ldr_test 137 | --------------------------------------------------------------------------------
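A minimal quick-start sketch, mirroring the `__main__` block of fedselect.py. It assumes the repository root is on `PYTHONPATH`, that CIFAR-10 can be downloaded to `data/cifar10`, and it uses the default flags from utils/options.py rather than any particular paper configuration:

```python
# Programmatic equivalent of running `python fedselect.py` with the default flags.
from utils.options import lth_args_parser
from fedselect import setup_seed, load_model, run_base_experiment

args = lth_args_parser()          # defaults include --dataset cifar10 --num_users 100 --frac 0.1
setup_seed(args.seed)             # seeds torch, numpy, and random for reproducibility
model = load_model(args)          # ResNet-18 with an args.num_classes-way classification head
run_base_experiment(model, args)  # samples clients and runs fedselect_algorithm
```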