├── LICENSE
├── README.md
├── env.yml
├── fractrain_cifar
│   ├── compute_flops.py
│   ├── data.py
│   ├── model_info_mbv2.npy
│   ├── models.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── bwn.py
│   │   ├── quantize.py
│   │   └── rnlu.py
│   ├── train_base.py
│   ├── train_dfq.py
│   ├── train_frac.py
│   ├── train_pfq.py
│   └── util_swa.py
├── fractrain_imagenet
│   ├── data.py
│   ├── models.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── bwn.py
│   │   ├── quantize.py
│   │   └── rnlu.py
│   ├── train_base.py
│   ├── train_dfq.py
│   ├── train_frac.py
│   └── train_pfq.py
└── img
    ├── DFQ.png
    ├── PFQ.png
    ├── dfq_result.png
    ├── fractrain_result_cifar.png
    ├── fractrain_result_imagenet.png
    └── pfq_result.png
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 RICE-EIC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FracTrain: Fractionally Squeezing Bit Savings Both Temporally and Spatially for Efficient DNN Training 2 | ***Yonggan Fu***, Haoran You, Yang Zhao, Yue Wang, Chaojian Li, Kailash Gopalakrishnan, Zhangyang Wang, Yingyan Lin 3 | 4 | Accepted at NeurIPS 2020 [[Paper Link]](https://papers.nips.cc/paper/2020/file/8dc5983b8c4ef1d8fcd5f325f9a65511-Paper.pdf). 5 | 6 | ## Overview 7 | As reducing precision is one of the most effective knobs for boosting DNN training time/energy efficiency, there has been a growing interest in low-precision DNN 8 | training. In this paper, we propose the ***FracTrain*** framework, which explores an orthogonal direction: how to fractionally squeeze out more training cost savings from the most redundant bit level, progressively along the training trajectory and dynamically per input. 9 | 10 | 11 | ## Method 12 | We propose ***FracTrain***, which integrates *(i)* progressive fractional quantization (PFQ), which gradually increases the precision of activations, weights, and gradients and reaches the precision of SOTA static quantized DNN training only in the final training stage, and *(ii)* dynamic fractional quantization (DFQ), which assigns precisions to both the activations and gradients of each layer in an input-adaptive manner, so that layer parameters are only “fractionally” updated. 13 | 14 | ### Progressive Fractional Quantization (PFQ) 15 | 16 |

![PFQ](img/PFQ.png)

19 | 20 | PFQ progressively increases the precision of both the forward and backward passes during DNN training, controlled by an automated indicator that gradually switches the focus from solution-space exploration, enabled by low precision, to accurate updates for better convergence, enabled by high precision. A minimal sketch of the switching logic follows. 21 | 22 |
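The sketch below illustrates the loss-plateau indicator described above (Alg. 1 of the paper); the class name, plateau test, and update interface are assumptions for illustration, not the repo's actual API (see `train_pfq.py` for the real implementation):
```
class PFQIndicator:
    """Illustrative PFQ indicator: advance the precision stage on a loss plateau."""

    def __init__(self, bits_schedule, grad_bits_schedule, epsilon=0.05, alpha=0.3):
        self.bits_schedule = bits_schedule            # e.g. [3, 4, 6, 8]  (--num_bits_schedule)
        self.grad_bits_schedule = grad_bits_schedule  # e.g. [6, 6, 8, 8]  (--num_grad_bits_schedule)
        self.epsilon = epsilon                        # epsilon in Alg. 1  (--initial_threshold)
        self.alpha = alpha                            # alpha in Alg. 1    (--decay)
        self.stage = 0
        self.prev_loss = None

    def update(self, running_loss):
        """Call periodically; returns (num_bits, num_grad_bits) for the next phase."""
        if self.prev_loss is not None and self.stage < len(self.bits_schedule) - 1:
            improvement = abs(self.prev_loss - running_loss) / max(self.prev_loss, 1e-12)
            if improvement < self.epsilon:   # loss plateaued -> turning point
                self.stage += 1              # move to the next (higher) precision
                self.epsilon *= self.alpha   # decay the threshold for later stages
        self.prev_loss = running_loss
        return self.bits_schedule[self.stage], self.grad_bits_schedule[self.stage]
```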

### Dynamic Fractional Quantization (DFQ) 23 | 24 |
![DFQ](img/DFQ.png)

27 | 28 | DFQ dynamically selects the precision for each layer’s activation and gradient during training in an input-adaptive manner, controlled by a lightweight LSTM-based gate function; a minimal sketch of such a gate is given below. 29 | 30 |
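The following sketch shows the shape of such a gate; the architecture details (pooled input summary, hidden size, hard argmax selection) are assumptions for illustration — the actual gates live in the `*_rnn_gate_*` models in `models.py`:
```
import torch.nn as nn

class PrecisionGate(nn.Module):
    """Illustrative input-adaptive gate: an LSTM that picks a per-layer precision."""

    def __init__(self, in_channels, hidden_dim=10, num_options=4):
        super(PrecisionGate, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)             # cheap summary of the layer input
        self.rnn = nn.LSTM(in_channels, hidden_dim)     # lightweight recurrent gate shared across layers
        self.proj = nn.Linear(hidden_dim, num_options)  # logits over candidate bit-widths

    def forward(self, x, hidden=None):
        feat = self.pool(x).flatten(1).unsqueeze(0)     # (1, batch, in_channels)
        out, hidden = self.rnn(feat, hidden)
        logits = self.proj(out.squeeze(0))              # (batch, num_options)
        choice = logits.argmax(dim=1)                   # index into a bit-width list, e.g. [0, 4, 6, 8]
        return choice, hidden                           # training would use a differentiable relaxation
```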

### The FracTrain framework 31 | 32 | ***FracTrain*** integrates PFQ and DFQ in a simple but effective way: it applies PFQ’s indicator to schedule DFQ’s compression ratio in different training stages. 33 | 34 | 35 | ## Evaluation 36 | We evaluate ***FracTrain*** on six models & four datasets (i.e., ResNet-38/74/MobileNetV2 on CIFAR-10/100, ResNet-18/34 on ImageNet, and Transformer on WikiText-103). Here are some representative experiments; please refer to [our paper](https://papers.nips.cc/paper/2020/file/8dc5983b8c4ef1d8fcd5f325f9a65511-Paper.pdf) for more results. 37 | 38 | 39 | ### Evaluating PFQ 40 | 41 |
![PFQ results](img/pfq_result.png)

44 | 45 | 46 | ### Evaluating DFQ 47 | 48 | 49 |

![DFQ results](img/dfq_result.png)

52 | 53 | 54 | ### Evaluating FracTrain integrating PFQ and DFQ 55 | 56 | - ***FracTrain*** on CIFAR-10/100 57 |

![FracTrain results on CIFAR-10/100](img/fractrain_result_cifar.png)

60 | 61 | 62 | - ***FracTrain*** on ImageNet 63 |

![FracTrain results on ImageNet](img/fractrain_result_imagenet.png)

66 | 67 | 68 | ## Code Usage 69 | Our code is inspired by [SkipNet](https://github.com/ucbdrive/skipnet) and we follow its training setting for CIFAR-10/100 and ImageNet. 70 | 71 | ### Prerequisites 72 | See `env.yml` for the complete conda environment. Create a new conda environment: 73 | ``` 74 | conda env create -f env.yml 75 | conda activate pytorch 76 | ``` 77 | 78 | ### Overview 79 | `fractrain_cifar` and `fractrain_imagenet` contain the code customized for CIFAR-10/100 and ImageNet, respectively, with a similar code structure. In particular, `train_pfq.py`, `train_dfq.py`, and `train_frac.py` are the scripts for training the target network with PFQ, DFQ, and ***FracTrain***, respectively. 80 | Here we use `fractrain_cifar` to demo the commands for training. 81 | 82 | ### Training with PFQ 83 | In addition to the commonly considered args, e.g., the target network, dataset, and data path via `--arch`, `--dataset`, and `--datadir`, respectively, you also need to: (1) specify the precision schedules for forward and backward with `--num_bits_schedule` and `--num_grad_bits_schedule`, respectively; (2) specify the epsilon and alpha in Alg. 1 of [our paper](https://papers.nips.cc/paper/2020/file/8dc5983b8c4ef1d8fcd5f325f9a65511-Paper.pdf) with `--initial_threshold` and `--decay`, respectively, as well as the number of turning points via `--num_turning_point`. 84 | 85 | - Example: Training ResNet-74 on CIFAR-100 with PFQ 86 | ``` 87 | cd fractrain_cifar 88 | python train_pfq.py --save_folder ./logs/ --arch cifar100_resnet_74 --workers 4 --dataset cifar100 --datadir path-to-cifar100 --num_bits_schedule 3 4 6 8 --num_grad_bits_schedule 6 6 8 8 --num_turning_point 3 --initial_threshold 0.05 --decay 0.3 89 | ``` 90 | 91 | ### Training with DFQ 92 | Specify the weight precision with `--weight_bits` and the target *cp* as defined in Sec. 3.2 of [our paper](https://papers.nips.cc/paper/2020/file/8dc5983b8c4ef1d8fcd5f325f9a65511-Paper.pdf) with `--target_ratio`. 93 | 94 | - Example: Training ResNet-74 on CIFAR-100 with DFQ 95 | ``` 96 | cd fractrain_cifar 97 | python train_dfq.py --save_folder ./logs --arch cifar100_rnn_gate_74 --workers 4 --dataset cifar100 --datadir path-to-cifar100 --weight_bits 4 --target_ratio 3 98 | ``` 99 | 100 | ### Training with FracTrain 101 | Specify the step (increment) of *cp* between two consecutive training stages with `--target_ratio_step`; the other args follow those used for PFQ and DFQ. A sketch of the stage scheduling is given below.
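For intuition only — the variable names below mirror the CLI flags but are not the script's actual code:
```
# Sketch: FracTrain reuses the PFQ-style loss-plateau indicator to advance training
# stages; each detected plateau ("turning point") raises DFQ's target compute ratio cp.
target_ratio = 1.5             # --target_ratio (cp of the first stage)
target_ratio_step = 0.5        # --target_ratio_step
threshold, decay = 0.05, 0.3   # --initial_threshold, --decay

def advance_stage(cp, eps):
    """Call when the indicator detects a loss plateau."""
    return cp + target_ratio_step, eps * decay  # higher compression target, tighter threshold
```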
102 | 103 | - Example: Training ResNet-74 on CIFAR-100 with ***FracTrain*** 104 | ``` 105 | cd fractrain_cifar 106 | python train_frac.py --save_folder ./logs --arch cifar100_rnn_gate_74 --workers 4 --dataset cifar100 --datadir path-to-cifar100 --weight_bits 4 --target_ratio 1.5 --target_ratio_step 0.5 --initial_threshold 0.05 --decay 0.3 107 | ``` 108 | 109 | Similarly, for training on ImageNet: 110 | 111 | - Example: Training ResNet-18 on ImageNet with ***FracTrain*** 112 | 113 | ``` 114 | cd fractrain_imagenet 115 | python train_frac.py --save_folder ./logs --arch resnet18_rnn --workers 32 --dataset imagenet --datadir path-to-imagenet --weight_bits 6 --target_ratio 3 --target_ratio_step 0.5 --initial_threshold 0.05 --decay 0.3 116 | ``` 117 | 118 | 119 | ## Citation 120 | ``` 121 | @article{fu2020fractrain, 122 | title={FracTrain: Fractionally Squeezing Bit Savings Both Temporally and Spatially for Efficient DNN Training}, 123 | author={Fu, Yonggan and You, Haoran and Zhao, Yang and Wang, Yue and Li, Chaojian and Gopalakrishnan, Kailash and Wang, Zhangyang and Lin, Yingyan}, 124 | journal={Advances in Neural Information Processing Systems}, 125 | volume={33}, 126 | year={2020} 127 | } 128 | ``` 129 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: pytorch 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _pytorch_select=0.2=gpu_0 10 | - backcall=0.1.0=py36_0 11 | - blas=1.0=mkl 12 | - ca-certificates=2020.10.14=0 13 | - certifi=2020.6.20=py36_0 14 | - cffi=1.12.3=py36h2e261b9_0 15 | - cudatoolkit=10.2.89=hfd86e86_1 16 | - cudnn=7.6.5=cuda10.2_0 17 | - decorator=4.4.1=py_0 18 | - freetype=2.8=hab7d2ae_1 19 | - intel-openmp=2019.4=243 20 | - ipykernel=5.1.3=py36h39e3cac_0 21 | - ipython=7.9.0=py36h39e3cac_0 22 | - ipython_genutils=0.2.0=py36_0 23 | - jedi=0.15.1=py36_0 24 | - jpeg=9b=h024ee3a_2 25 | - jupyter_client=5.3.4=py36_0 26 | - jupyter_core=4.6.1=py36_0 27 | - libedit=3.1.20170329=hf8c457e_1001 28 | - libevent=2.0.22=hb7f436b_1002 29 | - libffi=3.2.1=hd88cf55_4 30 | - libgcc-ng=8.2.0=hdf63c60_1 31 | - libgfortran-ng=7.3.0=hdf63c60_0 32 | - libpng=1.6.37=hbc83047_0 33 | - libsodium=1.0.16=h1bed415_0 34 | - libstdcxx-ng=8.2.0=hdf63c60_1 35 | - libtiff=4.0.10=h2733197_2 36 | - mkl=2019.4=243 37 | - mkl-service=2.3.0=py36he904b0f_0 38 | - mkl_fft=1.0.12=py36ha843d7b_0 39 | - mkl_random=1.0.2=py36hd81dba3_0 40 | - mpi=1.0=mpich 41 | - mpi4py=3.0.3=py36h028fd6f_0 42 | - mpich=3.3.2=hc856adb_0 43 | - ncurses=6.1=hf484d3e_1002 44 | - ninja=1.9.0=py36hfd86e86_0 45 | - numpy=1.16.4=py36h7e9f1db_0 46 | - numpy-base=1.16.4=py36hde5b4d6_0 47 | - olefile=0.46=py36_0 48 | - openssl=1.0.2u=h7b6447c_0 49 | - pandas=0.24.2=py36he6710b0_0 50 | - parso=0.5.1=py_0 51 | - pexpect=4.7.0=py36_0 52 | - pickleshare=0.7.5=py36_0 53 | - pip=19.1.1=py36_0 54 | - prompt_toolkit=2.0.10=py_0 55 | - ptyprocess=0.6.0=py36_0 56 | - pycparser=2.19=py36_0 57 | - pygments=2.4.2=py_0 58 | - python=3.6.5=hc3d631a_2 59 | - python-dateutil=2.8.0=py36_0 60 | - pytorch=1.6.0=py3.6_cuda10.2.89_cudnn7.6.5_0 61 | - pytz=2019.1=py_0 62 | - pyzmq=18.1.0=py36he6710b0_0 63 | - readline=7.0=ha6073c6_4 64 | - setuptools=41.0.1=py36_0 65 | - sqlite=3.23.1=he433501_0 66 | - tk=8.6.8=hbc83047_0 67 | - tmux=2.9=h45300e9_0 68 | - torchvision=0.7.0=py36_cu102 69 | - tornado=6.0.3=py36h7b6447c_0 70 | - traitlets=4.3.3=py36_0 71 | 
- wcwidth=0.1.7=py36_0 72 | - wheel=0.33.4=py36_0 73 | - xz=5.2.4=h14c3975_4 74 | - zeromq=4.3.1=he6710b0_3 75 | - zlib=1.2.11=h7b6447c_3 76 | - zstd=1.3.7=h0b5b093_0 77 | - pip: 78 | - absl-py==0.9.0 79 | - cached-property==1.5.2 80 | - cachetools==4.0.0 81 | - chardet==3.0.4 82 | - click==7.1.2 83 | - configparser==5.0.1 84 | - cycler==0.10.0 85 | - cython==0.29.15 86 | - dnspython==2.0.0 87 | - docker-pycreds==0.4.0 88 | - dominate==2.4.0 89 | - easydict==1.9 90 | - efficientnet-pytorch==0.5.1 91 | - gitdb==4.0.5 92 | - gitpython==3.1.11 93 | - google-auth==1.10.0 94 | - google-auth-oauthlib==0.4.1 95 | - grpcio==1.26.0 96 | - h5py==3.1.0 97 | - idna==2.8 98 | - image-tools==1.0.0 99 | - joblib==0.13.2 100 | - jsonpatch==1.25 101 | - jsonpointer==2.0 102 | - kiwisolver==1.1.0 103 | - kornia==0.2.0 104 | - lmdb==0.98 105 | - markdown==3.1.1 106 | - matplotlib==3.1.0 107 | - mpmath==1.1.0 108 | - oauthlib==3.1.0 109 | - opencv-python==4.1.2.30 110 | - paho-mqtt==1.5.1 111 | - pillow==6.1.0 112 | - promise==2.3 113 | - protobuf==3.14.0 114 | - psutil==5.7.3 115 | - pyasn1==0.4.8 116 | - pyasn1-modules==0.2.8 117 | - pycocotools==2.0.0 118 | - pyparsing==2.4.0 119 | - python-etcd==0.4.5 120 | - python-graphviz==0.13.2 121 | - pyyaml==5.3 122 | - requests==2.22.0 123 | - requests-oauthlib==1.3.0 124 | - rsa==4.0 125 | - scikit-learn==0.22.2.post1 126 | - scipy==1.1.0 127 | - sentry-sdk==0.19.5 128 | - setproctitle==1.2.1 129 | - shortuuid==1.0.1 130 | - six==1.15.0 131 | - sklearn==0.0 132 | - smmap==3.0.4 133 | - subprocess32==3.5.4 134 | - sympy==1.5.1 135 | - tabulate==0.8.3 136 | - tensorboard==2.1.0 137 | - tensorboardx==2.0 138 | - thop==0.0.31-1912272122 139 | - torchcontrib==0.0.2 140 | - torchfile==0.1.0 141 | - tqdm==4.41.1 142 | - urllib3==1.25.7 143 | - visdom==0.1.8.9 144 | - wandb==0.10.12 145 | - watchdog==1.0.1 146 | - websocket-client==0.57.0 147 | - werkzeug==0.16.0 148 | - yml==0.0.1 149 | prefix: /home/yf22/anaconda3/envs/pytorch 150 | 151 | -------------------------------------------------------------------------------- /fractrain_cifar/compute_flops.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/simochen/model-tools. 
2 | import numpy as np 3 | import os 4 | 5 | import torch 6 | import torchvision 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | 10 | from models_base import cifar10_resnet_38, cifar10_resnet_74, cifar100_resnet_38, cifar100_resnet_74 11 | 12 | def print_model_param_nums(model=None): 13 | if model == None: 14 | model = torchvision.models.alexnet() 15 | total = sum([param.nelement() if param.requires_grad else 0 for param in model.parameters()]) 16 | print(' + Number of params: %.4fM' % (total / 1e6)) 17 | 18 | def count_model_param_flops(model=None, input_res=32, multiply_adds=True): 19 | 20 | prods = {} 21 | def save_hook(name): 22 | def hook_per(self, input, output): 23 | prods[name] = np.prod(input[0].shape) 24 | return hook_per 25 | 26 | list_1=[] 27 | def simple_hook(self, input, output): 28 | list_1.append(np.prod(input[0].shape)) 29 | list_2={} 30 | def simple_hook2(self, input, output): 31 | list_2['names'] = np.prod(input[0].shape) 32 | 33 | 34 | list_conv=[] 35 | def conv_hook(self, input, output): 36 | batch_size, input_channels, input_height, input_width = input[0].size() 37 | output_channels, output_height, output_width = output[0].size() 38 | 39 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 40 | bias_ops = 1 if self.bias is not None else 0 41 | 42 | params = output_channels * (kernel_ops + bias_ops) 43 | # flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size 44 | 45 | num_weight_params = (self.weight.data != 0).float().sum() 46 | flops = (num_weight_params * (2 if multiply_adds else 1) + bias_ops * output_channels) * output_height * output_width * batch_size 47 | 48 | list_conv.append(flops) 49 | 50 | list_linear=[] 51 | def linear_hook(self, input, output): 52 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 53 | 54 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 55 | bias_ops = self.bias.nelement() 56 | 57 | flops = batch_size * (weight_ops + bias_ops) 58 | list_linear.append(flops) 59 | 60 | list_bn=[] 61 | def bn_hook(self, input, output): 62 | list_bn.append(input[0].nelement() * 2) 63 | 64 | list_relu=[] 65 | def relu_hook(self, input, output): 66 | list_relu.append(input[0].nelement()) 67 | 68 | list_pooling=[] 69 | def pooling_hook(self, input, output): 70 | batch_size, input_channels, input_height, input_width = input[0].size() 71 | output_channels, output_height, output_width = output[0].size() 72 | 73 | kernel_ops = self.kernel_size * self.kernel_size 74 | bias_ops = 0 75 | params = 0 76 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size 77 | 78 | list_pooling.append(flops) 79 | 80 | list_upsample=[] 81 | 82 | # For bilinear upsample 83 | def upsample_hook(self, input, output): 84 | batch_size, input_channels, input_height, input_width = input[0].size() 85 | output_channels, output_height, output_width = output[0].size() 86 | 87 | flops = output_height * output_width * output_channels * batch_size * 12 88 | list_upsample.append(flops) 89 | 90 | def foo(net): 91 | childrens = list(net.children()) 92 | if not childrens: 93 | if isinstance(net, torch.nn.Conv2d): 94 | net.register_forward_hook(conv_hook) 95 | if isinstance(net, torch.nn.Linear): 96 | net.register_forward_hook(linear_hook) 97 | if isinstance(net, torch.nn.BatchNorm2d): 98 | net.register_forward_hook(bn_hook) 99 | if isinstance(net, torch.nn.ReLU): 100 | net.register_forward_hook(relu_hook) 
101 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d): 102 | net.register_forward_hook(pooling_hook) 103 | if isinstance(net, torch.nn.Upsample): 104 | net.register_forward_hook(upsample_hook) 105 | return 106 | for c in childrens: 107 | foo(c) 108 | 109 | if model == None: 110 | model = torchvision.models.alexnet() 111 | foo(model) 112 | input = Variable(torch.rand(3,input_res,input_res).unsqueeze(0), requires_grad = True) 113 | out = model(input) 114 | 115 | 116 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample)) 117 | 118 | print(' + Number of FLOPs: %.2fG' % (total_flops / 1e9)) 119 | 120 | return total_flops 121 | 122 | 123 | if __name__ == '__main__': 124 | from models import * 125 | model = cifar10_mobilenet_v2() 126 | 127 | blocks = [] 128 | 129 | for module in model.modules(): 130 | if isinstance(module, Block): 131 | blocks.append(module) 132 | 133 | conv_list = [] 134 | def conv_hook(self, input, output): 135 | batch_size, input_channels, input_height, input_width = input[0].size() 136 | output_channels, output_height, output_width = output[0].size() 137 | 138 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 139 | bias_ops = 1 if self.bias is not None else 0 140 | 141 | params = output_channels * (kernel_ops + bias_ops) 142 | # flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size 143 | 144 | num_weight_params = (self.weight.data != 0).float().sum() 145 | flops = (num_weight_params * 2 + bias_ops * output_channels) * output_height * output_width * batch_size 146 | 147 | conv_list.append(flops) 148 | 149 | dws_list = [] 150 | def dws_hook(self, input, output): 151 | batch_size, input_channels, input_height, input_width = input[0].size() 152 | output_channels, output_height, output_width = output[0].size() 153 | 154 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) 155 | bias_ops = 1 if self.bias is not None else 0 156 | 157 | params = output_channels * (kernel_ops + bias_ops) 158 | # flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size 159 | 160 | num_weight_params = (self.weight.data != 0).float().sum() 161 | flops = (num_weight_params * 2 + bias_ops * output_channels) * output_height * output_width * batch_size 162 | 163 | dws_list.append(flops) 164 | 165 | block_flops_list = [] 166 | dws_list = [] 167 | hook_list = [] 168 | for block in blocks: 169 | for module in block.modules(): 170 | if isinstance(module, torch.nn.Conv2d): 171 | if module.groups == 1: 172 | hook = module.register_forward_hook(conv_hook) 173 | hook_list.append(hook) 174 | else: 175 | hook = module.register_forward_hook(dws_hook) 176 | hook_list.append(hook) 177 | 178 | input = Variable(torch.rand(3, 32, 32).unsqueeze(0), requires_grad = True) 179 | out = model(input, 0, 0) 180 | 181 | block_flops_list.append(sum(conv_list)) 182 | 183 | for hook in hook_list: 184 | hook.remove() 185 | 186 | conv_list = [] 187 | 188 | print(block_flops_list) 189 | print(dws_list) 190 | 191 | model_info = {'conv':block_flops_list, 'dws':dws_list} 192 | 193 | np.save('model_info_mbv2.npy', model_info) 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 
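A quick usage sketch for the helpers above (hypothetical; assumes the base ResNet imported at the top of this file builds and runs with a single 32x32 input):
```
from compute_flops import print_model_param_nums, count_model_param_flops
from models_base import cifar10_resnet_38

model = cifar10_resnet_38()
print_model_param_nums(model)                 # prints the trainable parameter count in millions
count_model_param_flops(model, input_res=32)  # registers hooks and runs one 1x3x32x32 forward pass
```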
-------------------------------------------------------------------------------- /fractrain_cifar/data.py: -------------------------------------------------------------------------------- 1 | """prepare CIFAR and SVHN 2 | """ 3 | 4 | from __future__ import print_function 5 | 6 | import torch 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | 11 | 12 | crop_size = 32 13 | padding = 4 14 | 15 | 16 | def prepare_train_data(dataset='cifar10', datadir='/home/yf22/dataset', batch_size=128, 17 | shuffle=True, num_workers=4): 18 | 19 | if 'cifar' in dataset: 20 | transform_train = transforms.Compose([ 21 | transforms.RandomCrop(crop_size, padding=padding), 22 | transforms.RandomHorizontalFlip(), 23 | transforms.ToTensor(), 24 | transforms.Normalize((0.4914, 0.4822, 0.4465), 25 | (0.2023, 0.1994, 0.2010)), 26 | ]) 27 | 28 | trainset = torchvision.datasets.__dict__[dataset.upper()]( 29 | root=datadir, train=True, download=True, transform=transform_train) 30 | train_loader = torch.utils.data.DataLoader(trainset, 31 | batch_size=batch_size, 32 | shuffle=shuffle, 33 | num_workers=num_workers) 34 | elif 'svhn' in dataset: 35 | transform_train =transforms.Compose([ 36 | transforms.ToTensor(), 37 | transforms.Normalize((0.4377, 0.4438, 0.4728), 38 | (0.1980, 0.2010, 0.1970)), 39 | ]) 40 | trainset = torchvision.datasets.__dict__[dataset.upper()]( 41 | root=datadir, 42 | split='train', 43 | download=True, 44 | transform=transform_train 45 | ) 46 | 47 | transform_extra = transforms.Compose([ 48 | transforms.ToTensor(), 49 | transforms.Normalize((0.4300, 0.4284, 0.4427), 50 | (0.1963, 0.1979, 0.1995)) 51 | 52 | ]) 53 | 54 | extraset = torchvision.datasets.__dict__[dataset.upper()]( 55 | root=datadir, 56 | split='extra', 57 | download=True, 58 | transform = transform_extra 59 | ) 60 | 61 | total_data = torch.utils.data.ConcatDataset([trainset, extraset]) 62 | 63 | train_loader = torch.utils.data.DataLoader(total_data, 64 | batch_size=batch_size, 65 | shuffle=shuffle, 66 | num_workers=num_workers) 67 | else: 68 | train_loader = None 69 | return train_loader 70 | 71 | 72 | def prepare_test_data(dataset='cifar10', datadir='/home/yf22/dataset', batch_size=128, 73 | shuffle=False, num_workers=4): 74 | 75 | if 'cifar' in dataset: 76 | transform_test = transforms.Compose([ 77 | transforms.ToTensor(), 78 | transforms.Normalize((0.4914, 0.4822, 0.4465), 79 | (0.2023, 0.1994, 0.2010)), 80 | ]) 81 | 82 | testset = torchvision.datasets.__dict__[dataset.upper()](root=datadir, 83 | train=False, 84 | download=True, 85 | transform=transform_test) 86 | test_loader = torch.utils.data.DataLoader(testset, 87 | batch_size=batch_size, 88 | shuffle=shuffle, 89 | num_workers=num_workers) 90 | elif 'svhn' in dataset: 91 | transform_test = transforms.Compose([ 92 | transforms.ToTensor(), 93 | transforms.Normalize((0.4524, 0.4525, 0.4690), 94 | (0.2194, 0.2266, 0.2285)), 95 | ]) 96 | testset = torchvision.datasets.__dict__[dataset.upper()]( 97 | root=datadir, 98 | split='test', 99 | download=True, 100 | transform=transform_test) 101 | np.place(testset.labels, testset.labels == 10, 0) 102 | test_loader = torch.utils.data.DataLoader(testset, 103 | batch_size=batch_size, 104 | shuffle=shuffle, 105 | num_workers=num_workers) 106 | else: 107 | test_loader = None 108 | return test_loader 109 | -------------------------------------------------------------------------------- /fractrain_cifar/model_info_mbv2.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/fractrain_cifar/model_info_mbv2.npy -------------------------------------------------------------------------------- /fractrain_cifar/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/fractrain_cifar/modules/__init__.py -------------------------------------------------------------------------------- /fractrain_cifar/modules/bwn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bounded weight norm 3 | Weight Normalization from https://arxiv.org/abs/1602.07868 4 | taken and adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/weight_norm.py 5 | """ 6 | import torch 7 | from torch.nn.parameter import Parameter 8 | from torch.autograd import Variable, Function 9 | import torch.nn as nn 10 | 11 | 12 | def gather_params(self, memo=None, param_func=lambda s: s._parameters.values()): 13 | if memo is None: 14 | memo = set() 15 | for p in param_func(self): 16 | if p is not None and p not in memo: 17 | memo.add(p) 18 | yield p 19 | for m in self.children(): 20 | for p in gather_params(m, memo, param_func): 21 | yield p 22 | 23 | nn.Module.gather_params = gather_params 24 | 25 | 26 | def _norm(x, dim, p=2): 27 | """Computes the norm over all dimensions except dim""" 28 | if p == float('inf'): # infinity norm 29 | func = lambda x, dim: x.abs().max(dim=dim)[0] 30 | else: 31 | func = lambda x, dim: torch.norm(x, dim=dim, p=p) 32 | if dim is None: 33 | return x.norm(p=p) 34 | elif dim == 0: 35 | output_size = (x.size(0),) + (1,) * (x.dim() - 1) 36 | return func(x.contiguous().view(x.size(0), -1), 1).view(*output_size) 37 | elif dim == x.dim() - 1: 38 | output_size = (1,) * (x.dim() - 1) + (x.size(-1),) 39 | return func(x.contiguous().view(-1, x.size(-1)), 0).view(*output_size) 40 | else: 41 | return _norm(x.transpose(0, dim), 0).transpose(0, dim) 42 | 43 | 44 | def _mean(p, dim): 45 | """Computes the mean over all dimensions except dim""" 46 | if dim is None: 47 | return p.mean() 48 | elif dim == 0: 49 | output_size = (p.size(0),) + (1,) * (p.dim() - 1) 50 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size) 51 | elif dim == p.dim() - 1: 52 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),) 53 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size) 54 | else: 55 | return _mean(p.transpose(0, dim), 0).transpose(0, dim) 56 | 57 | 58 | class BoundedWeighNorm(object): 59 | 60 | def __init__(self, name, dim, p): 61 | self.name = name 62 | self.dim = dim 63 | self.p = p 64 | 65 | def compute_weight(self, module): 66 | v = getattr(module, self.name + '_v') 67 | pre_norm = getattr(module, self.name + '_prenorm') 68 | return v * (pre_norm / _norm(v, self.dim, p=self.p)) 69 | 70 | @staticmethod 71 | def apply(module, name, dim, p): 72 | fn = BoundedWeighNorm(name, dim, p) 73 | 74 | weight = getattr(module, name) 75 | 76 | # remove w from parameter list 77 | del module._parameters[name] 78 | 79 | prenorm = _norm(weight, dim, p=p).mean() 80 | module.register_buffer(name + '_prenorm', prenorm.detach()) 81 | pre_norm = getattr(module, name + '_prenorm') 82 | print(pre_norm) 83 | module.register_parameter(name + '_v', Parameter(weight.data)) 84 | setattr(module, name, fn.compute_weight(module)) 85 | 86 | # recompute weight before every forward() 87 | 
module.register_forward_pre_hook(fn) 88 | 89 | def gather_normed_params(self, memo=None, param_func=lambda s: fn.compute_weight(s)): 90 | return gather_params(self, memo, param_func) 91 | module.gather_params = gather_normed_params 92 | return fn 93 | 94 | def remove(self, module): 95 | weight = self.compute_weight(module) 96 | delattr(module, self.name) 97 | del module._parameters[self.name + '_prenorm'] 98 | del module._parameters[self.name + '_v'] 99 | module.register_parameter(self.name, Parameter(weight.data)) 100 | 101 | def __call__(self, module, inputs): 102 | setattr(module, self.name, self.compute_weight(module)) 103 | 104 | 105 | def weight_norm(module, name='weight', dim=0, p=2): 106 | r"""Applies weight normalization to a parameter in the given module. 107 | 108 | .. math:: 109 | \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} 110 | 111 | Weight normalization is a reparameterization that decouples the magnitude 112 | of a weight tensor from its direction. This replaces the parameter specified 113 | by `name` (e.g. "weight") with two parameters: one specifying the magnitude 114 | (e.g. "weight_g") and one specifying the direction (e.g. "weight_v"). 115 | Weight normalization is implemented via a hook that recomputes the weight 116 | tensor from the magnitude and direction before every :meth:`~Module.forward` 117 | call. 118 | 119 | By default, with `dim=0`, the norm is computed independently per output 120 | channel/plane. To compute a norm over the entire weight tensor, use 121 | `dim=None`. 122 | 123 | See https://arxiv.org/abs/1602.07868 124 | 125 | Args: 126 | module (nn.Module): containing module 127 | name (str, optional): name of weight parameter 128 | dim (int, optional): dimension over which to compute the norm 129 | 130 | Returns: 131 | The original module with the weight norm hook 132 | 133 | Example:: 134 | 135 | >>> m = weight_norm(nn.Linear(20, 40), name='weight') 136 | Linear (20 -> 40) 137 | >>> m.weight_g.size() 138 | torch.Size([40, 1]) 139 | >>> m.weight_v.size() 140 | torch.Size([40, 20]) 141 | 142 | """ 143 | BoundedWeighNorm.apply(module, name, dim, p) 144 | return module 145 | 146 | 147 | def remove_weight_norm(module, name='weight'): 148 | r"""Removes the weight normalization reparameterization from a module. 
149 | 150 | Args: 151 | module (nn.Module): containing module 152 | name (str, optional): name of weight parameter 153 | 154 | Example: 155 | >>> m = weight_norm(nn.Linear(20, 40)) 156 | >>> remove_weight_norm(m) 157 | """ 158 | for k, hook in module._forward_pre_hooks.items(): 159 | if isinstance(hook, BoundedWeighNorm) and hook.name == name: 160 | hook.remove(module) 161 | del module._forward_pre_hooks[k] 162 | return module 163 | 164 | raise ValueError("weight_norm of '{}' not found in {}" 165 | .format(name, module)) 166 | -------------------------------------------------------------------------------- /fractrain_cifar/modules/quantize.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd.function import InplaceFunction, Function 7 | 8 | QParams = namedtuple('QParams', ['range', 'zero_point', 'num_bits']) 9 | 10 | _DEFAULT_FLATTEN = (1, -1) 11 | _DEFAULT_FLATTEN_GRAD = (0, -1) 12 | 13 | 14 | def _deflatten_as(x, x_full): 15 | shape = list(x.shape) + [1] * (x_full.dim() - x.dim()) 16 | return x.view(*shape) 17 | 18 | 19 | def calculate_qparams(x, num_bits, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, reduce_type='mean', keepdim=False, true_zero=False): 20 | with torch.no_grad(): 21 | x_flat = x.flatten(*flatten_dims) 22 | if x_flat.dim() == 1: 23 | min_values = _deflatten_as(x_flat.min(), x) 24 | max_values = _deflatten_as(x_flat.max(), x) 25 | else: 26 | min_values = _deflatten_as(x_flat.min(-1)[0], x) 27 | max_values = _deflatten_as(x_flat.max(-1)[0], x) 28 | 29 | if reduce_dim is not None: 30 | if reduce_type == 'mean': 31 | min_values = min_values.mean(reduce_dim, keepdim=keepdim) 32 | max_values = max_values.mean(reduce_dim, keepdim=keepdim) 33 | else: 34 | min_values = min_values.min(reduce_dim, keepdim=keepdim)[0] 35 | max_values = max_values.max(reduce_dim, keepdim=keepdim)[0] 36 | 37 | range_values = max_values - min_values 38 | return QParams(range=range_values, zero_point=min_values, 39 | num_bits=num_bits) 40 | 41 | 42 | class UniformQuantize(InplaceFunction): 43 | 44 | @staticmethod 45 | def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN, 46 | reduce_dim=0, dequantize=True, signed=False, stochastic=True, inplace=False): 47 | 48 | ctx.inplace = inplace 49 | 50 | if ctx.inplace: 51 | ctx.mark_dirty(input) 52 | output = input 53 | else: 54 | output = input.clone() 55 | 56 | if qparams is None: 57 | assert num_bits is not None, "either provide qparams of num_bits to quantize" 58 | qparams = calculate_qparams( 59 | input, num_bits=num_bits, flatten_dims=flatten_dims, reduce_dim=reduce_dim) 60 | 61 | zero_point = qparams.zero_point 62 | num_bits = qparams.num_bits 63 | qmin = -(2.**(num_bits - 1)) if signed else 0. 64 | qmax = qmin + 2.**num_bits - 1. 
65 | scale = qparams.range / (qmax - qmin) 66 | 67 | min_scale = torch.tensor(1e-8).expand_as(scale).cuda() 68 | scale = torch.max(scale, min_scale) 69 | 70 | with torch.no_grad(): 71 | output.add_(qmin * scale - zero_point).div_(scale) 72 | if stochastic: 73 | noise = output.new(output.shape).uniform_(-0.5, 0.5) 74 | output.add_(noise) 75 | # quantize 76 | output.clamp_(qmin, qmax).round_() 77 | 78 | if dequantize: 79 | output.mul_(scale).add_( 80 | zero_point - qmin * scale) # dequantize 81 | return output 82 | 83 | @staticmethod 84 | def backward(ctx, grad_output): 85 | # straight-through estimator 86 | grad_input = grad_output 87 | return grad_input, None, None, None, None, None, None, None, None 88 | 89 | 90 | class UniformQuantizeGrad(InplaceFunction): 91 | 92 | @staticmethod 93 | def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD, 94 | reduce_dim=0, dequantize=True, signed=False, stochastic=True): 95 | ctx.num_bits = num_bits 96 | ctx.qparams = qparams 97 | ctx.flatten_dims = flatten_dims 98 | ctx.stochastic = stochastic 99 | ctx.signed = signed 100 | ctx.dequantize = dequantize 101 | ctx.reduce_dim = reduce_dim 102 | ctx.inplace = False 103 | return input 104 | 105 | @staticmethod 106 | def backward(ctx, grad_output): 107 | qparams = ctx.qparams 108 | with torch.no_grad(): 109 | if qparams is None: 110 | assert ctx.num_bits is not None, "either provide qparams of num_bits to quantize" 111 | qparams = calculate_qparams( 112 | grad_output, num_bits=ctx.num_bits, flatten_dims=ctx.flatten_dims, reduce_dim=ctx.reduce_dim, reduce_type='extreme') 113 | 114 | grad_input = quantize(grad_output, num_bits=None, 115 | qparams=qparams, flatten_dims=ctx.flatten_dims, reduce_dim=ctx.reduce_dim, 116 | dequantize=True, signed=ctx.signed, stochastic=ctx.stochastic, inplace=False) 117 | return grad_input, None, None, None, None, None, None, None 118 | 119 | 120 | def conv2d_biprec(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, num_bits_grad=None): 121 | out1 = F.conv2d(input.detach(), weight, bias, 122 | stride, padding, dilation, groups) 123 | out2 = F.conv2d(input, weight.detach(), bias.detach() if bias is not None else None, 124 | stride, padding, dilation, groups) 125 | out2 = quantize_grad(out2, num_bits=num_bits_grad, flatten_dims=(1, -1)) 126 | return out1 + out2 - out1.detach() 127 | 128 | 129 | def linear_biprec(input, weight, bias=None, num_bits_grad=None): 130 | out1 = F.linear(input.detach(), weight, bias) 131 | out2 = F.linear(input, weight.detach(), bias.detach() 132 | if bias is not None else None) 133 | out2 = quantize_grad(out2, num_bits=num_bits_grad) 134 | return out1 + out2 - out1.detach() 135 | 136 | 137 | def quantize(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, dequantize=True, signed=False, stochastic=False, inplace=False): 138 | if qparams: 139 | if qparams.num_bits: 140 | return UniformQuantize().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic, inplace) 141 | elif num_bits: 142 | return UniformQuantize().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic, inplace) 143 | 144 | return x 145 | 146 | 147 | def quantize_grad(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD, reduce_dim=0, dequantize=True, signed=False, stochastic=True): 148 | if qparams: 149 | if qparams.num_bits: 150 | return UniformQuantizeGrad().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic) 151 
| elif num_bits: 152 | return UniformQuantizeGrad().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic) 153 | 154 | return x 155 | 156 | 157 | class QuantMeasure(nn.Module): 158 | """docstring for QuantMeasure.""" 159 | 160 | def __init__(self, shape_measure=(1,), flatten_dims=_DEFAULT_FLATTEN, 161 | inplace=False, dequantize=True, stochastic=False, momentum=0.9, measure=False): 162 | super(QuantMeasure, self).__init__() 163 | self.register_buffer('running_zero_point', torch.zeros(*shape_measure)) 164 | self.register_buffer('running_range', torch.zeros(*shape_measure)) 165 | self.measure = measure 166 | if self.measure: 167 | self.register_buffer('num_measured', torch.zeros(1)) 168 | self.flatten_dims = flatten_dims 169 | self.momentum = momentum 170 | self.dequantize = dequantize 171 | self.stochastic = stochastic 172 | self.inplace = inplace 173 | 174 | def forward(self, input, num_bits, qparams=None): 175 | 176 | if self.training or self.measure: 177 | if qparams is None: 178 | qparams = calculate_qparams( 179 | input, num_bits=num_bits, flatten_dims=self.flatten_dims, reduce_dim=0, reduce_type='extreme') 180 | with torch.no_grad(): 181 | if self.measure: 182 | momentum = self.num_measured / (self.num_measured + 1) 183 | self.num_measured += 1 184 | else: 185 | momentum = self.momentum 186 | self.running_zero_point.mul_(momentum).add_( 187 | qparams.zero_point * (1 - momentum)) 188 | self.running_range.mul_(momentum).add_( 189 | qparams.range * (1 - momentum)) 190 | else: 191 | qparams = QParams(range=self.running_range, 192 | zero_point=self.running_zero_point, num_bits=num_bits) 193 | if self.measure: 194 | return input 195 | else: 196 | q_input = quantize(input, qparams=qparams, dequantize=self.dequantize, 197 | stochastic=self.stochastic, inplace=self.inplace) 198 | return q_input 199 | 200 | 201 | class QConv2d(nn.Conv2d): 202 | """docstring for QConv2d.""" 203 | 204 | def __init__(self, in_channels, out_channels, kernel_size, 205 | stride=1, padding=0, dilation=1, groups=1, bias=True, momentum=0.1, quant_act_forward=0, quant_act_backward=0, quant_grad_act_error=0, quant_grad_act_gc=0, weight_bits=0, fix_prec=False): 206 | super(QConv2d, self).__init__(in_channels, out_channels, kernel_size, 207 | stride, padding, dilation, groups, bias) 208 | 209 | self.quantize_input_fw = QuantMeasure(shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1), momentum=momentum) 210 | self.quantize_input_bw = QuantMeasure(shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1), momentum=momentum) 211 | self.quant_act_forward = quant_act_forward 212 | self.quant_act_backward = quant_act_backward 213 | self.quant_grad_act_error = quant_grad_act_error 214 | self.quant_grad_act_gc = quant_grad_act_gc 215 | self.weight_bits = weight_bits 216 | self.fix_prec = fix_prec 217 | self.stride = stride 218 | 219 | 220 | def forward(self, input, num_bits, num_grad_bits): 221 | if num_bits == 0: 222 | output = F.conv2d(input, self.weight, self.bias, self.stride,self.padding, self.dilation, self.groups) 223 | return output 224 | 225 | if self.bias is not None: 226 | qbias = quantize( 227 | self.bias, num_bits=self.num_bits_weight + self.num_bits, 228 | flatten_dims=(0, -1)) 229 | else: 230 | qbias = None 231 | 232 | if self.fix_prec: 233 | if self.quant_act_forward or self.quant_act_backward or self.quant_grad_act_error or self.quant_grad_act_gc or self.weight_bits: 234 | weight_qparams = calculate_qparams(self.weight, num_bits=self.weight_bits, flatten_dims=(1, -1), reduce_dim=None) 235 | 
qweight = quantize(self.weight, qparams=weight_qparams) 236 | 237 | qinput_fw = self.quantize_input_fw(input, self.quant_act_forward) 238 | qinput_bw = self.quantize_input_bw(input, self.quant_act_backward) 239 | 240 | error_bits = self.quant_grad_act_error 241 | gc_bits = self.quant_grad_act_gc 242 | output = self.conv2d_quant_act(qinput_fw, qinput_bw, qweight, qbias, self.stride, self.padding, self.dilation, self.groups, error_bits, gc_bits) 243 | 244 | else: 245 | qinput = self.quantize_input_fw(input, num_bits) 246 | weight_qparams = calculate_qparams(self.weight, num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=None) 247 | qweight = quantize(self.weight, qparams=weight_qparams) 248 | output = F.conv2d(qinput, qweight, qbias, self.stride, self.padding, self.dilation, self.groups) 249 | output = quantize_grad(output, num_bits=num_grad_bits, flatten_dims=(1, -1)) 250 | 251 | return output 252 | 253 | weight_qparams = calculate_qparams(self.weight, num_bits=self.weight_bits, flatten_dims=(1, -1), reduce_dim=None) 254 | qweight = quantize(self.weight, qparams=weight_qparams) 255 | 256 | qinput = self.quantize_input_fw(input, num_bits) 257 | output = F.conv2d(qinput, qweight, qbias, self.stride, self.padding, self.dilation, self.groups) 258 | output = quantize_grad(output, num_bits=num_grad_bits, flatten_dims=(1, -1)) 259 | 260 | # if self.quant_act_forward == -1: 261 | # qinput_fw = self.quantize_input_fw(input, num_bits) 262 | # else: 263 | # qinput_fw = self.quantize_input_fw(input, self.quant_act_forward) 264 | 265 | # if self.quant_act_backward == -1: 266 | # qinput_bw = self.quantize_input_bw(input, num_bits) 267 | # else: 268 | # qinput_bw = self.quantize_input_bw(input, self.quant_act_backward) 269 | 270 | # if self.quant_grad_act_error == -1: 271 | # error_bits = num_grad_bits 272 | # else: 273 | # error_bits = self.quant_grad_act_error 274 | 275 | # if self.quant_grad_act_gc == -1: 276 | # gc_bits = num_grad_bits 277 | # else: 278 | # gc_bits = self.quant_grad_act_gc 279 | 280 | # output = self.conv2d_quant_act(qinput_fw, qinput_bw, qweight, qbias, self.stride, self.padding, self.dilation, self.groups, error_bits, gc_bits) 281 | 282 | return output 283 | 284 | 285 | def conv2d_quant_act(self, input_fw, input_bw, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, error_bits=0, gc_bits=0): 286 | out1 = F.conv2d(input_fw, weight.detach(), bias.detach() if bias is not None else None, 287 | stride, padding, dilation, groups) 288 | out2 = F.conv2d(input_bw.detach(), weight, bias, 289 | stride, padding, dilation, groups) 290 | out1 = quantize_grad(out1, num_bits=error_bits) 291 | out2 = quantize_grad(out2, num_bits=gc_bits) 292 | return out1 + out2 - out2.detach() 293 | 294 | 295 | class QLinear(nn.Linear): 296 | """docstring for QConv2d.""" 297 | 298 | def __init__(self, in_features, out_features, bias=True, num_bits=8, num_bits_weight=8, num_bits_grad=8, biprecision=True): 299 | super(QLinear, self).__init__(in_features, out_features, bias) 300 | self.num_bits = num_bits 301 | self.num_bits_weight = num_bits_weight or num_bits 302 | self.num_bits_grad = num_bits_grad 303 | self.biprecision = biprecision 304 | self.quantize_input = QuantMeasure(self.num_bits) 305 | 306 | def forward(self, input): 307 | qinput = self.quantize_input(input) 308 | weight_qparams = calculate_qparams( 309 | self.weight, num_bits=self.num_bits_weight, flatten_dims=(1, -1), reduce_dim=None) 310 | qweight = quantize(self.weight, qparams=weight_qparams) 311 | if self.bias is not None: 312 | qbias = 
quantize( 313 | self.bias, num_bits=self.num_bits_weight + self.num_bits, 314 | flatten_dims=(0, -1)) 315 | else: 316 | qbias = None 317 | 318 | if not self.biprecision or self.num_bits_grad is None: 319 | output = F.linear(qinput, qweight, qbias) 320 | if self.num_bits_grad is not None: 321 | output = quantize_grad( 322 | output, num_bits=self.num_bits_grad) 323 | else: 324 | output = linear_biprec(qinput, qweight, qbias, self.num_bits_grad) 325 | return output 326 | 327 | 328 | class RangeBN(nn.Module): 329 | # this is normalized RangeBN 330 | 331 | def __init__(self, num_features, dim=1, momentum=0.1, affine=True, num_chunks=16, eps=1e-5, num_bits=8, num_bits_grad=8): 332 | super(RangeBN, self).__init__() 333 | self.register_buffer('running_mean', torch.zeros(num_features)) 334 | self.register_buffer('running_var', torch.zeros(num_features)) 335 | 336 | self.momentum = momentum 337 | self.dim = dim 338 | if affine: 339 | self.bias = nn.Parameter(torch.Tensor(num_features)) 340 | self.weight = nn.Parameter(torch.Tensor(num_features)) 341 | self.num_bits = num_bits 342 | self.num_bits_grad = num_bits_grad 343 | self.quantize_input = QuantMeasure(inplace=True, shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1)) 344 | self.eps = eps 345 | self.num_chunks = num_chunks 346 | self.reset_params() 347 | 348 | def reset_params(self): 349 | if self.weight is not None: 350 | self.weight.data.uniform_() 351 | if self.bias is not None: 352 | self.bias.data.zero_() 353 | 354 | def forward(self, x, num_bits, num_grad_bits): 355 | x = self.quantize_input(x, num_bits) 356 | if x.dim() == 2: # 1d 357 | x = x.unsqueeze(-1,).unsqueeze(-1) 358 | 359 | if self.training: 360 | B, C, H, W = x.shape 361 | y = x.transpose(0, 1).contiguous() # C x B x H x W 362 | y = y.view(C, self.num_chunks, (B * H * W) // self.num_chunks) 363 | mean_max = y.max(-1)[0].mean(-1) # C 364 | mean_min = y.min(-1)[0].mean(-1) # C 365 | mean = y.view(C, -1).mean(-1) # C 366 | scale_fix = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) ** 367 | 0.5) / ((2 * math.log(y.size(-1))) ** 0.5) 368 | 369 | scale = (mean_max - mean_min) * scale_fix 370 | with torch.no_grad(): 371 | self.running_mean.mul_(self.momentum).add_( 372 | mean * (1 - self.momentum)) 373 | 374 | self.running_var.mul_(self.momentum).add_( 375 | scale * (1 - self.momentum)) 376 | else: 377 | mean = self.running_mean 378 | scale = self.running_var 379 | # scale = quantize(scale, num_bits=self.num_bits, min_value=float( 380 | # scale.min()), max_value=float(scale.max())) 381 | out = (x - mean.view(1, -1, 1, 1)) / \ 382 | (scale.view(1, -1, 1, 1) + self.eps) 383 | 384 | if self.weight is not None: 385 | qweight = self.weight 386 | # qweight = quantize(self.weight, num_bits=self.num_bits, 387 | # min_value=float(self.weight.min()), 388 | # max_value=float(self.weight.max())) 389 | out = out * qweight.view(1, -1, 1, 1) 390 | 391 | if self.bias is not None: 392 | qbias = self.bias 393 | # qbias = quantize(self.bias, num_bits=self.num_bits) 394 | out = out + qbias.view(1, -1, 1, 1) 395 | if num_grad_bits: 396 | out = quantize_grad( 397 | out, num_bits=num_grad_bits, flatten_dims=(1, -1)) 398 | 399 | if out.size(3) == 1 and out.size(2) == 1: 400 | out = out.squeeze(-1).squeeze(-1) 401 | return out 402 | 403 | 404 | class RangeBN1d(RangeBN): 405 | # this is normalized RangeBN 406 | 407 | def __init__(self, num_features, dim=1, momentum=0.1, affine=True, num_chunks=16, eps=1e-5, num_bits=8, num_bits_grad=8): 408 | super(RangeBN1d, self).__init__(num_features, dim, momentum, 409 | affine, 
num_chunks, eps, num_bits, num_bits_grad) 410 | self.quantize_input = QuantMeasure( 411 | self.num_bits, inplace=True, shape_measure=(1, 1), flatten_dims=(1, -1)) 412 | 413 | if __name__ == '__main__': 414 | x = torch.rand(2, 3) 415 | x_q = quantize(x, flatten_dims=(-1), num_bits=8, dequantize=True) 416 | print(x) 417 | print(x_q) -------------------------------------------------------------------------------- /fractrain_cifar/modules/rnlu.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.autograd.function import InplaceFunction 4 | from torch.autograd import Variable 5 | import torch.nn as nn 6 | import math 7 | 8 | 9 | class BiReLUFunction(InplaceFunction): 10 | 11 | @classmethod 12 | def forward(cls, ctx, input, inplace=False): 13 | if input.size(1) % 2 != 0: 14 | raise RuntimeError("dimension 1 of input must be multiple of 2, " 15 | "but got {}".format(input.size(1))) 16 | ctx.inplace = inplace 17 | 18 | if ctx.inplace: 19 | ctx.mark_dirty(input) 20 | output = input 21 | else: 22 | output = input.clone() 23 | 24 | pos, neg = output.chunk(2, dim=1) 25 | pos.clamp_(min=0) 26 | neg.clamp_(max=0) 27 | # scale = (pos - neg).view(pos.size(0), -1).mean(1).div_(2) 28 | # output. 29 | ctx.save_for_backward(output) 30 | return output 31 | 32 | @staticmethod 33 | def backward(ctx, grad_output): 34 | output, = ctx.saved_variables 35 | grad_input = grad_output.masked_fill(output.eq(0), 0) 36 | return grad_input, None 37 | 38 | 39 | def birelu(x, inplace=False): 40 | return BiReLUFunction().apply(x, inplace) 41 | 42 | 43 | class BiReLU(nn.Module): 44 | """docstring for BiReLU.""" 45 | 46 | def __init__(self, inplace=False): 47 | super(BiReLU, self).__init__() 48 | self.inplace = inplace 49 | 50 | def forward(self, inputs): 51 | return birelu(inputs, inplace=self.inplace) 52 | 53 | 54 | def binorm(x, shift=0, scale_fix=(2 / math.pi) ** 0.5): 55 | pos, neg = (x + shift).split(2, dim=1) 56 | scale = (pos - neg).view(pos.size(0), -1).mean(1).div_(2) * scale_fix 57 | return x / scale 58 | 59 | 60 | def _mean(p, dim): 61 | """Computes the mean over all dimensions except dim""" 62 | if dim is None: 63 | return p.mean() 64 | elif dim == 0: 65 | output_size = (p.size(0),) + (1,) * (p.dim() - 1) 66 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size) 67 | elif dim == p.dim() - 1: 68 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),) 69 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size) 70 | else: 71 | return _mean(p.transpose(0, dim), 0).transpose(0, dim) 72 | 73 | 74 | def rnlu(x, inplace=False, shift=0, scale_fix=(math.pi / 2) ** 0.5): 75 | x = birelu(x, inplace=inplace) 76 | pos, neg = (x + shift).chunk(2, dim=1) 77 | # scale = torch.cat((_mean(pos, 1), -_mean(neg, 1)), 1) * scale_fix + 1e-5 78 | scale = (pos - neg).view(pos.size(0), -1).mean(1) * scale_fix + 1e-8 79 | return x / scale.view(scale.size(0), *([1] * (x.dim() - 1))) 80 | 81 | 82 | class RnLU(nn.Module): 83 | """docstring for RnLU.""" 84 | 85 | def __init__(self, inplace=False): 86 | super(RnLU, self).__init__() 87 | self.inplace = inplace 88 | 89 | def forward(self, x): 90 | return rnlu(x, inplace=self.inplace) 91 | 92 | # output. 
93 | if __name__ == "__main__": 94 | x = Variable(torch.randn(2, 16, 5, 5).cuda(), requires_grad=True) 95 | output = rnlu(x) 96 | 97 | output.sum().backward() 98 | -------------------------------------------------------------------------------- /fractrain_cifar/train_base.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.backends.cudnn as cudnn 6 | from torch.autograd import Variable 7 | import numpy as np 8 | 9 | import os 10 | import shutil 11 | import argparse 12 | import time 13 | import logging 14 | 15 | import models 16 | from data import * 17 | 18 | import util_swa 19 | 20 | 21 | model_names = sorted(name for name in models.__dict__ 22 | if name.islower() and not name.startswith('__') 23 | and callable(models.__dict__[name]) 24 | ) 25 | 26 | 27 | def parse_args(): 28 | # hyper-parameters are from ResNet paper 29 | parser = argparse.ArgumentParser( 30 | description='Quantization Aware Training on CIFAR') 31 | parser.add_argument('--dir', help='annotate the working directory') 32 | parser.add_argument('--cmd', choices=['train', 'test'], default='train') 33 | parser.add_argument('--arch', metavar='ARCH', default='cifar10_resnet_38', 34 | choices=model_names, 35 | help='model architecture: ' + 36 | ' | '.join(model_names) + 37 | ' (default: cifar10_resnet_38)') 38 | parser.add_argument('--dataset', '-d', type=str, default='cifar10', 39 | choices=['cifar10', 'cifar100'], 40 | help='dataset choice') 41 | parser.add_argument('--datadir', default='/home/yf22/dataset', type=str, 42 | help='path to dataset') 43 | parser.add_argument('--workers', default=4, type=int, metavar='N', 44 | help='number of data loading workers (default: 4 )') 45 | parser.add_argument('--iters', default=64000, type=int, 46 | help='number of total iterations (default: 64,000)') 47 | parser.add_argument('--start_iter', default=0, type=int, 48 | help='manual iter number (useful on restarts)') 49 | parser.add_argument('--batch_size', default=128, type=int, 50 | help='mini-batch size (default: 128)') 51 | parser.add_argument('--lr_schedule', default='piecewise', type=str, 52 | help='learning rate schedule') 53 | parser.add_argument('--lr', default=0.1, type=float, 54 | help='initial learning rate') 55 | parser.add_argument('--momentum', default=0.9, type=float, 56 | help='momentum') 57 | parser.add_argument('--weight_decay', default=1e-4, type=float, 58 | help='weight decay (default: 1e-4)') 59 | parser.add_argument('--print_freq', default=10, type=int, 60 | help='print frequency (default: 10)') 61 | parser.add_argument('--resume', default='', type=str, 62 | help='path to latest checkpoint (default: None)') 63 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 64 | help='use pretrained model') 65 | parser.add_argument('--step_ratio', default=0.1, type=float, 66 | help='ratio for learning rate deduction') 67 | parser.add_argument('--warm_up', action='store_true', 68 | help='for n = 18, the model needs to warm up for 400 ' 69 | 'iterations') 70 | parser.add_argument('--save_folder', default='save_checkpoints', 71 | type=str, 72 | help='folder to save the checkpoints') 73 | parser.add_argument('--eval_every', default=390, type=int, 74 | help='evaluate model every (default: 1000) iterations') 75 | parser.add_argument('--num_bits',default=0,type=int, 76 | help='num bits for weight and activation') 77 | parser.add_argument('--num_grad_bits',default=0,type=int, 78 
| help='num bits for gradient') 79 | parser.add_argument('--schedule', default=None, type=int, nargs='*', 80 | help='precision schedule') 81 | parser.add_argument('--num_bits_schedule',default=None,type=int,nargs='*', 82 | help='schedule for weight/act precision') 83 | parser.add_argument('--num_grad_bits_schedule',default=None,type=int,nargs='*', 84 | help='schedule for grad precision') 85 | parser.add_argument('--act_fw', default=0, type=int, 86 | help='precision of activation during forward, -1 means dynamic, 0 means no quantize') 87 | parser.add_argument('--act_bw', default=0, type=int, 88 | help='precision of activation during backward, -1 means dynamic, 0 means no quantize') 89 | parser.add_argument('--grad_act_error', default=0, type=int, 90 | help='precision of activation gradient during error backward, -1 means dynamic, 0 means no quantize') 91 | parser.add_argument('--grad_act_gc', default=0, type=int, 92 | help='precision of activation gradient during weight gradient computation, -1 means dynamic, 0 means no quantize') 93 | parser.add_argument('--weight_bits', default=0, type=int, 94 | help='precision of weight') 95 | parser.add_argument('--momentum_act', default=0.9, type=float, 96 | help='momentum for act min/max') 97 | parser.add_argument('--swa_start', type=float, default=None, help='SWA start step number') 98 | parser.add_argument('--swa_freq', type=float, default=1170, 99 | help='SWA model collection frequency') 100 | args = parser.parse_args() 101 | return args 102 | 103 | 104 | def main(): 105 | args = parse_args() 106 | save_path = args.save_path = os.path.join(args.save_folder, args.arch) 107 | if not os.path.exists(save_path): 108 | os.makedirs(save_path) 109 | 110 | models.ACT_FW = args.act_fw 111 | models.ACT_BW = args.act_bw 112 | models.GRAD_ACT_ERROR = args.grad_act_error 113 | models.GRAD_ACT_GC = args.grad_act_gc 114 | models.WEIGHT_BITS = args.weight_bits 115 | models.MOMENTUM = args.momentum_act 116 | 117 | args.num_bits = args.num_bits if not (args.act_fw + args.act_bw + args.grad_act_error + args.grad_act_gc + args.weight_bits) else -1 118 | 119 | # config logging file 120 | args.logger_file = os.path.join(save_path, 'log_{}.txt'.format(args.cmd)) 121 | handlers = [logging.FileHandler(args.logger_file, mode='w'), 122 | logging.StreamHandler()] 123 | logging.basicConfig(level=logging.INFO, 124 | datefmt='%m-%d-%y %H:%M', 125 | format='%(asctime)s:%(message)s', 126 | handlers=handlers) 127 | 128 | if args.cmd == 'train': 129 | logging.info('start training {}'.format(args.arch)) 130 | run_training(args) 131 | 132 | elif args.cmd == 'test': 133 | logging.info('start evaluating {} with checkpoints from {}'.format( 134 | args.arch, args.resume)) 135 | test_model(args) 136 | 137 | 138 | def run_training(args): 139 | # create model 140 | model = models.__dict__[args.arch](args.pretrained) 141 | model = torch.nn.DataParallel(model).cuda() 142 | 143 | if args.swa_start is not None: 144 | print('SWA training') 145 | swa_model = torch.nn.DataParallel(models.__dict__[args.arch](args.pretrained)).cuda() 146 | swa_n = 0 147 | 148 | else: 149 | print('SGD training') 150 | 151 | best_prec1 = 0 152 | best_iter = 0 153 | 154 | best_swa_prec = 0 155 | best_swa_iter = 0 156 | 157 | # best_full_prec = 0 158 | 159 | if args.resume: 160 | if os.path.isfile(args.resume): 161 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 162 | checkpoint = torch.load(args.resume) 163 | args.start_iter = checkpoint['iter'] 164 | best_prec1 = checkpoint['best_prec1'] 165 | 
model.load_state_dict(checkpoint['state_dict']) 166 | 167 | if args.swa_start is not None: 168 | swa_state_dict = checkpoint['swa_state_dict'] 169 | if swa_state_dict is not None: 170 | swa_model.load_state_dict(swa_state_dict) 171 | swa_n_ckpt = checkpoint['swa_n'] 172 | if swa_n_ckpt is not None: 173 | swa_n = swa_n_ckpt 174 | best_swa_prec_ckpt = checkpoint['best_swa_prec'] 175 | if best_swa_prec_ckpt is not None: 176 | best_swa_prec = best_swa_prec_ckpt 177 | 178 | logging.info('=> loaded checkpoint `{}` (iter: {})'.format( 179 | args.resume, checkpoint['iter'] 180 | )) 181 | else: 182 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 183 | 184 | cudnn.benchmark = False 185 | 186 | train_loader = prepare_train_data(dataset=args.dataset, 187 | datadir=args.datadir, 188 | batch_size=args.batch_size, 189 | shuffle=True, 190 | num_workers=args.workers) 191 | test_loader = prepare_test_data(dataset=args.dataset, 192 | datadir=args.datadir, 193 | batch_size=args.batch_size, 194 | shuffle=False, 195 | num_workers=args.workers) 196 | if args.swa_start is not None: 197 | swa_loader = prepare_train_data(dataset=args.dataset, 198 | datadir=args.datadir, 199 | batch_size=args.batch_size, 200 | shuffle=False, 201 | num_workers=args.workers) 202 | 203 | # define loss function (criterion) and optimizer 204 | criterion = nn.CrossEntropyLoss().cuda() 205 | 206 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 207 | momentum=args.momentum, 208 | weight_decay=args.weight_decay) 209 | 210 | # optimizer = torch.optim.Adam(model.parameters(), args.lr, 211 | # weight_decay=args.weight_decay) 212 | 213 | batch_time = AverageMeter() 214 | data_time = AverageMeter() 215 | losses = AverageMeter() 216 | top1 = AverageMeter() 217 | cr = AverageMeter() 218 | 219 | end = time.time() 220 | 221 | i = args.start_iter 222 | while i < args.iters: 223 | for input, target in train_loader: 224 | # measuring data loading time 225 | data_time.update(time.time() - end) 226 | 227 | model.train() 228 | adjust_learning_rate(args, optimizer, i) 229 | adjust_precision(args, i) 230 | 231 | i += 1 232 | 233 | fw_cost = args.num_bits*args.num_bits/32/32 234 | eb_cost = args.num_bits*args.num_grad_bits/32/32 235 | gc_cost = eb_cost 236 | cr.update((fw_cost+eb_cost+gc_cost)/3) 237 | 238 | target = target.squeeze().long().cuda() 239 | input_var = Variable(input).cuda() 240 | target_var = Variable(target).cuda() 241 | 242 | # compute output 243 | output = model(input_var, args.num_bits, args.num_grad_bits) 244 | loss = criterion(output, target_var) 245 | 246 | # measure accuracy and record loss 247 | prec1, = accuracy(output.data, target, topk=(1,)) 248 | losses.update(loss.item(), input.size(0)) 249 | top1.update(prec1.item(), input.size(0)) 250 | 251 | # compute gradient and do SGD step 252 | optimizer.zero_grad() 253 | loss.backward() 254 | optimizer.step() 255 | 256 | # measure elapsed time 257 | batch_time.update(time.time() - end) 258 | end = time.time() 259 | 260 | # print log 261 | if i % args.print_freq == 0: 262 | logging.info("Iter: [{0}/{1}]\t" 263 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 264 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t" 265 | "Loss {loss.val:.3f} ({loss.avg:.3f})\t" 266 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t".format( 267 | i, 268 | args.iters, 269 | batch_time=batch_time, 270 | data_time=data_time, 271 | loss=losses, 272 | top1=top1) 273 | ) 274 | 275 | 276 | if args.swa_start is not None and i >= args.swa_start and i % args.swa_freq == 0: 277 | 
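The `cr` meter above tracks an analytic compute-cost ratio rather than measured time: a multiply between a `num_bits` operand and a `num_grad_bits` operand is counted as proportional to the product of the bit-widths, normalized by 32x32 for FP32. The forward pass therefore costs `num_bits**2 / 32**2`, while error back-propagation and weight-gradient computation each cost `num_bits * num_grad_bits / 32**2`. A sketch of the same bookkeeping:

```python
def cost_ratio(num_bits: int, num_grad_bits: int) -> float:
    """Mean cost of forward + error backprop + gradient computation vs. FP32,
    under the bit-product cost model used by the training loop above."""
    fw_cost = num_bits * num_bits / 32 / 32        # forward GEMMs
    eb_cost = num_bits * num_grad_bits / 32 / 32   # error back-propagation
    gc_cost = eb_cost                              # weight-gradient computation
    return (fw_cost + eb_cost + gc_cost) / 3

print(cost_ratio(8, 8))   # 0.0625 -> 16x cheaper than full precision
print(cost_ratio(6, 8))   # ~0.0430 with mixed forward/backward precision
```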
util_swa.moving_average(swa_model, model, 1.0 / (swa_n + 1)) 278 | swa_n += 1 279 | util_swa.bn_update(swa_loader, swa_model, args.num_bits, args.num_grad_bits) 280 | prec1 = validate(args, test_loader, swa_model, criterion, i, swa=True) 281 | 282 | if prec1 > best_swa_prec: 283 | best_swa_prec = prec1 284 | best_swa_iter = i 285 | 286 | print("Current Best SWA Prec@1: ", best_swa_prec) 287 | print("Current Best SWA Iteration: ", best_swa_iter) 288 | 289 | if (i % args.eval_every == 0 and i > 0) or (i == args.iters): 290 | with torch.no_grad(): 291 | prec1 = validate(args, test_loader, model, criterion, i) 292 | # prec_full = validate_full_prec(args, test_loader, model, criterion, i) 293 | 294 | is_best = prec1 > best_prec1 295 | if is_best: 296 | best_prec1 = prec1 297 | best_iter = i 298 | # best_full_prec = max(prec_full, best_full_prec) 299 | 300 | print("Current Best Prec@1: ", best_prec1) 301 | print("Current Best Iteration: ", best_iter) 302 | # print("Current Best Full Prec@1: ", best_full_prec) 303 | 304 | checkpoint_path = os.path.join(args.save_path, 'checkpoint_{:05d}_{:.2f}.pth.tar'.format(i, prec1)) 305 | save_checkpoint({ 306 | 'iter': i, 307 | 'arch': args.arch, 308 | 'state_dict': model.state_dict(), 309 | 'best_prec1': best_prec1, 310 | 'swa_state_dict' : swa_model.state_dict() if args.swa_start is not None else None, 311 | 'swa_n' : swa_n if args.swa_start is not None else None, 312 | 'best_swa_prec' : best_swa_prec if args.swa_start is not None else None, 313 | }, 314 | is_best, filename=checkpoint_path) 315 | shutil.copyfile(checkpoint_path, os.path.join(args.save_path, 316 | 'checkpoint_latest' 317 | '.pth.tar')) 318 | 319 | if i == args.iters: 320 | break 321 | 322 | 323 | def validate(args, test_loader, model, criterion, step, swa=False): 324 | batch_time = AverageMeter() 325 | losses = AverageMeter() 326 | top1 = AverageMeter() 327 | 328 | # switch to evaluation mode 329 | model.eval() 330 | end = time.time() 331 | for i, (input, target) in enumerate(test_loader): 332 | target = target.squeeze().long().cuda() 333 | input_var = Variable(input, volatile=True).cuda() 334 | target_var = Variable(target, volatile=True).cuda() 335 | 336 | # compute output 337 | output = model(input_var, args.num_bits, args.num_grad_bits) 338 | loss = criterion(output, target_var) 339 | 340 | # measure accuracy and record loss 341 | prec1, = accuracy(output.data, target, topk=(1,)) 342 | top1.update(prec1.item(), input.size(0)) 343 | losses.update(loss.item(), input.size(0)) 344 | batch_time.update(time.time() - end) 345 | end = time.time() 346 | 347 | if (i % args.print_freq == 0) or (i == len(test_loader) - 1): 348 | logging.info( 349 | 'Test: [{}/{}]\t' 350 | 'Time: {batch_time.val:.4f}({batch_time.avg:.4f})\t' 351 | 'Loss: {loss.val:.3f}({loss.avg:.3f})\t' 352 | 'Prec@1: {top1.val:.3f}({top1.avg:.3f})\t'.format( 353 | i, len(test_loader), batch_time=batch_time, 354 | loss=losses, top1=top1 355 | ) 356 | ) 357 | 358 | if not swa: 359 | logging.info('Step {} * Prec@1 {top1.avg:.3f}'.format(step, top1=top1)) 360 | else: 361 | logging.info('Step {} * SWA Prec@1 {top1.avg:.3f}'.format(step, top1=top1)) 362 | 363 | return top1.avg 364 | 365 | 366 | def validate_full_prec(args, test_loader, model, criterion, step): 367 | batch_time = AverageMeter() 368 | losses = AverageMeter() 369 | top1 = AverageMeter() 370 | 371 | # switch to evaluation mode 372 | model.eval() 373 | end = time.time() 374 | for i, (input, target) in enumerate(test_loader): 375 | target = target.squeeze().long().cuda() 376 
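The SWA update above relies on a standard incremental-mean identity: calling `util_swa.moving_average(swa_model, model, 1.0 / (swa_n + 1))` after collecting the n-th snapshot keeps the SWA weights equal to the arithmetic mean of all snapshots so far. A scalar sanity check:

```python
# Incremental mean, as in util_swa.moving_average with alpha = 1/(swa_n + 1).
snapshots = [2.0, 4.0, 9.0]   # stand-ins for a weight at successive SWA points
swa, n = 0.0, 0
for w in snapshots:
    alpha = 1.0 / (n + 1)
    swa = (1.0 - alpha) * swa + alpha * w   # what moving_average does per-param
    n += 1
assert abs(swa - sum(snapshots) / len(snapshots)) < 1e-12   # swa == 5.0
```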
| input_var = Variable(input, volatile=True).cuda() 377 | target_var = Variable(target, volatile=True).cuda() 378 | 379 | # compute output 380 | output = model(input_var, 0, 0) 381 | loss = criterion(output, target_var) 382 | 383 | # measure accuracy and record loss 384 | prec1, = accuracy(output.data, target, topk=(1,)) 385 | top1.update(prec1.item(), input.size(0)) 386 | losses.update(loss.item(), input.size(0)) 387 | batch_time.update(time.time() - end) 388 | end = time.time() 389 | 390 | 391 | logging.info('Step {} * Full Prec@1 {top1.avg:.3f}'.format(step, top1=top1)) 392 | return top1.avg 393 | 394 | 395 | def test_model(args): 396 | # create model 397 | model = models.__dict__[args.arch](args.pretrained) 398 | model = torch.nn.DataParallel(model).cuda() 399 | 400 | if args.resume: 401 | if os.path.isfile(args.resume): 402 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 403 | checkpoint = torch.load(args.resume) 404 | args.start_iter = checkpoint['iter'] 405 | best_prec1 = checkpoint['best_prec1'] 406 | model.load_state_dict(checkpoint['state_dict']) 407 | logging.info('=> loaded checkpoint `{}` (iter: {})'.format( 408 | args.resume, checkpoint['iter'] 409 | )) 410 | else: 411 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 412 | 413 | cudnn.benchmark = False 414 | test_loader = prepare_test_data(dataset=args.dataset, 415 | batch_size=args.batch_size, 416 | shuffle=False, 417 | num_workers=args.workers) 418 | criterion = nn.CrossEntropyLoss().cuda() 419 | 420 | # validate(args, test_loader, model, criterion) 421 | 422 | with torch.no_grad(): 423 | prec1 = validate(args, test_loader, model, criterion, args.start_iter) 424 | prec_full = validate_full_prec(args, test_loader, model, criterion, args.start_iter) 425 | 426 | 427 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 428 | torch.save(state, filename) 429 | if is_best: 430 | save_path = os.path.dirname(filename) 431 | shutil.copyfile(filename, os.path.join(save_path, 432 | 'model_best.pth.tar')) 433 | 434 | 435 | class AverageMeter(object): 436 | """Computes and stores the average and current value""" 437 | 438 | def __init__(self): 439 | self.reset() 440 | 441 | def reset(self): 442 | self.val = 0 443 | self.avg = 0 444 | self.sum = 0 445 | self.count = 0 446 | 447 | def update(self, val, n=1): 448 | self.val = val 449 | self.sum += val * n 450 | self.count += n 451 | self.avg = self.sum / self.count 452 | 453 | 454 | schedule_cnt = 0 455 | def adjust_precision(args, _iter): 456 | if args.schedule: 457 | global schedule_cnt 458 | 459 | assert len(args.num_bits_schedule) == len(args.schedule) + 1 460 | assert len(args.num_grad_bits_schedule) == len(args.schedule) + 1 461 | 462 | if schedule_cnt == 0: 463 | args.num_bits = args.num_bits_schedule[0] 464 | args.num_grad_bits = args.num_grad_bits_schedule[0] 465 | schedule_cnt += 1 466 | 467 | for step in args.schedule: 468 | if _iter == step: 469 | args.num_bits = args.num_bits_schedule[schedule_cnt] 470 | args.num_grad_bits = args.num_grad_bits_schedule[schedule_cnt] 471 | schedule_cnt += 1 472 | 473 | if _iter % args.eval_every == 0: 474 | logging.info('Iter [{}] num_bits = {} num_grad_bits = {}'.format(_iter, args.num_bits, args.num_grad_bits)) 475 | 476 | 477 | def adjust_learning_rate(args, optimizer, _iter): 478 | if args.lr_schedule == 'piecewise': 479 | if args.warm_up and (_iter < 400): 480 | lr = 0.01 481 | elif 32000 <= _iter < 48000: 482 | lr = args.lr * (args.step_ratio ** 1) 483 | elif _iter >= 48000: 484 | lr = 
args.lr * (args.step_ratio ** 2) 485 | else: 486 | lr = args.lr 487 | 488 | elif args.lr_schedule == 'linear': 489 | t = _iter / args.iters 490 | lr_ratio = 0.01 491 | if args.warm_up and (_iter < 400): 492 | lr = 0.01 493 | elif t < 0.5: 494 | lr = args.lr 495 | elif t < 0.9: 496 | lr = args.lr * (1 - (1-lr_ratio)*(t-0.5)/0.4) 497 | else: 498 | lr = args.lr * lr_ratio 499 | 500 | elif args.lr_schedule == 'anneal_cosine': 501 | lr_min = args.lr * (args.step_ratio ** 2) 502 | lr_max = args.lr 503 | lr = lr_min + 1/2 * (lr_max - lr_min) * (1 + np.cos(_iter/args.iters * 3.141592653)) 504 | 505 | if _iter % args.eval_every == 0: 506 | logging.info('Iter [{}] learning rate = {}'.format(_iter, lr)) 507 | 508 | for param_group in optimizer.param_groups: 509 | param_group['lr'] = lr 510 | 511 | 512 | def accuracy(output, target, topk=(1,)): 513 | """Computes the precision@k for the specified values of k""" 514 | maxk = max(topk) 515 | batch_size = target.size(0) 516 | 517 | _, pred = output.topk(maxk, 1, True, True) 518 | pred = pred.t() 519 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 520 | 521 | res = [] 522 | for k in topk: 523 | correct_k = correct[:k].view(-1).float().sum(0) 524 | res.append(correct_k.mul_(100.0 / batch_size)) 525 | return res 526 | 527 | 528 | if __name__ == '__main__': 529 | main() 530 | -------------------------------------------------------------------------------- /fractrain_cifar/train_pfq.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.backends.cudnn as cudnn 6 | from torch.autograd import Variable 7 | import numpy as np 8 | 9 | import os 10 | import shutil 11 | import argparse 12 | import time 13 | import logging 14 | 15 | import models 16 | from data import * 17 | 18 | import util_swa 19 | 20 | 21 | model_names = sorted(name for name in models.__dict__ 22 | if name.islower() and not name.startswith('__') 23 | and callable(models.__dict__[name]) 24 | ) 25 | 26 | 27 | def parse_args(): 28 | # hyper-parameters are from ResNet paper 29 | parser = argparse.ArgumentParser( 30 | description='PFQ on CIFAR') 31 | parser.add_argument('--dir', help='annotate the working directory') 32 | parser.add_argument('--cmd', choices=['train', 'test'], default='train') 33 | parser.add_argument('--arch', metavar='ARCH', default='cifar10_resnet_38', 34 | choices=model_names, 35 | help='model architecture: ' + 36 | ' | '.join(model_names) + 37 | ' (default: cifar10_resnet_38)') 38 | parser.add_argument('--dataset', '-d', type=str, default='cifar10', 39 | choices=['cifar10', 'cifar100'], 40 | help='dataset choice') 41 | parser.add_argument('--datadir', default='/home/yf22/dataset', type=str, 42 | help='path to dataset') 43 | parser.add_argument('--workers', default=4, type=int, metavar='N', 44 | help='number of data loading workers (default: 4 )') 45 | parser.add_argument('--iters', default=64000, type=int, 46 | help='number of total iterations (default: 64,000)') 47 | parser.add_argument('--start_iter', default=0, type=int, 48 | help='manual iter number (useful on restarts)') 49 | parser.add_argument('--batch_size', default=128, type=int, 50 | help='mini-batch size (default: 128)') 51 | parser.add_argument('--lr_schedule', default='piecewise', type=str, 52 | help='learning rate schedule') 53 | parser.add_argument('--lr', default=0.1, type=float, 54 | help='initial learning rate') 55 | parser.add_argument('--momentum', default=0.9, 
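For reference, here is `adjust_learning_rate` rewritten as a pure function, which makes the three supported schedules easy to probe (values below assume the defaults `lr=0.1`, `step_ratio=0.1`, `iters=64000`):

```python
import numpy as np

def lr_at(it, schedule='piecewise', lr=0.1, step_ratio=0.1,
          iters=64000, warm_up=False):
    """Pure-function version of adjust_learning_rate above."""
    if schedule == 'piecewise':
        if warm_up and it < 400:
            return 0.01
        if 32000 <= it < 48000:
            return lr * step_ratio          # 0.01
        if it >= 48000:
            return lr * step_ratio ** 2     # 0.001
        return lr
    if schedule == 'linear':
        t, lr_ratio = it / iters, 0.01
        if warm_up and it < 400:
            return 0.01
        if t < 0.5:
            return lr
        if t < 0.9:
            return lr * (1 - (1 - lr_ratio) * (t - 0.5) / 0.4)
        return lr * lr_ratio
    # 'anneal_cosine'
    lr_min, lr_max = lr * step_ratio ** 2, lr
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(it / iters * np.pi))

for it in (0, 32000, 48000):
    print(it, lr_at(it))                    # -> 0.1, 0.01, 0.001
```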
type=float, 56 | help='momentum') 57 | parser.add_argument('--weight_decay', default=1e-4, type=float, 58 | help='weight decay (default: 1e-4)') 59 | parser.add_argument('--print_freq', default=10, type=int, 60 | help='print frequency (default: 10)') 61 | parser.add_argument('--resume', default='', type=str, 62 | help='path to latest checkpoint (default: None)') 63 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 64 | help='use pretrained model') 65 | parser.add_argument('--step_ratio', default=0.1, type=float, 66 | help='ratio for learning rate reduction') 67 | parser.add_argument('--warm_up', action='store_true', 68 | help='for n = 18, the model needs to warm up for 400 ' 69 | 'iterations') 70 | parser.add_argument('--save_folder', default='save_checkpoints', 71 | type=str, 72 | help='folder to save the checkpoints') 73 | parser.add_argument('--eval_every', default=400, type=int, 74 | help='evaluate model every (default: 400) iterations') 75 | parser.add_argument('--num_bits', default=0, type=int, 76 | help='num bits for weight and activation') 77 | parser.add_argument('--num_grad_bits', default=0, type=int, 78 | help='num bits for gradient') 79 | parser.add_argument('--schedule', default=None, type=int, nargs='*', 80 | help='precision schedule') 81 | parser.add_argument('--num_bits_schedule', default=None, type=int, nargs='*', 82 | help='schedule for weight/act precision') 83 | parser.add_argument('--num_grad_bits_schedule', default=None, type=int, nargs='*', 84 | help='schedule for grad precision') 85 | parser.add_argument('--act_fw', default=0, type=int, 86 | help='precision of activation during forward, -1 means dynamic, 0 means no quantize') 87 | parser.add_argument('--act_bw', default=0, type=int, 88 | help='precision of activation during backward, -1 means dynamic, 0 means no quantize') 89 | parser.add_argument('--grad_act_error', default=0, type=int, 90 | help='precision of activation gradient during error backward, -1 means dynamic, 0 means no quantize') 91 | parser.add_argument('--grad_act_gc', default=0, type=int, 92 | help='precision of activation gradient during weight gradient computation, -1 means dynamic, 0 means no quantize') 93 | parser.add_argument('--weight_bits', default=0, type=int, 94 | help='precision of weight') 95 | parser.add_argument('--momentum_act', default=0.9, type=float, 96 | help='momentum for act min/max') 97 | parser.add_argument('--swa_start', type=float, default=None, help='SWA start step number') 98 | parser.add_argument('--swa_freq', type=float, default=1170, 99 | help='SWA model collection frequency') 100 | 101 | parser.add_argument('--num_turning_point', type=int, default=3) 102 | parser.add_argument('--initial_threshold', type=float, default=0.15) 103 | parser.add_argument('--decay', type=float, default=0.4) 104 | args = parser.parse_args() 105 | return args 106 | 107 | # loss-difference indicator used by PFQ to detect turning points 108 | class loss_diff_indicator(): 109 | def __init__(self, threshold, decay, epoch_keep=5): 110 | self.threshold = threshold 111 | self.decay = decay 112 | self.epoch_keep = epoch_keep 113 | self.loss = [] 114 | self.scale_loss = 1 115 | self.loss_diff = [1 for i in range(1, self.epoch_keep)] 116 | 117 | def reset(self): 118 | self.loss = [] 119 | self.loss_diff = [1 for i in range(1, self.epoch_keep)] 120 | 121 | def adaptive_threshold(self, turning_point_count): 122 | decay_1 = self.decay 123 | decay_2 = self.decay 124 | if turning_point_count == 1: 125 | self.threshold *= decay_1 126 | if turning_point_count == 2: 127 | self.threshold *= 
decay_2 128 | print('threshold decay to {}'.format(self.threshold)) 129 | 130 | def get_loss(self, current_epoch_loss): 131 | if len(self.loss) < self.epoch_keep: 132 | self.loss.append(current_epoch_loss) 133 | else: 134 | self.loss.pop(0) 135 | self.loss.append(current_epoch_loss) 136 | 137 | def cal_loss_diff(self): 138 | if len(self.loss) == self.epoch_keep: 139 | for i in range(len(self.loss)-1): 140 | loss_now = self.loss[-1] 141 | loss_pre = self.loss[i] 142 | self.loss_diff[i] = np.abs(loss_pre - loss_now) / self.scale_loss 143 | return True 144 | else: 145 | return False 146 | 147 | def turning_point_emerge(self): 148 | flag = self.cal_loss_diff() 149 | if flag == True: 150 | print(self.loss_diff) 151 | for i in range(len(self.loss_diff)): 152 | if self.loss_diff[i] > self.threshold: 153 | return False 154 | return True 155 | else: 156 | return False 157 | 158 | def main(): 159 | args = parse_args() 160 | global save_path 161 | save_path = args.save_path = os.path.join(args.save_folder, args.arch) 162 | if not os.path.exists(save_path): 163 | os.makedirs(save_path) 164 | 165 | models.ACT_FW = args.act_fw 166 | models.ACT_BW = args.act_bw 167 | models.GRAD_ACT_ERROR = args.grad_act_error 168 | models.GRAD_ACT_GC = args.grad_act_gc 169 | models.WEIGHT_BITS = args.weight_bits 170 | models.MOMENTUM = args.momentum_act 171 | 172 | args.num_bits = args.num_bits if not (args.act_fw + args.act_bw + args.grad_act_error + args.grad_act_gc + args.weight_bits) else -1 173 | 174 | # config logging file 175 | args.logger_file = os.path.join(save_path, 'log_{}.txt'.format(args.cmd)) 176 | if os.path.exists(args.logger_file): 177 | os.remove(args.logger_file) 178 | handlers = [logging.FileHandler(args.logger_file, mode='w'), 179 | logging.StreamHandler()] 180 | logging.basicConfig(level=logging.INFO, 181 | datefmt='%m-%d-%y %H:%M', 182 | format='%(asctime)s:%(message)s', 183 | handlers=handlers) 184 | 185 | global history_score 186 | history_score = np.zeros((args.iters // args.eval_every, 3)) 187 | 188 | # initialize indicator 189 | # initial_threshold=0.15 190 | global scale_loss 191 | scale_loss = 0 192 | global my_loss_diff_indicator 193 | my_loss_diff_indicator = loss_diff_indicator(threshold=args.initial_threshold, 194 | decay=args.decay) 195 | 196 | global turning_point_count 197 | turning_point_count = 0 198 | 199 | if args.cmd == 'train': 200 | logging.info('start training {}'.format(args.arch)) 201 | run_training(args) 202 | 203 | elif args.cmd == 'test': 204 | logging.info('start evaluating {} with checkpoints from {}'.format( 205 | args.arch, args.resume)) 206 | test_model(args) 207 | 208 | 209 | 210 | def run_training(args): 211 | # create model 212 | training_loss = 0 213 | training_acc = 0 214 | 215 | model = models.__dict__[args.arch](args.pretrained) 216 | model = torch.nn.DataParallel(model).cuda() 217 | 218 | if args.swa_start is not None: 219 | print('SWA training') 220 | swa_model = torch.nn.DataParallel(models.__dict__[args.arch](args.pretrained)).cuda() 221 | swa_n = 0 222 | 223 | else: 224 | print('SGD training') 225 | 226 | best_prec1 = 0 227 | best_iter = 0 228 | 229 | best_swa_prec = 0 230 | best_swa_iter = 0 231 | 232 | # best_full_prec = 0 233 | 234 | if args.resume: 235 | if os.path.isfile(args.resume): 236 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 237 | checkpoint = torch.load(args.resume) 238 | args.start_iter = checkpoint['iter'] 239 | best_prec1 = checkpoint['best_prec1'] 240 | model.load_state_dict(checkpoint['state_dict']) 241 | 242 | if 
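In plain terms, the indicator keeps the last `epoch_keep=5` epoch losses and fires once every normalized difference `|loss_i - loss_now| / scale_loss` in the window drops below the current threshold, i.e. once the loss curve has flattened at the current precision. A condensed, self-contained rerun of that logic on made-up losses:

```python
class LossDiffIndicator:
    """Condensed copy of the loss_diff_indicator class above."""
    def __init__(self, threshold=0.15, keep=5):
        self.threshold, self.keep = threshold, keep
        self.loss, self.scale_loss = [], 1.0

    def update(self, epoch_loss):
        """Returns True once the window of kept losses has flattened."""
        self.loss = (self.loss + [epoch_loss])[-self.keep:]
        if len(self.loss) < self.keep:
            return False
        diffs = [abs(l - self.loss[-1]) / self.scale_loss
                 for l in self.loss[:-1]]
        return max(diffs) <= self.threshold

ind = LossDiffIndicator()
ind.scale_loss = 2.0  # in train_pfq this is the mean loss of early epochs
for epoch, loss in enumerate([2.3, 1.6, 1.2, 1.05, 1.0, 0.98, 0.97], 1):
    if ind.update(loss):
        print('turning point at epoch', epoch)  # fires at epoch 7
        break
```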
args.swa_start is not None: 243 | swa_state_dict = checkpoint['swa_state_dict'] 244 | if swa_state_dict is not None: 245 | swa_model.load_state_dict(swa_state_dict) 246 | swa_n_ckpt = checkpoint['swa_n'] 247 | if swa_n_ckpt is not None: 248 | swa_n = swa_n_ckpt 249 | best_swa_prec_ckpt = checkpoint['best_swa_prec'] 250 | if best_swa_prec_ckpt is not None: 251 | best_swa_prec = best_swa_prec_ckpt 252 | 253 | logging.info('=> loaded checkpoint `{}` (iter: {})'.format( 254 | args.resume, checkpoint['iter'] 255 | )) 256 | else: 257 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 258 | 259 | cudnn.benchmark = False 260 | 261 | train_loader = prepare_train_data(dataset=args.dataset, 262 | datadir=args.datadir, 263 | batch_size=args.batch_size, 264 | shuffle=True, 265 | num_workers=args.workers) 266 | test_loader = prepare_test_data(dataset=args.dataset, 267 | datadir=args.datadir, 268 | batch_size=args.batch_size, 269 | shuffle=False, 270 | num_workers=args.workers) 271 | if args.swa_start is not None: 272 | swa_loader = prepare_train_data(dataset=args.dataset, 273 | datadir=args.datadir, 274 | batch_size=args.batch_size, 275 | shuffle=False, 276 | num_workers=args.workers) 277 | 278 | # define loss function (criterion) and optimizer 279 | criterion = nn.CrossEntropyLoss().cuda() 280 | 281 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 282 | momentum=args.momentum, 283 | weight_decay=args.weight_decay) 284 | 285 | # optimizer = torch.optim.Adam(model.parameters(), args.lr, 286 | # weight_decay=args.weight_decay) 287 | 288 | batch_time = AverageMeter() 289 | data_time = AverageMeter() 290 | losses = AverageMeter() 291 | top1 = AverageMeter() 292 | cr = AverageMeter() 293 | 294 | end = time.time() 295 | 296 | global scale_loss 297 | global turning_point_count 298 | global my_loss_diff_indicator 299 | 300 | i = args.start_iter 301 | while i < args.iters: 302 | for input, target in train_loader: 303 | # measuring data loading time 304 | data_time.update(time.time() - end) 305 | 306 | model.train() 307 | adjust_learning_rate(args, optimizer, i) 308 | # adjust_precision(args, i) 309 | adaptive_adjust_precision(args, turning_point_count) 310 | 311 | i += 1 312 | 313 | fw_cost = args.num_bits*args.num_bits/32/32 314 | eb_cost = args.num_bits*args.num_grad_bits/32/32 315 | gc_cost = eb_cost 316 | cr.update((fw_cost+eb_cost+gc_cost)/3) 317 | 318 | target = target.squeeze().long().cuda() 319 | input_var = Variable(input).cuda() 320 | target_var = Variable(target).cuda() 321 | 322 | # compute output 323 | output = model(input_var, args.num_bits, args.num_grad_bits) 324 | loss = criterion(output, target_var) 325 | training_loss += loss.item() 326 | 327 | # measure accuracy and record loss 328 | prec1, = accuracy(output.data, target, topk=(1,)) 329 | losses.update(loss.item(), input.size(0)) 330 | top1.update(prec1.item(), input.size(0)) 331 | training_acc += prec1.item() 332 | 333 | # compute gradient and do SGD step 334 | optimizer.zero_grad() 335 | loss.backward() 336 | optimizer.step() 337 | 338 | # measure elapsed time 339 | batch_time.update(time.time() - end) 340 | end = time.time() 341 | 342 | # print log 343 | if i % args.print_freq == 0: 344 | logging.info("Iter: [{0}/{1}]\t" 345 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 346 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t" 347 | "Loss {loss.val:.3f} ({loss.avg:.3f})\t" 348 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t".format( 349 | i, 350 | args.iters, 351 | batch_time=batch_time, 352 | 
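Where `train_base.py` switches precision at fixed iteration milestones, PFQ indexes its bit schedules by the number of turning points detected so far, via `adaptive_adjust_precision` (defined later in this file). With `--num_turning_point 3` the schedules have four stages; the schedule values below are hypothetical, for illustration only:

```python
# Hypothetical 4-stage schedules (--num_turning_point 3 => 4 stages).
num_bits_schedule      = [3, 4, 6, 8]
num_grad_bits_schedule = [6, 6, 8, 8]

def precision_for(turning_point_count):
    """Same indexing rule as adaptive_adjust_precision below."""
    return (num_bits_schedule[turning_point_count],
            num_grad_bits_schedule[turning_point_count])

for tp in range(4):
    print(tp, 'turning points ->', precision_for(tp))
```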
data_time=data_time, 353 | loss=losses, 354 | top1=top1) 355 | ) 356 | 357 | 358 | if args.swa_start is not None and i >= args.swa_start and i % args.swa_freq == 0: 359 | util_swa.moving_average(swa_model, model, 1.0 / (swa_n + 1)) 360 | swa_n += 1 361 | util_swa.bn_update(swa_loader, swa_model, args.num_bits, args.num_grad_bits) 362 | prec1 = validate(args, test_loader, swa_model, criterion, i, swa=True) 363 | 364 | if prec1 > best_swa_prec: 365 | best_swa_prec = prec1 366 | best_swa_iter = i 367 | 368 | # print("Current Best SWA Prec@1: ", best_swa_prec) 369 | # print("Current Best SWA Iteration: ", best_swa_iter) 370 | 371 | if (i % args.eval_every == 0 and i > 0) or (i == args.iters): 372 | # record training loss and test accuracy 373 | global history_score 374 | epoch = i // args.eval_every 375 | epoch_loss = training_loss / len(train_loader) 376 | with torch.no_grad(): 377 | prec1 = validate(args, test_loader, model, criterion, i) 378 | # prec_full = validate_full_prec(args, test_loader, model, criterion, i) 379 | history_score[epoch-1][0] = epoch_loss 380 | history_score[epoch-1][1] = np.round(training_acc / len(train_loader), 2) 381 | history_score[epoch-1][2] = prec1 382 | training_loss = 0 383 | training_acc = 0 384 | 385 | np.savetxt(os.path.join(save_path, 'record.txt'), history_score, fmt = '%10.5f', delimiter=',') 386 | 387 | # apply indicator 388 | # if epoch == 1: 389 | # logging.info('initial loss value: {}'.format(epoch_loss)) 390 | # my_loss_diff_indicator.scale_loss = epoch_loss 391 | if epoch <= 10: 392 | scale_loss += epoch_loss 393 | logging.info('scale_loss at epoch {}: {}'.format(epoch, scale_loss / epoch)) 394 | my_loss_diff_indicator.scale_loss = scale_loss / epoch 395 | if turning_point_count < args.num_turning_point: 396 | my_loss_diff_indicator.get_loss(epoch_loss) 397 | flag = my_loss_diff_indicator.turning_point_emerge() 398 | if flag == True: 399 | turning_point_count += 1 400 | logging.info('find {}-th turning point at {}-th epoch'.format(turning_point_count, epoch)) 401 | # print('find {}-th turning point at {}-th epoch'.format(turning_point_count, epoch)) 402 | my_loss_diff_indicator.adaptive_threshold(turning_point_count=turning_point_count) 403 | my_loss_diff_indicator.reset() 404 | 405 | logging.info('Epoch [{}] num_bits = {} num_grad_bits = {}'.format(epoch, args.num_bits, args.num_grad_bits)) 406 | 407 | # print statistics 408 | is_best = prec1 > best_prec1 409 | if is_best: 410 | best_prec1 = prec1 411 | best_iter = i 412 | # best_full_prec = max(prec_full, best_full_prec) 413 | 414 | logging.info("Current Best Prec@1: {}".format(best_prec1)) 415 | logging.info("Current Best Iteration: {}".format(best_iter)) 416 | logging.info("Current Best SWA Prec@1: {}".format(best_swa_prec)) 417 | logging.info("Current Best SWA Iteration: {}".format(best_swa_iter)) 418 | # print("Current Best Full Prec@1: ", best_full_prec) 419 | 420 | # checkpoint_path = os.path.join(args.save_path, 'checkpoint_{:05d}_{:.2f}.pth.tar'.format(i, prec1)) 421 | checkpoint_path = os.path.join(args.save_path, 'ckpt.pth.tar') 422 | save_checkpoint({ 423 | 'iter': i, 424 | 'arch': args.arch, 425 | 'state_dict': model.state_dict(), 426 | 'best_prec1': best_prec1, 427 | 'swa_state_dict' : swa_model.state_dict() if args.swa_start is not None else None, 428 | 'swa_n' : swa_n if args.swa_start is not None else None, 429 | 'best_swa_prec' : best_swa_prec if args.swa_start is not None else None, 430 | }, 431 | is_best, filename=checkpoint_path) 432 | # shutil.copyfile(checkpoint_path, 
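Two details of the indicator hookup above are easy to miss: `scale_loss` is the running mean of the first ten epoch losses, so the indicator thresholds *relative* rather than absolute loss changes, and each detected turning point shrinks the threshold by `--decay`, making the next switch harder to trigger. The threshold trajectory with the defaults:

```python
threshold, decay = 0.15, 0.4        # --initial_threshold / --decay defaults
for tp_count in (1, 2):             # adaptive_threshold only decays twice
    threshold *= decay
    print('after turning point', tp_count, '-> threshold %.4f' % threshold)
# 0.0600 after the first switch, 0.0240 after the second; the third turning
# point exhausts the four-stage schedule, so no further decay is needed.
```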
os.path.join(args.save_path, 433 | # 'checkpoint_latest' 434 | # '.pth.tar')) 435 | 436 | if i == args.iters: 437 | print("Best accuracy: "+str(best_prec1)) 438 | history_score[-1][0] = best_prec1 439 | np.savetxt(os.path.join(save_path, 'record.txt'), history_score, fmt = '%10.5f', delimiter=',') 440 | break 441 | 442 | 443 | def validate(args, test_loader, model, criterion, step, swa=False): 444 | batch_time = AverageMeter() 445 | losses = AverageMeter() 446 | top1 = AverageMeter() 447 | 448 | # switch to evaluation mode 449 | model.eval() 450 | end = time.time() 451 | for i, (input, target) in enumerate(test_loader): 452 | target = target.squeeze().long().cuda() 453 | input_var = Variable(input, volatile=True).cuda() 454 | target_var = Variable(target, volatile=True).cuda() 455 | 456 | # compute output 457 | output = model(input_var, args.num_bits, args.num_grad_bits) 458 | loss = criterion(output, target_var) 459 | 460 | # measure accuracy and record loss 461 | prec1, = accuracy(output.data, target, topk=(1,)) 462 | top1.update(prec1.item(), input.size(0)) 463 | losses.update(loss.item(), input.size(0)) 464 | batch_time.update(time.time() - end) 465 | end = time.time() 466 | 467 | if (i % args.print_freq == 0) or (i == len(test_loader) - 1): 468 | logging.info( 469 | 'Test: [{}/{}]\t' 470 | 'Time: {batch_time.val:.4f}({batch_time.avg:.4f})\t' 471 | 'Loss: {loss.val:.3f}({loss.avg:.3f})\t' 472 | 'Prec@1: {top1.val:.3f}({top1.avg:.3f})\t'.format( 473 | i, len(test_loader), batch_time=batch_time, 474 | loss=losses, top1=top1 475 | ) 476 | ) 477 | 478 | if not swa: 479 | logging.info('Step {} * Prec@1 {top1.avg:.3f}'.format(step, top1=top1)) 480 | else: 481 | logging.info('Step {} * SWA Prec@1 {top1.avg:.3f}'.format(step, top1=top1)) 482 | 483 | return top1.avg 484 | 485 | 486 | def validate_full_prec(args, test_loader, model, criterion, step): 487 | batch_time = AverageMeter() 488 | losses = AverageMeter() 489 | top1 = AverageMeter() 490 | 491 | # switch to evaluation mode 492 | model.eval() 493 | end = time.time() 494 | for i, (input, target) in enumerate(test_loader): 495 | target = target.squeeze().long().cuda() 496 | input_var = Variable(input, volatile=True).cuda() 497 | target_var = Variable(target, volatile=True).cuda() 498 | 499 | # compute output 500 | output = model(input_var, 0, 0) 501 | loss = criterion(output, target_var) 502 | 503 | # measure accuracy and record loss 504 | prec1, = accuracy(output.data, target, topk=(1,)) 505 | top1.update(prec1.item(), input.size(0)) 506 | losses.update(loss.item(), input.size(0)) 507 | batch_time.update(time.time() - end) 508 | end = time.time() 509 | 510 | 511 | logging.info('Step {} * Full Prec@1 {top1.avg:.3f}'.format(step, top1=top1)) 512 | return top1.avg 513 | 514 | 515 | def test_model(args): 516 | # create model 517 | model = models.__dict__[args.arch](args.pretrained) 518 | model = torch.nn.DataParallel(model).cuda() 519 | 520 | if args.resume: 521 | if os.path.isfile(args.resume): 522 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 523 | checkpoint = torch.load(args.resume) 524 | args.start_iter = checkpoint['iter'] 525 | best_prec1 = checkpoint['best_prec1'] 526 | model.load_state_dict(checkpoint['state_dict']) 527 | logging.info('=> loaded checkpoint `{}` (iter: {})'.format( 528 | args.resume, checkpoint['iter'] 529 | )) 530 | else: 531 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 532 | 533 | cudnn.benchmark = False 534 | test_loader = prepare_test_data(dataset=args.dataset, 535 | 
batch_size=args.batch_size, 536 | shuffle=False, 537 | num_workers=args.workers) 538 | criterion = nn.CrossEntropyLoss().cuda() 539 | 540 | # validate(args, test_loader, model, criterion) 541 | 542 | with torch.no_grad(): 543 | prec1 = validate(args, test_loader, model, criterion, args.start_iter) 544 | prec_full = validate_full_prec(args, test_loader, model, criterion, args.start_iter) 545 | 546 | 547 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 548 | torch.save(state, filename) 549 | if is_best: 550 | save_path = os.path.dirname(filename) 551 | shutil.copyfile(filename, os.path.join(save_path, 552 | 'model_best.pth.tar')) 553 | 554 | 555 | class AverageMeter(object): 556 | """Computes and stores the average and current value""" 557 | 558 | def __init__(self): 559 | self.reset() 560 | 561 | def reset(self): 562 | self.val = 0 563 | self.avg = 0 564 | self.sum = 0 565 | self.count = 0 566 | 567 | def update(self, val, n=1): 568 | self.val = val 569 | self.sum += val * n 570 | self.count += n 571 | self.avg = self.sum / self.count 572 | 573 | 574 | schedule_cnt = 0 575 | def adjust_precision(args, _iter): 576 | if args.schedule: 577 | global schedule_cnt 578 | 579 | assert len(args.num_bits_schedule) == len(args.schedule) + 1 580 | assert len(args.num_grad_bits_schedule) == len(args.schedule) + 1 581 | 582 | if schedule_cnt == 0: 583 | args.num_bits = args.num_bits_schedule[0] 584 | args.num_grad_bits = args.num_grad_bits_schedule[0] 585 | schedule_cnt += 1 586 | 587 | for step in args.schedule: 588 | if _iter == step: 589 | args.num_bits = args.num_bits_schedule[schedule_cnt] 590 | args.num_grad_bits = args.num_grad_bits_schedule[schedule_cnt] 591 | schedule_cnt += 1 592 | 593 | if _iter % args.eval_every == 0: 594 | logging.info('Iter [{}] num_bits = {} num_grad_bits = {}'.format(_iter, args.num_bits, args.num_grad_bits)) 595 | 596 | def adaptive_adjust_precision(args, turning_point_count): 597 | args.num_bits = args.num_bits_schedule[turning_point_count] 598 | args.num_grad_bits = args.num_grad_bits_schedule[turning_point_count] 599 | 600 | 601 | def adjust_learning_rate(args, optimizer, _iter): 602 | if args.lr_schedule == 'piecewise': 603 | if args.warm_up and (_iter < 400): 604 | lr = 0.01 605 | elif 32000 <= _iter < 48000: 606 | lr = args.lr * (args.step_ratio ** 1) 607 | elif _iter >= 48000: 608 | lr = args.lr * (args.step_ratio ** 2) 609 | else: 610 | lr = args.lr 611 | 612 | elif args.lr_schedule == 'linear': 613 | t = _iter / args.iters 614 | lr_ratio = 0.01 615 | if args.warm_up and (_iter < 400): 616 | lr = 0.01 617 | elif t < 0.5: 618 | lr = args.lr 619 | elif t < 0.9: 620 | lr = args.lr * (1 - (1-lr_ratio)*(t-0.5)/0.4) 621 | else: 622 | lr = args.lr * lr_ratio 623 | 624 | elif args.lr_schedule == 'anneal_cosine': 625 | lr_min = args.lr * (args.step_ratio ** 2) 626 | lr_max = args.lr 627 | lr = lr_min + 1/2 * (lr_max - lr_min) * (1 + np.cos(_iter/args.iters * 3.141592653)) 628 | 629 | if _iter % args.eval_every == 0: 630 | logging.info('Iter [{}] learning rate = {}'.format(_iter, lr)) 631 | 632 | for param_group in optimizer.param_groups: 633 | param_group['lr'] = lr 634 | 635 | 636 | def accuracy(output, target, topk=(1,)): 637 | """Computes the precision@k for the specified values of k""" 638 | maxk = max(topk) 639 | batch_size = target.size(0) 640 | 641 | _, pred = output.topk(maxk, 1, True, True) 642 | pred = pred.t() 643 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 644 | 645 | res = [] 646 | for k in topk: 647 | correct_k = 
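`accuracy` is the standard top-k recipe: take the k highest logits per sample, test them against the label, and rescale the hit count to a percentage. A toy batch makes the mechanics concrete:

```python
import torch

def topk_accuracy(output, target, topk=(1,)):
    """Functionally the same computation as accuracy() above."""
    maxk = max(topk)
    _, pred = output.topk(maxk, 1, True, True)   # (batch, maxk) class indices
    pred = pred.t()                              # (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum(0) * (100.0 / target.size(0))
            for k in topk]

logits = torch.tensor([[0.1, 0.7, 0.2],   # predicts class 1 -> hit
                       [0.6, 0.3, 0.1]])  # predicts class 0 -> miss (label 2)
target = torch.tensor([1, 2])
top1, top2 = topk_accuracy(logits, target, topk=(1, 2))
print(top1.item(), top2.item())           # 50.0 50.0
```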
correct[:k].view(-1).float().sum(0) 648 | res.append(correct_k.mul_(100.0 / batch_size)) 649 | return res 650 | 651 | 652 | if __name__ == '__main__': 653 | main() 654 | -------------------------------------------------------------------------------- /fractrain_cifar/util_swa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | def moving_average(net1, net2, alpha=1): 5 | for param1, param2 in zip(net1.parameters(), net2.parameters()): 6 | param1.data *= (1.0 - alpha) 7 | param1.data += param2.data * alpha 8 | 9 | 10 | def _check_bn(module, flag): 11 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 12 | flag[0] = True 13 | 14 | 15 | def check_bn(model): 16 | flag = [False] 17 | model.apply(lambda module: _check_bn(module, flag)) 18 | return flag[0] 19 | 20 | 21 | def reset_bn(module): 22 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 23 | module.running_mean = torch.zeros_like(module.running_mean) 24 | module.running_var = torch.ones_like(module.running_var) 25 | 26 | 27 | def _get_momenta(module, momenta): 28 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 29 | momenta[module] = module.momentum 30 | 31 | 32 | def _set_momenta(module, momenta): 33 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm): 34 | module.momentum = momenta[module] 35 | 36 | 37 | def bn_update(loader, model, num_bits, num_grad_bits): 38 | """ 39 | BatchNorm buffers update (if any). 40 | Performs one pass over the train dataset to estimate the buffer averages. 41 | :param loader: train dataset loader for buffers average estimation. 42 | :param model: model being updated 43 | :return: None 44 | """ 45 | if not check_bn(model): 46 | return 47 | model.train() 48 | momenta = {} 49 | model.apply(reset_bn) 50 | model.apply(lambda module: _get_momenta(module, momenta)) 51 | n = 0 52 | 53 | print("SWA Update BN...") 54 | for input, _ in loader: 55 | input = input.cuda(non_blocking=True)  # 'async' is reserved in Python 3.7+ 56 | input_var = torch.autograd.Variable(input) 57 | b = input_var.data.size(0) 58 | 59 | momentum = b / (n + b) 60 | for module in momenta.keys(): 61 | module.momentum = momentum 62 | 63 | model(input_var, num_bits, num_grad_bits) 64 | n += b 65 | 66 | model.apply(lambda module: _set_momenta(module, momenta)) -------------------------------------------------------------------------------- /fractrain_imagenet/data.py: -------------------------------------------------------------------------------- 1 | """prepare CIFAR, SVHN and ImageNet 2 | """ 3 | 4 | from __future__ import print_function 5 | 6 | import torch 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | 11 | 12 | crop_size = 32 13 | padding = 4 14 | 15 | 16 | def prepare_train_data(dataset='cifar10', datadir='/home/yf22/dataset', batch_size=128, 17 | shuffle=True, num_workers=4): 18 | 19 | if 'cifar' in dataset: 20 | transform_train = transforms.Compose([ 21 | transforms.RandomCrop(crop_size, padding=padding), 22 | transforms.RandomHorizontalFlip(), 23 | transforms.ToTensor(), 24 | transforms.Normalize((0.4914, 0.4822, 0.4465), 25 | (0.2023, 0.1994, 0.2010)), 26 | ]) 27 | 28 | trainset = torchvision.datasets.__dict__[dataset.upper()]( 29 | root=datadir, train=True, download=True, transform=transform_train) 30 | train_loader = torch.utils.data.DataLoader(trainset, 31 | batch_size=batch_size, 32 | shuffle=shuffle, 33 | num_workers=num_workers) 34 | 35 | elif 'imagenet' in dataset:  # elif, so the else below cannot clobber the cifar loader 36 | train_dataset = 
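The `momentum = b / (n + b)` line in `bn_update` deserves a comment: with that schedule, the exponential update `stat <- (1 - m) * stat + m * batch_stat` reduces to an exact running average over all batches seen, which is what the averaged SWA weights need their BatchNorm buffers to reflect. In scalar form (equal batch sizes assumed for this sketch):

```python
# Running average via per-batch momentum, as in bn_update above.
batch_means = [0.2, 0.6, 0.1]      # per-batch statistics
b, n, stat = 128, 0, 0.0
for m_batch in batch_means:
    momentum = b / (n + b)         # 1, 1/2, 1/3, ...
    stat = (1 - momentum) * stat + momentum * m_batch
    n += b
assert abs(stat - sum(batch_means) / len(batch_means)) < 1e-12   # exact mean
```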
torchvision.datasets.ImageFolder( 37 | datadir, 38 | transforms.Compose([ 39 | transforms.RandomResizedCrop(224), 40 | transforms.RandomHorizontalFlip(), 41 | transforms.ToTensor(), 42 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 43 | std=[0.229, 0.224, 0.225]) 44 | ])) 45 | train_loader = torch.utils.data.DataLoader( 46 | train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers) 47 | 48 | elif 'svhn' in dataset: 49 | transform_train = transforms.Compose([ 50 | transforms.ToTensor(), 51 | transforms.Normalize((0.4377, 0.4438, 0.4728), 52 | (0.1980, 0.2010, 0.1970)), 53 | ]) 54 | trainset = torchvision.datasets.__dict__[dataset.upper()]( 55 | root=datadir, 56 | split='train', 57 | download=True, 58 | transform=transform_train 59 | ) 60 | 61 | transform_extra = transforms.Compose([ 62 | transforms.ToTensor(), 63 | transforms.Normalize((0.4300, 0.4284, 0.4427), 64 | (0.1963, 0.1979, 0.1995)) 65 | ]) 66 | 67 | extraset = torchvision.datasets.__dict__[dataset.upper()]( 68 | root=datadir, 69 | split='extra', 70 | download=True, 71 | transform=transform_extra 72 | ) 73 | 74 | total_data = torch.utils.data.ConcatDataset([trainset, extraset]) 75 | 76 | train_loader = torch.utils.data.DataLoader(total_data, 77 | batch_size=batch_size, 78 | shuffle=shuffle, 79 | num_workers=num_workers) 80 | else: 81 | train_loader = None 82 | return train_loader 83 | 84 | 85 | def prepare_test_data(dataset='cifar10', datadir='/home/yf22/dataset', batch_size=128, 86 | shuffle=False, num_workers=4): 87 | 88 | if 'cifar' in dataset: 89 | transform_test = transforms.Compose([ 90 | transforms.ToTensor(), 91 | transforms.Normalize((0.4914, 0.4822, 0.4465), 92 | (0.2023, 0.1994, 0.2010)), 93 | ]) 94 | 95 | testset = torchvision.datasets.__dict__[dataset.upper()](root=datadir, 96 | train=False, 97 | download=True, 98 | transform=transform_test) 99 | test_loader = torch.utils.data.DataLoader(testset, 100 | batch_size=batch_size, 101 | shuffle=shuffle, 102 | num_workers=num_workers) 103 | 104 | elif 'imagenet' in dataset:  # elif, so the else below cannot clobber the cifar loader 105 | test_loader = torch.utils.data.DataLoader( 106 | torchvision.datasets.ImageFolder(datadir, transforms.Compose([ 107 | transforms.Resize(256), 108 | transforms.CenterCrop(224), 109 | transforms.ToTensor(), 110 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 111 | std=[0.229, 0.224, 0.225]) 112 | ])), 113 | batch_size=batch_size, shuffle=False, num_workers=num_workers) 114 | 115 | elif 'svhn' in dataset: 116 | transform_test = transforms.Compose([ 117 | transforms.ToTensor(), 118 | transforms.Normalize((0.4524, 0.4525, 0.4690), 119 | (0.2194, 0.2266, 0.2285)), 120 | ]) 121 | testset = torchvision.datasets.__dict__[dataset.upper()]( 122 | root=datadir, 123 | split='test', 124 | download=True, 125 | transform=transform_test) 126 | np.place(testset.labels, testset.labels == 10, 0) 127 | test_loader = torch.utils.data.DataLoader(testset, 128 | batch_size=batch_size, 129 | shuffle=shuffle, 130 | num_workers=num_workers) 131 | else: 132 | test_loader = None 133 | return test_loader 134 | -------------------------------------------------------------------------------- /fractrain_imagenet/models.py: -------------------------------------------------------------------------------- 1 | """ This file contains the model definitions for the RNN-gated (SkipNet-style) 2 | ResNet variants used on ImageNet; the original static ResNet is kept commented out. 
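Both loaders keep the same entry points as the CIFAR code (`prepare_train_data` / `prepare_test_data`) and dispatch on the dataset name, so caller code is identical across the two subdirectories. A usage sketch (the data directory is a placeholder, and the import assumes you run beside this module):

```python
from data import prepare_train_data, prepare_test_data  # this module

# CIFAR-10 downloads itself; 'imagenet' expects an ImageFolder directory tree.
train_loader = prepare_train_data(dataset='cifar10', datadir='./data',
                                  batch_size=128, num_workers=4)
test_loader = prepare_test_data(dataset='cifar10', datadir='./data',
                                batch_size=128, num_workers=4)

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)   # [128, 3, 32, 32], [128]
```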
3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import math 8 | from torch.autograd import Variable 9 | import torch.autograd as autograd 10 | from modules.quantize import quantize, quantize_grad, QConv2d, QLinear, RangeBN 11 | import torch.nn.functional as F 12 | 13 | 14 | ACT_FW = 0 15 | ACT_BW = 0 16 | GRAD_ACT_ERROR = 0 17 | GRAD_ACT_GC = 0 18 | WEIGHT_BITS = 0 19 | MOMENTUM = 0.9 20 | 21 | def Conv3x3(in_planes, out_planes, stride=1): 22 | "3x3 convolution with padding" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=1, bias=False) 25 | 26 | 27 | def conv3x3(in_planes, out_planes, stride=1, pool_size = None, fix_prec=False): 28 | "3x3 convolution with padding" 29 | return QConv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=1, bias=False, momentum=MOMENTUM, quant_act_forward=ACT_FW, quant_act_backward=ACT_BW, 31 | quant_grad_act_error=GRAD_ACT_ERROR, quant_grad_act_gc=GRAD_ACT_GC, weight_bits=WEIGHT_BITS, fix_prec=fix_prec) 32 | 33 | def conv1x1(in_planes, out_planes, stride=1, pool_size = None, padding=0, fix_prec=False): 34 | return QConv2d(in_planes, out_planes, kernel_size=1, stride=stride, 35 | padding=padding, bias=False, momentum=MOMENTUM, quant_act_forward=ACT_FW, quant_act_backward=ACT_BW, 36 | quant_grad_act_error=GRAD_ACT_ERROR, quant_grad_act_gc=GRAD_ACT_GC, weight_bits=WEIGHT_BITS, fix_prec=fix_prec) 37 | 38 | def make_bn(planes): 39 | return nn.BatchNorm2d(planes) 40 | # return RangeBN(planes) 41 | 42 | 43 | class BasicBlock(nn.Module): 44 | expansion = 1 45 | 46 | def __init__(self, inplanes, planes, stride=1, downsample=None, fix_prec=False): 47 | super(BasicBlock, self).__init__() 48 | self.conv1 = conv3x3(inplanes, planes, stride, fix_prec=fix_prec) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.conv2 = conv3x3(planes, planes, fix_prec=fix_prec) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x, num_bits, num_grad_bits, mask_list): 57 | residual = x 58 | 59 | out = self.conv1(x, num_bits, num_grad_bits, mask_list) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out, num_bits, num_grad_bits, mask_list) 64 | out = self.bn2(out) 65 | 66 | if self.downsample is not None: 67 | residual = self.downsample(x) 68 | 69 | out += residual 70 | out = self.relu(out) 71 | 72 | return out 73 | 74 | 75 | class Bottleneck(nn.Module): 76 | expansion = 4 77 | 78 | def __init__(self, inplanes, planes, stride=1, downsample=None, fix_prec=False): 79 | super(Bottleneck, self).__init__() 80 | self.conv1 = conv1x1(inplanes, planes, fix_prec=fix_prec) 81 | self.bn1 = nn.BatchNorm2d(planes) 82 | self.conv2 = conv3x3(planes, planes, stride=stride, fix_prec=fix_prec) 83 | self.bn2 = nn.BatchNorm2d(planes) 84 | self.conv3 = conv1x1(planes, planes * 4, fix_prec=fix_prec) 85 | self.bn3 = nn.BatchNorm2d(planes * 4) 86 | self.relu = nn.ReLU(inplace=True) 87 | self.downsample = downsample 88 | self.stride = stride 89 | 90 | def forward(self, x, num_bits, num_grad_bits, mask_list): 91 | residual = x 92 | 93 | out = self.conv1(x, num_bits, num_grad_bits, mask_list) 94 | out = self.bn1(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv2(out, num_bits, num_grad_bits, mask_list) 98 | out = self.bn2(out) 99 | out = self.relu(out) 100 | 101 | out = self.conv3(out, num_bits, num_grad_bits, mask_list) 102 | out = self.bn3(out) 103 | 104 | if self.downsample is not None: 105 | residual = 
self.downsample(x) 106 | 107 | out += residual 108 | out = self.relu(out) 109 | 110 | return out 111 | 112 | 113 | # class ResNet(nn.Module): 114 | # def __init__(self, block, layers, num_classes=1000): 115 | # self.inplanes = 64 116 | # super(ResNet, self).__init__() 117 | 118 | # self.num_layers = layers 119 | 120 | # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 121 | # bias=False) 122 | # self.bn1 = nn.BatchNorm2d(64) 123 | # self.relu = nn.ReLU(inplace=True) 124 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 125 | 126 | # self.layer1 = self._make_group(block, 64, layers[0], group_id=1) 127 | # self.layer2 = self._make_group(block, 128, layers[1], group_id=2) 128 | # self.layer3 = self._make_group(block, 256, layers[2], group_id=3) 129 | # self.layer4 = self._make_group(block, 512, layers[3], group_id=4) 130 | 131 | # self.avgpool = nn.AvgPool2d(7) 132 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 133 | 134 | # for m in self.modules(): 135 | # if isinstance(m, nn.Conv2d): 136 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 137 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 138 | # if isinstance(m, QConv2d): 139 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 140 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 141 | # elif isinstance(m, nn.BatchNorm2d): 142 | # m.weight.data.fill_(1) 143 | # m.bias.data.zero_() 144 | # elif isinstance(m, nn.Linear): 145 | # n = m.weight.size(0) * m.weight.size(1) 146 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 147 | 148 | 149 | # def _make_group(self, block, planes, layers, group_id): 150 | # """ Create the whole group""" 151 | # for i in range(layers): 152 | # if group_id > 1 and i == 0: 153 | # stride = 2 154 | # else: 155 | # stride = 1 156 | 157 | # layer = self._make_layer(block, planes, stride=stride) 158 | 159 | # setattr(self, 'group{}_layer{}'.format(group_id, i), layer) 160 | 161 | 162 | # def _make_layer(self, block, planes, stride=1): 163 | # downsample = None 164 | # if stride != 1 or self.inplanes != planes * block.expansion: 165 | # downsample = nn.Sequential( 166 | # nn.Conv2d(self.inplanes, planes * block.expansion, 167 | # kernel_size=1, stride=stride, bias=False), 168 | # nn.BatchNorm2d(planes * block.expansion), 169 | # ) 170 | 171 | # layer = block(self.inplanes, planes, stride, downsample, fix_prec=True) 172 | # self.inplanes = planes * block.expansion 173 | 174 | # return layer 175 | 176 | # def forward(self, x, num_bits, num_grad_bits): 177 | # x = self.conv1(x) 178 | # x = self.bn1(x) 179 | # x = self.relu(x) 180 | # x = self.maxpool(x) 181 | 182 | # for g in range(len(self.num_layers)): 183 | # for i in range(self.num_layers[g]): 184 | # x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x, num_bits, num_grad_bits) 185 | 186 | # x = self.avgpool(x) 187 | # x = x.view(x.size(0), -1) 188 | # x = self.fc(x) 189 | # return x 190 | 191 | 192 | # def resnet18(pretrained=False, **kwargs): 193 | # model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 194 | # return model 195 | 196 | 197 | # def resnet34(pretrained=False, **kwargs): 198 | # model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 199 | # return model 200 | 201 | 202 | # def resnet50(pretrained=False, **kwargs): 203 | # model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 204 | # return model 205 | 206 | 207 | 208 | class RNNGate(nn.Module): 209 | """Recurrent Gate definition. 
210 | Input is already passed through average pooling and embedding.""" 211 | def __init__(self, input_dim, hidden_dim, proj_dim, rnn_type='lstm'): 212 | super(RNNGate, self).__init__() 213 | self.rnn_type = rnn_type 214 | self.input_dim = input_dim 215 | self.hidden_dim = hidden_dim 216 | self.proj_dim = proj_dim 217 | 218 | if self.rnn_type == 'lstm': 219 | self.rnn_one = nn.LSTM(input_dim, hidden_dim) 220 | # self.rnn_two = nn.LSTM(hidden_dim, hidden_dim) 221 | else: 222 | self.rnn = None 223 | self.hidden_one = None 224 | # self.hidden_two = None 225 | 226 | # reduce dim 227 | self.proj = nn.Linear(hidden_dim, proj_dim) 228 | # self.proj_two = nn.Linear(hidden_dim, 4) 229 | self.prob = nn.Sigmoid() 230 | self.prob_layer = nn.Softmax() 231 | 232 | def init_hidden(self, batch_size): 233 | # The axes semantics are (num_layers, minibatch_size, hidden_dim) 234 | return (autograd.Variable(torch.zeros(1, batch_size, 235 | self.hidden_dim).cuda()), 236 | autograd.Variable(torch.zeros(1, batch_size, 237 | self.hidden_dim).cuda())) 238 | 239 | def repackage_hidden(self): 240 | self.hidden_one = repackage_hidden(self.hidden_one) 241 | # self.hidden_two = repackage_hidden(self.hidden_two) 242 | def forward(self, x): 243 | # Take the convolution output of each step 244 | batch_size = x.size(0) 245 | # self.rnn_one.flatten_parameters() 246 | # self.rnn_two.flatten_parameters() 247 | 248 | out_one, self.hidden_one = self.rnn_one(x.view(1, batch_size, -1), self.hidden_one) 249 | 250 | # out_one = F.dropout(out_one, p = 0.1, training=True) 251 | 252 | # out_two, self.hidden_two = self.rnn_two(out_one.view(1, batch_size, -1), self.hidden_two) 253 | 254 | x_one = self.proj(out_one.squeeze()) 255 | # x_two = self.proj_two(out_two.squeeze()) 256 | 257 | # proj = self.proj(out.squeeze()) 258 | prob = self.prob_layer(x_one) 259 | # prob_two = self.prob_layer(x_two) 260 | 261 | # x_one = (prob > 0.5).float().detach() - \ 262 | # prob.detach() + prob 263 | 264 | # x_two = prob_two.detach().cpu().numpy() 265 | 266 | x_one = prob.detach().cpu().numpy() 267 | 268 | hard = (x_one == x_one.max(axis=1)[:,None]).astype(int) 269 | hard = torch.from_numpy(hard) 270 | hard = hard.cuda() 271 | 272 | # x_two = hard.float().detach() - \ 273 | # prob_two.detach() + prob_two 274 | 275 | x_one = hard.float().detach() - \ 276 | prob.detach() + prob 277 | 278 | # print(x_one) 279 | 280 | x_one = x_one.view(x_one.size(0),x_one.size(1), 1, 1, 1) 281 | 282 | # x_two = x_two.view(x_two.size(0), x_two.size(1), 1, 1, 1) 283 | 284 | return x_one # , x_two 285 | 286 | 287 | 288 | class ResNet_RNN(nn.Module): 289 | def __init__(self, block, layers, num_classes=1000, embed_dim=40, hidden_dim=20, proj_dim=7): 290 | self.inplanes = 64 291 | super(ResNet_RNN, self).__init__() 292 | 293 | self.num_layers = layers 294 | self.embed_dim = embed_dim 295 | self.hidden_dim = hidden_dim 296 | 297 | self.control = RNNGate(embed_dim, hidden_dim, proj_dim, rnn_type='lstm') 298 | 299 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 300 | bias=False) 301 | self.bn1 = nn.BatchNorm2d(64) 302 | self.relu = nn.ReLU(inplace=True) 303 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 304 | self.gate_layer1 = nn.Sequential(nn.AvgPool2d(56), 305 | nn.Conv2d(in_channels=64, out_channels=self.embed_dim, kernel_size=1, stride=1)) 306 | 307 | self.layer1 = self._make_group(block, 64, layers[0], group_id=1, pool_size=56) 308 | self.layer2 = self._make_group(block, 128, layers[1], group_id=2, pool_size=28) 309 | self.layer3 = 
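The gate discretizes with the classic straight-through estimator: the numpy round-trip above builds a hard one-hot mask from the softmax, and `hard.float().detach() - prob.detach() + prob` makes the forward pass see the 0/1 mask while gradients flow into the soft probabilities (and hence into the LSTM). A self-contained check of the trick:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3, requires_grad=True)
prob = F.softmax(logits, dim=1)

# Hard one-hot argmax, equivalent to the numpy round-trip in RNNGate.forward.
hard = F.one_hot(prob.argmax(dim=1), num_classes=3).float()
mask = hard.detach() - prob.detach() + prob  # forward: hard, backward: soft

assert torch.equal(mask.detach(), hard)                  # forward sees 0/1
(mask * torch.tensor([0.0, 1.0, 2.0])).sum().backward()  # toy downstream cost
assert logits.grad is not None and logits.grad.abs().sum() > 0
```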
self._make_group(block, 256, layers[2], group_id=3, pool_size=14) 310 | self.layer4 = self._make_group(block, 512, layers[3], group_id=4, pool_size=7) 311 | 312 | self.avgpool = nn.AvgPool2d(7) 313 | self.fc = nn.Linear(512 * block.expansion, num_classes) 314 | 315 | for m in self.modules(): 316 | if isinstance(m, nn.Conv2d): 317 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 318 | m.weight.data.normal_(0, math.sqrt(2. / n)) 319 | if isinstance(m, QConv2d): 320 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 321 | m.weight.data.normal_(0, math.sqrt(2. / n)) 322 | elif isinstance(m, nn.BatchNorm2d): 323 | m.weight.data.fill_(1) 324 | m.bias.data.zero_() 325 | elif isinstance(m, nn.Linear): 326 | n = m.weight.size(0) * m.weight.size(1) 327 | m.weight.data.normal_(0, math.sqrt(2. / n)) 328 | 329 | 330 | def _make_group(self, block, planes, layers, group_id, pool_size): 331 | """ Create the whole group""" 332 | for i in range(layers): 333 | if group_id > 1 and i == 0: 334 | stride = 2 335 | else: 336 | stride = 1 337 | 338 | layer, gate_layer = self._make_layer(block, planes, stride=stride, pool_size=pool_size) 339 | 340 | setattr(self, 'group{}_layer{}'.format(group_id, i), layer) 341 | setattr(self, 'group{}_gate{}'.format(group_id, i), gate_layer) 342 | 343 | 344 | def _make_layer(self, block, planes, pool_size, stride=1): 345 | downsample = None 346 | if stride != 1 or self.inplanes != planes * block.expansion: 347 | downsample = nn.Sequential( 348 | nn.Conv2d(self.inplanes, planes * block.expansion, 349 | kernel_size=1, stride=stride, bias=False), 350 | nn.BatchNorm2d(planes * block.expansion), 351 | ) 352 | 353 | layer = block(self.inplanes, planes, stride, downsample) 354 | self.inplanes = planes * block.expansion 355 | 356 | gate_layer = nn.Sequential( 357 | nn.AvgPool2d(pool_size), 358 | nn.Conv2d(in_channels=planes * block.expansion, 359 | out_channels=self.embed_dim, 360 | kernel_size=1, 361 | stride=1)) 362 | 363 | return layer, gate_layer 364 | 365 | def forward(self, x, bits, grad_bits): 366 | x = self.conv1(x) 367 | x = self.bn1(x) 368 | x = self.relu(x) 369 | x = self.maxpool(x) 370 | 371 | batch_size = x.size(0) 372 | self.control.hidden_one = self.control.init_hidden(batch_size) 373 | 374 | masks = [] 375 | 376 | gate_feature = self.gate_layer1(x) 377 | mask = self.control(gate_feature) 378 | 379 | for g in range(len(self.num_layers)): 380 | for i in range(self.num_layers[g]): 381 | 382 | mask_list = [] 383 | 384 | for j in range(len(bits)): 385 | mask_list.append(mask[:,j,:,:,:]) 386 | 387 | x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x, bits, grad_bits, mask_list) 388 | 389 | mask_list = [mask.squeeze() for mask in mask_list] 390 | 391 | masks.append(mask_list) 392 | 393 | gate_feature = getattr(self, 'group{}_gate{}'.format(g+1, i))(x) 394 | mask = self.control(gate_feature) 395 | # mask_grad = self.control_grad(gate_feature) 396 | 397 | x = self.avgpool(x) 398 | x = x.view(x.size(0), -1) 399 | x = self.fc(x) 400 | 401 | return x, masks 402 | 403 | 404 | def resnet18_rnn(pretrained=False, **kwargs): 405 | model = ResNet_RNN(BasicBlock, [2, 2, 2, 2], **kwargs) 406 | return model 407 | 408 | 409 | def resnet34_rnn(pretrained=False, **kwargs): 410 | model = ResNet_RNN(BasicBlock, [3, 4, 6, 3], **kwargs) 411 | return model 412 | 413 | 414 | def resnet50_rnn(pretrained=False, **kwargs): 415 | model = ResNet_RNN(Bottleneck, [3, 4, 6, 3], **kwargs) 416 | return model 417 | 418 | 419 | 420 | if __name__ == '__main__': 421 | model = resnet18() 422 | 
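Two remarks on the `forward` wiring. First, the `__main__` probe just above calls `resnet18()`, which only exists in the commented-out static-ResNet section; `resnet18_rnn()` is the constructor this file actually exports. Second, the gate emits `mask` of shape `(batch, n_candidates, 1, 1, 1)`, one softmax-selected channel per candidate bit-width in `bits`, and the loop slices it into `mask_list` before each block; how `QConv2d` consumes the per-sample masks is internal to `modules/quantize.py`, but the slicing itself is just:

```python
import torch

batch, bits = 4, [4, 6, 8]           # candidate bit-widths (illustrative)
mask = torch.zeros(batch, len(bits), 1, 1, 1)
choice = torch.randint(len(bits), (batch,))
mask[torch.arange(batch), choice] = 1.0          # one-hot gate decision

# The same slicing as in forward(): one (batch, 1, 1, 1) mask per bit-width.
mask_list = [mask[:, j, :, :, :] for j in range(len(bits))]
print([tuple(m.shape) for m in mask_list])       # 3 x (4, 1, 1, 1)
```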
from thop import profile 423 | flops, params = profile(model, inputs=(torch.randn(1, 3, 256, 256),)) 424 | print('flops:', flops, 'params:', params) 425 | 426 | 427 | 428 | 429 | -------------------------------------------------------------------------------- /fractrain_imagenet/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/fractrain_imagenet/modules/__init__.py -------------------------------------------------------------------------------- /fractrain_imagenet/modules/bwn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Bounded weight norm 3 | Weight Normalization from https://arxiv.org/abs/1602.07868 4 | taken and adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/weight_norm.py 5 | """ 6 | import torch 7 | from torch.nn.parameter import Parameter 8 | from torch.autograd import Variable, Function 9 | import torch.nn as nn 10 | 11 | 12 | def gather_params(self, memo=None, param_func=lambda s: s._parameters.values()): 13 | if memo is None: 14 | memo = set() 15 | for p in param_func(self): 16 | if p is not None and p not in memo: 17 | memo.add(p) 18 | yield p 19 | for m in self.children(): 20 | for p in gather_params(m, memo, param_func): 21 | yield p 22 | 23 | nn.Module.gather_params = gather_params 24 | 25 | 26 | def _norm(x, dim, p=2): 27 | """Computes the norm over all dimensions except dim""" 28 | if p == float('inf'): # infinity norm 29 | func = lambda x, dim: x.abs().max(dim=dim)[0] 30 | else: 31 | func = lambda x, dim: torch.norm(x, dim=dim, p=p) 32 | if dim is None: 33 | return x.norm(p=p) 34 | elif dim == 0: 35 | output_size = (x.size(0),) + (1,) * (x.dim() - 1) 36 | return func(x.contiguous().view(x.size(0), -1), 1).view(*output_size) 37 | elif dim == x.dim() - 1: 38 | output_size = (1,) * (x.dim() - 1) + (x.size(-1),) 39 | return func(x.contiguous().view(-1, x.size(-1)), 0).view(*output_size) 40 | else: 41 | return _norm(x.transpose(0, dim), 0).transpose(0, dim) 42 | 43 | 44 | def _mean(p, dim): 45 | """Computes the mean over all dimensions except dim""" 46 | if dim is None: 47 | return p.mean() 48 | elif dim == 0: 49 | output_size = (p.size(0),) + (1,) * (p.dim() - 1) 50 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size) 51 | elif dim == p.dim() - 1: 52 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),) 53 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size) 54 | else: 55 | return _mean(p.transpose(0, dim), 0).transpose(0, dim) 56 | 57 | 58 | class BoundedWeighNorm(object): 59 | 60 | def __init__(self, name, dim, p): 61 | self.name = name 62 | self.dim = dim 63 | self.p = p 64 | 65 | def compute_weight(self, module): 66 | v = getattr(module, self.name + '_v') 67 | pre_norm = getattr(module, self.name + '_prenorm') 68 | return v * (pre_norm / _norm(v, self.dim, p=self.p)) 69 | 70 | @staticmethod 71 | def apply(module, name, dim, p): 72 | fn = BoundedWeighNorm(name, dim, p) 73 | 74 | weight = getattr(module, name) 75 | 76 | # remove w from parameter list 77 | del module._parameters[name] 78 | 79 | prenorm = _norm(weight, dim, p=p).mean() 80 | module.register_buffer(name + '_prenorm', prenorm.detach()) 81 | pre_norm = getattr(module, name + '_prenorm') 82 | print(pre_norm) 83 | module.register_parameter(name + '_v', Parameter(weight.data)) 84 | setattr(module, name, fn.compute_weight(module)) 85 | 86 | # 
recompute weight before every forward()
87 | module.register_forward_pre_hook(fn)
88 | 
89 | def gather_normed_params(self, memo=None, param_func=lambda s: fn.compute_weight(s)):
90 | return gather_params(self, memo, param_func)
91 | module.gather_params = gather_normed_params
92 | return fn
93 | 
94 | def remove(self, module):
95 | weight = self.compute_weight(module)
96 | delattr(module, self.name)
97 | del module._buffers[self.name + '_prenorm']  # '_prenorm' was registered via register_buffer, so it lives in _buffers, not _parameters
98 | del module._parameters[self.name + '_v']
99 | module.register_parameter(self.name, Parameter(weight.data))
100 | 
101 | def __call__(self, module, inputs):
102 | setattr(module, self.name, self.compute_weight(module))
103 | 
104 | 
105 | def weight_norm(module, name='weight', dim=0, p=2):
106 | r"""Applies bounded weight normalization to a parameter in the given module.
107 | 
108 | .. math::
109 | \mathbf{w} = \text{prenorm} \cdot \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
110 | 
111 | Weight normalization is a reparameterization that decouples the magnitude
112 | of a weight tensor from its direction. This bounded variant replaces the
113 | parameter specified by `name` (e.g. "weight") with a direction parameter
114 | (e.g. "weight_v") and a fixed buffer holding the initial mean norm
115 | (e.g. "weight_prenorm"). A forward pre-hook recomputes the weight from
116 | the stored norm and the direction before every :meth:`~Module.forward`
117 | call, so the weight magnitude stays bounded at its initial value.
118 | 
119 | By default, with `dim=0`, the norm is computed independently per output
120 | channel/plane. To compute a norm over the entire weight tensor, use
121 | `dim=None`.
122 | 
123 | See https://arxiv.org/abs/1602.07868
124 | 
125 | Args:
126 | module (nn.Module): containing module
127 | name (str, optional): name of weight parameter
128 | dim (int, optional): dimension over which to compute the norm
129 | 
130 | Returns:
131 | The original module with the weight norm hook
132 | 
133 | Example::
134 | 
135 | >>> m = weight_norm(nn.Linear(20, 40), name='weight')
136 | Linear (20 -> 40)
137 | >>> m.weight_prenorm.size()
138 | torch.Size([])
139 | >>> m.weight_v.size()
140 | torch.Size([40, 20])
141 | 
142 | """
143 | BoundedWeighNorm.apply(module, name, dim, p)
144 | return module
145 | 
146 | 
147 | def remove_weight_norm(module, name='weight'):
148 | r"""Removes the weight normalization reparameterization from a module.
149 | 150 | Args: 151 | module (nn.Module): containing module 152 | name (str, optional): name of weight parameter 153 | 154 | Example: 155 | >>> m = weight_norm(nn.Linear(20, 40)) 156 | >>> remove_weight_norm(m) 157 | """ 158 | for k, hook in module._forward_pre_hooks.items(): 159 | if isinstance(hook, BoundedWeighNorm) and hook.name == name: 160 | hook.remove(module) 161 | del module._forward_pre_hooks[k] 162 | return module 163 | 164 | raise ValueError("weight_norm of '{}' not found in {}" 165 | .format(name, module)) 166 | -------------------------------------------------------------------------------- /fractrain_imagenet/modules/quantize.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd.function import InplaceFunction, Function 7 | 8 | QParams = namedtuple('QParams', ['range', 'zero_point', 'num_bits']) 9 | 10 | _DEFAULT_FLATTEN = (1, -1) 11 | _DEFAULT_FLATTEN_GRAD = (0, -1) 12 | 13 | 14 | def _deflatten_as(x, x_full): 15 | shape = list(x.shape) + [1] * (x_full.dim() - x.dim()) 16 | return x.view(*shape) 17 | 18 | 19 | def calculate_qparams(x, num_bits, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, reduce_type='mean', keepdim=False, true_zero=False): 20 | with torch.no_grad(): 21 | x_flat = x.flatten(*flatten_dims) 22 | if x_flat.dim() == 1: 23 | min_values = _deflatten_as(x_flat.min(), x) 24 | max_values = _deflatten_as(x_flat.max(), x) 25 | else: 26 | min_values = _deflatten_as(x_flat.min(-1)[0], x) 27 | max_values = _deflatten_as(x_flat.max(-1)[0], x) 28 | if reduce_dim is not None: 29 | if reduce_type == 'mean': 30 | min_values = min_values.mean(reduce_dim, keepdim=keepdim) 31 | max_values = max_values.mean(reduce_dim, keepdim=keepdim) 32 | else: 33 | min_values = min_values.min(reduce_dim, keepdim=keepdim)[0] 34 | max_values = max_values.max(reduce_dim, keepdim=keepdim)[0] 35 | # TODO: re-add true zero computation 36 | range_values = max_values - min_values 37 | return QParams(range=range_values, zero_point=min_values, 38 | num_bits=num_bits) 39 | 40 | 41 | class UniformQuantize(InplaceFunction): 42 | 43 | @staticmethod 44 | def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN, 45 | reduce_dim=0, dequantize=True, signed=False, stochastic=False, inplace=False): 46 | 47 | ctx.inplace = inplace 48 | 49 | if ctx.inplace: 50 | ctx.mark_dirty(input) 51 | output = input 52 | else: 53 | output = input.clone() 54 | 55 | if qparams is None: 56 | assert num_bits is not None, "either provide qparams of num_bits to quantize" 57 | qparams = calculate_qparams( 58 | input, num_bits=num_bits, flatten_dims=flatten_dims, reduce_dim=reduce_dim) 59 | 60 | zero_point = qparams.zero_point 61 | num_bits = qparams.num_bits 62 | qmin = -(2.**(num_bits - 1)) if signed else 0. 63 | qmax = qmin + 2.**num_bits - 1. 
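# Worked example of the grid set up above: for unsigned num_bits=8 this
# gives qmin = 0 and qmax = 255; for signed num_bits=8, qmin = -128 and
# qmax = 127. The scale computed next maps the measured dynamic range
# onto that grid, scale = range / (qmax - qmin), e.g. a range of 2.0 at
# 8 unsigned bits yields a step size of 2.0 / 255 ≈ 0.00784.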
64 | scale = qparams.range / (qmax - qmin) 65 | 66 | min_scale = torch.tensor(1e-8).expand_as(scale).cuda() 67 | scale = torch.max(scale, min_scale) 68 | 69 | with torch.no_grad(): 70 | output.add_(qmin * scale - zero_point).div_(scale) 71 | if stochastic: 72 | noise = output.new(output.shape).uniform_(-0.5, 0.5) 73 | output.add_(noise) 74 | # quantize 75 | output.clamp_(qmin, qmax).round_() 76 | 77 | if dequantize: 78 | output.mul_(scale).add_( 79 | zero_point - qmin * scale) # dequantize 80 | return output 81 | 82 | @staticmethod 83 | def backward(ctx, grad_output): 84 | # straight-through estimator 85 | grad_input = grad_output 86 | return grad_input, None, None, None, None, None, None, None, None 87 | 88 | 89 | class UniformQuantizeGrad(InplaceFunction): 90 | 91 | @staticmethod 92 | def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD, 93 | reduce_dim=0, dequantize=True, signed=False, stochastic=True): 94 | ctx.num_bits = num_bits 95 | ctx.qparams = qparams 96 | ctx.flatten_dims = flatten_dims 97 | ctx.stochastic = stochastic 98 | ctx.signed = signed 99 | ctx.dequantize = dequantize 100 | ctx.reduce_dim = reduce_dim 101 | ctx.inplace = False 102 | return input 103 | 104 | @staticmethod 105 | def backward(ctx, grad_output): 106 | qparams = ctx.qparams 107 | with torch.no_grad(): 108 | if qparams is None: 109 | assert ctx.num_bits is not None, "either provide qparams of num_bits to quantize" 110 | qparams = calculate_qparams( 111 | grad_output, num_bits=ctx.num_bits, flatten_dims=ctx.flatten_dims, reduce_dim=ctx.reduce_dim, reduce_type='extreme') 112 | 113 | grad_input = quantize(grad_output, num_bits=None, 114 | qparams=qparams, flatten_dims=ctx.flatten_dims, reduce_dim=ctx.reduce_dim, 115 | dequantize=True, signed=ctx.signed, stochastic=ctx.stochastic, inplace=False) 116 | return grad_input, None, None, None, None, None, None, None 117 | 118 | 119 | def conv2d_biprec(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, num_bits_grad=None): 120 | out1 = F.conv2d(input.detach(), weight, bias, 121 | stride, padding, dilation, groups) 122 | out2 = F.conv2d(input, weight.detach(), bias.detach() if bias is not None else None, 123 | stride, padding, dilation, groups) 124 | out2 = quantize_grad(out2, num_bits=num_bits_grad, flatten_dims=(1, -1)) 125 | return out1 + out2 - out1.detach() 126 | 127 | 128 | def linear_biprec(input, weight, bias=None, num_bits_grad=None): 129 | out1 = F.linear(input.detach(), weight, bias) 130 | out2 = F.linear(input, weight.detach(), bias.detach() 131 | if bias is not None else None) 132 | out2 = quantize_grad(out2, num_bits=num_bits_grad) 133 | return out1 + out2 - out1.detach() 134 | 135 | 136 | def quantize(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, dequantize=True, signed=False, stochastic=False, inplace=False): 137 | if qparams: 138 | if qparams.num_bits: 139 | return UniformQuantize().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic, inplace) 140 | elif num_bits: 141 | return UniformQuantize().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic, inplace) 142 | 143 | return x 144 | 145 | 146 | def quantize_grad(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD, reduce_dim=0, dequantize=True, signed=False, stochastic=True): 147 | if qparams: 148 | if qparams.num_bits: 149 | return UniformQuantizeGrad().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic) 150 | 
elif num_bits: 151 | return UniformQuantizeGrad().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic) 152 | 153 | return x 154 | 155 | 156 | class QuantMeasure(nn.Module): 157 | """docstring for QuantMeasure.""" 158 | 159 | def __init__(self, shape_measure=(1,), flatten_dims=_DEFAULT_FLATTEN, 160 | inplace=False, dequantize=True, stochastic=False, momentum=0.9, measure=False): 161 | super(QuantMeasure, self).__init__() 162 | self.register_buffer('running_zero_point', torch.zeros(*shape_measure)) 163 | self.register_buffer('running_range', torch.zeros(*shape_measure)) 164 | self.measure = measure 165 | if self.measure: 166 | self.register_buffer('num_measured', torch.zeros(1)) 167 | self.flatten_dims = flatten_dims 168 | self.momentum = momentum 169 | self.dequantize = dequantize 170 | self.stochastic = stochastic 171 | self.inplace = inplace 172 | 173 | def forward(self, input, num_bits, qparams=None): 174 | 175 | if self.training or self.measure: 176 | if qparams is None: 177 | qparams = calculate_qparams( 178 | input, num_bits=num_bits, flatten_dims=self.flatten_dims, reduce_dim=0, reduce_type='extreme') 179 | with torch.no_grad(): 180 | if self.measure: 181 | momentum = self.num_measured / (self.num_measured + 1) 182 | self.num_measured += 1 183 | else: 184 | momentum = self.momentum 185 | self.running_zero_point.mul_(momentum).add_( 186 | qparams.zero_point * (1 - momentum)) 187 | self.running_range.mul_(momentum).add_( 188 | qparams.range * (1 - momentum)) 189 | else: 190 | qparams = QParams(range=self.running_range, 191 | zero_point=self.running_zero_point, num_bits=num_bits) 192 | if self.measure: 193 | return input 194 | else: 195 | q_input = quantize(input, qparams=qparams, dequantize=self.dequantize, 196 | stochastic=self.stochastic, inplace=self.inplace) 197 | return q_input 198 | 199 | 200 | class QConv2d(nn.Conv2d): 201 | """docstring for QConv2d.""" 202 | 203 | def __init__(self, in_channels, out_channels, kernel_size, 204 | stride=1, padding=0, dilation=1, groups=1, bias=True, momentum=0.1, quant_act_forward=0, quant_act_backward=0, quant_grad_act_error=0, quant_grad_act_gc=0, weight_bits=0, fix_prec=False): 205 | super(QConv2d, self).__init__(in_channels, out_channels, kernel_size, 206 | stride, padding, dilation, groups, bias) 207 | 208 | self.quantize_input_fw = QuantMeasure(shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1), momentum=momentum) 209 | self.quantize_input_bw = QuantMeasure(shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1), momentum=momentum) 210 | self.quant_act_forward = quant_act_forward 211 | self.quant_act_backward = quant_act_backward 212 | self.quant_grad_act_error = quant_grad_act_error 213 | self.quant_grad_act_gc = quant_grad_act_gc 214 | self.weight_bits = weight_bits 215 | self.fix_prec = fix_prec 216 | self.stride = stride 217 | 218 | 219 | def forward(self, input, num_bits, num_grad_bits, mask_list): 220 | 221 | input_candidates = [self.quantize_input_fw(input, num_bits=bit) for bit in num_bits] 222 | x = sum([mask_list[k].expand_as(input) * input_candidates[k] for k in range(len(num_bits))]) 223 | 224 | weight_qparams = calculate_qparams(self.weight, num_bits=self.weight_bits, flatten_dims=(1, -1), reduce_dim=None) 225 | qweight = quantize(self.weight, qparams=weight_qparams) 226 | qbias = None 227 | 228 | # qinput = self.quantize_input_fw(input, num_bits) 229 | output = F.conv2d(x, qweight, qbias, self.stride, self.padding, self.dilation, self.groups) 230 | 231 | output_candidates = [quantize_grad(output, 
num_bits=bit) for bit in num_grad_bits] 232 | x = sum([mask_list[k].expand_as(output).detach() * output_candidates[k] for k in range(len(num_grad_bits))]) 233 | 234 | return x 235 | 236 | 237 | def conv2d_quant_act(self, input_fw, input_bw, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, error_bits=0, gc_bits=0): 238 | out1 = F.conv2d(input_fw, weight.detach(), bias.detach() if bias is not None else None, 239 | stride, padding, dilation, groups) 240 | out2 = F.conv2d(input_bw.detach(), weight, bias, 241 | stride, padding, dilation, groups) 242 | out1 = quantize_grad(out1, num_bits=error_bits) 243 | out2 = quantize_grad(out2, num_bits=gc_bits) 244 | return out1 + out2 - out2.detach() 245 | 246 | 247 | class QLinear(nn.Linear): 248 | """docstring for QConv2d.""" 249 | 250 | def __init__(self, in_features, out_features, bias=True, num_bits=8, num_bits_weight=8, num_bits_grad=8, biprecision=True): 251 | super(QLinear, self).__init__(in_features, out_features, bias) 252 | self.num_bits = num_bits 253 | self.num_bits_weight = num_bits_weight or num_bits 254 | self.num_bits_grad = num_bits_grad 255 | self.biprecision = biprecision 256 | self.quantize_input = QuantMeasure(self.num_bits) 257 | 258 | def forward(self, input): 259 | qinput = self.quantize_input(input) 260 | weight_qparams = calculate_qparams( 261 | self.weight, num_bits=self.num_bits_weight, flatten_dims=(1, -1), reduce_dim=None) 262 | qweight = quantize(self.weight, qparams=weight_qparams) 263 | if self.bias is not None: 264 | qbias = quantize( 265 | self.bias, num_bits=self.num_bits_weight + self.num_bits, 266 | flatten_dims=(0, -1)) 267 | else: 268 | qbias = None 269 | 270 | if not self.biprecision or self.num_bits_grad is None: 271 | output = F.linear(qinput, qweight, qbias) 272 | if self.num_bits_grad is not None: 273 | output = quantize_grad( 274 | output, num_bits=self.num_bits_grad) 275 | else: 276 | output = linear_biprec(qinput, qweight, qbias, self.num_bits_grad) 277 | return output 278 | 279 | 280 | class RangeBN(nn.Module): 281 | # this is normalized RangeBN 282 | 283 | def __init__(self, num_features, dim=1, momentum=0.1, affine=True, num_chunks=16, eps=1e-5, num_bits=8, num_bits_grad=8): 284 | super(RangeBN, self).__init__() 285 | self.register_buffer('running_mean', torch.zeros(num_features)) 286 | self.register_buffer('running_var', torch.zeros(num_features)) 287 | 288 | self.momentum = momentum 289 | self.dim = dim 290 | if affine: 291 | self.bias = nn.Parameter(torch.Tensor(num_features)) 292 | self.weight = nn.Parameter(torch.Tensor(num_features)) 293 | self.num_bits = num_bits 294 | self.num_bits_grad = num_bits_grad 295 | self.quantize_input = QuantMeasure(inplace=True, shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1)) 296 | self.eps = eps 297 | self.num_chunks = num_chunks 298 | self.reset_params() 299 | 300 | def reset_params(self): 301 | if self.weight is not None: 302 | self.weight.data.uniform_() 303 | if self.bias is not None: 304 | self.bias.data.zero_() 305 | 306 | def forward(self, x, num_bits, num_grad_bits): 307 | x = self.quantize_input(x, num_bits) 308 | if x.dim() == 2: # 1d 309 | x = x.unsqueeze(-1,).unsqueeze(-1) 310 | 311 | if self.training: 312 | B, C, H, W = x.shape 313 | y = x.transpose(0, 1).contiguous() # C x B x H x W 314 | y = y.view(C, self.num_chunks, (B * H * W) // self.num_chunks) 315 | mean_max = y.max(-1)[0].mean(-1) # C 316 | mean_min = y.min(-1)[0].mean(-1) # C 317 | mean = y.view(C, -1).mean(-1) # C 318 | scale_fix = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) ** 
319 | 0.5) / ((2 * math.log(y.size(-1))) ** 0.5) 320 | 321 | scale = (mean_max - mean_min) * scale_fix 322 | with torch.no_grad(): 323 | self.running_mean.mul_(self.momentum).add_( 324 | mean * (1 - self.momentum)) 325 | 326 | self.running_var.mul_(self.momentum).add_( 327 | scale * (1 - self.momentum)) 328 | else: 329 | mean = self.running_mean 330 | scale = self.running_var 331 | # scale = quantize(scale, num_bits=self.num_bits, min_value=float( 332 | # scale.min()), max_value=float(scale.max())) 333 | out = (x - mean.view(1, -1, 1, 1)) / \ 334 | (scale.view(1, -1, 1, 1) + self.eps) 335 | 336 | if self.weight is not None: 337 | qweight = self.weight 338 | # qweight = quantize(self.weight, num_bits=self.num_bits, 339 | # min_value=float(self.weight.min()), 340 | # max_value=float(self.weight.max())) 341 | out = out * qweight.view(1, -1, 1, 1) 342 | 343 | if self.bias is not None: 344 | qbias = self.bias 345 | # qbias = quantize(self.bias, num_bits=self.num_bits) 346 | out = out + qbias.view(1, -1, 1, 1) 347 | if num_grad_bits: 348 | out = quantize_grad( 349 | out, num_bits=num_grad_bits, flatten_dims=(1, -1)) 350 | 351 | if out.size(3) == 1 and out.size(2) == 1: 352 | out = out.squeeze(-1).squeeze(-1) 353 | return out 354 | 355 | 356 | class RangeBN1d(RangeBN): 357 | # this is normalized RangeBN 358 | 359 | def __init__(self, num_features, dim=1, momentum=0.1, affine=True, num_chunks=16, eps=1e-5, num_bits=8, num_bits_grad=8): 360 | super(RangeBN1d, self).__init__(num_features, dim, momentum, 361 | affine, num_chunks, eps, num_bits, num_bits_grad) 362 | self.quantize_input = QuantMeasure( 363 | self.num_bits, inplace=True, shape_measure=(1, 1), flatten_dims=(1, -1)) 364 | 365 | if __name__ == '__main__': 366 | x = torch.rand(2, 3) 367 | x_q = quantize(x, flatten_dims=(-1), num_bits=8, dequantize=True) 368 | print(x) 369 | print(x_q) -------------------------------------------------------------------------------- /fractrain_imagenet/modules/rnlu.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.autograd.function import InplaceFunction 4 | from torch.autograd import Variable 5 | import torch.nn as nn 6 | import math 7 | 8 | 9 | class BiReLUFunction(InplaceFunction): 10 | 11 | @classmethod 12 | def forward(cls, ctx, input, inplace=False): 13 | if input.size(1) % 2 != 0: 14 | raise RuntimeError("dimension 1 of input must be multiple of 2, " 15 | "but got {}".format(input.size(1))) 16 | ctx.inplace = inplace 17 | 18 | if ctx.inplace: 19 | ctx.mark_dirty(input) 20 | output = input 21 | else: 22 | output = input.clone() 23 | 24 | pos, neg = output.chunk(2, dim=1) 25 | pos.clamp_(min=0) 26 | neg.clamp_(max=0) 27 | # scale = (pos - neg).view(pos.size(0), -1).mean(1).div_(2) 28 | # output. 
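# The clamped activations are saved so that backward() can zero the
# incoming gradient wherever a half-channel was clipped (output == 0),
# i.e. the usual ReLU derivative applied to both the positive and the
# negative halves.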
29 | ctx.save_for_backward(output)
30 | return output
31 | 
32 | @staticmethod
33 | def backward(ctx, grad_output):
34 | output, = ctx.saved_tensors  # saved_tensors is the current autograd API (saved_variables is deprecated)
35 | grad_input = grad_output.masked_fill(output.eq(0), 0)
36 | return grad_input, None
37 | 
38 | 
39 | def birelu(x, inplace=False):
40 | return BiReLUFunction().apply(x, inplace)
41 | 
42 | 
43 | class BiReLU(nn.Module):
44 | """Bipolar ReLU: clamps the first half of the channels at min 0 and the second half at max 0."""
45 | 
46 | def __init__(self, inplace=False):
47 | super(BiReLU, self).__init__()
48 | self.inplace = inplace
49 | 
50 | def forward(self, inputs):
51 | return birelu(inputs, inplace=self.inplace)
52 | 
53 | 
54 | def binorm(x, shift=0, scale_fix=(2 / math.pi) ** 0.5):
55 | pos, neg = (x + shift).chunk(2, dim=1)  # chunk into two halves, as in rnlu(); split(2, dim=1) would yield chunks of size 2
56 | scale = (pos - neg).view(pos.size(0), -1).mean(1).div_(2) * scale_fix
57 | return x / scale.view(scale.size(0), *([1] * (x.dim() - 1)))  # reshape the per-sample scale for broadcasting
58 | 
59 | 
60 | def _mean(p, dim):
61 | """Computes the mean over all dimensions except dim"""
62 | if dim is None:
63 | return p.mean()
64 | elif dim == 0:
65 | output_size = (p.size(0),) + (1,) * (p.dim() - 1)
66 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size)
67 | elif dim == p.dim() - 1:
68 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
69 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size)
70 | else:
71 | return _mean(p.transpose(0, dim), 0).transpose(0, dim)
72 | 
73 | 
74 | def rnlu(x, inplace=False, shift=0, scale_fix=(math.pi / 2) ** 0.5):
75 | x = birelu(x, inplace=inplace)
76 | pos, neg = (x + shift).chunk(2, dim=1)
77 | # scale = torch.cat((_mean(pos, 1), -_mean(neg, 1)), 1) * scale_fix + 1e-5
78 | scale = (pos - neg).view(pos.size(0), -1).mean(1) * scale_fix + 1e-8
79 | return x / scale.view(scale.size(0), *([1] * (x.dim() - 1)))
80 | 
81 | 
82 | class RnLU(nn.Module):
83 | """BiReLU followed by per-sample range normalization (see rnlu)."""
84 | 
85 | def __init__(self, inplace=False):
86 | super(RnLU, self).__init__()
87 | self.inplace = inplace
88 | 
89 | def forward(self, x):
90 | return rnlu(x, inplace=self.inplace)
91 | 
92 | 
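# Usage sketch (illustrative, assuming an even channel count): for x of
# shape (N, 2C, H, W), birelu() treats channels [0:C] like a ReLU
# (clamped at min 0) and channels [C:2C] like a negated ReLU (clamped at
# max 0); rnlu() additionally rescales each sample by the mean gap
# between the two halves:
#
#   x = torch.randn(4, 16, 5, 5)
#   y = birelu(x)   # first 8 channels >= 0, last 8 channels <= 0
#   z = rnlu(x)     # birelu output, normalized per sample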
93 | if __name__ == "__main__":
94 | x = Variable(torch.randn(2, 16, 5, 5).cuda(), requires_grad=True)
95 | output = rnlu(x)
96 | 
97 | output.sum().backward()
98 | 
--------------------------------------------------------------------------------
/fractrain_imagenet/train_base.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.backends.cudnn as cudnn
6 | from torch.autograd import Variable
7 | import numpy as np
8 | 
9 | import os
10 | import shutil
11 | import argparse
12 | import time
13 | import logging
14 | 
15 | import models
16 | from data import *
17 | 
18 | 
19 | model_names = sorted(name for name in models.__dict__
20 | if name.islower() and not name.startswith('__')
21 | and callable(models.__dict__[name])
22 | )
23 | 
24 | 
25 | def parse_args():
26 | # hyper-parameters are from ResNet paper
27 | parser = argparse.ArgumentParser(
28 | description='Quantization Aware Training on ImageNet')
29 | parser.add_argument('--dir', help='annotate the working directory')
30 | parser.add_argument('--cmd', choices=['train', 'test'], default='train')
31 | parser.add_argument('--arch', metavar='ARCH', default='resnet50',
32 | choices=model_names,
33 | help='model architecture: ' +
34 | ' | '.join(model_names) +
35 | ' (default: resnet50)')
36 | parser.add_argument('--dataset', '-d', type=str, default='imagenet',
37 | choices=['cifar10', 'cifar100', 'imagenet'],
38 | help='dataset choice')
39 | parser.add_argument('--datadir', default='/home/yf22/dataset', type=str,
40 | help='path to dataset')
41 | parser.add_argument('--workers', default=16, type=int, metavar='N',
42 | help='number of data loading workers (default: 16)')
43 | parser.add_argument('--epoch', default=90, type=int,
44 | help='number of epochs (default: 90)')
45 | parser.add_argument('--start_epoch', default=0, type=int,
46 | help='manual epoch number (useful on restarts)')
47 | parser.add_argument('--batch_size', default=256, type=int,
48 | help='mini-batch size (default: 256)')
49 | parser.add_argument('--lr_schedule', default='piecewise', type=str,
50 | help='learning rate schedule')
51 | parser.add_argument('--lr', default=0.1, type=float,
52 | help='initial learning rate')
53 | parser.add_argument('--momentum', default=0.9, type=float,
54 | help='momentum')
55 | parser.add_argument('--weight_decay', default=1e-4, type=float,
56 | help='weight decay (default: 1e-4)')
57 | parser.add_argument('--print_freq', default=10, type=int,
58 | help='print frequency (default: 10)')
59 | parser.add_argument('--resume', default='', type=str,
60 | help='path to latest checkpoint (default: none)')
61 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
62 | help='use pretrained model')
63 | parser.add_argument('--step_ratio', default=0.1, type=float,
64 | help='ratio for learning rate reduction')
65 | parser.add_argument('--warm_up', action='store_true',
66 | help='for n = 18, the model needs to warm up for 400 '
67 | 'iterations')
68 | parser.add_argument('--save_folder', default='save_checkpoints',
69 | type=str,
70 | help='folder to save the checkpoints')
71 | parser.add_argument('--eval_every', default=390, type=int,
72 | help='evaluate model every (default: 390) iterations')
73 | parser.add_argument('--num_bits', default=0, type=int,
74 | help='num bits for weight and activation')
75 | parser.add_argument('--num_grad_bits', default=0, type=int,
76 | help='num bits for gradient')
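# The three schedule flags below drive the epoch-wise (PFQ-style)
# precision schedule consumed by adjust_precision() further down. An
# illustrative invocation (the breakpoints and bit-widths here are
# examples, not prescribed values):
#
#   python train_base.py --cmd train --dataset imagenet \
#       --schedule 30 60 80 \
#       --num_bits_schedule 3 4 6 8 \
#       --num_grad_bits_schedule 6 8 12 16
#
# adjust_precision() asserts len(num_bits_schedule) == len(schedule) + 1,
# i.e. one precision per training stage delimited by the breakpoints.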
77 | parser.add_argument('--schedule', default=None, type=int, nargs='*', 78 | help='precision schedule') 79 | parser.add_argument('--num_bits_schedule',default=None,type=int,nargs='*', 80 | help='schedule for weight/act precision') 81 | parser.add_argument('--num_grad_bits_schedule',default=None,type=int,nargs='*', 82 | help='schedule for grad precision') 83 | parser.add_argument('--act_fw', default=0, type=int, 84 | help='precision of activation during forward, -1 means dynamic, 0 means no quantize') 85 | parser.add_argument('--act_bw', default=0, type=int, 86 | help='precision of activation during backward, -1 means dynamic, 0 means no quantize') 87 | parser.add_argument('--grad_act_error', default=0, type=int, 88 | help='precision of activation gradient during error backward, -1 means dynamic, 0 means no quantize') 89 | parser.add_argument('--grad_act_gc', default=0, type=int, 90 | help='precision of activation gradient during weight gradient computation, -1 means dynamic, 0 means no quantize') 91 | parser.add_argument('--weight_bits', default=0, type=int, 92 | help='precision of weight') 93 | parser.add_argument('--momentum_act', default=0.9, type=float, 94 | help='momentum for act min/max') 95 | args = parser.parse_args() 96 | return args 97 | 98 | 99 | def main(): 100 | args = parse_args() 101 | save_path = args.save_path = os.path.join(args.save_folder, args.arch) 102 | if not os.path.exists(save_path): 103 | os.makedirs(save_path) 104 | 105 | models.ACT_FW = args.act_fw 106 | models.ACT_BW = args.act_bw 107 | models.GRAD_ACT_ERROR = args.grad_act_error 108 | models.GRAD_ACT_GC = args.grad_act_gc 109 | models.WEIGHT_BITS = args.weight_bits 110 | models.MOMENTUM = args.momentum_act 111 | 112 | args.num_bits = args.num_bits if not (args.act_fw + args.act_bw + args.grad_act_error + args.grad_act_gc + args.weight_bits) else -1 113 | 114 | # config logging file 115 | args.logger_file = os.path.join(save_path, 'log_{}.txt'.format(args.cmd)) 116 | handlers = [logging.FileHandler(args.logger_file, mode='w'), 117 | logging.StreamHandler()] 118 | logging.basicConfig(level=logging.INFO, 119 | datefmt='%m-%d-%y %H:%M', 120 | format='%(asctime)s:%(message)s', 121 | handlers=handlers) 122 | 123 | if args.cmd == 'train': 124 | logging.info('start training {}'.format(args.arch)) 125 | run_training(args) 126 | 127 | elif args.cmd == 'test': 128 | logging.info('start evaluating {} with checkpoints from {}'.format( 129 | args.arch, args.resume)) 130 | test_model(args) 131 | 132 | 133 | def run_training(args): 134 | # create model 135 | model = models.__dict__[args.arch](args.pretrained) 136 | model = torch.nn.DataParallel(model).cuda() 137 | 138 | best_prec1 = 0 139 | best_full_prec = 0 140 | 141 | # optionally resume from a checkpoint 142 | if args.resume: 143 | if os.path.isfile(args.resume): 144 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 145 | checkpoint = torch.load(args.resume) 146 | args.start_epoch = checkpoint['epoch'] 147 | best_prec1 = checkpoint['best_prec1'] 148 | model.load_state_dict(checkpoint['state_dict']) 149 | logging.info('=> loaded checkpoint `{}` (epoch: {})'.format( 150 | args.resume, checkpoint['epoch'] 151 | )) 152 | else: 153 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 154 | 155 | cudnn.benchmark = False 156 | 157 | train_loader = prepare_train_data(dataset=args.dataset, 158 | datadir=args.datadir+'/train', 159 | batch_size=args.batch_size, 160 | shuffle=True, 161 | num_workers=args.workers) 162 | test_loader = 
prepare_test_data(dataset=args.dataset, 163 | datadir=args.datadir+'/val', 164 | batch_size=args.batch_size, 165 | shuffle=False, 166 | num_workers=args.workers) 167 | 168 | # define loss function (criterion) and optimizer 169 | criterion = nn.CrossEntropyLoss().cuda() 170 | 171 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 172 | momentum=args.momentum, 173 | weight_decay=args.weight_decay) 174 | 175 | # optimizer = torch.optim.Adam(model.parameters(), args.lr, 176 | # weight_decay=args.weight_decay) 177 | 178 | batch_time = AverageMeter() 179 | data_time = AverageMeter() 180 | losses = AverageMeter() 181 | top1 = AverageMeter() 182 | cr = AverageMeter() 183 | 184 | end = time.time() 185 | 186 | for _epoch in range(args.start_epoch, args.epoch): 187 | lr = adjust_learning_rate(args, optimizer, _epoch) 188 | adjust_precision(args, _epoch) 189 | 190 | print('Learning Rate:', lr) 191 | print('num bits:', args.num_bits, 'num grad bits:', args.num_grad_bits) 192 | 193 | for i, (input, target) in enumerate(train_loader): 194 | # measuring data loading time 195 | data_time.update(time.time() - end) 196 | 197 | model.train() 198 | 199 | fw_cost = args.num_bits*args.num_bits/32/32 200 | eb_cost = args.num_bits*args.num_grad_bits/32/32 201 | gc_cost = eb_cost 202 | cr.update((fw_cost+eb_cost+gc_cost)/3) 203 | 204 | target = target.squeeze().long().cuda() 205 | input_var = Variable(input).cuda() 206 | target_var = Variable(target).cuda() 207 | 208 | # compute output 209 | output = model(input_var, args.num_bits, args.num_grad_bits) 210 | loss = criterion(output, target_var) 211 | 212 | # measure accuracy and record loss 213 | prec1, = accuracy(output.data, target, topk=(1,)) 214 | losses.update(loss.item(), input.size(0)) 215 | top1.update(prec1.item(), input.size(0)) 216 | 217 | # compute gradient and do SGD step 218 | optimizer.zero_grad() 219 | loss.backward() 220 | optimizer.step() 221 | 222 | # measure elapsed time 223 | batch_time.update(time.time() - end) 224 | end = time.time() 225 | 226 | # print log 227 | if i % args.print_freq == 0: 228 | logging.info("Iter: [{0}][{1}/{2}]\t" 229 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 230 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t" 231 | "Loss {loss.val:.3f} ({loss.avg:.3f})\t" 232 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t".format( 233 | _epoch, 234 | i, 235 | len(train_loader), 236 | batch_time=batch_time, 237 | data_time=data_time, 238 | loss=losses, 239 | top1=top1) 240 | ) 241 | 242 | with torch.no_grad(): 243 | prec1 = validate(args, test_loader, model, criterion, _epoch) 244 | # prec_full = validate_full_prec(args, test_loader, model, criterion, i) 245 | 246 | is_best = prec1 > best_prec1 247 | best_prec1 = max(prec1, best_prec1) 248 | #best_full_prec = max(prec_full, best_full_prec) 249 | 250 | print("Current Best Prec@1: ", best_prec1) 251 | #print("Current Best Full Prec@1: ", best_full_prec) 252 | 253 | checkpoint_path = os.path.join(args.save_path, 'checkpoint_{:05d}_{:.2f}.pth.tar'.format(_epoch, prec1)) 254 | save_checkpoint({ 255 | 'epoch': _epoch, 256 | 'arch': args.arch, 257 | 'state_dict': model.state_dict(), 258 | 'best_prec1': best_prec1, 259 | }, 260 | is_best, filename=checkpoint_path) 261 | shutil.copyfile(checkpoint_path, os.path.join(args.save_path, 262 | 'checkpoint_latest' 263 | '.pth.tar')) 264 | 265 | 266 | 267 | def validate(args, test_loader, model, criterion, _epoch): 268 | batch_time = AverageMeter() 269 | losses = AverageMeter() 270 | top1 = AverageMeter() 271 | 272 | # switch to 
evaluation mode 273 | model.eval() 274 | end = time.time() 275 | for i, (input, target) in enumerate(test_loader): 276 | target = target.squeeze().long().cuda() 277 | input_var = Variable(input, volatile=True).cuda() 278 | target_var = Variable(target, volatile=True).cuda() 279 | 280 | # compute output 281 | output = model(input_var, args.num_bits, args.num_grad_bits) 282 | loss = criterion(output, target_var) 283 | 284 | # measure accuracy and record loss 285 | prec1, = accuracy(output.data, target, topk=(1,)) 286 | top1.update(prec1.item(), input.size(0)) 287 | losses.update(loss.item(), input.size(0)) 288 | batch_time.update(time.time() - end) 289 | end = time.time() 290 | 291 | if (i % args.print_freq == 0) or (i == len(test_loader) - 1): 292 | logging.info( 293 | 'Test: [{}/{}]\t' 294 | 'Time: {batch_time.val:.4f}({batch_time.avg:.4f})\t' 295 | 'Loss: {loss.val:.3f}({loss.avg:.3f})\t' 296 | 'Prec@1: {top1.val:.3f}({top1.avg:.3f})\t'.format( 297 | i, len(test_loader), batch_time=batch_time, 298 | loss=losses, top1=top1 299 | ) 300 | ) 301 | 302 | logging.info('Epoch {} * Prec@1 {top1.avg:.3f}'.format(_epoch, top1=top1)) 303 | return top1.avg 304 | 305 | 306 | def validate_full_prec(args, test_loader, model, criterion, _epoch): 307 | batch_time = AverageMeter() 308 | losses = AverageMeter() 309 | top1 = AverageMeter() 310 | 311 | # switch to evaluation mode 312 | model.eval() 313 | end = time.time() 314 | for i, (input, target) in enumerate(test_loader): 315 | target = target.squeeze().long().cuda() 316 | input_var = Variable(input, volatile=True).cuda() 317 | target_var = Variable(target, volatile=True).cuda() 318 | 319 | # compute output 320 | output = model(input_var, 0, 0) 321 | loss = criterion(output, target_var) 322 | 323 | # measure accuracy and record loss 324 | prec1, = accuracy(output.data, target, topk=(1,)) 325 | top1.update(prec1.item(), input.size(0)) 326 | losses.update(loss.item(), input.size(0)) 327 | batch_time.update(time.time() - end) 328 | end = time.time() 329 | 330 | 331 | logging.info('Epoch {} * Full Prec@1 {top1.avg:.3f}'.format(_epoch, top1=top1)) 332 | return top1.avg 333 | 334 | 335 | def test_model(args): 336 | # create model 337 | model = models.__dict__[args.arch](args.pretrained) 338 | model = torch.nn.DataParallel(model).cuda() 339 | 340 | if args.resume: 341 | if os.path.isfile(args.resume): 342 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 343 | checkpoint = torch.load(args.resume) 344 | args.start_epoch = checkpoint['epoch'] 345 | best_prec1 = checkpoint['best_prec1'] 346 | model.load_state_dict(checkpoint['state_dict']) 347 | logging.info('=> loaded checkpoint `{}` (epoch: {})'.format( 348 | args.resume, checkpoint['epoch'] 349 | )) 350 | else: 351 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 352 | 353 | cudnn.benchmark = False 354 | test_loader = prepare_test_data(dataset=args.dataset, 355 | batch_size=args.batch_size, 356 | shuffle=False, 357 | num_workers=args.workers) 358 | criterion = nn.CrossEntropyLoss().cuda() 359 | 360 | # validate(args, test_loader, model, criterion) 361 | 362 | with torch.no_grad(): 363 | prec1 = validate(args, test_loader, model, criterion, args.start_iter) 364 | # prec_full = validate_full_prec(args, test_loader, model, criterion, args.start_iter) 365 | 366 | 367 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 368 | torch.save(state, filename) 369 | if is_best: 370 | save_path = os.path.dirname(filename) 371 | shutil.copyfile(filename, 
os.path.join(save_path, 372 | 'model_best.pth.tar')) 373 | 374 | 375 | class AverageMeter(object): 376 | """Computes and stores the average and current value""" 377 | 378 | def __init__(self): 379 | self.reset() 380 | 381 | def reset(self): 382 | self.val = 0 383 | self.avg = 0 384 | self.sum = 0 385 | self.count = 0 386 | 387 | def update(self, val, n=1): 388 | self.val = val 389 | self.sum += val * n 390 | self.count += n 391 | self.avg = self.sum / self.count 392 | 393 | 394 | schedule_cnt = 0 395 | def adjust_precision(args, _epoch): 396 | if args.schedule: 397 | global schedule_cnt 398 | 399 | assert len(args.num_bits_schedule) == len(args.schedule) + 1 400 | assert len(args.num_grad_bits_schedule) == len(args.schedule) + 1 401 | 402 | if schedule_cnt == 0: 403 | args.num_bits = args.num_bits_schedule[0] 404 | args.num_grad_bits = args.num_grad_bits_schedule[0] 405 | schedule_cnt += 1 406 | 407 | for step in args.schedule: 408 | if _epoch == step: 409 | args.num_bits = args.num_bits_schedule[schedule_cnt] 410 | args.num_grad_bits = args.num_grad_bits_schedule[schedule_cnt] 411 | schedule_cnt += 1 412 | 413 | 414 | def adjust_learning_rate(args, optimizer, _epoch): 415 | lr = args.lr * (0.1 ** (_epoch // 30)) 416 | 417 | for param_group in optimizer.param_groups: 418 | param_group['lr'] = lr 419 | 420 | return lr 421 | 422 | 423 | def accuracy(output, target, topk=(1,)): 424 | """Computes the precision@k for the specified values of k""" 425 | maxk = max(topk) 426 | batch_size = target.size(0) 427 | 428 | _, pred = output.topk(maxk, 1, True, True) 429 | pred = pred.t() 430 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 431 | 432 | res = [] 433 | for k in topk: 434 | correct_k = correct[:k].view(-1).float().sum(0) 435 | res.append(correct_k.mul_(100.0 / batch_size)) 436 | return res 437 | 438 | 439 | if __name__ == '__main__': 440 | main() 441 | -------------------------------------------------------------------------------- /fractrain_imagenet/train_dfq.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.backends.cudnn as cudnn 6 | from torch.autograd import Variable 7 | import numpy as np 8 | 9 | from functools import reduce 10 | 11 | import os 12 | import shutil 13 | import argparse 14 | import time 15 | import logging 16 | 17 | import models 18 | from data import * 19 | 20 | 21 | model_names = sorted(name for name in models.__dict__ 22 | if name.islower() and not name.startswith('__') 23 | and callable(models.__dict__[name]) 24 | ) 25 | 26 | 27 | def parse_args(): 28 | # hyper-parameters are from ResNet paper 29 | parser = argparse.ArgumentParser( 30 | description='DFQ on ImageNet') 31 | parser.add_argument('--dir', help='annotate the working directory') 32 | parser.add_argument('--cmd', choices=['train', 'test'], default='train') 33 | parser.add_argument('--arch', metavar='ARCH', default='resnet50', 34 | choices=model_names, 35 | help='model architecture: ' + 36 | ' | '.join(model_names) + 37 | ' (default: cifar10_resnet_38)') 38 | parser.add_argument('--dataset', '-d', type=str, default='imagenet', 39 | choices=['cifar10', 'cifar100','imagenet'], 40 | help='dataset choice') 41 | parser.add_argument('--datadir', default='/home/yf22/dataset', type=str, 42 | help='path to dataset') 43 | parser.add_argument('--workers', default=16, type=int, metavar='N', 44 | help='number of data loading workers (default: 4 )') 45 | 
parser.add_argument('--epoch', default=90, type=int, 46 | help='number of epochs (default: 90)') 47 | parser.add_argument('--start_epoch', default=0, type=int, 48 | help='manual iter number (useful on restarts)') 49 | parser.add_argument('--batch_size', default=256, type=int, 50 | help='mini-batch size (default: 128)') 51 | parser.add_argument('--lr_schedule', default='piecewise', type=str, 52 | help='learning rate schedule') 53 | parser.add_argument('--lr', default=0.1, type=float, 54 | help='initial learning rate') 55 | parser.add_argument('--momentum', default=0.9, type=float, 56 | help='momentum') 57 | parser.add_argument('--weight_decay', default=1e-4, type=float, 58 | help='weight decay (default: 1e-4)') 59 | parser.add_argument('--print_freq', default=10, type=int, 60 | help='print frequency (default: 10)') 61 | parser.add_argument('--resume', default='', type=str, 62 | help='path to latest checkpoint (default: None)') 63 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 64 | help='use pretrained model') 65 | parser.add_argument('--step_ratio', default=0.1, type=float, 66 | help='ratio for learning rate deduction') 67 | parser.add_argument('--warm_up', action='store_true', 68 | help='for n = 18, the model needs to warm up for 400 ' 69 | 'iterations') 70 | parser.add_argument('--save_folder', default='save_checkpoints', 71 | type=str, 72 | help='folder to save the checkpoints') 73 | parser.add_argument('--eval_every', default=390, type=int, 74 | help='evaluate model every (default: 1000) iterations') 75 | parser.add_argument('--num_bits',default=0,type=int, 76 | help='num bits for weight and activation') 77 | parser.add_argument('--num_grad_bits',default=0,type=int, 78 | help='num bits for gradient') 79 | parser.add_argument('--schedule', default=None, type=int, nargs='*', 80 | help='precision schedule') 81 | parser.add_argument('--num_bits_schedule',default=None,type=int,nargs='*', 82 | help='schedule for weight/act precision') 83 | parser.add_argument('--num_grad_bits_schedule',default=None,type=int,nargs='*', 84 | help='schedule for grad precision') 85 | parser.add_argument('--act_fw', default=0, type=int, 86 | help='precision of activation during forward, -1 means dynamic, 0 means no quantize') 87 | parser.add_argument('--act_bw', default=0, type=int, 88 | help='precision of activation during backward, -1 means dynamic, 0 means no quantize') 89 | parser.add_argument('--grad_act_error', default=0, type=int, 90 | help='precision of activation gradient during error backward, -1 means dynamic, 0 means no quantize') 91 | parser.add_argument('--grad_act_gc', default=0, type=int, 92 | help='precision of activation gradient during weight gradient computation, -1 means dynamic, 0 means no quantize') 93 | parser.add_argument('--weight_bits', default=0, type=int, 94 | help='precision of weight') 95 | parser.add_argument('--target_ratio',default=4,type=float, 96 | help='target compression ratio') 97 | parser.add_argument('--target_ratio_schedule',default=None,type=float,nargs='*', 98 | help='schedule for target compression ratio') 99 | parser.add_argument('--momentum_act', default=0.9, type=float, 100 | help='momentum for act min/max') 101 | parser.add_argument('--relax', default=0, type=float, 102 | help='relax parameter for target ratio') 103 | parser.add_argument('--beta', default=1e-3, type=float, 104 | help='coefficient') 105 | parser.add_argument('--computation_cost', default=True, type=bool, 106 | help='using computation cost as regularization term') 107 | 
args = parser.parse_args() 108 | return args 109 | 110 | 111 | def main(): 112 | args = parse_args() 113 | save_path = args.save_path = os.path.join(args.save_folder, args.arch) 114 | if not os.path.exists(save_path): 115 | os.makedirs(save_path) 116 | 117 | models.ACT_FW = args.act_fw 118 | models.ACT_BW = args.act_bw 119 | models.GRAD_ACT_ERROR = args.grad_act_error 120 | models.GRAD_ACT_GC = args.grad_act_gc 121 | models.WEIGHT_BITS = args.weight_bits 122 | models.MOMENTUM = args.momentum_act 123 | 124 | args.num_bits = args.num_bits if not (args.act_fw + args.act_bw + args.grad_act_error + args.grad_act_gc + args.weight_bits) else -1 125 | 126 | # config logging file 127 | args.logger_file = os.path.join(save_path, 'log_{}.txt'.format(args.cmd)) 128 | handlers = [logging.FileHandler(args.logger_file, mode='w'), 129 | logging.StreamHandler()] 130 | logging.basicConfig(level=logging.INFO, 131 | datefmt='%m-%d-%y %H:%M', 132 | format='%(asctime)s:%(message)s', 133 | handlers=handlers) 134 | 135 | if args.cmd == 'train': 136 | logging.info('start training {}'.format(args.arch)) 137 | run_training(args) 138 | 139 | elif args.cmd == 'test': 140 | logging.info('start evaluating {} with checkpoints from {}'.format( 141 | args.arch, args.resume)) 142 | test_model(args) 143 | 144 | bits = [3, 4, 6, 8] 145 | grad_bits = [6, 8, 12, 16] 146 | 147 | def run_training(args): 148 | 149 | cost_fw = [] 150 | for bit in bits: 151 | if bit == 0: 152 | cost_fw.append(1) 153 | else: 154 | cost_fw.append(bit/32) 155 | cost_fw = np.array(cost_fw) * args.weight_bits/32 156 | 157 | cost_eb = [] 158 | for bit in grad_bits: 159 | if bit == 0: 160 | cost_eb.append(1) 161 | else: 162 | cost_eb.append(bit/32) 163 | cost_eb = np.array(cost_eb) * args.weight_bits/32 164 | 165 | cost_gc = [] 166 | for i in range(len(bits)): 167 | if bits[i] == 0: 168 | cost_gc.append(1) 169 | else: 170 | cost_gc.append(bits[i]*grad_bits[i]/32/32) 171 | cost_gc = np.array(cost_gc) 172 | 173 | model = models.__dict__[args.arch](args.pretrained, proj_dim=len(bits)) 174 | model = torch.nn.DataParallel(model).cuda() 175 | 176 | best_prec1 = 0 177 | best_epoch = 0 178 | best_full_prec = 0 179 | 180 | # optionally resume from a checkpoint 181 | if args.resume: 182 | if os.path.isfile(args.resume): 183 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 184 | checkpoint = torch.load(args.resume) 185 | args.start_epoch = checkpoint['epoch'] 186 | best_prec1 = checkpoint['best_prec1'] 187 | model.load_state_dict(checkpoint['state_dict']) 188 | logging.info('=> loaded checkpoint `{}` (epoch: {})'.format( 189 | args.resume, checkpoint['epoch'] 190 | )) 191 | else: 192 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 193 | 194 | cudnn.benchmark = False 195 | 196 | train_loader = prepare_train_data(dataset=args.dataset, 197 | datadir=args.datadir+'/train', 198 | batch_size=args.batch_size, 199 | shuffle=True, 200 | num_workers=args.workers) 201 | test_loader = prepare_test_data(dataset=args.dataset, 202 | datadir=args.datadir+'/val', 203 | batch_size=args.batch_size, 204 | shuffle=False, 205 | num_workers=args.workers) 206 | 207 | # define loss function (criterion) and optimizer 208 | criterion = nn.CrossEntropyLoss().cuda() 209 | 210 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 211 | momentum=args.momentum, 212 | weight_decay=args.weight_decay) 213 | 214 | batch_time = AverageMeter() 215 | data_time = AverageMeter() 216 | losses = AverageMeter() 217 | top1 = AverageMeter() 218 | cp_record = 
AverageMeter() 219 | cp_record_fw = AverageMeter() 220 | cp_record_eb = AverageMeter() 221 | cp_record_gc = AverageMeter() 222 | 223 | network_depth = sum(model.module.num_layers) 224 | 225 | layerwise_decision_statistics = [] 226 | 227 | for k in range(network_depth): 228 | layerwise_decision_statistics.append([]) 229 | for j in range(len(cost_fw)): 230 | ratio = AverageMeter() 231 | layerwise_decision_statistics[k].append(ratio) 232 | 233 | end = time.time() 234 | 235 | for _epoch in range(args.start_epoch, args.epoch): 236 | lr = adjust_learning_rate(args, optimizer, _epoch) 237 | adjust_target_ratio(args, _epoch) 238 | 239 | print('Learning Rate:', lr) 240 | print('Target Ratio:', args.target_ratio) 241 | 242 | for i, (input, target) in enumerate(train_loader): 243 | # measuring data loading time 244 | data_time.update(time.time() - end) 245 | 246 | model.train() 247 | 248 | target = target.squeeze().long().cuda() 249 | input_var = Variable(input).cuda() 250 | target_var = Variable(target).cuda() 251 | 252 | output, masks = model(input_var, bits, grad_bits) 253 | 254 | computation_cost_fw = 0 255 | computation_cost_eb = 0 256 | computation_cost_gc = 0 257 | computation_all = 0 258 | 259 | for layer in range(network_depth): 260 | 261 | full_layer = reduce((lambda x, y: x * y), masks[layer][0].shape) 262 | 263 | computation_all += full_layer 264 | 265 | for k in range(len(cost_fw)): 266 | 267 | dynamic_choice = masks[layer][k].sum() 268 | 269 | ratio = dynamic_choice / full_layer 270 | 271 | layerwise_decision_statistics[layer][k].update(ratio.data, 1) 272 | 273 | computation_cost_fw += masks[layer][k].sum() * cost_fw[k] 274 | computation_cost_eb += masks[layer][k].sum() * cost_eb[k] 275 | computation_cost_gc += masks[layer][k].sum() * cost_gc[k] 276 | 277 | computation_cost = computation_cost_fw + computation_cost_eb + computation_cost_gc 278 | 279 | cp_ratio_fw = (float(computation_cost_fw) / float(computation_all)) * 100 280 | cp_ratio_eb = (float(computation_cost_eb) / float(computation_all)) * 100 281 | cp_ratio_gc = (float(computation_cost_gc) / float(computation_all)) * 100 282 | 283 | cp_ratio = (float(computation_cost) / float(computation_all*3)) * 100 284 | 285 | computation_cost *= args.beta 286 | 287 | if cp_ratio < args.target_ratio - args.relax: 288 | reg = -1 289 | elif cp_ratio >= args.target_ratio + args.relax: 290 | reg = 1 291 | elif cp_ratio >=args.target_ratio: 292 | reg = 0.1 293 | else: 294 | reg = -0.1 295 | 296 | loss_cls = criterion(output, target_var) 297 | 298 | if args.computation_cost: 299 | loss = loss_cls + computation_cost * reg 300 | else: 301 | loss = loss_cls 302 | 303 | # measure accuracy and record loss 304 | prec1, = accuracy(output.data, target, topk=(1,)) 305 | losses.update(loss.item(), input.size(0)) 306 | top1.update(prec1.item(), input.size(0)) 307 | 308 | cp_record.update(cp_ratio,1) 309 | cp_record_fw.update(cp_ratio_fw,1) 310 | cp_record_eb.update(cp_ratio_eb,1) 311 | cp_record_gc.update(cp_ratio_gc,1) 312 | 313 | # compute gradient and do SGD step 314 | optimizer.zero_grad() 315 | loss.backward() 316 | optimizer.step() 317 | 318 | # measure elapsed time 319 | batch_time.update(time.time() - end) 320 | end = time.time() 321 | 322 | # print log 323 | if i % args.print_freq == 0: 324 | logging.info("Iter: [{0}][{1}/{2}]\t" 325 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 326 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t" 327 | "Loss {loss.val:.3f} ({loss.avg:.3f})\t" 328 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t" 329 | 
"Computation_Percentage: {cp_record.val:.3f}({cp_record.avg:.3f})\t" 330 | "Computation_Percentage_FW: {cp_record_fw.val:.3f}({cp_record_fw.avg:.3f})\t" 331 | "Computation_Percentage_EB: {cp_record_eb.val:.3f}({cp_record_eb.avg:.3f})\t" 332 | "Computation_Percentage_GC: {cp_record_gc.val:.3f}({cp_record_gc.avg:.3f})\t".format( 333 | _epoch, 334 | i, 335 | len(train_loader), 336 | batch_time=batch_time, 337 | data_time=data_time, 338 | loss=losses, 339 | top1=top1, 340 | cp_record=cp_record, 341 | cp_record_fw=cp_record_fw, 342 | cp_record_eb=cp_record_eb, 343 | cp_record_gc=cp_record_gc) 344 | ) 345 | 346 | with torch.no_grad(): 347 | prec1 = validate(args, test_loader, model, criterion, _epoch) 348 | # prec_full = validate_full_prec(args, test_loader, model, criterion, i) 349 | 350 | is_best = prec1 > best_prec1 351 | if is_best: 352 | best_prec1 = prec1 353 | best_epoch = _epoch 354 | #best_full_prec = max(prec_full, best_full_prec) 355 | 356 | print("Current Best Prec@1: ", best_prec1, "Best Epoch:", best_epoch) 357 | #print("Current Best Full Prec@1: ", best_full_prec) 358 | 359 | checkpoint_path = os.path.join(args.save_path, 'checkpoint_{:05d}_{:.2f}.pth.tar'.format(_epoch, prec1)) 360 | save_checkpoint({ 361 | 'epoch': _epoch, 362 | 'arch': args.arch, 363 | 'state_dict': model.state_dict(), 364 | 'best_prec1': best_prec1, 365 | }, 366 | is_best, filename=checkpoint_path) 367 | shutil.copyfile(checkpoint_path, os.path.join(args.save_path, 368 | 'checkpoint_latest' 369 | '.pth.tar')) 370 | 371 | 372 | 373 | def validate(args, test_loader, model, criterion, _epoch): 374 | 375 | cost_fw = [] 376 | for bit in bits: 377 | if bit == 0: 378 | cost_fw.append(1) 379 | else: 380 | cost_fw.append(bit/32) 381 | cost_fw = np.array(cost_fw) * args.weight_bits/32 382 | 383 | cost_eb = [] 384 | for bit in grad_bits: 385 | if bit == 0: 386 | cost_eb.append(1) 387 | else: 388 | cost_eb.append(bit/32) 389 | cost_eb = np.array(cost_eb) * args.weight_bits/32 390 | 391 | cost_gc = [] 392 | for i in range(len(bits)): 393 | if bits[i] == 0: 394 | cost_gc.append(1) 395 | else: 396 | cost_gc.append(bits[i]*grad_bits[i]/32/32) 397 | cost_gc = np.array(cost_gc) 398 | 399 | batch_time = AverageMeter() 400 | data_time = AverageMeter() 401 | losses = AverageMeter() 402 | top1 = AverageMeter() 403 | cp_record = AverageMeter() 404 | cp_record_fw = AverageMeter() 405 | cp_record_eb = AverageMeter() 406 | cp_record_gc = AverageMeter() 407 | 408 | network_depth = sum(model.module.num_layers) 409 | 410 | layerwise_decision_statistics = [] 411 | 412 | for k in range(network_depth): 413 | layerwise_decision_statistics.append([]) 414 | for j in range(len(cost_fw)): 415 | ratio = AverageMeter() 416 | layerwise_decision_statistics[k].append(ratio) 417 | 418 | model.eval() 419 | end = time.time() 420 | for i, (input, target) in enumerate(test_loader): 421 | target = target.squeeze().long().cuda() 422 | input_var = Variable(input, volatile=True).cuda() 423 | target_var = Variable(target, volatile=True).cuda() 424 | 425 | output, masks = model(input_var, bits, grad_bits) 426 | 427 | computation_cost_fw = 0 428 | computation_cost_eb = 0 429 | computation_cost_gc = 0 430 | computation_all = 0 431 | 432 | for layer in range(network_depth): 433 | 434 | full_layer = reduce((lambda x, y: x * y), masks[layer][0].shape) 435 | 436 | computation_all += full_layer 437 | 438 | for k in range(len(cost_fw)): 439 | 440 | dynamic_choice = masks[layer][k].sum() 441 | 442 | ratio = dynamic_choice / full_layer 443 | 444 | 
layerwise_decision_statistics[layer][k].update(ratio.data, 1) 445 | 446 | computation_cost_fw += masks[layer][k].sum() * cost_fw[k] 447 | computation_cost_eb += masks[layer][k].sum() * cost_eb[k] 448 | computation_cost_gc += masks[layer][k].sum() * cost_gc[k] 449 | 450 | computation_cost = computation_cost_fw + computation_cost_eb + computation_cost_gc 451 | 452 | cp_ratio_fw = (float(computation_cost_fw) / float(computation_all)) * 100 453 | cp_ratio_eb = (float(computation_cost_eb) / float(computation_all)) * 100 454 | cp_ratio_gc = (float(computation_cost_gc) / float(computation_all)) * 100 455 | 456 | cp_ratio = (float(computation_cost) / float(computation_all*3)) * 100 457 | 458 | loss = criterion(output, target_var) 459 | 460 | # measure accuracy and record loss 461 | prec1, = accuracy(output.data, target, topk=(1,)) 462 | losses.update(loss.item(), input.size(0)) 463 | top1.update(prec1.item(), input.size(0)) 464 | 465 | cp_record.update(cp_ratio,1) 466 | cp_record_fw.update(cp_ratio_fw,1) 467 | cp_record_eb.update(cp_ratio_eb,1) 468 | cp_record_gc.update(cp_ratio_gc,1) 469 | 470 | batch_time.update(time.time() - end) 471 | end = time.time() 472 | 473 | if i % args.print_freq == 0 or (i == (len(test_loader) - 1)): 474 | logging.info("Iter: [{0}/{1}]\t" 475 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 476 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t" 477 | "Loss {loss.val:.3f} ({loss.avg:.3f})\t" 478 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t" 479 | "Computation_Percentage: {cp_record.val:.3f}({cp_record.avg:.3f})\t" 480 | "Computation_Percentage_FW: {cp_record_fw.val:.3f}({cp_record_fw.avg:.3f})\t" 481 | "Computation_Percentage_EB: {cp_record_eb.val:.3f}({cp_record_eb.avg:.3f})\t" 482 | "Computation_Percentage_GC: {cp_record_gc.val:.3f}({cp_record_gc.avg:.3f})\t".format( 483 | i, 484 | len(test_loader), 485 | batch_time=batch_time, 486 | data_time=data_time, 487 | loss=losses, 488 | top1=top1, 489 | cp_record=cp_record, 490 | cp_record_fw=cp_record_fw, 491 | cp_record_eb=cp_record_eb, 492 | cp_record_gc=cp_record_gc) 493 | ) 494 | 495 | logging.info('Epoch {} * Prec@1 {top1.avg:.3f}'.format(_epoch, top1=top1)) 496 | 497 | for layer in range(network_depth): 498 | print('layer{}_decision'.format(layer + 1)) 499 | for g in range(len(cost_fw)): 500 | print('{}_ratio{}'.format(g,layerwise_decision_statistics[layer][g].avg)) 501 | 502 | return top1.avg 503 | 504 | 505 | def validate_full_prec(args, test_loader, model, criterion, _epoch): 506 | batch_time = AverageMeter() 507 | losses = AverageMeter() 508 | top1 = AverageMeter() 509 | 510 | bits_full = np.zeros(len(bits)) 511 | grad_bits_full = np.zeros(len(grad_bits)) 512 | # switch to evaluation mode 513 | model.eval() 514 | end = time.time() 515 | for i, (input, target) in enumerate(test_loader): 516 | target = target.squeeze().long().cuda() 517 | input_var = Variable(input, volatile=True).cuda() 518 | target_var = Variable(target, volatile=True).cuda() 519 | 520 | # compute output 521 | output, _ = model(input_var, bits_full, grad_bits_full) 522 | loss = criterion(output, target_var) 523 | 524 | # measure accuracy and record loss 525 | prec1, = accuracy(output.data, target, topk=(1,)) 526 | top1.update(prec1.item(), input.size(0)) 527 | losses.update(loss.item(), input.size(0)) 528 | batch_time.update(time.time() - end) 529 | end = time.time() 530 | 531 | if args.gate_type == 'rnn': 532 | model.module.control.repackage_hidden() 533 | 534 | logging.info('Epoch {} * Full Prec@1 {top1.avg:.3f}'.format(_epoch, top1=top1)) 
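# The 'Full Prec@1' logged above comes from running the gated model with
# all-zero bit-width candidates: quantize() and quantize_grad() simply
# return their inputs when num_bits is zero, so activation and gradient
# quantization are switched off for this pass (weights still follow
# weight_bits).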
535 | return top1.avg 536 | 537 | 538 | def test_model(args): 539 | # create model 540 | model = models.__dict__[args.arch](args.pretrained) 541 | model = torch.nn.DataParallel(model).cuda() 542 | 543 | if args.resume: 544 | if os.path.isfile(args.resume): 545 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 546 | checkpoint = torch.load(args.resume) 547 | args.start_epoch = checkpoint['epoch'] 548 | best_prec1 = checkpoint['best_prec1'] 549 | model.load_state_dict(checkpoint['state_dict']) 550 | logging.info('=> loaded checkpoint `{}` (epoch: {})'.format( 551 | args.resume, checkpoint['epoch'] 552 | )) 553 | else: 554 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 555 | 556 | cudnn.benchmark = False 557 | test_loader = prepare_test_data(dataset=args.dataset, 558 | batch_size=args.batch_size, 559 | shuffle=False, 560 | num_workers=args.workers) 561 | criterion = nn.CrossEntropyLoss().cuda() 562 | 563 | # validate(args, test_loader, model, criterion) 564 | 565 | with torch.no_grad(): 566 | prec1 = validate(args, test_loader, model, criterion, args.start_iter) 567 | # prec_full = validate_full_prec(args, test_loader, model, criterion, args.start_iter) 568 | 569 | 570 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 571 | torch.save(state, filename) 572 | if is_best: 573 | save_path = os.path.dirname(filename) 574 | shutil.copyfile(filename, os.path.join(save_path, 575 | 'model_best.pth.tar')) 576 | 577 | 578 | class AverageMeter(object): 579 | """Computes and stores the average and current value""" 580 | 581 | def __init__(self): 582 | self.reset() 583 | 584 | def reset(self): 585 | self.val = 0 586 | self.avg = 0 587 | self.sum = 0 588 | self.count = 0 589 | 590 | def update(self, val, n=1): 591 | self.val = val 592 | self.sum += val * n 593 | self.count += n 594 | self.avg = self.sum / self.count 595 | 596 | 597 | schedule_cnt = 0 598 | def adjust_target_ratio(args, _epoch): 599 | if args.schedule: 600 | global schedule_cnt 601 | 602 | assert len(args.target_ratio_schedule) == len(args.schedule) + 1 603 | 604 | if schedule_cnt == 0: 605 | args.target_ratio = args.target_ratio_schedule[0] 606 | schedule_cnt += 1 607 | 608 | for step in args.schedule: 609 | if _epoch == step: 610 | args.target_ratio = args.target_ratio_schedule[schedule_cnt] 611 | schedule_cnt += 1 612 | 613 | 614 | def adjust_learning_rate(args, optimizer, _epoch): 615 | lr = args.lr * (0.1 ** (_epoch // 30)) 616 | 617 | for param_group in optimizer.param_groups: 618 | param_group['lr'] = lr 619 | 620 | return lr 621 | 622 | 623 | def accuracy(output, target, topk=(1,)): 624 | """Computes the precision@k for the specified values of k""" 625 | maxk = max(topk) 626 | batch_size = target.size(0) 627 | 628 | _, pred = output.topk(maxk, 1, True, True) 629 | pred = pred.t() 630 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 631 | 632 | res = [] 633 | for k in topk: 634 | correct_k = correct[:k].view(-1).float().sum(0) 635 | res.append(correct_k.mul_(100.0 / batch_size)) 636 | return res 637 | 638 | 639 | if __name__ == '__main__': 640 | main() 641 | -------------------------------------------------------------------------------- /fractrain_imagenet/train_pfq.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.backends.cudnn as cudnn 6 | from torch.autograd import Variable 7 | import numpy as np 8 | 9 | import os 10 | import 
shutil 11 | import argparse 12 | import time 13 | import logging 14 | 15 | import models 16 | from data import * 17 | 18 | 19 | model_names = sorted(name for name in models.__dict__ 20 | if name.islower() and not name.startswith('__') 21 | and callable(models.__dict__[name]) 22 | ) 23 | 24 | 25 | def parse_args(): 26 | # hyper-parameters are from ResNet paper 27 | parser = argparse.ArgumentParser( 28 | description='PFQ on ImageNet') 29 | parser.add_argument('--dir', help='annotate the working directory') 30 | parser.add_argument('--cmd', choices=['train', 'test'], default='train') 31 | parser.add_argument('--arch', metavar='ARCH', default='resnet50', 32 | choices=model_names, 33 | help='model architecture: ' + 34 | ' | '.join(model_names) + 35 | ' (default: resnet50)') 36 | parser.add_argument('--dataset', '-d', type=str, default='imagenet', 37 | choices=['cifar10', 'cifar100', 'imagenet'], 38 | help='dataset choice') 39 | parser.add_argument('--datadir', default='/home/yf22/dataset', type=str, 40 | help='path to dataset') 41 | parser.add_argument('--workers', default=16, type=int, metavar='N', 42 | help='number of data loading workers (default: 16)') 43 | parser.add_argument('--epoch', default=90, type=int, 44 | help='number of epochs (default: 90)') 45 | parser.add_argument('--start_epoch', default=0, type=int, 46 | help='manual epoch number (useful on restarts)') 47 | parser.add_argument('--batch_size', default=256, type=int, 48 | help='mini-batch size (default: 256)') 49 | parser.add_argument('--lr_schedule', default='piecewise', type=str, 50 | help='learning rate schedule') 51 | parser.add_argument('--lr', default=0.1, type=float, 52 | help='initial learning rate') 53 | parser.add_argument('--momentum', default=0.9, type=float, 54 | help='momentum') 55 | parser.add_argument('--weight_decay', default=1e-4, type=float, 56 | help='weight decay (default: 1e-4)') 57 | parser.add_argument('--print_freq', default=10, type=int, 58 | help='print frequency (default: 10)') 59 | parser.add_argument('--resume', default='', type=str, 60 | help='path to latest checkpoint (default: None)') 61 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 62 | help='use pretrained model') 63 | parser.add_argument('--step_ratio', default=0.1, type=float, 64 | help='ratio for learning rate reduction') 65 | parser.add_argument('--warm_up', action='store_true', 66 | help='for n = 18, the model needs to warm up for 400 ' 67 | 'iterations') 68 | parser.add_argument('--save_folder', default='save_checkpoints', 69 | type=str, 70 | help='folder to save the checkpoints') 71 | parser.add_argument('--eval_every', default=390, type=int, 72 | help='evaluate model every (default: 390) iterations') 73 | parser.add_argument('--num_bits', default=0, type=int, 74 | help='num bits for weight and activation') 75 | parser.add_argument('--num_grad_bits', default=0, type=int, 76 | help='num bits for gradient') 77 | parser.add_argument('--schedule', default=None, type=int, nargs='*', 78 | help='precision schedule') 79 | parser.add_argument('--num_bits_schedule', default=None, type=int, nargs='*', 80 | help='schedule for weight/act precision') 81 | parser.add_argument('--num_grad_bits_schedule', default=None, type=int, nargs='*', 82 | help='schedule for grad precision') 83 | parser.add_argument('--act_fw', default=0, type=int, 84 | help='precision of activation during forward, -1 means dynamic, 0 means no quantize') 85 | parser.add_argument('--act_bw', default=0, type=int, 86 | help='precision of activation during
backward, -1 means dynamic, 0 means no quantize') 87 | parser.add_argument('--grad_act_error', default=0, type=int, 88 | help='precision of activation gradient during error backward, -1 means dynamic, 0 means no quantize') 89 | parser.add_argument('--grad_act_gc', default=0, type=int, 90 | help='precision of activation gradient during weight gradient computation, -1 means dynamic, 0 means no quantize') 91 | parser.add_argument('--weight_bits', default=0, type=int, 92 | help='precision of weight') 93 | parser.add_argument('--momentum_act', default=0.9, type=float, 94 | help='momentum for act min/max') 95 | 96 | parser.add_argument('--num_turning_point', type=int, default=3) 97 | parser.add_argument('--initial_threshold', type=float, default=0.05) 98 | parser.add_argument('--decay', type=float, default=0.3) 99 | args = parser.parse_args() 100 | return args 101 | 102 | # indicator 103 | class loss_diff_indicator(): 104 | def __init__(self, threshold, decay, epoch_keep=5): 105 | self.threshold = threshold 106 | self.decay = decay 107 | self.epoch_keep = epoch_keep 108 | self.loss = [] 109 | self.scale_loss = 1 110 | self.loss_diff = [1 for i in range(1, self.epoch_keep)] 111 | 112 | def reset(self): 113 | self.loss = [] 114 | self.loss_diff = [1 for i in range(1, self.epoch_keep)] 115 | 116 | def adaptive_threshold(self, turning_point_count): 117 | decay_1 = self.decay 118 | decay_2 = self.decay 119 | if turning_point_count == 1: 120 | self.threshold *= decay_1 121 | if turning_point_count == 2: 122 | self.threshold *= decay_2 123 | print('threshold decay to {}'.format(self.threshold)) 124 | 125 | def get_loss(self, current_epoch_loss): 126 | if len(self.loss) < self.epoch_keep: 127 | self.loss.append(current_epoch_loss) 128 | else: 129 | self.loss.pop(0) 130 | self.loss.append(current_epoch_loss) 131 | 132 | def cal_loss_diff(self): 133 | if len(self.loss) == self.epoch_keep: 134 | for i in range(len(self.loss)-1): 135 | loss_now = self.loss[-1] 136 | loss_pre = self.loss[i] 137 | self.loss_diff[i] = np.abs(loss_pre - loss_now) / self.scale_loss 138 | return True 139 | else: 140 | return False 141 | 142 | def turning_point_emerge(self): 143 | flag = self.cal_loss_diff() 144 | if flag == True: 145 | print(self.loss_diff) 146 | for i in range(len(self.loss_diff)): 147 | if self.loss_diff[i] > self.threshold: 148 | return False 149 | return True 150 | else: 151 | return False 152 | 153 | def main(): 154 | args = parse_args() 155 | global save_path 156 | save_path = args.save_path = os.path.join(args.save_folder, args.arch) 157 | if not os.path.exists(save_path): 158 | os.makedirs(save_path) 159 | 160 | models.ACT_FW = args.act_fw 161 | models.ACT_BW = args.act_bw 162 | models.GRAD_ACT_ERROR = args.grad_act_error 163 | models.GRAD_ACT_GC = args.grad_act_gc 164 | models.WEIGHT_BITS = args.weight_bits 165 | models.MOMENTUM = args.momentum_act 166 | 167 | args.num_bits = args.num_bits if not (args.act_fw + args.act_bw + args.grad_act_error + args.grad_act_gc + args.weight_bits) else -1 168 | 169 | # config logging file 170 | args.logger_file = os.path.join(save_path, 'log_{}.txt'.format(args.cmd)) 171 | if os.path.exists(args.logger_file): 172 | os.remove(args.logger_file) 173 | handlers = [logging.FileHandler(args.logger_file, mode='w'), 174 | logging.StreamHandler()] 175 | logging.basicConfig(level=logging.INFO, 176 | datefmt='%m-%d-%y %H:%M', 177 | format='%(asctime)s:%(message)s', 178 | handlers=handlers) 179 | 180 | global history_score 181 | history_score = np.zeros((args.epoch, 3)) 182 | 183 
| # initialize indicator 184 | # initial_threshold=0.15 185 | global scale_loss 186 | scale_loss = 0 187 | global my_loss_diff_indicator 188 | my_loss_diff_indicator = loss_diff_indicator(threshold=args.initial_threshold, 189 | decay=args.decay) 190 | 191 | global turning_point_count 192 | turning_point_count = 0 193 | 194 | if args.cmd == 'train': 195 | logging.info('start training {}'.format(args.arch)) 196 | run_training(args) 197 | 198 | elif args.cmd == 'test': 199 | logging.info('start evaluating {} with checkpoints from {}'.format( 200 | args.arch, args.resume)) 201 | test_model(args) 202 | 203 | 204 | def run_training(args): 205 | # create model 206 | training_loss = 0 207 | training_acc = 0 208 | 209 | model = models.__dict__[args.arch](args.pretrained) 210 | model = torch.nn.DataParallel(model).cuda() 211 | 212 | best_prec1 = 0 213 | best_epoch = 0 214 | 215 | # optionally resume from a checkpoint 216 | if args.resume: 217 | if os.path.isfile(args.resume): 218 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 219 | checkpoint = torch.load(args.resume) 220 | args.start_epoch = checkpoint['epoch'] 221 | best_prec1 = checkpoint['best_prec1'] 222 | model.load_state_dict(checkpoint['state_dict']) 223 | 224 | logging.info('=> loaded checkpoint `{}` (epoch: {})'.format( 225 | args.resume, checkpoint['epoch'] 226 | )) 227 | else: 228 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 229 | 230 | cudnn.benchmark = False 231 | 232 | train_loader = prepare_train_data(dataset=args.dataset, 233 | datadir=args.datadir+'/train', 234 | batch_size=args.batch_size, 235 | shuffle=True, 236 | num_workers=args.workers) 237 | test_loader = prepare_test_data(dataset=args.dataset, 238 | datadir=args.datadir+'/val', 239 | batch_size=args.batch_size, 240 | shuffle=False, 241 | num_workers=args.workers) 242 | 243 | # define loss function (criterion) and optimizer 244 | criterion = nn.CrossEntropyLoss().cuda() 245 | 246 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 247 | momentum=args.momentum, 248 | weight_decay=args.weight_decay) 249 | 250 | # optimizer = torch.optim.Adam(model.parameters(), args.lr, 251 | # weight_decay=args.weight_decay) 252 | 253 | batch_time = AverageMeter() 254 | data_time = AverageMeter() 255 | losses = AverageMeter() 256 | top1 = AverageMeter() 257 | cr = AverageMeter() 258 | 259 | end = time.time() 260 | 261 | global scale_loss 262 | global history_score 263 | global turning_point_count 264 | global my_loss_diff_indicator 265 | 266 | for _epoch in range(args.start_epoch, args.epoch): 267 | lr = adjust_learning_rate(args, optimizer, _epoch) 268 | # adjust_precision(args, _epoch) 269 | adaptive_adjust_precision(args, turning_point_count) 270 | 271 | print('Learning Rate:', lr) 272 | print('num bits:', args.num_bits, 'num grad bits:', args.num_grad_bits) 273 | 274 | for i, (input, target) in enumerate(train_loader): 275 | # measuring data loading time 276 | data_time.update(time.time() - end) 277 | 278 | model.train() 279 | 280 | fw_cost = args.num_bits*args.num_bits/32/32 281 | eb_cost = args.num_bits*args.num_grad_bits/32/32 282 | gc_cost = eb_cost 283 | cr.update((fw_cost+eb_cost+gc_cost)/3) 284 | 285 | target = target.squeeze().long().cuda() 286 | input_var = Variable(input).cuda() 287 | target_var = Variable(target).cuda() 288 | 289 | # compute output 290 | output = model(input_var, args.num_bits, args.num_grad_bits) 291 | loss = criterion(output, target_var) 292 | training_loss += loss.item() 293 | 294 | # measure accuracy and 
record loss 295 | prec1, = accuracy(output.data, target, topk=(1,)) 296 | losses.update(loss.item(), input.size(0)) 297 | top1.update(prec1.item(), input.size(0)) 298 | training_acc += prec1.item() 299 | 300 | # compute gradient and do SGD step 301 | optimizer.zero_grad() 302 | loss.backward() 303 | optimizer.step() 304 | 305 | # measure elapsed time 306 | batch_time.update(time.time() - end) 307 | end = time.time() 308 | 309 | # print log 310 | if i % args.print_freq == 0: 311 | logging.info("Iter: [{0}][{1}/{2}]\t" 312 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t" 313 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t" 314 | "Loss {loss.val:.3f} ({loss.avg:.3f})\t" 315 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t".format( 316 | _epoch, 317 | i, 318 | len(train_loader), 319 | batch_time=batch_time, 320 | data_time=data_time, 321 | loss=losses, 322 | top1=top1) 323 | ) 324 | 325 | epoch = _epoch + 1 326 | epoch_loss = training_loss / len(train_loader) 327 | with torch.no_grad(): 328 | prec1 = validate(args, test_loader, model, criterion, _epoch) 329 | # prec_full = validate_full_prec(args, test_loader, model, criterion, i) 330 | history_score[epoch-1][0] = epoch_loss 331 | history_score[epoch-1][1] = np.round(training_acc / len(train_loader), 2) 332 | history_score[epoch-1][2] = prec1 333 | training_loss = 0 334 | training_acc = 0 335 | 336 | np.savetxt(os.path.join(save_path, 'record.txt'), history_score, fmt = '%10.5f', delimiter=',') 337 | 338 | # apply indicator 339 | # if epoch == 1: 340 | # logging.info('initial loss value: {}'.format(epoch_loss)) 341 | # my_loss_diff_indicator.scale_loss = epoch_loss 342 | if epoch <= 10: 343 | scale_loss += epoch_loss 344 | logging.info('scale_loss at epoch {}: {}'.format(epoch, scale_loss / epoch)) 345 | my_loss_diff_indicator.scale_loss = scale_loss / epoch 346 | if turning_point_count < args.num_turning_point: 347 | my_loss_diff_indicator.get_loss(epoch_loss) 348 | flag = my_loss_diff_indicator.turning_point_emerge() 349 | if flag == True: 350 | turning_point_count += 1 351 | logging.info('find {}-th turning point at {}-th epoch'.format(turning_point_count, epoch)) 352 | # print('find {}-th turning point at {}-th epoch'.format(turning_point_count, epoch)) 353 | my_loss_diff_indicator.adaptive_threshold(turning_point_count=turning_point_count) 354 | my_loss_diff_indicator.reset() 355 | 356 | logging.info('Epoch [{}] num_bits = {} num_grad_bits = {}'.format(epoch, args.num_bits, args.num_grad_bits)) 357 | 358 | 359 | is_best = prec1 > best_prec1 360 | if is_best: 361 | best_prec1 = max(prec1, best_prec1) 362 | best_epoch = epoch 363 | #best_full_prec = max(prec_full, best_full_prec) 364 | 365 | print("Current Best Prec@1: ", best_prec1) 366 | logging.info("Current Best Epoch: {}".format(best_epoch)) 367 | #print("Current Best Full Prec@1: ", best_full_prec) 368 | 369 | checkpoint_path = os.path.join(args.save_path, 'checkpoint_{:05d}_{:.2f}.pth.tar'.format(_epoch, prec1)) 370 | save_checkpoint({ 371 | 'epoch': _epoch, 372 | 'arch': args.arch, 373 | 'state_dict': model.state_dict(), 374 | 'best_prec1': best_prec1, 375 | }, 376 | is_best, filename=checkpoint_path) 377 | shutil.copyfile(checkpoint_path, os.path.join(args.save_path, 378 | 'checkpoint_latest' 379 | '.pth.tar')) 380 | 381 | 382 | def validate(args, test_loader, model, criterion, _epoch): 383 | batch_time = AverageMeter() 384 | losses = AverageMeter() 385 | top1 = AverageMeter() 386 | 387 | # switch to evaluation mode 388 | model.eval() 389 | end = time.time() 390 | for i, 
(input, target) in enumerate(test_loader): 391 | target = target.squeeze().long().cuda() 392 | input_var = Variable(input, volatile=True).cuda() 393 | target_var = Variable(target, volatile=True).cuda() 394 | 395 | # compute output 396 | output = model(input_var, args.num_bits, args.num_grad_bits) 397 | loss = criterion(output, target_var) 398 | 399 | # measure accuracy and record loss 400 | prec1, = accuracy(output.data, target, topk=(1,)) 401 | top1.update(prec1.item(), input.size(0)) 402 | losses.update(loss.item(), input.size(0)) 403 | batch_time.update(time.time() - end) 404 | end = time.time() 405 | 406 | if (i % args.print_freq == 0) or (i == len(test_loader) - 1): 407 | logging.info( 408 | 'Test: [{}/{}]\t' 409 | 'Time: {batch_time.val:.4f}({batch_time.avg:.4f})\t' 410 | 'Loss: {loss.val:.3f}({loss.avg:.3f})\t' 411 | 'Prec@1: {top1.val:.3f}({top1.avg:.3f})\t'.format( 412 | i, len(test_loader), batch_time=batch_time, 413 | loss=losses, top1=top1 414 | ) 415 | ) 416 | 417 | logging.info('Epoch {} * Prec@1 {top1.avg:.3f}'.format(_epoch, top1=top1)) 418 | return top1.avg 419 | 420 | 421 | def validate_full_prec(args, test_loader, model, criterion, _epoch): 422 | batch_time = AverageMeter() 423 | losses = AverageMeter() 424 | top1 = AverageMeter() 425 | 426 | # switch to evaluation mode 427 | model.eval() 428 | end = time.time() 429 | for i, (input, target) in enumerate(test_loader): 430 | target = target.squeeze().long().cuda() 431 | input_var = Variable(input, volatile=True).cuda() 432 | target_var = Variable(target, volatile=True).cuda() 433 | 434 | # compute output 435 | output = model(input_var, 0, 0) 436 | loss = criterion(output, target_var) 437 | 438 | # measure accuracy and record loss 439 | prec1, = accuracy(output.data, target, topk=(1,)) 440 | top1.update(prec1.item(), input.size(0)) 441 | losses.update(loss.item(), input.size(0)) 442 | batch_time.update(time.time() - end) 443 | end = time.time() 444 | 445 | 446 | logging.info('Epoch {} * Full Prec@1 {top1.avg:.3f}'.format(_epoch, top1=top1)) 447 | return top1.avg 448 | 449 | 450 | def test_model(args): 451 | # create model 452 | model = models.__dict__[args.arch](args.pretrained) 453 | model = torch.nn.DataParallel(model).cuda() 454 | 455 | if args.resume: 456 | if os.path.isfile(args.resume): 457 | logging.info('=> loading checkpoint `{}`'.format(args.resume)) 458 | checkpoint = torch.load(args.resume) 459 | args.start_epoch = checkpoint['epoch'] 460 | best_prec1 = checkpoint['best_prec1'] 461 | model.load_state_dict(checkpoint['state_dict']) 462 | logging.info('=> loaded checkpoint `{}` (epoch: {})'.format( 463 | args.resume, checkpoint['epoch'] 464 | )) 465 | else: 466 | logging.info('=> no checkpoint found at `{}`'.format(args.resume)) 467 | 468 | cudnn.benchmark = False 469 | test_loader = prepare_test_data(dataset=args.dataset, datadir=args.datadir+'/val',  # use the same val split as run_training 470 | batch_size=args.batch_size, 471 | shuffle=False, 472 | num_workers=args.workers) 473 | criterion = nn.CrossEntropyLoss().cuda() 474 | 475 | # validate(args, test_loader, model, criterion) 476 | 477 | with torch.no_grad(): 478 | prec1 = validate(args, test_loader, model, criterion, args.start_epoch) 479 | # prec_full = validate_full_prec(args, test_loader, model, criterion, args.start_epoch) 480 | 481 | 482 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 483 | torch.save(state, filename) 484 | if is_best: 485 | save_path = os.path.dirname(filename) 486 | shutil.copyfile(filename, os.path.join(save_path, 487 | 'model_best.pth.tar')) 488 | 489 | 490 | class
AverageMeter(object): 491 | """Computes and stores the average and current value""" 492 | 493 | def __init__(self): 494 | self.reset() 495 | 496 | def reset(self): 497 | self.val = 0 498 | self.avg = 0 499 | self.sum = 0 500 | self.count = 0 501 | 502 | def update(self, val, n=1): 503 | self.val = val 504 | self.sum += val * n 505 | self.count += n 506 | self.avg = self.sum / self.count 507 | 508 | 509 | schedule_cnt = 0 510 | def adjust_precision(args, _epoch): 511 | if args.schedule: 512 | global schedule_cnt 513 | 514 | assert len(args.num_bits_schedule) == len(args.schedule) + 1 515 | assert len(args.num_grad_bits_schedule) == len(args.schedule) + 1 516 | 517 | if schedule_cnt == 0: 518 | args.num_bits = args.num_bits_schedule[0] 519 | args.num_grad_bits = args.num_grad_bits_schedule[0] 520 | schedule_cnt += 1 521 | 522 | for step in args.schedule: 523 | if _epoch == step: 524 | args.num_bits = args.num_bits_schedule[schedule_cnt] 525 | args.num_grad_bits = args.num_grad_bits_schedule[schedule_cnt] 526 | schedule_cnt += 1 527 | 528 | def adaptive_adjust_precision(args, turning_point_count): 529 | args.num_bits = args.num_bits_schedule[turning_point_count] 530 | args.num_grad_bits = args.num_grad_bits_schedule[turning_point_count] 531 | 532 | def adjust_learning_rate(args, optimizer, _epoch): 533 | lr = args.lr * (0.1 ** (_epoch // 30)) 534 | 535 | for param_group in optimizer.param_groups: 536 | param_group['lr'] = lr 537 | 538 | return lr 539 | 540 | 541 | def accuracy(output, target, topk=(1,)): 542 | """Computes the precision@k for the specified values of k""" 543 | maxk = max(topk) 544 | batch_size = target.size(0) 545 | 546 | _, pred = output.topk(maxk, 1, True, True) 547 | pred = pred.t() 548 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 549 | 550 | res = [] 551 | for k in topk: 552 | correct_k = correct[:k].view(-1).float().sum(0) 553 | res.append(correct_k.mul_(100.0 / batch_size)) 554 | return res 555 | 556 | 557 | if __name__ == '__main__': 558 | main() 559 | -------------------------------------------------------------------------------- /img/DFQ.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/img/DFQ.png -------------------------------------------------------------------------------- /img/PFQ.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/img/PFQ.png -------------------------------------------------------------------------------- /img/dfq_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/img/dfq_result.png -------------------------------------------------------------------------------- /img/fractrain_result_cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/img/fractrain_result_cifar.png -------------------------------------------------------------------------------- /img/fractrain_result_imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/img/fractrain_result_imagenet.png 
-------------------------------------------------------------------------------- /img/pfq_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/img/pfq_result.png --------------------------------------------------------------------------------
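Taken together, train_pfq.py implements a simple control loop: track the per-epoch training loss, declare a turning point once the last few losses plateau, and advance both bit-width schedules one step per turning point. The following minimal, self-contained sketch condenses that loop in one place; it is illustrative only, not the repository's API: the LossDiffIndicator class, the four-stage schedules, and the synthetic loss curve are stand-ins. It also prints the per-step compute ratio that run_training tracks (forward, error-backward, and gradient-computation costs averaged against full precision).

# Minimal sketch of PFQ's adaptive precision schedule (illustrative stand-in,
# not part of the FracTrain repository).

class LossDiffIndicator:
    """Condensed version of train_pfq.py's loss_diff_indicator: keep the last
    `epoch_keep` epoch losses and report a turning point once every normalized
    difference against the latest loss falls below `threshold`."""

    def __init__(self, threshold=0.05, decay=0.3, epoch_keep=5):
        self.threshold = threshold
        self.decay = decay
        self.epoch_keep = epoch_keep
        self.loss = []
        self.scale_loss = 1.0

    def update(self, epoch_loss):
        self.loss.append(epoch_loss)
        self.loss = self.loss[-self.epoch_keep:]  # keep a sliding window

    def turning_point(self):
        if len(self.loss) < self.epoch_keep:
            return False
        diffs = [abs(past - self.loss[-1]) / self.scale_loss
                 for past in self.loss[:-1]]
        return max(diffs) <= self.threshold  # plateau: all diffs small


# Hypothetical 4-stage schedule: 3 turning points walk precision 3->4->6->8 bits.
num_bits_schedule = [3, 4, 6, 8]
num_grad_bits_schedule = [6, 6, 8, 8]
synthetic_losses = [2.3, 1.6, 1.3, 1.15, 1.10, 1.08, 1.07, 1.06, 1.05]

indicator = LossDiffIndicator()
stage = 0
for epoch, epoch_loss in enumerate(synthetic_losses):
    # train_pfq.py normalizes by the average loss of the first 10 epochs;
    # the first epoch's loss is a crude stand-in here.
    if epoch == 0:
        indicator.scale_loss = epoch_loss
    indicator.update(epoch_loss)
    if stage < len(num_bits_schedule) - 1 and indicator.turning_point():
        stage += 1
        indicator.threshold *= indicator.decay  # cf. adaptive_threshold
        indicator.loss = []                     # reset, as in train_pfq.py
    num_bits = num_bits_schedule[stage]
    num_grad_bits = num_grad_bits_schedule[stage]
    # Per-step compute ratio vs. FP32, mirroring run_training:
    # fw = b*b/32/32, eb = gc = b*g/32/32, averaged over the three passes.
    cr = (num_bits ** 2 + 2 * num_bits * num_grad_bits) / (3 * 32.0 * 32.0)
    print('epoch %d: num_bits=%d, num_grad_bits=%d, compute ratio=%.4f'
          % (epoch, num_bits, num_grad_bits, cr))

On this short synthetic curve only the first turning point fires (around epoch 7, when all windowed loss differences drop below the initial 0.05 threshold), bumping precision from 3 to 4 bits; in a real run the decayed threshold would gate the later stages the same way.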