├── .gitignore
├── LICENSE
├── README.md
├── adabound
│   ├── __init__.py
│   └── adabound.py
├── demos
│   ├── README.md
│   └── cifar10
│       ├── .gitignore
│       ├── README.md
│       ├── curve
│       │   └── pretrained
│       │       ├── densenet-adabound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001
│       │       ├── densenet-adagrad-lr0.01
│       │       ├── densenet-adam-lr0.001-betas0.9-0.999
│       │       ├── densenet-amsbound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001
│       │       ├── densenet-amsgrad-lr0.001-betas0.9-0.999
│       │       ├── densenet-sgd-lr0.1-momentum0.9
│       │       ├── resnet-adabound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001
│       │       ├── resnet-adagrad-lr0.01
│       │       ├── resnet-adam-lr0.001-betas0.99-0.999
│       │       ├── resnet-amsbound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001
│       │       ├── resnet-amsgrad-lr0.001-betas0.99-0.999
│       │       └── resnet-sgd-lr0.1-momentum0.9
│       ├── main.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── densenet.py
│       │   └── resnet.py
│       ├── requirements.txt
│       └── visualization.ipynb
├── release.sh
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### JetBrains template
3 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
4 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
5 |
6 | # We just ignore the whole .idea folder
7 | .idea
8 |
9 | # User-specific stuff:
10 | .idea/**/workspace.xml
11 | .idea/**/tasks.xml
12 | .idea/dictionaries
13 |
14 | # Sensitive or high-churn files:
15 | .idea/**/dataSources/
16 | .idea/**/dataSources.ids
17 | .idea/**/dataSources.local.xml
18 | .idea/**/sqlDataSources.xml
19 | .idea/**/dynamic.xml
20 | .idea/**/uiDesigner.xml
21 |
22 | # Gradle:
23 | .idea/**/gradle.xml
24 | .idea/**/libraries
25 |
26 | # CMake
27 | cmake-build-debug/
28 | cmake-build-release/
29 |
30 | # Mongo Explorer plugin:
31 | .idea/**/mongoSettings.xml
32 |
33 | ## File-based project format:
34 | *.iws
35 |
36 | ## Plugin-specific files:
37 |
38 | # IntelliJ
39 | out/
40 |
41 | # mpeltonen/sbt-idea plugin
42 | .idea_modules/
43 |
44 | # JIRA plugin
45 | atlassian-ide-plugin.xml
46 |
47 | # Cursive Clojure plugin
48 | .idea/replstate.xml
49 |
50 | # Crashlytics plugin (for Android Studio and IntelliJ)
51 | com_crashlytics_export_strings.xml
52 | crashlytics.properties
53 | crashlytics-build.properties
54 | fabric.properties
55 | ### macOS template
56 | # General
57 | .DS_Store
58 | .AppleDouble
59 | .LSOverride
60 |
61 | # Icon must end with two \r
62 | Icon
63 |
64 | # Thumbnails
65 | ._*
66 |
67 | # Files that might appear in the root of a volume
68 | .DocumentRevisions-V100
69 | .fseventsd
70 | .Spotlight-V100
71 | .TemporaryItems
72 | .Trashes
73 | .VolumeIcon.icns
74 | .com.apple.timemachine.donotpresent
75 |
76 | # Directories potentially created on remote AFP share
77 | .AppleDB
78 | .AppleDesktop
79 | Network Trash Folder
80 | Temporary Items
81 | .apdisk
82 | ### Python template
83 | # Byte-compiled / optimized / DLL files
84 | __pycache__/
85 | *.py[cod]
86 | *$py.class
87 |
88 | # C extensions
89 | *.so
90 |
91 | # Distribution / packaging
92 | .Python
93 | build/
94 | develop-eggs/
95 | dist/
96 | downloads/
97 | eggs/
98 | .eggs/
99 | lib/
100 | lib64/
101 | parts/
102 | sdist/
103 | var/
104 | wheels/
105 | *.egg-info/
106 | .installed.cfg
107 | *.egg
108 | MANIFEST
109 |
110 | # PyInstaller
111 | # Usually these files are written by a python script from a template
112 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
113 | *.manifest
114 | *.spec
115 |
116 | # Installer logs
117 | pip-log.txt
118 | pip-delete-this-directory.txt
119 |
120 | # Unit test / coverage reports
121 | htmlcov/
122 | .tox/
123 | .coverage
124 | .coverage.*
125 | .cache
126 | nosetests.xml
127 | coverage.xml
128 | *.cover
129 | .hypothesis/
130 | .pytest_cache/
131 |
132 | # Translations
133 | *.mo
134 | *.pot
135 |
136 | # Django stuff:
137 | *.log
138 | local_settings.py
139 | db.sqlite3
140 |
141 | # Flask stuff:
142 | instance/
143 | .webassets-cache
144 |
145 | # Scrapy stuff:
146 | .scrapy
147 |
148 | # Sphinx documentation
149 | docs/_build/
150 |
151 | # PyBuilder
152 | target/
153 |
154 | # Jupyter Notebook
155 | .ipynb_checkpoints
156 |
157 | # pyenv
158 | .python-version
159 |
160 | # celery beat schedule file
161 | celerybeat-schedule
162 |
163 | # SageMath parsed files
164 | *.sage.py
165 |
166 | # Environments
167 | .env
168 | .venv
169 | env/
170 | venv/
171 | ENV/
172 | env.bak/
173 | venv.bak/
174 |
175 | # Spyder project settings
176 | .spyderproject
177 | .spyproject
178 |
179 | # Rope project settings
180 | .ropeproject
181 |
182 | # mkdocs documentation
183 | /site
184 |
185 | # mypy
186 | .mypy_cache/
187 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdaBound
2 | [PyPI](https://pypi.org/project/adabound/)
5 | [License](./LICENSE)
6 |
7 | An optimizer that trains as fast as Adam and generalizes as well as SGD, for developing
8 | state-of-the-art deep learning models on a wide variety of popular tasks in fields such as CV and NLP.
9 |
10 | Based on Luo et al. (2019).
11 | [Adaptive Gradient Methods with Dynamic Bound of Learning Rate](https://openreview.net/forum?id=Bkg3g2R9FX).
12 | In *Proc. of ICLR 2019*.
13 |
14 |
15 |
16 |
17 |
18 | ## Quick Links
19 |
20 | - [Website](https://www.luolc.com/publications/adabound/)
21 | - [Demos](./demos)
22 |
23 | ## Installation
24 |
25 | AdaBound requires Python 3.6.0 or later.
26 | We currently provide the PyTorch version; AdaBound for TensorFlow is coming soon.
27 |
28 | ### Installing via pip
29 |
30 | The preferred way to install AdaBound is via `pip` with a virtual environment.
31 | Just run
32 | ```bash
33 | pip install adabound
34 | ```
35 | in your Python environment and you are ready to go!
36 |
37 | ### Using source code
38 |
39 | As AdaBound is implemented in a single Python file of only a little over 100 lines, an alternative
40 | is to download [adabound.py](./adabound/adabound.py) directly and copy it into your project.
41 |
42 | ## Usage
43 |
44 | You can use AdaBound just like any other PyTorch optimizer.
45 |
46 | ```python3
47 | optimizer = adabound.AdaBound(model.parameters(), lr=1e-3, final_lr=0.1)
48 | ```
49 |
50 | As described in the paper, AdaBound is an optimizer that behaves like Adam at the beginning of
51 | training, and gradually transforms to SGD at the end.
52 | The `final_lr` parameter indicates that AdaBound will transform into SGD with this learning rate.
53 | In common cases, a default final learning rate of `0.1` can achieve relatively good and stable
54 | results on unseen data.
55 | AdaBound is not very sensitive to its hyperparameters.
56 | See Appendix G of the paper for more details.
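For a fuller picture, here is a minimal sketch of a single training step with AdaBound; the model, batch, and loss function below are placeholders for your own:

```python
import torch
import torch.nn as nn
import adabound

model = nn.Linear(10, 2)                 # placeholder model; substitute your own network
criterion = nn.CrossEntropyLoss()        # placeholder loss
optimizer = adabound.AdaBound(model.parameters(), lr=1e-3, final_lr=0.1)

inputs = torch.randn(8, 10)              # dummy batch
targets = torch.randint(0, 2, (8,))
optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()                         # AdaBound update, bounded toward final_lr
```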
57 |
58 | Despite its robust performance, we still have to state that **there is no silver bullet**.
59 | Using AdaBound does not mean you are free from tuning hyperparameters.
60 | The performance of a model depends on many things, including the task, the model structure,
61 | and the distribution of the data.
62 | **You still need to decide which hyperparameters to use based on your specific situation,
63 | but you will probably spend much less time on it than before!**
64 |
65 | ## Demos
66 |
67 | Thanks to the awesome work by the GitHub and Jupyter teams, Jupyter notebook (`.ipynb`)
68 | files render directly on GitHub.
69 | We provide several notebooks (like [this one](./demos/cifar10/visualization.ipynb)) for better
70 | visualization.
71 | We hope to illustrate the robust performance of AdaBound through these examples.
72 |
73 | For the full list of demos, please refer to [this page](./demos).
74 |
75 | ## Citing
76 | If you use AdaBound in your research, please cite [Adaptive Gradient Methods with Dynamic Bound of Learning Rate](https://openreview.net/forum?id=Bkg3g2R9FX).
77 | ```text
78 | @inproceedings{Luo2019AdaBound,
79 | author = {Luo, Liangchen and Xiong, Yuanhao and Liu, Yan and Sun, Xu},
80 | title = {Adaptive Gradient Methods with Dynamic Bound of Learning Rate},
81 | booktitle = {Proceedings of the 7th International Conference on Learning Representations},
82 | month = {May},
83 | year = {2019},
84 | address = {New Orleans, Louisiana}
85 | }
86 | ```
87 |
88 | ## Contributors
89 |
90 | [@kayuksel](https://github.com/kayuksel)
91 |
92 | ## License
93 | [Apache 2.0](./LICENSE)
94 |
--------------------------------------------------------------------------------
/adabound/__init__.py:
--------------------------------------------------------------------------------
1 | from .adabound import AdaBound
2 |
--------------------------------------------------------------------------------
/adabound/adabound.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim import Optimizer
4 |
5 |
6 | class AdaBound(Optimizer):
7 | """Implements AdaBound algorithm.
8 | It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_.
9 | Arguments:
10 | params (iterable): iterable of parameters to optimize or dicts defining
11 | parameter groups
12 | lr (float, optional): Adam learning rate (default: 1e-3)
13 | betas (Tuple[float, float], optional): coefficients used for computing
14 | running averages of gradient and its square (default: (0.9, 0.999))
15 | final_lr (float, optional): final (SGD) learning rate (default: 0.1)
16 | gamma (float, optional): convergence speed of the bound functions (default: 1e-3)
17 | eps (float, optional): term added to the denominator to improve
18 | numerical stability (default: 1e-8)
19 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
20 | amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm
21 | .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate:
22 | https://openreview.net/forum?id=Bkg3g2R9FX
23 | """
24 |
25 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3,
26 | eps=1e-8, weight_decay=0, amsbound=False):
27 | if not 0.0 <= lr:
28 | raise ValueError("Invalid learning rate: {}".format(lr))
29 | if not 0.0 <= eps:
30 | raise ValueError("Invalid epsilon value: {}".format(eps))
31 | if not 0.0 <= betas[0] < 1.0:
32 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
33 | if not 0.0 <= betas[1] < 1.0:
34 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
35 | if not 0.0 <= final_lr:
36 | raise ValueError("Invalid final learning rate: {}".format(final_lr))
37 | if not 0.0 <= gamma < 1.0:
38 | raise ValueError("Invalid gamma parameter: {}".format(gamma))
39 | defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps,
40 | weight_decay=weight_decay, amsbound=amsbound)
41 | super(AdaBound, self).__init__(params, defaults)
42 |
43 | self.base_lrs = list(map(lambda group: group['lr'], self.param_groups))
44 |
45 | def __setstate__(self, state):
46 | super(AdaBound, self).__setstate__(state)
47 | for group in self.param_groups:
48 | group.setdefault('amsbound', False)
49 |
50 | def step(self, closure=None):
51 | """Performs a single optimization step.
52 | Arguments:
53 | closure (callable, optional): A closure that reevaluates the model
54 | and returns the loss.
55 | """
56 | loss = None
57 | if closure is not None:
58 | loss = closure()
59 |
60 | for group, base_lr in zip(self.param_groups, self.base_lrs):
61 | for p in group['params']:
62 | if p.grad is None:
63 | continue
64 | grad = p.grad.data
65 | if grad.is_sparse:
66 | raise RuntimeError(
67 |                         'AdaBound does not support sparse gradients, please consider SparseAdam instead')
68 | amsbound = group['amsbound']
69 |
70 | state = self.state[p]
71 |
72 | # State initialization
73 | if len(state) == 0:
74 | state['step'] = 0
75 | # Exponential moving average of gradient values
76 | state['exp_avg'] = torch.zeros_like(p.data)
77 | # Exponential moving average of squared gradient values
78 | state['exp_avg_sq'] = torch.zeros_like(p.data)
79 | if amsbound:
80 | # Maintains max of all exp. moving avg. of sq. grad. values
81 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
82 |
83 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
84 | if amsbound:
85 | max_exp_avg_sq = state['max_exp_avg_sq']
86 | beta1, beta2 = group['betas']
87 |
88 | state['step'] += 1
89 |
90 | if group['weight_decay'] != 0:
91 | grad = grad.add(group['weight_decay'], p.data)
92 |
93 | # Decay the first and second moment running average coefficient
94 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
95 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
96 | if amsbound:
97 | # Maintains the maximum of all 2nd moment running avg. till now
98 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
99 | # Use the max. for normalizing running avg. of gradient
100 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
101 | else:
102 | denom = exp_avg_sq.sqrt().add_(group['eps'])
103 |
104 | bias_correction1 = 1 - beta1 ** state['step']
105 | bias_correction2 = 1 - beta2 ** state['step']
106 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
107 |
108 | # Applies bounds on actual learning rate
109 | # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay
110 | final_lr = group['final_lr'] * group['lr'] / base_lr
111 | lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1))
112 | upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step']))
113 | step_size = torch.full_like(denom, step_size)
114 | step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
115 |
116 | p.data.add_(-step_size)
117 |
118 | return loss
119 |
120 | class AdaBoundW(Optimizer):
121 | """Implements AdaBound algorithm with Decoupled Weight Decay (arxiv.org/abs/1711.05101)
122 | It has been proposed in `Adaptive Gradient Methods with Dynamic Bound of Learning Rate`_.
123 | Arguments:
124 | params (iterable): iterable of parameters to optimize or dicts defining
125 | parameter groups
126 | lr (float, optional): Adam learning rate (default: 1e-3)
127 | betas (Tuple[float, float], optional): coefficients used for computing
128 | running averages of gradient and its square (default: (0.9, 0.999))
129 | final_lr (float, optional): final (SGD) learning rate (default: 0.1)
130 | gamma (float, optional): convergence speed of the bound functions (default: 1e-3)
131 | eps (float, optional): term added to the denominator to improve
132 | numerical stability (default: 1e-8)
133 |         weight_decay (float, optional): decoupled weight decay (default: 0)
134 | amsbound (boolean, optional): whether to use the AMSBound variant of this algorithm
135 | .. Adaptive Gradient Methods with Dynamic Bound of Learning Rate:
136 | https://openreview.net/forum?id=Bkg3g2R9FX
137 | """
138 |
139 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), final_lr=0.1, gamma=1e-3,
140 | eps=1e-8, weight_decay=0, amsbound=False):
141 | if not 0.0 <= lr:
142 | raise ValueError("Invalid learning rate: {}".format(lr))
143 | if not 0.0 <= eps:
144 | raise ValueError("Invalid epsilon value: {}".format(eps))
145 | if not 0.0 <= betas[0] < 1.0:
146 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
147 | if not 0.0 <= betas[1] < 1.0:
148 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
149 | if not 0.0 <= final_lr:
150 | raise ValueError("Invalid final learning rate: {}".format(final_lr))
151 | if not 0.0 <= gamma < 1.0:
152 | raise ValueError("Invalid gamma parameter: {}".format(gamma))
153 | defaults = dict(lr=lr, betas=betas, final_lr=final_lr, gamma=gamma, eps=eps,
154 | weight_decay=weight_decay, amsbound=amsbound)
155 | super(AdaBoundW, self).__init__(params, defaults)
156 |
157 | self.base_lrs = list(map(lambda group: group['lr'], self.param_groups))
158 |
159 | def __setstate__(self, state):
160 | super(AdaBoundW, self).__setstate__(state)
161 | for group in self.param_groups:
162 | group.setdefault('amsbound', False)
163 |
164 | def step(self, closure=None):
165 | """Performs a single optimization step.
166 | Arguments:
167 | closure (callable, optional): A closure that reevaluates the model
168 | and returns the loss.
169 | """
170 | loss = None
171 | if closure is not None:
172 | loss = closure()
173 |
174 | for group, base_lr in zip(self.param_groups, self.base_lrs):
175 | for p in group['params']:
176 | if p.grad is None:
177 | continue
178 | grad = p.grad.data
179 | if grad.is_sparse:
180 | raise RuntimeError(
181 |                         'AdaBound does not support sparse gradients, please consider SparseAdam instead')
182 | amsbound = group['amsbound']
183 |
184 | state = self.state[p]
185 |
186 | # State initialization
187 | if len(state) == 0:
188 | state['step'] = 0
189 | # Exponential moving average of gradient values
190 | state['exp_avg'] = torch.zeros_like(p.data)
191 | # Exponential moving average of squared gradient values
192 | state['exp_avg_sq'] = torch.zeros_like(p.data)
193 | if amsbound:
194 | # Maintains max of all exp. moving avg. of sq. grad. values
195 | state['max_exp_avg_sq'] = torch.zeros_like(p.data)
196 |
197 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
198 | if amsbound:
199 | max_exp_avg_sq = state['max_exp_avg_sq']
200 | beta1, beta2 = group['betas']
201 |
202 | state['step'] += 1
203 |
204 | # Decay the first and second moment running average coefficient
205 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
206 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
207 | if amsbound:
208 | # Maintains the maximum of all 2nd moment running avg. till now
209 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
210 | # Use the max. for normalizing running avg. of gradient
211 | denom = max_exp_avg_sq.sqrt().add_(group['eps'])
212 | else:
213 | denom = exp_avg_sq.sqrt().add_(group['eps'])
214 |
215 | bias_correction1 = 1 - beta1 ** state['step']
216 | bias_correction2 = 1 - beta2 ** state['step']
217 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
218 |
219 | # Applies bounds on actual learning rate
220 | # lr_scheduler cannot affect final_lr, this is a workaround to apply lr decay
221 | final_lr = group['final_lr'] * group['lr'] / base_lr
222 | lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1))
223 | upper_bound = final_lr * (1 + 1 / (group['gamma'] * state['step']))
224 | step_size = torch.full_like(denom, step_size)
225 | step_size.div_(denom).clamp_(lower_bound, upper_bound).mul_(exp_avg)
226 |
227 | if group['weight_decay'] != 0:
228 | decayed_weights = torch.mul(p.data, group['weight_decay'])
229 | p.data.add_(-step_size)
230 | p.data.sub_(decayed_weights)
231 | else:
232 | p.data.add_(-step_size)
233 |
234 | return loss
235 |
--------------------------------------------------------------------------------
/demos/README.md:
--------------------------------------------------------------------------------
1 | # Demos
2 |
3 | Here we provide some demos of using AdaBound on several benchmark tasks.
4 | The purpose of these demos is to give an example of how to use AdaBound in your research, and
5 | also to illustrate its robust performance.
6 |
7 | In short, AdaBound can be regarded as an optimizer that dynamically transforms from Adam to SGD as
8 | the training step becomes larger.
9 | In this way, it **combines the benefits of adaptive methods, namely a fast initial training
10 | process, with the good final generalization properties of SGD**.
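Concretely, the dynamic bounds used in `adabound/adabound.py` clamp the Adam-style step size into a band that converges to `final_lr` as training proceeds; a minimal sketch of those bound functions:

```python
def adabound_bounds(final_lr, gamma, step):
    # Lower/upper bounds applied to the per-parameter step size at a given step.
    # Both bounds converge to final_lr, so late in training the update behaves
    # like SGD with learning rate final_lr.
    lower = final_lr * (1 - 1 / (gamma * step + 1))
    upper = final_lr * (1 + 1 / (gamma * step))
    return lower, upper

print(adabound_bounds(final_lr=0.1, gamma=1e-3, step=1))       # wide band: ~(0.0001, 100.1)
print(adabound_bounds(final_lr=0.1, gamma=1e-3, step=100000))  # tight band: ~(0.099, 0.101)
```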
11 |
12 | In most examples, you can observe that AdaBound trains much faster than SGD in the early stage,
13 | and its learning curve is much smoother than that of SGD.
14 | As for the final performance on unseen data, AdaBound achieves results better than or similar to
15 | those of SGD, and improves considerably over the adaptive methods.
16 |
17 | ## Demo List
18 | - CIFAR-10 \[[notebook](./cifar10/visualization.ipynb)\] \[[code](./cifar10)\]
19 |
20 | ## Future Work
21 |
22 | We will keep updating the demos in the near future to include more popular benchmarks.
23 | Feel free to leave an issue or send an email to the first author ([Liangchen Luo](mailto:luolc.witty@gmail.com))
24 | if you want to see a specific task which has not been included yet. :D
25 |
--------------------------------------------------------------------------------
/demos/cifar10/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | checkpoint
3 | curve/*
4 | !curve/pretrained
5 |
6 | main.sh
7 |
--------------------------------------------------------------------------------
/demos/cifar10/README.md:
--------------------------------------------------------------------------------
1 | # Examples on CIFAR-10
2 |
3 | In this example, we test AdaBound/AMSBound on the standard CIFAR-10 image classification dataset,
4 | comparing with several baseline methods including: SGD, AdaGrad, Adam, and AMSGrad.
5 | The implementation is largely based on [this project](https://github.com/kuangliu/pytorch-cifar).
6 |
7 | Tested with PyTorch 0.4.1.
8 |
9 | ## Visualization
10 |
11 | We provide a notebook to make it easier to visualize the performance of AdaBound.
12 | You can directly click [visualization.ipynb](./visualization.ipynb) and view the result on GitHub,
13 | or clone the project and run it locally.
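To run it locally, a sketch (assuming Jupyter and the usual plotting dependencies such as matplotlib are installed):

```bash
jupyter notebook visualization.ipynb
```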
14 |
15 | ## Settings
16 |
17 | We have already provided the results produced by AdaBound/AMSBound with default settings and by
18 | the baseline optimizers with their best hyperparameters.
19 | The procedure for searching for the best settings of the baseline optimizers is described in the
20 | experiment section of the paper.
21 | The best hyperparameters are listed below to make reproduction easier:
22 |
23 | **ResNet-34:**
24 |
25 | | optimizer | lr | momentum | beta1 | beta2 | final lr | gamma |
26 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
27 | | SGD | 0.1 | 0.9 | | | | |
28 | | AdaGrad | 0.01 | | | | | |
29 | | Adam | 0.001 | | 0.99 | 0.999 | | |
30 | | AMSGrad | 0.001 | | 0.99 | 0.999 | | |
31 | | AdaBound (def.) | 0.001 | | 0.9 | 0.999 | 0.1 | 0.001 |
32 | | AMSBound (def.) | 0.001 | | 0.9 | 0.999 | 0.1 | 0.001 |
33 |
34 | **DenseNet-121:**
35 |
36 | | optimizer | lr | momentum | beta1 | beta2 | final lr | gamma |
37 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
38 | | SGD | 0.1 | 0.9 | | | | |
39 | | AdaGrad | 0.01 | | | | | |
40 | | Adam | 0.001 | | 0.9 | 0.999 | | |
41 | | AMSGrad | 0.001 | | 0.9 | 0.999 | | |
42 | | AdaBound (def.) | 0.001 | | 0.9 | 0.999 | 0.1 | 0.001 |
43 | | AMSBound (def.) | 0.001 | | 0.9 | 0.999 | 0.1 | 0.001 |
44 |
45 | We apply a weight decay of `5e-4` to all the optimizers.
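For example, the ResNet-34 SGD row above corresponds to the following command for the training script in this folder (the weight decay of `5e-4` is already the default of `main.py`):

```bash
python main.py --model=resnet --optim=sgd --lr=0.1 --momentum=0.9
```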
46 |
47 | ## Running by Yourself
48 |
49 | You may also run the experiments and visualize the results yourself.
50 | The following is an example of training ResNet-34 with AdaBound, using a learning rate of 0.001
51 | and a final learning rate of 0.1.
52 |
53 | ```bash
54 | python main.py --model=resnet --optim=adabound --lr=0.001 --final_lr=0.1
55 | ```
56 |
57 | The checkpoints will be saved in the `checkpoint` folder and the data points of the learning curve
58 | will be saved in the `curve` folder.
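If a run is interrupted, you can continue from the best saved checkpoint by adding the `--resume` flag with the same hyperparameters (so that the checkpoint name matches), for example:

```bash
python main.py --model=resnet --optim=adabound --lr=0.001 --final_lr=0.1 --resume
```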
59 |
--------------------------------------------------------------------------------
/demos/cifar10/curve/pretrained/densenet-adabound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/densenet-adabound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001
--------------------------------------------------------------------------------
/demos/cifar10/curve/pretrained/densenet-adagrad-lr0.01:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/densenet-adagrad-lr0.01
--------------------------------------------------------------------------------
/demos/cifar10/curve/pretrained/densenet-adam-lr0.001-betas0.9-0.999:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/densenet-adam-lr0.001-betas0.9-0.999
--------------------------------------------------------------------------------
/demos/cifar10/curve/pretrained/densenet-amsbound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/densenet-amsbound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001
--------------------------------------------------------------------------------
/demos/cifar10/curve/pretrained/densenet-amsgrad-lr0.001-betas0.9-0.999:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/densenet-amsgrad-lr0.001-betas0.9-0.999
--------------------------------------------------------------------------------
/demos/cifar10/curve/pretrained/densenet-sgd-lr0.1-momentum0.9:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/densenet-sgd-lr0.1-momentum0.9
--------------------------------------------------------------------------------
/demos/cifar10/curve/pretrained/resnet-adabound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/resnet-adabound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001
--------------------------------------------------------------------------------
/demos/cifar10/curve/pretrained/resnet-adagrad-lr0.01:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/resnet-adagrad-lr0.01
--------------------------------------------------------------------------------
/demos/cifar10/curve/pretrained/resnet-adam-lr0.001-betas0.99-0.999:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/resnet-adam-lr0.001-betas0.99-0.999
--------------------------------------------------------------------------------
/demos/cifar10/curve/pretrained/resnet-amsbound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/resnet-amsbound-lr0.001-betas0.9-0.999-final_lr0.1-gamma0.001
--------------------------------------------------------------------------------
/demos/cifar10/curve/pretrained/resnet-amsgrad-lr0.001-betas0.99-0.999:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/resnet-amsgrad-lr0.001-betas0.99-0.999
--------------------------------------------------------------------------------
/demos/cifar10/curve/pretrained/resnet-sgd-lr0.1-momentum0.9:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Luolc/AdaBound/2e928c3007a2fc44af0e4c97e343e1fed6986e44/demos/cifar10/curve/pretrained/resnet-sgd-lr0.1-momentum0.9
--------------------------------------------------------------------------------
/demos/cifar10/main.py:
--------------------------------------------------------------------------------
1 | """Train CIFAR10 with PyTorch."""
2 | from __future__ import print_function
3 |
4 | import torch.optim as optim
5 | import torch.backends.cudnn as cudnn
6 | import torchvision
7 | import torchvision.transforms as transforms
8 |
9 | import os
10 | import argparse
11 |
12 | from models import *
13 | from adabound import AdaBound
14 |
15 |
16 | def get_parser():
17 | parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
18 | parser.add_argument('--model', default='resnet', type=str, help='model',
19 | choices=['resnet', 'densenet'])
20 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer',
21 | choices=['sgd', 'adagrad', 'adam', 'amsgrad', 'adabound', 'amsbound'])
22 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
23 | parser.add_argument('--final_lr', default=0.1, type=float,
24 | help='final learning rate of AdaBound')
25 | parser.add_argument('--gamma', default=1e-3, type=float,
26 | help='convergence speed term of AdaBound')
27 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum term')
28 | parser.add_argument('--beta1', default=0.9, type=float, help='Adam coefficients beta_1')
29 | parser.add_argument('--beta2', default=0.999, type=float, help='Adam coefficients beta_2')
30 | parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
31 | parser.add_argument('--weight_decay', default=5e-4, type=float,
32 | help='weight decay for optimizers')
33 | return parser
34 |
35 |
36 | def build_dataset():
37 | print('==> Preparing data..')
38 | transform_train = transforms.Compose([
39 | transforms.RandomCrop(32, padding=4),
40 | transforms.RandomHorizontalFlip(),
41 | transforms.ToTensor(),
42 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
43 | ])
44 |
45 | transform_test = transforms.Compose([
46 | transforms.ToTensor(),
47 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
48 | ])
49 |
50 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
51 | transform=transform_train)
52 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True,
53 | num_workers=2)
54 |
55 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
56 | transform=transform_test)
57 | test_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
58 |
59 | # classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
60 |
61 | return train_loader, test_loader
62 |
63 |
64 | def get_ckpt_name(model='resnet', optimizer='sgd', lr=0.1, final_lr=0.1, momentum=0.9,
65 | beta1=0.9, beta2=0.999, gamma=1e-3):
66 | name = {
67 | 'sgd': 'lr{}-momentum{}'.format(lr, momentum),
68 | 'adagrad': 'lr{}'.format(lr),
69 | 'adam': 'lr{}-betas{}-{}'.format(lr, beta1, beta2),
70 | 'amsgrad': 'lr{}-betas{}-{}'.format(lr, beta1, beta2),
71 | 'adabound': 'lr{}-betas{}-{}-final_lr{}-gamma{}'.format(lr, beta1, beta2, final_lr, gamma),
72 | 'amsbound': 'lr{}-betas{}-{}-final_lr{}-gamma{}'.format(lr, beta1, beta2, final_lr, gamma),
73 | }[optimizer]
74 | return '{}-{}-{}'.format(model, optimizer, name)
75 |
76 |
77 | def load_checkpoint(ckpt_name):
78 | print('==> Resuming from checkpoint..')
79 | path = os.path.join('checkpoint', ckpt_name)
80 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
81 | assert os.path.exists(path), 'Error: checkpoint {} not found'.format(ckpt_name)
82 |     return torch.load(path)
83 |
84 |
85 | def build_model(args, device, ckpt=None):
86 | print('==> Building model..')
87 | net = {
88 | 'resnet': ResNet34,
89 | 'densenet': DenseNet121,
90 | }[args.model]()
91 | net = net.to(device)
92 | if device == 'cuda':
93 | net = torch.nn.DataParallel(net)
94 | cudnn.benchmark = True
95 |
96 | if ckpt:
97 | net.load_state_dict(ckpt['net'])
98 |
99 | return net
100 |
101 |
102 | def create_optimizer(args, model_params):
103 | if args.optim == 'sgd':
104 | return optim.SGD(model_params, args.lr, momentum=args.momentum,
105 | weight_decay=args.weight_decay)
106 | elif args.optim == 'adagrad':
107 | return optim.Adagrad(model_params, args.lr, weight_decay=args.weight_decay)
108 | elif args.optim == 'adam':
109 | return optim.Adam(model_params, args.lr, betas=(args.beta1, args.beta2),
110 | weight_decay=args.weight_decay)
111 | elif args.optim == 'amsgrad':
112 | return optim.Adam(model_params, args.lr, betas=(args.beta1, args.beta2),
113 | weight_decay=args.weight_decay, amsgrad=True)
114 | elif args.optim == 'adabound':
115 | return AdaBound(model_params, args.lr, betas=(args.beta1, args.beta2),
116 | final_lr=args.final_lr, gamma=args.gamma,
117 | weight_decay=args.weight_decay)
118 | else:
119 | assert args.optim == 'amsbound'
120 | return AdaBound(model_params, args.lr, betas=(args.beta1, args.beta2),
121 | final_lr=args.final_lr, gamma=args.gamma,
122 | weight_decay=args.weight_decay, amsbound=True)
123 |
124 |
125 | def train(net, epoch, device, data_loader, optimizer, criterion):
126 | print('\nEpoch: %d' % epoch)
127 | net.train()
128 | train_loss = 0
129 | correct = 0
130 | total = 0
131 | for batch_idx, (inputs, targets) in enumerate(data_loader):
132 | inputs, targets = inputs.to(device), targets.to(device)
133 | optimizer.zero_grad()
134 | outputs = net(inputs)
135 | loss = criterion(outputs, targets)
136 | loss.backward()
137 | optimizer.step()
138 |
139 | train_loss += loss.item()
140 | _, predicted = outputs.max(1)
141 | total += targets.size(0)
142 | correct += predicted.eq(targets).sum().item()
143 |
144 | accuracy = 100. * correct / total
145 | print('train acc %.3f' % accuracy)
146 |
147 | return accuracy
148 |
149 |
150 | def test(net, device, data_loader, criterion):
151 | net.eval()
152 | test_loss = 0
153 | correct = 0
154 | total = 0
155 | with torch.no_grad():
156 | for batch_idx, (inputs, targets) in enumerate(data_loader):
157 | inputs, targets = inputs.to(device), targets.to(device)
158 | outputs = net(inputs)
159 | loss = criterion(outputs, targets)
160 |
161 | test_loss += loss.item()
162 | _, predicted = outputs.max(1)
163 | total += targets.size(0)
164 | correct += predicted.eq(targets).sum().item()
165 |
166 | accuracy = 100. * correct / total
167 | print(' test acc %.3f' % accuracy)
168 |
169 | return accuracy
170 |
171 |
172 | def main():
173 | parser = get_parser()
174 | args = parser.parse_args()
175 |
176 | train_loader, test_loader = build_dataset()
177 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
178 |
179 | ckpt_name = get_ckpt_name(model=args.model, optimizer=args.optim, lr=args.lr,
180 | final_lr=args.final_lr, momentum=args.momentum,
181 | beta1=args.beta1, beta2=args.beta2, gamma=args.gamma)
182 | if args.resume:
183 | ckpt = load_checkpoint(ckpt_name)
184 | best_acc = ckpt['acc']
185 | start_epoch = ckpt['epoch']
186 | else:
187 | ckpt = None
188 | best_acc = 0
189 | start_epoch = -1
190 |
191 | net = build_model(args, device, ckpt=ckpt)
192 | criterion = nn.CrossEntropyLoss()
193 | optimizer = create_optimizer(args, net.parameters())
194 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.1,
195 | last_epoch=start_epoch)
196 |
197 | train_accuracies = []
198 | test_accuracies = []
199 |
200 | for epoch in range(start_epoch + 1, 200):
201 | scheduler.step()
202 | train_acc = train(net, epoch, device, train_loader, optimizer, criterion)
203 | test_acc = test(net, device, test_loader, criterion)
204 |
205 | # Save checkpoint.
206 | if test_acc > best_acc:
207 | print('Saving..')
208 | state = {
209 | 'net': net.state_dict(),
210 | 'acc': test_acc,
211 | 'epoch': epoch,
212 | }
213 | if not os.path.isdir('checkpoint'):
214 | os.mkdir('checkpoint')
215 | torch.save(state, os.path.join('checkpoint', ckpt_name))
216 | best_acc = test_acc
217 |
218 | train_accuracies.append(train_acc)
219 | test_accuracies.append(test_acc)
220 | if not os.path.isdir('curve'):
221 | os.mkdir('curve')
222 | torch.save({'train_acc': train_accuracies, 'test_acc': test_accuracies},
223 | os.path.join('curve', ckpt_name))
224 |
225 |
226 | if __name__ == '__main__':
227 | main()
228 |
--------------------------------------------------------------------------------
/demos/cifar10/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .densenet import *
3 |
--------------------------------------------------------------------------------
/demos/cifar10/models/densenet.py:
--------------------------------------------------------------------------------
1 | """
2 | .. Densely Connected Convolutional Networks:
3 | https://arxiv.org/abs/1608.06993
4 | """
5 | import math
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class Bottleneck(nn.Module):
13 | def __init__(self, in_planes, growth_rate):
14 | super(Bottleneck, self).__init__()
15 | self.bn1 = nn.BatchNorm2d(in_planes)
16 | self.conv1 = nn.Conv2d(in_planes, 4 * growth_rate, kernel_size=1, bias=False)
17 | self.bn2 = nn.BatchNorm2d(4 * growth_rate)
18 | self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1, bias=False)
19 |
20 | def forward(self, x):
21 | out = self.conv1(F.relu(self.bn1(x)))
22 | out = self.conv2(F.relu(self.bn2(out)))
23 | out = torch.cat([out, x], 1)
24 | return out
25 |
26 |
27 | class Transition(nn.Module):
28 | def __init__(self, in_planes, out_planes):
29 | super(Transition, self).__init__()
30 | self.bn = nn.BatchNorm2d(in_planes)
31 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)
32 |
33 | def forward(self, x):
34 | out = self.conv(F.relu(self.bn(x)))
35 | out = F.avg_pool2d(out, 2)
36 | return out
37 |
38 |
39 | class DenseNet(nn.Module):
40 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10):
41 | super(DenseNet, self).__init__()
42 | self.growth_rate = growth_rate
43 |
44 | num_planes = 2 * growth_rate
45 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)
46 |
47 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0])
48 | num_planes += nblocks[0] * growth_rate
49 | out_planes = int(math.floor(num_planes * reduction))
50 | self.trans1 = Transition(num_planes, out_planes)
51 | num_planes = out_planes
52 |
53 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1])
54 | num_planes += nblocks[1] * growth_rate
55 | out_planes = int(math.floor(num_planes * reduction))
56 | self.trans2 = Transition(num_planes, out_planes)
57 | num_planes = out_planes
58 |
59 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2])
60 | num_planes += nblocks[2] * growth_rate
61 | out_planes = int(math.floor(num_planes * reduction))
62 | self.trans3 = Transition(num_planes, out_planes)
63 | num_planes = out_planes
64 |
65 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3])
66 | num_planes += nblocks[3] * growth_rate
67 |
68 | self.bn = nn.BatchNorm2d(num_planes)
69 | self.linear = nn.Linear(num_planes, num_classes)
70 |
71 | def _make_dense_layers(self, block, in_planes, nblock):
72 | layers = []
73 | for i in range(nblock):
74 | layers.append(block(in_planes, self.growth_rate))
75 | in_planes += self.growth_rate
76 | return nn.Sequential(*layers)
77 |
78 | def forward(self, x):
79 | out = self.conv1(x)
80 | out = self.trans1(self.dense1(out))
81 | out = self.trans2(self.dense2(out))
82 | out = self.trans3(self.dense3(out))
83 | out = self.dense4(out)
84 | out = F.avg_pool2d(F.relu(self.bn(out)), 4)
85 | out = out.view(out.size(0), -1)
86 | out = self.linear(out)
87 | return out
88 |
89 |
90 | def DenseNet121():
91 | return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=32)
92 |
93 |
94 | def DenseNet169():
95 | return DenseNet(Bottleneck, [6, 12, 32, 32], growth_rate=32)
96 |
97 |
98 | def DenseNet201():
99 | return DenseNet(Bottleneck, [6, 12, 48, 32], growth_rate=32)
100 |
101 |
102 | def DenseNet161():
103 | return DenseNet(Bottleneck, [6, 12, 36, 24], growth_rate=48)
104 |
105 |
106 | def densenet_cifar():
107 | return DenseNet(Bottleneck, [6, 12, 24, 16], growth_rate=12)
108 |
109 |
110 | def test():
111 | net = densenet_cifar()
112 | x = torch.randn(1, 3, 32, 32)
113 | y = net(x)
114 | print(y)
115 |
116 | # test()
117 |
--------------------------------------------------------------------------------
/demos/cifar10/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | .. Deep Residual Learning for Image Recognition:
3 | https://arxiv.org/abs/1512.03385
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class BasicBlock(nn.Module):
11 | expansion = 1
12 |
13 | def __init__(self, in_planes, planes, stride=1):
14 | super(BasicBlock, self).__init__()
15 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1,
16 | bias=False)
17 | self.bn1 = nn.BatchNorm2d(planes)
18 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
19 | self.bn2 = nn.BatchNorm2d(planes)
20 |
21 | self.shortcut = nn.Sequential()
22 | if stride != 1 or in_planes != self.expansion * planes:
23 | self.shortcut = nn.Sequential(
24 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride,
25 | bias=False),
26 | nn.BatchNorm2d(self.expansion * planes)
27 | )
28 |
29 | def forward(self, x):
30 | out = F.relu(self.bn1(self.conv1(x)))
31 | out = self.bn2(self.conv2(out))
32 | out += self.shortcut(x)
33 | out = F.relu(out)
34 | return out
35 |
36 |
37 | class Bottleneck(nn.Module):
38 | expansion = 4
39 |
40 | def __init__(self, in_planes, planes, stride=1):
41 | super(Bottleneck, self).__init__()
42 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
43 | self.bn1 = nn.BatchNorm2d(planes)
44 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
45 | self.bn2 = nn.BatchNorm2d(planes)
46 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
47 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
48 |
49 | self.shortcut = nn.Sequential()
50 | if stride != 1 or in_planes != self.expansion * planes:
51 | self.shortcut = nn.Sequential(
52 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride,
53 | bias=False),
54 | nn.BatchNorm2d(self.expansion * planes)
55 | )
56 |
57 | def forward(self, x):
58 | out = F.relu(self.bn1(self.conv1(x)))
59 | out = F.relu(self.bn2(self.conv2(out)))
60 | out = self.bn3(self.conv3(out))
61 | out += self.shortcut(x)
62 | out = F.relu(out)
63 | return out
64 |
65 |
66 | class ResNet(nn.Module):
67 | def __init__(self, block, num_blocks, num_classes=10):
68 | super(ResNet, self).__init__()
69 | self.in_planes = 64
70 |
71 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
72 | self.bn1 = nn.BatchNorm2d(64)
73 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
74 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
75 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
76 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
77 | self.linear = nn.Linear(512 * block.expansion, num_classes)
78 |
79 | def _make_layer(self, block, planes, num_blocks, stride):
80 | strides = [stride] + [1] * (num_blocks - 1)
81 | layers = []
82 | for stride in strides:
83 | layers.append(block(self.in_planes, planes, stride))
84 | self.in_planes = planes * block.expansion
85 | return nn.Sequential(*layers)
86 |
87 | def forward(self, x):
88 | out = F.relu(self.bn1(self.conv1(x)))
89 | out = self.layer1(out)
90 | out = self.layer2(out)
91 | out = self.layer3(out)
92 | out = self.layer4(out)
93 | out = F.avg_pool2d(out, 4)
94 | out = out.view(out.size(0), -1)
95 | out = self.linear(out)
96 | return out
97 |
98 |
99 | def ResNet18():
100 | return ResNet(BasicBlock, [2, 2, 2, 2])
101 |
102 |
103 | def ResNet34():
104 | return ResNet(BasicBlock, [3, 4, 6, 3])
105 |
106 |
107 | def ResNet50():
108 | return ResNet(Bottleneck, [3, 4, 6, 3])
109 |
110 |
111 | def ResNet101():
112 | return ResNet(Bottleneck, [3, 4, 23, 3])
113 |
114 |
115 | def ResNet152():
116 | return ResNet(Bottleneck, [3, 8, 36, 3])
117 |
118 |
119 | def test():
120 | net = ResNet18()
121 | y = net(torch.randn(1, 3, 32, 32))
122 | print(y.size())
123 |
124 | # test()
125 |
--------------------------------------------------------------------------------
/demos/cifar10/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=0.4.1
2 | torchvision>=0.2.1
3 |
--------------------------------------------------------------------------------
/release.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python3 setup.py sdist bdist_wheel
3 | twine upload dist/*
4 |
5 | rm -rf build
6 | rm -rf dist
7 | rm -rf *.egg-info
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | __VERSION__ = '0.0.5'
4 |
5 | setup(name='adabound',
6 | version=__VERSION__,
7 |       description='AdaBound optimization algorithm, built on PyTorch.',
8 | long_description=open("README.md").read(),
9 | long_description_content_type="text/markdown",
10 | keywords=['machine learning', 'deep learning'],
11 | classifiers=[
12 | 'Intended Audience :: Science/Research',
13 | 'Development Status :: 3 - Alpha',
14 | 'License :: OSI Approved :: Apache Software License',
15 | 'Programming Language :: Python :: 3.6',
16 | 'Programming Language :: Python :: 3.7',
17 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
18 | ],
19 | url='https://github.com/Luolc/AdaBound',
20 | author='Liangchen Luo',
21 | author_email='luolc.witty@gmail.com',
22 | license='Apache',
23 | packages=['adabound'],
24 | install_requires=[
25 | 'torch>=0.4.0',
26 | ],
27 | zip_safe=False,
28 | python_requires='>=3.6.0')
29 |
--------------------------------------------------------------------------------