├── .gitignore
├── LICENSE
├── README.md
├── cg-cifar10.py
├── model
│   ├── __init__.py
│   ├── cg_cifar10_resnet.py
│   ├── cg_cifar10_resnet_postact.py
│   ├── pg_cifar10_resnet.py
│   └── quantized_cifar10_resnet.py
├── pg-cifar10.py
├── scripts
│   ├── train_cg.sh
│   ├── train_cg_postact.sh
│   └── train_pg_pact.sh
└── utils
    ├── __init__.py
    ├── cg_utils.py
    ├── pg_utils.py
    └── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, cornell-zhang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dnn-gating 2 | 3 | ```dnn-gating``` is a collection of [precision-gating](https://arxiv.org/abs/2002.07136) and [channel-gating](https://arxiv.org/abs/1805.12549) reimplementations in PyTorch. 4 | 5 | ## Precision Gating (PG) 6 | 7 | ### Requirements of PG 8 | 9 | ``` 10 | python 3.6.8 11 | torch >= 1.3.0 12 | numpy 1.16.4 13 | matplotlib 3.1.0 14 | ``` 15 | 16 | ### Usage 17 | 18 | With this repo, you can: 19 | - Evaluate uniform quantization and [PACT](https://arxiv.org/abs/1805.06085). 20 | - Evaluate PG on ResNet CIFAR-10. 21 | - Apply PG to your own models and datasets. 22 | 23 | ##### Example 24 | The following example trains ResNet-20 on CIFAR-10 with activations quantized to 3 bits, of which the 2 MSBs are used for prediction. 25 | 26 | ```sh 27 | $ cd scripts 28 | $ source train_pg_pact.sh 29 | ``` 30 | 31 | ##### Specify the Flags 32 | 33 | Make sure to tune the training parameters to achieve a good model accuracy. 34 | ``` 35 | -w : bitwidth of weights (floating-point if set to 0) 36 | -a : bitwidth of activations (floating-point if set to 0) 37 | -pact : use parameterized clipping for activations 38 | -pg : use PG 39 | -pb : prediction bitwidth (only valid if -pg is turned on; must be smaller than the activation bitwidth) 40 | -gtar : the gating target 41 | -sg : the penalty factor on the gating loss 42 | -spbp : use sparse back-prop 43 | ``` 44 | 45 | ## Channel Gating (CG) 46 | 47 | ### Requirements of CG 48 | ``` 49 | python 2.7.12 50 | torch 1.1.0 51 | numpy 1.16.4 52 | matplotlib 2.1.0 53 | ``` 54 | ### Usage 55 | 56 | With this repo, you can: 57 | - Evaluate CG on ResNet CIFAR-10 (both the original and modified post-activated ResNets). 58 | - The post-activated ResNet allows applying channel gating to all convolutional layers in a residual module. 59 | - Apply CG to your own models and datasets. 60 | 61 | ##### Example 62 | The following examples use one quarter and one half of the input channels in the base path for the original and post-activated ResNets, respectively. 
63 | 64 | ```sh 65 | $ cd scripts 66 | $ source train_cg.sh 67 | $ source train_cg_postact.sh 68 | ``` 69 | 70 | ##### Specify the Flags 71 | 72 | The training parameters can be tuned to trade off FLOP reduction against model accuracy. 73 | ``` 74 | -lr : initial learning rate 75 | -wd : weight decay factor 76 | -pt : use a 1/pt fraction of channels for prediction 77 | -gi : the initial value of gating thresholds 78 | -gtar : the target value of gating thresholds 79 | -spbp : use sparse back-prop 80 | -group : use group conv in the base path 81 | -cg : use CG 82 | -postact : use post-activated ResNet 83 | ``` 84 | 85 | ## Apply PG/CG to Your Own Models & Datasets 86 | 87 | The following steps allow you to apply PG/CG to your own models. 88 | 1. Copy the model file to ```model/```. 89 | 2. Import ```utils/pg_utils.py```/```utils/cg_utils.py``` in the model file, and replace convolutional layers that are followed by activation functions with the ```PGConv2d```/```CGConv2d``` module. 90 | 3. Import ```model/your_model.py``` in the ```generate_model()``` function in ```pg-cifar10.py```/```cg-cifar10.py```. 91 | 92 | If you prepare your own training scripts, remember to add the **gating loss** to the model prediction loss before doing back-prop (the training loops in ```pg-cifar10.py```/```cg-cifar10.py``` show how). 93 | 94 | ##### Note 95 | - The way we report the update-phase sparsity is only valid when training on a single GPU, because PyTorch updates each per-GPU model replica rather than a global model when ```DataParallel``` is active. For multi-GPU training, we suggest turning off the sparsity printing during training, saving the trained model, and printing the sparsity only when testing. 96 | 97 | 98 | 99 | ### Citation 100 | If you use CG or PG in your research, please cite our NeurIPS'19 and ICLR'20 papers. 101 | 102 | **Channel Gating Neural Networks** 103 | ``` 104 | 105 | @incollection{NIPS2019_8464, 106 | title = {Channel Gating Neural Networks}, 107 | author = {Hua, Weizhe and Zhou, Yuan and De Sa, Christopher M and Zhang, Zhiru and Suh, G. Edward}, 108 | booktitle = {Advances in Neural Information Processing Systems 32}, 109 | editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett}, 110 | pages = {1886--1896}, 111 | year = {2019}, 112 | publisher = {Curran Associates, Inc.}, 113 | url = {http://papers.nips.cc/paper/8464-channel-gating-neural-networks.pdf} 114 | } 115 | ``` 116 | **Precision Gating: Improving Neural Network Efficiency with Dynamic Dual-Precision Activations** 117 | ``` 118 | @inproceedings{ 119 | Zhang2020Precision, 120 | title={Precision Gating: Improving Neural Network Efficiency with Dynamic Dual-Precision Activations}, 121 | author={Yichi Zhang and Ritchie Zhao and Weizhe Hua and Nayun Xu and G. 
Edward Suh and Zhiru Zhang}, 122 | booktitle={International Conference on Learning Representations}, 123 | year={2020}, 124 | url={https://openreview.net/forum?id=SJgVU0EKwS} 125 | } 126 | ``` 127 | -------------------------------------------------------------------------------- /cg-cifar10.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import matplotlib.pyplot as plt 11 | import utils.utils as util 12 | 13 | import numpy as np 14 | import os, time, sys 15 | import argparse 16 | 17 | import utils.cg_utils as G 18 | torch.manual_seed(123123) 19 | 20 | #---------------------------- 21 | # Argument parser. 22 | #---------------------------- 23 | parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 Training') 24 | parser.add_argument('--lr', '-lr', type=float, default=0.1, help='initial learning rate') 25 | parser.add_argument('--wt_decay', '-wd', type=float, default=1e-4, help='weight decaying') 26 | parser.add_argument('--save', '-s', action='store_true', help='save the model') 27 | parser.add_argument('--test', '-t', action='store_true', help='test only') 28 | parser.add_argument('--path', '-p', type=str, default=None, help='saved model path') 29 | parser.add_argument('--partitions', '-pt', type=int, default=4, help='number of partitions') 30 | parser.add_argument('--ginit', '-gi', type=float, default=0.0, help='initial value of the gating threshold') 31 | parser.add_argument('--alpha', '-a', type=float, default=2.0, help='slope of the gate backprop') 32 | parser.add_argument('--use_group', '-group', action='store_true', help='use group conv as the base path') 33 | parser.add_argument('--gtarget', '-gtar', type=float, default=0.0, help='gating target') 34 | parser.add_argument('--sparse_bp', '-spbp', action='store_true', help='sparse backprop of PGConv2d') 35 | parser.add_argument('--use_cg', '-cg', action='store_true', help='activate channel gating') 36 | parser.add_argument('--use_shuffle', '-shuffle', action='store_true', help='add channel shuffling') 37 | parser.add_argument('--use_postact', '-postact', action='store_true', help='use postact resnet') 38 | parser.add_argument('--which_gpus', '-gpu', type=str, default='0', help='which gpus to use') 39 | 40 | args = parser.parse_args() 41 | 42 | ######################### 43 | # parameters 44 | 45 | batch_size = 128 46 | num_epoch = 250 47 | _LAST_EPOCH = -1 #last_epoch arg is useful for restart 48 | _LEARNING_RATE = args.lr 49 | _WEIGHT_DECAY = args.wt_decay 50 | _ARCH = "resnet-20" 51 | this_file_path = os.path.dirname(os.path.abspath(__file__)) 52 | save_folder = os.path.join(this_file_path, 'save_CIFAR10_model') 53 | ######################### 54 | 55 | #---------------------------- 56 | # Load the CIFAR-10 dataset. 
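# (Added note, not in the original source.) The Normalize() used in
# load_cifar10() below applies the ImageNet channel statistics; CIFAR-10's
# own statistics are roughly the values given here, and either choice works
# since the train and test sets share the same transform:
# normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
#                                  std=[0.2470, 0.2435, 0.2616])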
57 | #---------------------------- 58 | 59 | def load_cifar10(): 60 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 61 | std=[0.229, 0.224, 0.225]) 62 | transform_train = transforms.Compose([ 63 | transforms.RandomHorizontalFlip(), 64 | transforms.RandomCrop(32, 4), 65 | transforms.ToTensor(), 66 | normalize 67 | ]) 68 | transform_test = transforms.Compose([ 69 | transforms.ToTensor(), 70 | normalize 71 | ]) 72 | 73 | # pin_memory=True makes transferring data from host to GPU faster 74 | trainset = torchvision.datasets.CIFAR10(root='/tmp/cifar10_data', train=True, 75 | download=True, transform=transform_train) 76 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, 77 | shuffle=True, num_workers=4, pin_memory=True) 78 | 79 | testset = torchvision.datasets.CIFAR10(root='/tmp/cifar10_data', train=False, 80 | download=True, transform=transform_test) 81 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, 82 | shuffle=False, num_workers=4, pin_memory=True) 83 | 84 | classes = ('plane', 'car', 'bird', 'cat', 85 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 86 | 87 | return trainloader, testloader, classes 88 | 89 | #---------------------------- 90 | # Define the weight decay. 91 | #---------------------------- 92 | def add_weight_decay(model, weight_decay=1e-5, skip_name='decay_skip_name'): 93 | decay = [] 94 | no_decay = [] 95 | for name, param in model.named_parameters(): 96 | if not param.requires_grad: 97 | continue 98 | if skip_name in name: 99 | no_decay.append(param) 100 | elif len(param.shape) == 1: 101 | no_decay.append(param) 102 | else: 103 | decay.append(param) 104 | return [ 105 | {'params': no_decay, 'weight_decay': 0.0}, 106 | {'params': decay, 'weight_decay': weight_decay}] 107 | 108 | #---------------------------- 109 | # Define the model. 110 | #---------------------------- 111 | 112 | def generate_model(model_arch): 113 | if model_arch == 'resnet-20': 114 | if args.use_cg: 115 | if args.use_postact: 116 | import model.cg_cifar10_resnet_postact as m 117 | else: 118 | import model.cg_cifar10_resnet as m 119 | kwargs = {'partitions':args.partitions, 'ginit':args.ginit, \ 120 | 'use_group':args.use_group, 'sparse_bp':args.sparse_bp, \ 121 | 'shuffle':args.use_shuffle, 'alpha':args.alpha} 122 | return m.resnet20(**kwargs) 123 | else: 124 | if args.use_postact: 125 | import model.cifar10_resnet_postact as m 126 | else: 127 | import model.cifar10_resnet as m 128 | return m.resnet20() 129 | else: 130 | raise NotImplementedError("Model architecture is not supported.") 131 | 132 | 133 | 134 | #---------------------------- 135 | # Train the network. 
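# (Added illustration, not in the original file; the helper name
# _demo_weight_decay_groups is ours.) A minimal, runnable sketch of how the
# parameter groups built by add_weight_decay() above feed the optimizer in
# train_model() below: 1-D parameters (BatchNorm scales/biases, CG gating
# thresholds) get no L2 penalty, while weight tensors do.
def _demo_weight_decay_groups():
    import torch.nn as nn
    import torch.optim as optim
    toy = nn.Sequential(nn.Conv2d(3, 8, 3, bias=False), nn.BatchNorm2d(8))
    groups = add_weight_decay(toy, weight_decay=1e-4, skip_name='threshold')
    assert groups[0]['weight_decay'] == 0.0   # BN params / gating thresholds
    assert groups[1]['weight_decay'] == 1e-4  # conv weight tensor
    return optim.SGD(groups, lr=0.1, momentum=0.9)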
136 | #---------------------------- 137 | 138 | def train_model(trainloader, testloader, net, device): 139 | # define the loss function 140 | criterion = (nn.CrossEntropyLoss().cuda() 141 | if torch.cuda.is_available() else nn.CrossEntropyLoss()) 142 | initial_lr = _LEARNING_RATE * batch_size / 128 143 | 144 | # add weight decay (the gating thresholds are excluded) 145 | params = add_weight_decay(net, _WEIGHT_DECAY, 'threshold') 146 | # initialize the optimizer 147 | optimizer = optim.SGD(params, 148 | lr=initial_lr, 149 | momentum=0.9) 150 | # multiply the lr by 0.1 at epochs 150 and 200 (3/5 and 4/5 of num_epoch) 151 | div = num_epoch // 5 152 | 153 | lr_decay_milestones = [div*3, div*4] 154 | scheduler = optim.lr_scheduler.MultiStepLR( 155 | optimizer, 156 | milestones=lr_decay_milestones, 157 | gamma=0.1, 158 | last_epoch=_LAST_EPOCH) 159 | 160 | for epoch in range(num_epoch): # loop over the dataset multiple times 161 | 162 | # set printing functions 163 | batch_time = util.AverageMeter('Time/batch', ':.3f') 164 | losses = util.AverageMeter('Loss', ':6.2f') 165 | top1 = util.AverageMeter('Acc', ':6.2f') 166 | progress = util.ProgressMeter( 167 | len(trainloader), 168 | [losses, top1, batch_time], 169 | prefix="Epoch: [{}]".format(epoch+1) 170 | ) 171 | 172 | # switch the model to the training mode 173 | net.train() 174 | 175 | print('current learning rate = {}'.format(optimizer.param_groups[0]['lr'])) 176 | 177 | # each epoch 178 | end = time.time() 179 | for i, data in enumerate(trainloader, 0): 180 | # get the inputs; data is a list of [inputs, labels] 181 | inputs, labels = data[0].to(device), data[1].to(device) 182 | 183 | # zero the parameter gradients 184 | optimizer.zero_grad() 185 | 186 | # forward + backward + optimize 187 | outputs = net(inputs) 188 | loss = criterion(outputs, labels) 189 | for name, param in net.named_parameters(): 190 | if 'threshold' in name: 191 | loss += 0.0001 * torch.sum((param-args.gtarget) ** 2) 192 | loss.backward() 193 | optimizer.step() 194 | 195 | # measure accuracy and record loss 196 | _, batch_predicted = torch.max(outputs.data, 1) 197 | batch_accu = 100.0 * (batch_predicted == labels).sum().item() / labels.size(0) 198 | losses.update(loss.item(), labels.size(0)) 199 | top1.update(batch_accu, labels.size(0)) 200 | 201 | # measure elapsed time 202 | batch_time.update(time.time() - end) 203 | end = time.time() 204 | 205 | if i % 50 == 49: 206 | # print statistics every 50 mini-batches 207 | progress.display(i) # i = batch id in the epoch 208 | 209 | # update the learning rate 210 | scheduler.step() 211 | 212 | # print test accuracy every few epochs 213 | if epoch % 10 == 9: 214 | print('epoch {}'.format(epoch+1)) 215 | test_accu(testloader, net, device) 216 | 217 | # save the model if required 218 | if args.save: 219 | print("Saving the trained model.") 220 | util.save_models(net.state_dict(), save_folder, suffix=_ARCH) 221 | 222 | print('Finished Training') 223 | 224 | 225 | #---------------------------- 226 | # Test accuracy. 
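# (Added illustration, not in the original file; the helper name
# _cg_gating_loss is ours.) The gating regularizer that train_model() above
# folds into `loss`, shown in isolation: every CG threshold is pulled toward
# the gating target with a fixed 1e-4 squared penalty, which is what trades
# accuracy against update-phase sparsity.
def _cg_gating_loss(model, gtarget, penalty=0.0001):
    reg = 0.0
    for name, param in model.named_parameters():
        if 'threshold' in name:
            reg = reg + penalty * torch.sum((param - gtarget) ** 2)
    return reg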
227 | #---------------------------- 228 | 229 | def test_accu(testloader, net, device): 230 | cnt_factor = 2 if args.use_postact else 1 231 | cnt_out = np.zeros(9 * cnt_factor) # this number is hardcoded for ResNet-20 232 | cnt_full = np.zeros(9 * cnt_factor) # this number is hardcoded for ResNet-20 233 | num_out = [] 234 | num_full = [] 235 | def _report_sparsity(m): 236 | classname = m.__class__.__name__ 237 | if isinstance(m, G.CGConv2d): 238 | num_out.append(m.num_out) 239 | num_full.append(m.num_full) 240 | 241 | correct = 0 242 | total = 0 243 | # switch the model to the evaluation mode 244 | net.eval() 245 | with torch.no_grad(): 246 | for data in testloader: 247 | images, labels = data[0].to(device), data[1].to(device) 248 | outputs = net(images) 249 | _, predicted = torch.max(outputs.data, 1) 250 | total += labels.size(0) 251 | correct += (predicted == labels).sum().item() 252 | 253 | """ calculate statistics per CG layer """ 254 | if args.use_cg: 255 | net.apply(_report_sparsity) 256 | cnt_out += np.array(num_out) 257 | cnt_full += np.array(num_full) 258 | num_out = [] 259 | num_full = [] 260 | 261 | print('Accuracy of the network on the 10000 test images: %.1f %%' % ( 262 | 100 * correct / total)) 263 | if args.use_cg: 264 | print('Sparsity of the update phase: %.1f %%' % (100-np.sum(cnt_full)*1.0/np.sum(cnt_out)*100)) 265 | 266 | 267 | #---------------------------- 268 | # Test accuracy per class 269 | #---------------------------- 270 | 271 | def per_class_test_accu(testloader, classes, net, device): 272 | class_correct = list(0. for i in range(10)) 273 | class_total = list(0. for i in range(10)) 274 | net.eval() 275 | with torch.no_grad(): 276 | for data in testloader: 277 | images, labels = data[0].to(device), data[1].to(device) 278 | outputs = net(images) 279 | _, predicted = torch.max(outputs, 1) 280 | c = (predicted == labels).squeeze() 281 | for i in range(labels.size(0)): # iterate over the whole batch, not just 4 samples 282 | label = labels[i] 283 | class_correct[label] += c[i].item() 284 | class_total[label] += 1 285 | 286 | 287 | for i in range(10): 288 | print('Accuracy of %5s : %.1f %%' % ( 289 | classes[i], 100 * class_correct[i] / class_total[i])) 290 | 291 | 292 | #---------------------------- 293 | # Main function. 
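# (Added illustration, not in the original file; the helper name is ours.)
# The statistic printed by test_accu() above, as a standalone formula: an
# output counts as "full" when its partial sum crossed the learned threshold
# and all input channels were therefore used, so the update-phase sparsity is
# the percentage of outputs that skipped the full computation.
def _update_phase_sparsity(cnt_full, cnt_out):
    return 100.0 - 100.0 * np.sum(cnt_full) / np.sum(cnt_out)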
294 | #---------------------------- 295 | 296 | def main(): 297 | os.environ["CUDA_VISIBLE_DEVICES"] = args.which_gpus 298 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 299 | print("Available GPUs: {}".format(torch.cuda.device_count())) 300 | 301 | print("Create {} model.".format(_ARCH)) 302 | net = generate_model(_ARCH) 303 | 304 | if args.path: 305 | print("@ Load trained model from {}.".format(args.path)) 306 | net.load_state_dict(torch.load(args.path)) 307 | if torch.cuda.device_count() > 1: 308 | print("Activate multi GPU support.") 309 | net = nn.DataParallel(net) 310 | net.to(device) 311 | 312 | print("Loading the data.") 313 | trainloader, testloader, classes = load_cifar10() 314 | if args.test: 315 | print("Mode: Test only.") 316 | test_accu(testloader, net, device) 317 | else: 318 | print("Start training.") 319 | train_model(trainloader, testloader, net, device) 320 | test_accu(testloader, net, device) 321 | per_class_test_accu(testloader, classes, net, device) 322 | 323 | 324 | if __name__ == "__main__": 325 | main() 326 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cornell-zhang/dnn-gating/31666fadf35789b433c79eec8669a3a2df818bd4/model/__init__.py -------------------------------------------------------------------------------- /model/cg_cifar10_resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Properly implemented ResNet-s for CIFAR10 as described in paper [1]. 3 | The implementation and structure of this file are hugely influenced by [2], 4 | which is implemented for ImageNet and doesn't have option A for identity. 5 | Moreover, most of the implementations on the web are copy-paste from 6 | torchvision's resnet and have the wrong number of params. 7 | Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following 8 | numbers of layers and parameters: 9 | name | layers | params 10 | ResNet20 | 20 | 0.27M 11 | ResNet32 | 32 | 0.46M 12 | ResNet44 | 44 | 0.66M 13 | ResNet56 | 56 | 0.85M 14 | ResNet110 | 110 | 1.7M 15 | ResNet1202| 1202 | 19.4M 16 | which this implementation indeed has. 17 | Reference: 18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 20 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 21 | If you use this implementation in your work, please don't forget to mention the 22 | author, Yerlan Idelbayev. 
23 | ''' 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | import utils.cg_utils as G 28 | 29 | from torch.autograd import Variable 30 | 31 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 32 | 33 | def _weights_init(m): 34 | classname = m.__class__.__name__ 35 | #print(classname) 36 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 37 | nn.init.kaiming_normal_(m.weight) 38 | 39 | class LambdaLayer(nn.Module): 40 | def __init__(self, lambd): 41 | super(LambdaLayer, self).__init__() 42 | self.lambd = lambd 43 | 44 | def forward(self, x): 45 | return self.lambd(x) 46 | 47 | 48 | class BasicBlock(nn.Module): 49 | expansion = 1 50 | 51 | def __init__(self, in_planes, planes, stride=1, option='A', **kwargs): 52 | super(BasicBlock, self).__init__() 53 | self.conv1 = G.CGConv2d(in_planes, planes, kernel_size=3, 54 | stride=stride, padding=1, bias=False, 55 | p=kwargs['partitions'], th=kwargs['ginit'], alpha=kwargs['alpha'], 56 | use_group=kwargs['use_group'], shuffle=kwargs['shuffle'], sparse_bp=kwargs['sparse_bp']) 57 | self.bn1 = nn.BatchNorm2d(planes) 58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 59 | stride=1, padding=1, bias=False) 60 | self.bn2 = nn.BatchNorm2d(planes) 61 | self.relu = nn.ReLU() 62 | 63 | self.shortcut = nn.Sequential() 64 | if stride != 1 or in_planes != planes: 65 | if option == 'A': 66 | """ 67 | For CIFAR10 ResNet paper uses option A. 68 | """ 69 | self.shortcut = LambdaLayer(lambda x: 70 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 71 | elif option == 'B': 72 | self.shortcut = nn.Sequential( 73 | nn.Conv2d(in_planes, self.expansion * planes, 74 | kernel_size=1, stride=stride, bias=False), 75 | nn.BatchNorm2d(self.expansion * planes) 76 | ) 77 | 78 | def forward(self, x): 79 | out = self.relu(self.bn1(self.conv1(x))) 80 | out = self.bn2(self.conv2(out)) 81 | out += self.shortcut(x) 82 | out = self.relu(out) 83 | return out 84 | 85 | 86 | class ResNet(nn.Module): 87 | def __init__(self, block, num_blocks, num_classes=10, **kwargs): 88 | super(ResNet, self).__init__() 89 | self.in_planes = 16 90 | 91 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 92 | self.bn1 = nn.BatchNorm2d(16) 93 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1, **kwargs) 94 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2, **kwargs) 95 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2, **kwargs) 96 | self.linear = nn.Linear(64, num_classes) 97 | self.relu = nn.ReLU(inplace=True) 98 | 99 | self.apply(_weights_init) 100 | 101 | def _make_layer(self, block, planes, num_blocks, stride, **kwargs): 102 | strides = [stride] + [1]*(num_blocks-1) 103 | layers = [] 104 | for stride in strides: 105 | layers.append(block(self.in_planes, planes, stride, **kwargs)) 106 | self.in_planes = planes * block.expansion 107 | 108 | return nn.Sequential(*layers) 109 | 110 | def forward(self, x): 111 | out = self.relu(self.bn1(self.conv1(x))) 112 | out = self.layer1(out) 113 | out = self.layer2(out) 114 | out = self.layer3(out) 115 | out = F.avg_pool2d(out, out.size()[3]) 116 | out = out.view(out.size(0), -1) 117 | out = self.linear(out) 118 | return out 119 | 120 | 121 | def resnet20(num_classes=10, **kwargs): 122 | return ResNet(BasicBlock, [3, 3, 3], num_classes=num_classes, **kwargs) 123 | 124 | 125 | def resnet32(): 126 | return ResNet(BasicBlock, [5, 5, 5]) 127 | 128 | 129 | def resnet44(): 
130 | return ResNet(BasicBlock, [7, 7, 7]) 131 | 132 | 133 | def resnet56(): 134 | return ResNet(BasicBlock, [9, 9, 9]) 135 | 136 | 137 | def resnet110(): 138 | return ResNet(BasicBlock, [18, 18, 18]) 139 | 140 | 141 | def resnet1202(): 142 | return ResNet(BasicBlock, [200, 200, 200]) 143 | 144 | ''' 145 | def test(net): 146 | import numpy as np 147 | total_params = 0 148 | 149 | for x in filter(lambda p: p.requires_grad, net.parameters()): 150 | total_params += np.prod(x.data.numpy().shape) 151 | print("Total number of params", total_params) 152 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 153 | 154 | 155 | if __name__ == "__main__": 156 | for net_name in __all__: 157 | if net_name.startswith('resnet'): 158 | print(net_name) 159 | test(globals()[net_name]()) 160 | print() 161 | ''' 162 | -------------------------------------------------------------------------------- /model/cg_cifar10_resnet_postact.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Properly implemented ResNet-s for CIFAR10 as described in paper [1]. 3 | The implementation and structure of this file are hugely influenced by [2], 4 | which is implemented for ImageNet and doesn't have option A for identity. 5 | Moreover, most of the implementations on the web are copy-paste from 6 | torchvision's resnet and have the wrong number of params. 7 | Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following 8 | numbers of layers and parameters: 9 | name | layers | params 10 | ResNet20 | 20 | 0.27M 11 | ResNet32 | 32 | 0.46M 12 | ResNet44 | 44 | 0.66M 13 | ResNet56 | 56 | 0.85M 14 | ResNet110 | 110 | 1.7M 15 | ResNet1202| 1202 | 19.4M 16 | which this implementation indeed has. 17 | Reference: 18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 20 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 21 | If you use this implementation in your work, please don't forget to mention the 22 | author, Yerlan Idelbayev. 
23 | ''' 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | import utils.cg_utils as G 28 | 29 | from torch.autograd import Variable 30 | 31 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 32 | 33 | def _weights_init(m): 34 | classname = m.__class__.__name__ 35 | #print(classname) 36 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 37 | nn.init.kaiming_normal_(m.weight) 38 | 39 | class LambdaLayer(nn.Module): 40 | def __init__(self, lambd): 41 | super(LambdaLayer, self).__init__() 42 | self.lambd = lambd 43 | 44 | def forward(self, x): 45 | return self.lambd(x) 46 | 47 | 48 | class BasicBlock(nn.Module): 49 | expansion = 1 50 | 51 | def __init__(self, in_planes, planes, stride=1, option='A', **kwargs): 52 | super(BasicBlock, self).__init__() 53 | #print("wbits:{}, abits:{}, trunc bits:{}".format(kwargs['wbits'], kwargs['abits'], kwargs['pred_bits'])) 54 | self.conv1 = G.CGConv2d(in_planes, planes, kernel_size=3, 55 | stride=stride, padding=1, bias=False, 56 | p=kwargs['partitions'], th=kwargs['ginit'], alpha=kwargs['alpha'], 57 | use_group=kwargs['use_group'], shuffle=kwargs['shuffle'], sparse_bp=kwargs['sparse_bp']) 58 | self.bn1 = nn.BatchNorm2d(planes) 59 | self.conv2 = G.CGConv2d(planes, planes, kernel_size=3, 60 | stride=1, padding=1, bias=False, 61 | p=kwargs['partitions'], th=kwargs['ginit'], alpha=kwargs['alpha'], 62 | use_group=kwargs['use_group'], shuffle=kwargs['shuffle'], sparse_bp=kwargs['sparse_bp']) 63 | self.bn2 = nn.BatchNorm2d(planes) 64 | self.relu = nn.ReLU() 65 | 66 | self.shortcut = nn.Sequential() 67 | if stride != 1 or in_planes != planes: 68 | if option == 'A': 69 | """ 70 | For CIFAR10 ResNet paper uses option A. 71 | """ 72 | self.shortcut = LambdaLayer(lambda x: 73 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 74 | elif option == 'B': 75 | self.shortcut = nn.Sequential( 76 | nn.Conv2d(in_planes, self.expansion * planes, 77 | kernel_size=1, stride=stride, bias=False), 78 | nn.BatchNorm2d(self.expansion * planes) 79 | ) 80 | 81 | def forward(self, x): 82 | out = self.relu(self.bn1(self.conv1(x))) 83 | out = self.relu(self.bn2(self.conv2(out))) 84 | out += self.shortcut(x) 85 | return out 86 | 87 | 88 | class ResNet(nn.Module): 89 | def __init__(self, block, num_blocks, num_classes=10, **kwargs): 90 | super(ResNet, self).__init__() 91 | self.in_planes = 16 92 | 93 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 94 | self.bn1 = nn.BatchNorm2d(16) 95 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1, **kwargs) 96 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2, **kwargs) 97 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2, **kwargs) 98 | self.linear = nn.Linear(64, num_classes) 99 | self.relu = nn.ReLU(inplace=True) 100 | 101 | self.apply(_weights_init) 102 | 103 | def _make_layer(self, block, planes, num_blocks, stride, **kwargs): 104 | strides = [stride] + [1]*(num_blocks-1) 105 | layers = [] 106 | for stride in strides: 107 | layers.append(block(self.in_planes, planes, stride, **kwargs)) 108 | self.in_planes = planes * block.expansion 109 | 110 | return nn.Sequential(*layers) 111 | 112 | def forward(self, x): 113 | out = self.relu(self.bn1(self.conv1(x))) 114 | out = self.layer1(out) 115 | out = self.layer2(out) 116 | out = self.layer3(out) 117 | out = F.avg_pool2d(out, out.size()[3]) 118 | out = out.view(out.size(0), -1) 119 | out = self.linear(out) 
120 | return out 121 | 122 | 123 | def resnet20(num_classes=10, **kwargs): 124 | return ResNet(BasicBlock, [3, 3, 3], num_classes=num_classes, **kwargs) 125 | 126 | 127 | def resnet32(): 128 | return ResNet(BasicBlock, [5, 5, 5]) 129 | 130 | 131 | def resnet44(): 132 | return ResNet(BasicBlock, [7, 7, 7]) 133 | 134 | 135 | def resnet56(): 136 | return ResNet(BasicBlock, [9, 9, 9]) 137 | 138 | 139 | def resnet110(): 140 | return ResNet(BasicBlock, [18, 18, 18]) 141 | 142 | 143 | def resnet1202(): 144 | return ResNet(BasicBlock, [200, 200, 200]) 145 | 146 | ''' 147 | def test(net): 148 | import numpy as np 149 | total_params = 0 150 | 151 | for x in filter(lambda p: p.requires_grad, net.parameters()): 152 | total_params += np.prod(x.data.numpy().shape) 153 | print("Total number of params", total_params) 154 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 155 | 156 | 157 | if __name__ == "__main__": 158 | for net_name in __all__: 159 | if net_name.startswith('resnet'): 160 | print(net_name) 161 | test(globals()[net_name]()) 162 | print() 163 | ''' 164 | -------------------------------------------------------------------------------- /model/pg_cifar10_resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Properly implemented ResNet-s for CIFAR10 as described in paper [1]. 3 | The implementation and structure of this file are hugely influenced by [2], 4 | which is implemented for ImageNet and doesn't have option A for identity. 5 | Moreover, most of the implementations on the web are copy-paste from 6 | torchvision's resnet and have the wrong number of params. 7 | Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following 8 | numbers of layers and parameters: 9 | name | layers | params 10 | ResNet20 | 20 | 0.27M 11 | ResNet32 | 32 | 0.46M 12 | ResNet44 | 44 | 0.66M 13 | ResNet56 | 56 | 0.85M 14 | ResNet110 | 110 | 1.7M 15 | ResNet1202| 1202 | 19.4M 16 | which this implementation indeed has. 17 | Reference: 18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 20 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 21 | If you use this implementation in your work, please don't forget to mention the 22 | author, Yerlan Idelbayev. 
23 | ''' 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | import utils.pg_utils as q 28 | 29 | from torch.autograd import Variable 30 | 31 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 32 | 33 | def _weights_init(m): 34 | classname = m.__class__.__name__ 35 | #print(classname) 36 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 37 | nn.init.kaiming_normal_(m.weight) 38 | 39 | class LambdaLayer(nn.Module): 40 | def __init__(self, lambd): 41 | super(LambdaLayer, self).__init__() 42 | self.lambd = lambd 43 | 44 | def forward(self, x): 45 | return self.lambd(x) 46 | 47 | 48 | class BasicBlock(nn.Module): 49 | expansion = 1 50 | 51 | def __init__(self, in_planes, planes, stride=1, option='A', **kwargs): 52 | super(BasicBlock, self).__init__() 53 | #print("wbits:{}, abits:{}, trunc bits:{}".format(kwargs['wbits'], kwargs['abits'], kwargs['pred_bits'])) 54 | self.conv1 = q.PGConv2d(in_planes, planes, kernel_size=3, 55 | stride=stride, padding=1, bias=False, 56 | wbits=kwargs['wbits'], abits=kwargs['abits'], 57 | pred_bits=kwargs['pred_bits'], sparse_bp=kwargs['sparse_bp']) 58 | self.bn1 = nn.BatchNorm2d(planes) 59 | self.conv2 = q.QuantizedConv2d(planes, planes, kernel_size=3, 60 | stride=1, padding=1, bias=False, 61 | wbits=kwargs['wbits'], abits=kwargs['abits']) 62 | self.bn2 = nn.BatchNorm2d(planes) 63 | self.relu = q.PactReLU() if kwargs['pact'] else nn.ReLU() 64 | 65 | self.shortcut = nn.Sequential() 66 | if stride != 1 or in_planes != planes: 67 | if option == 'A': 68 | """ 69 | For CIFAR10 ResNet paper uses option A. 70 | """ 71 | self.shortcut = LambdaLayer(lambda x: 72 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 73 | elif option == 'B': 74 | self.shortcut = nn.Sequential( 75 | q.QuantizedConv2d(in_planes, self.expansion * planes, 76 | kernel_size=1, stride=stride, bias=False, 77 | wbits=kwargs['wbits'], abits=kwargs['abits']), 78 | nn.BatchNorm2d(self.expansion * planes) 79 | ) 80 | 81 | def forward(self, x): 82 | out = self.relu(self.bn1(self.conv1(x))) 83 | out = self.bn2(self.conv2(out)) 84 | out += self.shortcut(x) 85 | out = self.relu(out) 86 | return out 87 | 88 | 89 | class ResNet(nn.Module): 90 | def __init__(self, block, num_blocks, num_classes=10, **kwargs): 91 | super(ResNet, self).__init__() 92 | self.in_planes = 16 93 | 94 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 95 | self.bn1 = nn.BatchNorm2d(16) 96 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1, **kwargs) 97 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2, **kwargs) 98 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2, **kwargs) 99 | self.linear = nn.Linear(64, num_classes) 100 | self.relu = q.PactReLU() if kwargs['pact'] else nn.ReLU() 101 | 102 | self.apply(_weights_init) 103 | 104 | def _make_layer(self, block, planes, num_blocks, stride, **kwargs): 105 | strides = [stride] + [1]*(num_blocks-1) 106 | layers = [] 107 | for stride in strides: 108 | layers.append(block(self.in_planes, planes, stride, **kwargs)) 109 | self.in_planes = planes * block.expansion 110 | 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, x): 114 | out = self.relu(self.bn1(self.conv1(x))) 115 | out = self.layer1(out) 116 | out = self.layer2(out) 117 | out = self.layer3(out) 118 | out = F.avg_pool2d(out, out.size()[3]) 119 | out = out.view(out.size(0), -1) 120 | out = self.linear(out) 121 | return out 122 
| 123 | 124 | def resnet20(num_classes=10, **kwargs): 125 | return ResNet(BasicBlock, [3, 3, 3], num_classes=num_classes, **kwargs) 126 | 127 | 128 | def resnet32(): 129 | return ResNet(BasicBlock, [5, 5, 5]) 130 | 131 | 132 | def resnet44(): 133 | return ResNet(BasicBlock, [7, 7, 7]) 134 | 135 | 136 | def resnet56(): 137 | return ResNet(BasicBlock, [9, 9, 9]) 138 | 139 | 140 | def resnet110(): 141 | return ResNet(BasicBlock, [18, 18, 18]) 142 | 143 | 144 | def resnet1202(): 145 | return ResNet(BasicBlock, [200, 200, 200]) 146 | 147 | ''' 148 | def test(net): 149 | import numpy as np 150 | total_params = 0 151 | 152 | for x in filter(lambda p: p.requires_grad, net.parameters()): 153 | total_params += np.prod(x.data.numpy().shape) 154 | print("Total number of params", total_params) 155 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 156 | 157 | 158 | if __name__ == "__main__": 159 | for net_name in __all__: 160 | if net_name.startswith('resnet'): 161 | print(net_name) 162 | test(globals()[net_name]()) 163 | print() 164 | ''' 165 | -------------------------------------------------------------------------------- /model/quantized_cifar10_resnet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Properly implemented ResNet-s for CIFAR10 as described in paper [1]. 3 | The implementation and structure of this file are hugely influenced by [2], 4 | which is implemented for ImageNet and doesn't have option A for identity. 5 | Moreover, most of the implementations on the web are copy-paste from 6 | torchvision's resnet and have the wrong number of params. 7 | Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following 8 | numbers of layers and parameters: 9 | name | layers | params 10 | ResNet20 | 20 | 0.27M 11 | ResNet32 | 32 | 0.46M 12 | ResNet44 | 44 | 0.66M 13 | ResNet56 | 56 | 0.85M 14 | ResNet110 | 110 | 1.7M 15 | ResNet1202| 1202 | 19.4M 16 | which this implementation indeed has. 17 | Reference: 18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 20 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 21 | If you use this implementation in your work, please don't forget to mention the 22 | author, Yerlan Idelbayev. 
23 | ''' 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | import utils.pg_utils as q 28 | 29 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202'] 30 | 31 | def _weights_init(m): 32 | classname = m.__class__.__name__ 33 | #print(classname) 34 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 35 | nn.init.kaiming_normal_(m.weight) 36 | 37 | class LambdaLayer(nn.Module): 38 | def __init__(self, lambd): 39 | super(LambdaLayer, self).__init__() 40 | self.lambd = lambd 41 | 42 | def forward(self, x): 43 | return self.lambd(x) 44 | 45 | 46 | class BasicBlock(nn.Module): 47 | expansion = 1 48 | 49 | def __init__(self, in_planes, planes, stride=1, option='A', **kwargs): 50 | super(BasicBlock, self).__init__() 51 | self.conv1 = q.QuantizedConv2d(in_planes, planes, kernel_size=3, 52 | stride=stride, padding=1, bias=False, 53 | wbits=kwargs['wbits'], abits=kwargs['abits']) 54 | self.bn1 = nn.BatchNorm2d(planes) 55 | self.conv2 = q.QuantizedConv2d(planes, planes, kernel_size=3, 56 | stride=1, padding=1, bias=False, 57 | wbits=kwargs['wbits'], abits=kwargs['abits']) 58 | self.bn2 = nn.BatchNorm2d(planes) 59 | self.relu = q.PactReLU() if kwargs['pact'] else nn.ReLU() 60 | 61 | self.shortcut = nn.Sequential() 62 | if stride != 1 or in_planes != planes: 63 | if option == 'A': 64 | """ 65 | For CIFAR10 ResNet paper uses option A. 66 | """ 67 | self.shortcut = LambdaLayer(lambda x: 68 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 69 | elif option == 'B': 70 | self.shortcut = nn.Sequential( 71 | q.QuantizedConv2d(in_planes, self.expansion * planes, 72 | kernel_size=1, stride=stride, bias=False, 73 | wbits=kwargs['wbits'], abits=kwargs['abits']), 74 | nn.BatchNorm2d(self.expansion * planes) 75 | ) 76 | 77 | def forward(self, x): 78 | out = self.relu(self.bn1(self.conv1(x))) 79 | out = self.bn2(self.conv2(out)) 80 | out += self.shortcut(x) 81 | out = self.relu(out) 82 | return out 83 | 84 | 85 | class ResNet(nn.Module): 86 | def __init__(self, block, num_blocks, num_classes=10, **kwargs): 87 | super(ResNet, self).__init__() 88 | self.in_planes = 16 89 | 90 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 91 | self.bn1 = nn.BatchNorm2d(16) 92 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1, **kwargs) 93 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2, **kwargs) 94 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2, **kwargs) 95 | self.linear = nn.Linear(64, num_classes) 96 | self.relu = q.PactReLU() if kwargs['pact'] else nn.ReLU() 97 | 98 | self.apply(_weights_init) 99 | 100 | def _make_layer(self, block, planes, num_blocks, stride, **kwargs): 101 | strides = [stride] + [1]*(num_blocks-1) 102 | layers = [] 103 | for stride in strides: 104 | layers.append(block(self.in_planes, planes, stride, **kwargs)) 105 | self.in_planes = planes * block.expansion 106 | 107 | return nn.Sequential(*layers) 108 | 109 | def forward(self, x): 110 | out = self.relu(self.bn1(self.conv1(x))) 111 | out = self.layer1(out) 112 | out = self.layer2(out) 113 | out = self.layer3(out) 114 | out = F.avg_pool2d(out, out.size()[3]) 115 | out = out.view(out.size(0), -1) 116 | out = self.linear(out) 117 | return out 118 | 119 | 120 | def resnet20(num_classes=10, **kwargs): 121 | return ResNet(BasicBlock, [3, 3, 3], num_classes=num_classes, **kwargs) 122 | 123 | 124 | def resnet32(): 125 | return ResNet(BasicBlock, [5, 5, 5]) 126 | 127 | 
128 | def resnet44(): 129 | return ResNet(BasicBlock, [7, 7, 7]) 130 | 131 | 132 | def resnet56(): 133 | return ResNet(BasicBlock, [9, 9, 9]) 134 | 135 | 136 | def resnet110(): 137 | return ResNet(BasicBlock, [18, 18, 18]) 138 | 139 | 140 | def resnet1202(): 141 | return ResNet(BasicBlock, [200, 200, 200]) 142 | 143 | ''' 144 | def test(net): 145 | import numpy as np 146 | total_params = 0 147 | 148 | for x in filter(lambda p: p.requires_grad, net.parameters()): 149 | total_params += np.prod(x.data.numpy().shape) 150 | print("Total number of params", total_params) 151 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 152 | 153 | 154 | if __name__ == "__main__": 155 | for net_name in __all__: 156 | if net_name.startswith('resnet'): 157 | print(net_name) 158 | test(globals()[net_name]()) 159 | print() 160 | ''' 161 | -------------------------------------------------------------------------------- /pg-cifar10.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | import torch.nn as nn 9 | # import torch.nn.functional as F 10 | import torch.optim as optim 11 | import matplotlib.pyplot as plt 12 | import utils.utils as util 13 | 14 | import numpy as np 15 | import os, time, sys 16 | import argparse 17 | 18 | import utils.pg_utils as q 19 | 20 | #torch.manual_seed(123123) 21 | 22 | ######################### 23 | # parameters 24 | 25 | batch_size = 128 26 | num_epoch = 200 27 | _LAST_EPOCH = -1 #last_epoch arg is useful for restart 28 | _WEIGHT_DECAY = 1e-4 29 | _ARCH = "resnet-20" 30 | this_file_path = os.path.dirname(os.path.abspath(__file__)) 31 | save_folder = os.path.join(this_file_path, 'save_CIFAR10_model') 32 | ######################### 33 | 34 | 35 | #---------------------------- 36 | # Argument parser. 37 | #---------------------------- 38 | parser = argparse.ArgumentParser(description='PyTorch CIFAR-10 Training') 39 | parser.add_argument('--save', '-s', action='store_true', help='save the model') 40 | parser.add_argument('--test', '-t', action='store_true', help='test only') 41 | parser.add_argument('--path', '-p', type=str, default=None, help='saved model path') 42 | parser.add_argument('--which_gpus', '-gpu', type=str, default='0', help='which gpus to use') 43 | 44 | # quantization 45 | parser.add_argument('--wbits', '-w', type=int, default=0, help='bitwidth of weights') 46 | parser.add_argument('--abits', '-a', type=int, default=0, help='bitwidth of activations') 47 | parser.add_argument('--ispact', '-pact', action='store_true', help='activate PACT ReLU') 48 | 49 | # PG specific arguments 50 | parser.add_argument('--pbits', '-pb', type=int, default=4, help='bitwidth of predictions') 51 | parser.add_argument('--gtarget', '-gtar', type=float, default=0.0, help='gating target') 52 | parser.add_argument('--sparse_bp', '-spbp', action='store_true', help='sparse backprop of PGConv2d') 53 | parser.add_argument('--ispg', '-pg', action='store_true', help='activate precision gating') 54 | parser.add_argument('--sigma', '-sg', type=float, default=0.001, help='the penalty factor') 55 | 56 | args = parser.parse_args() 57 | 58 | 59 | #---------------------------- 60 | # Load the CIFAR-10 dataset. 
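# (Added note, not in the original file.) How the flags map onto the model:
# with the settings from scripts/train_pg_pact.sh (-w 8 -a 3 -pb 2 -pact
# -spbp), generate_model() below effectively builds
#
#   import model.pg_cifar10_resnet as m
#   net = m.resnet20(wbits=8, abits=3, pred_bits=2,
#                    sparse_bp=True, pact=True)
#
# i.e. 8-bit weights, 3-bit activations with PACT clipping, and 2-bit
# (MSB-only) prediction in every PGConv2d.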
61 | #---------------------------- 62 | 63 | def load_cifar10(): 64 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 65 | std=[0.229, 0.224, 0.225]) 66 | transform_train = transforms.Compose([ 67 | transforms.RandomHorizontalFlip(), 68 | transforms.RandomCrop(32, 4), 69 | transforms.ToTensor(), 70 | normalize 71 | ]) 72 | transform_test = transforms.Compose([ 73 | transforms.ToTensor(), 74 | normalize 75 | ]) 76 | 77 | # pin_memory=True makes transferring data from host to GPU faster 78 | trainset = torchvision.datasets.CIFAR10(root='/tmp/cifar10_data', train=True, 79 | download=True, transform=transform_train) 80 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, 81 | shuffle=True, num_workers=4, pin_memory=True) 82 | 83 | testset = torchvision.datasets.CIFAR10(root='/tmp/cifar10_data', train=False, 84 | download=True, transform=transform_test) 85 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, 86 | shuffle=False, num_workers=4, pin_memory=True) 87 | 88 | classes = ('plane', 'car', 'bird', 'cat', 89 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 90 | 91 | return trainloader, testloader, classes 92 | 93 | 94 | #---------------------------- 95 | # Define the model. 96 | #---------------------------- 97 | 98 | def generate_model(model_arch): 99 | if model_arch == 'resnet-20': 100 | if args.ispg: 101 | import model.pg_cifar10_resnet as m 102 | kwargs = {'wbits':args.wbits, 'abits':args.abits, \ 103 | 'pred_bits':args.pbits, 'sparse_bp':args.sparse_bp, \ 104 | 'pact':args.ispact} 105 | return m.resnet20(**kwargs) 106 | else: 107 | import model.quantized_cifar10_resnet as m 108 | kwargs = {'wbits':args.wbits, 'abits':args.abits, 'pact':args.ispact} 109 | return m.resnet20(**kwargs) 110 | else: 111 | raise NotImplementedError("Model architecture is not supported.") 112 | 113 | 114 | 115 | #---------------------------- 116 | # Train the network. 117 | #---------------------------- 118 | 119 | def train_model(trainloader, testloader, net, device): 120 | if torch.cuda.device_count() > 1: 121 | # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs 122 | print("Activate multi GPU support.") 123 | net = nn.DataParallel(net) 124 | net.to(device) 125 | # define the loss function 126 | criterion = (nn.CrossEntropyLoss().cuda() 127 | if torch.cuda.is_available() else nn.CrossEntropyLoss()) 128 | # Scale the lr linearly with the batch size. 
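# (Added worked example, not in the original file:) with batch_size = 256 the
# line below would give initial_lr = 0.1 * 256 / 128 = 0.2.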
129 | # Should be 0.1 when batch_size=128 130 | initial_lr = 0.1 * batch_size / 128 131 | # initialize the optimizer 132 | optimizer = optim.SGD(net.parameters(), 133 | lr=initial_lr, 134 | momentum=0.9, 135 | weight_decay=_WEIGHT_DECAY) 136 | # multiply the lr by 0.1 at epochs 100 and 150 (1/2 and 3/4 of num_epoch) 137 | div = num_epoch // 4 138 | lr_decay_milestones = [div*2, div*3] 139 | scheduler = optim.lr_scheduler.MultiStepLR( 140 | optimizer, 141 | milestones=lr_decay_milestones, 142 | gamma=0.1, 143 | last_epoch=_LAST_EPOCH) 144 | 145 | for epoch in range(num_epoch): # loop over the dataset multiple times 146 | 147 | # set printing functions 148 | batch_time = util.AverageMeter('Time/batch', ':.3f') 149 | losses = util.AverageMeter('Loss', ':6.2f') 150 | top1 = util.AverageMeter('Acc', ':6.2f') 151 | progress = util.ProgressMeter( 152 | len(trainloader), 153 | [losses, top1, batch_time], 154 | prefix="Epoch: [{}]".format(epoch+1) 155 | ) 156 | 157 | # switch the model to the training mode 158 | net.train() 159 | 160 | print('current learning rate = {}'.format(optimizer.param_groups[0]['lr'])) 161 | 162 | # each epoch 163 | end = time.time() 164 | for i, data in enumerate(trainloader, 0): 165 | # get the inputs; data is a list of [inputs, labels] 166 | inputs, labels = data[0].to(device), data[1].to(device) 167 | 168 | # zero the parameter gradients 169 | optimizer.zero_grad() 170 | 171 | # forward + backward + optimize 172 | outputs = net(inputs) 173 | loss = criterion(outputs, labels) 174 | for name, param in net.named_parameters(): 175 | if 'threshold' in name: 176 | loss += args.sigma * torch.norm(param-args.gtarget) 177 | loss.backward() 178 | optimizer.step() 179 | 180 | # measure accuracy and record loss 181 | _, batch_predicted = torch.max(outputs.data, 1) 182 | batch_accu = 100.0 * (batch_predicted == labels).sum().item() / labels.size(0) 183 | losses.update(loss.item(), labels.size(0)) 184 | top1.update(batch_accu, labels.size(0)) 185 | 186 | # measure elapsed time 187 | batch_time.update(time.time() - end) 188 | end = time.time() 189 | 190 | if i % 50 == 49: 191 | # print statistics every 50 mini-batches 192 | progress.display(i) # i = batch id in the epoch 193 | 194 | # update the learning rate 195 | scheduler.step() 196 | 197 | # print test accuracy every few epochs 198 | if epoch % 10 == 9: 199 | print('epoch {}'.format(epoch+1)) 200 | test_accu(testloader, net, device) 201 | 202 | # save the model if required 203 | if args.save: 204 | print("Saving the trained model.") 205 | util.save_models(net.state_dict(), save_folder, suffix=_ARCH) 206 | 207 | print('Finished Training') 208 | 209 | 210 | #---------------------------- 211 | # Test accuracy. 
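# (Added illustration, not in the original file; the helper name
# _pg_gating_penalty is ours.) The PG gating regularizer that train_model()
# above adds to `loss`, in isolation: unlike the fixed squared penalty in
# cg-cifar10.py, PG pulls each threshold toward the gating target with an
# L2 norm scaled by the --sigma flag.
def _pg_gating_penalty(model, gtarget, sigma):
    reg = 0.0
    for name, param in model.named_parameters():
        if 'threshold' in name:
            reg = reg + sigma * torch.norm(param - gtarget)
    return reg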
212 | #---------------------------- 213 | 214 | def test_accu(testloader, net, device): 215 | net.to(device) 216 | cnt_out = np.zeros(9) # this 9 is hardcoded for ResNet-20 217 | cnt_high = np.zeros(9) # this 9 is hardcoded for ResNet-20 218 | num_out = [] 219 | num_high = [] 220 | def _report_sparsity(m): 221 | classname = m.__class__.__name__ 222 | if isinstance(m, q.PGConv2d): 223 | num_out.append(m.num_out) 224 | num_high.append(m.num_high) 225 | 226 | correct = 0 227 | total = 0 228 | # switch the model to the evaluation mode 229 | net.eval() 230 | with torch.no_grad(): 231 | for data in testloader: 232 | images, labels = data[0].to(device), data[1].to(device) 233 | outputs = net(images) 234 | _, predicted = torch.max(outputs.data, 1) 235 | total += labels.size(0) 236 | correct += (predicted == labels).sum().item() 237 | 238 | """ calculate statistics per PG layer """ 239 | if args.ispg: 240 | net.apply(_report_sparsity) 241 | cnt_out += np.array(num_out) 242 | cnt_high += np.array(num_high) 243 | num_out = [] 244 | num_high = [] 245 | 246 | print('Accuracy of the network on the 10000 test images: %.1f %%' % ( 247 | 100 * correct / total)) 248 | if args.ispg: 249 | print('Sparsity of the update phase: %.1f %%' % (100-np.sum(cnt_high)*1.0/np.sum(cnt_out)*100)) 250 | 251 | 252 | #---------------------------- 253 | # Test accuracy per class 254 | #---------------------------- 255 | 256 | def per_class_test_accu(testloader, classes, net, device): 257 | class_correct = list(0. for i in range(10)) 258 | class_total = list(0. for i in range(10)) 259 | net.eval() 260 | with torch.no_grad(): 261 | for data in testloader: 262 | images, labels = data[0].to(device), data[1].to(device) 263 | outputs = net(images) 264 | _, predicted = torch.max(outputs, 1) 265 | c = (predicted == labels).squeeze() 266 | for i in range(labels.size(0)): # iterate over the whole batch, not just 4 samples 267 | label = labels[i] 268 | class_correct[label] += c[i].item() 269 | class_total[label] += 1 270 | 271 | 272 | for i in range(10): 273 | print('Accuracy of %5s : %.1f %%' % ( 274 | classes[i], 100 * class_correct[i] / class_total[i])) 275 | 276 | 277 | #---------------------------- 278 | # Main function. 279 | #---------------------------- 280 | 281 | def main(): 282 | os.environ["CUDA_VISIBLE_DEVICES"] = args.which_gpus 283 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 284 | print("Available GPUs: {}".format(torch.cuda.device_count())) 285 | 286 | print("Create {} model.".format(_ARCH)) 287 | net = generate_model(_ARCH) 288 | #print(net) 289 | 290 | if args.path: 291 | print("@ Load trained model from {}.".format(args.path)) 292 | net.load_state_dict(torch.load(args.path)) 293 | 294 | print("Loading the data.") 295 | trainloader, testloader, classes = load_cifar10() 296 | if args.test: 297 | print("Mode: Test only.") 298 | test_accu(testloader, net, device) 299 | else: 300 | print("Start training.") 301 | train_model(trainloader, testloader, net, device) 302 | test_accu(testloader, net, device) 303 | per_class_test_accu(testloader, classes, net, device) 304 | 305 | 306 | if __name__ == "__main__": 307 | main() 308 | 309 | 310 | 311 | 312 | 313 | 314 | ############################# 315 | # Backup code. 316 | ############################# 317 | 318 | ''' 319 | #---------------------------- 320 | # Show images in the dataset. 
321 | #---------------------------- 322 | 323 | def imshow(img): 324 | img = img / 2 + 0.5 # unnormalize 325 | npimg = img.numpy() 326 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 327 | plt.show() 328 | 329 | # get some random training images 330 | dataiter = iter(trainloader) 331 | images, labels = dataiter.next() 332 | 333 | # show images 334 | imshow(torchvision.utils.make_grid(images)) 335 | # print labels 336 | print(' '.join('%5s' % classes[labels[j]] for j in range(4))) 337 | ''' 338 | 339 | -------------------------------------------------------------------------------- /scripts/train_cg.sh: -------------------------------------------------------------------------------- 1 | python ../cg-cifar10.py -gpu 0 -gi 0.0 -gt 1.0 -pt 4 --alpha 2.0 -cg --sparse_bp --use_group 2 | -------------------------------------------------------------------------------- /scripts/train_cg_postact.sh: -------------------------------------------------------------------------------- 1 | python ../cg-cifar10.py -gpu 1 -lr 0.08 -wd 3e-4 -gi 0.0 -gt 0.6 -pt 2 -cg --use_postact --sparse_bp --use_group 2 | -------------------------------------------------------------------------------- /scripts/train_pg_pact.sh: -------------------------------------------------------------------------------- 1 | python ../pg-cifar10.py -pg -w 8 -a 3 -pb 2 -gtar 1 -spbp -pact -gpu 0 2 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cornell-zhang/dnn-gating/31666fadf35789b433c79eec8669a3a2df818bd4/utils/__init__.py -------------------------------------------------------------------------------- /utils/cg_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class SparseGreaterThan(torch.autograd.Function): 7 | """ 8 | We can implement our own custom autograd Functions by subclassing 9 | torch.autograd.Function and implementing the forward and backward passes 10 | which operate on Tensors. 11 | """ 12 | 13 | @staticmethod 14 | def forward(ctx, input): 15 | """ 16 | In the forward pass we receive a Tensor containing the input and return 17 | a Tensor containing the output. ctx is a context object that can be used 18 | to stash information for backward computation. You can cache arbitrary 19 | objects for use in the backward pass using the ctx.save_for_backward method. 20 | """ 21 | ctx.save_for_backward(input) 22 | return torch.Tensor.float(torch.gt(input, torch.zeros_like(input))) 23 | 24 | @staticmethod 25 | def backward(ctx, grad_output): 26 | """ 27 | In the backward pass we receive a Tensor containing the gradient of the loss 28 | with respect to the output, and we need to compute the gradient of the loss 29 | with respect to the input. 30 | 31 | The backward pass treats the step (greater-than) function as the identity; the sparse variant additionally zeroes the gradient where the input is negative. 32 | """ 33 | input, = ctx.saved_tensors 34 | grad_input = grad_output.clone() 35 | grad_input[input < torch.zeros_like(input)] = 0 [lines 36-109 are garbled/missing in this dump; they hold the rest of SparseGreaterThan.backward plus the GreaterThan and GreaterThanSTE autograd functions and the channel_shuffle helper that CGConv2d references below -- the stray '(0.5/alpha)' fragment belonged to the STE's clipped gradient window] 110 | return grad_input, None 111 | 112 | class CGConv2d(nn.Conv2d): 113 | """ 114 | A convolutional layer computed as out = final_sum * mask + partial_sum * (1 - mask) 115 | - final_sum Y = W * X 116 | - partial_sum Yp = Wp * Xp = (W * mask) * X 117 | - gate decision d = Yp > Delta 118 | 119 | **Note**: 120 | 1. CG predicts with the partial sum 121 | 2. 
no bias due to Batch Normalization 122 | """ 123 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 124 | padding=0, dilation=1, groups=1, bias=False, 125 | padding_mode='zeros', sparse_bp=False, use_group=False, 126 | shuffle=False, p=4, th=-6.0, alpha=2.0): 127 | super(CGConv2d, self).__init__(in_channels, out_channels, 128 | kernel_size, stride, 129 | padding, dilation, groups, 130 | bias, padding_mode) 131 | self.gt = SparseGreaterThan.apply if sparse_bp else GreaterThan.apply 132 | self.gtSTE = GreaterThanSTE.apply 133 | self.th = th 134 | self.alpha = alpha 135 | self.p = p 136 | self.bn = nn.BatchNorm2d(out_channels, affine=False) 137 | self.shuffle = shuffle 138 | 139 | """ 140 | initialize the mask for the weights 141 | """ 142 | in_chunk_size = int(in_channels/self.p) 143 | out_chunk_size = int(out_channels/self.p) 144 | 145 | mask = torch.zeros(out_channels, in_channels, kernel_size, kernel_size) 146 | if use_group: 147 | for idx in range(self.p): 148 | mask[idx*out_chunk_size:(idx+1)*out_chunk_size, idx*in_chunk_size:(idx+1)*in_chunk_size] = torch.ones(out_chunk_size, in_chunk_size, kernel_size, kernel_size) 149 | else: 150 | mask[:, 0:in_chunk_size] = torch.ones(out_channels, in_chunk_size, kernel_size, kernel_size) 151 | self.mask = nn.Parameter(mask, requires_grad=False) 152 | 153 | """ 154 | initialize the threshold with th 155 | """ 156 | self.threshold = nn.Parameter(self.th * torch.ones(1, out_channels, 1, 1)) 157 | 158 | """ number of output features """ 159 | self.num_out = 0 160 | """ n!umber of output features computed using all input channels """ 161 | self.num_full = 0 162 | 163 | def forward(self, input): 164 | """ 165 | 1. mask the weight tensor 166 | 2. compute Yp 167 | 3. generate gating decision d 168 | """ 169 | if self.shuffle: 170 | input = channel_shuffle(input, self.p) 171 | Yp = F.conv2d(input, self.weight * self.mask, self.bias, self.stride, self.padding, self.dilation, self.groups) 172 | """ Calculate the gating decison d """ 173 | d = self.gt(torch.sigmoid(self.alpha*(self.bn(Yp)-self.threshold)) - 0.5 * torch.ones_like(Yp)) 174 | """ update report """ 175 | self.num_out = d.numel() 176 | self.num_full = d[d>0].numel() 177 | """ perform full convolution """ 178 | Y = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 179 | """ combine outputs """ 180 | return Y * d + Yp * (torch.ones_like(d) - d) 181 | -------------------------------------------------------------------------------- /utils/pg_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | ########## 7 | ## PACT 8 | ########## 9 | 10 | 11 | class PactClip(torch.autograd.Function): 12 | """ 13 | We can implement our own custom autograd Functions by subclassing 14 | torch.autograd.Function and implementing the forward and backward passes 15 | which operate on Tensors. 16 | """ 17 | 18 | @staticmethod 19 | def forward(ctx, input, upper_bound): 20 | """ 21 | In the forward pass we receive a Tensor containing the input and return 22 | a Tensor containing the output. ctx is a context object that can be used 23 | to stash information for backward computation. You can cache arbitrary 24 | objects for use in the backward pass using the ctx.save_for_backward method. 
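For reference, a minimal sketch of using `CGConv2d` on its own (the shapes and hyper-parameters below are illustrative, and the snippet assumes the repo root is on `PYTHONPATH`):

```python
import torch
from utils.cg_utils import CGConv2d

# 3x3 channel-gated convolution with p=4 channel groups: the base path uses
# 1/4 of the input channels, and the learned per-channel thresholds start at th
conv = CGConv2d(16, 32, kernel_size=3, padding=1,
                p=4, th=-6.0, alpha=2.0, use_group=True, sparse_bp=True)

x = torch.randn(8, 16, 32, 32)   # (batch, channels, height, width)
y = conv(x)                      # partial sums decide where the full sums run

# fraction of output features that required all input channels
print(conv.num_full / float(conv.num_out))
```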
/utils/pg_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | 
6 | ##########
7 | ## PACT
8 | ##########
9 | 
10 | 
11 | class PactClip(torch.autograd.Function):
12 |     """
13 |     We can implement our own custom autograd Functions by subclassing
14 |     torch.autograd.Function and implementing the forward and backward passes
15 |     which operate on Tensors.
16 |     """
17 | 
18 |     @staticmethod
19 |     def forward(ctx, input, upper_bound):
20 |         """
21 |         In the forward pass we receive a Tensor containing the input and return
22 |         a Tensor containing the output. ctx is a context object that can be used
23 |         to stash information for backward computation. You can cache arbitrary
24 |         objects for use in the backward pass using the ctx.save_for_backward method.
25 | 
26 |             upper_bound   if input > upper_bound
27 |         y = input         if 0 <= input <= upper_bound
28 |             0             if input < 0
29 |         """
30 |         ctx.save_for_backward(input, upper_bound)
31 |         return torch.clamp(input, 0, upper_bound.data)
32 | 
33 |     @staticmethod
34 |     def backward(ctx, grad_output):
35 |         """
36 |         In the backward pass we receive a Tensor containing the gradient of the loss
37 |         with respect to the output, and we need to compute the gradient of the loss
38 |         with respect to the input. Inside the clipping range the gradient flows
39 |         to the input; above the range it flows to the learned upper bound.
40 |         """
41 |         input, upper_bound, = ctx.saved_tensors
42 |         grad_input = grad_output.clone()
43 |         grad_upper_bound = grad_output.clone()
44 |         grad_input[input<0] = 0
45 |         grad_input[input>upper_bound] = 0
46 |         grad_upper_bound[input<=upper_bound] = 0
47 |         return grad_input, torch.sum(grad_upper_bound)
48 | 
49 | class PactReLU(nn.Module):
50 |     def __init__(self, upper_bound=6.0):
51 |         super(PactReLU, self).__init__()
52 |         self.upper_bound = nn.Parameter(torch.tensor(upper_bound))
53 | 
54 |     def forward(self, input):
55 |         return PactClip.apply(input, self.upper_bound)
56 | 
57 | 
58 | ##########
59 | ## Mask
60 | ##########
61 | 
62 | 
63 | class SparseGreaterThan(torch.autograd.Function):
64 |     """
65 |     We can implement our own custom autograd Functions by subclassing
66 |     torch.autograd.Function and implementing the forward and backward passes
67 |     which operate on Tensors.
68 |     """
69 | 
70 |     @staticmethod
71 |     def forward(ctx, input, threshold):
72 |         """
73 |         In the forward pass we receive a Tensor containing the input and return
74 |         a Tensor containing the output. ctx is a context object that can be used
75 |         to stash information for backward computation. You can cache arbitrary
76 |         objects for use in the backward pass using the ctx.save_for_backward method.
77 |         """
78 |         ctx.save_for_backward(input, torch.tensor(threshold))
79 |         return torch.Tensor.float(torch.gt(input, threshold))
80 | 
81 |     @staticmethod
82 |     def backward(ctx, grad_output):
83 |         """
84 |         In the backward pass we receive a Tensor containing the gradient of the loss
85 |         with respect to the output, and we need to compute the gradient of the loss
86 |         with respect to the input. The gradient is passed through unchanged,
87 |         except below the threshold, where it is zeroed (sparse back-prop).
88 |         """
89 |         input, threshold, = ctx.saved_tensors
90 |         grad_input = grad_output.clone()
91 |         grad_input[input < threshold] = 0
92 |         return grad_input, None
93 | 
94 | class GreaterThan(torch.autograd.Function):
95 |     """
96 |     Same forward pass as SparseGreaterThan, with the backward pass of the
97 |     greater-than function defined as the identity function.
98 |     """
99 | 
100 |     @staticmethod
101 |     def forward(ctx, input, threshold):
102 |         return torch.Tensor.float(torch.gt(input, threshold))
103 | 
104 |     @staticmethod
105 |     def backward(ctx, grad_output):
106 |         return grad_output.clone(), None
107 | 
108 | 
109 | ##########
110 | ## Quantization
111 | ##########
112 | 
113 | 
114 | class Clamp(torch.autograd.Function):
115 |     """ Clamp the input to [min, max]; the backward pass is the identity. """
116 | 
117 |     @staticmethod
118 |     def forward(ctx, input, min, max):
119 |         return torch.clamp(input, min, max)
120 | 
121 |     @staticmethod
122 |     def backward(ctx, grad_output):
123 |         return grad_output.clone(), None, None
124 | 
125 | class Round(torch.autograd.Function):
126 |     """ Round to the nearest integer; the backward pass is the identity. """
127 | 
128 |     @staticmethod
129 |     def forward(ctx, input):
130 |         return torch.round(input)
131 | 
132 |     @staticmethod
133 |     def backward(ctx, grad_output):
134 |         return grad_output.clone()
135 | 
136 | class Floor(torch.autograd.Function):
137 |     """ Round towards minus infinity; the backward pass is the identity. """
138 | 
139 |     @staticmethod
140 |     def forward(ctx, input):
141 |         return torch.floor(input)
142 | 
143 |     @staticmethod
144 |     def backward(ctx, grad_output):
145 |         return grad_output.clone()
146 | 
147 | class TorchBinarize(nn.Module):
148 |     """ Binarize an input tensor to {-1, +1} with straight-through gradients. """
149 |     def __init__(self):
150 |         super(TorchBinarize, self).__init__()
151 | 
152 |     def forward(self, input):
153 |         """ clip the input to [-1, 1] """
154 |         input = Clamp.apply(input, -1.0, 1.0)
155 |         """ rescale to [0, 1], round to {0, 1}, then map back to {-1, +1} """
156 |         input = Round.apply((input + 1.0) / 2.0)
157 |         return input * 2.0 - 1.0
158 | 
159 | class TorchRoundToBits(nn.Module):
160 |     """
161 |     Quantize an input tensor to a fixed-point representation with the given
162 |     number of bits.
163 |     Args:
164 |         bits: Number of bits in the fixed-point (must be larger than 1)
165 |     """
166 |     def __init__(self, bits=2):
167 |         super(TorchRoundToBits, self).__init__()
168 |         assert bits > 1, "RoundToBits is only used with bitwidth larger than 1."
169 |         self.bits = bits
170 |         self.epsilon = 1e-7
171 | 
172 |     def forward(self, input):
173 |         """ extract the sign of each element """
174 |         sign = torch.sign(input).detach()
175 |         """ get the mantissa bits """
176 |         input = torch.abs(input)
177 |         scaling = torch.max(input).detach() + self.epsilon
178 |         input = Clamp.apply(input/scaling, 0.0, 1.0)
179 |         """ round the mantissa bits to the required precision """
180 |         input = Round.apply(input * (2.0**self.bits-1.0)) / (2.0**self.bits-1.0)
181 |         return input * scaling * sign
182 | 
183 | class TorchTruncate(nn.Module):
184 |     """
185 |     Quantize an input tensor to a b-bit fixed-point representation and
186 |     retain only the bh most-significant bits.
187 |     Args:
188 |         input: Input tensor
189 |         b:  Number of bits in the fixed-point
190 |         bh: Number of most-significant bits retained
191 |     """
192 |     def __init__(self, b=8, bh=4):
193 |         super(TorchTruncate, self).__init__()
194 |         assert b > 0, "Cannot truncate floating-point numbers (b=0)."
195 |         assert bh > 0, "Cannot output floating-point numbers (bh=0)."
196 |         assert b > bh, "The number of MSBs must be smaller than the total bitwidth."
197 |         self.b = b
198 |         self.bh = bh
199 |         self.epsilon = 1e-7
200 | 
201 |     def forward(self, input):
202 |         """ extract the sign of each element """
203 |         sign = torch.sign(input).detach()
204 |         """ get the mantissa bits """
205 |         input = torch.abs(input)
206 |         scaling = torch.max(input).detach() + self.epsilon
207 |         input = Clamp.apply(input/scaling, 0.0, 1.0)
208 |         """ round the mantissa bits to the required precision """
209 |         input = Round.apply(input * (2.0**self.b-1.0))
210 |         """ truncate the mantissa bits """
211 |         input = Floor.apply(input / (2**(self.b-self.bh) * 1.0))
212 |         """ rescale """
213 |         input *= (2**(self.b-self.bh) * 1.0)
214 |         input /= (2.0**self.b-1.0)
215 |         return input * scaling * sign
216 | 
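217 | # Worked example for TorchTruncate (illustrative numbers): with b=3 and
218 | # bh=2 -- the README's "3-bit activations, 2 MSBs for prediction" -- an
219 | # element x = 0.75*scaling is first rounded to 3 bits:
220 | #     round(0.75 * (2**3 - 1)) = round(5.25) = 5 = 0b101
221 | # flooring by 2**(b-bh) = 2 then clears the LSB, 0b101 -> 0b100 = 4, so the
222 | # MSB value is 4/7*scaling while full quantization gives 5/7*scaling; the
223 | # difference 1/7*scaling is the LSB residue that PGConv2d only computes
224 | # where the prediction exceeds the threshold.
225 | 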
226 | class TorchQuantize(nn.Module):
227 |     """
228 |     Quantize an input tensor to the fixed-point representation.
229 |     Args:
230 |         input: Input tensor
231 |         bits: Number of bits in the fixed-point (0 keeps floating-point)
232 |     """
233 |     def __init__(self, bits=0):
234 |         super(TorchQuantize, self).__init__()
235 |         if bits == 0:
236 |             self.quantize = nn.Identity()
237 |         elif bits == 1:
238 |             self.quantize = TorchBinarize()
239 |         else:
240 |             self.quantize = TorchRoundToBits(bits)
241 | 
242 |     def forward(self, input):
243 |         return self.quantize(input)
244 | 
245 | 
246 | ##########
247 | ## Layer
248 | ##########
249 | 
250 | 
251 | class QuantizedConv2d(nn.Conv2d):
252 |     """
253 |     A convolutional layer with its weight tensor and input tensor quantized.
254 |     """
255 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
256 |                  padding=0, dilation=1, groups=1, bias=True,
257 |                  padding_mode='zeros', wbits=0, abits=0):
258 |         super(QuantizedConv2d, self).__init__(in_channels, out_channels,
259 |                                               kernel_size, stride,
260 |                                               padding, dilation, groups,
261 |                                               bias, padding_mode)
262 |         self.quantize_w = TorchQuantize(wbits)
263 |         self.quantize_a = TorchQuantize(abits)
264 |         self.weight_rescale = \
265 |             np.sqrt(1.0/(kernel_size**2 * in_channels)) if (wbits == 1) else 1.0
266 | 
267 |     def forward(self, input):
268 |         """
269 |         1. Quantize the input tensor
270 |         2. Quantize the weight tensor
271 |         3. Rescale via McDonnell 2018 (https://arxiv.org/abs/1802.08530)
272 |         4. perform convolution
273 |         """
274 |         return F.conv2d(self.quantize_a(input),
275 |                         self.quantize_w(self.weight) * self.weight_rescale,
276 |                         self.bias, self.stride, self.padding,
277 |                         self.dilation, self.groups)
278 | 
279 | class QuantizedLinear(nn.Linear):
280 |     """
281 |     A fully connected layer with its weight tensor and input tensor quantized.
282 |     """
283 |     def __init__(self, in_features, out_features, bias=True, wbits=0, abits=0):
284 |         super(QuantizedLinear, self).__init__(in_features, out_features, bias)
285 |         self.quantize_w = TorchQuantize(wbits)
286 |         self.quantize_a = TorchQuantize(abits)
287 |         self.weight_rescale = np.sqrt(1.0/in_features) if (wbits == 1) else 1.0
288 | 
289 |     def forward(self, input):
290 |         """
291 |         1. Quantize the input tensor
292 |         2. Quantize the weight tensor
293 |         3. Rescale via McDonnell 2018 (https://arxiv.org/abs/1802.08530)
294 |         4. perform matrix multiplication
295 |         """
296 |         return F.linear(self.quantize_a(input),
297 |                         self.quantize_w(self.weight) * self.weight_rescale,
298 |                         self.bias)
299 | 
300 | class PGConv2d(nn.Conv2d):
301 |     """
302 |     A convolutional layer computed as out = out_msb + mask * out_lsb
303 |         - out_msb = I_msb * W
304 |         - mask    = (I_msb * W) > Delta
305 |         - out_lsb = I_lsb * W
306 |     out_msb calculates the prediction results.
307 |     out_lsb is only calculated where a prediction result exceeds the threshold.
308 | 
309 |     **Note**:
310 |         1. PG predicts with the pred_bits MSBs of the quantized input.
311 |         2. bias must be set to False!
312 |     """
313 |     def __init__(self, in_channels, out_channels, kernel_size, stride=1,
314 |                  padding=0, dilation=1, groups=1, bias=False,
315 |                  padding_mode='zeros', wbits=8, abits=8, pred_bits=4,
316 |                  sparse_bp=False, alpha=5):
317 |         super(PGConv2d, self).__init__(in_channels, out_channels,
318 |                                        kernel_size, stride,
319 |                                        padding, dilation, groups,
320 |                                        bias, padding_mode)
321 |         self.quantize_w = TorchQuantize(wbits)
322 |         self.quantize_a = TorchQuantize(abits)
323 |         self.trunc_a = TorchTruncate(b=abits, bh=pred_bits)
324 |         self.gt = SparseGreaterThan.apply if sparse_bp else GreaterThan.apply
325 |         self.weight_rescale = \
326 |             np.sqrt(1.0/(kernel_size**2 * in_channels)) if (wbits == 1) else 1.0
327 |         self.alpha = alpha
328 | 
329 |         """
330 |         zero initialization of the learned thresholds
331 |         (initializing them with torch.Tensor led to a NaN loss)
332 |         """
333 |         self.threshold = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
334 | 
335 |         """ number of output features """
336 |         self.num_out = 0
337 |         """ number of output features computed at high precision """
338 |         self.num_high = 0
339 | 
340 |     def forward(self, input):
341 |         """
342 |         1. Truncate the input tensor
343 |         2. Quantize the weight tensor
344 |         3. Rescale via McDonnell 2018 (https://arxiv.org/abs/1802.08530)
345 |         4. perform MSB convolution
346 |         """
347 |         out_msb = F.conv2d(self.trunc_a(input),
348 |                            self.quantize_w(self.weight) * self.weight_rescale,
349 |                            self.bias, self.stride, self.padding,
350 |                            self.dilation, self.groups)
351 |         """ calculate the mask """
352 |         mask = self.gt(torch.sigmoid(self.alpha*(out_msb-self.threshold)), 0.5)
353 |         """ update report """
354 |         self.num_out = mask.cpu().numel()
355 |         self.num_high = mask[mask>0].cpu().numel()
356 |         """ perform LSB convolution """
357 |         out_lsb = F.conv2d(self.quantize_a(input)-self.trunc_a(input),
358 |                            self.quantize_w(self.weight) * self.weight_rescale,
359 |                            self.bias, self.stride, self.padding,
360 |                            self.dilation, self.groups)
361 |         """ combine outputs """
362 |         return out_msb + mask * out_lsb
363 | 
--------------------------------------------------------------------------------
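Analogously, a minimal sketch of `PGConv2d` in isolation (illustrative shapes; the bitwidths mirror the README example of 3-bit activations with 2 MSBs for prediction, and the snippet assumes the repo root is on `PYTHONPATH`):

```python
import torch
from utils.pg_utils import PGConv2d

# 8-bit weights, 3-bit activations, 2-bit prediction phase
conv = PGConv2d(16, 32, kernel_size=3, padding=1,
                wbits=8, abits=3, pred_bits=2, sparse_bp=True)

x = torch.rand(8, 16, 32, 32)    # (batch, channels, height, width)
y = conv(x)                      # out_msb + mask * out_lsb

# fraction of output features that also ran the LSB (update) phase
print(conv.num_high / float(conv.num_out))
```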
/utils/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import os
6 | import torch
7 | 
8 | class AverageMeter(object):
9 |     """Computes and stores the average and current value"""
10 |     def __init__(self, name, fmt=':f'):
11 |         self.name = name
12 |         self.fmt = fmt
13 |         self.reset()
14 | 
15 |     def reset(self):
16 |         self.val = 0
17 |         self.avg = 0
18 |         self.sum = 0
19 |         self.count = 0
20 | 
21 |     def update(self, val, n=1):
22 |         self.val = val
23 |         self.sum += val * n
24 |         self.count += n
25 |         self.avg = self.sum / self.count
26 | 
27 |     def __str__(self):
28 |         fmtstr = '{name} {avg' + self.fmt + '}'
29 |         return fmtstr.format(**self.__dict__)
30 | 
31 | 
32 | class ProgressMeter(object):
33 |     def __init__(self, num_batches, meters, prefix=""):
34 |         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
35 |         self.meters = meters
36 |         self.prefix = prefix
37 | 
38 |     def display(self, batch):
39 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
40 |         entries += [str(meter) for meter in self.meters]
41 |         print('\t'.join(entries))
42 | 
43 |     def _get_batch_fmtstr(self, num_batches):
44 |         num_digits = len(str(num_batches))
45 |         fmt = '{:' + str(num_digits) + 'd}'
46 |         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
47 | 
48 | def save_models(model, path, suffix=''):
49 |     """Save the model to the given path
50 |     Args:
51 |         model: model to be saved
52 |         path: directory in which the model will be saved
53 |         suffix: tag appended to the checkpoint file name
54 |     """
55 |     if not os.path.exists(path):
56 |         os.makedirs(path)
57 |     file_path = os.path.join(path, "model_{}.pt".format(suffix))
58 |     torch.save(model, file_path)  # saves the whole model, not just the state_dict
59 | 
60 | def poly_decay_lr(optimizer, global_steps, total_steps, base_lr, end_lr, power):
61 |     """Sets the learning rate to be polynomially decaying"""
62 |     lr = (base_lr - end_lr) * (1 - global_steps/total_steps) ** power + end_lr
63 |     for param_group in optimizer.param_groups:
64 |         param_group['lr'] = lr
65 | 
--------------------------------------------------------------------------------
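Finally, a minimal sketch of how the helpers in `utils/utils.py` fit together in a training loop (the model, optimizer, and step counts below are placeholders, not the repo's actual training setup):

```python
import torch
import torch.nn as nn
from utils.utils import AverageMeter, ProgressMeter, poly_decay_lr

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
total_steps, base_lr, end_lr = 1000, 0.1, 0.0

losses = AverageMeter('Loss', ':.3f')
progress = ProgressMeter(total_steps, [losses], prefix='Step: ')

for step in range(total_steps):
    # polynomial decay from base_lr down to end_lr over total_steps
    poly_decay_lr(optimizer, step, total_steps, base_lr, end_lr, power=0.9)
    loss = model(torch.randn(4, 10)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.update(loss.item(), n=4)
    if step % 200 == 0:
        progress.display(step)
```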