├── .gitignore ├── LICENSE ├── README.md ├── adabound ├── __init__.py └── adabound.py ├── demos ├── README.md └── cifar10 │ ├── .gitignore │ ├── README.md │ ├── curve │ └── pretrained │ │ ├── densenet-adabound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001 │ │ ├── densenet-adagrad-lr0.01 │ │ ├── densenet-adam-lr0.001-betas0.9-0.999 │ │ ├── densenet-amsbound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001 │ │ ├── densenet-amsgrad-lr0.001-betas0.9-0.999 │ │ ├── densenet-sgd-lr0.1-momentum0.9 │ │ ├── resnet-adabound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001 │ │ ├── resnet-adagrad-lr0.01 │ │ ├── resnet-adam-lr0.001-betas0.99-0.999 │ │ ├── resnet-amsbound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001 │ │ ├── resnet-amsgrad-lr0.001-betas0.99-0.999 │ │ └── resnet-sgd-lr0.1-momentum0.9 │ ├── main.py │ ├── models │ ├── __init__.py │ ├── densenet.py │ └── resnet.py │ ├── requirements.txt │ └── visualization.ipynb ├── release.sh └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### JetBrains template 3 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 4 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 5 | 6 | # We just ignore all .idea folder 7 | .idea 8 | 9 | # User-specific stuff: 10 | .idea/**/workspace.xml 11 | .idea/**/tasks.xml 12 | .idea/dictionaries 13 | 14 | # Sensitive or high-churn files: 15 | .idea/**/dataSources/ 16 | .idea/**/dataSources.ids 17 | .idea/**/dataSources.local.xml 18 | .idea/**/sqlDataSources.xml 19 | .idea/**/dynamic.xml 20 | .idea/**/uiDesigner.xml 21 | 22 | # Gradle: 23 | .idea/**/gradle.xml 24 | .idea/**/libraries 25 | 26 | # CMake 27 | cmake-build-debug/ 28 | cmake-build-release/ 29 | 30 | # Mongo Explorer plugin: 31 | .idea/**/mongoSettings.xml 32 | 33 | ## File-based project format: 34 | *.iws 35 | 36 | ## Plugin-specific files: 37 | 38 | # IntelliJ 39 | out/ 40 | 41 | # mpeltonen/sbt-idea plugin 42 | .idea_modules/ 43 | 44 | # JIRA plugin 45 | atlassian-ide-plugin.xml 46 | 47 | # Cursive Clojure plugin 48 | .idea/replstate.xml 49 | 50 | # Crashlytics plugin (for Android Studio and IntelliJ) 51 | com_crashlytics_export_strings.xml 52 | crashlytics.properties 53 | crashlytics-build.properties 54 | fabric.properties 55 | ### macOS template 56 | # General 57 | .DS_Store 58 | .AppleDouble 59 | .LSOverride 60 | 61 | # Icon must end with two \r 62 | Icon 63 | 64 | # Thumbnails 65 | ._* 66 | 67 | # Files that might appear in the root of a volume 68 | .DocumentRevisions-V100 69 | .fseventsd 70 | .Spotlight-V100 71 | .TemporaryItems 72 | .Trashes 73 | .VolumeIcon.icns 74 | .com.apple.timemachine.donotpresent 75 | 76 | # Directories potentially created on remote AFP share 77 | .AppleDB 78 | .AppleDesktop 79 | Network Trash Folder 80 | Temporary Items 81 | .apdisk 82 | ### Python template 83 | # Byte-compiled / optimized / DLL files 84 | __pycache__/ 85 | *.py[cod] 86 | *$py.class 87 | 88 | # C extensions 89 | *.so 90 | 91 | # Distribution / packaging 92 | .Python 93 | build/ 94 | develop-eggs/ 95 | dist/ 96 | downloads/ 97 | eggs/ 98 | .eggs/ 99 | lib/ 100 | lib64/ 101 | parts/ 102 | sdist/ 103 | var/ 104 | wheels/ 105 | *.egg-info/ 106 | .installed.cfg 107 | *.egg 108 | MANIFEST 109 | 110 | # PyInstaller 111 | # Usually these files are written by a python script from a template 112 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
113 | *.manifest 114 | *.spec 115 | 116 | # Installer logs 117 | pip-log.txt 118 | pip-delete-this-directory.txt 119 | 120 | # Unit test / coverage reports 121 | htmlcov/ 122 | .tox/ 123 | .coverage 124 | .coverage.* 125 | .cache 126 | nosetests.xml 127 | coverage.xml 128 | *.cover 129 | .hypothesis/ 130 | .pytest_cache/ 131 | 132 | # Translations 133 | *.mo 134 | *.pot 135 | 136 | # Django stuff: 137 | *.log 138 | local_settings.py 139 | db.sqlite3 140 | 141 | # Flask stuff: 142 | instance/ 143 | .webassets-cache 144 | 145 | # Scrapy stuff: 146 | .scrapy 147 | 148 | # Sphinx documentation 149 | docs/_build/ 150 | 151 | # PyBuilder 152 | target/ 153 | 154 | # Jupyter Notebook 155 | .ipynb_checkpoints 156 | 157 | # pyenv 158 | .python-version 159 | 160 | # celery beat schedule file 161 | celerybeat-schedule 162 | 163 | # SageMath parsed files 164 | *.sage.py 165 | 166 | # Environments 167 | .env 168 | .venv 169 | env/ 170 | venv/ 171 | ENV/ 172 | env.bak/ 173 | venv.bak/ 174 | 175 | # Spyder project settings 176 | .spyderproject 177 | .spyproject 178 | 179 | # Rope project settings 180 | .ropeproject 181 | 182 | # mkdocs documentation 183 | /site 184 | 185 | # mypy 186 | .mypy_cache/ 187 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaBound 2 | [![PyPI - Version](https://img.shields.io/pypi/v/adabound.svg?style=flat)](https://pypi.org/project/adabound/) 3 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/adabound.svg)](https://pypi.org/project/adabound/) 4 | [![PyPI - Wheel](https://img.shields.io/pypi/wheel/adabound.svg?style=flat)](https://pypi.org/project/adabound/) 5 | [![GitHub - LICENSE](https://img.shields.io/github/license/Luolc/AdaBound.svg?style=flat)](./LICENSE) 6 | 7 | An optimizer that trains as fast as Adam and generalizes as well as SGD, for developing state-of-the-art 8 | deep learning models on a wide variety of popular tasks in fields such as CV and NLP. 9 | 10 | Based on Luo et al. (2019). 11 | [Adaptive Gradient Methods with Dynamic Bound of Learning Rate](https://openreview.net/forum?id=Bkg3g2R9FX). 12 | In *Proc. of ICLR 2019*. 13 | 14 |

15 | 16 |

17 | 18 | ## Quick Links 19 | 20 | - [Website](https://www.luolc.com/publications/adabound/) 21 | - [Demos](./demos) 22 | 23 | ## Installation 24 | 25 | AdaBound requires Python 3.6.0 or later. 26 | We currently provide the PyTorch version; a TensorFlow version of AdaBound is coming soon. 27 | 28 | ### Installing via pip 29 | 30 | The preferred way to install AdaBound is via `pip` within a virtual environment. 31 | Just run 32 | ```bash 33 | pip install adabound 34 | ``` 35 | in your Python environment and you are ready to go! 36 | 37 | ### Using source code 38 | 39 | As AdaBound is a single Python class of a little over 100 lines, an alternative is to download 40 | [adabound.py](./adabound/adabound.py) directly and copy it into your project. 41 | 42 | ## Usage 43 | 44 | You can use AdaBound just like any other PyTorch optimizer. 45 | 46 | ```python 47 | optimizer = adabound.AdaBound(model.parameters(), lr=1e-3, final_lr=0.1) 48 | ``` 49 | 50 | As described in the paper, AdaBound is an optimizer that behaves like Adam at the beginning of 51 | training and gradually transforms into SGD at the end. 52 | The `final_lr` parameter is the learning rate of the SGD optimizer that AdaBound transforms into. 53 | In common cases, a default final learning rate of `0.1` achieves relatively good and stable 54 | results on unseen data. 55 | AdaBound is not very sensitive to its hyperparameters. 56 | See Appendix G of the paper for more details. 57 | 58 | Despite its robust performance, we still have to state that **there is no silver bullet**. 59 | Using AdaBound does not mean you will be free from tuning hyperparameters. 60 | The performance of a model depends on many things, including the task, the model structure, 61 | and the distribution of the data. 62 | **You still need to decide which hyperparameters to use based on your specific situation, 63 | but you will probably spend much less time on it than before!** 64 | 65 | ## Demos 66 | 67 | Thanks to the awesome work by the GitHub and Jupyter teams, Jupyter notebook (`.ipynb`) 68 | files render directly on GitHub. 69 | We provide several notebooks (like [this one](./demos/cifar10/visualization.ipynb)) for better 70 | visualization. 71 | We hope to illustrate the robust performance of AdaBound through these examples. 72 | 73 | For the full list of demos, please refer to [this page](./demos). 74 | 75 | ## Citing 76 | If you use AdaBound in your research, please cite [Adaptive Gradient Methods with Dynamic Bound of Learning Rate](https://openreview.net/forum?id=Bkg3g2R9FX).
77 | ```text 78 | @inproceedings{Luo2019AdaBound, 79 | author = {Luo, Liangchen and Xiong, Yuanhao and Liu, Yan and Sun, Xu}, 80 | title = {Adaptive Gradient Methods with Dynamic Bound of Learning Rate}, 81 | booktitle = {Proceedings of the 7th International Conference on Learning Representations}, 82 | month = {May}, 83 | year = {2019}, 84 | address = {New Orleans, Louisiana} 85 | } 86 | ``` 87 | 88 | ## Contributors 89 | 90 | [@kayuksel](https://github.com/kayuksel) 91 | 92 | ## License 93 | [Apache 2.0](./LICENSE) 94 | -------------------------------------------------------------------------------- /adabound/__init__.py: -------------------------------------------------------------------------------- 1 | from .adabound import AdaBound 2 | -------------------------------------------------------------------------------- /adabound/adabound.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim import Optimizer 4 | 5 | 6 | class AdaBound(Optimizer): 7 | """Implements AdaBound algorithm. 8 | It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_. 9 | Arguments: 10 | params (iterable): iterable of parameters to optimize or dicts defining 11 | parameter groups 12 | lr (float, optional): Adam learning rate (default: 1e-3) 13 | betas (Tuple[float, float], optional): coefficients used for computing 14 | running averages of gradient and its square (default: (0.9, 0.999)) 15 | final_lr (float, optional): final (SGD) learning rate (default: 0.1) 16 | gamma (float, optional): convergence speed of the bound functions (default: 1e-3) 17 | eps (float, optional): term added to the denominator to improve 18 | numerical stability (default: 1e-8) 19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 20 | amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm 21 | .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate: 22 | https://openreview.net/forum?id=Bkg3g2R9FX 23 | """ 24 | 25 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3, 26 | eps=1e-8, weight_decay=0, amsbound=False): 27 | if not 0.0 <= lr: 28 | raise ValueError("Invalid learning rate: {}".format(lr)) 29 | if not 0.0 <= eps: 30 | raise ValueError("Invalid epsilon value: {}".format(eps)) 31 | if not 0.0 <= betas[0] < 1.0: 32 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 33 | if not 0.0 <= betas[1] < 1.0: 34 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 35 | if not 0.0 <= final_lr: 36 | raise ValueError("Invalid final learning rate: {}".format(final_lr)) 37 | if not 0.0 <= gamma < 1.0: 38 | raise ValueError("Invalid gamma parameter: {}".format(gamma)) 39 | defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps, 40 | weight_decay=weight_decay, amsbound=amsbound) 41 | super(AdaBound, self).__init__(params, defaults) 42 | 43 | self.base_lrs = list(map(lambda group: group['lr'], self.param_groups)) 44 | 45 | def __setstate__(self, state): 46 | super(AdaBound, self).__setstate__(state) 47 | for group in self.param_groups: 48 | group.setdefault('amsbound', False) 49 | 50 | def step(self, closure=None): 51 | """Performs a single optimization step. 52 | Arguments: 53 | closure (callable, optional): A closure that reevaluates the model 54 | and returns the loss. 
55 | """ 56 | loss = None 57 | if closure is not None: 58 | loss = closure() 59 | 60 | for group, base_lr in zip(self.param_groups, self.base_lrs): 61 | for p in group['params']: 62 | if p.grad is None: 63 | continue 64 | grad = p.grad.data 65 | if grad.is_sparse: 66 | raise RuntimeError( 67 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 68 | amsbound = group['amsbound'] 69 | 70 | state = self.state[p] 71 | 72 | # State initialization 73 | if len(state) == 0: 74 | state['step'] = 0 75 | # Exponential moving average of gradient values 76 | state['exp_avg'] = torch.zeros_like(p.data) 77 | # Exponential moving average of squared gradient values 78 | state['exp_avg_sq'] = torch.zeros_like(p.data) 79 | if amsbound: 80 | # Maintains max of all exp. moving avg. of sq. grad. values 81 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 82 | 83 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 84 | if amsbound: 85 | max_exp_avg_sq = state['max_exp_avg_sq'] 86 | beta1, beta2 = group['betas'] 87 | 88 | state['step'] += 1 89 | 90 | if group['weight_decay'] != 0: 91 | grad = grad.add(group['weight_decay'], p.data) 92 | 93 | # Decay the first and second moment running average coefficient 94 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 95 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 96 | if amsbound: 97 | # Maintains the maximum of all 2nd moment running avg. till now 98 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 99 | # Use the max. for normalizing running avg. of gradient 100 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 101 | else: 102 | denom = exp_avg_sq.sqrt().add_(group['eps']) 103 | 104 | bias_correction1 = 1 - beta1 ** state['step'] 105 | bias_correction2 = 1 - beta2 ** state['step'] 106 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 107 | 108 | # Applies bounds on actual learning rate 109 | # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay 110 | final_lr = group['final_lr'] * group['lr'] / base_lr 111 | lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1)) 112 | upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step'])) 113 | step_size = torch.full_like(denom, step_size) 114 | step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg) 115 | 116 | p.data.add_(-step_size) 117 | 118 | return loss 119 | 120 | class AdaBoundW(Optimizer): 121 | """Implements AdaBound algorithm with Decoupled Weight Decay (arxiv.org/abs/1711.05101) 122 | It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_. 123 | Arguments: 124 | params (iterable): iterable of parameters to optimize or dicts defining 125 | parameter groups 126 | lr (float, optional): Adam learning rate (default: 1e-3) 127 | betas (Tuple[float, float], optional): coefficients used for computing 128 | running averages of gradient and its square (default: (0.9, 0.999)) 129 | final_lr (float, optional): final (SGD) learning rate (default: 0.1) 130 | gamma (float, optional): convergence speed of the bound functions (default: 1e-3) 131 | eps (float, optional): term added to the denominator to improve 132 | numerical stability (default: 1e-8) 133 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 134 | amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm 135 | .. 
Adaptive Gradient Methods with Dynamic Bound of Learning Rate: 136 | https://openreview.net/forum?id=Bkg3g2R9FX 137 | """ 138 | 139 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3, 140 | eps=1e-8, weight_decay=0, amsbound=False): 141 | if not 0.0 <= lr: 142 | raise ValueError("Invalid learning rate: {}".format(lr)) 143 | if not 0.0 <= eps: 144 | raise ValueError("Invalid epsilon value: {}".format(eps)) 145 | if not 0.0 <= betas[0] < 1.0: 146 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 147 | if not 0.0 <= betas[1] < 1.0: 148 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 149 | if not 0.0 <= final_lr: 150 | raise ValueError("Invalid final learning rate: {}".format(final_lr)) 151 | if not 0.0 <= gamma < 1.0: 152 | raise ValueError("Invalid gamma parameter: {}".format(gamma)) 153 | defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps, 154 | weight_decay=weight_decay, amsbound=amsbound) 155 | super(AdaBoundW, self).__init__(params, defaults) 156 | 157 | self.base_lrs = list(map(lambda group: group['lr'], self.param_groups)) 158 | 159 | def __setstate__(self, state): 160 | super(AdaBoundW, self).__setstate__(state) 161 | for group in self.param_groups: 162 | group.setdefault('amsbound', False) 163 | 164 | def step(self, closure=None): 165 | """Performs a single optimization step. 166 | Arguments: 167 | closure (callable, optional): A closure that reevaluates the model 168 | and returns the loss. 169 | """ 170 | loss = None 171 | if closure is not None: 172 | loss = closure() 173 | 174 | for group, base_lr in zip(self.param_groups, self.base_lrs): 175 | for p in group['params']: 176 | if p.grad is None: 177 | continue 178 | grad = p.grad.data 179 | if grad.is_sparse: 180 | raise RuntimeError( 181 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 182 | amsbound = group['amsbound'] 183 | 184 | state = self.state[p] 185 | 186 | # State initialization 187 | if len(state) == 0: 188 | state['step'] = 0 189 | # Exponential moving average of gradient values 190 | state['exp_avg'] = torch.zeros_like(p.data) 191 | # Exponential moving average of squared gradient values 192 | state['exp_avg_sq'] = torch.zeros_like(p.data) 193 | if amsbound: 194 | # Maintains max of all exp. moving avg. of sq. grad. values 195 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 196 | 197 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 198 | if amsbound: 199 | max_exp_avg_sq = state['max_exp_avg_sq'] 200 | beta1, beta2 = group['betas'] 201 | 202 | state['step'] += 1 203 | 204 | # Decay the first and second moment running average coefficient 205 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 206 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 207 | if amsbound: 208 | # Maintains the maximum of all 2nd moment running avg. till now 209 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 210 | # Use the max. for normalizing running avg. 
of gradient 211 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 212 | else: 213 | denom = exp_avg_sq.sqrt().add_(group['eps']) 214 | 215 | bias_correction1 = 1 - beta1 ** state['step'] 216 | bias_correction2 = 1 - beta2 ** state['step'] 217 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 218 | 219 | # Applies bounds on actual learning rate 220 | # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay 221 | final_lr = group['final_lr'] * group['lr'] / base_lr 222 | lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1)) 223 | upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step'])) 224 | step_size = torch.full_like(denom, step_size) 225 | step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg) 226 | 227 | if group['weight_decay'] != 0: 228 | decayed_weights = torch.mul(p.data, group['weight_decay']) 229 | p.data.add_(-step_size) 230 | p.data.sub_(decayed_weights) 231 | else: 232 | p.data.add_(-step_size) 233 | 234 | return loss 235 | -------------------------------------------------------------------------------- /demos/README.md: -------------------------------------------------------------------------------- 1 | # Demos 2 | 3 | Here we provide some demos of using AdaBound on several benchmark tasks. 4 | The purpose of these demos is to give an example of how to use AdaBound in your research, and also 5 | to illustrate its robust performance. 6 | 7 | In short, AdaBound can be regarded as an optimizer that dynamically transforms from Adam to SGD as 8 | the training step becomes larger. 9 | In this way, it can **combine the benefits of adaptive methods, viz. a fast initial training process, and the 10 | good final generalization properties of SGD**. 11 | 12 | In most examples, you can observe that AdaBound trains much faster than SGD 13 | in the early stage, and its learning curve is much smoother than that of SGD. 14 | As for the final performance on unseen data, AdaBound achieves better or similar performance 15 | compared with SGD, and shows a considerable improvement over the adaptive methods. 16 | 17 | ## Demo List 18 | - CIFAR-10 \[[notebook](./cifar10/visualization.ipynb)\] \[[code](./cifar10)\] 19 | 20 | ## Future Work 21 | 22 | We will keep updating the demos in the near future to include more popular benchmarks. 23 | Feel free to leave an issue or send an email to the first author ([Liangchen Luo](mailto:luolc.witty@gmail.com)) 24 | if you want to see a specific task that has not been included yet. :D 25 | -------------------------------------------------------------------------------- /demos/cifar10/.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | checkpoint 3 | curve/* 4 | !curve/pretrained 5 | 6 | main.sh 7 | -------------------------------------------------------------------------------- /demos/cifar10/README.md: -------------------------------------------------------------------------------- 1 | # Examples on CIFAR-10 2 | 3 | In this example, we test AdaBound/AMSBound on the standard CIFAR-10 image classification dataset, 4 | comparing them with several baseline methods: SGD, AdaGrad, Adam, and AMSGrad. 5 | The implementation is largely based on [this project](https://github.com/kuangliu/pytorch-cifar). 6 | 7 | Tested with PyTorch 0.4.1. 8 | 9 | ## Visualization 10 | 11 | We provide a notebook to make it easier to visualize the performance of AdaBound.
12 | You can directly click [visualization.ipynb](./visualization.ipynb) and view the result on GitHub, 13 | or clone the project and run it locally. 14 | 15 | ## Settings 16 | 17 | We have already provided the results produced by AdaBound/AMSBound with default settings, and by the 18 | baseline optimizers with their best hyperparameters. 19 | The procedure for finding the best settings for the baseline optimizers is described in the experiment 20 | section of the paper. 21 | The best hyperparameters are listed below to ease your reproduction: 22 | 23 | **ResNet-34:** 24 | 25 | | optimizer | lr | momentum | beta1 | beta2 | final lr | gamma | 26 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 27 | | SGD | 0.1 | 0.9 | | | | | 28 | | AdaGrad | 0.01 | | | | | | 29 | | Adam | 0.001 | | 0.99 | 0.999 | | | 30 | | AMSGrad | 0.001 | | 0.99 | 0.999 | | | 31 | | AdaBound (def.) | 0.001 | | 0.9 | 0.999 | 0.1 | 0.001 | 32 | | AMSBound (def.) | 0.001 | | 0.9 | 0.999 | 0.1 | 0.001 | 33 | 34 | **DenseNet-121:** 35 | 36 | | optimizer | lr | momentum | beta1 | beta2 | final lr | gamma | 37 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 38 | | SGD | 0.1 | 0.9 | | | | | 39 | | AdaGrad | 0.01 | | | | | | 40 | | Adam | 0.001 | | 0.9 | 0.999 | | | 41 | | AMSGrad | 0.001 | | 0.9 | 0.999 | | | 42 | | AdaBound (def.) | 0.001 | | 0.9 | 0.999 | 0.1 | 0.001 | 43 | | AMSBound (def.) | 0.001 | | 0.9 | 0.999 | 0.1 | 0.001 | 44 | 45 | We apply a weight decay of `5e-4` to all the optimizers. 46 | 47 | ## Running by Yourself 48 | 49 | You may also run the experiment and visualize the result by yourself. 50 | The following is an example of training ResNet-34 using AdaBound with a learning rate of 0.001 and 51 | a final learning rate of 0.1. 52 | 53 | ```bash 54 | python main.py --model=resnet --optim=adabound --lr=0.001 --final_lr=0.1 55 | ``` 56 | 57 | The checkpoints will be saved in the `checkpoint` folder and the data points of the learning curve 58 | will be saved in the `curve` folder.
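Each curve file is simply a Python dict serialized with `torch.save`, holding per-epoch `train_acc` and `test_acc` lists (see the end of `main.py`). If you want to inspect the results outside the notebook, the following is a minimal sketch of how they could be plotted; it assumes `matplotlib` is installed (it is not listed in `requirements.txt`) and uses a few of the pretrained ResNet-34 curve files shipped in `curve/pretrained` as an example.

```python
# Minimal sketch: plot saved learning curves from the `curve` folder.
# Assumes matplotlib is available; file names follow get_ckpt_name() in main.py.
import os

import torch
import matplotlib.pyplot as plt

CURVE_DIR = os.path.join('curve', 'pretrained')
CURVES = {
    'SGD': 'resnet-sgd-lr0.1-momentum0.9',
    'Adam': 'resnet-adam-lr0.001-betas0.99-0.999',
    'AdaBound': 'resnet-adabound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001',
}

for label, name in CURVES.items():
    # Each file stores {'train_acc': [...], 'test_acc': [...]}, one entry per epoch.
    curve = torch.load(os.path.join(CURVE_DIR, name))
    plt.plot(curve['test_acc'], label=label)

plt.xlabel('Epoch')
plt.ylabel('Test accuracy (%)')
plt.legend()
plt.show()
```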
59 | -------------------------------------------------------------------------------- /demos/cifar10/curve/pretrained/densenet-adabound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/densenet-adabound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001 -------------------------------------------------------------------------------- /demos/cifar10/curve/pretrained/densenet-adagrad-lr0.01: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/densenet-adagrad-lr0.01 -------------------------------------------------------------------------------- /demos/cifar10/curve/pretrained/densenet-adam-lr0.001-betas0.9-0.999: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/densenet-adam-lr0.001-betas0.9-0.999 -------------------------------------------------------------------------------- /demos/cifar10/curve/pretrained/densenet-amsbound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/densenet-amsbound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001 -------------------------------------------------------------------------------- /demos/cifar10/curve/pretrained/densenet-amsgrad-lr0.001-betas0.9-0.999: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/densenet-amsgrad-lr0.001-betas0.9-0.999 -------------------------------------------------------------------------------- /demos/cifar10/curve/pretrained/densenet-sgd-lr0.1-momentum0.9: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/densenet-sgd-lr0.1-momentum0.9 -------------------------------------------------------------------------------- /demos/cifar10/curve/pretrained/resnet-adabound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/resnet-adabound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001 -------------------------------------------------------------------------------- /demos/cifar10/curve/pretrained/resnet-adagrad-lr0.01: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/resnet-adagrad-lr0.01 -------------------------------------------------------------------------------- /demos/cifar10/curve/pretrained/resnet-adam-lr0.001-betas0.99-0.999: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/resnet-adam-lr0.001-betas0.99-0.999 -------------------------------------------------------------------------------- /demos/cifar10/curve/pretrained/resnet-amsbound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/resnet-amsbound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001 -------------------------------------------------------------------------------- /demos/cifar10/curve/pretrained/resnet-amsgrad-lr0.001-betas0.99-0.999: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/resnet-amsgrad-lr0.001-betas0.99-0.999 -------------------------------------------------------------------------------- /demos/cifar10/curve/pretrained/resnet-sgd-lr0.1-momentum0.9: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/resnet-sgd-lr0.1-momentum0.9 -------------------------------------------------------------------------------- /demos/cifar10/main.py: -------------------------------------------------------------------------------- 1 | """Train CIFAR10 with PyTorch.""" 2 | from __future__ import print_function 3 | 4 | import torch.optim as optim 5 | import torch.backends.cudnn as cudnn 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | 9 | import os 10 | import argparse 11 | 12 | from models import * 13 | from adabound import AdaBound 14 | 15 | 16 | def get_parser(): 17 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training') 18 | parser.add_argument('--model', default='resnet', type=str, help='model', 19 | choices=['resnet', 'densenet']) 20 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer', 21 | choices=['sgd', 'adagrad', 'adam', 'amsgrad', 'adabound', 'amsbound']) 22 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 23 | parser.add_argument('--final_lr', default=0.1, type=float, 24 | help='final learning rate of AdaBound') 25 | parser.add_argument('--gamma', default=1e-3, type=float, 26 | help='convergence speed term of AdaBound') 27 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum term') 28 | parser.add_argument('--beta1', default=0.9, type=float, help='Adam coefficients beta_1') 29 | parser.add_argument('--beta2', default=0.999, type=float, help='Adam coefficients beta_2') 30 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint') 31 | parser.add_argument('--weight_decay', default=5e-4, type=float, 32 | help='weight decay for optimizers') 33 | return parser 34 | 35 | 36 | def build_dataset(): 37 | print('==> Preparing data..') 38 | transform_train = transforms.Compose([ 39 | transforms.RandomCrop(32, padding=4), 40 | transforms.RandomHorizontalFlip(), 41 | transforms.ToTensor(), 42 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 43 | ]) 44 | 45 | transform_test = transforms.Compose([ 46 | transforms.ToTensor(), 47 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 48 | ]) 49 | 50 | 
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, 51 | transform=transform_train) 52 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, 53 | num_workers=2) 54 | 55 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, 56 | transform=transform_test) 57 | test_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2) 58 | 59 | # classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 60 | 61 | return train_loader, test_loader 62 | 63 | 64 | def get_ckpt_name(model='resnet', optimizer='sgd', lr=0.1, final_lr=0.1, momentum=0.9, 65 | beta1=0.9, beta2=0.999, gamma=1e-3): 66 | name = { 67 | 'sgd': 'lr{}-momentum{}'.format(lr, momentum), 68 | 'adagrad': 'lr{}'.format(lr), 69 | 'adam': 'lr{}-betas{}-{}'.format(lr, beta1, beta2), 70 | 'amsgrad': 'lr{}-betas{}-{}'.format(lr, beta1, beta2), 71 | 'adabound': 'lr{}-betas{}-{}-final_lr{}-gamma{}'.format(lr, beta1, beta2, final_lr, gamma), 72 | 'amsbound': 'lr{}-betas{}-{}-final_lr{}-gamma{}'.format(lr, beta1, beta2, final_lr, gamma), 73 | }[optimizer] 74 | return '{}-{}-{}'.format(model, optimizer, name) 75 | 76 | 77 | def load_checkpoint(ckpt_name): 78 | print('==> Resuming from checkpoint..') 79 | path = os.path.join('checkpoint', ckpt_name) 80 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 81 | assert os.path.exists(path), 'Error: checkpoint {} not found'.format(ckpt_name) 82 | return torch.load(path) 83 | 84 | 85 | def build_model(args, device, ckpt=None): 86 | print('==> Building model..') 87 | net = { 88 | 'resnet': ResNet34, 89 | 'densenet': DenseNet121, 90 | }[args.model]() 91 | net = net.to(device) 92 | if device == 'cuda': 93 | net = torch.nn.DataParallel(net) 94 | cudnn.benchmark = True 95 | 96 | if ckpt: 97 | net.load_state_dict(ckpt['net']) 98 | 99 | return net 100 | 101 | 102 | def create_optimizer(args, model_params): 103 | if args.optim == 'sgd': 104 | return optim.SGD(model_params, args.lr, momentum=args.momentum, 105 | weight_decay=args.weight_decay) 106 | elif args.optim == 'adagrad': 107 | return optim.Adagrad(model_params, args.lr, weight_decay=args.weight_decay) 108 | elif args.optim == 'adam': 109 | return optim.Adam(model_params, args.lr, betas=(args.beta1, args.beta2), 110 | weight_decay=args.weight_decay) 111 | elif args.optim == 'amsgrad': 112 | return optim.Adam(model_params, args.lr, betas=(args.beta1, args.beta2), 113 | weight_decay=args.weight_decay, amsgrad=True) 114 | elif args.optim == 'adabound': 115 | return AdaBound(model_params, args.lr, betas=(args.beta1, args.beta2), 116 | final_lr=args.final_lr, gamma=args.gamma, 117 | weight_decay=args.weight_decay) 118 | else: 119 | assert args.optim == 'amsbound' 120 | return AdaBound(model_params, args.lr, betas=(args.beta1, args.beta2), 121 | final_lr=args.final_lr, gamma=args.gamma, 122 | weight_decay=args.weight_decay, amsbound=True) 123 | 124 | 125 | def train(net, epoch, device, data_loader, optimizer, criterion): 126 | print('\nEpoch: %d' % epoch) 127 | net.train() 128 | train_loss = 0 129 | correct = 0 130 | total = 0 131 | for batch_idx, (inputs, targets) in enumerate(data_loader): 132 | inputs, targets = inputs.to(device), targets.to(device) 133 | optimizer.zero_grad() 134 | outputs = net(inputs) 135 | loss = criterion(outputs, targets) 136 | loss.backward() 137 | optimizer.step() 138 | 139 | train_loss += loss.item() 140 | _, predicted = outputs.max(1) 141 | total +=
targets.size(0) 142 | correct += predicted.eq(targets).sum().item() 143 | 144 | accuracy = 100. * correct / total 145 | print('train acc %.3f' % accuracy) 146 | 147 | return accuracy 148 | 149 | 150 | def test(net, device, data_loader, criterion): 151 | net.eval() 152 | test_loss = 0 153 | correct = 0 154 | total = 0 155 | with torch.no_grad(): 156 | for batch_idx, (inputs, targets) in enumerate(data_loader): 157 | inputs, targets = inputs.to(device), targets.to(device) 158 | outputs = net(inputs) 159 | loss = criterion(outputs, targets) 160 | 161 | test_loss += loss.item() 162 | _, predicted = outputs.max(1) 163 | total += targets.size(0) 164 | correct += predicted.eq(targets).sum().item() 165 | 166 | accuracy = 100. * correct / total 167 | print(' test acc %.3f' % accuracy) 168 | 169 | return accuracy 170 | 171 | 172 | def main(): 173 | parser = get_parser() 174 | args = parser.parse_args() 175 | 176 | train_loader, test_loader = build_dataset() 177 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 178 | 179 | ckpt_name = get_ckpt_name(model=args.model, optimizer=args.optim, lr=args.lr, 180 | final_lr=args.final_lr, momentum=args.momentum, 181 | beta1=args.beta1, beta2=args.beta2, gamma=args.gamma) 182 | if args.resume: 183 | ckpt = load_checkpoint(ckpt_name) 184 | best_acc = ckpt['acc'] 185 | start_epoch = ckpt['epoch'] 186 | else: 187 | ckpt = None 188 | best_acc = 0 189 | start_epoch = -1 190 | 191 | net = build_model(args, device, ckpt=ckpt) 192 | criterion = nn.CrossEntropyLoss() 193 | optimizer = create_optimizer(args, net.parameters()) 194 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.1, 195 | last_epoch=start_epoch) 196 | 197 | train_accuracies = [] 198 | test_accuracies = [] 199 | 200 | for epoch in range(start_epoch + 1, 200): 201 | scheduler.step() 202 | train_acc = train(net, epoch, device, train_loader, optimizer, criterion) 203 | test_acc = test(net, device, test_loader, criterion) 204 | 205 | # Save checkpoint. 206 | if test_acc > best_acc: 207 | print('Saving..') 208 | state = { 209 | 'net': net.state_dict(), 210 | 'acc': test_acc, 211 | 'epoch': epoch, 212 | } 213 | if not os.path.isdir('checkpoint'): 214 | os.mkdir('checkpoint') 215 | torch.save(state, os.path.join('checkpoint', ckpt_name)) 216 | best_acc = test_acc 217 | 218 | train_accuracies.append(train_acc) 219 | test_accuracies.append(test_acc) 220 | if not os.path.isdir('curve'): 221 | os.mkdir('curve') 222 | torch.save({'train_acc': train_accuracies, 'test_acc': test_accuracies}, 223 | os.path.join('curve', ckpt_name)) 224 | 225 | 226 | if __name__ == '__main__': 227 | main() 228 | -------------------------------------------------------------------------------- /demos/cifar10/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .densenet import * 3 | -------------------------------------------------------------------------------- /demos/cifar10/models/densenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. 
Densely Connected Convolutional Networks: 3 | https://arxiv.org/abs/1608.06993 4 | """ 5 | import math 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class Bottleneck(nn.Module): 13 | def __init__(self, in_planes, growth_rate): 14 | super(Bottleneck, self).__init__() 15 | self.bn1 = nn.BatchNorm2d(in_planes) 16 | self.conv1 = nn.Conv2d(in_planes, 4 * growth_rate, kernel_size=1, bias=False) 17 | self.bn2 = nn.BatchNorm2d(4 * growth_rate) 18 | self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 19 | 20 | def forward(self, x): 21 | out = self.conv1(F.relu(self.bn1(x))) 22 | out = self.conv2(F.relu(self.bn2(out))) 23 | out = torch.cat([out, x], 1) 24 | return out 25 | 26 | 27 | class Transition(nn.Module): 28 | def __init__(self, in_planes, out_planes): 29 | super(Transition, self).__init__() 30 | self.bn = nn.BatchNorm2d(in_planes) 31 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 32 | 33 | def forward(self, x): 34 | out = self.conv(F.relu(self.bn(x))) 35 | out = F.avg_pool2d(out, 2) 36 | return out 37 | 38 | 39 | class DenseNet(nn.Module): 40 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 41 | super(DenseNet, self).__init__() 42 | self.growth_rate = growth_rate 43 | 44 | num_planes = 2 * growth_rate 45 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 46 | 47 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 48 | num_planes += nblocks[0] * growth_rate 49 | out_planes = int(math.floor(num_planes * reduction)) 50 | self.trans1 = Transition(num_planes, out_planes) 51 | num_planes = out_planes 52 | 53 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 54 | num_planes += nblocks[1] * growth_rate 55 | out_planes = int(math.floor(num_planes * reduction)) 56 | self.trans2 = Transition(num_planes, out_planes) 57 | num_planes = out_planes 58 | 59 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 60 | num_planes += nblocks[2] * growth_rate 61 | out_planes = int(math.floor(num_planes * reduction)) 62 | self.trans3 = Transition(num_planes, out_planes) 63 | num_planes = out_planes 64 | 65 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 66 | num_planes += nblocks[3] * growth_rate 67 | 68 | self.bn = nn.BatchNorm2d(num_planes) 69 | self.linear = nn.Linear(num_planes, num_classes) 70 | 71 | def _make_dense_layers(self, block, in_planes, nblock): 72 | layers = [] 73 | for i in range(nblock): 74 | layers.append(block(in_planes, self.growth_rate)) 75 | in_planes += self.growth_rate 76 | return nn.Sequential(*layers) 77 | 78 | def forward(self, x): 79 | out = self.conv1(x) 80 | out = self.trans1(self.dense1(out)) 81 | out = self.trans2(self.dense2(out)) 82 | out = self.trans3(self.dense3(out)) 83 | out = self.dense4(out) 84 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 85 | out = out.view(out.size(0), -1) 86 | out = self.linear(out) 87 | return out 88 | 89 | 90 | def DenseNet121(): 91 | return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=32) 92 | 93 | 94 | def DenseNet169(): 95 | return DenseNet(Bottleneck, [6, 12, 32, 32], growth_rate=32) 96 | 97 | 98 | def DenseNet201(): 99 | return DenseNet(Bottleneck, [6, 12, 48, 32], growth_rate=32) 100 | 101 | 102 | def DenseNet161(): 103 | return DenseNet(Bottleneck, [6, 12, 36, 24], growth_rate=48) 104 | 105 | 106 | def densenet_cifar(): 107 | return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=12) 
108 | 109 | 110 | def test(): 111 | net = densenet_cifar() 112 | x = torch.randn(1, 3, 32, 32) 113 | y = net(x) 114 | print(y) 115 | 116 | # test() 117 | -------------------------------------------------------------------------------- /demos/cifar10/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. Deep Residual Learning for Image Recognition: 3 | https://arxiv.org/abs/1512.03385 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class BasicBlock(nn.Module): 11 | expansion = 1 12 | 13 | def __init__(self, in_planes, planes, stride=1): 14 | super(BasicBlock, self).__init__() 15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, 16 | bias=False) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 19 | self.bn2 = nn.BatchNorm2d(planes) 20 | 21 | self.shortcut = nn.Sequential() 22 | if stride != 1 or in_planes != self.expansion * planes: 23 | self.shortcut = nn.Sequential( 24 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, 25 | bias=False), 26 | nn.BatchNorm2d(self.expansion * planes) 27 | ) 28 | 29 | def forward(self, x): 30 | out = F.relu(self.bn1(self.conv1(x))) 31 | out = self.bn2(self.conv2(out)) 32 | out += self.shortcut(x) 33 | out = F.relu(out) 34 | return out 35 | 36 | 37 | class Bottleneck(nn.Module): 38 | expansion = 4 39 | 40 | def __init__(self, in_planes, planes, stride=1): 41 | super(Bottleneck, self).__init__() 42 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 43 | self.bn1 = nn.BatchNorm2d(planes) 44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 47 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride != 1 or in_planes != self.expansion * planes: 51 | self.shortcut = nn.Sequential( 52 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, 53 | bias=False), 54 | nn.BatchNorm2d(self.expansion * planes) 55 | ) 56 | 57 | def forward(self, x): 58 | out = F.relu(self.bn1(self.conv1(x))) 59 | out = F.relu(self.bn2(self.conv2(out))) 60 | out = self.bn3(self.conv3(out)) 61 | out += self.shortcut(x) 62 | out = F.relu(out) 63 | return out 64 | 65 | 66 | class ResNet(nn.Module): 67 | def __init__(self, block, num_blocks, num_classes=10): 68 | super(ResNet, self).__init__() 69 | self.in_planes = 64 70 | 71 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 72 | self.bn1 = nn.BatchNorm2d(64) 73 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 74 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 75 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 76 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 77 | self.linear = nn.Linear(512 * block.expansion, num_classes) 78 | 79 | def _make_layer(self, block, planes, num_blocks, stride): 80 | strides = [stride] + [1] * (num_blocks - 1) 81 | layers = [] 82 | for stride in strides: 83 | layers.append(block(self.in_planes, planes, stride)) 84 | self.in_planes = planes * block.expansion 85 | return nn.Sequential(*layers) 86 | 87 | def forward(self, x): 88 | out = F.relu(self.bn1(self.conv1(x))) 89 | out = self.layer1(out) 
90 | out = self.layer2(out) 91 | out = self.layer3(out) 92 | out = self.layer4(out) 93 | out = F.avg_pool2d(out, 4) 94 | out = out.view(out.size(0), -1) 95 | out = self.linear(out) 96 | return out 97 | 98 | 99 | def ResNet18(): 100 | return ResNet(BasicBlock, [2, 2, 2, 2]) 101 | 102 | 103 | def ResNet34(): 104 | return ResNet(BasicBlock, [3, 4, 6, 3]) 105 | 106 | 107 | def ResNet50(): 108 | return ResNet(Bottleneck, [3, 4, 6, 3]) 109 | 110 | 111 | def ResNet101(): 112 | return ResNet(Bottleneck, [3, 4, 23, 3]) 113 | 114 | 115 | def ResNet152(): 116 | return ResNet(Bottleneck, [3, 8, 36, 3]) 117 | 118 | 119 | def test(): 120 | net = ResNet18() 121 | y = net(torch.randn(1, 3, 32, 32)) 122 | print(y.size()) 123 | 124 | # test() 125 | -------------------------------------------------------------------------------- /demos/cifar10/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=0.4.1 2 | torchvision>=0.2.1 3 | -------------------------------------------------------------------------------- /release.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python3 setup.py sdist bdist_wheel 3 | twine upload dist/* 4 | 5 | rm -rf build 6 | rm -rf dist 7 | rm -rf *.egg-info 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | __VERSION__ = '0.0.5' 4 | 5 | setup(name='adabound', 6 | version=__VERSION__, 7 | description='AdaBound optimization algorithm, built on PyTorch.', 8 | long_description=open("README.md").read(), 9 | long_description_content_type="text/markdown", 10 | keywords=['machine learning', 'deep learning'], 11 | classifiers=[ 12 | 'Intended Audience :: Science/Research', 13 | 'Development Status :: 3 - Alpha', 14 | 'License :: OSI Approved :: Apache Software License', 15 | 'Programming Language :: Python :: 3.6', 16 | 'Programming Language :: Python :: 3.7', 17 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 18 | ], 19 | url='https://github.com/Luolc/AdaBound', 20 | author='Liangchen Luo', 21 | author_email='luolc.witty@gmail.com', 22 | license='Apache', 23 | packages=['adabound'], 24 | install_requires=[ 25 | 'torch>=0.4.0', 26 | ], 27 | zip_safe=False, 28 | python_requires='>=3.6.0') 29 | --------------------------------------------------------------------------------
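For local development, a source checkout can also be installed in editable mode with `pip install -e .` before running `release.sh`. The snippet below is a minimal, hypothetical smoke test of the public API (`AdaBound` with `lr`, `final_lr`, `gamma`, and the `amsbound` flag); the tiny model and random data are made up purely for illustration, and it assumes a PyTorch version that still accepts the pre-1.5 in-place `add_`/`addcmul_` call signatures used in `adabound.py`.

```python
# Hypothetical smoke test for a source checkout (after `pip install -e .`).
# The toy model and random data below are illustrative only.
import torch
import torch.nn as nn

from adabound import AdaBound

model = nn.Linear(10, 2)
optimizer = AdaBound(model.parameters(), lr=1e-3, final_lr=0.1, gamma=1e-3, amsbound=False)
criterion = nn.CrossEntropyLoss()

x = torch.randn(32, 10)         # fake inputs
y = torch.randint(0, 2, (32,))  # fake labels

for step in range(5):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    # Adam-like step whose per-parameter learning rate is clamped into the
    # dynamic [lower_bound, upper_bound] interval around final_lr.
    optimizer.step()
    print('step', step, 'loss', loss.item())
```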