├── LICENSE ├── README.md ├── anybit.py ├── config.py ├── data_pre.py ├── evaluators.py ├── main.py ├── models ├── __init__.py ├── alexnet.py ├── alexnet_all.py ├── quantization.py ├── resnet.py └── resnet18_all.py ├── quan-weight.sh ├── quan_all_main.py ├── quan_weight_main.py ├── tools └── cluster.py ├── train.sh └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Quantization Networks 2 | 3 | ### Overview 4 | This repository contains the training code of Quantization Networks introduced in our CVPR 2019 paper: [*Quantization Networks*](http://openaccess.thecvf.com/content_CVPR_2019/papers/Yang_Quantization_Networks_CVPR_2019_paper.pdf). 
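For a quick sense of the core idea described below, here is a minimal, illustrative sketch of the soft quantization function: a linear combination of temperature-controlled Sigmoid functions. The function name and sample values are ours for illustration only; the repository's actual implementations are `sigmoid_t` / `QuaOp.forward` in `anybit.py` and `SigmoidT` in `models/quantization.py`.

```python
import torch

def soft_quantize(x, biases, scales, offset, T=1.0):
    # y = sum_i scales[i] * sigmoid(T * (x - biases[i])) - offset
    # As T grows, each sigmoid approaches a step function and y approaches
    # a hard quantizer; small T keeps the function smooth and trainable.
    y = torch.zeros_like(x)
    for b, s in zip(biases, scales):
        y = y + s * torch.sigmoid(T * (x - b))
    return y - offset

# Ternary quantization to {-1, 0, 1}: two unit-gap sigmoids, offset = half the total gap.
x = torch.linspace(-2, 2, 9)
print(soft_quantize(x, biases=[-0.5, 0.5], scales=[1.0, 1.0], offset=1.0, T=50.0))
```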
5 | 
6 | In this work, we provide a **simple and uniform way** to quantize both weights and activations by formulating quantization as a differentiable non-linear function.
7 | The quantization function is represented as a linear combination of several
8 | Sigmoid functions with learnable biases and scales that
9 | can be learned in a lossless and end-to-end manner via
10 | continuous relaxation of the steepness of the Sigmoid functions.
11 | 
12 | Extensive experiments on image classification and object
13 | detection tasks show that our quantization networks outperform state-of-the-art methods.
14 | 
15 | ### Runtime environment
16 | 
17 | + Python 3.5
18 | + Python bindings for OpenCV
19 | + PyTorch 0.3.0
20 | 
21 | ### Usage
22 | 
23 | Download the ImageNet dataset and decompress it into a directory structure like
24 | 
25 |     dir/
26 |       train/
27 |         n01440764_10026.JPEG
28 |         ...
29 |       val/
30 |         ILSVRC2012_val_00000001.JPEG
31 |         ...
32 | 
33 | To train a weight-quantized ResNet-18 model, simply run
34 | 
35 |     sh quan-weight.sh
36 | 
37 | After training, the resulting model will be stored in `./logs/quan-weight/resnet18-quan-w-1`.
38 | 
39 | The other training procedures can be found in the paper.
40 | 
41 | ### License
42 | + Apache License 2.0
43 | 
44 | 
45 | ### Citation
46 | If you use our code or models in your research, please cite:
47 | ```
48 | @inproceedings{yang2019quantization,
49 |   title={Quantization Networks},
50 |   author={Yang, Jiwei and Shen, Xu and Xing, Jun and Tian, Xinmei and Li, Houqiang and Deng, Bing and Huang, Jianqiang and Hua, Xian-sheng},
51 |   booktitle={CVPR},
52 |   year={2019}
53 | }
54 | ```
55 | 
--------------------------------------------------------------------------------
/anybit.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # anybit.py is used to quantize the weights of a model.
4 | 
5 | from __future__ import print_function, absolute_import
6 | 
7 | import torch
8 | import torch.nn as nn
9 | from torch.autograd import Variable
10 | from torch.nn.parameter import Parameter
11 | import math
12 | import numpy
13 | import pdb
14 | 
15 | def sigmoid_t(x, b=0, t=1):
16 |     """
17 |     The sigmoid function with temperature T for the soft quantization function.
18 |     Args:
19 |         x: input
20 |         b: the bias
21 |         t: the temperature
22 |     Returns:
23 |         y = sigmoid(t(x-b))
24 |     """
25 |     temp = -1 * t * (x - b)
26 |     temp = torch.clamp(temp, min=-10.0, max=10.0)
27 |     return 1.0 / (1.0 + torch.exp(temp))
28 | 
29 | def step(x, bias):
30 |     """
31 |     The step function used as the ideal quantization function at test time.
32 |     """
33 |     y = torch.zeros_like(x)
34 |     mask = torch.gt(x - bias, 0.0)
35 |     y[mask] = 1.0
36 |     return y
37 | 
38 | class QuaOp(object):
39 |     """
40 |     Quantize weights.
41 |     Args:
42 |         model: the model to be quantized.
43 |         QW_biases (list): the biases of the quantization function.
44 |             QW_biases is an m x n list, where m is the number of layers
45 |             and n is the number of sigmoid_t components.
46 |         QW_values (list): the list of quantization values,
47 |             such as [-1, 0, 1] or [-2, -1, 0, 1, 2].
48 |     Returns:
49 |         Quantized model.
50 | """ 51 | def __init__(self, model, QW_biases, QW_values=[]): 52 | # Count the number of Conv2d and Linear 53 | count_targets = 0 54 | for m in model.modules(): 55 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 56 | count_targets = count_targets + 1 57 | # Omit the first conv layer and the last linear layer 58 | start_range = 1 59 | end_range = count_targets - 2 60 | self.bin_range = numpy.linspace(start_range, 61 | end_range, end_range-start_range+1)\ 62 | .astype('int').tolist() 63 | self.num_of_params = len(self.bin_range) 64 | self.saved_params = [] 65 | self.target_params = [] 66 | self.target_modules = [] 67 | index = -1 68 | for m in model.modules(): 69 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 70 | index = index + 1 71 | if index in self.bin_range: 72 | tmp = m.weight.data.clone() 73 | self.saved_params.append(tmp) 74 | self.target_modules.append(m.weight) 75 | 76 | print('target_modules number: ', len(self.target_modules)) 77 | self.QW_biases = QW_biases 78 | self.QW_values = QW_values 79 | # the number of sigmoid_t 80 | self.n = len(self.QW_values) - 1 81 | self.threshold = self.QW_values[-1] * 5 / 4.0 82 | # the gap between two quantization values 83 | self.scales = [] 84 | offset = 0. 85 | for i in range(self.n): 86 | gap = self.QW_values[i + 1] - self.QW_values[i] 87 | self.scales.append(gap) 88 | offset += gap 89 | self.offset = offset / 2. 90 | 91 | def forward(self, x, T, quan_bias, train=True): 92 | if train: 93 | y = sigmoid_t(x, b=quan_bias[0], t=T)*self.scales[0] 94 | for j in range(1, self.n): 95 | y += sigmoid_t(x, b=quan_bias[j], t=T)*self.scales[j] 96 | else: 97 | y = step(x, bias=quan_bias[0])*self.scales[0] 98 | for j in range(1, self.n): 99 | y += step(x, bias=quan_bias[j])*self.scales[j] 100 | y = y - self.offset 101 | 102 | return y 103 | 104 | def backward(self, x, T, quan_bias): 105 | y_1 = sigmoid_t(x, b=quan_bias[0], t=T)*self.scales[0] 106 | y_grad = (y_1.mul(self.scales[0] - y_1)).div(self.scales[0]) 107 | for j in range(1, self.n): 108 | y_temp = sigmoid_t(x, b=quan_bias[j], t=T)*self.scales[j] 109 | y_grad += (y_temp.mul(self.scales[j] - y_temp)).div(self.scales[j]) 110 | 111 | return y_grad 112 | 113 | def quantization(self, T, alpha, beta, init, train_phase=True): 114 | """ 115 | The operation of network quantization. 116 | Args: 117 | T: the temperature, a single number. 118 | alpha: the scale factor of the output, a list. 119 | beta: the scale factor of the input, a list. 120 | init: a flag represents the first loading of the quantization function. 121 | train_phase: a flag represents the quantization 122 | operation in the training stage. 
123 | """ 124 | self.save_params() 125 | self.quantizeConvParams(T, alpha, beta, init, train_phase=train_phase) 126 | 127 | def save_params(self): 128 | """ 129 | save the float parameters for backward 130 | """ 131 | for index in range(self.num_of_params): 132 | self.saved_params[index].copy_(self.target_modules[index].data) 133 | 134 | def restore_params(self): 135 | for index in range(self.num_of_params): 136 | self.target_modules[index].data.copy_(self.saved_params[index]) 137 | 138 | 139 | def quantizeConvParams(self, T, alpha, beta, init, train_phase): 140 | """ 141 | quantize the parameters in forward 142 | """ 143 | T = (T > 2000)*2000 + (T <= 2000)*T 144 | for index in range(self.num_of_params): 145 | if init: 146 | beta[index].data = torch.Tensor([self.threshold / self.target_modules[index].data.abs().max()]).cuda() 147 | alpha[index].data = torch.reciprocal(beta[index].data) 148 | # scale w 149 | x = self.target_modules[index].data.mul(beta[index].data) 150 | 151 | y = self.forward(x, T, self.QW_biases[index], train=train_phase) 152 | #scale w^hat 153 | self.target_modules[index].data = y.mul(alpha[index].data) 154 | 155 | 156 | def updateQuaGradWeight(self, T, alpha, beta, init): 157 | """ 158 | Calculate the gradients of all the parameters. 159 | The gradients of model parameters are saved in the [Variable].grad.data. 160 | Args: 161 | T: the temperature, a single number. 162 | alpha: the scale factor of the output, a list. 163 | beta: the scale factor of the input, a list. 164 | init: a flag represents the first loading of the quantization function. 165 | Returns: 166 | alpha_grad: the gradient of alpha. 167 | beta_grad: the gradient of beta. 168 | """ 169 | beta_grad = [0.0] * len(beta) 170 | alpha_grad = [0.0] * len(alpha) 171 | T = (T > 2000)*2000 + (T <= 2000)*T 172 | for index in range(self.num_of_params): 173 | if init: 174 | beta[index].data = torch.Tensor([self.threshold / self.target_modules[index].data.abs().max()]).cuda() 175 | alpha[index].data = torch.reciprocal(beta[index].data) 176 | x = self.target_modules[index].data.mul(beta[index].data) 177 | 178 | # set T = 1 when train binary model 179 | y_grad = self.backward(x, 1, self.QW_biases[index]).mul(T) 180 | # set T = T when train the other quantization model 181 | #y_grad = self.backward(x, T, self.QW_biases[index]).mul(T) 182 | 183 | 184 | beta_grad[index] = y_grad.mul(self.target_modules[index].data).mul(alpha[index].data).\ 185 | mul(self.target_modules[index].grad.data).sum() 186 | alpha_grad[index] = self.forward(x, T, self.QW_biases[index]).\ 187 | mul(self.target_modules[index].grad.data).sum() 188 | 189 | self.target_modules[index].grad.data = y_grad.mul(beta[index].data).mul(alpha[index].data).\ 190 | mul(self.target_modules[index].grad.data) 191 | return alpha_grad, beta_grad 192 | 193 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os.path 5 | 6 | # gets home dir cross platform 7 | HOME = os.path.expanduser("~") 8 | 9 | QW_values = { 10 | 'alexnet-w-1': [-1, 0, 1], 'alexnet-w-2': [-1, 0, 1], 'alexnet-w-3-pm2': [-2, -1, 0, 1, 2], 'alexnet-w-3-pm4': [-4, -2, -1, 0, 1, 2, 4], 11 | 'resnet18-w-1': [-1, 0, 1], 'resnet18-w-2': [-1, 0, 1],'resnet18-w-3':[-4,-2,-1,0,1,2,4] 12 | } 13 | 14 | QW_biases = { 15 | 'alexnet-w-1':[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], 16 | 17 | 
'resnet18-w-1':[[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], 18 | [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], 19 | [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], 20 | [0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], 21 | } 22 | 23 | QA_biases = { 24 | 'resnet18-a-1':[[0.05], [0.05], [0.05], [0.05], [0.05], [0.05], [0.05], [0.05], [0.05], [0.05], [0.05], [0.05], [0.05], [0.05], [0.05], 25 | [0.05], [0.05]] 26 | 27 | } 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /data_pre.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import absolute_import 5 | import os.path as osp 6 | from collections import defaultdict 7 | import numpy as np 8 | from PIL import Image 9 | import random 10 | 11 | import torch 12 | from torchvision.transforms import * 13 | 14 | #################################### data augmentation ################################ 15 | class Grayscale(object): 16 | 17 | def __call__(self, img): 18 | gs = img.clone() 19 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2]) 20 | gs[1].copy_(gs[0]) 21 | gs[2].copy_(gs[0]) 22 | return gs 23 | 24 | class Saturation(object): 25 | 26 | def __init__(self, var): 27 | self.var = var 28 | 29 | def __call__(self, img): 30 | gs = Grayscale()(img) 31 | alpha = random.uniform(0, self.var) 32 | return img.lerp(gs, alpha) 33 | 34 | 35 | class Brightness(object): 36 | 37 | def __init__(self, var): 38 | self.var = var 39 | 40 | def __call__(self, img): 41 | gs = img.new().resize_as_(img).zero_() 42 | alpha = random.uniform(0, self.var) 43 | return img.lerp(gs, alpha) 44 | 45 | 46 | class Contrast(object): 47 | 48 | def __init__(self, var): 49 | self.var = var 50 | 51 | def __call__(self, img): 52 | gs = Grayscale()(img) 53 | gs.fill_(gs.mean()) 54 | alpha = random.uniform(0, self.var) 55 | return img.lerp(gs, alpha) 56 | 57 | 58 | class RandomOrder(object): 59 | """ Composes several transforms together in random order. 
60 | """ 61 | 62 | def __init__(self, transforms): 63 | self.transforms = transforms 64 | 65 | def __call__(self, img): 66 | if self.transforms is None: 67 | return img 68 | order = torch.randperm(len(self.transforms)) 69 | for i in order: 70 | img = self.transforms[i](img) 71 | return img 72 | 73 | 74 | class ColorJitter(RandomOrder): 75 | 76 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 77 | self.transforms = [] 78 | if brightness != 0: 79 | self.transforms.append(Brightness(brightness)) 80 | if contrast != 0: 81 | self.transforms.append(Contrast(contrast)) 82 | if saturation != 0: 83 | self.transforms.append(Saturation(saturation)) 84 | 85 | class Lighting(object): 86 | """Lighting noise(AlexNet - style PCA - based noise)""" 87 | 88 | def __init__(self, alphastd, eigval, eigvec): 89 | self.alphastd = alphastd 90 | self.eigval = eigval 91 | self.eigvec = eigvec 92 | 93 | def __call__(self, img): 94 | if self.alphastd == 0: 95 | return img 96 | 97 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 98 | rgb = self.eigvec.type_as(img).clone()\ 99 | .mul(alpha.view(1, 3).expand(3, 3))\ 100 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 101 | .sum(1).squeeze() 102 | 103 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 104 | 105 | 106 | #################################### end data augmentation ########################## 107 | 108 | 109 | #################################### data preprocessor ################################ 110 | class Preprocessor(object): 111 | def __init__(self, dataset, root=None, transform=None): 112 | super(Preprocessor, self).__init__() 113 | self.dataset = dataset 114 | self.root = root 115 | self.transform = transform 116 | 117 | def __len__(self): 118 | return len(self.dataset) 119 | 120 | def __getitem__(self, indices): 121 | if isinstance(indices, (tuple, list)): 122 | return [self._get_single_item(index) for index in indices] 123 | return self._get_single_item(indices) 124 | 125 | def _get_single_item(self, index): 126 | fname, label = self.dataset[index] 127 | fpath = fname 128 | if self.root is not None: 129 | fpath = osp.join(self.root, fname) 130 | img = Image.open(fpath).convert('RGB') 131 | if self.transform is not None: 132 | img = self.transform(img) 133 | return img, fname, label 134 | 135 | #################################### end data preprocessor ################################ 136 | -------------------------------------------------------------------------------- /evaluators.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import print_function, absolute_import 5 | 6 | from utils import to_torch 7 | 8 | 9 | def accuracy(output, target, topk=(1,)): 10 | output, target = to_torch(output), to_torch(target) 11 | maxk = max(topk) 12 | batch_size = target.size(0) 13 | 14 | _, pred = output.topk(maxk, 1, True, True) 15 | pred = pred.t() 16 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 17 | 18 | ret = [] 19 | for k in topk: 20 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) 21 | ret.append(correct_k.mul_(1. / batch_size)) 22 | return ret 23 | 24 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Train the baseline model. 
5 | """ 6 | from __future__ import print_function, absolute_import 7 | import argparse 8 | import os.path as osp 9 | import os 10 | import numpy as np 11 | import sys 12 | import time 13 | import math 14 | import torch 15 | from torch import nn 16 | from torch.autograd import Variable 17 | from torch.backends import cudnn 18 | from torch.utils.data import DataLoader 19 | from torch.utils.data.sampler import RandomSampler 20 | from torch.nn.parameter import Parameter 21 | from torchvision import transforms as T 22 | 23 | from config import * 24 | import models 25 | from data_pre import Lighting, Preprocessor 26 | from utils import Logger, AverageMeter 27 | from utils import load_checkpoint, save_checkpoint 28 | from evaluators import accuracy 29 | import pdb 30 | 31 | 32 | def get_params(pretrained_model): 33 | pretrained_checkpoint = load_checkpoint(pretrained_model) 34 | for name, param in pretrained_checkpoint.items(): 35 | #for name, param in pretrained_checkpoint['state_dict'].items(): 36 | print('pretrained_model params name and size: ', name, param.size()) 37 | if isinstance(param, Parameter): 38 | # backwards compatibility for serialized parameters 39 | param = param.data 40 | try: 41 | np.save(name+'.npy', param.cpu().numpy()) 42 | print('############# new_model load params name: ',name) 43 | except: 44 | raise RuntimeError('While copying the parameter named {}, \ 45 | whose dimensions in the model are {} and \ 46 | whose dimensions in the checkpoint are {}.' 47 | .format(name, new_model_dict[name].size(), param.size())) 48 | 49 | 50 | def get_data(split_id, data_dir, img_size, scale_size, batch_size, 51 | workers, train_list, val_list): 52 | root = data_dir 53 | 54 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 55 | std=[0.229, 0.224, 0.225]) # RGB imagenet 56 | 57 | # with data augmentation 58 | train_transformer = T.Compose([ 59 | T.RandomResizedCrop(img_size), 60 | T.RandomHorizontalFlip(), 61 | T.ToTensor(), # [0, 255] to [0.0, 1.0] 62 | normalizer, # normalize each channel of the input 63 | ]) 64 | 65 | test_transformer = T.Compose([ 66 | T.Resize(scale_size), 67 | T.CenterCrop(img_size), 68 | T.ToTensor(), 69 | normalizer, 70 | ]) 71 | 72 | train_loader = DataLoader( 73 | Preprocessor(train_list, root=root, 74 | transform=train_transformer), 75 | batch_size=batch_size, num_workers=workers, 76 | sampler=RandomSampler(train_list), 77 | pin_memory=True, drop_last=False) 78 | 79 | val_loader = DataLoader( 80 | Preprocessor(val_list, root=root, 81 | transform=test_transformer), 82 | batch_size=batch_size, num_workers=workers, 83 | shuffle=False, pin_memory=True) 84 | 85 | return train_loader, val_loader 86 | 87 | def main(args): 88 | np.random.seed(args.seed) 89 | torch.manual_seed(args.seed) 90 | torch.cuda.manual_seed(args.seed) 91 | cudnn.benchmark = True 92 | data_dir = osp.join(args.data_dir, args.dataset) 93 | # Redirect print to both console and log file 94 | if not args.evaluate: 95 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 96 | else: 97 | sys.stdout = Logger(osp.join(args.logs_dir, 'evaluate-log.txt')) 98 | print('\n################## setting ###################') 99 | print(parser.parse_args()) 100 | print('################## setting ###################\n') 101 | # Create data loaders 102 | def readlist(fpath): 103 | lines=[] 104 | with open(fpath, 'r') as f: 105 | data = f.readlines() 106 | 107 | for line in data: 108 | name, label = line.split() 109 | lines.append((name, int(label))) 110 | return lines 111 | 112 | # Load data list 113 | if 
osp.exists(osp.join(data_dir, 'train.txt')):
114 |         train_list = readlist(osp.join(data_dir, 'train.txt'))
115 |     else:
116 |         raise RuntimeError("The training list -- {} doesn't exist".format(osp.join(data_dir, 'train.txt')))
117 | 
118 |     if osp.exists(osp.join(data_dir, 'val.txt')):
119 |         val_list = readlist(osp.join(data_dir, 'val.txt'))
120 |     else:
121 |         raise RuntimeError("The val list -- {} doesn't exist".format(osp.join(data_dir, 'val.txt')))
122 | 
123 | 
124 |     if args.scale_size is None:
125 |         args.scale_size = 256
126 |     if args.img_size is None:
127 |         args.img_size = 224
128 | 
129 |     train_loader, val_loader = \
130 |         get_data(args.split, data_dir, args.img_size,
131 |                  args.scale_size, args.batch_size, args.workers,
132 |                  train_list, val_list)
133 |     # Create model
134 |     #num_classes = 1000 # imagenet 1000
135 |     model = models.create(args.arch, False, num_classes=1000)
136 | 
137 |     if args.adam:
138 |         print('The optimizer is Adam !!!')
139 |         optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
140 |                                      weight_decay=args.weight_decay)
141 |     else:
142 |         print('The optimizer is SGD !!!')
143 |         optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
144 |                                     momentum=args.momentum,
145 |                                     weight_decay=args.weight_decay)
146 | 
147 |     # Load model from checkpoint
148 |     start_epoch = best_top1 = 0
149 |     if args.pretrained:
150 |         print('=> Start loading params from the pre-trained model...')
151 |         checkpoint = load_checkpoint(args.pretrained)
152 |         if 'alexnet' in args.arch or 'resnet' in args.arch:
153 |             model.load_state_dict(checkpoint)
154 |             #model.load_state_dict(checkpoint['state_dict'])
155 |             #torch.save(model.state_dict(), osp.join('./pre-models', 'resnet18-relu6-703.pth'))
156 |         else:
157 |             raise RuntimeError('Unsupported arch: {}'.format(args.arch))
158 | 
159 |         # get model parameters
160 |         get_params(args.pretrained)
161 |         pdb.set_trace()
162 | 
163 | 
164 |     if args.resume:
165 |         checkpoint = load_checkpoint(args.resume)
166 |         model.load_state_dict(checkpoint['state_dict'])
167 |         optimizer.load_state_dict(checkpoint['optimizer'])
168 |         start_epoch = args.resume_epoch
169 |         print("=> Finetune start epoch {}"
170 |               .format(start_epoch))
171 | 
172 | 
173 |     model = nn.DataParallel(model).cuda()
174 | 
175 |     # Criterion
176 |     criterion = nn.CrossEntropyLoss().cuda()
177 | 
178 |     evaluator = Evaluator(model, criterion)
179 |     if args.evaluate:
180 |         print('Test model: \n')
181 |         evaluator.evaluate(val_loader)
182 |         return
183 | 
184 |     # Trainer
185 |     trainer = Trainer(model, criterion)
186 | 
187 |     # Schedule learning rate
188 |     def adjust_lr(epoch):
189 |         step_size = args.step_size
190 |         decay_step = args.decay_step
191 |         lr = args.lr if epoch < step_size else \
192 |             args.lr * (0.1 ** ((epoch - step_size) // decay_step + 1))
193 |         for g in optimizer.param_groups:
194 |             g['lr'] = lr * g.get('lr_mult', 1)
195 | 
196 |     # Start training
197 |     trainer.show_info(with_arch=True, with_grad=False)
198 |     for epoch in range(start_epoch, args.epochs):
199 |         adjust_lr(epoch)
200 | 
201 |         trainer.train(epoch, train_loader, optimizer, print_info=args.print_info)
202 |         if epoch < args.start_save:
203 |             continue
204 |         top1 = evaluator.evaluate(val_loader)
205 | 
206 |         is_best = top1 > best_top1
207 |         best_top1 = max(top1, best_top1)
208 |         save_checkpoint({
209 |             'state_dict': model.module.state_dict(),
210 |             'optimizer': optimizer.state_dict()},
211 |             is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))
212 | 
213 |         print('\n * Finished epoch {:3d} top1: {:5.2%} model_best: {:5.2%} \n'.
214 | format(epoch, top1, best_top1)) 215 | 216 | if (epoch+1) % 5 == 0: 217 | model_name = 'epoch_'+ str(epoch) + '.pth.tar' 218 | torch.save({'state_dict':model.module.state_dict(), 219 | 'optimizer': optimizer.state_dict()}, 220 | osp.join(args.logs_dir, model_name)) 221 | 222 | class Trainer(object): 223 | def __init__(self, model, criterion): 224 | super(Trainer, self).__init__() 225 | self.model = model 226 | self.criterion = criterion 227 | 228 | def train(self, epoch, data_loader, optimizer, print_freq=1, print_info=10): 229 | self.model.train() 230 | 231 | batch_time = AverageMeter() 232 | data_time = AverageMeter() 233 | losses = AverageMeter() 234 | top1 = AverageMeter() 235 | top5 = AverageMeter() 236 | 237 | end = time.time() 238 | for i, inputs in enumerate(data_loader): 239 | data_time.update(time.time() - end) 240 | 241 | inputs_var, targets_var = self._parse_data(inputs) 242 | 243 | loss, prec1, prec5 = self._forward(inputs_var, targets_var) 244 | losses.update(loss.data[0], targets_var.size(0)) 245 | top1.update(prec1, targets_var.size(0)) 246 | top5.update(prec5, targets_var.size(0)) 247 | 248 | optimizer.zero_grad() 249 | loss.backward() 250 | torch.nn.utils.clip_grad_norm(self.model.parameters(), 5.0) 251 | 252 | optimizer.step() 253 | 254 | batch_time.update(time.time() - end) 255 | end = time.time() 256 | 257 | if (i + 1) % print_freq == 0: 258 | print('Epoch: [{}][{}/{}]\t' 259 | 'Time {:.3f} ({:.3f})\t' 260 | 'Data {:.3f} ({:.3f})\t' 261 | 'Loss {:.3f} ({:.3f})\t' 262 | 'Prec@1 {:.2%} ({:.2%})\t' 263 | 'Prec@5 {:.2%} ({:.2%})\t' 264 | .format(epoch, i + 1, len(data_loader), 265 | batch_time.val, batch_time.avg, 266 | data_time.val, data_time.avg, 267 | losses.val, losses.avg, 268 | top1.val, top1.avg, 269 | top5.val, top5.avg)) 270 | if (epoch+1) % print_info == 0: 271 | self.show_info() 272 | 273 | def show_info(self, with_arch=False, with_grad=True): 274 | if with_arch: 275 | print('\n\n################# model modules ###################') 276 | for name, m in self.model.named_modules(): 277 | print('{}: {}'.format(name, m)) 278 | print('################# model modules ###################\n\n') 279 | 280 | if with_grad: 281 | print('################# model params diff ###################') 282 | for name, param in self.model.named_parameters(): 283 | mean_value = torch.abs(param.data).mean() 284 | mean_grad = torch.abs(param.grad).mean().data[0] + 1e-8 285 | print('{}: size{}, data_abd_avg: {}, dgrad_abd_avg: {}, data/grad: {}'.format(name, 286 | param.size(), mean_value, mean_grad, mean_value/mean_grad)) 287 | print('################# model params diff ###################\n\n') 288 | 289 | else: 290 | print('################# model params ###################') 291 | for name, param in self.model.named_parameters(): 292 | print('{}: size{}, abs_avg: {}'.format(name, 293 | param.size(), 294 | torch.abs(param.data.cpu()).mean())) 295 | print('################# model params ###################\n\n') 296 | 297 | def _parse_data(self, inputs): 298 | imgs, _, labels = inputs 299 | inputs_var = [Variable(imgs)] 300 | targets_var = Variable(labels.cuda()) 301 | return inputs_var, targets_var 302 | 303 | def _forward(self, inputs, targets): 304 | outputs = self.model(*inputs) 305 | if isinstance(self.criterion, torch.nn.CrossEntropyLoss): 306 | loss = self.criterion(outputs, targets) 307 | prec1, prec5= accuracy(outputs.data, targets.data, topk=(1,5)) 308 | prec1 = prec1[0] 309 | prec5 = prec5[0] 310 | else: 311 | raise ValueError("Unsupported loss:", self.criterion) 312 
| return loss, prec1, prec5 313 | 314 | class Evaluator(object): 315 | def __init__(self, model, criterion): 316 | super(Evaluator, self).__init__() 317 | self.model = model 318 | self.criterion = criterion 319 | 320 | def evaluate(self, data_loader, print_freq=1): 321 | batch_time = AverageMeter() 322 | losses = AverageMeter() 323 | top1 = AverageMeter() 324 | top5 = AverageMeter() 325 | 326 | self.model.eval() 327 | 328 | end = time.time() 329 | 330 | for i, inputs in enumerate(data_loader): 331 | inputs_var, targets_var = self._parse_data(inputs) 332 | 333 | loss, prec1, prec5 = self._forward(inputs_var, targets_var) 334 | 335 | losses.update(loss.data[0], targets_var.size(0)) 336 | top1.update(prec1, targets_var.size(0)) 337 | top5.update(prec5, targets_var.size(0)) 338 | 339 | batch_time.update(time.time() - end) 340 | end = time.time() 341 | 342 | if i % print_freq == 0: 343 | print('Test: [{}/{}]\t' 344 | 'Time {:.3f} ({:.3f})\t' 345 | 'Loss {:.4f} ({:.4f})\t' 346 | 'Prec@1 {:.2%} ({:.2%})\t' 347 | 'Prec@5 {:.2%} ({:.2%})\t' 348 | .format(i + 1, len(data_loader), 349 | batch_time.val, batch_time.avg, 350 | losses.val, losses.avg, 351 | top1.val, top1.avg, 352 | top5.val, top5.avg)) 353 | 354 | print(' * Prec@1 {:.2%} Prec@5 {:.2%}'.format(top1.avg, top5.avg)) 355 | 356 | return top1.avg 357 | 358 | def _parse_data(self, inputs): 359 | imgs, _, labels = inputs 360 | inputs_var = [Variable(imgs, volatile=True)] 361 | targets_var = Variable(labels.cuda(), volatile=True) 362 | return inputs_var, targets_var 363 | 364 | def _forward(self, inputs, targets): 365 | outputs = self.model(*inputs) 366 | if isinstance(self.criterion, torch.nn.CrossEntropyLoss): 367 | loss = self.criterion(outputs, targets) 368 | prec1, prec5= accuracy(outputs.data, targets.data, topk=(1,5)) 369 | prec1 = prec1[0] 370 | prec5 = prec5[0] 371 | else: 372 | raise ValueError("Unsupported loss:", self.criterion) 373 | return loss, prec1, prec5 374 | 375 | 376 | if __name__ == '__main__': 377 | parser = argparse.ArgumentParser(description="Softmax loss classification") 378 | # data 379 | parser.add_argument('-d', '--dataset', type=str, default='imagenet') 380 | parser.add_argument('-b', '--batch-size', type=int, default=256) 381 | parser.add_argument('-j', '--workers', type=int, default=4) 382 | parser.add_argument('--split', type=int, default=0) 383 | parser.add_argument('--scale_size', type=int, default=256, 384 | help="val resize image size, default: 256 for ImageNet") 385 | parser.add_argument('--img_size', type=int, default=224, 386 | help="input image size, default: 224 for ImageNet") 387 | # model 388 | parser.add_argument('-a', '--arch', type=str, default='resnet50', 389 | choices=models.names()) 390 | # optimizer 391 | parser.add_argument('--lr', type=float, default=0.001, 392 | help="learning rate of new parameters, for pretrained " 393 | "parameters it is 10 times smaller than this") 394 | parser.add_argument('--momentum', type=float, default=0.9) 395 | parser.add_argument('--weight-decay', type=float, default=1e-5) 396 | parser.add_argument('--step_size', type=int, default=25) 397 | parser.add_argument('--decay_step', type=int, default=25) 398 | 399 | # training configs pretrained_model 400 | parser.add_argument('--pretrained', type=str, default='', metavar='PATH') 401 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 402 | parser.add_argument('--resume_epoch', type=int,default=0) 403 | parser.add_argument('--evaluate', action='store_true', 404 | help="evaluation only") 405 | 
parser.add_argument('--adam', action='store_true',
406 |                         help="use Adam")
407 |     parser.add_argument('--epochs', type=int, default=100)
408 |     parser.add_argument('--start_save', type=int, default=0,
409 |                         help="start saving checkpoints after a specific epoch")
410 |     parser.add_argument('--seed', type=int, default=1)
411 |     parser.add_argument('--print-freq', type=int, default=1)
412 |     parser.add_argument('--print-info', type=int, default=10)
413 |     # misc
414 |     working_dir = osp.dirname(osp.abspath(__file__))
415 |     parser.add_argument('--data-dir', type=str, metavar='PATH',
416 |                         default=osp.join(working_dir, 'data'))
417 |     parser.add_argument('--logs-dir', type=str, metavar='PATH',
418 |                         default=osp.join(working_dir, 'logs'))
419 |     main(parser.parse_args())
420 | 
421 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | from __future__ import absolute_import
5 | 
6 | from .alexnet import *
7 | from .alexnet_all import *
8 | from .resnet import *
9 | from .resnet18_all import *
10 | 
11 | __factory = {
12 |     'alexnet': alexnet,
13 |     'alexnet_q': alexnet_q,
14 |     'resnet18': resnet18,
15 |     'resnet18_q': resnet18_q,
16 | }
17 | 
18 | 
19 | def names():
20 |     return sorted(__factory.keys())
21 | 
22 | 
23 | def create(name, *args, **kwargs):
24 |     """
25 |     Create a model instance.
26 | 
27 |     Parameters
28 |     ----------
29 |     name : str
30 |         Model name. Can be one of 'alexnet', 'alexnet_q',
31 |         'resnet18', and 'resnet18_q'.
32 |     pretrained : bool, optional
33 |         If True, will use ImageNet pre-trained weights for the
34 |         model. Default: True
35 |     num_classes : int, optional
36 |         If positive, will append a Linear layer at the end as the classifier
37 |         with this number of output units.
Default: 0 38 | """ 39 | if name not in __factory: 40 | raise KeyError("Unknown model:", name) 41 | return __factory[name](*args, **kwargs) 42 | -------------------------------------------------------------------------------- /models/alexnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import print_function, absolute_import 5 | 6 | import os 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from torch.nn import init 12 | import torch.utils.model_zoo as model_zoo 13 | 14 | import pdb 15 | 16 | __all__ = ['AlexNet', 'alexnet'] 17 | 18 | 19 | class ContConv2d(nn.Module): 20 | def __init__(self, input_channels, output_channels, 21 | kernel_size=-1, stride=-1, padding=-1, groups=1, Linear=False): 22 | super(ContConv2d, self).__init__() 23 | self.kernel_size = kernel_size 24 | self.stride = stride 25 | self.padding = padding 26 | self.groups = groups 27 | 28 | self.Linear = Linear 29 | if not self.Linear: 30 | self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=self.kernel_size, 31 | stride=self.stride, padding=self.padding, groups=self.groups) 32 | self.bn = nn.BatchNorm2d(output_channels, eps=1e-3) 33 | else: 34 | self.linear = nn.Linear(input_channels, output_channels) 35 | self.bn = nn.BatchNorm1d(output_channels, eps=1e-3) 36 | self.relu = nn.ReLU(inplace=True) 37 | 38 | def forward(self, x): 39 | if not self.Linear: 40 | x = self.conv(x) 41 | else: 42 | x = self.linear(x) 43 | x = self.bn(x) 44 | x = self.relu(x) 45 | 46 | return x 47 | 48 | 49 | class AlexNet(nn.Module): 50 | def __init__(self, num_classes=1000): 51 | super(AlexNet, self).__init__() 52 | self.num_classes = num_classes 53 | self.features_0 = nn.Sequential( 54 | ContConv2d(3, 96, kernel_size=11, stride=4, padding=2), 55 | nn.MaxPool2d(kernel_size=3, stride=2), 56 | ) 57 | self.features_1 = nn.Sequential( 58 | ContConv2d(96, 256, kernel_size=5, stride=1, padding=2), 59 | nn.MaxPool2d(kernel_size=3, stride=2), 60 | ) 61 | self.features_2 = nn.Sequential( 62 | ContConv2d(256, 384, kernel_size=3, stride=1, padding=1), 63 | ContConv2d(384, 384, kernel_size=3, stride=1, padding=1), 64 | ContConv2d(384, 256, kernel_size=3, stride=1, padding=1), 65 | nn.MaxPool2d(kernel_size=3, stride=2), 66 | ) 67 | self.classifier = nn.Sequential( 68 | nn.Dropout(p=0.1), 69 | ContConv2d(256*6*6, 4096, Linear=True), 70 | nn.Dropout(p=0.1), 71 | ContConv2d(4096, 4096, Linear=True), 72 | nn.Linear(4096, self.num_classes), 73 | ) 74 | 75 | self.reset_params() 76 | 77 | def forward(self, x): 78 | x = self.features_0(x) 79 | x = self.features_1(x) 80 | x = self.features_2(x) 81 | x = x.view(x.size(0), -1) 82 | x = self.classifier(x) 83 | return x 84 | 85 | def reset_params(self): 86 | for m in self.modules(): 87 | if isinstance(m, nn.Conv2d): 88 | init.kaiming_normal(m.weight, mode='fan_in') 89 | if m.bias is not None: 90 | init.constant(m.bias, 0) 91 | elif isinstance(m, nn.BatchNorm2d): 92 | init.constant(m.weight, 1) 93 | init.constant(m.bias, 0) 94 | elif isinstance(m, nn.BatchNorm1d): 95 | init.constant(m.weight, 1) 96 | init.constant(m.bias, 0) 97 | elif isinstance(m, nn.Linear): 98 | init.kaiming_normal(m.weight, mode='fan_in') 99 | if m.bias is not None: 100 | init.constant(m.bias, 0) 101 | 102 | def alexnet(pretrained=False, **kwargs): 103 | model=AlexNet(**kwargs) 104 | if pretrained: 105 | model_path='model_list/alexnet.pth.tar' 106 | pretrained_model = 
torch.load(model_path) 107 | model.load_state_dict(pretrained_model['state_dict']) 108 | return model 109 | 110 | -------------------------------------------------------------------------------- /models/alexnet_all.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # alexnet_all.py is used to quantize the weight and activation of AlexNet. 4 | from __future__ import print_function, absolute_import 5 | 6 | import os 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from torch.nn import init 12 | import torch.utils.model_zoo as model_zoo 13 | from .quantization import * 14 | import pdb 15 | 16 | __all__ = ['AlexNet_Q', 'alexnet_q'] 17 | 18 | 19 | class ContConv2d(nn.Module): 20 | def __init__(self, input_channels, output_channels, ac_quan_values, ac_quan_bias, ac_init_beta, count, 21 | kernel_size=-1, stride=-1, padding=-1, groups=1, QA_flag=True, Linear=False): 22 | super(ContConv2d, self).__init__() 23 | self.kernel_size = kernel_size 24 | self.stride = stride 25 | self.padding = padding 26 | self.groups = groups 27 | self.QA_flag = QA_flag 28 | 29 | self.Linear = Linear 30 | if not self.Linear: 31 | self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=self.kernel_size, 32 | stride=self.stride, padding=self.padding, groups=self.groups) 33 | self.bn = nn.BatchNorm2d(output_channels, eps=1e-3) 34 | else: 35 | self.linear = nn.Linear(input_channels, output_channels) 36 | self.bn = nn.BatchNorm1d(output_channels, eps=1e-3) 37 | self.relu = nn.ReLU(inplace=True) 38 | 39 | if self.QA_flag: 40 | self.quan = Quantization(quant_values=ac_quan_values, quan_bias=ac_quan_bias[count], init_beta=ac_init_beta) 41 | 42 | self.ac_T = 1 43 | 44 | def set_activation_T(self, activation_T): 45 | self.ac_T = activation_T 46 | 47 | 48 | def forward(self, x): 49 | if not self.Linear: 50 | x = self.conv(x) 51 | else: 52 | x = self.linear(x) 53 | x = self.bn(x) 54 | x = self.relu(x) 55 | #quantization 56 | if self.QA_flag: 57 | x = self.quan(x, self.ac_T) 58 | 59 | return x 60 | 61 | 62 | class AlexNet_Q(nn.Module): 63 | def __init__(self, QA_flag=True, ac_quan_bias=None, ac_quan_values=None, ac_beta=None, num_classes=1000): 64 | self.ac_quan_values = ac_quan_values 65 | self.ac_quan_bias = ac_quan_bias 66 | self.ac_beta = ac_beta 67 | self.count = 0 68 | super(AlexNet_Q, self).__init__() 69 | self.num_classes = num_classes 70 | 71 | self.QA_flag = QA_flag 72 | 73 | self.features_0 = nn.Sequential( 74 | ContConv2d(3, 96, ac_quan_values=None, ac_quan_bias=None, ac_init_beta=None, count=0, 75 | kernel_size=11, stride=4, padding=2, QA_flag=False), 76 | nn.MaxPool2d(kernel_size=3, stride=2), 77 | ) 78 | if self.QA_flag: 79 | #print(self.count) 80 | self.quan0 = Quantization(quant_values=self.ac_quan_values, quan_bias=self.ac_quan_bias[self.count], 81 | init_beta=self.ac_beta[self.count]) 82 | self.count += 1 83 | 84 | self.features_1 = nn.Sequential( 85 | ContConv2d(96, 256, ac_quan_values=None, ac_quan_bias=None, ac_init_beta=None, count=0, 86 | kernel_size=5, stride=1, padding=2, QA_flag=False), 87 | nn.MaxPool2d(kernel_size=3, stride=2), 88 | ) 89 | if self.QA_flag: 90 | #print(self.count) 91 | self.quan1 = Quantization(quant_values=self.ac_quan_values, quan_bias=self.ac_quan_bias[self.count], 92 | init_beta=self.ac_beta[self.count]) 93 | self.count += 1 94 | 95 | self.features_2 = nn.Sequential( 96 | ContConv2d(256, 384, ac_quan_values=self.ac_quan_values, 
ac_quan_bias=self.ac_quan_bias, ac_init_beta=self.ac_beta, 97 | count=2, kernel_size=3, stride=1, padding=1, QA_flag=self.QA_flag), 98 | ContConv2d(384, 384, ac_quan_values=self.ac_quan_values, ac_quan_bias=self.ac_quan_bias, ac_init_beta=self.ac_beta, 99 | count=3, kernel_size=3, stride=1, padding=1, QA_flag=self.QA_flag), 100 | ContConv2d(384, 256, ac_quan_values=None, ac_quan_bias=None, ac_init_beta=None, count=0, 101 | kernel_size=3, stride=1, padding=1, QA_flag=False), 102 | nn.MaxPool2d(kernel_size=3, stride=2), 103 | ) 104 | if self.QA_flag: 105 | #print(self.count) 106 | self.quan2 = Quantization(quant_values=self.ac_quan_values, quan_bias=self.ac_quan_bias[4], 107 | init_beta=self.ac_beta[4]) 108 | 109 | self.classifier = nn.Sequential( 110 | nn.Dropout(p=0.1), 111 | ContConv2d(256*6*6, 4096, ac_quan_values=self.ac_quan_values, ac_quan_bias=self.ac_quan_bias, ac_init_beta=self.ac_beta, 112 | count=5, Linear=True, QA_flag=self.QA_flag), 113 | nn.Dropout(p=0.1), 114 | ContConv2d(4096, 4096, ac_quan_values=None, ac_quan_bias=None, ac_init_beta=None, count=0, Linear=True, QA_flag=False), 115 | nn.Linear(4096, self.num_classes), 116 | ) 117 | 118 | self.reset_params() 119 | 120 | def set_ac_T(self, input_ac_T): 121 | for m in self.features_0: 122 | if isinstance(m, ContConv2d): 123 | m.set_activation_T(input_ac_T) 124 | for m in self.features_1: 125 | if isinstance(m, ContConv2d): 126 | m.set_activation_T(input_ac_T) 127 | for m in self.features_2: 128 | if isinstance(m, ContConv2d): 129 | m.set_activation_T(input_ac_T) 130 | for m in self.classifier: 131 | if isinstance(m, ContConv2d): 132 | m.set_activation_T(input_ac_T) 133 | 134 | def forward(self, x, input_ac_T=1): 135 | if self.QA_flag: 136 | self.set_ac_T(input_ac_T) 137 | 138 | x = self.features_0(x) 139 | if self.QA_flag: 140 | x = self.quan0(x, input_ac_T) 141 | 142 | x = self.features_1(x) 143 | if self.QA_flag: 144 | x = self.quan1(x, input_ac_T) 145 | 146 | x = self.features_2(x) 147 | if self.QA_flag: 148 | x = self.quan2(x, input_ac_T) 149 | 150 | x = x.view(x.size(0), -1) 151 | x = self.classifier(x) 152 | return x 153 | 154 | def reset_params(self): 155 | for m in self.modules(): 156 | if isinstance(m, nn.Conv2d): 157 | init.kaiming_normal(m.weight, mode='fan_in') 158 | if m.bias is not None: 159 | init.constant(m.bias, 0) 160 | elif isinstance(m, nn.BatchNorm2d): 161 | init.constant(m.weight, 1) 162 | init.constant(m.bias, 0) 163 | elif isinstance(m, nn.BatchNorm1d): 164 | init.constant(m.weight, 1) 165 | init.constant(m.bias, 0) 166 | elif isinstance(m, nn.Linear): 167 | init.kaiming_normal(m.weight, mode='fan_in') 168 | if m.bias is not None: 169 | init.constant(m.bias, 0) 170 | 171 | def alexnet_q(pretrained=False, **kwargs): 172 | model=AlexNet_Q(**kwargs) 173 | if pretrained: 174 | model_path='model_list/alexnet.pth.tar' 175 | pretrained_model = torch.load(model_path) 176 | model.load_state_dict(pretrained_model['state_dict']) 177 | return model 178 | 179 | -------------------------------------------------------------------------------- /models/quantization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # quantization.py is used to quantize the activation of model. 
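# Usage sketch (added for illustration; the values below are hypothetical):
#     quan = Quantization(quant_values=[-1, 0, 1], quan_bias=[-0.1, 0.1], init_beta=2.0)
#     out = quan(feature_map, T)  # T: temperature, typically increased as training proceeds
# In training mode the module applies the soft, SigmoidT-based quantizer so gradients
# can flow; in eval mode it applies the hard step functions defined below.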
4 | from __future__ import print_function, absolute_import
5 | 
6 | import torch
7 | import torch.nn.functional as F
8 | from torch.nn import init
9 | import torch.nn as nn
10 | import pickle
11 | from torch.nn.parameter import Parameter
12 | from torch.autograd import Variable
13 | import numpy as np
14 | import pdb
15 | 
16 | class SigmoidT(torch.autograd.Function):
17 |     """ The sigmoid function with temperature T, used during training.
18 |     We need the gradients with respect to the input and the biases.
19 |     For customizing an autograd Function, refer to https://pytorch.org/docs/stable/notes/extending.html
20 |     """
21 | 
22 |     @staticmethod
23 |     def forward(self, input, scales, n, b, T):
24 |         self.save_for_backward(input)
25 |         self.T = T
26 |         self.b = b
27 |         self.scales = scales
28 |         self.n = n
29 | 
30 |         buf = torch.clamp(self.T * (input - self.b[0]), min=-10.0, max=10.0)
31 |         output = self.scales[0] / (1.0 + torch.exp(-buf))
32 |         for k in range(1, self.n):
33 |             buf = torch.clamp(self.T * (input - self.b[k]), min=-10.0, max=10.0)
34 |             output += self.scales[k] / (1.0 + torch.exp(-buf))
35 |         return output
36 | 
37 |     @staticmethod
38 |     def backward(self, grad_output):
39 |         # set T = 1 in the backward pass when training a binary model.
40 |         #self.T = 1
41 |         input, = self.saved_tensors
42 |         b_buf = torch.clamp(self.T * (input - self.b[0]), min=-10.0, max=10.0)
43 |         b_output = self.scales[0] / (1.0 + torch.exp(-b_buf))
44 |         temp = b_output * (1 - b_output) * self.T
45 |         for j in range(1, self.n):
46 |             b_buf = torch.clamp(self.T * (input - self.b[j]), min=-10.0, max=10.0)
47 |             b_output = self.scales[j] / (1.0 + torch.exp(-b_buf))
48 |             temp += b_output * (1 - b_output) * self.T
49 |         grad_input = Variable(temp) * grad_output
50 |         # only the input receives a gradient; scales, n, b, and T do not
51 |         return grad_input, None, None, None, None
52 | 
53 | sigmoidT = SigmoidT.apply
54 | 
55 | def step(x, b):
56 |     """
57 |     The step function used as the ideal quantization function at test time.
58 | """ 59 | y = torch.zeros_like(x) 60 | mask = torch.gt(x - b, 0.0) 61 | y[mask] = 1.0 62 | return y 63 | 64 | 65 | class Quantization(nn.Module): 66 | """ Quantization Activation 67 | Args: 68 | quant_values: the target quantized values, like [-4, -2, -1, 0, 1 , 2, 4] 69 | quan_bias and init_beta: the data for initialization of quantization parameters (biases, beta) 70 | - for activations, format as `N x 1` for biases and `1x1` for (beta) 71 | we need to obtain the intialization values for biases and beta offline 72 | 73 | Shape: 74 | - Input: :math:`(N, C, H, W)` 75 | - Output: :math:`(N, C, H, W)` (same shape as input) 76 | 77 | Usage: 78 | - for activations, just pending this module to the activations when build the graph 79 | """ 80 | 81 | def __init__(self, quant_values=[-1, 0, 1], quan_bias=[0], init_beta=0.0): 82 | super(Quantization, self).__init__() 83 | """register_parameter: params w/ grad, and need to be learned 84 | register_buffer: params w/o grad, do not need to be learned 85 | example shown in: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py 86 | """ 87 | self.values = quant_values 88 | # number of sigmoids 89 | self.n = len(self.values) - 1 90 | self.alpha = Parameter(torch.Tensor([1])) 91 | self.beta = Parameter(torch.Tensor([1])) 92 | self.register_buffer('biases', torch.zeros(self.n)) 93 | self.register_buffer('scales', torch.zeros(self.n)) 94 | 95 | boundary = np.array(quan_bias) 96 | self.init_scale_and_offset() 97 | self.bias_inited = False 98 | self.alpha_beta_inited = False 99 | self.init_biases(boundary) 100 | self.init_alpha_and_beta(init_beta) 101 | 102 | def init_scale_and_offset(self): 103 | """ 104 | Initialize the scale and offset of quantization function. 105 | """ 106 | for i in range(self.n): 107 | gap = self.values[i + 1] - self.values[i] 108 | self.scales[i] = gap 109 | 110 | def init_biases(self, init_data): 111 | """ 112 | Initialize the bias of quantization function. 113 | init_data in numpy format. 114 | """ 115 | # activations initialization (obtained offline) 116 | assert init_data.size == self.n 117 | self.biases.copy_(torch.from_numpy(init_data)) 118 | self.bias_inited = True 119 | #print('baises inited!!!') 120 | 121 | def init_alpha_and_beta(self, init_beta): 122 | """ 123 | Initialize the alpha and beta of quantization function. 124 | init_data in numpy format. 
125 | """ 126 | # activations initialization (obtained offline) 127 | self.beta.data = torch.Tensor([init_beta]).cuda() 128 | self.alpha.data = torch.reciprocal(self.beta.data) 129 | self.alpha_beta_inited = True 130 | 131 | def forward(self, input, T=1): 132 | assert self.bias_inited 133 | input = input.mul(self.beta) 134 | if self.training: 135 | assert self.alpha_beta_inited 136 | output = sigmoidT(input, self.scales, self.n, self.biases, T) 137 | else: 138 | output = step(input, b=self.biases[0])*self.scales[0] 139 | for i in range(1, self.n): 140 | output += step(input, b=self.biases[i])*self.scales[i] 141 | 142 | output = output.mul(self.alpha) 143 | return output 144 | 145 | 146 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import print_function, absolute_import 5 | 6 | import torch.nn as nn 7 | from torch.nn import init 8 | import math 9 | import torch.utils.model_zoo as model_zoo 10 | import pdb 11 | 12 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50'] 13 | 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1): 23 | "3x3 convolution with padding" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=1, bias=False) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None): 32 | super(BasicBlock, self).__init__() 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = nn.BatchNorm2d(planes) 35 | #self.relu = nn.ReLU6(inplace=True) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None): 65 | super(Bottleneck, self).__init__() 66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 69 | padding=1, bias=False) 70 | self.bn2 = nn.BatchNorm2d(planes) 71 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 72 | self.bn3 = nn.BatchNorm2d(planes * 4) 73 | #self.relu = nn.ReLU6(inplace=True) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out) 90 | out = self.bn3(out) 91 | 92 | if self.downsample is not None: 93 | residual = self.downsample(x) 94 | 95 | out += 
residual 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | 101 | class ResNet(nn.Module): 102 | 103 | def __init__(self, block, layers, num_classes=1000): 104 | self.inplanes = 64 105 | super(ResNet, self).__init__() 106 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 107 | bias=False) 108 | self.bn1 = nn.BatchNorm2d(64) 109 | #self.relu = nn.ReLU6(inplace=True) 110 | self.relu = nn.ReLU(inplace=True) 111 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 112 | self.layer1 = self._make_layer(block, 64, layers[0]) 113 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 114 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 115 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 116 | self.avgpool = nn.AvgPool2d(7, stride=1) 117 | self.fc = nn.Linear(512 * block.expansion, num_classes) 118 | 119 | 120 | self.set_params() 121 | # for m in self.modules(): 122 | # if isinstance(m, nn.Conv2d): 123 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 124 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 125 | # elif isinstance(m, nn.BatchNorm2d): 126 | # m.weight.data.fill_(1) 127 | # m.bias.data.zero_() 128 | 129 | def _make_layer(self, block, planes, blocks, stride=1): 130 | downsample = None 131 | if stride != 1 or self.inplanes != planes * block.expansion: 132 | downsample = nn.Sequential( 133 | nn.Conv2d(self.inplanes, planes * block.expansion, 134 | kernel_size=1, stride=stride, bias=False), 135 | nn.BatchNorm2d(planes * block.expansion), 136 | ) 137 | 138 | layers = [] 139 | layers.append(block(self.inplanes, planes, stride, downsample)) 140 | self.inplanes = planes * block.expansion 141 | for i in range(1, blocks): 142 | layers.append(block(self.inplanes, planes)) 143 | 144 | return nn.Sequential(*layers) 145 | 146 | def set_params(self): 147 | for m in self.modules(): 148 | if isinstance(m, nn.Conv2d): 149 | init.kaiming_normal(m.weight, mode='fan_in') 150 | if m.bias is not None: 151 | init.constant(m.bias, 0) 152 | elif isinstance(m, nn.BatchNorm2d): 153 | init.constant(m.weight, 1) 154 | init.constant(m.bias, 0) 155 | elif isinstance(m, nn.Linear): 156 | init.normal(m.weight, std=0.001) 157 | if m.bias is not None: 158 | init.constant(m.bias, 0) 159 | 160 | def forward(self, x): 161 | x = self.conv1(x) 162 | x = self.bn1(x) 163 | x = self.relu(x) 164 | x = self.maxpool(x) 165 | 166 | x = self.layer1(x) 167 | x = self.layer2(x) 168 | x = self.layer3(x) 169 | x = self.layer4(x) 170 | 171 | x = self.avgpool(x) 172 | x = x.view(x.size(0), -1) 173 | x = self.fc(x) 174 | 175 | return x 176 | 177 | 178 | def resnet18(pretrained=True, **kwargs): 179 | """Constructs a ResNet-18 model. 180 | 181 | Args: 182 | pretrained (bool): If True, returns a model pre-trained on ImageNet 183 | """ 184 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 185 | if pretrained: 186 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 187 | return model 188 | 189 | 190 | def resnet34(pretrained=True, **kwargs): 191 | """Constructs a ResNet-34 model. 192 | 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on ImageNet 195 | """ 196 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 197 | if pretrained: 198 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 199 | return model 200 | 201 | 202 | def resnet50(pretrained=True, **kwargs): 203 | """Constructs a ResNet-50 model. 
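    Example (illustrative): model = resnet50(pretrained=False, num_classes=1000)
    builds the full-precision network; in this repo such full-precision models
    can serve as the starting point for the quantized variants.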
204 | 205 | Args: 206 | pretrained (bool): If True, returns a model pre-trained on ImageNet 207 | """ 208 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 209 | if pretrained: 210 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 211 | return model 212 | 213 | -------------------------------------------------------------------------------- /models/resnet18_all.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | #resnet18_all.py is used to quantize the weight and activation of ResNet-18. 4 | from __future__ import print_function, absolute_import 5 | 6 | import torch.nn as nn 7 | from torch.nn import init 8 | import math 9 | import torch.utils.model_zoo as model_zoo 10 | from .quantization import * 11 | import pdb 12 | 13 | __all__ = ['ResNet_Q', 'resnet18_q'] 14 | 15 | 16 | model_urls = { 17 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 18 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1): 23 | "3x3 convolution with padding" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=1, bias=False) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, ac_quan_values, ac_quan_bias, ac_init_beta, count, 32 | stride=1, downsample=None, QA_flag=True): 33 | super(BasicBlock, self).__init__() 34 | self.QA_flag = QA_flag 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = nn.BatchNorm2d(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | if self.QA_flag: 39 | self.quan1 = Quantization(quant_values=ac_quan_values, quan_bias=ac_quan_bias[count], init_beta=ac_init_beta[count]) 40 | 41 | self.conv2 = conv3x3(planes, planes) 42 | self.bn2 = nn.BatchNorm2d(planes) 43 | self.downsample = downsample 44 | if self.QA_flag: 45 | self.quan2 = Quantization(quant_values=ac_quan_values, quan_bias=ac_quan_bias[count+1], init_beta=ac_init_beta[count+1]) 46 | 47 | self.stride = stride 48 | self.ac_T = 1 49 | 50 | def set_activation_T(self, activation_T): 51 | self.ac_T = activation_T 52 | 53 | def forward(self, x): 54 | residual = x 55 | 56 | out = self.conv1(x) 57 | out = self.bn1(out) 58 | out = self.relu(out) 59 | # quantization 60 | 61 | if self.QA_flag: 62 | out = self.quan1(out, self.ac_T) 63 | 64 | out = self.conv2(out) 65 | out = self.bn2(out) 66 | 67 | if self.downsample is not None: 68 | residual = self.downsample(x) 69 | 70 | out += residual 71 | out = self.relu(out) 72 | if self.QA_flag: 73 | out = self.quan2(out, self.ac_T) 74 | 75 | return out 76 | 77 | class ResNet_Q(nn.Module): 78 | 79 | def __init__(self, block, layers, num_classes=1000, QA_flag=True, ac_quan_bias=None, ac_quan_values=None, ac_beta=None): 80 | self.inplanes = 64 81 | self.ac_quan_values = ac_quan_values 82 | self.ac_quan_bias = ac_quan_bias 83 | self.ac_beta = ac_beta 84 | self.count=0 85 | super(ResNet_Q, self).__init__() 86 | self.QA_flag = QA_flag 87 | 88 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 89 | bias=False) 90 | self.bn1 = nn.BatchNorm2d(64) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 93 | if self.QA_flag: 94 | #print(self.count) 95 | self.quan0 = Quantization(quant_values=self.ac_quan_values, quan_bias=self.ac_quan_bias[self.count], 96 | init_beta=self.ac_beta[self.count]) 97 | self.count += 1 98 | 99 | 
self.layer1 = self._make_layer(block, 64, layers[0], QA_flag=self.QA_flag)
100 |         self.layer2 = self._make_layer(block, 128, layers[1], stride=2, QA_flag=self.QA_flag)
101 |         self.layer3 = self._make_layer(block, 256, layers[2], stride=2, QA_flag=self.QA_flag)
102 |         self.layer4 = self._make_layer(block, 512, layers[3], stride=2, QA_flag=self.QA_flag)
103 |         self.avgpool = nn.AvgPool2d(7, stride=1)
104 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
105 | 
106 | 
107 |         self.set_params()
108 |         # for m in self.modules():
109 |         #     if isinstance(m, nn.Conv2d):
110 |         #         n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
111 |         #         m.weight.data.normal_(0, math.sqrt(2. / n))
112 |         #     elif isinstance(m, nn.BatchNorm2d):
113 |         #         m.weight.data.fill_(1)
114 |         #         m.bias.data.zero_()
115 | 
116 |     def _make_layer(self, block, planes, blocks, stride=1, QA_flag=True):
117 |         downsample = None
118 |         if stride != 1 or self.inplanes != planes * block.expansion:
119 |             downsample = nn.Sequential(
120 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
121 |                           kernel_size=1, stride=stride, bias=False),
122 |                 nn.BatchNorm2d(planes * block.expansion),
123 |             )
124 | 
125 |         layers = []
126 |         layers.append(block(self.inplanes, planes, self.ac_quan_values, self.ac_quan_bias, self.ac_beta,
127 |                             self.count, stride, downsample, QA_flag=QA_flag))
128 |         self.inplanes = planes * block.expansion
129 |         self.count += 2
130 |         for i in range(1, blocks):
131 |             layers.append(block(self.inplanes, planes, self.ac_quan_values, \
132 |                                 self.ac_quan_bias, self.ac_beta, self.count, QA_flag=QA_flag))
133 |             self.count += 2
134 | 
135 |         return nn.Sequential(*layers)
136 | 
137 |     def set_resnet_ac_T(self, activation_T):
138 |         for m in self.layer1:
139 |             if isinstance(m, BasicBlock):
140 |                 m.set_activation_T(activation_T)
141 |         for m in self.layer2:
142 |             if isinstance(m, BasicBlock):
143 |                 m.set_activation_T(activation_T)
144 |         for m in self.layer3:
145 |             if isinstance(m, BasicBlock):
146 |                 m.set_activation_T(activation_T)
147 |         for m in self.layer4:
148 |             if isinstance(m, BasicBlock):
149 |                 m.set_activation_T(activation_T)
150 | 
151 | 
152 |     def set_params(self):
153 |         for m in self.modules():
154 |             if isinstance(m, nn.Conv2d):
155 |                 init.kaiming_normal(m.weight, mode='fan_in')
156 |                 if m.bias is not None:
157 |                     init.constant(m.bias, 0)
158 |             elif isinstance(m, nn.BatchNorm2d):
159 |                 init.constant(m.weight, 1)
160 |                 init.constant(m.bias, 0)
161 |             elif isinstance(m, nn.Linear):
162 |                 init.normal(m.weight, std=0.001)
163 |                 if m.bias is not None:
164 |                     init.constant(m.bias, 0)
165 | 
166 |     def forward(self, x, input_ac_T=0):
167 |         if self.QA_flag:
168 |             self.set_resnet_ac_T(input_ac_T)
169 | 
170 |         x = self.conv1(x)
171 |         x = self.bn1(x)
172 |         x = self.relu(x)
173 |         x = self.maxpool(x)
174 | 
175 |         if self.QA_flag:
176 |             x = self.quan0(x, input_ac_T)  # feed the quantized stem output into layer1
177 | 
178 |         x = self.layer1(x)
179 |         x = self.layer2(x)
180 |         x = self.layer3(x)
181 |         x = self.layer4(x)
182 | 
183 |         x = self.avgpool(x)
184 |         x = x.view(x.size(0), -1)
185 |         x = self.fc(x)
186 | 
187 |         return x
188 | 
189 | 
190 | def resnet18_q(pretrained=False, **kwargs):
191 |     """Constructs a ResNet-18 model.
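    Example (illustrative; the key names follow the QA_biases / QA_beta dicts in config.py):
        model = resnet18_q(QA_flag=True, ac_quan_values=[0, 1],
                           ac_quan_bias=QA_biases[key], ac_beta=QA_beta[key],
                           num_classes=1000)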
192 | 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on ImageNet 195 | """ 196 | model = ResNet_Q(BasicBlock, [2, 2, 2, 2], **kwargs) 197 | if pretrained: 198 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 199 | return model 200 | 201 | 202 | -------------------------------------------------------------------------------- /quan-weight.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python quan_weight_main.py -a resnet18 -b 256 -d imagenet \ 2 | --img_size 224 -j 16 --weight-decay 1e-4 \ 3 | --lr 0.001 --temperature 20 \ 4 | --offline_biases resnet18-w-1 \ 5 | --step_size 25 --decay_step 5 --epochs 35 \ 6 | --start_save 0 --print-info 1 \ 7 | --pretrained ./pre-models/resnet18.pth \ 8 | --logs-dir logs/quan-weight/resnet18-quan-w-1 9 | -------------------------------------------------------------------------------- /quan_all_main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # quan_all_main.py is used to train the weight and activation quantized model. 4 | 5 | from __future__ import print_function, absolute_import 6 | import argparse 7 | import os.path as osp 8 | import os 9 | import numpy as np 10 | import sys 11 | import time 12 | import math 13 | import torch 14 | from torch import nn 15 | from torch.autograd import Variable 16 | from torch.backends import cudnn 17 | from torch.utils.data import DataLoader 18 | from torch.utils.data.sampler import RandomSampler 19 | from torch.nn.parameter import Parameter 20 | from torchvision import transforms as T 21 | 22 | from config import * 23 | import models 24 | from data_pre import ColorJitter, Lighting, Preprocessor 25 | from utils import Logger, AverageMeter 26 | from utils import load_checkpoint, save_checkpoint 27 | from utils import RandomResized 28 | from anybit import QuaOp 29 | from evaluators import accuracy 30 | import pdb 31 | 32 | # define global qua_op 33 | qua_op = None 34 | 35 | def get_data(split_id, data_dir, img_size, scale_size, batch_size, 36 | workers, train_list, val_list): 37 | root = data_dir 38 | 39 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 40 | std=[0.229, 0.224, 0.225]) # RGB imagenet 41 | # with data augmentation 42 | train_transformer = T.Compose([ 43 | T.Resize(scale_size), 44 | T.RandomCrop(img_size), 45 | T.RandomHorizontalFlip(), 46 | T.ToTensor(), # [0, 255] to [0.0, 1.0] 47 | normalizer, # normalize each channel of the input 48 | ]) 49 | 50 | test_transformer = T.Compose([ 51 | T.Resize(scale_size), 52 | T.CenterCrop(img_size), 53 | T.ToTensor(), 54 | normalizer, 55 | ]) 56 | 57 | train_loader = DataLoader( 58 | Preprocessor(train_list, root=root, 59 | transform=train_transformer), 60 | batch_size=batch_size, num_workers=workers, 61 | sampler=RandomSampler(train_list), 62 | pin_memory=True, drop_last=False) 63 | 64 | val_loader = DataLoader( 65 | Preprocessor(val_list, root=root, 66 | transform=test_transformer), 67 | batch_size=batch_size, num_workers=workers, 68 | shuffle=False, pin_memory=True) 69 | 70 | return train_loader, val_loader 71 | 72 | def load_params(new_model, pretrained_model): 73 | #new_model_dict = new_model.module.state_dict() 74 | new_model_dict = new_model.state_dict() 75 | pretrained_checkpoint = load_checkpoint(pretrained_model) 76 | #for name, param in pretrained_checkpoint.items(): 77 | for name, param in pretrained_checkpoint['state_dict'].items(): 78 | 
print('pretrained_model params name and size: ', name, param.size()) 79 | if name in new_model_dict: 80 | if isinstance(param, Parameter): 81 | # backwards compatibility for serialized parameters 82 | param = param.data 83 | try: 84 | new_model_dict[name].copy_(param) 85 | print('############# new_model load params name: ',name) 86 | except: 87 | raise RuntimeError('While copying the parameter named {}, \ 88 | whose dimensions in the model are {} and \ 89 | whose dimensions in the checkpoint are {}.' 90 | .format(name, new_model_dict[name].size(), param.size())) 91 | else: 92 | continue 93 | 94 | def load_alexnet_params(new_model, pretrained_model): 95 | #new_model_dict = new_model.module.state_dict() 96 | new_model_dict = new_model.state_dict() 97 | pretrained_checkpoint = load_checkpoint(pretrained_model) 98 | for name, param in pretrained_checkpoint['state_dict'].items(): 99 | print('pretrained_model params name and size: ', name, param.size()) 100 | if name in new_model_dict: 101 | if isinstance(param, Parameter): 102 | # backwards compatibility for serialized parameters 103 | param = param.data 104 | try: 105 | new_model_dict[name].copy_(param) 106 | print('############# new_model load params name: ',name) 107 | except: 108 | raise RuntimeError('While copying the parameter named {}, \ 109 | whose dimensions in the model are {} and \ 110 | whose dimensions in the checkpoint are {}.' 111 | .format(name, new_model_dict[name].size(), param.size())) 112 | elif 'features.0' in name: 113 | if isinstance(param, Parameter): 114 | # backwards compatibility for serialized parameters 115 | param = param.data 116 | try: 117 | if name == 'features.0.conv.weight': 118 | new_name = 'features_0.0.conv.weight' 119 | elif name == 'features.0.conv.bias': 120 | new_name = 'features_0.0.conv.bias' 121 | elif name == 'features.0.bn.weight': 122 | new_name = 'features_0.0.bn.weight' 123 | elif name == 'features.0.bn.bias': 124 | new_name = 'features_0.0.bn.bias' 125 | elif name == 'features.0.bn.running_mean': 126 | new_name = 'features_0.0.bn.running_mean' 127 | elif name == 'features.0.bn.running_var': 128 | new_name = 'features_0.0.bn.running_var' 129 | new_model_dict[new_name].copy_(param) 130 | print('############# new_model load params name: ', new_name) 131 | except: 132 | raise RuntimeError('While copying the parameter named {}, \ 133 | whose dimensions in the model are {} and \ 134 | whose dimensions in the checkpoint are {}.' 135 | .format(name, new_model_dict[name].size(), param.size())) 136 | elif 'features.2' in name: 137 | if isinstance(param, Parameter): 138 | # backwards compatibility for serialized parameters 139 | param = param.data 140 | try: 141 | if name == 'features.2.conv.weight': 142 | new_name = 'features_1.0.conv.weight' 143 | elif name == 'features.2.conv.bias': 144 | new_name = 'features_1.0.conv.bias' 145 | elif name == 'features.2.bn.weight': 146 | new_name = 'features_1.0.bn.weight' 147 | elif name == 'features.2.bn.bias': 148 | new_name = 'features_1.0.bn.bias' 149 | elif name == 'features.2.bn.running_mean': 150 | new_name = 'features_1.0.bn.running_mean' 151 | elif name == 'features.2.bn.running_var': 152 | new_name = 'features_1.0.bn.running_var' 153 | new_model_dict[new_name].copy_(param) 154 | print('############# new_model load params name: ', new_name) 155 | except: 156 | raise RuntimeError('While copying the parameter named {}, \ 157 | whose dimensions in the model are {} and \ 158 | whose dimensions in the checkpoint are {}.' 
159 | .format(name, new_model_dict[name].size(), param.size())) 160 | elif 'features.4' in name: 161 | if isinstance(param, Parameter): 162 | # backwards compatibility for serialized parameters 163 | param = param.data 164 | try: 165 | if name == 'features.4.conv.weight': 166 | new_name = 'features_2.0.conv.weight' 167 | elif name == 'features.4.conv.bias': 168 | new_name = 'features_2.0.conv.bias' 169 | elif name == 'features.4.bn.weight': 170 | new_name = 'features_2.0.bn.weight' 171 | elif name == 'features.4.bn.bias': 172 | new_name = 'features_2.0.bn.bias' 173 | elif name == 'features.4.bn.running_mean': 174 | new_name = 'features_2.0.bn.running_mean' 175 | elif name == 'features.4.bn.running_var': 176 | new_name = 'features_2.0.bn.running_var' 177 | new_model_dict[new_name].copy_(param) 178 | print('############# new_model load params name: ', new_name) 179 | except: 180 | raise RuntimeError('While copying the parameter named {}, \ 181 | whose dimensions in the model are {} and \ 182 | whose dimensions in the checkpoint are {}.' 183 | .format(name, new_model_dict[name].size(), param.size())) 184 | elif 'features.5' in name: 185 | if isinstance(param, Parameter): 186 | # backwards compatibility for serialized parameters 187 | param = param.data 188 | try: 189 | if name == 'features.5.conv.weight': 190 | new_name = 'features_2.1.conv.weight' 191 | elif name == 'features.5.conv.bias': 192 | new_name = 'features_2.1.conv.bias' 193 | elif name == 'features.5.bn.weight': 194 | new_name = 'features_2.1.bn.weight' 195 | elif name == 'features.5.bn.bias': 196 | new_name = 'features_2.1.bn.bias' 197 | elif name == 'features.5.bn.running_mean': 198 | new_name = 'features_2.1.bn.running_mean' 199 | elif name == 'features.5.bn.running_var': 200 | new_name = 'features_2.1.bn.running_var' 201 | new_model_dict[new_name].copy_(param) 202 | print('############# new_model load params name: ', new_name) 203 | except: 204 | raise RuntimeError('While copying the parameter named {}, \ 205 | whose dimensions in the model are {} and \ 206 | whose dimensions in the checkpoint are {}.' 207 | .format(name, new_model_dict[name].size(), param.size())) 208 | elif 'features.6' in name: 209 | if isinstance(param, Parameter): 210 | # backwards compatibility for serialized parameters 211 | param = param.data 212 | try: 213 | if name == 'features.6.conv.weight': 214 | new_name = 'features_2.2.conv.weight' 215 | elif name == 'features.6.conv.bias': 216 | new_name = 'features_2.2.conv.bias' 217 | elif name == 'features.6.bn.weight': 218 | new_name = 'features_2.2.bn.weight' 219 | elif name == 'features.6.bn.bias': 220 | new_name = 'features_2.2.bn.bias' 221 | elif name == 'features.6.bn.running_mean': 222 | new_name = 'features_2.2.bn.running_mean' 223 | elif name == 'features.6.bn.running_var': 224 | new_name = 'features_2.2.bn.running_var' 225 | new_model_dict[new_name].copy_(param) 226 | print('############# new_model load params name: ', new_name) 227 | except: 228 | raise RuntimeError('While copying the parameter named {}, \ 229 | whose dimensions in the model are {} and \ 230 | whose dimensions in the checkpoint are {}.' 
231 | .format(name, new_model_dict[name].size(), param.size())) 232 | else: 233 | continue 234 | 235 | def main(args): 236 | np.random.seed(args.seed) 237 | torch.manual_seed(args.seed) 238 | torch.cuda.manual_seed(args.seed) 239 | cudnn.benchmark = True 240 | data_dir = osp.join(args.data_dir, args.dataset) 241 | # Redirect print to both console and log file 242 | if not args.evaluate: 243 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 244 | else: 245 | sys.stdout = Logger(osp.join(args.logs_dir, 'evaluate-log.txt')) 246 | print('\n################## setting ###################') 247 | print(parser.parse_args()) 248 | print('################## setting ###################\n') 249 | # Create data loaders 250 | def readlist(fpath): 251 | lines=[] 252 | with open(fpath, 'r') as f: 253 | data = f.readlines() 254 | 255 | for line in data: 256 | name, label = line.split() 257 | lines.append((name, int(label))) 258 | return lines 259 | 260 | # Load data list 261 | if osp.exists(osp.join(data_dir, 'train.txt')): 262 | train_list = readlist(osp.join(data_dir, 'train.txt')) 263 | else: 264 | raise RuntimeError("The training list -- {} doesn't exist".format(train_list)) 265 | 266 | if osp.exists(osp.join(data_dir, 'val.txt')): 267 | val_list = readlist(osp.join(data_dir, 'val.txt')) 268 | else: 269 | raise RuntimeError("The val list -- {} doesn't exist".format(val_list)) 270 | 271 | 272 | if args.scale_size is None : 273 | args.scale_size = 256 274 | if args.img_size is None : 275 | args.img_size = 224 276 | 277 | train_loader, val_loader = \ 278 | get_data(args.split, data_dir, args.img_size, 279 | args.scale_size, args.batch_size, args.workers, 280 | train_list, val_list) 281 | 282 | max_quan_value = pow(2, args.ak) 283 | ac_quan_values = [i for i in range(max_quan_value)] 284 | print('ac_quan_values: ', ac_quan_values) 285 | # Create model 286 | #num_classes = 1000 # imagenet 1000 287 | model = models.create(args.arch, QA_flag=True, ac_quan_bias = QA_biases[args.qa_biases], 288 | ac_quan_values=ac_quan_values, ac_beta=QA_beta[args.qa_beta], num_classes=1000) 289 | 290 | # create alpha and belta 291 | count = 0 292 | for m in model.modules(): 293 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 294 | count = count + 1 295 | alpha = [] 296 | beta = [] 297 | for i in range(count-2): 298 | alpha.append(Variable(torch.FloatTensor([0.0]).cuda(), requires_grad=True)) 299 | beta.append(Variable(torch.FloatTensor([0.0]).cuda(), requires_grad=True)) 300 | 301 | # model Load from checkpoint 302 | start_epoch = best_top1 = 0 303 | if args.pretrained_model: 304 | print('=> Start load params from pre-trained model...') 305 | if 'resnet' in args.arch: 306 | load_params(model, args.pretrained_model) 307 | elif 'alexnet' in args.arch: 308 | load_alexnet_params(model, args.pretrained_model) 309 | alpha = load_checkpoint(args.pretrained_model)['alpha'] 310 | beta = load_checkpoint(args.pretrained_model)['beta'] 311 | 312 | if args.resume: 313 | checkpoint = load_checkpoint(args.resume) 314 | model.load_state_dict(checkpoint['state_dict']) 315 | optimizer.load_state_dict(checkpoint['optimizer']) 316 | optimizer_alpha.load_state_dict(checkpoint['optimizer_alpha']) 317 | optimizer_beta.load_state_dict(checkpoint['optimizer_beta']) 318 | alpha = checkpoint['alpha'] 319 | beta = checkpoint['beta'] 320 | start_epoch = args.resume_epoch 321 | print("=> Finetune Start epoch {} " 322 | .format(start_epoch)) 323 | 324 | 325 | model = nn.DataParallel(model).cuda() 326 | 327 | # Criterion 328 | criterion = 
nn.CrossEntropyLoss().cuda() 329 | 330 | qw_values = QW_values[args.qw_biases] 331 | 332 | global qua_op 333 | qua_op = QuaOp(model, QW_biases[args.qw_biases], QW_values=qw_values) 334 | 335 | evaluator = Evaluator(model, criterion, alpha, beta) 336 | if args.evaluate: 337 | print('Test model: \n') 338 | evaluator.evaluate(val_loader, W_T=1) 339 | return 340 | 341 | # Optimizer 342 | spec_param_list = ['quan'] 343 | if args.change_lr_mult: 344 | def _key_in_name(name): 345 | for k in spec_param_list: 346 | if k in name: 347 | return True 348 | return False 349 | base_params = [] 350 | base_params_names = [] 351 | spec_params = [] 352 | spec_params_names = [] 353 | 354 | for name, param in model.named_parameters(): 355 | if _key_in_name(name): 356 | spec_params.append(param) 357 | spec_params_names.append(name) 358 | else: 359 | base_params.append(param) 360 | base_params_names.append(name) 361 | print('############# base params ################') 362 | print(base_params_names) 363 | print('lr_mult: {}'.format(args.base_lr_mult)) 364 | print('############# base params ################') 365 | print('############# spec params ################') 366 | print(spec_params_names) 367 | print('lr_mult: {}'.format(args.spec_lr_mult)) 368 | print('############# spec params ################') 369 | param_groups = [ 370 | {'params': base_params, 'lr_mult': args.base_lr_mult}, 371 | {'params': spec_params, 'lr_mult': args.spec_lr_mult}] 372 | else: 373 | param_groups = model.parameters() 374 | 375 | if args.adam: 376 | print('The optimizer is Adam !!!') 377 | optimizer = torch.optim.Adam(param_groups, lr=args.lr, 378 | weight_decay=args.weight_decay) 379 | optimizer_alpha = torch.optim.Adam(alpha, lr=args.lr) 380 | optimizer_beta = torch.optim.Adam(beta, lr=args.lr) 381 | else: 382 | print('The optimizer is SGD !!!') 383 | optimizer = torch.optim.SGD(param_groups, lr=args.lr, 384 | momentum=args.momentum, 385 | weight_decay=args.weight_decay) 386 | optimizer_alpha = torch.optim.SGD(alpha, lr=args.lr, 387 | momentum=args.momentum) 388 | optimizer_beta = torch.optim.SGD(beta, lr=args.lr, 389 | momentum=args.momentum) 390 | 391 | # Trainer 392 | trainer = Trainer(model, criterion, alpha, beta) 393 | 394 | # Schedule learning rate 395 | def adjust_lr(epoch): 396 | step_size = args.step_size 397 | decay_step = args.decay_step 398 | lr = args.lr if epoch < step_size else \ 399 | args.lr * (0.1 ** ((epoch - step_size) // decay_step + 1)) 400 | for g in optimizer.param_groups: 401 | g['lr'] = lr * g.get('lr_mult', 1) 402 | for k in optimizer_alpha.param_groups: 403 | k['lr'] = lr * args.base_lr_mult 404 | for m in optimizer_beta.param_groups: 405 | m['lr'] = lr * args.base_lr_mult 406 | return lr 407 | # Start training 408 | trainer.show_info(with_arch=True, with_grad=False) 409 | for epoch in range(start_epoch, args.epochs): 410 | lr = adjust_lr(epoch) 411 | w_t = (epoch + 1) * args.temperature_W # linear 412 | 413 | ac_t = (epoch + 1) * args.temperature_A # linear 414 | print('lr={}, W_T={}, A_T={}'.format(lr, w_t, ac_t)) 415 | 416 | trainer.train(epoch, train_loader, optimizer, optimizer_alpha, 417 | optimizer_beta, W_T=w_t, ac_T=ac_t, print_info=args.print_info) 418 | if epoch < args.start_save: 419 | continue 420 | top1 = evaluator.evaluate(val_loader, W_T=w_t) 421 | 422 | is_best = top1 > best_top1 423 | best_top1 = max(top1, best_top1) 424 | save_checkpoint({ 425 | 'state_dict':model.module.state_dict(), 426 | 'optimizer': optimizer.state_dict(), 427 | 'optimizer_alpha': optimizer_alpha.state_dict(), 
428 | 'optimizer_beta': optimizer_beta.state_dict(), 429 | 'alpha': alpha, 430 | 'beta': beta}, 431 | is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar')) 432 | 433 | print('\n * Finished epoch {:3d} top1: {:5.2%} model_best: {:5.2%} \n'. 434 | format(epoch, top1, best_top1)) 435 | 436 | if (epoch+1) % 5 == 0: 437 | model_name = 'epoch_'+ str(epoch) + '.pth.tar' 438 | torch.save({'state_dict':model.module.state_dict(), 439 | 'optimizer': optimizer.state_dict(), 440 | 'optimizer_alpha': optimizer_alpha.state_dict(), 441 | 'optimizer_beta': optimizer_beta.state_dict(), 442 | 'alpha': alpha, 443 | 'beta': beta}, 444 | osp.join(args.logs_dir, model_name)) 445 | 446 | class Trainer(object): 447 | def __init__(self, model, criterion, alpha, beta): 448 | super(Trainer, self).__init__() 449 | self.model = model 450 | self.criterion = criterion 451 | self.alpha = alpha 452 | self.beta = beta 453 | self.init = False 454 | 455 | def train(self, epoch, data_loader, optimizer, optimizer_alpha, 456 | optimizer_beta, W_T=1, ac_T=1, print_freq=1, print_info=10): 457 | self.model.train() 458 | 459 | batch_time = AverageMeter() 460 | data_time = AverageMeter() 461 | losses = AverageMeter() 462 | top1 = AverageMeter() 463 | top5 = AverageMeter() 464 | 465 | end = time.time() 466 | for i, inputs in enumerate(data_loader): 467 | if epoch == 0 and i == 0: 468 | self.init = True 469 | else: 470 | self.init = False 471 | data_time.update(time.time() - end) 472 | 473 | inputs_var, targets_var = self._parse_data(inputs) 474 | 475 | qua_op.quantization(W_T, self.alpha, self.beta, init=self.init) 476 | 477 | loss, prec1, prec5 = self._forward(inputs_var, targets_var, ac_T) 478 | losses.update(loss.data[0], targets_var.size(0)) 479 | top1.update(prec1, targets_var.size(0)) 480 | top5.update(prec5, targets_var.size(0)) 481 | 482 | optimizer.zero_grad() 483 | optimizer_alpha.zero_grad() 484 | optimizer_beta.zero_grad() 485 | loss.backward() 486 | torch.nn.utils.clip_grad_norm(self.model.parameters(), 20.0) 487 | 488 | qua_op.restore_params() 489 | alpha_grad, beta_grad = qua_op.updateQuaGradWeight(W_T, self.alpha, self.beta, init=self.init) 490 | for index in range(len(self.alpha)): 491 | self.alpha[index].grad = Variable(torch.FloatTensor([alpha_grad[index]]).cuda()) 492 | self.beta[index].grad = Variable(torch.FloatTensor([beta_grad[index]]).cuda()) 493 | 494 | optimizer.step() 495 | optimizer_alpha.step() 496 | optimizer_beta.step() 497 | 498 | batch_time.update(time.time() - end) 499 | end = time.time() 500 | 501 | if (i + 1) % print_freq == 0: 502 | print('Epoch: [{}][{}/{}]\t' 503 | 'Time {:.3f} ({:.3f})\t' 504 | 'Data {:.3f} ({:.3f})\t' 505 | 'Loss {:.3f} ({:.3f})\t' 506 | 'Prec@1 {:.2%} ({:.2%})\t' 507 | 'Prec@5 {:.2%} ({:.2%})\t' 508 | .format(epoch, i + 1, len(data_loader), 509 | batch_time.val, batch_time.avg, 510 | data_time.val, data_time.avg, 511 | losses.val, losses.avg, 512 | top1.val, top1.avg, 513 | top5.val, top5.avg)) 514 | #if (epoch+1) % print_info == 0: 515 | # self.show_info() 516 | 517 | def show_info(self, with_arch=False, with_grad=True): 518 | if with_arch: 519 | print('\n\n################# model modules ###################') 520 | for name, m in self.model.named_modules(): 521 | print('{}: {}'.format(name, m)) 522 | print('################# model modules ###################\n\n') 523 | 524 | if with_grad: 525 | print('################# model params diff ###################') 526 | for name, param in self.model.named_parameters(): 527 | mean_value = torch.abs(param.data).mean() 528 
| mean_grad = torch.abs(param.grad).mean().data[0] + 1e-8 529 | print('{}: size{}, data_abd_avg: {}, dgrad_abd_avg: {}, data/grad: {}'.format(name, 530 | param.size(), mean_value, mean_grad, mean_value/mean_grad)) 531 | print('################# model params diff ###################\n\n') 532 | 533 | else: 534 | print('################# model params ###################') 535 | for name, param in self.model.named_parameters(): 536 | print('{}: size{}, abs_avg: {}'.format(name, 537 | param.size(), 538 | torch.abs(param.data.cpu()).mean())) 539 | print('################# model params ###################\n\n') 540 | 541 | def _parse_data(self, inputs): 542 | imgs, _, labels = inputs 543 | inputs_var = Variable(imgs) 544 | targets_var = Variable(labels.cuda()) 545 | return inputs_var, targets_var 546 | 547 | def _forward(self, inputs, targets, ac_T): 548 | outputs = self.model(inputs, ac_T) 549 | if isinstance(self.criterion, torch.nn.CrossEntropyLoss): 550 | loss = self.criterion(outputs, targets) 551 | prec1, prec5= accuracy(outputs.data, targets.data, topk=(1,5)) 552 | prec1 = prec1[0] 553 | prec5 = prec5[0] 554 | else: 555 | raise ValueError("Unsupported loss:", self.criterion) 556 | return loss, prec1, prec5 557 | 558 | class Evaluator(object): 559 | def __init__(self, model, criterion, alpha, beta): 560 | super(Evaluator, self).__init__() 561 | self.model = model 562 | self.criterion = criterion 563 | self.alpha = alpha 564 | self.beta = beta 565 | 566 | def evaluate(self, data_loader, W_T=1, print_freq=1): 567 | batch_time = AverageMeter() 568 | losses = AverageMeter() 569 | top1 = AverageMeter() 570 | top5 = AverageMeter() 571 | 572 | self.model.eval() 573 | 574 | end = time.time() 575 | print('alpha: ', self.alpha) 576 | print('beta: ', self.beta) 577 | qua_op.quantization(W_T, self.alpha, self.beta, init=False, train_phase=False) 578 | 579 | for i, inputs in enumerate(data_loader): 580 | inputs_var, targets_var = self._parse_data(inputs) 581 | 582 | loss, prec1, prec5 = self._forward(inputs_var, targets_var) 583 | 584 | losses.update(loss.data[0], targets_var.size(0)) 585 | top1.update(prec1, targets_var.size(0)) 586 | top5.update(prec5, targets_var.size(0)) 587 | 588 | batch_time.update(time.time() - end) 589 | end = time.time() 590 | 591 | if i % print_freq == 0: 592 | print('Test: [{}/{}]\t' 593 | 'Time {:.3f} ({:.3f})\t' 594 | 'Loss {:.4f} ({:.4f})\t' 595 | 'Prec@1 {:.2%} ({:.2%})\t' 596 | 'Prec@5 {:.2%} ({:.2%})\t' 597 | .format(i + 1, len(data_loader), 598 | batch_time.val, batch_time.avg, 599 | losses.val, losses.avg, 600 | top1.val, top1.avg, 601 | top5.val, top5.avg)) 602 | 603 | qua_op.restore_params() 604 | 605 | print(' * Prec@1 {:.2%} Prec@5 {:.2%}'.format(top1.avg, top5.avg)) 606 | 607 | return top1.avg 608 | 609 | def _parse_data(self, inputs): 610 | imgs, _, labels = inputs 611 | inputs_var = Variable(imgs, volatile=True) 612 | targets_var = Variable(labels.cuda(), volatile=True) 613 | return inputs_var, targets_var 614 | 615 | def _forward(self, inputs, targets): 616 | outputs = self.model(inputs, input_ac_T=1) 617 | if isinstance(self.criterion, torch.nn.CrossEntropyLoss): 618 | loss = self.criterion(outputs, targets) 619 | prec1, prec5= accuracy(outputs.data, targets.data, topk=(1,5)) 620 | prec1 = prec1[0] 621 | prec5 = prec5[0] 622 | else: 623 | raise ValueError("Unsupported loss:", self.criterion) 624 | return loss, prec1, prec5 625 | 626 | 627 | if __name__ == '__main__': 628 | parser = argparse.ArgumentParser(description="Softmax loss classification") 629 | # 
data 630 | parser.add_argument('-d', '--dataset', type=str, default='imagenet') 631 | parser.add_argument('-b', '--batch-size', type=int, default=256) 632 | parser.add_argument('-j', '--workers', type=int, default=4) 633 | parser.add_argument('--split', type=int, default=0) 634 | parser.add_argument('--scale_size', type=int, default=256, 635 | help="val resize image size, default: 256 for ImageNet") 636 | parser.add_argument('--img_size', type=int, default=224, 637 | help="input image size, default: 224 for ImageNet") 638 | # model 639 | parser.add_argument('-a', '--arch', type=str, default='alexnet', 640 | choices=models.names()) 641 | # optimizer 642 | parser.add_argument('--lr', type=float, default=0.001, 643 | help="learning rate of new parameters, for pretrained " 644 | "parameters it is 10 times smaller than this") 645 | parser.add_argument('--momentum', type=float, default=0.9) 646 | parser.add_argument('--weight-decay', type=float, default=1e-5) 647 | parser.add_argument('--step_size', type=int, default=25) 648 | parser.add_argument('--decay_step', type=int, default=25) 649 | 650 | # adjust lr method 651 | parser.add_argument('--spec_lr_mult', type=float, default=1.0) 652 | parser.add_argument('--base_lr_mult', type=float, default=0.1) 653 | parser.add_argument('--change_lr_mult', type=bool, default=True) 654 | 655 | # training configs pretrained_model 656 | parser.add_argument('--ak', type=int, default=1, 657 | help="the bit number of activation quantization, default:1") 658 | parser.add_argument('--qa_biases', type=str, default='') 659 | parser.add_argument('--qa_beta', type=str, default='') 660 | parser.add_argument('--qw_biases', type=str, default='') 661 | parser.add_argument('--pretrained_model', type=str, default='', metavar='PATH') 662 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 663 | parser.add_argument('--resume_epoch', type=int,default=0) 664 | parser.add_argument('--evaluate', action='store_true', 665 | help="evaluation only") 666 | parser.add_argument('--adam', action='store_true', 667 | help="use Adam") 668 | parser.add_argument('--epochs', type=int, default=100) 669 | parser.add_argument('--start_save', type=int, default=0, 670 | help="start saving checkpoints after specific epoch") 671 | parser.add_argument('--seed', type=int, default=1) 672 | parser.add_argument('--print-freq', type=int, default=1) 673 | parser.add_argument('--print-info', type=int, default=10) 674 | parser.add_argument('--temperature_W', type=float, default=10) 675 | parser.add_argument('--temperature_A', type=float, default=10) 676 | # misc 677 | working_dir = osp.dirname(osp.abspath(__file__)) 678 | parser.add_argument('--data-dir', type=str, metavar='PATH', 679 | default=osp.join(working_dir, 'data')) 680 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 681 | default=osp.join(working_dir, 'logs')) 682 | main(parser.parse_args()) 683 | 684 | -------------------------------------------------------------------------------- /quan_weight_main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # quan_weight_main.py is used to train the weight quantized model. 
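"""
Minimal usage sketch (mirrors quan-weight.sh; paths and names are illustrative):

    CUDA_VISIBLE_DEVICES=0 python quan_weight_main.py -a resnet18 -b 256 -d imagenet \
        --lr 0.001 --temperature 20 --offline_biases resnet18-w-1 \
        --pretrained ./pre-models/resnet18.pth \
        --logs-dir logs/quan-weight/resnet18-quan-w-1

The sigmoid temperature grows linearly with the epoch, T = (epoch + 1) * temperature,
annealing the soft quantizer toward the hard step function used at test time.
"""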
4 | 5 | from __future__ import print_function, absolute_import 6 | import argparse 7 | import os.path as osp 8 | import os 9 | import numpy as np 10 | import sys 11 | import time 12 | import math 13 | import torch 14 | from torch import nn 15 | from torch.autograd import Variable 16 | from torch.backends import cudnn 17 | from torch.utils.data import DataLoader 18 | from torch.utils.data.sampler import RandomSampler 19 | from torch.nn.parameter import Parameter 20 | from torchvision import transforms as T 21 | 22 | from config import * 23 | import models 24 | from data_pre import Lighting, Preprocessor 25 | from utils import Logger, AverageMeter 26 | from utils import load_checkpoint, save_checkpoint 27 | from utils import RandomResized 28 | from anybit import QuaOp 29 | from evaluators import accuracy 30 | import pdb 31 | 32 | # define global qua_op 33 | qua_op = None 34 | 35 | def get_data(split_id, data_dir, img_size, scale_size, batch_size, 36 | workers, train_list, val_list): 37 | root = data_dir 38 | 39 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 40 | std=[0.229, 0.224, 0.225]) # RGB imagenet 41 | # with data augmentation 42 | train_transformer = T.Compose([ 43 | T.Resize(scale_size), 44 | T.RandomCrop(img_size), 45 | #T.RandomResizedCrop(img_size), 46 | T.RandomHorizontalFlip(), 47 | T.ToTensor(), # [0, 255] to [0.0, 1.0] 48 | normalizer, # normalize each channel of the input 49 | ]) 50 | 51 | test_transformer = T.Compose([ 52 | T.Resize(scale_size), 53 | T.CenterCrop(img_size), 54 | T.ToTensor(), 55 | normalizer, 56 | ]) 57 | 58 | train_loader = DataLoader( 59 | Preprocessor(train_list, root=root, 60 | transform=train_transformer), 61 | batch_size=batch_size, num_workers=workers, 62 | sampler=RandomSampler(train_list), 63 | pin_memory=True, drop_last=False) 64 | 65 | val_loader = DataLoader( 66 | Preprocessor(val_list, root=root, 67 | transform=test_transformer), 68 | batch_size=batch_size, num_workers=workers, 69 | shuffle=False, pin_memory=True) 70 | 71 | return train_loader, val_loader 72 | 73 | def main(args): 74 | np.random.seed(args.seed) 75 | torch.manual_seed(args.seed) 76 | torch.cuda.manual_seed(args.seed) 77 | cudnn.benchmark = True 78 | data_dir = osp.join(args.data_dir, args.dataset) 79 | # Redirect print to both console and log file 80 | if not args.evaluate: 81 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 82 | else: 83 | sys.stdout = Logger(osp.join(args.logs_dir, 'evaluate-log.txt')) 84 | print('\n################## setting ###################') 85 | print(parser.parse_args()) 86 | print('################## setting ###################\n') 87 | # Create data loaders 88 | def readlist(fpath): 89 | lines=[] 90 | with open(fpath, 'r') as f: 91 | data = f.readlines() 92 | 93 | for line in data: 94 | name, label = line.split() 95 | lines.append((name, int(label))) 96 | return lines 97 | 98 | # Load data list 99 | if osp.exists(osp.join(data_dir, 'train.txt')): 100 | train_list = readlist(osp.join(data_dir, 'train.txt')) 101 | else: 102 | raise RuntimeError("The training list -- {} doesn't exist".format(train_list)) 103 | 104 | if osp.exists(osp.join(data_dir, 'val.txt')): 105 | val_list = readlist(osp.join(data_dir, 'val.txt')) 106 | else: 107 | raise RuntimeError("The val list -- {} doesn't exist".format(val_list)) 108 | 109 | 110 | if args.scale_size is None : 111 | args.scale_size = 256 112 | if args.img_size is None : 113 | args.img_size = 224 114 | 115 | train_loader, val_loader = \ 116 | get_data(args.split, data_dir, args.img_size, 117 | 
args.scale_size, args.batch_size, args.workers, 118 | train_list, val_list) 119 | # Create model 120 | #num_classes = 1000 # imagenet 1000 121 | model = models.create(args.arch, False, num_classes=1000) 122 | # create alpha and belta 123 | count = 0 124 | for m in model.modules(): 125 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 126 | count = count + 1 127 | alpha = [] 128 | beta = [] 129 | for i in range(count-2): 130 | alpha.append(Variable(torch.FloatTensor([0.0]).cuda(), requires_grad=True)) 131 | beta.append(Variable(torch.FloatTensor([0.0]).cuda(), requires_grad=True)) 132 | 133 | if args.adam: 134 | print('The optimizer is Adam !!!') 135 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, 136 | weight_decay=args.weight_decay) 137 | optimizer_alpha = torch.optim.Adam(alpha, lr=args.lr) 138 | optimizer_beta = torch.optim.Adam(beta, lr=args.lr) 139 | else: 140 | print('The optimizer is SGD !!!') 141 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, 142 | momentum=args.momentum, 143 | weight_decay=args.weight_decay) 144 | optimizer_alpha = torch.optim.SGD(alpha, lr=args.lr, 145 | momentum=args.momentum) 146 | optimizer_beta = torch.optim.SGD(beta, lr=args.lr, 147 | momentum=args.momentum) 148 | # model Load from checkpoint 149 | start_epoch = best_top1 = 0 150 | if args.pretrained: 151 | print('=> Start load params from pre-trained model...') 152 | checkpoint = load_checkpoint(args.pretrained) 153 | if 'alexnet' in args.arch or 'resnet' in args.arch: 154 | model.load_state_dict(checkpoint) 155 | else: 156 | raise RuntimeError('The arch is ERROR!!!') 157 | if args.resume: 158 | checkpoint = load_checkpoint(args.resume) 159 | model.load_state_dict(checkpoint['state_dict']) 160 | optimizer.load_state_dict(checkpoint['optimizer']) 161 | optimizer_alpha.load_state_dict(checkpoint['optimizer_alpha']) 162 | optimizer_beta.load_state_dict(checkpoint['optimizer_beta']) 163 | alpha = checkpoint['alpha'] 164 | beta = checkpoint['beta'] 165 | start_epoch = args.resume_epoch 166 | print("=> Finetune Start epoch {} " 167 | .format(start_epoch)) 168 | 169 | 170 | model = nn.DataParallel(model).cuda() 171 | 172 | # Criterion 173 | criterion = nn.CrossEntropyLoss().cuda() 174 | 175 | qw_values = QW_values[args.offline_biases] 176 | print('qw_values: ', qw_values) 177 | global qua_op 178 | qua_op = QuaOp(model, QW_biases[args.offline_biases], QW_values=qw_values) 179 | 180 | evaluator = Evaluator(model, criterion, alpha, beta) 181 | if args.evaluate: 182 | print('Test model: \n') 183 | evaluator.evaluate(val_loader, T=1) 184 | return 185 | 186 | # Trainer 187 | trainer = Trainer(model, criterion, alpha, beta) 188 | 189 | # Schedule learning rate 190 | def adjust_lr(epoch): 191 | step_size = args.step_size 192 | decay_step = args.decay_step 193 | lr = args.lr if epoch < step_size else \ 194 | args.lr * (0.1 ** ((epoch - step_size) // decay_step + 1)) 195 | for g in optimizer.param_groups: 196 | g['lr'] = lr * g.get('lr_mult', 1) 197 | 198 | # Start training 199 | trainer.show_info(with_arch=True, with_grad=False) 200 | for epoch in range(start_epoch, args.epochs): 201 | adjust_lr(epoch) 202 | t = (epoch + 1) * args.temperature # linear 203 | print('W_T = ', t) 204 | 205 | trainer.train(epoch, train_loader, optimizer, optimizer_alpha, 206 | optimizer_beta, T=t, print_info=args.print_info) 207 | if epoch < args.start_save: 208 | continue 209 | top1 = evaluator.evaluate(val_loader, T=t) 210 | 211 | is_best = top1 > best_top1 212 | best_top1 = max(top1, best_top1) 213 | 
save_checkpoint({ 214 | 'state_dict':model.module.state_dict(), 215 | 'optimizer': optimizer.state_dict(), 216 | 'optimizer_alpha': optimizer_alpha.state_dict(), 217 | 'optimizer_beta': optimizer_beta.state_dict(), 218 | 'alpha': alpha, 219 | 'beta': beta}, 220 | is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar')) 221 | 222 | print('\n * Finished epoch {:3d} top1: {:5.2%} model_best: {:5.2%} \n'. 223 | format(epoch, top1, best_top1)) 224 | 225 | if (epoch+1) % 5 == 0: 226 | model_name = 'epoch_'+ str(epoch) + '.pth.tar' 227 | torch.save({'state_dict':model.module.state_dict(), 228 | 'optimizer': optimizer.state_dict(), 229 | 'optimizer_alpha': optimizer_alpha.state_dict(), 230 | 'optimizer_beta': optimizer_beta.state_dict(), 231 | 'alpha': alpha, 232 | 'beta': beta}, 233 | osp.join(args.logs_dir, model_name)) 234 | 235 | class Trainer(object): 236 | def __init__(self, model, criterion, alpha, beta): 237 | super(Trainer, self).__init__() 238 | self.model = model 239 | self.criterion = criterion 240 | self.alpha = alpha 241 | self.beta = beta 242 | self.init = False 243 | 244 | def train(self, epoch, data_loader, optimizer, optimizer_alpha, 245 | optimizer_beta, T=1, print_freq=1, print_info=10): 246 | self.model.train() 247 | 248 | batch_time = AverageMeter() 249 | data_time = AverageMeter() 250 | losses = AverageMeter() 251 | top1 = AverageMeter() 252 | top5 = AverageMeter() 253 | 254 | end = time.time() 255 | for i, inputs in enumerate(data_loader): 256 | if epoch == 0 and i == 0: 257 | self.init = True 258 | else: 259 | self.init = False 260 | data_time.update(time.time() - end) 261 | 262 | inputs_var, targets_var = self._parse_data(inputs) 263 | 264 | qua_op.quantization(T, self.alpha, self.beta, init=self.init) 265 | 266 | loss, prec1, prec5 = self._forward(inputs_var, targets_var) 267 | losses.update(loss.data[0], targets_var.size(0)) 268 | top1.update(prec1, targets_var.size(0)) 269 | top5.update(prec5, targets_var.size(0)) 270 | 271 | optimizer.zero_grad() 272 | optimizer_alpha.zero_grad() 273 | optimizer_beta.zero_grad() 274 | loss.backward() 275 | torch.nn.utils.clip_grad_norm(self.model.parameters(), 5.0) 276 | 277 | qua_op.restore_params() 278 | alpha_grad, beta_grad = qua_op.updateQuaGradWeight(T, self.alpha, self.beta, init=self.init) 279 | for index in range(len(self.alpha)): 280 | self.alpha[index].grad = Variable(torch.FloatTensor([alpha_grad[index]]).cuda()) 281 | self.beta[index].grad = Variable(torch.FloatTensor([beta_grad[index]]).cuda()) 282 | 283 | optimizer.step() 284 | optimizer_alpha.step() 285 | optimizer_beta.step() 286 | 287 | batch_time.update(time.time() - end) 288 | end = time.time() 289 | 290 | if (i + 1) % print_freq == 0: 291 | print('Epoch: [{}][{}/{}]\t' 292 | 'Time {:.3f} ({:.3f})\t' 293 | 'Data {:.3f} ({:.3f})\t' 294 | 'Loss {:.3f} ({:.3f})\t' 295 | 'Prec@1 {:.2%} ({:.2%})\t' 296 | 'Prec@5 {:.2%} ({:.2%})\t' 297 | .format(epoch, i + 1, len(data_loader), 298 | batch_time.val, batch_time.avg, 299 | data_time.val, data_time.avg, 300 | losses.val, losses.avg, 301 | top1.val, top1.avg, 302 | top5.val, top5.avg)) 303 | if (epoch+1) % print_info == 0: 304 | self.show_info() 305 | 306 | def show_info(self, with_arch=False, with_grad=True): 307 | if with_arch: 308 | print('\n\n################# model modules ###################') 309 | for name, m in self.model.named_modules(): 310 | print('{}: {}'.format(name, m)) 311 | print('################# model modules ###################\n\n') 312 | 313 | if with_grad: 314 | print('################# model 
params diff ###################') 315 | for name, param in self.model.named_parameters(): 316 | mean_value = torch.abs(param.data).mean() 317 | mean_grad = torch.abs(param.grad).mean().data[0] + 1e-8 318 | print('{}: size{}, data_abd_avg: {}, dgrad_abd_avg: {}, data/grad: {}'.format(name, 319 | param.size(), mean_value, mean_grad, mean_value/mean_grad)) 320 | print('################# model params diff ###################\n\n') 321 | 322 | else: 323 | print('################# model params ###################') 324 | for name, param in self.model.named_parameters(): 325 | print('{}: size{}, abs_avg: {}'.format(name, 326 | param.size(), 327 | torch.abs(param.data.cpu()).mean())) 328 | print('################# model params ###################\n\n') 329 | 330 | def _parse_data(self, inputs): 331 | imgs, _, labels = inputs 332 | inputs_var = [Variable(imgs)] 333 | targets_var = Variable(labels.cuda()) 334 | return inputs_var, targets_var 335 | 336 | def _forward(self, inputs, targets): 337 | outputs = self.model(*inputs) 338 | if isinstance(self.criterion, torch.nn.CrossEntropyLoss): 339 | loss = self.criterion(outputs, targets) 340 | prec1, prec5= accuracy(outputs.data, targets.data, topk=(1,5)) 341 | prec1 = prec1[0] 342 | prec5 = prec5[0] 343 | else: 344 | raise ValueError("Unsupported loss:", self.criterion) 345 | return loss, prec1, prec5 346 | 347 | class Evaluator(object): 348 | def __init__(self, model, criterion, alpha, beta): 349 | super(Evaluator, self).__init__() 350 | self.model = model 351 | self.criterion = criterion 352 | self.alpha = alpha 353 | self.beta = beta 354 | 355 | def evaluate(self, data_loader, T=1, print_freq=1): 356 | batch_time = AverageMeter() 357 | losses = AverageMeter() 358 | top1 = AverageMeter() 359 | top5 = AverageMeter() 360 | 361 | self.model.eval() 362 | 363 | end = time.time() 364 | print('alpha: ', self.alpha) 365 | print('beta: ', self.beta) 366 | qua_op.quantization(T, self.alpha, self.beta, init=False, train_phase=False) 367 | 368 | for i, inputs in enumerate(data_loader): 369 | inputs_var, targets_var = self._parse_data(inputs) 370 | 371 | loss, prec1, prec5 = self._forward(inputs_var, targets_var) 372 | 373 | losses.update(loss.data[0], targets_var.size(0)) 374 | top1.update(prec1, targets_var.size(0)) 375 | top5.update(prec5, targets_var.size(0)) 376 | 377 | batch_time.update(time.time() - end) 378 | end = time.time() 379 | 380 | if i % print_freq == 0: 381 | print('Test: [{}/{}]\t' 382 | 'Time {:.3f} ({:.3f})\t' 383 | 'Loss {:.4f} ({:.4f})\t' 384 | 'Prec@1 {:.2%} ({:.2%})\t' 385 | 'Prec@5 {:.2%} ({:.2%})\t' 386 | .format(i + 1, len(data_loader), 387 | batch_time.val, batch_time.avg, 388 | losses.val, losses.avg, 389 | top1.val, top1.avg, 390 | top5.val, top5.avg)) 391 | 392 | qua_op.restore_params() 393 | 394 | print(' * Prec@1 {:.2%} Prec@5 {:.2%}'.format(top1.avg, top5.avg)) 395 | 396 | return top1.avg 397 | 398 | def _parse_data(self, inputs): 399 | imgs, _, labels = inputs 400 | inputs_var = [Variable(imgs, volatile=True)] 401 | targets_var = Variable(labels.cuda(), volatile=True) 402 | return inputs_var, targets_var 403 | 404 | def _forward(self, inputs, targets): 405 | outputs = self.model(*inputs) 406 | if isinstance(self.criterion, torch.nn.CrossEntropyLoss): 407 | loss = self.criterion(outputs, targets) 408 | prec1, prec5= accuracy(outputs.data, targets.data, topk=(1,5)) 409 | prec1 = prec1[0] 410 | prec5 = prec5[0] 411 | else: 412 | raise ValueError("Unsupported loss:", self.criterion) 413 | return loss, prec1, prec5 414 | 415 | 416 | 
if __name__ == '__main__': 417 | parser = argparse.ArgumentParser(description="Softmax loss classification") 418 | # data 419 | parser.add_argument('-d', '--dataset', type=str, default='imagenet') 420 | parser.add_argument('-b', '--batch-size', type=int, default=256) 421 | parser.add_argument('-j', '--workers', type=int, default=4) 422 | parser.add_argument('--split', type=int, default=0) 423 | parser.add_argument('--scale_size', type=int, default=256, 424 | help="val resize image size, default: 256 for ImageNet") 425 | parser.add_argument('--img_size', type=int, default=224, 426 | help="input image size, default: 224 for ImageNet") 427 | # model 428 | parser.add_argument('-a', '--arch', type=str, default='alexnet', 429 | choices=models.names()) 430 | # optimizer 431 | parser.add_argument('--lr', type=float, default=0.001, 432 | help="learning rate of new parameters, for pretrained " 433 | "parameters it is 10 times smaller than this") 434 | parser.add_argument('--momentum', type=float, default=0.9) 435 | parser.add_argument('--weight-decay', type=float, default=1e-5) 436 | parser.add_argument('--step_size', type=int, default=25) 437 | parser.add_argument('--decay_step', type=int, default=25) 438 | 439 | # training configs pretrained_model 440 | parser.add_argument('--pretrained', type=str, default='', metavar='PATH') 441 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 442 | parser.add_argument('--resume_epoch', type=int,default=0) 443 | parser.add_argument('--evaluate', action='store_true', 444 | help="evaluation only") 445 | parser.add_argument('--adam', action='store_true', 446 | help="use Adam") 447 | parser.add_argument('--epochs', type=int, default=100) 448 | parser.add_argument('--start_save', type=int, default=0, 449 | help="start saving checkpoints after specific epoch") 450 | parser.add_argument('--seed', type=int, default=1) 451 | parser.add_argument('--print-freq', type=int, default=1) 452 | parser.add_argument('--print-info', type=int, default=10) 453 | parser.add_argument('--temperature', type=int, default=10) 454 | parser.add_argument('--offline_biases', type=str, default='') 455 | # misc 456 | working_dir = osp.dirname(osp.abspath(__file__)) 457 | parser.add_argument('--data-dir', type=str, metavar='PATH', 458 | default=osp.join(working_dir, 'data')) 459 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 460 | default=osp.join(working_dir, 'logs')) 461 | main(parser.parse_args()) 462 | 463 | -------------------------------------------------------------------------------- /tools/cluster.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # cluster.py is used to get the bias(b_i) of quantization function. 
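"""
Usage sketch (illustrative): dump each layer's weights to .npy files in one
directory, then run

    python tools/cluster.py -r ./npy-weights

For every .npy file this prints the k-means cluster boundaries b_i that are
used to initialize the quantization biases for the target values in Q_values.
"""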
4 | 
5 | from __future__ import print_function, absolute_import
6 | import argparse
7 | import os.path as osp
8 | 
9 | import os
10 | 
11 | import sys
12 | import pdb
13 | 
14 | import math
15 | import numpy as np
16 | import matplotlib
17 | import matplotlib.pyplot as plt
18 | 
19 | from sklearn.cluster import KMeans
20 | 
21 | 
22 | def params_cluster(params, Q_values):
23 |     # print("The max and min values of params: ", params.max(), params.min())
24 |     # print("The shape of params: ", params.shape)
25 | 
26 |     max_value = abs(params).max().tolist()
27 |     # print("max_abs_value: ", max_value)
28 | 
29 |     quan_values = Q_values
30 |     threshold = quan_values[-1] * 5 / 4.0
31 |     # print("scale threshold: ", threshold)
32 |     pre_params = np.sort(params.reshape(-1, 1), axis=0)
33 |     pre_params = pre_params * (threshold / max_value)
34 | 
35 |     # cluster
36 |     n_clusters = len(quan_values)
37 |     estimator = KMeans(n_clusters=n_clusters)
38 |     estimator.fit(pre_params)
39 |     label_pred = estimator.labels_
40 |     centroids = estimator.cluster_centers_
41 | 
42 |     # print("cluster_centers: ", centroids)
43 |     # print("label_pred: ", label_pred)
44 | 
45 |     temp = label_pred[0]
46 |     saved_index = [0] * (n_clusters - 1)
47 |     j = 0
48 |     for index, i in enumerate(label_pred):
49 |         if i != temp:
50 |             saved_index[j] = index
51 |             j += 1
52 |             temp = i
53 | 
54 |     # print("boundary_index: ", saved_index)
55 | 
56 |     # print(pre_params[saved_index[0]-1], pre_params[saved_index[0]])
57 |     # print(pre_params[saved_index[1]-1], pre_params[saved_index[1]])
58 | 
59 |     boundary = [0] * (n_clusters - 1)
60 |     for i in range(n_clusters - 1):
61 |         temp = (pre_params[saved_index[i] - 1] + pre_params[saved_index[i]]) / 2
62 |         boundary[i] = temp.tolist()[0]
63 |     # print("boundary: ", boundary)
64 |     return boundary
65 | 
66 | def main(args):
67 |     Q_values = [-4, -2, -1, 0, 1, 2, 4]
68 |     # Q_values = [-2, -1, 0, 1, 2]
69 |     # Q_values = [-1, 0, 1]
70 | 
71 | 
72 |     all_file = sorted(os.listdir(args.root))
73 |     for filename in all_file:
74 |         if '.npy' in filename:
75 |             params_path = osp.join(args.root, filename)
76 |             params = np.load(params_path)
77 |             boundary = params_cluster(params, Q_values)
78 |             print(filename, boundary)
79 | 
80 | if __name__ == '__main__':
81 |     parser = argparse.ArgumentParser(description="Parameter cluster")
82 |     # file path
83 |     parser.add_argument('-r', '--root', type=str, default=".")
84 | 
85 |     main(parser.parse_args())
86 | 
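# Example run (the directory and numbers are illustrative only): with per-layer
# weight arrays dumped as .npy files under logs/params/,
#
#     python tools/cluster.py -r logs/params
#
# prints one line per file: the filename followed by the len(Q_values) - 1
# boundaries, e.g. "conv1.npy [-2.9, -1.4, -0.5, 0.5, 1.4, 2.9]".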
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1 python main.py -a resnet18 -b 256 -d imagenet \
2 |     --img_size 224 -j 16 --weight-decay 1e-4 --lr 0.1 \
3 |     --step_size 40 --decay_step 25 --epochs 110 \
4 |     --start_save 0 --print-info 1 \
5 |     --logs-dir logs/baseline/resnet18-relu6-low-aug
6 |     #--resume logs/baseline/alexnet/epoch_79.pth.tar \
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | 
4 | from __future__ import print_function, absolute_import
5 | 
6 | import os
7 | import os.path as osp
8 | import sys
9 | import errno
10 | from PIL import Image
11 | 
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 | import shutil
16 | import random
17 | 
18 | from torchvision.transforms import *
19 | 
20 | def to_numpy(tensor):
21 |     if torch.is_tensor(tensor):
22 |         return tensor.cpu().numpy()
23 |     elif type(tensor).__module__ != 'numpy':
24 |         raise ValueError("Cannot convert {} to numpy array"
25 |                          .format(type(tensor)))
26 |     return tensor
27 | 
28 | 
29 | def to_torch(ndarray):
30 |     if type(ndarray).__module__ == 'numpy':
31 |         return torch.from_numpy(ndarray)
32 |     elif not torch.is_tensor(ndarray):
33 |         raise ValueError("Cannot convert {} to torch tensor"
34 |                          .format(type(ndarray)))
35 |     return ndarray
36 | 
37 | def mkdir_if_missing(dir_path):
38 |     try:
39 |         os.makedirs(dir_path)
40 |     except OSError as e:
41 |         if e.errno != errno.EEXIST:
42 |             raise
43 | 
44 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
45 |     mkdir_if_missing(osp.dirname(fpath))
46 |     torch.save(state, fpath)
47 |     if is_best:
48 |         shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))
49 | 
50 | def load_checkpoint(fpath):
51 |     if osp.isfile(fpath):
52 |         checkpoint = torch.load(fpath)
53 |         print("==> Loaded checkpoint '{}'".format(fpath))
54 |         return checkpoint
55 |     else:
56 |         raise ValueError("==> No checkpoint found at '{}'".format(fpath))
57 | 
58 | 
59 | 
60 | class AverageMeter(object):
61 |     """Computes and stores the average and current value."""
62 |     def __init__(self):
63 |         self.reset()
64 | 
65 |     def reset(self):
66 |         self.val = 0.0
67 |         self.avg = 0.0
68 |         self.sum = 0.0
69 |         self.count = 0
70 | 
71 |     def update(self, val, n=1):
72 |         self.val = val
73 |         self.sum += val * n
74 |         self.count += n
75 |         self.avg = self.sum / self.count
76 | 
77 | class Logger(object):
78 |     def __init__(self, fpath=None):
79 |         self.console = sys.stdout
80 |         self.file = None
81 |         if fpath is not None:
82 |             mkdir_if_missing(osp.dirname(fpath))
83 |             self.file = open(fpath, 'w')
84 | 
85 |     def __del__(self):
86 |         self.close()
87 | 
88 |     def __enter__(self):
89 |         return self
90 | 
91 |     def __exit__(self, *args):
92 |         self.close()
93 | 
94 |     def write(self, msg):
95 |         self.console.write(msg)
96 |         if self.file is not None:
97 |             self.file.write(msg)
98 | 
99 |     def flush(self):
100 |         self.console.flush()
101 |         if self.file is not None:
102 |             self.file.flush()
103 |             os.fsync(self.file.fileno())
104 | 
105 |     def close(self):
106 |         self.console.close()
107 |         if self.file is not None:
108 |             self.file.close()
109 | 
110 | class RandomResized(object):
111 |     def __init__(self, min_size, max_size):
112 |         self.min_size = min_size
113 |         self.max_size = max_size
114 | 
115 |     def __call__(self, img):
116 |         scale_size = random.randint(self.min_size, self.max_size)
117 |         scale = Resize(scale_size)
118 |         return scale(img)
119 | 
120 | 
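# Brief usage sketches for the two most-used helpers above (the loss values
# and log path are hypothetical):
#
#     meter = AverageMeter()
#     meter.update(0.9, n=256)   # running average, weighted by batch size
#     meter.update(0.7, n=256)
#     print(meter.avg)           # -> 0.8
#
#     sys.stdout = Logger(osp.join('logs', 'log.txt'))  # tee all prints to a file
--------------------------------------------------------------------------------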