├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── cifar_imagenet
│   ├── .gitignore
│   ├── LICENSE
│   ├── README.md
│   ├── cifar.py
│   ├── fourstep.sh
│   ├── imagenet.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── cifar
│   │   │   ├── __init__.py
│   │   │   ├── alexnet.py
│   │   │   ├── densenet.py
│   │   │   ├── preresnet.py
│   │   │   ├── resnet.py
│   │   │   ├── resnext.py
│   │   │   ├── vgg.py
│   │   │   └── wrn.py
│   │   └── imagenet
│   │       ├── __init__.py
│   │       └── resnext.py
│   ├── recipes.md
│   └── utils
│       ├── __init__.py
│       ├── eval.py
│       ├── images
│       │   ├── cifar.png
│       │   └── imagenet.png
│       ├── logger.py
│       ├── misc.py
│       ├── radam.py
│       └── visualize.py
├── img
│   └── variance.png
├── language-model
│   ├── README.md
│   ├── eval_1bw.py
│   ├── model_word_ada
│   │   ├── LM.py
│   │   ├── adaptive.py
│   │   ├── basic.py
│   │   ├── bnlstm.py
│   │   ├── dataset.py
│   │   ├── ddnet.py
│   │   ├── densenet.py
│   │   ├── ldnet.py
│   │   ├── radam.py
│   │   ├── resnet.py
│   │   └── utils.py
│   ├── pre_word_ada
│   │   ├── encode_data2folder.py
│   │   └── gene_map.py
│   ├── recipes.md
│   └── train_1bw.py
├── nmt
│   ├── README.md
│   ├── average_checkpoints.py
│   ├── eval.sh
│   ├── my_module
│   │   ├── __init__.py
│   │   ├── adam2.py
│   │   ├── linear_schedule.py
│   │   ├── novograd.py
│   │   ├── poly_schedule.py
│   │   └── radam.py
│   └── recipes.md
├── radam
│   ├── __init__.py
│   └── radam.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.gitignore.io/api/python,pycharm,jupyternotebooks,visualstudiocode
3 | # Edit at https://www.gitignore.io/?templates=python,pycharm,jupyternotebooks,visualstudiocode
4 |
5 | ### JupyterNotebooks ###
6 | # gitignore template for Jupyter Notebooks
7 | # website: http://jupyter.org/
8 |
9 | .ipynb_checkpoints
10 | */.ipynb_checkpoints/*
11 |
12 | # IPython
13 | profile_default/
14 | ipython_config.py
15 |
16 | # Remove previous ipynb_checkpoints
17 | # git rm -r .ipynb_checkpoints/
18 |
19 | ### PyCharm ###
20 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
21 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
22 |
23 | # User-specific stuff
24 | .idea/**/workspace.xml
25 | .idea/**/tasks.xml
26 | .idea/**/usage.statistics.xml
27 | .idea/**/dictionaries
28 | .idea/**/shelf
29 |
30 | # Generated files
31 | .idea/**/contentModel.xml
32 |
33 | # Sensitive or high-churn files
34 | .idea/**/dataSources/
35 | .idea/**/dataSources.ids
36 | .idea/**/dataSources.local.xml
37 | .idea/**/sqlDataSources.xml
38 | .idea/**/dynamic.xml
39 | .idea/**/uiDesigner.xml
40 | .idea/**/dbnavigator.xml
41 |
42 | # Gradle
43 | .idea/**/gradle.xml
44 | .idea/**/libraries
45 |
46 | # Gradle and Maven with auto-import
47 | # When using Gradle or Maven with auto-import, you should exclude module files,
48 | # since they will be recreated, and may cause churn. Uncomment if using
49 | # auto-import.
50 | # .idea/modules.xml
51 | # .idea/*.iml
52 | # .idea/modules
53 | # *.iml
54 | # *.ipr
55 |
56 | # CMake
57 | cmake-build-*/
58 |
59 | # Mongo Explorer plugin
60 | .idea/**/mongoSettings.xml
61 |
62 | # File-based project format
63 | *.iws
64 |
65 | # IntelliJ
66 | out/
67 |
68 | # mpeltonen/sbt-idea plugin
69 | .idea_modules/
70 |
71 | # JIRA plugin
72 | atlassian-ide-plugin.xml
73 |
74 | # Cursive Clojure plugin
75 | .idea/replstate.xml
76 |
77 | # Crashlytics plugin (for Android Studio and IntelliJ)
78 | com_crashlytics_export_strings.xml
79 | crashlytics.properties
80 | crashlytics-build.properties
81 | fabric.properties
82 |
83 | # Editor-based Rest Client
84 | .idea/httpRequests
85 |
86 | # Android studio 3.1+ serialized cache file
87 | .idea/caches/build_file_checksums.ser
88 |
89 | ### PyCharm Patch ###
90 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
91 |
92 | # *.iml
93 | # modules.xml
94 | # .idea/misc.xml
95 | # *.ipr
96 |
97 | # Sonarlint plugin
98 | .idea/**/sonarlint/
99 |
100 | # SonarQube Plugin
101 | .idea/**/sonarIssues.xml
102 |
103 | # Markdown Navigator plugin
104 | .idea/**/markdown-navigator.xml
105 | .idea/**/markdown-navigator/
106 |
107 | ### Python ###
108 | # Byte-compiled / optimized / DLL files
109 | __pycache__/
110 | *.py[cod]
111 | *$py.class
112 |
113 | # C extensions
114 | *.so
115 |
116 | # Distribution / packaging
117 | .Python
118 | build/
119 | develop-eggs/
120 | dist/
121 | downloads/
122 | eggs/
123 | .eggs/
124 | lib/
125 | lib64/
126 | parts/
127 | sdist/
128 | var/
129 | wheels/
130 | pip-wheel-metadata/
131 | share/python-wheels/
132 | *.egg-info/
133 | .installed.cfg
134 | *.egg
135 | MANIFEST
136 |
137 | # PyInstaller
138 | # Usually these files are written by a python script from a template
139 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
140 | *.manifest
141 | *.spec
142 |
143 | # Installer logs
144 | pip-log.txt
145 | pip-delete-this-directory.txt
146 |
147 | # Unit test / coverage reports
148 | htmlcov/
149 | .tox/
150 | .nox/
151 | .coverage
152 | .coverage.*
153 | .cache
154 | nosetests.xml
155 | coverage.xml
156 | *.cover
157 | .hypothesis/
158 | .pytest_cache/
159 |
160 | # Translations
161 | *.mo
162 | *.pot
163 |
164 | # Scrapy stuff:
165 | .scrapy
166 |
167 | # Sphinx documentation
168 | docs/_build/
169 |
170 | # PyBuilder
171 | target/
172 |
173 | # pyenv
174 | .python-version
175 |
176 | # pipenv
177 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
178 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
179 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
180 | # install all needed dependencies.
181 | #Pipfile.lock
182 |
183 | # celery beat schedule file
184 | celerybeat-schedule
185 |
186 | # SageMath parsed files
187 | *.sage.py
188 |
189 | # Spyder project settings
190 | .spyderproject
191 | .spyproject
192 |
193 | # Rope project settings
194 | .ropeproject
195 |
196 | # Mr Developer
197 | .mr.developer.cfg
198 | .project
199 | .pydevproject
200 |
201 | # mkdocs documentation
202 | /site
203 |
204 | # mypy
205 | .mypy_cache/
206 | .dmypy.json
207 | dmypy.json
208 |
209 | # Pyre type checker
210 | .pyre/
211 |
212 | ### VisualStudioCode ###
213 | .vscode/*
214 | !.vscode/settings.json
215 | !.vscode/tasks.json
216 | !.vscode/launch.json
217 | !.vscode/extensions.json
218 |
219 | ### VisualStudioCode Patch ###
220 | # Ignore all local history of files
221 | .history
222 |
223 | # End of https://www.gitignore.io/api/python,pycharm,jupyternotebooks,visualstudiocode
224 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | install: pip install flake8
3 | script: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [2019] [Liyuan Liu]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [License: Apache 2.0](https://opensource.org/licenses/Apache-2.0)
2 | [Build Status](https://travis-ci.org/LiyuanLucasLiu/RAdam)
3 |
4 | RAdam
5 | On the Variance of the Adaptive Learning Rate and Beyond
6 |
7 | We are in an early-release beta. Expect some adventures and rough edges.
8 |
9 | ## Table of Contents
10 |
11 | - [Introduction](#introduction)
12 | - [Motivation](#motivation)
13 | - [Questions and Discussions](#questions-and-discussions)
14 | - [Quick Start Guide](#quick-start-guide)
15 | - [Related Posts and Repos](#related-posts-and-repos)
16 | - [Citation](#citation)
17 |
18 | ## Introduction
19 | If warmup is the answer, what is the question?
20 |
21 | The learning rate warmup for Adam is a must-have trick for stable training in certain situations (or eps tuning). But the underlying mechanism is largely unknown. In our study, we suggest one fundamental cause is __the large variance of the adaptive learning rates__, and provide both theoretical and empirical evidence to support this claim.
22 |
23 | In addition to explaining __why we should use warmup__, we also propose __RAdam__, a theoretically sound variant of Adam.
24 |
25 | ## Motivation
26 |
27 | As shown in Figure 1, we assume that gradients follow a normal distribution (mean: \mu, variance: 1). The variance of the adaptive learning rate is simulated and plotted in Figure 1 (blue curve). We observe that the adaptive learning rate has a large variance in the early stage of training.
28 |
29 | ![variance](img/variance.png)
30 |
31 | When using a Transformer for NMT, a warmup stage is usually required to avoid convergence problems (e.g., Adam-vanilla converges around 500 PPL in Figure 2, while Adam-warmup successfully converges under 10 PPL).
32 | In further explorations, we notice that if we use an additional 2000 samples to estimate the adaptive learning rate, the convergence problems are avoided (Adam-2k); alternatively, if we increase the value of eps, the convergence problems are also alleviated (Adam-eps).
33 |
34 | Therefore, we conjecture that the large variance in the early stage causes the convergence problem, and further propose Rectified Adam by analytically reducing the large variance. More details can be found in our [paper](https://arxiv.org/abs/1908.03265).
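
To make the simulation behind Figure 1 concrete, here is a minimal Monte Carlo sketch (our illustration, not the paper's simulation code): gradients are drawn i.i.d. from a normal distribution with mean `mu` and variance 1, Adam's bias-corrected second moment is tracked with an illustrative `beta2`, and the variance of the adaptive learning rate `1/sqrt(v_hat)` is estimated at each early step.

```python
# Minimal sketch (illustrative, not the paper's simulation code).
# Gradients g_t ~ N(mu, 1); we estimate Var[1 / sqrt(v_hat_t)] across many runs,
# i.e. the variance of Adam's adaptive learning rate in the first few steps.
import numpy as np

def adaptive_lr_variance(mu=0.0, beta2=0.999, steps=25, runs=10000, seed=0):
    rng = np.random.default_rng(seed)
    grads = rng.normal(mu, 1.0, size=(runs, steps))   # simulated gradients
    v = np.zeros(runs)
    variances = []
    for t in range(1, steps + 1):
        v = beta2 * v + (1.0 - beta2) * grads[:, t - 1] ** 2
        v_hat = v / (1.0 - beta2 ** t)                # bias-corrected second moment
        psi = 1.0 / np.sqrt(v_hat + 1e-12)            # adaptive learning rate
        variances.append(psi.var())
    return variances

if __name__ == '__main__':
    for t, var in enumerate(adaptive_lr_variance(mu=0.0), start=1):
        print('step %2d: Var(adaptive lr) = %.3f' % (t, var))
```

Increasing `mu` (e.g. to 10) shrinks the estimated variance, consistent with the discussion below on gradients that are not zero-meaned.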
35 |
36 | ## Questions and Discussions
37 |
38 | ### Do I need to tune learning rate?
39 |
40 | Yes, the robustness of RAdam is not infinite. In our experiments, it works for a broader range of learning rates, but not for all learning rates.
41 |
42 | ### Notes on Transformer (more discussions can be found in our [Transformer Clinic](https://github.com/LiyuanLucasLiu/Transformer-Clinic) project)
43 |
44 | __Choice of the Original Transformer.__ We choose the original Transformer as our main study object because, without warmup, it suffers from the most serious convergence problems in our experiments. With such serious problems, our controlled experiments can better verify our hypothesis (i.e., we demonstrate that Adam-2k / Adam-eps can avoid spurious local optima by minimal changes).
45 |
46 | __Sensitivity.__ We observe that the Transformer is sensitive to the architecture configuration, despite its efficiency and effectiveness. For example, by changing the position of the layer norm, the model may or may not require warmup to achieve good performance. Intuitively, since the gradients of the attention layers can be sparser and the adaptive learning rates for smaller gradients have a larger variance, these layers are more sensitive. Nevertheless, we believe this problem deserves a more in-depth analysis and is beyond the scope of our study.
47 |
48 | ### Why does warmup have a bigger impact on some models than others?
49 |
50 | Although the adaptive learning rate has a larger variance in the early stage, the exact magnitude depends on the model design. Thus, the convergence problem can be more serious for some models/tasks than for others. In our experiments, we observe that RAdam achieves consistent improvements over vanilla Adam, which verifies that the variance issue exists widely (since we can get better performance by fixing it).
51 |
52 | ### What if the gradient is not zero-meaned?
53 |
54 | As shown in Figure 1 (above), even if the gradient is not zero-meaned, the original adaptive learning rate still has a larger variance at the beginning; thus, applying the rectification helps stabilize the training.
55 |
56 | Another related concern is that, when the mean of the gradient is significantly larger than its variance, the magnitude of the "problematic" variance may not be very large (e.g., in Figure 1, when \mu equals 10, the variance of the adaptive learning rate is relatively small and may not cause problems). We think this provides a possible explanation for why warmup has a bigger impact on some models than others. Still, we suggest that, in real-world applications, neural networks usually have some parameters that meet our assumption well (i.e., their gradient variance is larger than their gradient mean) and need the rectification to stabilize training.
57 |
58 | ### Why does SGD need warmup?
59 |
60 | To the best of our knowledge, the warmup heuristic was originally designed for large-minibatch SGD [0], based on the intuition that the network changes rapidly in the early stage. However, we find that it __does not__ explain why Adam requires warmup. Notice that Adam-2k uses the same large learning rate, yet with a better estimate of the adaptive learning rate it also avoids the convergence problems.
61 |
62 | The reason why warmup sometimes also helps SGD still lacks theoretical support. FYI, when optimizing a simple 2-layer CNN with gradient descent, the theory of [1] can be used to show the benefits of warmup. Specifically, the learning rate must be $O(\cos \phi)$, where $\phi$ is the angle between the current weight and the ground-truth weight, and $\cos \phi$ can be very small due to the high-dimensional space and random initialization. Thus, the learning rate must be very small at the beginning to guarantee convergence. $\cos \phi$, however, improves in the later stage, which allows the learning rate to be larger. Their theory can, to some extent, justify why warmup is needed by gradient descent on neural networks, but it is still far-fetched for real scenarios.
63 |
64 | > [0] Goyal et al, Accurate, Large Minibatch SGD: Training Imagenet in 1 Hour, 2017
65 | > [1] Du et al, Gradient Descent Learns One-hidden-layer CNN: Don’t be Afraid of Spurious Local Minima, 2017
66 |
67 | ## Quick Start Guide
68 |
69 | 1. Directly replace the vanilla Adam with RAdam without changing any settings.
70 | 2. Further tune hyper-parameters (including the learning rate) for better performance.
71 |
72 | Note that the major contribution of our paper is __to identify why we need warmup for Adam__. Although some researchers have successfully improved their model performance (__[user comments](#user-comments)__), considering the difficulty of training NNs, directly plugging in RAdam __may not__ result in an immediate performance boost. Based on our experience, replacing __vanilla Adam__ with RAdam usually results in better performance; however, if __warmup has already been employed and tuned__ in the baseline method, it is necessary to also tune the hyper-parameters for RAdam. A minimal drop-in sketch follows.
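
The sketch below is a concrete version of the two steps above. It assumes the `RAdam` class shipped in `radam/radam.py` of this repo is importable (e.g., installed via the provided `setup.py` or with the repo root on `PYTHONPATH`) and that it follows the usual `torch.optim` constructor convention (`params`, `lr`, `betas`, `weight_decay`); the toy model and data are placeholders.

```python
# Drop-in sketch, assuming `from radam import RAdam` exposes a torch.optim-style optimizer.
import torch
import torch.nn as nn
from radam import RAdam

model = nn.Linear(10, 2)                    # placeholder model
criterion = nn.CrossEntropyLoss()

# Step 1: swap torch.optim.Adam for RAdam without changing other settings.
optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0)

x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()                            # Step 2: further tune lr / betas for your task.
```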
73 |
74 | ## Related Posts and Repos
75 |
76 | ### Unofficial Re-Implementations
77 | RAdam is very easy to implement; we provide PyTorch implementations here, while third-party ones can be found at:
78 |
79 | [Keras Implementation](https://github.com/CyberZHG/keras-radam)
80 |
81 | [Keras Implementation](https://github.com/titu1994/keras_rectified_adam)
82 |
83 | [Julia implementation in Flux.jl](https://fluxml.ai/Flux.jl/stable/training/optimisers/#Flux.Optimise.RADAM)
84 |
85 | ### Unofficial Introduction & Mentions
86 |
87 | We provide a simple introduction in [Motivation](#motivation), and more details can be found in our [paper](https://arxiv.org/abs/1908.03265). There are some unofficial introductions available (with better writing); they are listed here for reference only (contents/claims in our paper are more accurate):
88 |
89 | [Medium Post](https://medium.com/@lessw/new-state-of-the-art-ai-optimizer-rectified-adam-radam-5d854730807b)
90 | > [related Twitter Post](https://twitter.com/jeremyphoward/status/1162118545095852032?ref_src=twsrc%5Etfw)
91 |
92 | [CSDN Post (in Chinese)](https://blog.csdn.net/u014248127/article/details/99696029)
93 |
94 | ### User Comments
95 |
96 | We are happy to see that our algorithms have been found useful by some users :-)
97 |
98 | "...I tested it on ImageNette and quickly got new high accuracy scores for the 5 and 20 epoch 128px leaderboard scores, so I know it works... https://forums.fast.ai/t/meet-radam-imo-the-new-state-of-the-art-ai-optimizer/52656
— Less Wright August 15, 2019
99 |
100 | Thought "sounds interesting, I'll give it a try" - top 5 are vanilla Adam, bottom 4 (I only have access to 4 GPUs) are RAdam... so far looking pretty promising! pic.twitter.com/irvJSeoVfx
— Hamish Dickson (@_mishy) August 16, 2019
101 |
102 | RAdam works great for me! It’s good to several % accuracy for free, but the biggest thing I like is the training stability. RAdam is way more stable! https://medium.com/@mgrankin/radam-works-great-for-me-344d37183943
— Grankin Mikhail August 17, 2019
103 |
104 | "... Also, I achieved higher accuracy results using the newly proposed RAdam optimization function....
105 | https://towardsdatascience.com/optimism-is-on-the-menu-a-recession-is-not-d87cce265b10
— Sameer Ahuja August 24, 2019
106 |
107 | "... Out-of-box RAdam implementation performs better than Adam and finetuned SGD... https://twitter.com/ukrdailo/status/1166265186920980480
— Alex Dailo August 27, 2019
108 |
109 | ## Citation
110 | Please cite the following paper if you find our model useful. Thanks!
111 |
112 | >Liyuan Liu, Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, and Jiawei Han (2020). On the Variance of the Adaptive Learning Rate and Beyond. In the Eighth International Conference on Learning Representations.
113 |
114 | ```
115 | @inproceedings{liu2019radam,
116 | author = {Liu, Liyuan and Jiang, Haoming and He, Pengcheng and Chen, Weizhu and Liu, Xiaodong and Gao, Jianfeng and Han, Jiawei},
117 | booktitle = {Proceedings of the Eighth International Conference on Learning Representations (ICLR 2020)},
118 | month = {April},
119 | title = {On the Variance of the Adaptive Learning Rate and Beyond},
120 | year = {2020}
121 | }
122 |
123 | ```
124 |
--------------------------------------------------------------------------------
/cifar_imagenet/.gitignore:
--------------------------------------------------------------------------------
1 | data/*
2 | checkpoint/*
3 |
--------------------------------------------------------------------------------
/cifar_imagenet/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Wei Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/cifar_imagenet/README.md:
--------------------------------------------------------------------------------
1 | # CIFAR & IMAGENET
2 |
3 | This folder is based on (and modified from) the PyTorch classification project ([original repo](https://github.com/bearpaw/pytorch-classification)). For more details about this code base, please refer to the original repo.
4 | A training [recipe](/cifar_imagenet/recipes.md) is provided for image classification experiments.
5 |
6 |
--------------------------------------------------------------------------------
/cifar_imagenet/fourstep.sh:
--------------------------------------------------------------------------------
1 |
2 | ROOT="cps"
3 | for SEED in 1111 2222 3333 4444 5555
4 | do
5 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer radam4s --beta1 0.9 --beta2 0.999 --checkpoint $ROOT/cifar10/resnet-20-adam4s-01-$SEED --gpu-id 0 --lr 0.1 --model_name adam4s --manualSeed $SEED
6 |
7 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer radam4s --beta1 0.9 --beta2 0.999 --checkpoint $ROOT/cifar10/resnet-20-adam4s-ua-01-$SEED --gpu-id 0 --lr 0.1 --model_name adam4s_ua --update_all --manualSeed $SEED
8 |
9 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer radam4s --beta1 0.9 --beta2 0.999 --checkpoint $ROOT/cifar10/resnet-20-adam4s-ua-af-01-$SEED --gpu-id 0 --lr 0.1 --model_name adam4s_ua_af --update_all --additional_four --manualSeed $SEED
10 | done
11 |
--------------------------------------------------------------------------------
/cifar_imagenet/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiyuanLucasLiu/RAdam/d9fd30a337894c4003768561d45e8730dbd41333/cifar_imagenet/models/__init__.py
--------------------------------------------------------------------------------
/cifar_imagenet/models/cifar/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | """The models subpackage contains definitions for the following model for CIFAR10/CIFAR100
4 | architectures:
5 |
6 | - `AlexNet`_
7 | - `VGG`_
8 | - `ResNet`_
9 | - `SqueezeNet`_
10 | - `DenseNet`_
11 |
12 | You can construct a model with random weights by calling its constructor:
13 |
14 | .. code:: python
15 |
16 | import torchvision.models as models
17 | resnet18 = models.resnet18()
18 | alexnet = models.alexnet()
19 | squeezenet = models.squeezenet1_0()
20 | densenet = models.densenet_161()
21 |
22 | We provide pre-trained models for the ResNet variants and AlexNet, using the
23 | PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing
24 | ``pretrained=True``:
25 |
26 | .. code:: python
27 |
28 | import torchvision.models as models
29 | resnet18 = models.resnet18(pretrained=True)
30 | alexnet = models.alexnet(pretrained=True)
31 |
32 | ImageNet 1-crop error rates (224x224)
33 |
34 | ======================== ============= =============
35 | Network Top-1 error Top-5 error
36 | ======================== ============= =============
37 | ResNet-18 30.24 10.92
38 | ResNet-34 26.70 8.58
39 | ResNet-50 23.85 7.13
40 | ResNet-101 22.63 6.44
41 | ResNet-152 21.69 5.94
42 | Inception v3 22.55 6.44
43 | AlexNet 43.45 20.91
44 | VGG-11 30.98 11.37
45 | VGG-13 30.07 10.75
46 | VGG-16 28.41 9.62
47 | VGG-19 27.62 9.12
48 | SqueezeNet 1.0 41.90 19.58
49 | SqueezeNet 1.1 41.81 19.38
50 | Densenet-121 25.35 7.83
51 | Densenet-169 24.00 7.00
52 | Densenet-201 22.80 6.43
53 | Densenet-161 22.35 6.20
54 | ======================== ============= =============
55 |
56 |
57 | .. _AlexNet: https://arxiv.org/abs/1404.5997
58 | .. _VGG: https://arxiv.org/abs/1409.1556
59 | .. _ResNet: https://arxiv.org/abs/1512.03385
60 | .. _SqueezeNet: https://arxiv.org/abs/1602.07360
61 | .. _DenseNet: https://arxiv.org/abs/1608.06993
62 | """
63 |
64 | from .alexnet import *
65 | from .vgg import *
66 | from .resnet import *
67 | from .resnext import *
68 | from .wrn import *
69 | from .preresnet import *
70 | from .densenet import *
71 |
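
The docstring above is inherited from torchvision; in this repository the CIFAR variants are built by calling the factory functions re-exported above. Below is a small usage sketch of ours (not part of the original file), assuming it is run from the `cifar_imagenet` directory so that the `models.cifar` package is importable.

```python
# Usage sketch (illustrative); run from the cifar_imagenet directory.
import torch
import models.cifar as models

# Factory functions re-exported by this __init__.py; argument names come from the model files.
alex = models.alexnet(num_classes=10)
vgg = models.vgg11_bn(num_classes=100)

x = torch.randn(8, 3, 32, 32)               # dummy CIFAR-sized batch
print(alex(x).shape)                        # torch.Size([8, 10])
print(vgg(x).shape)                         # torch.Size([8, 100])
```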
--------------------------------------------------------------------------------
/cifar_imagenet/models/cifar/alexnet.py:
--------------------------------------------------------------------------------
1 | '''AlexNet for CIFAR10. FC layers are removed. Paddings are adjusted.
2 | Without BN, the start learning rate should be 0.01
3 | (c) YANG, Wei
4 | '''
5 | import torch.nn as nn
6 |
7 |
8 | __all__ = ['alexnet']
9 |
10 |
11 | class AlexNet(nn.Module):
12 |
13 | def __init__(self, num_classes=10):
14 | super(AlexNet, self).__init__()
15 | self.features = nn.Sequential(
16 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5),
17 | nn.ReLU(inplace=True),
18 | nn.MaxPool2d(kernel_size=2, stride=2),
19 | nn.Conv2d(64, 192, kernel_size=5, padding=2),
20 | nn.ReLU(inplace=True),
21 | nn.MaxPool2d(kernel_size=2, stride=2),
22 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
23 | nn.ReLU(inplace=True),
24 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
25 | nn.ReLU(inplace=True),
26 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
27 | nn.ReLU(inplace=True),
28 | nn.MaxPool2d(kernel_size=2, stride=2),
29 | )
30 | self.classifier = nn.Linear(256, num_classes)
31 |
32 | def forward(self, x):
33 | x = self.features(x)
34 | x = x.view(x.size(0), -1)
35 | x = self.classifier(x)
36 | return x
37 |
38 |
39 | def alexnet(**kwargs):
40 | r"""AlexNet model architecture from the
41 | `"One weird trick..." `_ paper.
42 | """
43 | model = AlexNet(**kwargs)
44 | return model
45 |
--------------------------------------------------------------------------------
/cifar_imagenet/models/cifar/densenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 |
7 | __all__ = ['densenet']
8 |
9 |
10 | from torch.autograd import Variable
11 |
12 | class Bottleneck(nn.Module):
13 | def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0):
14 | super(Bottleneck, self).__init__()
15 | planes = expansion * growthRate
16 | self.bn1 = nn.BatchNorm2d(inplanes)
17 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
18 | self.bn2 = nn.BatchNorm2d(planes)
19 | self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3,
20 | padding=1, bias=False)
21 | self.relu = nn.ReLU(inplace=True)
22 | self.dropRate = dropRate
23 |
24 | def forward(self, x):
25 | out = self.bn1(x)
26 | out = self.relu(out)
27 | out = self.conv1(out)
28 | out = self.bn2(out)
29 | out = self.relu(out)
30 | out = self.conv2(out)
31 | if self.dropRate > 0:
32 | out = F.dropout(out, p=self.dropRate, training=self.training)
33 |
34 | out = torch.cat((x, out), 1)
35 |
36 | return out
37 |
38 |
39 | class BasicBlock(nn.Module):
40 | def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0):
41 | super(BasicBlock, self).__init__()
42 | planes = expansion * growthRate
43 | self.bn1 = nn.BatchNorm2d(inplanes)
44 | self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3,
45 | padding=1, bias=False)
46 | self.relu = nn.ReLU(inplace=True)
47 | self.dropRate = dropRate
48 |
49 | def forward(self, x):
50 | out = self.bn1(x)
51 | out = self.relu(out)
52 | out = self.conv1(out)
53 | if self.dropRate > 0:
54 | out = F.dropout(out, p=self.dropRate, training=self.training)
55 |
56 | out = torch.cat((x, out), 1)
57 |
58 | return out
59 |
60 |
61 | class Transition(nn.Module):
62 | def __init__(self, inplanes, outplanes):
63 | super(Transition, self).__init__()
64 | self.bn1 = nn.BatchNorm2d(inplanes)
65 | self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1,
66 | bias=False)
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | def forward(self, x):
70 | out = self.bn1(x)
71 | out = self.relu(out)
72 | out = self.conv1(out)
73 | out = F.avg_pool2d(out, 2)
74 | return out
75 |
76 |
77 | class DenseNet(nn.Module):
78 |
79 | def __init__(self, depth=22, block=Bottleneck,
80 | dropRate=0, num_classes=10, growthRate=12, compressionRate=2):
81 | super(DenseNet, self).__init__()
82 |
83 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4'
84 | n = (depth - 4) // 3 if block == BasicBlock else (depth - 4) // 6  # integer division so n can be used as a layer count
85 |
86 | self.growthRate = growthRate
87 | self.dropRate = dropRate
88 |
89 | # self.inplanes is a global variable used across multiple
90 | # helper functions
91 | self.inplanes = growthRate * 2
92 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1,
93 | bias=False)
94 | self.dense1 = self._make_denseblock(block, n)
95 | self.trans1 = self._make_transition(compressionRate)
96 | self.dense2 = self._make_denseblock(block, n)
97 | self.trans2 = self._make_transition(compressionRate)
98 | self.dense3 = self._make_denseblock(block, n)
99 | self.bn = nn.BatchNorm2d(self.inplanes)
100 | self.relu = nn.ReLU(inplace=True)
101 | self.avgpool = nn.AvgPool2d(8)
102 | self.fc = nn.Linear(self.inplanes, num_classes)
103 |
104 | # Weight initialization
105 | for m in self.modules():
106 | if isinstance(m, nn.Conv2d):
107 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
108 | m.weight.data.normal_(0, math.sqrt(2. / n))
109 | elif isinstance(m, nn.BatchNorm2d):
110 | m.weight.data.fill_(1)
111 | m.bias.data.zero_()
112 |
113 | def _make_denseblock(self, block, blocks):
114 | layers = []
115 | for i in range(blocks):
116 | # Currently we fix the expansion ratio as the default value
117 | layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate))
118 | self.inplanes += self.growthRate
119 |
120 | return nn.Sequential(*layers)
121 |
122 | def _make_transition(self, compressionRate):
123 | inplanes = self.inplanes
124 | outplanes = int(math.floor(self.inplanes // compressionRate))
125 | self.inplanes = outplanes
126 | return Transition(inplanes, outplanes)
127 |
128 |
129 | def forward(self, x):
130 | x = self.conv1(x)
131 |
132 | x = self.trans1(self.dense1(x))
133 | x = self.trans2(self.dense2(x))
134 | x = self.dense3(x)
135 | x = self.bn(x)
136 | x = self.relu(x)
137 |
138 | x = self.avgpool(x)
139 | x = x.view(x.size(0), -1)
140 | x = self.fc(x)
141 |
142 | return x
143 |
144 |
145 | def densenet(**kwargs):
146 | """
147 | Constructs a DenseNet model.
148 | """
149 | return DenseNet(**kwargs)
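
A small usage sketch for the factory above (ours, not part of the file): `depth` must satisfy the `(depth - 4) % 3 == 0` assertion, e.g. 40 or 100 with the default Bottleneck block.

```python
# Usage sketch (illustrative): a DenseNet-BC-style network for CIFAR-10.
import torch

model = densenet(depth=100, growthRate=12, compressionRate=2, num_classes=10)
x = torch.randn(4, 3, 32, 32)               # CIFAR-sized input
logits = model(x)
print(logits.shape)                          # torch.Size([4, 10])
```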
--------------------------------------------------------------------------------
/cifar_imagenet/models/cifar/preresnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | '''Resnet for cifar dataset.
4 | Ported form
5 | https://github.com/facebook/fb.resnet.torch
6 | and
7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
8 | (c) YANG, Wei
9 | '''
10 | import torch.nn as nn
11 | import math
12 |
13 |
14 | __all__ = ['preresnet']
15 |
16 | def conv3x3(in_planes, out_planes, stride=1):
17 | "3x3 convolution with padding"
18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
19 | padding=1, bias=False)
20 |
21 |
22 | class BasicBlock(nn.Module):
23 | expansion = 1
24 |
25 | def __init__(self, inplanes, planes, stride=1, downsample=None):
26 | super(BasicBlock, self).__init__()
27 | self.bn1 = nn.BatchNorm2d(inplanes)
28 | self.relu = nn.ReLU(inplace=True)
29 | self.conv1 = conv3x3(inplanes, planes, stride)
30 | self.bn2 = nn.BatchNorm2d(planes)
31 | self.conv2 = conv3x3(planes, planes)
32 | self.downsample = downsample
33 | self.stride = stride
34 |
35 | def forward(self, x):
36 | residual = x
37 |
38 | out = self.bn1(x)
39 | out = self.relu(out)
40 | out = self.conv1(out)
41 |
42 | out = self.bn2(out)
43 | out = self.relu(out)
44 | out = self.conv2(out)
45 |
46 | if self.downsample is not None:
47 | residual = self.downsample(x)
48 |
49 | out += residual
50 |
51 | return out
52 |
53 |
54 | class Bottleneck(nn.Module):
55 | expansion = 4
56 |
57 | def __init__(self, inplanes, planes, stride=1, downsample=None):
58 | super(Bottleneck, self).__init__()
59 | self.bn1 = nn.BatchNorm2d(inplanes)
60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
61 | self.bn2 = nn.BatchNorm2d(planes)
62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
63 | padding=1, bias=False)
64 | self.bn3 = nn.BatchNorm2d(planes)
65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
66 | self.relu = nn.ReLU(inplace=True)
67 | self.downsample = downsample
68 | self.stride = stride
69 |
70 | def forward(self, x):
71 | residual = x
72 |
73 | out = self.bn1(x)
74 | out = self.relu(out)
75 | out = self.conv1(out)
76 |
77 | out = self.bn2(out)
78 | out = self.relu(out)
79 | out = self.conv2(out)
80 |
81 | out = self.bn3(out)
82 | out = self.relu(out)
83 | out = self.conv3(out)
84 |
85 | if self.downsample is not None:
86 | residual = self.downsample(x)
87 |
88 | out += residual
89 |
90 | return out
91 |
92 |
93 | class PreResNet(nn.Module):
94 |
95 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock'):
96 | super(PreResNet, self).__init__()
97 | # Model type specifies number of layers for CIFAR-10 model
98 | if block_name.lower() == 'basicblock':
99 | assert (depth - 2) % 6 == 0, 'When using basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
100 | n = (depth - 2) // 6
101 | block = BasicBlock
102 | elif block_name.lower() == 'bottleneck':
103 | assert (depth - 2) % 9 == 0, 'When using bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
104 | n = (depth - 2) // 9
105 | block = Bottleneck
106 | else:
107 | raise ValueError('block_name should be BasicBlock or Bottleneck')
108 |
109 | self.inplanes = 16
110 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
111 | bias=False)
112 | self.layer1 = self._make_layer(block, 16, n)
113 | self.layer2 = self._make_layer(block, 32, n, stride=2)
114 | self.layer3 = self._make_layer(block, 64, n, stride=2)
115 | self.bn = nn.BatchNorm2d(64 * block.expansion)
116 | self.relu = nn.ReLU(inplace=True)
117 | self.avgpool = nn.AvgPool2d(8)
118 | self.fc = nn.Linear(64 * block.expansion, num_classes)
119 |
120 | for m in self.modules():
121 | if isinstance(m, nn.Conv2d):
122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
123 | m.weight.data.normal_(0, math.sqrt(2. / n))
124 | elif isinstance(m, nn.BatchNorm2d):
125 | m.weight.data.fill_(1)
126 | m.bias.data.zero_()
127 |
128 | def _make_layer(self, block, planes, blocks, stride=1):
129 | downsample = None
130 | if stride != 1 or self.inplanes != planes * block.expansion:
131 | downsample = nn.Sequential(
132 | nn.Conv2d(self.inplanes, planes * block.expansion,
133 | kernel_size=1, stride=stride, bias=False),
134 | )
135 |
136 | layers = []
137 | layers.append(block(self.inplanes, planes, stride, downsample))
138 | self.inplanes = planes * block.expansion
139 | for i in range(1, blocks):
140 | layers.append(block(self.inplanes, planes))
141 |
142 | return nn.Sequential(*layers)
143 |
144 | def forward(self, x):
145 | x = self.conv1(x)
146 |
147 | x = self.layer1(x) # 32x32
148 | x = self.layer2(x) # 16x16
149 | x = self.layer3(x) # 8x8
150 | x = self.bn(x)
151 | x = self.relu(x)
152 |
153 | x = self.avgpool(x)
154 | x = x.view(x.size(0), -1)
155 | x = self.fc(x)
156 |
157 | return x
158 |
159 |
160 | def preresnet(**kwargs):
161 | """
162 | Constructs a PreResNet (pre-activation ResNet) model.
163 | """
164 | return PreResNet(**kwargs)
165 |
--------------------------------------------------------------------------------
/cifar_imagenet/models/cifar/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | '''Resnet for cifar dataset.
4 | Ported form
5 | https://github.com/facebook/fb.resnet.torch
6 | and
7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
8 | (c) YANG, Wei
9 | '''
10 | import torch.nn as nn
11 | import math
12 |
13 |
14 | __all__ = ['resnet']
15 |
16 | def conv3x3(in_planes, out_planes, stride=1):
17 | "3x3 convolution with padding"
18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
19 | padding=1, bias=False)
20 |
21 |
22 | class BasicBlock(nn.Module):
23 | expansion = 1
24 |
25 | def __init__(self, inplanes, planes, stride=1, downsample=None):
26 | super(BasicBlock, self).__init__()
27 | self.conv1 = conv3x3(inplanes, planes, stride)
28 | self.bn1 = nn.BatchNorm2d(planes)
29 | self.relu = nn.ReLU(inplace=True)
30 | self.conv2 = conv3x3(planes, planes)
31 | self.bn2 = nn.BatchNorm2d(planes)
32 | self.downsample = downsample
33 | self.stride = stride
34 |
35 | def forward(self, x):
36 | residual = x
37 |
38 | out = self.conv1(x)
39 | out = self.bn1(out)
40 | out = self.relu(out)
41 |
42 | out = self.conv2(out)
43 | out = self.bn2(out)
44 |
45 | if self.downsample is not None:
46 | residual = self.downsample(x)
47 |
48 | out += residual
49 | out = self.relu(out)
50 |
51 | return out
52 |
53 |
54 | class Bottleneck(nn.Module):
55 | expansion = 4
56 |
57 | def __init__(self, inplanes, planes, stride=1, downsample=None):
58 | super(Bottleneck, self).__init__()
59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
60 | self.bn1 = nn.BatchNorm2d(planes)
61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
62 | padding=1, bias=False)
63 | self.bn2 = nn.BatchNorm2d(planes)
64 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
65 | self.bn3 = nn.BatchNorm2d(planes * 4)
66 | self.relu = nn.ReLU(inplace=True)
67 | self.downsample = downsample
68 | self.stride = stride
69 |
70 | def forward(self, x):
71 | residual = x
72 |
73 | out = self.conv1(x)
74 | out = self.bn1(out)
75 | out = self.relu(out)
76 |
77 | out = self.conv2(out)
78 | out = self.bn2(out)
79 | out = self.relu(out)
80 |
81 | out = self.conv3(out)
82 | out = self.bn3(out)
83 |
84 | if self.downsample is not None:
85 | residual = self.downsample(x)
86 |
87 | out += residual
88 | out = self.relu(out)
89 |
90 | return out
91 |
92 |
93 | class ResNet(nn.Module):
94 |
95 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock'):
96 | super(ResNet, self).__init__()
97 | # Model type specifies number of layers for CIFAR-10 model
98 | if block_name.lower() == 'basicblock':
99 | assert (depth - 2) % 6 == 0, 'When using basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
100 | n = (depth - 2) // 6
101 | block = BasicBlock
102 | elif block_name.lower() == 'bottleneck':
103 | assert (depth - 2) % 9 == 0, 'When using bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
104 | n = (depth - 2) // 9
105 | block = Bottleneck
106 | else:
107 | raise ValueError('block_name should be BasicBlock or Bottleneck')
108 |
109 |
110 | self.inplanes = 16
111 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,
112 | bias=False)
113 | self.bn1 = nn.BatchNorm2d(16)
114 | self.relu = nn.ReLU(inplace=True)
115 | self.layer1 = self._make_layer(block, 16, n)
116 | self.layer2 = self._make_layer(block, 32, n, stride=2)
117 | self.layer3 = self._make_layer(block, 64, n, stride=2)
118 | self.avgpool = nn.AvgPool2d(8)
119 | self.fc = nn.Linear(64 * block.expansion, num_classes)
120 |
121 | for m in self.modules():
122 | if isinstance(m, nn.Conv2d):
123 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
124 | m.weight.data.normal_(0, math.sqrt(2. / n))
125 | elif isinstance(m, nn.BatchNorm2d):
126 | m.weight.data.fill_(1)
127 | m.bias.data.zero_()
128 |
129 | def _make_layer(self, block, planes, blocks, stride=1):
130 | downsample = None
131 | if stride != 1 or self.inplanes != planes * block.expansion:
132 | downsample = nn.Sequential(
133 | nn.Conv2d(self.inplanes, planes * block.expansion,
134 | kernel_size=1, stride=stride, bias=False),
135 | nn.BatchNorm2d(planes * block.expansion),
136 | )
137 |
138 | layers = []
139 | layers.append(block(self.inplanes, planes, stride, downsample))
140 | self.inplanes = planes * block.expansion
141 | for i in range(1, blocks):
142 | layers.append(block(self.inplanes, planes))
143 |
144 | return nn.Sequential(*layers)
145 |
146 | def forward(self, x):
147 | x = self.conv1(x)
148 | x = self.bn1(x)
149 | x = self.relu(x) # 32x32
150 |
151 | x = self.layer1(x) # 32x32
152 | x = self.layer2(x) # 16x16
153 | x = self.layer3(x) # 8x8
154 |
155 | x = self.avgpool(x)
156 | x = x.view(x.size(0), -1)
157 | x = self.fc(x)
158 |
159 | return x
160 |
161 |
162 | def resnet(**kwargs):
163 | """
164 | Constructs a ResNet model.
165 | """
166 | return ResNet(**kwargs)
167 |
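
A small usage sketch for the factory above (ours, not part of the file): with `BasicBlock`, `depth` must be 6n+2, so `depth=20` gives the ResNet-20 configuration used in `fourstep.sh`.

```python
# Usage sketch (illustrative): the CIFAR ResNet-20 configuration used in fourstep.sh.
import torch

model = resnet(depth=20, num_classes=10, block_name='BasicBlock')
x = torch.randn(4, 3, 32, 32)               # CIFAR-sized input
logits = model(x)
print(logits.shape)                          # torch.Size([4, 10])
```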
--------------------------------------------------------------------------------
/cifar_imagenet/models/cifar/resnext.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | """
3 | Creates a ResNeXt Model as defined in:
4 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016).
5 | Aggregated residual transformations for deep neural networks.
6 | arXiv preprint arXiv:1611.05431.
7 | import from https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py
8 | """
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.nn import init
12 |
13 | __all__ = ['resnext']
14 |
15 | class ResNeXtBottleneck(nn.Module):
16 | """
17 | RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua)
18 | """
19 | def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor):
20 | """ Constructor
21 | Args:
22 | in_channels: input channel dimensionality
23 | out_channels: output channel dimensionality
24 | stride: conv stride. Replaces pooling layer.
25 | cardinality: num of convolution groups.
26 | widen_factor: factor to reduce the input dimensionality before convolution.
27 | """
28 | super(ResNeXtBottleneck, self).__init__()
29 | D = cardinality * out_channels // widen_factor
30 | self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False)
31 | self.bn_reduce = nn.BatchNorm2d(D)
32 | self.conv_conv = nn.Conv2d(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
33 | self.bn = nn.BatchNorm2d(D)
34 | self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
35 | self.bn_expand = nn.BatchNorm2d(out_channels)
36 |
37 | self.shortcut = nn.Sequential()
38 | if in_channels != out_channels:
39 | self.shortcut.add_module('shortcut_conv', nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False))
40 | self.shortcut.add_module('shortcut_bn', nn.BatchNorm2d(out_channels))
41 |
42 | def forward(self, x):
43 | bottleneck = self.conv_reduce.forward(x)
44 | bottleneck = F.relu(self.bn_reduce.forward(bottleneck), inplace=True)
45 | bottleneck = self.conv_conv.forward(bottleneck)
46 | bottleneck = F.relu(self.bn.forward(bottleneck), inplace=True)
47 | bottleneck = self.conv_expand.forward(bottleneck)
48 | bottleneck = self.bn_expand.forward(bottleneck)
49 | residual = self.shortcut.forward(x)
50 | return F.relu(residual + bottleneck, inplace=True)
51 |
52 |
53 | class CifarResNeXt(nn.Module):
54 | """
55 | ResNext optimized for the Cifar dataset, as specified in
56 | https://arxiv.org/pdf/1611.05431.pdf
57 | """
58 | def __init__(self, cardinality, depth, num_classes, widen_factor=4, dropRate=0):
59 | """ Constructor
60 | Args:
61 | cardinality: number of convolution groups.
62 | depth: number of layers.
63 | num_classes: number of classes
64 | widen_factor: factor to adjust the channel dimensionality
65 | """
66 | super(CifarResNeXt, self).__init__()
67 | self.cardinality = cardinality
68 | self.depth = depth
69 | self.block_depth = (self.depth - 2) // 9
70 | self.widen_factor = widen_factor
71 | self.num_classes = num_classes
72 | self.output_size = 64
73 | self.stages = [64, 64 * self.widen_factor, 128 * self.widen_factor, 256 * self.widen_factor]
74 |
75 | self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
76 | self.bn_1 = nn.BatchNorm2d(64)
77 | self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1)
78 | self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2)
79 | self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2)
80 | self.classifier = nn.Linear(1024, num_classes)
81 | init.kaiming_normal(self.classifier.weight)
82 |
83 | for key in self.state_dict():
84 | if key.split('.')[-1] == 'weight':
85 | if 'conv' in key:
86 | init.kaiming_normal(self.state_dict()[key], mode='fan_out')
87 | if 'bn' in key:
88 | self.state_dict()[key][...] = 1
89 | elif key.split('.')[-1] == 'bias':
90 | self.state_dict()[key][...] = 0
91 |
92 | def block(self, name, in_channels, out_channels, pool_stride=2):
93 | """ Stack n bottleneck modules where n is inferred from the depth of the network.
94 | Args:
95 | name: string name of the current block.
96 | in_channels: number of input channels
97 | out_channels: number of output channels
98 | pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
99 | Returns: a Module consisting of n sequential bottlenecks.
100 | """
101 | block = nn.Sequential()
102 | for bottleneck in range(self.block_depth):
103 | name_ = '%s_bottleneck_%d' % (name, bottleneck)
104 | if bottleneck == 0:
105 | block.add_module(name_, ResNeXtBottleneck(in_channels, out_channels, pool_stride, self.cardinality,
106 | self.widen_factor))
107 | else:
108 | block.add_module(name_,
109 | ResNeXtBottleneck(out_channels, out_channels, 1, self.cardinality, self.widen_factor))
110 | return block
111 |
112 | def forward(self, x):
113 | x = self.conv_1_3x3.forward(x)
114 | x = F.relu(self.bn_1.forward(x), inplace=True)
115 | x = self.stage_1.forward(x)
116 | x = self.stage_2.forward(x)
117 | x = self.stage_3.forward(x)
118 | x = F.avg_pool2d(x, 8, 1)
119 | x = x.view(-1, 1024)
120 | return self.classifier(x)
121 |
122 | def resnext(**kwargs):
123 | """Constructs a ResNeXt.
124 | """
125 | model = CifarResNeXt(**kwargs)
126 | return model
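
A small usage sketch for the factory above (ours, not part of the file). Because the classifier is hard-coded to 1024 input features (`256 * widen_factor` with `widen_factor=4`), this constructor effectively expects `widen_factor=4`; `depth` must be 9n+2.

```python
# Usage sketch (illustrative): a ResNeXt-29-style network for CIFAR-10.
import torch

model = resnext(cardinality=8, depth=29, num_classes=10, widen_factor=4)
x = torch.randn(4, 3, 32, 32)               # CIFAR-sized input
logits = model(x)
print(logits.shape)                          # torch.Size([4, 10])
```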
--------------------------------------------------------------------------------
/cifar_imagenet/models/cifar/vgg.py:
--------------------------------------------------------------------------------
1 | '''VGG for CIFAR10. FC layers are removed.
2 | (c) YANG, Wei
3 | '''
4 | import torch.nn as nn
5 | import torch.utils.model_zoo as model_zoo
6 | import math
7 |
8 |
9 | __all__ = [
10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
11 | 'vgg19_bn', 'vgg19',
12 | ]
13 |
14 |
15 | model_urls = {
16 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
17 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
18 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
19 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
20 | }
21 |
22 |
23 | class VGG(nn.Module):
24 |
25 | def __init__(self, features, num_classes=1000):
26 | super(VGG, self).__init__()
27 | self.features = features
28 | self.classifier = nn.Linear(512, num_classes)
29 | self._initialize_weights()
30 |
31 | def forward(self, x):
32 | x = self.features(x)
33 | x = x.view(x.size(0), -1)
34 | x = self.classifier(x)
35 | return x
36 |
37 | def _initialize_weights(self):
38 | for m in self.modules():
39 | if isinstance(m, nn.Conv2d):
40 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
41 | m.weight.data.normal_(0, math.sqrt(2. / n))
42 | if m.bias is not None:
43 | m.bias.data.zero_()
44 | elif isinstance(m, nn.BatchNorm2d):
45 | m.weight.data.fill_(1)
46 | m.bias.data.zero_()
47 | elif isinstance(m, nn.Linear):
48 | n = m.weight.size(1)
49 | m.weight.data.normal_(0, 0.01)
50 | m.bias.data.zero_()
51 |
52 |
53 | def make_layers(cfg, batch_norm=False):
54 | layers = []
55 | in_channels = 3
56 | for v in cfg:
57 | if v == 'M':
58 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
59 | else:
60 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
61 | if batch_norm:
62 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
63 | else:
64 | layers += [conv2d, nn.ReLU(inplace=True)]
65 | in_channels = v
66 | return nn.Sequential(*layers)
67 |
68 |
69 | cfg = {
70 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
71 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
72 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
73 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
74 | }
75 |
76 |
77 | def vgg11(**kwargs):
78 | """VGG 11-layer model (configuration "A")
79 |
80 | Args:
 81 |         **kwargs: forwarded to the VGG constructor (e.g. num_classes); pretrained ImageNet weights are not loaded by this CIFAR variant
82 | """
83 | model = VGG(make_layers(cfg['A']), **kwargs)
84 | return model
85 |
86 |
87 | def vgg11_bn(**kwargs):
88 | """VGG 11-layer model (configuration "A") with batch normalization"""
89 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
90 | return model
91 |
92 |
93 | def vgg13(**kwargs):
94 | """VGG 13-layer model (configuration "B")
95 |
96 | Args:
 97 |         **kwargs: forwarded to the VGG constructor (e.g. num_classes); pretrained ImageNet weights are not loaded by this CIFAR variant
98 | """
99 | model = VGG(make_layers(cfg['B']), **kwargs)
100 | return model
101 |
102 |
103 | def vgg13_bn(**kwargs):
104 | """VGG 13-layer model (configuration "B") with batch normalization"""
105 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
106 | return model
107 |
108 |
109 | def vgg16(**kwargs):
110 | """VGG 16-layer model (configuration "D")
111 |
112 | Args:
113 |         **kwargs: forwarded to the VGG constructor (e.g. num_classes); pretrained ImageNet weights are not loaded by this CIFAR variant
114 | """
115 | model = VGG(make_layers(cfg['D']), **kwargs)
116 | return model
117 |
118 |
119 | def vgg16_bn(**kwargs):
120 | """VGG 16-layer model (configuration "D") with batch normalization"""
121 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
122 | return model
123 |
124 |
125 | def vgg19(**kwargs):
126 | """VGG 19-layer model (configuration "E")
127 |
128 | Args:
129 |         **kwargs: forwarded to the VGG constructor (e.g. num_classes); pretrained ImageNet weights are not loaded by this CIFAR variant
130 | """
131 | model = VGG(make_layers(cfg['E']), **kwargs)
132 | return model
133 |
134 |
135 | def vgg19_bn(**kwargs):
136 | """VGG 19-layer model (configuration 'E') with batch normalization"""
137 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
138 | return model
139 |
--------------------------------------------------------------------------------
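
A quick shape check (not from the repository) for the CIFAR VGG variant above; `num_classes=10` is illustrative and the import path assumes the `cifar_imagenet` package layout.

```python
import torch
from models.cifar.vgg import vgg16_bn  # import path assumed from the repo layout

model = vgg16_bn(num_classes=10)
out = model(torch.randn(4, 3, 32, 32))  # five 2x2 max-pools reduce 32x32 to 1x1x512
print(out.shape)  # torch.Size([4, 10])
```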
/cifar_imagenet/models/cifar/wrn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | __all__ = ['wrn']
7 |
8 | class BasicBlock(nn.Module):
9 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
10 | super(BasicBlock, self).__init__()
11 | self.bn1 = nn.BatchNorm2d(in_planes)
12 | self.relu1 = nn.ReLU(inplace=True)
13 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
14 | padding=1, bias=False)
15 | self.bn2 = nn.BatchNorm2d(out_planes)
16 | self.relu2 = nn.ReLU(inplace=True)
17 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
18 | padding=1, bias=False)
19 | self.droprate = dropRate
20 | self.equalInOut = (in_planes == out_planes)
21 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
22 | padding=0, bias=False) or None
23 | def forward(self, x):
24 | if not self.equalInOut:
25 | x = self.relu1(self.bn1(x))
26 | else:
27 | out = self.relu1(self.bn1(x))
28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
29 | if self.droprate > 0:
30 | out = F.dropout(out, p=self.droprate, training=self.training)
31 | out = self.conv2(out)
32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
33 |
34 | class NetworkBlock(nn.Module):
35 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
36 | super(NetworkBlock, self).__init__()
37 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
38 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
39 | layers = []
40 | for i in range(nb_layers):
41 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate))
42 | return nn.Sequential(*layers)
43 | def forward(self, x):
44 | return self.layer(x)
45 |
46 | class WideResNet(nn.Module):
47 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
48 | super(WideResNet, self).__init__()
49 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
50 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
51 | n = (depth - 4) // 6
52 | block = BasicBlock
53 | # 1st conv before any network block
54 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
55 | padding=1, bias=False)
56 | # 1st block
57 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
58 | # 2nd block
59 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
60 | # 3rd block
61 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
62 | # global average pooling and classifier
63 | self.bn1 = nn.BatchNorm2d(nChannels[3])
64 | self.relu = nn.ReLU(inplace=True)
65 | self.fc = nn.Linear(nChannels[3], num_classes)
66 | self.nChannels = nChannels[3]
67 |
68 | for m in self.modules():
69 | if isinstance(m, nn.Conv2d):
70 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
71 | m.weight.data.normal_(0, math.sqrt(2. / n))
72 | elif isinstance(m, nn.BatchNorm2d):
73 | m.weight.data.fill_(1)
74 | m.bias.data.zero_()
75 | elif isinstance(m, nn.Linear):
76 | m.bias.data.zero_()
77 |
78 | def forward(self, x):
79 | out = self.conv1(x)
80 | out = self.block1(out)
81 | out = self.block2(out)
82 | out = self.block3(out)
83 | out = self.relu(self.bn1(out))
84 | out = F.avg_pool2d(out, 8)
85 | out = out.view(-1, self.nChannels)
86 | return self.fc(out)
87 |
88 | def wrn(**kwargs):
89 | """
 90 |     Constructs a Wide Residual Network.
91 | """
92 | model = WideResNet(**kwargs)
93 | return model
94 |
--------------------------------------------------------------------------------
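
As a usage sketch (not part of the repository), the well-known WRN-28-10 configuration can be built from the `wrn` factory above; the depth/width values are illustrative and the import path assumes the `cifar_imagenet` layout.

```python
import torch
from models.cifar.wrn import wrn  # import path assumed from the repo layout

# depth must satisfy (depth - 4) % 6 == 0 (see the assert above); depth 28, widen_factor 10 is WRN-28-10
model = wrn(depth=28, num_classes=10, widen_factor=10, dropRate=0.3)
out = model(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 10])
```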
/cifar_imagenet/models/imagenet/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .resnext import *
4 |
--------------------------------------------------------------------------------
/cifar_imagenet/models/imagenet/resnext.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | """
3 | Creates a ResNeXt Model as defined in:
4 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016).
5 | Aggregated residual transformations for deep neural networks.
6 | arXiv preprint arXiv:1611.05431.
7 | import from https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua
8 | """
9 | import math
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch.nn import init
13 | import torch
14 |
15 | __all__ = ['resnext50', 'resnext101', 'resnext152']
16 |
17 | class Bottleneck(nn.Module):
18 | """
 19 |     ResNeXt bottleneck type C
20 | """
21 | expansion = 4
22 |
23 | def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, downsample=None):
24 | """ Constructor
25 | Args:
26 | inplanes: input channel dimensionality
27 | planes: output channel dimensionality
28 | baseWidth: base width.
29 | cardinality: num of convolution groups.
30 | stride: conv stride. Replaces pooling layer.
31 | """
32 | super(Bottleneck, self).__init__()
33 |
34 | D = int(math.floor(planes * (baseWidth / 64)))
35 | C = cardinality
36 |
37 | self.conv1 = nn.Conv2d(inplanes, D*C, kernel_size=1, stride=1, padding=0, bias=False)
38 | self.bn1 = nn.BatchNorm2d(D*C)
39 | self.conv2 = nn.Conv2d(D*C, D*C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
40 | self.bn2 = nn.BatchNorm2d(D*C)
41 | self.conv3 = nn.Conv2d(D*C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
42 | self.bn3 = nn.BatchNorm2d(planes * 4)
43 | self.relu = nn.ReLU(inplace=True)
44 |
45 | self.downsample = downsample
46 |
47 | def forward(self, x):
48 | residual = x
49 |
50 | out = self.conv1(x)
51 | out = self.bn1(out)
52 | out = self.relu(out)
53 |
54 | out = self.conv2(out)
55 | out = self.bn2(out)
56 | out = self.relu(out)
57 |
58 | out = self.conv3(out)
59 | out = self.bn3(out)
60 |
61 | if self.downsample is not None:
62 | residual = self.downsample(x)
63 |
64 | out += residual
65 | out = self.relu(out)
66 |
67 | return out
68 |
69 |
70 | class ResNeXt(nn.Module):
71 | """
72 | ResNext optimized for the ImageNet dataset, as specified in
73 | https://arxiv.org/pdf/1611.05431.pdf
74 | """
75 | def __init__(self, baseWidth, cardinality, layers, num_classes):
76 | """ Constructor
77 | Args:
78 | baseWidth: baseWidth for ResNeXt.
79 | cardinality: number of convolution groups.
80 | layers: config of layers, e.g., [3, 4, 6, 3]
81 | num_classes: number of classes
82 | """
83 | super(ResNeXt, self).__init__()
84 | block = Bottleneck
85 |
86 | self.cardinality = cardinality
87 | self.baseWidth = baseWidth
88 | self.num_classes = num_classes
89 | self.inplanes = 64
90 | self.output_size = 64
91 |
92 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
93 | self.bn1 = nn.BatchNorm2d(64)
94 | self.relu = nn.ReLU(inplace=True)
95 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
96 | self.layer1 = self._make_layer(block, 64, layers[0])
97 | self.layer2 = self._make_layer(block, 128, layers[1], 2)
98 | self.layer3 = self._make_layer(block, 256, layers[2], 2)
99 | self.layer4 = self._make_layer(block, 512, layers[3], 2)
100 | self.avgpool = nn.AvgPool2d(7)
101 | self.fc = nn.Linear(512 * block.expansion, num_classes)
102 |
103 | for m in self.modules():
104 | if isinstance(m, nn.Conv2d):
105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
106 | m.weight.data.normal_(0, math.sqrt(2. / n))
107 | elif isinstance(m, nn.BatchNorm2d):
108 | m.weight.data.fill_(1)
109 | m.bias.data.zero_()
110 |
111 | def _make_layer(self, block, planes, blocks, stride=1):
112 |         """ Stack n bottleneck modules, where n is given by the blocks argument.
113 | Args:
114 | block: block type used to construct ResNext
115 | planes: number of output channels (need to multiply by block.expansion)
116 | blocks: number of blocks to be built
117 | stride: factor to reduce the spatial dimensionality in the first bottleneck of the block.
118 | Returns: a Module consisting of n sequential bottlenecks.
119 | """
120 | downsample = None
121 | if stride != 1 or self.inplanes != planes * block.expansion:
122 | downsample = nn.Sequential(
123 | nn.Conv2d(self.inplanes, planes * block.expansion,
124 | kernel_size=1, stride=stride, bias=False),
125 | nn.BatchNorm2d(planes * block.expansion),
126 | )
127 |
128 | layers = []
129 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, stride, downsample))
130 | self.inplanes = planes * block.expansion
131 | for i in range(1, blocks):
132 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality))
133 |
134 | return nn.Sequential(*layers)
135 |
136 | def forward(self, x):
137 | x = self.conv1(x)
138 | x = self.bn1(x)
139 | x = self.relu(x)
140 | x = self.maxpool1(x)
141 | x = self.layer1(x)
142 | x = self.layer2(x)
143 | x = self.layer3(x)
144 | x = self.layer4(x)
145 | x = self.avgpool(x)
146 | x = x.view(x.size(0), -1)
147 | x = self.fc(x)
148 |
149 | return x
150 |
151 |
152 | def resnext50(baseWidth, cardinality):
153 | """
154 | Construct ResNeXt-50.
155 | """
156 | model = ResNeXt(baseWidth, cardinality, [3, 4, 6, 3], 1000)
157 | return model
158 |
159 |
160 | def resnext101(baseWidth, cardinality):
161 | """
162 | Construct ResNeXt-101.
163 | """
164 | model = ResNeXt(baseWidth, cardinality, [3, 4, 23, 3], 1000)
165 | return model
166 |
167 |
168 | def resnext152(baseWidth, cardinality):
169 | """
170 | Construct ResNeXt-152.
171 | """
172 | model = ResNeXt(baseWidth, cardinality, [3, 8, 36, 3], 1000)
173 | return model
174 |
--------------------------------------------------------------------------------
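
For orientation, a minimal sketch (not from the repository) instantiating the ImageNet ResNeXt above with the common 32x4d setting; the configuration and input size are illustrative, and the import path assumes the `cifar_imagenet` layout.

```python
import torch
from models.imagenet.resnext import resnext50  # import path assumed from the repo layout

model = resnext50(baseWidth=4, cardinality=32)  # ResNeXt-50 (32x4d)
out = model(torch.randn(2, 3, 224, 224))        # the fixed AvgPool2d(7) expects 224x224 inputs
print(out.shape)  # torch.Size([2, 1000])
```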
/cifar_imagenet/recipes.md:
--------------------------------------------------------------------------------
1 | # SGD
2 |
3 | ```
4 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-sgd-01 --gpu-id 0 --model_name sgd_01
5 |
6 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-sgd-003 --gpu-id 0 --model_name sgd_003 --lr 0.03
7 |
8 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-sgd-001 --gpu-id 0 --model_name sgd_001 --lr 0.01
9 |
10 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-sgd-0003 --gpu-id 0 --model_name sgd_0003 --lr 0.003
11 | ```
12 |
13 | # Vanilla Adam
14 |
15 | ```
16 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint /cps/gadam/checkpoints/cifar10/resnet-20-adam-01 --gpu-id 0 --model_name adam_01
17 |
18 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint /cps/gadam/checkpoints/cifar10/resnet-20-adam-003 --gpu-id 0 --model_name adam_003 --lr 0.03
19 |
20 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint /cps/gadam/checkpoints/cifar10/resnet-20-adam-001 --gpu-id 0 --model_name adam_001 --lr 0.01
21 |
22 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint /cps/gadam/checkpoints/cifar10/resnet-20-adam-0003 --gpu-id 0 --model_name adam_0003 --lr 0.003
23 | ```
24 |
25 | # RAdam experiments
26 |
27 | ```
28 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer radam --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-radam-01 --gpu-id 0 --model_name radam_01
29 |
30 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer radam --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-radam-003 --gpu-id 0 --model_name radam_003 --lr 0.03
31 |
32 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer radam --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-radam-001 --gpu-id 0 --model_name radam_001 --lr 0.01
33 |
34 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer radam --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-radam-0003 --gpu-id 0 --model_name radam_0003 --lr 0.003
35 | ```
36 |
37 | # Adam with 100 warmup
38 |
39 | ```
40 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-01 --gpu-id 0 --warmup 100 --model_name adam_100_01
41 |
42 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-003 --gpu-id 0 --warmup 100 --model_name adam_100_003 --lr 0.03
43 |
44 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-001 --gpu-id 0 --warmup 100 --model_name adam_100_001 --lr 0.01
45 |
46 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-0003 --gpu-id 0 --warmup 100 --model_name adam_100_0003 --lr 0.003
47 | ```
48 |
49 | # Adam with 200 warmup
50 |
51 | ```
52 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-01 --gpu-id 0 --warmup 200 --model_name adam_200_01
53 |
54 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-003 --gpu-id 0 --warmup 200 --model_name adam_200_003 --lr 0.03
55 |
56 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-001 --gpu-id 0 --warmup 200 --model_name adam_200_001 --lr 0.01
57 |
58 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-0003 --gpu-id 0 --warmup 200 --model_name adam_200_0003 --lr 0.003
59 | ```
60 |
61 | # Adam with 500 warmup
62 |
63 | ```
64 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-01 --gpu-id 0 --warmup 500 --model_name adam_500_01
65 |
66 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-003 --gpu-id 0 --warmup 500 --model_name adam_500_003 --lr 0.03
67 |
68 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-001 --gpu-id 0 --warmup 500 --model_name adam_500_001 --lr 0.01
69 |
70 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-0003 --gpu-id 0 --warmup 500 --model_name adam_500_0003 --lr 0.003
71 | ```
72 |
73 | # Adam with 1000 warmup
74 |
75 | ```
76 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-01 --gpu-id 0 --warmup 1000 --model_name adam_1000_01
77 |
78 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-003 --gpu-id 0 --warmup 1000 --model_name adam_1000_003 --lr 0.03
79 |
80 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-001 --gpu-id 0 --warmup 1000 --model_name adam_1000_001 --lr 0.01
81 |
82 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-0003 --gpu-id 0 --warmup 1000 --model_name adam_1000_0003 --lr 0.003
83 | ```
84 |
85 | # ImageNet
86 |
87 | ```
88 | python imagenet.py -j 16 -a resnet18 --data /data/ILSVRC2012/ --epochs 90 --schedule 31 61 --gamma 0.1 -c ./cps/imagenet/resnet18_radam_0003 --model_name radam_0003 --optimizer radam --lr 0.003 --beta1 0.9 --beta2 0.999
89 |
90 | python imagenet.py -j 16 -a resnet18 --data /data/ILSVRC2012/ --epochs 90 --schedule 31 61 --gamma 0.1 -c ./cps/imagenet/resnet18_radam_0005 --model_name radam_0005 --optimizer radam --lr 0.005 --beta1 0.9 --beta2 0.999
91 | ```
92 |
--------------------------------------------------------------------------------
/cifar_imagenet/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Useful utils
2 | """
3 | from .misc import *
4 | from .logger import *
5 | from .visualize import *
6 | from .eval import *
7 |
8 | # progress bar
9 | import os, sys
10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
11 | from progress.bar import Bar as Bar
--------------------------------------------------------------------------------
/cifar_imagenet/utils/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 |
3 | __all__ = ['accuracy']
4 |
5 | def accuracy(output, target, topk=(1,)):
6 | """Computes the precision@k for the specified values of k"""
7 | maxk = max(topk)
8 | batch_size = target.size(0)
9 |
10 | _, pred = output.topk(maxk, 1, True, True)
11 | pred = pred.t()
12 | correct = pred.eq(target.view(1, -1).expand_as(pred))
13 |
14 | res = []
15 | for k in topk:
16 |         correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape: `correct` comes from a transposed tensor and may be non-contiguous
17 | res.append(correct_k.mul_(100.0 / batch_size))
18 | return res
--------------------------------------------------------------------------------
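
A tiny illustrative check (not from the repository) of the `accuracy` helper above, with hand-picked logits so the expected top-1/top-2 numbers are unambiguous; the import path assumes the repo's `utils` package and its dependencies (e.g. the `progress` bar package pulled in by `utils/__init__.py`) are available.

```python
import torch
from utils.eval import accuracy  # import path assumed from the repo layout

output = torch.tensor([[0.1, 0.7, 0.2],    # sample 0: top-1 is class 1 (correct)
                       [0.6, 0.1, 0.3]])   # sample 1: top-1 is class 0 (wrong), top-2 includes class 2
target = torch.tensor([1, 2])
top1, top2 = accuracy(output, target, topk=(1, 2))
print(top1.item(), top2.item())  # 50.0 100.0
```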
/cifar_imagenet/utils/images/cifar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiyuanLucasLiu/RAdam/d9fd30a337894c4003768561d45e8730dbd41333/cifar_imagenet/utils/images/cifar.png
--------------------------------------------------------------------------------
/cifar_imagenet/utils/images/imagenet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiyuanLucasLiu/RAdam/d9fd30a337894c4003768561d45e8730dbd41333/cifar_imagenet/utils/images/imagenet.png
--------------------------------------------------------------------------------
/cifar_imagenet/utils/logger.py:
--------------------------------------------------------------------------------
1 | # A simple torch style logger
2 | # (C) Wei YANG 2017
3 | from __future__ import absolute_import
4 | import matplotlib as mpl
5 | mpl.use('Agg')
6 |
7 | import matplotlib.pyplot as plt
8 | import os
9 | import sys
10 | import numpy as np
11 |
12 | __all__ = ['Logger', 'LoggerMonitor', 'savefig']
13 |
14 | def savefig(fname, dpi=None):
15 | dpi = 150 if dpi == None else dpi
16 | plt.savefig(fname, dpi=dpi)
17 |
18 | def plot_overlap(logger, names=None):
19 | names = logger.names if names == None else names
20 | numbers = logger.numbers
21 | for _, name in enumerate(names):
22 | x = np.arange(len(numbers[name]))
23 | plt.plot(x, np.asarray(numbers[name]))
24 | return [logger.title + '(' + name + ')' for name in names]
25 |
26 | class Logger(object):
27 |     '''Save the training process to a log file, with a simple plotting helper.'''
28 | def __init__(self, fpath, title=None, resume=False):
29 | self.file = None
30 | self.resume = resume
31 | self.title = '' if title == None else title
32 | if fpath is not None:
33 | if resume:
34 | self.file = open(fpath, 'r')
35 | name = self.file.readline()
36 | self.names = name.rstrip().split('\t')
37 | self.numbers = {}
38 | for _, name in enumerate(self.names):
39 | self.numbers[name] = []
40 |
41 | for numbers in self.file:
42 | numbers = numbers.rstrip().split('\t')
43 | for i in range(0, len(numbers)):
44 | self.numbers[self.names[i]].append(numbers[i])
45 | self.file.close()
46 | self.file = open(fpath, 'a')
47 | else:
48 | self.file = open(fpath, 'w')
49 |
50 | def set_names(self, names):
51 | if self.resume:
52 | pass
53 | # initialize numbers as empty list
54 | self.numbers = {}
55 | self.names = names
56 | for _, name in enumerate(self.names):
57 | self.file.write(name)
58 | self.file.write('\t')
59 | self.numbers[name] = []
60 | self.file.write('\n')
61 | self.file.flush()
62 |
63 |
64 | def append(self, numbers):
65 | assert len(self.names) == len(numbers), 'Numbers do not match names'
66 | for index, num in enumerate(numbers):
67 | self.file.write("{0:.6f}".format(num))
68 | self.file.write('\t')
69 | self.numbers[self.names[index]].append(num)
70 | self.file.write('\n')
71 | self.file.flush()
72 |
73 | def plot(self, names=None):
74 | names = self.names if names == None else names
75 | numbers = self.numbers
76 | for _, name in enumerate(names):
77 | x = np.arange(len(numbers[name]))
78 | plt.plot(x, np.asarray(numbers[name]))
79 | plt.legend([self.title + '(' + name + ')' for name in names])
80 | plt.grid(True)
81 |
82 | def close(self):
83 | if self.file is not None:
84 | self.file.close()
85 |
86 | class LoggerMonitor(object):
87 | '''Load and visualize multiple logs.'''
88 | def __init__ (self, paths):
89 |         '''paths is a dictionary of {name: filepath} pairs'''
90 | self.loggers = []
91 | for title, path in paths.items():
92 | logger = Logger(path, title=title, resume=True)
93 | self.loggers.append(logger)
94 |
95 | def plot(self, names=None):
96 | plt.figure()
97 | plt.subplot(121)
98 | legend_text = []
99 | for logger in self.loggers:
100 | legend_text += plot_overlap(logger, names)
101 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
102 | plt.grid(True)
103 |
104 | if __name__ == '__main__':
105 | # # Example
106 | # logger = Logger('test.txt')
107 | # logger.set_names(['Train loss', 'Valid loss','Test loss'])
108 |
109 | # length = 100
110 | # t = np.arange(length)
111 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
112 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
113 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1
114 |
115 | # for i in range(0, length):
116 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]])
117 | # logger.plot()
118 |
119 | # Example: logger monitor
120 | paths = {
121 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt',
122 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt',
123 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt',
124 | }
125 |
126 | field = ['Valid Acc.']
127 |
128 | monitor = LoggerMonitor(paths)
129 | monitor.plot(names=field)
130 | savefig('test.eps')
131 |
--------------------------------------------------------------------------------
/cifar_imagenet/utils/misc.py:
--------------------------------------------------------------------------------
1 | '''Some helper functions for PyTorch, including:
2 |     - get_mean_and_std: calculate the per-channel mean and std of a dataset.
3 |     - init_params: network parameter initialization.
4 |     - mkdir_p, AverageMeter: directory creation and running-average bookkeeping.
5 | '''
6 | import errno
7 | import os
8 | import sys
9 | import time
10 | import math
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.init as init
15 | from torch.autograd import Variable
16 |
17 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter']
18 |
19 |
20 | def get_mean_and_std(dataset):
21 | '''Compute the mean and std value of dataset.'''
22 |     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
23 |
24 | mean = torch.zeros(3)
25 | std = torch.zeros(3)
26 | print('==> Computing mean and std..')
27 | for inputs, targets in dataloader:
28 | for i in range(3):
29 | mean[i] += inputs[:,i,:,:].mean()
30 | std[i] += inputs[:,i,:,:].std()
31 | mean.div_(len(dataset))
32 | std.div_(len(dataset))
33 | return mean, std
34 |
35 | def init_params(net):
36 | '''Init layer parameters.'''
37 | for m in net.modules():
38 | if isinstance(m, nn.Conv2d):
39 | init.kaiming_normal(m.weight, mode='fan_out')
40 |             if m.bias is not None:
41 | init.constant(m.bias, 0)
42 | elif isinstance(m, nn.BatchNorm2d):
43 | init.constant(m.weight, 1)
44 | init.constant(m.bias, 0)
45 | elif isinstance(m, nn.Linear):
46 | init.normal(m.weight, std=1e-3)
47 |             if m.bias is not None:
48 | init.constant(m.bias, 0)
49 |
50 | def mkdir_p(path):
51 | '''make dir if not exist'''
52 | try:
53 | os.makedirs(path)
54 | except OSError as exc: # Python >2.5
55 | if exc.errno == errno.EEXIST and os.path.isdir(path):
56 | pass
57 | else:
58 | raise
59 |
60 | class AverageMeter(object):
61 | """Computes and stores the average and current value
62 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
63 | """
64 | def __init__(self):
65 | self.reset()
66 |
67 | def reset(self):
68 | self.val = 0
69 | self.avg = 0
70 | self.sum = 0
71 | self.count = 0
72 |
73 | def update(self, val, n=1):
74 | self.val = val
75 | self.sum += val * n
76 | self.count += n
77 | self.avg = self.sum / self.count
--------------------------------------------------------------------------------
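
As a short illustration (not from the repository), `AverageMeter` keeps a running average weighted by the per-update count `n`; the import path assumes the repo's `utils` package and its dependencies are available.

```python
from utils.misc import AverageMeter  # import path assumed from the repo layout

losses = AverageMeter()
losses.update(0.9, n=32)  # batch of 32 samples with mean loss 0.9
losses.update(0.7, n=16)  # batch of 16 samples with mean loss 0.7
print(losses.val)  # 0.7 (latest value)
print(losses.avg)  # (0.9 * 32 + 0.7 * 16) / 48 = 0.8333...
```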
/cifar_imagenet/utils/radam.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer, required
4 |
5 | class RAdam(Optimizer):
6 |
7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
8 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
9 | self.buffer = [[None, None, None] for ind in range(10)]
10 | super(RAdam, self).__init__(params, defaults)
11 |
12 | def __setstate__(self, state):
13 | super(RAdam, self).__setstate__(state)
14 |
15 | def step(self, closure=None):
16 |
17 | loss = None
18 | if closure is not None:
19 | loss = closure()
20 |
21 | for group in self.param_groups:
22 |
23 | for p in group['params']:
24 | if p.grad is None:
25 | continue
26 | grad = p.grad.data.float()
27 | if grad.is_sparse:
28 | raise RuntimeError('RAdam does not support sparse gradients')
29 |
30 | p_data_fp32 = p.data.float()
31 |
32 | state = self.state[p]
33 |
34 | if len(state) == 0:
35 | state['step'] = 0
36 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
37 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
38 | else:
39 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
40 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
41 |
42 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
43 | beta1, beta2 = group['betas']
44 |
45 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
46 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
47 |
48 | state['step'] += 1
49 | buffered = self.buffer[int(state['step'] % 10)]
50 | if state['step'] == buffered[0]:
51 | N_sma, step_size = buffered[1], buffered[2]
52 | else:
53 | buffered[0] = state['step']
54 | beta2_t = beta2 ** state['step']
55 | N_sma_max = 2 / (1 - beta2) - 1
56 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
57 | buffered[1] = N_sma
58 |
59 | # more conservative since it's an approximated value
60 | if N_sma >= 5:
61 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
62 | else:
63 | step_size = group['lr'] / (1 - beta1 ** state['step'])
64 | buffered[2] = step_size
65 |
66 | if group['weight_decay'] != 0:
67 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
68 |
69 | # more conservative since it's an approximated value
70 | if N_sma >= 5:
71 | denom = exp_avg_sq.sqrt().add_(group['eps'])
72 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
73 | else:
74 | p_data_fp32.add_(-step_size, exp_avg)
75 |
76 | p.data.copy_(p_data_fp32)
77 |
78 | return loss
79 |
80 | class RAdam_4step(Optimizer):
81 |
82 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, update_all=False, additional_four=False):
83 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
 84 |         self.update_all = update_all  # whether to update parameters during the first 4 steps
 85 |         self.additional_four = additional_four  # whether to use an additional 4 SGD steps
86 | self.buffer = [[None, None] for ind in range(10)]
87 | super(RAdam_4step, self).__init__(params, defaults)
88 |
89 | def __setstate__(self, state):
90 | super(RAdam_4step, self).__setstate__(state)
91 |
92 | def step(self, closure=None):
93 |
94 | loss = None
95 | if closure is not None:
96 | loss = closure()
97 |
98 | for group in self.param_groups:
99 |
100 | for p in group['params']:
101 | if p.grad is None:
102 | continue
103 | grad = p.grad.data.float()
104 | if grad.is_sparse:
105 | raise RuntimeError('RAdam_4step does not support sparse gradients')
106 |
107 | p_data_fp32 = p.data.float()
108 |
109 | state = self.state[p]
110 |
111 | if len(state) == 0:
112 |                     state['step'] = -4 if self.additional_four else 0  # since this experiment requires exactly 4 steps, it is hard-coded
113 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
114 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
115 | else:
116 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
117 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
118 |
119 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
120 | beta1, beta2 = group['betas']
121 |
122 | state['step'] += 1
123 |
124 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
125 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
126 |
127 | if state['step'] > 0:
128 |
129 |                     state_step = state['step'] + 4 if self.additional_four else state['step']  # since this experiment requires exactly 4 steps, it is hard-coded
130 |
131 | buffered = self.buffer[int(state_step % 10)]
132 | if state_step == buffered[0]:
133 | step_size = buffered[1]
134 | else:
135 | buffered[0] = state_step
136 | beta2_t = beta2 ** state['step']
137 |
138 |                         if state['step'] > 4:  # since this experiment requires exactly 4 steps, it is hard-coded
139 | N_sma_max = 2 / (1 - beta2) - 1
140 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
141 | step_size = group['lr'] * math.sqrt((N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state_step)
142 | elif self.update_all:
143 | step_size = group['lr'] / (1 - beta1 ** state_step)
144 | else:
145 | step_size = 0
146 | buffered[1] = step_size
147 |
148 |                     if state['step'] > 4:  # since this experiment requires exactly 4 steps, it is hard-coded
149 | if group['weight_decay'] != 0:
150 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
151 | denom = (exp_avg_sq.sqrt() / math.sqrt(1 - beta2 ** state_step)).add_(group['eps'])
152 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
153 | p.data.copy_(p_data_fp32)
154 | elif self.update_all:
155 | if group['weight_decay'] != 0:
156 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
157 | denom = (exp_avg_sq.sqrt() / math.sqrt(1 - beta2 ** state_step))
158 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
159 | p.data.copy_(p_data_fp32)
160 | else:
161 |                     state_step = state['step'] + 4 if self.additional_four else state['step']  # since this experiment requires exactly 4 steps, it is hard-coded
162 |
163 | if group['weight_decay'] != 0:
164 | p_data_fp32.add_(-group['weight_decay'] * 0.1, p_data_fp32)
165 |
166 | step_size = 0.1 / (1 - beta1 ** state_step)
167 | p_data_fp32.add_(-step_size, exp_avg)
168 | p.data.copy_(p_data_fp32)
169 |
170 | return loss
171 |
172 | class AdamW(Optimizer):
173 |
174 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
175 | weight_decay=0, use_variance=True, warmup = 4000):
176 | defaults = dict(lr=lr, betas=betas, eps=eps,
177 | weight_decay=weight_decay, use_variance=True, warmup = warmup)
178 | print('======== Warmup: {} ========='.format(warmup))
179 | super(AdamW, self).__init__(params, defaults)
180 |
181 | def __setstate__(self, state):
182 | super(AdamW, self).__setstate__(state)
183 |
184 | def step(self, closure=None):
185 |         # global iter_idx       # never defined anywhere; would raise a NameError on the first step
186 |         # iter_idx += 1
187 |         # grad_list = list()    # unused bookkeeping
188 |         # mom_list = list()
189 |         # mom_2rd_list = list()
190 |
191 | loss = None
192 | if closure is not None:
193 | loss = closure()
194 |
195 | for group in self.param_groups:
196 |
197 | for p in group['params']:
198 | if p.grad is None:
199 | continue
200 | grad = p.grad.data.float()
201 | if grad.is_sparse:
202 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
203 |
204 | p_data_fp32 = p.data.float()
205 |
206 | state = self.state[p]
207 |
208 | if len(state) == 0:
209 | state['step'] = 0
210 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
211 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
212 | else:
213 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
214 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
215 |
216 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
217 | beta1, beta2 = group['betas']
218 |
219 | state['step'] += 1
220 |
221 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
222 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
223 |
224 | denom = exp_avg_sq.sqrt().add_(group['eps'])
225 | bias_correction1 = 1 - beta1 ** state['step']
226 | bias_correction2 = 1 - beta2 ** state['step']
227 |
228 | if group['warmup'] > state['step']:
229 | scheduled_lr = 1e-6 + state['step'] * (group['lr'] - 1e-6) / group['warmup']
230 | else:
231 | scheduled_lr = group['lr']
232 |
233 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
234 | if group['weight_decay'] != 0:
235 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32)
236 |
237 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
238 |
239 | p.data.copy_(p_data_fp32)
240 |
241 | return loss
242 |
--------------------------------------------------------------------------------
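
For orientation, a minimal training-loop sketch (not part of the repository) wiring the `RAdam` class above into a toy model; it assumes a PyTorch version compatible with the legacy `add_`/`addcmul_` call signatures used in this file, and the import path assumes the repo's `utils` package and its dependencies are available.

```python
import torch
import torch.nn as nn
from utils.radam import RAdam  # import path assumed from the repo layout

model = nn.Linear(10, 2)
optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
for _ in range(5):               # a few toy iterations
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()             # uses the un-adapted fallback while N_sma < 5, rectified Adam afterwards
print(loss.item())
```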
/cifar_imagenet/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import torch
3 | import torch.nn as nn
4 | import torchvision
5 | import torchvision.transforms as transforms
6 | import numpy as np
7 | from .misc import *
8 |
9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single']
10 |
11 | # functions to show an image
12 | def make_image(img, mean=(0,0,0), std=(1,1,1)):
13 | for i in range(0, 3):
14 | img[i] = img[i] * std[i] + mean[i] # unnormalize
15 | npimg = img.numpy()
16 | return np.transpose(npimg, (1, 2, 0))
17 |
18 | def gauss(x,a,b,c):
19 | return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a)
20 |
21 | def colorize(x):
22 | ''' Converts a one-channel grayscale image to a color heatmap image '''
23 | if x.dim() == 2:
24 | torch.unsqueeze(x, 0, out=x)
25 | if x.dim() == 3:
26 | cl = torch.zeros([3, x.size(1), x.size(2)])
27 | cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)
28 | cl[1] = gauss(x,1,.5,.3)
29 | cl[2] = gauss(x,1,.2,.3)
30 | cl[cl.gt(1)] = 1
31 | elif x.dim() == 4:
32 | cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)])
33 | cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3)
34 | cl[:,1,:,:] = gauss(x,1,.5,.3)
35 | cl[:,2,:,:] = gauss(x,1,.2,.3)
36 | return cl
37 |
38 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
39 | images = make_image(torchvision.utils.make_grid(images), Mean, Std)
40 | plt.imshow(images)
41 | plt.show()
42 |
43 |
44 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
45 | im_size = images.size(2)
46 |
47 | # save for adding mask
48 | im_data = images.clone()
49 | for i in range(0, 3):
50 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize
51 |
52 | images = make_image(torchvision.utils.make_grid(images), Mean, Std)
53 | plt.subplot(2, 1, 1)
54 | plt.imshow(images)
55 | plt.axis('off')
56 |
57 | # for b in range(mask.size(0)):
58 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min())
59 | mask_size = mask.size(2)
60 | # print('Max %f Min %f' % (mask.max(), mask.min()))
61 | mask = (upsampling(mask, scale_factor=im_size/mask_size))
62 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size))
63 | # for c in range(3):
64 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c]
65 |
66 | # print(mask.size())
67 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
68 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std)
69 | plt.subplot(2, 1, 2)
70 | plt.imshow(mask)
71 | plt.axis('off')
72 |
73 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)):
74 | im_size = images.size(2)
75 |
76 | # save for adding mask
77 | im_data = images.clone()
78 | for i in range(0, 3):
79 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize
80 |
81 | images = make_image(torchvision.utils.make_grid(images), Mean, Std)
82 | plt.subplot(1+len(masklist), 1, 1)
83 | plt.imshow(images)
84 | plt.axis('off')
85 |
86 | for i in range(len(masklist)):
87 | mask = masklist[i].data.cpu()
88 | # for b in range(mask.size(0)):
89 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min())
90 | mask_size = mask.size(2)
91 | # print('Max %f Min %f' % (mask.max(), mask.min()))
92 | mask = (upsampling(mask, scale_factor=im_size/mask_size))
93 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size))
94 | # for c in range(3):
95 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c]
96 |
97 | # print(mask.size())
98 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data)))
99 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std)
100 | plt.subplot(1+len(masklist), 1, i+2)
101 | plt.imshow(mask)
102 | plt.axis('off')
103 |
104 |
105 |
106 | # x = torch.zeros(1, 3, 3)
107 | # out = colorize(x)
108 | # out_im = make_image(out)
109 | # plt.imshow(out_im)
110 | # plt.show()
--------------------------------------------------------------------------------
/img/variance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiyuanLucasLiu/RAdam/d9fd30a337894c4003768561d45e8730dbd41333/img/variance.png
--------------------------------------------------------------------------------
/language-model/README.md:
--------------------------------------------------------------------------------
1 | # One Billion Word
2 |
3 | This folder is modified based on the LD-Net project ([original repo](https://github.com/LiyuanLucasLiu/LD-Net)). For more details about this code base, please refer to the original repo.
4 | A training [recipe](/language-model/recipes.md) is provided for language modeling experiments.
5 |
6 |
--------------------------------------------------------------------------------
/language-model/eval_1bw.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import datetime
3 | import time
4 | import torch
5 | import torch.autograd as autograd
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import codecs
9 | import pickle
10 | import math
11 |
12 | from model_word_ada.LM import LM
13 | from model_word_ada.basic import BasicRNN
14 | from model_word_ada.ddnet import DDRNN
15 | from model_word_ada.densenet import DenseRNN
16 | from model_word_ada.ldnet import LDRNN
17 | from model_word_ada.dataset import LargeDataset, EvalDataset
18 | from model_word_ada.adaptive import AdaptiveSoftmax
19 | import model_word_ada.utils as utils
20 |
21 | # from tensorboardX import SummaryWriter
22 |
23 | import argparse
24 | import json
25 | import os
26 | import sys
27 | import itertools
28 | import functools
29 |
30 | if __name__ == "__main__":
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument('--dataset_folder', default='/data/billionwords/one_billion/')
33 | parser.add_argument('--load_checkpoint', default='./checkpoint/basic_1.model')
34 | parser.add_argument('--sequence_length', type=int, default=600)
35 | parser.add_argument('--hid_dim', type=int, default=2048)
36 | parser.add_argument('--word_dim', type=int, default=300)
37 | parser.add_argument('--label_dim', type=int, default=-1)
38 | parser.add_argument('--layer_num', type=int, default=2)
39 | parser.add_argument('--droprate', type=float, default=0.1)
40 | parser.add_argument('--add_relu', action='store_true')
41 | parser.add_argument('--gpu', type=int, default=1)
42 | parser.add_argument('--rnn_layer', choices=['Basic', 'DDNet', 'DenseNet', 'LDNet'], default='Basic')
43 | parser.add_argument('--rnn_unit', choices=['gru', 'lstm', 'rnn'], default='lstm')
44 | parser.add_argument('--cut_off', nargs='+', default=[4000,40000,200000])
45 | parser.add_argument('--limit', type=int, default=76800)
46 |
47 | args = parser.parse_args()
48 |
49 | if args.gpu >= 0:
50 | torch.cuda.set_device(args.gpu)
51 |
52 | print('loading dataset')
53 | dataset = pickle.load(open(args.dataset_folder + 'test.pk', 'rb'))
54 | w_map, test_data = dataset['w_map'], dataset['test_data']
55 |
56 | cut_off = args.cut_off + [len(w_map) + 1]
57 |
58 | test_loader = EvalDataset(test_data, args.sequence_length)
59 |
60 | print('building model')
61 |
62 | rnn_map = {'Basic': BasicRNN, 'DDNet': DDRNN, 'DenseNet': DenseRNN, 'LDNet': functools.partial(LDRNN, layer_drop = 0)}
63 | rnn_layer = rnn_map[args.rnn_layer](args.layer_num, args.rnn_unit, args.word_dim, args.hid_dim, args.droprate)
64 |
65 | if args.label_dim > 0:
66 | soft_max = AdaptiveSoftmax(args.label_dim, cut_off)
67 | else:
68 | soft_max = AdaptiveSoftmax(rnn_layer.output_dim, cut_off)
69 |
70 | lm_model = LM(rnn_layer, soft_max, len(w_map), args.word_dim, args.droprate, label_dim = args.label_dim, add_relu = args.add_relu)
71 | lm_model.cuda()
72 |
73 | if os.path.isfile(args.load_checkpoint):
74 | print("loading checkpoint: '{}'".format(args.load_checkpoint))
75 |
76 | checkpoint_file = torch.load(args.load_checkpoint, map_location=lambda storage, loc: storage)
77 | lm_model.load_state_dict(checkpoint_file['lm_model'])
78 | else:
79 | print("no checkpoint found at: '{}'".format(args.load_checkpoint))
80 |
81 | test_lm = nn.NLLLoss()
82 |
83 | test_lm.cuda()
84 | lm_model.cuda()
85 |
86 | print('evaluating')
87 | lm_model.eval()
88 |
89 | iterator = test_loader.get_tqdm()
90 |
91 | lm_model.init_hidden()
92 | total_loss = 0
93 | total_len = 0
94 | for word_t, label_t in iterator:
95 | label_t = label_t.view(-1)
96 | tmp_len = label_t.size(0)
97 | output = lm_model.log_prob(word_t)
98 | total_loss += tmp_len * utils.to_scalar(test_lm(autograd.Variable(output), label_t))
99 | total_len += tmp_len
100 |
101 | if args.limit > 0 and total_len > args.limit:
102 | break
103 |
104 | print(str(total_loss / total_len))
105 | ppl = math.exp(total_loss / total_len)
106 | print('PPL: ' + str(ppl))
107 |
--------------------------------------------------------------------------------
/language-model/model_word_ada/LM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import model_word_ada.utils as utils
5 |
6 | class LM(nn.Module):
7 |
8 | def __init__(self, rnn, soft_max, w_num, w_dim, droprate, label_dim = -1, add_relu=False):
9 | super(LM, self).__init__()
10 |
11 | self.rnn = rnn
12 | self.soft_max = soft_max
13 |
14 | if soft_max:
15 | self.forward = self.softmax_forward
16 | else:
17 | self.forward = self.embed_forward
18 |
19 | self.w_num = w_num
20 | self.w_dim = w_dim
21 | self.word_embed = nn.Embedding(w_num, w_dim)
22 |
23 | self.rnn_output = self.rnn.output_dim
24 |
25 | self.add_proj = label_dim > 0
26 | if self.add_proj:
27 | self.project = nn.Linear(self.rnn_output, label_dim)
28 | if add_relu:
29 | self.relu = nn.ReLU()
30 | else:
31 | self.relu = lambda x: x
32 |
33 | self.drop = nn.Dropout(p=droprate)
34 |
35 | def load_embed(self, origin_lm):
36 | self.word_embed = origin_lm.word_embed
37 | self.soft_max = origin_lm.soft_max
38 |
39 | def rand_ini(self):
40 |
41 | self.rnn.rand_ini()
42 | # utils.init_linear(self.project)
43 | self.soft_max.rand_ini()
44 | # if not self.tied_weight:
45 | utils.init_embedding(self.word_embed.weight)
46 |
47 | if self.add_proj:
48 | utils.init_linear(self.project)
49 |
50 | def init_hidden(self):
51 | self.rnn.init_hidden()
52 |
53 | def softmax_forward(self, w_in, target):
54 |
55 | w_emb = self.word_embed(w_in)
56 |
57 | w_emb = self.drop(w_emb)
58 |
59 | out = self.rnn(w_emb).contiguous().view(-1, self.rnn_output)
60 |
61 | if self.add_proj:
62 | out = self.drop(self.relu(self.project(out)))
63 | # out = self.drop(self.project(out))
64 |
65 | out = self.soft_max(out, target)
66 |
67 | return out
68 |
69 | def embed_forward(self, w_in, target):
70 |
71 | w_emb = self.word_embed(w_in)
72 |
73 | w_emb = self.drop(w_emb)
74 |
75 | out = self.rnn(w_emb).contiguous().view(-1, self.rnn_output)
76 |
77 | if self.add_proj:
78 | out = self.drop(self.relu(self.project(out)))
79 | # out = self.drop(self.project(out))
80 |
81 | out = self.soft_max(out, target)
82 |
83 | return out
84 |
85 | def log_prob(self, w_in):
86 |
87 | w_emb = self.word_embed(w_in)
88 |
89 | out = self.rnn(w_emb).contiguous().view(-1, self.rnn_output)
90 |
91 | if self.add_proj:
92 | out = self.relu(self.project(out))
93 |
94 | out = self.soft_max.log_prob(out)
95 |
96 | return out
--------------------------------------------------------------------------------
/language-model/model_word_ada/adaptive.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.autograd import Variable
4 |
5 | from math import sqrt
6 |
7 | class AdaptiveSoftmax(nn.Module):
8 | def __init__(self, input_size, cutoff):
9 | super().__init__()
10 |
11 | self.input_size = input_size
12 | self.cutoff = cutoff
13 | self.output_size = cutoff[0] + len(cutoff) - 1
14 |
15 | self.head = nn.Linear(input_size, self.output_size)
16 | self.tail = nn.ModuleList()
17 |
18 | self.cross_entropy = nn.CrossEntropyLoss(size_average=False)
19 |
20 | for i in range(len(self.cutoff) - 1):
21 | # seq = nn.Sequential(
22 | # nn.Linear(input_size, input_size // 4 ** i, False),
23 | # nn.Linear(input_size // 4 ** i, cutoff[i + 1] - cutoff[i], False)
24 | # )
25 |
26 | seq = nn.Sequential(
27 | nn.Linear(input_size, input_size // 4 ** (i + 1), False),
28 | nn.Linear(input_size // 4 ** (i + 1), cutoff[i + 1] - cutoff[i], False)
29 | )
30 |
31 | self.tail.append(seq)
32 |
33 | def rand_ini(self):
34 |
35 | nn.init.xavier_normal(self.head.weight)
36 |
37 | for tail in self.tail:
38 | nn.init.xavier_normal(tail[0].weight)
39 | nn.init.xavier_normal(tail[1].weight)
40 |
41 | def log_prob(self, w_in):
42 | lsm = nn.LogSoftmax(dim=1).cuda()
43 |
44 | head_out = self.head(w_in)
45 |
46 | batch_size = head_out.size(0)
47 | prob = torch.zeros(batch_size, self.cutoff[-1]).cuda()
48 |
49 | lsm_head = lsm(head_out)
50 | prob.narrow(1, 0, self.output_size).add_(lsm_head.narrow(1, 0, self.output_size).data)
51 |
52 | for i in range(len(self.tail)):
53 | pos = self.cutoff[i]
54 | i_size = self.cutoff[i + 1] - pos
55 | buffer = lsm_head.narrow(1, self.cutoff[0] + i, 1)
56 | buffer = buffer.expand(batch_size, i_size)
57 | lsm_tail = lsm(self.tail[i](w_in))
58 | prob.narrow(1, pos, i_size).copy_(buffer.data).add_(lsm_tail.data)
59 |
60 | return prob
61 |
62 | def forward(self, w_in, target):
63 |
64 | batch_size = w_in.size(0)
65 | output = 0.0
66 |
67 | first_target = target.clone()
68 |
69 | for i in range(len(self.cutoff) - 1):
70 |
71 | mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
72 |
73 | if mask.sum() > 0:
74 |
75 | first_target[mask] = self.cutoff[0] + i
76 |
77 | second_target = Variable(target[mask].add(-self.cutoff[i]))
78 | second_input = w_in.index_select(0, Variable(mask.nonzero().squeeze()))
79 |
80 | second_output = self.tail[i](second_input)
81 |
82 | output += self.cross_entropy(second_output, second_target)
83 |
84 | output += self.cross_entropy(self.head(w_in), Variable(first_target))
85 | output /= batch_size
86 | return output
--------------------------------------------------------------------------------
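
A small CPU smoke test (not from the repository) for the `AdaptiveSoftmax` above; the cutoff and dimensions are deliberately tiny (the recipes use cutoffs such as 4000/40000/200000 plus the vocabulary size), the targets are hand-picked so each cluster receives more than one example, and the import path assumes the snippet is run from the `language-model` directory.

```python
import torch
from model_word_ada.adaptive import AdaptiveSoftmax  # import path assumed from the repo layout

soft_max = AdaptiveSoftmax(input_size=16, cutoff=[5, 12, 20])  # 5-word head plus 2 tail clusters
w_in = torch.randn(8, 16)                                      # 8 positions with 16-dim features
target = torch.tensor([0, 3, 6, 7, 13, 15, 18, 19])            # word ids spanning the head and both tails
loss = soft_max(w_in, target)   # summed cross-entropy over head and tails, divided by the batch size
print(loss.item())
```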
/language-model/model_word_ada/basic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import model_word_ada.utils as utils
5 | from model_word_ada.bnlstm import BNLSTM
6 |
7 | class BasicUnit(nn.Module):
8 | def __init__(self, unit, input_dim, hid_dim, droprate):
9 | super(BasicUnit, self).__init__()
10 |
11 | rnnunit_map = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU, 'bnlstm': BNLSTM}
12 |
13 | self.batch_norm = (unit == 'bnlstm')
14 |
15 | self.layer = rnnunit_map[unit](input_dim, hid_dim, 1)
16 | self.droprate = droprate
17 |
18 | self.output_dim = hid_dim
19 |
20 | self.init_hidden()
21 |
22 | def init_hidden(self):
23 |
24 | self.hidden_state = None
25 |
26 | def rand_ini(self):
27 |
28 | if not self.batch_norm:
29 | utils.init_lstm(self.layer)
30 |
31 | def forward(self, x):
32 | # set_trace()
33 | out, new_hidden = self.layer(x, self.hidden_state)
34 |
35 | self.hidden_state = utils.repackage_hidden(new_hidden)
36 |
37 | if self.droprate > 0:
38 | out = F.dropout(out, p=self.droprate, training=self.training)
39 |
40 | return out
41 |
42 | class BasicRNN(nn.Module):
43 | def __init__(self, layer_num, unit, emb_dim, hid_dim, droprate):
44 | super(BasicRNN, self).__init__()
45 |
46 | layer_list = [BasicUnit(unit, emb_dim, hid_dim, droprate)] + [BasicUnit(unit, hid_dim, hid_dim, droprate) for i in range(layer_num - 1)]
47 | self.layer = nn.Sequential(*layer_list)
48 | self.output_dim = layer_list[-1].output_dim
49 |
50 | self.init_hidden()
51 |
52 | def init_hidden(self):
53 |
54 | for tup in self.layer.children():
55 | tup.init_hidden()
56 |
57 | def rand_ini(self):
58 |
59 | for tup in self.layer.children():
60 | tup.rand_ini()
61 |
62 | def forward(self, x):
63 | return self.layer(x)
--------------------------------------------------------------------------------
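
A quick shape check (not from the repository) for the `BasicRNN` stack above; the dimensions are illustrative and the import path assumes the snippet is run from the `language-model` directory so the `model_word_ada` package is importable.

```python
import torch
from model_word_ada.basic import BasicRNN  # import path assumed from the repo layout

rnn = BasicRNN(layer_num=2, unit='lstm', emb_dim=16, hid_dim=32, droprate=0.1)
x = torch.randn(5, 3, 16)  # (sequence length, batch, embedding dim), nn.LSTM's default layout
out = rnn(x)
print(out.shape)           # torch.Size([5, 3, 32])
rnn.init_hidden()          # reset the carried hidden state between unrelated sequences
```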
/language-model/model_word_ada/bnlstm.py:
--------------------------------------------------------------------------------
1 | """Implementation of batch-normalized LSTM.
2 | from: https://github.com/jihunchoi/recurrent-batch-normalization-pytorch
3 | """
4 | import torch
5 | from torch import nn
6 | from torch.autograd import Variable
7 | from torch.nn import functional, init
8 |
9 |
10 | class SeparatedBatchNorm1d(nn.Module):
11 |
12 | """
13 | A batch normalization module which keeps its running mean
14 | and variance separately per timestep.
15 | """
16 |
17 | def __init__(self, num_features, max_length, eps=1e-5, momentum=0.1,
18 | affine=True):
19 | """
20 | Most parts are copied from
21 | torch.nn.modules.batchnorm._BatchNorm.
22 | """
23 |
24 | super(SeparatedBatchNorm1d, self).__init__()
25 | self.num_features = num_features
26 | self.max_length = max_length
27 | self.affine = affine
28 | self.eps = eps
29 | self.momentum = momentum
30 | if self.affine:
31 | self.weight = nn.Parameter(torch.FloatTensor(num_features))
32 | self.bias = nn.Parameter(torch.FloatTensor(num_features))
33 | else:
34 | self.register_parameter('weight', None)
35 | self.register_parameter('bias', None)
36 | for i in range(max_length):
37 | self.register_buffer(
38 | 'running_mean_{}'.format(i), torch.zeros(num_features))
39 | self.register_buffer(
40 | 'running_var_{}'.format(i), torch.ones(num_features))
41 | self.reset_parameters()
42 |
43 | def reset_parameters(self):
44 | for i in range(self.max_length):
45 | running_mean_i = getattr(self, 'running_mean_{}'.format(i))
46 | running_var_i = getattr(self, 'running_var_{}'.format(i))
47 | running_mean_i.zero_()
48 | running_var_i.fill_(1)
49 | if self.affine:
50 | self.weight.data.uniform_()
51 | self.bias.data.zero_()
52 |
53 | def _check_input_dim(self, input_):
54 | if input_.size(1) != self.running_mean_0.nelement():
55 | raise ValueError('got {}-feature tensor, expected {}'
56 | .format(input_.size(1), self.num_features))
57 |
58 | def forward(self, input_, time):
59 | self._check_input_dim(input_)
60 | if time >= self.max_length:
61 | time = self.max_length - 1
62 | running_mean = getattr(self, 'running_mean_{}'.format(time))
63 | running_var = getattr(self, 'running_var_{}'.format(time))
64 | return functional.batch_norm(
65 | input=input_, running_mean=running_mean, running_var=running_var,
66 | weight=self.weight, bias=self.bias, training=self.training,
67 | momentum=self.momentum, eps=self.eps)
68 |
69 | def __repr__(self):
70 | return ('{name}({num_features}, eps={eps}, momentum={momentum},'
71 | ' max_length={max_length}, affine={affine})'
72 | .format(name=self.__class__.__name__, **self.__dict__))
73 |
74 | class BNLSTMCell(nn.Module):
75 |
76 | """A BN-LSTM cell."""
77 |
78 | def __init__(self, input_size, hidden_size, max_length=784, use_bias=True):
79 |
80 | super(BNLSTMCell, self).__init__()
81 | self.input_size = input_size
82 | self.hidden_size = hidden_size
83 | self.max_length = max_length
84 | self.use_bias = use_bias
85 | self.weight_ih = nn.Parameter(
86 | torch.FloatTensor(input_size, 4 * hidden_size))
87 | self.weight_hh = nn.Parameter(
88 | torch.FloatTensor(hidden_size, 4 * hidden_size))
89 | if use_bias:
90 | self.bias = nn.Parameter(torch.FloatTensor(4 * hidden_size))
91 | else:
92 | self.register_parameter('bias', None)
93 | # BN parameters
94 | self.bn_ih = SeparatedBatchNorm1d(
95 | num_features=4 * hidden_size, max_length=max_length)
96 | self.bn_hh = SeparatedBatchNorm1d(
97 | num_features=4 * hidden_size, max_length=max_length)
98 | self.bn_c = SeparatedBatchNorm1d(
99 | num_features=hidden_size, max_length=max_length)
100 | self.reset_parameters()
101 |
102 | def reset_parameters(self):
103 | """
104 | Initialize parameters following the way proposed in the paper.
105 | """
106 |
107 | # The input-to-hidden weight matrix is initialized orthogonally.
108 | init.orthogonal(self.weight_ih.data)
109 | # The hidden-to-hidden weight matrix is initialized as an identity
110 | # matrix.
111 | weight_hh_data = torch.eye(self.hidden_size)
112 | weight_hh_data = weight_hh_data.repeat(1, 4)
113 | self.weight_hh.data.set_(weight_hh_data)
114 | # The bias is just set to zero vectors.
115 | init.constant(self.bias.data, val=0)
116 | # Initialization of BN parameters.
117 | self.bn_ih.reset_parameters()
118 | self.bn_hh.reset_parameters()
119 | self.bn_c.reset_parameters()
120 | self.bn_ih.bias.data.fill_(0)
121 | self.bn_hh.bias.data.fill_(0)
122 | self.bn_ih.weight.data.fill_(0.1)
123 | self.bn_hh.weight.data.fill_(0.1)
124 | self.bn_c.weight.data.fill_(0.1)
125 |
126 | def forward(self, input_, hx, time):
127 | """
128 | Args:
129 | input_: A (batch, input_size) tensor containing input
130 | features.
131 | hx: A tuple (h_0, c_0), which contains the initial hidden
132 | and cell state, where the size of both states is
133 | (batch, hidden_size).
134 | time: The current timestep value, which is used to
135 | get appropriate running statistics.
136 |
137 | Returns:
138 | h_1, c_1: Tensors containing the next hidden and cell state.
139 | """
140 |
141 | h_0, c_0 = hx
142 | batch_size = h_0.size(0)
143 | bias_batch = (self.bias.unsqueeze(0)
144 | .expand(batch_size, *self.bias.size()))
145 | wh = torch.mm(h_0, self.weight_hh)
146 | wi = torch.mm(input_, self.weight_ih)
147 | bn_wh = self.bn_hh(wh, time=time)
148 | bn_wi = self.bn_ih(wi, time=time)
149 | f, i, o, g = torch.split(bn_wh + bn_wi + bias_batch,
150 | split_size=self.hidden_size, dim=1)
151 | c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g)
152 | h_1 = torch.sigmoid(o) * torch.tanh(self.bn_c(c_1, time=time))
153 | return h_1, c_1
154 |
155 |
156 | class BNLSTM(nn.Module):
157 |
158 | """A module that runs multiple steps of LSTM."""
159 |
160 | def __init__(self, input_size, hidden_size, num_layers=1,
161 | use_bias=True, batch_first=False, dropout=0, **kwargs):
162 | super(BNLSTM, self).__init__()
163 | self.input_size = input_size
164 | self.hidden_size = hidden_size
165 | self.num_layers = num_layers
166 | self.use_bias = use_bias
167 | self.batch_first = batch_first
168 | self.dropout = dropout
169 |
170 | for layer in range(num_layers):
171 | layer_input_size = input_size if layer == 0 else hidden_size
172 | cell = BNLSTMCell(input_size=layer_input_size,
173 | hidden_size=hidden_size,
174 | **kwargs)
175 | setattr(self, 'cell_{}'.format(layer), cell)
176 | self.dropout_layer = nn.Dropout(dropout)
177 | self.reset_parameters()
178 |
179 | def get_cell(self, layer):
180 | return getattr(self, 'cell_{}'.format(layer))
181 |
182 | def reset_parameters(self):
183 | for layer in range(self.num_layers):
184 | cell = self.get_cell(layer)
185 | cell.reset_parameters()
186 |
187 | @staticmethod
188 | def _forward_rnn(cell, input_, hx):
189 | max_time = input_.size(0)
190 | output = []
191 | for time in range(max_time):
192 | h_next, c_next = cell(input_=input_[time], hx=hx, time=time)
193 | # # mask = (time < length).float().unsqueeze(1).expand_as(h_next)
194 | # h_next = h_next*mask + hx[0]*(1 - mask)
195 | # c_next = c_next*mask + hx[1]*(1 - mask)
196 | hx_next = (h_next, c_next)
197 | output.append(h_next)
198 | hx = hx_next
199 | output = torch.stack(output, 0)
200 | return output, hx
201 |
202 | def forward(self, input_, length=None, hx=None):
203 | if self.batch_first:
204 | input_ = input_.transpose(0, 1)
205 | max_time, batch_size, _ = input_.size()
206 | # if length is None:
207 | # length = Variable(torch.LongTensor([max_time] * batch_size))
208 | # if input_.is_cuda:
209 | # device = input_.get_device()
210 | # length = length.cuda(device)
211 | if hx is None:
212 | hx = Variable(input_.data.new(batch_size, self.hidden_size).zero_())
213 | hx = (hx, hx)
214 | h_n = []
215 | c_n = []
216 | layer_output = None
217 | for layer in range(self.num_layers):
218 | cell = self.get_cell(layer)
219 | layer_output, (layer_h_n, layer_c_n) = BNLSTM._forward_rnn(
220 | cell=cell, input_=input_, hx=hx)
221 | input_ = self.dropout_layer(layer_output)
222 | h_n.append(layer_h_n)
223 | c_n.append(layer_c_n)
224 | output = layer_output
225 | h_n = torch.stack(h_n, 0)
226 | c_n = torch.stack(c_n, 0)
227 | return output, (h_n, c_n)
--------------------------------------------------------------------------------
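
A small sketch of the per-timestep statistics described in the `SeparatedBatchNorm1d` docstring above (toy sizes, assuming `model_word_ada` is importable): each timestep `t` selects its own `running_mean_t` / `running_var_t` buffers, and timesteps beyond `max_length - 1` reuse the last set.

```
import torch
from model_word_ada.bnlstm import SeparatedBatchNorm1d

bn = SeparatedBatchNorm1d(num_features=8, max_length=10)
x = torch.randn(4, 8)            # (batch, num_features)

y0 = bn(x, time=0)               # updates running_mean_0 / running_var_0
y3 = bn(x, time=3)               # a different set of running statistics
print(y0.shape, bn.running_mean_0, bn.running_mean_3)
```
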
/language-model/model_word_ada/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.autograd as autograd
5 |
6 | import sys
7 | import pickle
8 | import random
9 | from tqdm import tqdm
10 |
11 | from torch.utils.data import Dataset
12 |
13 | class EvalDataset(object):
14 |
15 | def __init__(self, dataset, sequence_length):
16 | super(EvalDataset, self).__init__()
17 | self.dataset = dataset
18 |
19 | self.sequence_length = sequence_length
20 |
21 | self.construct_index()
22 |
23 | def get_tqdm(self):
24 | return tqdm(self, mininterval=2, total=self.index_length, leave=False, file=sys.stdout, ncols=80)
25 |
26 | def construct_index(self):
27 | token_per_batch = self.sequence_length
28 | tot_num = len(self.dataset) - 1
29 | res_num = tot_num - tot_num % token_per_batch
30 |
31 | self.x = list(torch.unbind(torch.LongTensor(self.dataset[0:res_num]).view(-1, self.sequence_length), 0))
32 | self.y = list(torch.unbind(torch.LongTensor(self.dataset[1:res_num+1]).view(-1, self.sequence_length), 0))
33 |
34 | self.x.append(torch.LongTensor(self.dataset[res_num:tot_num]))
35 | self.y.append(torch.LongTensor(self.dataset[res_num+1:tot_num+1]))
36 |
37 | self.index_length = len(self.x)
38 | self.cur_idx = 0
39 |
40 | def __iter__(self):
41 | return self
42 |
43 | def __next__(self):
44 | if self.cur_idx == self.index_length:
45 | self.cur_idx = 0
46 | raise StopIteration
47 |
48 | word_t = autograd.Variable(self.x[self.cur_idx]).cuda().view(-1, 1)
49 | label_t = autograd.Variable(self.y[self.cur_idx]).cuda().view(-1, 1)
50 |
51 | self.cur_idx += 1
52 |
53 | return word_t, label_t
54 |
55 | # class SmallDataset(object):
56 |
57 | # def __init__(self, dataset, batch_size, sequence_length):
58 | # super(SmallDataset, self).__init__()
59 | # self.dataset = dataset
60 |
61 | # self.batch_size = batch_size
62 | # self.sequence_length = sequence_length
63 |
64 | # self.construct_index()
65 |
66 | # def get_tqdm(self):
67 | # return tqdm(self, mininterval=2, total=self.index_length, leave=False, file=sys.stdout)
68 |
69 | # def construct_index(self):
70 | # token_per_batch = self.batch_size * self.sequence_length
71 | # res_num = len(self.dataset) - 1
72 | # res_num = res_num - res_num % token_per_batch
73 |
74 | # self.x = torch.LongTensor(self.dataset[0:res_num]).view(self.batch_size, -1, self.sequence_length).transpose_(0, 1).contiguous()
75 | # self.y = torch.LongTensor(self.dataset[1:res_num+1]).view(self.batch_size, -1, self.sequence_length).transpose_(0, 1).contiguous()
76 |
77 | # self.index_length = self.x.size(0)
78 | # self.cur_idx = 0
79 |
80 | # def __iter__(self):
81 | # return self
82 |
83 | # def __next__(self):
84 | # if self.cur_idx == self.index_length:
85 | # self.cur_idx = 0
86 | # raise StopIteration
87 |
88 | # word_t = autograd.Variable(self.x[self.cur_idx]).cuda()
89 | # label_t = autograd.Variable(self.y[self.cur_idx]).cuda()
90 |
91 | # self.cur_idx += 1
92 |
93 | # return word_t, label_t
94 |
95 | class LargeDataset(object):
96 |
97 | def __init__(self, root, range_idx, batch_size, sequence_length):
98 | super(LargeDataset, self).__init__()
99 | self.root = root
100 | self.range_idx = range_idx
101 | self.shuffle_list = list(range(0, range_idx))
102 | self.shuffle()
103 |
104 | self.batch_size = batch_size
105 | self.sequence_length = sequence_length
106 | self.token_per_batch = self.batch_size * self.sequence_length
107 |
108 | self.total_batch_num = -1
109 |
110 | def shuffle(self):
111 | random.shuffle(self.shuffle_list)
112 |
113 | def get_tqdm(self):
114 | self.batch_count = 0
115 |
116 | if self.total_batch_num <= 0:
117 | return tqdm(self, mininterval=2, leave=False, file=sys.stdout).__iter__()
118 | else:
119 | return tqdm(self, mininterval=2, total=self.total_batch_num, leave=False, file=sys.stdout, ncols=80).__iter__()
120 |
121 | def __iter__(self):
122 | self.cur_idx = 0
123 | self.file_idx = 0
124 | self.index_length = 0
125 | return self
126 |
127 | def __next__(self):
128 | if self.cur_idx >= self.index_length:
129 | self.open_next()
130 |
131 | word_t = autograd.Variable(self.x[self.cur_idx]).cuda()
132 | # label_t = autograd.Variable(self.y[self.cur_idx]).cuda()
133 | label_t = self.y[self.cur_idx].cuda()
134 |
135 | self.cur_idx += 1
136 |
137 | return word_t, label_t
138 |
139 | def open_next(self):
140 | if self.file_idx >= self.range_idx:
141 | self.total_batch_num = self.batch_count
142 | self.shuffle()
143 | raise StopIteration
144 |
145 |         self.dataset = pickle.load(open(self.root + 'train_' + str(self.shuffle_list[self.file_idx]) + '.pk', 'rb'))
146 |
147 | res_num = len(self.dataset) - 1
148 | res_num = res_num - res_num % self.token_per_batch
149 |
150 | self.x = torch.LongTensor(self.dataset[0:res_num]).view(self.batch_size, -1, self.sequence_length).transpose_(0, 1).transpose_(1, 2).contiguous()
151 | self.y = torch.LongTensor(self.dataset[1:res_num+1]).view(self.batch_size, -1, self.sequence_length).transpose_(0, 1).transpose_(1, 2).contiguous()
152 |
153 | self.index_length = self.x.size(0)
154 | self.cur_idx = 0
155 |
156 | self.batch_count += self.index_length
157 | self.file_idx += 1
--------------------------------------------------------------------------------
/language-model/model_word_ada/ddnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import model_word_ada.utils as utils
5 | from model_word_ada.bnlstm import BNLSTM
6 |
7 | class BasicUnit(nn.Module):
8 | def __init__(self, unit, unit_number, emb_dim, hid_dim, droprate):
9 | super(BasicUnit, self).__init__()
10 |
11 | rnnunit_map = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU, 'bnlstm': BNLSTM}
12 |
13 | self.batch_norm = (unit == 'bnlstm')
14 |
15 | self.unit_number = unit_number
16 | # self.unit_weight = nn.Parameter(torch.FloatTensor([1] * unit_number))
17 |
18 | self.unit_list = nn.ModuleList()
19 | self.unit_list.append(rnnunit_map[unit](emb_dim, hid_dim, 1))
20 | if unit_number > 1:
21 | self.unit_list.extend([rnnunit_map[unit](hid_dim, hid_dim, 1) for ind in range(unit_number - 1)])
22 |
23 | self.droprate = droprate
24 |
25 | self.output_dim = emb_dim + hid_dim * unit_number
26 |
27 | self.init_hidden()
28 |
29 | def init_hidden(self):
30 |
31 | self.hidden_list = [None for i in range(self.unit_number)]
32 |
33 | def rand_ini(self):
34 |
35 | if not self.batch_norm:
36 | for cur_lstm in self.unit_list:
37 | utils.init_lstm(cur_lstm)
38 |
39 | def forward(self, x):
40 |
41 | out = 0
42 | # n_w = F.softmax(self.unit_weight, dim=0)
43 | for ind in range(self.unit_number):
44 | nout, new_hidden = self.unit_list[ind](x[ind], self.hidden_list[ind])
45 | self.hidden_list[ind] = utils.repackage_hidden(new_hidden)
46 | out = out + nout
47 | # out = out + n_w[ind] * self.unit_number * nout
48 |
49 | if self.droprate > 0:
50 | out = F.dropout(out, p=self.droprate, training=self.training)
51 |
52 | x.append(out)
53 |
54 | return x
55 |
56 | class DDRNN(nn.Module):
57 | def __init__(self, layer_num, unit, emb_dim, hid_dim, droprate):
58 | super(DDRNN, self).__init__()
59 |
60 | layer_list = [BasicUnit(unit, i + 1, emb_dim, hid_dim, droprate) for i in range(layer_num)]
61 | self.layer = nn.Sequential(*layer_list)
62 | self.output_dim = layer_list[-1].output_dim
63 |
64 | self.init_hidden()
65 |
66 | def init_hidden(self):
67 |
68 | for tup in self.layer.children():
69 | tup.init_hidden()
70 |
71 | def rand_ini(self):
72 |
73 | for tup in self.layer.children():
74 | tup.rand_ini()
75 |
76 | def forward(self, x):
77 | out = self.layer([x])
78 | return torch.cat(out, 2)
--------------------------------------------------------------------------------
/language-model/model_word_ada/densenet.py:
--------------------------------------------------------------------------------
1 | """
2 | .. module:: densenet
3 | :synopsis: vanilla dense RNN
4 |
5 | .. moduleauthor:: Liyuan Liu
6 | """
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import model_word_ada.utils as utils
11 |
12 | class BasicUnit(nn.Module):
13 | def __init__(self, unit, input_dim, increase_rate, droprate):
14 | super(BasicUnit, self).__init__()
15 |
16 | rnnunit_map = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU}
17 |
18 | self.unit = unit
19 |
20 | self.layer = rnnunit_map[unit](input_dim, increase_rate, 1)
21 |
22 | if 'lstm' == self.unit:
23 | utils.init_lstm(self.layer)
24 |
25 | self.droprate = droprate
26 |
27 | self.input_dim = input_dim
28 | self.increase_rate = increase_rate
29 | self.output_dim = input_dim + increase_rate
30 |
31 | self.init_hidden()
32 |
33 | def init_hidden(self):
34 |
35 | self.hidden_state = None
36 |
37 | def rand_ini(self):
38 | return
39 |
40 | def forward(self, x):
41 |
42 | if self.droprate > 0:
43 | new_x = F.dropout(x, p=self.droprate, training=self.training)
44 | else:
45 | new_x = x
46 |
47 | out, new_hidden = self.layer(new_x, self.hidden_state)
48 |
49 | self.hidden_state = utils.repackage_hidden(new_hidden)
50 |
51 | out = out.contiguous()
52 |
53 | return torch.cat([x, out], 2)
54 |
55 | class DenseRNN(nn.Module):
56 | def __init__(self, layer_num, unit, emb_dim, hid_dim, droprate):
57 | super(DenseRNN, self).__init__()
58 |
59 | self.layer_list = [BasicUnit(unit, emb_dim + i * hid_dim, hid_dim, droprate) for i in range(layer_num)]
60 | self.layer = nn.Sequential(*self.layer_list)
61 | self.output_dim = self.layer_list[-1].output_dim
62 |
63 | self.init_hidden()
64 |
65 | def init_hidden(self):
66 |
67 | for tup in self.layer_list:
68 | tup.init_hidden()
69 |
70 | def rand_ini(self):
71 |
72 | for tup in self.layer_list:
73 | tup.rand_ini()
74 |
75 | def forward(self, x):
76 | return self.layer(x)
--------------------------------------------------------------------------------
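
A toy shape check for `DenseRNN` (hypothetical dimensions; assumes a PyTorch version that still accepts the deprecated `nn.init.uniform` alias used by `utils.init_lstm`): every layer concatenates its output onto its input, so the feature size grows by `hid_dim` per layer.

```
import torch
from model_word_ada.densenet import DenseRNN

rnn = DenseRNN(layer_num=3, unit='lstm', emb_dim=100, hid_dim=50, droprate=0.0)
rnn.init_hidden()

x = torch.randn(35, 4, 100)      # (seq_len, batch, emb_dim)
out = rnn(x)
print(out.shape, rnn.output_dim) # torch.Size([35, 4, 250]) 250
```
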
/language-model/model_word_ada/ldnet.py:
--------------------------------------------------------------------------------
1 | """
2 | .. module:: ldnet
3 |     :synopsis: dense RNN with layer drop
4 |
5 | .. moduleauthor:: Liyuan Liu
6 | """
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import model_word_ada.utils as utils
11 | import random
12 |
13 | class BasicUnit(nn.Module):
14 | def __init__(self, unit, input_dim, increase_rate, droprate, layer_drop = 0):
15 | super(BasicUnit, self).__init__()
16 |
17 | rnnunit_map = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU}
18 |
19 | self.unit = unit
20 |
21 | self.layer = rnnunit_map[unit](input_dim, increase_rate, 1)
22 |
23 | if 'lstm' == self.unit:
24 | utils.init_lstm(self.layer)
25 |
26 | self.layer_drop = layer_drop
27 |
28 | self.droprate = droprate
29 |
30 | self.input_dim = input_dim
31 | self.increase_rate = increase_rate
32 | self.output_dim = input_dim + increase_rate
33 |
34 | self.init_hidden()
35 |
36 | def init_hidden(self):
37 |
38 | self.hidden_state = None
39 |
40 | def rand_ini(self):
41 | return
42 |
43 | def forward(self, x, p_out):
44 |
45 | if self.droprate > 0:
46 | new_x = F.dropout(x, p=self.droprate, training=self.training)
47 | else:
48 | new_x = x
49 |
50 | out, new_hidden = self.layer(new_x, self.hidden_state)
51 |
52 | self.hidden_state = utils.repackage_hidden(new_hidden)
53 |
54 | out = out.contiguous()
55 |
56 | if self.training and random.uniform(0, 1) < self.layer_drop:
57 | deep_out = torch.autograd.Variable( torch.zeros(x.size(0), x.size(1), self.increase_rate) ).cuda()
58 | else:
59 | deep_out = out
60 |
61 | o_out = torch.cat([p_out, out], 2)
62 | d_out = torch.cat([x, deep_out], 2)
63 | return d_out, o_out
64 |
65 | class LDRNN(nn.Module):
66 | def __init__(self, layer_num, unit, emb_dim, hid_dim, droprate, layer_drop):
67 | super(LDRNN, self).__init__()
68 |
69 | self.layer_list = [BasicUnit(unit, emb_dim + i * hid_dim, hid_dim, droprate, layer_drop) for i in range(layer_num)]
70 |
71 | self.layer_num = layer_num
72 | self.layer = nn.ModuleList(self.layer_list)
73 | self.output_dim = self.layer_list[-1].output_dim
74 |
75 | self.init_hidden()
76 |
77 | def init_hidden(self):
78 |
79 | for tup in self.layer_list:
80 | tup.init_hidden()
81 |
82 | def rand_ini(self):
83 |
84 | for tup in self.layer_list:
85 | tup.rand_ini()
86 |
87 | def forward(self, x):
88 | output = x
89 | for ind in range(self.layer_num):
90 | x, output = self.layer_list[ind](x, output)
91 | return output
--------------------------------------------------------------------------------
/language-model/model_word_ada/radam.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer, required
4 |
5 | # from tensorboardX import SummaryWriter
6 | # writer = SummaryWriter(logdir='/cps/gadam/n_cifa/')
7 | # iter_idx = 0
8 |
9 | # from ipdb import set_trace
10 | import torch.optim
11 |
12 | class RAdam(Optimizer):
13 |
14 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
15 | weight_decay=0):
16 | defaults = dict(lr=lr, betas=betas, eps=eps,
17 | weight_decay=weight_decay)
18 |
19 | super(RAdam, self).__init__(params, defaults)
20 |
21 | def __setstate__(self, state):
22 | super(RAdam, self).__setstate__(state)
23 |
24 | def step(self, closure=None):
25 | loss = None
26 | beta2_t = None
27 | ratio = None
28 | N_sma_max = None
29 | N_sma = None
30 |
31 | if closure is not None:
32 | loss = closure()
33 |
34 | for group in self.param_groups:
35 |
36 | for p in group['params']:
37 | if p.grad is None:
38 | continue
39 | grad = p.grad.data.float()
40 | if grad.is_sparse:
41 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
42 |
43 | p_data_fp32 = p.data.float()
44 |
45 | state = self.state[p]
46 |
47 | if len(state) == 0:
48 | state['step'] = 0
49 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
50 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
51 | else:
52 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
53 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
54 |
55 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
56 | beta1, beta2 = group['betas']
57 |
58 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
59 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
60 |
61 | state['step'] += 1
62 | if beta2_t is None:
63 | beta2_t = beta2 ** state['step']
64 | N_sma_max = 2 / (1 - beta2) - 1
65 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
66 | beta1_t = 1 - beta1 ** state['step']
67 | if N_sma >= 5:
68 | ratio = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / beta1_t
69 |
70 | if group['weight_decay'] != 0:
71 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
72 |
73 | # more conservative since it's an approximated value
74 | if N_sma >= 5:
75 | step_size = group['lr'] * ratio
76 | denom = exp_avg_sq.sqrt().add_(group['eps'])
77 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
78 | else:
79 | step_size = group['lr'] / beta1_t
80 | p_data_fp32.add_(-step_size, exp_avg)
81 |
82 | p.data.copy_(p_data_fp32)
83 |
84 | return loss
85 |
86 |
87 | class AdamW(Optimizer):
88 |
89 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
90 | weight_decay=0):
91 | defaults = dict(lr=lr, betas=betas, eps=eps,
92 | weight_decay=weight_decay)
93 |
94 | super(AdamW, self).__init__(params, defaults)
95 |
96 | def __setstate__(self, state):
97 | super(AdamW, self).__setstate__(state)
98 |
99 | def step(self, closure=None):
100 |         # The TensorBoard bookkeeping that used to live here relied on the
101 |         # `writer`/`iter_idx` globals commented out at the top of this file;
102 |         # it is disabled so that calling step() does not raise a NameError.
103 |         # global iter_idx
104 |         # iter_idx += 1
105 |
106 | loss = None
107 | if closure is not None:
108 | loss = closure()
109 |
110 | for group in self.param_groups:
111 |
112 | for p in group['params']:
113 | if p.grad is None:
114 | continue
115 | grad = p.grad.data.float()
116 | if grad.is_sparse:
117 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
118 |
119 | p_data_fp32 = p.data.float()
120 |
121 | state = self.state[p]
122 |
123 | if len(state) == 0:
124 | state['step'] = 0
125 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
126 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
127 | else:
128 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
129 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
130 |
131 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
132 | beta1, beta2 = group['betas']
133 |
134 | state['step'] += 1
135 |
136 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
137 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
138 |
139 | denom = exp_avg_sq.sqrt().add_(group['eps'])
140 | bias_correction1 = 1 - beta1 ** state['step']
141 | bias_correction2 = 1 - beta2 ** state['step']
142 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
143 |
144 | if group['weight_decay'] != 0:
145 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
146 |
147 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
148 |
149 | p.data.copy_(p_data_fp32)
150 |
151 | return loss
152 |
--------------------------------------------------------------------------------
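
As a reading aid for `RAdam.step()` above (not part of the repository), the learning-rate multiplier it computes can be written as a standalone function: `rho_t` is the approximated length of the simple moving average, and while `rho_t < 5` the optimizer falls back to an un-adapted momentum step with only the first-moment bias correction.

```
import math

def radam_rectification(t, beta1=0.9, beta2=0.999):
    beta2_t = beta2 ** t
    rho_inf = 2 / (1 - beta2) - 1
    rho_t = rho_inf - 2 * t * beta2_t / (1 - beta2_t)
    if rho_t < 5:
        return None                       # step() then uses lr / (1 - beta1 ** t)
    r_t = math.sqrt((1 - beta2_t) * (rho_t - 4) / (rho_inf - 4)
                    * (rho_t - 2) / rho_t * rho_inf / (rho_inf - 2))
    return r_t / (1 - beta1 ** t)         # multiplied into group['lr'] as `ratio`

for t in (1, 4, 5, 100, 10000):
    print(t, radam_rectification(t))      # None for the first steps, then grows towards 1
```
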
/language-model/model_word_ada/resnet.py:
--------------------------------------------------------------------------------
1 | # import torch
2 | # import torch.nn as nn
3 | # import torch.nn.functional as F
4 | # import model.utils as utils
5 |
6 | # class BasicUnit(nn.Module):
7 | # def __init__(self, unit, input_dim, hid_dim, droprate):
8 | # super(BasicUnit, self).__init__()
9 |
10 | # rnnunit_map = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU}
11 | # self.unit_number = unit_number
12 |
13 | # self.layer = rnnunit_map[unit](input_dim, hid_dim, 1)
14 |
15 | # self.droprate = droprate
16 |
17 | # self.output_dim = input_dim + hid_dim
18 |
19 | # self.init_hidden()
20 |
21 | # def init_hidden(self):
22 |
23 | # self.hidden_list = [None for i in range(unit_number)]
24 |
25 | # def rand_ini(self):
26 |
27 | # for cur_lstm in self.unit_list:
28 | # utils.init_lstm(cur_lstm)
29 |
30 | # def forward(self, x):
31 |
32 | # out, _ = self.layer(x)
33 |
34 | # if self.droprate > 0:
35 | # out = F.dropout(out, p=self.droprate, training=self.training)
36 |
37 | # return toch.cat([x, out], 2)
38 |
39 | # class DenseRNN(nn.Module):
40 | # def __init__(self, layer_num, unit, emb_dim, hid_dim, droprate):
41 | # super(DenseRNN, self).__init__()
42 |
43 | # self.layer = nn.Sequential([BasicUnit(unit, emb_dim + i * hid_dim, hid_dim, droprate) for i in range(layer_num) ])
44 |
45 | # self.output_dim = self.layer[-1].output_dim
46 |
47 | # self.init_hidden()
48 |
49 | # def init_hidden(self):
50 | # self.layer.apply(lambda t: t.init_hidden())
51 |
52 | # def rand_ini(self):
53 | # self.layer.apply(lambda t: t.rand_ini())
54 |
55 | # def forward(self, x):
56 | # return self.layer(x)
--------------------------------------------------------------------------------
/language-model/model_word_ada/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import json
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.init
8 |
9 | from torch.autograd import Variable
10 |
11 | def repackage_hidden(h):
12 | """Wraps hidden states in new Variables, to detach them from their history."""
13 | if type(h) == torch.Tensor:
14 | return h.data
15 | else:
16 | return tuple(repackage_hidden(v) for v in h)
17 |
18 | def to_scalar(var):
19 | """change the first element of a tensor to scalar
20 | """
21 | return var.view(-1).data.tolist()[0]
22 |
23 | def init_embedding(input_embedding):
24 | """
25 | Initialize embedding
26 | """
27 | bias = np.sqrt(3.0 / input_embedding.size(1))
28 | nn.init.uniform(input_embedding, -bias, bias)
29 |
30 | def init_linear(input_linear):
31 | """
32 | Initialize linear transformation
33 | """
34 | bias = np.sqrt(6.0 / (input_linear.weight.size(0) + input_linear.weight.size(1)))
35 | nn.init.uniform(input_linear.weight, -bias, bias)
36 | if input_linear.bias is not None:
37 | input_linear.bias.data.zero_()
38 |
39 | def adjust_learning_rate(optimizer, lr):
40 | """
41 | shrink learning rate for pytorch
42 | """
43 | for param_group in optimizer.param_groups:
44 | param_group['lr'] = lr
45 |
46 | def init_lstm(input_lstm):
47 | """
48 | Initialize lstm
49 | """
50 | for ind in range(0, input_lstm.num_layers):
51 | weight = eval('input_lstm.weight_ih_l'+str(ind))
52 | bias = np.sqrt(6.0 / (weight.size(0)/4 + weight.size(1)))
53 | nn.init.uniform(weight, -bias, bias)
54 | weight = eval('input_lstm.weight_hh_l'+str(ind))
55 | bias = np.sqrt(6.0 / (weight.size(0)/4 + weight.size(1)))
56 | nn.init.uniform(weight, -bias, bias)
57 |
58 | if input_lstm.bias:
59 | for ind in range(0, input_lstm.num_layers):
60 | weight = eval('input_lstm.bias_ih_l'+str(ind))
61 | weight.data.zero_()
62 | weight.data[input_lstm.hidden_size: 2 * input_lstm.hidden_size] = 1
63 | weight = eval('input_lstm.bias_hh_l'+str(ind))
64 | weight.data.zero_()
65 | weight.data[input_lstm.hidden_size: 2 * input_lstm.hidden_size] = 1
66 |
--------------------------------------------------------------------------------
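
A minimal truncated-BPTT sketch showing why `repackage_hidden` is called between batches (toy model and shapes, assuming `model_word_ada` is importable): the hidden state is carried across chunks but detached, so gradients never flow across chunk boundaries.

```
import torch
import torch.nn as nn
from model_word_ada.utils import repackage_hidden

lstm = nn.LSTM(input_size=10, hidden_size=16)
hidden = None

for _ in range(3):                       # three consecutive chunks of one long stream
    x = torch.randn(20, 4, 10)           # (seq_len, batch, input_size)
    out, hidden = lstm(x, hidden)
    loss = out.sum()
    loss.backward()                      # backprop stays inside the current chunk
    hidden = repackage_hidden(hidden)    # detach before feeding the next chunk
```
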
/language-model/pre_word_ada/encode_data2folder.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import argparse
3 | import os
4 | import codecs
5 | import random
6 | import numpy as np
7 |
8 | from tqdm import tqdm
9 |
10 | import itertools
11 | import functools
12 |
13 | def encode_dataset(input_folder, w_map, reverse):
14 |
15 | w_eof = w_map['\n']
16 | w_unk = w_map['']
17 |
18 | list_dirs = os.walk(input_folder)
19 |
20 | lines = list()
21 |
22 | for root, dirs, files in list_dirs:
23 | for file in tqdm(files):
24 | with codecs.open(os.path.join(root, file), 'r', 'utf-8') as fin:
25 | lines = lines + list(filter(lambda t: t and not t.isspace(), fin.readlines()))
26 |
27 | dataset = list()
28 | for line in lines:
29 | dataset += list(map(lambda t: w_map.get(t, w_unk), line.split())) + [w_eof]
30 |
31 | if reverse:
32 | dataset = dataset[::-1]
33 |
34 | return dataset
35 |
36 | def encode_dataset2file(input_folder, output_folder, w_map, reverse):
37 |
38 | w_eof = w_map['\n']
39 | w_unk = w_map['']
40 |
41 | list_dirs = os.walk(input_folder)
42 |
43 | range_ind = 0
44 |
45 | for root, dirs, files in list_dirs:
46 | for file in tqdm(files):
47 | with codecs.open(os.path.join(root, file), 'r', 'utf-8') as fin:
48 | lines = list(filter(lambda t: t and not t.isspace(), fin.readlines()))
49 |
50 | dataset = list()
51 | for line in lines:
52 | dataset += list(map(lambda t: w_map.get(t, w_unk), line.split())) + [w_eof]
53 |
54 | if reverse:
55 | dataset = dataset[::-1]
56 |
57 | with open(output_folder+'train_'+ str(range_ind) + '.pk', 'wb') as f:
58 | pickle.dump(dataset, f)
59 |
60 | range_ind += 1
61 |
62 | return range_ind
63 |
64 | if __name__ == "__main__":
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument('--train_folder', default="/data/billionwords/1-billion-word-language-modeling-benchmark/training-monolingual.tokenized.shuffled")
67 | parser.add_argument('--test_folder', default="/data/billionwords/1-billion-word-language-modeling-benchmark/heldout-monolingual.tokenized.shuffled")
68 | parser.add_argument('--input_map', default="/data/billionwords/1b_map.pk")
69 | parser.add_argument('--output_folder', default="/data/billionwords/one_billion/")
70 | parser.add_argument('--threshold', type=int, default=3)
71 | parser.add_argument('--unk', default='')
72 | parser.add_argument('--reverse', action='store_true')
73 | args = parser.parse_args()
74 |
75 | with open(args.input_map, 'rb') as f:
76 | w_count = pickle.load(f)
77 |
78 | unk_count = sum([v for k, v in w_count.items() if v <= args.threshold])
79 | w_list = [(k, v) for k, v in w_count.items() if v > args.threshold]
80 | w_list.append(('', unk_count))
81 | w_list.sort(key=lambda t: t[1], reverse=True)
82 | w_map = {kv[0]:v for v, kv in enumerate(w_list)}
83 |
84 | range_ind = encode_dataset2file(args.train_folder, args.output_folder, w_map, args.reverse)
85 |
86 | test_dataset = encode_dataset(args.test_folder, w_map, args.reverse)
87 |
88 | with open(args.output_folder+'test.pk', 'wb') as f:
89 | pickle.dump({'w_map': w_map, 'test_data':test_dataset, 'range' : range_ind}, f)
90 |
--------------------------------------------------------------------------------
/language-model/pre_word_ada/gene_map.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import argparse
3 | import os
4 | import random
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | import itertools
9 | import functools
10 |
11 | if __name__ == "__main__":
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--input_folder', default="/data/billionwords/1-billion-word-language-modeling-benchmark/training-monolingual.tokenized.shuffled")
14 | parser.add_argument('--output_map', default="/data/billionwords/1b_map.pk")
15 | args = parser.parse_args()
16 |
17 | w_count = {'\n':0}
18 |
19 | list_dirs = os.walk(args.input_folder)
20 |
21 | for root, dirs, files in list_dirs:
22 | for file in tqdm(files):
23 | with open(os.path.join(root, file)) as fin:
24 | for line in fin:
25 | if not line or line.isspace():
26 | continue
27 | line = line.split()
28 | for tup in line:
29 | w_count[tup] = w_count.get(tup, 0) + 1
30 | w_count['\n'] += 1
31 |
32 | with open(args.output_map, 'wb') as f:
33 | pickle.dump(w_count, f)
--------------------------------------------------------------------------------
/language-model/recipes.md:
--------------------------------------------------------------------------------
1 | # Pre-process
2 |
3 | ```
4 | python pre_word_ada/gene_map.py --input_folder /data/billionwords/1-billion-word-language-modeling-benchmark/training-monolingual.tokenized.shuffled --output_map /data/billionwords/1b_map.pk
5 |
6 | python pre_word_ada/encode_data2folder.py --train_folder /data/billionwords/1-billion-word-language-modeling-benchmark/training-monolingual.tokenized.shuffled --test_folder /data/billionwords/1-billion-word-language-modeling-benchmark/heldout-monolingual.tokenized.shuffled --input_map /data/billionwords/1b_map.pk --output_folder /data/billionwords/one_billion/
7 | ```
8 |
9 | # Training
10 |
11 | ## Adam
12 | ```
13 | python train_1bw.py --dataset_folder /data/billionwords/one_billion/ --lr 0.001 --checkpath ./cps/gadam/ --model_name adam --update Adam
14 | ```
15 |
16 | ## RAdam
17 | ```
18 | python train_1bw.py --dataset_folder /data/billionwords/one_billion/ --lr 0.001 --checkpath ./cps/gadam/ --model_name radam --update RAdam
19 | ```
20 |
--------------------------------------------------------------------------------
/language-model/train_1bw.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import datetime
3 | import time
4 | import torch
5 | import torch.autograd as autograd
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import codecs
9 | import pickle
10 | import math
11 |
12 | from model_word_ada.LM import LM
13 | from model_word_ada.basic import BasicRNN
14 | from model_word_ada.ddnet import DDRNN
15 | from model_word_ada.radam import RAdam
16 | from model_word_ada.ldnet import LDRNN
17 | from model_word_ada.densenet import DenseRNN
18 | from model_word_ada.dataset import LargeDataset, EvalDataset
19 | from model_word_ada.adaptive import AdaptiveSoftmax
20 | import model_word_ada.utils as utils
21 |
22 | # from tensorboardX import SummaryWriter
23 | # writer = SummaryWriter(logdir='./cps/gadam/log_1bw_full/')
24 |
25 | import argparse
26 | import json
27 | import os
28 | import sys
29 | import itertools
30 | import functools
31 |
32 | def evaluate(data_loader, lm_model, criterion, limited = 76800):
33 | print('evaluating')
34 | lm_model.eval()
35 |
36 | iterator = data_loader.get_tqdm()
37 |
38 | lm_model.init_hidden()
39 | total_loss = 0
40 | total_len = 0
41 | for word_t, label_t in iterator:
42 | label_t = label_t.view(-1)
43 | tmp_len = label_t.size(0)
44 | output = lm_model.log_prob(word_t)
45 | total_loss += tmp_len * utils.to_scalar(criterion(autograd.Variable(output), label_t))
46 | total_len += tmp_len
47 |
48 | if limited >=0 and total_len > limited:
49 | break
50 |
51 | ppl = math.exp(total_loss / total_len)
52 | print('PPL: ' + str(ppl))
53 |
54 | return ppl
55 |
56 | if __name__ == "__main__":
57 | parser = argparse.ArgumentParser()
58 | parser.add_argument('--dataset_folder', default='/data/billionwords/one_billion/')
59 | parser.add_argument('--load_checkpoint', default='')
60 | parser.add_argument('--batch_size', type=int, default=128)
61 | parser.add_argument('--sequence_length', type=int, default=20)
62 | parser.add_argument('--hid_dim', type=int, default=2048)
63 | parser.add_argument('--word_dim', type=int, default=300)
64 | parser.add_argument('--label_dim', type=int, default=-1)
65 | parser.add_argument('--layer_num', type=int, default=2)
66 | parser.add_argument('--droprate', type=float, default=0.1)
67 | parser.add_argument('--add_relu', action='store_true')
68 | parser.add_argument('--layer_drop', type=float, default=0.5)
69 | parser.add_argument('--gpu', type=int, default=1)
70 | parser.add_argument('--epoch', type=int, default=14)
71 | parser.add_argument('--clip', type=float, default=5)
72 | parser.add_argument('--update', choices=['Adam', 'Adagrad', 'Adadelta', 'RAdam', 'SGD'], default='Adam')
73 | parser.add_argument('--rnn_layer', choices=['Basic', 'DDNet', 'DenseNet', 'LDNet'], default='Basic')
74 | parser.add_argument('--rnn_unit', choices=['gru', 'lstm', 'rnn', 'bnlstm'], default='lstm')
75 | parser.add_argument('--lr', type=float, default=-1)
76 | parser.add_argument('--lr_decay', type=lambda t: [int(tup) for tup in t.split(',')], default=[8])
77 |     parser.add_argument('--cut_off', nargs='+', type=int, default=[4000,40000,200000])
78 | parser.add_argument('--interval', type=int, default=100)
79 | parser.add_argument('--check_interval', type=int, default=4000)
80 | parser.add_argument('--checkpath', default='./cps/gadam/')
81 | parser.add_argument('--model_name', default='adam')
82 | args = parser.parse_args()
83 |
84 | if args.gpu >= 0:
85 | torch.cuda.set_device(args.gpu)
86 |
87 | print('loading dataset')
88 | dataset = pickle.load(open(args.dataset_folder + 'test.pk', 'rb'))
89 | w_map, test_data, range_idx = dataset['w_map'], dataset['test_data'], dataset['range']
90 |
91 | cut_off = args.cut_off + [len(w_map) + 1]
92 |
93 | train_loader = LargeDataset(args.dataset_folder, range_idx, args.batch_size, args.sequence_length)
94 | test_loader = EvalDataset(test_data, args.batch_size)
95 |
96 | print('building model')
97 |
98 | rnn_map = {'Basic': BasicRNN, 'DDNet': DDRNN, 'DenseNet': DenseRNN, 'LDNet': functools.partial(LDRNN, layer_drop = args.layer_drop)}
99 | rnn_layer = rnn_map[args.rnn_layer](args.layer_num, args.rnn_unit, args.word_dim, args.hid_dim, args.droprate)
100 |
101 | if args.label_dim > 0:
102 | soft_max = AdaptiveSoftmax(args.label_dim, cut_off)
103 | else:
104 | soft_max = AdaptiveSoftmax(rnn_layer.output_dim, cut_off)
105 |
106 | lm_model = LM(rnn_layer, soft_max, len(w_map), args.word_dim, args.droprate, label_dim = args.label_dim, add_relu=args.add_relu)
107 | lm_model.rand_ini()
108 | # lm_model.cuda()
109 |
110 | optim_map = {'Adam' : optim.Adam, 'Adagrad': optim.Adagrad, 'Adadelta': optim.Adadelta, 'RAdam': RAdam, 'SGD': functools.partial(optim.SGD, momentum=0.9)}
111 | if args.lr > 0:
112 | optimizer=optim_map[args.update](lm_model.parameters(), lr=args.lr)
113 | else:
114 | optimizer=optim_map[args.update](lm_model.parameters())
115 |
116 | if args.load_checkpoint:
117 | if os.path.isfile(args.load_checkpoint):
118 | print("loading checkpoint: '{}'".format(args.load_checkpoint))
119 | checkpoint_file = torch.load(args.load_checkpoint, map_location=lambda storage, loc: storage)
120 | lm_model.load_state_dict(checkpoint_file['lm_model'], False)
121 | optimizer.load_state_dict(checkpoint_file['opt'], False)
122 | else:
123 | print("no checkpoint found at: '{}'".format(args.load_checkpoint))
124 |
125 | test_lm = nn.NLLLoss()
126 |
127 | test_lm.cuda()
128 | lm_model.cuda()
129 |
130 | batch_index = 0
131 | epoch_loss = 0
132 | full_epoch_loss = 0
133 | best_train_ppl = float('inf')
134 | cur_lr = args.lr
135 |
136 | try:
137 | for indexs in range(args.epoch):
138 |
139 | print('#' * 89)
140 | print('Start: {}'.format(indexs))
141 |
142 | iterator = train_loader.get_tqdm()
143 | full_epoch_loss = 0
144 |
145 | lm_model.train()
146 |
147 | for word_t, label_t in iterator:
148 |
149 | if 1 == train_loader.cur_idx:
150 | lm_model.init_hidden()
151 |
152 | label_t = label_t.view(-1)
153 |
154 | lm_model.zero_grad()
155 | loss = lm_model(word_t, label_t)
156 |
157 | loss.backward()
158 | torch.nn.utils.clip_grad_norm(lm_model.parameters(), args.clip)
159 | optimizer.step()
160 |
161 | batch_index += 1
162 |
163 | if 0 == batch_index % args.interval:
164 | s_loss = utils.to_scalar(loss)
165 | # writer.add_scalars('loss_tracking/train_loss', {args.model_name:s_loss}, batch_index)
166 |
167 | epoch_loss += utils.to_scalar(loss)
168 | full_epoch_loss += utils.to_scalar(loss)
169 | if 0 == batch_index % args.check_interval:
170 | epoch_ppl = math.exp(epoch_loss / args.check_interval)
171 | # writer.add_scalars('loss_tracking/train_ppl', {args.model_name: epoch_ppl}, batch_index)
172 | print('epoch_ppl: {} lr: {} @ batch_index: {}'.format(epoch_ppl, cur_lr, batch_index))
173 | epoch_loss = 0
174 |
175 | if indexs in args.lr_decay and cur_lr > 0:
176 | cur_lr *= 0.1
177 | print('adjust_learning_rate...')
178 | utils.adjust_learning_rate(optimizer, cur_lr)
179 |
180 | test_ppl = evaluate(test_loader, lm_model, test_lm, -1)
181 |
182 | # writer.add_scalars('loss_tracking/test_ppl', {args.model_name: test_ppl}, indexs)
183 | print('test_ppl: {} @ index: {}'.format(test_ppl, indexs))
184 |
185 | torch.save({'lm_model': lm_model.state_dict(), 'opt':optimizer.state_dict()}, args.checkpath+args.model_name+'.model')
186 |
187 | except KeyboardInterrupt:
188 |
189 | print('Exiting from training early')
190 | test_ppl = evaluate(test_loader, lm_model, test_lm, -1)
191 | # writer.add_scalars('loss_tracking/test_ppl', {args.model_name: test_ppl}, args.epoch)
192 |
193 | torch.save({'lm_model': lm_model.state_dict(), 'opt':optimizer.state_dict()}, args.checkpath+args.model_name+'.model')
194 |
195 | # writer.close()
196 |
--------------------------------------------------------------------------------
/nmt/README.md:
--------------------------------------------------------------------------------
1 | # NMT
2 |
3 | This folder is based on the fairseq package ([repo](https://github.com/pytorch/fairseq)). For more details about this code base, please refer to the original repo.
4 | A training [recipe](/nmt/recipes.md) is provided for the NMT experiments.
5 |
6 |
--------------------------------------------------------------------------------
/nmt/average_checkpoints.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) 2017-present, Facebook, Inc.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the LICENSE file in
6 | # the root directory of this source tree. An additional grant of patent rights
7 | # can be found in the PATENTS file in the same directory.
8 |
9 | import argparse
10 | import collections
11 | import torch
12 | import os
13 | import re
14 |
15 |
16 | def average_checkpoints(inputs):
17 | """Loads checkpoints from inputs and returns a model with averaged weights.
18 |
19 | Args:
20 | inputs: An iterable of string paths of checkpoints to load from.
21 |
22 | Returns:
23 | A dict of string keys mapping to various values. The 'model' key
24 | from the returned dict should correspond to an OrderedDict mapping
25 | string parameter names to torch Tensors.
26 | """
27 | params_dict = collections.OrderedDict()
28 | params_keys = None
29 | new_state = None
30 | num_models = len(inputs)
31 |
32 | for f in inputs:
33 | state = torch.load(
34 | f,
35 | map_location=(
36 | lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
37 | ),
38 | )
39 | # Copies over the settings from the first checkpoint
40 | if new_state is None:
41 | new_state = state
42 |
43 | model_params = state['model']
44 |
45 | model_params_keys = list(model_params.keys())
46 | if params_keys is None:
47 | params_keys = model_params_keys
48 | elif params_keys != model_params_keys:
49 | raise KeyError(
50 | 'For checkpoint {}, expected list of params: {}, '
51 | 'but found: {}'.format(f, params_keys, model_params_keys)
52 | )
53 |
54 | for k in params_keys:
55 | p = model_params[k]
56 | if isinstance(p, torch.HalfTensor):
57 | p = p.float()
58 | if k not in params_dict:
59 | params_dict[k] = p.clone()
60 | # NOTE: clone() is needed in case of p is a shared parameter
61 | else:
62 | params_dict[k] += p
63 |
64 | averaged_params = collections.OrderedDict()
65 | for k, v in params_dict.items():
66 | averaged_params[k] = v
67 | averaged_params[k].div_(num_models)
68 | new_state['model'] = averaged_params
69 | return new_state
70 |
71 |
72 | def last_n_checkpoints(paths, n, update_based, upper_bound=None):
73 | assert len(paths) == 1
74 | path = paths[0]
75 | if update_based:
76 | pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt')
77 | else:
78 | pt_regexp = re.compile(r'checkpoint(\d+)\.pt')
79 | files = os.listdir(path)
80 |
81 | entries = []
82 | for f in files:
83 | m = pt_regexp.fullmatch(f)
84 | if m is not None:
85 | sort_key = int(m.group(1))
86 | if upper_bound is None or sort_key <= upper_bound:
87 | entries.append((sort_key, m.group(0)))
88 | if len(entries) < n:
89 |         raise Exception('Found {} checkpoint files but need at least {}'.format(len(entries), n))
90 | return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]]
91 |
92 |
93 | def main():
94 | parser = argparse.ArgumentParser(
95 | description='Tool to average the params of input checkpoints to '
96 | 'produce a new checkpoint',
97 | )
98 | # fmt: off
99 | parser.add_argument('--inputs', required=True, nargs='+',
100 | help='Input checkpoint file paths.')
101 | parser.add_argument('--output', required=True, metavar='FILE',
102 | help='Write the new checkpoint containing the averaged weights to this path.')
103 | num_group = parser.add_mutually_exclusive_group()
104 | num_group.add_argument('--num-epoch-checkpoints', type=int,
105 | help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
106 | 'and average last this many of them.')
107 | num_group.add_argument('--num-update-checkpoints', type=int,
108 | help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
109 | 'and average last this many of them.')
110 | parser.add_argument('--checkpoint-upper-bound', type=int,
111 | help='when using --num-epoch-checkpoints, this will set an upper bound on which checkpoint to use, '
112 | 'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged.')
113 | # fmt: on
114 | args = parser.parse_args()
115 | print(args)
116 |
117 | num = None
118 | is_update_based = False
119 | if args.num_update_checkpoints is not None:
120 | num = args.num_update_checkpoints
121 | is_update_based = True
122 | elif args.num_epoch_checkpoints is not None:
123 | num = args.num_epoch_checkpoints
124 |
125 | assert args.checkpoint_upper_bound is None or args.num_epoch_checkpoints is not None, \
126 | '--checkpoint-upper-bound requires --num-epoch-checkpoints'
127 | assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \
128 | 'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints'
129 |
130 | if num is not None:
131 | args.inputs = last_n_checkpoints(
132 | args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound,
133 | )
134 | print('averaging checkpoints: ', args.inputs)
135 |
136 | new_state = average_checkpoints(args.inputs)
137 | torch.save(new_state, args.output)
138 | print('Finished writing averaged checkpoint to {}.'.format(args.output))
139 |
140 |
141 | if __name__ == '__main__':
142 | main()
143 |
--------------------------------------------------------------------------------
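
A toy illustration of what `average_checkpoints()` above does (not the script's CLI; for the real invocation see `eval.sh` below, which calls it with `--num-epoch-checkpoints 10`): parameters are summed across checkpoints and divided by the number of models.

```
import collections
import torch

ckpts = [
    {'model': collections.OrderedDict(w=torch.tensor([1.0, 2.0]))},
    {'model': collections.OrderedDict(w=torch.tensor([3.0, 4.0]))},
]

avg = collections.OrderedDict()
for state in ckpts:
    for k, p in state['model'].items():
        avg[k] = avg.get(k, torch.zeros_like(p)) + p.float()
for k in avg:
    avg[k].div_(len(ckpts))

print(avg['w'])    # tensor([2., 3.])
```
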
/nmt/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo "Model path" $SAVEDIR
3 | GPUDEV=${2:-0}
4 | SAVEDIR=${1}
5 | MODELDIR=$SAVEDIR/model_ed.pt
6 | if [ -f $MODELDIR ]; then
7 | echo $MODELDIR "already exists"
8 | else
9 | echo "Start averaging model"
10 | python average_checkpoints.py --inputs $SAVEDIR --num-epoch-checkpoints 10 --output $MODELDIR | grep 'Finish'
11 | echo "End averaging model"
12 | fi
13 |
14 | CUDA_VISIBLE_DEVICES=$GPUDEV fairseq-generate data-bin/iwslt14.tokenized.de-en \
15 | --path $MODELDIR \
16 | --batch-size 128 --beam 5 --remove-bpe \
17 | --user-dir ./my_module 2>&1 | grep BLEU4
18 |
19 | # CUDA_VISIBLE_DEVICES=$GPUDEV fairseq-generate data-bin/iwslt14.tokenized.en-de \
20 | # --path $MODELDIR \
21 | # --batch-size 128 --beam 5 --remove-bpe \
22 | # --user-dir ./my_module 2>&1 | grep BLEU4
23 |
--------------------------------------------------------------------------------
/nmt/my_module/__init__.py:
--------------------------------------------------------------------------------
1 | from .radam import *
2 | from .adam2 import *
3 | from .linear_schedule import *
4 | from .poly_schedule import *
5 | from .novograd import *
6 |
--------------------------------------------------------------------------------
/nmt/my_module/adam2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | import math
9 | import types
10 |
11 | import torch
12 | import torch.optim
13 |
14 | from fairseq.optim import FairseqOptimizer, register_optimizer
15 |
16 | from tensorboardX import SummaryWriter
17 | # writer = SummaryWriter(logdir='./log/ada/')
18 | # # writer = SummaryWriter(logdir='./log/wmt/')
19 |
20 | iter_idx = 0
21 |
22 | @register_optimizer('adam2')
23 | class FairseqAdam2(FairseqOptimizer):
24 |
25 | def __init__(self, args, params):
26 | super().__init__(args, params)
27 |
28 | self._optimizer = Adam2(params, **self.optimizer_config)
29 | self._optimizer.name = args.tb_tag + '_' + self._optimizer.name
30 |
31 | @staticmethod
32 | def add_args(parser):
33 | """Add optimizer-specific arguments to the parser."""
34 | # fmt: off
35 | parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B',
36 | help='betas for Adam optimizer')
37 | parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
38 | help='epsilon for Adam optimizer')
39 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
40 | help='weight decay')
41 | parser.add_argument('--tb-tag', default="", type=str,
42 | help='tb tag')
43 | parser.add_argument('--amsgrad', action='store_true')
44 | parser.add_argument('--adam-freeze', default=5000, type=float)
45 | parser.add_argument('--adam-no-correction1', action='store_true')
46 | # fmt: on
47 |
48 | @property
49 | def optimizer_config(self):
50 | """
51 | Return a kwarg dictionary that will be used to override optimizer
52 | args stored in checkpoints. This allows us to load a checkpoint and
53 | resume training using a different set of optimizer args, e.g., with a
54 | different learning rate.
55 | """
56 | return {
57 | 'lr': self.args.lr[0],
58 | 'betas': eval(self.args.adam_betas),
59 | 'eps': self.args.adam_eps,
60 | 'weight_decay': self.args.weight_decay,
61 | 'amsgrad': self.args.amsgrad,
62 | 'adam_freeze': self.args.adam_freeze,
63 | 'adam_no_correction1': self.args.adam_no_correction1,
64 | }
65 |
66 |
67 | class Adam2(torch.optim.Optimizer):
68 |
69 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
70 | weight_decay=0, amsgrad=False, adam_freeze=5000, adam_no_correction1=False):
71 | defaults = dict(lr=lr, betas=betas, eps=eps,
72 | weight_decay=weight_decay, amsgrad=amsgrad, adam_freeze=adam_freeze, adam_no_correction1=adam_no_correction1)
73 | self.name = '{}_{}_{}'.format(lr, betas[0], betas[1])
74 | super(Adam2, self).__init__(params, defaults)
75 |
76 | @property
77 | def supports_memory_efficient_fp16(self):
78 | return True
79 |
80 | def step(self, closure=None):
81 | """Performs a single optimization step.
82 |
83 | Arguments:
84 | closure (callable, optional): A closure that reevaluates the model
85 | and returns the loss.
86 | """
87 | global iter_idx
88 | iter_idx += 1
89 | grad_list = list()
90 | mom_list = list()
91 | mom_2rd_list = list()
92 |
93 | loss = None
94 | if closure is not None:
95 | loss = closure()
96 |
97 | for group in self.param_groups:
98 |
99 | # if 'adam_1k' in self.name:
100 | # writer_iter = iter_idx - group['adam_freeze']
101 | # else:
102 | # writer_iter = iter_idx
103 |
104 | for p in group['params']:
105 | if p.grad is None:
106 | continue
107 | grad = p.grad.data.float()
108 | if grad.is_sparse:
109 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
110 | amsgrad = group['amsgrad']
111 |
112 | p_data_fp32 = p.data.float()
113 |
114 | state = self.state[p]
115 |
116 | if len(state) == 0:
117 | state['step'] = 0
118 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
119 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
120 | if amsgrad:
121 | state['max_exp_avg_sq'] = torch.zeros_like(p_data_fp32)
122 | else:
123 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
124 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
125 | if amsgrad:
126 | state['max_exp_avg_sq'] = state['max_exp_avg_sq'].type_as(p_data_fp32)
127 |
128 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
129 | if amsgrad:
130 | max_exp_avg_sq = state['max_exp_avg_sq']
131 | beta1, beta2 = group['betas']
132 |
133 | state['step'] += 1
134 | exp_avg_sq.mul_(beta2).addcmul_(1-beta2, grad, grad)
135 |
136 | denom = exp_avg_sq.sqrt().add_(group['eps'])
137 |
138 | if group['adam_no_correction1']:
139 | bias_correction1 = 1
140 | else:
141 | bias_correction1 = (1 - beta1 ** state['step'])
142 |
143 | bias_correction2 = (1 - beta2 ** state['step'])**0.5
144 | step_size = group['lr'] * bias_correction2 / bias_correction1
145 |
146 |
147 | if 'adam_1k' not in self.name or state['step'] > group['adam_freeze']:
148 | if group['weight_decay'] != 0:
149 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
150 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
151 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
152 | p.data.copy_(p_data_fp32)
153 |
154 | # if writer_iter > 0 and writer_iter % 300 == 0 or writer_iter in [1, 5, 10, 25, 50, 75, 100, 150, 200]:
155 | # grad_list.extend( grad.abs().add_(1e-9).log().view(-1).tolist() )
156 | # mom_list.extend( exp_avg.abs().add_(1e-9).log().view(-1).tolist() )
157 | # mom_2rd_list.extend( exp_avg_sq.abs().add_(1e-9).log().view(-1).tolist() )
158 |
159 | # if writer_iter > 0 and writer_iter % 300 == 0 or writer_iter in [1, 5, 10, 25, 50, 75, 100, 150, 200]:
160 | # writer.add_histogram('grad/{}'.format(self.name), grad_list, writer_iter)
161 | # writer.add_histogram('mom/{}'.format(self.name), mom_list, writer_iter)
162 | # writer.add_histogram('mom_sq/{}'.format(self.name), mom_2rd_list, writer_iter)
163 |
164 | return loss
165 |
--------------------------------------------------------------------------------
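
A standalone sketch of the `adam_1k` behaviour wired into `Adam2.step()` above (hypothetical toy parameter; it assumes fairseq and tensorboardX are installed so the module imports, and a PyTorch version that still accepts the pre-1.5 `addcmul_`/`add_` overloads): when the optimizer name contains `adam_1k`, the second moment keeps accumulating but the momentum and parameter updates are skipped until `adam_freeze` steps have passed.

```
import torch
from my_module.adam2 import Adam2

w = torch.nn.Parameter(torch.ones(3))
opt = Adam2([w], lr=0.1, adam_freeze=2)
opt.name = 'adam_1k_' + opt.name        # opt into the frozen warmup branch

for step in range(4):
    opt.zero_grad()
    (w ** 2).sum().backward()
    opt.step()
    print(step, w.data)                 # unchanged for the first two steps
```
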
/nmt/my_module/linear_schedule.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
9 |
10 |
11 | @register_lr_scheduler('linear')
12 | class LinearSchedule(FairseqLRScheduler):
13 | """Decay the LR based on the inverse square root of the update number.
14 |
15 | We also support a warmup phase where we linearly increase the learning rate
16 | from some initial learning rate (``--warmup-init-lr``) until the configured
17 | learning rate (``--lr``). Thereafter we decay proportional to the number of
18 | updates, with a decay factor set to align with the configured learning rate.
19 |
20 | During warmup::
21 |
22 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
23 | lr = lrs[update_num]
24 |
25 | After warmup::
26 |
27 | decay_factor = args.lr * sqrt(args.warmup_updates)
28 | lr = decay_factor / sqrt(update_num)
29 | """
30 |
31 | def __init__(self, args, optimizer):
32 | super().__init__(args, optimizer)
33 | if len(args.lr) > 1:
34 | raise ValueError(
35 |                 'Cannot use a fixed learning rate schedule with the linear scheduler.'
36 |                 ' Consider --lr-scheduler=fixed instead.'
37 | )
38 | warmup_end_lr = args.lr[0]
39 | if args.warmup_init_lr < 0:
40 | args.warmup_init_lr = warmup_end_lr
41 |
42 | # linearly warmup for the first args.warmup_updates
43 | self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
44 |
45 |         # then, decay linearly from the peak LR towards min_lr over the
46 |         # remaining (max_update - warmup_updates) updates
47 | self.min_lr = args.min_lr
48 | self.warmup_end_lr = warmup_end_lr - self.min_lr
49 |
50 | # initial learning rate
51 | self.lr = args.warmup_init_lr
52 | self.optimizer.set_lr(self.lr)
53 |
54 | self.max_update = args.max_update - args.warmup_updates
55 |
56 | @staticmethod
57 | def add_args(parser):
58 | """Add arguments to the parser for this LR scheduler."""
59 | # fmt: off
60 | parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N',
61 | help='warmup the learning rate linearly for the first N updates')
62 | parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
63 | help='initial learning rate during warmup phase; default is args.lr')
64 | # fmt: on
65 |
66 | def step(self, epoch, val_loss=None):
67 | """Update the learning rate at the end of the given epoch."""
68 | super().step(epoch, val_loss)
69 | # we don't change the learning rate at epoch boundaries
70 | return self.optimizer.get_lr()
71 |
72 | def step_update(self, num_updates):
73 | """Update the learning rate after each update."""
74 | if num_updates < self.args.warmup_updates:
75 | self.lr = self.args.warmup_init_lr + num_updates*self.lr_step
76 | else:
77 | self.lr = self.warmup_end_lr * (1 - (num_updates - self.args.warmup_updates) / self.max_update) + self.min_lr
78 | self.optimizer.set_lr(self.lr)
79 | return self.lr
80 |
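For reference, the schedule above can be reproduced standalone. The sketch below (plain Python, no fairseq; default values borrowed from the recipes in nmt/recipes.md) mirrors `step_update`:

```
def linear_lr(update, lr=3e-4, warmup_init_lr=1e-8, warmup_updates=4000,
              max_update=70000, min_lr=1e-9):
    """Linear warmup to lr, then linear decay to min_lr at max_update."""
    if update < warmup_updates:
        return warmup_init_lr + update * (lr - warmup_init_lr) / warmup_updates
    progress = (update - warmup_updates) / (max_update - warmup_updates)
    return (lr - min_lr) * (1 - progress) + min_lr

# linear_lr(0) -> 1e-8, linear_lr(4000) -> 3e-4, linear_lr(70000) -> 1e-9
```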
--------------------------------------------------------------------------------
/nmt/my_module/novograd.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | import math
9 | import types
10 |
11 | import torch
12 | import torch.optim
13 |
14 | from fairseq.optim import FairseqOptimizer, register_optimizer
15 |
16 |
17 | iter_idx = 0
18 |
19 | @register_optimizer('novograd')
20 | class FairseqNovograd(FairseqOptimizer):
21 |
22 | def __init__(self, args, params):
23 | super().__init__(args, params)
24 |
25 | self._optimizer = Novograd(params, **self.optimizer_config)
26 | self._optimizer.name = args.tb_tag + '_' + self._optimizer.name
27 |
28 | @staticmethod
29 | def add_args(parser):
30 | """Add optimizer-specific arguments to the parser."""
31 | # fmt: off
32 | parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B',
33 | help='betas for Adam optimizer')
34 | parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
35 | help='epsilon for Adam optimizer')
36 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
37 | help='weight decay')
38 | parser.add_argument('--tb-tag', default="", type=str,
39 | help='tb tag')
40 | parser.add_argument('--amsgrad', action='store_true')
41 | parser.add_argument('--adam-freeze', default=5000, type=float)
42 | parser.add_argument('--adam-no-correction1', action='store_true')
43 | # fmt: on
44 |
45 | @property
46 | def optimizer_config(self):
47 | """
48 | Return a kwarg dictionary that will be used to override optimizer
49 | args stored in checkpoints. This allows us to load a checkpoint and
50 | resume training using a different set of optimizer args, e.g., with a
51 | different learning rate.
52 | """
53 | return {
54 | 'lr': self.args.lr[0],
55 | 'betas': eval(self.args.adam_betas),
56 | 'eps': self.args.adam_eps,
57 | 'weight_decay': self.args.weight_decay,
58 | 'amsgrad': self.args.amsgrad,
59 | 'adam_freeze': self.args.adam_freeze,
60 | 'adam_no_correction1': self.args.adam_no_correction1,
61 | }
62 |
63 |
64 | class Novograd(torch.optim.Optimizer):
65 |
66 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
67 | weight_decay=0, amsgrad=False, adam_freeze=5000, adam_no_correction1=False):
68 | defaults = dict(lr=lr, betas=betas, eps=eps,
69 | weight_decay=weight_decay, amsgrad=amsgrad, adam_freeze=adam_freeze, adam_no_correction1=adam_no_correction1)
70 | self.name = '{}_{}_{}'.format(lr, betas[0], betas[1])
71 | super(Novograd, self).__init__(params, defaults)
72 |
73 | @property
74 | def supports_memory_efficient_fp16(self):
75 | return True
76 |
77 | def step(self, closure=None):
78 | """Performs a single optimization step.
79 |
80 | Arguments:
81 | closure (callable, optional): A closure that reevaluates the model
82 | and returns the loss.
83 | """
84 | global iter_idx
85 | iter_idx += 1
86 | grad_list = list()
87 | mom_list = list()
88 | mom_2rd_list = list()
89 |
90 | loss = None
91 | if closure is not None:
92 | loss = closure()
93 |
94 | for group in self.param_groups:
95 |
96 |
97 | for p in group['params']:
98 | if p.grad is None:
99 | continue
100 | grad = p.grad.data.float()
101 | if grad.is_sparse:
102 |                     raise RuntimeError('Novograd does not support sparse gradients, please consider SparseAdam instead')
103 | amsgrad = group['amsgrad']
104 |
105 | p_data_fp32 = p.data.float()
106 |
107 | state = self.state[p]
108 |
109 | if len(state) == 0:
110 | state['step'] = 0
111 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
112 |                     state['exp_avg_sq'] = 0  # Novograd keeps a scalar (per-tensor) second moment
113 | if amsgrad:
114 | state['max_exp_avg_sq'] = torch.zeros_like(p_data_fp32)
115 | else:
116 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
117 | if amsgrad:
118 | state['max_exp_avg_sq'] = state['max_exp_avg_sq'].type_as(p_data_fp32)
119 |
120 | exp_avg = state['exp_avg']
121 | if amsgrad:
122 | max_exp_avg_sq = state['max_exp_avg_sq']
123 | beta1, beta2 = group['betas']
124 |
125 | state['step'] += 1
126 | state['exp_avg_sq'] = state['exp_avg_sq']*beta2 + (1-beta2)*grad.norm().item()**2
127 |
128 | denom = state['exp_avg_sq']**0.5 + group['eps']
129 |
130 | step_size = group['lr']
131 |
132 |
133 | exp_avg.mul_(beta1).add_( grad/denom )
134 | if group['weight_decay'] != 0:
135 | exp_avg.add_( group['weight_decay'], p_data_fp32)
136 | p_data_fp32.add_(-step_size, exp_avg)
137 | p.data.copy_(p_data_fp32)
138 |
139 | return loss
140 |
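A minimal usage sketch for the `Novograd` class above, outside of fairseq (the toy model and data are made up for illustration; note that the in-place op calls in this file use the pre-1.5 PyTorch signatures, so an older PyTorch is assumed):

```
import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)
optimizer = Novograd(model.parameters(), lr=1e-3, weight_decay=1e-4)

x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(5):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
```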
--------------------------------------------------------------------------------
/nmt/my_module/poly_schedule.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler
9 |
10 |
11 | @register_lr_scheduler('poly')
12 | class PolySchedule(FairseqLRScheduler):
13 | """Decay the LR based on the inverse square root of the update number.
14 |
15 | We also support a warmup phase where we linearly increase the learning rate
16 | from some initial learning rate (``--warmup-init-lr``) until the configured
17 | learning rate (``--lr``). Thereafter we decay proportional to the number of
18 | updates, with a decay factor set to align with the configured learning rate.
19 |
20 | During warmup::
21 |
22 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
23 | lr = lrs[update_num]
24 |
25 | After warmup::
26 |
27 | decay_factor = args.lr * sqrt(args.warmup_updates)
28 | lr = decay_factor / sqrt(update_num)
29 | """
30 |
31 | def __init__(self, args, optimizer):
32 | super().__init__(args, optimizer)
33 | if len(args.lr) > 1:
34 | raise ValueError(
35 |                 'Cannot use a fixed learning rate schedule with the poly scheduler.'
36 |                 ' Consider --lr-scheduler=fixed instead.'
37 | )
38 |
39 |         # decay polynomially towards min_lr; unlike the linear scheduler,
40 |         # there is no warmup phase here
41 | self.min_lr = args.min_lr
42 |
43 | # initial learning rate
44 | self.lr = args.lr[0]
45 | self.optimizer.set_lr(self.lr)
46 |
47 | self.max_update = args.max_update
48 |
49 | @staticmethod
50 | def add_args(parser):
51 | """Add arguments to the parser for this LR scheduler."""
52 | # fmt: off
53 | parser.add_argument('--poly-pow', default=2, type=float, metavar='N',
54 |                             help='poly power')
55 |         # fmt: on
56 | def step(self, epoch, val_loss=None):
57 | """Update the learning rate at the end of the given epoch."""
58 | super().step(epoch, val_loss)
59 | # we don't change the learning rate at epoch boundaries
60 | return self.optimizer.get_lr()
61 |
62 | def step_update(self, num_updates):
63 | """Update the learning rate after each update."""
64 | self.lr = (self.args.lr[0] - self.min_lr)* (1 - num_updates / self.max_update)**self.args.poly_pow + self.min_lr
65 | self.optimizer.set_lr(self.lr)
66 | return self.lr
67 |
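The polynomial decay computed in `step_update` above, as a standalone sketch (illustrative only; default values borrowed from the recipes, with `poly_pow=2` as in `add_args`):

```
def poly_lr(update, lr=3e-4, min_lr=1e-9, max_update=70000, poly_pow=2):
    """Polynomial decay from lr to min_lr at max_update; no warmup phase."""
    progress = update / max_update
    return (lr - min_lr) * (1 - progress) ** poly_pow + min_lr

# poly_lr(0) -> 3e-4, poly_lr(35000) -> ~7.5e-5, poly_lr(70000) -> 1e-9
```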
--------------------------------------------------------------------------------
/nmt/my_module/radam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the LICENSE file in
5 | # the root directory of this source tree. An additional grant of patent rights
6 | # can be found in the PATENTS file in the same directory.
7 |
8 | import math
9 | import types
10 |
11 | import torch
12 | import torch.optim
13 | # from ipdb import set_trace
14 | from fairseq.optim import FairseqOptimizer, register_optimizer
15 |
16 | # from tensorboardX import SummaryWriter
17 | # # writer = SummaryWriter(logdir='./log/wmt/')
18 | # writer = SummaryWriter(logdir='./log/ada/')
19 | iter_idx = 0  # module-level step counter; required by `global iter_idx` in step()
20 |
21 | @register_optimizer('radam')
22 | class FairseqRAdam(FairseqOptimizer):
23 |
24 | def __init__(self, args, params):
25 | super().__init__(args, params)
26 |
27 | self._optimizer = RAdam(params, **self.optimizer_config)
28 | self._optimizer.name = args.tb_tag + '_' + self._optimizer.name
29 |
30 | @staticmethod
31 | def add_args(parser):
32 | """Add optimizer-specific arguments to the parser."""
33 | # fmt: off
34 | parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B',
35 | help='betas for Adam optimizer')
36 | parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
37 | help='epsilon for Adam optimizer')
38 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
39 | help='weight decay')
40 | parser.add_argument('--tb-tag', default="", type=str,
41 | help='tb tag')
42 | # fmt: on
43 |
44 | @property
45 | def optimizer_config(self):
46 | """
47 | Return a kwarg dictionary that will be used to override optimizer
48 | args stored in checkpoints. This allows us to load a checkpoint and
49 | resume training using a different set of optimizer args, e.g., with a
50 | different learning rate.
51 | """
52 | return {
53 | 'lr': self.args.lr[0],
54 | 'betas': eval(self.args.adam_betas),
55 | 'eps': self.args.adam_eps,
56 | 'weight_decay': self.args.weight_decay,
57 | }
58 |
59 | class RAdam(torch.optim.Optimizer):
60 |
61 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
62 | weight_decay=0, amsgrad=False):
63 | defaults = dict(lr=lr, betas=betas, eps=eps,
64 | weight_decay=weight_decay, amsgrad=amsgrad)
65 |
66 | self.name = '{}_{}_{}'.format(lr, betas[0], betas[1])
67 | super(RAdam, self).__init__(params, defaults)
68 |
69 | @property
70 | def supports_memory_efficient_fp16(self):
71 | return True
72 |
73 | def step(self, closure=None):
74 | """Performs a single optimization step.
75 |
76 | Arguments:
77 | closure (callable, optional): A closure that reevaluates the model
78 | and returns the loss.
79 | """
80 | global iter_idx
81 | iter_idx += 1
82 | grad_list = list()
83 | mom_list = list()
84 | mom_2rd_list = list()
85 | assert 'adam_1k' not in self.name
86 | writer_iter = iter_idx
87 |
88 | loss = None
89 | if closure is not None:
90 | loss = closure()
91 |
92 | for group in self.param_groups:
93 |
94 | for p in group['params']:
95 | if p.grad is None:
96 | continue
97 | grad = p.grad.data.float()
98 | if grad.is_sparse:
99 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
100 | amsgrad = group['amsgrad']
101 |
102 | p_data_fp32 = p.data.float()
103 |
104 | state = self.state[p]
105 |
106 | if len(state) == 0:
107 | state['step'] = 0
108 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
109 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
110 | else:
111 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
112 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
113 |
114 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
115 | beta1, beta2 = group['betas']
116 |
117 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
118 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
119 |
120 | state['step'] += 1
121 |
122 | beta2_t = beta2 ** state['step']
123 | N_sma_max = 2 / (1 - beta2) - 1
124 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
125 |
126 | if group['weight_decay'] != 0:
127 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
128 |
129 | # more conservative since it's an approximated value
130 | if N_sma >= 5:
131 |                     step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
132 | denom = exp_avg_sq.sqrt().add_(group['eps'])
133 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
134 | else:
135 | step_size = group['lr'] / (1 - beta1 ** state['step'])
136 | p_data_fp32.add_(-step_size, exp_avg)
137 |
138 | p.data.copy_(p_data_fp32)
139 |
140 | # if writer_iter > 0 and writer_iter % 300 == 0 or writer_iter in [1, 5, 10, 25, 50, 75, 100, 150, 200]:
141 | # grad_list.extend( grad.abs().add_(1e-9).log().view(-1).tolist() )
142 | # mom_list.extend( exp_avg.abs().add_(1e-9).log().view(-1).tolist() )
143 | # mom_2rd_list.extend( exp_avg_sq.abs().add_(1e-9).log().view(-1).tolist() )
144 |
145 | # if writer_iter > 0 and writer_iter % 300 == 0 or writer_iter in [1, 5, 10, 25, 50, 75, 100, 150, 200]:
146 | # writer.add_histogram('grad/{}'.format(self.name), grad_list, writer_iter)
147 | # writer.add_histogram('mom/{}'.format(self.name), mom_list, writer_iter)
148 | # writer.add_histogram('mom_sq/{}'.format(self.name), mom_2rd_list, writer_iter)
149 |
150 | return loss
151 |
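For intuition on the `N_sma >= 5` check above: the rectified adaptive step is only taken once the approximated SMA length clears the threshold; before that, the un-rectified momentum branch is used. The small illustrative snippet below just evaluates the formula from the code for beta2 = 0.999:

```
beta2 = 0.999
N_sma_max = 2 / (1 - beta2) - 1              # 1999 for beta2 = 0.999
for step in range(1, 9):
    beta2_t = beta2 ** step
    N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)
    branch = 'rectified adaptive step' if N_sma >= 5 else 'un-rectified momentum step'
    print(step, round(N_sma, 3), branch)
```

With these settings the rectified branch kicks in at roughly the sixth update; the first few updates fall back to the momentum-only step.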
--------------------------------------------------------------------------------
/nmt/recipes.md:
--------------------------------------------------------------------------------
1 |
2 | # Adam with warmup
3 |
4 | ```
5 | CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en -a transformer_iwslt_de_en --optimizer adam2 --lr 0.0003 -s de -t en --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 --warmup-init-lr 1e-8 --min-lr '1e-09' --lr-scheduler linear --weight-decay 0.0001 --criterion label_smoothed_cross_entropy --max-update 70000 --warmup-updates 4000 --adam-betas '(0.9, 0.999)' --save-dir /cps/gadam/nmt/adam_warmup_f_0 --tb-tag adam_warmup_f_0 --user-dir ./my_module --restore-file x.pt
6 |
7 | bash eval.sh /cps/gadam/nmt/adam_warmup_f_0 0 >> results_f_5.txt
8 |
9 | for SEED in 1111 2222 3333 4444
10 | do
11 |
12 | CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en -a transformer_iwslt_de_en --optimizer adam2 --lr 0.0003 -s de -t en --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 --warmup-init-lr 1e-8 --min-lr '1e-09' --lr-scheduler linear --weight-decay 0.0001 --criterion label_smoothed_cross_entropy --max-update 70000 --warmup-updates 4000 --adam-betas '(0.9, 0.999)' --save-dir /cps/gadam/nmt/adam_warmup_f_$SEED --tb-tag adam_warmup_f_$SEED --user-dir ./my_module --restore-file x.pt --seed $SEED
13 |
14 | bash eval.sh /cps/gadam/nmt/adam_warmup_f_$SEED 0 >> results_f_5.txt
15 | done
16 | ```
17 |
18 | # Adam-2k
19 |
20 | ```
21 | CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en -a transformer_iwslt_de_en --optimizer adam2 --lr 0.0003086 -s de -t en --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 --min-lr '1e-09' --lr-scheduler linear --weight-decay 0.0001 --criterion label_smoothed_cross_entropy --max-update 72000 --warmup-updates 1 --adam-betas '(0.9, 0.999)' --save-dir /cps/gadam/nmt/adam_1k --tb-tag adam_1k --user-dir ./my_module --fp16 --restore-file x.pt --adam-freeze 2000
22 | ```
23 |
24 | # Adam-eps
25 |
26 | ```
27 | CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en -a transformer_iwslt_de_en --optimizer adam2 --lr 0.0003 -s de -t en --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 --min-lr '1e-09' --lr-scheduler linear --weight-decay 0.0001 --criterion label_smoothed_cross_entropy --max-update 70000 --warmup-updates 1 --adam-betas '(0.9, 0.999)' --save-dir /cps/gadam/nmt/adam_eps --tb-tag adam_eps --user-dir ./my_module --fp16 --adam-eps 1e-4 --restore-file x.pt
28 |
29 | ```
30 |
31 | # RAdam
32 |
33 | ```
34 | CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en -a transformer_iwslt_de_en --optimizer radam --lr 0.0003 -s de -t en --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 --min-lr '1e-09' --lr-scheduler linear --weight-decay 0.0001 --criterion label_smoothed_cross_entropy --max-update 70000 --warmup-updates 1 --adam-betas '(0.9, 0.999)' --save-dir /cps/gadam/nmt/radam_0 --tb-tag radam_0 --user-dir ./my_module --fp16 --restore-file x.pt
35 |
36 | bash eval.sh /cps/gadam/nmt/radam_0 0 >> results_f_5.txt
37 |
38 | for SEED in 1111 2222 3333 4444
39 | do
40 | CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en -a transformer_iwslt_de_en --optimizer radam --lr 0.0003 -s de -t en --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 --min-lr '1e-09' --lr-scheduler linear --weight-decay 0.0001 --criterion label_smoothed_cross_entropy --max-update 70000 --warmup-updates 1 --adam-betas '(0.9, 0.9995)' --save-dir /cps/gadam/nmt/radam_$SEED --tb-tag radam_$SEED --user-dir ./my_module --fp16 --restore-file x.pt --seed $SEED
41 |
42 | bash eval.sh /cps/gadam/nmt/radam_$SEED 0 >> results_f_5.txt
43 | done
44 | ```
45 |
46 | # Novograd
47 | We also implemented [novograd](https://arxiv.org/pdf/1905.11286.pdf), which claims that no warmup is required.
48 | We tried the following command with lr = 0.03, 0.0003, 0.00003, and 0.00001; none of these worked without warmup.
49 |
50 | ```
51 | CUDA_VISIBLE_DEVICES=0 fairseq-train ./data-bin/iwslt14.tokenized.de-en -a transformer_iwslt_de_en --optimizer novograd --lr 0.0003 -s en -t de --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 --min-lr '1e-09' --lr-scheduler poly --weight-decay 5e-5 --criterion label_smoothed_cross_entropy --max-update 70000 --adam-betas '(0.9, 0.999)' --save-dir /ckp/nmt/novograd --tb-tag novograd --user-dir ./my_module --fp16 --restore-file x.pt
52 | ```
53 |
--------------------------------------------------------------------------------
/radam/__init__.py:
--------------------------------------------------------------------------------
1 | from .radam import RAdam, PlainRAdam, AdamW
2 |
--------------------------------------------------------------------------------
/radam/radam.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer, required
4 |
5 | class RAdam(Optimizer):
6 |
7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False):
8 | if not 0.0 <= lr:
9 | raise ValueError("Invalid learning rate: {}".format(lr))
10 | if not 0.0 <= eps:
11 | raise ValueError("Invalid epsilon value: {}".format(eps))
12 | if not 0.0 <= betas[0] < 1.0:
13 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
14 | if not 0.0 <= betas[1] < 1.0:
15 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
16 |
17 | self.degenerated_to_sgd = degenerated_to_sgd
18 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
19 | for param in params:
20 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
21 | param['buffer'] = [[None, None, None] for _ in range(10)]
22 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
23 | super(RAdam, self).__init__(params, defaults)
24 |
25 | def __setstate__(self, state):
26 | super(RAdam, self).__setstate__(state)
27 |
28 | def step(self, closure=None):
29 |
30 | loss = None
31 | if closure is not None:
32 | loss = closure()
33 |
34 | for group in self.param_groups:
35 |
36 | for p in group['params']:
37 | if p.grad is None:
38 | continue
39 | grad = p.grad.data.float()
40 | if grad.is_sparse:
41 | raise RuntimeError('RAdam does not support sparse gradients')
42 |
43 | p_data_fp32 = p.data.float()
44 |
45 | state = self.state[p]
46 |
47 | if len(state) == 0:
48 | state['step'] = 0
49 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
50 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
51 | else:
52 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
53 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
54 |
55 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
56 | beta1, beta2 = group['betas']
57 |
58 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
59 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
60 |
61 | state['step'] += 1
62 |                 buffered = group['buffer'][int(state['step'] % 10)]  # reuse N_sma / step_size computed earlier in this step
63 | if state['step'] == buffered[0]:
64 | N_sma, step_size = buffered[1], buffered[2]
65 | else:
66 | buffered[0] = state['step']
67 | beta2_t = beta2 ** state['step']
68 | N_sma_max = 2 / (1 - beta2) - 1
69 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
70 | buffered[1] = N_sma
71 |
72 | # more conservative since it's an approximated value
73 | if N_sma >= 5:
74 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
75 | elif self.degenerated_to_sgd:
76 | step_size = 1.0 / (1 - beta1 ** state['step'])
77 | else:
78 | step_size = -1
79 | buffered[2] = step_size
80 |
81 | # more conservative since it's an approximated value
82 | if N_sma >= 5:
83 | if group['weight_decay'] != 0:
84 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
85 | denom = exp_avg_sq.sqrt().add_(group['eps'])
86 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
87 | p.data.copy_(p_data_fp32)
88 | elif step_size > 0:
89 | if group['weight_decay'] != 0:
90 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
91 | p_data_fp32.add_(-step_size * group['lr'], exp_avg)
92 | p.data.copy_(p_data_fp32)
93 |
94 | return loss
95 |
96 | class PlainRAdam(Optimizer):
97 |
98 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False):
99 | if not 0.0 <= lr:
100 | raise ValueError("Invalid learning rate: {}".format(lr))
101 | if not 0.0 <= eps:
102 | raise ValueError("Invalid epsilon value: {}".format(eps))
103 | if not 0.0 <= betas[0] < 1.0:
104 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
105 | if not 0.0 <= betas[1] < 1.0:
106 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
107 |
108 | self.degenerated_to_sgd = degenerated_to_sgd
109 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
110 |
111 | super(PlainRAdam, self).__init__(params, defaults)
112 |
113 | def __setstate__(self, state):
114 | super(PlainRAdam, self).__setstate__(state)
115 |
116 | def step(self, closure=None):
117 |
118 | loss = None
119 | if closure is not None:
120 | loss = closure()
121 |
122 | for group in self.param_groups:
123 |
124 | for p in group['params']:
125 | if p.grad is None:
126 | continue
127 | grad = p.grad.data.float()
128 | if grad.is_sparse:
129 | raise RuntimeError('RAdam does not support sparse gradients')
130 |
131 | p_data_fp32 = p.data.float()
132 |
133 | state = self.state[p]
134 |
135 | if len(state) == 0:
136 | state['step'] = 0
137 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
138 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
139 | else:
140 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
141 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
142 |
143 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
144 | beta1, beta2 = group['betas']
145 |
146 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
147 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
148 |
149 | state['step'] += 1
150 | beta2_t = beta2 ** state['step']
151 | N_sma_max = 2 / (1 - beta2) - 1
152 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
153 |
154 |
155 | # more conservative since it's an approximated value
156 | if N_sma >= 5:
157 | if group['weight_decay'] != 0:
158 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
159 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
160 | denom = exp_avg_sq.sqrt().add_(group['eps'])
161 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
162 | p.data.copy_(p_data_fp32)
163 | elif self.degenerated_to_sgd:
164 | if group['weight_decay'] != 0:
165 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
166 | step_size = group['lr'] / (1 - beta1 ** state['step'])
167 | p_data_fp32.add_(-step_size, exp_avg)
168 | p.data.copy_(p_data_fp32)
169 |
170 | return loss
171 |
172 |
173 | class AdamW(Optimizer):
174 |
175 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0):
176 | if not 0.0 <= lr:
177 | raise ValueError("Invalid learning rate: {}".format(lr))
178 | if not 0.0 <= eps:
179 | raise ValueError("Invalid epsilon value: {}".format(eps))
180 | if not 0.0 <= betas[0] < 1.0:
181 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
182 | if not 0.0 <= betas[1] < 1.0:
183 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
184 |
185 | defaults = dict(lr=lr, betas=betas, eps=eps,
186 | weight_decay=weight_decay, warmup = warmup)
187 | super(AdamW, self).__init__(params, defaults)
188 |
189 | def __setstate__(self, state):
190 | super(AdamW, self).__setstate__(state)
191 |
192 | def step(self, closure=None):
193 | loss = None
194 | if closure is not None:
195 | loss = closure()
196 |
197 | for group in self.param_groups:
198 |
199 | for p in group['params']:
200 | if p.grad is None:
201 | continue
202 | grad = p.grad.data.float()
203 | if grad.is_sparse:
204 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
205 |
206 | p_data_fp32 = p.data.float()
207 |
208 | state = self.state[p]
209 |
210 | if len(state) == 0:
211 | state['step'] = 0
212 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
213 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
214 | else:
215 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
216 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
217 |
218 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
219 | beta1, beta2 = group['betas']
220 |
221 | state['step'] += 1
222 |
223 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
224 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
225 |
226 | denom = exp_avg_sq.sqrt().add_(group['eps'])
227 | bias_correction1 = 1 - beta1 ** state['step']
228 | bias_correction2 = 1 - beta2 ** state['step']
229 |
230 | if group['warmup'] > state['step']:
231 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
232 | else:
233 | scheduled_lr = group['lr']
234 |
235 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
236 |
237 | if group['weight_decay'] != 0:
238 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32)
239 |
240 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
241 |
242 | p.data.copy_(p_data_fp32)
243 |
244 | return loss
245 |
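A minimal drop-in usage sketch for the packaged `RAdam` above (the toy model and data are made up; depending on the installed PyTorch version, the pre-1.5 in-place op signatures used in this file may emit deprecation warnings):

```
import torch
import torch.nn.functional as F
from radam import RAdam  # as exported by radam/__init__.py

model = torch.nn.Linear(10, 2)
optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0)

x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))
for _ in range(3):
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
```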
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | import re
3 |
4 | try:
5 | import torch
6 | has_dev_pytorch = "dev" in torch.__version__
7 | except ImportError:
8 | has_dev_pytorch = False
9 |
10 | # Base requirements
11 | install_requires = [
12 | "torch",
13 | ]
14 |
15 | if has_dev_pytorch: # Remove the PyTorch requirement
16 | install_requires = [
17 | install_require for install_require in install_requires
18 | if "torch" != re.split(r"(=|<|>)", install_require)[0]
19 | ]
20 |
21 | setup(
22 | name='RAdam',
23 | version='0.0.1',
24 | url='https://github.com/LiyuanLucasLiu/RAdam.git',
25 | author='Liyuan Liu',
26 | author_email='llychinalz@gmail.com',
27 | description='Implementation of the RAdam optimization algorithm described in On the Variance of the Adaptive Learning Rate and Beyond (https://arxiv.org/abs/1908.03265)',
28 | packages=find_packages(),
29 | install_requires=install_requires,
30 | )
31 |
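With this setup.py, installation and import presumably look as follows (a sketch; the project README is the authoritative reference):

```
# pip install .    (from the repository root)
# or: pip install git+https://github.com/LiyuanLucasLiu/RAdam.git
from radam import RAdam, PlainRAdam, AdamW
```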
--------------------------------------------------------------------------------