├── __init__.py
├── utils
│   ├── __init__.py
│   ├── bn_type.py
│   ├── eval_utils.py
│   ├── net_utils.py
│   ├── schedulers.py
│   ├── logging.py
│   ├── conv_type.py
│   └── builder.py
├── .gitattributes
├── data
│   ├── __init__.py
│   ├── utils.py
│   └── imagenet.py
├── models
│   ├── __init__.py
│   ├── mobilenetv1.py
│   └── resnet.py
├── requirements.txt
├── configs
│   ├── reparam
│   │   ├── resnet50-dense.yaml
│   │   ├── resnet50-prune.yaml
│   │   ├── mobilenetv1-prune.yaml
│   │   └── mobilenetv1-dense.yaml
│   └── parser.py
├── README.md
├── .gitignore
├── trainer.py
├── args.py
└── main.py

/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 | 
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from data.imagenet import ImageNet
2 | from data.imagenet import TinyImageNet
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.resnet import ResNet18, ResNet50
2 | from models.mobilenetv1 import MobileNetV1
3 | 
4 | __all__ = [
5 |     "ResNet18",
6 |     "ResNet50",
7 |     "MobileNetV1"
8 | ]
--------------------------------------------------------------------------------
/utils/bn_type.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | LearnedBatchNorm = nn.BatchNorm2d
4 | 
5 | 
6 | class NonAffineBatchNorm(nn.BatchNorm2d):
7 |     def __init__(self, dim):
8 |         super(NonAffineBatchNorm, self).__init__(dim, affine=False)
9 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cudatoolkit=10.2.89
2 | cudnn=8.2.1.32
3 | numpy=1.21.4
4 | python=3.7.11
5 | pytorch=1.10.0
6 | tensorboard=2.7.0
7 | torchvision=0.11.1
8 | absl-py=0.15.0
9 | grpcio=1.42.0
10 | markdown=3.3.6
11 | pillow=9.0.1
12 | protobuf=3.19.1
13 | pyyaml=6.0
14 | six=1.16.0
15 | tqdm=4.62.3
16 | werkzeug=2.0.3
17 | 
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data.dataset import Dataset
3 | 
4 | 
5 | def one_batch_dataset(dataset, batch_size):
6 |     print("==> Grabbing a single batch")
7 | 
8 |     perm = torch.randperm(len(dataset))
9 | 
10 |     one_batch = [dataset[idx.item()] for idx in perm[:batch_size]]
11 | 
12 |     class _OneBatchWrapper(Dataset):
13 |         def __init__(self):
14 |             self.batch = one_batch
15 | 
16 |         def __getitem__(self, index):
17 |             return self.batch[index]
18 | 
19 |         def __len__(self):
20 |             return len(self.batch)
21 | 
22 |     return _OneBatchWrapper()
23 | 
--------------------------------------------------------------------------------
/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def accuracy(output, target, topk=(1,)):
5 | """Computes the accuracy over the k top predictions for the specified values of k""" 6 | with torch.no_grad(): 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res 19 | -------------------------------------------------------------------------------- /configs/reparam/resnet50-dense.yaml: -------------------------------------------------------------------------------- 1 | # Architecture 2 | arch: ResNet50 3 | 4 | # ===== Dataset ===== # 5 | set: ImageNet 6 | name: r50-dense 7 | 8 | # ===== Learning Rate Policy ======== # 9 | optimizer: sgd 10 | lr: 0.256 11 | lr_policy: cosine_lr 12 | warmup_length: 5 13 | 14 | # ===== Network training config ===== # 15 | epochs: 100 16 | weight_decay: 0.000030517578125 17 | momentum: 0.875 18 | batch_size: 256 19 | label_smoothing: 0.1 20 | 21 | # ===== Sparsity =========== # 22 | conv_type: STRConv 23 | bn_type: LearnedBatchNorm 24 | init: kaiming_normal 25 | mode: fan_in 26 | nonlinearity: relu 27 | sparse_function: identity 28 | 29 | # ===== Hardware setup ===== # 30 | workers: 48 -------------------------------------------------------------------------------- /configs/reparam/resnet50-prune.yaml: -------------------------------------------------------------------------------- 1 | # Architecture 2 | arch: ResNet50 3 | 4 | # ===== Dataset ===== # 5 | set: ImageNet 6 | name: r50-prune 7 | 8 | # ===== Learning Rate Policy ======== # 9 | optimizer: sgd 10 | lr: 0.256 11 | lr_policy: cosine_lr 12 | warmup_length: 5 13 | 14 | # ===== Network training config ===== # 15 | epochs: 100 16 | weight_decay: 0.000030517578125 17 | momentum: 0.875 18 | batch_size: 256 19 | label_smoothing: 0.1 20 | 21 | # ===== Sparsity =========== # 22 | conv_type: STRConv 23 | bn_type: LearnedBatchNorm 24 | init: kaiming_normal 25 | mode: fan_in 26 | nonlinearity: relu 27 | sparse_function: stmod 28 | 29 | # ===== Hardware setup ===== # 30 | workers: 48 -------------------------------------------------------------------------------- /configs/reparam/mobilenetv1-prune.yaml: -------------------------------------------------------------------------------- 1 | # Architecture 2 | arch: MobileNetV1 3 | 4 | # ===== Dataset ===== # 5 | set: ImageNet 6 | name: mbnet-prune 7 | 8 | # ===== Learning Rate Policy ======== # 9 | optimizer: sgd 10 | lr: 0.256 11 | lr_policy: cosine_lr 12 | warmup_length: 5 13 | 14 | # ===== Network training config ===== # 15 | epochs: 100 16 | weight_decay: 0.00003051757813 17 | momentum: 0.875 18 | batch_size: 256 19 | label_smoothing: 0.1 20 | 21 | # ===== Sparsity =========== # 22 | conv_type: STRConv 23 | bn_type: LearnedBatchNorm 24 | init: kaiming_normal 25 | mode: fan_in 26 | nonlinearity: relu 27 | sparse_function: stmod 28 | 29 | # ===== Hardware setup ===== # 30 | workers: 48 -------------------------------------------------------------------------------- /configs/reparam/mobilenetv1-dense.yaml: -------------------------------------------------------------------------------- 1 | # Architecture 2 | arch: MobileNetV1 3 | 4 | # ===== Dataset ===== # 5 | set: ImageNet 6 | name: mbnet-dense 7 | 8 | # ===== Learning Rate Policy ======== # 9 | optimizer: sgd 10 | lr: 0.256 11 | lr_policy: cosine_lr 12 | warmup_length: 5 13 | 14 | # ===== Network training 
config ===== # 15 | epochs: 100 16 | weight_decay: 0.00003051757813 17 | momentum: 0.875 18 | batch_size: 256 19 | label_smoothing: 0.1 20 | 21 | # ===== Sparsity =========== # 22 | conv_type: STRConv 23 | bn_type: LearnedBatchNorm 24 | init: kaiming_normal 25 | mode: fan_in 26 | nonlinearity: relu 27 | sparse_function: identity 28 | 29 | # ===== Hardware setup ===== # 30 | workers: 48 -------------------------------------------------------------------------------- /configs/parser.py: -------------------------------------------------------------------------------- 1 | USABLE_TYPES = set([float, int]) 2 | 3 | 4 | def trim_preceding_hyphens(st): 5 | i = 0 6 | while st[i] == "-": 7 | i += 1 8 | 9 | return st[i:] 10 | 11 | 12 | def arg_to_varname(st: str): 13 | st = trim_preceding_hyphens(st) 14 | st = st.replace("-", "_") 15 | 16 | return st.split("=")[0] 17 | 18 | 19 | def argv_to_vars(argv): 20 | var_names = [] 21 | for arg in argv: 22 | if arg.startswith("-") and arg_to_varname(arg) != "config": 23 | var_names.append(arg_to_varname(arg)) 24 | 25 | return var_names 26 | 27 | 28 | def produce_override_string(args, override_args): 29 | lines = [] 30 | for v in override_args: 31 | if v != "multigpu": 32 | v_arg = getattr(args, v) 33 | if type(v_arg) in USABLE_TYPES: 34 | lines.append(v + ": " + str(v_arg)) 35 | else: 36 | lines.append(v + ": " + f'"{str(v_arg)}"') 37 | else: 38 | lines.append("multigpu: " + str(args.multigpu)) 39 | 40 | return "\n# ===== Overrided ===== #\n" + "\n".join(lines) 41 | -------------------------------------------------------------------------------- /models/mobilenetv1.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from utils.builder import get_builder 3 | 4 | 5 | class MobileNetV1(nn.Module): 6 | def __init__(self): 7 | super(MobileNetV1, self).__init__() 8 | builder = get_builder() 9 | 10 | def conv_bn(inp, oup, stride): 11 | return nn.Sequential( 12 | builder.conv2d(inp, oup, 3, stride, 1, bias=False), 13 | nn.BatchNorm2d(oup), 14 | nn.ReLU(inplace=True) 15 | ) 16 | 17 | def conv_dw(inp, oup, stride): 18 | return nn.Sequential( 19 | builder.conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 20 | nn.BatchNorm2d(inp), 21 | nn.ReLU(inplace=True), 22 | 23 | builder.conv2d(inp, oup, 1, 1, 0, bias=False), 24 | nn.BatchNorm2d(oup), 25 | nn.ReLU(inplace=True), 26 | ) 27 | 28 | self.model = nn.Sequential( 29 | conv_bn(3, 32, 2), 30 | conv_dw(32, 64, 1), 31 | conv_dw(64, 128, 2), 32 | conv_dw(128, 128, 1), 33 | conv_dw(128, 256, 2), 34 | conv_dw(256, 256, 1), 35 | conv_dw(256, 512, 2), 36 | conv_dw(512, 512, 1), 37 | conv_dw(512, 512, 1), 38 | conv_dw(512, 512, 1), 39 | conv_dw(512, 512, 1), 40 | conv_dw(512, 512, 1), 41 | conv_dw(512, 1024, 2), 42 | conv_dw(1024, 1024, 1), 43 | nn.AvgPool2d(7), 44 | ) 45 | self.fc = builder.conv1x1(1024, 1000) 46 | 47 | def forward(self, x): 48 | x = self.model(x) 49 | x = self.fc(x) 50 | x = x.view(-1, 1000) 51 | return x 52 | 53 | 54 | -------------------------------------------------------------------------------- /utils/net_utils.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import os 3 | import pathlib 4 | import shutil 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def save_checkpoint(state, is_best, filename="checkpoint.pth", save=False): 11 | filename = pathlib.Path(filename) 12 | 13 | if not filename.parent.exists(): 14 | os.makedirs(filename.parent) 15 | 16 | 
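    # The checkpoint is always written to disk first; when is_best it is additionally copied to
    # model_best.pth, and unless save=True the per-epoch file itself is removed again afterwards.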
torch.save(state, filename) 17 | 18 | if is_best: 19 | shutil.copyfile(filename, str(filename.parent / "model_best.pth")) 20 | 21 | if not save: 22 | os.remove(filename) 23 | 24 | 25 | def get_lr(optimizer): 26 | return optimizer.param_groups[0]["lr"] 27 | 28 | 29 | def accumulate(model, f): 30 | acc = 0.0 31 | 32 | for child in model.children(): 33 | acc += accumulate(child, f) 34 | 35 | acc += f(model) 36 | 37 | return acc 38 | 39 | 40 | class LabelSmoothing(nn.Module): 41 | """ 42 | NLL loss with label smoothing. 43 | """ 44 | 45 | def __init__(self, smoothing=0.0): 46 | """ 47 | Constructor for the LabelSmoothing module. 48 | 49 | :param smoothing: label smoothing factor 50 | """ 51 | super(LabelSmoothing, self).__init__() 52 | self.confidence = 1.0 - smoothing 53 | self.smoothing = smoothing 54 | 55 | def forward(self, x, target): 56 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 57 | 58 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 59 | nll_loss = nll_loss.squeeze(1) 60 | smooth_loss = -logprobs.mean(dim=-1) 61 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 62 | return loss.mean() 63 | 64 | 65 | class MaskL1RegLoss(nn.Module): 66 | def __init__(self, temperature=1.0): 67 | super().__init__() 68 | self.temperature = temperature 69 | 70 | def forward(self, model): 71 | l1_accum = accumulate(model, self.l1_of_mask) 72 | 73 | return l1_accum 74 | 75 | def l1_of_mask(self, m): 76 | if hasattr(m, "mask"): 77 | return (self.temperature * m.mask).sigmoid().sum() 78 | else: 79 | return 0.0 80 | 81 | -------------------------------------------------------------------------------- /utils/schedulers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | __all__ = ["step_lr", "cosine_lr", "constant_lr", "efficientnet_lr", "get_policy"] 4 | 5 | 6 | def get_policy(name): 7 | if name is None: 8 | return constant_lr 9 | 10 | out_dict = { 11 | "constant_lr": constant_lr, 12 | "cosine_lr": cosine_lr, 13 | "efficientnet_lr": efficientnet_lr, 14 | "step_lr": step_lr, 15 | "multistep_lr": multistep_lr, 16 | } 17 | 18 | return out_dict[name] 19 | 20 | 21 | def assign_learning_rate(optimizer, new_lr): 22 | for param_group in optimizer.param_groups: 23 | param_group["lr"] = new_lr 24 | 25 | 26 | def constant_lr(optimizer, args, **kwargs): 27 | def _lr_adjuster(epoch, iteration): 28 | if epoch < args.warmup_length: 29 | lr = _warmup_lr(args.lr, args.warmup_length, epoch) 30 | else: 31 | lr = args.lr 32 | 33 | assign_learning_rate(optimizer, lr) 34 | 35 | return lr 36 | 37 | return _lr_adjuster 38 | 39 | 40 | def cosine_lr(optimizer, args, **kwargs): 41 | def _lr_adjuster(epoch, iteration): 42 | if epoch < args.warmup_length: 43 | lr = _warmup_lr(args.lr, args.warmup_length, epoch) 44 | else: 45 | e = epoch - args.warmup_length 46 | es = args.epochs - args.warmup_length 47 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * (args.lr - args.lr_min) + args.lr_min 48 | 49 | assign_learning_rate(optimizer, lr) 50 | 51 | return lr 52 | 53 | return _lr_adjuster 54 | 55 | 56 | def efficientnet_lr(optimizer, args, **kwargs): 57 | def _lr_adjuster(epoch, iteration): 58 | if epoch < args.warmup_length: 59 | lr = _warmup_lr(args.lr, args.warmup_length, epoch) 60 | else: 61 | lr = args.lr * (0.97 ** (epoch / 2.4)) 62 | 63 | assign_learning_rate(optimizer, lr) 64 | 65 | return lr 66 | 67 | return _lr_adjuster 68 | 69 | 70 | def step_lr(optimizer, args, **kwargs): 71 | """Sets the learning rate to the initial LR decayed by 
10 every 30 epochs"""
72 | 
73 |     def _lr_adjuster(epoch, iteration):
74 |         lr = args.lr * (args.lr_gamma ** (epoch // args.lr_adjust))
75 | 
76 |         assign_learning_rate(optimizer, lr)
77 | 
78 |         return lr
79 | 
80 |     return _lr_adjuster
81 | 
82 | ############# multistep scheduler with user-specified milestones
83 | def multistep_lr(optimizer, args, **kwargs):
84 |     lr = args.lr
85 |     def _lr_adjuster(epoch, iteration):
86 |         nonlocal lr
87 |         if epoch < args.warmup_length:
88 |             lr = _warmup_lr(args.lr, args.warmup_length, epoch)
89 |         elif epoch in args.lr_milestones:
90 |             lr = args.lr_gamma * lr
91 | 
92 |         assign_learning_rate(optimizer, lr)
93 | 
94 |         return lr
95 | 
96 |     return _lr_adjuster
97 | ##############################
98 | 
99 | 
100 | 
101 | def _warmup_lr(base_lr, warmup_length, epoch):
102 |     return base_lr * (epoch + 1) / warmup_length
103 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Unified Framework for Soft Threshold Pruning
2 | This directory contains the code for reproducing this paper. The code is adapted from the open-source code of [STR](https://github.com/RAIVNLab/STR).
3 | 
4 | ## Dependency
5 | 
6 | The major dependencies of this code are listed below. The full list is given in `requirements.txt`.
7 | 
8 | ```
9 | # Name          Version
10 | cudatoolkit     10.2.89
11 | cudnn           8.2.1.32
12 | numpy           1.21.4
13 | python          3.7.11
14 | pytorch         1.10.0
15 | tensorboard     2.7.0
16 | torchvision     0.11.1
17 | pyyaml          6.0
18 | ```
19 | 
20 | ## Environment
21 | 
22 | Running the code requires an NVIDIA GPU; it has been tested with *CUDA 10.2* on *Ubuntu 16.04*. The hardware platform used in our experiments is listed below.
23 | 
24 | - GPU: Tesla V100
25 | - CPU: Intel(R) Xeon(R) Platinum 8168 CPU @ 2.70GHz
26 | 
27 | Each trial requires 8 GPUs.
28 | 
29 | ## Usage
30 | 
31 | **Note:** You should give each experiment a distinct name via `--name`; otherwise it becomes tedious to locate the results of a particular trial. The final threshold settings are given in **Appendix I** of the paper.
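In each command below, supply your own values: the ImageNet directory after `--data`, the final threshold (see Appendix I) after `--flat-width`, and an experiment name after `--name`. As an illustrative template only (all bracketed values are placeholders, not settings from the paper), an S-LATS run on ResNet-50 has the following shape:

```shell
python main.py --multigpu 0,1,2,3,4,5,6,7 \
    --config configs/reparam/resnet50-prune.yaml \
    --gradual sinp --flat-width <final threshold from Appendix I> \
    --print-freq 4096 \
    --data <path to ImageNet> \
    --name <experiment name>
```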
32 | 33 | #### Dense training on ResNet-50: 34 | 35 | ```shell 36 | python main.py --multigpu 0,1,2,3,4,5,6,7 --config configs/reparam/resnet50-dense.yaml --print-freq 4096 --data 37 | ``` 38 | 39 | #### Dense training on MobileNet-V1: 40 | 41 | ```shell 42 | python main.py --multigpu 0,1,2,3,4,5,6,7 --config configs/reparam/mobilenetv1-dense.yaml --print-freq 4096 --data 43 | ``` 44 | 45 | #### S-LATS on ResNet-50: 46 | 47 | ```shell 48 | python main.py --multigpu 0,1,2,3,4,5,6,7 --config configs/reparam/resnet50-prune.yaml --gradual sinp --flat-width --print-freq 4096 --data --name 49 | ``` 50 | 51 | #### S-LATS on ResNet-50 (1024 batch size): 52 | 53 | ```shell 54 | python main.py --multigpu 0,1,2,3,4,5,6,7 --config configs/reparam/resnet50-prune.yaml --gradual sinp --flat-width --batch-size 1024 --lr 0.512 --print-freq 4096 --data --name 55 | ``` 56 | 57 | #### PGH on ResNet-50: 58 | 59 | ```shell 60 | python main.py --multigpu 0,1,2,3,4,5,6,7 --config configs/reparam/resnet50-prune.yaml --gradual sinppgh --flat-width --print-freq 4096 --data --name 61 | ``` 62 | 63 | #### LATS on ResNet-50: 64 | 65 | ```shell 66 | python main.py --multigpu 0,1,2,3,4,5,6,7 --config configs/reparam/resnet50-prune.yaml --gradual sinp --flat-width --print-freq 4096 --data --low-freq --name 67 | ``` 68 | 69 | #### S-LATS on MobileNet-V1: 70 | 71 | ```shell 72 | python main.py --multigpu 0,1,2,3,4,5,6,7 --config configs/reparam/mobilenetv1-prune.yaml --gradual sinp --flat-width --print-freq 4096 --data --name 73 | ``` 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | # PyCharm 148 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 150 | # and can be added to the global gitignore or merged into this file. For a more nuclear 151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 152 | #.idea/ 153 | -------------------------------------------------------------------------------- /utils/logging.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import tqdm 3 | 4 | from torch.utils.tensorboard import SummaryWriter 5 | 6 | 7 | class ProgressMeter(object): 8 | def __init__(self, num_batches, meters, prefix=""): 9 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 10 | self.meters = meters 11 | self.prefix = prefix 12 | 13 | def display(self, batch, tqdm_writer=True): 14 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 15 | entries += [str(meter) for meter in self.meters] 16 | if not tqdm_writer: 17 | print("\t".join(entries)) 18 | else: 19 | tqdm.tqdm.write("\t".join(entries)) 20 | 21 | def write_to_tensorboard( 22 | self, writer: SummaryWriter, prefix="train", global_step=None 23 | ): 24 | for meter in self.meters: 25 | avg = meter.avg 26 | val = meter.val 27 | if meter.write_val: 28 | writer.add_scalar( 29 | f"{prefix}/{meter.name}_val", val, global_step=global_step 30 | ) 31 | 32 | if meter.write_avg: 33 | writer.add_scalar( 34 | f"{prefix}/{meter.name}_avg", avg, global_step=global_step 35 | ) 36 | 37 | def _get_batch_fmtstr(self, num_batches): 38 | num_digits = len(str(num_batches // 1)) 39 | fmt = "{:" + str(num_digits) + "d}" 40 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 41 | 42 | 43 | class Meter(object): 44 | @abc.abstractmethod 45 | def __init__(self, name, fmt=":f"): 46 | pass 47 | 48 | @abc.abstractmethod 49 | def reset(self): 50 | pass 51 | 52 | @abc.abstractmethod 53 | def update(self, val, n=1): 54 | pass 55 | 56 | @abc.abstractmethod 57 | def __str__(self): 58 | pass 59 | 60 | 61 | class AverageMeter(Meter): 62 | """ Computes and 
stores the average and current value """ 63 | 64 | def __init__(self, name, fmt=":f", write_val=True, write_avg=True): 65 | self.name = name 66 | self.fmt = fmt 67 | self.reset() 68 | 69 | self.write_val = write_val 70 | self.write_avg = write_avg 71 | 72 | def reset(self): 73 | self.val = 0 74 | self.avg = 0 75 | self.sum = 0 76 | self.count = 0 77 | 78 | def update(self, val, n=1): 79 | self.val = val 80 | self.sum += val * n 81 | self.count += n 82 | self.avg = self.sum / self.count 83 | 84 | def __str__(self): 85 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 86 | return fmtstr.format(**self.__dict__) 87 | 88 | 89 | class VarianceMeter(Meter): 90 | def __init__(self, name, fmt=":f", write_val=False): 91 | self.name = name 92 | self._ex_sq = AverageMeter(name="_subvariance_1", fmt=":.02f") 93 | self._sq_ex = AverageMeter(name="_subvariance_2", fmt=":.02f") 94 | self.fmt = fmt 95 | self.reset() 96 | self.write_val = False 97 | self.write_avg = True 98 | 99 | @property 100 | def val(self): 101 | return self._ex_sq.val - self._sq_ex.val ** 2 102 | 103 | @property 104 | def avg(self): 105 | return self._ex_sq.avg - self._sq_ex.avg ** 2 106 | 107 | def reset(self): 108 | self._ex_sq.reset() 109 | self._sq_ex.reset() 110 | 111 | def update(self, val, n=1): 112 | self._ex_sq.update(val ** 2, n=n) 113 | self._sq_ex.update(val, n=n) 114 | 115 | def __str__(self): 116 | return ("{name} (var {avg" + self.fmt + "})").format( 117 | name=self.name, avg=self.avg 118 | ) 119 | -------------------------------------------------------------------------------- /data/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torchvision import datasets, transforms 5 | 6 | import torch.multiprocessing 7 | import h5py 8 | import os 9 | import numpy as np 10 | torch.multiprocessing.set_sharing_strategy("file_system") 11 | 12 | class ImageNet: 13 | def __init__(self, args): 14 | super(ImageNet, self).__init__() 15 | 16 | data_root = args.data 17 | 18 | use_cuda = torch.cuda.is_available() 19 | 20 | # Data loading code 21 | kwargs = {"num_workers": args.workers, "pin_memory": True} if use_cuda else {} 22 | 23 | # Data loading code 24 | traindir = os.path.join(data_root, "train") 25 | valdir = os.path.join(data_root, "val") 26 | 27 | normalize = transforms.Normalize( 28 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 29 | ) 30 | 31 | train_dataset = datasets.ImageFolder( 32 | traindir, 33 | transforms.Compose( 34 | [ 35 | transforms.RandomResizedCrop(224), 36 | transforms.RandomHorizontalFlip(), 37 | transforms.ToTensor(), 38 | normalize, 39 | ] 40 | ), 41 | ) 42 | 43 | self.train_loader = torch.utils.data.DataLoader( 44 | train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs 45 | ) 46 | 47 | self.val_loader = torch.utils.data.DataLoader( 48 | datasets.ImageFolder( 49 | valdir, 50 | transforms.Compose( 51 | [ 52 | transforms.Resize(256), 53 | transforms.CenterCrop(224), 54 | transforms.ToTensor(), 55 | normalize, 56 | ] 57 | ), 58 | ), 59 | batch_size=args.batch_size, 60 | shuffle=False, 61 | **kwargs 62 | ) 63 | 64 | class TinyImageNet: 65 | def __init__(self, args): 66 | super(TinyImageNet, self).__init__() 67 | 68 | data_root = os.path.join(args.data, "tiny_imagenet") 69 | 70 | use_cuda = torch.cuda.is_available() 71 | kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {} 72 | 73 | normalize = transforms.Normalize( 74 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 75 | ) 
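        # Samples are read from the HDF5 files as raw uint8 numpy arrays, so the training
        # pipeline converts them to PIL images first before applying the flip augmentation.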
76 | 77 | train_transforms = transforms.Compose([ 78 | transforms.ToPILImage(), 79 | transforms.RandomHorizontalFlip(), 80 | transforms.ToTensor(), 81 | normalize, 82 | ]) 83 | 84 | test_transforms = transforms.Compose([ 85 | transforms.ToTensor(), 86 | normalize, 87 | ]) 88 | 89 | train_dataset = H5DatasetOld(data_root + '/train.h5', transform=train_transforms) 90 | test_dataset = H5DatasetOld(data_root + '/val.h5', transform=test_transforms) 91 | self.train_loader = torch.utils.data.DataLoader( 92 | train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs 93 | ) 94 | self.val_loader = torch.utils.data.DataLoader( 95 | test_dataset, batch_size=args.batch_size, shuffle=False, **kwargs 96 | ) 97 | 98 | class H5Dataset(torch.utils.data.Dataset): 99 | def __init__(self, h5_file, transform=None): 100 | self.transform = transform 101 | self.dataFile = None 102 | self.h5_file = h5_file 103 | 104 | def __len__(self): 105 | datasetNames = list(self.dataFile.keys()) 106 | return len(self.dataFile[datasetNames[0]]) 107 | 108 | 109 | def __getitem__(self, idx): 110 | if self.dataFile is None: 111 | self.dataFile = h5py.File(self.h5_file, 'r') 112 | data = self.dataFile[list(self.dataFile.keys())[0]][idx] 113 | label = self.dataFile[list(self.dataFile.keys())[1]][idx] 114 | if self.transform: 115 | data = self.transform(data) 116 | return (data, label) 117 | 118 | class H5DatasetOld(torch.utils.data.Dataset): 119 | def __init__(self, h5_file, transform=None): 120 | self.transform = transform 121 | self.dataFile = h5py.File(h5_file, 'r') 122 | # self.h5_file = h5_file 123 | 124 | def __len__(self): 125 | datasetNames = list(self.dataFile.keys()) 126 | return len(self.dataFile[datasetNames[0]]) 127 | 128 | 129 | def __getitem__(self, idx): 130 | # if self.dataFile is None: 131 | # self.dataFile = h5py.File(self.h5_file, 'r') 132 | data = self.dataFile[list(self.dataFile.keys())[0]][idx] 133 | label = self.dataFile[list(self.dataFile.keys())[1]][idx] 134 | if self.transform: 135 | data = self.transform(data) 136 | return (data, label) -------------------------------------------------------------------------------- /utils/conv_type.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.autograd as autograd 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from args import args as parser_args 7 | 8 | DenseConv = nn.Conv2d 9 | 10 | def sparseFunction(x, s, activation=torch.relu, f=torch.sigmoid): 11 | return torch.sign(x)*activation(torch.abs(x)-f(s)) 12 | 13 | def initialize_sInit(): 14 | 15 | if parser_args.sInit_type == "constant": 16 | return parser_args.sInit_value*torch.ones([1, 1]) 17 | 18 | class PseudoRelu(torch.autograd.Function): 19 | @staticmethod 20 | def forward(ctx, x): 21 | return torch.relu(x) 22 | 23 | @staticmethod 24 | def backward(ctx, grad_output): 25 | grad_x = None 26 | if ctx.needs_input_grad[0]: 27 | grad_x = grad_output 28 | return grad_x, None 29 | 30 | pseudoRelu = PseudoRelu.apply 31 | 32 | def softThresholdinv(x, s): 33 | return torch.sign(x) * (torch.abs(x) + s) 34 | 35 | def softThresholdmod(x, s): 36 | return torch.sign(x) * pseudoRelu(torch.abs(x)-s) 37 | 38 | class STRConv(nn.Conv2d): 39 | def __init__(self, *args, **kwargs): 40 | super().__init__(*args, **kwargs) 41 | 42 | #self.activation = pseudoRelu 43 | with torch.no_grad(): 44 | if parser_args.sparse_function == 'identity': 45 | self.mapping = lambda x: x 46 | elif parser_args.sparse_function == 'stmod': 47 | if 
parser_args.gradual is None: 48 | self.mapping = lambda x: softThresholdmod(x, parser_args.flat_width) 49 | else: 50 | self.mapping = lambda x: x 51 | 52 | if parser_args.sparse_function == 'stmod' and parser_args.gradual is None: 53 | self.weight.data = softThresholdinv(self.weight.data, parser_args.flat_width) 54 | 55 | def forward(self, x): 56 | # In case STR is not training for the hyperparameters given in the paper, change sparseWeight to self.sparseWeight if it is a problem of backprop. 57 | # However, that should not be the case according to graph computation. 58 | 59 | sparseWeight = self.mapping(self.weight) 60 | x = F.conv2d( 61 | x, sparseWeight, self.bias, self.stride, self.padding, self.dilation, self.groups 62 | ) 63 | return x 64 | 65 | def getSparsity(self): #, f=torch.sigmoid): 66 | #sparseWeight = sparseFunction(self.weight, self.sparseThreshold, self.activation, self.f) 67 | sparseWeight = self.mapping(self.weight) 68 | temp = sparseWeight.detach().cpu() 69 | return (temp == 0).sum(), temp.numel()#, f(self.sparseThreshold).item() 70 | 71 | @torch.no_grad() 72 | def getSparseWeight(self): 73 | return self.mapping(self.weight) 74 | #return sparseFunction(self.weight, self.sparseThreshold, self.activation, self.f) 75 | 76 | @torch.no_grad() 77 | def setFlatWidth(self, width): 78 | if parser_args.sparse_function == 'stmod': 79 | self.mapping = lambda x: softThresholdmod(x, width) 80 | 81 | class ChooseEdges(autograd.Function): 82 | @staticmethod 83 | def forward(ctx, weight, prune_rate): 84 | output = weight.clone() 85 | _, idx = weight.flatten().abs().sort() 86 | p = int(prune_rate * weight.numel()) 87 | # flat_oup and output access the same memory. 88 | flat_oup = output.flatten() 89 | flat_oup[idx[:p]] = 0 90 | return output 91 | 92 | @staticmethod 93 | def backward(ctx, grad_output): 94 | return grad_output, None 95 | 96 | class DNWConv(nn.Conv2d): 97 | def __init__(self, *args, **kwargs): 98 | super().__init__(*args, **kwargs) 99 | 100 | def set_prune_rate(self, prune_rate): 101 | self.prune_rate = prune_rate 102 | print(f"=> Setting prune rate to {prune_rate}") 103 | 104 | def forward(self, x): 105 | w = ChooseEdges.apply(self.weight, self.prune_rate) 106 | 107 | x = F.conv2d( 108 | x, w, self.bias, self.stride, self.padding, self.dilation, self.groups 109 | ) 110 | 111 | return x 112 | 113 | def GMPChooseEdges(weight, prune_rate): 114 | output = weight.clone() 115 | _, idx = weight.flatten().abs().sort() 116 | p = int(prune_rate * weight.numel()) 117 | # flat_oup and output access the same memory. 
118 | flat_oup = output.flatten() 119 | flat_oup[idx[:p]] = 0 120 | return output 121 | 122 | class GMPConv(nn.Conv2d): 123 | def __init__(self, *args, **kwargs): 124 | super().__init__(*args, **kwargs) 125 | 126 | def set_prune_rate(self, prune_rate): 127 | self.prune_rate = prune_rate 128 | self.curr_prune_rate = 0.0 129 | print(f"=> Setting prune rate to {prune_rate}") 130 | 131 | def set_curr_prune_rate(self, curr_prune_rate): 132 | self.curr_prune_rate = curr_prune_rate 133 | 134 | def forward(self, x): 135 | w = GMPChooseEdges(self.weight, self.curr_prune_rate) 136 | x = F.conv2d( 137 | x, w, self.bias, self.stride, self.padding, self.dilation, self.groups 138 | ) 139 | 140 | return x 141 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from utils.builder import get_builder 4 | from args import args 5 | 6 | # BasicBlock {{{ 7 | class BasicBlock(nn.Module): 8 | M = 2 9 | expansion = 1 10 | 11 | def __init__(self, builder, inplanes, planes, stride=1, downsample=None): 12 | super(BasicBlock, self).__init__() 13 | self.conv1 = builder.conv3x3(inplanes, planes, stride) 14 | self.bn1 = builder.batchnorm(planes) 15 | self.relu = builder.activation() 16 | self.conv2 = builder.conv3x3(planes, planes) 17 | self.bn2 = builder.batchnorm(planes, last_bn=True) 18 | self.downsample = downsample 19 | self.stride = stride 20 | 21 | def forward(self, x): 22 | residual = x 23 | 24 | out = self.conv1(x) 25 | if self.bn1 is not None: 26 | out = self.bn1(out) 27 | 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | 32 | if self.bn2 is not None: 33 | out = self.bn2(out) 34 | 35 | if self.downsample is not None: 36 | residual = self.downsample(x) 37 | 38 | out += residual 39 | out = self.relu(out) 40 | 41 | return out 42 | 43 | 44 | # BasicBlock }}} 45 | 46 | # Bottleneck {{{ 47 | class Bottleneck(nn.Module): 48 | M = 3 49 | expansion = 4 50 | 51 | def __init__(self, builder, inplanes, planes, stride=1, downsample=None): 52 | super(Bottleneck, self).__init__() 53 | self.conv1 = builder.conv1x1(inplanes, planes) 54 | self.bn1 = builder.batchnorm(planes) 55 | self.conv2 = builder.conv3x3(planes, planes, stride=stride) 56 | self.bn2 = builder.batchnorm(planes) 57 | self.conv3 = builder.conv1x1(planes, planes * self.expansion) 58 | self.bn3 = builder.batchnorm(planes * self.expansion, last_bn=True) 59 | self.relu = builder.activation() 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | residual = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv3(out) 75 | out = self.bn3(out) 76 | 77 | if self.downsample is not None: 78 | residual = self.downsample(x) 79 | 80 | out += residual 81 | 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | 87 | # Bottleneck }}} 88 | 89 | # ResNet {{{ 90 | class ResNet(nn.Module): 91 | def __init__(self, builder, block, layers, num_classes=1000): 92 | self.inplanes = 64 93 | super(ResNet, self).__init__() 94 | if args.first_layer_dense: 95 | print("FIRST LAYER DENSE!!!!") 96 | self.conv1 = nn.Conv2d( 97 | 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False 98 | ) 99 | else: 100 | self.conv1 = builder.conv7x7(3, 64, stride=2, first_layer=True) 101 | 102 | self.bn1 = builder.batchnorm(64) 103 | self.relu = 
builder.activation() 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 105 | self.layer1 = self._make_layer(builder, block, 64, layers[0]) 106 | self.layer2 = self._make_layer(builder, block, 128, layers[1], stride=2) 107 | self.layer3 = self._make_layer(builder, block, 256, layers[2], stride=2) 108 | self.layer4 = self._make_layer(builder, block, 512, layers[3], stride=2) 109 | self.avgpool = nn.AdaptiveAvgPool2d(1) 110 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 111 | if args.last_layer_dense: 112 | self.fc = nn.Conv2d(512 * block.expansion, num_classes, 1) 113 | else: 114 | self.fc = builder.conv1x1(512 * block.expansion, num_classes) 115 | 116 | def _make_layer(self, builder, block, planes, blocks, stride=1): 117 | downsample = None 118 | if stride != 1 or self.inplanes != planes * block.expansion: 119 | dconv = builder.conv1x1( 120 | self.inplanes, planes * block.expansion, stride=stride 121 | ) 122 | dbn = builder.batchnorm(planes * block.expansion) 123 | if dbn is not None: 124 | downsample = nn.Sequential(dconv, dbn) 125 | else: 126 | downsample = dconv 127 | 128 | layers = [] 129 | layers.append(block(builder, self.inplanes, planes, stride, downsample)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(builder, self.inplanes, planes)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | if self.bn1 is not None: 139 | x = self.bn1(x) 140 | x = self.relu(x) 141 | x = self.maxpool(x) 142 | 143 | x = self.layer1(x) 144 | x = self.layer2(x) 145 | x = self.layer3(x) 146 | x = self.layer4(x) 147 | 148 | x = self.avgpool(x) 149 | x = self.fc(x) 150 | x = x.view(x.size(0), -1) 151 | 152 | return x 153 | 154 | 155 | # ResNet }}} 156 | def ResNet18(pretrained=False): 157 | # TODO: pretrained 158 | return ResNet(get_builder(), BasicBlock, [2, 2, 2, 2], 1000) 159 | 160 | 161 | def ResNet50(pretrained=False): 162 | # TODO: pretrained 163 | return ResNet(get_builder(), Bottleneck, [3, 4, 6, 3], 1000) 164 | -------------------------------------------------------------------------------- /utils/builder.py: -------------------------------------------------------------------------------- 1 | from args import args 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | import utils.conv_type 8 | import utils.bn_type 9 | 10 | 11 | class Builder(object): 12 | def __init__(self, conv_layer, bn_layer, first_layer=None): 13 | self.conv_layer = conv_layer 14 | self.bn_layer = bn_layer 15 | self.first_layer = first_layer or conv_layer 16 | 17 | def conv(self, kernel_size, in_planes, out_planes, stride=1, first_layer=False): 18 | conv_layer = self.first_layer if first_layer else self.conv_layer 19 | 20 | if first_layer: 21 | print(f"==> Building first layer with {args.first_layer_type}") 22 | 23 | if kernel_size == 3: 24 | conv = conv_layer( 25 | in_planes, 26 | out_planes, 27 | kernel_size=3, 28 | stride=stride, 29 | padding=1, 30 | bias=False, 31 | ) 32 | elif kernel_size == 1: 33 | conv = conv_layer( 34 | in_planes, out_planes, kernel_size=1, stride=stride, bias=False 35 | ) 36 | elif kernel_size == 5: 37 | conv = conv_layer( 38 | in_planes, 39 | out_planes, 40 | kernel_size=5, 41 | stride=stride, 42 | padding=2, 43 | bias=False, 44 | ) 45 | elif kernel_size == 7: 46 | conv = conv_layer( 47 | in_planes, 48 | out_planes, 49 | kernel_size=7, 50 | stride=stride, 51 | padding=3, 52 | bias=False, 53 | ) 54 | else: 55 | return None 56 | 57 | 
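        # every conv built above is initialized in place according to args.init / args.mode / args.nonlinearity (see _init_conv below)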
self._init_conv(conv) 58 | 59 | return conv 60 | 61 | def conv2d( 62 | self, 63 | in_channels, 64 | out_channels, 65 | kernel_size, 66 | stride=1, 67 | padding=0, 68 | dilation=1, 69 | groups=1, 70 | bias=True, 71 | padding_mode="zeros", 72 | ): 73 | return self.conv_layer( 74 | in_channels, 75 | out_channels, 76 | kernel_size, 77 | stride, 78 | padding, 79 | dilation, 80 | groups, 81 | bias, 82 | padding_mode, 83 | ) 84 | 85 | def conv3x3(self, in_planes, out_planes, stride=1, first_layer=False): 86 | """3x3 convolution with padding""" 87 | c = self.conv(3, in_planes, out_planes, stride=stride, first_layer=first_layer) 88 | return c 89 | 90 | def conv1x1(self, in_planes, out_planes, stride=1, first_layer=False): 91 | """1x1 convolution with padding""" 92 | c = self.conv(1, in_planes, out_planes, stride=stride, first_layer=first_layer) 93 | return c 94 | 95 | def conv7x7(self, in_planes, out_planes, stride=1, first_layer=False): 96 | """7x7 convolution with padding""" 97 | c = self.conv(7, in_planes, out_planes, stride=stride, first_layer=first_layer) 98 | return c 99 | 100 | def conv5x5(self, in_planes, out_planes, stride=1, first_layer=False): 101 | """5x5 convolution with padding""" 102 | c = self.conv(5, in_planes, out_planes, stride=stride, first_layer=first_layer) 103 | return c 104 | 105 | def batchnorm(self, planes, last_bn=False, first_layer=False): 106 | return self.bn_layer(planes) 107 | 108 | def activation(self): 109 | if args.nonlinearity == "relu": 110 | return (lambda: nn.ReLU(inplace=True))() 111 | else: 112 | raise ValueError(f"{args.nonlinearity} is not an initialization option!") 113 | 114 | def _init_conv(self, conv): 115 | if args.init == "signed_constant": 116 | 117 | fan = nn.init._calculate_correct_fan(conv.weight, args.mode) 118 | if args.scale_fan: 119 | fan = fan * (1 - args.prune_rate) 120 | gain = nn.init.calculate_gain(args.nonlinearity) 121 | std = gain / math.sqrt(fan) 122 | conv.weight.data = conv.weight.data.sign() * std 123 | 124 | elif args.init == "unsigned_constant": 125 | 126 | fan = nn.init._calculate_correct_fan(conv.weight, args.mode) 127 | if args.scale_fan: 128 | fan = fan * (1 - args.prune_rate) 129 | 130 | gain = nn.init.calculate_gain(args.nonlinearity) 131 | std = gain / math.sqrt(fan) 132 | conv.weight.data = torch.ones_like(conv.weight.data) * std 133 | 134 | elif args.init == "kaiming_normal": 135 | 136 | if args.scale_fan: 137 | fan = nn.init._calculate_correct_fan(conv.weight, args.mode) 138 | fan = fan * (1 - args.prune_rate) 139 | gain = nn.init.calculate_gain(args.nonlinearity) 140 | std = gain / math.sqrt(fan) 141 | with torch.no_grad(): 142 | conv.weight.data.normal_(0, std) 143 | else: 144 | nn.init.kaiming_normal_( 145 | conv.weight, mode=args.mode, nonlinearity=args.nonlinearity 146 | ) 147 | 148 | elif args.init == "standard": 149 | nn.init.kaiming_uniform_(conv.weight, a=math.sqrt(5)) 150 | else: 151 | raise ValueError(f"{args.init} is not an initialization option!") 152 | 153 | 154 | def get_builder(): 155 | 156 | print("==> Conv Type: {}".format(args.conv_type)) 157 | print("==> BN Type: {}".format(args.bn_type)) 158 | 159 | conv_layer = getattr(utils.conv_type, args.conv_type) 160 | bn_layer = getattr(utils.bn_type, args.bn_type) 161 | 162 | if args.first_layer_type is not None: 163 | first_layer = getattr(utils.conv_type, args.first_layer_type) 164 | print(f"==> First Layer Type {args.first_layer_type}") 165 | else: 166 | first_layer = None 167 | 168 | builder = Builder(conv_layer=conv_layer, bn_layer=bn_layer, 
first_layer=first_layer) 169 | 170 | return builder 171 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import tqdm 4 | 5 | from torch.cuda import amp 6 | from utils.eval_utils import accuracy 7 | from utils.logging import AverageMeter, ProgressMeter 8 | import math 9 | 10 | def sinpInc(n, N): 11 | # S-LATS 12 | return math.sin(math.pi * float(n) / N) / math.pi + float(n) / N 13 | 14 | def sinppghInc(n, N, beta): 15 | # PGH 16 | x = float(n) / N 17 | lbeta = math.log(beta) 18 | beta_powerx = beta ** x 19 | return ((math.pi ** 2) * (beta_powerx - 1) + (beta_powerx - 2) * (lbeta ** 2) + beta_powerx * lbeta * (lbeta * math.cos(math.pi * x) + math.pi * math.sin(math.pi * x))) / ((math.pi ** 2) * (beta - 1) - 2 * (lbeta ** 2)) 20 | 21 | def sinpLowfBaseInc(n, N): 22 | # LATS 23 | return 0.5 * (1.0 + 2 * n + (math.sin(math.pi * (n - 0.5) / N) / math.sin(math.pi * 0.5 / N))) / (N + 1.0) 24 | 25 | __all__ = ["train", "validate"] 26 | 27 | 28 | def train(train_loader, model, criterion, optimizer, epoch, args, writer, scaler=None): 29 | batch_time = AverageMeter("Time", ":6.3f") 30 | data_time = AverageMeter("Data", ":6.3f") 31 | losses = AverageMeter("Loss", ":.3f") 32 | top1 = AverageMeter("Acc@1", ":6.2f") 33 | top5 = AverageMeter("Acc@5", ":6.2f") 34 | progress = ProgressMeter( 35 | len(train_loader), 36 | [batch_time, data_time, losses, top1, top5], 37 | prefix=f"Epoch: [{epoch}]", 38 | ) 39 | 40 | # switch to train mode 41 | model.train() 42 | 43 | batch_size = train_loader.batch_size 44 | num_batches = len(train_loader) 45 | 46 | if epoch >= args.pruning_start_epoch: 47 | step = epoch * num_batches 48 | total_step = num_batches * args.epochs 49 | begin_step = args.pruning_start_epoch * num_batches 50 | 51 | if args.low_freq and args.gradual == 'sinp': 52 | base_threshold = sinpLowfBaseInc(epoch - args.pruning_start_epoch, args.epochs - args.pruning_start_epoch) 53 | flat_step = (1 + math.cos((epoch - args.pruning_start_epoch) * math.pi / (args.epochs - args.pruning_start_epoch))) / num_batches / (1 + args.epochs - args.pruning_start_epoch) 54 | b_step = 0 55 | 56 | 57 | end = time.time() 58 | for i, (images, target) in enumerate(train_loader): 59 | # measure data loading time 60 | data_time.update(time.time() - end) 61 | 62 | if args.gpu is not None: 63 | images = images.cuda(args.gpu, non_blocking=True) 64 | 65 | target = target.cuda(args.gpu, non_blocking=True).long() 66 | 67 | # compute output 68 | if scaler is not None: 69 | with amp.autocast(): 70 | output = model(images) 71 | loss = criterion(output, target.view(-1)) 72 | else: 73 | output = model(images) 74 | loss = criterion(output, target.view(-1)) 75 | 76 | # measure accuracy and record loss 77 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 78 | losses.update(loss.item(), images.size(0)) 79 | top1.update(acc1.item(), images.size(0)) 80 | top5.update(acc5.item(), images.size(0)) 81 | 82 | # compute gradient and do SGD step 83 | optimizer.zero_grad() 84 | if scaler is not None: 85 | scaler.scale(loss).backward() 86 | scaler.step(optimizer) 87 | scaler.update() 88 | else: 89 | loss.backward() 90 | optimizer.step() 91 | 92 | # update threshold 93 | if epoch >= args.pruning_start_epoch: 94 | step = step + 1 95 | if args.gradual is not None: 96 | if args.low_freq: 97 | b_step = b_step + 1 98 | flat_width = (base_threshold + b_step * flat_step) * args.flat_width 99 | # 
writer.add_scalar("threshold", flat_width, step) 100 | for module in model.modules(): 101 | if hasattr(module, 'setFlatWidth'): 102 | module.setFlatWidth(flat_width) 103 | else: 104 | if args.gradual == 'sinp': 105 | flat_width = sinpInc(step - begin_step, total_step - begin_step) 106 | elif args.gradual == 'sinppgh': 107 | flat_width = sinppghInc(step - begin_step, total_step - begin_step, args.beta) 108 | else: 109 | raise NotImplementedError 110 | 111 | normal_flat_width = flat_width * args.flat_width 112 | # writer.add_scalar("threshold", flat_width, step) 113 | for module in model.modules(): 114 | if hasattr(module, 'setFlatWidth'): 115 | module.setFlatWidth(normal_flat_width) 116 | 117 | 118 | # measure elapsed time 119 | batch_time.update(time.time() - end) 120 | end = time.time() 121 | 122 | if i % args.print_freq == 0: 123 | t = (num_batches * epoch + i) * batch_size 124 | progress.display(i) 125 | progress.write_to_tensorboard(writer, prefix="train", global_step=t) 126 | 127 | return top1.avg, top5.avg 128 | 129 | 130 | def validate(val_loader, model, criterion, args, writer, epoch): 131 | batch_time = AverageMeter("Time", ":6.3f", write_val=False) 132 | losses = AverageMeter("Loss", ":.3f", write_val=False) 133 | top1 = AverageMeter("Acc@1", ":6.2f", write_val=False) 134 | top5 = AverageMeter("Acc@5", ":6.2f", write_val=False) 135 | progress = ProgressMeter( 136 | len(val_loader), [batch_time, losses, top1, top5], prefix="Test: " 137 | ) 138 | 139 | # switch to evaluate mode 140 | model.eval() 141 | 142 | with torch.no_grad(): 143 | end = time.time() 144 | 145 | for i, (images, target) in enumerate(val_loader): 146 | if args.gpu is not None: 147 | images = images.cuda(args.gpu, non_blocking=True) 148 | 149 | target = target.cuda(args.gpu, non_blocking=True).long() 150 | 151 | # compute output 152 | output = model(images) 153 | 154 | loss = criterion(output, target.view(-1)) 155 | 156 | # measure accuracy and record loss 157 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 158 | losses.update(loss.item(), images.size(0)) 159 | top1.update(acc1.item(), images.size(0)) 160 | top5.update(acc5.item(), images.size(0)) 161 | 162 | # measure elapsed time 163 | batch_time.update(time.time() - end) 164 | end = time.time() 165 | 166 | if i % args.print_freq == 0: 167 | progress.display(i) 168 | 169 | progress.display(len(val_loader)) 170 | 171 | if writer is not None: 172 | progress.write_to_tensorboard(writer, prefix="test", global_step=epoch) 173 | 174 | return top1.avg, top5.avg 175 | 176 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import yaml 4 | 5 | from configs import parser as _parser 6 | 7 | args = None 8 | 9 | 10 | def parse_arguments(): 11 | parser = argparse.ArgumentParser(description="PyTorch ImageNet Training for STR, DNW and GMP") 12 | 13 | # General Config 14 | parser.add_argument( 15 | "--data", help="path to dataset base directory", default="" 16 | ) 17 | parser.add_argument("--optimizer", help="Which optimizer to use", default="sgd") 18 | parser.add_argument("--set", help="name of dataset", type=str, default="ImageNet") 19 | parser.add_argument( 20 | "-a", "--arch", metavar="ARCH", default="ResNet50", help="model architecture" 21 | ) 22 | parser.add_argument( 23 | "--config", help="Config file to use (see configs dir)", default=None 24 | ) 25 | parser.add_argument( 26 | "--log-dir", help="Where to save 
the runs. If None use ./runs", default=None 27 | ) 28 | parser.add_argument( 29 | "-j", 30 | "--workers", 31 | default=48, 32 | type=int, 33 | metavar="N", 34 | help="number of data loading workers (default: 20)", 35 | ) 36 | parser.add_argument( 37 | "--epochs", 38 | default=100, 39 | type=int, 40 | metavar="N", 41 | help="number of total epochs to run", 42 | ) 43 | parser.add_argument( 44 | "--start-epoch", 45 | default=None, 46 | type=int, 47 | metavar="N", 48 | help="manual epoch number (useful on restarts)", 49 | ) 50 | parser.add_argument( 51 | "-b", 52 | "--batch-size", 53 | default=256, 54 | type=int, 55 | metavar="N", 56 | help="mini-batch size (default: 256), this is the total " 57 | "batch size of all GPUs on the current node when " 58 | "using Data Parallel or Distributed Data Parallel", 59 | ) 60 | parser.add_argument( 61 | "--lr", 62 | "--learning-rate", 63 | default=0.256, 64 | type=float, 65 | metavar="LR", 66 | help="initial learning rate", 67 | dest="lr", 68 | ) 69 | parser.add_argument( 70 | "--warmup_length", default=5, type=int, help="Number of warmup iterations" 71 | ) 72 | parser.add_argument( 73 | "--init_prune_epoch", default=0, type=int, help="Init epoch for pruning in GMP" 74 | ) 75 | parser.add_argument( 76 | "--final_prune_epoch", default=-100, type=int, help="Final epoch for pruning in GMP" 77 | ) 78 | parser.add_argument( 79 | "--momentum", default=0.875, type=float, metavar="M", help="momentum" 80 | ) 81 | parser.add_argument( 82 | "--wd", 83 | "--weight-decay", 84 | default=1e-4, 85 | type=float, 86 | metavar="W", 87 | help="weight decay (default: 1e-4)", 88 | dest="weight_decay", 89 | ) 90 | parser.add_argument( 91 | "-p", 92 | "--print-freq", 93 | default=10, 94 | type=int, 95 | metavar="N", 96 | help="print frequency (default: 10)", 97 | ) 98 | parser.add_argument( 99 | "--num-classes", 100 | default=10, 101 | type=int, 102 | ) 103 | parser.add_argument( 104 | "--resume", 105 | default="", 106 | type=str, 107 | metavar="PATH", 108 | help="path to latest checkpoint (default: none)", 109 | ) 110 | parser.add_argument( 111 | "-e", 112 | "--evaluate", 113 | dest="evaluate", 114 | action="store_true", 115 | help="evaluate model on validation set", 116 | ) 117 | parser.add_argument( 118 | "--pretrained", 119 | type=str, 120 | default=None, 121 | ) 122 | parser.add_argument( 123 | "--seed", default=None, type=int, help="seed for initializing training. " 124 | ) 125 | parser.add_argument( 126 | "--multigpu", 127 | default=None, 128 | type=lambda x: [int(a) for a in x.split(",")], 129 | help="Which GPUs to use for multigpu training", 130 | ) 131 | 132 | # Learning Rate Policy Specific 133 | parser.add_argument( 134 | "--lr-policy", default="cosine_lr", help="Policy for the learning rate." 
135 | ) 136 | parser.add_argument( 137 | "--multistep-lr-adjust", default=30, type=int, help="Interval to drop lr" 138 | ) 139 | parser.add_argument( 140 | "--lr-gamma", default=0.1, type=int, help="Multistep multiplier" 141 | ) 142 | parser.add_argument( 143 | "--name", default=None, type=str, help="Experiment name to append to filepath" 144 | ) 145 | parser.add_argument( 146 | "--save_every", default=-1, type=int, help="Save every ___ epochs" 147 | ) 148 | parser.add_argument( 149 | "--prune-rate", 150 | default=0.0, 151 | help="Amount of pruning to do during sparse training", 152 | type=float, 153 | ) 154 | parser.add_argument( 155 | "--width-mult", 156 | default=1.0, 157 | help="How much to vary the width of the network.", 158 | type=float, 159 | ) 160 | parser.add_argument( 161 | "--nesterov", 162 | default=False, 163 | action="store_true", 164 | help="Whether or not to use nesterov for SGD", 165 | ) 166 | parser.add_argument( 167 | "--random-mask", 168 | action="store_true", 169 | help="Whether or not to use a random mask when fine tuning for lottery experiments", 170 | ) 171 | parser.add_argument( 172 | "--one-batch", 173 | action="store_true", 174 | help="One batch train set for debugging purposes (test overfitting)", 175 | ) 176 | parser.add_argument( 177 | "--conv-type", type=str, default="STRConv", help="What kind of sparsity to use" 178 | ) 179 | parser.add_argument( 180 | "--freeze-weights", 181 | action="store_true", 182 | help="Whether or not to train only mask (this freezes weights)", 183 | ) 184 | parser.add_argument("--mode", default="fan_in", help="Weight initialization mode") 185 | parser.add_argument( 186 | "--nonlinearity", default="relu", help="Nonlinearity used by initialization" 187 | ) 188 | parser.add_argument("--bn-type", default="LearnedBatchNorm", help="BatchNorm type") 189 | parser.add_argument( 190 | "--init", default="kaiming_normal", help="Weight initialization modifications" 191 | ) 192 | parser.add_argument( 193 | "--no-bn-decay", action="store_true", default=False, help="No batchnorm decay" 194 | ) 195 | parser.add_argument( 196 | "--dense-conv-model", action="store_true", default=False, help="Store a model variant of the given pretrained model that is compatible to CNNs with DenseConv (nn.Conv2d)" 197 | ) 198 | parser.add_argument( 199 | "--st-decay", type=float, default=None, help="decay for sparse thresh. If none then use normal weight decay." 
200 |     )
201 |     parser.add_argument(
202 |         "--scale-fan", action="store_true", default=False, help="scale fan"
203 |     )
204 |     parser.add_argument(
205 |         "--first-layer-dense", action="store_true", help="First layer dense or sparse"
206 |     )
207 |     parser.add_argument(
208 |         "--last-layer-dense", action="store_true", help="Last layer dense or sparse"
209 |     )
210 |     parser.add_argument(
211 |         "--label-smoothing",
212 |         type=float,
213 |         help="Label smoothing to use, default 0.1",
214 |         default=0.1,
215 |     )
216 |     parser.add_argument(
217 |         "--first-layer-type", type=str, default=None, help="Conv type of first layer"
218 |     )
219 | 
220 |     parser.add_argument(
221 |         "--sInit-type",
222 |         type=str,
223 |         help="type of sInit",
224 |         default="constant",
225 |     )
226 | 
227 |     parser.add_argument(
228 |         "--sInit-value",
229 |         type=float,
230 |         help="initial value for sInit",
231 |         default=100,
232 |     )
233 | 
234 |     # parser.add_argument(
235 |     #     "--sparse-function", type=str, default='sigmoid', help="choice of g(s)"
236 |     # )
237 | 
238 |     parser.add_argument(
239 |         "--sparse-function", type=str, choices=['identity', 'stmod'], default='identity', help="choice of reparameterization function")
240 | 
241 |     parser.add_argument(
242 |         "--lr-milestones", type=int, nargs='+', default=[30, 70, 90], help="list of epoch milestones when using multistep LR. Must be increasing."
243 |     )
244 | 
245 |     parser.add_argument(
246 |         "--flat-width", type=float, default=1.0, help="final threshold.")
247 | 
248 |     parser.add_argument("--gradual", type=str, choices=['sinp', 'sinppgh'], default=None, help="threshold scheduler")
249 | 
250 |     parser.add_argument('--pruning-start-epoch', default=0, type=int,
251 |                         help='the initial epoch to start pruning, only works when gradual is on')
252 | 
253 |     parser.add_argument("--low-freq", action="store_true", help="whether to use original LATS")
254 | 
255 |     parser.add_argument("--amp", action="store_true", help='use AMP training')
256 | 
257 |     parser.add_argument("--beta", default=0, type=float, help="parameter when using sinppgh scheduler")
258 | 
259 |     parser.add_argument("--lr-min", default=0.0, type=float, help="minimum learning rate of cosine annealing")
260 | 
261 |     parser.add_argument(
262 |         "--use-budget", action="store_true", help="use the budget from the pretrained model."
263 |     )
264 |     parser.add_argument(
265 |         "--ignore-pretrained-weights", action="store_true", help="ignore the weights of a pretrained model."
266 | ) 267 | 268 | args = parser.parse_args() 269 | 270 | get_config(args) 271 | 272 | return args 273 | 274 | 275 | def get_config(args): 276 | # get commands from command line 277 | override_args = _parser.argv_to_vars(sys.argv) 278 | 279 | # load yaml file 280 | yaml_txt = open(args.config).read() 281 | 282 | # override args 283 | loaded_yaml = yaml.load(yaml_txt, Loader=yaml.FullLoader) 284 | for v in override_args: 285 | loaded_yaml[v] = getattr(args, v) 286 | 287 | print(f"=> Reading YAML config from {args.config}") 288 | args.__dict__.update(loaded_yaml) 289 | 290 | 291 | def run_args(): 292 | global args 293 | if args is None: 294 | args = parse_arguments() 295 | 296 | 297 | run_args() 298 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import random 4 | import shutil 5 | import time 6 | import json 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import torch.utils.data 14 | import torch.utils.data.distributed 15 | from torch.utils.tensorboard import SummaryWriter 16 | from torch.cuda import amp 17 | from utils.logging import AverageMeter, ProgressMeter 18 | from utils.net_utils import save_checkpoint, get_lr 19 | from utils.schedulers import get_policy 20 | from utils.conv_type import STRConv 21 | from utils.conv_type import sparseFunction 22 | 23 | from args import args 24 | from trainer import train, validate 25 | 26 | import data 27 | import models 28 | 29 | 30 | def main(): 31 | print(args) 32 | 33 | if args.seed is not None: 34 | random.seed(args.seed) 35 | torch.manual_seed(args.seed) 36 | torch.cuda.manual_seed(args.seed) 37 | torch.cuda.manual_seed_all(args.seed) 38 | 39 | # Simply call main_worker function 40 | main_worker(args) 41 | 42 | 43 | def main_worker(args): 44 | args.gpu = None 45 | 46 | if args.gpu is not None: 47 | print("Use GPU: {} for training".format(args.gpu)) 48 | 49 | # create model and optimizer 50 | model = get_model(args) 51 | if args.gradual == 'grad': 52 | for module in model.modules(): 53 | if hasattr(module, 'setFlatWidth'): 54 | module.register_buffer('threshold', torch.tensor([0.])) 55 | model = set_gpu(args, model) 56 | 57 | # Set up directories 58 | run_base_dir, ckpt_base_dir, log_base_dir = get_directories(args) 59 | 60 | # Loading pretrained model 61 | if args.pretrained: 62 | pretrained(args, model) 63 | 64 | # Saving a DenseConv (nn.Conv2d) compatible model 65 | if args.dense_conv_model: 66 | print(f"==> DenseConv compatible model, saving at {ckpt_base_dir / 'model_best.pth'}") 67 | save_checkpoint( 68 | { 69 | "epoch": 0, 70 | "arch": args.arch, 71 | "state_dict": model.state_dict(), 72 | }, 73 | True, 74 | filename=ckpt_base_dir / f"epoch_pretrained.state", 75 | save=True, 76 | ) 77 | return 78 | 79 | optimizer = get_optimizer(args, model) 80 | data = get_dataset(args) 81 | lr_policy = get_policy(args.lr_policy)(optimizer, args) 82 | 83 | if args.label_smoothing is None: 84 | criterion = nn.CrossEntropyLoss().cuda() 85 | else: 86 | # criterion = LabelSmoothing(smoothing=args.label_smoothing) 87 | criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing).cuda() 88 | 89 | # optionally resume from a checkpoint 90 | best_acc1 = 0.0 91 | best_acc5 = 0.0 92 | best_train_acc1 = 0.0 93 | best_train_acc5 = 0.0 94 | 95 | if args.resume: 96 | best_acc1 = resume(args, model, 
optimizer) 97 | 98 | # Evaluation of a model 99 | if args.evaluate: 100 | acc1, acc5 = validate( 101 | data.val_loader, model, criterion, args, writer=None, epoch=args.start_epoch 102 | ) 103 | return 104 | 105 | writer = SummaryWriter(log_dir=log_base_dir) 106 | epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False) 107 | validation_time = AverageMeter("validation_time", ":.4f", write_avg=False) 108 | train_time = AverageMeter("train_time", ":.4f", write_avg=False) 109 | progress_overall = ProgressMeter( 110 | 1, [epoch_time, validation_time, train_time], prefix="Overall Timing" 111 | ) 112 | if args.amp: 113 | scaler = amp.GradScaler() 114 | else: 115 | scaler = None 116 | 117 | prev_masks = dict() 118 | 119 | end_epoch = time.time() 120 | args.start_epoch = args.start_epoch or 0 121 | acc1 = None 122 | 123 | # Save the initial state 124 | save_checkpoint( 125 | { 126 | "epoch": 0, 127 | "arch": args.arch, 128 | "state_dict": model.state_dict(), 129 | "best_acc1": best_acc1, 130 | "best_acc5": best_acc5, 131 | "best_train_acc1": best_train_acc1, 132 | "best_train_acc5": best_train_acc5, 133 | "optimizer": optimizer.state_dict(), 134 | "curr_acc1": acc1 if acc1 else "Not evaluated", 135 | }, 136 | False, 137 | filename=ckpt_base_dir / f"initial.state", 138 | save=False, 139 | ) 140 | 141 | # Start training 142 | for epoch in range(args.start_epoch, args.epochs): 143 | lr_policy(epoch, iteration=None) 144 | cur_lr = get_lr(optimizer) 145 | 146 | # Gradual pruning in GMP experiments 147 | if args.conv_type == "GMPConv" and epoch >= args.init_prune_epoch and epoch <= args.final_prune_epoch: 148 | total_prune_epochs = args.final_prune_epoch - args.init_prune_epoch + 1 149 | for n, m in model.named_modules(): 150 | if hasattr(m, 'set_curr_prune_rate'): 151 | prune_decay = (1 - ((args.curr_prune_epoch - args.init_prune_epoch)/total_prune_epochs))**3 152 | curr_prune_rate = m.prune_rate - (m.prune_rate*prune_decay) 153 | m.set_curr_prune_rate(curr_prune_rate) 154 | 155 | # train for one epoch 156 | start_train = time.time() 157 | train_acc1, train_acc5 = train( 158 | data.train_loader, model, criterion, optimizer, epoch, args, writer=writer, scaler=scaler 159 | ) 160 | train_time.update((time.time() - start_train) / 60) 161 | 162 | # evaluate on validation set 163 | start_validation = time.time() 164 | acc1, acc5 = validate(data.val_loader, model, criterion, args, writer, epoch) 165 | validation_time.update((time.time() - start_validation) / 60) 166 | 167 | # remember best acc@1 and save checkpoint 168 | is_best = acc1 > best_acc1 169 | best_acc1 = max(acc1, best_acc1) 170 | best_acc5 = max(acc5, best_acc5) 171 | best_train_acc1 = max(train_acc1, best_train_acc1) 172 | best_train_acc5 = max(train_acc5, best_train_acc5) 173 | 174 | save = ((epoch % args.save_every) == 0) and args.save_every > 0 175 | if is_best or save or epoch == args.epochs - 1: 176 | if is_best: 177 | print(f"==> New best, saving at {ckpt_base_dir / 'model_best.pth'}") 178 | 179 | save_checkpoint( 180 | { 181 | "epoch": epoch + 1, 182 | "arch": args.arch, 183 | "state_dict": model.state_dict(), 184 | "best_acc1": best_acc1, 185 | "best_acc5": best_acc5, 186 | "best_train_acc1": best_train_acc1, 187 | "best_train_acc5": best_train_acc5, 188 | "optimizer": optimizer.state_dict(), 189 | "curr_acc1": acc1, 190 | "curr_acc5": acc5, 191 | }, 192 | is_best, 193 | filename=ckpt_base_dir / f"epoch_{epoch}.state", 194 | save=save, 195 | ) 196 | 197 | epoch_time.update((time.time() - end_epoch) / 60) 198 | 
progress_overall.display(epoch) 199 | progress_overall.write_to_tensorboard( 200 | writer, prefix="diagnostics", global_step=epoch 201 | ) 202 | 203 | # Storing sparsity and threshold statistics for STRConv models 204 | with torch.no_grad(): 205 | if args.conv_type == "STRConv": 206 | total_zerocnt = 0 207 | total_numel = 0 208 | for n, m in model.named_modules(): 209 | if isinstance(m, STRConv): 210 | if n in prev_masks: 211 | curr_mask = (m.getSparseWeight() == 0) 212 | prev_mask = prev_masks[n] 213 | regrowth_ratio = torch.logical_and(prev_mask, torch.logical_not(curr_mask)).sum() 214 | prune_ratio = torch.logical_and(torch.logical_not(prev_mask), curr_mask).sum() 215 | writer.add_scalar("regrowth/{}".format(n), regrowth_ratio, epoch) 216 | writer.add_scalar("prune/{}".format(n), prune_ratio, epoch) 217 | prev_masks[n] = (m.getSparseWeight() == 0) 218 | 219 | if epoch == 0 or (epoch + 1) % 20 == 0: 220 | writer.add_histogram("w/{}".format(n), m.getSparseWeight(), epoch) 221 | writer.add_histogram("theta/{}".format(n), m.weight, epoch) 222 | 223 | #sparsity, total_params, thresh = m.getSparsity() 224 | zerocnt, numel = m.getSparsity() 225 | print(f'{n}: {zerocnt / numel * 100:.2f}%') 226 | writer.add_scalar(f'sparsity/{n}', zerocnt / numel, epoch) 227 | 228 | #writer.add_scalar("thresh/{}".format(n), thresh, epoch) 229 | total_zerocnt += zerocnt 230 | total_numel += numel 231 | 232 | if args.first_layer_dense and n == 'module.conv1': 233 | print(f'{n}: 0.00%') 234 | writer.add_scalar(f'sparsity/{n}', 0.0, epoch) 235 | total_numel += m.weight.data.numel() 236 | 237 | if args.last_layer_dense and n == 'module.fc': 238 | print(f'{n}: 0.00%') 239 | writer.add_scalar(f'sparsity/{n}', 0.0, epoch) 240 | total_numel += m.weight.data.numel() 241 | 242 | total_sparsity = total_zerocnt / total_numel 243 | print(f'total: {total_sparsity * 100:.2f}%') 244 | writer.add_scalar("sparsity/total", total_sparsity, epoch) 245 | 246 | writer.add_scalar("test/lr", cur_lr, epoch) 247 | end_epoch = time.time() 248 | 249 | # write_result_to_csv( 250 | # best_acc1=best_acc1, 251 | # best_acc5=best_acc5, 252 | # best_train_acc1=best_train_acc1, 253 | # best_train_acc5=best_train_acc5, 254 | # prune_rate=args.prune_rate, 255 | # curr_acc1=acc1, 256 | # curr_acc5=acc5, 257 | # base_config=args.config, 258 | # name=args.name, 259 | # ) 260 | # if args.conv_type == "STRConv": 261 | # json_data = {} 262 | # #json_thres = {} 263 | # sum_sparse = 0.0 264 | # count = 0.0 265 | # for n, m in model.named_modules(): 266 | # if isinstance(m, STRConv): 267 | # zerocnt, numel = m.getSparsity() 268 | # json_data[n] = (zerocnt / numel * 100).item() 269 | # sum_sparse += zerocnt 270 | # count += numel 271 | # #json_thres[n] = sparsity[2] 272 | # json_data["total"] = (100 - (100 * sum_sparse / count)).item() 273 | # if not os.path.exists("runs/layerwise_sparsity"): 274 | # os.mkdir("runs/layerwise_sparsity") 275 | # if not os.path.exists("runs/layerwise_threshold"): 276 | # os.mkdir("runs/layerwise_threshold") 277 | # with open("runs/layerwise_sparsity/{}.json".format(args.name), "w") as f: 278 | # json.dump(json_data, f) 279 | #with open("runs/layerwise_threshold/{}.json".format(args.name), "w") as f: 280 | # json.dump(json_thres, f) 281 | 282 | 283 | def set_gpu(args, model): 284 | if args.gpu is not None: 285 | torch.cuda.set_device(args.gpu) 286 | model = model.cuda(args.gpu) 287 | else: 288 | # DataParallel will divide and allocate batch_size to all available GPUs 289 | print(f"=> Parallelizing on {args.multigpu} gpus") 290 | 
torch.cuda.set_device(args.multigpu[0]) 291 | args.gpu = args.multigpu[0] 292 | model = torch.nn.DataParallel(model, device_ids=args.multigpu).cuda( 293 | args.multigpu[0] 294 | ) 295 | 296 | cudnn.benchmark = True 297 | 298 | return model 299 | 300 | 301 | def resume(args, model, optimizer): 302 | if os.path.isfile(args.resume): 303 | print(f"=> Loading checkpoint '{args.resume}'") 304 | 305 | checkpoint = torch.load(args.resume) 306 | if args.start_epoch is None: 307 | print(f"=> Setting new start epoch at {checkpoint['epoch']}") 308 | args.start_epoch = checkpoint["epoch"] 309 | 310 | best_acc1 = checkpoint["best_acc1"] 311 | 312 | model.load_state_dict(checkpoint["state_dict"]) 313 | 314 | optimizer.load_state_dict(checkpoint["optimizer"]) 315 | 316 | print(f"=> Loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})") 317 | 318 | return best_acc1 319 | else: 320 | print(f"=> No checkpoint found at '{args.resume}'") 321 | 322 | 323 | def pretrained(args, model): 324 | if os.path.isfile(args.pretrained): 325 | print("=> loading pretrained weights from '{}'".format(args.pretrained)) 326 | pretrained = torch.load( 327 | args.pretrained, 328 | map_location=torch.device("cuda:{}".format(args.multigpu[0])), 329 | )["state_dict"] 330 | 331 | model_state_dict = model.state_dict() 332 | 333 | if not args.ignore_pretrained_weights: 334 | 335 | pretrained_final = { 336 | k: v 337 | for k, v in pretrained.items() 338 | if (k in model_state_dict and v.size() == model_state_dict[k].size()) 339 | } 340 | 341 | if args.conv_type != "STRConv": 342 | for k, v in pretrained.items(): 343 | if 'sparseThreshold' in k: 344 | wkey = k.split('sparse')[0] + 'weight' 345 | weight = pretrained[wkey] 346 | pretrained_final[wkey] = sparseFunction(weight, v) 347 | 348 | model_state_dict.update(pretrained_final) 349 | model.load_state_dict(model_state_dict) 350 | 351 | # Using the budgets of STR models for other models like DNW and GMP 352 | if args.use_budget: 353 | budget = {} 354 | for k, v in pretrained.items(): 355 | if 'sparseThreshold' in k: 356 | wkey = k.split('sparse')[0] + 'weight' 357 | weight = pretrained[wkey] 358 | sparse_weight = sparseFunction(weight, v) 359 | budget[wkey] = (sparse_weight.abs() > 0).float().mean().item() 360 | 361 | for n, m in model.named_modules(): 362 | if hasattr(m, 'set_prune_rate'): 363 | pr = 1 - budget[n + '.weight'] 364 | m.set_prune_rate(pr) 365 | print('set prune rate', n, pr) 366 | 367 | 368 | else: 369 | print("=> no pretrained weights found at '{}'".format(args.pretrained)) 370 | 371 | 372 | def get_dataset(args): 373 | print(f"=> Getting {args.set} dataset") 374 | dataset = getattr(data, args.set)(args) 375 | 376 | return dataset 377 | 378 | 379 | def get_model(args): 380 | if args.first_layer_dense: 381 | args.first_layer_type = "DenseConv" 382 | 383 | print("=> Creating model '{}'".format(args.arch)) 384 | model = models.__dict__[args.arch]() 385 | 386 | print(f"=> Num model params {sum(p.numel() for p in model.parameters())}") 387 | 388 | # applying sparsity to the network 389 | if args.conv_type != "DenseConv": 390 | 391 | print(f"==> Setting prune rate of network to {args.prune_rate}") 392 | 393 | def _sparsity(m): 394 | if hasattr(m, "set_prune_rate"): 395 | m.set_prune_rate(args.prune_rate) 396 | 397 | model.apply(_sparsity) 398 | 399 | # freezing the weights if we are only doing mask training 400 | if args.freeze_weights: 401 | print(f"=> Freezing model weights") 402 | 403 | def _freeze(m): 404 | if hasattr(m, "mask"): 405 | m.weight.requires_grad = 
False 406 | if hasattr(m, "bias") and m.bias is not None: 407 | m.bias.requires_grad = False 408 | 409 | model.apply(_freeze) 410 | 411 | return model 412 | 413 | 414 | def get_optimizer(args, model): 415 | for n, v in model.named_parameters(): 416 | if v.requires_grad: 417 | pass #print(" gradient to", n) 418 | 419 | if not v.requires_grad: 420 | pass #print(" no gradient to", n) 421 | 422 | if args.optimizer == "sgd": 423 | parameters = list(model.named_parameters()) 424 | # sparse_thresh = [v for n, v in parameters if ("sparseThreshold" in n) and v.requires_grad] 425 | bn_params = [v for n, v in parameters if ("bn" in n) and v.requires_grad] 426 | rest_params = [v for n, v in parameters if ("bn" not in n) and ('sparseThreshold' not in n) and v.requires_grad] 427 | optimizer = torch.optim.SGD( 428 | [ 429 | { 430 | "params": bn_params, 431 | "weight_decay": 0 if args.no_bn_decay else args.weight_decay, 432 | }, 433 | # { 434 | # "params": sparse_thresh, 435 | # "weight_decay": args.st_decay if args.st_decay is not None else args.weight_decay, 436 | # }, 437 | {"params": rest_params, "weight_decay": args.weight_decay}, 438 | ], 439 | args.lr, 440 | momentum=args.momentum, 441 | weight_decay=args.weight_decay, 442 | nesterov=args.nesterov, 443 | ) 444 | elif args.optimizer == "adam": 445 | optimizer = torch.optim.Adam( 446 | filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr 447 | ) 448 | 449 | return optimizer 450 | 451 | 452 | def _run_dir_exists(run_base_dir): 453 | log_base_dir = run_base_dir / "logs" 454 | ckpt_base_dir = run_base_dir / "checkpoints" 455 | 456 | return log_base_dir.exists() or ckpt_base_dir.exists() 457 | 458 | 459 | def get_directories(args): 460 | if args.config is None or args.name is None: 461 | raise ValueError("Must have name and config") 462 | 463 | config = pathlib.Path(args.config).stem 464 | if args.log_dir is None: 465 | run_base_dir = pathlib.Path( 466 | f"runs/{config}/{args.name}/prune_rate={args.prune_rate}" 467 | ) 468 | else: 469 | run_base_dir = pathlib.Path( 470 | f"{args.log_dir}/{config}/{args.name}/prune_rate={args.prune_rate}" 471 | ) 472 | if args.width_mult != 1.0: 473 | run_base_dir = run_base_dir / "width_mult={}".format(str(args.width_mult)) 474 | 475 | if _run_dir_exists(run_base_dir): 476 | rep_count = 0 477 | while _run_dir_exists(run_base_dir / str(rep_count)): 478 | rep_count += 1 479 | 480 | run_base_dir = run_base_dir / str(rep_count) 481 | 482 | log_base_dir = run_base_dir / "logs" 483 | ckpt_base_dir = run_base_dir / "checkpoints" 484 | 485 | if not run_base_dir.exists(): 486 | os.makedirs(run_base_dir) 487 | 488 | (run_base_dir / "settings.txt").write_text(str(args)) 489 | 490 | return run_base_dir, ckpt_base_dir, log_base_dir 491 | 492 | 493 | def write_result_to_csv(**kwargs): 494 | results = pathlib.Path("runs") / "results.csv" 495 | 496 | if not results.exists(): 497 | results.write_text( 498 | "Date Finished, " 499 | "Base Config, " 500 | "Name, " 501 | "Prune Rate, " 502 | "Current Val Top 1, " 503 | "Current Val Top 5, " 504 | "Best Val Top 1, " 505 | "Best Val Top 5, " 506 | "Best Train Top 1, " 507 | "Best Train Top 5\n" 508 | ) 509 | 510 | now = time.strftime("%m-%d-%y_%H:%M:%S") 511 | 512 | with open(results, "a+") as f: 513 | f.write( 514 | ( 515 | "{now}, " 516 | "{base_config}, " 517 | "{name}, " 518 | "{prune_rate}, " 519 | "{curr_acc1:.02f}, " 520 | "{curr_acc5:.02f}, " 521 | "{best_acc1:.02f}, " 522 | "{best_acc5:.02f}, " 523 | "{best_train_acc1:.02f}, " 524 | "{best_train_acc5:.02f}\n" 525 | 
).format(now=now, **kwargs) 526 | ) 527 | 528 | 529 | if __name__ == "__main__": 530 | main() 531 | --------------------------------------------------------------------------------
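Usage note (a minimal sketch based only on the code above; the config path, run name, and prune rate below are illustrative placeholders, not files or settings shipped with this repository): main.py requires both --config and --name, since get_directories() raises "Must have name and config" otherwise, so a training run is launched along the lines of

    python main.py --config path/to/config.yaml --name example_run --prune-rate 0.9

Any flag passed on the command line overrides the matching key read from the YAML file, because get_config() only overwrites the keys that actually appear in sys.argv; every other hyperparameter is taken from the YAML. Evaluating a saved checkpoint reuses the same entry point with --evaluate together with --pretrained pointing at the checkpoint file, and GPU placement is driven by args.multigpu inside set_gpu(), so the GPU ids must be supplied via the config or the command line.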