├── Readme.md
├── __init__.py
├── environment.yml
├── eval.py
├── fig
│   ├── conv1.weight.png
│   ├── fc.weight.png
│   ├── layer1.0.conv1.weight.png
│   ├── layer1.0.conv2.weight.png
│   ├── layer1.1.conv1.weight.png
│   ├── layer1.1.conv2.weight.png
│   ├── layer1.2.conv1.weight.png
│   ├── layer1.2.conv2.weight.png
│   ├── layer1.3.conv1.weight.png
│   ├── layer1.3.conv2.weight.png
│   ├── layer1.4.conv1.weight.png
│   ├── layer1.4.conv2.weight.png
│   ├── layer2.0.conv1.weight.png
│   ├── layer2.0.conv2.weight.png
│   ├── layer2.1.conv1.weight.png
│   ├── layer2.1.conv2.weight.png
│   ├── layer2.2.conv1.weight.png
│   ├── layer2.2.conv2.weight.png
│   ├── layer2.3.conv1.weight.png
│   ├── layer2.3.conv2.weight.png
│   ├── layer2.4.conv1.weight.png
│   ├── layer2.4.conv2.weight.png
│   ├── layer3.0.conv1.weight.png
│   ├── layer3.0.conv2.weight.png
│   ├── layer3.1.conv1.weight.png
│   ├── layer3.1.conv2.weight.png
│   ├── layer3.2.conv1.weight.png
│   ├── layer3.2.conv2.weight.png
│   ├── layer3.3.conv1.weight.png
│   ├── layer3.3.conv2.weight.png
│   ├── layer3.4.conv1.weight.png
│   └── layer3.4.conv2.weight.png
├── hyperparmeter_search.py
├── logger.py
├── models
│   ├── ResNet150.pth
│   ├── ResNet82.pth
│   ├── W4_ResNet.pth
│   ├── __init__.py
│   └── cifar100_models.py
├── our_network.py
├── quant_module.py
├── quant_utils.py
├── train.py
├── utils.py
└── visualize.py

/Readme.md:
--------------------------------------------------------------------------------
1 | # Position-based Scaled Gradient for Model Quantization and Pruning (NeurIPS 2020)
2 |
3 | This repository is the official implementation of [Position-based Scaled Gradient for Model Quantization and Pruning](https://papers.nips.cc/paper/2020/hash/eb1e78328c46506b46a4ac4a1e378b91-Abstract.html).
4 |
5 | The source code reproduces the results of Figure 1 of the original paper and Table 1A of Appendix A.
6 |
7 | > The repository provides code for training and visualization, as well as pre-trained models.
8 |
9 | ## Requirements
10 |
11 | To install the requirements using [environment.yml](environment.yml), refer to the [documentation](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-from-an-environment-yml-file).
12 |
13 | ```
14 | torch=1.4.0
15 | tensorboard=2.1.1
16 | ```
17 |
18 | > Manually installing the above packages will also work. Some visualization packages (matplotlib and seaborn) are also needed for running [visualize.py](visualize.py).
19 |
20 | ## Training
21 |
22 | [train.py](train.py) is the code for training **with PSGD** after the first learning-rate decay. To train the model(s) in the paper, run this command:
23 |
24 | ```train
25 | python train.py --arch <ARCH> --seed <SEED> --cu_num <GPU_ID>
26 | --lr <LR> --epoch <EPOCHS> --decay_epoch <DECAY_EPOCHS> --adam <0|1>
27 | --w_bit <W_BIT> --lambda_s <LAMBDA_S> --first_last_quant <0|1> --a_bit <A_BIT>
28 | --load_pretrained <CHECKPOINT_PATH>
29 |
30 | # The results from the original paper can be reproduced by running:
31 | python train.py --arch ResNet32 --load_pretrained models/ResNet82.pth --lr 0.01 --epoch 150 --decay_epoch 123 --adam 0 \
32 | --w_bit 4 --lambda_s 150 --first_last_quant 1 --seed 1 --cu_num 0
33 | ```
34 |
35 | > The code for training the vanilla pre-trained model is not included, but the checkpoint file is provided.
36 | > One can train from scratch or continue training from a vanilla model; the latter is the method used in the original paper.
37 | >
38 | > Please refer to the original paper and Appendix C for further training details.
39 | >
40 | > DATASET_PATH in the source files may need to be modified.
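> As background for the command above: PSGD leaves the full-precision forward pass untouched and rescales each weight's gradient in proportion to its distance from the nearest quantization grid point (see `min_max_post_training_symmetric` in [quant_module.py](quant_module.py)). Below is a minimal sketch of that scaling, assuming the restricted-range symmetric quantizer used in the code; the function and variable names are illustrative, not part of the repository:

```python
import torch

def psg_scaled_grad(w, grad, n_bits=4, lambda_s=150.0):
    # Restricted-range symmetric min-max quantizer: integer grid
    # [-(2**(b-1) - 1), 2**(b-1) - 1] with delta = max|w| / (2**(b-1) - 1).
    q_max = 2 ** (n_bits - 1) - 1
    delta = w.abs().max() / q_max
    w_q = torch.round(w / delta).clamp(-q_max, q_max) * delta
    # PSG: scale the gradient elementwise by lambda_s * |w - Q(w)|.
    return grad * lambda_s * (w - w_q).abs()
```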
41 |
42 | ## Evaluation
43 |
44 | To evaluate the model with specified bit-widths, run:
45 |
46 | ```eval
47 | python eval.py --arch <ARCH> --model_path <MODEL_PATH> --cu_num <GPU_ID>
48 | --w_bit <W_BIT> --a_bit <A_BIT> --lambda_s <LAMBDA_S> --first_last_quant <0|1>
49 | --act_quant <0|1> --act_clipping <0|1> --clipping_range <RANGE>
50 |
51 | # The results from the original paper can be reproduced by running:
52 | python eval.py --arch ResNet32 --model_path models/W4_ResNet.pth --w_bit 4 --a_bit 4 --lambda_s 150 --first_last_quant 1 \
53 | --act_quant 0 --act_clipping 0 --cu_num 0
54 | ```
55 |
56 | > When activations are also quantized, the best clipping range can be searched for on the training set.
57 | > For more details, please refer to Section 5 of [Data-Free Quantization Through Weight Equalization and Bias Correction](https://arxiv.org/abs/1906.04721).
58 |
59 | ## Pre-trained Models
60 |
61 | We provide three pre-trained models:
62 |
63 | - models/W4_ResNet.pth : PSGD-trained targeting 4-bit weights (150 epochs)
64 | - models/ResNet150.pth : SGD-trained (150 epochs)
65 | - models/ResNet82.pth : SGD-trained (82 epochs), used as the starting point for PSGD training
66 |
67 | > For more implementation details, please refer to Appendix A and Appendix C.
68 |
69 | ## Results
70 |
71 | For results on other bit-widths, refer to Table 1A of Appendix A:
72 |
73 | ### Accuracy on CIFAR-100
74 |
75 | | Model     | Full precision | W4A32  |
76 | | --------- | -------------- | ------ |
77 | | ResNet-32 | 70.08%         | 69.57% |
78 |
79 |
80 | ## Visualizations
81 |
82 | To visualize the weight distributions, run [visualize.py](visualize.py).
83 | > Change the paths in the source code to visualize models other than the provided ones.
84 | > Default output directory: 'visualizations/'
85 |
86 | This saves weight-distribution figures for all convolutional layers of the provided PSGD- and SGD-trained models in the 'visualizations' folder.
87 |
88 | ![conv1.weight](fig/conv1.weight.png)
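The plotting code itself lives in [visualize.py](visualize.py); as a rough sketch of the idea, the weight distributions of the provided PSGD and SGD checkpoints can be compared along the following lines. The checkpoint-unwrapping logic mirrors eval.py, but the parameter name and the seaborn calls are assumptions, not the script's actual code:

```python
import os
import torch
import seaborn as sns
import matplotlib.pyplot as plt

# PSGD checkpoints store a 'state_dict' key; plain checkpoints may be raw
# state dicts (assumption based on the handling in eval.py).
psgd = torch.load('models/W4_ResNet.pth', map_location='cpu')
sgd = torch.load('models/ResNet150.pth', map_location='cpu')
psgd_sd = psgd['state_dict'] if 'state_dict' in psgd else psgd
sgd_sd = sgd['state_dict'] if 'state_dict' in sgd else sgd

os.makedirs('visualizations', exist_ok=True)
name = 'conv1.weight'  # any parameter present in both state dicts
sns.kdeplot(psgd_sd[name].flatten().numpy(), label='PSGD')
sns.kdeplot(sgd_sd[name].flatten().numpy(), label='SGD')
plt.legend()
plt.savefig(os.path.join('visualizations', name + '.png'))
```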
89 | 90 | ## Citation 91 | Please refer to the following citation if this work is useful for your research. 92 | 93 | ### Bibtex: 94 | 95 | ``` 96 | @misc{kim2020positionbased, 97 | title={Position-based Scaled Gradient for Model Quantization and Pruning}, 98 | author={Jangho Kim and KiYoon Yoo and Nojun Kwak}, 99 | year={2020}, 100 | eprint={2005.11035}, 101 | archivePrefix={arXiv}, 102 | primaryClass={cs.CV} 103 | } 104 | ``` -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/__init__.py -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: nips2020 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _pytorch_select=0.2=gpu_0 8 | - absl-py=0.9.0=py37hc8dfbb8_1 9 | - blas=1.0=mkl 10 | - blinker=1.4=py_1 11 | - brotlipy=0.7.0=py37h8f50634_1000 12 | - c-ares=1.15.0=h516909a_1001 13 | - ca-certificates=2020.1.1=0 14 | - cachetools=4.1.0=py_1 15 | - certifi=2020.4.5.1=py37_0 16 | - cffi=1.14.0=py37h2e261b9_0 17 | - chardet=3.0.4=py37hc8dfbb8_1006 18 | - click=7.1.2=pyh9f0ad1d_0 19 | - cryptography=2.9.2=py37hb09aad4_0 20 | - cudatoolkit=10.1.243=h6bb024c_0 21 | - cudnn=7.6.5=cuda10.1_0 22 | - cycler=0.10.0=py37_0 23 | - dbus=1.13.14=hb2f20db_0 24 | - expat=2.2.6=he6710b0_0 25 | - fontconfig=2.13.0=h9420a91_0 26 | - freetype=2.9.1=h8a8886c_1 27 | - glib=2.63.1=h5a9c865_0 28 | - google-auth=1.16.0=pyh9f0ad1d_0 29 | - google-auth-oauthlib=0.4.1=py_2 30 | - grpcio=1.27.2=py37hf8bcb03_0 31 | - gst-plugins-base=1.14.0=hbbd80ab_1 32 | - gstreamer=1.14.0=hb453b48_1 33 | - icu=58.2=he6710b0_3 34 | - idna=2.9=py_1 35 | - importlib-metadata=1.6.1=py37hc8dfbb8_0 36 | - intel-openmp=2020.1=217 37 | - jpeg=9b=h024ee3a_2 38 | - kiwisolver=1.2.0=py37hfd86e86_0 39 | - ld_impl_linux-64=2.33.1=h53a641e_7 40 | - libedit=3.1.20181209=hc058e9b_0 41 | - libffi=3.2.1=hd88cf55_4 42 | - libgcc-ng=9.1.0=hdf63c60_0 43 | - libgfortran-ng=7.3.0=hdf63c60_0 44 | - libpng=1.6.37=hbc83047_0 45 | - libprotobuf=3.12.3=h8b12597_0 46 | - libstdcxx-ng=9.1.0=hdf63c60_0 47 | - libtiff=4.1.0=h2733197_1 48 | - libuuid=1.0.3=h1bed415_2 49 | - libxcb=1.13=h1bed415_1 50 | - libxml2=2.9.9=hea5a465_1 51 | - lz4-c=1.9.2=he6710b0_0 52 | - markdown=3.2.2=py_0 53 | - matplotlib=3.1.3=py37_0 54 | - matplotlib-base=3.1.3=py37hef1b27d_0 55 | - mkl=2020.1=217 56 | - mkl-service=2.3.0=py37he904b0f_0 57 | - mkl_fft=1.0.15=py37ha843d7b_0 58 | - mkl_random=1.1.1=py37h0573a6f_0 59 | - ncurses=6.2=he6710b0_1 60 | - ninja=1.9.0=py37hfd86e86_0 61 | - numpy=1.18.1=py37h4f9e942_0 62 | - numpy-base=1.18.1=py37hde5b4d6_1 63 | - oauthlib=3.0.1=py_0 64 | - olefile=0.46=py37_0 65 | - openssl=1.1.1g=h7b6447c_0 66 | - pandas=1.0.3=py37h0573a6f_0 67 | - pcre=8.43=he6710b0_0 68 | - pillow=7.1.2=py37hb39fc2d_0 69 | - pip=20.0.2=py37_3 70 | - protobuf=3.12.3=py37h3340039_0 71 | - pyasn1=0.4.8=py_0 72 | - pyasn1-modules=0.2.7=py_0 73 | - pycparser=2.20=py_0 74 | - pyjwt=1.7.1=py_0 75 | - pyopenssl=19.1.0=py_1 76 | - pyparsing=2.4.7=py_0 77 | - pyqt=5.9.2=py37h05f1152_2 78 | - pysocks=1.7.1=py37hc8dfbb8_1 79 | - python=3.7.6=h0371630_2 80 | - python-dateutil=2.8.1=py_0 81 | - python_abi=3.7=1_cp37m 82 | - pytorch=1.4.0=cuda101py37h02f0884_0 83 | - pytz=2020.1=py_0 84 | - 
qt=5.9.7=h5867ecd_1 85 | - readline=7.0=h7b6447c_5 86 | - requests=2.23.0=pyh8c360ce_2 87 | - requests-oauthlib=1.2.0=py_0 88 | - rsa=4.0=py_0 89 | - scipy=1.4.1=py37h0b6359f_0 90 | - seaborn=0.10.1=py_0 91 | - setuptools=47.1.1=py37_0 92 | - sip=4.19.8=py37hf484d3e_0 93 | - six=1.15.0=py_0 94 | - sqlite=3.31.1=h62c20be_1 95 | - tensorboard=2.1.1=py_1 96 | - tk=8.6.8=hbc83047_0 97 | - torchvision=0.2.1=py37_0 98 | - tornado=6.0.4=py37h7b6447c_1 99 | - urllib3=1.25.9=py_0 100 | - werkzeug=1.0.1=pyh9f0ad1d_0 101 | - wheel=0.34.2=py37_0 102 | - xz=5.2.5=h7b6447c_0 103 | - zipp=3.1.0=py_0 104 | - zlib=1.2.11=h7b6447c_3 105 | - zstd=1.4.4=h0b5b093_3 106 | prefix: /home/miniconda3/envs/nips2020 107 | 108 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Change the dataset path 2 | DATASET_PATH ='~/data' 3 | 4 | import os 5 | import argparse 6 | import time 7 | from datetime import datetime 8 | import json 9 | import warnings 10 | warnings.filterwarnings("ignore") 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.autograd import Variable 15 | 16 | import utils 17 | import our_network 18 | import quant_utils 19 | 20 | parser = argparse.ArgumentParser(description='Test for CIFAR10/100') 21 | parser.add_argument('--arch', metavar='ARCH', default='ResNet32', choices=['ResNet32', 'Vgg16_bn']) 22 | parser.add_argument('--text', default='result.txt', type=str) 23 | parser.add_argument('--exp_name', default='cifar100', type=str) 24 | parser.add_argument('--log_time', default='1', type=str) 25 | parser.add_argument('--model_path', default='models/W4_ResNet.pth', type=str) 26 | 27 | parser.add_argument('--w_bit', default='4', type=int) 28 | parser.add_argument('--lambda_s', default='150', type=float) # For logging purposes 29 | parser.add_argument('--a_bit', default='8', type=float) 30 | parser.add_argument('--first_last_quant', default=1, type=int) 31 | parser.add_argument('--act_quant', default=0, type=int) 32 | parser.add_argument('--act_clipping', default=0, type=int) 33 | parser.add_argument('--clipping_range', default=6, type=int) 34 | parser.add_argument('--cu_num', default='0', type=str) 35 | 36 | args = parser.parse_args() 37 | print(args) 38 | os.environ['CUDA_VISIBLE_DEVICES'] = args.cu_num 39 | 40 | lambda_s = args.lambda_s 41 | w_bits = args.w_bit 42 | a_bits = args.a_bit 43 | act_quant = True if args.act_quant else False 44 | fl_quant = True if args.first_last_quant else False 45 | act_clipping = True if args.act_clipping else False 46 | clipping_range = args.clipping_range 47 | DEVICE = torch.device("cuda") 48 | EXPERIMENT_NAME = args.exp_name 49 | 50 | trainloader, valloader, testloader = utils.get_cifar100_dataloaders(128, 100) 51 | model = our_network.ResNet32(w_bits, a_bits, lambda_s, use_fp=True, activation_quant=False, quant_first_last=fl_quant) 52 | lp_net = our_network.ResNet32(w_bits, a_bits, lambda_s, use_fp=False, activation_quant=act_quant, quant_first_last=fl_quant) 53 | 54 | states = torch.load(args.model_path, map_location=DEVICE) 55 | utils.load_checkpoint(model, states) 56 | if 'state_dict' in states.keys(): 57 | epoch = states['epoch'] 58 | else: 59 | epoch = 0 60 | model.to(DEVICE) 61 | criterion_CE = nn.CrossEntropyLoss() 62 | 63 | def test(net): 64 | epoch_start_time = time.time() 65 | net.eval() 66 | test_loss = 0 67 | correct = 0 68 | total = 0 69 | criterion_CE = nn.CrossEntropyLoss() 70 | 71 | for batch_idx, (inputs, 
targets) in enumerate(testloader): 72 | inputs, targets = inputs.to(DEVICE), targets.to(DEVICE) 73 | inputs, targets = Variable(inputs), Variable(targets) 74 | outputs = net(inputs) 75 | 76 | loss = criterion_CE(outputs, targets) 77 | 78 | test_loss += loss.item() 79 | _, predicted = torch.max(outputs.data, 1) 80 | total += targets.size(0) 81 | correct += predicted.eq(targets.data).cpu().sum().float().item() 82 | b_idx = batch_idx 83 | 84 | print('Test \t Time Taken: %.2f sec' % (time.time() - epoch_start_time)) 85 | print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss / (b_idx + 1), 100. * correct / total, correct, total)) 86 | return test_loss / (b_idx + 1), correct / total 87 | 88 | def test_LP(): 89 | global my_list 90 | net = lp_net 91 | utils.load_checkpoint(net, states) 92 | if act_quant and act_clipping: 93 | _ = quant_utils.search_bn(net, clipping_range) 94 | net.to(DEVICE) 95 | print("Low Precision: ") 96 | test_loss, acc = test(net) 97 | return test_loss, acc 98 | 99 | def test_FP(): 100 | net = model 101 | net.to(DEVICE) 102 | print("Full Precision: ") 103 | test_loss, acc = test(net) 104 | return test_loss, acc 105 | 106 | if __name__ == '__main__': 107 | time_log = datetime.now().strftime('%m-%d %H:%M') 108 | if int(args.log_time) : 109 | folder_name = 'W{}A{}Scale{}_{}'.format(w_bits, a_bits, lambda_s, time_log) 110 | else: 111 | folder_name = 'W{}A{}Scale{}'.format(w_bits, a_bits, lambda_s) 112 | 113 | path = os.path.join(EXPERIMENT_NAME, folder_name) 114 | if not os.path.exists('results/' + path): 115 | os.makedirs('results/' + path) 116 | # Save argparse arguments as logging 117 | with open('results/{}/commandline_args.txt'.format(path), 'w') as f: 118 | json.dump(args.__dict__, f, indent=2) 119 | 120 | f = open(os.path.join("results/"+ path, args.text), "a") 121 | test_loss_LP, accuracy_LP = test_LP() 122 | test_loss_FP, accuracy_FP = test_FP() 123 | 124 | f.write('{} \t EPOCH {epoch} \t' 125 | 'FP Acc {:.4f} \t LP Acc {:.4f} \n'.format( 126 | time_log, accuracy_FP, accuracy_LP, epoch=epoch)) 127 | f.close() 128 | 129 | -------------------------------------------------------------------------------- /fig/conv1.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/conv1.weight.png -------------------------------------------------------------------------------- /fig/fc.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/fc.weight.png -------------------------------------------------------------------------------- /fig/layer1.0.conv1.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer1.0.conv1.weight.png -------------------------------------------------------------------------------- /fig/layer1.0.conv2.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer1.0.conv2.weight.png -------------------------------------------------------------------------------- /fig/layer1.1.conv1.weight.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer1.1.conv1.weight.png -------------------------------------------------------------------------------- /fig/layer1.1.conv2.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer1.1.conv2.weight.png -------------------------------------------------------------------------------- /fig/layer1.2.conv1.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer1.2.conv1.weight.png -------------------------------------------------------------------------------- /fig/layer1.2.conv2.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer1.2.conv2.weight.png -------------------------------------------------------------------------------- /fig/layer1.3.conv1.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer1.3.conv1.weight.png -------------------------------------------------------------------------------- /fig/layer1.3.conv2.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer1.3.conv2.weight.png -------------------------------------------------------------------------------- /fig/layer1.4.conv1.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer1.4.conv1.weight.png -------------------------------------------------------------------------------- /fig/layer1.4.conv2.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer1.4.conv2.weight.png -------------------------------------------------------------------------------- /fig/layer2.0.conv1.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer2.0.conv1.weight.png -------------------------------------------------------------------------------- /fig/layer2.0.conv2.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer2.0.conv2.weight.png -------------------------------------------------------------------------------- /fig/layer2.1.conv1.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer2.1.conv1.weight.png -------------------------------------------------------------------------------- /fig/layer2.1.conv2.weight.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer2.1.conv2.weight.png -------------------------------------------------------------------------------- /fig/layer2.2.conv1.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer2.2.conv1.weight.png -------------------------------------------------------------------------------- /fig/layer2.2.conv2.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer2.2.conv2.weight.png -------------------------------------------------------------------------------- /fig/layer2.3.conv1.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer2.3.conv1.weight.png -------------------------------------------------------------------------------- /fig/layer2.3.conv2.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer2.3.conv2.weight.png -------------------------------------------------------------------------------- /fig/layer2.4.conv1.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer2.4.conv1.weight.png -------------------------------------------------------------------------------- /fig/layer2.4.conv2.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer2.4.conv2.weight.png -------------------------------------------------------------------------------- /fig/layer3.0.conv1.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer3.0.conv1.weight.png -------------------------------------------------------------------------------- /fig/layer3.0.conv2.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer3.0.conv2.weight.png -------------------------------------------------------------------------------- /fig/layer3.1.conv1.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer3.1.conv1.weight.png -------------------------------------------------------------------------------- /fig/layer3.1.conv2.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer3.1.conv2.weight.png -------------------------------------------------------------------------------- /fig/layer3.2.conv1.weight.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer3.2.conv1.weight.png -------------------------------------------------------------------------------- /fig/layer3.2.conv2.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer3.2.conv2.weight.png -------------------------------------------------------------------------------- /fig/layer3.3.conv1.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer3.3.conv1.weight.png -------------------------------------------------------------------------------- /fig/layer3.3.conv2.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer3.3.conv2.weight.png -------------------------------------------------------------------------------- /fig/layer3.4.conv1.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer3.4.conv1.weight.png -------------------------------------------------------------------------------- /fig/layer3.4.conv2.weight.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/fig/layer3.4.conv2.weight.png -------------------------------------------------------------------------------- /hyperparmeter_search.py: -------------------------------------------------------------------------------- 1 | # Change the dataset path 2 | DATASET_PATH ='~/data' 3 | 4 | import argparse 5 | import json 6 | import time 7 | from datetime import datetime 8 | import warnings 9 | import os 10 | warnings.filterwarnings("ignore") 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | 16 | from logger import SummaryLogger 17 | import utils 18 | import our_network 19 | 20 | 21 | parser = argparse.ArgumentParser(description='Quantization finetuning for CIFAR100') 22 | parser.add_argument('--arch', metavar='ARCH', default='ResNet32', choices=['ResNet32', 'Vgg16_bn']) 23 | parser.add_argument('--text', default='log.txt', type=str) 24 | parser.add_argument('--exp_name', default='cifar100/hyperparameter/', type=str) 25 | parser.add_argument('--log_time', default='1', type=str) 26 | parser.add_argument('--lr', default='0.01', type=float) # By default 1e-4 for Adam // 1e-2 for SGD when starting from EPOCH 82 27 | parser.add_argument('--resume_epoch', default='83', type=int) 28 | parser.add_argument('--epoch', default='150', type=int) 29 | parser.add_argument('--decay_epoch', default=[123], nargs="*", type=int) 30 | parser.add_argument('--w_decay', default='1e-4', type=float) 31 | parser.add_argument('--adam', default='0', type=float) 32 | parser.add_argument('--cu_num', default='0', type=str) 33 | parser.add_argument('--seed', default='1', type=str) 34 | 35 | parser.add_argument('--load_pretrained', default='models/ResNet82.pth', type=str) 36 | parser.add_argument('--save_model', default='ckpt.t7', type=str) 37 | 38 | parser.add_argument('--w_bit', default='2', type=int) 39 | 
parser.add_argument('--lambda_s', default='150', type=float) 40 | parser.add_argument('--a_bit', default='4', type=float) 41 | parser.add_argument('--first_last_quant', default=1, type=int) 42 | 43 | torch.backends.cudnn.deterministic = True 44 | torch.backends.cudnn.benchmark = False 45 | 46 | args = parser.parse_args() 47 | print(args) 48 | 49 | torch.manual_seed(int(args.seed)) 50 | os.environ['CUDA_VISIBLE_DEVICES'] = args.cu_num 51 | trainloader, valloader, testloader = utils.get_cifar100_dataloaders_disjoint(128, 100) 52 | 53 | #Quantization parameters 54 | base_lr = args.lr 55 | lambda_s = args.lambda_s 56 | w_bits = args.w_bit 57 | a_bits = args.a_bit 58 | fl_quant = True if args.first_last_quant else False 59 | 60 | #Other parameters 61 | DEVICE = torch.device("cuda") 62 | RESUME_EPOCH = args.resume_epoch 63 | DECAY_EPOCH = args.decay_epoch 64 | DECAY_EPOCH = [ep - RESUME_EPOCH for ep in DECAY_EPOCH] 65 | FINAL_EPOCH = args.epoch 66 | EXPERIMENT_NAME = args.exp_name 67 | W_DECAY = args.w_decay 68 | USE_ADAM = int(args.adam) 69 | if w_bits == 2: 70 | print("*" * 20) 71 | print("W_DECAY set to 0") 72 | print("*" * 20) 73 | W_DECAY = 0 74 | 75 | model = our_network.__dict__[args.arch](w_bits, a_bits, lambda_s, use_fp=True, activation_quant=False, quant_first_last=fl_quant) 76 | 77 | if len(args.load_pretrained) > 2 : 78 | path = args.load_pretrained 79 | state = torch.load(path) 80 | utils.load_checkpoint(model, state) 81 | 82 | model.to(DEVICE) 83 | 84 | if not USE_ADAM: 85 | optimizer = optim.SGD(model.parameters(), lr=base_lr, nesterov=False, momentum=0.9, weight_decay=W_DECAY) 86 | else: 87 | print("*" *20) 88 | print("Using Adam as optimizer...") 89 | print("*" *20) 90 | base_lr *= 1e-2 * lambda_s 91 | optimizer = optim.Adam(model.parameters(), lr=base_lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=W_DECAY) 92 | optimizer.load_state_dict(state['optimizer']) 93 | 94 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=DECAY_EPOCH, gamma=0.1) 95 | criterion_CE = nn.CrossEntropyLoss() 96 | 97 | def eval(net, test_flag=False): 98 | loader = valloader if not test_flag else testloader 99 | flag = 'Val.' if not test_flag else 'Test' 100 | 101 | epoch_start_time = time.time() 102 | net.eval() 103 | val_loss = 0 104 | correct = 0 105 | total = 0 106 | criterion_CE = nn.CrossEntropyLoss() 107 | 108 | for batch_idx, (inputs, targets) in enumerate(loader): 109 | inputs, targets = inputs.to(DEVICE), targets.to(DEVICE) 110 | outputs = net(inputs) 111 | 112 | loss = criterion_CE(outputs, targets) 113 | val_loss += loss.item() 114 | _, predicted = torch.max(outputs.data, 1) 115 | 116 | total += targets.size(0) 117 | correct += predicted.eq(targets.data).cpu().sum().float().item() 118 | b_idx = batch_idx 119 | 120 | print('%s \t Time Taken: %.2f sec' % (flag, time.time() - epoch_start_time)) 121 | print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (val_loss / (b_idx + 1), 100. 
* correct / total, correct, total)) 122 | return val_loss / (b_idx + 1), correct / total 123 | 124 | def train(model, epoch): 125 | epoch_start_time = time.time() 126 | print('\n EPOCH: %d' % epoch) 127 | model.train() 128 | 129 | train_loss = 0 130 | correct = 0 131 | total = 0 132 | 133 | global optimizer 134 | 135 | for batch_idx, (inputs, targets) in enumerate(trainloader): 136 | inputs, targets = inputs.to(DEVICE), targets.to(DEVICE) 137 | optimizer.zero_grad() 138 | outputs = model(inputs) 139 | 140 | loss = criterion_CE(outputs, targets) 141 | loss.backward() 142 | 143 | optimizer.step() 144 | train_loss += loss.item() 145 | 146 | _, predicted = torch.max(outputs.data, 1) 147 | total += targets.size(0) 148 | correct += predicted.eq(targets.data).cpu().sum().float().item() 149 | b_idx = batch_idx 150 | 151 | print('Train s1 \t Time Taken: %.2f sec' % (time.time() - epoch_start_time)) 152 | print('Loss: %.3f | Acc s1: %.3f%% (%d/%d)' % (train_loss / (b_idx + 1), 100. * correct / total, correct, total)) 153 | 154 | return train_loss / (b_idx + 1), correct / total 155 | 156 | def eval_LP(address, lambda_s,num_bits, test_flag=False): 157 | net = our_network.__dict__[args.arch](num_bits, a_bits, lambda_s, use_fp=False, activation_quant=False, quant_first_last=fl_quant) 158 | 159 | old_param = torch.load(address) 160 | net.load_state_dict(old_param) 161 | net.to(DEVICE) 162 | if test_flag: 163 | print("***Test***") 164 | print("Low Precision: ") 165 | val_loss, acc = eval(net, test_flag) 166 | return val_loss, acc 167 | 168 | def eval_FP(address, lambda_s,num_bits, test_flag=False): 169 | net = our_network.__dict__[args.arch](num_bits, a_bits, lambda_s, use_fp=True, activation_quant=False, quant_first_last=fl_quant) 170 | 171 | old_param = torch.load(address) 172 | net.load_state_dict(old_param) 173 | net.to(DEVICE) 174 | 175 | if test_flag: 176 | print("***Test***") 177 | print("Full Precision: ") 178 | val_loss, acc = eval(net, test_flag) 179 | return val_loss, acc 180 | 181 | if __name__ == '__main__': 182 | time_log = datetime.now().strftime('%m-%d %H:%M') 183 | if int(args.log_time) : 184 | folder_name = 'Bit{}_Scale{}_{}'.format(w_bits, lambda_s, time_log) 185 | else: 186 | folder_name = 'Bit{}_Scale{}'.format(w_bits, lambda_s) 187 | 188 | path = os.path.join(EXPERIMENT_NAME, folder_name) 189 | if not os.path.exists('ckpt/' + path): 190 | os.makedirs('ckpt/' + path) 191 | if not os.path.exists('logs/' + path): 192 | os.makedirs('logs/' + path) 193 | 194 | # Save argparse arguments as logging 195 | with open('logs/{}/commandline_args.txt'.format(path), 'w') as f: 196 | json.dump(args.__dict__, f, indent=2) 197 | # Instantiate logger 198 | logger = SummaryLogger(path) 199 | best_FP = 0 200 | best_LP = 0 201 | 202 | for epoch in range(RESUME_EPOCH, FINAL_EPOCH+1): 203 | f = open(os.path.join("logs/" + path, 'log.txt'), "a") 204 | ### Train ### 205 | train_loss, acc = train(model, epoch) 206 | scheduler.step() 207 | ### save for evaluating LP and FP ### 208 | torch.save(model.state_dict(), "ckpt/{}/temp.t7".format(path)) 209 | address = "ckpt/{}/temp.t7".format(path) 210 | ### Evaluate LP and FP models ### 211 | val_loss_LP, accuracy_LP = eval_LP(address,lambda_s,w_bits, test_flag=False) 212 | val_loss_FP, accuracy_FP = eval_FP(address,lambda_s,w_bits, test_flag=False) 213 | 214 | is_best = accuracy_FP > best_FP 215 | best_FP = max(accuracy_FP, best_FP) 216 | LP_is_best = accuracy_LP > best_LP 217 | best_LP = max(accuracy_LP, best_LP) 218 | 219 | utils.save_checkpoint({ 220 | 'epoch': 
epoch, 221 | 'state_dict': model.state_dict(), 222 | 'best_FP_acc': best_FP, 223 | 'best_LP_acc' : best_LP, 224 | 'optimizer' : optimizer.state_dict(), 225 | }, is_best, 'ckpt/' + path, filename='{}.pth'.format(epoch)) 226 | 227 | train_log = {'Loss': train_loss, 'Accuracy': acc} 228 | val_log = {'LP loss': val_loss_LP, 'LP accuracy': accuracy_LP, 229 | 'FP loss': val_loss_FP, 'FP accuracy': accuracy_FP} 230 | 231 | logger.add_scalar_group('Train', train_log, epoch) 232 | logger.add_scalar_group('Val', val_log, epoch) 233 | 234 | f.write('EPOCH {epoch} \t' 235 | 'Trainacc : {acc:.4f} \t Valacc_LP : {top1_LP:.4f}\t' 236 | 'Valacc_FP : {top1_FP:.4f} \t Bestacc_LP : {best_LP:.4f} \t' 237 | 'Bestacc_FP : {best_FP:.4f} \n'.format( 238 | epoch=epoch, acc=acc, top1_LP=accuracy_LP, top1_FP=accuracy_FP, best_LP=best_LP, best_FP=best_FP) 239 | ) 240 | f.close() 241 | 242 | 243 | print("*" * 20) 244 | print("Testing final model") 245 | print("*" * 20) 246 | test_loss_LP, test_accuracy_LP = eval_LP(address, lambda_s, w_bits, test_flag=True) 247 | test_loss_FP, test_accuracy_FP = eval_FP(address, lambda_s, w_bits, test_flag=True) 248 | f = open(os.path.join("logs/" + path, 'log.txt'), "a") 249 | f.write('Test FP : {:.4f} \t Test LP : {:.4f}'.format(test_accuracy_FP, test_accuracy_LP)) 250 | f.close() 251 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | from torch.utils.tensorboard import SummaryWriter 2 | 3 | 4 | class SummaryLogger(SummaryWriter): 5 | 6 | def __init__(self, path): 7 | super().__init__() 8 | file_path = 'logs/' + path 9 | self.logger = SummaryWriter(file_path) 10 | 11 | def add_scalar_group(self, main_tag, tag_scalar_dict, global_step): 12 | for sub_tag, scalar in tag_scalar_dict.items(): 13 | self.logger.add_scalar(main_tag+'/{}'.format(sub_tag), scalar, global_step) 14 | 15 | 16 | #################################################################### 17 | """ 18 | Convenience function for logging multiple scalars of the same group 19 | Example Below: """ 20 | # logger = SummaryLogger('logs') 21 | # 22 | # import numpy as np 23 | # scalar_dict = {'Loss1': np.random.random(), 24 | # 'Loss2': np.random.random(), 25 | # 'Loss3.': np.random.random()} 26 | # main_tag ='Train' 27 | # for i in range(100): 28 | # global_step = i 29 | # logger.add_scalar_group(main_tag, scalar_dict, global_step) 30 | # 31 | # logger.add_scalar('tag',5,5) 32 | ####################################################################### -------------------------------------------------------------------------------- /models/ResNet150.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/models/ResNet150.pth -------------------------------------------------------------------------------- /models/ResNet82.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/models/ResNet82.pth -------------------------------------------------------------------------------- /models/W4_ResNet.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jangho-Kim/PSG-pytorch/878e8ac6245703eb59148d91d723f25a0169830f/models/W4_ResNet.pth 
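The three checkpoints above can be loaded the way eval.py does; a minimal, illustrative sketch follows (not an official entry point; it assumes utils.load_checkpoint unwraps a 'state_dict' key when present, as eval.py's usage suggests):

```python
import torch
import utils
import our_network

# Build the full-precision model following eval.py's settings for the W4 checkpoint.
model = our_network.ResNet32(w_bits=4, a_bits=4, lambda_s=150,
                             use_fp=True, activation_quant=False,
                             quant_first_last=True)
states = torch.load('models/W4_ResNet.pth', map_location='cpu')
utils.load_checkpoint(model, states)
```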
-------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .cifar100_models import * 2 | 3 | -------------------------------------------------------------------------------- /models/cifar100_models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch.nn.functional as F 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.relu1 = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | 25 | self.downsample = downsample 26 | self.stride = stride 27 | 28 | def forward(self, x): 29 | residual = x 30 | out = self.conv1(x) 31 | out = self.bn1(out) 32 | out = self.relu1(out) 33 | 34 | out = self.conv2(out) 35 | out = self.bn2(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu2(out) 42 | 43 | return out 44 | 45 | 46 | class Bottleneck(nn.Module): 47 | expansion = 4 48 | 49 | def __init__(self, inplanes, planes, stride=1, downsample=None): 50 | super(Bottleneck, self).__init__() 51 | 52 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False) 57 | self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion) 58 | self.relu = nn.ReLU(inplace=True) 59 | 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | x = F.relu(x) 65 | residual = x 66 | 67 | out = self.conv1(x) 68 | out = self.bn1(out) 69 | out = self.relu(out) 70 | 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv3(out) 76 | out = self.bn3(out) 77 | if self.downsample is not None: 78 | residual = self.downsample(x) 79 | 80 | out += residual 81 | out = self.relu(out) 82 | 83 | return out 84 | 85 | 86 | class ResNet(nn.Module): 87 | def __init__(self, depth=32, num_classes=100, bottleneck=False): 88 | super(ResNet, self).__init__() 89 | self.inplanes = 16 90 | print(bottleneck) 91 | if bottleneck == True: 92 | n = int((depth - 2) / 9) 93 | block = Bottleneck 94 | else: 95 | n = int((depth - 2) / 6) 96 | block = BasicBlock 97 | 98 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 99 | self.bn1 = nn.BatchNorm2d(self.inplanes) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.layer1 = self._make_layer(block, 16, n) 102 | self.layer2 = self._make_layer(block, 32, n, stride=2) 103 | self.layer3 = self._make_layer(block, 64, n, stride=2) 104 | self.avgpool = nn.AvgPool2d(8) 105 | self.fc = nn.Linear(64 * block.expansion, num_classes) 106 | 107 | for m in self.modules(): 108 | if isinstance(m, 
nn.Conv2d): 109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 110 | m.weight.data.normal_(0, math.sqrt(2. / n)) 111 | elif isinstance(m, nn.BatchNorm2d): 112 | m.weight.data.fill_(1) 113 | m.bias.data.zero_() 114 | 115 | def _make_layer(self, block, planes, blocks, stride=1): 116 | downsample = None 117 | if stride != 1 or self.inplanes != planes * block.expansion: 118 | downsample = nn.Sequential( 119 | nn.Conv2d(self.inplanes, planes * block.expansion, 120 | kernel_size=1, stride=stride, bias=False), 121 | nn.BatchNorm2d(planes * block.expansion), 122 | ) 123 | 124 | layers = [] 125 | layers.append(block(self.inplanes, planes, stride, downsample)) 126 | self.inplanes = planes * block.expansion 127 | for i in range(1, blocks): 128 | layers.append(block(self.inplanes, planes)) 129 | 130 | return nn.Sequential(*layers) 131 | 132 | def forward(self, x): 133 | x = self.conv1(x) 134 | x = self.bn1(x) 135 | x = self.relu(x) 136 | 137 | x = self.layer1(x) 138 | x = self.layer2(x) 139 | x = self.layer3(x) 140 | x = F.relu(x) 141 | 142 | x = self.avgpool(x) 143 | x = x.view(x.size(0), -1) 144 | x = self.fc(x) 145 | return x 146 | 147 | 148 | cfg = { 149 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 150 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 151 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 152 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 153 | 512, 512, 512, 512, 'M'], 154 | } 155 | 156 | class VGG(nn.Module): 157 | ''' 158 | VGG model 159 | ''' 160 | def __init__(self, features): 161 | super(VGG, self).__init__() 162 | self.features = features 163 | self.last_layer = nn.Linear(4096,100) 164 | self.classifier = nn.Sequential( 165 | # nn.Dropout(), 166 | nn.Linear(512, 4096, bias=True), 167 | nn.ReLU(True), 168 | nn.Dropout(), 169 | nn.Linear(4096, 4096, bias=True), 170 | nn.ReLU(True), 171 | nn.Dropout(), 172 | self.last_layer, 173 | ) 174 | # Initialize weights 175 | for m in self.modules(): 176 | if isinstance(m, nn.Conv2d): 177 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 178 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 179 | # m.bias.data.zero_() 180 | 181 | 182 | def forward(self, x): 183 | x = self.features(x) 184 | x = x.view(x.size(0), -1) 185 | x = self.classifier(x) 186 | return x 187 | 188 | 189 | def make_layers(cfg, batch_norm=False): 190 | layers = [] 191 | in_channels = 3 192 | for v in cfg: 193 | if v == 'M': 194 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 195 | else: 196 | if batch_norm: 197 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False) 198 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 199 | else: 200 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 201 | layers += [conv2d, nn.ReLU(inplace=True)] 202 | in_channels = v 203 | return nn.Sequential(*layers) 204 | 205 | def vgg16_bn(): 206 | """VGG 16-layer model (configuration "D") with batch normalization""" 207 | return VGG(make_layers(cfg['D'], batch_norm=True)) 208 | -------------------------------------------------------------------------------- /our_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from quant_module import * 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | expansion = 1 9 | def __init__(self, lambda_s, w_bits, a_bits, in_planes, planes, stride=1, use_fp=True, activation_quant=True): 10 | super(BasicBlock, self).__init__() 11 | self.conv1 = Conv2d_minmax(lambda_s, w_bits, in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False, use_fp=use_fp) 12 | self.bn1 = nn.BatchNorm2d(planes) 13 | self.act1 = ReLU_quant(a_bits) if activation_quant else nn.ReLU() 14 | self.conv2 = Conv2d_minmax(lambda_s, w_bits, planes, planes, kernel_size=3, stride=1, padding=1, bias=False, use_fp=use_fp) 15 | self.bn2 = nn.BatchNorm2d(planes) 16 | self.act2 = ReLU_quant(a_bits) if activation_quant else nn.ReLU() 17 | 18 | self.downsample = nn.Sequential() 19 | if stride != 1 or in_planes != self.expansion*planes: 20 | self.downsample = nn.Sequential( 21 | Conv2d_minmax(lambda_s, w_bits, in_planes, self.expansion*planes, kernel_size=1, stride=stride, padding=0, bias=False, use_fp=use_fp), 22 | nn.BatchNorm2d(self.expansion*planes) 23 | ) 24 | 25 | def forward(self, x): 26 | out = self.act1(self.bn1(self.conv1(x))) 27 | out = self.bn2(self.conv2(out)) 28 | out += self.downsample(x) 29 | out = self.act2(out) 30 | return out 31 | 32 | 33 | class ResNet(nn.Module): 34 | def __init__(self, block, num_blocks, num_classes=100, lambda_s=1, w_bits=4, a_bits=4, use_fp=True, activation_quant=True, quant_first_last=False): 35 | super(ResNet, self).__init__() 36 | self.in_planes = 16 37 | self.w_bits = w_bits 38 | self.a_bits = a_bits 39 | self.lambda_s= lambda_s 40 | self.act_quant = activation_quant 41 | 42 | if quant_first_last: # Fixed to 8 bit for 3 bits or lower 43 | fl_bits = 8 if w_bits < 4 else w_bits 44 | self.conv1 = Conv2d_minmax(lambda_s, fl_bits, 3, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False, use_fp=use_fp) 45 | else: 46 | self.conv1 = nn.Conv2d(3, self.in_planes, kernel_size=3, stride=1, padding=1, bias=False) 47 | 48 | self.bn1 = nn.BatchNorm2d(self.in_planes) 49 | self.act = ReLU_quant(a_bits) if activation_quant else nn.ReLU() 50 | 51 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1, use_fp=use_fp) 52 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2, use_fp=use_fp) 53 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2, use_fp=use_fp) 54 | self.avgpool = 
nn.AvgPool2d(kernel_size=8, stride=1) 55 | 56 | if quant_first_last: 57 | self.fc = Linear_minmax(lambda_s, fl_bits, 64 * block.expansion, num_classes, use_fp=use_fp) 58 | else: 59 | self.fc = nn.Linear(64, num_classes) 60 | 61 | def _make_layer(self, block, planes, num_blocks, stride, use_fp): 62 | layers = [] 63 | layers.append(block(self.lambda_s, self.w_bits, self.a_bits, self.in_planes, planes, stride, use_fp, self.act_quant)) 64 | self.in_planes = planes * block.expansion 65 | for _ in range(1, num_blocks): 66 | layers.append(block(self.lambda_s, self.w_bits, self.a_bits, self.in_planes, planes, 1, use_fp, self.act_quant)) 67 | return nn.Sequential(*layers) 68 | 69 | def forward(self, x): 70 | out = self.act(self.bn1(self.conv1(x))) 71 | out = self.layer1(out) 72 | out = self.layer2(out) 73 | out = self.layer3(out) 74 | out = self.avgpool(out) 75 | out = out.view(out.size(0), -1) 76 | out = self.fc(out) 77 | return out 78 | 79 | 80 | def ResNet32(w_bits, a_bits, lambda_s, use_fp=True, activation_quant=False, quant_first_last=False): 81 | return ResNet(BasicBlock, [5,5,5], w_bits=w_bits, a_bits=a_bits, lambda_s=lambda_s, use_fp=use_fp, activation_quant=activation_quant, quant_first_last=quant_first_last) 82 | 83 | 84 | cfg = { 85 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 86 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 87 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 88 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 89 | 512, 512, 512, 512, 'M'], 90 | } 91 | 92 | class VGG(nn.Module): 93 | ''' 94 | VGG model 95 | ''' 96 | def __init__(self, features, q_cfg): 97 | super(VGG, self).__init__() 98 | self.features = features 99 | self.q_cfg = q_cfg 100 | self.quant_first_last = q_cfg['quant_fl'] 101 | self.use_fp = q_cfg['use_fp'] 102 | fl_bit = q_cfg['w_bits'] if q_cfg['w_bits'] > 3 else 8 103 | fl_lambda_s = q_cfg['lambda_s'] 104 | 105 | if self.quant_first_last: 106 | self.last_layer = Linear_minmax(fl_lambda_s, fl_bit, 4096, 100, use_fp=self.use_fp) 107 | else: 108 | self.last_layer = nn.Linear(4096, 100) 109 | 110 | self.act = ReLU_quant(self.q_cfg['a_bits']) if self.q_cfg['activation_quant'] else nn.ReLU() 111 | 112 | self.classifier = nn.Sequential( 113 | Linear_minmax(self.q_cfg['lambda_s'], self.q_cfg['w_bits'], 512, 4096, use_fp=self.use_fp, bias=True), 114 | self.act, # ReLUQuant 115 | nn.Dropout(), 116 | Linear_minmax(self.q_cfg['lambda_s'], self.q_cfg['w_bits'], 4096, 4096, use_fp=self.use_fp, bias=True), 117 | self.act, #ReLUQuant x 118 | nn.Dropout(), 119 | self.last_layer, 120 | ) 121 | 122 | def forward(self, x): 123 | x = self.features(x) 124 | x = x.view(x.size(0), -1) 125 | x = self.classifier(x) 126 | return x 127 | 128 | 129 | def make_layers(cfg, q_cfg, batch_norm=False): 130 | layers = [] 131 | in_channels = 3 132 | activation = ReLU_quant(q_cfg['a_bits']) if q_cfg['activation_quant'] else nn.ReLU() 133 | fl_bit = q_cfg['w_bits'] if q_cfg['w_bits'] > 3 else 8 134 | fl_lambda_s = q_cfg['lambda_s'] 135 | 136 | if q_cfg['quant_fl']: 137 | first_layer = Conv2d_minmax(fl_lambda_s, fl_bit, in_channels, cfg[0], kernel_size=3, padding=1, use_fp=q_cfg['use_fp'], bias=False) 138 | else: 139 | first_layer = nn.Conv2d(in_channels, cfg[0], kernel_size=3, padding=1, bias=False) 140 | 141 | if batch_norm: 142 | layers += [first_layer, nn.BatchNorm2d(cfg[0]), activation] 143 | else: 144 | layers += [first_layer, activation] 145 
| in_channels = cfg[0] 146 | 147 | for v in cfg[1:]: 148 | if v == 'M': 149 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 150 | else: 151 | conv2d = Conv2d_minmax(q_cfg['lambda_s'], q_cfg['w_bits'], in_channels, v, kernel_size=3, padding=1, use_fp=q_cfg['use_fp'], bias=False) 152 | if batch_norm: 153 | layers += [conv2d, nn.BatchNorm2d(v), activation] 154 | else: 155 | layers += [conv2d, activation] 156 | in_channels = v 157 | return nn.Sequential(*layers) 158 | 159 | 160 | def Vgg16_bn(w_bits, a_bits, lambda_s, use_fp=True, activation_quant=True, quant_first_last=False): 161 | """VGG 16-layer model (configuration "D") with batch normalization""" 162 | q_cfg = {'w_bits': w_bits, "a_bits": a_bits, "lambda_s": lambda_s, "use_fp": use_fp,\ 163 | "activation_quant": activation_quant, "quant_fl":quant_first_last} 164 | return VGG(make_layers(cfg['D'], q_cfg, batch_norm=True), q_cfg) -------------------------------------------------------------------------------- /quant_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.autograd as autograd 5 | import numpy as np 6 | 7 | ### Quantization modules ### 8 | class post_training_weight(nn.Module): 9 | def __init__(self,lambda_s,n_bits, use_fp, activation, momentum=0.1): 10 | super(post_training_weight, self).__init__() 11 | self.n_bits = n_bits 12 | self.lambda_s = lambda_s 13 | # type of quantization 14 | self.activation = activation 15 | if activation: 16 | self.type_q = min_max_post_training_asymmetric 17 | else : 18 | self.type_q = min_max_post_training_symmetric if use_fp else min_max_post_training_symmetric_lp 19 | self.x_min = None 20 | self.x_max = None 21 | self.momentum = momentum 22 | self.fixed_range = False 23 | self.quantizer = None 24 | # use running mean 25 | self.use_running_mean = True 26 | 27 | def update_range(self,data): 28 | if self.fixed_range: 29 | self.quantizer = self.type_q(self.lambda_s,self.n_bits,self.x_min, self.x_max) 30 | return 31 | 32 | x_min = min(0., float(data.min())) 33 | x_max = max(0., float(data.max())) 34 | 35 | if self.use_running_mean: 36 | if self.x_min is not None: 37 | self.x_min = self.momentum * x_min + (1 - self.momentum) * (self.x_min or x_min) 38 | self.x_max = self.momentum * x_max + (1 - self.momentum) * (self.x_max or x_max) 39 | elif self.x_min is None: 40 | self.x_min = x_min 41 | self.x_max = x_max 42 | else: 43 | self.x_min = x_min 44 | self.x_max = x_max 45 | 46 | self.quantizer = self.type_q(self.lambda_s,self.n_bits, self.x_min, self.x_max) 47 | 48 | def change_range_mode(self,Boolean): 49 | self.fixed_range = Boolean 50 | 51 | def forward(self, x): 52 | self.update_range(x) 53 | return_value = self.quantizer.return_scale_value(x) 54 | return self.quantizer(x), return_value 55 | 56 | 57 | class min_max_post_training_asymmetric(autograd.Function): 58 | def __init__(self,beta, n_bits, x_min, x_max): 59 | super(min_max_post_training_asymmetric, self).__init__() 60 | self.beta = beta 61 | if n_bits == 0: 62 | return None 63 | else: 64 | lower = 0 65 | upper = 2 ** n_bits 66 | # np.arange upper -1 so the range will be 0~255 67 | self.constraint = np.arange(lower, upper) 68 | self.valmin = float(self.constraint.min()) 69 | self.valmax = float(self.constraint.max()) 70 | 71 | self.n_levels = 2 ** (n_bits) 72 | self.delta = float(x_max) / (self.n_levels - 1) 73 | 74 | def forward(self, *args, **kwargs): 75 | x = args[0] 76 | lambda_s = self.delta 77 | 
x_lambda_s = torch.div(x, lambda_s) 78 | x_clip = F.hardtanh(x_lambda_s, min_val=self.valmin , max_val=self.valmax) 79 | x_round = torch.round(x_clip) 80 | x_restore = torch.mul(x_round, lambda_s) 81 | return x_restore 82 | 83 | def backward(self, *grad_outputs): 84 | grad_top = grad_outputs[0] 85 | return grad_top 86 | 87 | def return_scale_value(self, x): 88 | lambda_s = self.delta 89 | x_lambda_s = torch.div(x, lambda_s) 90 | x_clip = F.hardtanh(x_lambda_s, min_val=self.valmin, max_val=self.valmax) 91 | x_round = torch.round(x_clip) 92 | x_restore = torch.mul(x_round, lambda_s) 93 | scale_value = torch.abs(x-x_restore) 94 | return scale_value 95 | 96 | ### Full precision modules (fowarding with FP weights) ### 97 | class min_max_post_training_symmetric(autograd.Function): 98 | def __init__(self, beta, n_bits, x_min, x_max): 99 | super(min_max_post_training_symmetric, self).__init__() 100 | self.iter = 0 101 | self.beta = beta 102 | if n_bits == 0: 103 | return None 104 | else: 105 | # Restricted Mode 106 | lower = -2 ** (n_bits - 1) + 1 107 | upper = 2 ** (n_bits - 1) 108 | 109 | self.constraint = np.arange(lower, upper) 110 | self.valmin = float(self.constraint.min()) 111 | self.valmax = float(self.constraint.max()) 112 | x_absmax = max(abs(x_min), x_max) 113 | ### Full range ### 114 | # self.n_levels = 2 ** (n_bits) - 1 115 | # self.delta = float(x_absmax) / (self.n_levels / 2) 116 | ################## 117 | self.n_levels = 2 ** (n_bits-1) # Restricted range 118 | self.delta = float(x_absmax) / (self.n_levels - 1) 119 | 120 | def forward(self, *args, **kwargs): 121 | x = args[0] 122 | lambda_s = self.delta 123 | x_lambda_s = torch.div(x, lambda_s) 124 | x_clip = F.hardtanh(x_lambda_s, min_val=self.valmin, max_val=self.valmax) 125 | x_round = torch.round(x_clip) 126 | x_restore = torch.mul(x_round, lambda_s) 127 | scale = torch.abs(x-x_restore) 128 | self.save_for_backward(scale) 129 | return x 130 | 131 | def backward(self, *grad_outputs): 132 | grad_top = grad_outputs[0] 133 | scale = self.saved_tensors[0] 134 | # return grad_top *scale*self.beta / scale.max() # Non-seperable Scaling 135 | return grad_top * scale * self.beta # Vanilla 136 | 137 | def return_scale_value(self, x): 138 | lambda_s = self.delta 139 | x_lambda_s = torch.div(x, lambda_s) 140 | x_clip = F.hardtanh(x_lambda_s, min_val=self.valmin, max_val=self.valmax) 141 | x_round = torch.round(x_clip) 142 | x_restore = torch.mul(x_round, lambda_s) 143 | scale_value = torch.abs(x-x_restore) 144 | return scale_value 145 | 146 | ### Low precision modules (fowarding with LP weights) ### 147 | class min_max_post_training_symmetric_lp(autograd.Function): 148 | def __init__(self, beta, n_bits, x_min, x_max): 149 | super(min_max_post_training_symmetric_lp, self).__init__() 150 | self.iter = 0 151 | self.beta = beta 152 | if n_bits == 0: 153 | return None 154 | else: 155 | lower = -2 ** (n_bits - 1) + 1 156 | upper = 2 ** (n_bits - 1) 157 | 158 | self.constraint = np.arange(lower, upper) 159 | self.valmin = float(self.constraint.min()) 160 | self.valmax = float(self.constraint.max()) 161 | x_absmax = max(abs(x_min), x_max) 162 | x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax 163 | self.n_levels = 2 ** (n_bits-1) 164 | self.delta = float(x_absmax) / (self.n_levels - 1) 165 | 166 | def forward(self, *args, **kwargs): 167 | x = args[0] 168 | 169 | lambda_s = self.delta 170 | x_lambda_s = torch.div(x, lambda_s) 171 | x_clip = F.hardtanh(x_lambda_s, min_val=self.valmin, max_val=self.valmax) 172 | x_round = torch.round(x_clip) 173 
| x_restore = torch.mul(x_round, lambda_s) 174 | 175 | scale = torch.abs(x-x_restore) 176 | self.save_for_backward(scale) 177 | return x_restore 178 | 179 | def backward(self, *grad_outputs): 180 | grad_top = grad_outputs[0] 181 | return grad_top 182 | 183 | def return_scale_value(self, x): 184 | lambda_s = self.delta 185 | x_lambda_s = torch.div(x, lambda_s) 186 | x_clip = F.hardtanh(x_lambda_s, min_val=self.valmin, max_val=self.valmax) 187 | x_round = torch.round(x_clip) 188 | x_restore = torch.mul(x_round, lambda_s) 189 | scale_value = torch.abs(x-x_restore) 190 | return scale_value 191 | 192 | 193 | ### Network module ### 194 | class Conv2d(nn.Conv2d): 195 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, 196 | padding=1, dilation=1, groups=1, bias=False): 197 | super(Conv2d, self).__init__(in_channels, out_channels, kernel_size,stride, 198 | padding, dilation, groups, bias) 199 | self.wquantizer = None 200 | 201 | def forward(self, x): 202 | weight = self.weight if self.wquantizer is None else self.wquantizer(self.weight) 203 | return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 204 | 205 | 206 | 207 | class Conv2d_minmax(nn.Conv2d): 208 | def __init__(self,lambda_s,n_bits,in_channels, out_channels, kernel_size=3, stride=1, 209 | padding=1, dilation=1, groups=1, bias=False, use_fp=True): 210 | super(Conv2d_minmax, self).__init__(in_channels, out_channels, kernel_size,stride, 211 | padding, dilation, groups, bias) 212 | 213 | self.register_buffer('scale_value', torch.rand(out_channels, in_channels, kernel_size, kernel_size)) 214 | self.use_fp = use_fp 215 | # Activation quant. is not needed for Conv. module 216 | self.wquantizer = post_training_weight(lambda_s,n_bits, use_fp, activation=False) 217 | 218 | def forward(self, x): 219 | weight, return_value = self.weight if self.wquantizer is None else self.wquantizer(self.weight) 220 | if self.use_fp: 221 | self.scale_value = return_value 222 | return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 223 | 224 | class Linear_minmax(nn.Linear): 225 | def __init__(self,lambda_s,n_bits, in_features, out_features, use_fp, activation=False, bias=True): 226 | super(Linear_minmax, self).__init__(in_features, out_features, bias) 227 | self.wquantizer = post_training_weight(lambda_s,n_bits, use_fp, activation=False) 228 | self.bquantizer = post_training_weight(lambda_s,n_bits, use_fp, activation=False) 229 | self.register_buffer('wscale_value', torch.rand(out_features, in_features)) 230 | self.register_buffer('bscale_value', torch.rand(out_features)) 231 | self.use_bias = bias 232 | self.use_fp = use_fp 233 | 234 | def forward(self, x): 235 | weight, wreturn_value = self.weight if self.wquantizer is None else self.wquantizer(self.weight) 236 | if self.use_bias: 237 | bias, breturn_value = self.bias if self.bquantizer is None else self.bquantizer(self.bias) 238 | else : 239 | bias, breturn_value = None, None 240 | if self.use_fp: 241 | self.wscale_value = wreturn_value 242 | self.bscale_value = breturn_value 243 | return F.linear(x, weight, bias) 244 | 245 | class ReLU_quant(nn.ReLU): 246 | def __init__(self, a_bits): 247 | super(ReLU_quant, self).__init__() 248 | self.aquantizer = post_training_weight(1, a_bits, use_fp=False, activation=True) 249 | 250 | def forward(self, x): 251 | x_quant , _ = x if self.aquantizer is None else self.aquantizer(x) 252 | return x_quant # ReLU function is applied inside aquantizer 253 | 254 | 255 | 
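# --- Usage sketch (illustrative; not part of the original module) ---
# post_training_weight wraps the quantizers above: with use_fp=False it
# fake-quantizes the tensor it receives, and the second return value is the
# elementwise distance |w - Q(w)| that PSG uses as the gradient scale in the
# full-precision path. Relies on the legacy autograd.Function style that the
# pinned torch==1.4.0 still supports.
#
#   q = post_training_weight(lambda_s=150, n_bits=4, use_fp=False, activation=False)
#   w = torch.randn(3, 3)
#   w_q, scale = q(w)
#   # For 4 bits, w_q / q.quantizer.delta lies on the integer grid [-7, 7]
#   print(torch.round(w_q / q.quantizer.delta).unique(), scale.max())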
--------------------------------------------------------------------------------
/quant_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from quant_module import *
4 | 
5 | 
6 | def save_bn_param(module, bn_module, n):
7 |     quantizer = module.aquantizer
8 |     quantizer.change_range_mode(True)
9 |     shift = bn_module.bias
10 |     scale = bn_module.weight
11 |     x_min = 0
12 |     x_max = shift + n * scale
13 |     quantizer.x_min = x_min
14 |     quantizer.x_max = x_max.max()
15 | 
16 | def is_bn(m):
17 |     return isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)
18 | 
19 | def search_bn(model, n):
20 |     my_list = []
21 |     prev = prev_name = None
22 |     for name, m in model.named_children():
23 |         if is_bn(prev) and is_ReLU(m):
24 |             save_bn_param(m, prev, n)
25 |             my_list.append((m, prev))
26 |         search_bn(m, n)
27 |         prev = m
28 |         prev_name = name
29 |     return my_list
30 | 
31 | def is_ReLU(m):
32 |     return isinstance(m, ReLU_quant)
33 | 
34 | def compute_mse(model):
35 |     # Computes the Mean Absolute Percentage Error (MAPE) & Sum of Squares Error (SSE) of each layer
36 |     sse_dict = {}
37 |     mape_dict = {}
38 |     for n, m in model.named_modules():
39 |         if isinstance(m, Conv2d_minmax) or isinstance(m, Linear_minmax):
40 |             lp_weight = m.wquantizer(m.weight)[0].detach()
41 |             fp_weight = m.weight
42 |             mae = torch.nn.functional.l1_loss(fp_weight, lp_weight, reduction='none')
43 |             sse = torch.nn.functional.mse_loss(fp_weight, lp_weight, reduction='sum')
44 |             mape = mae / (torch.abs(fp_weight) + 1e-12)
45 |             mape_dict[n] = torch.mean(mape).item()
46 |             sse_dict[n] = sse.item()
47 |     return mape_dict, sse_dict
48 | 
49 | 
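> Editor's sketch (hedged, not part of quant_utils.py): how `compute_mse` above could be pointed at the released checkpoint to report per-layer quantization error. It assumes the `ResNet32(w_bits, a_bits, lambda_s, ...)` constructor call used in train.py and the checkpoint layout read by visualize.py.

```python
import torch
import our_network
from quant_utils import compute_mse

# Hypothetical driver: 4-bit weights, lambda_s = 150
net = our_network.ResNet32(4, 4, 150, use_fp=True, activation_quant=False, quant_first_last=True)
state = torch.load('models/W4_ResNet.pth', map_location='cpu')
net.load_state_dict(state['state_dict'], strict=False)

mape, sse = compute_mse(net)  # per-layer MAPE and SSE between FP and quantized weights
for name in sse:
    print('{}: MAPE={:.4f}  SSE={:.6f}'.format(name, mape[name], sse[name]))
```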
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Change the dataset path
2 | DATASET_PATH = '~/data'
3 | 
4 | import argparse
5 | import json
6 | import time
7 | from datetime import datetime
8 | import warnings
9 | import os
10 | warnings.filterwarnings("ignore")
11 | 
12 | import torch
13 | import torch.nn as nn
14 | import torch.optim as optim
15 | 
16 | from logger import SummaryLogger
17 | import utils
18 | import our_network
19 | 
20 | 
21 | parser = argparse.ArgumentParser(description='Quantization finetuning for CIFAR100')
22 | parser.add_argument('--arch', metavar='ARCH', default='ResNet32', choices=['ResNet32', 'Vgg16_bn'])
23 | parser.add_argument('--text', default='log.txt', type=str)
24 | parser.add_argument('--exp_name', default='cifar100/4bits', type=str)
25 | parser.add_argument('--log_time', default='1', type=str)
26 | parser.add_argument('--lr', default='0.01', type=float)  # By default 1e-4 for Adam // 1e-2 for SGD when starting from EPOCH 82
27 | parser.add_argument('--resume_epoch', default='83', type=int)
28 | parser.add_argument('--epoch', default='150', type=int)
29 | parser.add_argument('--decay_epoch', default=[123], nargs="*", type=int)
30 | parser.add_argument('--w_decay', default='1e-4', type=float)
31 | parser.add_argument('--adam', default='0', type=float)
32 | parser.add_argument('--cu_num', default='0', type=str)
33 | parser.add_argument('--seed', default='1', type=str)
34 | 
35 | parser.add_argument('--load_pretrained', default='models/ResNet82.pth', type=str)
36 | parser.add_argument('--save_model', default='ckpt.t7', type=str)
37 | 
38 | parser.add_argument('--w_bit', default='4', type=int)
39 | parser.add_argument('--lambda_s', default='150', type=float)
40 | parser.add_argument('--a_bit', default='4', type=int)  # bit-widths are integers, matching --w_bit
41 | parser.add_argument('--first_last_quant', default=1, type=int)
42 | 
43 | torch.backends.cudnn.deterministic = True
44 | torch.backends.cudnn.benchmark = False
45 | 
46 | args = parser.parse_args()
47 | print(args)
48 | 
49 | torch.manual_seed(int(args.seed))
50 | os.environ['CUDA_VISIBLE_DEVICES'] = args.cu_num
51 | trainloader, valloader, testloader = utils.get_cifar100_dataloaders(128, 100)
52 | 
53 | # Quantization parameters
54 | base_lr = args.lr
55 | lambda_s = args.lambda_s
56 | w_bits = args.w_bit
57 | a_bits = args.a_bit
58 | fl_quant = True if args.first_last_quant else False
59 | 
60 | # Other parameters
61 | DEVICE = torch.device("cuda")
62 | RESUME_EPOCH = args.resume_epoch
63 | DECAY_EPOCH = args.decay_epoch
64 | DECAY_EPOCH = [ep - RESUME_EPOCH for ep in DECAY_EPOCH]
65 | FINAL_EPOCH = args.epoch
66 | EXPERIMENT_NAME = args.exp_name
67 | W_DECAY = args.w_decay
68 | USE_ADAM = int(args.adam)
69 | if w_bits == 2:
70 |     print("*" * 20)
71 |     print("W_DECAY set to 0")
72 |     print("*" * 20)
73 |     W_DECAY = 0
74 | 
75 | model = our_network.__dict__[args.arch](w_bits, a_bits, lambda_s, use_fp=True, activation_quant=False, quant_first_last=fl_quant)
76 | 
77 | if len(args.load_pretrained) > 2:
78 |     path = args.load_pretrained
79 |     state = torch.load(path)
80 |     utils.load_checkpoint(model, state)
81 | 
82 | model.to(DEVICE)
83 | 
84 | if not USE_ADAM:
85 |     optimizer = optim.SGD(model.parameters(), lr=base_lr, nesterov=False, momentum=0.9, weight_decay=W_DECAY)
86 | else:
87 |     print("*" * 20)
88 |     print("Using Adam as optimizer...")
89 |     print("*" * 20)
90 |     base_lr *= 1e-2 * lambda_s
91 |     optimizer = optim.Adam(model.parameters(), lr=base_lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=W_DECAY)
92 |     optimizer.load_state_dict(state['optimizer'])  # assumes --load_pretrained supplied a checkpoint that stores the optimizer state
93 | 
94 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=DECAY_EPOCH, gamma=0.1)
95 | criterion_CE = nn.CrossEntropyLoss()
96 | 
97 | def eval(net, test_flag=False):
98 |     loader = valloader if not test_flag else testloader
99 |     flag = 'Val.' if not test_flag else 'Test'
100 | 
101 |     epoch_start_time = time.time()
102 |     net.eval()
103 |     val_loss = 0
104 |     correct = 0
105 |     total = 0
106 |     criterion_CE = nn.CrossEntropyLoss()
107 | 
108 |     for batch_idx, (inputs, targets) in enumerate(loader):
109 |         inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
110 |         outputs = net(inputs)
111 | 
112 |         loss = criterion_CE(outputs, targets)
113 |         val_loss += loss.item()
114 |         _, predicted = torch.max(outputs.data, 1)
115 | 
116 |         total += targets.size(0)
117 |         correct += predicted.eq(targets.data).cpu().sum().float().item()
118 |         b_idx = batch_idx
119 | 
120 |     print('%s \t Time Taken: %.2f sec' % (flag, time.time() - epoch_start_time))
121 |     print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (val_loss / (b_idx + 1), 100. * correct / total, correct, total))
122 |     return val_loss / (b_idx + 1), correct / total
123 | 
124 | def train(model, epoch):
125 |     epoch_start_time = time.time()
126 |     print('\n EPOCH: %d' % epoch)
127 |     model.train()
128 | 
129 |     train_loss = 0
130 |     correct = 0
131 |     total = 0
132 | 
133 |     global optimizer
134 | 
135 |     for batch_idx, (inputs, targets) in enumerate(trainloader):
136 |         inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
137 |         optimizer.zero_grad()
138 |         outputs = model(inputs)
139 | 
140 |         loss = criterion_CE(outputs, targets)
141 |         loss.backward()
142 | 
143 |         optimizer.step()
144 |         train_loss += loss.item()
145 | 
146 |         _, predicted = torch.max(outputs.data, 1)
147 |         total += targets.size(0)
148 |         correct += predicted.eq(targets.data).cpu().sum().float().item()
149 |         b_idx = batch_idx
150 | 
151 |     print('Train s1 \t Time Taken: %.2f sec' % (time.time() - epoch_start_time))
152 |     print('Loss: %.3f | Acc s1: %.3f%% (%d/%d)' % (train_loss / (b_idx + 1), 100. * correct / total, correct, total))
153 | 
154 |     return train_loss / (b_idx + 1), correct / total
155 | 
156 | def eval_LP(address, lambda_s, num_bits, test_flag=False):
157 |     net = our_network.__dict__[args.arch](num_bits, a_bits, lambda_s, use_fp=False, activation_quant=False, quant_first_last=fl_quant)
158 | 
159 |     old_param = torch.load(address)
160 |     net.load_state_dict(old_param)
161 |     net.to(DEVICE)
162 |     if test_flag:
163 |         print("***Test***")
164 |     print("Low Precision: ")
165 |     val_loss, acc = eval(net, test_flag)
166 |     return val_loss, acc
167 | 
168 | def eval_FP(address, lambda_s, num_bits, test_flag=False):
169 |     net = our_network.__dict__[args.arch](num_bits, a_bits, lambda_s, use_fp=True, activation_quant=False, quant_first_last=fl_quant)
170 | 
171 |     old_param = torch.load(address)
172 |     net.load_state_dict(old_param)
173 |     net.to(DEVICE)
174 | 
175 |     if test_flag:
176 |         print("***Test***")
177 |     print("Full Precision: ")
178 |     val_loss, acc = eval(net, test_flag)
179 |     return val_loss, acc
180 | 
181 | if __name__ == '__main__':
182 |     time_log = datetime.now().strftime('%m-%d %H:%M')
183 |     if int(args.log_time):
184 |         folder_name = 'Bit{}_Scale{}_{}'.format(w_bits, lambda_s, time_log)
185 |     else:
186 |         folder_name = 'Bit{}_Scale{}'.format(w_bits, lambda_s)
187 | 
188 |     path = os.path.join(EXPERIMENT_NAME, folder_name)
189 |     if not os.path.exists('ckpt/' + path):
190 |         os.makedirs('ckpt/' + path)
191 |     if not os.path.exists('logs/' + path):
192 |         os.makedirs('logs/' + path)
193 | 
194 |     # Save the argparse arguments for logging
195 |     with open('logs/{}/commandline_args.txt'.format(path), 'w') as f:
196 |         json.dump(args.__dict__, f, indent=2)
197 |     # Instantiate logger
198 |     logger = SummaryLogger(path)
199 |     best_FP = 0
200 |     best_LP = 0
201 | 
202 |     with open(os.path.join("logs/" + path, 'log.txt'), "a") as f:
203 |         torch.save(model.state_dict(), "ckpt/{}/temp.t7".format(path))
204 |         address = "ckpt/{}/temp.t7".format(path)
205 |         print("Performance of pre-trained model")
206 |         _, _ = eval_FP(address, lambda_s, w_bits, test_flag=True)
207 |         _, _ = eval_LP(address, lambda_s, w_bits, test_flag=True)
208 | 
209 |     for epoch in range(RESUME_EPOCH, FINAL_EPOCH + 1):
210 |         f = open(os.path.join("logs/" + path, 'log.txt'), "a")
211 |         ### Train ###
212 |         train_loss, acc = train(model, epoch)
213 |         scheduler.step()
214 |         ### Save for evaluating LP and FP ###
215 |         torch.save(model.state_dict(), "ckpt/{}/temp.t7".format(path))
216 |         address = "ckpt/{}/temp.t7".format(path)
217 |         ### Evaluate LP and FP models ###
218 |         val_loss_LP, accuracy_LP = eval_LP(address, lambda_s, w_bits, test_flag=True)
219 |         val_loss_FP, accuracy_FP = eval_FP(address, lambda_s, w_bits, test_flag=True)
220 | 
221 |         is_best = accuracy_FP > best_FP
222 |         best_FP = max(accuracy_FP, best_FP)
223 |         LP_is_best = accuracy_LP > best_LP
224 |         best_LP = max(accuracy_LP, best_LP)
225 | 
226 |         utils.save_checkpoint({
227 |             'epoch': epoch,
228 |             'state_dict': model.state_dict(),
229 |             'best_FP_acc': best_FP,
230 |             'best_LP_acc': best_LP,
231 |             'optimizer': optimizer.state_dict(),
232 |         }, is_best, 'ckpt/' + path, filename='{}.pth'.format(epoch))
233 | 
234 |         train_log = {'Loss': train_loss, 'Accuracy': acc}
235 |         val_log = {'LP loss': val_loss_LP, 'LP accuracy': accuracy_LP,
236 |                    'FP loss': val_loss_FP, 'FP accuracy': accuracy_FP}
237 | 
238 |         logger.add_scalar_group('Train', train_log, epoch)
239 |         logger.add_scalar_group('Val', val_log, epoch)
240 | 
241 |         f.write('EPOCH {epoch} \t'
242 |                 'Trainacc : {acc:.4f} \t Valacc_LP : {top1_LP:.4f}\t'
243 |                 'Valacc_FP : {top1_FP:.4f} \t Bestacc_LP : {best_LP:.4f} \t'
244 |                 'Bestacc_FP : {best_FP:.4f} \n'.format(
245 |                     epoch=epoch, acc=acc, top1_LP=accuracy_LP, top1_FP=accuracy_FP, best_LP=best_LP, best_FP=best_FP)
246 |                 )
247 |         f.close()
248 | 
249 | 
250 |     print("*" * 20)
251 |     print("Testing final model")
252 |     print("*" * 20)
253 |     test_loss_LP, test_accuracy_LP = eval_LP(address, lambda_s, w_bits, test_flag=True)
254 |     test_loss_FP, test_accuracy_FP = eval_FP(address, lambda_s, w_bits, test_flag=True)
255 |     f = open(os.path.join("logs/" + path, 'log.txt'), "a")
256 |     f.write('Test FP : {:.4f} \t Test LP : {:.4f}'.format(test_accuracy_FP, test_accuracy_LP))
257 |     f.close()
258 | 
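> Intuition for the training loop above (editor's note): with `use_fp=True` the quantizer leaves the forward pass untouched and instead rescales each weight's gradient by its distance to the nearest quantization grid point (see `min_max_post_training_symmetric.backward`), so weights far from the grid are pushed harder toward it. A minimal sketch of one such scaled update, with assumed numbers:

```python
import torch

delta, beta, lr = 0.1, 2.0, 0.01        # assumed step size, scale factor, and learning rate
w = torch.tensor([0.149, 0.101])        # distances to the grid: 0.049 vs. 0.001
grad = torch.ones(2)                    # upstream gradient
scale = torch.abs(w - torch.round(w / delta) * delta)
w_psg = w - lr * grad * scale * beta    # the far-from-grid weight moves ~49x more
print(w_psg)
```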
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import numpy as np
4 | 
5 | import torch
6 | import torchvision
7 | from torchvision import transforms
8 | from torch.utils.data import DataLoader
9 | from torch.utils.data.sampler import SubsetRandomSampler
10 | 
11 | 
12 | def save_checkpoint(state, is_best, path, filename='checkpoint.pth.tar'):
13 |     filename = os.path.join(path, filename)
14 |     torch.save(state, filename)
15 |     if is_best:
16 |         shutil.copyfile(filename, os.path.join(path, 'model_best.pth.tar'))
17 | 
18 | def load_checkpoint(model, checkpoint):
19 |     m_keys = list(model.state_dict().keys())
20 | 
21 |     if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
22 |         c_keys = list(checkpoint['state_dict'].keys())
23 |         not_m_keys = [i for i in c_keys if i not in m_keys]
24 |         not_c_keys = [i for i in m_keys if i not in c_keys]
25 |         model.load_state_dict(checkpoint['state_dict'], strict=False)
26 | 
27 |     else:
28 |         c_keys = list(checkpoint.keys())
29 |         not_m_keys = [i for i in c_keys if i not in m_keys]
30 |         not_c_keys = [i for i in m_keys if i not in c_keys]
31 |         model.load_state_dict(checkpoint, strict=False)
32 | 
33 |     print("--------------------------------------\n LOADING PRETRAINED WEIGHTS \n")
34 |     print("Not in Model: ")
35 |     print(not_m_keys)
36 |     print("Not in Checkpoint: ")
37 |     print(not_c_keys)
38 |     print('\n\n')
39 | 
40 | def get_cifar100_dataloaders(train_batch_size, test_batch_size):
41 |     transform_train = transforms.Compose([
42 |         transforms.Pad(4, padding_mode='reflect'),
43 |         transforms.RandomHorizontalFlip(),
44 |         transforms.RandomCrop(32),
45 |         transforms.ToTensor(),
46 |         transforms.Normalize(mean=[x / 255.0 for x in [129.3, 124.1, 112.4]],
47 |                              std=[x / 255.0 for x in [68.2, 65.4, 70.4]])
48 |     ])
49 | 
50 |     transform_test = transforms.Compose([
51 |         transforms.ToTensor(),
52 |         transforms.Normalize(mean=[x / 255.0 for x in [129.3, 124.1, 112.4]],
53 |                              std=[x / 255.0 for x in [68.2, 65.4, 70.4]])])
54 | 
55 | 
56 |     trainset = torchvision.datasets.CIFAR100(root='~/data', train=True, download=True,
57 |                                              transform=transform_train)
58 |     trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=4)
59 | 
60 |     testset = torchvision.datasets.CIFAR100(root='~/data', train=False, download=True,
61 |                                             transform=transform_test)
62 |     testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=False, num_workers=4)
63 | 
64 |     subset_idx = np.random.randint(0, len(trainset), size=10000)  # validation subset drawn from the training set, with replacement
65 |     valloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=False, num_workers=4, sampler=SubsetRandomSampler(subset_idx))
66 | 
67 |     return trainloader, valloader, testloader
68 | 
69 | def get_cifar100_dataloaders_disjoint(train_batch_size, test_batch_size):
70 |     np.random.seed(0)
71 |     transform_train = transforms.Compose([
72 |         transforms.Pad(4, padding_mode='reflect'),
73 |         transforms.RandomHorizontalFlip(),
74 |         transforms.RandomCrop(32),
75 |         transforms.ToTensor(),
76 |         transforms.Normalize(mean=[x / 255.0 for x in [129.3, 124.1, 112.4]],
77 |                              std=[x / 255.0 for x in [68.2, 65.4, 70.4]])
78 |     ])
79 |     transform_test = transforms.Compose([
80 |         transforms.ToTensor(),
81 |         transforms.Normalize(mean=[x / 255.0 for x in [129.3, 124.1, 112.4]],
82 |                              std=[x / 255.0 for x in [68.2, 65.4, 70.4]])])
83 | 
84 | 
85 |     trainset = torchvision.datasets.CIFAR100(root='~/data', train=True, download=True, transform=transform_train)
86 | 
87 |     total_idx = np.arange(0, len(trainset))
88 |     np.random.shuffle(total_idx)
89 |     subset_idx = total_idx[:10000]
90 |     _subset_idx = total_idx[~np.in1d(total_idx, subset_idx)]
91 |     valloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=False, num_workers=4, sampler=SubsetRandomSampler(subset_idx))
92 |     trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size, shuffle=False, num_workers=4, sampler=SubsetRandomSampler(_subset_idx))
93 | 
94 |     testset = torchvision.datasets.CIFAR100(root='~/data', train=False, download=True, transform=transform_test)
95 |     testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size, shuffle=False, num_workers=4)
96 | 
97 |     return trainloader, valloader, testloader
98 | 
99 | 
100 | 
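> Editor's note: `get_cifar100_dataloaders` above draws its 10,000-image validation subset from the training set with replacement (`np.random.randint`), so it can overlap the training stream; `get_cifar100_dataloaders_disjoint` carves out a disjoint split instead. A hedged sketch of swapping in the disjoint variant:

```python
import utils

# Same (train, val, test) tuple and signature, but the validation indices
# are removed from the training sampler.
trainloader, valloader, testloader = utils.get_cifar100_dataloaders_disjoint(128, 100)
```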
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import matplotlib as mpl
3 | import torch
4 | import os
5 | 
6 | plt.style.use('seaborn-ticks')
7 | mpl.use('Agg')  # non-interactive backend; figures are only saved, never shown
8 | 
9 | paths = {'PSGD': 'models/W4_ResNet.pth', 'SGD': 'models/ResNet150.pth'}
10 | states = [torch.load(paths['SGD'])['state_dict'], torch.load(paths['PSGD'])['state_dict']]
11 | 
12 | weight_key = []
13 | for i in states[1].keys():
14 |     if 'weight' in i and ('bn' not in i and 'downsample' not in i):
15 |         weight_key.append(i)
16 | 
17 | if not os.path.exists('visualizations/'):
18 |     os.makedirs('visualizations/')
19 | 
20 | for layer in weight_key:
21 |     fp_tensor = (states[0][layer].flatten().detach().cpu().numpy(), states[1][layer].flatten().detach().cpu().numpy())
22 | 
23 |     fig = plt.figure()
24 |     ax1 = fig.add_subplot(211)
25 |     ax2 = fig.add_subplot(212)
26 |     hist1 = ax1.hist(fp_tensor[0], bins=100, color='b', label='SGD')
27 |     hist2 = ax2.hist(fp_tensor[1], bins=200, color='r', label='PSGD')
28 |     ax1.legend()
29 |     ax2.legend()
30 |     plt.savefig('visualizations/{}.png'.format(layer))
31 |     plt.close()
--------------------------------------------------------------------------------