├── .gitignore
├── LICENCE
├── README.md
├── argconfig.py
├── config
│   ├── cifar10.yaml
│   ├── cifar100.yaml
│   └── imagenet.yaml
├── core
│   ├── bnet.py
│   ├── layers
│   │   └── bconv2d.py
│   └── model_converter.py
├── dataset.py
├── logger.py
├── main.py
├── model.py
├── run_cifar.sh
├── run_imagenet.sh
├── vis_cost.ipynb
└── vis_hist.ipynb

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | __pycache__/
 2 | *.py[cod]
 3 | *.egg-info/
 4 | .installed.cfg
 5 | *.egg
 6 | 
 7 | .*
 8 | !/.gitignore
 9 | 
10 | */__pycache__/*
11 | 
12 | log/
13 | data/
14 | 
15 | # Ignore private files for debugging
16 | plot_stats.ipynb
17 | plot_util.py
18 | misc/
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
 1 | Copyright (C) 2022 Denso IT Laboratory, Inc.
 2 | All Rights Reserved
 3 | 
 4 | Denso IT Laboratory, Inc. retains sole and exclusive ownership of all
 5 | intellectual property rights including copyrights and patents related to this
 6 | Software.
 7 | 
 8 | Permission is hereby granted, free of charge, to any person obtaining a copy
 9 | of the Software and accompanying documentation to use, copy, modify, merge,
10 | publish, or distribute the Software or software derived from it for
11 | non-commercial purposes, such as academic study, education and personal use,
12 | subject to the following conditions:
13 | 
14 | 1. Redistributions of source code must retain the above copyright notice,
15 | this list of conditions and the following disclaimer.
16 | 
17 | 2. Redistributions in binary form must reproduce the above copyright notice,
18 | this list of conditions and the following disclaimer in the documentation
19 | and/or other materials provided with the distribution.
20 | 
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
24 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
25 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
26 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
27 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
28 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
29 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
30 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
31 | POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Bit-Pruning
 2 | This is the official repository for the ICLR 2023 paper "Bit-Pruning: A Sparse Multiplication-Less Dot-Product"
 3 | by Yusuke Sekikawa and Shingo Yashima.
 4 | 
 5 | [paper](https://openreview.net/pdf?id=YUDiZcZTI8), [openreview](https://openreview.net/forum?id=YUDiZcZTI8)
 6 | 
 7 | 
 8 | # Usage
 9 | 
10 | ## Set GPU(s) to use
11 | ```
12 | export CUDA_VISIBLE_DEVICES=0
13 | export CUDA_VISIBLE_DEVICES=1
14 | ```
15 | 
16 | ## Run train and test
17 | ```
18 | python main.py --config 'config/cifar10.yaml'
19 | python main.py --config 'config/cifar100.yaml'
20 | python main.py --config 'config/imagenet.yaml'
21 | ```
22 | You can override any config entry by appending its dotted key (e.g., `optim.spr_w`) followed by the desired value, e.g., `optim.spr_w 16 model.wgt_bit 8 ...`.
23 | Use `run_cifar.sh` / `run_imagenet.sh` for batch execution over a sweep of `optim.spr_w` values.
24 | 
25 | 
26 | ## Plot bit-pruning loss
27 | Run `vis_cost.ipynb` to plot the proximal weights and their loss landscapes (Fig. 3, Fig. 11).
28 | 
29 | ## Plot histogram of learned weights
30 | Run `vis_hist.ipynb` to plot the histogram of the learned weight distribution (Fig. 6).
31 | 
32 | # Installation
33 | ## Setup docker image from pytorch/pytorch
34 | ```
35 | docker pull pytorch/pytorch
36 | docker run -dit --gpus all -v /username/Project/:/home/src -v /data1/dataset:/home/data --name username --shm-size=64gb pytorch/pytorch
37 | docker exec -e CUDA_VISIBLE_DEVICES='0' -u 0 -it username bash
38 | apt-get -y update && apt-get -y install libgl1 libglib2.0-0
39 | yes | pip install opencv-python
40 | yes | pip install opencv-contrib-python
41 | yes | pip install einops
42 | yes | pip install kornia
43 | yes | pip install lightning-bolts
44 | yes | pip install pytorch-lightning
45 | yes | pip install fvcore
46 | yes | pip install scipy
47 | conda install -c conda-forge easydict --yes
48 | conda install -c conda-forge ruamel.yaml --yes
49 | ```
50 | 
51 | 
52 | If you find our code or paper useful, please cite the following:
53 | ```
54 | @inproceedings{iclr2023bitprune,
55 |   author    = {Sekikawa, Yusuke and Yashima, Shingo},
56 |   title     = {Bit-Pruning: A Sparse Multiplication-Less Dot-Product},
57 |   booktitle = {Proceedings of the International Conference on Learning Representations},
58 |   year      = {2023},
59 |   url       = {https://openreview.net/forum?id=YUDiZcZTI8}
60 | }
61 | ```
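For a quick, self-contained look at the per-value bit cost that drives the bit-pruning loss, the following minimal sketch uses two helpers defined in `core/layers/bconv2d.py` (run from the repository root):
```
# Nonzero-bit count of every representable signed 8-bit weight value.
from core.layers.bconv2d import compute_bit_cost, compute_wgt_vec

wgt_vec = compute_wgt_vec(8)        # the 257 values in [-128, 128]
bit_cost, _ = compute_bit_cost(8)   # bits set in |value|, one entry per value
print(wgt_vec[128].item(), bit_cost[128].item())  # 0.0 -> 0.0 bits: a zero weight is free
print(wgt_vec[129].item(), bit_cost[129].item())  # 1.0 -> 1.0 bit
print(wgt_vec[255].item(), bit_cost[255].item())  # 127.0 -> 7.0 bits (0b1111111)
```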
--------------------------------------------------------------------------------
/argconfig.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | import ruamel.yaml as yaml
 3 | from fvcore.common.config import CfgNode as CN
 4 | import os
 5 | 
 6 | def default_args():
 7 |     import argparse
 8 |     parser = argparse.ArgumentParser()
 9 |     parser.add_argument(
10 |         "--config-file",
11 |         default=None,
12 |         type=str,
13 |         help="/path/to/config-file"
14 |     )
15 |     parser.add_argument(
16 |         "opts",
17 |         help="""
18 | Modify config options at the end of the command. For Yacs configs, use
19 | space-separated "PATH.KEY VALUE" pairs.
20 | """.strip(),
21 |         default=None,
22 |         nargs=argparse.REMAINDER,
23 |     )
24 |     return parser
25 | 
26 | def load():
27 |     args = default_args().parse_args()
28 |     print(args)
29 |     print(args.config_file)
30 |     if args.config_file is not None:
31 |         cfg = CN(yaml.safe_load(open(args.config_file, 'r')))
32 |     else:
33 |         raise ValueError('Please specify config-file')
34 |     cfg.merge_from_list(args.opts)
35 | 
36 |     return prep(cfg)
37 | 
38 | def prep(cfg, print_cfg=True):
39 |     cfg.cfg_name = '{}_wd{}_bs{}_epoch{}_lr{}_{}_lamda_{:0.1f}_bit{}_{}{}'.format(
40 |         cfg.optim.loss_type,
41 |         cfg.optim.enable_decay,
42 |         cfg.optim.batch_size,
43 |         cfg.optim.epochs,
44 |         cfg.optim.lr_core,
45 |         cfg.optim.lr_mask,
46 |         cfg.optim.lamda_ini,
47 |         cfg.model.wgt_bit,
48 |         cfg.model.act_bit,
49 |         cfg.misc.suffix
50 |     )
51 |     cfg.file_name = "{:08d}".format(int(cfg.optim.spr_w))
52 |     # create dir for score
53 |     cfg.root_dir = os.path.join(cfg.misc.log_dir, cfg.dataset.name, cfg.model.name, cfg.cfg_name, cfg.file_name)
54 |     cfg.stats_dir = os.path.join(cfg.root_dir, 'score')
55 |     cfg.model_dir = os.path.join(cfg.root_dir, 'model')
56 | 
57 |     os.makedirs(cfg.stats_dir, exist_ok=True)
58 |     os.makedirs(cfg.model_dir, exist_ok=True)
59 |     cfg.model_path_final = os.path.join(cfg.model_dir, 'final.pt')
60 |     cfg.model_path_best = os.path.join(cfg.model_dir, 'best.pt')
61 |     with open(os.path.join(cfg.root_dir, 'config.yaml'), 'w') as f:
62 |         f.write(cfg.dump())
63 | 
64 |     if print_cfg:
65 |         print(cfg)
66 |     return cfg
--------------------------------------------------------------------------------
/config/cifar10.yaml:
--------------------------------------------------------------------------------
 1 | misc:
 2 |   suffix: ''
 3 |   log_dir: 'log/'
 4 |   log_name: 'log_re'
 5 |   seed: 0
 6 | 
 7 | dataset:
 8 |   name: 'CIFAR10' # 'CIFAR10': timm augmentations, 'cifar10': plain crop+flip (see dataset.py)
 9 |   path: "data"
10 |   input_size: 32
11 |   color_jitter: 0.4
12 |   aa: 'rand-m9-mstd0.5-inc1'
13 |   train_interpolation: 'bicubic'
14 |   reprob: 0.25
15 |   remode: 'pixel'
16 |   recount: 1
17 |   imagenet_default_mean_and_std: True
18 |   pin_mem: True
19 |   crop_pct: 0.0
20 | 
21 | model:
22 |   name: 'resnet18'
23 |   wgt_bit: 8
24 |   act_bit: 8
25 |   pretrain_path: ''
26 |   pretrained: 0 # 0: scratch, 1: ImageNet weights, 2: task checkpoint from pretrain_path
27 |   width: 1 # channel-width divisor applied in model_converter
28 | 
29 | optim:
30 |   batch_size: 512
31 |   scheduler: 'OneCycleLR' # 'ExponentialLR', 'OneCycleLR'
32 |   optimizer: 'AdamW' # 'SGD', 'AdamW'
33 |   epochs: 200
34 |   lr_core: 0.05
35 |   lr_mask: 0.01
36 |   gamma: 0.98
37 |   momentum: 0.9
38 |   weight_decay: 0.01 # 0
39 |   enable_decay: 1 # 0
40 |   loss_type: 'wgt_tgt' # 'wgt_naive', 'act_naive', 'wgt_bilinear', 'act_bilinear', 'wgt_tgt', 'act_tgt', 'wgt_prox'
41 |   spr_w: 5 # weight of the sparsity (bit) loss added to the cross-entropy
42 |   lamda_ini: 1.0
43 |   wgt_p_norm: 0.5
44 |   smoothing: 0.0
45 |   drop_path: 0.0
46 |   use_correction: True
47 | 
48 | hardware:
49 |   num_cpu_workers: 16
50 |   gpu_device: [0]
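How these YAML files are consumed, in a minimal sketch mirroring `argconfig.load` (the printed values are the defaults from `config/cifar10.yaml` above):
```
import ruamel.yaml as yaml
from fvcore.common.config import CfgNode as CN

cfg = CN(yaml.safe_load(open('config/cifar10.yaml', 'r')))
print(cfg.optim.loss_type, cfg.optim.spr_w)   # wgt_tgt 5
cfg.merge_from_list(['optim.spr_w', '16'])    # the CLI override path used by main.py
```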
--------------------------------------------------------------------------------
/config/cifar100.yaml:
--------------------------------------------------------------------------------
 1 | misc:
 2 |   suffix: ''
 3 |   log_dir: 'log/'
 4 |   log_name: 'log_re'
 5 |   seed: 0
 6 | 
 7 | dataset:
 8 |   name: 'CIFAR100' # 'CIFAR100': timm augmentations, 'cifar100': plain crop+flip (see dataset.py)
 9 |   path: "data"
10 |   input_size: 32
11 |   color_jitter: 0.4
12 |   aa: 'rand-m9-mstd0.5-inc1'
13 |   train_interpolation: 'bicubic'
14 |   reprob: 0.25
15 |   remode: 'pixel'
16 |   recount: 1
17 |   imagenet_default_mean_and_std: True
18 |   pin_mem: True
19 |   crop_pct: 0.0
20 | 
21 | model:
22 |   name: 'resnet18'
23 |   wgt_bit: 8
24 |   act_bit: 8
25 |   pretrain_path: ''
26 |   pretrained: 0 # 0: scratch, 1: ImageNet weights, 2: task checkpoint from pretrain_path
27 |   width: 1 # channel-width divisor applied in model_converter
28 | 
29 | optim:
30 |   batch_size: 512
31 |   scheduler: 'OneCycleLR' # 'ExponentialLR', 'OneCycleLR'
32 |   optimizer: 'AdamW' # 'SGD', 'AdamW'
33 |   epochs: 200
34 |   lr_core: 0.05
35 |   lr_mask: 0.01
36 |   gamma: 0.98
37 |   momentum: 0.9
38 |   weight_decay: 0.01 # 0
39 |   enable_decay: 1 # 0
40 |   loss_type: 'act_bilinear' # 'wgt_naive', 'act_naive', 'wgt_bilinear', 'act_bilinear', 'wgt_tgt', 'act_tgt', 'wgt_prox', 'wgt_lprox'
41 |   spr_w: 5
42 |   lamda_ini: 1.0
43 |   wgt_p_norm: 0.5
44 |   smoothing: 0.0
45 |   drop_path: 0.0
46 |   use_correction: True
47 | 
48 | hardware:
49 |   num_cpu_workers: 16
50 |   gpu_device: [0]
--------------------------------------------------------------------------------
/config/imagenet.yaml:
--------------------------------------------------------------------------------
 1 | misc:
 2 |   suffix: ''
 3 |   log_dir: 'log/'
 4 |   log_name: 'log'
 5 |   seed: 0
 6 | 
 7 | dataset:
 8 |   name: 'IMNET'
 9 |   path: "/home/data/ILSVRC/Data/CLS-LOC"
10 |   input_size: 224
11 |   color_jitter: 0.4
12 |   aa: 'rand-m9-mstd0.5-inc1'
13 |   train_interpolation: 'bicubic'
14 |   reprob: 0.25
15 |   remode: 'pixel'
16 |   recount: 1
17 |   imagenet_default_mean_and_std: True
18 |   pin_mem: True
19 |   crop_pct: 0.0
20 | 
21 | model:
22 |   name: 'convnext_base'
23 |   wgt_bit: 8
24 |   act_bit: 8
25 |   pretrain_path: ''
26 |   pretrained: 1 # 0: scratch, 1: ImageNet weights, 2: task checkpoint from pretrain_path
27 |   width: 1 # channel-width divisor applied in model_converter
28 | 
29 | optim:
30 |   batch_size: 256
31 |   scheduler: 'CosineDecay' # 'ExponentialLR', 'OneCycleLR', 'CosineDecay'
32 |   optimizer: 'AdamW' # 'SGD', 'AdamW'
33 |   epochs: 100
34 |   lr_core: 5e-4
35 |   lr_mask: 5e-4
36 |   gamma: 0.98
37 |   momentum: 0.9
38 |   weight_decay: 1e-7 # 0
39 |   enable_decay: 1 # 0
40 |   loss_type: 'act_tgt' # 'fp32', 'wgt_naive', 'act_naive', 'wgt_bilinear', 'act_bilinear', 'wgt_tgt', 'act_tgt', 'wgt_prox', 'wgt_lprox'
41 |   spr_w: 5
42 |   lamda_ini: 1.0
43 |   wgt_p_norm: 0.5
44 |   smoothing: 0.1
45 |   drop_path: 0.8
46 |   use_correction: True
47 | 
48 | hardware:
49 |   num_cpu_workers: 32
50 |   gpu_device: [0,1,2,3]
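A minimal smoke test for the conversion pipeline defined in `core/bnet.py` below (hypothetical usage; it assumes a `cfg` loaded as in the sketch above):
```
from model import get_model
from core.bnet import BNet

cfg.dataset.input_shape = [1, 3, 32, 32]  # BNet's dummy forward pass needs this
model = get_model(cfg)                    # plain torchvision ResNet-18 for CIFAR
net = BNet(model, cfg)                    # swaps Conv2d/Linear for BConv2d and prints dense SYN/ACT counts
```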
--------------------------------------------------------------------------------
/core/bnet.py:
--------------------------------------------------------------------------------
  1 | from numpy import inf as Inf
  2 | from core.layers.bconv2d import BConv2d
  3 | from core.model_converter import convert_layers
  4 | import torch.nn as nn
  5 | import numpy as np
  6 | import torch
  7 | import random
  8 | from easydict import EasyDict as edict
  9 | 
 10 | class BNet(nn.Module):
 11 |     def __init__(self, model, cfg):
 12 |         super(BNet, self).__init__()
 13 |         self.cfg = cfg
 14 |         self.selected_out = []
 15 |         self.fhooks = []
 16 |         self.DSS_cand = []
 17 |         self.selected_idx = []
 18 | 
 19 |         model, _ = convert_layers(model, cfg)
 20 |         self.model = model
 21 | 
 22 |         model.eval()
 23 |         dummy_input = torch.randn(cfg.dataset.input_shape)
 24 |         self.prepare_hook()
 25 |         self.compute_dense_syn_cnt(dummy_input)
 26 | 
 27 | 
 28 |     def forward(self, x):
 29 |         x = self.model(x)
 30 |         return x
 31 | 
 32 |     # hooks
 33 |     def prepare_hook(self):
 34 |         DSS_cand = []
 35 | 
 36 |         print('<<<< collecting BConv2d layers >>>>')
 37 |         idx = 0
 38 |         names = []
 39 |         th_cand = []
 40 |         for name, module in self.named_modules():
 41 |             if isinstance(module, BConv2d):
 42 |                 module.name = name
 43 |                 module.idx = idx
 44 |                 DSS_cand.append(module)
 45 |                 names.append(name)
 46 |                 idx += 1
 47 |         self.DSS_cand = DSS_cand
 48 |         self.th_cand = th_cand
 49 |         # print(DSS_cand)
 50 |         return
 51 | 
 52 |     def reset_hook(self):
 53 |         for fhook in self.fhooks:
 54 |             fhook.remove()
 55 |         self.fhooks.clear()
 56 |         self.selected_out.clear()
 57 | 
 58 |         for module in self.DSS_cand:
 59 |             if isinstance(module, BConv2d):
 60 |                 module.selected_itr = []
 61 |                 module.hook_count = 0
 62 | 
 63 | 
 64 |     def register_idx_hook(self, idx=[]):
 65 |         self.max_hook = 3
 66 |         for fhook in self.fhooks:
 67 |             fhook.remove()
 68 |         self.fhooks.clear()
 69 |         self.selected_out.clear()
 70 |         self.selected_idx = idx
 71 |         for idx_ in self.selected_idx:
 72 |             self.fhooks.append(self.DSS_cand[idx_].register_forward_hook(self.forward_hook(idx_)))
 73 | 
 74 | 
 75 |     def register_rand_hook(self, n_sample=1):
 76 |         self.max_hook = Inf
 77 |         for fhook in self.fhooks:
 78 |             fhook.remove()
 79 |         self.fhooks.clear()
 80 |         self.selected_out.clear()
 81 |         self.selected_idx.clear()
 82 |         for idx, module in enumerate(self.DSS_cand):
 83 |             module.is_selected = False
 84 | 
 85 |         for itr in range(n_sample):
 86 |             idx_ = random.choices(range(len(self.DSS_cand)), weights=self.dense_syn_cnt)[0]  # sample layers in proportion to their dense synapse count
 87 |             if idx_ not in self.selected_idx:
 88 |                 self.selected_idx.append(idx_)
 89 |                 self.DSS_cand[idx_].is_selected = True
 90 |                 # self.selected_idx = np.random.randint(low=0, high=len(self.DSS_cand), size=1)[0]
 91 |                 # print('register_rand_hook %d %d '%(self.selected_idx, len(self.fhooks)))
 92 |                 self.fhooks.append(self.DSS_cand[idx_].register_forward_hook(self.forward_hook(idx_)))
 93 | 
 94 |     def register_all_hook(self):
 95 |         self.max_hook = Inf
 96 |         for fhook in self.fhooks:
 97 |             fhook.remove()
 98 |         self.fhooks.clear()
 99 |         self.selected_out.clear()
100 | 
101 |         for idx, module in enumerate(self.DSS_cand):
102 |             self.fhooks.append(module.register_forward_hook(self.forward_hook(idx)))
103 | 
104 |     def forward_hook(self, selected_idx):
105 |         def hook(module, input, output):
106 |             module.hook_count += 1
107 |             if isinstance(module, BConv2d):
108 |                 module.compute_spr_loss(input[0])
109 |                 module.layer_stats.selected_idx = selected_idx
110 |                 self.selected_out.append(module)
111 |         return hook
112 | 
113 |     def set_lamda(self, lamda):
114 |         for idx, module in enumerate(self.DSS_cand):
115 |             module.lamda = (lamda+0.001)
116 | 
117 | 
118 |     def compute_dense_syn_cnt(self, input):
119 |         for fhook in self.fhooks:
120 |             fhook.remove()
121 |         self.fhooks.clear()
122 |         self.selected_out.clear()
123 | 
124 |         # print('register_all_hook %d '%(len(self.fhooks)))
125 |         for idx, module in enumerate(self.DSS_cand):
126 |             self.fhooks.append(module.register_forward_hook(self.forward_hook(idx)))
127 | 
128 |         self(input)
129 | 
130 |         dense_syn_cnt = np.zeros(len(self.DSS_cand))
131 |         dense_act_cnt = np.zeros(len(self.DSS_cand))
132 |         for module in self.selected_out:
133 |             stats = module.layer_stats
134 |             dense_syn_cnt[stats.selected_idx] += stats.dense_syn_cnt
135 |             dense_act_cnt[stats.selected_idx] += stats.dense_act_cnt
136 | 
137 | 
138 |         self.dense_syn_cnt = dense_syn_cnt
139 |         print('Dense SYN CNT (10^9): %f'%(dense_syn_cnt.sum()/1e9))
140 |         print('Dense ACT CNT (10^9): %f'%(dense_act_cnt.sum()/1e9))
141 | 
142 |     def set_init_state(self, init_state):
143 |         for module in self.DSS_cand:
144 |             module.act_lsq.set_init_state(init_state)
145 | 
146 | 
147 |     def aggregate_dss(self):
148 |         # accumulate per-layer stats from the hooked layers
149 |         layer_stats_all = edict()
150 |         layer_stats_all.op_loss = 0.0
151 |         layer_stats_all.mac = 0.0
152 |         layer_stats_all.sac = 0.0
153 |         layer_stats_all.mac2 = 0.0
154 |         layer_stats_all.sac2 = 0.0
155 |         layer_stats_all.bit_cnt = 0.0
156 |         layer_stats_all.wgt_cnt = 0.0
157 |         layer_stats_all.act_cnt = 0.0
158 | 
159 |         layer_stats_all.dense_act_cnt = 0.0
160 |         layer_stats_all.dense_syn_cnt = 0.0
161 |         layer_stats_all.dense_wgt_cnt = 0.0
162 | 
163 |         for module in self.selected_out:
164 |             stats = module.layer_stats
165 |             layer_stats_all.op_loss += stats.op_loss.to('cuda:0')
166 |             layer_stats_all.mac += stats.mac
167 |             layer_stats_all.sac += stats.sac
168 |             layer_stats_all.mac2 += stats.mac2
169 |             layer_stats_all.sac2 += stats.sac2
170 |             layer_stats_all.bit_cnt += stats.bit_cnt
171 |             layer_stats_all.wgt_cnt += stats.wgt_cnt
172 |             layer_stats_all.act_cnt += stats.act_cnt
173 | 
174 |             layer_stats_all.dense_act_cnt += stats.dense_act_cnt
175 |             layer_stats_all.dense_syn_cnt += stats.dense_syn_cnt
176 |             layer_stats_all.dense_wgt_cnt += stats.dense_wgt_cnt
177 | 
178 |         # Normalize to the full-network dense synapse count (only a subset of layers is hooked during training)
179 |         syn_scale = (np.sum(self.dense_syn_cnt)/layer_stats_all.dense_syn_cnt)
180 |         layer_stats_all.op_loss /= layer_stats_all.dense_syn_cnt
181 |         layer_stats_all.mac *= syn_scale
182 |         layer_stats_all.sac *= syn_scale
183 |         layer_stats_all.mac2 *= syn_scale
184 |         layer_stats_all.sac2 *= syn_scale
185 |         layer_stats_all.avg_bit = layer_stats_all.bit_cnt/layer_stats_all.dense_wgt_cnt
186 |         layer_stats_all.avg_wgt = layer_stats_all.wgt_cnt/layer_stats_all.dense_wgt_cnt
187 |         layer_stats_all.avg_act = layer_stats_all.act_cnt/layer_stats_all.dense_act_cnt
188 |         return layer_stats_all
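The rounding used throughout `core/layers/bconv2d.py` below is a straight-through estimator (STE): round in the forward pass, identity in the backward pass. A one-line demonstration:
```
import torch

x = torch.tensor([0.4, 1.6], requires_grad=True)
y = x + (x.round() - x).detach()  # forward: round(x); backward: gradient passes straight through
y.sum().backward()
print(y.detach(), x.grad)         # tensor([0., 2.]) tensor([1., 1.])
```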
--------------------------------------------------------------------------------
/core/layers/bconv2d.py:
--------------------------------------------------------------------------------
 1 | from easydict import EasyDict as edict
 2 | import torch
 3 | import torch.nn as nn
 4 | import torch.nn.functional as F
 5 | import numpy as np
 6 | from torchvision import transforms as transforms
 7 | import torch.nn.init as init
 8 | import math
 9 | import os
10 | 
11 | # Adapted from https://github.com/hustzxd/LSQuantization/blob/master/lsq.py
12 | def ste_round_(x, add_noise=False):
13 |     eps = (x.round() - x).detach()
14 |     # eps[eps.isnan()] = 0
15 |     if add_noise and False:  # noise injection is disabled
16 |         pm = -eps.sgn()
17 |         noise = pm*torch.randint_like(x, 0, 2)
18 |         return x + eps + noise
19 |     else:
20 |         return x + eps
21 | 
22 | class Round(torch.autograd.Function):
23 |     @staticmethod
24 |     def forward(ctx, x):
25 |         return x.round()
26 | 
27 |     @staticmethod
28 |     def backward(ctx, g):
29 |         return g
30 | 
31 | class Floor(torch.autograd.Function):
32 |     @staticmethod
33 |     def forward(ctx, x):
34 |         return x.floor()
35 | 
36 |     @staticmethod
37 |     def backward(ctx, g):
38 |         return g
39 | 
40 | 
41 | # Adapted from https://github.com/hustzxd/EfficientPyTorch/tree/1fcb533c7bfdafba4aba8272f1e0c34cbde91309
42 | def grad_scale(x, scale):
43 |     y = x
44 |     y_grad = x * scale
45 |     return y.detach() - y_grad.detach() + y_grad
46 | 
47 | def round_pass(x):
48 |     y = x.round()
49 |     y_grad = x
50 |     return y.detach() - y_grad.detach() + y_grad
51 | 
52 | class ActLSQ(nn.Module):
53 |     def __init__(self, nbits=4):
54 |         super(ActLSQ, self).__init__()
55 |         self.nbits = nbits
56 |         if self.nbits == 0:
57 |             self.register_parameter('alpha', None)
58 |             return
59 |         requires_grad = nbits<=8
60 |         self.register_buffer('init_state', torch.zeros(1))
61 |         self.alpha = nn.Parameter(torch.Tensor(1), requires_grad=requires_grad)
62 |         self.itr = 0
63 |         self.running_mean = 0
64 | 
65 |     def forward(self, x):
66 |         if self.alpha is None:
67 |             return x
68 | 
69 |         Qp = 2 ** self.nbits - 1
70 |         if self.init_state == 0:
71 |             self.running_mean += x.abs().mean()  # calibration pass: collect statistics for the LSQ step size
72 |             self.itr += 1
73 |             return x
74 | 
75 |         # return x
76 |         g = 1.0 / math.sqrt(x.numel() * Qp)
77 |         alpha = grad_scale(self.alpha, g)
78 |         x = round_pass((x / alpha).clamp(None, Qp)) * alpha
79 |         return x
80 | 
81 |     def set_init_state(self, init_state):
82 |         Qp = 2 ** self.nbits - 1
83 |         if self.alpha is None:
84 |             return
85 | 
86 |         if
init_state==0: 87 | self.init_state.fill_(0) 88 | self.running_mean=0 89 | elif init_state==1: 90 | self.init_state.fill_(1) 91 | self.alpha.data.copy_(2 * self.running_mean /self.itr / math.sqrt(Qp)) 92 | # print(self.alpha) 93 | else: 94 | pass 95 | 96 | def bit2min(bit): 97 | return -2**((bit-1)) 98 | 99 | def bit2max(bit): 100 | return 2**((bit-1)) 101 | 102 | def binary(x, bit): 103 | mask = 2**torch.arange(bit).to(x.device, x.dtype) 104 | return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte() 105 | 106 | def compute_wgt_vec(bit): 107 | return torch.linspace(bit2min(bit), bit2max(bit), 2**bit+1) 108 | 109 | def compute_bit_cost(bit): 110 | wgt_vec = compute_wgt_vec(bit) 111 | bit_cost = [] 112 | for v_ in wgt_vec: 113 | tmp = binary((v_.abs()).byte(), bit) 114 | bit_cost.append(tmp.sum()) 115 | 116 | bit_cost = torch.stack(bit_cost).float() 117 | 118 | # make dist matrix 119 | bit_cost_m = bit_cost.view([2**bit+1,1]).sub(bit_cost.view([1,2**bit+1])) 120 | return bit_cost, bit_cost_m 121 | 122 | def compute_wgt_cost(bit, pnorm): 123 | wgt_vec = compute_wgt_vec(bit) 124 | wgt_cost = torch.sgn(wgt_vec)*torch.pow(wgt_vec.abs(), pnorm) 125 | wgt_cost_m = -wgt_cost.view([-1,1]).sub(wgt_cost.view([1,-1])).abs() 126 | return wgt_cost, wgt_cost_m 127 | 128 | def compute_cost(bit, pnorm, lamda): 129 | bit_cost, bit_cost_m = compute_bit_cost(bit) 130 | wgt_cost, wgt_cost_m = compute_wgt_cost(bit, pnorm) 131 | 132 | cost = lamda*bit_cost_m + wgt_cost_m 133 | tgt_bin = cost.argmax(dim=1) 134 | 135 | # Copy target value's cost, and add the distance to the 136 | wgt_cost_new = torch.zeros_like(wgt_cost) 137 | wgt_cost_new = -wgt_cost_m[range(len(wgt_cost_new)), tgt_bin] 138 | cost = bit_cost[tgt_bin] + wgt_cost_new 139 | return cost 140 | 141 | def w2cost(w, cost, bit): 142 | w_ = w.add(2**(bit-1)).round() 143 | return cost[w_.long()].float() 144 | 145 | def abs_binalize(x): 146 | return x.abs().gt(0).to(x) 147 | 148 | ######################################################################## 149 | # Bit-Pruning Conv2d 150 | # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 151 | class BConv2d(nn.Module): 152 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=[0,0], dilation=1, groups=1, bias=True, cfg=None, is_input=False): 153 | super(BConv2d, self).__init__() 154 | self.kernel_size = kernel_size 155 | self.padding = padding 156 | self.dilation = dilation 157 | self.groups = groups 158 | self.in_channels = in_channels 159 | self.out_channels = out_channels 160 | self.padding = padding 161 | self.cfg = cfg 162 | 163 | if not isinstance(stride, (tuple, list)): 164 | stride = (stride, stride) 165 | 166 | self.stride = stride 167 | w_init_alg = 'kaiming_uniform' # 'kaiming_uniform', 'kaiming_normal' 168 | self.register_parameter("fweight", nn.Parameter(torch.zeros(out_channels, in_channels//groups, kernel_size[0], kernel_size[1]))) 169 | 170 | if w_init_alg=='kaiming_normal': 171 | init.kaiming_normal_(self.fweight) 172 | elif w_init_alg=='kaiming_uniform': 173 | init.kaiming_uniform_(self.fweight) 174 | 175 | 176 | self.mask_th = 0.0 177 | if bias: 178 | self.register_parameter("bias", nn.Parameter(torch.zeros(out_channels))) 179 | else: 180 | self.bias = None 181 | 182 | self.reg_type = 3 183 | self.hook_count = 0 184 | self.lamda = (self.cfg.optim.lamda_ini+0.001) 185 | 186 | self.thd_neg = bit2min(cfg.model.wgt_bit) 187 | self.thd_pos = bit2max(cfg.model.wgt_bit) 188 | 189 | bit_cost, bit_cost_m = compute_bit_cost(cfg.model.wgt_bit) 190 | wgt_cost, wgt_cost_m = 
compute_wgt_cost(cfg.model.wgt_bit, self.cfg.optim.wgt_p_norm) 191 | 192 | self.register_parameter("bit_cost", nn.Parameter(bit_cost, requires_grad=False)) 193 | self.register_parameter("bit_cost_m", nn.Parameter(bit_cost_m, requires_grad=False)) 194 | self.register_parameter("wgt_cost", nn.Parameter(wgt_cost, requires_grad=False)) 195 | self.register_parameter("wgt_cost_m", nn.Parameter(wgt_cost_m, requires_grad=False)) 196 | 197 | if self.cfg.optim.lamda_ini>0: 198 | cost = compute_cost(self.cfg.model.wgt_bit, self.cfg.optim.wgt_p_norm, self.lamda+0.001) 199 | else: 200 | cost = self.bit_cost 201 | self.register_parameter("cost", nn.Parameter(cost, requires_grad=False)) 202 | 203 | self.act_lsq = ActLSQ(0 if is_input else cfg.model.act_bit) 204 | 205 | self.is_selected = False 206 | self.layer_stats = edict() 207 | self.layer_stats.dense_act_cnt = 1 208 | self.layer_stats.dense_syn_cnt = 1 209 | self.layer_stats.dense_wgt_cnt = torch.numel(self.fweight) 210 | 211 | self.register_parameter("scale", nn.Parameter(torch.ones([out_channels, 1, 1, 1]))) 212 | self.set_scale() 213 | 214 | def set_scale(self, type='absmax'): 215 | with torch.no_grad(): 216 | v_max = self.fweight.abs().flatten(1).max(dim=1)[0] 217 | v_std = 3*self.fweight.flatten(1).std(dim=1) 218 | 219 | if type=='absmax': 220 | scale = v_max.div(self.thd_pos) 221 | else: 222 | scale = torch.minimum(v_max, v_std).div(self.thd_pos) 223 | 224 | self.scale.data.copy_(scale.view([self.out_channels, 1, 1, 1]).data) 225 | 226 | def count_bit(self, x, cost, bit): 227 | hist = torch.histogram(x.cpu(), bins=2**bit+1, range=(bit2min(bit)-0.5, bit2max(bit)+0.5), density=False) 228 | bit_cnt = hist[0].mul(cost.cpu().squeeze()).sum() 229 | return bit_cnt, self.layer_stats.dense_wgt_cnt - hist[0][2**(bit-1)] 230 | 231 | def calibrate_masked_weight(self): 232 | self.fweight.data = self.fweight.data + torch.sign(self.fweight.data)*self.mask_th 233 | 234 | def compute_out_shape(self,x): 235 | b, c, h, w = x.shape 236 | h_ = np.floor((h+2*self.padding[0]-1*(self.kernel_size[0]-1)-1)/self.stride[0]+1)-(self.dilation[0]-1)*2 237 | w_ = np.floor((w+2*self.padding[1]-1*(self.kernel_size[1]-1)-1)/self.stride[1]+1)-(self.dilation[0]-1)*2 238 | 239 | self.view4 = (-1, c*np.prod(self.kernel_size), int(h_), int(w_)) 240 | self.view5 = (-1, c, np.prod(self.kernel_size), int(h_), int(w_)) 241 | return self.view4, self.view5 242 | 243 | def get_dense_syn_cnt(self): 244 | return self.dense_syn_cnt 245 | 246 | def tgt_loss(self, x_unfold, w, mask_act=False, use_correction=False): 247 | tgt = self.compute_prox() 248 | 249 | if use_correction: 250 | with torch.no_grad(): 251 | d_cost = self.compute_dcost() 252 | d = F.softshrink((tgt-w), lambd=0.5).mul(d_cost).div(6).div(256) 253 | else: 254 | d = F.softshrink(tgt-w, lambd=0.5).div(256) 255 | 256 | if mask_act: 257 | L = self.act_spr_loss(x_unfold, d) 258 | else: 259 | L = self.wgt_spr_loss(d) 260 | return L 261 | 262 | 263 | def bilinear_loss(self, x_unfold, w, mask_act=False): 264 | bit = self.cfg.model.wgt_bit 265 | cost = self.cost 266 | 267 | pi = F.pad(torch.tensor(cost).to(w.device).float(), (1,1), mode='constant', value=cost[0]) 268 | 269 | pi = pi.div(pi.sum()) 270 | 271 | w = w.add(1.0 + 2**(bit-1)) 272 | 273 | x_floor = w.floor().long() 274 | v_floor = pi[x_floor] 275 | 276 | x_ceil = w.ceil().long() 277 | v_ceil = pi[x_ceil] 278 | 279 | d_floor = w - w.floor() 280 | d_ceil = (1-d_floor) 281 | 282 | v = v_floor*d_ceil + v_ceil*d_floor 283 | 284 | if mask_act: 285 | L = torch.einsum('oi, bimn->b', 
v.flatten(1), x_unfold).mean() 286 | else: 287 | H, W = x_unfold.shape[-2:] 288 | L = v.sum() 289 | return L 290 | 291 | def act_spr_loss(self, x, weight): 292 | if self.groups>1: 293 | x = x.view(self.view5) 294 | else: 295 | x = x.view(self.view4) 296 | # d = d.view([out_shape[0], -1, out_shape[-2]* out_shape[-1]]) 297 | weight = weight.flatten(1) 298 | if self.reg_type==3: 299 | # Square-Hoyer 300 | if self.groups>1: 301 | sum1 = torch.einsum('oi, boimn->b', weight.abs(), x) 302 | sum2 = torch.einsum('oi, boimn->b', weight.abs(), x**2) 303 | else: 304 | sum1 = torch.einsum('oi, bimn->b', weight.abs(), x) 305 | sum2 = torch.einsum('oi, bimn->b', weight.abs(), x**2) 306 | dss = (sum1**2)/(sum2+1e-06) 307 | elif self.reg_type==4: 308 | # Group-Square-Hoyer 309 | sum1 = torch.einsum('oi, bimn->bi', weight.abs(), x) 310 | sum2 = torch.einsum('oi, bimn->bi', weight.abs(), x**2) 311 | dss = ((sum1**2)/(sum2+1e-06)).sum(1) 312 | elif self.reg_type==1: 313 | # 1 norm 314 | dss = torch.einsum('oi, bimn->b', weight.abs(), x) 315 | return dss.mean() 316 | 317 | def calc_l1_and_zero_ratio(self, weights, scale): 318 | x = Round.apply(weights.abs() / 2 ** (scale - 8)) 319 | 320 | b1 = Floor.apply(x/64) 321 | b2 = Floor.apply((x-b1.detach()*64)/16) 322 | b3 = Floor.apply((x-b1.detach()*64-b2.detach()*16)/4) 323 | b4 = x-b1.detach()*64-b2.detach()*16-b3.detach()*4 324 | 325 | l1_norm = b1.abs().sum() + b2.abs().sum() + b3.abs().sum() + b4.abs().sum() 326 | return l1_norm 327 | 328 | def wgt_spr_loss(self, weight): 329 | # d = d.view([out_shape[0], -1, out_shape[-2]* out_shape[-1]]) 330 | if self.reg_type==3: 331 | # Square-Hoyer 332 | sum1 = weight.abs().flatten(0).sum(dim=0) 333 | sum2 = (weight**2).flatten(0).sum(dim=0) 334 | L = (sum1**2)/(sum2+1e-06) 335 | elif self.reg_type==4: 336 | # Group-Square-Hoyer 337 | sum1 = weight.abs().flatten(1).sum(dim=1) 338 | sum2 = (weight**2).flatten(1).sum(dim=1) 339 | L = (sum1**2)/(sum2+1e-06) 340 | L = L.sum() 341 | elif self.reg_type==1: 342 | # 1 norm 343 | L = weight.abs().sum() 344 | return L.mean() 345 | 346 | def compute_spr_loss(self, x): 347 | if not self.is_conv: 348 | x = x.permute([0,3,1,2]) 349 | 350 | x = self.act_lsq(x) 351 | # x = x.relu() 352 | 353 | self.compute_out_shape(x) 354 | weight_s = self.scale_w(self.fweight) 355 | weight_r = self.round_w(weight_s, prep=True) 356 | weight_q = self.qunatize_w(weight_r, prep=True) 357 | weight_b = abs_binalize(weight_r) # no-grad 358 | 359 | # Unfold 360 | if self.is_conv: 361 | I_unfold = F.unfold(x, kernel_size=self.kernel_size, padding=self.padding, stride=self.stride, dilation=self.dilation).view(self.view4) 362 | else: 363 | I_unfold = x 364 | 365 | bit_cnt, wgt_cnt = self.count_bit(weight_s, self.bit_cost, self.cfg.model.wgt_bit) 366 | 367 | space_scale = np.prod(self.view4[-2:])/self.groups 368 | 369 | op_loss = torch.tensor([0.0]).to(x) 370 | if self.cfg.optim.loss_type[:3]=='act': 371 | if self.cfg.optim.loss_type in ['act_bilinear']: 372 | op_loss = self.bilinear_loss(I_unfold, weight_s, mask_act=True) 373 | op_loss = op_loss.mul(32) # magic number to make the loss comparable to other losses 374 | elif self.cfg.optim.loss_type in ['act_tgt']: 375 | op_loss = self.tgt_loss(I_unfold, weight_s, mask_act=True, use_correction=self.cfg.optim.use_correction) 376 | op_loss = op_loss.mul(256) # magic number to make the loss comparable to other losses 377 | elif self.cfg.optim.loss_type in ['act_slice']: 378 | op_loss = self.calc_l1_and_zero_ratio(self.fweight, self.scale) 379 | op_loss = 
op_loss.mul(256) # magic number to make the loss comparable to other losses 380 | elif self.cfg.optim.loss_type in ['act_naive']: 381 | op_loss = self.act_spr_loss(I_unfold, weight_q) 382 | else: 383 | os.error('unsupported loss_type ' + self.cfg.optim.loss_type) 384 | elif self.cfg.optim.loss_type[:3]=='wgt': 385 | if self.cfg.optim.loss_type in ['wgt_bilinear']: 386 | op_loss = self.bilinear_loss(I_unfold, weight_s, mask_act=False) 387 | op_loss = op_loss.mul(space_scale).mul(32) # magic number to make the loss comparable to other losses 388 | elif self.cfg.optim.loss_type in ['wgt_tgt']: 389 | op_loss = self.tgt_loss(I_unfold, weight_s, mask_act=False).mul(space_scale) 390 | op_loss = op_loss.mul(space_scale).mul(4) # magic number to make the loss comparable to other losses 391 | elif self.cfg.optim.loss_type in ['wgt_slice']: 392 | op_loss = self.calc_l1_and_zero_ratio(self.fweight, self.scale) 393 | op_loss = op_loss.mul(256) # magic number to make the loss comparable to other losses 394 | elif self.cfg.optim.loss_type in ['wgt_naive']: 395 | op_loss = self.wgt_spr_loss(weight_q).mul(space_scale) 396 | else: 397 | os.error('unsupported loss_type ' + self.cfg.optim.loss_type) 398 | else: 399 | os.error('unsupported loss_type ' + self.cfg.optim.loss_type) 400 | 401 | with torch.no_grad(): 402 | w_cost = w2cost(weight_r, self.bit_cost.to(x.device), self.cfg.model.wgt_bit) 403 | if self.groups>1: 404 | mac = torch.einsum('oi, boimn->b', weight_b.flatten(1), (I_unfold.view(self.view5)!=0).to(x)).mean() 405 | sac = torch.einsum('oi, boimn->b', w_cost.flatten(1), (I_unfold.view(self.view5)!=0).to(x)).mean() 406 | else: 407 | mac = torch.einsum('oi, bimn->b', weight_b.flatten(1), (I_unfold!=0).to(x)).mean() 408 | sac = torch.einsum('oi, bimn->b', w_cost.flatten(1), (I_unfold!=0).to(x)).mean() 409 | mac2 = wgt_cnt.mul(space_scale) 410 | sac2 = bit_cnt.mul(space_scale) 411 | act_cnt = torch.count_nonzero(x)/x.shape[0] 412 | 413 | assert(~torch.isnan(op_loss)) 414 | self.layer_stats.op_loss = op_loss 415 | self.layer_stats.mac = mac.item() 416 | self.layer_stats.sac = sac.item() 417 | self.layer_stats.mac2 = mac2.item() # weight sparsity only 418 | self.layer_stats.sac2 = sac2.item() # weight sparsity only 419 | self.layer_stats.bit_cnt = bit_cnt.item() 420 | self.layer_stats.wgt_cnt = wgt_cnt.item() 421 | self.layer_stats.act_cnt = act_cnt.item() 422 | 423 | # https://pmelchior.net/blog/proximal-matrix-factorization-in-pytorch.html 424 | def solve_argmin(self, lamda=1.001): 425 | cost = lamda*self.bit_cost_m + self.wgt_cost_m 426 | tgt_bin = cost.argmax(dim=1) 427 | return tgt_bin 428 | 429 | def compute_prox(self): 430 | min_idx = self.solve_argmin(self.lamda) 431 | rw = self.round_w((self.fweight)).sub(self.thd_neg).long() 432 | return min_idx[rw].float().add(self.thd_neg) 433 | 434 | def compute_dcost(self): 435 | min_idx = self.solve_argmin(self.lamda) 436 | rw = self.round_w((self.fweight)).sub(self.thd_neg).long() 437 | d_cost = self.bit_cost[rw] - self.bit_cost[min_idx][rw] 438 | return d_cost.float() 439 | 440 | 441 | def project_prox(self, x): 442 | with torch.no_grad(): 443 | # Project to proximal point 444 | fweight_ = self.compute_prox().mul(self.scale) 445 | output_prox = F.conv2d(x, fweight_, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation) # 'zeros' 446 | output_org = F.conv2d(x, self.qunatize_w(self.fweight), bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation) # 'zeros' 447 | 448 | eps = (output_prox - 
output_org).detach() 449 | return output_org + eps 450 | 451 | # LSQ Qauntization 452 | def qunatize_w(self,w, prep=False): 453 | if prep: 454 | return w.mul(self.scale) 455 | else: 456 | return self.round_w(w).mul(self.scale) 457 | def round_w(self,w, prep=False): 458 | if prep: 459 | return ste_round_(w) 460 | else: 461 | return ste_round_(self.scale_w(w)) 462 | def scale_w(self,w): 463 | return torch.clamp(w.div(self.scale), self.thd_neg-0.5, self.thd_pos+0.5) 464 | 465 | 466 | def forward(self, x): 467 | x = self.act_lsq(x) 468 | self.layer_stats.dense_act_cnt=np.prod([*x.shape[1:]]) 469 | if self.is_conv: 470 | self.layer_stats.dense_syn_cnt=np.prod([*self.fweight.shape])*np.prod([*x.shape[2:]])/(np.prod(self.stride))/(self.groups) 471 | else: 472 | self.layer_stats.dense_syn_cnt=np.prod([*self.fweight.shape])*np.prod([*x.shape[1:3]])/(np.prod(self.stride)) 473 | 474 | # Experimental code (Project to proximal point in forward pass) 475 | if self.is_selected and self.cfg.optim.loss_type in ['wgt_prox']: 476 | return self.project_prox(x) 477 | 478 | if self.cfg.optim.loss_type[:4]=='fp32': 479 | fweight_ = self.fweight 480 | else: 481 | fweight_ = self.qunatize_w(self.fweight) 482 | 483 | # Main convolution 484 | if self.is_conv: 485 | output = F.conv2d(x, fweight_, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) 486 | else: 487 | output = F.linear(x, fweight_[:,:,0,0], bias=self.bias) 488 | 489 | return output -------------------------------------------------------------------------------- /core/model_converter.py: -------------------------------------------------------------------------------- 1 | from logging import error 2 | import torch.nn as nn 3 | from core.layers.bconv2d import BConv2d 4 | from timm.models.layers import trunc_normal_, DropPath 5 | 6 | # Ref 7 | # https://discuss.pytorch.org/t/how-can-i-replace-an-intermediate-layer-in-a-pre-trained-network/3586/5 8 | 9 | 10 | def convert_layers(model, cfg): 11 | conversion_count = 0 12 | for name, module in reversed(model._modules.items()): 13 | if module is not None and len(list(module.children())) > 0 and type(module) != nn.Conv2d: 14 | model._modules[name], num_converted = convert_layers(module, cfg) 15 | 16 | if type(module) == nn.Conv2d: 17 | in_channels, out_channels = module.in_channels, module.out_channels 18 | out_channels = out_channels//cfg.model.width 19 | if not hasattr(module, 'is_input'): 20 | in_channels = in_channels//cfg.model.width 21 | module_new = BConv2d(in_channels, out_channels, module.kernel_size, module.stride, 22 | module.padding, module.dilation, module.groups, 23 | module.bias is not None, cfg, is_input = hasattr(module, 'is_input')) 24 | 25 | 26 | module_new.is_conv = True 27 | if cfg.model.pretrained==1: 28 | # module_new.fweight = module.weight 29 | # module_new.bias = module.bias 30 | module_new.fweight.data.copy_(module.weight.data) 31 | module_new.bias.data.copy_(module.bias.data) 32 | # module_new.calibrate_masked_weight() 33 | module_new.set_scale(type='absmax') 34 | 35 | # module_new.loss_type = loss_type 36 | model._modules[name] = module_new 37 | 38 | conversion_count += 1 39 | 40 | if type(module) == nn.BatchNorm2d: 41 | # print(module) 42 | num_features = module.num_features 43 | num_features = num_features//cfg.model.width 44 | model._modules[name] = nn.BatchNorm2d(num_features) 45 | conversion_count += 1 46 | 47 | if type(module) == nn.Linear: 48 | if module.out_features==1000 or module.out_features==100 or module.out_features==10: 
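# The classification head (out_features matching the 10/100/1000 class counts)
# is rebuilt as a plain nn.Linear below; every other Linear layer is converted
# to a 1x1 BConv2d so that it participates in bit-pruning as well.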
49 | in_features, out_features = module.in_features, module.out_features 50 | # print(out_features) 51 | in_features = in_features//cfg.model.width 52 | module_new = nn.Linear(in_features, out_features) 53 | else: 54 | module_new = BConv2d(module.in_features, module.out_features, [1,1], [1,1], 55 | [0,0], [1,1], 1, 56 | module.bias is not None, cfg, is_input = False) 57 | 58 | module_new.is_conv = False 59 | if cfg.model.pretrained==1: 60 | module_new.fweight.data.copy_(module.weight[:,:,None,None].data) 61 | module_new.bias.data.copy_(module.bias.data) 62 | module_new.set_scale(type='absmax') 63 | # module_new.calibrate_masked_weight() 64 | 65 | # module_new.name = module.name 66 | 67 | model._modules[name] = module_new 68 | 69 | conversion_count += 1 70 | 71 | if hasattr(module, "drop_path"): 72 | module.drop_path = DropPath(cfg.optim.drop_path) if cfg.optim.drop_path > 0. else nn.Identity() 73 | 74 | if type(module) == nn.GELU: 75 | model._modules[name] = nn.ReLU() 76 | 77 | 78 | 79 | return model, conversion_count 80 | 81 | if __name__ == '__main__': 82 | print('__main__') -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from torchvision import datasets, transforms 4 | from timm.data.constants import \ 5 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 6 | from timm.data import create_transform 7 | import torch.distributed as dist 8 | 9 | def is_dist_avail_and_initialized(): 10 | if not dist.is_available(): 11 | return False 12 | if not dist.is_initialized(): 13 | return False 14 | return True 15 | 16 | 17 | def get_world_size(): 18 | if not is_dist_avail_and_initialized(): 19 | return 1 20 | return dist.get_world_size() 21 | 22 | 23 | def get_rank(): 24 | if not is_dist_avail_and_initialized(): 25 | return 0 26 | return dist.get_rank() 27 | 28 | 29 | from pl_bolts.transforms.dataset_normalizations import ( 30 | cifar10_normalization, 31 | imagenet_normalization, 32 | ) 33 | 34 | train_transforms_cifar = transforms.Compose( 35 | [ 36 | transforms.RandomCrop(32, padding=4), 37 | transforms.RandomHorizontalFlip(), 38 | transforms.ToTensor(), 39 | cifar10_normalization(), 40 | ] 41 | ) 42 | 43 | test_transforms_cifar = transforms.Compose( 44 | [ 45 | transforms.ToTensor(), 46 | cifar10_normalization(), 47 | ] 48 | ) 49 | 50 | def build_dataset(is_train, cfg): 51 | if cfg.dataset.name == 'CIFAR100': 52 | transform = build_transform(is_train, cfg) 53 | dataset = datasets.CIFAR100(cfg.dataset.path, train=is_train, transform=transform, download=True) 54 | nb_classes = 100 55 | elif cfg.dataset.name == 'cifar100': 56 | if is_train: 57 | transform = train_transforms_cifar 58 | else: 59 | transform = test_transforms_cifar 60 | dataset = datasets.CIFAR100(cfg.dataset.path, train=is_train, transform=transform, download=True) 61 | nb_classes = 100 62 | elif cfg.dataset.name == 'CIFAR10': 63 | transform = build_transform(is_train, cfg) 64 | dataset = datasets.CIFAR10(cfg.dataset.path, train=is_train, transform=transform, download=True) 65 | nb_classes = 10 66 | elif cfg.dataset.name == 'cifar10': 67 | if is_train: 68 | transform = train_transforms_cifar 69 | else: 70 | transform = test_transforms_cifar 71 | dataset = datasets.CIFAR10(cfg.dataset.path, train=is_train, transform=transform, download=True) 72 | nb_classes = 10 73 | elif cfg.dataset.name == 'IMNET': 74 | print("reading from datapath", 
cfg.dataset.path) 75 | transform = build_transform(is_train, cfg) 76 | root = os.path.join(cfg.dataset.path, 'train' if is_train else 'val') 77 | dataset = datasets.ImageFolder(root, transform=transform) 78 | nb_classes = 1000 79 | elif cfg.dataset.name == "image_folder": 80 | root = cfg.dataset.path if is_train else cfg.dataset.eval_data_path 81 | dataset = datasets.ImageFolder(root, transform=transform) 82 | nb_classes = cfg.dataset.nb_classes 83 | assert len(dataset.class_to_idx) == nb_classes 84 | else: 85 | raise NotImplementedError() 86 | 87 | print("Transform = ") 88 | if isinstance(transform, tuple): 89 | for trans in transform: 90 | print(" - - - - - - - - - - ") 91 | for t in trans.transforms: 92 | print(t) 93 | else: 94 | for t in transform.transforms: 95 | print(t) 96 | print("---------------------------") 97 | 98 | print("Number of the class = %d" % nb_classes) 99 | 100 | return dataset, nb_classes 101 | 102 | # adopted from https://github.com/facebookresearch/ConvNeXt/blob/33440594b4221b713d493ce11f33b939c4afd696/datasets.py 103 | def build_transform(is_train, cfg): 104 | imagenet_default_mean_and_std = cfg.dataset.imagenet_default_mean_and_std 105 | resize_im = cfg.dataset.input_size > 32 106 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 107 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 108 | 109 | if is_train: 110 | # this should always dispatch to transforms_imagenet_train 111 | transform = create_transform( 112 | input_size=cfg.dataset.input_size, 113 | is_training=True, 114 | color_jitter=cfg.dataset.color_jitter, 115 | auto_augment=cfg.dataset.aa, 116 | interpolation=cfg.dataset.train_interpolation, 117 | re_prob=cfg.dataset.reprob, 118 | re_mode=cfg.dataset.remode, 119 | re_count=cfg.dataset.recount, 120 | mean=mean, 121 | std=std, 122 | ) 123 | if not resize_im: 124 | transform.transforms[0] = transforms.RandomCrop( 125 | cfg.dataset.input_size, padding=4) 126 | return transform 127 | 128 | t = [] 129 | if resize_im: 130 | # warping (no cropping) when evaluated at 384 or larger 131 | if cfg.dataset.input_size >= 384: 132 | t.append( 133 | transforms.Resize((cfg.dataset.input_size, cfg.dataset.input_size), 134 | interpolation=transforms.InterpolationMode.BICUBIC), 135 | ) 136 | print(f"Warping {cfg.dataset.input_size} size input images...") 137 | else: 138 | if cfg.dataset.crop_pct<=0: 139 | cfg.dataset.crop_pct = 224 / 256 140 | size = int(cfg.dataset.input_size / cfg.dataset.crop_pct) 141 | t.append( 142 | # to maintain same ratio w.r.t. 
224 images 143 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), 144 | ) 145 | t.append(transforms.CenterCrop(cfg.dataset.input_size)) 146 | 147 | t.append(transforms.ToTensor()) 148 | t.append(transforms.Normalize(mean, std)) 149 | return transforms.Compose(t) 150 | 151 | 152 | def getLoader(cfg): 153 | dataset_train, nb_classes = build_dataset(True, cfg) 154 | dataset_val, nb_classes = build_dataset(False, cfg) 155 | cfg.dataset.input_shape = [1, 3, cfg.dataset.input_size, cfg.dataset.input_size] 156 | 157 | num_tasks = get_world_size() 158 | global_rank = get_rank() 159 | sampler_train = torch.utils.data.DistributedSampler( 160 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=cfg.misc.seed, 161 | ) 162 | data_loader_train = torch.utils.data.DataLoader( 163 | dataset_train, sampler=sampler_train, 164 | batch_size=cfg.optim.batch_size, 165 | num_workers=cfg.hardware.num_cpu_workers, 166 | pin_memory=cfg.dataset.pin_mem, 167 | drop_last=True, 168 | ) 169 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 170 | data_loader_val = torch.utils.data.DataLoader( 171 | dataset_val, sampler=sampler_val, 172 | batch_size=cfg.optim.batch_size, 173 | num_workers=cfg.hardware.num_cpu_workers, 174 | pin_memory=cfg.dataset.pin_mem, 175 | drop_last=False 176 | ) 177 | return data_loader_train, data_loader_val 178 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | import numpy as np 3 | import pickle 4 | import os 5 | import torch 6 | 7 | class AverageMeter(object): 8 | """Computes and stores the average and current value""" 9 | def __init__(self): 10 | self.initialized = False 11 | self.val = None 12 | self.avg = None 13 | self.sum = None 14 | self.count = None 15 | 16 | def initialize(self, val, weight): 17 | self.val = val 18 | self.avg = val 19 | self.sum = val * weight 20 | self.count = weight 21 | self.initialized = True 22 | 23 | def update(self, val, weight=1): 24 | if not self.initialized: 25 | self.initialize(val, weight) 26 | else: 27 | self.add(val, weight) 28 | 29 | def add(self, val, weight): 30 | self.val = val 31 | self.sum += val * weight 32 | self.count += weight 33 | self.avg = self.sum / self.count 34 | 35 | def value(self): 36 | return self.val 37 | 38 | def average(self): 39 | return self.avg 40 | 41 | class STATS(): 42 | def __init__(self, config): 43 | self.cfg = config 44 | log = edict() 45 | log.train = edict() 46 | log.val = edict() 47 | log.test = edict() 48 | 49 | 50 | def init_loss(e_obj): 51 | e_obj.ce_loss = np.zeros(config.optim.epochs) 52 | e_obj.acc = np.zeros(config.optim.epochs) 53 | e_obj.acc5 = np.zeros(config.optim.epochs) 54 | e_obj.mac = np.zeros(config.optim.epochs) 55 | e_obj.sac = np.zeros(config.optim.epochs) 56 | e_obj.mac2 = np.zeros(config.optim.epochs) 57 | e_obj.sac2 = np.zeros(config.optim.epochs) 58 | e_obj.op_loss = np.zeros(config.optim.epochs) 59 | e_obj.avg_wgt = np.zeros(config.optim.epochs) 60 | e_obj.avg_act = np.zeros(config.optim.epochs) 61 | e_obj.avg_bit = np.zeros(config.optim.epochs) 62 | e_obj.lr = np.zeros(config.optim.epochs) 63 | init_loss(log.train) 64 | init_loss(log.val) 65 | init_loss(log.test) 66 | 67 | log.best = edict() 68 | log.best.acc = 0.0 69 | log.best.ce_loss = 0.0 70 | log.best.itr = 0 71 | log.best.epoch = -1 72 | 73 | self.log = log 74 | self.epoch = 0 75 | 76 | def init_meter(self, 
phase): 77 | self.phase = phase 78 | if phase=='train': 79 | print("<<<<<<<<< start epoch{:3d} {} {}".format(self.epoch, self.cfg.dataset.name, self.cfg.cfg_name + ' ' +self.cfg.file_name)) 80 | 81 | self.ce_loss_meter = AverageMeter() 82 | self.acc_meter = AverageMeter() 83 | self.acc5_meter = AverageMeter() 84 | self.mac_meter = AverageMeter() 85 | self.sac_meter = AverageMeter() 86 | self.mac2_meter = AverageMeter() 87 | self.sac2_meter = AverageMeter() 88 | self.op_loss_meter = AverageMeter() 89 | self.avg_wgt_meter = AverageMeter() 90 | self.avg_act_meter = AverageMeter() 91 | self.avg_bit_meter = AverageMeter() 92 | self.ce_loss_meter = AverageMeter() 93 | self.lr_meter = AverageMeter() 94 | 95 | def update_meter(self, ce_loss, outputs, labels, lr, layer_stats): 96 | acc1, acc5 = self.__accuracy(outputs, labels, topk=(1, 5)) 97 | 98 | self.ce_loss_meter.update(ce_loss.item()) 99 | self.acc_meter.update(acc1.item()) 100 | self.acc5_meter.update(acc5.item()) 101 | self.op_loss_meter.update(layer_stats.op_loss.item()) 102 | self.mac_meter.update(layer_stats.mac) 103 | self.sac_meter.update(layer_stats.sac) 104 | self.mac2_meter.update(layer_stats.mac2) 105 | self.sac2_meter.update(layer_stats.sac2) 106 | self.avg_wgt_meter.update(layer_stats.avg_wgt) 107 | self.avg_act_meter.update(layer_stats.avg_act) 108 | self.avg_bit_meter.update(layer_stats.avg_bit) 109 | self.lr_meter.update(lr) 110 | 111 | 112 | def save(self, logname='log2'): 113 | self.log[self.phase].ce_loss[self.epoch] = self.ce_loss_meter.average() 114 | self.log[self.phase].acc[self.epoch] = self.acc_meter.average() 115 | self.log[self.phase].acc5[self.epoch] = self.acc5_meter.average() 116 | self.log[self.phase].mac[self.epoch] = self.mac_meter.average() 117 | self.log[self.phase].sac[self.epoch] = self.sac_meter.average() 118 | self.log[self.phase].mac2[self.epoch] = self.mac2_meter.average() 119 | self.log[self.phase].sac2[self.epoch] = self.sac2_meter.average() 120 | self.log[self.phase].op_loss[self.epoch] = self.op_loss_meter.average() 121 | self.log[self.phase].avg_wgt[self.epoch] = self.avg_wgt_meter.average() 122 | self.log[self.phase].avg_act[self.epoch] = self.avg_act_meter.average() 123 | self.log[self.phase].avg_bit[self.epoch] = self.avg_bit_meter.average() 124 | self.log[self.phase].lr[self.epoch] = self.lr_meter.average() 125 | 126 | 127 | if self.phase=='val' and self.log[self.phase].acc[self.epoch]>self.log.best.acc: 128 | self.log.best.acc = self.log[self.phase].acc[self.epoch] 129 | self.log.best.epoch = self.epoch 130 | print(" Best updated! 
epoch{:3d}, acc: {:0.3f}, ".format(self.log.best.epoch, self.log.best.acc)) 131 | 132 | with open(os.path.join(self.cfg.stats_dir, logname), 'wb') as handle: 133 | pickle.dump(self.log, handle, protocol=pickle.HIGHEST_PROTOCOL) 134 | 135 | # print("finished epoch{:3d} {}>>>>>>>>>>".format(self.epoch, self.phase,)) 136 | 137 | def disp(self): 138 | print_scale_op_loss = 1e-3 139 | print_scale_mac = 1e9 140 | 141 | print(" {}, lr:{:0.5f} epoch{:3d}, acc: {:0.3f}, acc5: {:0.3f}, ce: {:2.3f}, spr: {:2.3f}, mac: {:2.3f}, {:2.3f}, sac: {:2.3f},{:2.3f} avg_bit: {:0.4f}, avg_wgt: {:0.4f}, avg_act: {:0.4f}".format( 142 | self.phase, 143 | self.lr_meter.average(), 144 | self.epoch, 145 | self.acc_meter.average(), 146 | self.acc5_meter.average(), 147 | self.ce_loss_meter.average(), 148 | self.op_loss_meter.average()/print_scale_op_loss, 149 | self.mac_meter.average()/print_scale_mac, 150 | self.mac2_meter.average()/print_scale_mac, 151 | self.sac_meter.average()/print_scale_mac, 152 | self.sac2_meter.average()/print_scale_mac, 153 | self.avg_bit_meter.average(), 154 | self.avg_wgt_meter.average(), 155 | self.avg_act_meter.average()) 156 | ) 157 | 158 | 159 | @staticmethod 160 | def __accuracy(output, target, topk=(1,)): 161 | """Computes the accuracy over the k top predictions for the specified values of k.""" 162 | with torch.no_grad(): 163 | maxk = max(topk) 164 | batch_size = target.size(0) 165 | 166 | _, pred = output.topk(maxk, 1, True, True) 167 | pred = pred.t() 168 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 169 | 170 | res = [] 171 | for k in topk: 172 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 173 | res.append(correct_k.mul_(1.0 / batch_size)) 174 | return res -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torch.optim import lr_scheduler 5 | import tqdm 6 | import torch.nn as nn 7 | from logger import STATS 8 | from core.bnet import BNet 9 | from model import get_model 10 | import argconfig 11 | import dataset 12 | import timm 13 | 14 | def extract_batch(batch): 15 | x, y = batch 16 | x, y = x.to('cuda'), y.to('cuda') 17 | return x, y 18 | 19 | def main(): 20 | cfg = argconfig.load() 21 | train_loader, val_loader = dataset.getLoader(cfg) 22 | cfg.freeze() 23 | 24 | model = get_model(cfg) 25 | net = BNet(model, cfg) 26 | 27 | # Setup Optimizer 28 | wd0_para = [] 29 | wd1_para = [] 30 | for name, value in net.model.named_parameters(): 31 | if "scale" in name: 32 | wd0_para += [value] 33 | elif 'alpha' in name: 34 | wd0_para += [value] 35 | elif "fweight" in name: 36 | if cfg.optim.enable_decay: 37 | wd1_para += [value] 38 | else: 39 | wd0_para += [value] 40 | else: 41 | wd1_para += [value] 42 | 43 | if cfg.optim.optimizer=='SGD': 44 | optimizer = torch.optim.SGD(wd1_para,lr=cfg.optim.lr_core, momentum=cfg.optim.momentum, weight_decay=cfg.optim.weight_decay) 45 | elif cfg.optim.optimizer=='AdamW': 46 | optimizer = optim.AdamW(wd1_para, lr=cfg.optim.lr_core, weight_decay=cfg.optim.weight_decay) 47 | 48 | optimizer.add_param_group({"params": wd0_para, 'lr':cfg.optim.lr_mask,'weight_decay':0.0}) # BASE_LR: 0.00004, ADAMW 0.001 49 | 50 | if cfg.optim.scheduler=='OneCycleLR': 51 | steps_per_epoch = len(train_loader) 52 | scheduler = lr_scheduler.OneCycleLR(optimizer,[cfg.optim.lr_core, cfg.optim.lr_mask], epochs=cfg.optim.epochs, 
steps_per_epoch=steps_per_epoch) 53 | elif cfg.optim.scheduler=='ExponentialLR': 54 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=cfg.optim.gamma) 55 | elif cfg.optim.scheduler=='CosineAnnealingLR': 56 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.optim.T_max) 57 | elif cfg.optim.scheduler=='CosineDecay': 58 | scheduler = timm.scheduler.CosineLRScheduler(optimizer, t_initial=cfg.optim.epochs, warmup_lr_init=cfg.optim.lr_core/1e3, warmup_t=10) 59 | 60 | if cfg.model.pretrained==2 and cfg.optim.spr_w>0: 61 | pretrain_path = cfg.model.pretrain_path 62 | net.load_state_dict(torch.load(pretrain_path), strict=True) 63 | print('loading '+ pretrain_path) 64 | # Move to GPU 65 | if torch.cuda.is_available(): 66 | net = net.cuda() 67 | num_gpu = list(range(torch.cuda.device_count())) 68 | 69 | stats = STATS(cfg) 70 | if cfg.optim.smoothing>0: 71 | criterion = timm.loss.LabelSmoothingCrossEntropy(smoothing=cfg.optim.smoothing) 72 | else: 73 | criterion = nn.CrossEntropyLoss() 74 | 75 | 76 | # Dummy forward for initializing LSQ, We do NOT need this when using pre-trained model from same bit 77 | if (cfg.model.pretrained==2 and cfg.optim.spr_w>0) or cfg.model.act_bit<=0: 78 | net.set_init_state(-1) 79 | else: 80 | net.set_init_state(0) 81 | dummy(cfg, stats, criterion, net, optimizer, scheduler, train_loader) 82 | net.set_init_state(1) 83 | 84 | # Parallel GPUs 85 | net = torch.nn.DataParallel(net, device_ids=cfg.hardware.gpu_device) 86 | 87 | for epoch in range(cfg.optim.epochs): 88 | stats.epoch = epoch 89 | stats.init_meter('train') 90 | train(cfg, stats, criterion, net, optimizer, scheduler, train_loader) 91 | stats.save(cfg.misc.log_name) 92 | stats.disp() 93 | 94 | stats.init_meter('val') 95 | evaluate(cfg, stats, criterion, net, optimizer, scheduler, val_loader) 96 | stats.save() 97 | stats.disp() 98 | 99 | if stats.epoch% 10 == 0: 100 | torch.save(net.module.state_dict(), cfg.model_path_final) 101 | if stats.epoch==stats.log.best.epoch: 102 | torch.save(net.module.state_dict(), cfg.model_path_best) 103 | 104 | def dummy(cfg, stats, criterion, net, optimizer, scheduler, loader): 105 | net.train() 106 | net.reset_hook() 107 | with torch.no_grad(): 108 | for iter, batch in tqdm.tqdm(enumerate(loader)): 109 | x, _ = extract_batch(batch) 110 | _ = net(x) 111 | 112 | 113 | def train(cfg, stats, criterion, net, optimizer, scheduler, loader): 114 | net.train() 115 | 116 | lamda_c = cfg.optim.lamda_ini 117 | net.module.set_lamda(lamda_c) 118 | 119 | for iter, batch in tqdm.tqdm(enumerate(loader)): 120 | optimizer.zero_grad() 121 | x, labels = extract_batch(batch) 122 | 123 | net.module.register_rand_hook(3) 124 | 125 | outputs = net(x) 126 | ce_loss = criterion(outputs, labels) 127 | 128 | 129 | layer_stats = net.module.aggregate_dss() 130 | stats.update_meter(ce_loss, outputs, labels, optimizer.param_groups[0]['lr'], layer_stats) 131 | 132 | if cfg.optim.spr_w==0: 133 | loss = ce_loss 134 | else: 135 | loss = ce_loss + cfg.optim.spr_w*layer_stats.op_loss.to('cuda:0') 136 | loss.backward() 137 | 138 | if iter % 16 == 0: 139 | stats.disp() 140 | 141 | optimizer.step() 142 | if isinstance(scheduler, lr_scheduler.OneCycleLR): 143 | scheduler.step() 144 | 145 | 146 | if not isinstance(scheduler, lr_scheduler.OneCycleLR): 147 | if isinstance(scheduler, timm.scheduler.CosineLRScheduler): 148 | scheduler.step(stats.epoch) 149 | else: 150 | scheduler.step() 151 | 152 | 153 | def evaluate(cfg, stats, criterion, net, optimizer, scheduler, loader): 154 | net.eval() 155 | for iter, 
/model.py:
--------------------------------------------------------------------------------
1 | import torchvision
2 | import torch.nn as nn
3 | 
4 | import torchvision.models as M
5 | 
6 | def get_model(cfg):
7 |     pretrained = cfg.model.pretrained==1
8 |     if cfg.dataset.name in ['CIFAR10', 'cifar10']:
9 |         return create_cifar_model(pretrained=pretrained, num_classes=10)
10 |     elif cfg.dataset.name in ['CIFAR100', 'cifar100']:
11 |         return create_cifar_model(pretrained=pretrained, num_classes=100)
12 |     elif cfg.dataset.name == 'IMNET':
13 |         return create_imagenet_model(pretrained=pretrained)
14 | 
15 | def create_cifar_model(pretrained, num_classes):
16 |     if pretrained:
17 |         model = torchvision.models.resnet18(pretrained=pretrained)
18 |         model.fc = nn.Linear(512, num_classes)
19 |     else:
20 |         model = torchvision.models.resnet18(pretrained=False, num_classes=num_classes)
21 |     model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)  # CIFAR-sized 3x3 stem
22 |     model.conv1.is_input = True
23 |     model.maxpool = nn.Identity()  # no down-sampling for 32x32 inputs
24 |     return model
25 | 
26 | def create_imagenet_model(pretrained):
27 |     if pretrained:
28 |         model = torchvision.models.convnext_base(weights=M.ConvNeXt_Base_Weights.DEFAULT)
29 |     else:
30 |         model = torchvision.models.convnext_base()
31 |     model.features[0][0].is_input = True
32 | 
33 |     return model
--------------------------------------------------------------------------------
/run_cifar.sh:
--------------------------------------------------------------------------------
1 | echo "----run_cifar----";
2 | echo "loss_type: $1"
3 | echo "lamda_ini: $2"
4 | echo "wgt_bit: $3"
5 | echo "act_bit: $4"
6 | echo "config: $5"
7 | 
8 | 
9 | for i in 0 1 2 4 8 16 32 64 128 256 512 1024 2048 4096 8192 16384 32768
10 | do
11 |     echo "spr_weight: $i"
12 |     python main.py --config $5 optim.spr_w $i optim.loss_type $1 optim.lamda_ini $2 model.wgt_bit $3 model.act_bit $4
13 | done
14 | 
15 | echo "----done---";
--------------------------------------------------------------------------------
/run_imagenet.sh:
--------------------------------------------------------------------------------
1 | echo "----run_imagenet----";
2 | echo "loss_type: $1"
3 | echo "lamda_ini: $2"
4 | echo "wgt_bit: $3"
5 | echo "act_bit: $4"
6 | echo "config: $5"
7 | 
8 | for i in 0 1 2 4 8 32 64 128
9 | do
10 |     echo "spr_weight: $i"
11 |     python main.py --config $5 optim.spr_w $i optim.loss_type $1 optim.lamda_ini $2 model.wgt_bit $3 model.act_bit $4
12 | done
13 | 
14 | echo "----done---";
15 | 
--------------------------------------------------------------------------------
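Both sweep scripts take the same positional arguments, in order: `loss_type`, `lamda_ini`, `wgt_bit`, `act_bit`, and the config path, and then sweep `optim.spr_w` over the listed values. For example (argument values chosen for illustration):

```
bash run_cifar.sh act_tgt 1.0 8 8 config/cifar10.yaml
```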
cm\n", 17 | "import torch\n", 18 | "import torch.nn.functional as F\n", 19 | "from core.bnet import BNet\n", 20 | "from core.layers.bconv2d import compute_bit_cost, bit2min, bit2max\n", 21 | "from model import get_model\n", 22 | "\n", 23 | "label_fontsize = 12\n", 24 | "\n", 25 | "def scale_w(w, scale, bit=8):\n", 26 | " thd_neg = bit2min(bit)\n", 27 | " thd_pos = bit2max(bit)\n", 28 | " w = torch.clamp(w.div(scale), thd_neg-0.5, thd_pos+0.5)\n", 29 | " return w\n", 30 | "\n", 31 | "def get_layer_ind(K=3):\n", 32 | " unsorted_max_indices = np.argpartition(-net.dense_syn_cnt, K)[:K]\n", 33 | " y = net.dense_syn_cnt[unsorted_max_indices]\n", 34 | " indices = np.argsort(-y)\n", 35 | " max_k_indices = unsorted_max_indices[indices]\n", 36 | "\n", 37 | " return max_k_indices\n", 38 | "\n", 39 | "print('Start')\n", 40 | "\n", 41 | "dssw = 1024\n", 42 | "act_bit = 32 # 8, 32\n", 43 | "wgt_bit = 8 # 8, 32\n", 44 | "lamda = 1.0 # 0.0, 1.0\n", 45 | "epoch = 0\n", 46 | "dataset_name = 'cifar10' # 'cifar10' , 'cifar100'\n", 47 | "optim_loss_type = 'act_tgt' # 'act_tgt' 'act_naive'\n", 48 | "model_name = 'resnet18' # \n", 49 | "\n", 50 | "exp_name = '{}/{}/{}_wd1_bs512_epoch200_lr0.05_0.01_lamda_{:0.1f}_bit8_{}'.format(dataset_name, model_name, optim_loss_type, lamda, act_bit)\n", 51 | "\n", 52 | "log_tpath = os.path.join('log', exp_name,'{:08d}','score','log') \n", 53 | "\n", 54 | "bit_cost, _ = compute_bit_cost(wgt_bit)\n", 55 | "\n", 56 | "cost_color = []\n", 57 | "\n", 58 | "for i in range(len(bit_cost)):\n", 59 | " if bit_cost[i]==0:\n", 60 | " cost_color.append((0,0,0))\n", 61 | " else:\n", 62 | " cost_color.append(cm.jet((bit_cost[i].item())/7.0))\n", 63 | "\n", 64 | "score_tpath = os.path.join('log', exp_name, '{:08d}').format(dssw)\n", 65 | "model_tpath = os.path.join('log', exp_name, '{:08d}'.format(dssw), 'model', 'final.pt').format(dssw)\n", 66 | "\n", 67 | "print('path OK ' + score_tpath)\n", 68 | "\n", 69 | "cfg = CN((yaml.safe_load(open('config/cifar10.yaml', 'r'))))\n", 70 | "model = get_model(cfg)\n", 71 | "cfg.dataset.input_shape = [1, 3, 32, 32]\n", 72 | "net = BNet(model, cfg)\n", 73 | "\n", 74 | "print('model loadted ' + exp_name)\n", 75 | "\n", 76 | "if epoch>0:\n", 77 | " net.load_state_dict(torch.load(model_tpath, map_location='cpu'), strict=False)\n", 78 | "\n", 79 | "\n", 80 | "os.makedirs(os.path.join(score_tpath.format(dssw), 'png') ,exist_ok=True)\n", 81 | "os.makedirs(os.path.join(score_tpath.format(dssw), 'pdf') ,exist_ok=True)\n", 82 | "\n", 83 | "max_k_indices = get_layer_ind(3)\n", 84 | "print(net.dense_syn_cnt/1e6)\n", 85 | "\n", 86 | "\n", 87 | "for idx, module in enumerate(net.DSS_cand):\n", 88 | " plt.figure(21, figsize=(6,3),dpi=200)\n", 89 | " w = module.fweight\n", 90 | " \n", 91 | " if idx in max_k_indices:\n", 92 | " print(idx)\n", 93 | " print(net.dense_syn_cnt[idx]/1e6)\n", 94 | " else:\n", 95 | " continue\n", 96 | "\n", 97 | " scale = module.scale\n", 98 | " w = scale_w(w, scale, wgt_bit)\n", 99 | " hist = torch.histogram(w.cpu(), bins=2**wgt_bit+1, range=(bit2min(wgt_bit)-0.5, bit2max(wgt_bit)+0.5), density=False)\n", 100 | " bit_cnt = hist[0].mul(bit_cost.cpu().squeeze()).sum()/torch.numel(w)\n", 101 | " fig = plt.bar(hist[1].ceil()[:257].detach().numpy(), hist[0].detach().numpy(), color=cost_color, linewidth=1)\n", 102 | " \n", 103 | " for c in range(8):\n", 104 | " if c==0:\n", 105 | " col = (0,0,0)\n", 106 | " else:\n", 107 | " col = cm.jet(c/7.0)\n", 108 | " # plt.text(-125+ (c)*16, 50, str(c))\n", 109 | " plt.text(10+ (c)*15, 1e5, str(c), 
109 |     "        plt.text(10+ (c)*15, 1e5, str(c), bbox=dict(edgecolor=col, facecolor='none'))\n",
110 |     "\n",
111 |     "\n",
112 |     "    plt.ylim([1, 1e6])  # lower bound must be positive for the log scale below\n",
113 |     "    plt.yscale('log')\n",
114 |     "    plt.xlabel(\"Weight\", fontsize=label_fontsize)\n",
115 |     "    plt.ylabel(\"Frequency (log)\", fontsize=label_fontsize)\n",
116 |     "    plt.xlim([-128.5, 128.5])\n",
117 |     "    plt.yticks([1e1, 1e2, 1e3, 1e4, 1e5, 1e6])\n",
118 |     "    plt.savefig(os.path.join(score_tpath, 'png', 'layer{:03d}_{:03d}_.png'.format(idx, epoch)), bbox_inches='tight')\n",
119 |     "    plt.savefig(os.path.join(score_tpath, 'pdf', 'layer{:03d}_{:03d}_.pdf'.format(idx, epoch)), bbox_inches='tight')\n",
120 |     "    print(os.path.join(score_tpath, 'png', 'layer{:03d}_{:03d}_.png'.format(idx, epoch)))\n",
121 |     "    plt.show()"
122 |    ]
123 |   }
124 |  ],
125 |  "metadata": {
126 |   "kernelspec": {
127 |    "display_name": "Python 3.9.13 ('bitprune')",
128 |    "language": "python",
129 |    "name": "python3"
130 |   },
131 |   "language_info": {
132 |    "codemirror_mode": {
133 |     "name": "ipython",
134 |     "version": 3
135 |    },
136 |    "file_extension": ".py",
137 |    "mimetype": "text/x-python",
138 |    "name": "python",
139 |    "nbconvert_exporter": "python",
140 |    "pygments_lexer": "ipython3",
141 |    "version": "3.9.13 | packaged by conda-forge | (main, May 27 2022, 16:56:21) \n[GCC 10.3.0]"
142 |   },
143 |   "orig_nbformat": 4,
144 |   "vscode": {
145 |    "interpreter": {
146 |     "hash": "fa8be68a8b7298b717a779de91d0437c8c17312fb770321ab98c25f6837e8a76"
147 |    }
148 |   }
149 |  },
150 |  "nbformat": 4,
151 |  "nbformat_minor": 2
152 | }
153 | 
--------------------------------------------------------------------------------
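`STATS.save` in logger.py pickles the accumulated log object, so results can also be inspected offline. A minimal sketch of reading one back, assuming the `log/<exp_name>/<spr_w>/score/log` layout from the `log_tpath` template in vis_hist.ipynb (the concrete path below is hypothetical, and attribute access assumes the log is a dict-like object such as an EasyDict):

```
import pickle

# Hypothetical path, following the log_tpath pattern in vis_hist.ipynb.
with open('log/cifar10/resnet18/exp/00001024/score/log', 'rb') as handle:
    log = pickle.load(handle)

# The logger tracks the best epoch and its accuracy (see logger.py above).
print(log.best.epoch, log.best.acc)
```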