├── LICENSE
├── README.md
├── env.yml
├── fractrain_cifar
│   ├── compute_flops.py
│   ├── data.py
│   ├── model_info_mbv2.npy
│   ├── models.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── bwn.py
│   │   ├── quantize.py
│   │   └── rnlu.py
│   ├── train_base.py
│   ├── train_dfq.py
│   ├── train_frac.py
│   ├── train_pfq.py
│   └── util_swa.py
├── fractrain_imagenet
│   ├── data.py
│   ├── models.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── bwn.py
│   │   ├── quantize.py
│   │   └── rnlu.py
│   ├── train_base.py
│   ├── train_dfq.py
│   ├── train_frac.py
│   └── train_pfq.py
└── img
    ├── DFQ.png
    ├── PFQ.png
    ├── dfq_result.png
    ├── fractrain_result_cifar.png
    ├── fractrain_result_imagenet.png
    └── pfq_result.png
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 RICE-EIC
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FracTrain: Fractionally Squeezing Bit Savings Both Temporally and Spatially for Efficient DNN Training
2 | ***Yonggan Fu***, Haoran You, Yang Zhao, Yue Wang, Chaojian Li, Kailash Gopalakrishnan, Zhangyang Wang, Yingyan Lin
3 |
4 | Accepted at NeurIPS 2020 [[Paper Link]](https://papers.nips.cc/paper/2020/file/8dc5983b8c4ef1d8fcd5f325f9a65511-Paper.pdf).
5 |
6 | ## Overview
7 | As reducing precision is one of the most effective knobs for boosting DNN training time/energy efficiency, there has been a growing interest in low-precision DNN
8 | training. In this paper, we propose the ***FracTrain*** framework, which explores an orthogonal direction: how to fractionally squeeze out more training cost savings from the most redundant bit level, progressively along the training trajectory and dynamically per input.
9 |
10 |
11 | ## Method
12 | We propose ***FracTrain***, which integrates *(i)* progressive fractional quantization (PFQ), which gradually increases the precision of activations, weights, and gradients, reaching the precision of SOTA static quantized DNN training only in the final training stage, and *(ii)* dynamic fractional quantization (DFQ), which assigns precisions to both the activations and gradients of each layer in an input-adaptive manner, for only "fractionally" updating layer parameters.
13 |
14 | ### Progressive Fractional Quantization (PFQ)
15 |
16 | <p align="center">
17 |   <img src="img/PFQ.png" width="80%">
18 | </p>
19 |
20 | PFQ progressively increases the precision of both the forward and backward passes during DNN training, controlled by an automated indicator that gradually shifts the focus from solution-space exploration (enabled by low precision) to accurate updates for better convergence (enabled by high precision).
21 |
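Below is a minimal, hypothetical sketch of this indicator logic (the exact rule is Alg. 1 in the paper; epsilon and alpha correspond to the `--initial_threshold` and `--decay` flags used later in this README):

```
# Hypothetical sketch of PFQ's indicator, NOT the repo's exact Alg. 1:
# advance to the next precision stage once the relative loss improvement
# falls below a threshold (epsilon), then decay the threshold (alpha).
bits_schedule      = [3, 4, 6, 8]   # --num_bits_schedule
grad_bits_schedule = [6, 6, 8, 8]   # --num_grad_bits_schedule
threshold, decay   = 0.05, 0.3      # --initial_threshold, --decay
stage = 0

def update_precision(prev_loss, curr_loss):
    global stage, threshold
    if stage < len(bits_schedule) - 1 and (prev_loss - curr_loss) < threshold * prev_loss:
        stage += 1                  # switch precision at a "turning point"
        threshold *= decay
    return bits_schedule[stage], grad_bits_schedule[stage]
```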
22 | ### Dynamic Fractional Quantization (DFQ)
23 |
24 | <p align="center">
25 |   <img src="img/DFQ.png" width="80%">
26 | </p>
27 |
28 | DFQ dynamically selects the precision for each layer's activation and gradient during training in an input-adaptive manner, controlled by a lightweight LSTM-based gate function.
29 |
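As a rough illustration, such a gate can be sketched as follows (a hypothetical sketch only; the repo's actual gate lives inside its `rnn_gate` model variants, e.g., `cifar100_rnn_gate_74`):

```
import torch.nn as nn
import torch.nn.functional as F

class PrecisionGate(nn.Module):
    """Hypothetical sketch: pick a per-input bit-width from pooled features."""
    def __init__(self, in_channels, hidden=16, candidates=(0, 4, 6, 8)):
        super().__init__()
        self.candidates = candidates
        self.rnn = nn.LSTM(in_channels, hidden, batch_first=True)
        self.fc = nn.Linear(hidden, len(candidates))

    def forward(self, x, state=None):
        feat = F.adaptive_avg_pool2d(x, 1).flatten(1).unsqueeze(1)  # (B, 1, C)
        out, state = self.rnn(feat, state)         # carry state across layers
        logits = self.fc(out.squeeze(1))           # (B, num_candidates)
        idx = logits.argmax(dim=-1)                # hard choice; training typically uses a soft relaxation
        return [self.candidates[i] for i in idx.tolist()], state
```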
30 | ### The FracTrain framework
31 |
32 | ***FracTrain*** integrates PFQ and DFQ in a simple yet effective way: it applies PFQ's indicator to schedule DFQ's compression ratio across different training stages.
33 |
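A minimal, hypothetical sketch of that scheduling (variable names mirror the flags of `train_frac.py`; the exact indicator is Alg. 1 in the paper):

```
# Hypothetical sketch: PFQ's loss-based indicator advances the training stage,
# and each new stage raises DFQ's target compression ratio cp by a fixed step.
target_ratio, target_ratio_step = 1.5, 0.5   # --target_ratio, --target_ratio_step
threshold, decay = 0.05, 0.3                 # --initial_threshold, --decay

def on_turning_point():
    """Called when the indicator detects a loss plateau (a turning point)."""
    global target_ratio, threshold
    target_ratio += target_ratio_step        # spend more compute in later stages
    threshold *= decay
```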
34 |
35 | ## Evaluation
36 | We evaluate ***FracTrain*** on six models & four datasets (i.e., ResNet-38/74/MobileNetV2 on CIFAR-10/100, ResNet-18/34 on ImageNet, and Transformer on WikiText-103). Representative experiments are shown below; please refer to [our paper](https://papers.nips.cc/paper/2020/file/8dc5983b8c4ef1d8fcd5f325f9a65511-Paper.pdf) for more results.
37 |
38 |
39 | ### Evaluating PFQ
40 |
41 | <p align="center">
42 |   <img src="img/pfq_result.png" width="80%">
43 | </p>
44 |
45 |
46 | ### Evaluating DFQ
47 |
48 | <p align="center">
49 |   <img src="img/dfq_result.png" width="80%">
50 | </p>
51 |
52 |
53 |
54 | ### Evaluating FracTrain integrating PFQ and DFQ
55 |
56 | - ***FracTrain*** on CIFAR-10/100
57 |
58 | <p align="center">
59 |   <img src="img/fractrain_result_cifar.png" width="80%">
60 | </p>
61 |
62 | - ***FracTrain*** on ImageNet
63 |
64 | <p align="center">
65 |   <img src="img/fractrain_result_imagenet.png" width="80%">
66 | </p>
67 |
68 | ## Code Usage
69 | Our code is inspired by [SkipNet](https://github.com/ucbdrive/skipnet), and we follow its training settings for CIFAR-10/100 and ImageNet.
70 |
71 | ### Prerequisites
72 | See `env.yml` for the complete conda environment. Create a new conda environment:
73 | ```
74 | conda env create -f env.yml
75 | conda activate pytorch
76 | ```
77 |
78 | ### Overview
79 | `fractrain_cifar` and `fractrain_imagenet` contain the code customized for CIFAR-10/100 and ImageNet, respectively, and share a similar code structure. In particular, `train_pfq.py`, `train_dfq.py`, and `train_frac.py` are the training scripts for training the target network with PFQ, DFQ, and ***FracTrain***, respectively.
80 | Below we use `fractrain_cifar` to demonstrate the training commands.
81 |
82 | ### Training with PFQ
83 | In addition to the commonly used args, e.g., the target network, dataset, and data path via `--arch`, `--dataset`, and `--datadir`, respectively, you also need to: (1) specify the precision schedules for forward and backward with `--num_bits_schedule` and `--num_grad_bits_schedule`, respectively; (2) specify the epsilon and alpha of Alg. 1 in [our paper](https://papers.nips.cc/paper/2020/file/8dc5983b8c4ef1d8fcd5f325f9a65511-Paper.pdf) with `--initial_threshold` and `--decay`, respectively, as well as the number of turning points via `--num_turning_point`.
84 |
85 | - Example: Training ResNet-74 on CIFAR-100 with PFQ
86 | ```
87 | cd fractrain_cifar
88 | python train_pfq.py --save_folder ./logs/ --arch cifar100_resnet_74 --workers 4 --dataset cifar100 --datadir path-to-cifar100 --num_bits_schedule 3 4 6 8 --num_grad_bits_schedule 6 6 8 8 --num_turning_point 3 --initial_threshold 0.05 --decay 0.3
89 | ```
90 |
91 | ### Training with DFQ
92 | Specify the weight precision with `--weight_bits` and the target *cp* as defined in Sec. 3.2 of [our paper](https://papers.nips.cc/paper/2020/file/8dc5983b8c4ef1d8fcd5f325f9a65511-Paper.pdf) with `--target_ratio`.
93 |
94 | - Example: Training ResNet-74 on CIFAR-100 with DFQ
95 | ```
96 | cd fractrain_cifar
97 | python train_dfq.py --save_folder ./logs --arch cifar100_rnn_gate_74 --workers 4 --dataset cifar100 --datadir path-to-cifar100 --weight_bits 4 --target_ratio 3
98 | ```
99 |
100 | ### Training with FracTrain
101 | Specify the step (increment) of *cp* between two consecutive training stages with `--target_ratio_step`; the other args follow those of PFQ and DFQ.
102 |
103 | - Example: Training ResNet-74 on CIFAR-100 with ***FracTrain***
104 | ```
105 | cd fractrain_cifar
106 | python train_frac.py --save_folder ./logs --arch cifar100_rnn_gate_74 --workers 4 --dataset cifar100 --datadir path-to-cifar100 --weight_bits 4 --target_ratio 1.5 --target_ratio_step 0.5 --initial_threshold 0.05 --decay 0.3
107 | ```
108 |
109 | Similarly, for training on ImageNet:
110 |
111 | - Example: Training ResNet-18 on ImageNet with ***FracTrain***
112 |
113 | ```
114 | cd fractrain_imagenet
115 | python train_frac.py --save_folder ./logs --arch resnet18_rnn --workers 32 --dataset imagenet --datadir path-to-imagenet --weight_bits 6 --target_ratio 3 --target_ratio_step 0.5 --initial_threshold 0.05 --decay 0.3
116 | ```
117 |
118 |
119 | ## Citation
120 | ```
121 | @article{fu2020fractrain,
122 | title={FracTrain: Fractionally Squeezing Bit Savings Both Temporally and Spatially for Efficient DNN Training},
123 | author={Fu, Yonggan and You, Haoran and Zhao, Yang and Wang, Yue and Li, Chaojian and Gopalakrishnan, Kailash and Wang, Zhangyang and Lin, Yingyan},
124 | journal={Advances in Neural Information Processing Systems},
125 | volume={33},
126 | year={2020}
127 | }
128 | ```
129 |
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: pytorch
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - anaconda
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - _pytorch_select=0.2=gpu_0
10 | - backcall=0.1.0=py36_0
11 | - blas=1.0=mkl
12 | - ca-certificates=2020.10.14=0
13 | - certifi=2020.6.20=py36_0
14 | - cffi=1.12.3=py36h2e261b9_0
15 | - cudatoolkit=10.2.89=hfd86e86_1
16 | - cudnn=7.6.5=cuda10.2_0
17 | - decorator=4.4.1=py_0
18 | - freetype=2.8=hab7d2ae_1
19 | - intel-openmp=2019.4=243
20 | - ipykernel=5.1.3=py36h39e3cac_0
21 | - ipython=7.9.0=py36h39e3cac_0
22 | - ipython_genutils=0.2.0=py36_0
23 | - jedi=0.15.1=py36_0
24 | - jpeg=9b=h024ee3a_2
25 | - jupyter_client=5.3.4=py36_0
26 | - jupyter_core=4.6.1=py36_0
27 | - libedit=3.1.20170329=hf8c457e_1001
28 | - libevent=2.0.22=hb7f436b_1002
29 | - libffi=3.2.1=hd88cf55_4
30 | - libgcc-ng=8.2.0=hdf63c60_1
31 | - libgfortran-ng=7.3.0=hdf63c60_0
32 | - libpng=1.6.37=hbc83047_0
33 | - libsodium=1.0.16=h1bed415_0
34 | - libstdcxx-ng=8.2.0=hdf63c60_1
35 | - libtiff=4.0.10=h2733197_2
36 | - mkl=2019.4=243
37 | - mkl-service=2.3.0=py36he904b0f_0
38 | - mkl_fft=1.0.12=py36ha843d7b_0
39 | - mkl_random=1.0.2=py36hd81dba3_0
40 | - mpi=1.0=mpich
41 | - mpi4py=3.0.3=py36h028fd6f_0
42 | - mpich=3.3.2=hc856adb_0
43 | - ncurses=6.1=hf484d3e_1002
44 | - ninja=1.9.0=py36hfd86e86_0
45 | - numpy=1.16.4=py36h7e9f1db_0
46 | - numpy-base=1.16.4=py36hde5b4d6_0
47 | - olefile=0.46=py36_0
48 | - openssl=1.0.2u=h7b6447c_0
49 | - pandas=0.24.2=py36he6710b0_0
50 | - parso=0.5.1=py_0
51 | - pexpect=4.7.0=py36_0
52 | - pickleshare=0.7.5=py36_0
53 | - pip=19.1.1=py36_0
54 | - prompt_toolkit=2.0.10=py_0
55 | - ptyprocess=0.6.0=py36_0
56 | - pycparser=2.19=py36_0
57 | - pygments=2.4.2=py_0
58 | - python=3.6.5=hc3d631a_2
59 | - python-dateutil=2.8.0=py36_0
60 | - pytorch=1.6.0=py3.6_cuda10.2.89_cudnn7.6.5_0
61 | - pytz=2019.1=py_0
62 | - pyzmq=18.1.0=py36he6710b0_0
63 | - readline=7.0=ha6073c6_4
64 | - setuptools=41.0.1=py36_0
65 | - sqlite=3.23.1=he433501_0
66 | - tk=8.6.8=hbc83047_0
67 | - tmux=2.9=h45300e9_0
68 | - torchvision=0.7.0=py36_cu102
69 | - tornado=6.0.3=py36h7b6447c_0
70 | - traitlets=4.3.3=py36_0
71 | - wcwidth=0.1.7=py36_0
72 | - wheel=0.33.4=py36_0
73 | - xz=5.2.4=h14c3975_4
74 | - zeromq=4.3.1=he6710b0_3
75 | - zlib=1.2.11=h7b6447c_3
76 | - zstd=1.3.7=h0b5b093_0
77 | - pip:
78 | - absl-py==0.9.0
79 | - cached-property==1.5.2
80 | - cachetools==4.0.0
81 | - chardet==3.0.4
82 | - click==7.1.2
83 | - configparser==5.0.1
84 | - cycler==0.10.0
85 | - cython==0.29.15
86 | - dnspython==2.0.0
87 | - docker-pycreds==0.4.0
88 | - dominate==2.4.0
89 | - easydict==1.9
90 | - efficientnet-pytorch==0.5.1
91 | - gitdb==4.0.5
92 | - gitpython==3.1.11
93 | - google-auth==1.10.0
94 | - google-auth-oauthlib==0.4.1
95 | - grpcio==1.26.0
96 | - h5py==3.1.0
97 | - idna==2.8
98 | - image-tools==1.0.0
99 | - joblib==0.13.2
100 | - jsonpatch==1.25
101 | - jsonpointer==2.0
102 | - kiwisolver==1.1.0
103 | - kornia==0.2.0
104 | - lmdb==0.98
105 | - markdown==3.1.1
106 | - matplotlib==3.1.0
107 | - mpmath==1.1.0
108 | - oauthlib==3.1.0
109 | - opencv-python==4.1.2.30
110 | - paho-mqtt==1.5.1
111 | - pillow==6.1.0
112 | - promise==2.3
113 | - protobuf==3.14.0
114 | - psutil==5.7.3
115 | - pyasn1==0.4.8
116 | - pyasn1-modules==0.2.8
117 | - pycocotools==2.0.0
118 | - pyparsing==2.4.0
119 | - python-etcd==0.4.5
120 | - python-graphviz==0.13.2
121 | - pyyaml==5.3
122 | - requests==2.22.0
123 | - requests-oauthlib==1.3.0
124 | - rsa==4.0
125 | - scikit-learn==0.22.2.post1
126 | - scipy==1.1.0
127 | - sentry-sdk==0.19.5
128 | - setproctitle==1.2.1
129 | - shortuuid==1.0.1
130 | - six==1.15.0
131 | - sklearn==0.0
132 | - smmap==3.0.4
133 | - subprocess32==3.5.4
134 | - sympy==1.5.1
135 | - tabulate==0.8.3
136 | - tensorboard==2.1.0
137 | - tensorboardx==2.0
138 | - thop==0.0.31-1912272122
139 | - torchcontrib==0.0.2
140 | - torchfile==0.1.0
141 | - tqdm==4.41.1
142 | - urllib3==1.25.7
143 | - visdom==0.1.8.9
144 | - wandb==0.10.12
145 | - watchdog==1.0.1
146 | - websocket-client==0.57.0
147 | - werkzeug==0.16.0
148 | - yml==0.0.1
149 | prefix: /home/yf22/anaconda3/envs/pytorch
150 |
151 |
--------------------------------------------------------------------------------
/fractrain_cifar/compute_flops.py:
--------------------------------------------------------------------------------
1 | # Code from https://github.com/simochen/model-tools.
2 | import numpy as np
3 | import os
4 |
5 | import torch
6 | import torchvision
7 | import torch.nn as nn
8 | from torch.autograd import Variable
9 |
10 | from models_base import cifar10_resnet_38, cifar10_resnet_74, cifar100_resnet_38, cifar100_resnet_74
11 |
12 | def print_model_param_nums(model=None):
13 |     if model is None:
14 | model = torchvision.models.alexnet()
15 | total = sum([param.nelement() if param.requires_grad else 0 for param in model.parameters()])
16 | print(' + Number of params: %.4fM' % (total / 1e6))
17 |
18 | def count_model_param_flops(model=None, input_res=32, multiply_adds=True):
19 |
20 | prods = {}
21 | def save_hook(name):
22 | def hook_per(self, input, output):
23 | prods[name] = np.prod(input[0].shape)
24 | return hook_per
25 |
26 | list_1=[]
27 | def simple_hook(self, input, output):
28 | list_1.append(np.prod(input[0].shape))
29 | list_2={}
30 | def simple_hook2(self, input, output):
31 | list_2['names'] = np.prod(input[0].shape)
32 |
33 |
34 | list_conv=[]
35 | def conv_hook(self, input, output):
36 | batch_size, input_channels, input_height, input_width = input[0].size()
37 | output_channels, output_height, output_width = output[0].size()
38 |
39 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups)
40 | bias_ops = 1 if self.bias is not None else 0
41 |
42 | params = output_channels * (kernel_ops + bias_ops)
43 | # flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size
44 |
45 | num_weight_params = (self.weight.data != 0).float().sum()
46 | flops = (num_weight_params * (2 if multiply_adds else 1) + bias_ops * output_channels) * output_height * output_width * batch_size
47 |
48 | list_conv.append(flops)
49 |
50 | list_linear=[]
51 | def linear_hook(self, input, output):
52 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1
53 |
54 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1)
55 | bias_ops = self.bias.nelement()
56 |
57 | flops = batch_size * (weight_ops + bias_ops)
58 | list_linear.append(flops)
59 |
60 | list_bn=[]
61 | def bn_hook(self, input, output):
62 | list_bn.append(input[0].nelement() * 2)
63 |
64 | list_relu=[]
65 | def relu_hook(self, input, output):
66 | list_relu.append(input[0].nelement())
67 |
68 | list_pooling=[]
69 | def pooling_hook(self, input, output):
70 | batch_size, input_channels, input_height, input_width = input[0].size()
71 | output_channels, output_height, output_width = output[0].size()
72 |
73 | kernel_ops = self.kernel_size * self.kernel_size
74 | bias_ops = 0
75 | params = 0
76 | flops = (kernel_ops + bias_ops) * output_channels * output_height * output_width * batch_size
77 |
78 | list_pooling.append(flops)
79 |
80 | list_upsample=[]
81 |
82 | # For bilinear upsample
83 | def upsample_hook(self, input, output):
84 | batch_size, input_channels, input_height, input_width = input[0].size()
85 | output_channels, output_height, output_width = output[0].size()
86 |
87 | flops = output_height * output_width * output_channels * batch_size * 12
88 | list_upsample.append(flops)
89 |
90 | def foo(net):
91 | childrens = list(net.children())
92 | if not childrens:
93 | if isinstance(net, torch.nn.Conv2d):
94 | net.register_forward_hook(conv_hook)
95 | if isinstance(net, torch.nn.Linear):
96 | net.register_forward_hook(linear_hook)
97 | if isinstance(net, torch.nn.BatchNorm2d):
98 | net.register_forward_hook(bn_hook)
99 | if isinstance(net, torch.nn.ReLU):
100 | net.register_forward_hook(relu_hook)
101 | if isinstance(net, torch.nn.MaxPool2d) or isinstance(net, torch.nn.AvgPool2d):
102 | net.register_forward_hook(pooling_hook)
103 | if isinstance(net, torch.nn.Upsample):
104 | net.register_forward_hook(upsample_hook)
105 | return
106 | for c in childrens:
107 | foo(c)
108 |
109 |     if model is None:
110 | model = torchvision.models.alexnet()
111 | foo(model)
112 | input = Variable(torch.rand(3,input_res,input_res).unsqueeze(0), requires_grad = True)
113 | out = model(input)
114 |
115 |
116 | total_flops = (sum(list_conv) + sum(list_linear) + sum(list_bn) + sum(list_relu) + sum(list_pooling) + sum(list_upsample))
117 |
118 | print(' + Number of FLOPs: %.2fG' % (total_flops / 1e9))
119 |
120 | return total_flops
121 |
122 |
123 | if __name__ == '__main__':
124 | from models import *
125 | model = cifar10_mobilenet_v2()
126 |
127 | blocks = []
128 |
129 | for module in model.modules():
130 | if isinstance(module, Block):
131 | blocks.append(module)
132 |
133 | conv_list = []
134 | def conv_hook(self, input, output):
135 | batch_size, input_channels, input_height, input_width = input[0].size()
136 | output_channels, output_height, output_width = output[0].size()
137 |
138 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups)
139 | bias_ops = 1 if self.bias is not None else 0
140 |
141 | params = output_channels * (kernel_ops + bias_ops)
142 | # flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size
143 |
144 | num_weight_params = (self.weight.data != 0).float().sum()
145 | flops = (num_weight_params * 2 + bias_ops * output_channels) * output_height * output_width * batch_size
146 |
147 | conv_list.append(flops)
148 |
149 | dws_list = []
150 | def dws_hook(self, input, output):
151 | batch_size, input_channels, input_height, input_width = input[0].size()
152 | output_channels, output_height, output_width = output[0].size()
153 |
154 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups)
155 | bias_ops = 1 if self.bias is not None else 0
156 |
157 | params = output_channels * (kernel_ops + bias_ops)
158 | # flops = (kernel_ops * (2 if multiply_adds else 1) + bias_ops) * output_channels * output_height * output_width * batch_size
159 |
160 | num_weight_params = (self.weight.data != 0).float().sum()
161 | flops = (num_weight_params * 2 + bias_ops * output_channels) * output_height * output_width * batch_size
162 |
163 | dws_list.append(flops)
164 |
165 | block_flops_list = []
166 | dws_list = []
167 | hook_list = []
168 | for block in blocks:
169 | for module in block.modules():
170 | if isinstance(module, torch.nn.Conv2d):
171 | if module.groups == 1:
172 | hook = module.register_forward_hook(conv_hook)
173 | hook_list.append(hook)
174 | else:
175 | hook = module.register_forward_hook(dws_hook)
176 | hook_list.append(hook)
177 |
178 | input = Variable(torch.rand(3, 32, 32).unsqueeze(0), requires_grad = True)
179 | out = model(input, 0, 0)
180 |
181 | block_flops_list.append(sum(conv_list))
182 |
183 | for hook in hook_list:
184 | hook.remove()
185 |
186 | conv_list = []
187 |
188 | print(block_flops_list)
189 | print(dws_list)
190 |
191 | model_info = {'conv':block_flops_list, 'dws':dws_list}
192 |
193 | np.save('model_info_mbv2.npy', model_info)
194 |
195 |
--------------------------------------------------------------------------------
/fractrain_cifar/data.py:
--------------------------------------------------------------------------------
1 | """prepare CIFAR and SVHN
2 | """
3 |
4 | from __future__ import print_function
5 |
6 | import torch
7 | import torchvision
8 | import torchvision.transforms as transforms
9 | import numpy as np
10 |
11 |
12 | crop_size = 32
13 | padding = 4
14 |
15 |
16 | def prepare_train_data(dataset='cifar10', datadir='/home/yf22/dataset', batch_size=128,
17 | shuffle=True, num_workers=4):
18 |
19 | if 'cifar' in dataset:
20 | transform_train = transforms.Compose([
21 | transforms.RandomCrop(crop_size, padding=padding),
22 | transforms.RandomHorizontalFlip(),
23 | transforms.ToTensor(),
24 | transforms.Normalize((0.4914, 0.4822, 0.4465),
25 | (0.2023, 0.1994, 0.2010)),
26 | ])
27 |
28 | trainset = torchvision.datasets.__dict__[dataset.upper()](
29 | root=datadir, train=True, download=True, transform=transform_train)
30 | train_loader = torch.utils.data.DataLoader(trainset,
31 | batch_size=batch_size,
32 | shuffle=shuffle,
33 | num_workers=num_workers)
34 | elif 'svhn' in dataset:
35 | transform_train =transforms.Compose([
36 | transforms.ToTensor(),
37 | transforms.Normalize((0.4377, 0.4438, 0.4728),
38 | (0.1980, 0.2010, 0.1970)),
39 | ])
40 | trainset = torchvision.datasets.__dict__[dataset.upper()](
41 | root=datadir,
42 | split='train',
43 | download=True,
44 | transform=transform_train
45 | )
46 |
47 | transform_extra = transforms.Compose([
48 | transforms.ToTensor(),
49 | transforms.Normalize((0.4300, 0.4284, 0.4427),
50 | (0.1963, 0.1979, 0.1995))
51 |
52 | ])
53 |
54 | extraset = torchvision.datasets.__dict__[dataset.upper()](
55 | root=datadir,
56 | split='extra',
57 | download=True,
58 | transform = transform_extra
59 | )
60 |
61 | total_data = torch.utils.data.ConcatDataset([trainset, extraset])
62 |
63 | train_loader = torch.utils.data.DataLoader(total_data,
64 | batch_size=batch_size,
65 | shuffle=shuffle,
66 | num_workers=num_workers)
67 | else:
68 | train_loader = None
69 | return train_loader
70 |
71 |
72 | def prepare_test_data(dataset='cifar10', datadir='/home/yf22/dataset', batch_size=128,
73 | shuffle=False, num_workers=4):
74 |
75 | if 'cifar' in dataset:
76 | transform_test = transforms.Compose([
77 | transforms.ToTensor(),
78 | transforms.Normalize((0.4914, 0.4822, 0.4465),
79 | (0.2023, 0.1994, 0.2010)),
80 | ])
81 |
82 | testset = torchvision.datasets.__dict__[dataset.upper()](root=datadir,
83 | train=False,
84 | download=True,
85 | transform=transform_test)
86 | test_loader = torch.utils.data.DataLoader(testset,
87 | batch_size=batch_size,
88 | shuffle=shuffle,
89 | num_workers=num_workers)
90 | elif 'svhn' in dataset:
91 | transform_test = transforms.Compose([
92 | transforms.ToTensor(),
93 | transforms.Normalize((0.4524, 0.4525, 0.4690),
94 | (0.2194, 0.2266, 0.2285)),
95 | ])
96 | testset = torchvision.datasets.__dict__[dataset.upper()](
97 | root=datadir,
98 | split='test',
99 | download=True,
100 | transform=transform_test)
101 | np.place(testset.labels, testset.labels == 10, 0)
102 | test_loader = torch.utils.data.DataLoader(testset,
103 | batch_size=batch_size,
104 | shuffle=shuffle,
105 | num_workers=num_workers)
106 | else:
107 | test_loader = None
108 | return test_loader
109 |
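# Example usage (a sketch; the data directory is a placeholder):
#   train_loader = prepare_train_data('cifar10', datadir='./data', batch_size=128)
#   test_loader = prepare_test_data('cifar10', datadir='./data', batch_size=128)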
--------------------------------------------------------------------------------
/fractrain_cifar/model_info_mbv2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/fractrain_cifar/model_info_mbv2.npy
--------------------------------------------------------------------------------
/fractrain_cifar/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/fractrain_cifar/modules/__init__.py
--------------------------------------------------------------------------------
/fractrain_cifar/modules/bwn.py:
--------------------------------------------------------------------------------
1 | """
2 | Bounded weight norm
3 | Weight Normalization from https://arxiv.org/abs/1602.07868
4 | taken and adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/weight_norm.py
5 | """
6 | import torch
7 | from torch.nn.parameter import Parameter
8 | from torch.autograd import Variable, Function
9 | import torch.nn as nn
10 |
11 |
12 | def gather_params(self, memo=None, param_func=lambda s: s._parameters.values()):
13 | if memo is None:
14 | memo = set()
15 | for p in param_func(self):
16 | if p is not None and p not in memo:
17 | memo.add(p)
18 | yield p
19 | for m in self.children():
20 | for p in gather_params(m, memo, param_func):
21 | yield p
22 |
23 | nn.Module.gather_params = gather_params
24 |
25 |
26 | def _norm(x, dim, p=2):
27 | """Computes the norm over all dimensions except dim"""
28 | if p == float('inf'): # infinity norm
29 | func = lambda x, dim: x.abs().max(dim=dim)[0]
30 | else:
31 | func = lambda x, dim: torch.norm(x, dim=dim, p=p)
32 | if dim is None:
33 | return x.norm(p=p)
34 | elif dim == 0:
35 | output_size = (x.size(0),) + (1,) * (x.dim() - 1)
36 | return func(x.contiguous().view(x.size(0), -1), 1).view(*output_size)
37 | elif dim == x.dim() - 1:
38 | output_size = (1,) * (x.dim() - 1) + (x.size(-1),)
39 | return func(x.contiguous().view(-1, x.size(-1)), 0).view(*output_size)
40 | else:
41 | return _norm(x.transpose(0, dim), 0).transpose(0, dim)
42 |
43 |
44 | def _mean(p, dim):
45 | """Computes the mean over all dimensions except dim"""
46 | if dim is None:
47 | return p.mean()
48 | elif dim == 0:
49 | output_size = (p.size(0),) + (1,) * (p.dim() - 1)
50 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size)
51 | elif dim == p.dim() - 1:
52 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
53 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size)
54 | else:
55 | return _mean(p.transpose(0, dim), 0).transpose(0, dim)
56 |
57 |
58 | class BoundedWeighNorm(object):
59 |
60 | def __init__(self, name, dim, p):
61 | self.name = name
62 | self.dim = dim
63 | self.p = p
64 |
65 | def compute_weight(self, module):
66 | v = getattr(module, self.name + '_v')
67 | pre_norm = getattr(module, self.name + '_prenorm')
68 | return v * (pre_norm / _norm(v, self.dim, p=self.p))
69 |
70 | @staticmethod
71 | def apply(module, name, dim, p):
72 | fn = BoundedWeighNorm(name, dim, p)
73 |
74 | weight = getattr(module, name)
75 |
76 | # remove w from parameter list
77 | del module._parameters[name]
78 |
79 | prenorm = _norm(weight, dim, p=p).mean()
80 | module.register_buffer(name + '_prenorm', prenorm.detach())
81 | pre_norm = getattr(module, name + '_prenorm')
82 | print(pre_norm)
83 | module.register_parameter(name + '_v', Parameter(weight.data))
84 | setattr(module, name, fn.compute_weight(module))
85 |
86 | # recompute weight before every forward()
87 | module.register_forward_pre_hook(fn)
88 |
89 | def gather_normed_params(self, memo=None, param_func=lambda s: fn.compute_weight(s)):
90 | return gather_params(self, memo, param_func)
91 | module.gather_params = gather_normed_params
92 | return fn
93 |
94 | def remove(self, module):
95 | weight = self.compute_weight(module)
96 | delattr(module, self.name)
97 |         del module._buffers[self.name + '_prenorm']
98 | del module._parameters[self.name + '_v']
99 | module.register_parameter(self.name, Parameter(weight.data))
100 |
101 | def __call__(self, module, inputs):
102 | setattr(module, self.name, self.compute_weight(module))
103 |
104 |
105 | def weight_norm(module, name='weight', dim=0, p=2):
106 | r"""Applies weight normalization to a parameter in the given module.
107 |
108 | .. math::
109 | \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
110 |
111 | Weight normalization is a reparameterization that decouples the magnitude
112 | of a weight tensor from its direction. This replaces the parameter specified
113 |     by `name` (e.g. "weight") with a buffer holding its pre-norm magnitude
114 |     (e.g. "weight_prenorm") and a parameter specifying the direction (e.g. "weight_v").
115 | Weight normalization is implemented via a hook that recomputes the weight
116 | tensor from the magnitude and direction before every :meth:`~Module.forward`
117 | call.
118 |
119 | By default, with `dim=0`, the norm is computed independently per output
120 | channel/plane. To compute a norm over the entire weight tensor, use
121 | `dim=None`.
122 |
123 | See https://arxiv.org/abs/1602.07868
124 |
125 | Args:
126 | module (nn.Module): containing module
127 | name (str, optional): name of weight parameter
128 | dim (int, optional): dimension over which to compute the norm
129 |
130 | Returns:
131 | The original module with the weight norm hook
132 |
133 | Example::
134 |
135 | >>> m = weight_norm(nn.Linear(20, 40), name='weight')
136 | Linear (20 -> 40)
137 |         >>> m.weight_prenorm.size()
138 |         torch.Size([])
139 | >>> m.weight_v.size()
140 | torch.Size([40, 20])
141 |
142 | """
143 | BoundedWeighNorm.apply(module, name, dim, p)
144 | return module
145 |
146 |
147 | def remove_weight_norm(module, name='weight'):
148 | r"""Removes the weight normalization reparameterization from a module.
149 |
150 | Args:
151 | module (nn.Module): containing module
152 | name (str, optional): name of weight parameter
153 |
154 | Example:
155 | >>> m = weight_norm(nn.Linear(20, 40))
156 | >>> remove_weight_norm(m)
157 | """
158 | for k, hook in module._forward_pre_hooks.items():
159 | if isinstance(hook, BoundedWeighNorm) and hook.name == name:
160 | hook.remove(module)
161 | del module._forward_pre_hooks[k]
162 | return module
163 |
164 | raise ValueError("weight_norm of '{}' not found in {}"
165 | .format(name, module))
166 |
--------------------------------------------------------------------------------
/fractrain_cifar/modules/quantize.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd.function import InplaceFunction, Function
7 |
8 | QParams = namedtuple('QParams', ['range', 'zero_point', 'num_bits'])
9 |
10 | _DEFAULT_FLATTEN = (1, -1)
11 | _DEFAULT_FLATTEN_GRAD = (0, -1)
12 |
13 |
14 | def _deflatten_as(x, x_full):
15 | shape = list(x.shape) + [1] * (x_full.dim() - x.dim())
16 | return x.view(*shape)
17 |
18 |
19 | def calculate_qparams(x, num_bits, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, reduce_type='mean', keepdim=False, true_zero=False):
20 | with torch.no_grad():
21 | x_flat = x.flatten(*flatten_dims)
22 | if x_flat.dim() == 1:
23 | min_values = _deflatten_as(x_flat.min(), x)
24 | max_values = _deflatten_as(x_flat.max(), x)
25 | else:
26 | min_values = _deflatten_as(x_flat.min(-1)[0], x)
27 | max_values = _deflatten_as(x_flat.max(-1)[0], x)
28 |
29 | if reduce_dim is not None:
30 | if reduce_type == 'mean':
31 | min_values = min_values.mean(reduce_dim, keepdim=keepdim)
32 | max_values = max_values.mean(reduce_dim, keepdim=keepdim)
33 | else:
34 | min_values = min_values.min(reduce_dim, keepdim=keepdim)[0]
35 | max_values = max_values.max(reduce_dim, keepdim=keepdim)[0]
36 |
37 | range_values = max_values - min_values
38 | return QParams(range=range_values, zero_point=min_values,
39 | num_bits=num_bits)
40 |
41 |
42 | class UniformQuantize(InplaceFunction):
43 |
44 | @staticmethod
45 | def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN,
46 | reduce_dim=0, dequantize=True, signed=False, stochastic=True, inplace=False):
47 |
48 | ctx.inplace = inplace
49 |
50 | if ctx.inplace:
51 | ctx.mark_dirty(input)
52 | output = input
53 | else:
54 | output = input.clone()
55 |
56 | if qparams is None:
57 |             assert num_bits is not None, "either provide qparams or num_bits to quantize"
58 | qparams = calculate_qparams(
59 | input, num_bits=num_bits, flatten_dims=flatten_dims, reduce_dim=reduce_dim)
60 |
61 | zero_point = qparams.zero_point
62 | num_bits = qparams.num_bits
63 | qmin = -(2.**(num_bits - 1)) if signed else 0.
64 | qmax = qmin + 2.**num_bits - 1.
65 | scale = qparams.range / (qmax - qmin)
66 |
67 | min_scale = torch.tensor(1e-8).expand_as(scale).cuda()
68 | scale = torch.max(scale, min_scale)
69 |
70 | with torch.no_grad():
71 | output.add_(qmin * scale - zero_point).div_(scale)
72 | if stochastic:
73 | noise = output.new(output.shape).uniform_(-0.5, 0.5)
74 | output.add_(noise)
75 | # quantize
76 | output.clamp_(qmin, qmax).round_()
77 |
78 | if dequantize:
79 | output.mul_(scale).add_(
80 | zero_point - qmin * scale) # dequantize
81 | return output
82 |
83 | @staticmethod
84 | def backward(ctx, grad_output):
85 | # straight-through estimator
86 | grad_input = grad_output
87 | return grad_input, None, None, None, None, None, None, None, None
88 |
89 |
90 | class UniformQuantizeGrad(InplaceFunction):
91 |
92 | @staticmethod
93 | def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD,
94 | reduce_dim=0, dequantize=True, signed=False, stochastic=True):
95 | ctx.num_bits = num_bits
96 | ctx.qparams = qparams
97 | ctx.flatten_dims = flatten_dims
98 | ctx.stochastic = stochastic
99 | ctx.signed = signed
100 | ctx.dequantize = dequantize
101 | ctx.reduce_dim = reduce_dim
102 | ctx.inplace = False
103 | return input
104 |
105 | @staticmethod
106 | def backward(ctx, grad_output):
107 | qparams = ctx.qparams
108 | with torch.no_grad():
109 | if qparams is None:
110 |                 assert ctx.num_bits is not None, "either provide qparams or num_bits to quantize"
111 | qparams = calculate_qparams(
112 | grad_output, num_bits=ctx.num_bits, flatten_dims=ctx.flatten_dims, reduce_dim=ctx.reduce_dim, reduce_type='extreme')
113 |
114 | grad_input = quantize(grad_output, num_bits=None,
115 | qparams=qparams, flatten_dims=ctx.flatten_dims, reduce_dim=ctx.reduce_dim,
116 | dequantize=True, signed=ctx.signed, stochastic=ctx.stochastic, inplace=False)
117 | return grad_input, None, None, None, None, None, None, None
118 |
119 |
120 | def conv2d_biprec(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, num_bits_grad=None):
121 | out1 = F.conv2d(input.detach(), weight, bias,
122 | stride, padding, dilation, groups)
123 | out2 = F.conv2d(input, weight.detach(), bias.detach() if bias is not None else None,
124 | stride, padding, dilation, groups)
125 | out2 = quantize_grad(out2, num_bits=num_bits_grad, flatten_dims=(1, -1))
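    # Note: the forward value equals out2 (out1 - out1.detach() cancels), while
    # weight/bias gradients flow through out1 unquantized and the gradient
    # reaching the input flows through out2, quantized to num_bits_grad.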
126 | return out1 + out2 - out1.detach()
127 |
128 |
129 | def linear_biprec(input, weight, bias=None, num_bits_grad=None):
130 | out1 = F.linear(input.detach(), weight, bias)
131 | out2 = F.linear(input, weight.detach(), bias.detach()
132 | if bias is not None else None)
133 | out2 = quantize_grad(out2, num_bits=num_bits_grad)
134 | return out1 + out2 - out1.detach()
135 |
136 |
137 | def quantize(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, dequantize=True, signed=False, stochastic=False, inplace=False):
138 | if qparams:
139 | if qparams.num_bits:
140 | return UniformQuantize().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic, inplace)
141 | elif num_bits:
142 | return UniformQuantize().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic, inplace)
143 |
144 | return x
145 |
146 |
147 | def quantize_grad(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD, reduce_dim=0, dequantize=True, signed=False, stochastic=True):
148 | if qparams:
149 | if qparams.num_bits:
150 | return UniformQuantizeGrad().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic)
151 | elif num_bits:
152 | return UniformQuantizeGrad().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic)
153 |
154 | return x
155 |
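# Example (sketch): quantize activations in the forward pass and gradients in
# the backward pass, e.g., 4-bit forward and 8-bit backward:
#   y = quantize(x, num_bits=4)        # quantized forward, straight-through backward
#   y = quantize_grad(y, num_bits=8)   # identity forward, quantized backward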
156 |
157 | class QuantMeasure(nn.Module):
158 | """docstring for QuantMeasure."""
159 |
160 | def __init__(self, shape_measure=(1,), flatten_dims=_DEFAULT_FLATTEN,
161 | inplace=False, dequantize=True, stochastic=False, momentum=0.9, measure=False):
162 | super(QuantMeasure, self).__init__()
163 | self.register_buffer('running_zero_point', torch.zeros(*shape_measure))
164 | self.register_buffer('running_range', torch.zeros(*shape_measure))
165 | self.measure = measure
166 | if self.measure:
167 | self.register_buffer('num_measured', torch.zeros(1))
168 | self.flatten_dims = flatten_dims
169 | self.momentum = momentum
170 | self.dequantize = dequantize
171 | self.stochastic = stochastic
172 | self.inplace = inplace
173 |
174 | def forward(self, input, num_bits, qparams=None):
175 |
176 | if self.training or self.measure:
177 | if qparams is None:
178 | qparams = calculate_qparams(
179 | input, num_bits=num_bits, flatten_dims=self.flatten_dims, reduce_dim=0, reduce_type='extreme')
180 | with torch.no_grad():
181 | if self.measure:
182 | momentum = self.num_measured / (self.num_measured + 1)
183 | self.num_measured += 1
184 | else:
185 | momentum = self.momentum
186 | self.running_zero_point.mul_(momentum).add_(
187 | qparams.zero_point * (1 - momentum))
188 | self.running_range.mul_(momentum).add_(
189 | qparams.range * (1 - momentum))
190 | else:
191 | qparams = QParams(range=self.running_range,
192 | zero_point=self.running_zero_point, num_bits=num_bits)
193 | if self.measure:
194 | return input
195 | else:
196 | q_input = quantize(input, qparams=qparams, dequantize=self.dequantize,
197 | stochastic=self.stochastic, inplace=self.inplace)
198 | return q_input
199 |
200 |
201 | class QConv2d(nn.Conv2d):
202 | """docstring for QConv2d."""
203 |
204 | def __init__(self, in_channels, out_channels, kernel_size,
205 | stride=1, padding=0, dilation=1, groups=1, bias=True, momentum=0.1, quant_act_forward=0, quant_act_backward=0, quant_grad_act_error=0, quant_grad_act_gc=0, weight_bits=0, fix_prec=False):
206 | super(QConv2d, self).__init__(in_channels, out_channels, kernel_size,
207 | stride, padding, dilation, groups, bias)
208 |
209 | self.quantize_input_fw = QuantMeasure(shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1), momentum=momentum)
210 | self.quantize_input_bw = QuantMeasure(shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1), momentum=momentum)
211 | self.quant_act_forward = quant_act_forward
212 | self.quant_act_backward = quant_act_backward
213 | self.quant_grad_act_error = quant_grad_act_error
214 | self.quant_grad_act_gc = quant_grad_act_gc
215 | self.weight_bits = weight_bits
216 | self.fix_prec = fix_prec
217 | self.stride = stride
218 |
219 |
220 | def forward(self, input, num_bits, num_grad_bits):
221 | if num_bits == 0:
222 | output = F.conv2d(input, self.weight, self.bias, self.stride,self.padding, self.dilation, self.groups)
223 | return output
224 |
225 | if self.bias is not None:
226 |             qbias = quantize(
227 |                 self.bias, num_bits=self.weight_bits + num_bits,
228 |                 flatten_dims=(0, -1))
229 | else:
230 | qbias = None
231 |
232 | if self.fix_prec:
233 | if self.quant_act_forward or self.quant_act_backward or self.quant_grad_act_error or self.quant_grad_act_gc or self.weight_bits:
234 | weight_qparams = calculate_qparams(self.weight, num_bits=self.weight_bits, flatten_dims=(1, -1), reduce_dim=None)
235 | qweight = quantize(self.weight, qparams=weight_qparams)
236 |
237 | qinput_fw = self.quantize_input_fw(input, self.quant_act_forward)
238 | qinput_bw = self.quantize_input_bw(input, self.quant_act_backward)
239 |
240 | error_bits = self.quant_grad_act_error
241 | gc_bits = self.quant_grad_act_gc
242 | output = self.conv2d_quant_act(qinput_fw, qinput_bw, qweight, qbias, self.stride, self.padding, self.dilation, self.groups, error_bits, gc_bits)
243 |
244 | else:
245 | qinput = self.quantize_input_fw(input, num_bits)
246 | weight_qparams = calculate_qparams(self.weight, num_bits=num_bits, flatten_dims=(1, -1), reduce_dim=None)
247 | qweight = quantize(self.weight, qparams=weight_qparams)
248 | output = F.conv2d(qinput, qweight, qbias, self.stride, self.padding, self.dilation, self.groups)
249 | output = quantize_grad(output, num_bits=num_grad_bits, flatten_dims=(1, -1))
250 |
251 | return output
252 |
253 | weight_qparams = calculate_qparams(self.weight, num_bits=self.weight_bits, flatten_dims=(1, -1), reduce_dim=None)
254 | qweight = quantize(self.weight, qparams=weight_qparams)
255 |
256 | qinput = self.quantize_input_fw(input, num_bits)
257 | output = F.conv2d(qinput, qweight, qbias, self.stride, self.padding, self.dilation, self.groups)
258 | output = quantize_grad(output, num_bits=num_grad_bits, flatten_dims=(1, -1))
259 |
260 | # if self.quant_act_forward == -1:
261 | # qinput_fw = self.quantize_input_fw(input, num_bits)
262 | # else:
263 | # qinput_fw = self.quantize_input_fw(input, self.quant_act_forward)
264 |
265 | # if self.quant_act_backward == -1:
266 | # qinput_bw = self.quantize_input_bw(input, num_bits)
267 | # else:
268 | # qinput_bw = self.quantize_input_bw(input, self.quant_act_backward)
269 |
270 | # if self.quant_grad_act_error == -1:
271 | # error_bits = num_grad_bits
272 | # else:
273 | # error_bits = self.quant_grad_act_error
274 |
275 | # if self.quant_grad_act_gc == -1:
276 | # gc_bits = num_grad_bits
277 | # else:
278 | # gc_bits = self.quant_grad_act_gc
279 |
280 | # output = self.conv2d_quant_act(qinput_fw, qinput_bw, qweight, qbias, self.stride, self.padding, self.dilation, self.groups, error_bits, gc_bits)
281 |
282 | return output
283 |
284 |
285 | def conv2d_quant_act(self, input_fw, input_bw, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, error_bits=0, gc_bits=0):
286 | out1 = F.conv2d(input_fw, weight.detach(), bias.detach() if bias is not None else None,
287 | stride, padding, dilation, groups)
288 | out2 = F.conv2d(input_bw.detach(), weight, bias,
289 | stride, padding, dilation, groups)
290 | out1 = quantize_grad(out1, num_bits=error_bits)
291 | out2 = quantize_grad(out2, num_bits=gc_bits)
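        # Note: the forward value equals out1 (out2 - out2.detach() cancels); the
        # activation-error gradient flows through out1 (quantized to error_bits),
        # while the weight-gradient path flows through out2 (quantized to gc_bits).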
292 | return out1 + out2 - out2.detach()
293 |
294 |
295 | class QLinear(nn.Linear):
296 | """docstring for QConv2d."""
297 |
298 | def __init__(self, in_features, out_features, bias=True, num_bits=8, num_bits_weight=8, num_bits_grad=8, biprecision=True):
299 | super(QLinear, self).__init__(in_features, out_features, bias)
300 | self.num_bits = num_bits
301 | self.num_bits_weight = num_bits_weight or num_bits
302 | self.num_bits_grad = num_bits_grad
303 | self.biprecision = biprecision
304 |         self.quantize_input = QuantMeasure(shape_measure=(1,), flatten_dims=(1, -1))
305 |
306 | def forward(self, input):
307 |         qinput = self.quantize_input(input, self.num_bits)
308 | weight_qparams = calculate_qparams(
309 | self.weight, num_bits=self.num_bits_weight, flatten_dims=(1, -1), reduce_dim=None)
310 | qweight = quantize(self.weight, qparams=weight_qparams)
311 | if self.bias is not None:
312 | qbias = quantize(
313 | self.bias, num_bits=self.num_bits_weight + self.num_bits,
314 | flatten_dims=(0, -1))
315 | else:
316 | qbias = None
317 |
318 | if not self.biprecision or self.num_bits_grad is None:
319 | output = F.linear(qinput, qweight, qbias)
320 | if self.num_bits_grad is not None:
321 | output = quantize_grad(
322 | output, num_bits=self.num_bits_grad)
323 | else:
324 | output = linear_biprec(qinput, qweight, qbias, self.num_bits_grad)
325 | return output
326 |
327 |
328 | class RangeBN(nn.Module):
329 | # this is normalized RangeBN
330 |
331 | def __init__(self, num_features, dim=1, momentum=0.1, affine=True, num_chunks=16, eps=1e-5, num_bits=8, num_bits_grad=8):
332 | super(RangeBN, self).__init__()
333 | self.register_buffer('running_mean', torch.zeros(num_features))
334 | self.register_buffer('running_var', torch.zeros(num_features))
335 |
336 | self.momentum = momentum
337 | self.dim = dim
338 | if affine:
339 | self.bias = nn.Parameter(torch.Tensor(num_features))
340 | self.weight = nn.Parameter(torch.Tensor(num_features))
341 | self.num_bits = num_bits
342 | self.num_bits_grad = num_bits_grad
343 | self.quantize_input = QuantMeasure(inplace=True, shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1))
344 | self.eps = eps
345 | self.num_chunks = num_chunks
346 | self.reset_params()
347 |
348 | def reset_params(self):
349 | if self.weight is not None:
350 | self.weight.data.uniform_()
351 | if self.bias is not None:
352 | self.bias.data.zero_()
353 |
354 | def forward(self, x, num_bits, num_grad_bits):
355 | x = self.quantize_input(x, num_bits)
356 | if x.dim() == 2: # 1d
357 | x = x.unsqueeze(-1,).unsqueeze(-1)
358 |
359 | if self.training:
360 | B, C, H, W = x.shape
361 | y = x.transpose(0, 1).contiguous() # C x B x H x W
362 | y = y.view(C, self.num_chunks, (B * H * W) // self.num_chunks)
363 | mean_max = y.max(-1)[0].mean(-1) # C
364 | mean_min = y.min(-1)[0].mean(-1) # C
365 | mean = y.view(C, -1).mean(-1) # C
366 | scale_fix = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) **
367 | 0.5) / ((2 * math.log(y.size(-1))) ** 0.5)
368 |
369 | scale = (mean_max - mean_min) * scale_fix
370 | with torch.no_grad():
371 | self.running_mean.mul_(self.momentum).add_(
372 | mean * (1 - self.momentum))
373 |
374 | self.running_var.mul_(self.momentum).add_(
375 | scale * (1 - self.momentum))
376 | else:
377 | mean = self.running_mean
378 | scale = self.running_var
379 | # scale = quantize(scale, num_bits=self.num_bits, min_value=float(
380 | # scale.min()), max_value=float(scale.max()))
381 | out = (x - mean.view(1, -1, 1, 1)) / \
382 | (scale.view(1, -1, 1, 1) + self.eps)
383 |
384 | if self.weight is not None:
385 | qweight = self.weight
386 | # qweight = quantize(self.weight, num_bits=self.num_bits,
387 | # min_value=float(self.weight.min()),
388 | # max_value=float(self.weight.max()))
389 | out = out * qweight.view(1, -1, 1, 1)
390 |
391 | if self.bias is not None:
392 | qbias = self.bias
393 | # qbias = quantize(self.bias, num_bits=self.num_bits)
394 | out = out + qbias.view(1, -1, 1, 1)
395 | if num_grad_bits:
396 | out = quantize_grad(
397 | out, num_bits=num_grad_bits, flatten_dims=(1, -1))
398 |
399 | if out.size(3) == 1 and out.size(2) == 1:
400 | out = out.squeeze(-1).squeeze(-1)
401 | return out
402 |
403 |
404 | class RangeBN1d(RangeBN):
405 | # this is normalized RangeBN
406 |
407 | def __init__(self, num_features, dim=1, momentum=0.1, affine=True, num_chunks=16, eps=1e-5, num_bits=8, num_bits_grad=8):
408 | super(RangeBN1d, self).__init__(num_features, dim, momentum,
409 | affine, num_chunks, eps, num_bits, num_bits_grad)
410 | self.quantize_input = QuantMeasure(
411 |             inplace=True, shape_measure=(1, 1), flatten_dims=(1, -1))
412 |
413 | if __name__ == '__main__':
414 | x = torch.rand(2, 3)
415 | x_q = quantize(x, flatten_dims=(-1), num_bits=8, dequantize=True)
416 | print(x)
417 | print(x_q)
--------------------------------------------------------------------------------
/fractrain_cifar/modules/rnlu.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch.autograd.function import InplaceFunction
4 | from torch.autograd import Variable
5 | import torch.nn as nn
6 | import math
7 |
8 |
9 | class BiReLUFunction(InplaceFunction):
10 |
11 | @classmethod
12 | def forward(cls, ctx, input, inplace=False):
13 | if input.size(1) % 2 != 0:
14 | raise RuntimeError("dimension 1 of input must be multiple of 2, "
15 | "but got {}".format(input.size(1)))
16 | ctx.inplace = inplace
17 |
18 | if ctx.inplace:
19 | ctx.mark_dirty(input)
20 | output = input
21 | else:
22 | output = input.clone()
23 |
24 | pos, neg = output.chunk(2, dim=1)
25 | pos.clamp_(min=0)
26 | neg.clamp_(max=0)
27 | # scale = (pos - neg).view(pos.size(0), -1).mean(1).div_(2)
28 | # output.
29 | ctx.save_for_backward(output)
30 | return output
31 |
32 | @staticmethod
33 | def backward(ctx, grad_output):
34 | output, = ctx.saved_variables
35 | grad_input = grad_output.masked_fill(output.eq(0), 0)
36 | return grad_input, None
37 |
38 |
39 | def birelu(x, inplace=False):
40 | return BiReLUFunction().apply(x, inplace)
41 |
42 |
43 | class BiReLU(nn.Module):
44 | """docstring for BiReLU."""
45 |
46 | def __init__(self, inplace=False):
47 | super(BiReLU, self).__init__()
48 | self.inplace = inplace
49 |
50 | def forward(self, inputs):
51 | return birelu(inputs, inplace=self.inplace)
52 |
53 |
54 | def binorm(x, shift=0, scale_fix=(2 / math.pi) ** 0.5):
55 |     pos, neg = (x + shift).chunk(2, dim=1)
56 | scale = (pos - neg).view(pos.size(0), -1).mean(1).div_(2) * scale_fix
57 | return x / scale
58 |
59 |
60 | def _mean(p, dim):
61 | """Computes the mean over all dimensions except dim"""
62 | if dim is None:
63 | return p.mean()
64 | elif dim == 0:
65 | output_size = (p.size(0),) + (1,) * (p.dim() - 1)
66 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size)
67 | elif dim == p.dim() - 1:
68 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
69 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size)
70 | else:
71 | return _mean(p.transpose(0, dim), 0).transpose(0, dim)
72 |
73 |
74 | def rnlu(x, inplace=False, shift=0, scale_fix=(math.pi / 2) ** 0.5):
75 | x = birelu(x, inplace=inplace)
76 | pos, neg = (x + shift).chunk(2, dim=1)
77 | # scale = torch.cat((_mean(pos, 1), -_mean(neg, 1)), 1) * scale_fix + 1e-5
78 | scale = (pos - neg).view(pos.size(0), -1).mean(1) * scale_fix + 1e-8
79 | return x / scale.view(scale.size(0), *([1] * (x.dim() - 1)))
80 |
81 |
82 | class RnLU(nn.Module):
83 | """docstring for RnLU."""
84 |
85 | def __init__(self, inplace=False):
86 | super(RnLU, self).__init__()
87 | self.inplace = inplace
88 |
89 | def forward(self, x):
90 | return rnlu(x, inplace=self.inplace)
91 |
92 | # output.
93 | if __name__ == "__main__":
94 | x = Variable(torch.randn(2, 16, 5, 5).cuda(), requires_grad=True)
95 | output = rnlu(x)
96 |
97 | output.sum().backward()
98 |
--------------------------------------------------------------------------------
/fractrain_cifar/train_base.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.backends.cudnn as cudnn
6 | from torch.autograd import Variable
7 | import numpy as np
8 |
9 | import os
10 | import shutil
11 | import argparse
12 | import time
13 | import logging
14 |
15 | import models
16 | from data import *
17 |
18 | import util_swa
19 |
20 |
21 | model_names = sorted(name for name in models.__dict__
22 | if name.islower() and not name.startswith('__')
23 | and callable(models.__dict__[name])
24 | )
25 |
26 |
27 | def parse_args():
28 | # hyper-parameters are from ResNet paper
29 | parser = argparse.ArgumentParser(
30 | description='Quantization Aware Training on CIFAR')
31 | parser.add_argument('--dir', help='annotate the working directory')
32 | parser.add_argument('--cmd', choices=['train', 'test'], default='train')
33 | parser.add_argument('--arch', metavar='ARCH', default='cifar10_resnet_38',
34 | choices=model_names,
35 | help='model architecture: ' +
36 | ' | '.join(model_names) +
37 | ' (default: cifar10_resnet_38)')
38 | parser.add_argument('--dataset', '-d', type=str, default='cifar10',
39 | choices=['cifar10', 'cifar100'],
40 | help='dataset choice')
41 | parser.add_argument('--datadir', default='/home/yf22/dataset', type=str,
42 | help='path to dataset')
43 | parser.add_argument('--workers', default=4, type=int, metavar='N',
44 |                         help='number of data loading workers (default: 4)')
45 | parser.add_argument('--iters', default=64000, type=int,
46 | help='number of total iterations (default: 64,000)')
47 | parser.add_argument('--start_iter', default=0, type=int,
48 | help='manual iter number (useful on restarts)')
49 | parser.add_argument('--batch_size', default=128, type=int,
50 | help='mini-batch size (default: 128)')
51 | parser.add_argument('--lr_schedule', default='piecewise', type=str,
52 | help='learning rate schedule')
53 | parser.add_argument('--lr', default=0.1, type=float,
54 | help='initial learning rate')
55 | parser.add_argument('--momentum', default=0.9, type=float,
56 | help='momentum')
57 | parser.add_argument('--weight_decay', default=1e-4, type=float,
58 | help='weight decay (default: 1e-4)')
59 | parser.add_argument('--print_freq', default=10, type=int,
60 | help='print frequency (default: 10)')
61 | parser.add_argument('--resume', default='', type=str,
62 | help='path to latest checkpoint (default: None)')
63 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
64 | help='use pretrained model')
65 | parser.add_argument('--step_ratio', default=0.1, type=float,
66 | help='ratio for learning rate deduction')
67 | parser.add_argument('--warm_up', action='store_true',
68 | help='for n = 18, the model needs to warm up for 400 '
69 | 'iterations')
70 | parser.add_argument('--save_folder', default='save_checkpoints',
71 | type=str,
72 | help='folder to save the checkpoints')
73 |     parser.add_argument('--eval_every', default=390, type=int,
74 |                         help='evaluate model every (default: 390) iterations')
75 | parser.add_argument('--num_bits',default=0,type=int,
76 | help='num bits for weight and activation')
77 | parser.add_argument('--num_grad_bits',default=0,type=int,
78 | help='num bits for gradient')
79 | parser.add_argument('--schedule', default=None, type=int, nargs='*',
80 | help='precision schedule')
81 | parser.add_argument('--num_bits_schedule',default=None,type=int,nargs='*',
82 | help='schedule for weight/act precision')
83 | parser.add_argument('--num_grad_bits_schedule',default=None,type=int,nargs='*',
84 | help='schedule for grad precision')
85 | parser.add_argument('--act_fw', default=0, type=int,
86 | help='precision of activation during forward, -1 means dynamic, 0 means no quantize')
87 | parser.add_argument('--act_bw', default=0, type=int,
88 | help='precision of activation during backward, -1 means dynamic, 0 means no quantize')
89 | parser.add_argument('--grad_act_error', default=0, type=int,
90 | help='precision of activation gradient during error backward, -1 means dynamic, 0 means no quantize')
91 | parser.add_argument('--grad_act_gc', default=0, type=int,
92 | help='precision of activation gradient during weight gradient computation, -1 means dynamic, 0 means no quantize')
93 | parser.add_argument('--weight_bits', default=0, type=int,
94 | help='precision of weight')
95 | parser.add_argument('--momentum_act', default=0.9, type=float,
96 | help='momentum for act min/max')
97 | parser.add_argument('--swa_start', type=float, default=None, help='SWA start step number')
98 | parser.add_argument('--swa_freq', type=float, default=1170,
99 | help='SWA model collection frequency')
100 | args = parser.parse_args()
101 | return args
102 |
103 |
104 | def main():
105 | args = parse_args()
106 | save_path = args.save_path = os.path.join(args.save_folder, args.arch)
107 | if not os.path.exists(save_path):
108 | os.makedirs(save_path)
109 |
110 | models.ACT_FW = args.act_fw
111 | models.ACT_BW = args.act_bw
112 | models.GRAD_ACT_ERROR = args.grad_act_error
113 | models.GRAD_ACT_GC = args.grad_act_gc
114 | models.WEIGHT_BITS = args.weight_bits
115 | models.MOMENTUM = args.momentum_act
116 |
117 | args.num_bits = args.num_bits if not (args.act_fw + args.act_bw + args.grad_act_error + args.grad_act_gc + args.weight_bits) else -1
118 |
119 | # config logging file
120 | args.logger_file = os.path.join(save_path, 'log_{}.txt'.format(args.cmd))
121 | handlers = [logging.FileHandler(args.logger_file, mode='w'),
122 | logging.StreamHandler()]
123 | logging.basicConfig(level=logging.INFO,
124 | datefmt='%m-%d-%y %H:%M',
125 | format='%(asctime)s:%(message)s',
126 | handlers=handlers)
127 |
128 | if args.cmd == 'train':
129 | logging.info('start training {}'.format(args.arch))
130 | run_training(args)
131 |
132 | elif args.cmd == 'test':
133 | logging.info('start evaluating {} with checkpoints from {}'.format(
134 | args.arch, args.resume))
135 | test_model(args)
136 |
137 |
138 | def run_training(args):
139 | # create model
140 | model = models.__dict__[args.arch](args.pretrained)
141 | model = torch.nn.DataParallel(model).cuda()
142 |
143 | if args.swa_start is not None:
144 | print('SWA training')
145 | swa_model = torch.nn.DataParallel(models.__dict__[args.arch](args.pretrained)).cuda()
146 | swa_n = 0
147 |
148 | else:
149 | print('SGD training')
150 |
151 | best_prec1 = 0
152 | best_iter = 0
153 |
154 | best_swa_prec = 0
155 | best_swa_iter = 0
156 |
157 | # best_full_prec = 0
158 |
159 | if args.resume:
160 | if os.path.isfile(args.resume):
161 | logging.info('=> loading checkpoint `{}`'.format(args.resume))
162 | checkpoint = torch.load(args.resume)
163 | args.start_iter = checkpoint['iter']
164 | best_prec1 = checkpoint['best_prec1']
165 | model.load_state_dict(checkpoint['state_dict'])
166 |
167 | if args.swa_start is not None:
168 | swa_state_dict = checkpoint['swa_state_dict']
169 | if swa_state_dict is not None:
170 | swa_model.load_state_dict(swa_state_dict)
171 | swa_n_ckpt = checkpoint['swa_n']
172 | if swa_n_ckpt is not None:
173 | swa_n = swa_n_ckpt
174 | best_swa_prec_ckpt = checkpoint['best_swa_prec']
175 | if best_swa_prec_ckpt is not None:
176 | best_swa_prec = best_swa_prec_ckpt
177 |
178 | logging.info('=> loaded checkpoint `{}` (iter: {})'.format(
179 | args.resume, checkpoint['iter']
180 | ))
181 | else:
182 | logging.info('=> no checkpoint found at `{}`'.format(args.resume))
183 |
184 | cudnn.benchmark = False
185 |
186 | train_loader = prepare_train_data(dataset=args.dataset,
187 | datadir=args.datadir,
188 | batch_size=args.batch_size,
189 | shuffle=True,
190 | num_workers=args.workers)
191 | test_loader = prepare_test_data(dataset=args.dataset,
192 | datadir=args.datadir,
193 | batch_size=args.batch_size,
194 | shuffle=False,
195 | num_workers=args.workers)
196 | if args.swa_start is not None:
197 | swa_loader = prepare_train_data(dataset=args.dataset,
198 | datadir=args.datadir,
199 | batch_size=args.batch_size,
200 | shuffle=False,
201 | num_workers=args.workers)
202 |
203 | # define loss function (criterion) and optimizer
204 | criterion = nn.CrossEntropyLoss().cuda()
205 |
206 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
207 | momentum=args.momentum,
208 | weight_decay=args.weight_decay)
209 |
210 | # optimizer = torch.optim.Adam(model.parameters(), args.lr,
211 | # weight_decay=args.weight_decay)
212 |
213 | batch_time = AverageMeter()
214 | data_time = AverageMeter()
215 | losses = AverageMeter()
216 | top1 = AverageMeter()
217 | cr = AverageMeter()
218 |
219 | end = time.time()
220 |
221 | i = args.start_iter
222 | while i < args.iters:
223 | for input, target in train_loader:
224 | # measuring data loading time
225 | data_time.update(time.time() - end)
226 |
227 | model.train()
228 | adjust_learning_rate(args, optimizer, i)
229 | adjust_precision(args, i)
230 |
231 | i += 1
232 |
233 | fw_cost = args.num_bits*args.num_bits/32/32
234 | eb_cost = args.num_bits*args.num_grad_bits/32/32
235 | gc_cost = eb_cost
236 | cr.update((fw_cost+eb_cost+gc_cost)/3)
237 |
238 | target = target.squeeze().long().cuda()
239 | input_var = Variable(input).cuda()
240 | target_var = Variable(target).cuda()
241 |
242 | # compute output
243 | output = model(input_var, args.num_bits, args.num_grad_bits)
244 | loss = criterion(output, target_var)
245 |
246 | # measure accuracy and record loss
247 | prec1, = accuracy(output.data, target, topk=(1,))
248 | losses.update(loss.item(), input.size(0))
249 | top1.update(prec1.item(), input.size(0))
250 |
251 | # compute gradient and do SGD step
252 | optimizer.zero_grad()
253 | loss.backward()
254 | optimizer.step()
255 |
256 | # measure elapsed time
257 | batch_time.update(time.time() - end)
258 | end = time.time()
259 |
260 | # print log
261 | if i % args.print_freq == 0:
262 | logging.info("Iter: [{0}/{1}]\t"
263 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
264 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
265 | "Loss {loss.val:.3f} ({loss.avg:.3f})\t"
266 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t".format(
267 | i,
268 | args.iters,
269 | batch_time=batch_time,
270 | data_time=data_time,
271 | loss=losses,
272 | top1=top1)
273 | )
274 |
275 |
276 | if args.swa_start is not None and i >= args.swa_start and i % args.swa_freq == 0:
277 | util_swa.moving_average(swa_model, model, 1.0 / (swa_n + 1))
278 | swa_n += 1
279 | util_swa.bn_update(swa_loader, swa_model, args.num_bits, args.num_grad_bits)
280 |                 with torch.no_grad():
281 |                     prec1 = validate(args, test_loader, swa_model, criterion, i, swa=True)
282 | if prec1 > best_swa_prec:
283 | best_swa_prec = prec1
284 | best_swa_iter = i
285 |
286 | print("Current Best SWA Prec@1: ", best_swa_prec)
287 | print("Current Best SWA Iteration: ", best_swa_iter)
288 |
289 | if (i % args.eval_every == 0 and i > 0) or (i == args.iters):
290 | with torch.no_grad():
291 | prec1 = validate(args, test_loader, model, criterion, i)
292 | # prec_full = validate_full_prec(args, test_loader, model, criterion, i)
293 |
294 | is_best = prec1 > best_prec1
295 | if is_best:
296 | best_prec1 = prec1
297 | best_iter = i
298 | # best_full_prec = max(prec_full, best_full_prec)
299 |
300 | print("Current Best Prec@1: ", best_prec1)
301 | print("Current Best Iteration: ", best_iter)
302 | # print("Current Best Full Prec@1: ", best_full_prec)
303 |
304 | checkpoint_path = os.path.join(args.save_path, 'checkpoint_{:05d}_{:.2f}.pth.tar'.format(i, prec1))
305 | save_checkpoint({
306 | 'iter': i,
307 | 'arch': args.arch,
308 | 'state_dict': model.state_dict(),
309 | 'best_prec1': best_prec1,
310 | 'swa_state_dict' : swa_model.state_dict() if args.swa_start is not None else None,
311 | 'swa_n' : swa_n if args.swa_start is not None else None,
312 | 'best_swa_prec' : best_swa_prec if args.swa_start is not None else None,
313 | },
314 | is_best, filename=checkpoint_path)
315 | shutil.copyfile(checkpoint_path, os.path.join(args.save_path,
316 | 'checkpoint_latest'
317 | '.pth.tar'))
318 |
319 | if i == args.iters:
320 | break
321 |
322 |
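For intuition, the `cr` meter in `run_training` follows the standard bit-width cost model: a fixed-point multiply-accumulate is assumed to cost proportionally to the product of its operands' bit-widths, normalized to the 32-bit baseline and averaged over the forward (FW), error-backward (EB), and weight-gradient (GC) passes. A minimal sketch of the same arithmetic (the 8/16-bit operating point is illustrative):

```python
def cost_ratio(num_bits, num_grad_bits, full=32):
    fw = num_bits * num_bits / full ** 2       # forward: act x weight, both at num_bits
    eb = num_bits * num_grad_bits / full ** 2  # error backward: weight x grad
    gc = eb                                    # weight-gradient computation: act x grad
    return (fw + eb + gc) / 3

print(cost_ratio(8, 16))  # ~0.104, i.e. roughly 10% of full-precision compute
```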
323 | def validate(args, test_loader, model, criterion, step, swa=False):
324 | batch_time = AverageMeter()
325 | losses = AverageMeter()
326 | top1 = AverageMeter()
327 |
328 | # switch to evaluation mode
329 | model.eval()
330 | end = time.time()
331 | for i, (input, target) in enumerate(test_loader):
332 | target = target.squeeze().long().cuda()
333 |         input_var = input.cuda()   # Variable(volatile=True) was removed in PyTorch >= 0.4;
334 |         target_var = target        # callers wrap validate() in torch.no_grad() instead
335 |
336 | # compute output
337 | output = model(input_var, args.num_bits, args.num_grad_bits)
338 | loss = criterion(output, target_var)
339 |
340 | # measure accuracy and record loss
341 | prec1, = accuracy(output.data, target, topk=(1,))
342 | top1.update(prec1.item(), input.size(0))
343 | losses.update(loss.item(), input.size(0))
344 | batch_time.update(time.time() - end)
345 | end = time.time()
346 |
347 | if (i % args.print_freq == 0) or (i == len(test_loader) - 1):
348 | logging.info(
349 | 'Test: [{}/{}]\t'
350 | 'Time: {batch_time.val:.4f}({batch_time.avg:.4f})\t'
351 | 'Loss: {loss.val:.3f}({loss.avg:.3f})\t'
352 | 'Prec@1: {top1.val:.3f}({top1.avg:.3f})\t'.format(
353 | i, len(test_loader), batch_time=batch_time,
354 | loss=losses, top1=top1
355 | )
356 | )
357 |
358 | if not swa:
359 | logging.info('Step {} * Prec@1 {top1.avg:.3f}'.format(step, top1=top1))
360 | else:
361 | logging.info('Step {} * SWA Prec@1 {top1.avg:.3f}'.format(step, top1=top1))
362 |
363 | return top1.avg
364 |
365 |
366 | def validate_full_prec(args, test_loader, model, criterion, step):
367 | batch_time = AverageMeter()
368 | losses = AverageMeter()
369 | top1 = AverageMeter()
370 |
371 | # switch to evaluation mode
372 | model.eval()
373 | end = time.time()
374 | for i, (input, target) in enumerate(test_loader):
375 | target = target.squeeze().long().cuda()
376 |         input_var = input.cuda()   # evaluated under torch.no_grad() by the caller
377 |         target_var = target
378 |
379 | # compute output
380 | output = model(input_var, 0, 0)
381 | loss = criterion(output, target_var)
382 |
383 | # measure accuracy and record loss
384 | prec1, = accuracy(output.data, target, topk=(1,))
385 | top1.update(prec1.item(), input.size(0))
386 | losses.update(loss.item(), input.size(0))
387 | batch_time.update(time.time() - end)
388 | end = time.time()
389 |
390 |
391 | logging.info('Step {} * Full Prec@1 {top1.avg:.3f}'.format(step, top1=top1))
392 | return top1.avg
393 |
394 |
395 | def test_model(args):
396 | # create model
397 | model = models.__dict__[args.arch](args.pretrained)
398 | model = torch.nn.DataParallel(model).cuda()
399 |
400 | if args.resume:
401 | if os.path.isfile(args.resume):
402 | logging.info('=> loading checkpoint `{}`'.format(args.resume))
403 | checkpoint = torch.load(args.resume)
404 | args.start_iter = checkpoint['iter']
405 | best_prec1 = checkpoint['best_prec1']
406 | model.load_state_dict(checkpoint['state_dict'])
407 | logging.info('=> loaded checkpoint `{}` (iter: {})'.format(
408 | args.resume, checkpoint['iter']
409 | ))
410 | else:
411 | logging.info('=> no checkpoint found at `{}`'.format(args.resume))
412 |
413 | cudnn.benchmark = False
414 |     test_loader = prepare_test_data(dataset=args.dataset, datadir=args.datadir,
415 |                                     batch_size=args.batch_size,
416 | shuffle=False,
417 | num_workers=args.workers)
418 | criterion = nn.CrossEntropyLoss().cuda()
419 |
420 | # validate(args, test_loader, model, criterion)
421 |
422 | with torch.no_grad():
423 | prec1 = validate(args, test_loader, model, criterion, args.start_iter)
424 | prec_full = validate_full_prec(args, test_loader, model, criterion, args.start_iter)
425 |
426 |
427 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
428 | torch.save(state, filename)
429 | if is_best:
430 | save_path = os.path.dirname(filename)
431 | shutil.copyfile(filename, os.path.join(save_path,
432 | 'model_best.pth.tar'))
433 |
434 |
435 | class AverageMeter(object):
436 | """Computes and stores the average and current value"""
437 |
438 | def __init__(self):
439 | self.reset()
440 |
441 | def reset(self):
442 | self.val = 0
443 | self.avg = 0
444 | self.sum = 0
445 | self.count = 0
446 |
447 | def update(self, val, n=1):
448 | self.val = val
449 | self.sum += val * n
450 | self.count += n
451 | self.avg = self.sum / self.count
452 |
453 |
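As a quick illustration (values made up), `update(val, n)` treats `val` as a per-batch mean over `n` samples, so `avg` is the sample-weighted running mean rather than a plain mean of batch means:

```python
meter = AverageMeter()
meter.update(0.9, n=128)  # batch 1: mean 0.9 over 128 samples
meter.update(0.7, n=64)   # batch 2: mean 0.7 over 64 samples
assert abs(meter.avg - (0.9 * 128 + 0.7 * 64) / 192) < 1e-12  # 0.8333...
```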
454 | schedule_cnt = 0
455 | def adjust_precision(args, _iter):
456 | if args.schedule:
457 | global schedule_cnt
458 |
459 | assert len(args.num_bits_schedule) == len(args.schedule) + 1
460 | assert len(args.num_grad_bits_schedule) == len(args.schedule) + 1
461 |
462 | if schedule_cnt == 0:
463 | args.num_bits = args.num_bits_schedule[0]
464 | args.num_grad_bits = args.num_grad_bits_schedule[0]
465 | schedule_cnt += 1
466 |
467 | for step in args.schedule:
468 | if _iter == step:
469 | args.num_bits = args.num_bits_schedule[schedule_cnt]
470 | args.num_grad_bits = args.num_grad_bits_schedule[schedule_cnt]
471 | schedule_cnt += 1
472 |
473 | if _iter % args.eval_every == 0:
474 | logging.info('Iter [{}] num_bits = {} num_grad_bits = {}'.format(_iter, args.num_bits, args.num_grad_bits))
475 |
476 |
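To make the schedule semantics concrete: `--schedule` lists the switch iterations, so with `k` switch points each bits schedule needs `k + 1` entries, and the precision jumps exactly when `_iter` equals a listed step. A hedged sketch with made-up values:

```python
import argparse

args = argparse.Namespace(
    schedule=[20000, 40000],             # hypothetical switch iterations
    num_bits_schedule=[6, 8, 8],         # k + 1 = 3 precision stages
    num_grad_bits_schedule=[8, 12, 16],
    num_bits=0, num_grad_bits=0, eval_every=400,
)
for it in (0, 20000, 40000):
    adjust_precision(args, it)
    print(it, args.num_bits, args.num_grad_bits)
# 0 -> 6/8, 20000 -> 8/12, 40000 -> 8/16
```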
477 | def adjust_learning_rate(args, optimizer, _iter):
478 | if args.lr_schedule == 'piecewise':
479 | if args.warm_up and (_iter < 400):
480 | lr = 0.01
481 | elif 32000 <= _iter < 48000:
482 | lr = args.lr * (args.step_ratio ** 1)
483 | elif _iter >= 48000:
484 | lr = args.lr * (args.step_ratio ** 2)
485 | else:
486 | lr = args.lr
487 |
488 | elif args.lr_schedule == 'linear':
489 | t = _iter / args.iters
490 | lr_ratio = 0.01
491 | if args.warm_up and (_iter < 400):
492 | lr = 0.01
493 | elif t < 0.5:
494 | lr = args.lr
495 | elif t < 0.9:
496 | lr = args.lr * (1 - (1-lr_ratio)*(t-0.5)/0.4)
497 | else:
498 | lr = args.lr * lr_ratio
499 |
500 | elif args.lr_schedule == 'anneal_cosine':
501 | lr_min = args.lr * (args.step_ratio ** 2)
502 | lr_max = args.lr
503 |         lr = lr_min + 1/2 * (lr_max - lr_min) * (1 + np.cos(_iter / args.iters * np.pi))
504 |
505 | if _iter % args.eval_every == 0:
506 | logging.info('Iter [{}] learning rate = {}'.format(_iter, lr))
507 |
508 | for param_group in optimizer.param_groups:
509 | param_group['lr'] = lr
510 |
511 |
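The `anneal_cosine` branch is the usual cosine annealing curve; with `t` the current iteration and `T = args.iters`, it evaluates

```latex
\eta(t) = \eta_{\min} + \tfrac{1}{2}\,(\eta_{\max} - \eta_{\min})\,\bigl(1 + \cos(\pi t / T)\bigr),
\qquad \eta_{\max} = \texttt{lr}, \quad \eta_{\min} = \texttt{lr}\cdot\texttt{step\_ratio}^{2}.
```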
512 | def accuracy(output, target, topk=(1,)):
513 | """Computes the precision@k for the specified values of k"""
514 | maxk = max(topk)
515 | batch_size = target.size(0)
516 |
517 | _, pred = output.topk(maxk, 1, True, True)
518 | pred = pred.t()
519 | correct = pred.eq(target.view(1, -1).expand_as(pred))
520 |
521 | res = []
522 | for k in topk:
523 |         correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape: pred.t() is non-contiguous
524 | res.append(correct_k.mul_(100.0 / batch_size))
525 | return res
526 |
527 |
528 | if __name__ == '__main__':
529 | main()
530 |
--------------------------------------------------------------------------------
/fractrain_cifar/train_pfq.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.backends.cudnn as cudnn
6 | from torch.autograd import Variable
7 | import numpy as np
8 |
9 | import os
10 | import shutil
11 | import argparse
12 | import time
13 | import logging
14 |
15 | import models
16 | from data import *
17 |
18 | import util_swa
19 |
20 |
21 | model_names = sorted(name for name in models.__dict__
22 | if name.islower() and not name.startswith('__')
23 | and callable(models.__dict__[name])
24 | )
25 |
26 |
27 | def parse_args():
28 | # hyper-parameters are from ResNet paper
29 | parser = argparse.ArgumentParser(
30 | description='PFQ on CIFAR')
31 | parser.add_argument('--dir', help='annotate the working directory')
32 | parser.add_argument('--cmd', choices=['train', 'test'], default='train')
33 | parser.add_argument('--arch', metavar='ARCH', default='cifar10_resnet_38',
34 | choices=model_names,
35 | help='model architecture: ' +
36 | ' | '.join(model_names) +
37 | ' (default: cifar10_resnet_38)')
38 | parser.add_argument('--dataset', '-d', type=str, default='cifar10',
39 | choices=['cifar10', 'cifar100'],
40 | help='dataset choice')
41 | parser.add_argument('--datadir', default='/home/yf22/dataset', type=str,
42 | help='path to dataset')
43 | parser.add_argument('--workers', default=4, type=int, metavar='N',
44 | help='number of data loading workers (default: 4 )')
45 | parser.add_argument('--iters', default=64000, type=int,
46 | help='number of total iterations (default: 64,000)')
47 | parser.add_argument('--start_iter', default=0, type=int,
48 | help='manual iter number (useful on restarts)')
49 | parser.add_argument('--batch_size', default=128, type=int,
50 | help='mini-batch size (default: 128)')
51 | parser.add_argument('--lr_schedule', default='piecewise', type=str,
52 | help='learning rate schedule')
53 | parser.add_argument('--lr', default=0.1, type=float,
54 | help='initial learning rate')
55 | parser.add_argument('--momentum', default=0.9, type=float,
56 | help='momentum')
57 | parser.add_argument('--weight_decay', default=1e-4, type=float,
58 | help='weight decay (default: 1e-4)')
59 | parser.add_argument('--print_freq', default=10, type=int,
60 | help='print frequency (default: 10)')
61 | parser.add_argument('--resume', default='', type=str,
62 | help='path to latest checkpoint (default: None)')
63 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
64 | help='use pretrained model')
65 | parser.add_argument('--step_ratio', default=0.1, type=float,
66 |                         help='ratio for learning rate reduction')
67 | parser.add_argument('--warm_up', action='store_true',
68 | help='for n = 18, the model needs to warm up for 400 '
69 | 'iterations')
70 | parser.add_argument('--save_folder', default='save_checkpoints',
71 | type=str,
72 | help='folder to save the checkpoints')
73 | parser.add_argument('--eval_every', default=400, type=int,
74 |                         help='evaluate model every (default: 400) iterations')
75 | parser.add_argument('--num_bits',default=0,type=int,
76 | help='num bits for weight and activation')
77 | parser.add_argument('--num_grad_bits',default=0,type=int,
78 | help='num bits for gradient')
79 | parser.add_argument('--schedule', default=None, type=int, nargs='*',
80 | help='precision schedule')
81 | parser.add_argument('--num_bits_schedule',default=None,type=int,nargs='*',
82 | help='schedule for weight/act precision')
83 | parser.add_argument('--num_grad_bits_schedule',default=None,type=int,nargs='*',
84 | help='schedule for grad precision')
85 | parser.add_argument('--act_fw', default=0, type=int,
86 | help='precision of activation during forward, -1 means dynamic, 0 means no quantize')
87 | parser.add_argument('--act_bw', default=0, type=int,
88 | help='precision of activation during backward, -1 means dynamic, 0 means no quantize')
89 | parser.add_argument('--grad_act_error', default=0, type=int,
90 | help='precision of activation gradient during error backward, -1 means dynamic, 0 means no quantize')
91 | parser.add_argument('--grad_act_gc', default=0, type=int,
92 | help='precision of activation gradient during weight gradient computation, -1 means dynamic, 0 means no quantize')
93 | parser.add_argument('--weight_bits', default=0, type=int,
94 | help='precision of weight')
95 | parser.add_argument('--momentum_act', default=0.9, type=float,
96 | help='momentum for act min/max')
97 | parser.add_argument('--swa_start', type=float, default=None, help='SWA start step number')
98 | parser.add_argument('--swa_freq', type=float, default=1170,
99 | help='SWA model collection frequency')
100 |
101 |     parser.add_argument('--num_turning_point', type=int, default=3, help='number of precision-raising turning points PFQ may trigger')
102 |     parser.add_argument('--initial_threshold', type=float, default=0.15, help='initial loss-difference threshold of the indicator')
103 |     parser.add_argument('--decay', type=float, default=0.4, help='multiplicative threshold decay after each turning point')
104 | args = parser.parse_args()
105 | return args
106 |
107 | # indicator
108 | class loss_diff_indicator():
109 | def __init__(self, threshold, decay, epoch_keep=5):
110 | self.threshold = threshold
111 | self.decay = decay
112 | self.epoch_keep = epoch_keep
113 | self.loss = []
114 | self.scale_loss = 1
115 | self.loss_diff = [1 for i in range(1, self.epoch_keep)]
116 |
117 | def reset(self):
118 | self.loss = []
119 | self.loss_diff = [1 for i in range(1, self.epoch_keep)]
120 |
121 | def adaptive_threshold(self, turning_point_count):
122 | decay_1 = self.decay
123 | decay_2 = self.decay
124 | if turning_point_count == 1:
125 | self.threshold *= decay_1
126 | if turning_point_count == 2:
127 | self.threshold *= decay_2
128 | print('threshold decay to {}'.format(self.threshold))
129 |
130 | def get_loss(self, current_epoch_loss):
131 | if len(self.loss) < self.epoch_keep:
132 | self.loss.append(current_epoch_loss)
133 | else:
134 | self.loss.pop(0)
135 | self.loss.append(current_epoch_loss)
136 |
137 | def cal_loss_diff(self):
138 | if len(self.loss) == self.epoch_keep:
139 | for i in range(len(self.loss)-1):
140 | loss_now = self.loss[-1]
141 | loss_pre = self.loss[i]
142 | self.loss_diff[i] = np.abs(loss_pre - loss_now) / self.scale_loss
143 | return True
144 | else:
145 | return False
146 |
147 | def turning_point_emerge(self):
148 | flag = self.cal_loss_diff()
149 | if flag == True:
150 | print(self.loss_diff)
151 | for i in range(len(self.loss_diff)):
152 | if self.loss_diff[i] > self.threshold:
153 | return False
154 | return True
155 | else:
156 | return False
157 |
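A minimal sketch of the indicator on a toy loss curve (all numbers made up): once the last `epoch_keep` epoch losses differ from the latest one by less than `threshold`, relative to `scale_loss`, a turning point fires, the threshold decays, and PFQ can move to the next precision stage.

```python
ind = loss_diff_indicator(threshold=0.15, decay=0.4)
ind.scale_loss = 2.0  # e.g. the mean loss over the first epochs
for epoch_loss in [2.0, 1.2, 0.9, 0.85, 0.84, 0.83, 0.825]:
    ind.get_loss(epoch_loss)
    if ind.turning_point_emerge():
        print('turning point: raise precision')  # fires once the curve flattens
        ind.adaptive_threshold(turning_point_count=1)
        ind.reset()
```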
158 | def main():
159 | args = parse_args()
160 | global save_path
161 | save_path = args.save_path = os.path.join(args.save_folder, args.arch)
162 | if not os.path.exists(save_path):
163 | os.makedirs(save_path)
164 |
165 | models.ACT_FW = args.act_fw
166 | models.ACT_BW = args.act_bw
167 | models.GRAD_ACT_ERROR = args.grad_act_error
168 | models.GRAD_ACT_GC = args.grad_act_gc
169 | models.WEIGHT_BITS = args.weight_bits
170 | models.MOMENTUM = args.momentum_act
171 |
172 | args.num_bits = args.num_bits if not (args.act_fw + args.act_bw + args.grad_act_error + args.grad_act_gc + args.weight_bits) else -1
173 |
174 | # config logging file
175 | args.logger_file = os.path.join(save_path, 'log_{}.txt'.format(args.cmd))
176 | if os.path.exists(args.logger_file):
177 | os.remove(args.logger_file)
178 | handlers = [logging.FileHandler(args.logger_file, mode='w'),
179 | logging.StreamHandler()]
180 | logging.basicConfig(level=logging.INFO,
181 | datefmt='%m-%d-%y %H:%M',
182 | format='%(asctime)s:%(message)s',
183 | handlers=handlers)
184 |
185 | global history_score
186 | history_score = np.zeros((args.iters // args.eval_every, 3))
187 |
188 | # initialize indicator
189 | # initial_threshold=0.15
190 | global scale_loss
191 | scale_loss = 0
192 | global my_loss_diff_indicator
193 | my_loss_diff_indicator = loss_diff_indicator(threshold=args.initial_threshold,
194 | decay=args.decay)
195 |
196 | global turning_point_count
197 | turning_point_count = 0
198 |
199 | if args.cmd == 'train':
200 | logging.info('start training {}'.format(args.arch))
201 | run_training(args)
202 |
203 | elif args.cmd == 'test':
204 | logging.info('start evaluating {} with checkpoints from {}'.format(
205 | args.arch, args.resume))
206 | test_model(args)
207 |
208 |
209 |
210 | def run_training(args):
211 |     # running statistics for the PFQ loss indicator
212 |     training_loss = 0
213 |     training_acc = 0
214 | 
215 |     model = models.__dict__[args.arch](args.pretrained)  # create model
216 | model = torch.nn.DataParallel(model).cuda()
217 |
218 | if args.swa_start is not None:
219 | print('SWA training')
220 | swa_model = torch.nn.DataParallel(models.__dict__[args.arch](args.pretrained)).cuda()
221 | swa_n = 0
222 |
223 | else:
224 | print('SGD training')
225 |
226 | best_prec1 = 0
227 | best_iter = 0
228 |
229 | best_swa_prec = 0
230 | best_swa_iter = 0
231 |
232 | # best_full_prec = 0
233 |
234 | if args.resume:
235 | if os.path.isfile(args.resume):
236 | logging.info('=> loading checkpoint `{}`'.format(args.resume))
237 | checkpoint = torch.load(args.resume)
238 | args.start_iter = checkpoint['iter']
239 | best_prec1 = checkpoint['best_prec1']
240 | model.load_state_dict(checkpoint['state_dict'])
241 |
242 | if args.swa_start is not None:
243 | swa_state_dict = checkpoint['swa_state_dict']
244 | if swa_state_dict is not None:
245 | swa_model.load_state_dict(swa_state_dict)
246 | swa_n_ckpt = checkpoint['swa_n']
247 | if swa_n_ckpt is not None:
248 | swa_n = swa_n_ckpt
249 | best_swa_prec_ckpt = checkpoint['best_swa_prec']
250 | if best_swa_prec_ckpt is not None:
251 | best_swa_prec = best_swa_prec_ckpt
252 |
253 | logging.info('=> loaded checkpoint `{}` (iter: {})'.format(
254 | args.resume, checkpoint['iter']
255 | ))
256 | else:
257 | logging.info('=> no checkpoint found at `{}`'.format(args.resume))
258 |
259 | cudnn.benchmark = False
260 |
261 | train_loader = prepare_train_data(dataset=args.dataset,
262 | datadir=args.datadir,
263 | batch_size=args.batch_size,
264 | shuffle=True,
265 | num_workers=args.workers)
266 | test_loader = prepare_test_data(dataset=args.dataset,
267 | datadir=args.datadir,
268 | batch_size=args.batch_size,
269 | shuffle=False,
270 | num_workers=args.workers)
271 | if args.swa_start is not None:
272 | swa_loader = prepare_train_data(dataset=args.dataset,
273 | datadir=args.datadir,
274 | batch_size=args.batch_size,
275 | shuffle=False,
276 | num_workers=args.workers)
277 |
278 | # define loss function (criterion) and optimizer
279 | criterion = nn.CrossEntropyLoss().cuda()
280 |
281 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
282 | momentum=args.momentum,
283 | weight_decay=args.weight_decay)
284 |
285 | # optimizer = torch.optim.Adam(model.parameters(), args.lr,
286 | # weight_decay=args.weight_decay)
287 |
288 | batch_time = AverageMeter()
289 | data_time = AverageMeter()
290 | losses = AverageMeter()
291 | top1 = AverageMeter()
292 | cr = AverageMeter()
293 |
294 | end = time.time()
295 |
296 | global scale_loss
297 | global turning_point_count
298 | global my_loss_diff_indicator
299 |
300 | i = args.start_iter
301 | while i < args.iters:
302 | for input, target in train_loader:
303 | # measuring data loading time
304 | data_time.update(time.time() - end)
305 |
306 | model.train()
307 | adjust_learning_rate(args, optimizer, i)
308 | # adjust_precision(args, i)
309 | adaptive_adjust_precision(args, turning_point_count)
310 |
311 | i += 1
312 |
313 | fw_cost = args.num_bits*args.num_bits/32/32
314 | eb_cost = args.num_bits*args.num_grad_bits/32/32
315 | gc_cost = eb_cost
316 | cr.update((fw_cost+eb_cost+gc_cost)/3)
317 |
318 | target = target.squeeze().long().cuda()
319 | input_var = Variable(input).cuda()
320 | target_var = Variable(target).cuda()
321 |
322 | # compute output
323 | output = model(input_var, args.num_bits, args.num_grad_bits)
324 | loss = criterion(output, target_var)
325 | training_loss += loss.item()
326 |
327 | # measure accuracy and record loss
328 | prec1, = accuracy(output.data, target, topk=(1,))
329 | losses.update(loss.item(), input.size(0))
330 | top1.update(prec1.item(), input.size(0))
331 | training_acc += prec1.item()
332 |
333 | # compute gradient and do SGD step
334 | optimizer.zero_grad()
335 | loss.backward()
336 | optimizer.step()
337 |
338 | # measure elapsed time
339 | batch_time.update(time.time() - end)
340 | end = time.time()
341 |
342 | # print log
343 | if i % args.print_freq == 0:
344 | logging.info("Iter: [{0}/{1}]\t"
345 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
346 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
347 | "Loss {loss.val:.3f} ({loss.avg:.3f})\t"
348 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t".format(
349 | i,
350 | args.iters,
351 | batch_time=batch_time,
352 | data_time=data_time,
353 | loss=losses,
354 | top1=top1)
355 | )
356 |
357 |
358 | if args.swa_start is not None and i >= args.swa_start and i % args.swa_freq == 0:
359 | util_swa.moving_average(swa_model, model, 1.0 / (swa_n + 1))
360 | swa_n += 1
361 | util_swa.bn_update(swa_loader, swa_model, args.num_bits, args.num_grad_bits)
362 |                 with torch.no_grad():
363 |                     prec1 = validate(args, test_loader, swa_model, criterion, i, swa=True)
364 | if prec1 > best_swa_prec:
365 | best_swa_prec = prec1
366 | best_swa_iter = i
367 |
368 | # print("Current Best SWA Prec@1: ", best_swa_prec)
369 | # print("Current Best SWA Iteration: ", best_swa_iter)
370 |
371 | if (i % args.eval_every == 0 and i > 0) or (i == args.iters):
372 | # record training loss and test accuracy
373 | global history_score
374 | epoch = i // args.eval_every
375 | epoch_loss = training_loss / len(train_loader)
376 | with torch.no_grad():
377 | prec1 = validate(args, test_loader, model, criterion, i)
378 | # prec_full = validate_full_prec(args, test_loader, model, criterion, i)
379 | history_score[epoch-1][0] = epoch_loss
380 | history_score[epoch-1][1] = np.round(training_acc / len(train_loader), 2)
381 | history_score[epoch-1][2] = prec1
382 | training_loss = 0
383 | training_acc = 0
384 |
385 | np.savetxt(os.path.join(save_path, 'record.txt'), history_score, fmt = '%10.5f', delimiter=',')
386 |
387 | # apply indicator
388 | # if epoch == 1:
389 | # logging.info('initial loss value: {}'.format(epoch_loss))
390 | # my_loss_diff_indicator.scale_loss = epoch_loss
391 | if epoch <= 10:
392 | scale_loss += epoch_loss
393 | logging.info('scale_loss at epoch {}: {}'.format(epoch, scale_loss / epoch))
394 | my_loss_diff_indicator.scale_loss = scale_loss / epoch
395 | if turning_point_count < args.num_turning_point:
396 | my_loss_diff_indicator.get_loss(epoch_loss)
397 | flag = my_loss_diff_indicator.turning_point_emerge()
398 | if flag == True:
399 | turning_point_count += 1
400 | logging.info('find {}-th turning point at {}-th epoch'.format(turning_point_count, epoch))
401 | # print('find {}-th turning point at {}-th epoch'.format(turning_point_count, epoch))
402 | my_loss_diff_indicator.adaptive_threshold(turning_point_count=turning_point_count)
403 | my_loss_diff_indicator.reset()
404 |
405 | logging.info('Epoch [{}] num_bits = {} num_grad_bits = {}'.format(epoch, args.num_bits, args.num_grad_bits))
406 |
407 | # print statistics
408 | is_best = prec1 > best_prec1
409 | if is_best:
410 | best_prec1 = prec1
411 | best_iter = i
412 | # best_full_prec = max(prec_full, best_full_prec)
413 |
414 | logging.info("Current Best Prec@1: {}".format(best_prec1))
415 | logging.info("Current Best Iteration: {}".format(best_iter))
416 | logging.info("Current Best SWA Prec@1: {}".format(best_swa_prec))
417 | logging.info("Current Best SWA Iteration: {}".format(best_swa_iter))
418 | # print("Current Best Full Prec@1: ", best_full_prec)
419 |
420 | # checkpoint_path = os.path.join(args.save_path, 'checkpoint_{:05d}_{:.2f}.pth.tar'.format(i, prec1))
421 | checkpoint_path = os.path.join(args.save_path, 'ckpt.pth.tar')
422 | save_checkpoint({
423 | 'iter': i,
424 | 'arch': args.arch,
425 | 'state_dict': model.state_dict(),
426 | 'best_prec1': best_prec1,
427 | 'swa_state_dict' : swa_model.state_dict() if args.swa_start is not None else None,
428 | 'swa_n' : swa_n if args.swa_start is not None else None,
429 | 'best_swa_prec' : best_swa_prec if args.swa_start is not None else None,
430 | },
431 | is_best, filename=checkpoint_path)
432 | # shutil.copyfile(checkpoint_path, os.path.join(args.save_path,
433 | # 'checkpoint_latest'
434 | # '.pth.tar'))
435 |
436 | if i == args.iters:
437 | print("Best accuracy: "+str(best_prec1))
438 | history_score[-1][0] = best_prec1
439 | np.savetxt(os.path.join(save_path, 'record.txt'), history_score, fmt = '%10.5f', delimiter=',')
440 | break
441 |
442 |
443 | def validate(args, test_loader, model, criterion, step, swa=False):
444 | batch_time = AverageMeter()
445 | losses = AverageMeter()
446 | top1 = AverageMeter()
447 |
448 | # switch to evaluation mode
449 | model.eval()
450 | end = time.time()
451 | for i, (input, target) in enumerate(test_loader):
452 | target = target.squeeze().long().cuda()
453 |         input_var = input.cuda()   # Variable(volatile=True) was removed in PyTorch >= 0.4;
454 |         target_var = target        # callers wrap validate() in torch.no_grad() instead
455 |
456 | # compute output
457 | output = model(input_var, args.num_bits, args.num_grad_bits)
458 | loss = criterion(output, target_var)
459 |
460 | # measure accuracy and record loss
461 | prec1, = accuracy(output.data, target, topk=(1,))
462 | top1.update(prec1.item(), input.size(0))
463 | losses.update(loss.item(), input.size(0))
464 | batch_time.update(time.time() - end)
465 | end = time.time()
466 |
467 | if (i % args.print_freq == 0) or (i == len(test_loader) - 1):
468 | logging.info(
469 | 'Test: [{}/{}]\t'
470 | 'Time: {batch_time.val:.4f}({batch_time.avg:.4f})\t'
471 | 'Loss: {loss.val:.3f}({loss.avg:.3f})\t'
472 | 'Prec@1: {top1.val:.3f}({top1.avg:.3f})\t'.format(
473 | i, len(test_loader), batch_time=batch_time,
474 | loss=losses, top1=top1
475 | )
476 | )
477 |
478 | if not swa:
479 | logging.info('Step {} * Prec@1 {top1.avg:.3f}'.format(step, top1=top1))
480 | else:
481 | logging.info('Step {} * SWA Prec@1 {top1.avg:.3f}'.format(step, top1=top1))
482 |
483 | return top1.avg
484 |
485 |
486 | def validate_full_prec(args, test_loader, model, criterion, step):
487 | batch_time = AverageMeter()
488 | losses = AverageMeter()
489 | top1 = AverageMeter()
490 |
491 | # switch to evaluation mode
492 | model.eval()
493 | end = time.time()
494 | for i, (input, target) in enumerate(test_loader):
495 | target = target.squeeze().long().cuda()
496 |         input_var = input.cuda()   # evaluated under torch.no_grad() by the caller
497 |         target_var = target
498 |
499 | # compute output
500 | output = model(input_var, 0, 0)
501 | loss = criterion(output, target_var)
502 |
503 | # measure accuracy and record loss
504 | prec1, = accuracy(output.data, target, topk=(1,))
505 | top1.update(prec1.item(), input.size(0))
506 | losses.update(loss.item(), input.size(0))
507 | batch_time.update(time.time() - end)
508 | end = time.time()
509 |
510 |
511 | logging.info('Step {} * Full Prec@1 {top1.avg:.3f}'.format(step, top1=top1))
512 | return top1.avg
513 |
514 |
515 | def test_model(args):
516 | # create model
517 | model = models.__dict__[args.arch](args.pretrained)
518 | model = torch.nn.DataParallel(model).cuda()
519 |
520 | if args.resume:
521 | if os.path.isfile(args.resume):
522 | logging.info('=> loading checkpoint `{}`'.format(args.resume))
523 | checkpoint = torch.load(args.resume)
524 | args.start_iter = checkpoint['iter']
525 | best_prec1 = checkpoint['best_prec1']
526 | model.load_state_dict(checkpoint['state_dict'])
527 | logging.info('=> loaded checkpoint `{}` (iter: {})'.format(
528 | args.resume, checkpoint['iter']
529 | ))
530 | else:
531 | logging.info('=> no checkpoint found at `{}`'.format(args.resume))
532 |
533 | cudnn.benchmark = False
534 |     test_loader = prepare_test_data(dataset=args.dataset, datadir=args.datadir,
535 |                                     batch_size=args.batch_size,
536 | shuffle=False,
537 | num_workers=args.workers)
538 | criterion = nn.CrossEntropyLoss().cuda()
539 |
540 | # validate(args, test_loader, model, criterion)
541 |
542 | with torch.no_grad():
543 | prec1 = validate(args, test_loader, model, criterion, args.start_iter)
544 | prec_full = validate_full_prec(args, test_loader, model, criterion, args.start_iter)
545 |
546 |
547 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
548 | torch.save(state, filename)
549 | if is_best:
550 | save_path = os.path.dirname(filename)
551 | shutil.copyfile(filename, os.path.join(save_path,
552 | 'model_best.pth.tar'))
553 |
554 |
555 | class AverageMeter(object):
556 | """Computes and stores the average and current value"""
557 |
558 | def __init__(self):
559 | self.reset()
560 |
561 | def reset(self):
562 | self.val = 0
563 | self.avg = 0
564 | self.sum = 0
565 | self.count = 0
566 |
567 | def update(self, val, n=1):
568 | self.val = val
569 | self.sum += val * n
570 | self.count += n
571 | self.avg = self.sum / self.count
572 |
573 |
574 | schedule_cnt = 0
575 | def adjust_precision(args, _iter):
576 | if args.schedule:
577 | global schedule_cnt
578 |
579 | assert len(args.num_bits_schedule) == len(args.schedule) + 1
580 | assert len(args.num_grad_bits_schedule) == len(args.schedule) + 1
581 |
582 | if schedule_cnt == 0:
583 | args.num_bits = args.num_bits_schedule[0]
584 | args.num_grad_bits = args.num_grad_bits_schedule[0]
585 | schedule_cnt += 1
586 |
587 | for step in args.schedule:
588 | if _iter == step:
589 | args.num_bits = args.num_bits_schedule[schedule_cnt]
590 | args.num_grad_bits = args.num_grad_bits_schedule[schedule_cnt]
591 | schedule_cnt += 1
592 |
593 | if _iter % args.eval_every == 0:
594 | logging.info('Iter [{}] num_bits = {} num_grad_bits = {}'.format(_iter, args.num_bits, args.num_grad_bits))
595 |
596 | def adaptive_adjust_precision(args, turning_point_count):
597 | args.num_bits = args.num_bits_schedule[turning_point_count]
598 | args.num_grad_bits = args.num_grad_bits_schedule[turning_point_count]
599 |
600 |
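So under PFQ the bits schedules are indexed by how many turning points have fired rather than by iteration; with `--num_turning_point 3` each schedule needs four entries. A hedged example (the schedule values are illustrative, not the paper's settings):

```python
import argparse

args = argparse.Namespace(
    num_bits_schedule=[3, 6, 8, 8],        # hypothetical 4-stage fw/act precision
    num_grad_bits_schedule=[6, 8, 12, 16],
)
for turning_point_count in range(4):
    adaptive_adjust_precision(args, turning_point_count)
    print(turning_point_count, args.num_bits, args.num_grad_bits)
# 0 -> 3/6, 1 -> 6/8, 2 -> 8/12, 3 -> 8/16
```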
601 | def adjust_learning_rate(args, optimizer, _iter):
602 | if args.lr_schedule == 'piecewise':
603 | if args.warm_up and (_iter < 400):
604 | lr = 0.01
605 | elif 32000 <= _iter < 48000:
606 | lr = args.lr * (args.step_ratio ** 1)
607 | elif _iter >= 48000:
608 | lr = args.lr * (args.step_ratio ** 2)
609 | else:
610 | lr = args.lr
611 |
612 | elif args.lr_schedule == 'linear':
613 | t = _iter / args.iters
614 | lr_ratio = 0.01
615 | if args.warm_up and (_iter < 400):
616 | lr = 0.01
617 | elif t < 0.5:
618 | lr = args.lr
619 | elif t < 0.9:
620 | lr = args.lr * (1 - (1-lr_ratio)*(t-0.5)/0.4)
621 | else:
622 | lr = args.lr * lr_ratio
623 |
624 | elif args.lr_schedule == 'anneal_cosine':
625 | lr_min = args.lr * (args.step_ratio ** 2)
626 | lr_max = args.lr
627 |         lr = lr_min + 1/2 * (lr_max - lr_min) * (1 + np.cos(_iter / args.iters * np.pi))
628 |
629 | if _iter % args.eval_every == 0:
630 | logging.info('Iter [{}] learning rate = {}'.format(_iter, lr))
631 |
632 | for param_group in optimizer.param_groups:
633 | param_group['lr'] = lr
634 |
635 |
636 | def accuracy(output, target, topk=(1,)):
637 | """Computes the precision@k for the specified values of k"""
638 | maxk = max(topk)
639 | batch_size = target.size(0)
640 |
641 | _, pred = output.topk(maxk, 1, True, True)
642 | pred = pred.t()
643 | correct = pred.eq(target.view(1, -1).expand_as(pred))
644 |
645 | res = []
646 | for k in topk:
647 |         correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape: pred.t() is non-contiguous
648 | res.append(correct_k.mul_(100.0 / batch_size))
649 | return res
650 |
651 |
652 | if __name__ == '__main__':
653 | main()
654 |
--------------------------------------------------------------------------------
/fractrain_cifar/util_swa.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 | def moving_average(net1, net2, alpha=1):
5 | for param1, param2 in zip(net1.parameters(), net2.parameters()):
6 | param1.data *= (1.0 - alpha)
7 | param1.data += param2.data * alpha
8 |
9 |
10 | def _check_bn(module, flag):
11 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
12 | flag[0] = True
13 |
14 |
15 | def check_bn(model):
16 | flag = [False]
17 | model.apply(lambda module: _check_bn(module, flag))
18 | return flag[0]
19 |
20 |
21 | def reset_bn(module):
22 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
23 | module.running_mean = torch.zeros_like(module.running_mean)
24 | module.running_var = torch.ones_like(module.running_var)
25 |
26 |
27 | def _get_momenta(module, momenta):
28 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
29 | momenta[module] = module.momentum
30 |
31 |
32 | def _set_momenta(module, momenta):
33 | if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
34 | module.momentum = momenta[module]
35 |
36 |
37 | def bn_update(loader, model, num_bits, num_grad_bits):
38 | """
39 | BatchNorm buffers update (if any).
40 |     Performs one pass over the training set to estimate the buffer averages.
41 |     :param loader: train dataset loader for buffer average estimation.
42 |     :param model: model being updated
43 | :return: None
44 | """
45 | if not check_bn(model):
46 | return
47 | model.train()
48 | momenta = {}
49 | model.apply(reset_bn)
50 | model.apply(lambda module: _get_momenta(module, momenta))
51 | n = 0
52 |
53 | print("SWA Update BN...")
54 | for input, _ in loader:
55 |         input = input.cuda(non_blocking=True)  # `async` is a reserved word in Python 3.7+
56 | input_var = torch.autograd.Variable(input)
57 | b = input_var.data.size(0)
58 |
59 | momentum = b / (n + b)
60 | for module in momenta.keys():
61 | module.momentum = momentum
62 |
63 | model(input_var, num_bits, num_grad_bits)
64 | n += b
65 |
66 | model.apply(lambda module: _set_momenta(module, momenta))
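Putting the helpers together, the training scripts keep `swa_model` as a running average of SGD iterates and re-estimate BatchNorm statistics before evaluating it. A hypothetical wrapper, `swa_step`, mirroring the call sites in the `train_*.py` scripts:

```python
def swa_step(model, swa_model, swa_n, swa_loader, num_bits, num_grad_bits):
    """One SWA collection step: fold the current weights into the running
    average, then refresh the BN buffers at the current precision."""
    moving_average(swa_model, model, alpha=1.0 / (swa_n + 1))
    bn_update(swa_loader, swa_model, num_bits, num_grad_bits)
    return swa_n + 1
```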
--------------------------------------------------------------------------------
/fractrain_imagenet/data.py:
--------------------------------------------------------------------------------
1 | """prepare CIFAR and SVHN
2 | """
3 |
4 | from __future__ import print_function
5 |
6 | import torch
7 | import torchvision
8 | import torchvision.transforms as transforms
9 | import numpy as np
10 |
11 |
12 | crop_size = 32
13 | padding = 4
14 |
15 |
16 | def prepare_train_data(dataset='cifar10', datadir='/home/yf22/dataset', batch_size=128,
17 | shuffle=True, num_workers=4):
18 |
19 | if 'cifar' in dataset:
20 | transform_train = transforms.Compose([
21 | transforms.RandomCrop(crop_size, padding=padding),
22 | transforms.RandomHorizontalFlip(),
23 | transforms.ToTensor(),
24 | transforms.Normalize((0.4914, 0.4822, 0.4465),
25 | (0.2023, 0.1994, 0.2010)),
26 | ])
27 |
28 | trainset = torchvision.datasets.__dict__[dataset.upper()](
29 | root=datadir, train=True, download=True, transform=transform_train)
30 | train_loader = torch.utils.data.DataLoader(trainset,
31 | batch_size=batch_size,
32 | shuffle=shuffle,
33 | num_workers=num_workers)
34 |
35 |     elif 'imagenet' in dataset:  # elif: the trailing else must not reset the CIFAR loader to None
36 | train_dataset = torchvision.datasets.ImageFolder(
37 | datadir,
38 | transforms.Compose([
39 | transforms.RandomResizedCrop(224),
40 | transforms.RandomHorizontalFlip(),
41 | transforms.ToTensor(),
42 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
43 | std=[0.229, 0.224, 0.225])
44 | ]))
45 | train_loader = torch.utils.data.DataLoader(
46 | train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
47 |
48 | elif 'svhn' in dataset:
49 | transform_train =transforms.Compose([
50 | transforms.ToTensor(),
51 | transforms.Normalize((0.4377, 0.4438, 0.4728),
52 | (0.1980, 0.2010, 0.1970)),
53 | ])
54 | trainset = torchvision.datasets.__dict__[dataset.upper()](
55 | root=datadir,
56 | split='train',
57 | download=True,
58 | transform=transform_train
59 | )
60 |
61 | transform_extra = transforms.Compose([
62 | transforms.ToTensor(),
63 | transforms.Normalize((0.4300, 0.4284, 0.4427),
64 | (0.1963, 0.1979, 0.1995))
65 | ])
66 |
67 | extraset = torchvision.datasets.__dict__[dataset.upper()](
68 | root=datadir,
69 | split='extra',
70 | download=True,
71 | transform = transform_extra
72 | )
73 |
74 | total_data = torch.utils.data.ConcatDataset([trainset, extraset])
75 |
76 | train_loader = torch.utils.data.DataLoader(total_data,
77 | batch_size=batch_size,
78 | shuffle=shuffle,
79 | num_workers=num_workers)
80 | else:
81 | train_loader = None
82 | return train_loader
83 |
84 |
85 | def prepare_test_data(dataset='cifar10', datadir='/home/yf22/dataset', batch_size=128,
86 | shuffle=False, num_workers=4):
87 |
88 | if 'cifar' in dataset:
89 | transform_test = transforms.Compose([
90 | transforms.ToTensor(),
91 | transforms.Normalize((0.4914, 0.4822, 0.4465),
92 | (0.2023, 0.1994, 0.2010)),
93 | ])
94 |
95 | testset = torchvision.datasets.__dict__[dataset.upper()](root=datadir,
96 | train=False,
97 | download=True,
98 | transform=transform_test)
99 | test_loader = torch.utils.data.DataLoader(testset,
100 | batch_size=batch_size,
101 | shuffle=shuffle,
102 | num_workers=num_workers)
103 |
104 |     elif 'imagenet' in dataset:  # elif: the trailing else must not reset the CIFAR loader to None
105 | test_loader = torch.utils.data.DataLoader(
106 | torchvision.datasets.ImageFolder(datadir, transforms.Compose([
107 | transforms.Resize(256),
108 | transforms.CenterCrop(224),
109 | transforms.ToTensor(),
110 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
111 | std=[0.229, 0.224, 0.225])
112 | ])),
113 | batch_size=batch_size, shuffle=False, num_workers=num_workers)
114 |
115 | elif 'svhn' in dataset:
116 | transform_test = transforms.Compose([
117 | transforms.ToTensor(),
118 | transforms.Normalize((0.4524, 0.4525, 0.4690),
119 | (0.2194, 0.2266, 0.2285)),
120 | ])
121 | testset = torchvision.datasets.__dict__[dataset.upper()](
122 | root=datadir,
123 | split='test',
124 | download=True,
125 | transform=transform_test)
126 | np.place(testset.labels, testset.labels == 10, 0)
127 | test_loader = torch.utils.data.DataLoader(testset,
128 | batch_size=batch_size,
129 | shuffle=shuffle,
130 | num_workers=num_workers)
131 | else:
132 | test_loader = None
133 | return test_loader
134 |
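A quick usage sketch of the two entry points (the dataset path is a placeholder). Note that for ImageNet, `datadir` must point directly at the split directory (the `train/` or `val/` folder), since those branches use `ImageFolder`:

```python
train_loader = prepare_train_data(dataset='cifar10', datadir='/path/to/data',
                                  batch_size=128, shuffle=True, num_workers=4)
test_loader = prepare_test_data(dataset='cifar10', datadir='/path/to/data',
                                batch_size=128, shuffle=False, num_workers=4)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([128, 3, 32, 32]) torch.Size([128])
```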
--------------------------------------------------------------------------------
/fractrain_imagenet/models.py:
--------------------------------------------------------------------------------
1 | """ This file contains the model definitions for both original ResNet (6n+2
2 | layers) and SkipNets.
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 | import math
8 | from torch.autograd import Variable
9 | import torch.autograd as autograd
10 | from modules.quantize import quantize, quantize_grad, QConv2d, QLinear, RangeBN
11 | import torch.nn.functional as F
12 |
13 |
14 | ACT_FW = 0
15 | ACT_BW = 0
16 | GRAD_ACT_ERROR = 0
17 | GRAD_ACT_GC = 0
18 | WEIGHT_BITS = 0
19 | MOMENTUM = 0.9
20 |
21 | def Conv3x3(in_planes, out_planes, stride=1):
22 | "3x3 convolution with padding"
23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
24 | padding=1, bias=False)
25 |
26 |
27 | def conv3x3(in_planes, out_planes, stride=1, pool_size = None, fix_prec=False):
28 | "3x3 convolution with padding"
29 | return QConv2d(in_planes, out_planes, kernel_size=3, stride=stride,
30 | padding=1, bias=False, momentum=MOMENTUM, quant_act_forward=ACT_FW, quant_act_backward=ACT_BW,
31 | quant_grad_act_error=GRAD_ACT_ERROR, quant_grad_act_gc=GRAD_ACT_GC, weight_bits=WEIGHT_BITS, fix_prec=fix_prec)
32 |
33 | def conv1x1(in_planes, out_planes, stride=1, pool_size = None, padding=0, fix_prec=False):
34 | return QConv2d(in_planes, out_planes, kernel_size=1, stride=stride,
35 | padding=padding, bias=False, momentum=MOMENTUM, quant_act_forward=ACT_FW, quant_act_backward=ACT_BW,
36 | quant_grad_act_error=GRAD_ACT_ERROR, quant_grad_act_gc=GRAD_ACT_GC, weight_bits=WEIGHT_BITS, fix_prec=fix_prec)
37 |
38 | def make_bn(planes):
39 | return nn.BatchNorm2d(planes)
40 | # return RangeBN(planes)
41 |
42 |
43 | class BasicBlock(nn.Module):
44 | expansion = 1
45 |
46 | def __init__(self, inplanes, planes, stride=1, downsample=None, fix_prec=False):
47 | super(BasicBlock, self).__init__()
48 | self.conv1 = conv3x3(inplanes, planes, stride, fix_prec=fix_prec)
49 | self.bn1 = nn.BatchNorm2d(planes)
50 | self.relu = nn.ReLU(inplace=True)
51 | self.conv2 = conv3x3(planes, planes, fix_prec=fix_prec)
52 | self.bn2 = nn.BatchNorm2d(planes)
53 | self.downsample = downsample
54 | self.stride = stride
55 |
56 | def forward(self, x, num_bits, num_grad_bits, mask_list):
57 | residual = x
58 |
59 | out = self.conv1(x, num_bits, num_grad_bits, mask_list)
60 | out = self.bn1(out)
61 | out = self.relu(out)
62 |
63 | out = self.conv2(out, num_bits, num_grad_bits, mask_list)
64 | out = self.bn2(out)
65 |
66 | if self.downsample is not None:
67 | residual = self.downsample(x)
68 |
69 | out += residual
70 | out = self.relu(out)
71 |
72 | return out
73 |
74 |
75 | class Bottleneck(nn.Module):
76 | expansion = 4
77 |
78 | def __init__(self, inplanes, planes, stride=1, downsample=None, fix_prec=False):
79 | super(Bottleneck, self).__init__()
80 | self.conv1 = conv1x1(inplanes, planes, fix_prec=fix_prec)
81 | self.bn1 = nn.BatchNorm2d(planes)
82 | self.conv2 = conv3x3(planes, planes, stride=stride, fix_prec=fix_prec)
83 | self.bn2 = nn.BatchNorm2d(planes)
84 | self.conv3 = conv1x1(planes, planes * 4, fix_prec=fix_prec)
85 | self.bn3 = nn.BatchNorm2d(planes * 4)
86 | self.relu = nn.ReLU(inplace=True)
87 | self.downsample = downsample
88 | self.stride = stride
89 |
90 | def forward(self, x, num_bits, num_grad_bits, mask_list):
91 | residual = x
92 |
93 | out = self.conv1(x, num_bits, num_grad_bits, mask_list)
94 | out = self.bn1(out)
95 | out = self.relu(out)
96 |
97 | out = self.conv2(out, num_bits, num_grad_bits, mask_list)
98 | out = self.bn2(out)
99 | out = self.relu(out)
100 |
101 | out = self.conv3(out, num_bits, num_grad_bits, mask_list)
102 | out = self.bn3(out)
103 |
104 | if self.downsample is not None:
105 | residual = self.downsample(x)
106 |
107 | out += residual
108 | out = self.relu(out)
109 |
110 | return out
111 |
112 |
113 | # class ResNet(nn.Module):
114 | # def __init__(self, block, layers, num_classes=1000):
115 | # self.inplanes = 64
116 | # super(ResNet, self).__init__()
117 |
118 | # self.num_layers = layers
119 |
120 | # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
121 | # bias=False)
122 | # self.bn1 = nn.BatchNorm2d(64)
123 | # self.relu = nn.ReLU(inplace=True)
124 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
125 |
126 | # self.layer1 = self._make_group(block, 64, layers[0], group_id=1)
127 | # self.layer2 = self._make_group(block, 128, layers[1], group_id=2)
128 | # self.layer3 = self._make_group(block, 256, layers[2], group_id=3)
129 | # self.layer4 = self._make_group(block, 512, layers[3], group_id=4)
130 |
131 | # self.avgpool = nn.AvgPool2d(7)
132 | # self.fc = nn.Linear(512 * block.expansion, num_classes)
133 |
134 | # for m in self.modules():
135 | # if isinstance(m, nn.Conv2d):
136 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
137 | # m.weight.data.normal_(0, math.sqrt(2. / n))
138 | # if isinstance(m, QConv2d):
139 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
140 | # m.weight.data.normal_(0, math.sqrt(2. / n))
141 | # elif isinstance(m, nn.BatchNorm2d):
142 | # m.weight.data.fill_(1)
143 | # m.bias.data.zero_()
144 | # elif isinstance(m, nn.Linear):
145 | # n = m.weight.size(0) * m.weight.size(1)
146 | # m.weight.data.normal_(0, math.sqrt(2. / n))
147 |
148 |
149 | # def _make_group(self, block, planes, layers, group_id):
150 | # """ Create the whole group"""
151 | # for i in range(layers):
152 | # if group_id > 1 and i == 0:
153 | # stride = 2
154 | # else:
155 | # stride = 1
156 |
157 | # layer = self._make_layer(block, planes, stride=stride)
158 |
159 | # setattr(self, 'group{}_layer{}'.format(group_id, i), layer)
160 |
161 |
162 | # def _make_layer(self, block, planes, stride=1):
163 | # downsample = None
164 | # if stride != 1 or self.inplanes != planes * block.expansion:
165 | # downsample = nn.Sequential(
166 | # nn.Conv2d(self.inplanes, planes * block.expansion,
167 | # kernel_size=1, stride=stride, bias=False),
168 | # nn.BatchNorm2d(planes * block.expansion),
169 | # )
170 |
171 | # layer = block(self.inplanes, planes, stride, downsample, fix_prec=True)
172 | # self.inplanes = planes * block.expansion
173 |
174 | # return layer
175 |
176 | # def forward(self, x, num_bits, num_grad_bits):
177 | # x = self.conv1(x)
178 | # x = self.bn1(x)
179 | # x = self.relu(x)
180 | # x = self.maxpool(x)
181 |
182 | # for g in range(len(self.num_layers)):
183 | # for i in range(self.num_layers[g]):
184 | # x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x, num_bits, num_grad_bits)
185 |
186 | # x = self.avgpool(x)
187 | # x = x.view(x.size(0), -1)
188 | # x = self.fc(x)
189 | # return x
190 |
191 |
192 | # def resnet18(pretrained=False, **kwargs):
193 | # model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
194 | # return model
195 |
196 |
197 | # def resnet34(pretrained=False, **kwargs):
198 | # model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
199 | # return model
200 |
201 |
202 | # def resnet50(pretrained=False, **kwargs):
203 | # model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
204 | # return model
205 |
206 |
207 |
208 | class RNNGate(nn.Module):
209 | """Recurrent Gate definition.
210 | Input is already passed through average pooling and embedding."""
211 | def __init__(self, input_dim, hidden_dim, proj_dim, rnn_type='lstm'):
212 | super(RNNGate, self).__init__()
213 | self.rnn_type = rnn_type
214 | self.input_dim = input_dim
215 | self.hidden_dim = hidden_dim
216 | self.proj_dim = proj_dim
217 |
218 | if self.rnn_type == 'lstm':
219 | self.rnn_one = nn.LSTM(input_dim, hidden_dim)
220 | # self.rnn_two = nn.LSTM(hidden_dim, hidden_dim)
221 | else:
222 | self.rnn = None
223 | self.hidden_one = None
224 | # self.hidden_two = None
225 |
226 | # reduce dim
227 | self.proj = nn.Linear(hidden_dim, proj_dim)
228 | # self.proj_two = nn.Linear(hidden_dim, 4)
229 | self.prob = nn.Sigmoid()
230 |         self.prob_layer = nn.Softmax(dim=-1)  # explicit dim; the implicit choice is deprecated
231 |
232 | def init_hidden(self, batch_size):
233 | # The axes semantics are (num_layers, minibatch_size, hidden_dim)
234 | return (autograd.Variable(torch.zeros(1, batch_size,
235 | self.hidden_dim).cuda()),
236 | autograd.Variable(torch.zeros(1, batch_size,
237 | self.hidden_dim).cuda()))
238 |
239 | def repackage_hidden(self):
240 | self.hidden_one = repackage_hidden(self.hidden_one)
241 | # self.hidden_two = repackage_hidden(self.hidden_two)
242 | def forward(self, x):
243 | # Take the convolution output of each step
244 | batch_size = x.size(0)
245 | # self.rnn_one.flatten_parameters()
246 | # self.rnn_two.flatten_parameters()
247 |
248 | out_one, self.hidden_one = self.rnn_one(x.view(1, batch_size, -1), self.hidden_one)
249 |
250 | # out_one = F.dropout(out_one, p = 0.1, training=True)
251 |
252 | # out_two, self.hidden_two = self.rnn_two(out_one.view(1, batch_size, -1), self.hidden_two)
253 |
254 | x_one = self.proj(out_one.squeeze())
255 | # x_two = self.proj_two(out_two.squeeze())
256 |
257 | # proj = self.proj(out.squeeze())
258 | prob = self.prob_layer(x_one)
259 | # prob_two = self.prob_layer(x_two)
260 |
261 | # x_one = (prob > 0.5).float().detach() - \
262 | # prob.detach() + prob
263 |
264 | # x_two = prob_two.detach().cpu().numpy()
265 |
266 | x_one = prob.detach().cpu().numpy()
267 |
268 | hard = (x_one == x_one.max(axis=1)[:,None]).astype(int)
269 | hard = torch.from_numpy(hard)
270 | hard = hard.cuda()
271 |
272 | # x_two = hard.float().detach() - \
273 | # prob_two.detach() + prob_two
274 |
275 | x_one = hard.float().detach() - \
276 | prob.detach() + prob
277 |
278 | # print(x_one)
279 |
280 | x_one = x_one.view(x_one.size(0),x_one.size(1), 1, 1, 1)
281 |
282 | # x_two = x_two.view(x_two.size(0), x_two.size(1), 1, 1, 1)
283 |
284 | return x_one # , x_two
285 |
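The `hard.float().detach() - prob.detach() + prob` line above is the standard straight-through estimator: the forward pass emits the hard one-hot precision choice while gradients flow through the soft softmax probabilities. A self-contained illustration of the trick:

```python
import torch

prob = torch.tensor([[0.2, 0.5, 0.3]], requires_grad=True)
hard = (prob == prob.max(dim=1, keepdim=True).values).float()
out = hard.detach() - prob.detach() + prob  # forward value: one-hot; gradient: identity w.r.t. prob
out.sum().backward()
print(out)        # tensor([[0., 1., 0.]], grad_fn=...)
print(prob.grad)  # tensor([[1., 1., 1.]])
```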
286 |
287 |
288 | class ResNet_RNN(nn.Module):
289 | def __init__(self, block, layers, num_classes=1000, embed_dim=40, hidden_dim=20, proj_dim=7):
290 | self.inplanes = 64
291 | super(ResNet_RNN, self).__init__()
292 |
293 | self.num_layers = layers
294 | self.embed_dim = embed_dim
295 | self.hidden_dim = hidden_dim
296 |
297 | self.control = RNNGate(embed_dim, hidden_dim, proj_dim, rnn_type='lstm')
298 |
299 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
300 | bias=False)
301 | self.bn1 = nn.BatchNorm2d(64)
302 | self.relu = nn.ReLU(inplace=True)
303 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
304 | self.gate_layer1 = nn.Sequential(nn.AvgPool2d(56),
305 | nn.Conv2d(in_channels=64, out_channels=self.embed_dim, kernel_size=1, stride=1))
306 |
307 | self.layer1 = self._make_group(block, 64, layers[0], group_id=1, pool_size=56)
308 | self.layer2 = self._make_group(block, 128, layers[1], group_id=2, pool_size=28)
309 | self.layer3 = self._make_group(block, 256, layers[2], group_id=3, pool_size=14)
310 | self.layer4 = self._make_group(block, 512, layers[3], group_id=4, pool_size=7)
311 |
312 | self.avgpool = nn.AvgPool2d(7)
313 | self.fc = nn.Linear(512 * block.expansion, num_classes)
314 |
315 | for m in self.modules():
316 | if isinstance(m, nn.Conv2d):
317 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
318 | m.weight.data.normal_(0, math.sqrt(2. / n))
319 | if isinstance(m, QConv2d):
320 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
321 | m.weight.data.normal_(0, math.sqrt(2. / n))
322 | elif isinstance(m, nn.BatchNorm2d):
323 | m.weight.data.fill_(1)
324 | m.bias.data.zero_()
325 | elif isinstance(m, nn.Linear):
326 | n = m.weight.size(0) * m.weight.size(1)
327 | m.weight.data.normal_(0, math.sqrt(2. / n))
328 |
329 |
330 | def _make_group(self, block, planes, layers, group_id, pool_size):
331 | """ Create the whole group"""
332 | for i in range(layers):
333 | if group_id > 1 and i == 0:
334 | stride = 2
335 | else:
336 | stride = 1
337 |
338 | layer, gate_layer = self._make_layer(block, planes, stride=stride, pool_size=pool_size)
339 |
340 | setattr(self, 'group{}_layer{}'.format(group_id, i), layer)
341 | setattr(self, 'group{}_gate{}'.format(group_id, i), gate_layer)
342 |
343 |
344 | def _make_layer(self, block, planes, pool_size, stride=1):
345 | downsample = None
346 | if stride != 1 or self.inplanes != planes * block.expansion:
347 | downsample = nn.Sequential(
348 | nn.Conv2d(self.inplanes, planes * block.expansion,
349 | kernel_size=1, stride=stride, bias=False),
350 | nn.BatchNorm2d(planes * block.expansion),
351 | )
352 |
353 | layer = block(self.inplanes, planes, stride, downsample)
354 | self.inplanes = planes * block.expansion
355 |
356 | gate_layer = nn.Sequential(
357 | nn.AvgPool2d(pool_size),
358 | nn.Conv2d(in_channels=planes * block.expansion,
359 | out_channels=self.embed_dim,
360 | kernel_size=1,
361 | stride=1))
362 |
363 | return layer, gate_layer
364 |
365 | def forward(self, x, bits, grad_bits):
366 | x = self.conv1(x)
367 | x = self.bn1(x)
368 | x = self.relu(x)
369 | x = self.maxpool(x)
370 |
371 | batch_size = x.size(0)
372 | self.control.hidden_one = self.control.init_hidden(batch_size)
373 |
374 | masks = []
375 |
376 | gate_feature = self.gate_layer1(x)
377 | mask = self.control(gate_feature)
378 |
379 | for g in range(len(self.num_layers)):
380 | for i in range(self.num_layers[g]):
381 |
382 | mask_list = []
383 |
384 | for j in range(len(bits)):
385 | mask_list.append(mask[:,j,:,:,:])
386 |
387 | x = getattr(self, 'group{}_layer{}'.format(g+1, i))(x, bits, grad_bits, mask_list)
388 |
389 | mask_list = [mask.squeeze() for mask in mask_list]
390 |
391 | masks.append(mask_list)
392 |
393 | gate_feature = getattr(self, 'group{}_gate{}'.format(g+1, i))(x)
394 | mask = self.control(gate_feature)
395 | # mask_grad = self.control_grad(gate_feature)
396 |
397 | x = self.avgpool(x)
398 | x = x.view(x.size(0), -1)
399 | x = self.fc(x)
400 |
401 | return x, masks
402 |
403 |
404 | def resnet18_rnn(pretrained=False, **kwargs):
405 | model = ResNet_RNN(BasicBlock, [2, 2, 2, 2], **kwargs)
406 | return model
407 |
408 |
409 | def resnet34_rnn(pretrained=False, **kwargs):
410 | model = ResNet_RNN(BasicBlock, [3, 4, 6, 3], **kwargs)
411 | return model
412 |
413 |
414 | def resnet50_rnn(pretrained=False, **kwargs):
415 | model = ResNet_RNN(Bottleneck, [3, 4, 6, 3], **kwargs)
416 | return model
417 |
418 |
419 |
420 | if __name__ == '__main__':
421 |     # quick profiling hook; the RNN gates allocate CUDA hidden states, and
422 |     # ResNet_RNN.forward() expects candidate bit-width lists (values here are illustrative)
423 |     model = resnet18_rnn().cuda()
424 |     from thop import profile
425 |     flops, params = profile(model, inputs=(torch.randn(2, 3, 224, 224).cuda(), [4], [8]))
426 |     print('flops:', flops, 'params:', params)
427 |
428 |
429 |
--------------------------------------------------------------------------------
/fractrain_imagenet/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/fractrain_imagenet/modules/__init__.py
--------------------------------------------------------------------------------
/fractrain_imagenet/modules/bwn.py:
--------------------------------------------------------------------------------
1 | """
2 | Bounded weight norm
3 | Weight Normalization from https://arxiv.org/abs/1602.07868
4 | taken and adapted from https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/weight_norm.py
5 | """
6 | import torch
7 | from torch.nn.parameter import Parameter
8 | from torch.autograd import Variable, Function
9 | import torch.nn as nn
10 |
11 |
12 | def gather_params(self, memo=None, param_func=lambda s: s._parameters.values()):
13 | if memo is None:
14 | memo = set()
15 | for p in param_func(self):
16 | if p is not None and p not in memo:
17 | memo.add(p)
18 | yield p
19 | for m in self.children():
20 | for p in gather_params(m, memo, param_func):
21 | yield p
22 |
23 | nn.Module.gather_params = gather_params
24 |
25 |
26 | def _norm(x, dim, p=2):
27 | """Computes the norm over all dimensions except dim"""
28 | if p == float('inf'): # infinity norm
29 | func = lambda x, dim: x.abs().max(dim=dim)[0]
30 | else:
31 | func = lambda x, dim: torch.norm(x, dim=dim, p=p)
32 | if dim is None:
33 | return x.norm(p=p)
34 | elif dim == 0:
35 | output_size = (x.size(0),) + (1,) * (x.dim() - 1)
36 | return func(x.contiguous().view(x.size(0), -1), 1).view(*output_size)
37 | elif dim == x.dim() - 1:
38 | output_size = (1,) * (x.dim() - 1) + (x.size(-1),)
39 | return func(x.contiguous().view(-1, x.size(-1)), 0).view(*output_size)
40 | else:
41 | return _norm(x.transpose(0, dim), 0).transpose(0, dim)
42 |
43 |
44 | def _mean(p, dim):
45 | """Computes the mean over all dimensions except dim"""
46 | if dim is None:
47 | return p.mean()
48 | elif dim == 0:
49 | output_size = (p.size(0),) + (1,) * (p.dim() - 1)
50 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size)
51 | elif dim == p.dim() - 1:
52 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
53 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size)
54 | else:
55 | return _mean(p.transpose(0, dim), 0).transpose(0, dim)
56 |
57 |
58 | class BoundedWeightNorm(object):
59 |
60 | def __init__(self, name, dim, p):
61 | self.name = name
62 | self.dim = dim
63 | self.p = p
64 |
65 | def compute_weight(self, module):
66 | v = getattr(module, self.name + '_v')
67 | pre_norm = getattr(module, self.name + '_prenorm')
68 | return v * (pre_norm / _norm(v, self.dim, p=self.p))
69 |
70 | @staticmethod
71 | def apply(module, name, dim, p):
72 |         fn = BoundedWeightNorm(name, dim, p)
73 |
74 | weight = getattr(module, name)
75 |
76 | # remove w from parameter list
77 | del module._parameters[name]
78 |
79 | prenorm = _norm(weight, dim, p=p).mean()
80 | module.register_buffer(name + '_prenorm', prenorm.detach())
81 | pre_norm = getattr(module, name + '_prenorm')
82 |         # print(pre_norm)  # debug output
83 | module.register_parameter(name + '_v', Parameter(weight.data))
84 | setattr(module, name, fn.compute_weight(module))
85 |
86 | # recompute weight before every forward()
87 | module.register_forward_pre_hook(fn)
88 |
89 | def gather_normed_params(self, memo=None, param_func=lambda s: fn.compute_weight(s)):
90 | return gather_params(self, memo, param_func)
91 | module.gather_params = gather_normed_params
92 | return fn
93 |
94 | def remove(self, module):
95 | weight = self.compute_weight(module)
96 | delattr(module, self.name)
97 |         del module._buffers[self.name + '_prenorm']  # prenorm is a buffer, not a parameter
98 | del module._parameters[self.name + '_v']
99 | module.register_parameter(self.name, Parameter(weight.data))
100 |
101 | def __call__(self, module, inputs):
102 | setattr(module, self.name, self.compute_weight(module))
103 |
104 |
105 | def weight_norm(module, name='weight', dim=0, p=2):
106 | r"""Applies weight normalization to a parameter in the given module.
107 |
108 | .. math::
109 | \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
110 |
111 | Weight normalization is a reparameterization that decouples the magnitude
112 | of a weight tensor from its direction. This replaces the parameter specified
113 | by `name` (e.g. "weight") with two parameters: one specifying the magnitude
114 | (e.g. "weight_g") and one specifying the direction (e.g. "weight_v").
115 | Weight normalization is implemented via a hook that recomputes the weight
116 | tensor from the magnitude and direction before every :meth:`~Module.forward`
117 | call.
118 |
119 | By default, with `dim=0`, the norm is computed independently per output
120 | channel/plane. To compute a norm over the entire weight tensor, use
121 | `dim=None`.
122 |
123 | See https://arxiv.org/abs/1602.07868
124 |
125 | Args:
126 | module (nn.Module): containing module
127 | name (str, optional): name of weight parameter
128 | dim (int, optional): dimension over which to compute the norm
129 |
130 | Returns:
131 | The original module with the weight norm hook
132 |
133 | Example::
134 |
135 | >>> m = weight_norm(nn.Linear(20, 40), name='weight')
136 | Linear (20 -> 40)
137 | >>> m.weight_g.size()
138 | torch.Size([40, 1])
139 | >>> m.weight_v.size()
140 | torch.Size([40, 20])
141 |
142 | """
143 |     BoundedWeightNorm.apply(module, name, dim, p)
144 | return module
145 |
146 |
147 | def remove_weight_norm(module, name='weight'):
148 | r"""Removes the weight normalization reparameterization from a module.
149 |
150 | Args:
151 | module (nn.Module): containing module
152 | name (str, optional): name of weight parameter
153 |
154 | Example:
155 | >>> m = weight_norm(nn.Linear(20, 40))
156 | >>> remove_weight_norm(m)
157 | """
158 | for k, hook in module._forward_pre_hooks.items():
159 |         if isinstance(hook, BoundedWeightNorm) and hook.name == name:
160 | hook.remove(module)
161 | del module._forward_pre_hooks[k]
162 | return module
163 |
164 | raise ValueError("weight_norm of '{}' not found in {}"
165 | .format(name, module))
166 |
--------------------------------------------------------------------------------
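A usage sketch for the bounded weight norm above, assuming the module is importable as `modules.bwn`: it reparameterizes a stock `nn.Conv2d` and then bakes the normalized weight back in.

```python
import torch
import torch.nn as nn
from modules.bwn import weight_norm, remove_weight_norm

conv = weight_norm(nn.Conv2d(3, 8, 3), name='weight', dim=0, p=2)
y = conv(torch.randn(1, 3, 32, 32))  # weight rebuilt from weight_v before forward
remove_weight_norm(conv)             # restores a plain `weight` Parameter
```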
/fractrain_imagenet/modules/quantize.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd.function import InplaceFunction, Function
7 |
8 | QParams = namedtuple('QParams', ['range', 'zero_point', 'num_bits'])
9 |
10 | _DEFAULT_FLATTEN = (1, -1)
11 | _DEFAULT_FLATTEN_GRAD = (0, -1)
12 |
13 |
14 | def _deflatten_as(x, x_full):
15 | shape = list(x.shape) + [1] * (x_full.dim() - x.dim())
16 | return x.view(*shape)
17 |
18 |
19 | def calculate_qparams(x, num_bits, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, reduce_type='mean', keepdim=False, true_zero=False):
20 | with torch.no_grad():
21 | x_flat = x.flatten(*flatten_dims)
22 | if x_flat.dim() == 1:
23 | min_values = _deflatten_as(x_flat.min(), x)
24 | max_values = _deflatten_as(x_flat.max(), x)
25 | else:
26 | min_values = _deflatten_as(x_flat.min(-1)[0], x)
27 | max_values = _deflatten_as(x_flat.max(-1)[0], x)
28 | if reduce_dim is not None:
29 | if reduce_type == 'mean':
30 | min_values = min_values.mean(reduce_dim, keepdim=keepdim)
31 | max_values = max_values.mean(reduce_dim, keepdim=keepdim)
32 | else:
33 | min_values = min_values.min(reduce_dim, keepdim=keepdim)[0]
34 | max_values = max_values.max(reduce_dim, keepdim=keepdim)[0]
35 | # TODO: re-add true zero computation
36 | range_values = max_values - min_values
37 | return QParams(range=range_values, zero_point=min_values,
38 | num_bits=num_bits)
39 |
40 |
41 | class UniformQuantize(InplaceFunction):
42 |
43 | @staticmethod
44 | def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN,
45 | reduce_dim=0, dequantize=True, signed=False, stochastic=False, inplace=False):
46 |
47 | ctx.inplace = inplace
48 |
49 | if ctx.inplace:
50 | ctx.mark_dirty(input)
51 | output = input
52 | else:
53 | output = input.clone()
54 |
55 | if qparams is None:
56 |             assert num_bits is not None, "either provide qparams or num_bits to quantize"
57 | qparams = calculate_qparams(
58 | input, num_bits=num_bits, flatten_dims=flatten_dims, reduce_dim=reduce_dim)
59 |
60 | zero_point = qparams.zero_point
61 | num_bits = qparams.num_bits
62 | qmin = -(2.**(num_bits - 1)) if signed else 0.
63 | qmax = qmin + 2.**num_bits - 1.
64 | scale = qparams.range / (qmax - qmin)
65 |
66 | min_scale = torch.tensor(1e-8).expand_as(scale).cuda()
67 | scale = torch.max(scale, min_scale)
68 |
69 | with torch.no_grad():
70 | output.add_(qmin * scale - zero_point).div_(scale)
71 | if stochastic:
72 | noise = output.new(output.shape).uniform_(-0.5, 0.5)
73 | output.add_(noise)
74 | # quantize
75 | output.clamp_(qmin, qmax).round_()
76 |
77 | if dequantize:
78 | output.mul_(scale).add_(
79 | zero_point - qmin * scale) # dequantize
80 | return output
81 |
82 | @staticmethod
83 | def backward(ctx, grad_output):
84 | # straight-through estimator
85 | grad_input = grad_output
86 | return grad_input, None, None, None, None, None, None, None, None
87 |
88 |
89 | class UniformQuantizeGrad(InplaceFunction):
90 |
91 | @staticmethod
92 | def forward(ctx, input, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD,
93 | reduce_dim=0, dequantize=True, signed=False, stochastic=True):
94 | ctx.num_bits = num_bits
95 | ctx.qparams = qparams
96 | ctx.flatten_dims = flatten_dims
97 | ctx.stochastic = stochastic
98 | ctx.signed = signed
99 | ctx.dequantize = dequantize
100 | ctx.reduce_dim = reduce_dim
101 | ctx.inplace = False
102 | return input
103 |
104 | @staticmethod
105 | def backward(ctx, grad_output):
106 | qparams = ctx.qparams
107 | with torch.no_grad():
108 | if qparams is None:
109 |                 assert ctx.num_bits is not None, "either provide qparams or num_bits to quantize"
110 | qparams = calculate_qparams(
111 | grad_output, num_bits=ctx.num_bits, flatten_dims=ctx.flatten_dims, reduce_dim=ctx.reduce_dim, reduce_type='extreme')
112 |
113 | grad_input = quantize(grad_output, num_bits=None,
114 | qparams=qparams, flatten_dims=ctx.flatten_dims, reduce_dim=ctx.reduce_dim,
115 | dequantize=True, signed=ctx.signed, stochastic=ctx.stochastic, inplace=False)
116 | return grad_input, None, None, None, None, None, None, None
117 |
118 |
119 | def conv2d_biprec(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, num_bits_grad=None):
120 | out1 = F.conv2d(input.detach(), weight, bias,
121 | stride, padding, dilation, groups)
122 | out2 = F.conv2d(input, weight.detach(), bias.detach() if bias is not None else None,
123 | stride, padding, dilation, groups)
124 | out2 = quantize_grad(out2, num_bits=num_bits_grad, flatten_dims=(1, -1))
125 | return out1 + out2 - out1.detach()
126 |
127 |
128 | def linear_biprec(input, weight, bias=None, num_bits_grad=None):
129 | out1 = F.linear(input.detach(), weight, bias)
130 | out2 = F.linear(input, weight.detach(), bias.detach()
131 | if bias is not None else None)
132 | out2 = quantize_grad(out2, num_bits=num_bits_grad)
133 | return out1 + out2 - out1.detach()
134 |
135 |
136 | def quantize(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN, reduce_dim=0, dequantize=True, signed=False, stochastic=False, inplace=False):
137 | if qparams:
138 | if qparams.num_bits:
139 | return UniformQuantize().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic, inplace)
140 | elif num_bits:
141 | return UniformQuantize().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic, inplace)
142 |
143 | return x
144 |
145 |
146 | def quantize_grad(x, num_bits=None, qparams=None, flatten_dims=_DEFAULT_FLATTEN_GRAD, reduce_dim=0, dequantize=True, signed=False, stochastic=True):
147 | if qparams:
148 | if qparams.num_bits:
149 | return UniformQuantizeGrad().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic)
150 | elif num_bits:
151 | return UniformQuantizeGrad().apply(x, num_bits, qparams, flatten_dims, reduce_dim, dequantize, signed, stochastic)
152 |
153 | return x
154 |
155 |
156 | class QuantMeasure(nn.Module):
157 | """docstring for QuantMeasure."""
158 |
159 | def __init__(self, shape_measure=(1,), flatten_dims=_DEFAULT_FLATTEN,
160 | inplace=False, dequantize=True, stochastic=False, momentum=0.9, measure=False):
161 | super(QuantMeasure, self).__init__()
162 | self.register_buffer('running_zero_point', torch.zeros(*shape_measure))
163 | self.register_buffer('running_range', torch.zeros(*shape_measure))
164 | self.measure = measure
165 | if self.measure:
166 | self.register_buffer('num_measured', torch.zeros(1))
167 | self.flatten_dims = flatten_dims
168 | self.momentum = momentum
169 | self.dequantize = dequantize
170 | self.stochastic = stochastic
171 | self.inplace = inplace
172 |
173 | def forward(self, input, num_bits, qparams=None):
174 |
175 | if self.training or self.measure:
176 | if qparams is None:
177 | qparams = calculate_qparams(
178 | input, num_bits=num_bits, flatten_dims=self.flatten_dims, reduce_dim=0, reduce_type='extreme')
179 | with torch.no_grad():
180 | if self.measure:
181 | momentum = self.num_measured / (self.num_measured + 1)
182 | self.num_measured += 1
183 | else:
184 | momentum = self.momentum
185 | self.running_zero_point.mul_(momentum).add_(
186 | qparams.zero_point * (1 - momentum))
187 | self.running_range.mul_(momentum).add_(
188 | qparams.range * (1 - momentum))
189 | else:
190 | qparams = QParams(range=self.running_range,
191 | zero_point=self.running_zero_point, num_bits=num_bits)
192 | if self.measure:
193 | return input
194 | else:
195 | q_input = quantize(input, qparams=qparams, dequantize=self.dequantize,
196 | stochastic=self.stochastic, inplace=self.inplace)
197 | return q_input
198 |
199 |
200 | class QConv2d(nn.Conv2d):
201 | """docstring for QConv2d."""
202 |
203 | def __init__(self, in_channels, out_channels, kernel_size,
204 | stride=1, padding=0, dilation=1, groups=1, bias=True, momentum=0.1, quant_act_forward=0, quant_act_backward=0, quant_grad_act_error=0, quant_grad_act_gc=0, weight_bits=0, fix_prec=False):
205 | super(QConv2d, self).__init__(in_channels, out_channels, kernel_size,
206 | stride, padding, dilation, groups, bias)
207 |
208 | self.quantize_input_fw = QuantMeasure(shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1), momentum=momentum)
209 | self.quantize_input_bw = QuantMeasure(shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1), momentum=momentum)
210 | self.quant_act_forward = quant_act_forward
211 | self.quant_act_backward = quant_act_backward
212 | self.quant_grad_act_error = quant_grad_act_error
213 | self.quant_grad_act_gc = quant_grad_act_gc
214 | self.weight_bits = weight_bits
215 | self.fix_prec = fix_prec
216 | self.stride = stride
217 |
218 |
219 | def forward(self, input, num_bits, num_grad_bits, mask_list):
220 |
221 | input_candidates = [self.quantize_input_fw(input, num_bits=bit) for bit in num_bits]
222 | x = sum([mask_list[k].expand_as(input) * input_candidates[k] for k in range(len(num_bits))])
223 |
224 | weight_qparams = calculate_qparams(self.weight, num_bits=self.weight_bits, flatten_dims=(1, -1), reduce_dim=None)
225 | qweight = quantize(self.weight, qparams=weight_qparams)
226 | qbias = None
227 |
228 | # qinput = self.quantize_input_fw(input, num_bits)
229 | output = F.conv2d(x, qweight, qbias, self.stride, self.padding, self.dilation, self.groups)
230 |
231 | output_candidates = [quantize_grad(output, num_bits=bit) for bit in num_grad_bits]
232 | x = sum([mask_list[k].expand_as(output).detach() * output_candidates[k] for k in range(len(num_grad_bits))])
233 |
234 | return x
235 |
236 |
237 | def conv2d_quant_act(self, input_fw, input_bw, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, error_bits=0, gc_bits=0):
238 | out1 = F.conv2d(input_fw, weight.detach(), bias.detach() if bias is not None else None,
239 | stride, padding, dilation, groups)
240 | out2 = F.conv2d(input_bw.detach(), weight, bias,
241 | stride, padding, dilation, groups)
242 | out1 = quantize_grad(out1, num_bits=error_bits)
243 | out2 = quantize_grad(out2, num_bits=gc_bits)
244 | return out1 + out2 - out2.detach()
245 |
246 |
247 | class QLinear(nn.Linear):
248 |     """docstring for QLinear."""
249 |
250 | def __init__(self, in_features, out_features, bias=True, num_bits=8, num_bits_weight=8, num_bits_grad=8, biprecision=True):
251 | super(QLinear, self).__init__(in_features, out_features, bias)
252 | self.num_bits = num_bits
253 | self.num_bits_weight = num_bits_weight or num_bits
254 | self.num_bits_grad = num_bits_grad
255 | self.biprecision = biprecision
256 |         self.quantize_input = QuantMeasure(shape_measure=(1,), flatten_dims=(1, -1))
257 |
258 | def forward(self, input):
259 |         qinput = self.quantize_input(input, self.num_bits)
260 | weight_qparams = calculate_qparams(
261 | self.weight, num_bits=self.num_bits_weight, flatten_dims=(1, -1), reduce_dim=None)
262 | qweight = quantize(self.weight, qparams=weight_qparams)
263 | if self.bias is not None:
264 | qbias = quantize(
265 | self.bias, num_bits=self.num_bits_weight + self.num_bits,
266 | flatten_dims=(0, -1))
267 | else:
268 | qbias = None
269 |
270 | if not self.biprecision or self.num_bits_grad is None:
271 | output = F.linear(qinput, qweight, qbias)
272 | if self.num_bits_grad is not None:
273 | output = quantize_grad(
274 | output, num_bits=self.num_bits_grad)
275 | else:
276 | output = linear_biprec(qinput, qweight, qbias, self.num_bits_grad)
277 | return output
278 |
279 |
280 | class RangeBN(nn.Module):
281 | # this is normalized RangeBN
282 |
283 | def __init__(self, num_features, dim=1, momentum=0.1, affine=True, num_chunks=16, eps=1e-5, num_bits=8, num_bits_grad=8):
284 | super(RangeBN, self).__init__()
285 | self.register_buffer('running_mean', torch.zeros(num_features))
286 | self.register_buffer('running_var', torch.zeros(num_features))
287 |
288 | self.momentum = momentum
289 | self.dim = dim
290 |         # keep affine params as None when affine=False so reset_params()/forward() work
291 |         self.bias = nn.Parameter(torch.Tensor(num_features)) if affine else None
292 |         self.weight = nn.Parameter(torch.Tensor(num_features)) if affine else None
293 | self.num_bits = num_bits
294 | self.num_bits_grad = num_bits_grad
295 | self.quantize_input = QuantMeasure(inplace=True, shape_measure=(1, 1, 1, 1), flatten_dims=(1, -1))
296 | self.eps = eps
297 | self.num_chunks = num_chunks
298 | self.reset_params()
299 |
300 | def reset_params(self):
301 | if self.weight is not None:
302 | self.weight.data.uniform_()
303 | if self.bias is not None:
304 | self.bias.data.zero_()
305 |
306 | def forward(self, x, num_bits, num_grad_bits):
307 | x = self.quantize_input(x, num_bits)
308 | if x.dim() == 2: # 1d
309 |             x = x.unsqueeze(-1).unsqueeze(-1)
310 |
311 | if self.training:
312 | B, C, H, W = x.shape
313 | y = x.transpose(0, 1).contiguous() # C x B x H x W
314 | y = y.view(C, self.num_chunks, (B * H * W) // self.num_chunks)
315 | mean_max = y.max(-1)[0].mean(-1) # C
316 | mean_min = y.min(-1)[0].mean(-1) # C
317 | mean = y.view(C, -1).mean(-1) # C
318 | scale_fix = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) **
319 | 0.5) / ((2 * math.log(y.size(-1))) ** 0.5)
320 |
321 | scale = (mean_max - mean_min) * scale_fix
322 | with torch.no_grad():
323 | self.running_mean.mul_(self.momentum).add_(
324 | mean * (1 - self.momentum))
325 |
326 | self.running_var.mul_(self.momentum).add_(
327 | scale * (1 - self.momentum))
328 | else:
329 | mean = self.running_mean
330 | scale = self.running_var
331 | # scale = quantize(scale, num_bits=self.num_bits, min_value=float(
332 | # scale.min()), max_value=float(scale.max()))
333 | out = (x - mean.view(1, -1, 1, 1)) / \
334 | (scale.view(1, -1, 1, 1) + self.eps)
335 |
336 | if self.weight is not None:
337 | qweight = self.weight
338 | # qweight = quantize(self.weight, num_bits=self.num_bits,
339 | # min_value=float(self.weight.min()),
340 | # max_value=float(self.weight.max()))
341 | out = out * qweight.view(1, -1, 1, 1)
342 |
343 | if self.bias is not None:
344 | qbias = self.bias
345 | # qbias = quantize(self.bias, num_bits=self.num_bits)
346 | out = out + qbias.view(1, -1, 1, 1)
347 | if num_grad_bits:
348 | out = quantize_grad(
349 | out, num_bits=num_grad_bits, flatten_dims=(1, -1))
350 |
351 | if out.size(3) == 1 and out.size(2) == 1:
352 | out = out.squeeze(-1).squeeze(-1)
353 | return out
354 |
355 |
356 | class RangeBN1d(RangeBN):
357 | # this is normalized RangeBN
358 |
359 | def __init__(self, num_features, dim=1, momentum=0.1, affine=True, num_chunks=16, eps=1e-5, num_bits=8, num_bits_grad=8):
360 | super(RangeBN1d, self).__init__(num_features, dim, momentum,
361 | affine, num_chunks, eps, num_bits, num_bits_grad)
362 |         self.quantize_input = QuantMeasure(
363 |             inplace=True, shape_measure=(1, 1), flatten_dims=(1, -1))
364 |
365 | if __name__ == '__main__':
366 |     x = torch.rand(2, 3).cuda()  # quantize() clamps its scale on the GPU
367 |     x_q = quantize(x, flatten_dims=(0, -1), num_bits=8, dequantize=True)
368 | print(x)
369 | print(x_q)
--------------------------------------------------------------------------------
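A quick round-trip through the helpers above, assuming they are importable as `modules.quantize`. Note that `UniformQuantize` clamps its scale with a hard-coded `.cuda()`, so the tensors must live on the GPU:

```python
import torch
from modules.quantize import calculate_qparams, quantize, quantize_grad

x = torch.randn(4, 16).cuda()
qparams = calculate_qparams(x, num_bits=8, flatten_dims=(1, -1), reduce_dim=0)
x_q = quantize(x, qparams=qparams, dequantize=True)  # 8-bit fake quantization
print((x - x_q).abs().max())                         # small reconstruction error

y = quantize_grad(x_q, num_bits=8)  # identity in forward; quantizes the gradient
```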
/fractrain_imagenet/modules/rnlu.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch.autograd.function import InplaceFunction
4 | from torch.autograd import Variable
5 | import torch.nn as nn
6 | import math
7 |
8 |
9 | class BiReLUFunction(InplaceFunction):
10 |
11 | @classmethod
12 | def forward(cls, ctx, input, inplace=False):
13 | if input.size(1) % 2 != 0:
14 | raise RuntimeError("dimension 1 of input must be multiple of 2, "
15 | "but got {}".format(input.size(1)))
16 | ctx.inplace = inplace
17 |
18 | if ctx.inplace:
19 | ctx.mark_dirty(input)
20 | output = input
21 | else:
22 | output = input.clone()
23 |
24 | pos, neg = output.chunk(2, dim=1)
25 | pos.clamp_(min=0)
26 | neg.clamp_(max=0)
27 | # scale = (pos - neg).view(pos.size(0), -1).mean(1).div_(2)
28 | # output.
29 | ctx.save_for_backward(output)
30 | return output
31 |
32 | @staticmethod
33 | def backward(ctx, grad_output):
34 | output, = ctx.saved_variables
35 | grad_input = grad_output.masked_fill(output.eq(0), 0)
36 | return grad_input, None
37 |
38 |
39 | def birelu(x, inplace=False):
40 | return BiReLUFunction().apply(x, inplace)
41 |
42 |
43 | class BiReLU(nn.Module):
44 | """docstring for BiReLU."""
45 |
46 | def __init__(self, inplace=False):
47 | super(BiReLU, self).__init__()
48 | self.inplace = inplace
49 |
50 | def forward(self, inputs):
51 | return birelu(inputs, inplace=self.inplace)
52 |
53 |
54 | def binorm(x, shift=0, scale_fix=(2 / math.pi) ** 0.5):
55 |     pos, neg = (x + shift).chunk(2, dim=1)  # two halves along channels, as in rnlu()
56 | scale = (pos - neg).view(pos.size(0), -1).mean(1).div_(2) * scale_fix
57 | return x / scale
58 |
59 |
60 | def _mean(p, dim):
61 | """Computes the mean over all dimensions except dim"""
62 | if dim is None:
63 | return p.mean()
64 | elif dim == 0:
65 | output_size = (p.size(0),) + (1,) * (p.dim() - 1)
66 | return p.contiguous().view(p.size(0), -1).mean(dim=1).view(*output_size)
67 | elif dim == p.dim() - 1:
68 | output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
69 | return p.contiguous().view(-1, p.size(-1)).mean(dim=0).view(*output_size)
70 | else:
71 | return _mean(p.transpose(0, dim), 0).transpose(0, dim)
72 |
73 |
74 | def rnlu(x, inplace=False, shift=0, scale_fix=(math.pi / 2) ** 0.5):
75 | x = birelu(x, inplace=inplace)
76 | pos, neg = (x + shift).chunk(2, dim=1)
77 | # scale = torch.cat((_mean(pos, 1), -_mean(neg, 1)), 1) * scale_fix + 1e-5
78 | scale = (pos - neg).view(pos.size(0), -1).mean(1) * scale_fix + 1e-8
79 | return x / scale.view(scale.size(0), *([1] * (x.dim() - 1)))
80 |
81 |
82 | class RnLU(nn.Module):
83 | """docstring for RnLU."""
84 |
85 | def __init__(self, inplace=False):
86 | super(RnLU, self).__init__()
87 | self.inplace = inplace
88 |
89 | def forward(self, x):
90 | return rnlu(x, inplace=self.inplace)
91 |
92 | # output.
93 | if __name__ == "__main__":
94 | x = Variable(torch.randn(2, 16, 5, 5).cuda(), requires_grad=True)
95 | output = rnlu(x)
96 |
97 | output.sum().backward()
98 |
--------------------------------------------------------------------------------
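A CPU-friendly sketch of `BiReLU` above (import path `modules.rnlu` assumed): it splits the channels in half and keeps the positive part of the first half and the negative part of the second.

```python
import torch
from modules.rnlu import BiReLU

act = BiReLU()
x = torch.randn(2, 4, 8, 8)  # channel count must be even
y = act(x)
# channels 0-1 behave like ReLU(x); channels 2-3 like -ReLU(-x)
```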
/fractrain_imagenet/train_base.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.backends.cudnn as cudnn
6 | from torch.autograd import Variable
7 | import numpy as np
8 |
9 | import os
10 | import shutil
11 | import argparse
12 | import time
13 | import logging
14 |
15 | import models
16 | from data import *
17 |
18 |
19 | model_names = sorted(name for name in models.__dict__
20 | if name.islower() and not name.startswith('__')
21 | and callable(models.__dict__[name])
22 | )
23 |
24 |
25 | def parse_args():
26 | # hyper-parameters are from ResNet paper
27 | parser = argparse.ArgumentParser(
28 | description='Quantization Aware Training on ImageNet')
29 | parser.add_argument('--dir', help='annotate the working directory')
30 | parser.add_argument('--cmd', choices=['train', 'test'], default='train')
31 | parser.add_argument('--arch', metavar='ARCH', default='resnet50',
32 | choices=model_names,
33 | help='model architecture: ' +
34 | ' | '.join(model_names) +
35 |                         ' (default: resnet50)')
36 | parser.add_argument('--dataset', '-d', type=str, default='imagenet',
37 | choices=['cifar10', 'cifar100','imagenet'],
38 | help='dataset choice')
39 | parser.add_argument('--datadir', default='/home/yf22/dataset', type=str,
40 | help='path to dataset')
41 | parser.add_argument('--workers', default=16, type=int, metavar='N',
42 |                         help='number of data loading workers (default: 16)')
43 | parser.add_argument('--epoch', default=90, type=int,
44 | help='number of epochs (default: 90)')
45 | parser.add_argument('--start_epoch', default=0, type=int,
46 |                         help='manual epoch number (useful on restarts)')
47 | parser.add_argument('--batch_size', default=256, type=int,
48 |                         help='mini-batch size (default: 256)')
49 | parser.add_argument('--lr_schedule', default='piecewise', type=str,
50 | help='learning rate schedule')
51 | parser.add_argument('--lr', default=0.1, type=float,
52 | help='initial learning rate')
53 | parser.add_argument('--momentum', default=0.9, type=float,
54 | help='momentum')
55 | parser.add_argument('--weight_decay', default=1e-4, type=float,
56 | help='weight decay (default: 1e-4)')
57 | parser.add_argument('--print_freq', default=10, type=int,
58 | help='print frequency (default: 10)')
59 | parser.add_argument('--resume', default='', type=str,
60 | help='path to latest checkpoint (default: None)')
61 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
62 | help='use pretrained model')
63 | parser.add_argument('--step_ratio', default=0.1, type=float,
64 |                         help='ratio for learning rate reduction')
65 | parser.add_argument('--warm_up', action='store_true',
66 | help='for n = 18, the model needs to warm up for 400 '
67 | 'iterations')
68 | parser.add_argument('--save_folder', default='save_checkpoints',
69 | type=str,
70 | help='folder to save the checkpoints')
71 | parser.add_argument('--eval_every', default=390, type=int,
72 |                         help='evaluate model every (default: 390) iterations')
73 | parser.add_argument('--num_bits',default=0,type=int,
74 | help='num bits for weight and activation')
75 | parser.add_argument('--num_grad_bits',default=0,type=int,
76 | help='num bits for gradient')
77 | parser.add_argument('--schedule', default=None, type=int, nargs='*',
78 | help='precision schedule')
79 | parser.add_argument('--num_bits_schedule',default=None,type=int,nargs='*',
80 | help='schedule for weight/act precision')
81 | parser.add_argument('--num_grad_bits_schedule',default=None,type=int,nargs='*',
82 | help='schedule for grad precision')
83 | parser.add_argument('--act_fw', default=0, type=int,
84 | help='precision of activation during forward, -1 means dynamic, 0 means no quantize')
85 | parser.add_argument('--act_bw', default=0, type=int,
86 | help='precision of activation during backward, -1 means dynamic, 0 means no quantize')
87 | parser.add_argument('--grad_act_error', default=0, type=int,
88 | help='precision of activation gradient during error backward, -1 means dynamic, 0 means no quantize')
89 | parser.add_argument('--grad_act_gc', default=0, type=int,
90 | help='precision of activation gradient during weight gradient computation, -1 means dynamic, 0 means no quantize')
91 | parser.add_argument('--weight_bits', default=0, type=int,
92 | help='precision of weight')
93 | parser.add_argument('--momentum_act', default=0.9, type=float,
94 | help='momentum for act min/max')
95 | args = parser.parse_args()
96 | return args
97 |
98 |
99 | def main():
100 | args = parse_args()
101 | save_path = args.save_path = os.path.join(args.save_folder, args.arch)
102 | if not os.path.exists(save_path):
103 | os.makedirs(save_path)
104 |
105 | models.ACT_FW = args.act_fw
106 | models.ACT_BW = args.act_bw
107 | models.GRAD_ACT_ERROR = args.grad_act_error
108 | models.GRAD_ACT_GC = args.grad_act_gc
109 | models.WEIGHT_BITS = args.weight_bits
110 | models.MOMENTUM = args.momentum_act
111 |
112 | args.num_bits = args.num_bits if not (args.act_fw + args.act_bw + args.grad_act_error + args.grad_act_gc + args.weight_bits) else -1
113 |
114 | # config logging file
115 | args.logger_file = os.path.join(save_path, 'log_{}.txt'.format(args.cmd))
116 | handlers = [logging.FileHandler(args.logger_file, mode='w'),
117 | logging.StreamHandler()]
118 | logging.basicConfig(level=logging.INFO,
119 | datefmt='%m-%d-%y %H:%M',
120 | format='%(asctime)s:%(message)s',
121 | handlers=handlers)
122 |
123 | if args.cmd == 'train':
124 | logging.info('start training {}'.format(args.arch))
125 | run_training(args)
126 |
127 | elif args.cmd == 'test':
128 | logging.info('start evaluating {} with checkpoints from {}'.format(
129 | args.arch, args.resume))
130 | test_model(args)
131 |
132 |
133 | def run_training(args):
134 | # create model
135 | model = models.__dict__[args.arch](args.pretrained)
136 | model = torch.nn.DataParallel(model).cuda()
137 |
138 | best_prec1 = 0
139 | best_full_prec = 0
140 |
141 | # optionally resume from a checkpoint
142 | if args.resume:
143 | if os.path.isfile(args.resume):
144 | logging.info('=> loading checkpoint `{}`'.format(args.resume))
145 | checkpoint = torch.load(args.resume)
146 | args.start_epoch = checkpoint['epoch']
147 | best_prec1 = checkpoint['best_prec1']
148 | model.load_state_dict(checkpoint['state_dict'])
149 | logging.info('=> loaded checkpoint `{}` (epoch: {})'.format(
150 | args.resume, checkpoint['epoch']
151 | ))
152 | else:
153 | logging.info('=> no checkpoint found at `{}`'.format(args.resume))
154 |
155 | cudnn.benchmark = False
156 |
157 | train_loader = prepare_train_data(dataset=args.dataset,
158 | datadir=args.datadir+'/train',
159 | batch_size=args.batch_size,
160 | shuffle=True,
161 | num_workers=args.workers)
162 | test_loader = prepare_test_data(dataset=args.dataset,
163 | datadir=args.datadir+'/val',
164 | batch_size=args.batch_size,
165 | shuffle=False,
166 | num_workers=args.workers)
167 |
168 | # define loss function (criterion) and optimizer
169 | criterion = nn.CrossEntropyLoss().cuda()
170 |
171 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
172 | momentum=args.momentum,
173 | weight_decay=args.weight_decay)
174 |
175 | # optimizer = torch.optim.Adam(model.parameters(), args.lr,
176 | # weight_decay=args.weight_decay)
177 |
178 | batch_time = AverageMeter()
179 | data_time = AverageMeter()
180 | losses = AverageMeter()
181 | top1 = AverageMeter()
182 | cr = AverageMeter()
183 |
184 | end = time.time()
185 |
186 | for _epoch in range(args.start_epoch, args.epoch):
187 | lr = adjust_learning_rate(args, optimizer, _epoch)
188 | adjust_precision(args, _epoch)
189 |
190 | print('Learning Rate:', lr)
191 | print('num bits:', args.num_bits, 'num grad bits:', args.num_grad_bits)
192 |
193 | for i, (input, target) in enumerate(train_loader):
194 | # measuring data loading time
195 | data_time.update(time.time() - end)
196 |
197 | model.train()
198 |
199 | fw_cost = args.num_bits*args.num_bits/32/32
200 | eb_cost = args.num_bits*args.num_grad_bits/32/32
201 | gc_cost = eb_cost
202 | cr.update((fw_cost+eb_cost+gc_cost)/3)
203 |
204 | target = target.squeeze().long().cuda()
205 | input_var = Variable(input).cuda()
206 | target_var = Variable(target).cuda()
207 |
208 | # compute output
209 | output = model(input_var, args.num_bits, args.num_grad_bits)
210 | loss = criterion(output, target_var)
211 |
212 | # measure accuracy and record loss
213 | prec1, = accuracy(output.data, target, topk=(1,))
214 | losses.update(loss.item(), input.size(0))
215 | top1.update(prec1.item(), input.size(0))
216 |
217 | # compute gradient and do SGD step
218 | optimizer.zero_grad()
219 | loss.backward()
220 | optimizer.step()
221 |
222 | # measure elapsed time
223 | batch_time.update(time.time() - end)
224 | end = time.time()
225 |
226 | # print log
227 | if i % args.print_freq == 0:
228 | logging.info("Iter: [{0}][{1}/{2}]\t"
229 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
230 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
231 | "Loss {loss.val:.3f} ({loss.avg:.3f})\t"
232 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t".format(
233 | _epoch,
234 | i,
235 | len(train_loader),
236 | batch_time=batch_time,
237 | data_time=data_time,
238 | loss=losses,
239 | top1=top1)
240 | )
241 |
242 | with torch.no_grad():
243 | prec1 = validate(args, test_loader, model, criterion, _epoch)
244 | # prec_full = validate_full_prec(args, test_loader, model, criterion, i)
245 |
246 | is_best = prec1 > best_prec1
247 | best_prec1 = max(prec1, best_prec1)
248 | #best_full_prec = max(prec_full, best_full_prec)
249 |
250 | print("Current Best Prec@1: ", best_prec1)
251 | #print("Current Best Full Prec@1: ", best_full_prec)
252 |
253 | checkpoint_path = os.path.join(args.save_path, 'checkpoint_{:05d}_{:.2f}.pth.tar'.format(_epoch, prec1))
254 | save_checkpoint({
255 | 'epoch': _epoch,
256 | 'arch': args.arch,
257 | 'state_dict': model.state_dict(),
258 | 'best_prec1': best_prec1,
259 | },
260 | is_best, filename=checkpoint_path)
261 | shutil.copyfile(checkpoint_path, os.path.join(args.save_path,
262 | 'checkpoint_latest'
263 | '.pth.tar'))
264 |
265 |
266 |
267 | def validate(args, test_loader, model, criterion, _epoch):
268 | batch_time = AverageMeter()
269 | losses = AverageMeter()
270 | top1 = AverageMeter()
271 |
272 | # switch to evaluation mode
273 | model.eval()
274 | end = time.time()
275 | for i, (input, target) in enumerate(test_loader):
276 | target = target.squeeze().long().cuda()
277 | input_var = Variable(input, volatile=True).cuda()
278 | target_var = Variable(target, volatile=True).cuda()
279 |
280 | # compute output
281 | output = model(input_var, args.num_bits, args.num_grad_bits)
282 | loss = criterion(output, target_var)
283 |
284 | # measure accuracy and record loss
285 | prec1, = accuracy(output.data, target, topk=(1,))
286 | top1.update(prec1.item(), input.size(0))
287 | losses.update(loss.item(), input.size(0))
288 | batch_time.update(time.time() - end)
289 | end = time.time()
290 |
291 | if (i % args.print_freq == 0) or (i == len(test_loader) - 1):
292 | logging.info(
293 | 'Test: [{}/{}]\t'
294 | 'Time: {batch_time.val:.4f}({batch_time.avg:.4f})\t'
295 | 'Loss: {loss.val:.3f}({loss.avg:.3f})\t'
296 | 'Prec@1: {top1.val:.3f}({top1.avg:.3f})\t'.format(
297 | i, len(test_loader), batch_time=batch_time,
298 | loss=losses, top1=top1
299 | )
300 | )
301 |
302 | logging.info('Epoch {} * Prec@1 {top1.avg:.3f}'.format(_epoch, top1=top1))
303 | return top1.avg
304 |
305 |
306 | def validate_full_prec(args, test_loader, model, criterion, _epoch):
307 | batch_time = AverageMeter()
308 | losses = AverageMeter()
309 | top1 = AverageMeter()
310 |
311 | # switch to evaluation mode
312 | model.eval()
313 | end = time.time()
314 | for i, (input, target) in enumerate(test_loader):
315 | target = target.squeeze().long().cuda()
316 | input_var = Variable(input, volatile=True).cuda()
317 | target_var = Variable(target, volatile=True).cuda()
318 |
319 | # compute output
320 | output = model(input_var, 0, 0)
321 | loss = criterion(output, target_var)
322 |
323 | # measure accuracy and record loss
324 | prec1, = accuracy(output.data, target, topk=(1,))
325 | top1.update(prec1.item(), input.size(0))
326 | losses.update(loss.item(), input.size(0))
327 | batch_time.update(time.time() - end)
328 | end = time.time()
329 |
330 |
331 | logging.info('Epoch {} * Full Prec@1 {top1.avg:.3f}'.format(_epoch, top1=top1))
332 | return top1.avg
333 |
334 |
335 | def test_model(args):
336 | # create model
337 | model = models.__dict__[args.arch](args.pretrained)
338 | model = torch.nn.DataParallel(model).cuda()
339 |
340 | if args.resume:
341 | if os.path.isfile(args.resume):
342 | logging.info('=> loading checkpoint `{}`'.format(args.resume))
343 | checkpoint = torch.load(args.resume)
344 | args.start_epoch = checkpoint['epoch']
345 | best_prec1 = checkpoint['best_prec1']
346 | model.load_state_dict(checkpoint['state_dict'])
347 | logging.info('=> loaded checkpoint `{}` (epoch: {})'.format(
348 | args.resume, checkpoint['epoch']
349 | ))
350 | else:
351 | logging.info('=> no checkpoint found at `{}`'.format(args.resume))
352 |
353 | cudnn.benchmark = False
354 | test_loader = prepare_test_data(dataset=args.dataset,
355 | batch_size=args.batch_size,
356 | shuffle=False,
357 | num_workers=args.workers)
358 | criterion = nn.CrossEntropyLoss().cuda()
359 |
360 | # validate(args, test_loader, model, criterion)
361 |
362 | with torch.no_grad():
363 |         prec1 = validate(args, test_loader, model, criterion, args.start_epoch)
364 | # prec_full = validate_full_prec(args, test_loader, model, criterion, args.start_iter)
365 |
366 |
367 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
368 | torch.save(state, filename)
369 | if is_best:
370 | save_path = os.path.dirname(filename)
371 | shutil.copyfile(filename, os.path.join(save_path,
372 | 'model_best.pth.tar'))
373 |
374 |
375 | class AverageMeter(object):
376 | """Computes and stores the average and current value"""
377 |
378 | def __init__(self):
379 | self.reset()
380 |
381 | def reset(self):
382 | self.val = 0
383 | self.avg = 0
384 | self.sum = 0
385 | self.count = 0
386 |
387 | def update(self, val, n=1):
388 | self.val = val
389 | self.sum += val * n
390 | self.count += n
391 | self.avg = self.sum / self.count
392 |
393 |
394 | schedule_cnt = 0
395 | def adjust_precision(args, _epoch):
396 | if args.schedule:
397 | global schedule_cnt
398 |
399 | assert len(args.num_bits_schedule) == len(args.schedule) + 1
400 | assert len(args.num_grad_bits_schedule) == len(args.schedule) + 1
401 |
402 | if schedule_cnt == 0:
403 | args.num_bits = args.num_bits_schedule[0]
404 | args.num_grad_bits = args.num_grad_bits_schedule[0]
405 | schedule_cnt += 1
406 |
407 | for step in args.schedule:
408 | if _epoch == step:
409 | args.num_bits = args.num_bits_schedule[schedule_cnt]
410 | args.num_grad_bits = args.num_grad_bits_schedule[schedule_cnt]
411 | schedule_cnt += 1
412 |
413 |
414 | def adjust_learning_rate(args, optimizer, _epoch):
415 | lr = args.lr * (0.1 ** (_epoch // 30))
416 |
417 | for param_group in optimizer.param_groups:
418 | param_group['lr'] = lr
419 |
420 | return lr
421 |
422 |
423 | def accuracy(output, target, topk=(1,)):
424 | """Computes the precision@k for the specified values of k"""
425 | maxk = max(topk)
426 | batch_size = target.size(0)
427 |
428 | _, pred = output.topk(maxk, 1, True, True)
429 | pred = pred.t()
430 | correct = pred.eq(target.view(1, -1).expand_as(pred))
431 |
432 | res = []
433 | for k in topk:
434 | correct_k = correct[:k].view(-1).float().sum(0)
435 | res.append(correct_k.mul_(100.0 / batch_size))
436 | return res
437 |
438 |
439 | if __name__ == '__main__':
440 | main()
441 |
--------------------------------------------------------------------------------
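`adjust_precision` above mutates `args.num_bits`/`args.num_grad_bits` in place at each boundary epoch. An equivalent stateless reading of the schedule, with made-up example values:

```python
# standalone sketch of the precision schedule in train_base.py;
# the boundaries and bit-widths below are hypothetical examples
schedule = [30, 60, 80]                  # epochs at which precision is raised
num_bits_schedule = [6, 8, 8, 8]         # len(schedule) + 1 entries
num_grad_bits_schedule = [8, 8, 12, 16]

def precision_at(epoch):
    stage = sum(epoch >= step for step in schedule)
    return num_bits_schedule[stage], num_grad_bits_schedule[stage]

assert precision_at(0) == (6, 8)    # before the first boundary
assert precision_at(45) == (8, 8)   # after epoch 30
assert precision_at(85) == (8, 16)  # final stage
```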
/fractrain_imagenet/train_dfq.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.backends.cudnn as cudnn
6 | from torch.autograd import Variable
7 | import numpy as np
8 |
9 | from functools import reduce
10 |
11 | import os
12 | import shutil
13 | import argparse
14 | import time
15 | import logging
16 |
17 | import models
18 | from data import *
19 |
20 |
21 | model_names = sorted(name for name in models.__dict__
22 | if name.islower() and not name.startswith('__')
23 | and callable(models.__dict__[name])
24 | )
25 |
26 |
27 | def parse_args():
28 | # hyper-parameters are from ResNet paper
29 | parser = argparse.ArgumentParser(
30 | description='DFQ on ImageNet')
31 | parser.add_argument('--dir', help='annotate the working directory')
32 | parser.add_argument('--cmd', choices=['train', 'test'], default='train')
33 | parser.add_argument('--arch', metavar='ARCH', default='resnet50',
34 | choices=model_names,
35 | help='model architecture: ' +
36 | ' | '.join(model_names) +
37 |                         ' (default: resnet50)')
38 | parser.add_argument('--dataset', '-d', type=str, default='imagenet',
39 | choices=['cifar10', 'cifar100','imagenet'],
40 | help='dataset choice')
41 | parser.add_argument('--datadir', default='/home/yf22/dataset', type=str,
42 | help='path to dataset')
43 | parser.add_argument('--workers', default=16, type=int, metavar='N',
44 |                         help='number of data loading workers (default: 16)')
45 | parser.add_argument('--epoch', default=90, type=int,
46 | help='number of epochs (default: 90)')
47 | parser.add_argument('--start_epoch', default=0, type=int,
48 |                         help='manual epoch number (useful on restarts)')
49 | parser.add_argument('--batch_size', default=256, type=int,
50 |                         help='mini-batch size (default: 256)')
51 | parser.add_argument('--lr_schedule', default='piecewise', type=str,
52 | help='learning rate schedule')
53 | parser.add_argument('--lr', default=0.1, type=float,
54 | help='initial learning rate')
55 | parser.add_argument('--momentum', default=0.9, type=float,
56 | help='momentum')
57 | parser.add_argument('--weight_decay', default=1e-4, type=float,
58 | help='weight decay (default: 1e-4)')
59 | parser.add_argument('--print_freq', default=10, type=int,
60 | help='print frequency (default: 10)')
61 | parser.add_argument('--resume', default='', type=str,
62 | help='path to latest checkpoint (default: None)')
63 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
64 | help='use pretrained model')
65 | parser.add_argument('--step_ratio', default=0.1, type=float,
66 |                         help='ratio for learning rate reduction')
67 | parser.add_argument('--warm_up', action='store_true',
68 | help='for n = 18, the model needs to warm up for 400 '
69 | 'iterations')
70 | parser.add_argument('--save_folder', default='save_checkpoints',
71 | type=str,
72 | help='folder to save the checkpoints')
73 | parser.add_argument('--eval_every', default=390, type=int,
74 |                         help='evaluate model every (default: 390) iterations')
75 | parser.add_argument('--num_bits',default=0,type=int,
76 | help='num bits for weight and activation')
77 | parser.add_argument('--num_grad_bits',default=0,type=int,
78 | help='num bits for gradient')
79 | parser.add_argument('--schedule', default=None, type=int, nargs='*',
80 | help='precision schedule')
81 | parser.add_argument('--num_bits_schedule',default=None,type=int,nargs='*',
82 | help='schedule for weight/act precision')
83 | parser.add_argument('--num_grad_bits_schedule',default=None,type=int,nargs='*',
84 | help='schedule for grad precision')
85 | parser.add_argument('--act_fw', default=0, type=int,
86 | help='precision of activation during forward, -1 means dynamic, 0 means no quantize')
87 | parser.add_argument('--act_bw', default=0, type=int,
88 | help='precision of activation during backward, -1 means dynamic, 0 means no quantize')
89 | parser.add_argument('--grad_act_error', default=0, type=int,
90 | help='precision of activation gradient during error backward, -1 means dynamic, 0 means no quantize')
91 | parser.add_argument('--grad_act_gc', default=0, type=int,
92 | help='precision of activation gradient during weight gradient computation, -1 means dynamic, 0 means no quantize')
93 | parser.add_argument('--weight_bits', default=0, type=int,
94 | help='precision of weight')
95 | parser.add_argument('--target_ratio',default=4,type=float,
96 | help='target compression ratio')
97 | parser.add_argument('--target_ratio_schedule',default=None,type=float,nargs='*',
98 | help='schedule for target compression ratio')
99 | parser.add_argument('--momentum_act', default=0.9, type=float,
100 | help='momentum for act min/max')
101 | parser.add_argument('--relax', default=0, type=float,
102 | help='relax parameter for target ratio')
103 | parser.add_argument('--beta', default=1e-3, type=float,
104 | help='coefficient')
105 | parser.add_argument('--computation_cost', default=True, type=bool,
106 | help='using computation cost as regularization term')
107 | args = parser.parse_args()
108 | return args
109 |
110 |
111 | def main():
112 | args = parse_args()
113 | save_path = args.save_path = os.path.join(args.save_folder, args.arch)
114 | if not os.path.exists(save_path):
115 | os.makedirs(save_path)
116 |
117 | models.ACT_FW = args.act_fw
118 | models.ACT_BW = args.act_bw
119 | models.GRAD_ACT_ERROR = args.grad_act_error
120 | models.GRAD_ACT_GC = args.grad_act_gc
121 | models.WEIGHT_BITS = args.weight_bits
122 | models.MOMENTUM = args.momentum_act
123 |
124 | args.num_bits = args.num_bits if not (args.act_fw + args.act_bw + args.grad_act_error + args.grad_act_gc + args.weight_bits) else -1
125 |
126 | # config logging file
127 | args.logger_file = os.path.join(save_path, 'log_{}.txt'.format(args.cmd))
128 | handlers = [logging.FileHandler(args.logger_file, mode='w'),
129 | logging.StreamHandler()]
130 | logging.basicConfig(level=logging.INFO,
131 | datefmt='%m-%d-%y %H:%M',
132 | format='%(asctime)s:%(message)s',
133 | handlers=handlers)
134 |
135 | if args.cmd == 'train':
136 | logging.info('start training {}'.format(args.arch))
137 | run_training(args)
138 |
139 | elif args.cmd == 'test':
140 | logging.info('start evaluating {} with checkpoints from {}'.format(
141 | args.arch, args.resume))
142 | test_model(args)
143 |
144 | bits = [3, 4, 6, 8]
145 | grad_bits = [6, 8, 12, 16]
146 |
147 | def run_training(args):
148 |
149 | cost_fw = []
150 | for bit in bits:
151 | if bit == 0:
152 | cost_fw.append(1)
153 | else:
154 | cost_fw.append(bit/32)
155 | cost_fw = np.array(cost_fw) * args.weight_bits/32
156 |
157 | cost_eb = []
158 | for bit in grad_bits:
159 | if bit == 0:
160 | cost_eb.append(1)
161 | else:
162 | cost_eb.append(bit/32)
163 | cost_eb = np.array(cost_eb) * args.weight_bits/32
164 |
165 | cost_gc = []
166 | for i in range(len(bits)):
167 | if bits[i] == 0:
168 | cost_gc.append(1)
169 | else:
170 | cost_gc.append(bits[i]*grad_bits[i]/32/32)
171 | cost_gc = np.array(cost_gc)
172 |
173 | model = models.__dict__[args.arch](args.pretrained, proj_dim=len(bits))
174 | model = torch.nn.DataParallel(model).cuda()
175 |
176 | best_prec1 = 0
177 | best_epoch = 0
178 | best_full_prec = 0
179 |
180 | # optionally resume from a checkpoint
181 | if args.resume:
182 | if os.path.isfile(args.resume):
183 | logging.info('=> loading checkpoint `{}`'.format(args.resume))
184 | checkpoint = torch.load(args.resume)
185 | args.start_epoch = checkpoint['epoch']
186 | best_prec1 = checkpoint['best_prec1']
187 | model.load_state_dict(checkpoint['state_dict'])
188 | logging.info('=> loaded checkpoint `{}` (epoch: {})'.format(
189 | args.resume, checkpoint['epoch']
190 | ))
191 | else:
192 | logging.info('=> no checkpoint found at `{}`'.format(args.resume))
193 |
194 | cudnn.benchmark = False
195 |
196 | train_loader = prepare_train_data(dataset=args.dataset,
197 | datadir=args.datadir+'/train',
198 | batch_size=args.batch_size,
199 | shuffle=True,
200 | num_workers=args.workers)
201 | test_loader = prepare_test_data(dataset=args.dataset,
202 | datadir=args.datadir+'/val',
203 | batch_size=args.batch_size,
204 | shuffle=False,
205 | num_workers=args.workers)
206 |
207 | # define loss function (criterion) and optimizer
208 | criterion = nn.CrossEntropyLoss().cuda()
209 |
210 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
211 | momentum=args.momentum,
212 | weight_decay=args.weight_decay)
213 |
214 | batch_time = AverageMeter()
215 | data_time = AverageMeter()
216 | losses = AverageMeter()
217 | top1 = AverageMeter()
218 | cp_record = AverageMeter()
219 | cp_record_fw = AverageMeter()
220 | cp_record_eb = AverageMeter()
221 | cp_record_gc = AverageMeter()
222 |
223 | network_depth = sum(model.module.num_layers)
224 |
225 | layerwise_decision_statistics = []
226 |
227 | for k in range(network_depth):
228 | layerwise_decision_statistics.append([])
229 | for j in range(len(cost_fw)):
230 | ratio = AverageMeter()
231 | layerwise_decision_statistics[k].append(ratio)
232 |
233 | end = time.time()
234 |
235 | for _epoch in range(args.start_epoch, args.epoch):
236 | lr = adjust_learning_rate(args, optimizer, _epoch)
237 | adjust_target_ratio(args, _epoch)
238 |
239 | print('Learning Rate:', lr)
240 | print('Target Ratio:', args.target_ratio)
241 |
242 | for i, (input, target) in enumerate(train_loader):
243 | # measuring data loading time
244 | data_time.update(time.time() - end)
245 |
246 | model.train()
247 |
248 | target = target.squeeze().long().cuda()
249 | input_var = Variable(input).cuda()
250 | target_var = Variable(target).cuda()
251 |
252 | output, masks = model(input_var, bits, grad_bits)
253 |
254 | computation_cost_fw = 0
255 | computation_cost_eb = 0
256 | computation_cost_gc = 0
257 | computation_all = 0
258 |
259 | for layer in range(network_depth):
260 |
261 | full_layer = reduce((lambda x, y: x * y), masks[layer][0].shape)
262 |
263 | computation_all += full_layer
264 |
265 | for k in range(len(cost_fw)):
266 |
267 | dynamic_choice = masks[layer][k].sum()
268 |
269 | ratio = dynamic_choice / full_layer
270 |
271 | layerwise_decision_statistics[layer][k].update(ratio.data, 1)
272 |
273 | computation_cost_fw += masks[layer][k].sum() * cost_fw[k]
274 | computation_cost_eb += masks[layer][k].sum() * cost_eb[k]
275 | computation_cost_gc += masks[layer][k].sum() * cost_gc[k]
276 |
277 | computation_cost = computation_cost_fw + computation_cost_eb + computation_cost_gc
278 |
279 | cp_ratio_fw = (float(computation_cost_fw) / float(computation_all)) * 100
280 | cp_ratio_eb = (float(computation_cost_eb) / float(computation_all)) * 100
281 | cp_ratio_gc = (float(computation_cost_gc) / float(computation_all)) * 100
282 |
283 | cp_ratio = (float(computation_cost) / float(computation_all*3)) * 100
284 |
285 | computation_cost *= args.beta
286 |
287 | if cp_ratio < args.target_ratio - args.relax:
288 | reg = -1
289 | elif cp_ratio >= args.target_ratio + args.relax:
290 | reg = 1
291 |             elif cp_ratio >= args.target_ratio:
292 | reg = 0.1
293 | else:
294 | reg = -0.1
295 |
296 | loss_cls = criterion(output, target_var)
297 |
298 | if args.computation_cost:
299 | loss = loss_cls + computation_cost * reg
300 | else:
301 | loss = loss_cls
302 |
303 | # measure accuracy and record loss
304 | prec1, = accuracy(output.data, target, topk=(1,))
305 | losses.update(loss.item(), input.size(0))
306 | top1.update(prec1.item(), input.size(0))
307 |
308 | cp_record.update(cp_ratio,1)
309 | cp_record_fw.update(cp_ratio_fw,1)
310 | cp_record_eb.update(cp_ratio_eb,1)
311 | cp_record_gc.update(cp_ratio_gc,1)
312 |
313 | # compute gradient and do SGD step
314 | optimizer.zero_grad()
315 | loss.backward()
316 | optimizer.step()
317 |
318 | # measure elapsed time
319 | batch_time.update(time.time() - end)
320 | end = time.time()
321 |
322 | # print log
323 | if i % args.print_freq == 0:
324 | logging.info("Iter: [{0}][{1}/{2}]\t"
325 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
326 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
327 | "Loss {loss.val:.3f} ({loss.avg:.3f})\t"
328 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t"
329 | "Computation_Percentage: {cp_record.val:.3f}({cp_record.avg:.3f})\t"
330 | "Computation_Percentage_FW: {cp_record_fw.val:.3f}({cp_record_fw.avg:.3f})\t"
331 | "Computation_Percentage_EB: {cp_record_eb.val:.3f}({cp_record_eb.avg:.3f})\t"
332 | "Computation_Percentage_GC: {cp_record_gc.val:.3f}({cp_record_gc.avg:.3f})\t".format(
333 | _epoch,
334 | i,
335 | len(train_loader),
336 | batch_time=batch_time,
337 | data_time=data_time,
338 | loss=losses,
339 | top1=top1,
340 | cp_record=cp_record,
341 | cp_record_fw=cp_record_fw,
342 | cp_record_eb=cp_record_eb,
343 | cp_record_gc=cp_record_gc)
344 | )
345 |
346 | with torch.no_grad():
347 | prec1 = validate(args, test_loader, model, criterion, _epoch)
348 | # prec_full = validate_full_prec(args, test_loader, model, criterion, i)
349 |
350 | is_best = prec1 > best_prec1
351 | if is_best:
352 | best_prec1 = prec1
353 | best_epoch = _epoch
354 | #best_full_prec = max(prec_full, best_full_prec)
355 |
356 | print("Current Best Prec@1: ", best_prec1, "Best Epoch:", best_epoch)
357 | #print("Current Best Full Prec@1: ", best_full_prec)
358 |
359 | checkpoint_path = os.path.join(args.save_path, 'checkpoint_{:05d}_{:.2f}.pth.tar'.format(_epoch, prec1))
360 | save_checkpoint({
361 | 'epoch': _epoch,
362 | 'arch': args.arch,
363 | 'state_dict': model.state_dict(),
364 | 'best_prec1': best_prec1,
365 | },
366 | is_best, filename=checkpoint_path)
367 | shutil.copyfile(checkpoint_path, os.path.join(args.save_path,
368 | 'checkpoint_latest'
369 | '.pth.tar'))
370 |
371 |
372 |
373 | def validate(args, test_loader, model, criterion, _epoch):
374 |
375 | cost_fw = []
376 | for bit in bits:
377 | if bit == 0:
378 | cost_fw.append(1)
379 | else:
380 | cost_fw.append(bit/32)
381 | cost_fw = np.array(cost_fw) * args.weight_bits/32
382 |
383 | cost_eb = []
384 | for bit in grad_bits:
385 | if bit == 0:
386 | cost_eb.append(1)
387 | else:
388 | cost_eb.append(bit/32)
389 | cost_eb = np.array(cost_eb) * args.weight_bits/32
390 |
391 | cost_gc = []
392 | for i in range(len(bits)):
393 | if bits[i] == 0:
394 | cost_gc.append(1)
395 | else:
396 | cost_gc.append(bits[i]*grad_bits[i]/32/32)
397 | cost_gc = np.array(cost_gc)
398 |
399 | batch_time = AverageMeter()
400 | data_time = AverageMeter()
401 | losses = AverageMeter()
402 | top1 = AverageMeter()
403 | cp_record = AverageMeter()
404 | cp_record_fw = AverageMeter()
405 | cp_record_eb = AverageMeter()
406 | cp_record_gc = AverageMeter()
407 |
408 | network_depth = sum(model.module.num_layers)
409 |
410 | layerwise_decision_statistics = []
411 |
412 | for k in range(network_depth):
413 | layerwise_decision_statistics.append([])
414 | for j in range(len(cost_fw)):
415 | ratio = AverageMeter()
416 | layerwise_decision_statistics[k].append(ratio)
417 |
418 | model.eval()
419 | end = time.time()
420 | for i, (input, target) in enumerate(test_loader):
421 | target = target.squeeze().long().cuda()
422 | input_var = Variable(input, volatile=True).cuda()
423 | target_var = Variable(target, volatile=True).cuda()
424 |
425 | output, masks = model(input_var, bits, grad_bits)
426 |
427 | computation_cost_fw = 0
428 | computation_cost_eb = 0
429 | computation_cost_gc = 0
430 | computation_all = 0
431 |
432 | for layer in range(network_depth):
433 |
434 | full_layer = reduce((lambda x, y: x * y), masks[layer][0].shape)
435 |
436 | computation_all += full_layer
437 |
438 | for k in range(len(cost_fw)):
439 |
440 | dynamic_choice = masks[layer][k].sum()
441 |
442 | ratio = dynamic_choice / full_layer
443 |
444 | layerwise_decision_statistics[layer][k].update(ratio.data, 1)
445 |
446 | computation_cost_fw += masks[layer][k].sum() * cost_fw[k]
447 | computation_cost_eb += masks[layer][k].sum() * cost_eb[k]
448 | computation_cost_gc += masks[layer][k].sum() * cost_gc[k]
449 |
450 | computation_cost = computation_cost_fw + computation_cost_eb + computation_cost_gc
451 |
452 | cp_ratio_fw = (float(computation_cost_fw) / float(computation_all)) * 100
453 | cp_ratio_eb = (float(computation_cost_eb) / float(computation_all)) * 100
454 | cp_ratio_gc = (float(computation_cost_gc) / float(computation_all)) * 100
455 |
456 | cp_ratio = (float(computation_cost) / float(computation_all*3)) * 100
457 |
458 | loss = criterion(output, target_var)
459 |
460 | # measure accuracy and record loss
461 | prec1, = accuracy(output.data, target, topk=(1,))
462 | losses.update(loss.item(), input.size(0))
463 | top1.update(prec1.item(), input.size(0))
464 |
465 | cp_record.update(cp_ratio,1)
466 | cp_record_fw.update(cp_ratio_fw,1)
467 | cp_record_eb.update(cp_ratio_eb,1)
468 | cp_record_gc.update(cp_ratio_gc,1)
469 |
470 | batch_time.update(time.time() - end)
471 | end = time.time()
472 |
473 | if i % args.print_freq == 0 or (i == (len(test_loader) - 1)):
474 | logging.info("Iter: [{0}/{1}]\t"
475 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
476 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
477 | "Loss {loss.val:.3f} ({loss.avg:.3f})\t"
478 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t"
479 | "Computation_Percentage: {cp_record.val:.3f}({cp_record.avg:.3f})\t"
480 | "Computation_Percentage_FW: {cp_record_fw.val:.3f}({cp_record_fw.avg:.3f})\t"
481 | "Computation_Percentage_EB: {cp_record_eb.val:.3f}({cp_record_eb.avg:.3f})\t"
482 | "Computation_Percentage_GC: {cp_record_gc.val:.3f}({cp_record_gc.avg:.3f})\t".format(
483 | i,
484 | len(test_loader),
485 | batch_time=batch_time,
486 | data_time=data_time,
487 | loss=losses,
488 | top1=top1,
489 | cp_record=cp_record,
490 | cp_record_fw=cp_record_fw,
491 | cp_record_eb=cp_record_eb,
492 | cp_record_gc=cp_record_gc)
493 | )
494 |
495 | logging.info('Epoch {} * Prec@1 {top1.avg:.3f}'.format(_epoch, top1=top1))
496 |
497 | for layer in range(network_depth):
498 | print('layer{}_decision'.format(layer + 1))
499 | for g in range(len(cost_fw)):
500 |             print('{}_ratio: {}'.format(g, layerwise_decision_statistics[layer][g].avg))
501 |
502 | return top1.avg
503 |
504 |
505 | def validate_full_prec(args, test_loader, model, criterion, _epoch):
506 | batch_time = AverageMeter()
507 | losses = AverageMeter()
508 | top1 = AverageMeter()
509 |
510 | bits_full = np.zeros(len(bits))
511 | grad_bits_full = np.zeros(len(grad_bits))
512 | # switch to evaluation mode
513 | model.eval()
514 | end = time.time()
515 | for i, (input, target) in enumerate(test_loader):
516 | target = target.squeeze().long().cuda()
517 | input_var = Variable(input, volatile=True).cuda()
518 | target_var = Variable(target, volatile=True).cuda()
519 |
520 | # compute output
521 | output, _ = model(input_var, bits_full, grad_bits_full)
522 | loss = criterion(output, target_var)
523 |
524 | # measure accuracy and record loss
525 | prec1, = accuracy(output.data, target, topk=(1,))
526 | top1.update(prec1.item(), input.size(0))
527 | losses.update(loss.item(), input.size(0))
528 | batch_time.update(time.time() - end)
529 | end = time.time()
530 |
531 | if args.gate_type == 'rnn':
532 | model.module.control.repackage_hidden()
533 |
534 | logging.info('Epoch {} * Full Prec@1 {top1.avg:.3f}'.format(_epoch, top1=top1))
535 | return top1.avg
536 |
537 |
538 | def test_model(args):
539 | # create model
540 | model = models.__dict__[args.arch](args.pretrained)
541 | model = torch.nn.DataParallel(model).cuda()
542 |
543 | if args.resume:
544 | if os.path.isfile(args.resume):
545 | logging.info('=> loading checkpoint `{}`'.format(args.resume))
546 | checkpoint = torch.load(args.resume)
547 | args.start_epoch = checkpoint['epoch']
548 | best_prec1 = checkpoint['best_prec1']
549 | model.load_state_dict(checkpoint['state_dict'])
550 | logging.info('=> loaded checkpoint `{}` (epoch: {})'.format(
551 | args.resume, checkpoint['epoch']
552 | ))
553 | else:
554 | logging.info('=> no checkpoint found at `{}`'.format(args.resume))
555 |
556 | cudnn.benchmark = False
557 |     test_loader = prepare_test_data(dataset=args.dataset, datadir=args.datadir+'/val',
558 | batch_size=args.batch_size,
559 | shuffle=False,
560 | num_workers=args.workers)
561 | criterion = nn.CrossEntropyLoss().cuda()
562 |
563 | # validate(args, test_loader, model, criterion)
564 |
565 | with torch.no_grad():
566 |         prec1 = validate(args, test_loader, model, criterion, args.start_epoch)
567 | # prec_full = validate_full_prec(args, test_loader, model, criterion, args.start_iter)
568 |
569 |
570 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
571 | torch.save(state, filename)
572 | if is_best:
573 | save_path = os.path.dirname(filename)
574 | shutil.copyfile(filename, os.path.join(save_path,
575 | 'model_best.pth.tar'))
576 |
577 |
578 | class AverageMeter(object):
579 | """Computes and stores the average and current value"""
580 |
581 | def __init__(self):
582 | self.reset()
583 |
584 | def reset(self):
585 | self.val = 0
586 | self.avg = 0
587 | self.sum = 0
588 | self.count = 0
589 |
590 | def update(self, val, n=1):
591 | self.val = val
592 | self.sum += val * n
593 | self.count += n
594 | self.avg = self.sum / self.count
595 |
596 |
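# Step args.target_ratio (the cost target used by the dynamic gates) at the
# epochs listed in args.schedule; target_ratio_schedule must hold one more
# entry than args.schedule (see the assert below).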
597 | schedule_cnt = 0
598 | def adjust_target_ratio(args, _epoch):
599 | if args.schedule:
600 | global schedule_cnt
601 |
602 | assert len(args.target_ratio_schedule) == len(args.schedule) + 1
603 |
604 | if schedule_cnt == 0:
605 | args.target_ratio = args.target_ratio_schedule[0]
606 | schedule_cnt += 1
607 |
608 | for step in args.schedule:
609 | if _epoch == step:
610 | args.target_ratio = args.target_ratio_schedule[schedule_cnt]
611 | schedule_cnt += 1
612 |
613 |
614 | def adjust_learning_rate(args, optimizer, _epoch):
615 | lr = args.lr * (0.1 ** (_epoch // 30))
616 |
617 | for param_group in optimizer.param_groups:
618 | param_group['lr'] = lr
619 |
620 | return lr
621 |
622 |
623 | def accuracy(output, target, topk=(1,)):
624 | """Computes the precision@k for the specified values of k"""
625 | maxk = max(topk)
626 | batch_size = target.size(0)
627 |
628 | _, pred = output.topk(maxk, 1, True, True)
629 | pred = pred.t()
630 | correct = pred.eq(target.view(1, -1).expand_as(pred))
631 |
632 | res = []
633 | for k in topk:
634 | correct_k = correct[:k].view(-1).float().sum(0)
635 | res.append(correct_k.mul_(100.0 / batch_size))
636 | return res
637 |
638 |
639 | if __name__ == '__main__':
640 | main()
641 |
--------------------------------------------------------------------------------
/fractrain_imagenet/train_pfq.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.backends.cudnn as cudnn
6 | from torch.autograd import Variable
7 | import numpy as np
8 |
9 | import os
10 | import shutil
11 | import argparse
12 | import time
13 | import logging
14 |
15 | import models
16 | from data import *
17 |
18 |
19 | model_names = sorted(name for name in models.__dict__
20 | if name.islower() and not name.startswith('__')
21 | and callable(models.__dict__[name])
22 | )
23 |
24 |
25 | def parse_args():
26 | # hyper-parameters are from ResNet paper
27 | parser = argparse.ArgumentParser(
28 | description='PFQ on ImageNet')
29 |     parser.add_argument('--dir', help='working directory for this run')
30 | parser.add_argument('--cmd', choices=['train', 'test'], default='train')
31 | parser.add_argument('--arch', metavar='ARCH', default='resnet50',
32 | choices=model_names,
33 | help='model architecture: ' +
34 | ' | '.join(model_names) +
35 |                         ' (default: resnet50)')
36 | parser.add_argument('--dataset', '-d', type=str, default='imagenet',
37 | choices=['cifar10', 'cifar100','imagenet'],
38 | help='dataset choice')
39 | parser.add_argument('--datadir', default='/home/yf22/dataset', type=str,
40 | help='path to dataset')
41 | parser.add_argument('--workers', default=16, type=int, metavar='N',
42 |                         help='number of data loading workers (default: 16)')
43 | parser.add_argument('--epoch', default=90, type=int,
44 | help='number of epochs (default: 90)')
45 | parser.add_argument('--start_epoch', default=0, type=int,
46 | help='manual iter number (useful on restarts)')
47 | parser.add_argument('--batch_size', default=256, type=int,
48 |                         help='mini-batch size (default: 256)')
49 | parser.add_argument('--lr_schedule', default='piecewise', type=str,
50 | help='learning rate schedule')
51 | parser.add_argument('--lr', default=0.1, type=float,
52 | help='initial learning rate')
53 | parser.add_argument('--momentum', default=0.9, type=float,
54 | help='momentum')
55 | parser.add_argument('--weight_decay', default=1e-4, type=float,
56 | help='weight decay (default: 1e-4)')
57 | parser.add_argument('--print_freq', default=10, type=int,
58 | help='print frequency (default: 10)')
59 | parser.add_argument('--resume', default='', type=str,
60 | help='path to latest checkpoint (default: None)')
61 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
62 | help='use pretrained model')
63 | parser.add_argument('--step_ratio', default=0.1, type=float,
64 |                         help='ratio for learning rate reduction')
65 | parser.add_argument('--warm_up', action='store_true',
66 | help='for n = 18, the model needs to warm up for 400 '
67 | 'iterations')
68 | parser.add_argument('--save_folder', default='save_checkpoints',
69 | type=str,
70 | help='folder to save the checkpoints')
71 | parser.add_argument('--eval_every', default=390, type=int,
72 |                         help='evaluate model every (default: 390) iterations')
73 | parser.add_argument('--num_bits',default=0,type=int,
74 | help='num bits for weight and activation')
75 | parser.add_argument('--num_grad_bits',default=0,type=int,
76 | help='num bits for gradient')
77 | parser.add_argument('--schedule', default=None, type=int, nargs='*',
78 | help='precision schedule')
79 | parser.add_argument('--num_bits_schedule',default=None,type=int,nargs='*',
80 | help='schedule for weight/act precision')
81 | parser.add_argument('--num_grad_bits_schedule',default=None,type=int,nargs='*',
82 | help='schedule for grad precision')
83 | parser.add_argument('--act_fw', default=0, type=int,
84 | help='precision of activation during forward, -1 means dynamic, 0 means no quantize')
85 | parser.add_argument('--act_bw', default=0, type=int,
86 | help='precision of activation during backward, -1 means dynamic, 0 means no quantize')
87 | parser.add_argument('--grad_act_error', default=0, type=int,
88 | help='precision of activation gradient during error backward, -1 means dynamic, 0 means no quantize')
89 | parser.add_argument('--grad_act_gc', default=0, type=int,
90 | help='precision of activation gradient during weight gradient computation, -1 means dynamic, 0 means no quantize')
91 | parser.add_argument('--weight_bits', default=0, type=int,
92 | help='precision of weight')
93 | parser.add_argument('--momentum_act', default=0.9, type=float,
94 | help='momentum for act min/max')
95 |
96 | parser.add_argument('--num_turning_point', type=int, default=3)
97 | parser.add_argument('--initial_threshold', type=float, default=0.05)
98 | parser.add_argument('--decay', type=float, default=0.3)
99 | args = parser.parse_args()
100 | return args
101 |
102 | # indicator
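# Tracks the training loss over a sliding window of `epoch_keep` epochs and
# normalizes absolute loss differences by `scale_loss`. A "turning point" is
# declared once every normalized difference in the window drops below
# `threshold`, i.e. the loss has plateaued at the current precision; the
# threshold is then decayed so that later, flatter plateaus are still caught.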
103 | class loss_diff_indicator():
104 | def __init__(self, threshold, decay, epoch_keep=5):
105 | self.threshold = threshold
106 | self.decay = decay
107 | self.epoch_keep = epoch_keep
108 | self.loss = []
109 | self.scale_loss = 1
110 | self.loss_diff = [1 for i in range(1, self.epoch_keep)]
111 |
112 | def reset(self):
113 | self.loss = []
114 | self.loss_diff = [1 for i in range(1, self.epoch_keep)]
115 |
116 | def adaptive_threshold(self, turning_point_count):
117 | decay_1 = self.decay
118 | decay_2 = self.decay
119 | if turning_point_count == 1:
120 | self.threshold *= decay_1
121 | if turning_point_count == 2:
122 | self.threshold *= decay_2
123 |         print('threshold decayed to {}'.format(self.threshold))
124 |
125 | def get_loss(self, current_epoch_loss):
126 | if len(self.loss) < self.epoch_keep:
127 | self.loss.append(current_epoch_loss)
128 | else:
129 | self.loss.pop(0)
130 | self.loss.append(current_epoch_loss)
131 |
132 | def cal_loss_diff(self):
133 | if len(self.loss) == self.epoch_keep:
134 | for i in range(len(self.loss)-1):
135 | loss_now = self.loss[-1]
136 | loss_pre = self.loss[i]
137 | self.loss_diff[i] = np.abs(loss_pre - loss_now) / self.scale_loss
138 | return True
139 | else:
140 | return False
141 |
142 | def turning_point_emerge(self):
143 | flag = self.cal_loss_diff()
144 | if flag == True:
145 | print(self.loss_diff)
146 | for i in range(len(self.loss_diff)):
147 | if self.loss_diff[i] > self.threshold:
148 | return False
149 | return True
150 | else:
151 | return False
152 |
153 | def main():
154 | args = parse_args()
155 | global save_path
156 | save_path = args.save_path = os.path.join(args.save_folder, args.arch)
157 | if not os.path.exists(save_path):
158 | os.makedirs(save_path)
159 |
160 | models.ACT_FW = args.act_fw
161 | models.ACT_BW = args.act_bw
162 | models.GRAD_ACT_ERROR = args.grad_act_error
163 | models.GRAD_ACT_GC = args.grad_act_gc
164 | models.WEIGHT_BITS = args.weight_bits
165 | models.MOMENTUM = args.momentum_act
166 |
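# If any fine-grained precision flag above is set, override num_bits with -1,
# which the models module appears to treat as a sentinel for using those
# module-level settings instead of a single uniform bit-width.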
167 | args.num_bits = args.num_bits if not (args.act_fw + args.act_bw + args.grad_act_error + args.grad_act_gc + args.weight_bits) else -1
168 |
169 | # config logging file
170 | args.logger_file = os.path.join(save_path, 'log_{}.txt'.format(args.cmd))
171 | if os.path.exists(args.logger_file):
172 | os.remove(args.logger_file)
173 | handlers = [logging.FileHandler(args.logger_file, mode='w'),
174 | logging.StreamHandler()]
175 | logging.basicConfig(level=logging.INFO,
176 | datefmt='%m-%d-%y %H:%M',
177 | format='%(asctime)s:%(message)s',
178 | handlers=handlers)
179 |
180 | global history_score
181 | history_score = np.zeros((args.epoch, 3))
182 |
183 | # initialize indicator
184 | # initial_threshold=0.15
185 | global scale_loss
186 | scale_loss = 0
187 | global my_loss_diff_indicator
188 | my_loss_diff_indicator = loss_diff_indicator(threshold=args.initial_threshold,
189 | decay=args.decay)
190 |
191 | global turning_point_count
192 | turning_point_count = 0
193 |
194 | if args.cmd == 'train':
195 | logging.info('start training {}'.format(args.arch))
196 | run_training(args)
197 |
198 | elif args.cmd == 'test':
199 | logging.info('start evaluating {} with checkpoints from {}'.format(
200 | args.arch, args.resume))
201 | test_model(args)
202 |
203 |
204 | def run_training(args):
205 | # create model
206 | training_loss = 0
207 | training_acc = 0
208 |
209 | model = models.__dict__[args.arch](args.pretrained)
210 | model = torch.nn.DataParallel(model).cuda()
211 |
212 | best_prec1 = 0
213 | best_epoch = 0
214 |
215 | # optionally resume from a checkpoint
216 | if args.resume:
217 | if os.path.isfile(args.resume):
218 | logging.info('=> loading checkpoint `{}`'.format(args.resume))
219 | checkpoint = torch.load(args.resume)
220 | args.start_epoch = checkpoint['epoch']
221 | best_prec1 = checkpoint['best_prec1']
222 | model.load_state_dict(checkpoint['state_dict'])
223 |
224 | logging.info('=> loaded checkpoint `{}` (epoch: {})'.format(
225 | args.resume, checkpoint['epoch']
226 | ))
227 | else:
228 | logging.info('=> no checkpoint found at `{}`'.format(args.resume))
229 |
230 | cudnn.benchmark = False
231 |
232 | train_loader = prepare_train_data(dataset=args.dataset,
233 | datadir=args.datadir+'/train',
234 | batch_size=args.batch_size,
235 | shuffle=True,
236 | num_workers=args.workers)
237 | test_loader = prepare_test_data(dataset=args.dataset,
238 | datadir=args.datadir+'/val',
239 | batch_size=args.batch_size,
240 | shuffle=False,
241 | num_workers=args.workers)
242 |
243 | # define loss function (criterion) and optimizer
244 | criterion = nn.CrossEntropyLoss().cuda()
245 |
246 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
247 | momentum=args.momentum,
248 | weight_decay=args.weight_decay)
249 |
250 | # optimizer = torch.optim.Adam(model.parameters(), args.lr,
251 | # weight_decay=args.weight_decay)
252 |
253 | batch_time = AverageMeter()
254 | data_time = AverageMeter()
255 | losses = AverageMeter()
256 | top1 = AverageMeter()
257 | cr = AverageMeter()
258 |
259 | end = time.time()
260 |
261 | global scale_loss
262 | global history_score
263 | global turning_point_count
264 | global my_loss_diff_indicator
265 |
266 | for _epoch in range(args.start_epoch, args.epoch):
267 | lr = adjust_learning_rate(args, optimizer, _epoch)
268 | # adjust_precision(args, _epoch)
269 | adaptive_adjust_precision(args, turning_point_count)
270 |
271 | print('Learning Rate:', lr)
272 | print('num bits:', args.num_bits, 'num grad bits:', args.num_grad_bits)
273 |
274 | for i, (input, target) in enumerate(train_loader):
275 | # measuring data loading time
276 | data_time.update(time.time() - end)
277 |
278 | model.train()
279 |
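# Compute-cost ratio of this iteration relative to fp32 training: forward
# scales with num_bits^2 (weights and activations share num_bits here), while
# error backward (eb) and weight-gradient computation (gc) scale with
# num_bits * num_grad_bits. E.g. num_bits=8, num_grad_bits=16:
# fw = 64/1024 = 1/16, eb = gc = 128/1024 = 1/8, so cr averages ~10.4%.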
280 | fw_cost = args.num_bits*args.num_bits/32/32
281 | eb_cost = args.num_bits*args.num_grad_bits/32/32
282 | gc_cost = eb_cost
283 | cr.update((fw_cost+eb_cost+gc_cost)/3)
284 |
285 | target = target.squeeze().long().cuda()
286 | input_var = Variable(input).cuda()
287 | target_var = Variable(target).cuda()
288 |
289 | # compute output
290 | output = model(input_var, args.num_bits, args.num_grad_bits)
291 | loss = criterion(output, target_var)
292 | training_loss += loss.item()
293 |
294 | # measure accuracy and record loss
295 | prec1, = accuracy(output.data, target, topk=(1,))
296 | losses.update(loss.item(), input.size(0))
297 | top1.update(prec1.item(), input.size(0))
298 | training_acc += prec1.item()
299 |
300 | # compute gradient and do SGD step
301 | optimizer.zero_grad()
302 | loss.backward()
303 | optimizer.step()
304 |
305 | # measure elapsed time
306 | batch_time.update(time.time() - end)
307 | end = time.time()
308 |
309 | # print log
310 | if i % args.print_freq == 0:
311 | logging.info("Iter: [{0}][{1}/{2}]\t"
312 | "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
313 | "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
314 | "Loss {loss.val:.3f} ({loss.avg:.3f})\t"
315 | "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t".format(
316 | _epoch,
317 | i,
318 | len(train_loader),
319 | batch_time=batch_time,
320 | data_time=data_time,
321 | loss=losses,
322 | top1=top1)
323 | )
324 |
325 | epoch = _epoch + 1
326 | epoch_loss = training_loss / len(train_loader)
327 | with torch.no_grad():
328 | prec1 = validate(args, test_loader, model, criterion, _epoch)
329 | # prec_full = validate_full_prec(args, test_loader, model, criterion, i)
330 | history_score[epoch-1][0] = epoch_loss
331 | history_score[epoch-1][1] = np.round(training_acc / len(train_loader), 2)
332 | history_score[epoch-1][2] = prec1
333 | training_loss = 0
334 | training_acc = 0
335 |
336 | np.savetxt(os.path.join(save_path, 'record.txt'), history_score, fmt = '%10.5f', delimiter=',')
337 |
338 | # apply indicator
339 | # if epoch == 1:
340 | # logging.info('initial loss value: {}'.format(epoch_loss))
341 | # my_loss_diff_indicator.scale_loss = epoch_loss
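# Normalize the indicator by the running mean of the training loss over the
# first 10 epochs, making the turning-point threshold insensitive to the
# absolute loss scale of the dataset/architecture.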
342 | if epoch <= 10:
343 | scale_loss += epoch_loss
344 | logging.info('scale_loss at epoch {}: {}'.format(epoch, scale_loss / epoch))
345 | my_loss_diff_indicator.scale_loss = scale_loss / epoch
346 | if turning_point_count < args.num_turning_point:
347 | my_loss_diff_indicator.get_loss(epoch_loss)
348 | flag = my_loss_diff_indicator.turning_point_emerge()
349 | if flag == True:
350 | turning_point_count += 1
351 |                 logging.info('found turning point {} at epoch {}'.format(turning_point_count, epoch))
352 | # print('find {}-th turning point at {}-th epoch'.format(turning_point_count, epoch))
353 | my_loss_diff_indicator.adaptive_threshold(turning_point_count=turning_point_count)
354 | my_loss_diff_indicator.reset()
355 |
356 | logging.info('Epoch [{}] num_bits = {} num_grad_bits = {}'.format(epoch, args.num_bits, args.num_grad_bits))
357 |
358 |
359 | is_best = prec1 > best_prec1
360 | if is_best:
361 | best_prec1 = max(prec1, best_prec1)
362 | best_epoch = epoch
363 | #best_full_prec = max(prec_full, best_full_prec)
364 |
365 | print("Current Best Prec@1: ", best_prec1)
366 | logging.info("Current Best Epoch: {}".format(best_epoch))
367 | #print("Current Best Full Prec@1: ", best_full_prec)
368 |
369 | checkpoint_path = os.path.join(args.save_path, 'checkpoint_{:05d}_{:.2f}.pth.tar'.format(_epoch, prec1))
370 | save_checkpoint({
371 | 'epoch': _epoch,
372 | 'arch': args.arch,
373 | 'state_dict': model.state_dict(),
374 | 'best_prec1': best_prec1,
375 | },
376 | is_best, filename=checkpoint_path)
377 | shutil.copyfile(checkpoint_path, os.path.join(args.save_path,
378 | 'checkpoint_latest'
379 | '.pth.tar'))
380 |
381 |
382 | def validate(args, test_loader, model, criterion, _epoch):
383 | batch_time = AverageMeter()
384 | losses = AverageMeter()
385 | top1 = AverageMeter()
386 |
387 | # switch to evaluation mode
388 | model.eval()
389 | end = time.time()
390 | for i, (input, target) in enumerate(test_loader):
391 | target = target.squeeze().long().cuda()
392 | input_var = Variable(input, volatile=True).cuda()
393 | target_var = Variable(target, volatile=True).cuda()
394 |
395 | # compute output
396 | output = model(input_var, args.num_bits, args.num_grad_bits)
397 | loss = criterion(output, target_var)
398 |
399 | # measure accuracy and record loss
400 | prec1, = accuracy(output.data, target, topk=(1,))
401 | top1.update(prec1.item(), input.size(0))
402 | losses.update(loss.item(), input.size(0))
403 | batch_time.update(time.time() - end)
404 | end = time.time()
405 |
406 | if (i % args.print_freq == 0) or (i == len(test_loader) - 1):
407 | logging.info(
408 | 'Test: [{}/{}]\t'
409 | 'Time: {batch_time.val:.4f}({batch_time.avg:.4f})\t'
410 | 'Loss: {loss.val:.3f}({loss.avg:.3f})\t'
411 | 'Prec@1: {top1.val:.3f}({top1.avg:.3f})\t'.format(
412 | i, len(test_loader), batch_time=batch_time,
413 | loss=losses, top1=top1
414 | )
415 | )
416 |
417 | logging.info('Epoch {} * Prec@1 {top1.avg:.3f}'.format(_epoch, top1=top1))
418 | return top1.avg
419 |
420 |
421 | def validate_full_prec(args, test_loader, model, criterion, _epoch):
422 | batch_time = AverageMeter()
423 | losses = AverageMeter()
424 | top1 = AverageMeter()
425 |
426 | # switch to evaluation mode
427 | model.eval()
428 | end = time.time()
429 | for i, (input, target) in enumerate(test_loader):
430 | target = target.squeeze().long().cuda()
431 | input_var = Variable(input, volatile=True).cuda()
432 | target_var = Variable(target, volatile=True).cuda()
433 |
434 | # compute output
435 | output = model(input_var, 0, 0)
436 | loss = criterion(output, target_var)
437 |
438 | # measure accuracy and record loss
439 | prec1, = accuracy(output.data, target, topk=(1,))
440 | top1.update(prec1.item(), input.size(0))
441 | losses.update(loss.item(), input.size(0))
442 | batch_time.update(time.time() - end)
443 | end = time.time()
444 |
445 |
446 | logging.info('Epoch {} * Full Prec@1 {top1.avg:.3f}'.format(_epoch, top1=top1))
447 | return top1.avg
448 |
449 |
450 | def test_model(args):
451 | # create model
452 | model = models.__dict__[args.arch](args.pretrained)
453 | model = torch.nn.DataParallel(model).cuda()
454 |
455 | if args.resume:
456 | if os.path.isfile(args.resume):
457 | logging.info('=> loading checkpoint `{}`'.format(args.resume))
458 | checkpoint = torch.load(args.resume)
459 | args.start_epoch = checkpoint['epoch']
460 | best_prec1 = checkpoint['best_prec1']
461 | model.load_state_dict(checkpoint['state_dict'])
462 | logging.info('=> loaded checkpoint `{}` (epoch: {})'.format(
463 | args.resume, checkpoint['epoch']
464 | ))
465 | else:
466 | logging.info('=> no checkpoint found at `{}`'.format(args.resume))
467 |
468 | cudnn.benchmark = False
469 |     test_loader = prepare_test_data(dataset=args.dataset, datadir=args.datadir+'/val',
470 | batch_size=args.batch_size,
471 | shuffle=False,
472 | num_workers=args.workers)
473 | criterion = nn.CrossEntropyLoss().cuda()
474 |
475 | # validate(args, test_loader, model, criterion)
476 |
477 | with torch.no_grad():
478 |         prec1 = validate(args, test_loader, model, criterion, args.start_epoch)
479 | # prec_full = validate_full_prec(args, test_loader, model, criterion, args.start_iter)
480 |
481 |
482 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
483 | torch.save(state, filename)
484 | if is_best:
485 | save_path = os.path.dirname(filename)
486 | shutil.copyfile(filename, os.path.join(save_path,
487 | 'model_best.pth.tar'))
488 |
489 |
490 | class AverageMeter(object):
491 | """Computes and stores the average and current value"""
492 |
493 | def __init__(self):
494 | self.reset()
495 |
496 | def reset(self):
497 | self.val = 0
498 | self.avg = 0
499 | self.sum = 0
500 | self.count = 0
501 |
502 | def update(self, val, n=1):
503 | self.val = val
504 | self.sum += val * n
505 | self.count += n
506 | self.avg = self.sum / self.count
507 |
508 |
509 | schedule_cnt = 0
510 | def adjust_precision(args, _epoch):
511 | if args.schedule:
512 | global schedule_cnt
513 |
514 | assert len(args.num_bits_schedule) == len(args.schedule) + 1
515 | assert len(args.num_grad_bits_schedule) == len(args.schedule) + 1
516 |
517 | if schedule_cnt == 0:
518 | args.num_bits = args.num_bits_schedule[0]
519 | args.num_grad_bits = args.num_grad_bits_schedule[0]
520 | schedule_cnt += 1
521 |
522 | for step in args.schedule:
523 | if _epoch == step:
524 | args.num_bits = args.num_bits_schedule[schedule_cnt]
525 | args.num_grad_bits = args.num_grad_bits_schedule[schedule_cnt]
526 | schedule_cnt += 1
527 |
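# PFQ: bump precision one step each time the loss indicator fires. Schedules
# are indexed by turning_point_count, so num_bits_schedule and
# num_grad_bits_schedule each need args.num_turning_point + 1 entries
# (e.g. 4 entries when num_turning_point = 3).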
528 | def adaptive_adjust_precision(args, turning_point_count):
529 | args.num_bits = args.num_bits_schedule[turning_point_count]
530 | args.num_grad_bits = args.num_grad_bits_schedule[turning_point_count]
531 |
532 | def adjust_learning_rate(args, optimizer, _epoch):
533 | lr = args.lr * (0.1 ** (_epoch // 30))
534 |
535 | for param_group in optimizer.param_groups:
536 | param_group['lr'] = lr
537 |
538 | return lr
539 |
540 |
541 | def accuracy(output, target, topk=(1,)):
542 | """Computes the precision@k for the specified values of k"""
543 | maxk = max(topk)
544 | batch_size = target.size(0)
545 |
546 | _, pred = output.topk(maxk, 1, True, True)
547 | pred = pred.t()
548 | correct = pred.eq(target.view(1, -1).expand_as(pred))
549 |
550 | res = []
551 | for k in topk:
552 | correct_k = correct[:k].view(-1).float().sum(0)
553 | res.append(correct_k.mul_(100.0 / batch_size))
554 | return res
555 |
556 |
557 | if __name__ == '__main__':
558 | main()
559 |
--------------------------------------------------------------------------------
/img/DFQ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/img/DFQ.png
--------------------------------------------------------------------------------
/img/PFQ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/img/PFQ.png
--------------------------------------------------------------------------------
/img/dfq_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/img/dfq_result.png
--------------------------------------------------------------------------------
/img/fractrain_result_cifar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/img/fractrain_result_cifar.png
--------------------------------------------------------------------------------
/img/fractrain_result_imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/img/fractrain_result_imagenet.png
--------------------------------------------------------------------------------
/img/pfq_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GATECH-EIC/FracTrain/1113ec227e6ef12225db582de3ea9a551d00c51a/img/pfq_result.png
--------------------------------------------------------------------------------