├── metrics
│   ├── __init__.py
│   ├── metrics.py
│   └── flops.py
├── runners
│   ├── __init__.py
│   ├── scratchRunner.py
│   ├── pretrainedRunner.py
│   ├── ensembleRunner.py
│   └── baseRunner.py
├── strategies
│   ├── __init__.py
│   ├── strategies.py
│   └── ensembleStrategies.py
├── utilities
│   ├── __init__.py
│   ├── lr_schedulers.py
│   └── utilities.py
├── citation.bib
├── models
│   ├── imagenet.py
│   ├── mnist.py
│   ├── cifar100.py
│   └── cifar10.py
├── .vscode
│   └── launch.json
├── README.md
├── config.py
├── main.py
└── .gitignore
/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /runners/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /strategies/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utilities/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /citation.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{zimmer2024sparse, 2 | title={Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging}, 3 | author={Max Zimmer and Christoph Spiegel and Sebastian Pokutta}, 4 | booktitle={The Twelfth International Conference on Learning Representations}, 5 | year={2024}, 6 | url={https://openreview.net/forum?id=xx0ITyHp3u} 7 | } 8 | -------------------------------------------------------------------------------- /models/imagenet.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: models/imagenet.py 5 | # Description: ImageNet Models 6 | # =========================================================================== 7 | 8 | import torchvision 9 | 10 | 11 | def ResNet50(): 12 | return torchvision.models.resnet50(pretrained=False) 13 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Run main.py with --debug", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${workspaceFolder}/main.py", 12 | "args": [ 13 | "--debug" 14 | ], 15 | "console": "integratedTerminal" 16 | } 17 | ] 18 | } 19 | -------------------------------------------------------------------------------- /models/mnist.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: models/mnist.py 5 | # Description: MNIST Models 6 | # =========================================================================== 7 | import torch 8 | 9 | from utilities.utilities import Utilities as Utils 10 | 11 | 12 | class Simple(torch.nn.Module): 13 | def __init__(self): 14 | super(Simple, self).__init__() 15 | self.fc1 = torch.nn.Linear(784, 512, bias=True) 16 | self.dropout1 = torch.nn.Dropout(0.2) 17 | self.fc2 = torch.nn.Linear(512, 10, bias=True) 18 | 19 | def forward(self, x): 20 | x = torch.flatten(x, start_dim=1, end_dim=3) 21 | x = self.fc1(x) 22 | x = torch.nn.functional.relu(x) 23 | x = self.dropout1(x) 24 | x = self.fc2(x) 25 | return x 26 | 27 | @staticmethod 28 | def get_permutation_spec(): 29 | dense = lambda name, p_in, p_out, bias=True: {f"{name}.weight": (p_out, p_in), 30 | f"{name}.bias": (p_out,)} if bias else { 31 | f"{name}.weight": (p_out, p_in)} 32 | 33 | return Utils.permutation_spec_from_axes_to_perm({ 34 | **dense("fc1", None, "P_bg0", True), 35 | **dense("fc2", "P_bg0", None, True), 36 | }) 37 | -------------------------------------------------------------------------------- /utilities/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: lr_schedulers.py 5 | # Description: All kinds of learning rate schedulers 6 | # =========================================================================== 7 | 8 | import warnings 9 | from bisect import bisect_right 10 | 11 | import torch 12 | 13 | 14 | class FixedLR(torch.optim.lr_scheduler._LRScheduler): 15 | """ 16 | Just uses the learning rate given by a list 17 | """ 18 | 19 | def __init__(self, optimizer, lrList, last_epoch=-1): 20 | self.lrList = lrList 21 | 22 | super(FixedLR, self).__init__(optimizer, last_epoch) 23 | 24 | def get_lr(self): 25 | if not self._get_lr_called_within_step: 26 | warnings.warn("To get the last learning rate computed by the scheduler, " 27 | "please use `get_last_lr()`.", UserWarning) 28 | 29 | return [self.lrList[self.last_epoch] for _ in self.optimizer.param_groups] 30 | 31 | 32 | class SequentialSchedulers(torch.optim.lr_scheduler.SequentialLR): 33 | """ 34 | Repairs SequentialLR to properly use the last learning rate of the previous scheduler when reaching milestones 35 | """ 36 | 37 | def __init__(self, **kwargs): 38 | self.optimizer = kwargs['schedulers'][0].optimizer 39 | super(SequentialSchedulers, self).__init__(**kwargs) 40 | 41 | def step(self): 42 | self.last_epoch += 1 43 | idx = bisect_right(self._milestones, 
self.last_epoch) 44 | self._schedulers[idx].step() 45 | 46 | 47 | class ChainedSchedulers(torch.optim.lr_scheduler.ChainedScheduler): 48 | """ 49 | Repairs ChainedScheduler to work around a known bug whose fix has not yet made it into a PyTorch release 50 | """ 51 | 52 | def __init__(self, **kwargs): 53 | self.optimizer = kwargs['schedulers'][0].optimizer 54 | super(ChainedSchedulers, self).__init__(**kwargs) 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## [ICLR24] Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging 2 | 3 | *Authors: [Max Zimmer](https://maxzimmer.org/), [Christoph Spiegel](http://www.christophspiegel.berlin/), [Sebastian Pokutta](http://www.pokutta.com/)* 4 | 5 | This repository contains the code to reproduce the experiments from the ICLR24 paper ["Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging"](https://arxiv.org/abs/2306.16788). 6 | The code is based on [PyTorch 1.9](https://pytorch.org/) and the experiment-tracking platform [Weights & Biases](https://wandb.ai). See the [blog post](https://www.pokutta.com/blog/research/2023/08/05/abstract-SMS.html) or the [twitter thread](https://x.com/maxzimmerberlin/status/1787052536442077479) for a TL;DR. 7 | 8 | ### Structure and Usage 9 | #### Structure 10 | Experiments are started from the following file: 11 | 12 | - [`main.py`](main.py): Starts experiments using the dictionary format of Weights & Biases. 13 | 14 | The rest of the project is structured as follows: 15 | 16 | - [`strategies`](strategies): Contains the strategies used for training, pruning and model averaging. 17 | - [`runners`](runners): Contains classes to control the training and collection of metrics. 18 | - [`metrics`](metrics): Contains all metrics as well as FLOP computation methods. 19 | - [`models`](models): Contains all model architectures used. 20 | - [`utilities`](utilities): Contains useful auxiliary functions and classes. 21 | 22 | #### Usage 23 | An entire experiment is subdivided into multiple steps, each consisting of multiple (potentially many) different runs and wandb experiments. First of all, a model has to be pretrained using the `Dense` strategy. This step is completely agnostic to any pruning specifications. Then, for each phase or prune-retrain cycle (specified by the `n_phases` parameter and controlled by the `phase` parameter), the following steps are executed: 24 | 1. Strategy `IMP`: Prune the model using the IMP strategy. Here, it is important to specify the `ensemble_by`, `split_val` and `n_splits_total` parameters: 25 | - `ensemble_by`: The parameter which is varied when retraining multiple models. E.g. setting this to `weight_decay` will train multiple models with different weight decay values. 26 | - `split_val`: The value by which the `ensemble_by` parameter is split. E.g. setting this to 0.0001 while using `weight_decay` as `ensemble_by` will retrain a model with weight decay 0.0001, all else being equal. 27 | - `n_splits_total`: The total number of splits for the `ensemble_by` parameter. If set to three, the souping operation in the next step will expect three models to be present, given the `ensemble_by` configuration. 28 | 2. Strategy `Ensemble`: Souping the models. This step will average the weights of the models specified by the `ensemble_by` parameter. The `ensemble_by` parameter has to be the same as in the previous step. `n_splits_total` has to be the same as well. 
`split_val` is not used in this step and has to be set to None. The `ensemble_method` parameter controls how the models are averaged. 29 | 30 | ### Citation 31 | 32 | In case you find the paper or the implementation useful for your own research, please consider citing: 33 | 34 | ``` 35 | @inproceedings{zimmer2024sparse, 36 | title={Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging}, 37 | author={Max Zimmer and Christoph Spiegel and Sebastian Pokutta}, 38 | booktitle={The Twelfth International Conference on Learning Representations}, 39 | year={2024}, 40 | url={https://openreview.net/forum?id=xx0ITyHp3u} 41 | } 42 | ``` 43 | -------------------------------------------------------------------------------- /runners/scratchRunner.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: scratchRunner.py 5 | # Description: Runner class for methods that do *not* start from a pretrained model 6 | # =========================================================================== 7 | import os 8 | import sys 9 | import time 10 | 11 | import numpy as np 12 | import torch 13 | import wandb 14 | 15 | from runners.baseRunner import baseRunner 16 | from utilities.utilities import Utilities as Utils 17 | 18 | 19 | class scratchRunner(baseRunner): 20 | 21 | def __init__(self, **kwargs): 22 | super().__init__(**kwargs) 23 | self.artifact = None 24 | 25 | entity, project = wandb.run.entity, wandb.run.project 26 | self.initial_artifact_name = f"initial-{entity}-{project}-{self.config.arch}-{self.config.dataset}-{self.config.run_id}" 27 | 28 | def find_existing_model(self): 29 | """Finds an existing wandb artifact and downloads the initial model file.""" 30 | # Create a new artifact, this is idempotent, i.e. no artifact is created if this already exists 31 | try: 32 | self.artifact = wandb.run.use_artifact(f"{self.initial_artifact_name}:latest") 33 | seed = self.artifact.metadata["seed"] 34 | self.artifact.download(root=self.tmp_dir) 35 | self.checkpoint_file = os.path.join(self.tmp_dir, 'initial_model.pt') 36 | self.seed = seed 37 | 38 | except Exception as e: 39 | print(e) 40 | 41 | outputStr = f"Found {self.initial_artifact_name} with seed {seed}" if self.artifact is not None else "Nothing found." 42 | sys.stdout.write(f"Trying to find reference initial model in project: {outputStr}\n") 43 | 44 | def run(self): 45 | """Function controlling the workflow of scratchRunner""" 46 | # If not existing, start a new model, otherwise use existing one with same run-id 47 | self.find_existing_model() 48 | 49 | if self.seed is None: 50 | # Generate a random seed 51 | self.seed = int((os.getpid() + 1) * time.time()) % 2 ** 32 52 | 53 | wandb.config.update({'seed': self.seed}) # Push the seed to wandb 54 | 55 | # Set a unique random seed 56 | np.random.seed(self.seed) 57 | torch.manual_seed(self.seed) 58 | # Remark: If you are working with a multi-GPU model, this function is insufficient to get determinism. To seed all GPUs, use manual_seed_all(). 
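        # Note: torch.cuda.manual_seed only seeds the currently selected GPU (and is silently ignored if CUDA is unavailable); torch.cuda.manual_seed_all(self.seed) would be required to seed all visible devices.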
59 | torch.cuda.manual_seed(self.seed) # This works if CUDA not available 60 | 61 | torch.backends.cudnn.benchmark = True 62 | 63 | self.trainLoader, self.valLoader, self.testLoader, self.trainLoader_unshuffled = self.get_dataloaders() 64 | self.model = self.get_model(reinit=True, temporary=True) 65 | # Save initial model before training 66 | if self.artifact is None: 67 | self.artifact = wandb.Artifact(self.initial_artifact_name, type='model', metadata={'seed': self.seed}) 68 | sys.stdout.write(f"Creating {self.initial_artifact_name}.\n") 69 | self.save_model(model_type='initial', temporary=True) 70 | self.artifact.add_file(f"{self.tmp_dir}/initial_model.pt") 71 | wandb.run.use_artifact(self.artifact) 72 | 73 | self.strategy = self.define_strategy() 74 | self.strategy.after_initialization() 75 | self.define_optimizer_scheduler() # This HAS to be after the definition of the strategy, otherwise changing the models parameters will not be noticed by the optimizer! 76 | self.strategy.set_optimizer(opt=self.optimizer, n_total_iterations=self.n_total_iterations) 77 | 78 | self.strategy.at_train_begin() 79 | 80 | # Do proper training 81 | self.train() 82 | 83 | self.squared_model_norm = Utils.get_model_norm_square(model=self.model) 84 | 85 | self.strategy.at_train_end() 86 | 87 | # Save trained model, to be used by pretrainedRunner 88 | self.checkpoint_file = self.save_model(model_type='trained') 89 | wandb.summary['final_model_file'] = 'trained_model.pt' 90 | 91 | self.strategy.final() 92 | 93 | # Upload iteration-lr dict from self.strategy to be used during retraining 94 | Utils.dump_dict_to_json_wandb(dumpDict=self.strategy.lr_dict, name='iteration-lr-dict') 95 | -------------------------------------------------------------------------------- /metrics/metrics.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: metrics/metrics.py 5 | # Description: Useful metrics 6 | # =========================================================================== 7 | import math 8 | from typing import Union, Tuple, List 9 | 10 | import torch 11 | 12 | from metrics import flops 13 | 14 | 15 | @torch.no_grad() 16 | def get_flops(model, x_input): 17 | return flops.flops(model, x_input) 18 | 19 | 20 | @torch.no_grad() 21 | def get_theoretical_speedup(n_flops: int, n_nonzero_flops: int) -> dict: 22 | if n_nonzero_flops == 0: 23 | # Would yield infinite speedup 24 | return {} 25 | return float(n_flops) / n_nonzero_flops 26 | 27 | 28 | def modular_sparsity(parameters_to_prune: List) -> float: 29 | """Returns the global sparsity out of all prunable parameters""" 30 | n_total, n_zero = 0., 0. 31 | for module, param_type in parameters_to_prune: 32 | if hasattr(module, param_type) and not isinstance(getattr(module, param_type), type(None)): 33 | param = getattr(module, param_type) 34 | n_param = float(torch.numel(param)) 35 | n_zero_param = float(torch.sum(param == 0)) 36 | n_total += n_param 37 | n_zero += n_zero_param 38 | return float(n_zero) / n_total if n_total > 0 else 0 39 | 40 | 41 | def global_sparsity(module: torch.nn.Module, param_type: Union[str, None] = None) -> float: 42 | """Returns the global sparsity of module (mostly of entire model)""" 43 | n_total, n_zero = 0., 0. 
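    # Count zero entries over both weights and biases unless a specific param_type is requested.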
44 | param_list = ['weight', 'bias'] if not param_type else [param_type] 45 | for name, module in module.named_modules(): 46 | for param_type in param_list: 47 | if hasattr(module, param_type) and not isinstance(getattr(module, param_type), type(None)): 48 | param = getattr(module, param_type) 49 | n_param = float(torch.numel(param)) 50 | n_zero_param = float(torch.sum(param == 0)) 51 | n_total += n_param 52 | n_zero += n_zero_param 53 | return float(n_zero) / n_total 54 | 55 | 56 | @torch.no_grad() 57 | def get_parameter_count(model: torch.nn.Module) -> Tuple[int, int]: 58 | n_total = 0 59 | n_nonzero = 0 60 | param_list = ['weight', 'bias'] 61 | for name, module in model.named_modules(): 62 | for param_type in param_list: 63 | if hasattr(module, param_type) and not isinstance(getattr(module, param_type), type(None)): 64 | p = getattr(module, param_type) 65 | n_total += int(p.numel()) 66 | n_nonzero += int(torch.sum(p != 0)) 67 | return n_total, n_nonzero 68 | 69 | 70 | @torch.no_grad() 71 | def get_distance_to_pruned(model: torch.nn.Module, sparsity: float) -> Tuple[float, float]: 72 | prune_vector = torch.cat( 73 | [module.weight.flatten() for name, module in model.named_modules() if hasattr(module, 'weight') 74 | and not isinstance(module.weight, type(None)) and not isinstance(module, 75 | torch.nn.BatchNorm2d)]) 76 | n_params = float(prune_vector.numel()) 77 | k = int((1 - sparsity) * n_params) 78 | total_norm = float(torch.norm(prune_vector, p=2)) 79 | pruned_norm = float(torch.norm(torch.topk(torch.abs(prune_vector), k=k).values, p=2)) 80 | distance_to_pruned = math.sqrt(abs(total_norm ** 2 - pruned_norm ** 2)) 81 | rel_distance_to_pruned = distance_to_pruned / total_norm if total_norm > 0 else 0 82 | return distance_to_pruned, rel_distance_to_pruned 83 | 84 | 85 | @torch.no_grad() 86 | def get_distance_to_origin(model: torch.nn.Module) -> float: 87 | prune_vector = torch.cat( 88 | [module.weight.flatten() for name, module in model.named_modules() if hasattr(module, 'weight') 89 | and not isinstance(module.weight, type(None)) and not isinstance(module, 90 | torch.nn.BatchNorm2d)]) 91 | return float(torch.norm(prune_vector, p=2)) 92 | 93 | 94 | def per_layer_sparsity(model: torch.nn.Module): 95 | """Returns the per-layer-sparsity of model""" 96 | per_layer_sparsity_dict = dict() 97 | param_type = 'weight' # Only compute for weights, since we do not sparsify biases 98 | for name, submodule in model.named_modules(): 99 | if hasattr(submodule, param_type) and not isinstance(getattr(submodule, param_type), type(None)): 100 | if name in per_layer_sparsity_dict: 101 | continue 102 | per_layer_sparsity_dict[name] = global_sparsity(submodule, param_type=param_type) 103 | return per_layer_sparsity_dict 104 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: config.py 5 | # Description: Datasets, Normalization and Transforms 6 | # =========================================================================== 7 | 8 | import numpy as np 9 | import torchvision 10 | from torch.utils.data import Dataset 11 | from torchvision import transforms 12 | from torchvision.transforms import ToTensor 13 | 14 | 15 | class CIFARCORRUPT(Dataset): 16 | # 
CIFAR10CORRUPT and CIFAR100CORRUPT are the same, only the root changes 17 | def __init__(self, root, corruption='gaussian_noise', severity=3, transform=ToTensor()): 18 | self.root = root 19 | self.corruption = corruption # e.g. 'gaussian_noise' 20 | self.severity = severity # in [1, 2, 3, 4, 5] 21 | self.transform = transform 22 | data = np.load(f'{root}/{corruption}.npy') 23 | self.labels = np.load(f'{root}/labels.npy') 24 | 25 | # Only load images with the specified severity level 26 | start_index = (severity - 1) * 10000 27 | end_index = severity * 10000 28 | self.data = data[start_index:end_index] 29 | self.labels = self.labels[start_index:end_index] 30 | 31 | def __len__(self): 32 | return len(self.data) 33 | 34 | def __getitem__(self, index): 35 | image = self.data[index] 36 | label = self.labels[index] 37 | 38 | if self.transform: 39 | image = self.transform(image) 40 | 41 | return image, label 42 | 43 | 44 | means = { 45 | 'cifar10': (0.4914, 0.4822, 0.4465), 46 | 'cifar100': (0.5071, 0.4867, 0.4408), 47 | 'imagenet': (0.485, 0.456, 0.406), 48 | 'tinyimagenet': (0.485, 0.456, 0.406), 49 | } 50 | 51 | stds = { 52 | 'cifar10': (0.2023, 0.1994, 0.2010), 53 | 'cifar100': (0.2675, 0.2565, 0.2761), 54 | 'imagenet': (0.229, 0.224, 0.225), 55 | 'tinyimagenet': (0.229, 0.224, 0.225), 56 | } 57 | 58 | datasetDict = { # Links dataset names to actual torch datasets 59 | 'mnist': getattr(torchvision.datasets, 'MNIST'), 60 | 'cifar10': getattr(torchvision.datasets, 'CIFAR10'), 61 | 'fashionMNIST': getattr(torchvision.datasets, 'FashionMNIST'), 62 | 'SVHN': getattr(torchvision.datasets, 'SVHN'), # This needs scipy 63 | 'STL10': getattr(torchvision.datasets, 'STL10'), 64 | 'cifar100': getattr(torchvision.datasets, 'CIFAR100'), 65 | 'imagenet': getattr(torchvision.datasets, 'ImageNet'), 66 | 'tinyimagenet': getattr(torchvision.datasets, 'ImageFolder'), 67 | 'CIFAR10CORRUPT': CIFARCORRUPT, 68 | 'CIFAR100CORRUPT': CIFARCORRUPT, 69 | } 70 | 71 | trainTransformDict = { # Links dataset names to train dataset transformers 72 | 'mnist': transforms.Compose([transforms.ToTensor()]), 73 | 'cifar10': transforms.Compose([ 74 | transforms.RandomCrop(32, padding=4), 75 | transforms.RandomHorizontalFlip(), 76 | transforms.ToTensor(), 77 | transforms.Normalize(mean=means['cifar10'], std=stds['cifar10']), ]), 78 | 'cifar100': transforms.Compose([ 79 | transforms.RandomCrop(32, padding=4), 80 | transforms.RandomHorizontalFlip(), 81 | transforms.RandomRotation(15), 82 | transforms.ToTensor(), 83 | transforms.Normalize(mean=means['cifar100'], std=stds['cifar100']), ]), 84 | 'imagenet': transforms.Compose([ 85 | transforms.RandomResizedCrop(224), 86 | transforms.RandomHorizontalFlip(), 87 | transforms.ToTensor(), 88 | transforms.Normalize(mean=means['imagenet'], std=stds['imagenet']), ]), 89 | 'tinyimagenet': transforms.Compose([ 90 | transforms.RandomResizedCrop(224), 91 | transforms.RandomHorizontalFlip(), 92 | transforms.ToTensor(), 93 | transforms.Normalize(mean=means['tinyimagenet'], std=stds['tinyimagenet']), ]), 94 | } 95 | testTransformDict = { # Links dataset names to test dataset transformers 96 | 'mnist': transforms.Compose([transforms.ToTensor()]), 97 | 'cifar10': transforms.Compose([ 98 | transforms.ToTensor(), 99 | transforms.Normalize(mean=means['cifar10'], std=stds['cifar10']), ]), 100 | 'cifar100': transforms.Compose([ 101 | transforms.ToTensor(), 102 | transforms.Normalize(mean=means['cifar100'], std=stds['cifar100']), ]), 103 | 'imagenet': transforms.Compose([ 104 | transforms.Resize(256), 105 | 
transforms.CenterCrop(224), 106 | transforms.ToTensor(), 107 | transforms.Normalize(mean=means['imagenet'], std=stds['imagenet']), ]), 108 | 'tinyimagenet': transforms.Compose([ 109 | transforms.Resize(256), 110 | transforms.CenterCrop(224), 111 | transforms.ToTensor(), 112 | transforms.Normalize(mean=means['tinyimagenet'], std=stds['tinyimagenet']), ]), 113 | } 114 | -------------------------------------------------------------------------------- /models/cifar100.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: models/cifar100.py 5 | # Description: CIFAR-100 Models 6 | # =========================================================================== 7 | 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | 12 | class WideResNet20(nn.Module): 13 | # WideResNet implementation but with widen_factor=2 and depth=22 instead of 10 and 28 respectively. 14 | def __init__(self, depth=22, widen_factor=10, dropout_rate=0.3, num_classes=100): 15 | super(WideResNet20, self).__init__() 16 | self.in_planes = 16 17 | 18 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' 19 | n = (depth - 4) / 6 20 | k = widen_factor 21 | 22 | nStages = [16, 16 * k, 32 * k, 64 * k] 23 | 24 | class wide_basic(nn.Module): 25 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 26 | super(wide_basic, self).__init__() 27 | self.bn1 = nn.BatchNorm2d(in_planes) 28 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 29 | 30 | self.dropout = nn.Dropout(p=dropout_rate) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 33 | 34 | self.shortcut = nn.Sequential() 35 | if stride != 1 or in_planes != planes: 36 | self.shortcut = nn.Sequential( 37 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 38 | ) 39 | 40 | def forward(self, x): 41 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 42 | out = self.conv2(F.relu(self.bn2(out))) 43 | out += self.shortcut(x) 44 | 45 | return out 46 | 47 | self.conv1 = self.conv3x3(3, nStages[0]) 48 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 49 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 50 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 51 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 52 | self.linear = nn.Linear(nStages[3], num_classes) 53 | 54 | def conv3x3(self, in_planes, out_planes, stride=1): 55 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 56 | 57 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 58 | strides = [stride] + [1] * (int(num_blocks) - 1) 59 | layers = [] 60 | 61 | for stride in strides: 62 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 63 | self.in_planes = planes 64 | 65 | return nn.Sequential(*layers) 66 | 67 | def forward(self, x): 68 | out = self.conv1(x) 69 | out = self.layer1(out) 70 | out = self.layer2(out) 71 | out = self.layer3(out) 72 | out = F.relu(self.bn1(out)) 73 | out = F.avg_pool2d(out, 8) 74 | out = out.view(out.size(0), -1) 75 | out = self.linear(out) 76 | 77 | return out 78 | 79 | 80 | class 
WideResNet(nn.Module): 81 | # Based on https://github.com/meliketoy/wide-resnet.pytorch 82 | def __init__(self, depth=28, widen_factor=10, dropout_rate=0.3, num_classes=100): 83 | super(WideResNet, self).__init__() 84 | self.in_planes = 16 85 | 86 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' 87 | n = (depth - 4) / 6 88 | k = widen_factor 89 | 90 | nStages = [16, 16 * k, 32 * k, 64 * k] 91 | 92 | class wide_basic(nn.Module): 93 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 94 | super(wide_basic, self).__init__() 95 | self.bn1 = nn.BatchNorm2d(in_planes) 96 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 97 | 98 | self.dropout = nn.Dropout(p=dropout_rate) 99 | self.bn2 = nn.BatchNorm2d(planes) 100 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 101 | 102 | self.shortcut = nn.Sequential() 103 | if stride != 1 or in_planes != planes: 104 | self.shortcut = nn.Sequential( 105 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 106 | ) 107 | 108 | def forward(self, x): 109 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 110 | out = self.conv2(F.relu(self.bn2(out))) 111 | out += self.shortcut(x) 112 | 113 | return out 114 | 115 | self.conv1 = self.conv3x3(3, nStages[0]) 116 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 117 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 118 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 119 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 120 | self.linear = nn.Linear(nStages[3], num_classes) 121 | 122 | def conv3x3(self, in_planes, out_planes, stride=1): 123 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 124 | 125 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 126 | strides = [stride] + [1] * (int(num_blocks) - 1) 127 | layers = [] 128 | 129 | for stride in strides: 130 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 131 | self.in_planes = planes 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | out = self.conv1(x) 137 | out = self.layer1(out) 138 | out = self.layer2(out) 139 | out = self.layer3(out) 140 | out = F.relu(self.bn1(out)) 141 | out = F.avg_pool2d(out, 8) 142 | out = out.view(out.size(0), -1) 143 | out = self.linear(out) 144 | 145 | return out 146 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: main.py 5 | # Description: Starts up a run 6 | # =========================================================================== 7 | 8 | import os 9 | import shutil 10 | import socket 11 | import sys 12 | import tempfile 13 | from contextlib import contextmanager 14 | 15 | import torch 16 | import wandb 17 | 18 | from runners.ensembleRunner import ensembleRunner 19 | from runners.pretrainedRunner import pretrainedRunner 20 | from runners.scratchRunner import scratchRunner 21 | 22 | from utilities.utilities import Utilities as Utils 23 | 24 | debug = "--debug" in sys.argv 25 | defaults = dict( 26 | # System 27 | run_id=1, # The run 
id, determines the original random seed 28 | computer=socket.gethostname(), # The computer that runs the experiment 29 | 30 | # Setup 31 | dataset='mnist', # The dataset to use, see config.py for available options 32 | arch='Simple', # The architecture to use, see models/ for available options 33 | n_epochs=2, # The number of epochs to pretrain the model for (Note: this only controls the pretraining) 34 | batch_size=1028, # The batch size to use 35 | 36 | # Efficiency 37 | use_amp=True, # Whether to use automatic mixed precision 38 | 39 | # Optimizer 40 | optimizer='SGD', # The optimizer to use for pretraining/retraining, currently only SGD implemented 41 | learning_rate='(Linear, 0.1)', # The learning rate to use for pretraining 42 | n_epochs_warmup=None, # The number of epochs to warmup the lr, must be an int 43 | momentum=0.9, # The momentum to use for the optimizer 44 | weight_decay=0.0001, # The weight decay to use for the optimizer 45 | 46 | # Sparsifying strategy 47 | strategy='Dense', # The strategy to use, see strategies/ for available options. 'Dense' = pretraining, 'IMP' = iterative magnitude pruning, 'Ensemble' = ensembl/soup methods 48 | goal_sparsity=0.9, # The goal sparsity to reach after n_phases many prune-retrain cycles 49 | pruning_selector='global', # Pruning allocation, must be in ['global', 'uniform', 'random'] 50 | 51 | # Retraining 52 | phase=1, # The current phase of IMP/Ensemble 53 | n_phases=1, # The total number of phases to run 54 | n_epochs_per_phase=1, # The number of epochs to retrain for each phase 55 | retrain_schedule='LLR', # The retrain lr schedule, must be in ['FT', 'LRW', 'SLR', 'CLR', 'LLR', 'ALLR'] 56 | 57 | # Ensemble method 58 | ensemble_method='UniformEnsembling', # The ensemble/soup method to use, must be in ['UniformEnsembling', 'GreedySoup'] 59 | ensemble_by='pruned_seed', # The parameter controlling what is varied during retraining, must be in ['pruned_seed', 'weight_decay', 'retrain_length', 'retrain_schedule'] 60 | split_val=None, # The value to split the ensemble_by parameter on, e.g. 
ensemble_by='weight_decay' and split_val=0.0001 will retrain with a weight decay of 0.0001 61 | n_splits_total=2, # The total number of splits we expect to have, will raise an error if more models to average are found 62 | bn_recalibration_frac=0.2, # The fraction of the dataset to use for recalibrating the batch norm layers, must be in [0,1] 63 | ) 64 | 65 | if not debug: 66 | # Set everything to None recursively 67 | defaults = Utils.fill_dict_with_none(defaults) 68 | 69 | # Add the hostname to the defaults 70 | defaults['computer'] = socket.gethostname() 71 | 72 | # Configure wandb logging 73 | wandb.init( 74 | config=defaults, 75 | project='test-000', # automatically changed in sweep 76 | entity=None, # automatically changed in sweep 77 | ) 78 | config = wandb.config 79 | config = Utils.update_config_with_default(config, defaults) 80 | ngpus = torch.cuda.device_count() 81 | if ngpus > 0: 82 | config.update(dict(device='cuda:0')) 83 | else: 84 | config.update(dict(device='cpu')) 85 | 86 | 87 | @contextmanager 88 | def tempdir(): 89 | tmp_root = '/scratch/local/' 90 | tmp_path = os.path.join(tmp_root, 'tmp') 91 | if os.path.isdir(tmp_root): 92 | if not os.path.isdir(tmp_path): os.mkdir(tmp_path) 93 | path = tempfile.mkdtemp(dir=tmp_path) 94 | else: 95 | path = tempfile.mkdtemp() 96 | try: 97 | yield path 98 | finally: 99 | try: 100 | shutil.rmtree(path) 101 | sys.stdout.write(f"Removed temporary directory {path}.\n") 102 | except IOError: 103 | sys.stderr.write('Failed to clean up temp dir {}'.format(path)) 104 | 105 | 106 | with tempdir() as tmp_dir: 107 | # At the moment, IMP is the only strategy that requires a pretrained model, all others start from scratch 108 | config.update({'tmp_dir': tmp_dir}) 109 | 110 | if config.strategy == 'Ensemble': 111 | runner = ensembleRunner(config=config) 112 | elif config.strategy == 'IMP': 113 | # Use the pretrainedRunner 114 | runner = pretrainedRunner(config=config) 115 | elif config.strategy == 'Dense': 116 | # Use the scratchRunner 117 | runner = scratchRunner(config=config) 118 | runner.run() 119 | 120 | # Close wandb run 121 | wandb_dir_path = wandb.run.dir 122 | wandb.join() 123 | 124 | # Delete the local files 125 | if os.path.exists(wandb_dir_path): 126 | shutil.rmtree(wandb_dir_path) 127 | -------------------------------------------------------------------------------- /metrics/flops.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: metrics/flops.py 5 | # Description: Methods to compute Inference-FLOPS. 
Modified from https://github.com/JJGO/shrinkbench 6 | # =========================================================================== 7 | from collections import OrderedDict 8 | 9 | import numpy as np 10 | import torch 11 | 12 | 13 | @torch.no_grad() 14 | def forward_hook_applyfn(hook, model): 15 | """Modified from https://github.com/JJGO/shrinkbench""" 16 | hooks = [] 17 | 18 | def register_hook(module): 19 | if ( 20 | not isinstance(module, torch.nn.Sequential) 21 | and 22 | not isinstance(module, torch.nn.ModuleList) 23 | and 24 | not isinstance(module, torch.nn.ModuleDict) 25 | and 26 | not (module == model) 27 | ): 28 | hooks.append(module.register_forward_hook(hook)) 29 | 30 | return register_hook, hooks 31 | 32 | 33 | @torch.no_grad() 34 | def get_flops_on_activations(model, x_input): 35 | flops_on_activations = OrderedDict() 36 | FLOP_fn = { 37 | torch.nn.Conv2d: _conv2d_flops, 38 | torch.nn.Linear: _linear_flops, 39 | } 40 | 41 | def store_flops(module, input, output): 42 | if isinstance(module, torch.nn.ReLU): 43 | return 44 | assert module not in flops_on_activations, \ 45 | f"{module} already in flops_on_activations" 46 | if module.__class__ in FLOP_fn: 47 | module_flops = FLOP_fn[module.__class__](module=module, activation=input[0]) 48 | flops_on_activations[module] = int(module_flops) 49 | 50 | fn, hooks = forward_hook_applyfn(store_flops, model) 51 | model.apply(fn) 52 | with torch.no_grad(): 53 | model.eval()(x_input) 54 | 55 | for h in hooks: 56 | h.remove() 57 | 58 | return flops_on_activations 59 | 60 | 61 | @torch.no_grad() 62 | def dense_flops(in_neurons, out_neurons): 63 | """Compute the number of multiply-adds used by a Dense (Linear) layer""" 64 | return in_neurons * out_neurons 65 | 66 | 67 | @torch.no_grad() 68 | def conv2d_flops(in_channels, out_channels, input_shape, kernel_shape, 69 | padding='same', strides=1, dilation=1): 70 | """Compute the number of multiply-adds used by a Conv2D layer 71 | Args: 72 | in_channels (int): The number of channels in the layer's input 73 | out_channels (int): The number of channels in the layer's output 74 | input_shape (int, int): The spatial shape of the rank-3 input tensor 75 | kernel_shape (int, int): The spatial shape of the rank-4 kernel 76 | padding ({'same', 'valid'}): The padding used by the convolution 77 | strides (int) or (int, int): The spatial stride of the convolution; 78 | two numbers may be specified if it's different for the x and y axes 79 | dilation (int): Must be 1 for now. 
80 | Returns: 81 | int: The number of multiply-adds a direct convolution would require 82 | (i.e., no FFT, no Winograd, etc) 83 | """ 84 | # validate + sanitize input 85 | assert in_channels > 0 86 | assert out_channels > 0 87 | assert len(input_shape) == 2 88 | assert len(kernel_shape) == 2 89 | padding = padding.lower() 90 | assert padding in ('same', 'valid', 'zeros'), "Padding must be one of same|valid|zeros" 91 | try: 92 | strides = tuple(strides) 93 | except TypeError: 94 | # if one number provided, make it a 2-tuple 95 | strides = (strides, strides) 96 | assert dilation == 1 or all(d == 1 for d in dilation), "Dilation > 1 is not supported" 97 | 98 | # compute output spatial shape 99 | # based on TF computations https://stackoverflow.com/a/37674568 100 | if padding in ['same', 'zeros']: 101 | out_nrows = np.ceil(float(input_shape[0]) / strides[0]) 102 | out_ncols = np.ceil(float(input_shape[1]) / strides[1]) 103 | else: # padding == 'valid' 104 | out_nrows = np.ceil((input_shape[0] - kernel_shape[0] + 1) / strides[0]) # noqa 105 | out_ncols = np.ceil((input_shape[1] - kernel_shape[1] + 1) / strides[1]) # noqa 106 | output_shape = (int(out_nrows), int(out_ncols)) 107 | 108 | # work to compute one output spatial position 109 | nflops = in_channels * out_channels * int(np.prod(kernel_shape)) 110 | 111 | # total work = work per output position * number of output positions 112 | return nflops * int(np.prod(output_shape)) 113 | 114 | 115 | @torch.no_grad() 116 | def _conv2d_flops(module, activation): 117 | # Auxiliary func to use abstract flop computation 118 | 119 | # Drop batch & channels. Channels can be dropped since 120 | # unlike shape they have to match to in_channels 121 | input_shape = activation.shape[2:] 122 | return conv2d_flops(in_channels=module.in_channels, 123 | out_channels=module.out_channels, 124 | input_shape=input_shape, 125 | kernel_shape=module.kernel_size, 126 | padding=module.padding_mode, 127 | strides=module.stride, 128 | dilation=module.dilation) 129 | 130 | 131 | @torch.no_grad() 132 | def _linear_flops(module, activation): 133 | # Auxiliary func to use abstract flop computation 134 | return dense_flops(module.in_features, module.out_features) 135 | 136 | 137 | @torch.no_grad() 138 | def flops(model, x_input): 139 | """Compute Multiply-add FLOPs estimate from model 140 | Arguments: 141 | model {torch.nn.Module} -- Module to compute flops for 142 | x_input {torch.Tensor} -- Input tensor needed for activations 143 | Returns: 144 | tuple: 145 | - int - Number of total FLOPs 146 | - int - Number of FLOPs related to nonzero parameters 147 | """ 148 | 149 | total_flops = nonzero_flops = 0 150 | flops_on_activations = get_flops_on_activations(model, x_input) 151 | 152 | # The ones we need for backprop 153 | for m, module_flops in flops_on_activations.items(): 154 | total_flops += module_flops 155 | # For our operations, all weights are symmetric so we can just 156 | # do simple rule of three for the estimation 157 | nonzero_flops += module_flops * float(torch.sum(m.weight != 0.0)) / float(m.weight.numel()) 158 | 159 | return int(total_flops), int(nonzero_flops) 160 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### JetBrains template 2 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 3 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 4 | 
5 | .idea/ 6 | datasets_pytorch/ 7 | wandb/ 8 | 9 | # User-specific stuff 10 | .idea/**/workspace.xml 11 | .idea/**/tasks.xml 12 | .idea/**/usage.statistics.xml 13 | .idea/**/dictionaries 14 | .idea/**/shelf 15 | 16 | # AWS User-specific 17 | .idea/**/aws.xml 18 | 19 | # Generated files 20 | .idea/**/contentModel.xml 21 | 22 | # Sensitive or high-churn files 23 | .idea/**/dataSources/ 24 | .idea/**/dataSources.ids 25 | .idea/**/dataSources.local.xml 26 | .idea/**/sqlDataSources.xml 27 | .idea/**/dynamic.xml 28 | .idea/**/uiDesigner.xml 29 | .idea/**/dbnavigator.xml 30 | 31 | # Gradle 32 | .idea/**/gradle.xml 33 | .idea/**/libraries 34 | 35 | # Gradle and Maven with auto-import 36 | # When using Gradle or Maven with auto-import, you should exclude module files, 37 | # since they will be recreated, and may cause churn. Uncomment if using 38 | # auto-import. 39 | # .idea/artifacts 40 | # .idea/compiler.xml 41 | # .idea/jarRepositories.xml 42 | # .idea/modules.xml 43 | # .idea/*.iml 44 | # .idea/modules 45 | # *.iml 46 | # *.ipr 47 | 48 | # CMake 49 | cmake-build-*/ 50 | 51 | # Mongo Explorer plugin 52 | .idea/**/mongoSettings.xml 53 | 54 | # File-based project format 55 | *.iws 56 | 57 | # IntelliJ 58 | out/ 59 | 60 | # mpeltonen/sbt-idea plugin 61 | .idea_modules/ 62 | 63 | # JIRA plugin 64 | atlassian-ide-plugin.xml 65 | 66 | # Cursive Clojure plugin 67 | .idea/replstate.xml 68 | 69 | # SonarLint plugin 70 | .idea/sonarlint/ 71 | 72 | # Crashlytics plugin (for Android Studio and IntelliJ) 73 | com_crashlytics_export_strings.xml 74 | crashlytics.properties 75 | crashlytics-build.properties 76 | fabric.properties 77 | 78 | # Editor-based Rest Client 79 | .idea/httpRequests 80 | 81 | # Android studio 3.1+ serialized cache file 82 | .idea/caches/build_file_checksums.ser 83 | 84 | ### PyCharm template 85 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 86 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 87 | 88 | # User-specific stuff 89 | .idea/**/workspace.xml 90 | .idea/**/tasks.xml 91 | .idea/**/usage.statistics.xml 92 | .idea/**/dictionaries 93 | .idea/**/shelf 94 | 95 | # AWS User-specific 96 | .idea/**/aws.xml 97 | 98 | # Generated files 99 | .idea/**/contentModel.xml 100 | 101 | # Sensitive or high-churn files 102 | .idea/**/dataSources/ 103 | .idea/**/dataSources.ids 104 | .idea/**/dataSources.local.xml 105 | .idea/**/sqlDataSources.xml 106 | .idea/**/dynamic.xml 107 | .idea/**/uiDesigner.xml 108 | .idea/**/dbnavigator.xml 109 | 110 | # Gradle 111 | .idea/**/gradle.xml 112 | .idea/**/libraries 113 | 114 | # Gradle and Maven with auto-import 115 | # When using Gradle or Maven with auto-import, you should exclude module files, 116 | # since they will be recreated, and may cause churn. Uncomment if using 117 | # auto-import. 
118 | # .idea/artifacts 119 | # .idea/compiler.xml 120 | # .idea/jarRepositories.xml 121 | # .idea/modules.xml 122 | # .idea/*.iml 123 | # .idea/modules 124 | # *.iml 125 | # *.ipr 126 | 127 | # CMake 128 | cmake-build-*/ 129 | 130 | # Mongo Explorer plugin 131 | .idea/**/mongoSettings.xml 132 | 133 | # File-based project format 134 | *.iws 135 | 136 | # IntelliJ 137 | out/ 138 | 139 | # mpeltonen/sbt-idea plugin 140 | .idea_modules/ 141 | 142 | # JIRA plugin 143 | atlassian-ide-plugin.xml 144 | 145 | # Cursive Clojure plugin 146 | .idea/replstate.xml 147 | 148 | # SonarLint plugin 149 | .idea/sonarlint/ 150 | 151 | # Crashlytics plugin (for Android Studio and IntelliJ) 152 | com_crashlytics_export_strings.xml 153 | crashlytics.properties 154 | crashlytics-build.properties 155 | fabric.properties 156 | 157 | # Editor-based Rest Client 158 | .idea/httpRequests 159 | 160 | # Android studio 3.1+ serialized cache file 161 | .idea/caches/build_file_checksums.ser 162 | 163 | ### Python template 164 | # Byte-compiled / optimized / DLL files 165 | __pycache__/ 166 | *.py[cod] 167 | *$py.class 168 | 169 | # C extensions 170 | *.so 171 | 172 | # Distribution / packaging 173 | .Python 174 | build/ 175 | develop-eggs/ 176 | dist/ 177 | downloads/ 178 | eggs/ 179 | .eggs/ 180 | lib/ 181 | lib64/ 182 | parts/ 183 | sdist/ 184 | var/ 185 | wheels/ 186 | share/python-wheels/ 187 | *.egg-info/ 188 | .installed.cfg 189 | *.egg 190 | MANIFEST 191 | 192 | # PyInstaller 193 | # Usually these files are written by a python script from a template 194 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 195 | *.manifest 196 | *.spec 197 | 198 | # Installer logs 199 | pip-log.txt 200 | pip-delete-this-directory.txt 201 | 202 | # Unit test / coverage reports 203 | htmlcov/ 204 | .tox/ 205 | .nox/ 206 | .coverage 207 | .coverage.* 208 | .cache 209 | nosetests.xml 210 | coverage.xml 211 | *.cover 212 | *.py,cover 213 | .hypothesis/ 214 | .pytest_cache/ 215 | cover/ 216 | 217 | # Translations 218 | *.mo 219 | *.pot 220 | 221 | # Django stuff: 222 | *.log 223 | local_settings.py 224 | db.sqlite3 225 | db.sqlite3-journal 226 | 227 | # Flask stuff: 228 | instance/ 229 | .webassets-cache 230 | 231 | # Scrapy stuff: 232 | .scrapy 233 | 234 | # Sphinx documentation 235 | docs/_build/ 236 | 237 | # PyBuilder 238 | .pybuilder/ 239 | target/ 240 | 241 | # Jupyter Notebook 242 | .ipynb_checkpoints 243 | 244 | # IPython 245 | profile_default/ 246 | ipython_config.py 247 | 248 | # pyenv 249 | # For a library or package, you might want to ignore these files since the code is 250 | # intended to run in multiple environments; otherwise, check them in: 251 | # .python-version 252 | 253 | # pipenv 254 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 255 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 256 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 257 | # install all needed dependencies. 258 | #Pipfile.lock 259 | 260 | # poetry 261 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 262 | # This is especially recommended for binary packages to ensure reproducibility, and is more 263 | # commonly ignored for libraries. 
264 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 265 | #poetry.lock 266 | 267 | # pdm 268 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 269 | #pdm.lock 270 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 271 | # in version control. 272 | # https://pdm.fming.dev/#use-with-ide 273 | .pdm.toml 274 | 275 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 276 | __pypackages__/ 277 | 278 | # Celery stuff 279 | celerybeat-schedule 280 | celerybeat.pid 281 | 282 | # SageMath parsed files 283 | *.sage.py 284 | 285 | # Environments 286 | .env 287 | .venv 288 | env/ 289 | venv/ 290 | ENV/ 291 | env.bak/ 292 | venv.bak/ 293 | 294 | # Spyder project settings 295 | .spyderproject 296 | .spyproject 297 | 298 | # Rope project settings 299 | .ropeproject 300 | 301 | # mkdocs documentation 302 | /site 303 | 304 | # mypy 305 | .mypy_cache/ 306 | .dmypy.json 307 | dmypy.json 308 | 309 | # Pyre type checker 310 | .pyre/ 311 | 312 | # pytype static type analyzer 313 | .pytype/ 314 | 315 | # Cython debug symbols 316 | cython_debug/ 317 | 318 | # PyCharm 319 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 320 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 321 | # and can be added to the global gitignore or merged into this file. For a more nuclear 322 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 323 | #.idea/ 324 | 325 | -------------------------------------------------------------------------------- /strategies/strategies.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: strategies/strategies.py 5 | # Description: Sparsification strategies for regular training 6 | # =========================================================================== 7 | import sys 8 | from collections import OrderedDict 9 | 10 | import torch 11 | import torch.nn.utils.prune as prune 12 | 13 | 14 | #### Dense Base Class 15 | class Dense: 16 | """Dense base class for defining callbacks, does nothing but showing the structure and inherits.""" 17 | required_params = [] 18 | 19 | def __init__(self, **kwargs): 20 | self.masks = dict() 21 | self.lr_dict = OrderedDict() # it:lr 22 | self.is_in_finetuning_phase = False 23 | 24 | self.model = kwargs['model'] 25 | self.run_config = kwargs['config'] 26 | self.callbacks = kwargs['callbacks'] 27 | self.goal_sparsity = self.run_config['goal_sparsity'] 28 | 29 | self.optimizer = None # To be set 30 | self.n_total_iterations = None 31 | 32 | def after_initialization(self): 33 | """Called after initialization of the strategy""" 34 | self.parameters_to_prune = [(module, 'weight') for name, module in self.model.named_modules() if 35 | hasattr(module, 'weight') 36 | and not isinstance(module.weight, type(None)) and not isinstance(module, 37 | torch.nn.BatchNorm2d)] 38 | self.n_prunable_parameters = sum( 39 | getattr(module, param_type).numel() for module, param_type in self.parameters_to_prune) 40 | 41 | def set_optimizer(self, opt, **kwargs): 42 | self.optimizer = opt 43 | if 'n_total_iterations' in kwargs: 44 | 
self.n_total_iterations = kwargs['n_total_iterations'] 45 | 46 | @torch.no_grad() 47 | def after_training_iteration(self, **kwargs): 48 | """Called after each training iteration""" 49 | if not self.is_in_finetuning_phase: 50 | self.lr_dict[kwargs['it']] = kwargs['lr'] 51 | 52 | def at_train_begin(self): 53 | """Called before training begins""" 54 | pass 55 | 56 | def at_epoch_start(self, **kwargs): 57 | """Called before the epoch starts""" 58 | pass 59 | 60 | def at_epoch_end(self, **kwargs): 61 | """Called at epoch end""" 62 | pass 63 | 64 | def at_train_end(self, **kwargs): 65 | """Called at the end of training""" 66 | pass 67 | 68 | def final(self): 69 | pass 70 | 71 | @torch.no_grad() 72 | def pruning_step(self, pruning_sparsity, only_save_mask=False, compute_from_scratch=False): 73 | if compute_from_scratch: 74 | # We have to revert to weight_orig and then compute the mask 75 | for module, param_type in self.parameters_to_prune: 76 | if prune.is_pruned(module): 77 | # Enforce the equivalence of weight_orig and weight 78 | orig = getattr(module, param_type + "_orig").detach().clone() 79 | prune.remove(module, param_type) 80 | p = getattr(module, param_type) 81 | p.copy_(orig) 82 | del orig 83 | elif only_save_mask and len(self.masks) > 0: 84 | for module, param_type in self.parameters_to_prune: 85 | if (module, param_type) in self.masks: 86 | prune.custom_from_mask(module, param_type, self.masks[(module, param_type)]) 87 | 88 | if self.run_config['pruning_selector'] is not None and self.run_config['pruning_selector'] == 'uniform': 89 | # We prune each layer individually 90 | for module, param_type in self.parameters_to_prune: 91 | prune.l1_unstructured(module, name=param_type, amount=pruning_sparsity) 92 | else: 93 | # Default: prune globally 94 | prune.global_unstructured( 95 | self.parameters_to_prune, 96 | pruning_method=self.get_pruning_method(), 97 | amount=pruning_sparsity, 98 | ) 99 | 100 | self.masks = dict() # Stays empty if we use regular pruning 101 | if only_save_mask: 102 | for module, param_type in self.parameters_to_prune: 103 | if prune.is_pruned(module): 104 | # Save the mask 105 | mask = getattr(module, param_type + '_mask') 106 | self.masks[(module, param_type)] = mask.detach().clone() 107 | setattr(module, param_type + '_mask', torch.ones_like(mask)) 108 | # Remove (i.e. make permanent) the reparameterization 109 | prune.remove(module=module, name=param_type) 110 | # Delete the temporary mask to free memory 111 | del mask 112 | 113 | def enforce_prunedness(self): 114 | """ 115 | Makes the pruning permanent, i.e. set the pruned weights to zero, than reinitialize from the same mask 116 | This ensures that we can actually work (i.e. LMO, rescale computation) with the parameters 117 | Important: For this to work we require that pruned weights stay zero in weight_orig over training 118 | hence training, projecting etc should not modify (pruned) 0 weights in weight_orig 119 | """ 120 | for module, param_type in self.parameters_to_prune: 121 | if prune.is_pruned(module): 122 | # Save the mask 123 | mask = getattr(module, param_type + '_mask') 124 | # Remove (i.e. 
make permanent) the reparameterization 125 | prune.remove(module=module, name=param_type) 126 | # Reinitialize the pruning 127 | prune.custom_from_mask(module=module, name=param_type, mask=mask) 128 | # Delete the temporary mask to free memory 129 | del mask 130 | 131 | def prune_momentum(self): 132 | opt_state = self.optimizer.state 133 | for module, param_type in self.parameters_to_prune: 134 | if prune.is_pruned(module): 135 | # Enforce the prunedness of momentum buffer 136 | param_state = opt_state[getattr(module, param_type + "_orig")] 137 | if 'momentum_buffer' in param_state: 138 | mask = getattr(module, param_type + "_mask") 139 | param_state['momentum_buffer'] *= mask.to(dtype=param_state['momentum_buffer'].dtype) 140 | 141 | def get_pruning_method(self): 142 | raise NotImplementedError("Dense has no pruning method, this must be implemented in each child class.") 143 | 144 | @torch.no_grad() 145 | def make_pruning_permanent(self): 146 | """Makes the pruning permanent and removes the pruning hooks""" 147 | # Note: this does not remove the pruning itself, but rather makes it permanent 148 | if len(self.masks) == 0: 149 | for module, param_type in self.parameters_to_prune: 150 | if prune.is_pruned(module): 151 | prune.remove(module, param_type) 152 | else: 153 | for module, param_type in self.masks: 154 | # Get the mask 155 | mask = self.masks[(module, param_type)] 156 | 157 | # Apply the mask 158 | orig = getattr(module, param_type) 159 | orig *= mask 160 | self.masks = dict() 161 | 162 | def set_to_finetuning_phase(self): 163 | self.is_in_finetuning_phase = True 164 | 165 | 166 | class IMP(Dense): 167 | """Iterative Magnitude Pruning Base Class""" 168 | 169 | def __init__(self, **kwargs) -> None: 170 | super().__init__(**kwargs) 171 | 172 | self.phase = self.run_config['phase'] 173 | self.n_phases = self.run_config['n_phases'] 174 | self.n_epochs_per_phase = self.run_config['n_epochs_per_phase'] 175 | 176 | def at_train_end(self, **kwargs): 177 | # Sparsity factor on remaining weights after each round, yields desired_sparsity after all rounds 178 | prune_per_phase = 1 - (1 - self.goal_sparsity) ** (1. 
/ self.n_phases) 179 | phase = self.phase 180 | self.pruning_step(pruning_sparsity=prune_per_phase) 181 | self.current_sparsity = 1 - (1 - prune_per_phase) ** phase 182 | self.callbacks['after_pruning_callback']() 183 | self.finetuning_step(pruning_sparsity=prune_per_phase, phase=phase) 184 | 185 | def finetuning_step(self, pruning_sparsity, phase): 186 | self.callbacks['finetuning_callback'](pruning_sparsity=pruning_sparsity, 187 | n_epochs_finetune=self.n_epochs_per_phase, 188 | phase=phase) 189 | 190 | def get_pruning_method(self): 191 | if self.run_config['pruning_selector'] in ['global', 'uniform']: 192 | # For uniform this is not actually needed, we always select using L1 193 | return prune.L1Unstructured 194 | elif self.run_config['pruning_selector'] == 'random': 195 | return prune.RandomUnstructured 196 | else: 197 | raise NotImplementedError 198 | 199 | def final(self): 200 | super().final() 201 | self.callbacks['final_log_callback']() 202 | -------------------------------------------------------------------------------- /strategies/ensembleStrategies.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: strategies/ensembleStrategies.py 5 | # Description: Strategies for building a soup. 6 | # =========================================================================== 7 | import sys 8 | from collections import OrderedDict 9 | from collections import defaultdict 10 | 11 | import numpy as np 12 | import torch 13 | 14 | from strategies import strategies as usual_strategies 15 | from utilities.utilities import Candidate 16 | from utilities.utilities import Utilities as Utils 17 | 18 | 19 | #### Base Class 20 | class EnsemblingBaseClass(usual_strategies.Dense): 21 | """Ensembling Base Class""" 22 | 23 | def __init__(self, **kwargs) -> None: 24 | super().__init__(**kwargs) 25 | self.candidate_model_list = kwargs['candidate_models'] 26 | self.runner = kwargs['runner'] 27 | self.selected_models = None 28 | self.soup_metrics = {soup_type: {} for soup_type in ['candidates', 'selected']} 29 | 30 | @torch.no_grad() 31 | def get_soup_metrics(self, soup_list: list[Candidate]): 32 | 33 | # Load the models 34 | model_list = [candidate.get_model_weights() for candidate in soup_list] 35 | 36 | soup_metrics = { 37 | 'max_barycentre_distance': Utils.get_barycentre_l2_distance(model_list), 38 | 'min_barycentre_distance': Utils.get_barycentre_l2_distance(model_list, maximize=False), 39 | } 40 | 41 | for metric_name, metric_fn in zip(['l2_distance', 'angle'], [Utils.get_l2_distance, Utils.get_angle]): 42 | for agg_name, agg_fn in zip(['max', 'min', 'mean'], [torch.max, torch.min, torch.mean]): 43 | soup_metrics[f'{agg_name}_{metric_name}'] = Utils.aggregate_group_metrics(models=model_list, 44 | metric_fn=metric_fn, 45 | aggregate_fn=agg_fn) 46 | return soup_metrics 47 | 48 | def collect_candidate_information(self): 49 | model_list = [] 50 | metrics_dict = {split: defaultdict(list) for split in ['test', 'ood']} 51 | for candidate in self.candidate_model_list: 52 | candidate_id, candidate_file = candidate.id, candidate.file 53 | if self.runner.model is not None: 54 | del self.runner.model 55 | torch.cuda.empty_cache() 56 | 57 | state_dict = torch.load(candidate_file, 58 | map_location=torch.device('cpu')) # Load to CPU to avoid memory overhead 59 
| self.runner.load_soup_model(ensemble_state_dict=state_dict) 60 | m, _ = Utils.split_weights_and_masks(state_dict) 61 | model_list.append(m) 62 | del state_dict 63 | self.runner.recalibrate_bn() 64 | 65 | # Collect and set test/ood metrics 66 | for split in ['test', 'ood']: 67 | single_model_metrics = self.runner.evaluate_soup(data=split) 68 | for metric, value in single_model_metrics.items(): 69 | metrics_dict[split][metric].append(value) 70 | candidate.set_metrics(metrics=single_model_metrics, split=split) 71 | 72 | # Collect metrics that are needed for other strategies to perform model selection 73 | single_model_val_metrics = self.runner.evaluate_soup(data='val') 74 | candidate.set_metrics(metrics=single_model_val_metrics, split='val') 75 | 76 | # Collect a lot of soup metrics 77 | candidates_soup_metrics = self.get_soup_metrics(soup_list=self.candidate_model_list) 78 | self.soup_metrics['candidates'] = candidates_soup_metrics 79 | for split in ['test', 'ood']: 80 | for aggName, aggFunc in zip(['mean', 'max'], [np.mean, np.max]): 81 | for metric, values in metrics_dict[split].items(): 82 | self.soup_metrics['candidates'][f'{split}.{metric}_{aggName}'] = aggFunc(values) 83 | 84 | # Collect prediction ensemble metrics 85 | ensemble_labels = self.runner.collect_avg_output_full(data='test', 86 | candidate_model_list=self.candidate_model_list) 87 | ensemble_metrics = { 88 | 'pred_ensemble.test': self.runner.evaluate_soup(data='test', ensemble_labels=ensemble_labels)} 89 | self.soup_metrics['candidates'].update(ensemble_metrics) 90 | 91 | sys.stdout.write(f"Test accuracies of ensemble runs: {metrics_dict['test']['accuracy']}.\n") 92 | 93 | def create_ensemble(self, **kwargs): 94 | n_models = len(self.candidate_model_list) 95 | assert n_models >= 2, "Not enough models to ensemble" 96 | self.enforce_prunedness() 97 | 98 | @torch.no_grad() 99 | def enforce_prunedness(self, device=torch.device('cpu')): 100 | """Enforce prunedness of the model""" 101 | for candidate in self.candidate_model_list: 102 | candidate.enforce_prunedness(device=device) 103 | 104 | @torch.no_grad() 105 | def average_models(self, soup_list: list[Candidate], soup_weights: torch.Tensor = None, 106 | device: torch.device = torch.device('cpu')): 107 | if soup_weights is None: 108 | soup_weights = torch.ones(len(soup_list)) / len(soup_list) 109 | ensemble_state_dict = OrderedDict() 110 | 111 | for idx, candidate in enumerate(soup_list): 112 | candidate_id, candidate_file = candidate.id, candidate.file 113 | state_dict = torch.load(candidate_file, map_location=device) 114 | for key, val in state_dict.items(): 115 | factor = soup_weights[idx].item() # No need to use tensor here 116 | if '_mask' in key: 117 | # We dont want to average the masks, hence we skip them and add later 118 | continue 119 | if key not in ensemble_state_dict: 120 | ensemble_state_dict[ 121 | key] = factor * val.detach().clone() # Important: clone otherwise we modify the tensors 122 | else: 123 | ensemble_state_dict[ 124 | key] += factor * val.detach().clone() # Important: clone otherwise we modify the tensors 125 | 126 | # Add the masks from the last state_dict 127 | for key, val in state_dict.items(): 128 | if '_mask' in key: 129 | ensemble_state_dict[key] = val.detach().clone() 130 | 131 | return ensemble_state_dict 132 | 133 | def final(self): 134 | self.callbacks['final_log_callback']() 135 | 136 | def get_ensemble_metrics(self): 137 | if self.selected_models == 'all': 138 | # We have already collected the metrics for all models 139 | 
self.soup_metrics['selected'] = self.soup_metrics['candidates'] 140 | else: 141 | assert self.selected_models is not None and len(self.selected_models) > 0, "No models selected for metrics." 142 | # Collect individual metrics for the selected models, which we already have 143 | metrics_dict = defaultdict(lambda 144 | : defaultdict(list)) 145 | for split in ['test', 'ood']: 146 | for candidate in self.selected_models: 147 | single_model_metrics = candidate.get_metrics(split=split) 148 | for metric, value in single_model_metrics.items(): 149 | metrics_dict[split][metric].append(value) 150 | for aggName, aggFunc in zip(['mean', 'max'], [np.mean, np.max]): 151 | for metric, values in metrics_dict[split].items(): 152 | self.soup_metrics['selected'][f'{split}.{metric}_{aggName}'] = aggFunc(values) 153 | 154 | # Collect group_metrics for the selected models 155 | group_metrics = self.get_soup_metrics(soup_list=self.selected_models) 156 | self.soup_metrics['selected'].update(group_metrics) 157 | 158 | # Collect prediction ensemble metrics, only for test for now 159 | ensemble_labels = self.runner.collect_avg_output_full(data='test', 160 | candidate_model_list=self.selected_models) 161 | ensemble_metrics = { 162 | 'pred_ensemble.test': self.runner.evaluate_soup(data='test', ensemble_labels=ensemble_labels)} 163 | self.soup_metrics['selected'].update(ensemble_metrics) 164 | return self.soup_metrics 165 | 166 | 167 | class UniformEnsembling(EnsemblingBaseClass): 168 | """Just averages all models""" 169 | 170 | def __init__(self, **kwargs) -> None: 171 | super().__init__(**kwargs) 172 | 173 | @torch.no_grad() 174 | def create_ensemble(self, **kwargs): 175 | super().create_ensemble(**kwargs) 176 | 177 | device = torch.device('cpu') 178 | soup_weights = self.get_soup_weights(soup_list=self.candidate_model_list) 179 | ensemble_state_dict = self.average_models(soup_list=self.candidate_model_list, soup_weights=soup_weights, 180 | device=device) 181 | self.selected_models = 'all' 182 | return ensemble_state_dict 183 | 184 | def get_soup_weights(self, soup_list: list[Candidate]): 185 | uniform_factor = 1. 
/ len(soup_list) 186 | return torch.tensor([uniform_factor] * len(soup_list)) 187 | 188 | 189 | class GreedySoup(EnsemblingBaseClass): 190 | """Greedy approach""" 191 | 192 | def __init__(self, **kwargs) -> None: 193 | super().__init__(**kwargs) 194 | 195 | @torch.no_grad() 196 | def create_ensemble(self, **kwargs): 197 | super().create_ensemble(**kwargs) 198 | val_accuracies = [(candidate, candidate.get_single_metric(metric='accuracy', split='val')) 199 | for candidate in self.candidate_model_list] 200 | device = torch.device('cpu') 201 | 202 | # Sort the models by their validation accuracy in decreasing order 203 | sorted_tuples = sorted(val_accuracies, key=lambda x: x[1], reverse=True) 204 | 205 | ingredients_candidates = [sorted_tuples[0][0]] 206 | max_val_accuracy = sorted_tuples[0][1] 207 | for candidate, _ in sorted_tuples[1:]: 208 | # Check whether we benefit from adding to the soup 209 | ensemble_state_dict = self.average_models(soup_list=ingredients_candidates + [candidate], device=device) 210 | self.callbacks['load_soup_callback'](ensemble_state_dict=ensemble_state_dict) 211 | self.callbacks['recalibrate_bn_callback']() 212 | soup_metrics = self.callbacks['soup_evaluation_callback'](data='val') 213 | soup_val_accuracy = soup_metrics['accuracy'] 214 | if soup_val_accuracy >= max_val_accuracy: 215 | ingredients_candidates = ingredients_candidates + [candidate] 216 | max_val_accuracy = soup_val_accuracy 217 | 218 | self.selected_models = ingredients_candidates 219 | if len(ingredients_candidates) == len(self.candidate_model_list): 220 | self.selected_models = 'all' 221 | sys.stdout.write("GreedySoup used all candidates.\n") 222 | else: 223 | sys.stdout.write( 224 | f"GreedySoup used candidates with ids: {[candidate.id for candidate in ingredients_candidates]}.\n") 225 | final_ensemble_state_dict = self.average_models(soup_list=ingredients_candidates, device=device) 226 | return final_ensemble_state_dict 227 | -------------------------------------------------------------------------------- /runners/pretrainedRunner.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: pretrainedRunner.py 5 | # Description: Runner class for starting from a pretrained model 6 | # =========================================================================== 7 | import json 8 | import sys 9 | import warnings 10 | from collections import OrderedDict 11 | 12 | import numpy as np 13 | import torch 14 | import wandb 15 | 16 | from runners.baseRunner import baseRunner 17 | from utilities.utilities import Utilities as Utils 18 | 19 | 20 | class pretrainedRunner(baseRunner): 21 | 22 | def __init__(self, **kwargs): 23 | super().__init__(**kwargs) 24 | self.reference_run = None 25 | 26 | def find_existing_model(self, filterDict): 27 | """Finds an existing wandb run and downloads the model file.""" 28 | phase_before_current = self.config.phase - 1 29 | if phase_before_current > 0: 30 | # We specify the phase in the filterDict, because we want to find the model that was trained in the previous phase 31 | filterDict['$and'].append({'config.phase': phase_before_current}) 32 | 33 | # We specify several other identifiers 34 | identifiers = [{"config.goal_sparsity": self.config.goal_sparsity}, 35 | {"config.n_epochs_per_phase": self.config.n_epochs_per_phase}, 
36 | {"config.n_phases": self.config.n_phases}, 37 | {"config.retrain_schedule": self.config.retrain_schedule}] 38 | for identifier in identifiers: 39 | filterDict['$and'].append(identifier) 40 | 41 | sys.stdout.write( 42 | f"Specified ensemble_method {self.config.ensemble_method}, ensemble_by {self.config.ensemble_by}, split_val {self.config.split_val}.\n") 43 | filterDict['$and'].append({'config.ensemble_by': self.config.ensemble_by}) 44 | if self.config.ensemble_method not in [None, 'None', 'none']: 45 | filterDict['$and'].append({'config.strategy': 'Ensemble'}) 46 | filterDict['$and'].append({'config.ensemble_method': self.config.ensemble_method}) 47 | 48 | # We now also need to filter for n_splits_total since otherwise we use different settings 49 | sys.stdout.write(f"Looking for n_splits_total {self.config.n_splits_total}.\n") 50 | assert self.config.n_splits_total is not None 51 | filterDict['$and'].append({'config.n_splits_total': self.config.n_splits_total}) 52 | else: 53 | # No ensemble method specified, we perform regular IMP 54 | sys.stdout.write("Looking for last retrained model.\n") 55 | filterDict['$and'].append({'config.strategy': 'IMP'}) 56 | filterDict['$and'].append({'config.split_val': self.config.split_val}) 57 | else: 58 | assert self.config.n_splits_total is not None 59 | filterDict['$and'].append({'config.strategy': 'Dense'}) 60 | 61 | entity, project = wandb.run.entity, wandb.run.project 62 | api = wandb.Api() 63 | # Some variables have to be extracted from the filterDict and checked manually, e.g. weight decay in scientific format 64 | manualVariables = ['weight_decay', 'penalty', 'group_penalty'] 65 | manVarDict = {} 66 | dropIndices = [] 67 | for var in manualVariables: 68 | for i in range(len(filterDict['$and'])): 69 | entry = filterDict['$and'][i] 70 | s = f"config.{var}" 71 | if s in entry: 72 | dropIndices.append(i) 73 | manVarDict[var] = entry[s] 74 | for idx in reversed(sorted(dropIndices)): filterDict['$and'].pop(idx) 75 | 76 | checkpoint_file = None 77 | runs = api.runs(f"{entity}/{project}", filters=filterDict) 78 | runsExist = False # If True, then there exist runs that try to set a fixed init 79 | for run in runs: 80 | if run.state == 'failed': 81 | # Ignore this run 82 | continue 83 | # Check if run satisfies the manual variables 84 | conflict = False 85 | for var, val in manVarDict.items(): 86 | if var in run.config and run.config[var] != val: 87 | conflict = True 88 | break 89 | if conflict: 90 | continue 91 | 92 | checkpoint_file = run.summary.get('final_model_file') 93 | try: 94 | if checkpoint_file is not None: 95 | runsExist = True 96 | run.file(checkpoint_file).download(root=self.tmp_dir) 97 | seed = run.config['seed'] 98 | reference_run = run 99 | break 100 | except Exception as e: # The run is online, but the model is not uploaded yet -> results in failing runs 101 | print(e) 102 | checkpoint_file = None 103 | assert not ( 104 | runsExist and checkpoint_file is None), "Runs found, but none of them have a model available -> abort." 105 | outputStr = f"Found {checkpoint_file} in run {run.name}" \ 106 | if checkpoint_file is not None else "Nothing found." 107 | sys.stdout.write(f"Trying to find reference trained model in project: {outputStr}\n") 108 | assert checkpoint_file is not None, "No reference trained model found, Aborting." 
109 | return checkpoint_file, seed, reference_run 110 | 111 | def get_missing_config(self): 112 | missing_config_keys = ['momentum', 113 | 'n_epochs_warmup', 114 | 'n_epochs'] # Have to have n_epochs even though it might be specified, otherwise ALLR doesnt have this 115 | 116 | additional_dict = { 117 | 'last_training_lr': self.reference_run.summary['final.learning_rate'], 118 | 'final.test.accuracy': self.reference_run.summary['final.test']['accuracy'], 119 | 'final.train.accuracy': self.reference_run.summary['final.train']['accuracy'], 120 | 'final.train.loss': self.reference_run.summary['final.train']['loss'], 121 | } 122 | for key in missing_config_keys: 123 | if key not in self.config or self.config[key] is None: 124 | # Allow_val_change = true because e.g. momentum defaults to None, but shouldn't be passed here 125 | val = self.reference_run.config.get(key) # If not found, defaults to None 126 | self.config.update({key: val}, allow_val_change=True) 127 | self.config.update(additional_dict) 128 | 129 | self.trained_test_accuracy = additional_dict['final.test.accuracy'] 130 | self.trained_train_loss = additional_dict['final.train.loss'] 131 | self.trained_train_accuracy = additional_dict['final.train.accuracy'] 132 | 133 | def define_optimizer_scheduler(self): 134 | # Define the optimizer using the parameters from the reference run 135 | if self.config.optimizer == 'SGD': 136 | wd = self.config['weight_decay'] or 0. 137 | if self.config.ensemble_by == 'weight_decay': 138 | wd = self.config.split_val 139 | sys.stdout.write(f"We split by the weight decay. Value {wd}.\n") 140 | self.optimizer = torch.optim.SGD(params=self.model.parameters(), lr=self.config['last_training_lr'], 141 | momentum=self.config['momentum'], 142 | weight_decay=wd, 143 | nesterov=wd > 0.) 144 | 145 | def fill_strategy_information(self): 146 | # Get the wandb information about lr and fill the corresponding strategy dicts, which can then be used by rewinders 147 | f = self.reference_run.file('iteration-lr-dict.json').download(root=self.tmp_dir) 148 | with open(f.name) as json_file: 149 | loaded_dict = json.load(json_file) 150 | self.strategy.lr_dict = OrderedDict(loaded_dict) 151 | # Upload iteration-lr dict from self.strategy to be used during retraining 152 | Utils.dump_dict_to_json_wandb(dumpDict=self.strategy.lr_dict, name='iteration-lr-dict') 153 | 154 | def run(self): 155 | """Function controlling the workflow of pretrainedRunner""" 156 | # Find the reference run 157 | filterDict = {"$and": [{"config.run_id": self.config.run_id}, 158 | {"config.arch": self.config.arch}, 159 | {"config.optimizer": self.config.optimizer}, 160 | ]} 161 | 162 | assert self.config.phase is not None 163 | assert self.config.split_val is not None, "split_val has to be specified." 164 | if self.config.ensemble_by not in [None, 'None', 'none']: 165 | # We do not perform regular IMP 166 | assert self.config.ensemble_by in ['pruned_seed', 'weight_decay', 'retrain_length', 'retrain_schedule'] 167 | 168 | if self.config.learning_rate is not None: 169 | warnings.warn( 170 | "You specified an explicit learning rate for retraining. Note that this only controls the selection of the pretrained model.") 171 | filterDict["$and"].append({"config.learning_rate": self.config.learning_rate}) 172 | if self.config.n_epochs is not None: 173 | warnings.warn( 174 | "You specified n_epochs for retraining. 
Note that this only controls the selection of the pretrained model.") 175 | filterDict["$and"].append({"config.n_epochs": self.config.n_epochs}) 176 | 177 | self.checkpoint_file, self.seed, self.reference_run = self.find_existing_model(filterDict=filterDict) 178 | wandb.config.update({'seed': self.seed}) # Push the seed to wandb 179 | seed = self.seed 180 | if self.config.ensemble_by == 'pruned_seed': 181 | # We use a new seed for retraining depending on the true seed (self.seed) and the pruned_seed 182 | seed = self.seed + self.config.split_val 183 | sys.stdout.write(f"Original seed {self.seed}, new seed {seed}.\n") 184 | # Set a unique random seed 185 | np.random.seed(seed) 186 | torch.manual_seed(seed) 187 | # Remark: If you are working with a multi-GPU model, this function is insufficient to get determinism. To seed all GPUs, use manual_seed_all(). 188 | torch.cuda.manual_seed(seed) # This works if CUDA not available 189 | 190 | torch.backends.cudnn.benchmark = True 191 | self.get_missing_config() # Load keys that are missing in the config 192 | 193 | self.trainLoader, self.valLoader, self.testLoader, self.trainLoader_unshuffled = self.get_dataloaders() 194 | self.model = self.get_model(reinit=True, temporary=True) # Load the previous model 195 | 196 | self.squared_model_norm = Utils.get_model_norm_square(model=self.model) 197 | # Define strategy 198 | self.strategy = self.define_strategy() 199 | self.strategy.set_to_finetuning_phase() 200 | self.strategy.after_initialization() # To ensure that all parameters are properly set 201 | self.define_optimizer_scheduler() # This HAS to be after the definition of the strategy, otherwise changing the models parameters will not be noticed by the optimizer! 202 | self.strategy.set_optimizer(opt=self.optimizer) 203 | self.fill_strategy_information() 204 | 205 | # Run the computations 206 | self.strategy.at_train_end() 207 | 208 | self.strategy.final() 209 | 210 | # Save pruned model, to be used by pretrainedRunner 211 | self.checkpoint_file = self.save_model(model_type='pruned') 212 | wandb.summary['final_model_file'] = f"pruned_model_{self.config.split_val}_{self.config.phase}.pt" 213 | -------------------------------------------------------------------------------- /utilities/utilities.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: utilities.py 5 | # Description: Contains a variety of useful functions. 
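# Note: several helpers below (split_weights_and_masks, join_weights_and_masks,
# Candidate.enforce_prunedness) follow torch.nn.utils.prune's convention of
# storing a pruned parameter as "<name>_orig" together with a binary
# "<name>_mask"; the effective weight is their elementwise product.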
6 | # =========================================================================== 7 | import itertools 8 | import json 9 | import math 10 | import os 11 | import sys 12 | from collections import defaultdict, OrderedDict 13 | from typing import NamedTuple, Union 14 | 15 | import torch 16 | import torchmetrics 17 | import wandb 18 | from torchmetrics.classification import MulticlassAccuracy as Accuracy 19 | 20 | 21 | class PermutationSpec(NamedTuple): 22 | perm_to_axes: dict 23 | axes_to_perm: dict 24 | 25 | 26 | class Utilities: 27 | """Class of utility functions""" 28 | 29 | @staticmethod 30 | def fill_dict_with_none(d): 31 | for key in d: 32 | if isinstance(d[key], dict): 33 | Utilities.fill_dict_with_none(d[key]) # Recursive call for nested dictionaries 34 | else: 35 | d[key] = None 36 | return d 37 | 38 | @staticmethod 39 | def update_config_with_default(configDict, defaultDict): 40 | """Update config with default values recursively.""" 41 | for key, default_value in defaultDict.items(): 42 | if key not in configDict: 43 | configDict[key] = default_value 44 | elif isinstance(default_value, dict): 45 | configDict[key] = Utilities.update_config_with_default(configDict.get(key, {}), default_value) 46 | return configDict 47 | 48 | @staticmethod 49 | @torch.no_grad() 50 | def get_model_norm_square(model): 51 | """Get L2 norm squared of parameter vector. This works for a pruned model as well.""" 52 | squared_norm = 0. 53 | param_list = ['weight', 'bias'] 54 | for name, module in model.named_modules(): 55 | for param_type in param_list: 56 | if hasattr(module, param_type) and not isinstance(getattr(module, param_type), type(None)): 57 | param = getattr(module, param_type) 58 | squared_norm += torch.norm(param, p=2) ** 2 59 | return float(squared_norm) 60 | 61 | @staticmethod 62 | @torch.no_grad() 63 | def aggregate_group_metrics(models: list[Union[OrderedDict, torch.nn.Module]], metric_fn: callable, 64 | aggregate_fn: callable) -> float: 65 | if len(models) == 1: 66 | sys.stdout.write('Warning: aggregate_group_metrics called with only one model. Returning 0.\n') 67 | return 0. 68 | for idx in range(len(models)): 69 | if isinstance(models[idx], torch.nn.Module): 70 | models[idx] = models[idx].state_dict() 71 | 72 | collected_vals = [] 73 | for idx_i, idx_j in itertools.combinations(range(len(models)), 2): 74 | model_i, model_j = models[idx_i], models[idx_j] 75 | dist = metric_fn(model_i, model_j) 76 | collected_vals.append(dist) 77 | return aggregate_fn(torch.tensor(collected_vals)) 78 | 79 | @staticmethod 80 | @torch.no_grad() 81 | def get_angle(model_a: Union[OrderedDict, torch.nn.Module], model_b: Union[OrderedDict, torch.nn.Module]) -> float: 82 | """Get the angle between two models given as state_dict or nn.Module""" 83 | model_a_dict, model_b_dict = model_a, model_b 84 | if isinstance(model_a, torch.nn.Module): 85 | model_a_dict = model_a.state_dict() 86 | if isinstance(model_b, torch.nn.Module): 87 | model_b_dict = model_b.state_dict() 88 | 89 | dot_product = 0. 90 | squared_norm_a, squared_norm_b = 0., 0. 
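        # The loop below accumulates the dot product and squared norms parameter
        # by parameter, so the result is the angle between the two flattened
        # weight vectors: angle = arccos(<a, b> / (||a|| * ||b||)), in degrees.
        # Small worked example with hypothetical weights: a = (1, 0), b = (1, 1)
        # gives cos_sim = 1 / sqrt(2), i.e. an angle of 45 degrees.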
91 | for pName in model_a_dict.keys(): 92 | p_a, p_b = model_a_dict[pName].flatten(), model_b_dict[pName].flatten() 93 | dot_product += torch.dot(p_a, p_b).item() 94 | squared_norm_a += torch.dot(p_a, p_a).item() 95 | squared_norm_b += torch.dot(p_b, p_b).item() 96 | 97 | # Compute the cosine similarity 98 | cos_sim = dot_product / (math.sqrt(squared_norm_a) * math.sqrt(squared_norm_b)) 99 | 100 | # Calculate the angle in degrees, but first clamp the cosine similarity to [-1, 1] to avoid numerical errors 101 | angle_deg = math.degrees(math.acos(min(max(cos_sim, -1), 1))) 102 | return angle_deg 103 | 104 | @staticmethod 105 | @torch.no_grad() 106 | def get_l2_distance(model_a: Union[OrderedDict, torch.nn.Module], 107 | model_b: Union[OrderedDict, torch.nn.Module]) -> float: 108 | model_a_dict, model_b_dict = model_a, model_b 109 | if isinstance(model_a, torch.nn.Module): 110 | model_a_dict = model_a.state_dict() 111 | if isinstance(model_b, torch.nn.Module): 112 | model_b_dict = model_b.state_dict() 113 | 114 | squared_norm = 0 115 | for pName in model_a_dict.keys(): 116 | p_a, p_b = model_a_dict[pName], model_b_dict[pName] 117 | squared_norm += torch.norm((p_a - p_b).float(), p=2) ** 2 118 | return float(torch.sqrt(squared_norm)) 119 | 120 | @staticmethod 121 | @torch.no_grad() 122 | def get_barycentre_l2_distance(models: list[Union[OrderedDict, torch.nn.Module]], maximize=True): 123 | """Get the distance between the barycentre of the models and the model with the largest distance to the barycentre. 124 | :param models: list of models given as state_dict or nn.Module 125 | :param maximize: if True, return the maximum distance, else return the minimum distance 126 | :return: the distance between the barycentre of the models and the model with the largest distance to the barycentre. 127 | """ 128 | if len(models) == 1: return 0. 129 | for idx in range(len(models)): 130 | if isinstance(models[idx], torch.nn.Module): 131 | models[idx] = models[idx].state_dict() 132 | 133 | # Compute the barycentre of all models 134 | factor = 1. 
/ len(models) 135 | barycentre = OrderedDict() 136 | 137 | for model_state_dict in models: 138 | for key, val in model_state_dict.items(): 139 | if key not in barycentre: 140 | barycentre[key] = val.detach().clone() # Important: clone otherwise we modify the tensors 141 | else: 142 | barycentre[key] += val.detach().clone() # Important: clone otherwise we modify the tensors 143 | 144 | for key, val in barycentre.items(): 145 | barycentre[key] = barycentre[key] * factor 146 | 147 | distances = [] 148 | for idx_i, model_a in enumerate(models): 149 | dist = Utilities.get_l2_distance(model_a, barycentre) 150 | distances.append(dist) 151 | 152 | if maximize: 153 | return max(distances) 154 | else: 155 | return min(distances) 156 | 157 | @staticmethod 158 | def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec: 159 | perm_to_axes = defaultdict(list) 160 | for wk, axis_perms in axes_to_perm.items(): 161 | for axis, perm in enumerate(axis_perms): 162 | if perm is not None: 163 | perm_to_axes[perm].append((wk, axis)) 164 | return PermutationSpec(perm_to_axes=dict(perm_to_axes), axes_to_perm=axes_to_perm) 165 | 166 | @staticmethod 167 | def dump_dict_to_json_wandb(dumpDict, name): 168 | """Dump some dict to json and upload it""" 169 | fPath = os.path.join(wandb.run.dir, f'{name}.json') 170 | with open(fPath, 'w') as fp: 171 | json.dump(dumpDict, fp) 172 | wandb.save(fPath) 173 | 174 | @staticmethod 175 | def get_overloaded_dataset(OriginalDataset): 176 | class AlteredDatasetWrapper(OriginalDataset): 177 | 178 | def __init__(self, *args, **kwargs): 179 | super(AlteredDatasetWrapper, self).__init__(*args, **kwargs) 180 | 181 | def __getitem__(self, index): 182 | # Overload this to collect the class indices once in a vector, which can then be used in the sampler 183 | image, label = super(AlteredDatasetWrapper, self).__getitem__(index=index) 184 | return image, label, index 185 | 186 | AlteredDatasetWrapper.__name__ = OriginalDataset.__name__ 187 | return AlteredDatasetWrapper 188 | 189 | @staticmethod 190 | def split_weights_and_masks(model): 191 | weights, masks = OrderedDict(), OrderedDict() 192 | 193 | for key, value in model.items(): 194 | if '_mask' in key: 195 | name = key.replace('_mask', '') 196 | masks[name] = value 197 | elif '_orig' in key: 198 | name = key.replace('_orig', '') 199 | weights[name] = value 200 | else: 201 | weights[key] = value 202 | return weights, masks 203 | 204 | @staticmethod 205 | def join_weights_and_masks(weights, masks): 206 | state_dict = OrderedDict() 207 | for key, value in weights.items(): 208 | state_dict[key + '_orig'] = value 209 | for key, value in masks.items(): 210 | state_dict[key + '_mask'] = value 211 | return state_dict 212 | 213 | 214 | class WorstClassAccuracy(Accuracy): 215 | def __init__(self, **kwargs): 216 | super().__init__(average=None, **kwargs) 217 | 218 | def compute(self): 219 | class_accuracies = super().compute() 220 | return class_accuracies.min() 221 | 222 | 223 | class CalibrationError(torchmetrics.Metric): 224 | def __init__(self, num_bins=15, norm='l1', dist_sync_on_step=False): 225 | super().__init__(dist_sync_on_step=dist_sync_on_step) 226 | self.num_bins = num_bins 227 | self.norm = norm 228 | self.add_state("bin_boundaries", default=torch.linspace(0, 1, num_bins + 1), dist_reduce_fx=None) 229 | self.add_state("bin_conf_sums", default=torch.zeros(num_bins), dist_reduce_fx="sum") 230 | self.add_state("bin_correct_sums", default=torch.zeros(num_bins), dist_reduce_fx="sum") 231 | 
self.add_state("bin_total_count", default=torch.zeros(num_bins), dist_reduce_fx="sum") 232 | self.add_state("total_count", default=torch.tensor(0), dist_reduce_fx="sum") 233 | 234 | @torch.no_grad() 235 | def update(self, preds: torch.Tensor, targets: torch.Tensor): 236 | # Transform the predictions into probabilities 237 | preds = torch.softmax(preds, dim=1) 238 | 239 | # Compute the maximum probability for each prediction 240 | max_probs, max_classes = preds.max(dim=1) 241 | 242 | # Check if the predicted class matches the target 243 | correct = (max_classes == targets).float() 244 | 245 | # Compute the confidence for each prediction 246 | confidences = max_probs 247 | 248 | # Map confidences to the corresponding bins 249 | bin_indices = torch.bucketize(confidences, self.bin_boundaries[:-1]) - 1 250 | 251 | # Ensure that the bin indices are in the correct range 252 | bin_indices = bin_indices.clamp(min=0, max=self.num_bins - 1) 253 | 254 | # Update the bin sums and counts 255 | for bin_idx in range(self.num_bins): 256 | mask = bin_indices == bin_idx 257 | self.bin_conf_sums[bin_idx] += (mask * confidences).sum() 258 | self.bin_correct_sums[bin_idx] += (mask * correct).sum() 259 | self.bin_total_count[bin_idx] += mask.sum() 260 | 261 | # Update the total count 262 | self.total_count += preds.shape[0] 263 | 264 | def compute(self): 265 | assert self.total_count.item() == self.bin_total_count.sum() 266 | # Compute the bin accuracies and confidences 267 | bin_accuracies = self.bin_correct_sums / self.bin_total_count.clamp(min=1) 268 | bin_confidences = self.bin_conf_sums / self.bin_total_count.clamp(min=1) 269 | 270 | abs_errors = torch.abs(bin_accuracies - bin_confidences) 271 | rel_freq = self.bin_total_count / self.total_count 272 | if self.norm == 'l1': 273 | ece = torch.sum(abs_errors * rel_freq) 274 | elif self.norm == 'max': 275 | ece = torch.max(abs_errors) 276 | else: 277 | raise ValueError("Invalid norm. 
Supported norms are 'l1' and 'max'.") 278 | return ece 279 | 280 | 281 | class Candidate(object): 282 | """Candidate for ensembling.""" 283 | 284 | def __init__(self, candidate_id, candidate_file, candidate_run): 285 | self.id = candidate_id 286 | self.file = candidate_file 287 | self.run = candidate_run 288 | 289 | self._candidate_metrics = defaultdict(defaultdict) # 'test'/'val'/'ood' -> {metric -> value} 290 | 291 | def set_metrics(self, metrics, split): 292 | self._candidate_metrics[split] = metrics 293 | 294 | def get_metrics(self, split): 295 | return self._candidate_metrics[split] 296 | 297 | def get_single_metric(self, metric, split): 298 | return self._candidate_metrics[split][metric] 299 | 300 | def get_model_weights(self): 301 | m = torch.load(self.file, map_location=torch.device('cpu')) 302 | weights, _ = Utilities.split_weights_and_masks(m) 303 | return weights 304 | 305 | def enforce_prunedness(self, device): 306 | state_dict = torch.load(self.file, map_location=device) 307 | new_state_dict = OrderedDict() 308 | for key, val in state_dict.items(): 309 | v_new = val # Remains unchanged if not in _orig format 310 | if key.endswith("_orig"): 311 | # We loaded the _orig tensor and corresponding mask 312 | name = key.replace("_orig", "") # Truncate the "_orig" 313 | if f"{name}_mask" in state_dict.keys(): 314 | v_new = v_new * state_dict[f"{name}_mask"] 315 | new_state_dict[key] = v_new 316 | 317 | # Save the new state dict 318 | torch.save(new_state_dict, self.file) 319 | -------------------------------------------------------------------------------- /models/cifar10.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: models/cifar10.py 5 | # Description: CIFAR-10 Models 6 | # =========================================================================== 7 | 8 | 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from utilities.utilities import Utilities as Utils 13 | 14 | 15 | class CNN(nn.Module): 16 | def __init__(self): 17 | super(CNN, self).__init__() 18 | self.conv1 = nn.Conv2d(3, 32, 3, 1, bias=True) 19 | self.conv2 = nn.Conv2d(32, 128, 3, 1, bias=True) 20 | self.dropout1 = nn.Dropout(0.25) 21 | self.dropout2 = nn.Dropout(0.5) 22 | self.avg = nn.AvgPool2d(kernel_size=1, stride=1) 23 | self.fc1 = nn.Linear(128, 128) 24 | self.fc2 = nn.Linear(128, 10, bias=True) 25 | 26 | def forward(self, x): 27 | x = self.conv1(x) 28 | x = F.relu(x) 29 | x = self.conv2(x) 30 | x = F.relu(x) 31 | x = F.max_pool2d(x, 2) 32 | x = self.dropout1(x) 33 | x = self.avg(x) 34 | x = x.view(x.size(0), -1) 35 | x = self.fc1(x) 36 | x = F.relu(x) 37 | x = self.dropout2(x) 38 | x = self.fc2(x) 39 | output = F.log_softmax(x, dim=1) 40 | return output 41 | 42 | @staticmethod 43 | def get_permutation_spec(): 44 | conv = lambda name, p_in, p_out, bias=True: {f"{name}.weight": (p_out, p_in, None, None,), 45 | f"{name}.bias": (p_out,)} if bias else { 46 | f"{name}.weight": (p_out, p_in, None, None,)} 47 | dense = lambda name, p_in, p_out, bias=True: {f"{name}.weight": (p_out, p_in), 48 | f"{name}.bias": (p_out,)} if bias else { 49 | f"{name}.weight": (p_out, p_in)} 50 | return Utils.permutation_spec_from_axes_to_perm({ 51 | **conv("conv1", None, "P_bg0"), 52 | **conv("conv2", "P_bg0", "P_bg1", False), 53 | **dense("fc1", "P_bg1", 
"P_bg2"), 54 | **dense("fc2", "P_bg2", None, True), 55 | }) 56 | 57 | 58 | def ResNet56(): 59 | class ResNet(nn.Module): 60 | # Proper implementation of ResNet, taken from https://github.com/JJGO/shrinkbench/blob/master/models/cifar_resnet.py 61 | def __init__(self, block, num_blocks, num_classes=10): 62 | super(ResNet, self).__init__() 63 | self.in_planes = 16 64 | 65 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 66 | self.bn1 = nn.BatchNorm2d(16) 67 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 68 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 69 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 70 | self.linear = nn.Linear(64, num_classes) 71 | # self.linear.is_classifier = True # So layer is not pruned 72 | 73 | def _make_layer(self, block, planes, num_blocks, stride): 74 | strides = [stride] + [1] * (num_blocks - 1) 75 | layers = [] 76 | for stride in strides: 77 | layers.append(block(self.in_planes, planes, stride)) 78 | self.in_planes = planes * block.expansion 79 | 80 | return nn.Sequential(*layers) 81 | 82 | def forward(self, x): 83 | out = F.relu(self.bn1(self.conv1(x))) 84 | out = self.layer1(out) 85 | out = self.layer2(out) 86 | out = self.layer3(out) 87 | out = F.avg_pool2d(out, out.size()[3]) 88 | out = out.view(out.size(0), -1) 89 | out = self.linear(out) 90 | return out 91 | 92 | class LambdaLayer(nn.Module): 93 | def __init__(self, lambd): 94 | super(LambdaLayer, self).__init__() 95 | self.lambd = lambd 96 | 97 | def forward(self, x): 98 | return self.lambd(x) 99 | 100 | class BasicBlock(nn.Module): 101 | expansion = 1 102 | 103 | def __init__(self, in_planes, planes, stride=1, option='A'): 104 | super(BasicBlock, self).__init__() 105 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 106 | self.bn1 = nn.BatchNorm2d(planes) 107 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 108 | self.bn2 = nn.BatchNorm2d(planes) 109 | 110 | self.shortcut = nn.Sequential() 111 | if stride != 1 or in_planes != planes: 112 | if option == 'A': 113 | """ 114 | For CIFAR10 ResNet paper uses option A. 
115 | """ 116 | self.shortcut = LambdaLayer(lambda x: 117 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes // 4, planes // 4), 118 | "constant", 0)) 119 | elif option == 'B': 120 | self.shortcut = nn.Sequential( 121 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 122 | nn.BatchNorm2d(self.expansion * planes) 123 | ) 124 | 125 | def forward(self, x): 126 | out = F.relu(self.bn1(self.conv1(x))) 127 | out = self.bn2(self.conv2(out)) 128 | out += self.shortcut(x) 129 | out = F.relu(out) 130 | return out 131 | 132 | model = ResNet(BasicBlock, [9, 9, 9], num_classes=10) 133 | return model 134 | 135 | 136 | def ResNet18(): 137 | # Based on https://github.com/charlieokonomiyaki/pytorch-resnet18-cifar10/blob/master/models/resnet.py 138 | class BasicBlock(nn.Module): 139 | expansion = 1 140 | 141 | def __init__(self, in_planes, planes, stride=1): 142 | super(BasicBlock, self).__init__() 143 | self.conv1 = nn.Conv2d( 144 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 145 | self.bn1 = nn.BatchNorm2d(planes) 146 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 147 | stride=1, padding=1, bias=False) 148 | self.bn2 = nn.BatchNorm2d(planes) 149 | 150 | self.shortcut = nn.Sequential() 151 | if stride != 1 or in_planes != self.expansion * planes: 152 | self.shortcut = nn.Sequential( 153 | nn.Conv2d(in_planes, self.expansion * planes, 154 | kernel_size=1, stride=stride, bias=False), 155 | nn.BatchNorm2d(self.expansion * planes) 156 | ) 157 | 158 | def forward(self, x): 159 | out = F.relu(self.bn1(self.conv1(x))) 160 | out = self.bn2(self.conv2(out)) 161 | out += self.shortcut(x) 162 | out = F.relu(out) 163 | return out 164 | 165 | class Bottleneck(nn.Module): 166 | expansion = 4 167 | 168 | def __init__(self, in_planes, planes, stride=1): 169 | super(Bottleneck, self).__init__() 170 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 171 | self.bn1 = nn.BatchNorm2d(planes) 172 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 173 | stride=stride, padding=1, bias=False) 174 | self.bn2 = nn.BatchNorm2d(planes) 175 | self.conv3 = nn.Conv2d(planes, self.expansion * 176 | planes, kernel_size=1, bias=False) 177 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 178 | 179 | self.shortcut = nn.Sequential() 180 | if stride != 1 or in_planes != self.expansion * planes: 181 | self.shortcut = nn.Sequential( 182 | nn.Conv2d(in_planes, self.expansion * planes, 183 | kernel_size=1, stride=stride, bias=False), 184 | nn.BatchNorm2d(self.expansion * planes) 185 | ) 186 | 187 | def forward(self, x): 188 | out = F.relu(self.bn1(self.conv1(x))) 189 | out = F.relu(self.bn2(self.conv2(out))) 190 | out = self.bn3(self.conv3(out)) 191 | out += self.shortcut(x) 192 | out = F.relu(out) 193 | return out 194 | 195 | class ResNet(nn.Module): 196 | def __init__(self, block, num_blocks, num_classes=10): 197 | super(ResNet, self).__init__() 198 | self.in_planes = 64 199 | 200 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, 201 | stride=1, padding=1, bias=False) 202 | self.bn1 = nn.BatchNorm2d(64) 203 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 204 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 205 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 206 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 207 | self.linear = nn.Linear(512 * block.expansion, num_classes) 208 | 209 | def _make_layer(self, block, planes, num_blocks, stride): 210 | strides = 
[stride] + [1] * (num_blocks - 1) 211 | layers = [] 212 | for stride in strides: 213 | layers.append(block(self.in_planes, planes, stride)) 214 | self.in_planes = planes * block.expansion 215 | return nn.Sequential(*layers) 216 | 217 | def forward(self, x): 218 | out = F.relu(self.bn1(self.conv1(x))) 219 | out = self.layer1(out) 220 | out = self.layer2(out) 221 | out = self.layer3(out) 222 | out = self.layer4(out) 223 | out = F.avg_pool2d(out, 4) 224 | out = out.view(out.size(0), -1) 225 | out = self.linear(out) 226 | return out 227 | 228 | model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10) 229 | return model 230 | 231 | 232 | def VGG16(): 233 | return VGG(vgg_name='VGG16') 234 | 235 | 236 | class VGG(nn.Module): 237 | # Adapted from https://github.com/jaeho-lee/layer-adaptive-sparsity/blob/main/tools/models/vgg.py 238 | def __init__(self, vgg_name, use_bn=True): 239 | cfg = { 240 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 241 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 242 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 243 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 244 | 'M'], 245 | } 246 | 247 | super(VGG, self).__init__() 248 | self.features = self._make_layers(cfg[vgg_name], use_bn) 249 | self.classifier = nn.Linear(512, 10) 250 | 251 | def forward(self, x): 252 | out = self.features(x) 253 | out = out.view(out.size(0), -1) 254 | out = self.classifier(out) 255 | return out 256 | 257 | def _make_layers(self, cfg, use_bn): 258 | layers = [] 259 | in_channels = 3 260 | for x in cfg: 261 | if x == 'M': 262 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 263 | else: 264 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1)] 265 | if use_bn: 266 | layers += [nn.BatchNorm2d(x)] 267 | layers += [nn.ReLU(inplace=True)] 268 | in_channels = x 269 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 270 | return nn.Sequential(*layers) 271 | 272 | 273 | class WideResNet20(nn.Module): 274 | # WideResNet implementation but with widen_factor=2 and depth=22 instead of 10 and 28 respectively. 
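    # With the 6n+4 rule enforced below, depth=22 gives n = (22 - 4) / 6 = 3
    # wide_basic blocks per stage, with channel widths [16, 16*k, 32*k, 64*k]
    # for widen_factor k.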
275 | # In Repo of Git-Rebasin, this is referred to as ResNet20 276 | def __init__(self, depth=22, widen_factor=10, dropout_rate=0.3, num_classes=10): 277 | super(WideResNet20, self).__init__() 278 | self.in_planes = 16 279 | 280 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' 281 | n = (depth - 4) / 6 282 | k = widen_factor 283 | 284 | nStages = [16, 16 * k, 32 * k, 64 * k] 285 | 286 | class wide_basic(nn.Module): 287 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 288 | super(wide_basic, self).__init__() 289 | self.bn1 = nn.BatchNorm2d(in_planes) 290 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 291 | 292 | self.dropout = nn.Dropout(p=dropout_rate) 293 | self.bn2 = nn.BatchNorm2d(planes) 294 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 295 | 296 | self.shortcut = nn.Sequential() 297 | if stride != 1 or in_planes != planes: 298 | self.shortcut = nn.Sequential( 299 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 300 | ) 301 | 302 | def forward(self, x): 303 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 304 | out = self.conv2(F.relu(self.bn2(out))) 305 | out += self.shortcut(x) 306 | 307 | return out 308 | 309 | self.conv1 = self.conv3x3(3, nStages[0]) 310 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 311 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 312 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 313 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 314 | self.linear = nn.Linear(nStages[3], num_classes) 315 | 316 | def conv3x3(self, in_planes, out_planes, stride=1): 317 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 318 | 319 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 320 | strides = [stride] + [1] * (int(num_blocks) - 1) 321 | layers = [] 322 | 323 | for stride in strides: 324 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 325 | self.in_planes = planes 326 | 327 | return nn.Sequential(*layers) 328 | 329 | def forward(self, x): 330 | out = self.conv1(x) 331 | out = self.layer1(out) 332 | out = self.layer2(out) 333 | out = self.layer3(out) 334 | out = F.relu(self.bn1(out)) 335 | out = F.avg_pool2d(out, 8) 336 | out = out.view(out.size(0), -1) 337 | out = self.linear(out) 338 | 339 | return out 340 | 341 | @staticmethod 342 | def get_permutation_spec(): 343 | conv = lambda name, p_in, p_out, bias=True: {f"{name}.weight": (p_out, p_in, None, None,), 344 | f"{name}.bias": (p_out,)} if bias else { 345 | f"{name}.weight": (p_out, p_in, None, None,)} 346 | norm = lambda name, p: {f"{name}.weight": (p,), f"{name}.bias": (p,)} 347 | 348 | dense = lambda name, p_in, p_out, bias=True: {f"{name}.weight": (p_out, p_in), 349 | f"{name}.bias": (p_out,)} if bias else { 350 | f"{name}.weight": (p_out, p_in)} 351 | 352 | # This is for easy blocks that use a residual connection, without any change in the number of channels. 353 | easyblock = lambda name, p: { 354 | **norm(f"{name}.bn1", p), 355 | **conv(f"{name}.conv1", p, f"P_{name}_inner"), 356 | **norm(f"{name}.bn2", f"P_{name}_inner"), 357 | **conv(f"{name}.conv2", f"P_{name}_inner", p), 358 | } 359 | 360 | # This is for blocks that use a residual connection, but change the number of channels via a Conv. 
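        # Sketch of the axes_to_perm entries this produces (names taken from the
        # spec built below): a conv weight of shape (out, in, kH, kW) maps to
        # (P_out, P_in, None, None), e.g.
        #   "layer1.0.conv1.weight":      ("P_layer1.0_inner", "P_bg0", None, None)
        #   "layer1.0.shortcut.0.weight": ("P_bg1", "P_bg0", None, None)
        # permutation_spec_from_axes_to_perm then inverts this into perm_to_axes,
        # listing every (tensor name, axis) pair each permutation acts on.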
361 | shortcutblock = lambda name, p_in, p_out: { 362 | **norm(f"{name}.bn1", p_in), 363 | **conv(f"{name}.conv1", p_in, f"P_{name}_inner"), 364 | **norm(f"{name}.bn2", f"P_{name}_inner"), 365 | **conv(f"{name}.conv2", f"P_{name}_inner", p_out), 366 | **conv(f"{name}.shortcut.0", p_in, p_out), 367 | # **norm(f"{name}.shortcut.1", p_out), # Removed this since not occuring in state dict 368 | } 369 | 370 | return Utils.permutation_spec_from_axes_to_perm({ 371 | **conv("conv1", None, "P_bg0"), 372 | # 373 | **shortcutblock("layer1.0", "P_bg0", "P_bg1"), 374 | **easyblock("layer1.1", "P_bg1", ), 375 | **easyblock("layer1.2", "P_bg1"), 376 | # **easyblock("layer1.3", "P_bg1"), 377 | 378 | **shortcutblock("layer2.0", "P_bg1", "P_bg2"), 379 | **easyblock("layer2.1", "P_bg2", ), 380 | **easyblock("layer2.2", "P_bg2"), 381 | # **easyblock("layer2.3", "P_bg2"), 382 | 383 | **shortcutblock("layer3.0", "P_bg2", "P_bg3"), 384 | **easyblock("layer3.1", "P_bg3", ), 385 | **easyblock("layer3.2", "P_bg3"), 386 | # **easyblock("layer3.3", "P_bg3"), 387 | 388 | **norm("bn1", "P_bg3"), 389 | 390 | **dense("linear", "P_bg3", None), 391 | 392 | }) 393 | -------------------------------------------------------------------------------- /runners/ensembleRunner.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: ensembleRunner.py 5 | # Description: Runner class for starting from pruned models 6 | # =========================================================================== 7 | import itertools 8 | import json 9 | import math 10 | import os 11 | import sys 12 | import warnings 13 | from collections import OrderedDict 14 | from typing import List 15 | 16 | import numpy as np 17 | import torch 18 | import wandb 19 | from torch.cuda.amp import autocast 20 | from torchmetrics.classification import MulticlassAccuracy as Accuracy 21 | from tqdm.auto import tqdm 22 | 23 | from runners.baseRunner import baseRunner 24 | from strategies import ensembleStrategies 25 | from utilities.utilities import Utilities as Utils, WorstClassAccuracy, CalibrationError, Candidate 26 | 27 | 28 | class ensembleRunner(baseRunner): 29 | 30 | def __init__(self, **kwargs): 31 | super().__init__(**kwargs) 32 | 33 | 34 | def find_multiple_existing_models(self, filterDict): 35 | """Finds existing wandb runs and downloads the model files.""" 36 | current_phase = self.config.phase # We are in the same phase 37 | filterDict['$and'].append({'config.phase': current_phase}) 38 | filterDict['$and'].append({'config.n_splits_total': self.config.n_splits_total}) 39 | if current_phase > 1: 40 | # We need to specify the previous ensemble method as well 41 | filterDict['$and'].append({'config.ensemble_method': self.config.ensemble_method}) 42 | 43 | filterDict['$and'].append({'config.ensemble_by': self.config.ensemble_by}) 44 | entity, project = wandb.run.entity, wandb.run.project 45 | api = wandb.Api() 46 | candidate_model_list = [] 47 | 48 | # Some variables have to be extracted from the filterDict and checked manually, e.g. 
weight decay in scientific format 49 | manualVariables = ['weight_decay', 'penalty', 'group_penalty'] 50 | manVarDict = {} 51 | dropIndices = [] 52 | for var in manualVariables: 53 | for i in range(len(filterDict['$and'])): 54 | entry = filterDict['$and'][i] 55 | s = f"config.{var}" 56 | if s in entry: 57 | dropIndices.append(i) 58 | manVarDict[var] = entry[s] 59 | for idx in reversed(sorted(dropIndices)): filterDict['$and'].pop(idx) 60 | 61 | runs = api.runs(f"{entity}/{project}", filters=filterDict) 62 | runsExist = False # If True, then there exist runs that try to set a fixed init 63 | for run in runs: 64 | if run.state != 'finished': 65 | # Ignore this run 66 | continue 67 | # Check if run satisfies the manual variables 68 | conflict = False 69 | for var, val in manVarDict.items(): 70 | if var in run.config and run.config[var] != val: 71 | conflict = True 72 | break 73 | if conflict: 74 | continue 75 | sys.stdout.write(f"Trying to access {run.name}.\n") 76 | checkpoint_file = run.summary.get('final_model_file') 77 | try: 78 | if checkpoint_file is not None: 79 | runsExist = True 80 | sys.stdout.write( 81 | f"Downloading pruned model with split {run.config['ensemble_by']} value: {run.config['split_val']}.\n") 82 | run.file(checkpoint_file).download( 83 | root=self.tmp_dir) 84 | self.seed = run.config['seed'] 85 | candidate_id = (run.config['split_val']) 86 | candidate_model_list.append( 87 | Candidate(candidate_id=candidate_id, candidate_file=os.path.join(self.tmp_dir, checkpoint_file), 88 | candidate_run=run)) 89 | except Exception as e: # The run is online, but the model is not uploaded yet -> results in failing runs 90 | print(e) 91 | checkpoint_file = None 92 | break 93 | assert not ( 94 | runsExist and checkpoint_file is None), "Runs found, but one of them has no model available -> abort." 95 | outputStr = f"Found {len(candidate_model_list)} pruned models with split vals {sorted([c.id for c in candidate_model_list])}" \ 96 | if checkpoint_file is not None else "Nothing found." 97 | sys.stdout.write(f"Trying to find reference pruned models in project: {outputStr}\n") 98 | assert checkpoint_file is not None, "One of the pruned models has no model file to download, Aborting." 99 | assert len(candidate_model_list) == self.config.n_splits_total, "Not all pruned models were found, Aborting.\n" 100 | 101 | return candidate_model_list 102 | 103 | def define_optimizer_scheduler(self): 104 | # Define the optimizer 105 | if self.config.optimizer == 'SGD': 106 | self.optimizer = torch.optim.SGD(params=self.model.parameters(), lr=0.) 107 | 108 | def transport_information(self, ref_run): 109 | missing_config_keys = ['momentum', 110 | 'n_epochs_warmup', 111 | 'n_epochs'] # Have to have n_epochs even though it might be specified, otherwise ALLR doesnt have this 112 | 113 | additional_dict = { 114 | 'last_training_lr': ref_run.summary['final.learning_rate'], 115 | 'final.test.accuracy': ref_run.summary['final.test']['accuracy'], 116 | 'final.train.accuracy': ref_run.summary['final.train']['accuracy'], 117 | 'final.train.loss': ref_run.summary['final.train']['loss'], 118 | } 119 | for key in missing_config_keys: 120 | if key not in self.config or self.config[key] is None: 121 | # Allow_val_change = true because e.g. 
momentum defaults to None, but shouldn't be passed here 122 | val = ref_run.config.get(key) # If not found, defaults to None 123 | self.config.update({key: val}, allow_val_change=True) 124 | self.config.update(additional_dict) 125 | 126 | self.trained_test_accuracy = additional_dict['final.test.accuracy'] 127 | self.trained_train_loss = additional_dict['final.train.loss'] 128 | self.trained_train_accuracy = additional_dict['final.train.accuracy'] 129 | 130 | # Get the wandb information about lr and fill the corresponding strategy dicts, which can then be used by rewinders 131 | f = ref_run.file('iteration-lr-dict.json').download(root=self.tmp_dir) 132 | with open(f.name) as json_file: 133 | loaded_dict = json.load(json_file) 134 | lr_dict = OrderedDict(loaded_dict) 135 | # Upload iteration-lr dict from self.strategy to be used during retraining 136 | Utils.dump_dict_to_json_wandb(dumpDict=lr_dict, name='iteration-lr-dict') 137 | 138 | def load_soup_model(self, ensemble_state_dict): 139 | # Save the ensemble state dict 140 | fName = f"ensemble_model.pt" 141 | fPath = os.path.join(self.tmp_dir, fName) 142 | torch.save(ensemble_state_dict, fPath) # Save the state_dict 143 | self.checkpoint_file = fName 144 | 145 | # Actually load the model 146 | self.model = self.get_model(reinit=True, temporary=True) # Load the ensembled model 147 | 148 | def evaluate_soup(self, data='val', ensemble_labels: torch.Tensor = None): 149 | # Perform an evaluation pass 150 | AccuracyMeter = Accuracy(num_classes=self.n_classes).to(device=self.device) 151 | ECEMeter = CalibrationError(norm='l1').to(device=self.device) 152 | MCEMeter = CalibrationError(norm='max').to(device=self.device) 153 | WorstClassAccuracyMeter = WorstClassAccuracy(num_classes=self.n_classes).to(device=self.device) 154 | 155 | if data == 'val': 156 | loader = self.valLoader 157 | elif data == 'test': 158 | loader = self.testLoader 159 | elif data == 'ood': 160 | loader = self.oodLoader 161 | if loader is None: 162 | sys.stdout.write(f"No OOD data found, skipping OOD evaluation.\n") 163 | return {} 164 | else: 165 | raise NotImplementedError 166 | 167 | if ensemble_labels is not None: 168 | sys.stdout.write(f"Performing computation of prediction ensemble {data} accuracy.\n") 169 | else: 170 | sys.stdout.write(f"Performing computation of soup {data} accuracy.\n") 171 | with tqdm(loader, leave=True) as pbar: 172 | for x_input, y_target, indices in pbar: 173 | # Move to CUDA if possible 174 | x_input = x_input.to(self.device, non_blocking=True) 175 | indices = indices.to(self.device, non_blocking=True) 176 | if ensemble_labels is not None: 177 | y_target = ensemble_labels[indices] # Avg probs/predictions of batch 178 | y_target = y_target.to(self.device, non_blocking=True) 179 | 180 | with autocast(enabled=(self.config.use_amp is True)): 181 | output = self.model.train(mode=False)(x_input) 182 | AccuracyMeter(output, y_target) 183 | ECEMeter(output, y_target) 184 | MCEMeter(output, y_target) 185 | WorstClassAccuracyMeter(output, y_target) 186 | 187 | outputDict = { 188 | 'accuracy': AccuracyMeter.compute().item(), 189 | 'ece': ECEMeter.compute().item(), 190 | 'mce': MCEMeter.compute().item(), 191 | 'worst_class_accuracy': WorstClassAccuracyMeter.compute().item(), 192 | } 193 | return outputDict 194 | 195 | @torch.no_grad() 196 | def collect_avg_output_full(self, data: str, candidate_model_list: List[Candidate]): 197 | output_type = 'soft_prediction' 198 | assert data in ['val', 'test'] 199 | if data == 'val': 200 | loader = self.valLoader 201 | 
else: 202 | loader = self.testLoader 203 | sys.stdout.write(f"\nCollecting ensemble prediction.\n") 204 | 205 | compute_avg_probs = (output_type in ['softmax', 'soft_prediction']) 206 | store_tensor = torch.zeros(len(loader.dataset), self.n_classes, device=self.device) # On CUDA for now 207 | 208 | for candidate in candidate_model_list: 209 | # Load the candidate model 210 | candidate_id, candidate_file = candidate.id, candidate.file 211 | if self.model is not None: 212 | del self.model 213 | torch.cuda.empty_cache() 214 | 215 | state_dict = torch.load(candidate_file, 216 | map_location=torch.device('cpu')) 217 | self.load_soup_model(ensemble_state_dict=state_dict) 218 | with tqdm(loader, leave=True) as pbar: 219 | for x_input, _, indices in pbar: 220 | x_input = x_input.to(self.device, non_blocking=True) # Move to CUDA if possible 221 | with autocast(enabled=(self.config.use_amp is True)): 222 | output = self.model.eval()(x_input) # Logits 223 | probabilities = torch.nn.functional.softmax(output, dim=1) # Softmax(Logits) 224 | if compute_avg_probs: 225 | # Just add the probabilities for the average 226 | store_tensor[indices] += probabilities 227 | else: 228 | # Add the prediction as one hot 229 | binary_tensor = torch.zeros_like(store_tensor[indices]) 230 | # Add the ones at corresponding entries 231 | binary_tensor[torch.arange(binary_tensor.size(0)).unsqueeze(1), torch.argmax(probabilities, 232 | dim=1).unsqueeze( 233 | 1)] = 1. 234 | 235 | store_tensor[indices] += binary_tensor 236 | 237 | if compute_avg_probs: 238 | store_tensor.mul_(1. / len(candidate_model_list)) # Weighting 239 | else: 240 | assert store_tensor.sum() == (len(candidate_model_list) * len(loader.dataset)) 241 | 242 | if output_type in ['soft_prediction', 'hard_prediction']: 243 | # Take the prediction given average probabilities OR Take the most frequent prediction 244 | store_tensor = torch.argmax(store_tensor, dim=1) 245 | 246 | return store_tensor 247 | 248 | def run(self): 249 | """Function controlling the workflow of pretrainedRunner""" 250 | assert self.config.ensemble_by in ['pruned_seed', 'weight_decay', 'retrain_length', 'retrain_schedule'] 251 | assert self.config.n_splits_total is not None 252 | assert self.config.split_val is None 253 | 254 | # Find the reference run 255 | filterDict = {"$and": [{"config.run_id": self.config.run_id}, 256 | {"config.arch": self.config.arch}, 257 | {"config.optimizer": self.config.optimizer}, 258 | {"config.goal_sparsity": self.config.goal_sparsity}, 259 | {"config.n_epochs_per_phase": self.config.n_epochs_per_phase}, 260 | {"config.n_phases": self.config.n_phases}, 261 | {"config.retrain_schedule": self.config.retrain_schedule}, 262 | {"config.strategy": 'IMP'}, 263 | ]} 264 | 265 | if self.config.learning_rate is not None: 266 | warnings.warn( 267 | "You specified an explicit learning rate for retraining. Note that this only controls the selection of the pretrained model.") 268 | filterDict["$and"].append({"config.learning_rate": self.config.learning_rate}) 269 | if self.config.n_epochs is not None: 270 | warnings.warn( 271 | "You specified n_epochs for retraining. 
Note that this only controls the selection of the pretrained model.") 272 | filterDict["$and"].append({"config.n_epochs": self.config.n_epochs}) 273 | 274 | candidate_models = self.find_multiple_existing_models(filterDict=filterDict) 275 | wandb.config.update({'seed': self.seed}) # Push the seed to wandb 276 | 277 | # Set a unique random seed 278 | np.random.seed(self.seed) 279 | torch.manual_seed(self.seed) 280 | # Remark: If you are working with a multi-GPU model, this function is insufficient to get determinism. To seed all GPUs, use manual_seed_all(). 281 | torch.cuda.manual_seed(self.seed) # This works if CUDA not available 282 | 283 | torch.backends.cudnn.benchmark = True 284 | 285 | self.transport_information(ref_run=candidate_models[0].run) 286 | 287 | self.trainLoader, self.valLoader, self.testLoader, self.trainLoader_unshuffled = self.get_dataloaders() 288 | self.oodLoader = self.get_ood_dataloaders() 289 | 290 | # We first define the ensembling strategy, create the ensemble, then use the 'Dense' strategy and regularly 291 | # load the model 292 | # Define callbacks finetuning_callback, restore_callback, save_model_callback 293 | callbackDict = { 294 | 'final_log_callback': self.final_log, 295 | 'soup_evaluation_callback': self.evaluate_soup, 296 | 'load_soup_callback': self.load_soup_model, 297 | 'recalibrate_bn_callback': self.recalibrate_bn, 298 | } 299 | self.ensemble_strategy = getattr(ensembleStrategies, self.config.ensemble_method)(model=None, 300 | n_classes=self.n_classes, 301 | config=self.config, 302 | candidate_models=candidate_models, 303 | runner=self, 304 | callbacks=callbackDict) 305 | 306 | self.ensemble_strategy.collect_candidate_information() 307 | 308 | # Create ensemble 309 | ensemble_state_dict = self.ensemble_strategy.create_ensemble() 310 | 311 | # Save the ensemble state dict 312 | fName = f"ensemble_model.pt" 313 | fPath = os.path.join(self.tmp_dir, fName) 314 | torch.save(ensemble_state_dict, fPath) # Save the state_dict 315 | self.checkpoint_file = fName 316 | 317 | # Actually load the model 318 | self.model = self.get_model(reinit=True, temporary=True) # Load the ensembled model 319 | 320 | # Create 'Dense' as the Base Strategy 321 | self.strategy = self.define_strategy(use_dense_base=True) 322 | self.strategy.after_initialization() 323 | 324 | # Define optimizer to not get errors in the main evaluation (even though we do not actually use the optimizer) 325 | self.define_optimizer_scheduler() 326 | 327 | # Evaluate ensemble 328 | self.ensemble_strategy.final() 329 | 330 | self.checkpoint_file = self.save_model(model_type='ensemble') 331 | wandb.summary['final_model_file'] = f"ensemble_model_{self.config.ensemble_method}_{self.config.phase}.pt" 332 | -------------------------------------------------------------------------------- /runners/baseRunner.py: -------------------------------------------------------------------------------- 1 | # =========================================================================== 2 | # Project: Sparse Model Soups: A Recipe for Improved Pruning via Model Averaging - IOL Lab @ ZIB 3 | # Paper: arxiv.org/abs/2306.16788 4 | # File: baseRunner.py 5 | # Description: Base Runner class, all other runners inherit from this one 6 | # =========================================================================== 7 | import importlib 8 | import os 9 | import sys 10 | import time 11 | from collections import OrderedDict 12 | from math import sqrt 13 | 14 | import numpy as np 15 | import torch 16 | import torch.nn.utils.prune as prune 
17 | import wandb 18 | from torch.cuda.amp import autocast 19 | from torchmetrics import MeanMetric 20 | from torchmetrics.classification import MulticlassAccuracy as Accuracy 21 | from tqdm.auto import tqdm 22 | 23 | from config import datasetDict, trainTransformDict, testTransformDict 24 | from metrics import metrics 25 | from strategies import strategies as usual_strategies 26 | from utilities.lr_schedulers import SequentialSchedulers, FixedLR 27 | from utilities.utilities import Utilities as Utils 28 | from utilities.utilities import WorstClassAccuracy, CalibrationError 29 | 30 | 31 | class baseRunner: 32 | """Base class for all runners, defines the general functions""" 33 | 34 | def __init__(self, config): 35 | 36 | self.config = config 37 | self.dataParallel = (torch.cuda.device_count() > 1) 38 | if not self.dataParallel: 39 | self.device = torch.device(config.device) 40 | if 'gpu' in config.device: 41 | torch.cuda.set_device(self.device) 42 | else: 43 | # Use all visible GPUs 44 | self.device = torch.device("cuda:0") 45 | torch.cuda.device(self.device) 46 | 47 | # Set a couple useful variables 48 | self.checkpoint_file = None 49 | self.trained_test_accuracy = None 50 | self.trained_train_loss = None 51 | self.trained_train_accuracy = None 52 | self.after_pruning_metrics = None 53 | self.seed = None 54 | self.squared_model_norm = None 55 | self.n_warmup_epochs = None 56 | self.trainIterationCtr = 1 57 | self.tmp_dir = config['tmp_dir'] 58 | sys.stdout.write(f"Using temporary directory {self.tmp_dir}.\n") 59 | self.ampGradScaler = None # Note: this must be reset before training, and before retraining 60 | self.num_workers = None 61 | 62 | # Variables to be set by inheriting classes 63 | self.strategy = None 64 | self.ensemble_strategy = None 65 | self.trainLoader = None 66 | self.valLoader = None 67 | self.testLoader = None 68 | self.trainLoader_unshuffled = None 69 | self.oodLoader = None 70 | self.n_datapoints = None 71 | self.model = None 72 | self.dense_model = None 73 | self.wd_scheduler = None 74 | self.trainData = None 75 | self.n_total_iterations = None 76 | 77 | self.ultimate_log_dict = None 78 | 79 | if self.config.dataset in ['mnist', 'cifar10']: 80 | self.n_classes = 10 81 | elif self.config.dataset in ['cifar100']: 82 | self.n_classes = 100 83 | elif self.config.dataset in ['tinyimagenet']: 84 | self.n_classes = 200 85 | elif self.config.dataset in ['imagenet']: 86 | self.n_classes = 1000 87 | else: 88 | raise NotImplementedError 89 | 90 | # Define the loss object and metrics 91 | # Important note: for the correct computation of loss/accuracy it's important to have reduction == 'mean' 92 | self.loss_criterion = torch.nn.CrossEntropyLoss(reduction='mean').to(device=self.device) 93 | 94 | self.metrics = {mode: {'loss': MeanMetric().to(device=self.device), 95 | 'accuracy': Accuracy(num_classes=self.n_classes).to(device=self.device), 96 | 'ips_throughput': MeanMetric().to(device=self.device)} 97 | for mode in ['train', 'val', 'test', 'ood']} 98 | for mode in ['val', 'test', 'ood']: 99 | self.metrics[mode]['ece'] = CalibrationError(norm='l1').to(device=self.device) 100 | self.metrics[mode]['mce'] = CalibrationError(norm='max').to(device=self.device) 101 | self.metrics[mode]['worst_class_accuracy'] = WorstClassAccuracy(num_classes=self.n_classes).to( 102 | device=self.device) 103 | 104 | def reset_averaged_metrics(self): 105 | """Resets all metrics""" 106 | for mode in self.metrics.keys(): 107 | for metric in self.metrics[mode].values(): 108 | metric.reset() 109 | 110 | def 
get_metrics(self): 111 | with torch.no_grad(): 112 | n_total, n_nonzero = metrics.get_parameter_count(model=self.model) 113 | 114 | x_input, y_target, indices = next(iter(self.valLoader)) 115 | x_input, y_target = x_input.to(self.device), y_target.to(self.device) # Move to CUDA if possible 116 | n_flops, n_nonzero_flops = metrics.get_flops(model=self.model, x_input=x_input) 117 | 118 | distance_to_pruned, rel_distance_to_pruned = {}, {} 119 | if self.config.goal_sparsity is not None: 120 | distance_to_pruned, rel_distance_to_pruned = metrics.get_distance_to_pruned(model=self.model, 121 | sparsity=self.config.goal_sparsity) 122 | 123 | soup_metrics = self.ensemble_strategy.get_ensemble_metrics() if self.ensemble_strategy is not None else {} 124 | loggingDict = dict( 125 | train={metric_name: metric.compute() for metric_name, metric in self.metrics['train'].items() if 126 | getattr(metric, 'mode', True) is not None}, # Check if metric computable 127 | val={metric_name: metric.compute() for metric_name, metric in self.metrics['val'].items()}, 128 | global_sparsity=metrics.global_sparsity(module=self.model), 129 | modular_sparsity=metrics.modular_sparsity(parameters_to_prune=self.strategy.parameters_to_prune), 130 | n_total_params=n_total, 131 | n_nonzero_params=n_nonzero, 132 | nonzero_inference_flops=n_nonzero_flops, 133 | baseline_inference_flops=n_flops, 134 | theoretical_speedup=metrics.get_theoretical_speedup(n_flops=n_flops, n_nonzero_flops=n_nonzero_flops), 135 | learning_rate=float(self.optimizer.param_groups[0]['lr']), 136 | distance_to_origin=metrics.get_distance_to_origin(self.model), 137 | distance_to_pruned=distance_to_pruned, 138 | rel_distance_to_pruned=rel_distance_to_pruned, 139 | soup_metrics=soup_metrics, 140 | ) 141 | 142 | for split in ['test', 'ood']: 143 | loggingDict[split] = dict() 144 | for metric_name, metric in self.metrics[split].items(): 145 | try: 146 | # Catch case where MeanMetric mode not set yet 147 | loggingDict[split][metric_name] = metric.compute() 148 | except Exception as e: 149 | continue 150 | 151 | return loggingDict 152 | 153 | def get_dataset_root(self, dataset_name: str) -> str: 154 | """Copies the dataset and returns the rootpath.""" 155 | # Determine where the data lies 156 | for root in ['/software/pytorch_datasets/', './datasets_pytorch/']: 157 | rootPath = f"{root}{dataset_name}" 158 | if os.path.isdir(rootPath): 159 | break 160 | 161 | return rootPath 162 | 163 | def get_ood_dataloaders(self): 164 | if self.config.dataset == 'cifar10': 165 | ood_dataset_name = 'CIFAR10CORRUPT' 166 | elif self.config.dataset == 'cifar100': 167 | ood_dataset_name = 'CIFAR100CORRUPT' 168 | else: 169 | return None 170 | 171 | sys.stdout.write(f"Loading {ood_dataset_name} dataset for OOD performance.\n") 172 | ood_root = self.get_dataset_root(ood_dataset_name) 173 | ood_dataset = Utils.get_overloaded_dataset(datasetDict[ood_dataset_name])(root=ood_root, 174 | transform=testTransformDict[ 175 | self.config.dataset]) 176 | ood_loader = torch.utils.data.DataLoader(ood_dataset, batch_size=self.config.batch_size, shuffle=False, 177 | pin_memory=torch.cuda.is_available(), num_workers=self.num_workers) 178 | 179 | return ood_loader 180 | 181 | def get_dataloaders(self): 182 | rootPath = self.get_dataset_root(dataset_name=self.config.dataset) 183 | 184 | if self.config.dataset in ['imagenet']: 185 | trainData = Utils.get_overloaded_dataset(datasetDict[self.config.dataset])(root=rootPath, split='train', 186 | transform=trainTransformDict[ 187 | self.config.dataset]) 188 | 
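`get_metrics` above delegates parameter counting and sparsity computation to `metrics.get_parameter_count` and `metrics.global_sparsity` in metrics/metrics.py, which are not shown here. A rough standalone sketch of the underlying idea, assuming that exactly-zero weights count as pruned (the helper names are illustrative, not the repository's implementation):

import torch


@torch.no_grad()
def parameter_counts(model: torch.nn.Module):
    # Return (total, nonzero) parameter counts over all parameters of the model.
    n_total, n_nonzero = 0, 0
    for p in model.parameters():
        n_total += p.numel()
        n_nonzero += int(torch.count_nonzero(p).item())
    return n_total, n_nonzero


def global_sparsity(model: torch.nn.Module) -> float:
    # Fraction of parameters that are exactly zero.
    n_total, n_nonzero = parameter_counts(model)
    return 1.0 - n_nonzero / n_total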
testData = Utils.get_overloaded_dataset(datasetDict[self.config.dataset])(root=rootPath, split='val', 189 | transform=testTransformDict[ 190 | self.config.dataset]) 191 | elif self.config.dataset == 'tinyimagenet': 192 | traindir = os.path.join(rootPath, 'train') 193 | valdir = os.path.join(rootPath, 'val') 194 | trainData = Utils.get_overloaded_dataset(datasetDict[self.config.dataset])(root=traindir, 195 | transform=trainTransformDict[ 196 | self.config.dataset]) 197 | testData = Utils.get_overloaded_dataset(datasetDict[self.config.dataset])(root=valdir, 198 | transform=testTransformDict[ 199 | self.config.dataset]) 200 | else: 201 | trainData = Utils.get_overloaded_dataset(datasetDict[self.config.dataset])(root=rootPath, train=True, 202 | download=True, 203 | transform=trainTransformDict[ 204 | self.config.dataset]) 205 | 206 | testData = Utils.get_overloaded_dataset(datasetDict[self.config.dataset])(root=rootPath, train=False, 207 | transform=testTransformDict[ 208 | self.config.dataset]) 209 | train_size = int(0.9 * len(trainData)) 210 | val_size = len(trainData) - train_size 211 | self.trainData, valData = torch.utils.data.random_split(trainData, [train_size, val_size], 212 | generator=torch.Generator().manual_seed(42)) 213 | self.n_datapoints = train_size 214 | 215 | if self.config.dataset in ['imagenet', 'cifar100', 'tinyimagenet']: 216 | self.num_workers = 4 * torch.cuda.device_count() if torch.cuda.is_available() else 0 217 | else: 218 | self.num_workers = 2 if torch.cuda.is_available() else 0 219 | 220 | trainLoader = torch.utils.data.DataLoader(self.trainData, batch_size=self.config.batch_size, shuffle=True, 221 | pin_memory=torch.cuda.is_available(), num_workers=self.num_workers) 222 | trainLoader_unshuffled = torch.utils.data.DataLoader(self.trainData, batch_size=self.config.batch_size, 223 | shuffle=False, 224 | pin_memory=torch.cuda.is_available(), 225 | num_workers=self.num_workers) 226 | valLoader = torch.utils.data.DataLoader(valData, batch_size=self.config.batch_size, shuffle=False, 227 | pin_memory=torch.cuda.is_available(), num_workers=self.num_workers) 228 | testLoader = torch.utils.data.DataLoader(testData, batch_size=self.config.batch_size, shuffle=False, 229 | pin_memory=torch.cuda.is_available(), num_workers=self.num_workers) 230 | 231 | return trainLoader, valLoader, testLoader, trainLoader_unshuffled 232 | 233 | def get_model(self, reinit: bool, temporary: bool = True) -> torch.nn.Module: 234 | if reinit: 235 | # Define the model 236 | model = getattr(importlib.import_module('models.' + self.config.dataset), self.config.arch)() 237 | else: 238 | # The model has been initialized already 239 | model = self.model 240 | 241 | file = self.checkpoint_file 242 | masks = None 243 | if file is not None: 244 | dir = wandb.run.dir if not temporary else self.tmp_dir 245 | fPath = os.path.join(dir, file) 246 | 247 | state_dict = torch.load(fPath, map_location=self.device) 248 | 249 | new_state_dict = OrderedDict() 250 | masks = OrderedDict() 251 | mask_module_names = [] 252 | require_DP_format = isinstance(model, 253 | torch.nn.DataParallel) # If true, ensure all keys start with "module." 254 | for k, v in state_dict.items(): 255 | is_in_DP_format = k.startswith("module.") 256 | if require_DP_format and is_in_DP_format: 257 | name = k 258 | elif require_DP_format and not is_in_DP_format: 259 | name = "module." + k # Add 'module' prefix 260 | elif not require_DP_format and is_in_DP_format: 261 | name = k[7:] # Remove 'module.' 
262 | elif not require_DP_format and not is_in_DP_format: 263 | name = k 264 | 265 | v_new = v # Remains unchanged if not in _orig format 266 | if k.endswith("_orig"): 267 | # We loaded the _orig tensor and corresponding mask 268 | name = name[:-5] # Truncate the "_orig" 269 | if f"{k[:-5]}_mask" in state_dict.keys(): 270 | # Split name into the modules name and the param_type (i.e. weight, bias or similar) 271 | module_name, param_type = name.rsplit(".", 1) 272 | 273 | masks[(module_name, param_type)] = state_dict[f"{k[:-5]}_mask"] 274 | mask_module_names.append(module_name) 275 | 276 | new_state_dict[name] = v_new 277 | 278 | maskKeys = [k for k in new_state_dict.keys() if k.endswith("_mask")] 279 | for k in maskKeys: 280 | del new_state_dict[k] 281 | 282 | # Load the state_dict 283 | model.load_state_dict(new_state_dict) 284 | 285 | module_dict = {} 286 | for name, module in model.named_modules(): 287 | if name in mask_module_names: 288 | module_dict[name] = module 289 | 290 | if self.dataParallel and reinit and not isinstance(model, 291 | torch.nn.DataParallel): # Only apply DataParallel when re-initializing the model! 292 | # We use DataParallelism 293 | model = torch.nn.DataParallel(model) 294 | model = model.to(device=self.device) 295 | 296 | # We reinforce the previous pruning 297 | if masks is not None: 298 | for (module_name, param_type), v in masks.items(): 299 | module = module_dict[module_name] 300 | v = v.to(self.device) 301 | prune.custom_from_mask(module, name=param_type, mask=v) 302 | 303 | return model 304 | 305 | def define_optimizer_scheduler(self): 306 | # Learning rate scheduler in the form (type, kwargs) 307 | tupleStr = self.config.learning_rate.strip() 308 | # Remove parenthesis 309 | if tupleStr[0] == '(': 310 | tupleStr = tupleStr[1:] 311 | if tupleStr[-1] == ')': 312 | tupleStr = tupleStr[:-1] 313 | name, *kwargs = tupleStr.split(',') 314 | if name in ['StepLR', 'MultiStepLR', 'ExponentialLR', 'Linear', 'Cosine', 'Constant']: 315 | scheduler = (name, kwargs) 316 | self.initial_lr = float(kwargs[0]) 317 | else: 318 | raise NotImplementedError(f"LR Scheduler {name} not implemented.") 319 | 320 | # Define the optimizer 321 | if self.config.optimizer == 'SGD': 322 | wd = self.config['weight_decay'] or 0. 323 | self.optimizer = torch.optim.SGD(params=self.model.parameters(), lr=self.initial_lr, 324 | momentum=self.config.momentum, 325 | weight_decay=wd, nesterov=wd > 0.) 326 | 327 | # We define a scheduler. All schedulers work on a per-iteration basis 328 | iterations_per_epoch = len(self.trainLoader) 329 | n_total_iterations = iterations_per_epoch * self.config.n_epochs 330 | self.n_total_iterations = n_total_iterations 331 | n_warmup_iterations = 0 332 | 333 | # Set the initial learning rate 334 | for param_group in self.optimizer.param_groups: param_group['lr'] = self.initial_lr 335 | 336 | # Define the warmup scheduler if needed 337 | warmup_scheduler, milestone = None, None 338 | if self.config.n_epochs_warmup and self.config.n_epochs_warmup > 0: 339 | assert int( 340 | self.config.n_epochs_warmup) == self.config.n_epochs_warmup, "At the moment no float warmup allowed." 341 | n_warmup_iterations = int(float(self.config.n_epochs_warmup) * iterations_per_epoch) 342 | # As a start factor we use 1e-20, to avoid division by zero when putting 0. 
343 | warmup_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer=self.optimizer, 344 | start_factor=1e-20, end_factor=1., 345 | total_iters=n_warmup_iterations) 346 | milestone = n_warmup_iterations + 1 347 | 348 | n_remaining_iterations = n_total_iterations - n_warmup_iterations 349 | 350 | name, kwargs = scheduler 351 | scheduler = None 352 | if name == 'Constant': 353 | scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer=self.optimizer, 354 | factor=1.0, 355 | total_iters=n_remaining_iterations) 356 | elif name == 'StepLR': 357 | # Tuple of form ('StepLR', initial_lr, step_size, gamma) 358 | # Reduces initial_lr by gamma every step_size epochs 359 | step_size, gamma = int(kwargs[1]), float(kwargs[2]) 360 | 361 | # Convert to iterations 362 | step_size = iterations_per_epoch * step_size 363 | 364 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer=self.optimizer, step_size=step_size, 365 | gamma=gamma) 366 | elif name == 'MultiStepLR': 367 | # Tuple of form ('MultiStepLR', initial_lr, milestones, gamma) 368 | # Reduces initial_lr by gamma every epoch that is in the list milestones 369 | milestones, gamma = kwargs[1].strip(), float(kwargs[2]) 370 | # Remove square bracket 371 | if milestones[0] == '[': 372 | milestones = milestones[1:] 373 | if milestones[-1] == ']': 374 | milestones = milestones[:-1] 375 | # Convert to iterations directly 376 | milestones = [int(ms) * iterations_per_epoch for ms in milestones.split('|')] 377 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=self.optimizer, milestones=milestones, 378 | gamma=gamma) 379 | elif name == 'ExponentialLR': 380 | # Tuple of form ('ExponentialLR', initial_lr, gamma) 381 | gamma = float(kwargs[1]) 382 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=self.optimizer, gamma=gamma) 383 | elif name == 'Linear': 384 | if len(kwargs) == 2: 385 | # The final learning rate has also been passed 386 | end_factor = float(kwargs[1]) / float(kwargs[0]) 387 | else: 388 | end_factor = 0. 389 | scheduler = torch.optim.lr_scheduler.LinearLR(optimizer=self.optimizer, 390 | start_factor=1.0, end_factor=end_factor, 391 | total_iters=n_remaining_iterations) 392 | elif name == 'Cosine': 393 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, 394 | T_max=n_remaining_iterations, eta_min=0.) 395 | 396 | # Reset base lrs to make this work 397 | scheduler.base_lrs = [self.initial_lr if warmup_scheduler else 0. 
for _ in self.optimizer.param_groups] 398 | 399 | # Define the Sequential Scheduler 400 | if warmup_scheduler is None: 401 | self.scheduler = scheduler 402 | elif name in ['StepLR', 'MultiStepLR']: 403 | # We need parallel schedulers, since the steps should be counted during warmup 404 | self.scheduler = torch.optim.lr_scheduler.ChainedScheduler(schedulers=[warmup_scheduler, scheduler]) 405 | else: 406 | self.scheduler = SequentialSchedulers(optimizer=self.optimizer, schedulers=[warmup_scheduler, scheduler], 407 | milestones=[milestone]) 408 | 409 | def define_strategy(self, use_dense_base=False): 410 | #### UNSTRUCTURED 411 | # Define callbacksfinetuning_callback, restore_callback, save_model_callback 412 | callbackDict = { 413 | 'after_pruning_callback': self.after_pruning_callback, 414 | 'finetuning_callback': self.fine_tuning, 415 | 'restore_callback': self.restore_model, 416 | 'save_model_callback': self.save_model, 417 | 'final_log_callback': self.final_log, 418 | } 419 | # Base strategies 420 | if use_dense_base: 421 | return getattr(usual_strategies, 'Dense')(model=self.model, n_classes=self.n_classes, 422 | config=self.config, callbacks=callbackDict) 423 | else: 424 | return getattr(usual_strategies, self.config.strategy)(model=self.model, n_classes=self.n_classes, 425 | config=self.config, callbacks=callbackDict) 426 | 427 | def log(self, runTime, finetuning: bool = False, final_logging: bool = False): 428 | loggingDict = self.get_metrics() 429 | loggingDict.update({'epoch_run_time': runTime}) 430 | if not finetuning: 431 | # Update final trained metrics (necessary to be able to filter via wandb) 432 | for metric_type, val in loggingDict.items(): 433 | wandb.run.summary[f"final.{metric_type}"] = val 434 | # The usual logging of one epoch 435 | wandb.log( 436 | loggingDict 437 | ) 438 | 439 | else: 440 | if not final_logging: 441 | wandb.log( 442 | dict(finetune=loggingDict, 443 | ), 444 | ) 445 | else: 446 | # We add the after_pruning_metrics and don't commit, since the values are updated by self.final_log 447 | self.ultimate_log_dict = dict(finetune=loggingDict, 448 | pruned=self.after_pruning_metrics, 449 | ) 450 | 451 | def final_log(self): 452 | """This function can ONLY be called by pretrained strategies using the final sparsified model""" 453 | # Recompute accuracy and loss 454 | sys.stdout.write( 455 | f"\nFinal logging\n") 456 | self.reset_averaged_metrics() 457 | if self.config.strategy != 'Dense': 458 | # We recalibrate the BN statistics also for IMP 459 | self.recalibrate_bn() 460 | self.evaluate_model(data='val') 461 | self.evaluate_model(data='test') 462 | self.evaluate_model(data='ood') 463 | 464 | # Update final trained metrics (necessary to be able to filter via wandb) 465 | loggingDict = self.get_metrics() 466 | for metric_type, val in loggingDict.items(): 467 | wandb.run.summary[f"final.{metric_type}"] = val 468 | 469 | # Update after prune metrics 470 | if self.after_pruning_metrics is not None: 471 | for metric_type, val in self.after_pruning_metrics.items(): 472 | wandb.run.summary[f"pruned.{metric_type}"] = val 473 | 474 | # Add to existing self.ultimate_log_dict which was not committed yet 475 | if self.ultimate_log_dict is not None: 476 | if loggingDict['train']['accuracy'] == 0: 477 | # we did not perform the recomputation, use the old values for train 478 | del loggingDict['train'] 479 | 480 | self.ultimate_log_dict['finetune'].update(loggingDict) 481 | else: 482 | self.ultimate_log_dict = {'finetune': loggingDict} 483 | 484 | 
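`final_log` re-evaluates the model on the val/test/ood splits, which feeds the ECE and MCE meters imported from utilities.utilities. Those meters are not shown in this file; the following is only a rough sketch of the standard binned expected calibration error they approximate, not the repository's implementation:

import torch


@torch.no_grad()
def expected_calibration_error(logits: torch.Tensor, targets: torch.Tensor, n_bins: int = 15) -> float:
    # Equal-width confidence bins; l1-weighted gap between confidence and accuracy.
    probs = torch.softmax(logits, dim=1)
    confidences, predictions = probs.max(dim=1)
    accuracies = (predictions == targets).float()
    edges = torch.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            weight = in_bin.float().mean()
            gap = (accuracies[in_bin].mean() - confidences[in_bin].mean()).abs()
            ece += (weight * gap).item()
    return ece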
wandb.log(self.ultimate_log_dict) 485 | Utils.dump_dict_to_json_wandb(metrics.per_layer_sparsity(model=self.model), 'sparsity_distribution') 486 | 487 | def after_pruning_callback(self): 488 | """Collects pruning metrics. Is called ONCE per run, namely on the LAST PRUNING step.""" 489 | 490 | # Make the pruning permanent (this is in conflict with strategies that do not have a permanent pruning) 491 | self.strategy.enforce_prunedness() 492 | 493 | # Compute losses, accuracies after pruning 494 | sys.stdout.write(f"\nGoal sparsity reached - Computing incurred losses after pruning.\n") 495 | self.reset_averaged_metrics() 496 | 497 | # self.evaluate_model(data='train') 498 | self.evaluate_model(data='val') 499 | self.evaluate_model(data='test') 500 | if self.squared_model_norm is not None: 501 | L2_norm_square = Utils.get_model_norm_square(self.model) 502 | norm_drop = sqrt(abs(self.squared_model_norm - L2_norm_square)) 503 | if float(sqrt(self.squared_model_norm)) > 0: 504 | relative_norm_drop = norm_drop / float(sqrt(self.squared_model_norm)) 505 | else: 506 | relative_norm_drop = {} 507 | else: 508 | norm_drop, relative_norm_drop = {}, {} 509 | 510 | pruning_instability, pruning_stability = {}, {} 511 | if self.trained_test_accuracy is not None and self.trained_test_accuracy > 0: 512 | pruning_instability = ( 513 | self.trained_test_accuracy - self.metrics['test'][ 514 | 'accuracy'].compute()) / self.trained_test_accuracy 515 | pruning_stability = 1 - pruning_instability 516 | 517 | self.after_pruning_metrics = dict( 518 | val={metric_name: metric.compute() for metric_name, metric in self.metrics['val'].items()}, 519 | test={metric_name: metric.compute() for metric_name, metric in self.metrics['test'].items()}, 520 | norm_drop=norm_drop, 521 | relative_norm_drop=relative_norm_drop, 522 | pruning_instability=pruning_instability, 523 | pruning_stability=pruning_stability, 524 | ) 525 | 526 | # Reset squared model norm for following pruning steps, otherwise ALLR does not work properly 527 | self.squared_model_norm = Utils.get_model_norm_square(model=self.model) 528 | 529 | def restore_model(self) -> None: 530 | sys.stdout.write( 531 | f"Restoring model from {self.checkpoint_file}.\n") 532 | self.model = self.get_model(reinit=False, temporary=True) 533 | 534 | def save_model(self, model_type: str, remove_pruning_hooks: bool = False, temporary: bool = False) -> str: 535 | if model_type not in ['initial', 'trained', 'pruned', 'ensemble']: 536 | print(f"Ignoring to save {model_type} for now.") 537 | return None 538 | fName = f"{model_type}_model.pt" 539 | if model_type == 'pruned': 540 | fName = f"{model_type}_model_{self.config.split_val}_{self.config.phase}.pt" 541 | elif model_type == 'ensemble': 542 | fName = f"{model_type}_model_{self.config.ensemble_method}_{self.config.phase}.pt" 543 | fPath = os.path.join(wandb.run.dir, fName) if not temporary else os.path.join(self.tmp_dir, fName) 544 | if remove_pruning_hooks: 545 | self.strategy.make_pruning_permanent(model=self.model) 546 | 547 | # Only save models in their non-module version, to avoid problems when loading 548 | try: 549 | model_state_dict = self.model.module.state_dict() 550 | except AttributeError: 551 | model_state_dict = self.model.state_dict() 552 | 553 | torch.save(model_state_dict, fPath) # Save the state_dict 554 | return fPath 555 | 556 | def evaluate_model(self, data='train'): 557 | return self.train_epoch(data=data, is_training=False) 558 | 559 | def define_retrain_schedule(self, n_epochs_finetune, pruning_sparsity): 560 | 
"""Define the retraining schedule. 561 | - Tuneable schedules all require both an initial value as well as a warmup length 562 | - Fixed schedules require no additional parameters and are mere conversions such as LRW 563 | """ 564 | fixed_schedules = ['FT', # Use last lr of original training as schedule (Han et al.), no warmup 565 | 'LRW', # Learning Rate Rewinding (Renda et al.), no warmup 566 | 'SLR', # Scaled Learning Rate Restarting (Le et al.), maxLR init, 10% warmup 567 | 'CLR', # Cyclic Learning Rate Restarting (Le et al.), maxLR init, 10% warmup 568 | 'LLR', # Linear from the largest original lr to 0, maxLR init, 10% warmup 569 | 'ALLR', # LLR, but choose initial value adaptively 570 | ] 571 | retrain_schedule = self.config.retrain_schedule 572 | init_val = None 573 | if self.config.ensemble_by == 'retrain_schedule': 574 | retrain_schedule = self.config.split_val 575 | # Check if the retrain schedule is a float 576 | init_val = float(retrain_schedule) 577 | retrain_schedule = 'LLR' 578 | sys.stdout.write(f"We split by the retrain schedule initial value. Value {init_val}.\n") 579 | 580 | # Define the initial lr, max lr and min lr 581 | maxLR = max( 582 | self.strategy.lr_dict.values()) 583 | after_warmup_index = (self.config.n_epochs_warmup or 0) * len(self.trainLoader) 584 | minLR = min(list(self.strategy.lr_dict.values())[after_warmup_index:]) # Ignores warmup in orig. schedule 585 | 586 | n_total_iterations = len(self.trainLoader) * n_epochs_finetune 587 | 588 | if retrain_schedule in fixed_schedules: 589 | # Define warmup length 590 | if retrain_schedule in ['FT', 'LRW']: 591 | n_warmup_iterations = 0 592 | else: 593 | # 10% warmup 594 | n_warmup_iterations = int(0.1 * n_total_iterations) 595 | 596 | # Define the after_warmup_lr 597 | if init_val is not None: 598 | after_warmup_lr = init_val 599 | elif retrain_schedule == 'FT': 600 | after_warmup_lr = minLR 601 | elif retrain_schedule == 'LRW': 602 | after_warmup_lr = list(self.strategy.lr_dict.values())[ 603 | -n_total_iterations] # == remaining iterations since we don't do warmup 604 | elif retrain_schedule in ['ALLR']: 605 | minLRThreshold = min(float(n_epochs_finetune) / self.config.n_epochs, 1.0) * maxLR 606 | # Use the norm drop 607 | relative_norm_drop = self.after_pruning_metrics['relative_norm_drop'] 608 | scaling = relative_norm_drop / sqrt(pruning_sparsity) 609 | 610 | discounted_LR = float(scaling) * maxLR 611 | 612 | after_warmup_lr = np.clip(discounted_LR, a_min=minLRThreshold, a_max=maxLR) 613 | 614 | elif retrain_schedule in ['SLR', 'CLR', 'LLR']: 615 | after_warmup_lr = maxLR 616 | else: 617 | raise NotImplementedError 618 | else: 619 | raise NotImplementedError 620 | 621 | # Set the optimizer lr 622 | for param_group in self.optimizer.param_groups: 623 | if n_warmup_iterations > 0: 624 | # If warmup, then we actually begin with 0 and increase to after_warmup_lr 625 | param_group['lr'] = 0.0 626 | else: 627 | param_group['lr'] = after_warmup_lr 628 | 629 | # Define warmup scheduler 630 | warmup_scheduler, milestone = None, None 631 | if n_warmup_iterations > 0: 632 | warmup_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR \ 633 | (self.optimizer, T_max=n_warmup_iterations, eta_min=after_warmup_lr) 634 | milestone = n_warmup_iterations + 1 635 | 636 | # Define scheduler after the warmup 637 | n_remaining_iterations = n_total_iterations - n_warmup_iterations 638 | scheduler = None 639 | if retrain_schedule in ['FT']: 640 | # Does essentially nothing but keeping the smallest learning rate 641 | scheduler = 
torch.optim.lr_scheduler.ConstantLR(optimizer=self.optimizer, 642 | factor=1.0, 643 | total_iters=n_remaining_iterations) 644 | elif retrain_schedule == 'LRW': 645 | iterationsLR = list(self.strategy.lr_dict.values())[(-n_remaining_iterations):] 646 | iterationsLR.append(iterationsLR[-1]) # Double the last learning rate so we avoid the IndexError 647 | scheduler = FixedLR(optimizer=self.optimizer, lrList=iterationsLR) 648 | 649 | elif retrain_schedule in ['SLR']: 650 | iterationsLR = [lr if int(it) >= after_warmup_index else maxLR 651 | for it, lr in self.strategy.lr_dict.items()] 652 | 653 | interpolation_width = (len(self.strategy.lr_dict)) / n_remaining_iterations # In general not an integer 654 | reducedLRs = [iterationsLR[int(j * interpolation_width)] for j in range(n_remaining_iterations)] 655 | # Add a last LR to avoid IndexError 656 | reducedLRs = reducedLRs + [reducedLRs[-1]] 657 | 658 | lr_lambda = lambda it: reducedLRs[it] / float(maxLR) # Function returning the correct learning rate factor 659 | scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_lambda) 660 | 661 | elif retrain_schedule in ['CLR']: 662 | stopLR = minLR 663 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR \ 664 | (self.optimizer, T_max=n_remaining_iterations, eta_min=stopLR) 665 | 666 | elif retrain_schedule in ['LLR', 'ALLR']: 667 | scheduler = torch.optim.lr_scheduler.LinearLR(optimizer=self.optimizer, 668 | start_factor=1.0, end_factor=0., 669 | total_iters=n_remaining_iterations) 670 | 671 | # Reset base lrs to make this work 672 | scheduler.base_lrs = [after_warmup_lr for _ in self.optimizer.param_groups] 673 | 674 | # Define the Sequential Scheduler 675 | if warmup_scheduler is None: 676 | self.scheduler = scheduler 677 | else: 678 | self.scheduler = SequentialSchedulers(optimizer=self.optimizer, schedulers=[warmup_scheduler, scheduler], 679 | milestones=[milestone]) 680 | 681 | def fine_tuning(self, pruning_sparsity, n_epochs_finetune, phase=1): 682 | if n_epochs_finetune == 0: 683 | return 684 | if self.config.ensemble_by == 'retrain_length': 685 | n_epochs_finetune = self.config.split_val 686 | sys.stdout.write(f"We split by the retrain length. Value {n_epochs_finetune}.\n") 687 | n_phases = self.config.n_phases or 1 688 | 689 | # Reset the GradScaler for AutoCast 690 | self.ampGradScaler = torch.cuda.amp.GradScaler(enabled=(self.config.use_amp is True)) 691 | 692 | # Update the retrain schedule individually for every phase/cycle 693 | self.define_retrain_schedule(n_epochs_finetune=n_epochs_finetune, 694 | pruning_sparsity=pruning_sparsity) 695 | 696 | self.strategy.set_to_finetuning_phase() 697 | for epoch in range(1, n_epochs_finetune + 1, 1): 698 | self.reset_averaged_metrics() 699 | sys.stdout.write( 700 | f"\nFinetuning: phase {phase}/{n_phases} | epoch {epoch}/{n_epochs_finetune}\n") 701 | # Train 702 | t = time.time() 703 | self.train_epoch(data='train') 704 | self.evaluate_model(data='val') 705 | 706 | self.strategy.at_epoch_end(epoch=epoch) 707 | self.log(runTime=time.time() - t, finetuning=True, 708 | final_logging=(epoch == n_epochs_finetune and phase == n_phases)) 709 | 710 | def train_epoch(self, data='train', is_training=True): 711 | assert not (data in ['test', 'val', 'ood'] and is_training), "Can't train on test/val/ood set." 
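For reference, the ALLR branch of `define_retrain_schedule` above picks its post-warmup learning rate from the relative norm drop measured after pruning. A condensed restatement of that rule as a standalone helper (the function and argument names are illustrative, not part of the repository):

from math import sqrt

import numpy as np


def allr_initial_lr(max_lr: float, relative_norm_drop: float, pruning_sparsity: float,
                    n_epochs_finetune: int, n_epochs_train: int) -> float:
    # Larger norm drops (relative to the pruned fraction) warrant a larger restart LR,
    # clipped between a budget-dependent lower bound and the original maximum LR.
    min_lr_threshold = min(float(n_epochs_finetune) / n_epochs_train, 1.0) * max_lr
    discounted_lr = (relative_norm_drop / sqrt(pruning_sparsity)) * max_lr
    return float(np.clip(discounted_lr, a_min=min_lr_threshold, a_max=max_lr))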
712 | loaderDict = {'train': self.trainLoader, 713 | 'val': self.valLoader, 714 | 'test': self.testLoader, 715 | 'ood': self.oodLoader} 716 | loader = loaderDict[data] 717 | if loader is None and data == 'ood': 718 | sys.stdout.write(f"No OOD data available. Skipping.\n") 719 | return 720 | 721 | sys.stdout.write(f"Training:\n") if is_training else sys.stdout.write( 722 | f"Evaluation of {data} data:\n") 723 | 724 | with torch.set_grad_enabled(is_training): 725 | with tqdm(loader, leave=True) as pbar: 726 | for x_input, y_target, indices in pbar: 727 | # Move to CUDA if possible 728 | x_input = x_input.to(self.device, non_blocking=True) 729 | y_target = y_target.to(self.device, non_blocking=True) 730 | self.optimizer.zero_grad() # Zero the gradient buffers 731 | 732 | itStartTime = time.time() 733 | 734 | with autocast(enabled=(self.config.use_amp is True)): 735 | output = self.model.train(mode=(data == 'train'))(x_input) 736 | loss = self.loss_criterion(output, y_target) 737 | 738 | if is_training: 739 | self.ampGradScaler.scale(loss).backward() # Scaling + Backpropagation 740 | self.ampGradScaler.step(self.optimizer) # Optimization step 741 | self.ampGradScaler.update() 742 | 743 | self.strategy.after_training_iteration(it=self.trainIterationCtr, 744 | lr=float(self.optimizer.param_groups[0]['lr'])) 745 | self.scheduler.step() 746 | self.trainIterationCtr += 1 747 | 748 | itEndTime = time.time() 749 | n_img_in_iteration = len(y_target) 750 | ips = n_img_in_iteration / (itEndTime - itStartTime) # Images processed per second 751 | 752 | self.metrics[data]['loss'](value=loss, weight=len(y_target)) 753 | self.metrics[data]['accuracy'](output, y_target) 754 | self.metrics[data]['ips_throughput'](ips) 755 | if data in ['val', 'test']: 756 | self.metrics[data]['ece'](output, y_target) 757 | self.metrics[data]['mce'](output, y_target) 758 | self.metrics[data]['worst_class_accuracy'](output, y_target) 759 | 760 | def train(self): 761 | self.ampGradScaler = torch.cuda.amp.GradScaler(enabled=(self.config.use_amp is True)) 762 | for epoch in range(self.config.n_epochs + 1): 763 | self.reset_averaged_metrics() 764 | sys.stdout.write(f"\n\nEpoch {epoch}/{self.config.n_epochs}\n") 765 | t = time.time() 766 | if epoch > 0: 767 | # Train 768 | self.train_epoch(data='train') 769 | self.evaluate_model(data='val') 770 | 771 | if epoch == self.config.n_epochs: 772 | # Do one complete evaluation on the test data set 773 | self.evaluate_model(data='test') 774 | 775 | self.strategy.at_epoch_end(epoch=epoch) 776 | 777 | self.log(runTime=time.time() - t) 778 | 779 | self.trained_test_accuracy = self.metrics['test']['accuracy'].compute() 780 | self.trained_train_loss = self.metrics['train']['loss'].compute() 781 | 782 | def recalibrate_bn(self): 783 | # Reset BN statistics 784 | recalibration_fraction = self.config.bn_recalibration_frac 785 | if self.config.bn_recalibration_frac is None or not (0 <= self.config.bn_recalibration_frac <= 1): 786 | recalibration_fraction = 1. 787 | sys.stdout.write( 788 | f"bn_recalibration_frac not specified or invalid ({self.config.bn_recalibration_frac}). Recalibrating BN-statistics on 100% of the training data (unshuffled).\n") 789 | 790 | reset_ctr = 0 791 | for m in self.model.modules(): 792 | if isinstance(m, torch.nn.BatchNorm2d): 793 | m.reset_running_stats() 794 | reset_ctr += 1 795 | sys.stdout.write( 796 | f"\nReset of {reset_ctr} BN-layers successful. 
Recalibrating BN-statistics on {int(recalibration_fraction * 100)}% of the training data (unshuffled).\n") 797 | n_batches = len(self.trainLoader_unshuffled) 798 | max_n_batches = int(recalibration_fraction * n_batches) 799 | if max_n_batches == 0: return 800 | it = 0 801 | with tqdm(self.trainLoader_unshuffled, leave=True) as pbar: 802 | for x_input, y_target, indices in pbar: 803 | # Move to CUDA if possible 804 | x_input = x_input.to(self.device, non_blocking=True) 805 | 806 | with autocast(enabled=(self.config.use_amp is True)): 807 | self.model.train()(x_input) 808 | it += 1 809 | if it >= max_n_batches: 810 | break 811 | 812 | # Free the cuda cache since it might be that the entire trainLoader is allocated 813 | # sys.stdout.write("Emptying cuda cache.\n") 814 | torch.cuda.empty_cache() 815 | --------------------------------------------------------------------------------
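`recalibrate_bn` above resets every BatchNorm2d layer and re-estimates its running statistics with forward passes only, which is how averaged (souped) models are repaired before evaluation. A self-contained sketch of the same recipe for an arbitrary model and loader; it assumes batches whose first element is the input tensor and is not the method above verbatim:

import torch


@torch.no_grad()
def recalibrate_batchnorm(model: torch.nn.Module, loader, device, max_batches=None):
    # Reset running mean/var of all BatchNorm layers, then refill them with forward passes.
    n_reset = 0
    for m in model.modules():
        if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()
            n_reset += 1
    if n_reset == 0:
        return  # Nothing to recalibrate, e.g. for plain MLPs.
    model.train()  # BN layers update running statistics only in train mode.
    for i, batch in enumerate(loader):
        x_input = batch[0].to(device, non_blocking=True)
        model(x_input)
        if max_batches is not None and i + 1 >= max_batches:
            break
    model.eval()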