├── .gitignore
├── LICENSE
├── README.md
├── TernaryQuantization.pdf
├── data.py
├── logs
│   ├── quantized_wp_wn_trainable.txt
│   └── quantized_wp_wn_trainable_v2.txt
├── main_autoquantize.py
├── main_original.py
├── main_ternary.py
├── model.py
├── quantification.py
└── weights
    ├── autoquantize.ckpt
    ├── original.ckpt
    └── quantized.ckpt

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | mnist/**
106 | 
107 | **/.DS_Store
108 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Vinay Sisodia
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Ternary quantization
2 | Training models with ternary quantized weights. A PyTorch implementation of [Trained Ternary Quantization](https://arxiv.org/abs/1612.01064).
3 | 
4 | ### Work in progress
5 | - [x] Train MNIST model in original format (float32)
6 | - [x] Train MNIST model with quantized weights
7 | - [x] Add training logs
8 | - [ ] Analyze quantized weights
9 | - [ ] Quantize weights keeping w_p and w_n fixed
10 | 
11 | ### Repo Guide
12 | - A simple model (`model_full`) defined in `model.py` was trained on MNIST data using full precision weights. The trained weights are stored as `weights/original.ckpt`.
13 | - Code for training can be found under `main_original.py`.
14 | - A copy of the above model (`model_to_quantify`), loaded with the trained weights, was then trained using ternary quantization. The trained weights are stored as `weights/quantized.ckpt`.
15 | - Code for training can be found under `main_ternary.py`. The logs can be found inside the file `logs/quantized_wp_wn_trainable.txt`.
16 | - I also tried updating the weights __by an equal amount__ in the direction of their gradients. In other words, I took the sign of every parameter's gradient and updated the parameter by a small value (`0.001`) like so:
17 | `param.grad.data = torch.sign(param.grad.data) * 0.001`
18 | - I got decent results but didn't dig deeper into it. The weights for this model are `weights/autoquantize.ckpt`.
19 | 
20 | ### Notes:
21 | - The full precision model gives an accuracy of 98.8%
22 | - The quantized model gives an accuracy as high as 98.52%
23 | - I slightly changed the way gradients are calculated: using `mean` instead of `sum` in lines 15 and 16 of `quantification.py` gave better results:
24 | ```python
25 | w_p_grad = (a * grad_data).mean() # not (a * grad_data).sum()
26 | w_n_grad = (b * grad_data).mean() # not (b * grad_data).sum()
27 | ```
28 | 
--------------------------------------------------------------------------------
/TernaryQuantization.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinsis/ternary-quantization/2a8349c3773e89735d46bdcbeb44f5a62278b239/TernaryQuantization.pdf
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision
4 | from torchvision import transforms
5 | 
6 | batch_size = 100
7 | mnist_folder = os.path.join(os.path.dirname(__file__), 'mnist')
8 | train_dataset = torchvision.datasets.MNIST(root=mnist_folder,
9 |                                            train=True,
10 |                                            transform=transforms.ToTensor(),
11 |                                            download=True)
12 | 
13 | test_dataset = torchvision.datasets.MNIST(root=mnist_folder,
14 |                                           train=False,
15 |                                           transform=transforms.ToTensor())
16 | 
17 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
18 |                                            batch_size=batch_size,
19 |                                            shuffle=True)
20 | 
21 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
22 |                                           batch_size=batch_size,
23 |                                           shuffle=False)
24 | 
--------------------------------------------------------------------------------
/logs/quantized_wp_wn_trainable.txt:
--------------------------------------------------------------------------------
1 | === Epoch 0 ===
2 | Iteration 10, loss: 0.08856464177370071
3 | Accuracy on 10000 images: 88.29 %
4 | Iteration 20, loss: 1.9434486627578735
5 | 
Accuracy on 10000 images: 91.96 % 6 | Iteration 30, loss: 0.0001876211172202602 7 | Accuracy on 10000 images: 94.77 % 8 | Iteration 40, loss: 1.0599431991577148 9 | Accuracy on 10000 images: 95.9 % 10 | Iteration 50, loss: 0.9612407088279724 11 | Accuracy on 10000 images: 96.45 % 12 | Iteration 60, loss: 0.37571418285369873 13 | Saving model now! 14 | Accuracy on 10000 images: 96.55 % 15 | Iteration 70, loss: 0.9669553637504578 16 | Saving model now! 17 | Accuracy on 10000 images: 96.64 % 18 | Iteration 80, loss: 1.6495921611785889 19 | Accuracy on 10000 images: 87.27 % 20 | Iteration 90, loss: 1.9698643684387207 21 | Accuracy on 10000 images: 94.48 % 22 | Iteration 100, loss: 0.23592466115951538 23 | Accuracy on 10000 images: 95.45 % 24 | Iteration 110, loss: 0.47355014085769653 25 | Accuracy on 10000 images: 95.66 % 26 | Iteration 120, loss: 1.3091685771942139 27 | Accuracy on 10000 images: 95.69 % 28 | Iteration 130, loss: 1.2354906797409058 29 | Accuracy on 10000 images: 92.67 % 30 | Iteration 140, loss: 0.6938989758491516 31 | Accuracy on 10000 images: 94.69 % 32 | Iteration 150, loss: 2.0067062377929688 33 | Accuracy on 10000 images: 95.67 % 34 | Iteration 160, loss: 1.4914294481277466 35 | Accuracy on 10000 images: 95.98 % 36 | Iteration 170, loss: 0.2864416837692261 37 | Accuracy on 10000 images: 96.15 % 38 | Iteration 180, loss: 1.453955888748169 39 | Accuracy on 10000 images: 96.24 % 40 | Iteration 190, loss: 1.3282577991485596 41 | Accuracy on 10000 images: 96.13 % 42 | Iteration 200, loss: 0.16358226537704468 43 | Accuracy on 10000 images: 96.04 % 44 | Iteration 210, loss: 1.0136550664901733 45 | Accuracy on 10000 images: 96.15 % 46 | Iteration 220, loss: 0.7128799557685852 47 | Accuracy on 10000 images: 96.16 % 48 | Iteration 230, loss: 0.751685619354248 49 | Accuracy on 10000 images: 96.22 % 50 | Iteration 240, loss: 0.6268892884254456 51 | Accuracy on 10000 images: 96.31 % 52 | Iteration 250, loss: 0.2759740352630615 53 | Accuracy on 10000 images: 96.29 % 54 | Iteration 260, loss: 0.3361165523529053 55 | Accuracy on 10000 images: 96.45 % 56 | Iteration 270, loss: 1.03572416305542 57 | Saving model now! 58 | Accuracy on 10000 images: 96.71 % 59 | Iteration 280, loss: 0.6482753157615662 60 | Saving model now! 61 | Accuracy on 10000 images: 96.7 % 62 | Iteration 290, loss: 0.20610004663467407 63 | Saving model now! 64 | Accuracy on 10000 images: 96.64 % 65 | Iteration 300, loss: 0.13605904579162598 66 | Saving model now! 67 | Accuracy on 10000 images: 96.69 % 68 | Iteration 310, loss: 0.27332818508148193 69 | Saving model now! 70 | Accuracy on 10000 images: 96.8 % 71 | Iteration 320, loss: 0.3908751606941223 72 | Saving model now! 73 | Accuracy on 10000 images: 96.82 % 74 | Iteration 330, loss: 0.6515353918075562 75 | Saving model now! 76 | Accuracy on 10000 images: 97.01 % 77 | Iteration 340, loss: 1.0389719009399414 78 | Saving model now! 79 | Accuracy on 10000 images: 97.03 % 80 | Iteration 350, loss: 6.0358048358466476e-05 81 | Saving model now! 82 | Accuracy on 10000 images: 96.91 % 83 | Iteration 360, loss: 0.32066208124160767 84 | Saving model now! 85 | Accuracy on 10000 images: 96.85 % 86 | Iteration 370, loss: 0.023409880697727203 87 | Saving model now! 88 | Accuracy on 10000 images: 96.83 % 89 | Iteration 380, loss: 0.04021806642413139 90 | Saving model now! 91 | Accuracy on 10000 images: 97.1 % 92 | Iteration 390, loss: 0.1939018815755844 93 | Saving model now! 94 | Accuracy on 10000 images: 97.11 % 95 | Iteration 400, loss: 0.7400006651878357 96 | Saving model now! 
97 | Accuracy on 10000 images: 97.14 % 98 | Iteration 410, loss: 0.08159051835536957 99 | Saving model now! 100 | Accuracy on 10000 images: 97.29 % 101 | Iteration 420, loss: 0.47149860858917236 102 | Saving model now! 103 | Accuracy on 10000 images: 97.38 % 104 | Iteration 430, loss: 0.17372164130210876 105 | Saving model now! 106 | Accuracy on 10000 images: 97.49 % 107 | Iteration 440, loss: 0.4086321294307709 108 | Saving model now! 109 | Accuracy on 10000 images: 97.23 % 110 | Iteration 450, loss: 0.10978679358959198 111 | Saving model now! 112 | Accuracy on 10000 images: 97.1 % 113 | Iteration 460, loss: 0.04596172645688057 114 | Saving model now! 115 | Accuracy on 10000 images: 97.05 % 116 | Iteration 470, loss: 0.4255121648311615 117 | Saving model now! 118 | Accuracy on 10000 images: 97.41 % 119 | Iteration 480, loss: 0.7094628214836121 120 | Saving model now! 121 | Accuracy on 10000 images: 97.49 % 122 | Iteration 490, loss: 1.1689056158065796 123 | Saving model now! 124 | Accuracy on 10000 images: 97.45 % 125 | Iteration 500, loss: 0.06582614779472351 126 | Saving model now! 127 | Accuracy on 10000 images: 97.39 % 128 | Iteration 510, loss: 0.41349831223487854 129 | Saving model now! 130 | Accuracy on 10000 images: 97.56 % 131 | Iteration 520, loss: 0.5745459794998169 132 | Saving model now! 133 | Accuracy on 10000 images: 97.48 % 134 | Iteration 530, loss: 0.6338735818862915 135 | Saving model now! 136 | Accuracy on 10000 images: 97.4 % 137 | Iteration 540, loss: 0.1614023596048355 138 | Saving model now! 139 | Accuracy on 10000 images: 97.49 % 140 | Iteration 550, loss: 0.4620133340358734 141 | Saving model now! 142 | Accuracy on 10000 images: 97.51 % 143 | Iteration 560, loss: 0.17199371755123138 144 | Saving model now! 145 | Accuracy on 10000 images: 97.54 % 146 | Iteration 570, loss: 0.3327215313911438 147 | Saving model now! 148 | Accuracy on 10000 images: 97.63 % 149 | Iteration 580, loss: 0.5757023692131042 150 | Saving model now! 151 | Accuracy on 10000 images: 97.61 % 152 | Iteration 590, loss: 1.558800458908081 153 | Saving model now! 154 | Accuracy on 10000 images: 97.59 % 155 | Iteration 600, loss: 0.40290412306785583 156 | Saving model now! 157 | Accuracy on 10000 images: 97.6 % 158 | === Epoch 1 === 159 | Iteration 10, loss: 0.22842007875442505 160 | Saving model now! 161 | Accuracy on 10000 images: 97.64 % 162 | Iteration 20, loss: 0.23299770057201385 163 | Saving model now! 164 | Accuracy on 10000 images: 97.57 % 165 | Iteration 30, loss: 0.5452508330345154 166 | Saving model now! 167 | Accuracy on 10000 images: 97.71 % 168 | Iteration 40, loss: 0.15451961755752563 169 | Saving model now! 170 | Accuracy on 10000 images: 97.66 % 171 | Iteration 50, loss: 0.38366585969924927 172 | Saving model now! 173 | Accuracy on 10000 images: 97.66 % 174 | Iteration 60, loss: 0.7998814582824707 175 | Saving model now! 176 | Accuracy on 10000 images: 97.62 % 177 | Iteration 70, loss: 0.24839545786380768 178 | Saving model now! 179 | Accuracy on 10000 images: 97.65 % 180 | Iteration 80, loss: 0.4398971199989319 181 | Saving model now! 182 | Accuracy on 10000 images: 97.55 % 183 | Iteration 90, loss: 0.15301388502120972 184 | Saving model now! 185 | Accuracy on 10000 images: 97.55 % 186 | Iteration 100, loss: 0.14006121456623077 187 | Saving model now! 188 | Accuracy on 10000 images: 97.55 % 189 | Iteration 110, loss: 0.18368420004844666 190 | Saving model now! 191 | Accuracy on 10000 images: 97.6 % 192 | Iteration 120, loss: 0.06166192516684532 193 | Saving model now! 
194 | Accuracy on 10000 images: 97.57 % 195 | Iteration 130, loss: 0.36177727580070496 196 | Saving model now! 197 | Accuracy on 10000 images: 97.56 % 198 | Iteration 140, loss: 0.2763940691947937 199 | Saving model now! 200 | Accuracy on 10000 images: 97.62 % 201 | Iteration 150, loss: 0.021834805607795715 202 | Saving model now! 203 | Accuracy on 10000 images: 97.53 % 204 | Iteration 160, loss: 0.5558332204818726 205 | Saving model now! 206 | Accuracy on 10000 images: 97.68 % 207 | Iteration 170, loss: 0.09090041369199753 208 | Saving model now! 209 | Accuracy on 10000 images: 97.81 % 210 | Iteration 180, loss: 0.7641462683677673 211 | Saving model now! 212 | Accuracy on 10000 images: 97.73 % 213 | Iteration 190, loss: 0.013445436023175716 214 | Saving model now! 215 | Accuracy on 10000 images: 97.8 % 216 | Iteration 200, loss: 0.00017448679136577994 217 | Saving model now! 218 | Accuracy on 10000 images: 97.82 % 219 | Iteration 210, loss: 0.09267725050449371 220 | Saving model now! 221 | Accuracy on 10000 images: 97.81 % 222 | Iteration 220, loss: 1.257249116897583 223 | Saving model now! 224 | Accuracy on 10000 images: 97.79 % 225 | Iteration 230, loss: 5.364418029785156e-06 226 | Saving model now! 227 | Accuracy on 10000 images: 97.75 % 228 | Iteration 240, loss: 0.033226992934942245 229 | Saving model now! 230 | Accuracy on 10000 images: 97.77 % 231 | Iteration 250, loss: 0.46451255679130554 232 | Saving model now! 233 | Accuracy on 10000 images: 97.7 % 234 | Iteration 260, loss: 0.31817829608917236 235 | Saving model now! 236 | Accuracy on 10000 images: 97.57 % 237 | Iteration 270, loss: 0.002036280697211623 238 | Saving model now! 239 | Accuracy on 10000 images: 97.58 % 240 | Iteration 280, loss: 0.19252490997314453 241 | Saving model now! 242 | Accuracy on 10000 images: 97.64 % 243 | Iteration 290, loss: 1.198459506034851 244 | Saving model now! 245 | Accuracy on 10000 images: 97.59 % 246 | Iteration 300, loss: 0.11013296991586685 247 | Saving model now! 248 | Accuracy on 10000 images: 97.62 % 249 | Iteration 310, loss: 1.0312072038650513 250 | Saving model now! 251 | Accuracy on 10000 images: 97.51 % 252 | Iteration 320, loss: 0.1421511471271515 253 | Saving model now! 254 | Accuracy on 10000 images: 97.49 % 255 | Iteration 330, loss: 0.31679967045783997 256 | Saving model now! 257 | Accuracy on 10000 images: 97.5 % 258 | Iteration 340, loss: 0.5681060552597046 259 | Saving model now! 260 | Accuracy on 10000 images: 97.53 % 261 | Iteration 350, loss: 8.335709571838379e-05 262 | Saving model now! 263 | Accuracy on 10000 images: 97.51 % 264 | Iteration 360, loss: 0.0031784249003976583 265 | Saving model now! 266 | Accuracy on 10000 images: 97.43 % 267 | Iteration 370, loss: 0.30954718589782715 268 | Saving model now! 269 | Accuracy on 10000 images: 97.54 % 270 | Iteration 380, loss: 1.0090899467468262 271 | Saving model now! 272 | Accuracy on 10000 images: 97.71 % 273 | Iteration 390, loss: 0.6651298403739929 274 | Saving model now! 275 | Accuracy on 10000 images: 97.64 % 276 | Iteration 400, loss: 0.0010495162568986416 277 | Saving model now! 278 | Accuracy on 10000 images: 97.67 % 279 | Iteration 410, loss: 0.05260579288005829 280 | Saving model now! 281 | Accuracy on 10000 images: 97.64 % 282 | Iteration 420, loss: 0.29221874475479126 283 | Saving model now! 284 | Accuracy on 10000 images: 97.69 % 285 | Iteration 430, loss: 0.25449228286743164 286 | Saving model now! 287 | Accuracy on 10000 images: 97.75 % 288 | Iteration 440, loss: 0.2852615416049957 289 | Saving model now! 
290 | Accuracy on 10000 images: 97.75 % 291 | Iteration 450, loss: 0.19199492037296295 292 | Saving model now! 293 | Accuracy on 10000 images: 97.73 % 294 | Iteration 460, loss: 0.33736559748649597 295 | Saving model now! 296 | Accuracy on 10000 images: 97.69 % 297 | Iteration 470, loss: 0.16470222175121307 298 | Saving model now! 299 | Accuracy on 10000 images: 97.7 % 300 | Iteration 480, loss: 0.00966645497828722 301 | Saving model now! 302 | Accuracy on 10000 images: 97.68 % 303 | Iteration 490, loss: 0.04073469340801239 304 | Saving model now! 305 | Accuracy on 10000 images: 97.77 % 306 | Iteration 500, loss: 0.5707113742828369 307 | Saving model now! 308 | Accuracy on 10000 images: 97.8 % 309 | Iteration 510, loss: 0.3132103681564331 310 | Saving model now! 311 | Accuracy on 10000 images: 97.78 % 312 | Iteration 520, loss: 0.3855304718017578 313 | Saving model now! 314 | Accuracy on 10000 images: 97.72 % 315 | Iteration 530, loss: 0.24326926469802856 316 | Saving model now! 317 | Accuracy on 10000 images: 97.86 % 318 | Iteration 540, loss: 0.22886115312576294 319 | Saving model now! 320 | Accuracy on 10000 images: 97.78 % 321 | Iteration 550, loss: 0.428986132144928 322 | Saving model now! 323 | Accuracy on 10000 images: 97.66 % 324 | Iteration 560, loss: 0.5072702169418335 325 | Saving model now! 326 | Accuracy on 10000 images: 97.65 % 327 | Iteration 570, loss: 0.25695374608039856 328 | Saving model now! 329 | Accuracy on 10000 images: 97.75 % 330 | Iteration 580, loss: 0.15769313275814056 331 | Saving model now! 332 | Accuracy on 10000 images: 97.83 % 333 | Iteration 590, loss: 0.0010339879663661122 334 | Saving model now! 335 | Accuracy on 10000 images: 97.83 % 336 | Iteration 600, loss: 0.20295009016990662 337 | Saving model now! 338 | Accuracy on 10000 images: 97.87 % 339 | 340 | -------------------------------------------------------------------------------- /logs/quantized_wp_wn_trainable_v2.txt: -------------------------------------------------------------------------------- 1 | === Epoch 0 === 2 | Iteration 10, loss: 0.10921064019203186 3 | Accuracy on 10000 images: 91.61 % 4 | Iteration 20, loss: 0.2128966897726059 5 | Accuracy on 10000 images: 95.63 % 6 | Iteration 30, loss: 0.6946740746498108 7 | Accuracy on 10000 images: 97.07 % 8 | Iteration 40, loss: 0.0 9 | Accuracy on 10000 images: 98.04 % 10 | Iteration 50, loss: 0.4621486961841583 11 | Accuracy on 10000 images: 98.35 % 12 | Iteration 60, loss: 0.3064645230770111 13 | Accuracy on 10000 images: 98.19 % 14 | Iteration 70, loss: 0.13775737583637238 15 | Accuracy on 10000 images: 97.8 % 16 | Iteration 80, loss: 0.5893713235855103 17 | Accuracy on 10000 images: 97.97 % 18 | Iteration 90, loss: 0.12024127691984177 19 | Accuracy on 10000 images: 98.27 % 20 | Iteration 100, loss: 0.09390382468700409 21 | Saving model now! 22 | Accuracy on 10000 images: 98.52 % 23 | Iteration 110, loss: 0.0 24 | Saving model now! 25 | Accuracy on 10000 images: 98.52 % 26 | Iteration 120, loss: 0.32539936900138855 27 | Accuracy on 10000 images: 98.04 % 28 | Iteration 130, loss: 0.23605237901210785 29 | Saving model now! 
30 | Accuracy on 10000 images: 98.41 %
31 | Iteration 140, loss: 0.09215492010116577
32 | Accuracy on 10000 images: 98.28 %
33 | Iteration 150, loss: 0.5238853693008423
34 | Accuracy on 10000 images: 97.96 %
35 | Iteration 160, loss: 0.10518604516983032
36 | Accuracy on 10000 images: 98.11 %
37 | Iteration 170, loss: 0.2523190677165985
38 | Accuracy on 10000 images: 97.79 %
39 | Iteration 180, loss: 0.00587494857609272
40 | Accuracy on 10000 images: 98.14 %
41 | Iteration 190, loss: 0.05473804473876953
42 | Accuracy on 10000 images: 98.12 %
43 | Iteration 200, loss: 0.04333222284913063
44 | Accuracy on 10000 images: 98.28 %
45 | Iteration 210, loss: 0.42413681745529175
46 | Accuracy on 10000 images: 98.02 %
47 | Iteration 220, loss: 0.030421152710914612
48 | Accuracy on 10000 images: 97.88 %
49 | Iteration 230, loss: 0.2965756356716156
50 | Accuracy on 10000 images: 96.5 %
51 | Iteration 240, loss: 0.0006092119147069752
52 | Accuracy on 10000 images: 98.23 %
53 | Iteration 250, loss: 0.03490591421723366
54 | Saving model now!
55 | Accuracy on 10000 images: 98.43 %
56 | 
--------------------------------------------------------------------------------
/main_autoquantize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | 
5 | from model import model_auto, device
6 | from data import train_loader, test_loader
7 | 
8 | criterion = nn.CrossEntropyLoss()
9 | optimizer = torch.optim.SGD(model_auto.parameters(), lr=1.0)
10 | num_epochs = 2
11 | 
12 | def quantize_params(model = model_auto):
13 |     for n,p in model.named_parameters():
14 |         p.data = torch.sign(p.data) * 0.01
15 | 
16 | # def update_weights(model = model_auto):
17 | #     for n,p in model.named_parameters():
18 | #         p.data = p.grad.data * 0.1
19 | 
20 | def train(model = model_auto):
21 |     total_step = len(train_loader)
22 |     for epoch in range(num_epochs):
23 |         for i, (images, labels) in enumerate(train_loader):
24 |             images = images.to(device)
25 |             labels = labels.to(device)
26 | 
27 |             # Forward pass
28 |             outputs = model(images)
29 |             loss = criterion(outputs, labels)
30 | 
31 |             # Backward and optimize
32 |             optimizer.zero_grad()
33 |             loss.backward()
34 |             for param in optimizer.param_groups[0]['params']:
35 |                 param.grad.data = torch.sign(param.grad.data) * 0.001
36 |             optimizer.step()
37 | 
38 |             if (i+1) % 10 == 0:
39 |                 print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
40 |                        .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
41 |             if (i+1) % 100 == 0:
42 |                 test()
43 |                 model.train()  # test() switches the model to eval mode; switch back before training continues
44 |         test()
45 |         model.train()  # switch back again after the end-of-epoch evaluation
46 | 
47 | def test(model = model_auto):
48 |     model.eval()
49 |     with torch.no_grad():
50 |         correct = 0
51 |         total = 0
52 |         for images, labels in test_loader:
53 |             images = images.to(device)
54 |             labels = labels.to(device)
55 |             outputs = model(images)
56 |             _, predicted = torch.max(outputs.data, 1)
57 |             total += labels.size(0)
58 |             correct += (predicted == labels).sum().item()
59 | 
60 |         print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
61 | 
62 | def save_model(model = model_auto):
63 |     dirname = os.path.dirname(__file__)
64 |     dirname = os.path.join(dirname, 'weights')
65 |     weightname = os.path.join(dirname, '{}.ckpt'.format(model.name))
66 |     torch.save(model.state_dict(), weightname)
67 | 
68 | if __name__ == '__main__':
69 |     quantize_params()
70 |     train()
71 |     # test()
72 |     # save_model()
73 | 
--------------------------------------------------------------------------------
/main_original.py:
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | from model import model_full, device 6 | from data import train_loader, test_loader 7 | 8 | criterion = nn.CrossEntropyLoss() 9 | optimizer = torch.optim.Adam(model_full.parameters(), lr=0.001) 10 | num_epochs = 2 11 | 12 | def train(model = model_full): 13 | total_step = len(train_loader) 14 | for epoch in range(num_epochs): 15 | for i, (images, labels) in enumerate(train_loader): 16 | images = images.to(device) 17 | labels = labels.to(device) 18 | 19 | # Forward pass 20 | outputs = model(images) 21 | loss = criterion(outputs, labels) 22 | 23 | # Backward and optimize 24 | optimizer.zero_grad() 25 | loss.backward() 26 | optimizer.step() 27 | 28 | if (i+1) % 100 == 0: 29 | print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 30 | .format(epoch+1, num_epochs, i+1, total_step, loss.item())) 31 | 32 | def test(model = model_full): 33 | model.eval() 34 | with torch.no_grad(): 35 | correct = 0 36 | total = 0 37 | for images, labels in test_loader: 38 | images = images.to(device) 39 | labels = labels.to(device) 40 | outputs = model(images) 41 | _, predicted = torch.max(outputs.data, 1) 42 | total += labels.size(0) 43 | correct += (predicted == labels).sum().item() 44 | 45 | print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) 46 | 47 | def save_model(model = model_full): 48 | dirname = os.path.dirname(__file__) 49 | dirname = os.path.join(dirname, 'weights') 50 | weightname = os.path.join(dirname, '{}.ckpt'.format(model.name)) 51 | torch.save(model.state_dict(), weightname) 52 | 53 | if __name__ == '__main__': 54 | train() 55 | test() 56 | save_model() 57 | -------------------------------------------------------------------------------- /main_ternary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | from model import model_to_quantify, device 6 | from data import train_loader, test_loader 7 | from quantification import quantize, get_quantization_grads 8 | 9 | criterion = nn.CrossEntropyLoss() 10 | num_epochs = 2 11 | 12 | # load model with full precision trained weights 13 | dirname = os.path.dirname(__file__) 14 | dirname = os.path.join(dirname, 'weights') 15 | weightname = os.path.join(dirname, '{}.ckpt'.format('original')) 16 | model_to_quantify.load_state_dict(torch.load(weightname, map_location='cpu')) 17 | 18 | # create a list of parameters that need to be quantized 19 | ''' 20 | Model parameter names and parameter sizes: 21 | [('layer1.0.weight', torch.Size([16, 1, 5, 5])), 22 | ('layer1.0.bias', torch.Size([16])), 23 | ('layer1.1.weight', torch.Size([16])), 24 | ('layer1.1.bias', torch.Size([16])), 25 | ('layer2.0.weight', torch.Size([32, 16, 5, 5])), 26 | ('layer2.0.bias', torch.Size([32])), 27 | ('layer2.1.weight', torch.Size([32])), 28 | ('layer2.1.bias', torch.Size([32])), 29 | ('fc.weight', torch.Size([10, 1568])), 30 | ('fc.bias', torch.Size([10]))] 31 | 32 | layer1.1.* and layer2.1.* correspond to batch normalization layers. 33 | We do not quantize BN layers for now. 
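The quantized parameters follow the ternary rule implemented in quantification.py: with a
per-tensor threshold delta = THRESHOLD * max|w| (THRESHOLD = 0.15), weights above +delta map
to w_p, weights below -delta map to -w_n, and everything in between maps to 0. The scaling
factors w_p and w_n are trained along with the full precision copies.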
34 | '''
35 | 
36 | bn_weights = [ param for name,param in model_to_quantify.named_parameters() if '.1' in name]
37 | weights_to_be_quantized = [ param for name,param in model_to_quantify.named_parameters() if not '.1' in name]
38 | 
39 | # store a full precision copy of parameters that need to be quantized
40 | full_precision_copies = [ param.data.clone().to(device).requires_grad_() for param in weights_to_be_quantized ]  # .to(device) before requires_grad_() so each copy is a leaf tensor
41 | 
42 | # for each parameter to be quantized, create a trainable tensor of scaling factors (w_p and w_n)
43 | # scaling_factors = torch.ones(len(weights_to_be_quantized), 2, requires_grad=True).to(device)
44 | scaling_factors = [torch.ones(2, device=device, requires_grad=True) for _ in range(len(weights_to_be_quantized))]  # created directly on the device so each tensor is a leaf
45 | 
46 | # create optimizers for different parameter groups
47 | 
48 | # optimizer for the network's parameters: the quantized weights and the batch norm weights
49 | optimizer_main = torch.optim.Adam(
50 |     [{'params': bn_weights}, {'params': weights_to_be_quantized}],
51 |     lr=0.001
52 | )
53 | # optimizers for the full precision copies and the scaling factors
54 | optimizer_full_precision_weights = torch.optim.Adam(full_precision_copies, lr=0.001)
55 | optimizer_scaling_factors = torch.optim.Adam(scaling_factors, lr=0.001)
56 | 
57 | def train():
58 |     total_step = len(train_loader)
59 |     for i, (images, labels) in enumerate(train_loader):
60 |         # quantize weights from full precision weights
61 |         for index, weight in enumerate(weights_to_be_quantized):
62 |             w_p, w_n = scaling_factors[index]
63 |             weight.data = quantize(full_precision_copies[index].data, w_p, w_n)
64 |         # forward pass
65 |         images = images.to(device)
66 |         labels = labels.to(device)
67 | 
68 |         outputs = model_to_quantify(images)
69 |         loss = criterion(outputs, labels)
70 | 
71 |         # backward pass - calculate gradients
72 |         optimizer_main.zero_grad()
73 |         optimizer_full_precision_weights.zero_grad()
74 |         optimizer_scaling_factors.zero_grad()
75 |         loss.backward()
76 | 
77 |         for index, weight in enumerate(weights_to_be_quantized):
78 |             w_p, w_n = scaling_factors[index]
79 |             full_precision_data = full_precision_copies[index].data
80 |             full_precision_grad, w_p_grad, w_n_grad = get_quantization_grads(weight.grad.data, full_precision_data, w_p.item(), w_n.item())
81 |             full_precision_copies[index].grad = full_precision_grad.to(device)
82 |             scaling_factors[index].grad = torch.FloatTensor([w_p_grad, w_n_grad]).to(device)
83 |             weight.grad.data.zero_()
84 | 
85 |         if (i+1) % 10 == 0:
86 |             print('Iteration {}, loss: {}'.format(i+1, loss.item()))
87 |             test()
88 |             model_to_quantify.train()  # test() puts the model in eval mode; switch back before the next update
89 |         optimizer_main.step()
90 |         optimizer_full_precision_weights.step()
91 |         optimizer_scaling_factors.step()
92 | 
93 | 
94 | def test():
95 |     model_to_quantify.eval()
96 |     with torch.no_grad():
97 |         correct = 0
98 |         total = 0
99 |         for images, labels in test_loader:
100 |             images = images.to(device)
101 |             labels = labels.to(device)
102 |             outputs = model_to_quantify(images)
103 |             _, predicted = torch.max(outputs.data, 1)
104 |             total += labels.size(0)
105 |             correct += (predicted == labels).sum().item()
106 |         accuracy = 100 * correct / total
107 |         if accuracy >= 98.4:
108 |             print('Saving model now!')
109 |             save_model()
110 |         print('\tAccuracy on 10000 images: {} %'.format(accuracy))
111 | 
112 | def save_model(model = model_to_quantify):
113 |     dirname = os.path.dirname(__file__)
114 |     dirname = os.path.join(dirname, 'weights')
115 |     weightname = os.path.join(dirname, '{}.ckpt'.format(model.name))
116 |     torch.save(model.state_dict(), weightname)
117 | 
118 | if __name__ == '__main__':
119 |     assert full_precision_copies[0].requires_grad is True
120 |     assert len(weights_to_be_quantized) == len(scaling_factors)
121 |     assert len(weights_to_be_quantized) == len(full_precision_copies)
122 |     for epoch in range(num_epochs):
123 |         print('=== Epoch {} ==='.format(epoch))
124 |         train()
125 |     print(scaling_factors)
126 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
5 | 
6 | class ConvNet(nn.Module):
7 |     def __init__(self, name, num_classes=10):
8 |         super(ConvNet, self).__init__()
9 |         self.name = name
10 |         self.layer1 = nn.Sequential(
11 |             nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
12 |             nn.BatchNorm2d(16),
13 |             nn.ReLU(),
14 |             nn.MaxPool2d(kernel_size=2, stride=2))
15 |         self.layer2 = nn.Sequential(
16 |             nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
17 |             nn.BatchNorm2d(32),
18 |             nn.ReLU(),
19 |             nn.MaxPool2d(kernel_size=2, stride=2))
20 |         self.fc = nn.Linear(7*7*32, num_classes)
21 | 
22 |     def forward(self, x):
23 |         out = self.layer1(x)
24 |         out = self.layer2(out)
25 |         out = out.reshape(out.size(0), -1)
26 |         out = self.fc(out)
27 |         return out
28 | 
29 | model_full = ConvNet(name='original').to(device)
30 | model_to_quantify = ConvNet(name='quantized').to(device)
31 | 
32 | class AutoQuantizedNet(nn.Module):
33 |     def __init__(self, name, num_classes=10):
34 |         super(AutoQuantizedNet, self).__init__()
35 |         self.name = name
36 |         self.relu = nn.ReLU()
37 |         self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
38 |         self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2)
39 |         self.bn1 = nn.BatchNorm2d(16)
40 |         self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2)
41 |         self.bn2 = nn.BatchNorm2d(32)
42 |         self.fc = nn.Linear(7*7*32, num_classes)
43 | 
44 |     def forward(self, x):
45 |         out = self.relu(self.bn1(self.conv1(x)))
46 |         out = self.maxpool(out)
47 |         out = self.relu(self.bn2(self.conv2(out)))
48 |         out = self.maxpool(out)
49 |         out = out.reshape(out.size(0), -1)
50 |         out = self.fc(out)
51 |         return out
52 | 
53 | model_auto = AutoQuantizedNet(name='autoquantize').to(device)
54 | 
--------------------------------------------------------------------------------
/quantification.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | THRESHOLD = 0.15
4 | def quantize(tensor_data, w_p, w_n, threshold=THRESHOLD):
5 |     delta = tensor_data.abs().max() * threshold
6 |     return (tensor_data > delta).float() * w_p + (tensor_data < -delta).float() * -w_n
7 | 
8 | def get_quantization_grads(grad_data, full_precision_data, w_p_data, w_n_data, threshold=THRESHOLD):
9 |     delta = full_precision_data.abs().max() * threshold
10 |     a = (full_precision_data > delta).float()
11 |     b = (full_precision_data < -delta).float()
12 |     c = torch.ones_like(full_precision_data) - a - b
13 | 
14 |     full_precision_grad = a * grad_data * w_p_data + b * grad_data * w_n_data + c * grad_data * 1
15 |     w_p_grad = (a * grad_data).mean()
16 |     w_n_grad = (b * grad_data).mean()
17 |     return full_precision_grad, w_p_grad, w_n_grad
18 | 
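19 | 
20 | # A minimal usage sketch (an added illustration; not used by the training scripts).
21 | # quantize() keeps only the coarse sign pattern of a tensor: with delta =
22 | # threshold * max|w|, entries above +delta become w_p, entries below -delta
23 | # become -w_n, and everything else becomes 0. get_quantization_grads() routes
24 | # the incoming gradient: positive-region entries accumulate into w_p's gradient,
25 | # negative-region entries into w_n's gradient.
26 | if __name__ == '__main__':
27 |     w = torch.tensor([0.9, 0.05, -0.4, -0.02, 0.2])
28 |     print(quantize(w, w_p=0.8, w_n=0.6))  # delta = 0.135 -> tensor([ 0.8000,  0.0000, -0.6000,  0.0000,  0.8000])
29 |     _, w_p_grad, w_n_grad = get_quantization_grads(torch.ones_like(w), w, w_p_data=0.8, w_n_data=0.6)
30 |     print(w_p_grad, w_n_grad)  # tensor(0.4000) tensor(0.2000): two entries lie above +delta, one below -delta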
--------------------------------------------------------------------------------
/weights/autoquantize.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinsis/ternary-quantization/2a8349c3773e89735d46bdcbeb44f5a62278b239/weights/autoquantize.ckpt
--------------------------------------------------------------------------------
/weights/original.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinsis/ternary-quantization/2a8349c3773e89735d46bdcbeb44f5a62278b239/weights/original.ckpt
--------------------------------------------------------------------------------
/weights/quantized.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinsis/ternary-quantization/2a8349c3773e89735d46bdcbeb44f5a62278b239/weights/quantized.ckpt
--------------------------------------------------------------------------------