├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── cifar_imagenet ├── .gitignore ├── LICENSE ├── README.md ├── cifar.py ├── fourstep.sh ├── imagenet.py ├── models │ ├── __init__.py │ ├── cifar │ │ ├── __init__.py │ │ ├── alexnet.py │ │ ├── densenet.py │ │ ├── preresnet.py │ │ ├── resnet.py │ │ ├── resnext.py │ │ ├── vgg.py │ │ └── wrn.py │ └── imagenet │ │ ├── __init__.py │ │ └── resnext.py ├── recipes.md └── utils │ ├── __init__.py │ ├── eval.py │ ├── images │ ├── cifar.png │ └── imagenet.png │ ├── logger.py │ ├── misc.py │ ├── radam.py │ └── visualize.py ├── img └── variance.png ├── language-model ├── README.md ├── eval_1bw.py ├── model_word_ada │ ├── LM.py │ ├── adaptive.py │ ├── basic.py │ ├── bnlstm.py │ ├── dataset.py │ ├── ddnet.py │ ├── densenet.py │ ├── ldnet.py │ ├── radam.py │ ├── resnet.py │ └── utils.py ├── pre_word_ada │ ├── encode_data2folder.py │ └── gene_map.py ├── recipes.md └── train_1bw.py ├── nmt ├── README.md ├── average_checkpoints.py ├── eval.sh ├── my_module │ ├── __init__.py │ ├── adam2.py │ ├── linear_schedule.py │ ├── novograd.py │ ├── poly_schedule.py │ └── radam.py └── recipes.md ├── radam ├── __init__.py └── radam.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.gitignore.io/api/python,pycharm,jupyternotebooks,visualstudiocode 3 | # Edit at https://www.gitignore.io/?templates=python,pycharm,jupyternotebooks,visualstudiocode 4 | 5 | ### JupyterNotebooks ### 6 | # gitignore template for Jupyter Notebooks 7 | # website: http://jupyter.org/ 8 | 9 | .ipynb_checkpoints 10 | */.ipynb_checkpoints/* 11 | 12 | # IPython 13 | profile_default/ 14 | ipython_config.py 15 | 16 | # Remove previous ipynb_checkpoints 17 | # git rm -r .ipynb_checkpoints/ 18 | 19 | ### PyCharm ### 20 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 21 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 22 | 23 | # User-specific stuff 24 | .idea/**/workspace.xml 25 | .idea/**/tasks.xml 26 | .idea/**/usage.statistics.xml 27 | .idea/**/dictionaries 28 | .idea/**/shelf 29 | 30 | # Generated files 31 | .idea/**/contentModel.xml 32 | 33 | # Sensitive or high-churn files 34 | .idea/**/dataSources/ 35 | .idea/**/dataSources.ids 36 | .idea/**/dataSources.local.xml 37 | .idea/**/sqlDataSources.xml 38 | .idea/**/dynamic.xml 39 | .idea/**/uiDesigner.xml 40 | .idea/**/dbnavigator.xml 41 | 42 | # Gradle 43 | .idea/**/gradle.xml 44 | .idea/**/libraries 45 | 46 | # Gradle and Maven with auto-import 47 | # When using Gradle or Maven with auto-import, you should exclude module files, 48 | # since they will be recreated, and may cause churn. Uncomment if using 49 | # auto-import. 
50 | # .idea/modules.xml 51 | # .idea/*.iml 52 | # .idea/modules 53 | # *.iml 54 | # *.ipr 55 | 56 | # CMake 57 | cmake-build-*/ 58 | 59 | # Mongo Explorer plugin 60 | .idea/**/mongoSettings.xml 61 | 62 | # File-based project format 63 | *.iws 64 | 65 | # IntelliJ 66 | out/ 67 | 68 | # mpeltonen/sbt-idea plugin 69 | .idea_modules/ 70 | 71 | # JIRA plugin 72 | atlassian-ide-plugin.xml 73 | 74 | # Cursive Clojure plugin 75 | .idea/replstate.xml 76 | 77 | # Crashlytics plugin (for Android Studio and IntelliJ) 78 | com_crashlytics_export_strings.xml 79 | crashlytics.properties 80 | crashlytics-build.properties 81 | fabric.properties 82 | 83 | # Editor-based Rest Client 84 | .idea/httpRequests 85 | 86 | # Android studio 3.1+ serialized cache file 87 | .idea/caches/build_file_checksums.ser 88 | 89 | ### PyCharm Patch ### 90 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 91 | 92 | # *.iml 93 | # modules.xml 94 | # .idea/misc.xml 95 | # *.ipr 96 | 97 | # Sonarlint plugin 98 | .idea/**/sonarlint/ 99 | 100 | # SonarQube Plugin 101 | .idea/**/sonarIssues.xml 102 | 103 | # Markdown Navigator plugin 104 | .idea/**/markdown-navigator.xml 105 | .idea/**/markdown-navigator/ 106 | 107 | ### Python ### 108 | # Byte-compiled / optimized / DLL files 109 | __pycache__/ 110 | *.py[cod] 111 | *$py.class 112 | 113 | # C extensions 114 | *.so 115 | 116 | # Distribution / packaging 117 | .Python 118 | build/ 119 | develop-eggs/ 120 | dist/ 121 | downloads/ 122 | eggs/ 123 | .eggs/ 124 | lib/ 125 | lib64/ 126 | parts/ 127 | sdist/ 128 | var/ 129 | wheels/ 130 | pip-wheel-metadata/ 131 | share/python-wheels/ 132 | *.egg-info/ 133 | .installed.cfg 134 | *.egg 135 | MANIFEST 136 | 137 | # PyInstaller 138 | # Usually these files are written by a python script from a template 139 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 140 | *.manifest 141 | *.spec 142 | 143 | # Installer logs 144 | pip-log.txt 145 | pip-delete-this-directory.txt 146 | 147 | # Unit test / coverage reports 148 | htmlcov/ 149 | .tox/ 150 | .nox/ 151 | .coverage 152 | .coverage.* 153 | .cache 154 | nosetests.xml 155 | coverage.xml 156 | *.cover 157 | .hypothesis/ 158 | .pytest_cache/ 159 | 160 | # Translations 161 | *.mo 162 | *.pot 163 | 164 | # Scrapy stuff: 165 | .scrapy 166 | 167 | # Sphinx documentation 168 | docs/_build/ 169 | 170 | # PyBuilder 171 | target/ 172 | 173 | # pyenv 174 | .python-version 175 | 176 | # pipenv 177 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 178 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 179 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 180 | # install all needed dependencies. 
181 | #Pipfile.lock 182 | 183 | # celery beat schedule file 184 | celerybeat-schedule 185 | 186 | # SageMath parsed files 187 | *.sage.py 188 | 189 | # Spyder project settings 190 | .spyderproject 191 | .spyproject 192 | 193 | # Rope project settings 194 | .ropeproject 195 | 196 | # Mr Developer 197 | .mr.developer.cfg 198 | .project 199 | .pydevproject 200 | 201 | # mkdocs documentation 202 | /site 203 | 204 | # mypy 205 | .mypy_cache/ 206 | .dmypy.json 207 | dmypy.json 208 | 209 | # Pyre type checker 210 | .pyre/ 211 | 212 | ### VisualStudioCode ### 213 | .vscode/* 214 | !.vscode/settings.json 215 | !.vscode/tasks.json 216 | !.vscode/launch.json 217 | !.vscode/extensions.json 218 | 219 | ### VisualStudioCode Patch ### 220 | # Ignore all local history of files 221 | .history 222 | 223 | # End of https://www.gitignore.io/api/python,pycharm,jupyternotebooks,visualstudiocode 224 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | install: pip install flake8 3 | script: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2019] [Liyuan Liu] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 2 | [![Travis-CI](https://travis-ci.org/LiyuanLucasLiu/RAdam.svg?branch=master)](https://travis-ci.org/LiyuanLucasLiu/RAdam) 3 | 4 |

# RAdam
5 | ### On the Variance of the Adaptive Learning Rate and Beyond
6 | 7 | We are in an early-release beta. Expect some adventures and rough edges. 8 | 9 | ## Table of Contents 10 | 11 | - [Introduction](#introduction) 12 | - [Motivation](#motivation) 13 | - [Questions and Discussions](#questions-and-discussions) 14 | - [Quick Start Guide](#quick-start-guide) 15 | - [Related Posts and Repos](#related-posts-and-repos) 16 | - [Citation](#citation) 17 | 18 | ## Introduction 19 |
> _If warmup is the answer, what is the question?_
20 | 21 | In certain situations, learning rate warmup (or eps tuning) is a must-have trick for stable training with Adam, but the underlying mechanism is largely unknown. In our study, we suggest that one fundamental cause is __the large variance of the adaptive learning rates__, and we provide both theoretical and empirical evidence to support this. 22 | 23 | In addition to explaining __why we should use warmup__, we also propose __RAdam__, a theoretically sound variant of Adam. 24 | 25 | ## Motivation 26 | 27 | As shown in Figure 1, we assume that gradients follow a normal distribution (mean: \mu, variance: 1). The variance of the adaptive learning rate is simulated and plotted in Figure 1 (blue curve). We observe that the adaptive learning rate has a large variance in the early stage of training. 28 | 29 |
![Figure 1 & 2: simulated variance of the adaptive learning rate and NMT convergence curves](img/variance.png)
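To see the qualitative behavior in Figure 1 for yourself, the following is a minimal Monte Carlo sketch (not the script used to produce the figure) that estimates the variance of Adam's adaptive learning rate 1/sqrt(v_hat_t), assuming i.i.d. gradients drawn from N(\mu, 1) and the standard exponential-moving-average second-moment estimate:

```python
# Rough Monte Carlo estimate of Var[1 / sqrt(v_hat_t)] under i.i.d. N(mu, 1) gradients.
# Illustration only; at very small t the true variance is divergent, so the estimate
# there is large and noisy, which is exactly the point of Figure 1.
import numpy as np

def adaptive_lr_variance(mu=0.0, beta2=0.999, steps=100, trials=5000, seed=0):
    rng = np.random.default_rng(seed)
    grads = rng.normal(mu, 1.0, size=(trials, steps))           # g_1, ..., g_T per trial
    v = np.zeros(trials)
    variances = []
    for t in range(1, steps + 1):
        v = beta2 * v + (1.0 - beta2) * grads[:, t - 1] ** 2    # second-moment EMA
        v_hat = v / (1.0 - beta2 ** t)                          # bias correction
        variances.append(np.var(1.0 / np.sqrt(v_hat)))          # variance of the adaptive lr
    return variances

if __name__ == "__main__":
    var = adaptive_lr_variance(mu=0.0)
    print("t=1:", var[0], " t=100:", var[-1])                   # shrinks as t grows
```

Increasing `mu` (e.g., to 10) makes the early-stage variance much smaller, which matches the discussion of non-zero-mean gradients below.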
30 | 31 | When using a Transformer for NMT, a warmup stage is usually required to avoid convergence problems (e.g., Adam-vanilla converges around 500 PPL in Figure 2, while Adam-warmup successfully converges under 10 PPL). 32 | In further explorations, we notice that, if we use an additional 2000 samples to estimate the adaptive learning rate, the convergence problems are avoided (Adam-2k); or, if we increase the value of eps, the convergence problems are also relieved (Adam-eps). 33 | 34 | Therefore, we conjecture that the large variance in the early stage causes the convergence problem, and further propose Rectified Adam by analytically reducing the large variance. More details can be found in our [paper](https://arxiv.org/abs/1908.03265). 35 | 36 | ## Questions and Discussions 37 | 38 | ### Do I need to tune the learning rate? 39 | 40 | Yes, the robustness of RAdam is not unlimited. In our experiments, it works for a broader range of learning rates, but not all learning rates. 41 | 42 | ### Notes on Transformer (more discussions can be found in our [Transformer Clinic](https://github.com/LiyuanLucasLiu/Transformer-Clinic) project) 43 | 44 | __Choice of the Original Transformer.__ We choose the original Transformer as our main study object because, without warmup, it suffers from the most serious convergence problems in our experiments. With such serious problems, our controlled experiments can better verify our hypothesis (i.e., we demonstrate that Adam-2k / Adam-eps can avoid spurious local optima with minimal changes). 45 | 46 | __Sensitivity.__ We observe that the Transformer is sensitive to the architecture configuration, despite its efficiency and effectiveness. For example, by changing the position of the layer norm, the model may or may not require warmup to achieve good performance. Intuitively, since the gradient of the attention layer could be more sparse, and the adaptive learning rates for smaller gradients have a larger variance, such models are more sensitive. Nevertheless, we believe this problem deserves more in-depth analysis and is beyond the scope of our study. 47 | 48 | ### Why does warmup have a bigger impact on some models than others? 49 | 50 | Although the adaptive learning rate has a larger variance in the early stage, the exact magnitude is subject to the model design. Thus, the convergence problem could be more serious for some models/tasks than others. In our experiments, we observe that RAdam achieves consistent improvements over the vanilla Adam, which verifies that the variance issue exists widely (since we can get better performance by fixing it). 51 | 52 | ### What if the gradient is not zero-meaned? 53 | 54 | As in Figure 1 (above), even if the gradient is not zero-meaned, the original adaptive learning rate still has a larger variance in the beginning, so applying the rectification can help stabilize the training. 55 | 56 | Another related concern is that, when the mean of the gradient is significantly larger than its variance, the magnitude of the "problematic" variance may not be very large (i.e., in Figure 1, when \mu equals 10, the adaptive learning rate variance is relatively small and may not cause problems). We think this provides a possible explanation of why warmup has a bigger impact on some models than others. Still, we note that, in real-world applications, neural networks usually have some parameters that meet our assumption well (i.e., their gradient variance is larger than their gradient mean) and therefore need the rectification to stabilize training.
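For concreteness, the following is a compact, single-parameter sketch of the rectified update, written to follow Algorithm 2 of the paper; the maintained optimizer (including details such as the exact threshold, eps handling, and weight decay) lives in `radam/radam.py`:

```python
# Illustration of the rectified update (paper's Algorithm 2); see radam/radam.py
# for the maintained PyTorch optimizer.
import math

def radam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One rectified-Adam step for a single parameter (scalar or numpy array)."""
    m = beta1 * m + (1 - beta1) * grad                 # first-moment EMA
    v = beta2 * v + (1 - beta2) * grad * grad          # second-moment EMA
    m_hat = m / (1 - beta1 ** t)                       # bias-corrected first moment
    rho_inf = 2 / (1 - beta2) - 1
    rho_t = rho_inf - 2 * t * beta2 ** t / (1 - beta2 ** t)   # SMA length
    if rho_t > 4:                                      # variance of the adaptive lr is tractable
        r_t = math.sqrt(((rho_t - 4) * (rho_t - 2) * rho_inf) /
                        ((rho_inf - 4) * (rho_inf - 2) * rho_t))  # rectification term
        v_hat = (v / (1 - beta2 ** t)) ** 0.5          # bias-corrected sqrt of second moment
        theta = theta - lr * r_t * m_hat / (v_hat + eps)
    else:                                              # early steps: un-adapted update
        theta = theta - lr * m_hat
    return theta, m, v
```

The `eps` term is only for numerical stability; the rectification itself is the `r_t` factor applied to the adaptive step.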
57 | 58 | ### Why does SGD need warmup? 59 | 60 | To the best of our knowledge, the warmup heuristic was originally designed for large-minibatch SGD [0], based on the intuition that the network changes rapidly in the early stage. However, we find that it __does not__ explain why Adam requires warmup: Adam-2k uses the same large learning rate, yet it avoids the convergence problems simply by using a better estimate of the adaptive learning rate. 61 | 62 | The reason why warmup sometimes also helps SGD still lacks theoretical support. FYI, when optimizing a simple 2-layer CNN with gradient descent, the theory of [1] can be used to show the benefit of warmup. Specifically, the learning rate must be $O(\cos \phi)$, where $\phi$ is the angle between the current weight and the ground-truth weight, and $\cos \phi$ can be very small due to the high-dimensional space and random initialization. Thus the learning rate must be very small at the beginning to guarantee convergence. Since $\cos \phi$ improves in the later stage, the learning rate is then allowed to be larger. Their theory can, to some extent, justify why warmup is needed by gradient descent on neural networks, but it is still far from the real scenario. 63 | 64 | > [0] Goyal et al, Accurate, Large Minibatch SGD: Training Imagenet in 1 Hour, 2017
65 | > [1] Du et al, Gradient Descent Learns One-hidden-layer CNN: Don’t be Afraid of Spurious Local Minima, 2017 66 | 67 | ## Quick Start Guide 68 | 69 | 1. Directly replace the vanilla Adam with RAdam without changing any settings (a minimal usage sketch is included right before the [Citation](#citation) section). 70 | 2. Further tune hyper-parameters (including the learning rate) for better performance. 71 | 72 | Note that in our paper, our major contribution is __to identify why we need the warmup for Adam__. Although some researchers successfully improve their model performance (__[user comments](#user-comments)__), considering the difficulty of training NNs, directly plugging in RAdam __may not__ result in an immediate performance boost. Based on our experience, replacing __the vanilla Adam__ with RAdam usually results in better performance; however, if __warmup has already been employed and tuned__ in the baseline method, it is necessary to also tune hyper-parameters for RAdam. 73 | 74 | ## Related Posts and Repos 75 | 76 | ### Unofficial Re-Implementations 77 | RAdam is very easy to implement. We provide a PyTorch implementation here; third-party implementations can be found at: 78 | 79 | [Keras Implementation](https://github.com/CyberZHG/keras-radam) 80 | 81 | [Keras Implementation](https://github.com/titu1994/keras_rectified_adam) 82 | 83 | [Julia implementation in Flux.jl](https://fluxml.ai/Flux.jl/stable/training/optimisers/#Flux.Optimise.RADAM) 84 | 85 | ### Unofficial Introduction & Mentions 86 | 87 | We provide a simple introduction in [Motivation](#motivation), and more details can be found in our [paper](https://arxiv.org/abs/1908.03265). There are some unofficial introductions available (often better written), and they are listed here for reference only (the contents/claims in our paper are more accurate): 88 | 89 | [Medium Post](https://medium.com/@lessw/new-state-of-the-art-ai-optimizer-rectified-adam-radam-5d854730807b) 90 | > [related Twitter Post](https://twitter.com/jeremyphoward/status/1162118545095852032?ref_src=twsrc%5Etfw) 91 | 92 | [CSDN Post (in Chinese)](https://blog.csdn.net/u014248127/article/details/99696029) 93 | 94 | ### User Comments 95 | 96 | We are happy to see that our algorithms are found to be useful by some users :-) 97 | 98 |

"...I tested it on ImageNette and quickly got new high accuracy scores for the 5 and 20 epoch 128px leaderboard scores, so I know it works... https://forums.fast.ai/t/meet-radam-imo-the-new-state-of-the-art-ai-optimizer/52656

— Less Wright August 15, 2019
99 | 100 |

Thought "sounds interesting, I'll give it a try" - top 5 are vanilla Adam, bottom 4 (I only have access to 4 GPUs) are RAdam... so far looking pretty promising! pic.twitter.com/irvJSeoVfx

— Hamish Dickson (@_mishy) August 16, 2019
101 | 102 |

> RAdam works great for me! It’s good to several % accuracy for free, but the biggest thing I like is the training stability. RAdam is way more stable! https://medium.com/@mgrankin/radam-works-great-for-me-344d37183943
>
> — Grankin Mikhail August 17, 2019
103 | 104 |

"... Also, I achieved higher accuracy results using the newly proposed RAdam optimization function.... 105 | https://towardsdatascience.com/optimism-is-on-the-menu-a-recession-is-not-d87cce265b10

— Sameer Ahuja August 24, 2019
106 | 107 |

"... Out-of-box RAdam implementation performs better than Adam and finetuned SGD... https://twitter.com/ukrdailo/status/1166265186920980480

— Alex Dailo August 27, 2019
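To complement the [Quick Start Guide](#quick-start-guide) above, here is a minimal usage sketch. It assumes this repo's `radam` package has been installed (e.g., via `python setup.py install`) so that `RAdam` can be imported; the model and batches below are placeholders:

```python
# Drop-in replacement for torch.optim.Adam (quick start step 1); the model and
# data here are placeholders for your own.
import torch
import torch.nn as nn
from radam import RAdam  # provided by this repo's `radam` package

model = nn.Linear(10, 2)                               # stand-in for your network
optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

for step in range(100):                                # stand-in training loop
    x = torch.randn(32, 10)                            # fake batch
    y = torch.randint(0, 2, (32,))                     # fake labels
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
```

Quick start step 2 then amounts to re-tuning `lr` (and possibly `weight_decay`) as you would for any optimizer swap.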
108 | 109 | ## Citation 110 | Please cite the following paper if you found our model useful. Thanks! 111 | 112 | >Liyuan Liu , Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, and Jiawei Han (2020). On the Variance of the Adaptive Learning Rate and Beyond. the Eighth International Conference on Learning Representations. 113 | 114 | ``` 115 | @inproceedings{liu2019radam, 116 | author = {Liu, Liyuan and Jiang, Haoming and He, Pengcheng and Chen, Weizhu and Liu, Xiaodong and Gao, Jianfeng and Han, Jiawei}, 117 | booktitle = {Proceedings of the Eighth International Conference on Learning Representations (ICLR 2020)}, 118 | month = {April}, 119 | title = {On the Variance of the Adaptive Learning Rate and Beyond}, 120 | year = {2020} 121 | } 122 | 123 | ``` 124 | -------------------------------------------------------------------------------- /cifar_imagenet/.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | checkpoint/* 3 | -------------------------------------------------------------------------------- /cifar_imagenet/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Wei Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /cifar_imagenet/README.md: -------------------------------------------------------------------------------- 1 | # CIFAR & IMAGENET 2 | 3 | This folder is modified based on the pytorch classification project [original repo](https://github.com/bearpaw/pytorch-classification). For more details about this code base, please refer to the original repo. 4 | A training [recipe](/cifar_imagenet/recipes.md) is provided for image classification experiments. 
5 | 6 | -------------------------------------------------------------------------------- /cifar_imagenet/fourstep.sh: -------------------------------------------------------------------------------- 1 | 2 | ROOT="cps" 3 | for SEED in 1111 2222 3333 4444 5555 4 | do 5 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer radam4s --beta1 0.9 --beta2 0.999 --checkpoint $ROOT/cifar10/resnet-20-adam4s-01-$SEED --gpu-id 0 --lr 0.1 --model_name adam4s --manualSeed $SEED 6 | 7 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer radam4s --beta1 0.9 --beta2 0.999 --checkpoint $ROOT/cifar10/resnet-20-adam4s-ua-01-$SEED --gpu-id 0 --lr 0.1 --model_name adam4s_ua --update_all --manualSeed $SEED 8 | 9 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer radam4s --beta1 0.9 --beta2 0.999 --checkpoint $ROOT/cifar10/resnet-20-adam4s-ua-af-01-$SEED --gpu-id 0 --lr 0.1 --model_name adam4s_ua_af --update_all --additional_four --manualSeed $SEED 10 | done 11 | -------------------------------------------------------------------------------- /cifar_imagenet/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiyuanLucasLiu/RAdam/d9fd30a337894c4003768561d45e8730dbd41333/cifar_imagenet/models/__init__.py -------------------------------------------------------------------------------- /cifar_imagenet/models/cifar/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | """The models subpackage contains definitions for the following model for CIFAR10/CIFAR100 4 | architectures: 5 | 6 | - `AlexNet`_ 7 | - `VGG`_ 8 | - `ResNet`_ 9 | - `SqueezeNet`_ 10 | - `DenseNet`_ 11 | 12 | You can construct a model with random weights by calling its constructor: 13 | 14 | .. code:: python 15 | 16 | import torchvision.models as models 17 | resnet18 = models.resnet18() 18 | alexnet = models.alexnet() 19 | squeezenet = models.squeezenet1_0() 20 | densenet = models.densenet_161() 21 | 22 | We provide pre-trained models for the ResNet variants and AlexNet, using the 23 | PyTorch :mod:`torch.utils.model_zoo`. These can constructed by passing 24 | ``pretrained=True``: 25 | 26 | .. code:: python 27 | 28 | import torchvision.models as models 29 | resnet18 = models.resnet18(pretrained=True) 30 | alexnet = models.alexnet(pretrained=True) 31 | 32 | ImageNet 1-crop error rates (224x224) 33 | 34 | ======================== ============= ============= 35 | Network Top-1 error Top-5 error 36 | ======================== ============= ============= 37 | ResNet-18 30.24 10.92 38 | ResNet-34 26.70 8.58 39 | ResNet-50 23.85 7.13 40 | ResNet-101 22.63 6.44 41 | ResNet-152 21.69 5.94 42 | Inception v3 22.55 6.44 43 | AlexNet 43.45 20.91 44 | VGG-11 30.98 11.37 45 | VGG-13 30.07 10.75 46 | VGG-16 28.41 9.62 47 | VGG-19 27.62 9.12 48 | SqueezeNet 1.0 41.90 19.58 49 | SqueezeNet 1.1 41.81 19.38 50 | Densenet-121 25.35 7.83 51 | Densenet-169 24.00 7.00 52 | Densenet-201 22.80 6.43 53 | Densenet-161 22.35 6.20 54 | ======================== ============= ============= 55 | 56 | 57 | .. _AlexNet: https://arxiv.org/abs/1404.5997 58 | .. _VGG: https://arxiv.org/abs/1409.1556 59 | .. _ResNet: https://arxiv.org/abs/1512.03385 60 | .. _SqueezeNet: https://arxiv.org/abs/1602.07360 61 | .. 
_DenseNet: https://arxiv.org/abs/1608.06993 62 | """ 63 | 64 | from .alexnet import * 65 | from .vgg import * 66 | from .resnet import * 67 | from .resnext import * 68 | from .wrn import * 69 | from .preresnet import * 70 | from .densenet import * 71 | -------------------------------------------------------------------------------- /cifar_imagenet/models/cifar/alexnet.py: -------------------------------------------------------------------------------- 1 | '''AlexNet for CIFAR10. FC layers are removed. Paddings are adjusted. 2 | Without BN, the start learning rate should be 0.01 3 | (c) YANG, Wei 4 | ''' 5 | import torch.nn as nn 6 | 7 | 8 | __all__ = ['alexnet'] 9 | 10 | 11 | class AlexNet(nn.Module): 12 | 13 | def __init__(self, num_classes=10): 14 | super(AlexNet, self).__init__() 15 | self.features = nn.Sequential( 16 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=5), 17 | nn.ReLU(inplace=True), 18 | nn.MaxPool2d(kernel_size=2, stride=2), 19 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 20 | nn.ReLU(inplace=True), 21 | nn.MaxPool2d(kernel_size=2, stride=2), 22 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 27 | nn.ReLU(inplace=True), 28 | nn.MaxPool2d(kernel_size=2, stride=2), 29 | ) 30 | self.classifier = nn.Linear(256, num_classes) 31 | 32 | def forward(self, x): 33 | x = self.features(x) 34 | x = x.view(x.size(0), -1) 35 | x = self.classifier(x) 36 | return x 37 | 38 | 39 | def alexnet(**kwargs): 40 | r"""AlexNet model architecture from the 41 | `"One weird trick..." `_ paper. 42 | """ 43 | model = AlexNet(**kwargs) 44 | return model 45 | -------------------------------------------------------------------------------- /cifar_imagenet/models/cifar/densenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | 7 | __all__ = ['densenet'] 8 | 9 | 10 | from torch.autograd import Variable 11 | 12 | class Bottleneck(nn.Module): 13 | def __init__(self, inplanes, expansion=4, growthRate=12, dropRate=0): 14 | super(Bottleneck, self).__init__() 15 | planes = expansion * growthRate 16 | self.bn1 = nn.BatchNorm2d(inplanes) 17 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(planes) 19 | self.conv2 = nn.Conv2d(planes, growthRate, kernel_size=3, 20 | padding=1, bias=False) 21 | self.relu = nn.ReLU(inplace=True) 22 | self.dropRate = dropRate 23 | 24 | def forward(self, x): 25 | out = self.bn1(x) 26 | out = self.relu(out) 27 | out = self.conv1(out) 28 | out = self.bn2(out) 29 | out = self.relu(out) 30 | out = self.conv2(out) 31 | if self.dropRate > 0: 32 | out = F.dropout(out, p=self.dropRate, training=self.training) 33 | 34 | out = torch.cat((x, out), 1) 35 | 36 | return out 37 | 38 | 39 | class BasicBlock(nn.Module): 40 | def __init__(self, inplanes, expansion=1, growthRate=12, dropRate=0): 41 | super(BasicBlock, self).__init__() 42 | planes = expansion * growthRate 43 | self.bn1 = nn.BatchNorm2d(inplanes) 44 | self.conv1 = nn.Conv2d(inplanes, growthRate, kernel_size=3, 45 | padding=1, bias=False) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.dropRate = dropRate 48 | 49 | def forward(self, x): 50 | out = self.bn1(x) 51 | out = self.relu(out) 52 | out = self.conv1(out) 53 | if self.dropRate > 0: 54 | out = F.dropout(out, p=self.dropRate, 
training=self.training) 55 | 56 | out = torch.cat((x, out), 1) 57 | 58 | return out 59 | 60 | 61 | class Transition(nn.Module): 62 | def __init__(self, inplanes, outplanes): 63 | super(Transition, self).__init__() 64 | self.bn1 = nn.BatchNorm2d(inplanes) 65 | self.conv1 = nn.Conv2d(inplanes, outplanes, kernel_size=1, 66 | bias=False) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | def forward(self, x): 70 | out = self.bn1(x) 71 | out = self.relu(out) 72 | out = self.conv1(out) 73 | out = F.avg_pool2d(out, 2) 74 | return out 75 | 76 | 77 | class DenseNet(nn.Module): 78 | 79 | def __init__(self, depth=22, block=Bottleneck, 80 | dropRate=0, num_classes=10, growthRate=12, compressionRate=2): 81 | super(DenseNet, self).__init__() 82 | 83 | assert (depth - 4) % 3 == 0, 'depth should be 3n+4' 84 | n = (depth - 4) / 3 if block == BasicBlock else (depth - 4) // 6 85 | 86 | self.growthRate = growthRate 87 | self.dropRate = dropRate 88 | 89 | # self.inplanes is a global variable used across multiple 90 | # helper functions 91 | self.inplanes = growthRate * 2 92 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, padding=1, 93 | bias=False) 94 | self.dense1 = self._make_denseblock(block, n) 95 | self.trans1 = self._make_transition(compressionRate) 96 | self.dense2 = self._make_denseblock(block, n) 97 | self.trans2 = self._make_transition(compressionRate) 98 | self.dense3 = self._make_denseblock(block, n) 99 | self.bn = nn.BatchNorm2d(self.inplanes) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.avgpool = nn.AvgPool2d(8) 102 | self.fc = nn.Linear(self.inplanes, num_classes) 103 | 104 | # Weight initialization 105 | for m in self.modules(): 106 | if isinstance(m, nn.Conv2d): 107 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 108 | m.weight.data.normal_(0, math.sqrt(2. / n)) 109 | elif isinstance(m, nn.BatchNorm2d): 110 | m.weight.data.fill_(1) 111 | m.bias.data.zero_() 112 | 113 | def _make_denseblock(self, block, blocks): 114 | layers = [] 115 | for i in range(blocks): 116 | # Currently we fix the expansion ratio as the default value 117 | layers.append(block(self.inplanes, growthRate=self.growthRate, dropRate=self.dropRate)) 118 | self.inplanes += self.growthRate 119 | 120 | return nn.Sequential(*layers) 121 | 122 | def _make_transition(self, compressionRate): 123 | inplanes = self.inplanes 124 | outplanes = int(math.floor(self.inplanes // compressionRate)) 125 | self.inplanes = outplanes 126 | return Transition(inplanes, outplanes) 127 | 128 | 129 | def forward(self, x): 130 | x = self.conv1(x) 131 | 132 | x = self.trans1(self.dense1(x)) 133 | x = self.trans2(self.dense2(x)) 134 | x = self.dense3(x) 135 | x = self.bn(x) 136 | x = self.relu(x) 137 | 138 | x = self.avgpool(x) 139 | x = x.view(x.size(0), -1) 140 | x = self.fc(x) 141 | 142 | return x 143 | 144 | 145 | def densenet(**kwargs): 146 | """ 147 | Constructs a ResNet model. 148 | """ 149 | return DenseNet(**kwargs) -------------------------------------------------------------------------------- /cifar_imagenet/models/cifar/preresnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | '''Resnet for cifar dataset. 
4 | Ported form 5 | https://github.com/facebook/fb.resnet.torch 6 | and 7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 8 | (c) YANG, Wei 9 | ''' 10 | import torch.nn as nn 11 | import math 12 | 13 | 14 | __all__ = ['preresnet'] 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | "3x3 convolution with padding" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(BasicBlock, self).__init__() 27 | self.bn1 = nn.BatchNorm2d(inplanes) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.bn1(x) 39 | out = self.relu(out) 40 | out = self.conv1(out) 41 | 42 | out = self.bn2(out) 43 | out = self.relu(out) 44 | out = self.conv2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None): 58 | super(Bottleneck, self).__init__() 59 | self.bn1 = nn.BatchNorm2d(inplanes) 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn2 = nn.BatchNorm2d(planes) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn3 = nn.BatchNorm2d(planes) 65 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.downsample = downsample 68 | self.stride = stride 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out = self.bn1(x) 74 | out = self.relu(out) 75 | out = self.conv1(out) 76 | 77 | out = self.bn2(out) 78 | out = self.relu(out) 79 | out = self.conv2(out) 80 | 81 | out = self.bn3(out) 82 | out = self.relu(out) 83 | out = self.conv3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | 90 | return out 91 | 92 | 93 | class PreResNet(nn.Module): 94 | 95 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock'): 96 | super(PreResNet, self).__init__() 97 | # Model type specifies number of layers for CIFAR-10 model 98 | if block_name.lower() == 'basicblock': 99 | assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' 100 | n = (depth - 2) // 6 101 | block = BasicBlock 102 | elif block_name.lower() == 'bottleneck': 103 | assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 
20, 29, 47, 56, 110, 1199' 104 | n = (depth - 2) // 9 105 | block = Bottleneck 106 | else: 107 | raise ValueError('block_name shoule be Basicblock or Bottleneck') 108 | 109 | self.inplanes = 16 110 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 111 | bias=False) 112 | self.layer1 = self._make_layer(block, 16, n) 113 | self.layer2 = self._make_layer(block, 32, n, stride=2) 114 | self.layer3 = self._make_layer(block, 64, n, stride=2) 115 | self.bn = nn.BatchNorm2d(64 * block.expansion) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.avgpool = nn.AvgPool2d(8) 118 | self.fc = nn.Linear(64 * block.expansion, num_classes) 119 | 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 123 | m.weight.data.normal_(0, math.sqrt(2. / n)) 124 | elif isinstance(m, nn.BatchNorm2d): 125 | m.weight.data.fill_(1) 126 | m.bias.data.zero_() 127 | 128 | def _make_layer(self, block, planes, blocks, stride=1): 129 | downsample = None 130 | if stride != 1 or self.inplanes != planes * block.expansion: 131 | downsample = nn.Sequential( 132 | nn.Conv2d(self.inplanes, planes * block.expansion, 133 | kernel_size=1, stride=stride, bias=False), 134 | ) 135 | 136 | layers = [] 137 | layers.append(block(self.inplanes, planes, stride, downsample)) 138 | self.inplanes = planes * block.expansion 139 | for i in range(1, blocks): 140 | layers.append(block(self.inplanes, planes)) 141 | 142 | return nn.Sequential(*layers) 143 | 144 | def forward(self, x): 145 | x = self.conv1(x) 146 | 147 | x = self.layer1(x) # 32x32 148 | x = self.layer2(x) # 16x16 149 | x = self.layer3(x) # 8x8 150 | x = self.bn(x) 151 | x = self.relu(x) 152 | 153 | x = self.avgpool(x) 154 | x = x.view(x.size(0), -1) 155 | x = self.fc(x) 156 | 157 | return x 158 | 159 | 160 | def preresnet(**kwargs): 161 | """ 162 | Constructs a ResNet model. 163 | """ 164 | return PreResNet(**kwargs) 165 | -------------------------------------------------------------------------------- /cifar_imagenet/models/cifar/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | '''Resnet for cifar dataset. 
4 | Ported form 5 | https://github.com/facebook/fb.resnet.torch 6 | and 7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 8 | (c) YANG, Wei 9 | ''' 10 | import torch.nn as nn 11 | import math 12 | 13 | 14 | __all__ = ['resnet'] 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | "3x3 convolution with padding" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(inplanes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = nn.BatchNorm2d(planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | residual = self.downsample(x) 47 | 48 | out += residual 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None): 58 | super(Bottleneck, self).__init__() 59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 60 | self.bn1 = nn.BatchNorm2d(planes) 61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 62 | padding=1, bias=False) 63 | self.bn2 = nn.BatchNorm2d(planes) 64 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 65 | self.bn3 = nn.BatchNorm2d(planes * 4) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.downsample = downsample 68 | self.stride = stride 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out = self.conv1(x) 74 | out = self.bn1(out) 75 | out = self.relu(out) 76 | 77 | out = self.conv2(out) 78 | out = self.bn2(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv3(out) 82 | out = self.bn3(out) 83 | 84 | if self.downsample is not None: 85 | residual = self.downsample(x) 86 | 87 | out += residual 88 | out = self.relu(out) 89 | 90 | return out 91 | 92 | 93 | class ResNet(nn.Module): 94 | 95 | def __init__(self, depth, num_classes=1000, block_name='BasicBlock'): 96 | super(ResNet, self).__init__() 97 | # Model type specifies number of layers for CIFAR-10 model 98 | if block_name.lower() == 'basicblock': 99 | assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' 100 | n = (depth - 2) // 6 101 | block = BasicBlock 102 | elif block_name.lower() == 'bottleneck': 103 | assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 
20, 29, 47, 56, 110, 1199' 104 | n = (depth - 2) // 9 105 | block = Bottleneck 106 | else: 107 | raise ValueError('block_name shoule be Basicblock or Bottleneck') 108 | 109 | 110 | self.inplanes = 16 111 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, 112 | bias=False) 113 | self.bn1 = nn.BatchNorm2d(16) 114 | self.relu = nn.ReLU(inplace=True) 115 | self.layer1 = self._make_layer(block, 16, n) 116 | self.layer2 = self._make_layer(block, 32, n, stride=2) 117 | self.layer3 = self._make_layer(block, 64, n, stride=2) 118 | self.avgpool = nn.AvgPool2d(8) 119 | self.fc = nn.Linear(64 * block.expansion, num_classes) 120 | 121 | for m in self.modules(): 122 | if isinstance(m, nn.Conv2d): 123 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 124 | m.weight.data.normal_(0, math.sqrt(2. / n)) 125 | elif isinstance(m, nn.BatchNorm2d): 126 | m.weight.data.fill_(1) 127 | m.bias.data.zero_() 128 | 129 | def _make_layer(self, block, planes, blocks, stride=1): 130 | downsample = None 131 | if stride != 1 or self.inplanes != planes * block.expansion: 132 | downsample = nn.Sequential( 133 | nn.Conv2d(self.inplanes, planes * block.expansion, 134 | kernel_size=1, stride=stride, bias=False), 135 | nn.BatchNorm2d(planes * block.expansion), 136 | ) 137 | 138 | layers = [] 139 | layers.append(block(self.inplanes, planes, stride, downsample)) 140 | self.inplanes = planes * block.expansion 141 | for i in range(1, blocks): 142 | layers.append(block(self.inplanes, planes)) 143 | 144 | return nn.Sequential(*layers) 145 | 146 | def forward(self, x): 147 | x = self.conv1(x) 148 | x = self.bn1(x) 149 | x = self.relu(x) # 32x32 150 | 151 | x = self.layer1(x) # 32x32 152 | x = self.layer2(x) # 16x16 153 | x = self.layer3(x) # 8x8 154 | 155 | x = self.avgpool(x) 156 | x = x.view(x.size(0), -1) 157 | x = self.fc(x) 158 | 159 | return x 160 | 161 | 162 | def resnet(**kwargs): 163 | """ 164 | Constructs a ResNet model. 165 | """ 166 | return ResNet(**kwargs) 167 | -------------------------------------------------------------------------------- /cifar_imagenet/models/cifar/resnext.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | """ 3 | Creates a ResNeXt Model as defined in: 4 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016). 5 | Aggregated residual transformations for deep neural networks. 6 | arXiv preprint arXiv:1611.05431. 7 | import from https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py 8 | """ 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn import init 12 | 13 | __all__ = ['resnext'] 14 | 15 | class ResNeXtBottleneck(nn.Module): 16 | """ 17 | RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua) 18 | """ 19 | def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor): 20 | """ Constructor 21 | Args: 22 | in_channels: input channel dimensionality 23 | out_channels: output channel dimensionality 24 | stride: conv stride. Replaces pooling layer. 25 | cardinality: num of convolution groups. 26 | widen_factor: factor to reduce the input dimensionality before convolution. 
27 | """ 28 | super(ResNeXtBottleneck, self).__init__() 29 | D = cardinality * out_channels // widen_factor 30 | self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False) 31 | self.bn_reduce = nn.BatchNorm2d(D) 32 | self.conv_conv = nn.Conv2d(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 33 | self.bn = nn.BatchNorm2d(D) 34 | self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False) 35 | self.bn_expand = nn.BatchNorm2d(out_channels) 36 | 37 | self.shortcut = nn.Sequential() 38 | if in_channels != out_channels: 39 | self.shortcut.add_module('shortcut_conv', nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False)) 40 | self.shortcut.add_module('shortcut_bn', nn.BatchNorm2d(out_channels)) 41 | 42 | def forward(self, x): 43 | bottleneck = self.conv_reduce.forward(x) 44 | bottleneck = F.relu(self.bn_reduce.forward(bottleneck), inplace=True) 45 | bottleneck = self.conv_conv.forward(bottleneck) 46 | bottleneck = F.relu(self.bn.forward(bottleneck), inplace=True) 47 | bottleneck = self.conv_expand.forward(bottleneck) 48 | bottleneck = self.bn_expand.forward(bottleneck) 49 | residual = self.shortcut.forward(x) 50 | return F.relu(residual + bottleneck, inplace=True) 51 | 52 | 53 | class CifarResNeXt(nn.Module): 54 | """ 55 | ResNext optimized for the Cifar dataset, as specified in 56 | https://arxiv.org/pdf/1611.05431.pdf 57 | """ 58 | def __init__(self, cardinality, depth, num_classes, widen_factor=4, dropRate=0): 59 | """ Constructor 60 | Args: 61 | cardinality: number of convolution groups. 62 | depth: number of layers. 63 | num_classes: number of classes 64 | widen_factor: factor to adjust the channel dimensionality 65 | """ 66 | super(CifarResNeXt, self).__init__() 67 | self.cardinality = cardinality 68 | self.depth = depth 69 | self.block_depth = (self.depth - 2) // 9 70 | self.widen_factor = widen_factor 71 | self.num_classes = num_classes 72 | self.output_size = 64 73 | self.stages = [64, 64 * self.widen_factor, 128 * self.widen_factor, 256 * self.widen_factor] 74 | 75 | self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False) 76 | self.bn_1 = nn.BatchNorm2d(64) 77 | self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1) 78 | self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2) 79 | self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2) 80 | self.classifier = nn.Linear(1024, num_classes) 81 | init.kaiming_normal(self.classifier.weight) 82 | 83 | for key in self.state_dict(): 84 | if key.split('.')[-1] == 'weight': 85 | if 'conv' in key: 86 | init.kaiming_normal(self.state_dict()[key], mode='fan_out') 87 | if 'bn' in key: 88 | self.state_dict()[key][...] = 1 89 | elif key.split('.')[-1] == 'bias': 90 | self.state_dict()[key][...] = 0 91 | 92 | def block(self, name, in_channels, out_channels, pool_stride=2): 93 | """ Stack n bottleneck modules where n is inferred from the depth of the network. 94 | Args: 95 | name: string name of the current block. 96 | in_channels: number of input channels 97 | out_channels: number of output channels 98 | pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block. 99 | Returns: a Module consisting of n sequential bottlenecks. 
100 | """ 101 | block = nn.Sequential() 102 | for bottleneck in range(self.block_depth): 103 | name_ = '%s_bottleneck_%d' % (name, bottleneck) 104 | if bottleneck == 0: 105 | block.add_module(name_, ResNeXtBottleneck(in_channels, out_channels, pool_stride, self.cardinality, 106 | self.widen_factor)) 107 | else: 108 | block.add_module(name_, 109 | ResNeXtBottleneck(out_channels, out_channels, 1, self.cardinality, self.widen_factor)) 110 | return block 111 | 112 | def forward(self, x): 113 | x = self.conv_1_3x3.forward(x) 114 | x = F.relu(self.bn_1.forward(x), inplace=True) 115 | x = self.stage_1.forward(x) 116 | x = self.stage_2.forward(x) 117 | x = self.stage_3.forward(x) 118 | x = F.avg_pool2d(x, 8, 1) 119 | x = x.view(-1, 1024) 120 | return self.classifier(x) 121 | 122 | def resnext(**kwargs): 123 | """Constructs a ResNeXt. 124 | """ 125 | model = CifarResNeXt(**kwargs) 126 | return model -------------------------------------------------------------------------------- /cifar_imagenet/models/cifar/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG for CIFAR10. FC layers are removed. 2 | (c) YANG, Wei 3 | ''' 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | import math 7 | 8 | 9 | __all__ = [ 10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 11 | 'vgg19_bn', 'vgg19', 12 | ] 13 | 14 | 15 | model_urls = { 16 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 17 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 18 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 19 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 20 | } 21 | 22 | 23 | class VGG(nn.Module): 24 | 25 | def __init__(self, features, num_classes=1000): 26 | super(VGG, self).__init__() 27 | self.features = features 28 | self.classifier = nn.Linear(512, num_classes) 29 | self._initialize_weights() 30 | 31 | def forward(self, x): 32 | x = self.features(x) 33 | x = x.view(x.size(0), -1) 34 | x = self.classifier(x) 35 | return x 36 | 37 | def _initialize_weights(self): 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 41 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 42 | if m.bias is not None: 43 | m.bias.data.zero_() 44 | elif isinstance(m, nn.BatchNorm2d): 45 | m.weight.data.fill_(1) 46 | m.bias.data.zero_() 47 | elif isinstance(m, nn.Linear): 48 | n = m.weight.size(1) 49 | m.weight.data.normal_(0, 0.01) 50 | m.bias.data.zero_() 51 | 52 | 53 | def make_layers(cfg, batch_norm=False): 54 | layers = [] 55 | in_channels = 3 56 | for v in cfg: 57 | if v == 'M': 58 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 59 | else: 60 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 61 | if batch_norm: 62 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 63 | else: 64 | layers += [conv2d, nn.ReLU(inplace=True)] 65 | in_channels = v 66 | return nn.Sequential(*layers) 67 | 68 | 69 | cfg = { 70 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 71 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 72 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 73 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 74 | } 75 | 76 | 77 | def vgg11(**kwargs): 78 | """VGG 11-layer model (configuration "A") 79 | 80 | Args: 81 | pretrained (bool): If True, returns a model pre-trained on ImageNet 82 | """ 83 | model = VGG(make_layers(cfg['A']), **kwargs) 84 | return model 85 | 86 | 87 | def vgg11_bn(**kwargs): 88 | """VGG 11-layer model (configuration "A") with batch normalization""" 89 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs) 90 | return model 91 | 92 | 93 | def vgg13(**kwargs): 94 | """VGG 13-layer model (configuration "B") 95 | 96 | Args: 97 | pretrained (bool): If True, returns a model pre-trained on ImageNet 98 | """ 99 | model = VGG(make_layers(cfg['B']), **kwargs) 100 | return model 101 | 102 | 103 | def vgg13_bn(**kwargs): 104 | """VGG 13-layer model (configuration "B") with batch normalization""" 105 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs) 106 | return model 107 | 108 | 109 | def vgg16(**kwargs): 110 | """VGG 16-layer model (configuration "D") 111 | 112 | Args: 113 | pretrained (bool): If True, returns a model pre-trained on ImageNet 114 | """ 115 | model = VGG(make_layers(cfg['D']), **kwargs) 116 | return model 117 | 118 | 119 | def vgg16_bn(**kwargs): 120 | """VGG 16-layer model (configuration "D") with batch normalization""" 121 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs) 122 | return model 123 | 124 | 125 | def vgg19(**kwargs): 126 | """VGG 19-layer model (configuration "E") 127 | 128 | Args: 129 | pretrained (bool): If True, returns a model pre-trained on ImageNet 130 | """ 131 | model = VGG(make_layers(cfg['E']), **kwargs) 132 | return model 133 | 134 | 135 | def vgg19_bn(**kwargs): 136 | """VGG 19-layer model (configuration 'E') with batch normalization""" 137 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs) 138 | return model 139 | -------------------------------------------------------------------------------- /cifar_imagenet/models/cifar/wrn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | __all__ = ['wrn'] 7 | 8 | class BasicBlock(nn.Module): 9 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 10 | super(BasicBlock, self).__init__() 11 | self.bn1 = nn.BatchNorm2d(in_planes) 12 | self.relu1 = nn.ReLU(inplace=True) 13 | self.conv1 = 
nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 14 | padding=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(out_planes) 16 | self.relu2 = nn.ReLU(inplace=True) 17 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 18 | padding=1, bias=False) 19 | self.droprate = dropRate 20 | self.equalInOut = (in_planes == out_planes) 21 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 22 | padding=0, bias=False) or None 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | class NetworkBlock(nn.Module): 35 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 36 | super(NetworkBlock, self).__init__() 37 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 38 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 39 | layers = [] 40 | for i in range(nb_layers): 41 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 42 | return nn.Sequential(*layers) 43 | def forward(self, x): 44 | return self.layer(x) 45 | 46 | class WideResNet(nn.Module): 47 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 48 | super(WideResNet, self).__init__() 49 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 50 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 51 | n = (depth - 4) // 6 52 | block = BasicBlock 53 | # 1st conv before any network block 54 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 55 | padding=1, bias=False) 56 | # 1st block 57 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 58 | # 2nd block 59 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 60 | # 3rd block 61 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 62 | # global average pooling and classifier 63 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.fc = nn.Linear(nChannels[3], num_classes) 66 | self.nChannels = nChannels[3] 67 | 68 | for m in self.modules(): 69 | if isinstance(m, nn.Conv2d): 70 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 71 | m.weight.data.normal_(0, math.sqrt(2. / n)) 72 | elif isinstance(m, nn.BatchNorm2d): 73 | m.weight.data.fill_(1) 74 | m.bias.data.zero_() 75 | elif isinstance(m, nn.Linear): 76 | m.bias.data.zero_() 77 | 78 | def forward(self, x): 79 | out = self.conv1(x) 80 | out = self.block1(out) 81 | out = self.block2(out) 82 | out = self.block3(out) 83 | out = self.relu(self.bn1(out)) 84 | out = F.avg_pool2d(out, 8) 85 | out = out.view(-1, self.nChannels) 86 | return self.fc(out) 87 | 88 | def wrn(**kwargs): 89 | """ 90 | Constructs a Wide Residual Networks. 
91 | """ 92 | model = WideResNet(**kwargs) 93 | return model 94 | -------------------------------------------------------------------------------- /cifar_imagenet/models/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .resnext import * 4 | -------------------------------------------------------------------------------- /cifar_imagenet/models/imagenet/resnext.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | """ 3 | Creates a ResNeXt Model as defined in: 4 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016). 5 | Aggregated residual transformations for deep neural networks. 6 | arXiv preprint arXiv:1611.05431. 7 | import from https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua 8 | """ 9 | import math 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn import init 13 | import torch 14 | 15 | __all__ = ['resnext50', 'resnext101', 'resnext152'] 16 | 17 | class Bottleneck(nn.Module): 18 | """ 19 | RexNeXt bottleneck type C 20 | """ 21 | expansion = 4 22 | 23 | def __init__(self, inplanes, planes, baseWidth, cardinality, stride=1, downsample=None): 24 | """ Constructor 25 | Args: 26 | inplanes: input channel dimensionality 27 | planes: output channel dimensionality 28 | baseWidth: base width. 29 | cardinality: num of convolution groups. 30 | stride: conv stride. Replaces pooling layer. 31 | """ 32 | super(Bottleneck, self).__init__() 33 | 34 | D = int(math.floor(planes * (baseWidth / 64))) 35 | C = cardinality 36 | 37 | self.conv1 = nn.Conv2d(inplanes, D*C, kernel_size=1, stride=1, padding=0, bias=False) 38 | self.bn1 = nn.BatchNorm2d(D*C) 39 | self.conv2 = nn.Conv2d(D*C, D*C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False) 40 | self.bn2 = nn.BatchNorm2d(D*C) 41 | self.conv3 = nn.Conv2d(D*C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False) 42 | self.bn3 = nn.BatchNorm2d(planes * 4) 43 | self.relu = nn.ReLU(inplace=True) 44 | 45 | self.downsample = downsample 46 | 47 | def forward(self, x): 48 | residual = x 49 | 50 | out = self.conv1(x) 51 | out = self.bn1(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | out = self.relu(out) 57 | 58 | out = self.conv3(out) 59 | out = self.bn3(out) 60 | 61 | if self.downsample is not None: 62 | residual = self.downsample(x) 63 | 64 | out += residual 65 | out = self.relu(out) 66 | 67 | return out 68 | 69 | 70 | class ResNeXt(nn.Module): 71 | """ 72 | ResNext optimized for the ImageNet dataset, as specified in 73 | https://arxiv.org/pdf/1611.05431.pdf 74 | """ 75 | def __init__(self, baseWidth, cardinality, layers, num_classes): 76 | """ Constructor 77 | Args: 78 | baseWidth: baseWidth for ResNeXt. 79 | cardinality: number of convolution groups. 
80 | layers: config of layers, e.g., [3, 4, 6, 3] 81 | num_classes: number of classes 82 | """ 83 | super(ResNeXt, self).__init__() 84 | block = Bottleneck 85 | 86 | self.cardinality = cardinality 87 | self.baseWidth = baseWidth 88 | self.num_classes = num_classes 89 | self.inplanes = 64 90 | self.output_size = 64 91 | 92 | self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False) 93 | self.bn1 = nn.BatchNorm2d(64) 94 | self.relu = nn.ReLU(inplace=True) 95 | self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 96 | self.layer1 = self._make_layer(block, 64, layers[0]) 97 | self.layer2 = self._make_layer(block, 128, layers[1], 2) 98 | self.layer3 = self._make_layer(block, 256, layers[2], 2) 99 | self.layer4 = self._make_layer(block, 512, layers[3], 2) 100 | self.avgpool = nn.AvgPool2d(7) 101 | self.fc = nn.Linear(512 * block.expansion, num_classes) 102 | 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 106 | m.weight.data.normal_(0, math.sqrt(2. / n)) 107 | elif isinstance(m, nn.BatchNorm2d): 108 | m.weight.data.fill_(1) 109 | m.bias.data.zero_() 110 | 111 | def _make_layer(self, block, planes, blocks, stride=1): 112 | """ Stack n bottleneck modules where n is inferred from the depth of the network. 113 | Args: 114 | block: block type used to construct ResNext 115 | planes: number of output channels (need to multiply by block.expansion) 116 | blocks: number of blocks to be built 117 | stride: factor to reduce the spatial dimensionality in the first bottleneck of the block. 118 | Returns: a Module consisting of n sequential bottlenecks. 119 | """ 120 | downsample = None 121 | if stride != 1 or self.inplanes != planes * block.expansion: 122 | downsample = nn.Sequential( 123 | nn.Conv2d(self.inplanes, planes * block.expansion, 124 | kernel_size=1, stride=stride, bias=False), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality, stride, downsample)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes, self.baseWidth, self.cardinality)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool1(x) 141 | x = self.layer1(x) 142 | x = self.layer2(x) 143 | x = self.layer3(x) 144 | x = self.layer4(x) 145 | x = self.avgpool(x) 146 | x = x.view(x.size(0), -1) 147 | x = self.fc(x) 148 | 149 | return x 150 | 151 | 152 | def resnext50(baseWidth, cardinality): 153 | """ 154 | Construct ResNeXt-50. 155 | """ 156 | model = ResNeXt(baseWidth, cardinality, [3, 4, 6, 3], 1000) 157 | return model 158 | 159 | 160 | def resnext101(baseWidth, cardinality): 161 | """ 162 | Construct ResNeXt-101. 163 | """ 164 | model = ResNeXt(baseWidth, cardinality, [3, 4, 23, 3], 1000) 165 | return model 166 | 167 | 168 | def resnext152(baseWidth, cardinality): 169 | """ 170 | Construct ResNeXt-152. 
171 | """ 172 | model = ResNeXt(baseWidth, cardinality, [3, 8, 36, 3], 1000) 173 | return model 174 | -------------------------------------------------------------------------------- /cifar_imagenet/recipes.md: -------------------------------------------------------------------------------- 1 | # SGD 2 | 3 | ``` 4 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-sgd-01 --gpu-id 0 --model_name sgd_01 5 | 6 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-sgd-003 --gpu-id 0 --model_name sgd_003 --lr 0.03 7 | 8 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-sgd-001 --gpu-id 0 --model_name sgd_001 --lr 0.01 9 | 10 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-sgd-0003 --gpu-id 0 --model_name sgd_0003 --lr 0.003 11 | ``` 12 | 13 | # Vanilla Adam 14 | 15 | ``` 16 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint /cps/gadam/checkpoints/cifar10/resnet-20-adam-01 --gpu-id 0 --model_name adam_01 17 | 18 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint /cps/gadam/checkpoints/cifar10/resnet-20-adam-003 --gpu-id 0 --model_name adam_003 --lr 0.03 19 | 20 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint /cps/gadam/checkpoints/cifar10/resnet-20-adam-001 --gpu-id 0 --model_name adam_001 --lr 0.01 21 | 22 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint /cps/gadam/checkpoints/cifar10/resnet-20-adam-0003 --gpu-id 0 --model_name adam_0003 --lr 0.003 23 | ``` 24 | 25 | # RAdam experiments 26 | 27 | ``` 28 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer radam --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-radam-01 --gpu-id 0 --model_name radam_01 29 | 30 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer radam --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-radam-003 --gpu-id 0 --model_name radam_003 --lr 0.03 31 | 32 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer radam --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-radam-001 --gpu-id 0 --model_name radam_001 --lr 0.01 33 | 34 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer radam --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-radam-0003 --gpu-id 0 --model_name radam_0003 --lr 0.003 35 | ``` 36 | 37 | # Adam with 100 warmup 38 | 39 | ``` 40 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-01 --gpu-id 0 --warmup 100 --model_name adam_100_01 41 | 42 | python cifar.py -a resnet --depth 20 
--epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-003 --gpu-id 0 --warmup 100 --model_name adam_100_003 --lr 0.03 43 | 44 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-001 --gpu-id 0 --warmup 100 --model_name adam_100_001 --lr 0.01 45 | 46 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-0003 --gpu-id 0 --warmup 100 --model_name adam_100_0003 --lr 0.003 47 | ``` 48 | 49 | # Adam with 200 warmup 50 | 51 | ``` 52 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-01 --gpu-id 0 --warmup 200 --model_name adam_200_01 53 | 54 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-003 --gpu-id 0 --warmup 200 --model_name adam_200_003 --lr 0.03 55 | 56 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-001 --gpu-id 0 --warmup 200 --model_name adam_200_001 --lr 0.01 57 | 58 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-0003 --gpu-id 0 --warmup 200 --model_name adam_200_0003 --lr 0.003 59 | ``` 60 | 61 | # Adam with 500 warmup 62 | 63 | ``` 64 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-01 --gpu-id 0 --warmup 500 --model_name adam_500_01 65 | 66 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-003 --gpu-id 0 --warmup 500 --model_name adam_500_003 --lr 0.03 67 | 68 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-001 --gpu-id 0 --warmup 500 --model_name adam_500_001 --lr 0.01 69 | 70 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-0003 --gpu-id 0 --warmup 500 --model_name adam_500_0003 --lr 0.003 71 | ``` 72 | 73 | # Adam with 1000 warmup 74 | 75 | ``` 76 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-01 --gpu-id 0 --warmup 1000 --model_name adam_1000_01 77 | 78 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-003 --gpu-id 0 --warmup 1000 --model_name 
adam_1000_003 --lr 0.03 79 | 80 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-001 --gpu-id 0 --warmup 1000 --model_name adam_1000_001 --lr 0.01 81 | 82 | python cifar.py -a resnet --depth 20 --epochs 164 --schedule 81 122 --gamma 0.1 --wd 1e-4 --optimizer adamw --beta1 0.9 --beta2 0.999 --checkpoint ./cps/gadam/checkpoints/cifar10/resnet-20-adam-0003 --gpu-id 0 --warmup 1000 --model_name adam_1000_0003 --lr 0.003 83 | ``` 84 | 85 | # ImageNet 86 | 87 | ``` 88 | python imagenet.py -j 16 -a resnet18 --data /data/ILSVRC2012/ --epochs 90 --schedule 31 61 --gamma 0.1 -c ./cps/imagenet/resnet18_radam_0003 --model_name radam_0003 --optimizer radam --lr 0.003 --beta1 0.9 --beta2 0.999 89 | 90 | python imagenet.py -j 16 -a resnet18 --data /data/ILSVRC2012/ --epochs 90 --schedule 31 61 --gamma 0.1 -c ./cps/imagenet/resnet18_radam_0005 --model_name radam_0005 --optimizer radam --lr 0.005 --beta1 0.9 --beta2 0.999 91 | ``` 92 | -------------------------------------------------------------------------------- /cifar_imagenet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Useful utils 2 | """ 3 | from .misc import * 4 | from .logger import * 5 | from .visualize import * 6 | from .eval import * 7 | 8 | # progress bar 9 | import os, sys 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 11 | from progress.bar import Bar as Bar -------------------------------------------------------------------------------- /cifar_imagenet/utils/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | __all__ = ['accuracy'] 4 | 5 | def accuracy(output, target, topk=(1,)): 6 | """Computes the precision@k for the specified values of k""" 7 | maxk = max(topk) 8 | batch_size = target.size(0) 9 | 10 | _, pred = output.topk(maxk, 1, True, True) 11 | pred = pred.t() 12 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 13 | 14 | res = [] 15 | for k in topk: 16 | correct_k = correct[:k].view(-1).float().sum(0) 17 | res.append(correct_k.mul_(100.0 / batch_size)) 18 | return res -------------------------------------------------------------------------------- /cifar_imagenet/utils/images/cifar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiyuanLucasLiu/RAdam/d9fd30a337894c4003768561d45e8730dbd41333/cifar_imagenet/utils/images/cifar.png -------------------------------------------------------------------------------- /cifar_imagenet/utils/images/imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiyuanLucasLiu/RAdam/d9fd30a337894c4003768561d45e8730dbd41333/cifar_imagenet/utils/images/imagenet.png -------------------------------------------------------------------------------- /cifar_imagenet/utils/logger.py: -------------------------------------------------------------------------------- 1 | # A simple torch style logger 2 | # (C) Wei YANG 2017 3 | from __future__ import absolute_import 4 | import matplotlib as mpl 5 | mpl.use('Agg') 6 | 7 | import matplotlib.pyplot as plt 8 | import os 9 | import sys 10 | import numpy as np 11 | 12 | __all__ = ['Logger', 'LoggerMonitor', 'savefig'] 13 | 14 | def savefig(fname, dpi=None): 15 | dpi = 150 if dpi == None 
else dpi 16 | plt.savefig(fname, dpi=dpi) 17 | 18 | def plot_overlap(logger, names=None): 19 | names = logger.names if names == None else names 20 | numbers = logger.numbers 21 | for _, name in enumerate(names): 22 | x = np.arange(len(numbers[name])) 23 | plt.plot(x, np.asarray(numbers[name])) 24 | return [logger.title + '(' + name + ')' for name in names] 25 | 26 | class Logger(object): 27 | '''Save training process to log file with simple plot function.''' 28 | def __init__(self, fpath, title=None, resume=False): 29 | self.file = None 30 | self.resume = resume 31 | self.title = '' if title == None else title 32 | if fpath is not None: 33 | if resume: 34 | self.file = open(fpath, 'r') 35 | name = self.file.readline() 36 | self.names = name.rstrip().split('\t') 37 | self.numbers = {} 38 | for _, name in enumerate(self.names): 39 | self.numbers[name] = [] 40 | 41 | for numbers in self.file: 42 | numbers = numbers.rstrip().split('\t') 43 | for i in range(0, len(numbers)): 44 | self.numbers[self.names[i]].append(numbers[i]) 45 | self.file.close() 46 | self.file = open(fpath, 'a') 47 | else: 48 | self.file = open(fpath, 'w') 49 | 50 | def set_names(self, names): 51 | if self.resume: 52 | pass 53 | # initialize numbers as empty list 54 | self.numbers = {} 55 | self.names = names 56 | for _, name in enumerate(self.names): 57 | self.file.write(name) 58 | self.file.write('\t') 59 | self.numbers[name] = [] 60 | self.file.write('\n') 61 | self.file.flush() 62 | 63 | 64 | def append(self, numbers): 65 | assert len(self.names) == len(numbers), 'Numbers do not match names' 66 | for index, num in enumerate(numbers): 67 | self.file.write("{0:.6f}".format(num)) 68 | self.file.write('\t') 69 | self.numbers[self.names[index]].append(num) 70 | self.file.write('\n') 71 | self.file.flush() 72 | 73 | def plot(self, names=None): 74 | names = self.names if names == None else names 75 | numbers = self.numbers 76 | for _, name in enumerate(names): 77 | x = np.arange(len(numbers[name])) 78 | plt.plot(x, np.asarray(numbers[name])) 79 | plt.legend([self.title + '(' + name + ')' for name in names]) 80 | plt.grid(True) 81 | 82 | def close(self): 83 | if self.file is not None: 84 | self.file.close() 85 | 86 | class LoggerMonitor(object): 87 | '''Load and visualize multiple logs.''' 88 | def __init__ (self, paths): 89 | '''paths is a distionary with {name:filepath} pair''' 90 | self.loggers = [] 91 | for title, path in paths.items(): 92 | logger = Logger(path, title=title, resume=True) 93 | self.loggers.append(logger) 94 | 95 | def plot(self, names=None): 96 | plt.figure() 97 | plt.subplot(121) 98 | legend_text = [] 99 | for logger in self.loggers: 100 | legend_text += plot_overlap(logger, names) 101 | plt.legend(legend_text, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) 
102 | plt.grid(True) 103 | 104 | if __name__ == '__main__': 105 | # # Example 106 | # logger = Logger('test.txt') 107 | # logger.set_names(['Train loss', 'Valid loss','Test loss']) 108 | 109 | # length = 100 110 | # t = np.arange(length) 111 | # train_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 112 | # valid_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 113 | # test_loss = np.exp(-t / 10.0) + np.random.rand(length) * 0.1 114 | 115 | # for i in range(0, length): 116 | # logger.append([train_loss[i], valid_loss[i], test_loss[i]]) 117 | # logger.plot() 118 | 119 | # Example: logger monitor 120 | paths = { 121 | 'resadvnet20':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet20/log.txt', 122 | 'resadvnet32':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet32/log.txt', 123 | 'resadvnet44':'/home/wyang/code/pytorch-classification/checkpoint/cifar10/resadvnet44/log.txt', 124 | } 125 | 126 | field = ['Valid Acc.'] 127 | 128 | monitor = LoggerMonitor(paths) 129 | monitor.plot(names=field) 130 | savefig('test.eps') 131 | -------------------------------------------------------------------------------- /cifar_imagenet/utils/misc.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - msr_init: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 5 | ''' 6 | import errno 7 | import os 8 | import sys 9 | import time 10 | import math 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.init as init 15 | from torch.autograd import Variable 16 | 17 | __all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'AverageMeter'] 18 | 19 | 20 | def get_mean_and_std(dataset): 21 | '''Compute the mean and std value of dataset.''' 22 | dataloader = trainloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 23 | 24 | mean = torch.zeros(3) 25 | std = torch.zeros(3) 26 | print('==> Computing mean and std..') 27 | for inputs, targets in dataloader: 28 | for i in range(3): 29 | mean[i] += inputs[:,i,:,:].mean() 30 | std[i] += inputs[:,i,:,:].std() 31 | mean.div_(len(dataset)) 32 | std.div_(len(dataset)) 33 | return mean, std 34 | 35 | def init_params(net): 36 | '''Init layer parameters.''' 37 | for m in net.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | init.kaiming_normal(m.weight, mode='fan_out') 40 | if m.bias: 41 | init.constant(m.bias, 0) 42 | elif isinstance(m, nn.BatchNorm2d): 43 | init.constant(m.weight, 1) 44 | init.constant(m.bias, 0) 45 | elif isinstance(m, nn.Linear): 46 | init.normal(m.weight, std=1e-3) 47 | if m.bias: 48 | init.constant(m.bias, 0) 49 | 50 | def mkdir_p(path): 51 | '''make dir if not exist''' 52 | try: 53 | os.makedirs(path) 54 | except OSError as exc: # Python >2.5 55 | if exc.errno == errno.EEXIST and os.path.isdir(path): 56 | pass 57 | else: 58 | raise 59 | 60 | class AverageMeter(object): 61 | """Computes and stores the average and current value 62 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 63 | """ 64 | def __init__(self): 65 | self.reset() 66 | 67 | def reset(self): 68 | self.val = 0 69 | self.avg = 0 70 | self.sum = 0 71 | self.count = 0 72 | 73 | def update(self, val, n=1): 74 | self.val = val 75 | self.sum += val * n 76 | self.count += n 77 | self.avg = self.sum / self.count 
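
The `AverageMeter` defined just above (together with `accuracy` from `eval.py` in the same `utils` package) is the bookkeeping pattern the training scripts in this folder rely on: each batch contributes its value weighted by its batch size, and `.avg` gives the running epoch-level figure. Below is a minimal usage sketch, not code from the repo: the model, loss, and data are made-up stand-ins, and the imports assume the snippet is run from `cifar_imagenet/` so that `utils.misc` and `utils.eval` are importable.

```python
import torch
import torch.nn as nn

# Assumes the working directory is cifar_imagenet/ so these modules resolve.
from utils.misc import AverageMeter
from utils.eval import accuracy

# Stand-in model and data, for illustration only.
model = nn.Linear(32, 10)
criterion = nn.CrossEntropyLoss()
loader = [(torch.randn(8, 32), torch.randint(0, 10, (8,))) for _ in range(5)]

losses = AverageMeter()  # running average of the loss
top1 = AverageMeter()    # running average of top-1 accuracy (in %)

model.eval()
with torch.no_grad():
    for inputs, targets in loader:
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        prec1 = accuracy(outputs, targets, topk=(1,))[0]
        # update(val, n): accumulate the batch value weighted by batch size
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))

print('loss {:.4f} | top-1 {:.2f}%'.format(losses.avg, top1.avg))
```
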
-------------------------------------------------------------------------------- /cifar_imagenet/utils/radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class RAdam(Optimizer): 6 | 7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 8 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 9 | self.buffer = [[None, None, None] for ind in range(10)] 10 | super(RAdam, self).__init__(params, defaults) 11 | 12 | def __setstate__(self, state): 13 | super(RAdam, self).__setstate__(state) 14 | 15 | def step(self, closure=None): 16 | 17 | loss = None 18 | if closure is not None: 19 | loss = closure() 20 | 21 | for group in self.param_groups: 22 | 23 | for p in group['params']: 24 | if p.grad is None: 25 | continue 26 | grad = p.grad.data.float() 27 | if grad.is_sparse: 28 | raise RuntimeError('RAdam does not support sparse gradients') 29 | 30 | p_data_fp32 = p.data.float() 31 | 32 | state = self.state[p] 33 | 34 | if len(state) == 0: 35 | state['step'] = 0 36 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 37 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 38 | else: 39 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 40 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 41 | 42 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 43 | beta1, beta2 = group['betas'] 44 | 45 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 46 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 47 | 48 | state['step'] += 1 49 | buffered = self.buffer[int(state['step'] % 10)] 50 | if state['step'] == buffered[0]: 51 | N_sma, step_size = buffered[1], buffered[2] 52 | else: 53 | buffered[0] = state['step'] 54 | beta2_t = beta2 ** state['step'] 55 | N_sma_max = 2 / (1 - beta2) - 1 56 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 57 | buffered[1] = N_sma 58 | 59 | # more conservative since it's an approximated value 60 | if N_sma >= 5: 61 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 62 | else: 63 | step_size = group['lr'] / (1 - beta1 ** state['step']) 64 | buffered[2] = step_size 65 | 66 | if group['weight_decay'] != 0: 67 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 68 | 69 | # more conservative since it's an approximated value 70 | if N_sma >= 5: 71 | denom = exp_avg_sq.sqrt().add_(group['eps']) 72 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 73 | else: 74 | p_data_fp32.add_(-step_size, exp_avg) 75 | 76 | p.data.copy_(p_data_fp32) 77 | 78 | return loss 79 | 80 | class RAdam_4step(Optimizer): 81 | 82 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, update_all=False, additional_four=False): 83 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 84 | self.update_all = update_all # whether update the first 4 steps 85 | self.additional_four = additional_four # whether use additional 4 steps for SGD 86 | self.buffer = [[None, None] for ind in range(10)] 87 | super(RAdam_4step, self).__init__(params, defaults) 88 | 89 | def __setstate__(self, state): 90 | super(RAdam_4step, self).__setstate__(state) 91 | 92 | def step(self, closure=None): 93 | 94 | loss = None 95 | if closure is not None: 96 | loss = closure() 97 | 98 | for group in self.param_groups: 99 | 100 | for p 
in group['params']: 101 | if p.grad is None: 102 | continue 103 | grad = p.grad.data.float() 104 | if grad.is_sparse: 105 | raise RuntimeError('RAdam_4step does not support sparse gradients') 106 | 107 | p_data_fp32 = p.data.float() 108 | 109 | state = self.state[p] 110 | 111 | if len(state) == 0: 112 | state['step'] = -4 if self.additional_four else 0 #since this exp requires exactly 4 step, it is hard coded 113 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 114 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 115 | else: 116 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 117 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 118 | 119 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 120 | beta1, beta2 = group['betas'] 121 | 122 | state['step'] += 1 123 | 124 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 125 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 126 | 127 | if state['step'] > 0: 128 | 129 | state_step = state['step'] + 4 if self.additional_four else state['step'] #since this exp requires exactly 4 step, it is hard coded 130 | 131 | buffered = self.buffer[int(state_step % 10)] 132 | if state_step == buffered[0]: 133 | step_size = buffered[1] 134 | else: 135 | buffered[0] = state_step 136 | beta2_t = beta2 ** state['step'] 137 | 138 | if state['step'] > 4: #since this exp requires exactly 4 step, it is hard coded 139 | N_sma_max = 2 / (1 - beta2) - 1 140 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 141 | step_size = group['lr'] * math.sqrt((N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state_step) 142 | elif self.update_all: 143 | step_size = group['lr'] / (1 - beta1 ** state_step) 144 | else: 145 | step_size = 0 146 | buffered[1] = step_size 147 | 148 | if state['step'] > 4: #since this exp requires exactly 4 step, it is hard coded 149 | if group['weight_decay'] != 0: 150 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 151 | denom = (exp_avg_sq.sqrt() / math.sqrt(1 - beta2 ** state_step)).add_(group['eps']) 152 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 153 | p.data.copy_(p_data_fp32) 154 | elif self.update_all: 155 | if group['weight_decay'] != 0: 156 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 157 | denom = (exp_avg_sq.sqrt() / math.sqrt(1 - beta2 ** state_step)) 158 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 159 | p.data.copy_(p_data_fp32) 160 | else: 161 | state_step = state['step'] + 4 if self.additional_four else state['step'] #since this exp requires exactly 4 step, it is hard coded 162 | 163 | if group['weight_decay'] != 0: 164 | p_data_fp32.add_(-group['weight_decay'] * 0.1, p_data_fp32) 165 | 166 | step_size = 0.1 / (1 - beta1 ** state_step) 167 | p_data_fp32.add_(-step_size, exp_avg) 168 | p.data.copy_(p_data_fp32) 169 | 170 | return loss 171 | 172 | class AdamW(Optimizer): 173 | 174 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 175 | weight_decay=0, use_variance=True, warmup = 4000): 176 | defaults = dict(lr=lr, betas=betas, eps=eps, 177 | weight_decay=weight_decay, use_variance=True, warmup = warmup) 178 | print('======== Warmup: {} ========='.format(warmup)) 179 | super(AdamW, self).__init__(params, defaults) 180 | 181 | def __setstate__(self, state): 182 | super(AdamW, self).__setstate__(state) 183 | 184 | def step(self, closure=None): 185 | global iter_idx 186 | iter_idx += 1 187 | grad_list = list() 188 | mom_list = list() 189 | mom_2rd_list = list() 
190 | 191 | loss = None 192 | if closure is not None: 193 | loss = closure() 194 | 195 | for group in self.param_groups: 196 | 197 | for p in group['params']: 198 | if p.grad is None: 199 | continue 200 | grad = p.grad.data.float() 201 | if grad.is_sparse: 202 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 203 | 204 | p_data_fp32 = p.data.float() 205 | 206 | state = self.state[p] 207 | 208 | if len(state) == 0: 209 | state['step'] = 0 210 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 211 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 212 | else: 213 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 214 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 215 | 216 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 217 | beta1, beta2 = group['betas'] 218 | 219 | state['step'] += 1 220 | 221 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 222 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 223 | 224 | denom = exp_avg_sq.sqrt().add_(group['eps']) 225 | bias_correction1 = 1 - beta1 ** state['step'] 226 | bias_correction2 = 1 - beta2 ** state['step'] 227 | 228 | if group['warmup'] > state['step']: 229 | scheduled_lr = 1e-6 + state['step'] * (group['lr'] - 1e-6) / group['warmup'] 230 | else: 231 | scheduled_lr = group['lr'] 232 | 233 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 234 | if group['weight_decay'] != 0: 235 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 236 | 237 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 238 | 239 | p.data.copy_(p_data_fp32) 240 | 241 | return loss 242 | -------------------------------------------------------------------------------- /cifar_imagenet/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import torch 3 | import torch.nn as nn 4 | import torchvision 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | from .misc import * 8 | 9 | __all__ = ['make_image', 'show_batch', 'show_mask', 'show_mask_single'] 10 | 11 | # functions to show an image 12 | def make_image(img, mean=(0,0,0), std=(1,1,1)): 13 | for i in range(0, 3): 14 | img[i] = img[i] * std[i] + mean[i] # unnormalize 15 | npimg = img.numpy() 16 | return np.transpose(npimg, (1, 2, 0)) 17 | 18 | def gauss(x,a,b,c): 19 | return torch.exp(-torch.pow(torch.add(x,-b),2).div(2*c*c)).mul(a) 20 | 21 | def colorize(x): 22 | ''' Converts a one-channel grayscale image to a color heatmap image ''' 23 | if x.dim() == 2: 24 | torch.unsqueeze(x, 0, out=x) 25 | if x.dim() == 3: 26 | cl = torch.zeros([3, x.size(1), x.size(2)]) 27 | cl[0] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 28 | cl[1] = gauss(x,1,.5,.3) 29 | cl[2] = gauss(x,1,.2,.3) 30 | cl[cl.gt(1)] = 1 31 | elif x.dim() == 4: 32 | cl = torch.zeros([x.size(0), 3, x.size(2), x.size(3)]) 33 | cl[:,0,:,:] = gauss(x,.5,.6,.2) + gauss(x,1,.8,.3) 34 | cl[:,1,:,:] = gauss(x,1,.5,.3) 35 | cl[:,2,:,:] = gauss(x,1,.2,.3) 36 | return cl 37 | 38 | def show_batch(images, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 39 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 40 | plt.imshow(images) 41 | plt.show() 42 | 43 | 44 | def show_mask_single(images, mask, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 45 | im_size = images.size(2) 46 | 47 | # save for adding mask 48 | im_data = images.clone() 49 | for i in range(0, 3): 50 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 51 | 52 | images = 
make_image(torchvision.utils.make_grid(images), Mean, Std) 53 | plt.subplot(2, 1, 1) 54 | plt.imshow(images) 55 | plt.axis('off') 56 | 57 | # for b in range(mask.size(0)): 58 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 59 | mask_size = mask.size(2) 60 | # print('Max %f Min %f' % (mask.max(), mask.min())) 61 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 62 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 63 | # for c in range(3): 64 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 65 | 66 | # print(mask.size()) 67 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 68 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 69 | plt.subplot(2, 1, 2) 70 | plt.imshow(mask) 71 | plt.axis('off') 72 | 73 | def show_mask(images, masklist, Mean=(2, 2, 2), Std=(0.5,0.5,0.5)): 74 | im_size = images.size(2) 75 | 76 | # save for adding mask 77 | im_data = images.clone() 78 | for i in range(0, 3): 79 | im_data[:,i,:,:] = im_data[:,i,:,:] * Std[i] + Mean[i] # unnormalize 80 | 81 | images = make_image(torchvision.utils.make_grid(images), Mean, Std) 82 | plt.subplot(1+len(masklist), 1, 1) 83 | plt.imshow(images) 84 | plt.axis('off') 85 | 86 | for i in range(len(masklist)): 87 | mask = masklist[i].data.cpu() 88 | # for b in range(mask.size(0)): 89 | # mask[b] = (mask[b] - mask[b].min())/(mask[b].max() - mask[b].min()) 90 | mask_size = mask.size(2) 91 | # print('Max %f Min %f' % (mask.max(), mask.min())) 92 | mask = (upsampling(mask, scale_factor=im_size/mask_size)) 93 | # mask = colorize(upsampling(mask, scale_factor=im_size/mask_size)) 94 | # for c in range(3): 95 | # mask[:,c,:,:] = (mask[:,c,:,:] - Mean[c])/Std[c] 96 | 97 | # print(mask.size()) 98 | mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask.expand_as(im_data))) 99 | # mask = make_image(torchvision.utils.make_grid(0.3*im_data+0.7*mask), Mean, Std) 100 | plt.subplot(1+len(masklist), 1, i+2) 101 | plt.imshow(mask) 102 | plt.axis('off') 103 | 104 | 105 | 106 | # x = torch.zeros(1, 3, 3) 107 | # out = colorize(x) 108 | # out_im = make_image(out) 109 | # plt.imshow(out_im) 110 | # plt.show() -------------------------------------------------------------------------------- /img/variance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiyuanLucasLiu/RAdam/d9fd30a337894c4003768561d45e8730dbd41333/img/variance.png -------------------------------------------------------------------------------- /language-model/README.md: -------------------------------------------------------------------------------- 1 | # One Billion Word 2 | 3 | This folder is modified based on the LD-Net project ([original repo](https://github.com/LiyuanLucasLiu/LD-Net)). For more details about this code base, please refer to the original repo. 4 | A training [recipe](/language-model/recipes.md) is provided for language modeling experiments. 
5 | 6 | -------------------------------------------------------------------------------- /language-model/eval_1bw.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import datetime 3 | import time 4 | import torch 5 | import torch.autograd as autograd 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import codecs 9 | import pickle 10 | import math 11 | 12 | from model_word_ada.LM import LM 13 | from model_word_ada.basic import BasicRNN 14 | from model_word_ada.ddnet import DDRNN 15 | from model_word_ada.densenet import DenseRNN 16 | from model_word_ada.ldnet import LDRNN 17 | from model_word_ada.dataset import LargeDataset, EvalDataset 18 | from model_word_ada.adaptive import AdaptiveSoftmax 19 | import model_word_ada.utils as utils 20 | 21 | # from tensorboardX import SummaryWriter 22 | 23 | import argparse 24 | import json 25 | import os 26 | import sys 27 | import itertools 28 | import functools 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--dataset_folder', default='/data/billionwords/one_billion/') 33 | parser.add_argument('--load_checkpoint', default='./checkpoint/basic_1.model') 34 | parser.add_argument('--sequence_length', type=int, default=600) 35 | parser.add_argument('--hid_dim', type=int, default=2048) 36 | parser.add_argument('--word_dim', type=int, default=300) 37 | parser.add_argument('--label_dim', type=int, default=-1) 38 | parser.add_argument('--layer_num', type=int, default=2) 39 | parser.add_argument('--droprate', type=float, default=0.1) 40 | parser.add_argument('--add_relu', action='store_true') 41 | parser.add_argument('--gpu', type=int, default=1) 42 | parser.add_argument('--rnn_layer', choices=['Basic', 'DDNet', 'DenseNet', 'LDNet'], default='Basic') 43 | parser.add_argument('--rnn_unit', choices=['gru', 'lstm', 'rnn'], default='lstm') 44 | parser.add_argument('--cut_off', nargs='+', default=[4000,40000,200000]) 45 | parser.add_argument('--limit', type=int, default=76800) 46 | 47 | args = parser.parse_args() 48 | 49 | if args.gpu >= 0: 50 | torch.cuda.set_device(args.gpu) 51 | 52 | print('loading dataset') 53 | dataset = pickle.load(open(args.dataset_folder + 'test.pk', 'rb')) 54 | w_map, test_data = dataset['w_map'], dataset['test_data'] 55 | 56 | cut_off = args.cut_off + [len(w_map) + 1] 57 | 58 | test_loader = EvalDataset(test_data, args.sequence_length) 59 | 60 | print('building model') 61 | 62 | rnn_map = {'Basic': BasicRNN, 'DDNet': DDRNN, 'DenseNet': DenseRNN, 'LDNet': functools.partial(LDRNN, layer_drop = 0)} 63 | rnn_layer = rnn_map[args.rnn_layer](args.layer_num, args.rnn_unit, args.word_dim, args.hid_dim, args.droprate) 64 | 65 | if args.label_dim > 0: 66 | soft_max = AdaptiveSoftmax(args.label_dim, cut_off) 67 | else: 68 | soft_max = AdaptiveSoftmax(rnn_layer.output_dim, cut_off) 69 | 70 | lm_model = LM(rnn_layer, soft_max, len(w_map), args.word_dim, args.droprate, label_dim = args.label_dim, add_relu = args.add_relu) 71 | lm_model.cuda() 72 | 73 | if os.path.isfile(args.load_checkpoint): 74 | print("loading checkpoint: '{}'".format(args.load_checkpoint)) 75 | 76 | checkpoint_file = torch.load(args.load_checkpoint, map_location=lambda storage, loc: storage) 77 | lm_model.load_state_dict(checkpoint_file['lm_model']) 78 | else: 79 | print("no checkpoint found at: '{}'".format(args.load_checkpoint)) 80 | 81 | test_lm = nn.NLLLoss() 82 | 83 | test_lm.cuda() 84 | lm_model.cuda() 85 | 86 | print('evaluating') 
87 | lm_model.eval() 88 | 89 | iterator = test_loader.get_tqdm() 90 | 91 | lm_model.init_hidden() 92 | total_loss = 0 93 | total_len = 0 94 | for word_t, label_t in iterator: 95 | label_t = label_t.view(-1) 96 | tmp_len = label_t.size(0) 97 | output = lm_model.log_prob(word_t) 98 | total_loss += tmp_len * utils.to_scalar(test_lm(autograd.Variable(output), label_t)) 99 | total_len += tmp_len 100 | 101 | if args.limit > 0 and total_len > args.limit: 102 | break 103 | 104 | print(str(total_loss / total_len)) 105 | ppl = math.exp(total_loss / total_len) 106 | print('PPL: ' + str(ppl)) 107 | -------------------------------------------------------------------------------- /language-model/model_word_ada/LM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import model_word_ada.utils as utils 5 | 6 | class LM(nn.Module): 7 | 8 | def __init__(self, rnn, soft_max, w_num, w_dim, droprate, label_dim = -1, add_relu=False): 9 | super(LM, self).__init__() 10 | 11 | self.rnn = rnn 12 | self.soft_max = soft_max 13 | 14 | if soft_max: 15 | self.forward = self.softmax_forward 16 | else: 17 | self.forward = self.embed_forward 18 | 19 | self.w_num = w_num 20 | self.w_dim = w_dim 21 | self.word_embed = nn.Embedding(w_num, w_dim) 22 | 23 | self.rnn_output = self.rnn.output_dim 24 | 25 | self.add_proj = label_dim > 0 26 | if self.add_proj: 27 | self.project = nn.Linear(self.rnn_output, label_dim) 28 | if add_relu: 29 | self.relu = nn.ReLU() 30 | else: 31 | self.relu = lambda x: x 32 | 33 | self.drop = nn.Dropout(p=droprate) 34 | 35 | def load_embed(self, origin_lm): 36 | self.word_embed = origin_lm.word_embed 37 | self.soft_max = origin_lm.soft_max 38 | 39 | def rand_ini(self): 40 | 41 | self.rnn.rand_ini() 42 | # utils.init_linear(self.project) 43 | self.soft_max.rand_ini() 44 | # if not self.tied_weight: 45 | utils.init_embedding(self.word_embed.weight) 46 | 47 | if self.add_proj: 48 | utils.init_linear(self.project) 49 | 50 | def init_hidden(self): 51 | self.rnn.init_hidden() 52 | 53 | def softmax_forward(self, w_in, target): 54 | 55 | w_emb = self.word_embed(w_in) 56 | 57 | w_emb = self.drop(w_emb) 58 | 59 | out = self.rnn(w_emb).contiguous().view(-1, self.rnn_output) 60 | 61 | if self.add_proj: 62 | out = self.drop(self.relu(self.project(out))) 63 | # out = self.drop(self.project(out)) 64 | 65 | out = self.soft_max(out, target) 66 | 67 | return out 68 | 69 | def embed_forward(self, w_in, target): 70 | 71 | w_emb = self.word_embed(w_in) 72 | 73 | w_emb = self.drop(w_emb) 74 | 75 | out = self.rnn(w_emb).contiguous().view(-1, self.rnn_output) 76 | 77 | if self.add_proj: 78 | out = self.drop(self.relu(self.project(out))) 79 | # out = self.drop(self.project(out)) 80 | 81 | out = self.soft_max(out, target) 82 | 83 | return out 84 | 85 | def log_prob(self, w_in): 86 | 87 | w_emb = self.word_embed(w_in) 88 | 89 | out = self.rnn(w_emb).contiguous().view(-1, self.rnn_output) 90 | 91 | if self.add_proj: 92 | out = self.relu(self.project(out)) 93 | 94 | out = self.soft_max.log_prob(out) 95 | 96 | return out -------------------------------------------------------------------------------- /language-model/model_word_ada/adaptive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | 5 | from math import sqrt 6 | 7 | class AdaptiveSoftmax(nn.Module): 8 | def __init__(self, input_size, cutoff): 
9 | super().__init__() 10 | 11 | self.input_size = input_size 12 | self.cutoff = cutoff 13 | self.output_size = cutoff[0] + len(cutoff) - 1 14 | 15 | self.head = nn.Linear(input_size, self.output_size) 16 | self.tail = nn.ModuleList() 17 | 18 | self.cross_entropy = nn.CrossEntropyLoss(size_average=False) 19 | 20 | for i in range(len(self.cutoff) - 1): 21 | # seq = nn.Sequential( 22 | # nn.Linear(input_size, input_size // 4 ** i, False), 23 | # nn.Linear(input_size // 4 ** i, cutoff[i + 1] - cutoff[i], False) 24 | # ) 25 | 26 | seq = nn.Sequential( 27 | nn.Linear(input_size, input_size // 4 ** (i + 1), False), 28 | nn.Linear(input_size // 4 ** (i + 1), cutoff[i + 1] - cutoff[i], False) 29 | ) 30 | 31 | self.tail.append(seq) 32 | 33 | def rand_ini(self): 34 | 35 | nn.init.xavier_normal(self.head.weight) 36 | 37 | for tail in self.tail: 38 | nn.init.xavier_normal(tail[0].weight) 39 | nn.init.xavier_normal(tail[1].weight) 40 | 41 | def log_prob(self, w_in): 42 | lsm = nn.LogSoftmax(dim=1).cuda() 43 | 44 | head_out = self.head(w_in) 45 | 46 | batch_size = head_out.size(0) 47 | prob = torch.zeros(batch_size, self.cutoff[-1]).cuda() 48 | 49 | lsm_head = lsm(head_out) 50 | prob.narrow(1, 0, self.output_size).add_(lsm_head.narrow(1, 0, self.output_size).data) 51 | 52 | for i in range(len(self.tail)): 53 | pos = self.cutoff[i] 54 | i_size = self.cutoff[i + 1] - pos 55 | buffer = lsm_head.narrow(1, self.cutoff[0] + i, 1) 56 | buffer = buffer.expand(batch_size, i_size) 57 | lsm_tail = lsm(self.tail[i](w_in)) 58 | prob.narrow(1, pos, i_size).copy_(buffer.data).add_(lsm_tail.data) 59 | 60 | return prob 61 | 62 | def forward(self, w_in, target): 63 | 64 | batch_size = w_in.size(0) 65 | output = 0.0 66 | 67 | first_target = target.clone() 68 | 69 | for i in range(len(self.cutoff) - 1): 70 | 71 | mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1])) 72 | 73 | if mask.sum() > 0: 74 | 75 | first_target[mask] = self.cutoff[0] + i 76 | 77 | second_target = Variable(target[mask].add(-self.cutoff[i])) 78 | second_input = w_in.index_select(0, Variable(mask.nonzero().squeeze())) 79 | 80 | second_output = self.tail[i](second_input) 81 | 82 | output += self.cross_entropy(second_output, second_target) 83 | 84 | output += self.cross_entropy(self.head(w_in), Variable(first_target)) 85 | output /= batch_size 86 | return output -------------------------------------------------------------------------------- /language-model/model_word_ada/basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import model_word_ada.utils as utils 5 | from model_word_ada.bnlstm import BNLSTM 6 | 7 | class BasicUnit(nn.Module): 8 | def __init__(self, unit, input_dim, hid_dim, droprate): 9 | super(BasicUnit, self).__init__() 10 | 11 | rnnunit_map = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU, 'bnlstm': BNLSTM} 12 | 13 | self.batch_norm = (unit == 'bnlstm') 14 | 15 | self.layer = rnnunit_map[unit](input_dim, hid_dim, 1) 16 | self.droprate = droprate 17 | 18 | self.output_dim = hid_dim 19 | 20 | self.init_hidden() 21 | 22 | def init_hidden(self): 23 | 24 | self.hidden_state = None 25 | 26 | def rand_ini(self): 27 | 28 | if not self.batch_norm: 29 | utils.init_lstm(self.layer) 30 | 31 | def forward(self, x): 32 | # set_trace() 33 | out, new_hidden = self.layer(x, self.hidden_state) 34 | 35 | self.hidden_state = utils.repackage_hidden(new_hidden) 36 | 37 | if self.droprate > 0: 38 | out = F.dropout(out, 
p=self.droprate, training=self.training) 39 | 40 | return out 41 | 42 | class BasicRNN(nn.Module): 43 | def __init__(self, layer_num, unit, emb_dim, hid_dim, droprate): 44 | super(BasicRNN, self).__init__() 45 | 46 | layer_list = [BasicUnit(unit, emb_dim, hid_dim, droprate)] + [BasicUnit(unit, hid_dim, hid_dim, droprate) for i in range(layer_num - 1)] 47 | self.layer = nn.Sequential(*layer_list) 48 | self.output_dim = layer_list[-1].output_dim 49 | 50 | self.init_hidden() 51 | 52 | def init_hidden(self): 53 | 54 | for tup in self.layer.children(): 55 | tup.init_hidden() 56 | 57 | def rand_ini(self): 58 | 59 | for tup in self.layer.children(): 60 | tup.rand_ini() 61 | 62 | def forward(self, x): 63 | return self.layer(x) -------------------------------------------------------------------------------- /language-model/model_word_ada/bnlstm.py: -------------------------------------------------------------------------------- 1 | """Implementation of batch-normalized LSTM. 2 | from: https://github.com/jihunchoi/recurrent-batch-normalization-pytorch 3 | """ 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Variable 7 | from torch.nn import functional, init 8 | 9 | 10 | class SeparatedBatchNorm1d(nn.Module): 11 | 12 | """ 13 | A batch normalization module which keeps its running mean 14 | and variance separately per timestep. 15 | """ 16 | 17 | def __init__(self, num_features, max_length, eps=1e-5, momentum=0.1, 18 | affine=True): 19 | """ 20 | Most parts are copied from 21 | torch.nn.modules.batchnorm._BatchNorm. 22 | """ 23 | 24 | super(SeparatedBatchNorm1d, self).__init__() 25 | self.num_features = num_features 26 | self.max_length = max_length 27 | self.affine = affine 28 | self.eps = eps 29 | self.momentum = momentum 30 | if self.affine: 31 | self.weight = nn.Parameter(torch.FloatTensor(num_features)) 32 | self.bias = nn.Parameter(torch.FloatTensor(num_features)) 33 | else: 34 | self.register_parameter('weight', None) 35 | self.register_parameter('bias', None) 36 | for i in range(max_length): 37 | self.register_buffer( 38 | 'running_mean_{}'.format(i), torch.zeros(num_features)) 39 | self.register_buffer( 40 | 'running_var_{}'.format(i), torch.ones(num_features)) 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | for i in range(self.max_length): 45 | running_mean_i = getattr(self, 'running_mean_{}'.format(i)) 46 | running_var_i = getattr(self, 'running_var_{}'.format(i)) 47 | running_mean_i.zero_() 48 | running_var_i.fill_(1) 49 | if self.affine: 50 | self.weight.data.uniform_() 51 | self.bias.data.zero_() 52 | 53 | def _check_input_dim(self, input_): 54 | if input_.size(1) != self.running_mean_0.nelement(): 55 | raise ValueError('got {}-feature tensor, expected {}' 56 | .format(input_.size(1), self.num_features)) 57 | 58 | def forward(self, input_, time): 59 | self._check_input_dim(input_) 60 | if time >= self.max_length: 61 | time = self.max_length - 1 62 | running_mean = getattr(self, 'running_mean_{}'.format(time)) 63 | running_var = getattr(self, 'running_var_{}'.format(time)) 64 | return functional.batch_norm( 65 | input=input_, running_mean=running_mean, running_var=running_var, 66 | weight=self.weight, bias=self.bias, training=self.training, 67 | momentum=self.momentum, eps=self.eps) 68 | 69 | def __repr__(self): 70 | return ('{name}({num_features}, eps={eps}, momentum={momentum},' 71 | ' max_length={max_length}, affine={affine})' 72 | .format(name=self.__class__.__name__, **self.__dict__)) 73 | 74 | class BNLSTMCell(nn.Module): 75 | 76 | 
"""A BN-LSTM cell.""" 77 | 78 | def __init__(self, input_size, hidden_size, max_length=784, use_bias=True): 79 | 80 | super(BNLSTMCell, self).__init__() 81 | self.input_size = input_size 82 | self.hidden_size = hidden_size 83 | self.max_length = max_length 84 | self.use_bias = use_bias 85 | self.weight_ih = nn.Parameter( 86 | torch.FloatTensor(input_size, 4 * hidden_size)) 87 | self.weight_hh = nn.Parameter( 88 | torch.FloatTensor(hidden_size, 4 * hidden_size)) 89 | if use_bias: 90 | self.bias = nn.Parameter(torch.FloatTensor(4 * hidden_size)) 91 | else: 92 | self.register_parameter('bias', None) 93 | # BN parameters 94 | self.bn_ih = SeparatedBatchNorm1d( 95 | num_features=4 * hidden_size, max_length=max_length) 96 | self.bn_hh = SeparatedBatchNorm1d( 97 | num_features=4 * hidden_size, max_length=max_length) 98 | self.bn_c = SeparatedBatchNorm1d( 99 | num_features=hidden_size, max_length=max_length) 100 | self.reset_parameters() 101 | 102 | def reset_parameters(self): 103 | """ 104 | Initialize parameters following the way proposed in the paper. 105 | """ 106 | 107 | # The input-to-hidden weight matrix is initialized orthogonally. 108 | init.orthogonal(self.weight_ih.data) 109 | # The hidden-to-hidden weight matrix is initialized as an identity 110 | # matrix. 111 | weight_hh_data = torch.eye(self.hidden_size) 112 | weight_hh_data = weight_hh_data.repeat(1, 4) 113 | self.weight_hh.data.set_(weight_hh_data) 114 | # The bias is just set to zero vectors. 115 | init.constant(self.bias.data, val=0) 116 | # Initialization of BN parameters. 117 | self.bn_ih.reset_parameters() 118 | self.bn_hh.reset_parameters() 119 | self.bn_c.reset_parameters() 120 | self.bn_ih.bias.data.fill_(0) 121 | self.bn_hh.bias.data.fill_(0) 122 | self.bn_ih.weight.data.fill_(0.1) 123 | self.bn_hh.weight.data.fill_(0.1) 124 | self.bn_c.weight.data.fill_(0.1) 125 | 126 | def forward(self, input_, hx, time): 127 | """ 128 | Args: 129 | input_: A (batch, input_size) tensor containing input 130 | features. 131 | hx: A tuple (h_0, c_0), which contains the initial hidden 132 | and cell state, where the size of both states is 133 | (batch, hidden_size). 134 | time: The current timestep value, which is used to 135 | get appropriate running statistics. 136 | 137 | Returns: 138 | h_1, c_1: Tensors containing the next hidden and cell state. 
139 | """ 140 | 141 | h_0, c_0 = hx 142 | batch_size = h_0.size(0) 143 | bias_batch = (self.bias.unsqueeze(0) 144 | .expand(batch_size, *self.bias.size())) 145 | wh = torch.mm(h_0, self.weight_hh) 146 | wi = torch.mm(input_, self.weight_ih) 147 | bn_wh = self.bn_hh(wh, time=time) 148 | bn_wi = self.bn_ih(wi, time=time) 149 | f, i, o, g = torch.split(bn_wh + bn_wi + bias_batch, 150 | split_size=self.hidden_size, dim=1) 151 | c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g) 152 | h_1 = torch.sigmoid(o) * torch.tanh(self.bn_c(c_1, time=time)) 153 | return h_1, c_1 154 | 155 | 156 | class BNLSTM(nn.Module): 157 | 158 | """A module that runs multiple steps of LSTM.""" 159 | 160 | def __init__(self, input_size, hidden_size, num_layers=1, 161 | use_bias=True, batch_first=False, dropout=0, **kwargs): 162 | super(BNLSTM, self).__init__() 163 | self.input_size = input_size 164 | self.hidden_size = hidden_size 165 | self.num_layers = num_layers 166 | self.use_bias = use_bias 167 | self.batch_first = batch_first 168 | self.dropout = dropout 169 | 170 | for layer in range(num_layers): 171 | layer_input_size = input_size if layer == 0 else hidden_size 172 | cell = BNLSTMCell(input_size=layer_input_size, 173 | hidden_size=hidden_size, 174 | **kwargs) 175 | setattr(self, 'cell_{}'.format(layer), cell) 176 | self.dropout_layer = nn.Dropout(dropout) 177 | self.reset_parameters() 178 | 179 | def get_cell(self, layer): 180 | return getattr(self, 'cell_{}'.format(layer)) 181 | 182 | def reset_parameters(self): 183 | for layer in range(self.num_layers): 184 | cell = self.get_cell(layer) 185 | cell.reset_parameters() 186 | 187 | @staticmethod 188 | def _forward_rnn(cell, input_, hx): 189 | max_time = input_.size(0) 190 | output = [] 191 | for time in range(max_time): 192 | h_next, c_next = cell(input_=input_[time], hx=hx, time=time) 193 | # # mask = (time < length).float().unsqueeze(1).expand_as(h_next) 194 | # h_next = h_next*mask + hx[0]*(1 - mask) 195 | # c_next = c_next*mask + hx[1]*(1 - mask) 196 | hx_next = (h_next, c_next) 197 | output.append(h_next) 198 | hx = hx_next 199 | output = torch.stack(output, 0) 200 | return output, hx 201 | 202 | def forward(self, input_, length=None, hx=None): 203 | if self.batch_first: 204 | input_ = input_.transpose(0, 1) 205 | max_time, batch_size, _ = input_.size() 206 | # if length is None: 207 | # length = Variable(torch.LongTensor([max_time] * batch_size)) 208 | # if input_.is_cuda: 209 | # device = input_.get_device() 210 | # length = length.cuda(device) 211 | if hx is None: 212 | hx = Variable(input_.data.new(batch_size, self.hidden_size).zero_()) 213 | hx = (hx, hx) 214 | h_n = [] 215 | c_n = [] 216 | layer_output = None 217 | for layer in range(self.num_layers): 218 | cell = self.get_cell(layer) 219 | layer_output, (layer_h_n, layer_c_n) = BNLSTM._forward_rnn( 220 | cell=cell, input_=input_, hx=hx) 221 | input_ = self.dropout_layer(layer_output) 222 | h_n.append(layer_h_n) 223 | c_n.append(layer_c_n) 224 | output = layer_output 225 | h_n = torch.stack(h_n, 0) 226 | c_n = torch.stack(c_n, 0) 227 | return output, (h_n, c_n) -------------------------------------------------------------------------------- /language-model/model_word_ada/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.autograd as autograd 5 | 6 | import sys 7 | import pickle 8 | import random 9 | from tqdm import tqdm 10 | 11 | from torch.utils.data 
import Dataset 12 | 13 | class EvalDataset(object): 14 | 15 | def __init__(self, dataset, sequence_length): 16 | super(EvalDataset, self).__init__() 17 | self.dataset = dataset 18 | 19 | self.sequence_length = sequence_length 20 | 21 | self.construct_index() 22 | 23 | def get_tqdm(self): 24 | return tqdm(self, mininterval=2, total=self.index_length, leave=False, file=sys.stdout, ncols=80) 25 | 26 | def construct_index(self): 27 | token_per_batch = self.sequence_length 28 | tot_num = len(self.dataset) - 1 29 | res_num = tot_num - tot_num % token_per_batch 30 | 31 | self.x = list(torch.unbind(torch.LongTensor(self.dataset[0:res_num]).view(-1, self.sequence_length), 0)) 32 | self.y = list(torch.unbind(torch.LongTensor(self.dataset[1:res_num+1]).view(-1, self.sequence_length), 0)) 33 | 34 | self.x.append(torch.LongTensor(self.dataset[res_num:tot_num])) 35 | self.y.append(torch.LongTensor(self.dataset[res_num+1:tot_num+1])) 36 | 37 | self.index_length = len(self.x) 38 | self.cur_idx = 0 39 | 40 | def __iter__(self): 41 | return self 42 | 43 | def __next__(self): 44 | if self.cur_idx == self.index_length: 45 | self.cur_idx = 0 46 | raise StopIteration 47 | 48 | word_t = autograd.Variable(self.x[self.cur_idx]).cuda().view(-1, 1) 49 | label_t = autograd.Variable(self.y[self.cur_idx]).cuda().view(-1, 1) 50 | 51 | self.cur_idx += 1 52 | 53 | return word_t, label_t 54 | 55 | # class SmallDataset(object): 56 | 57 | # def __init__(self, dataset, batch_size, sequence_length): 58 | # super(SmallDataset, self).__init__() 59 | # self.dataset = dataset 60 | 61 | # self.batch_size = batch_size 62 | # self.sequence_length = sequence_length 63 | 64 | # self.construct_index() 65 | 66 | # def get_tqdm(self): 67 | # return tqdm(self, mininterval=2, total=self.index_length, leave=False, file=sys.stdout) 68 | 69 | # def construct_index(self): 70 | # token_per_batch = self.batch_size * self.sequence_length 71 | # res_num = len(self.dataset) - 1 72 | # res_num = res_num - res_num % token_per_batch 73 | 74 | # self.x = torch.LongTensor(self.dataset[0:res_num]).view(self.batch_size, -1, self.sequence_length).transpose_(0, 1).contiguous() 75 | # self.y = torch.LongTensor(self.dataset[1:res_num+1]).view(self.batch_size, -1, self.sequence_length).transpose_(0, 1).contiguous() 76 | 77 | # self.index_length = self.x.size(0) 78 | # self.cur_idx = 0 79 | 80 | # def __iter__(self): 81 | # return self 82 | 83 | # def __next__(self): 84 | # if self.cur_idx == self.index_length: 85 | # self.cur_idx = 0 86 | # raise StopIteration 87 | 88 | # word_t = autograd.Variable(self.x[self.cur_idx]).cuda() 89 | # label_t = autograd.Variable(self.y[self.cur_idx]).cuda() 90 | 91 | # self.cur_idx += 1 92 | 93 | # return word_t, label_t 94 | 95 | class LargeDataset(object): 96 | 97 | def __init__(self, root, range_idx, batch_size, sequence_length): 98 | super(LargeDataset, self).__init__() 99 | self.root = root 100 | self.range_idx = range_idx 101 | self.shuffle_list = list(range(0, range_idx)) 102 | self.shuffle() 103 | 104 | self.batch_size = batch_size 105 | self.sequence_length = sequence_length 106 | self.token_per_batch = self.batch_size * self.sequence_length 107 | 108 | self.total_batch_num = -1 109 | 110 | def shuffle(self): 111 | random.shuffle(self.shuffle_list) 112 | 113 | def get_tqdm(self): 114 | self.batch_count = 0 115 | 116 | if self.total_batch_num <= 0: 117 | return tqdm(self, mininterval=2, leave=False, file=sys.stdout).__iter__() 118 | else: 119 | return tqdm(self, mininterval=2, total=self.total_batch_num, leave=False, 
file=sys.stdout, ncols=80).__iter__() 120 | 121 | def __iter__(self): 122 | self.cur_idx = 0 123 | self.file_idx = 0 124 | self.index_length = 0 125 | return self 126 | 127 | def __next__(self): 128 | if self.cur_idx >= self.index_length: 129 | self.open_next() 130 | 131 | word_t = autograd.Variable(self.x[self.cur_idx]).cuda() 132 | # label_t = autograd.Variable(self.y[self.cur_idx]).cuda() 133 | label_t = self.y[self.cur_idx].cuda() 134 | 135 | self.cur_idx += 1 136 | 137 | return word_t, label_t 138 | 139 | def open_next(self): 140 | if self.file_idx >= self.range_idx: 141 | self.total_batch_num = self.batch_count 142 | self.shuffle() 143 | raise StopIteration 144 | 145 | self.dataset = pickle.load(open(self.root + 'train_' + str( self.shuffle_list[self.file_idx])+'.pk', 'rb')) 146 | 147 | res_num = len(self.dataset) - 1 148 | res_num = res_num - res_num % self.token_per_batch 149 | 150 | self.x = torch.LongTensor(self.dataset[0:res_num]).view(self.batch_size, -1, self.sequence_length).transpose_(0, 1).transpose_(1, 2).contiguous() 151 | self.y = torch.LongTensor(self.dataset[1:res_num+1]).view(self.batch_size, -1, self.sequence_length).transpose_(0, 1).transpose_(1, 2).contiguous() 152 | 153 | self.index_length = self.x.size(0) 154 | self.cur_idx = 0 155 | 156 | self.batch_count += self.index_length 157 | self.file_idx += 1 -------------------------------------------------------------------------------- /language-model/model_word_ada/ddnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import model_word_ada.utils as utils 5 | from model_word_ada.bnlstm import BNLSTM 6 | 7 | class BasicUnit(nn.Module): 8 | def __init__(self, unit, unit_number, emb_dim, hid_dim, droprate): 9 | super(BasicUnit, self).__init__() 10 | 11 | rnnunit_map = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU, 'bnlstm': BNLSTM} 12 | 13 | self.batch_norm = (unit == 'bnlstm') 14 | 15 | self.unit_number = unit_number 16 | # self.unit_weight = nn.Parameter(torch.FloatTensor([1] * unit_number)) 17 | 18 | self.unit_list = nn.ModuleList() 19 | self.unit_list.append(rnnunit_map[unit](emb_dim, hid_dim, 1)) 20 | if unit_number > 1: 21 | self.unit_list.extend([rnnunit_map[unit](hid_dim, hid_dim, 1) for ind in range(unit_number - 1)]) 22 | 23 | self.droprate = droprate 24 | 25 | self.output_dim = emb_dim + hid_dim * unit_number 26 | 27 | self.init_hidden() 28 | 29 | def init_hidden(self): 30 | 31 | self.hidden_list = [None for i in range(self.unit_number)] 32 | 33 | def rand_ini(self): 34 | 35 | if not self.batch_norm: 36 | for cur_lstm in self.unit_list: 37 | utils.init_lstm(cur_lstm) 38 | 39 | def forward(self, x): 40 | 41 | out = 0 42 | # n_w = F.softmax(self.unit_weight, dim=0) 43 | for ind in range(self.unit_number): 44 | nout, new_hidden = self.unit_list[ind](x[ind], self.hidden_list[ind]) 45 | self.hidden_list[ind] = utils.repackage_hidden(new_hidden) 46 | out = out + nout 47 | # out = out + n_w[ind] * self.unit_number * nout 48 | 49 | if self.droprate > 0: 50 | out = F.dropout(out, p=self.droprate, training=self.training) 51 | 52 | x.append(out) 53 | 54 | return x 55 | 56 | class DDRNN(nn.Module): 57 | def __init__(self, layer_num, unit, emb_dim, hid_dim, droprate): 58 | super(DDRNN, self).__init__() 59 | 60 | layer_list = [BasicUnit(unit, i + 1, emb_dim, hid_dim, droprate) for i in range(layer_num)] 61 | self.layer = nn.Sequential(*layer_list) 62 | self.output_dim = layer_list[-1].output_dim 63 | 64 | 
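        # NOTE (illustrative, not part of the original code): each BasicUnit above
        # receives the list of all earlier outputs [emb, out_1, ..., out_i], runs
        # one RNN per entry, sums the results, and appends that sum to the list.
        # DDRNN.forward() then concatenates the embedding and every layer's output,
        # so output_dim = emb_dim + hid_dim * layer_num.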
self.init_hidden() 65 | 66 | def init_hidden(self): 67 | 68 | for tup in self.layer.children(): 69 | tup.init_hidden() 70 | 71 | def rand_ini(self): 72 | 73 | for tup in self.layer.children(): 74 | tup.rand_ini() 75 | 76 | def forward(self, x): 77 | out = self.layer([x]) 78 | return torch.cat(out, 2) -------------------------------------------------------------------------------- /language-model/model_word_ada/densenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. module:: densenet 3 | :synopsis: vanilla dense RNN 4 | 5 | .. moduleauthor:: Liyuan Liu 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import model_word_ada.utils as utils 11 | 12 | class BasicUnit(nn.Module): 13 | def __init__(self, unit, input_dim, increase_rate, droprate): 14 | super(BasicUnit, self).__init__() 15 | 16 | rnnunit_map = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU} 17 | 18 | self.unit = unit 19 | 20 | self.layer = rnnunit_map[unit](input_dim, increase_rate, 1) 21 | 22 | if 'lstm' == self.unit: 23 | utils.init_lstm(self.layer) 24 | 25 | self.droprate = droprate 26 | 27 | self.input_dim = input_dim 28 | self.increase_rate = increase_rate 29 | self.output_dim = input_dim + increase_rate 30 | 31 | self.init_hidden() 32 | 33 | def init_hidden(self): 34 | 35 | self.hidden_state = None 36 | 37 | def rand_ini(self): 38 | return 39 | 40 | def forward(self, x): 41 | 42 | if self.droprate > 0: 43 | new_x = F.dropout(x, p=self.droprate, training=self.training) 44 | else: 45 | new_x = x 46 | 47 | out, new_hidden = self.layer(new_x, self.hidden_state) 48 | 49 | self.hidden_state = utils.repackage_hidden(new_hidden) 50 | 51 | out = out.contiguous() 52 | 53 | return torch.cat([x, out], 2) 54 | 55 | class DenseRNN(nn.Module): 56 | def __init__(self, layer_num, unit, emb_dim, hid_dim, droprate): 57 | super(DenseRNN, self).__init__() 58 | 59 | self.layer_list = [BasicUnit(unit, emb_dim + i * hid_dim, hid_dim, droprate) for i in range(layer_num)] 60 | self.layer = nn.Sequential(*self.layer_list) 61 | self.output_dim = self.layer_list[-1].output_dim 62 | 63 | self.init_hidden() 64 | 65 | def init_hidden(self): 66 | 67 | for tup in self.layer_list: 68 | tup.init_hidden() 69 | 70 | def rand_ini(self): 71 | 72 | for tup in self.layer_list: 73 | tup.rand_ini() 74 | 75 | def forward(self, x): 76 | return self.layer(x) -------------------------------------------------------------------------------- /language-model/model_word_ada/ldnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. module:: densenet 3 | :synopsis: vanilla dense RNN 4 | 5 | .. 
moduleauthor:: Liyuan Liu 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import model_word_ada.utils as utils 11 | import random 12 | 13 | class BasicUnit(nn.Module): 14 | def __init__(self, unit, input_dim, increase_rate, droprate, layer_drop = 0): 15 | super(BasicUnit, self).__init__() 16 | 17 | rnnunit_map = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU} 18 | 19 | self.unit = unit 20 | 21 | self.layer = rnnunit_map[unit](input_dim, increase_rate, 1) 22 | 23 | if 'lstm' == self.unit: 24 | utils.init_lstm(self.layer) 25 | 26 | self.layer_drop = layer_drop 27 | 28 | self.droprate = droprate 29 | 30 | self.input_dim = input_dim 31 | self.increase_rate = increase_rate 32 | self.output_dim = input_dim + increase_rate 33 | 34 | self.init_hidden() 35 | 36 | def init_hidden(self): 37 | 38 | self.hidden_state = None 39 | 40 | def rand_ini(self): 41 | return 42 | 43 | def forward(self, x, p_out): 44 | 45 | if self.droprate > 0: 46 | new_x = F.dropout(x, p=self.droprate, training=self.training) 47 | else: 48 | new_x = x 49 | 50 | out, new_hidden = self.layer(new_x, self.hidden_state) 51 | 52 | self.hidden_state = utils.repackage_hidden(new_hidden) 53 | 54 | out = out.contiguous() 55 | 56 | if self.training and random.uniform(0, 1) < self.layer_drop: 57 | deep_out = torch.autograd.Variable( torch.zeros(x.size(0), x.size(1), self.increase_rate) ).cuda() 58 | else: 59 | deep_out = out 60 | 61 | o_out = torch.cat([p_out, out], 2) 62 | d_out = torch.cat([x, deep_out], 2) 63 | return d_out, o_out 64 | 65 | class LDRNN(nn.Module): 66 | def __init__(self, layer_num, unit, emb_dim, hid_dim, droprate, layer_drop): 67 | super(LDRNN, self).__init__() 68 | 69 | self.layer_list = [BasicUnit(unit, emb_dim + i * hid_dim, hid_dim, droprate, layer_drop) for i in range(layer_num)] 70 | 71 | self.layer_num = layer_num 72 | self.layer = nn.ModuleList(self.layer_list) 73 | self.output_dim = self.layer_list[-1].output_dim 74 | 75 | self.init_hidden() 76 | 77 | def init_hidden(self): 78 | 79 | for tup in self.layer_list: 80 | tup.init_hidden() 81 | 82 | def rand_ini(self): 83 | 84 | for tup in self.layer_list: 85 | tup.rand_ini() 86 | 87 | def forward(self, x): 88 | output = x 89 | for ind in range(self.layer_num): 90 | x, output = self.layer_list[ind](x, output) 91 | return output -------------------------------------------------------------------------------- /language-model/model_word_ada/radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | # from tensorboardX import SummaryWriter 6 | # writer = SummaryWriter(logdir='/cps/gadam/n_cifa/') 7 | # iter_idx = 0 8 | 9 | # from ipdb import set_trace 10 | import torch.optim 11 | 12 | class RAdam(Optimizer): 13 | 14 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 15 | weight_decay=0): 16 | defaults = dict(lr=lr, betas=betas, eps=eps, 17 | weight_decay=weight_decay) 18 | 19 | super(RAdam, self).__init__(params, defaults) 20 | 21 | def __setstate__(self, state): 22 | super(RAdam, self).__setstate__(state) 23 | 24 | def step(self, closure=None): 25 | loss = None 26 | beta2_t = None 27 | ratio = None 28 | N_sma_max = None 29 | N_sma = None 30 | 31 | if closure is not None: 32 | loss = closure() 33 | 34 | for group in self.param_groups: 35 | 36 | for p in group['params']: 37 | if p.grad is None: 38 | continue 39 | grad = p.grad.data.float() 40 | if grad.is_sparse: 41 | raise 
RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 42 | 43 | p_data_fp32 = p.data.float() 44 | 45 | state = self.state[p] 46 | 47 | if len(state) == 0: 48 | state['step'] = 0 49 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 50 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 51 | else: 52 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 53 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 54 | 55 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 56 | beta1, beta2 = group['betas'] 57 | 58 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 59 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 60 | 61 | state['step'] += 1 62 | if beta2_t is None: 63 | beta2_t = beta2 ** state['step'] 64 | N_sma_max = 2 / (1 - beta2) - 1 65 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 66 | beta1_t = 1 - beta1 ** state['step'] 67 | if N_sma >= 5: 68 | ratio = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / beta1_t 69 | 70 | if group['weight_decay'] != 0: 71 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 72 | 73 | # more conservative since it's an approximated value 74 | if N_sma >= 5: 75 | step_size = group['lr'] * ratio 76 | denom = exp_avg_sq.sqrt().add_(group['eps']) 77 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 78 | else: 79 | step_size = group['lr'] / beta1_t 80 | p_data_fp32.add_(-step_size, exp_avg) 81 | 82 | p.data.copy_(p_data_fp32) 83 | 84 | return loss 85 | 86 | 87 | class AdamW(Optimizer): 88 | 89 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 90 | weight_decay=0): 91 | defaults = dict(lr=lr, betas=betas, eps=eps, 92 | weight_decay=weight_decay) 93 | 94 | super(AdamW, self).__init__(params, defaults) 95 | 96 | def __setstate__(self, state): 97 | super(AdamW, self).__setstate__(state) 98 | 99 | def step(self, closure=None): 100 | global iter_idx 101 | iter_idx += 1 102 | grad_list = list() 103 | mom_list = list() 104 | mom_2rd_list = list() 105 | 106 | loss = None 107 | if closure is not None: 108 | loss = closure() 109 | 110 | for group in self.param_groups: 111 | 112 | for p in group['params']: 113 | if p.grad is None: 114 | continue 115 | grad = p.grad.data.float() 116 | if grad.is_sparse: 117 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 118 | 119 | p_data_fp32 = p.data.float() 120 | 121 | state = self.state[p] 122 | 123 | if len(state) == 0: 124 | state['step'] = 0 125 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 126 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 127 | else: 128 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 129 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 130 | 131 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 132 | beta1, beta2 = group['betas'] 133 | 134 | state['step'] += 1 135 | 136 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 137 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 138 | 139 | denom = exp_avg_sq.sqrt().add_(group['eps']) 140 | bias_correction1 = 1 - beta1 ** state['step'] 141 | bias_correction2 = 1 - beta2 ** state['step'] 142 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 143 | 144 | if group['weight_decay'] != 0: 145 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 146 | 147 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 148 | 149 | p.data.copy_(p_data_fp32) 150 | 151 | 
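        # NOTE (illustrative, not part of the original code): RAdam.step above only
        # applies the adaptive step when the variance of the second moment is
        # tractable. With
        #   rho_inf = 2 / (1 - beta2) - 1
        #   rho_t   = rho_inf - 2 * t * beta2**t / (1 - beta2**t)
        # the `N_sma >= 5` branch uses the rectified step size
        #   lr_t = lr * sqrt((1 - beta2**t) * (rho_t - 4) / (rho_inf - 4)
        #                    * (rho_t - 2) / rho_t * rho_inf / (rho_inf - 2))
        #               / (1 - beta1**t)
        # applied to exp_avg / (sqrt(exp_avg_sq) + eps); otherwise it falls back to
        # an SGD-with-momentum style step of lr / (1 - beta1**t) * exp_avg.
        # Also note: AdamW.step increments a global `iter_idx` that is only defined
        # in the commented-out tensorboardX block at the top of this file, so AdamW
        # raises a NameError here unless that counter is restored.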
return loss 152 | -------------------------------------------------------------------------------- /language-model/model_word_ada/resnet.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | # import torch.nn as nn 3 | # import torch.nn.functional as F 4 | # import model.utils as utils 5 | 6 | # class BasicUnit(nn.Module): 7 | # def __init__(self, unit, input_dim, hid_dim, droprate): 8 | # super(BasicUnit, self).__init__() 9 | 10 | # rnnunit_map = {'rnn': nn.RNN, 'lstm': nn.LSTM, 'gru': nn.GRU} 11 | # self.unit_number = unit_number 12 | 13 | # self.layer = rnnunit_map[unit](input_dim, hid_dim, 1) 14 | 15 | # self.droprate = droprate 16 | 17 | # self.output_dim = input_dim + hid_dim 18 | 19 | # self.init_hidden() 20 | 21 | # def init_hidden(self): 22 | 23 | # self.hidden_list = [None for i in range(unit_number)] 24 | 25 | # def rand_ini(self): 26 | 27 | # for cur_lstm in self.unit_list: 28 | # utils.init_lstm(cur_lstm) 29 | 30 | # def forward(self, x): 31 | 32 | # out, _ = self.layer(x) 33 | 34 | # if self.droprate > 0: 35 | # out = F.dropout(out, p=self.droprate, training=self.training) 36 | 37 | # return toch.cat([x, out], 2) 38 | 39 | # class DenseRNN(nn.Module): 40 | # def __init__(self, layer_num, unit, emb_dim, hid_dim, droprate): 41 | # super(DenseRNN, self).__init__() 42 | 43 | # self.layer = nn.Sequential([BasicUnit(unit, emb_dim + i * hid_dim, hid_dim, droprate) for i in range(layer_num) ]) 44 | 45 | # self.output_dim = self.layer[-1].output_dim 46 | 47 | # self.init_hidden() 48 | 49 | # def init_hidden(self): 50 | # self.layer.apply(lambda t: t.init_hidden()) 51 | 52 | # def rand_ini(self): 53 | # self.layer.apply(lambda t: t.rand_ini()) 54 | 55 | # def forward(self, x): 56 | # return self.layer(x) -------------------------------------------------------------------------------- /language-model/model_word_ada/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import json 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init 8 | 9 | from torch.autograd import Variable 10 | 11 | def repackage_hidden(h): 12 | """Wraps hidden states in new Variables, to detach them from their history.""" 13 | if type(h) == torch.Tensor: 14 | return h.data 15 | else: 16 | return tuple(repackage_hidden(v) for v in h) 17 | 18 | def to_scalar(var): 19 | """change the first element of a tensor to scalar 20 | """ 21 | return var.view(-1).data.tolist()[0] 22 | 23 | def init_embedding(input_embedding): 24 | """ 25 | Initialize embedding 26 | """ 27 | bias = np.sqrt(3.0 / input_embedding.size(1)) 28 | nn.init.uniform(input_embedding, -bias, bias) 29 | 30 | def init_linear(input_linear): 31 | """ 32 | Initialize linear transformation 33 | """ 34 | bias = np.sqrt(6.0 / (input_linear.weight.size(0) + input_linear.weight.size(1))) 35 | nn.init.uniform(input_linear.weight, -bias, bias) 36 | if input_linear.bias is not None: 37 | input_linear.bias.data.zero_() 38 | 39 | def adjust_learning_rate(optimizer, lr): 40 | """ 41 | shrink learning rate for pytorch 42 | """ 43 | for param_group in optimizer.param_groups: 44 | param_group['lr'] = lr 45 | 46 | def init_lstm(input_lstm): 47 | """ 48 | Initialize lstm 49 | """ 50 | for ind in range(0, input_lstm.num_layers): 51 | weight = eval('input_lstm.weight_ih_l'+str(ind)) 52 | bias = np.sqrt(6.0 / (weight.size(0)/4 + weight.size(1))) 53 | nn.init.uniform(weight, -bias, bias) 54 | weight = 
eval('input_lstm.weight_hh_l'+str(ind)) 55 | bias = np.sqrt(6.0 / (weight.size(0)/4 + weight.size(1))) 56 | nn.init.uniform(weight, -bias, bias) 57 | 58 | if input_lstm.bias: 59 | for ind in range(0, input_lstm.num_layers): 60 | weight = eval('input_lstm.bias_ih_l'+str(ind)) 61 | weight.data.zero_() 62 | weight.data[input_lstm.hidden_size: 2 * input_lstm.hidden_size] = 1 63 | weight = eval('input_lstm.bias_hh_l'+str(ind)) 64 | weight.data.zero_() 65 | weight.data[input_lstm.hidden_size: 2 * input_lstm.hidden_size] = 1 66 | -------------------------------------------------------------------------------- /language-model/pre_word_ada/encode_data2folder.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import os 4 | import codecs 5 | import random 6 | import numpy as np 7 | 8 | from tqdm import tqdm 9 | 10 | import itertools 11 | import functools 12 | 13 | def encode_dataset(input_folder, w_map, reverse): 14 | 15 | w_eof = w_map['\n'] 16 | w_unk = w_map[''] 17 | 18 | list_dirs = os.walk(input_folder) 19 | 20 | lines = list() 21 | 22 | for root, dirs, files in list_dirs: 23 | for file in tqdm(files): 24 | with codecs.open(os.path.join(root, file), 'r', 'utf-8') as fin: 25 | lines = lines + list(filter(lambda t: t and not t.isspace(), fin.readlines())) 26 | 27 | dataset = list() 28 | for line in lines: 29 | dataset += list(map(lambda t: w_map.get(t, w_unk), line.split())) + [w_eof] 30 | 31 | if reverse: 32 | dataset = dataset[::-1] 33 | 34 | return dataset 35 | 36 | def encode_dataset2file(input_folder, output_folder, w_map, reverse): 37 | 38 | w_eof = w_map['\n'] 39 | w_unk = w_map[''] 40 | 41 | list_dirs = os.walk(input_folder) 42 | 43 | range_ind = 0 44 | 45 | for root, dirs, files in list_dirs: 46 | for file in tqdm(files): 47 | with codecs.open(os.path.join(root, file), 'r', 'utf-8') as fin: 48 | lines = list(filter(lambda t: t and not t.isspace(), fin.readlines())) 49 | 50 | dataset = list() 51 | for line in lines: 52 | dataset += list(map(lambda t: w_map.get(t, w_unk), line.split())) + [w_eof] 53 | 54 | if reverse: 55 | dataset = dataset[::-1] 56 | 57 | with open(output_folder+'train_'+ str(range_ind) + '.pk', 'wb') as f: 58 | pickle.dump(dataset, f) 59 | 60 | range_ind += 1 61 | 62 | return range_ind 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--train_folder', default="/data/billionwords/1-billion-word-language-modeling-benchmark/training-monolingual.tokenized.shuffled") 67 | parser.add_argument('--test_folder', default="/data/billionwords/1-billion-word-language-modeling-benchmark/heldout-monolingual.tokenized.shuffled") 68 | parser.add_argument('--input_map', default="/data/billionwords/1b_map.pk") 69 | parser.add_argument('--output_folder', default="/data/billionwords/one_billion/") 70 | parser.add_argument('--threshold', type=int, default=3) 71 | parser.add_argument('--unk', default='') 72 | parser.add_argument('--reverse', action='store_true') 73 | args = parser.parse_args() 74 | 75 | with open(args.input_map, 'rb') as f: 76 | w_count = pickle.load(f) 77 | 78 | unk_count = sum([v for k, v in w_count.items() if v <= args.threshold]) 79 | w_list = [(k, v) for k, v in w_count.items() if v > args.threshold] 80 | w_list.append(('', unk_count)) 81 | w_list.sort(key=lambda t: t[1], reverse=True) 82 | w_map = {kv[0]:v for v, kv in enumerate(w_list)} 83 | 84 | range_ind = encode_dataset2file(args.train_folder, args.output_folder, w_map, args.reverse) 85 | 86 
| test_dataset = encode_dataset(args.test_folder, w_map, args.reverse) 87 | 88 | with open(args.output_folder+'test.pk', 'wb') as f: 89 | pickle.dump({'w_map': w_map, 'test_data':test_dataset, 'range' : range_ind}, f) 90 | -------------------------------------------------------------------------------- /language-model/pre_word_ada/gene_map.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import os 4 | import random 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | import itertools 9 | import functools 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--input_folder', default="/data/billionwords/1-billion-word-language-modeling-benchmark/training-monolingual.tokenized.shuffled") 14 | parser.add_argument('--output_map', default="/data/billionwords/1b_map.pk") 15 | args = parser.parse_args() 16 | 17 | w_count = {'\n':0} 18 | 19 | list_dirs = os.walk(args.input_folder) 20 | 21 | for root, dirs, files in list_dirs: 22 | for file in tqdm(files): 23 | with open(os.path.join(root, file)) as fin: 24 | for line in fin: 25 | if not line or line.isspace(): 26 | continue 27 | line = line.split() 28 | for tup in line: 29 | w_count[tup] = w_count.get(tup, 0) + 1 30 | w_count['\n'] += 1 31 | 32 | with open(args.output_map, 'wb') as f: 33 | pickle.dump(w_count, f) -------------------------------------------------------------------------------- /language-model/recipes.md: -------------------------------------------------------------------------------- 1 | # Pre-process 2 | 3 | ``` 4 | python pre_word_ada/gene_map.py --input_folder /data/billionwords/1-billion-word-language-modeling-benchmark/training-monolingual.tokenized.shuffled --output_map /data/billionwords/1b_map.pk 5 | 6 | python pre_word_ada/encode_data2folder.py --train_folder /data/billionwords/1-billion-word-language-modeling-benchmark/training-monolingual.tokenized.shuffled --test_folder /data/billionwords/1-billion-word-language-modeling-benchmark/heldout-monolingual.tokenized.shuffled --input_map /data/billionwords/1b_map.pk --output_folder /data/billionwords/one_billion/ 7 | ``` 8 | 9 | # Training 10 | 11 | ## Adam 12 | ``` 13 | python train_1bw.py --dataset_folder /data/billionwords/one_billion/ --lr 0.001 --checkpath ./cps/gadam/ --model_name adam --update Adam 14 | ``` 15 | 16 | ## RAdam 17 | ``` 18 | python train_1bw.py --dataset_folder /data/billionwords/one_billion/ --lr 0.001 --checkpath ./cps/gadam/ --model_name radam --update RAdam 19 | ``` 20 | -------------------------------------------------------------------------------- /language-model/train_1bw.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import datetime 3 | import time 4 | import torch 5 | import torch.autograd as autograd 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import codecs 9 | import pickle 10 | import math 11 | 12 | from model_word_ada.LM import LM 13 | from model_word_ada.basic import BasicRNN 14 | from model_word_ada.ddnet import DDRNN 15 | from model_word_ada.radam import RAdam 16 | from model_word_ada.ldnet import LDRNN 17 | from model_word_ada.densenet import DenseRNN 18 | from model_word_ada.dataset import LargeDataset, EvalDataset 19 | from model_word_ada.adaptive import AdaptiveSoftmax 20 | import model_word_ada.utils as utils 21 | 22 | # from tensorboardX import SummaryWriter 23 | # writer = 
SummaryWriter(logdir='./cps/gadam/log_1bw_full/') 24 | 25 | import argparse 26 | import json 27 | import os 28 | import sys 29 | import itertools 30 | import functools 31 | 32 | def evaluate(data_loader, lm_model, criterion, limited = 76800): 33 | print('evaluating') 34 | lm_model.eval() 35 | 36 | iterator = data_loader.get_tqdm() 37 | 38 | lm_model.init_hidden() 39 | total_loss = 0 40 | total_len = 0 41 | for word_t, label_t in iterator: 42 | label_t = label_t.view(-1) 43 | tmp_len = label_t.size(0) 44 | output = lm_model.log_prob(word_t) 45 | total_loss += tmp_len * utils.to_scalar(criterion(autograd.Variable(output), label_t)) 46 | total_len += tmp_len 47 | 48 | if limited >=0 and total_len > limited: 49 | break 50 | 51 | ppl = math.exp(total_loss / total_len) 52 | print('PPL: ' + str(ppl)) 53 | 54 | return ppl 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument('--dataset_folder', default='/data/billionwords/one_billion/') 59 | parser.add_argument('--load_checkpoint', default='') 60 | parser.add_argument('--batch_size', type=int, default=128) 61 | parser.add_argument('--sequence_length', type=int, default=20) 62 | parser.add_argument('--hid_dim', type=int, default=2048) 63 | parser.add_argument('--word_dim', type=int, default=300) 64 | parser.add_argument('--label_dim', type=int, default=-1) 65 | parser.add_argument('--layer_num', type=int, default=2) 66 | parser.add_argument('--droprate', type=float, default=0.1) 67 | parser.add_argument('--add_relu', action='store_true') 68 | parser.add_argument('--layer_drop', type=float, default=0.5) 69 | parser.add_argument('--gpu', type=int, default=1) 70 | parser.add_argument('--epoch', type=int, default=14) 71 | parser.add_argument('--clip', type=float, default=5) 72 | parser.add_argument('--update', choices=['Adam', 'Adagrad', 'Adadelta', 'RAdam', 'SGD'], default='Adam') 73 | parser.add_argument('--rnn_layer', choices=['Basic', 'DDNet', 'DenseNet', 'LDNet'], default='Basic') 74 | parser.add_argument('--rnn_unit', choices=['gru', 'lstm', 'rnn', 'bnlstm'], default='lstm') 75 | parser.add_argument('--lr', type=float, default=-1) 76 | parser.add_argument('--lr_decay', type=lambda t: [int(tup) for tup in t.split(',')], default=[8]) 77 | parser.add_argument('--cut_off', nargs='+', default=[4000,40000,200000]) 78 | parser.add_argument('--interval', type=int, default=100) 79 | parser.add_argument('--check_interval', type=int, default=4000) 80 | parser.add_argument('--checkpath', default='./cps/gadam/') 81 | parser.add_argument('--model_name', default='adam') 82 | args = parser.parse_args() 83 | 84 | if args.gpu >= 0: 85 | torch.cuda.set_device(args.gpu) 86 | 87 | print('loading dataset') 88 | dataset = pickle.load(open(args.dataset_folder + 'test.pk', 'rb')) 89 | w_map, test_data, range_idx = dataset['w_map'], dataset['test_data'], dataset['range'] 90 | 91 | cut_off = args.cut_off + [len(w_map) + 1] 92 | 93 | train_loader = LargeDataset(args.dataset_folder, range_idx, args.batch_size, args.sequence_length) 94 | test_loader = EvalDataset(test_data, args.batch_size) 95 | 96 | print('building model') 97 | 98 | rnn_map = {'Basic': BasicRNN, 'DDNet': DDRNN, 'DenseNet': DenseRNN, 'LDNet': functools.partial(LDRNN, layer_drop = args.layer_drop)} 99 | rnn_layer = rnn_map[args.rnn_layer](args.layer_num, args.rnn_unit, args.word_dim, args.hid_dim, args.droprate) 100 | 101 | if args.label_dim > 0: 102 | soft_max = AdaptiveSoftmax(args.label_dim, cut_off) 103 | else: 104 | soft_max = 
AdaptiveSoftmax(rnn_layer.output_dim, cut_off) 105 | 106 | lm_model = LM(rnn_layer, soft_max, len(w_map), args.word_dim, args.droprate, label_dim = args.label_dim, add_relu=args.add_relu) 107 | lm_model.rand_ini() 108 | # lm_model.cuda() 109 | 110 | optim_map = {'Adam' : optim.Adam, 'Adagrad': optim.Adagrad, 'Adadelta': optim.Adadelta, 'RAdam': RAdam, 'SGD': functools.partial(optim.SGD, momentum=0.9)} 111 | if args.lr > 0: 112 | optimizer=optim_map[args.update](lm_model.parameters(), lr=args.lr) 113 | else: 114 | optimizer=optim_map[args.update](lm_model.parameters()) 115 | 116 | if args.load_checkpoint: 117 | if os.path.isfile(args.load_checkpoint): 118 | print("loading checkpoint: '{}'".format(args.load_checkpoint)) 119 | checkpoint_file = torch.load(args.load_checkpoint, map_location=lambda storage, loc: storage) 120 | lm_model.load_state_dict(checkpoint_file['lm_model'], False) 121 | optimizer.load_state_dict(checkpoint_file['opt'], False) 122 | else: 123 | print("no checkpoint found at: '{}'".format(args.load_checkpoint)) 124 | 125 | test_lm = nn.NLLLoss() 126 | 127 | test_lm.cuda() 128 | lm_model.cuda() 129 | 130 | batch_index = 0 131 | epoch_loss = 0 132 | full_epoch_loss = 0 133 | best_train_ppl = float('inf') 134 | cur_lr = args.lr 135 | 136 | try: 137 | for indexs in range(args.epoch): 138 | 139 | print('#' * 89) 140 | print('Start: {}'.format(indexs)) 141 | 142 | iterator = train_loader.get_tqdm() 143 | full_epoch_loss = 0 144 | 145 | lm_model.train() 146 | 147 | for word_t, label_t in iterator: 148 | 149 | if 1 == train_loader.cur_idx: 150 | lm_model.init_hidden() 151 | 152 | label_t = label_t.view(-1) 153 | 154 | lm_model.zero_grad() 155 | loss = lm_model(word_t, label_t) 156 | 157 | loss.backward() 158 | torch.nn.utils.clip_grad_norm(lm_model.parameters(), args.clip) 159 | optimizer.step() 160 | 161 | batch_index += 1 162 | 163 | if 0 == batch_index % args.interval: 164 | s_loss = utils.to_scalar(loss) 165 | # writer.add_scalars('loss_tracking/train_loss', {args.model_name:s_loss}, batch_index) 166 | 167 | epoch_loss += utils.to_scalar(loss) 168 | full_epoch_loss += utils.to_scalar(loss) 169 | if 0 == batch_index % args.check_interval: 170 | epoch_ppl = math.exp(epoch_loss / args.check_interval) 171 | # writer.add_scalars('loss_tracking/train_ppl', {args.model_name: epoch_ppl}, batch_index) 172 | print('epoch_ppl: {} lr: {} @ batch_index: {}'.format(epoch_ppl, cur_lr, batch_index)) 173 | epoch_loss = 0 174 | 175 | if indexs in args.lr_decay and cur_lr > 0: 176 | cur_lr *= 0.1 177 | print('adjust_learning_rate...') 178 | utils.adjust_learning_rate(optimizer, cur_lr) 179 | 180 | test_ppl = evaluate(test_loader, lm_model, test_lm, -1) 181 | 182 | # writer.add_scalars('loss_tracking/test_ppl', {args.model_name: test_ppl}, indexs) 183 | print('test_ppl: {} @ index: {}'.format(test_ppl, indexs)) 184 | 185 | torch.save({'lm_model': lm_model.state_dict(), 'opt':optimizer.state_dict()}, args.checkpath+args.model_name+'.model') 186 | 187 | except KeyboardInterrupt: 188 | 189 | print('Exiting from training early') 190 | test_ppl = evaluate(test_loader, lm_model, test_lm, -1) 191 | # writer.add_scalars('loss_tracking/test_ppl', {args.model_name: test_ppl}, args.epoch) 192 | 193 | torch.save({'lm_model': lm_model.state_dict(), 'opt':optimizer.state_dict()}, args.checkpath+args.model_name+'.model') 194 | 195 | # writer.close() 196 | -------------------------------------------------------------------------------- /nmt/README.md: 
-------------------------------------------------------------------------------- 1 | # NMT 2 | 3 | This folder is based on the fairseq package ([repo](https://github.com/pytorch/fairseq)). For more details about this code base, please refer to the original repo. 4 | A training [recipe](/nmt/recipes.md) is provided for nmt experiments. 5 | 6 | -------------------------------------------------------------------------------- /nmt/average_checkpoints.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the LICENSE file in 6 | # the root directory of this source tree. An additional grant of patent rights 7 | # can be found in the PATENTS file in the same directory. 8 | 9 | import argparse 10 | import collections 11 | import torch 12 | import os 13 | import re 14 | 15 | 16 | def average_checkpoints(inputs): 17 | """Loads checkpoints from inputs and returns a model with averaged weights. 18 | 19 | Args: 20 | inputs: An iterable of string paths of checkpoints to load from. 21 | 22 | Returns: 23 | A dict of string keys mapping to various values. The 'model' key 24 | from the returned dict should correspond to an OrderedDict mapping 25 | string parameter names to torch Tensors. 26 | """ 27 | params_dict = collections.OrderedDict() 28 | params_keys = None 29 | new_state = None 30 | num_models = len(inputs) 31 | 32 | for f in inputs: 33 | state = torch.load( 34 | f, 35 | map_location=( 36 | lambda s, _: torch.serialization.default_restore_location(s, 'cpu') 37 | ), 38 | ) 39 | # Copies over the settings from the first checkpoint 40 | if new_state is None: 41 | new_state = state 42 | 43 | model_params = state['model'] 44 | 45 | model_params_keys = list(model_params.keys()) 46 | if params_keys is None: 47 | params_keys = model_params_keys 48 | elif params_keys != model_params_keys: 49 | raise KeyError( 50 | 'For checkpoint {}, expected list of params: {}, ' 51 | 'but found: {}'.format(f, params_keys, model_params_keys) 52 | ) 53 | 54 | for k in params_keys: 55 | p = model_params[k] 56 | if isinstance(p, torch.HalfTensor): 57 | p = p.float() 58 | if k not in params_dict: 59 | params_dict[k] = p.clone() 60 | # NOTE: clone() is needed in case of p is a shared parameter 61 | else: 62 | params_dict[k] += p 63 | 64 | averaged_params = collections.OrderedDict() 65 | for k, v in params_dict.items(): 66 | averaged_params[k] = v 67 | averaged_params[k].div_(num_models) 68 | new_state['model'] = averaged_params 69 | return new_state 70 | 71 | 72 | def last_n_checkpoints(paths, n, update_based, upper_bound=None): 73 | assert len(paths) == 1 74 | path = paths[0] 75 | if update_based: 76 | pt_regexp = re.compile(r'checkpoint_\d+_(\d+)\.pt') 77 | else: 78 | pt_regexp = re.compile(r'checkpoint(\d+)\.pt') 79 | files = os.listdir(path) 80 | 81 | entries = [] 82 | for f in files: 83 | m = pt_regexp.fullmatch(f) 84 | if m is not None: 85 | sort_key = int(m.group(1)) 86 | if upper_bound is None or sort_key <= upper_bound: 87 | entries.append((sort_key, m.group(0))) 88 | if len(entries) < n: 89 | raise Exception('Found {} checkpoint files but need at least {}', len(entries), n) 90 | return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)[:n]] 91 | 92 | 93 | def main(): 94 | parser = argparse.ArgumentParser( 95 | description='Tool to average the params of input checkpoints to ' 96 | 'produce a new checkpoint', 
97 | ) 98 | # fmt: off 99 | parser.add_argument('--inputs', required=True, nargs='+', 100 | help='Input checkpoint file paths.') 101 | parser.add_argument('--output', required=True, metavar='FILE', 102 | help='Write the new checkpoint containing the averaged weights to this path.') 103 | num_group = parser.add_mutually_exclusive_group() 104 | num_group.add_argument('--num-epoch-checkpoints', type=int, 105 | help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, ' 106 | 'and average last this many of them.') 107 | num_group.add_argument('--num-update-checkpoints', type=int, 108 | help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, ' 109 | 'and average last this many of them.') 110 | parser.add_argument('--checkpoint-upper-bound', type=int, 111 | help='when using --num-epoch-checkpoints, this will set an upper bound on which checkpoint to use, ' 112 | 'e.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged.') 113 | # fmt: on 114 | args = parser.parse_args() 115 | print(args) 116 | 117 | num = None 118 | is_update_based = False 119 | if args.num_update_checkpoints is not None: 120 | num = args.num_update_checkpoints 121 | is_update_based = True 122 | elif args.num_epoch_checkpoints is not None: 123 | num = args.num_epoch_checkpoints 124 | 125 | assert args.checkpoint_upper_bound is None or args.num_epoch_checkpoints is not None, \ 126 | '--checkpoint-upper-bound requires --num-epoch-checkpoints' 127 | assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \ 128 | 'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints' 129 | 130 | if num is not None: 131 | args.inputs = last_n_checkpoints( 132 | args.inputs, num, is_update_based, upper_bound=args.checkpoint_upper_bound, 133 | ) 134 | print('averaging checkpoints: ', args.inputs) 135 | 136 | new_state = average_checkpoints(args.inputs) 137 | torch.save(new_state, args.output) 138 | print('Finished writing averaged checkpoint to {}.'.format(args.output)) 139 | 140 | 141 | if __name__ == '__main__': 142 | main() 143 | -------------------------------------------------------------------------------- /nmt/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "Model path" $SAVEDIR 3 | GPUDEV=${2:-0} 4 | SAVEDIR=${1} 5 | MODELDIR=$SAVEDIR/model_ed.pt 6 | if [ -f $MODELDIR ]; then 7 | echo $MODELDIR "already exists" 8 | else 9 | echo "Start averaging model" 10 | python average_checkpoints.py --inputs $SAVEDIR --num-epoch-checkpoints 10 --output $MODELDIR | grep 'Finish' 11 | echo "End averaging model" 12 | fi 13 | 14 | CUDA_VISIBLE_DEVICES=$GPUDEV fairseq-generate data-bin/iwslt14.tokenized.de-en \ 15 | --path $MODELDIR \ 16 | --batch-size 128 --beam 5 --remove-bpe \ 17 | --user-dir ./my_module 2>&1 | grep BLEU4 18 | 19 | # CUDA_VISIBLE_DEVICES=$GPUDEV fairseq-generate data-bin/iwslt14.tokenized.en-de \ 20 | # --path $MODELDIR \ 21 | # --batch-size 128 --beam 5 --remove-bpe \ 22 | # --user-dir ./my_module 2>&1 | grep BLEU4 23 | -------------------------------------------------------------------------------- /nmt/my_module/__init__.py: -------------------------------------------------------------------------------- 1 | from .radam import * 2 | from .adam2 import * 3 | from .linear_schedule import * 4 | from .poly_schedule import * 5 | from .novograd import * 6 | 
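A note on the checkpoint averaging used above: `average_checkpoints.py` boils down to an element-wise mean of the `'model'` state dicts of the selected checkpoints, which `eval.sh` then writes to `model_ed.pt`. A minimal standalone sketch of that computation follows; the helper name, variable names, and paths are placeholders and not part of this repo, and the real script also keeps the remaining fields of the first checkpoint.

```
import collections
import torch

def average_state_dicts(state_dicts):
    """Element-wise mean of model state dicts that share the same keys."""
    avg = collections.OrderedDict()
    for key in state_dicts[0]:
        # stack the per-checkpoint tensors for this parameter and average them
        stacked = torch.stack([sd[key].float() for sd in state_dicts], dim=0)
        avg[key] = stacked.mean(dim=0)
    return avg

# usage sketch, mirroring eval.sh (which averages the last 10 epoch checkpoints):
# states = [torch.load(p, map_location='cpu')['model'] for p in checkpoint_paths]
# torch.save({'model': average_state_dicts(states)}, 'model_ed.pt')
```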
-------------------------------------------------------------------------------- /nmt/my_module/adam2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | import types 10 | 11 | import torch 12 | import torch.optim 13 | 14 | from fairseq.optim import FairseqOptimizer, register_optimizer 15 | 16 | from tensorboardX import SummaryWriter 17 | # writer = SummaryWriter(logdir='./log/ada/') 18 | # # writer = SummaryWriter(logdir='./log/wmt/') 19 | 20 | iter_idx = 0 21 | 22 | @register_optimizer('adam2') 23 | class FairseqAdam2(FairseqOptimizer): 24 | 25 | def __init__(self, args, params): 26 | super().__init__(args, params) 27 | 28 | self._optimizer = Adam2(params, **self.optimizer_config) 29 | self._optimizer.name = args.tb_tag + '_' + self._optimizer.name 30 | 31 | @staticmethod 32 | def add_args(parser): 33 | """Add optimizer-specific arguments to the parser.""" 34 | # fmt: off 35 | parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B', 36 | help='betas for Adam optimizer') 37 | parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D', 38 | help='epsilon for Adam optimizer') 39 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 40 | help='weight decay') 41 | parser.add_argument('--tb-tag', default="", type=str, 42 | help='tb tag') 43 | parser.add_argument('--amsgrad', action='store_true') 44 | parser.add_argument('--adam-freeze', default=5000, type=float) 45 | parser.add_argument('--adam-no-correction1', action='store_true') 46 | # fmt: on 47 | 48 | @property 49 | def optimizer_config(self): 50 | """ 51 | Return a kwarg dictionary that will be used to override optimizer 52 | args stored in checkpoints. This allows us to load a checkpoint and 53 | resume training using a different set of optimizer args, e.g., with a 54 | different learning rate. 55 | """ 56 | return { 57 | 'lr': self.args.lr[0], 58 | 'betas': eval(self.args.adam_betas), 59 | 'eps': self.args.adam_eps, 60 | 'weight_decay': self.args.weight_decay, 61 | 'amsgrad': self.args.amsgrad, 62 | 'adam_freeze': self.args.adam_freeze, 63 | 'adam_no_correction1': self.args.adam_no_correction1, 64 | } 65 | 66 | 67 | class Adam2(torch.optim.Optimizer): 68 | 69 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 70 | weight_decay=0, amsgrad=False, adam_freeze=5000, adam_no_correction1=False): 71 | defaults = dict(lr=lr, betas=betas, eps=eps, 72 | weight_decay=weight_decay, amsgrad=amsgrad, adam_freeze=adam_freeze, adam_no_correction1=adam_no_correction1) 73 | self.name = '{}_{}_{}'.format(lr, betas[0], betas[1]) 74 | super(Adam2, self).__init__(params, defaults) 75 | 76 | @property 77 | def supports_memory_efficient_fp16(self): 78 | return True 79 | 80 | def step(self, closure=None): 81 | """Performs a single optimization step. 82 | 83 | Arguments: 84 | closure (callable, optional): A closure that reevaluates the model 85 | and returns the loss. 
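        Illustrative note (not in the original docstring): when the optimizer
        name contains 'adam_1k' (the name is prefixed with --tb-tag), the
        parameters and the first-moment estimate are left untouched until
        state['step'] exceeds --adam-freeze (default 5000); only exp_avg_sq
        keeps accumulating during that window. Afterwards the update is
        essentially standard Adam (modulo --adam-no-correction1), with weight
        decay applied directly to the weights.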
86 | """ 87 | global iter_idx 88 | iter_idx += 1 89 | grad_list = list() 90 | mom_list = list() 91 | mom_2rd_list = list() 92 | 93 | loss = None 94 | if closure is not None: 95 | loss = closure() 96 | 97 | for group in self.param_groups: 98 | 99 | # if 'adam_1k' in self.name: 100 | # writer_iter = iter_idx - group['adam_freeze'] 101 | # else: 102 | # writer_iter = iter_idx 103 | 104 | for p in group['params']: 105 | if p.grad is None: 106 | continue 107 | grad = p.grad.data.float() 108 | if grad.is_sparse: 109 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 110 | amsgrad = group['amsgrad'] 111 | 112 | p_data_fp32 = p.data.float() 113 | 114 | state = self.state[p] 115 | 116 | if len(state) == 0: 117 | state['step'] = 0 118 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 119 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 120 | if amsgrad: 121 | state['max_exp_avg_sq'] = torch.zeros_like(p_data_fp32) 122 | else: 123 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 124 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 125 | if amsgrad: 126 | state['max_exp_avg_sq'] = state['max_exp_avg_sq'].type_as(p_data_fp32) 127 | 128 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 129 | if amsgrad: 130 | max_exp_avg_sq = state['max_exp_avg_sq'] 131 | beta1, beta2 = group['betas'] 132 | 133 | state['step'] += 1 134 | exp_avg_sq.mul_(beta2).addcmul_(1-beta2, grad, grad) 135 | 136 | denom = exp_avg_sq.sqrt().add_(group['eps']) 137 | 138 | if group['adam_no_correction1']: 139 | bias_correction1 = 1 140 | else: 141 | bias_correction1 = (1 - beta1 ** state['step']) 142 | 143 | bias_correction2 = (1 - beta2 ** state['step'])**0.5 144 | step_size = group['lr'] * bias_correction2 / bias_correction1 145 | 146 | 147 | if 'adam_1k' not in self.name or state['step'] > group['adam_freeze']: 148 | if group['weight_decay'] != 0: 149 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 150 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 151 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 152 | p.data.copy_(p_data_fp32) 153 | 154 | # if writer_iter > 0 and writer_iter % 300 == 0 or writer_iter in [1, 5, 10, 25, 50, 75, 100, 150, 200]: 155 | # grad_list.extend( grad.abs().add_(1e-9).log().view(-1).tolist() ) 156 | # mom_list.extend( exp_avg.abs().add_(1e-9).log().view(-1).tolist() ) 157 | # mom_2rd_list.extend( exp_avg_sq.abs().add_(1e-9).log().view(-1).tolist() ) 158 | 159 | # if writer_iter > 0 and writer_iter % 300 == 0 or writer_iter in [1, 5, 10, 25, 50, 75, 100, 150, 200]: 160 | # writer.add_histogram('grad/{}'.format(self.name), grad_list, writer_iter) 161 | # writer.add_histogram('mom/{}'.format(self.name), mom_list, writer_iter) 162 | # writer.add_histogram('mom_sq/{}'.format(self.name), mom_2rd_list, writer_iter) 163 | 164 | return loss 165 | -------------------------------------------------------------------------------- /nmt/my_module/linear_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
7 | 8 | from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler 9 | 10 | 11 | @register_lr_scheduler('linear') 12 | class LinearSchedule(FairseqLRScheduler): 13 | """Decay the LR based on the inverse square root of the update number. 14 | 15 | We also support a warmup phase where we linearly increase the learning rate 16 | from some initial learning rate (``--warmup-init-lr``) until the configured 17 | learning rate (``--lr``). Thereafter we decay proportional to the number of 18 | updates, with a decay factor set to align with the configured learning rate. 19 | 20 | During warmup:: 21 | 22 | lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) 23 | lr = lrs[update_num] 24 | 25 | After warmup:: 26 | 27 | decay_factor = args.lr * sqrt(args.warmup_updates) 28 | lr = decay_factor / sqrt(update_num) 29 | """ 30 | 31 | def __init__(self, args, optimizer): 32 | super().__init__(args, optimizer) 33 | if len(args.lr) > 1: 34 | raise ValueError( 35 | 'Cannot use a fixed learning rate schedule with inverse_sqrt.' 36 | ' Consider --lr-scheduler=fixed instead.' 37 | ) 38 | warmup_end_lr = args.lr[0] 39 | if args.warmup_init_lr < 0: 40 | args.warmup_init_lr = warmup_end_lr 41 | 42 | # linearly warmup for the first args.warmup_updates 43 | self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates 44 | 45 | # then, decay prop. to the inverse square root of the update number 46 | # self.warmup_end_lr = warmup_end_lr * args.warmup_updates**0.5 47 | self.min_lr = args.min_lr 48 | self.warmup_end_lr = warmup_end_lr - self.min_lr 49 | 50 | # initial learning rate 51 | self.lr = args.warmup_init_lr 52 | self.optimizer.set_lr(self.lr) 53 | 54 | self.max_update = args.max_update - args.warmup_updates 55 | 56 | @staticmethod 57 | def add_args(parser): 58 | """Add arguments to the parser for this LR scheduler.""" 59 | # fmt: off 60 | parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N', 61 | help='warmup the learning rate linearly for the first N updates') 62 | parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', 63 | help='initial learning rate during warmup phase; default is args.lr') 64 | # fmt: on 65 | 66 | def step(self, epoch, val_loss=None): 67 | """Update the learning rate at the end of the given epoch.""" 68 | super().step(epoch, val_loss) 69 | # we don't change the learning rate at epoch boundaries 70 | return self.optimizer.get_lr() 71 | 72 | def step_update(self, num_updates): 73 | """Update the learning rate after each update.""" 74 | if num_updates < self.args.warmup_updates: 75 | self.lr = self.args.warmup_init_lr + num_updates*self.lr_step 76 | else: 77 | self.lr = self.warmup_end_lr * (1 - (num_updates - self.args.warmup_updates) / self.max_update) + self.min_lr 78 | self.optimizer.set_lr(self.lr) 79 | return self.lr 80 | -------------------------------------------------------------------------------- /nmt/my_module/novograd.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 
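# NOTE (illustrative, not part of the original code): LinearSchedule above ramps
# the LR from --warmup-init-lr to --lr over --warmup-updates steps and then decays
# it linearly, reaching --min-lr exactly at --max-update:
#   lr(t) = (lr_peak - min_lr) * (1 - (t - warmup) / (max_update - warmup)) + min_lr
# For example, with hypothetical settings lr=3e-4, min_lr~0, warmup=4000 and
# max-update=60000, the LR at update 32000 is about 1.5e-4.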
7 | 8 | import math 9 | import types 10 | 11 | import torch 12 | import torch.optim 13 | 14 | from fairseq.optim import FairseqOptimizer, register_optimizer 15 | 16 | 17 | iter_idx = 0 18 | 19 | @register_optimizer('novograd') 20 | class FairseqNovograd(FairseqOptimizer): 21 | 22 | def __init__(self, args, params): 23 | super().__init__(args, params) 24 | 25 | self._optimizer = Novograd(params, **self.optimizer_config) 26 | self._optimizer.name = args.tb_tag + '_' + self._optimizer.name 27 | 28 | @staticmethod 29 | def add_args(parser): 30 | """Add optimizer-specific arguments to the parser.""" 31 | # fmt: off 32 | parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B', 33 | help='betas for Adam optimizer') 34 | parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D', 35 | help='epsilon for Adam optimizer') 36 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 37 | help='weight decay') 38 | parser.add_argument('--tb-tag', default="", type=str, 39 | help='tb tag') 40 | parser.add_argument('--amsgrad', action='store_true') 41 | parser.add_argument('--adam-freeze', default=5000, type=float) 42 | parser.add_argument('--adam-no-correction1', action='store_true') 43 | # fmt: on 44 | 45 | @property 46 | def optimizer_config(self): 47 | """ 48 | Return a kwarg dictionary that will be used to override optimizer 49 | args stored in checkpoints. This allows us to load a checkpoint and 50 | resume training using a different set of optimizer args, e.g., with a 51 | different learning rate. 52 | """ 53 | return { 54 | 'lr': self.args.lr[0], 55 | 'betas': eval(self.args.adam_betas), 56 | 'eps': self.args.adam_eps, 57 | 'weight_decay': self.args.weight_decay, 58 | 'amsgrad': self.args.amsgrad, 59 | 'adam_freeze': self.args.adam_freeze, 60 | 'adam_no_correction1': self.args.adam_no_correction1, 61 | } 62 | 63 | 64 | class Novograd(torch.optim.Optimizer): 65 | 66 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 67 | weight_decay=0, amsgrad=False, adam_freeze=5000, adam_no_correction1=False): 68 | defaults = dict(lr=lr, betas=betas, eps=eps, 69 | weight_decay=weight_decay, amsgrad=amsgrad, adam_freeze=adam_freeze, adam_no_correction1=adam_no_correction1) 70 | self.name = '{}_{}_{}'.format(lr, betas[0], betas[1]) 71 | super(Novograd, self).__init__(params, defaults) 72 | 73 | @property 74 | def supports_memory_efficient_fp16(self): 75 | return True 76 | 77 | def step(self, closure=None): 78 | """Performs a single optimization step. 79 | 80 | Arguments: 81 | closure (callable, optional): A closure that reevaluates the model 82 | and returns the loss. 
83 | """ 84 | global iter_idx 85 | iter_idx += 1 86 | grad_list = list() 87 | mom_list = list() 88 | mom_2rd_list = list() 89 | 90 | loss = None 91 | if closure is not None: 92 | loss = closure() 93 | 94 | for group in self.param_groups: 95 | 96 | 97 | for p in group['params']: 98 | if p.grad is None: 99 | continue 100 | grad = p.grad.data.float() 101 | if grad.is_sparse: 102 | raise RuntimeError('NovoGrad does not support sparse gradients') 103 | amsgrad = group['amsgrad'] 104 | 105 | p_data_fp32 = p.data.float() 106 | 107 | state = self.state[p] 108 | 109 | if len(state) == 0: 110 | state['step'] = 0 111 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 112 | state['exp_avg_sq'] = 0 113 | if amsgrad: 114 | state['max_exp_avg_sq'] = torch.zeros_like(p_data_fp32) 115 | else: 116 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 117 | if amsgrad: 118 | state['max_exp_avg_sq'] = state['max_exp_avg_sq'].type_as(p_data_fp32) 119 | 120 | exp_avg = state['exp_avg'] 121 | if amsgrad: 122 | max_exp_avg_sq = state['max_exp_avg_sq'] 123 | beta1, beta2 = group['betas'] 124 | 125 | state['step'] += 1 126 | state['exp_avg_sq'] = state['exp_avg_sq']*beta2 + (1-beta2)*grad.norm().item()**2 127 | 128 | denom = state['exp_avg_sq']**0.5 + group['eps'] 129 | 130 | step_size = group['lr'] 131 | 132 | 133 | exp_avg.mul_(beta1).add_( grad/denom ) 134 | if group['weight_decay'] != 0: 135 | exp_avg.add_( group['weight_decay'], p_data_fp32) 136 | p_data_fp32.add_(-step_size, exp_avg) 137 | p.data.copy_(p_data_fp32) 138 | 139 | return loss 140 | -------------------------------------------------------------------------------- /nmt/my_module/poly_schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler 9 | 10 | 11 | @register_lr_scheduler('poly') 12 | class PolySchedule(FairseqLRScheduler): 13 | """Decay the LR polynomially towards ``--min-lr``. 14 | 15 | The learning rate starts from the configured learning rate (``--lr``) and is 16 | decayed towards ``--min-lr`` following a polynomial of the remaining update 17 | budget, with the power of the polynomial controlled by ``--poly-pow``. The 18 | decay finishes at ``--max-update``; this scheduler performs no warmup. 19 | 20 | At every update:: 21 | 22 | remaining = 1 - update_num / args.max_update 23 | lr = (args.lr - args.min_lr) * remaining**args.poly_pow + args.min_lr 24 | 25 | With ``--poly-pow 1`` the decay is linear; larger powers decay the learning 26 | rate more aggressively at the start of training and more gently as the run 27 | approaches ``--max-update``. 28 | 29 | """ 30 | 31 | def __init__(self, args, optimizer): 32 | super().__init__(args, optimizer) 33 | if len(args.lr) > 1: 34 | raise ValueError( 35 | 'Cannot use a fixed learning rate schedule with the poly scheduler.' 36 | ' Consider --lr-scheduler=fixed instead.' 37 | ) 38 | 39 | # then, decay the LR polynomially
towards args.min_lr over the remaining updates (see step_update below) 40 | # self.warmup_end_lr = warmup_end_lr * args.warmup_updates**0.5 41 | self.min_lr = args.min_lr 42 | 43 | # initial learning rate 44 | self.lr = args.lr[0] 45 | self.optimizer.set_lr(self.lr) 46 | 47 | self.max_update = args.max_update 48 | 49 | @staticmethod 50 | def add_args(parser): 51 | """Add arguments to the parser for this LR scheduler.""" 52 | # fmt: off 53 | parser.add_argument('--poly-pow', default=2, type=float, metavar='N', 54 | help='poly power') 55 | 56 | def step(self, epoch, val_loss=None): 57 | """Update the learning rate at the end of the given epoch.""" 58 | super().step(epoch, val_loss) 59 | # we don't change the learning rate at epoch boundaries 60 | return self.optimizer.get_lr() 61 | 62 | def step_update(self, num_updates): 63 | """Update the learning rate after each update.""" 64 | self.lr = (self.args.lr[0] - self.min_lr)* (1 - num_updates / self.max_update)**self.args.poly_pow + self.min_lr 65 | self.optimizer.set_lr(self.lr) 66 | return self.lr 67 | -------------------------------------------------------------------------------- /nmt/my_module/radam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import math 9 | import types 10 | 11 | import torch 12 | import torch.optim 13 | # from ipdb import set_trace 14 | from fairseq.optim import FairseqOptimizer, register_optimizer 15 | 16 | # from tensorboardX import SummaryWriter 17 | # # writer = SummaryWriter(logdir='./log/wmt/') 18 | # writer = SummaryWriter(logdir='./log/ada/') 19 | iter_idx = 0 20 | 21 | @register_optimizer('radam') 22 | class FairseqRAdam(FairseqOptimizer): 23 | 24 | def __init__(self, args, params): 25 | super().__init__(args, params) 26 | 27 | self._optimizer = RAdam(params, **self.optimizer_config) 28 | self._optimizer.name = args.tb_tag + '_' + self._optimizer.name 29 | 30 | @staticmethod 31 | def add_args(parser): 32 | """Add optimizer-specific arguments to the parser.""" 33 | # fmt: off 34 | parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B', 35 | help='betas for Adam optimizer') 36 | parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D', 37 | help='epsilon for Adam optimizer') 38 | parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', 39 | help='weight decay') 40 | parser.add_argument('--tb-tag', default="", type=str, 41 | help='tb tag') 42 | # fmt: on 43 | 44 | @property 45 | def optimizer_config(self): 46 | """ 47 | Return a kwarg dictionary that will be used to override optimizer 48 | args stored in checkpoints. This allows us to load a checkpoint and 49 | resume training using a different set of optimizer args, e.g., with a 50 | different learning rate.
51 | """ 52 | return { 53 | 'lr': self.args.lr[0], 54 | 'betas': eval(self.args.adam_betas), 55 | 'eps': self.args.adam_eps, 56 | 'weight_decay': self.args.weight_decay, 57 | } 58 | 59 | class RAdam(torch.optim.Optimizer): 60 | 61 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 62 | weight_decay=0, amsgrad=False): 63 | defaults = dict(lr=lr, betas=betas, eps=eps, 64 | weight_decay=weight_decay, amsgrad=amsgrad) 65 | 66 | self.name = '{}_{}_{}'.format(lr, betas[0], betas[1]) 67 | super(RAdam, self).__init__(params, defaults) 68 | 69 | @property 70 | def supports_memory_efficient_fp16(self): 71 | return True 72 | 73 | def step(self, closure=None): 74 | """Performs a single optimization step. 75 | 76 | Arguments: 77 | closure (callable, optional): A closure that reevaluates the model 78 | and returns the loss. 79 | """ 80 | global iter_idx 81 | iter_idx += 1 82 | grad_list = list() 83 | mom_list = list() 84 | mom_2rd_list = list() 85 | assert 'adam_1k' not in self.name 86 | writer_iter = iter_idx 87 | 88 | loss = None 89 | if closure is not None: 90 | loss = closure() 91 | 92 | for group in self.param_groups: 93 | 94 | for p in group['params']: 95 | if p.grad is None: 96 | continue 97 | grad = p.grad.data.float() 98 | if grad.is_sparse: 99 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 100 | amsgrad = group['amsgrad'] 101 | 102 | p_data_fp32 = p.data.float() 103 | 104 | state = self.state[p] 105 | 106 | if len(state) == 0: 107 | state['step'] = 0 108 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 109 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 110 | else: 111 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 112 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 113 | 114 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 115 | beta1, beta2 = group['betas'] 116 | 117 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 118 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 119 | 120 | state['step'] += 1 121 | 122 | beta2_t = beta2 ** state['step'] 123 | N_sma_max = 2 / (1 - beta2) - 1 124 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 125 | 126 | if group['weight_decay'] != 0: 127 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 128 | 129 | # more conservative since it's an approximated value 130 | if N_sma >= 5: 131 | step_size = group['lr'] * math.sqrt((1 - beta2_t ) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) * (N_sma_max) / N_sma / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 132 | denom = exp_avg_sq.sqrt().add_(group['eps']) 133 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 134 | else: 135 | step_size = group['lr'] / (1 - beta1 ** state['step']) 136 | p_data_fp32.add_(-step_size, exp_avg) 137 | 138 | p.data.copy_(p_data_fp32) 139 | 140 | # if writer_iter > 0 and writer_iter % 300 == 0 or writer_iter in [1, 5, 10, 25, 50, 75, 100, 150, 200]: 141 | # grad_list.extend( grad.abs().add_(1e-9).log().view(-1).tolist() ) 142 | # mom_list.extend( exp_avg.abs().add_(1e-9).log().view(-1).tolist() ) 143 | # mom_2rd_list.extend( exp_avg_sq.abs().add_(1e-9).log().view(-1).tolist() ) 144 | 145 | # if writer_iter > 0 and writer_iter % 300 == 0 or writer_iter in [1, 5, 10, 25, 50, 75, 100, 150, 200]: 146 | # writer.add_histogram('grad/{}'.format(self.name), grad_list, writer_iter) 147 | # writer.add_histogram('mom/{}'.format(self.name), mom_list, writer_iter) 148 | # writer.add_histogram('mom_sq/{}'.format(self.name), mom_2rd_list, 
writer_iter) 149 | 150 | return loss 151 | -------------------------------------------------------------------------------- /nmt/recipes.md: -------------------------------------------------------------------------------- 1 | 2 | # Adam with warmup 3 | 4 | ``` 5 | CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en -a transformer_iwslt_de_en --optimizer adam2 --lr 0.0003 -s de -t en --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 --warmup-init-lr 1e-8 --min-lr '1e-09' --lr-scheduler linear --weight-decay 0.0001 --criterion label_smoothed_cross_entropy --max-update 70000 --warmup-updates 4000 --adam-betas '(0.9, 0.999)' --save-dir /cps/gadam/nmt/adam_warmup_f_0 --tb-tag adam_warmup_f_0 --user-dir ./my_module --restore-file x.pt 6 | 7 | bash eval.sh /cps/gadam/nmt/adam_warmup_f_0 0 >> results_f_5.txt 8 | 9 | for SEED in 1111 2222 3333 4444 10 | do 11 | 12 | CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en -a transformer_iwslt_de_en --optimizer adam2 --lr 0.0003 -s de -t en --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 --warmup-init-lr 1e-8 --min-lr '1e-09' --lr-scheduler linear --weight-decay 0.0001 --criterion label_smoothed_cross_entropy --max-update 70000 --warmup-updates 4000 --adam-betas '(0.9, 0.999)' --save-dir /cps/gadam/nmt/adam_warmup_f_$SEED --tb-tag adam_warmup_f_$SEED --user-dir ./my_module --restore-file x.pt --seed $SEED 13 | 14 | bash eval.sh /cps/gadam/nmt/adam_warmup_f_$SEED 0 >> results_f_5.txt 15 | done 16 | ``` 17 | 18 | # Adam-2k 19 | 20 | ``` 21 | CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en -a transformer_iwslt_de_en --optimizer adam2 --lr 0.0003086 -s de -t en --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 --min-lr '1e-09' --lr-scheduler linear --weight-decay 0.0001 --criterion label_smoothed_cross_entropy --max-update 72000 --warmup-updates 1 --adam-betas '(0.9, 0.999)' --save-dir /cps/gadam/nmt/adam_1k --tb-tag adam_1k --user-dir ./my_module --fp16 --restore-file x.pt --adam-freeze 2000 22 | ``` 23 | 24 | # Adam-eps 25 | 26 | ``` 27 | CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en -a transformer_iwslt_de_en --optimizer adam2 --lr 0.0003 -s de -t en --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 --min-lr '1e-09' --lr-scheduler linear --weight-decay 0.0001 --criterion label_smoothed_cross_entropy --max-update 70000 --warmup-updates 1 --adam-betas '(0.9, 0.999)' --save-dir /cps/gadam/nmt/adam_eps --tb-tag adam_eps --user-dir ./my_module --fp16 --adam-eps 1e-4 --restore-file x.pt 28 | 29 | ``` 30 | 31 | # RAdam 32 | 33 | ``` 34 | CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en -a transformer_iwslt_de_en --optimizer radam --lr 0.0003 -s de -t en --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 --min-lr '1e-09' --lr-scheduler linear --weight-decay 0.0001 --criterion label_smoothed_cross_entropy --max-update 70000 --warmup-updates 1 --adam-betas '(0.9, 0.999)' --save-dir /cps/gadam/nmt/radam_0 --tb-tag radam_0 --user-dir ./my_module --fp16 --restore-file x.pt 35 | 36 | bash eval.sh /cps/gadam/nmt/radam_0 0 >> results_f_5.txt 37 | 38 | for SEED in 1111 2222 3333 4444 39 | do 40 | CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en -a transformer_iwslt_de_en --optimizer radam --lr 0.0003 -s de -t en --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 --min-lr '1e-09' --lr-scheduler linear --weight-decay 0.0001 --criterion label_smoothed_cross_entropy --max-update 70000 --warmup-updates 1 --adam-betas '(0.9, 
0.9995)' --save-dir /cps/gadam/nmt/radam_$SEED --tb-tag radam_$SEED --user-dir ./my_module --fp16 --restore-file x.pt --seed $SEED 41 | 42 | bash eval.sh /cps/gadam/nmt/radam_$SEED 0 >> results_f_5.txt 43 | done 44 | ``` 45 | 46 | # Novograd 47 | We also implemented [novograd](https://arxiv.org/pdf/1905.11286.pdf), which claims that no warmup is required. 48 | We tried this setting with lr = 0.03, 0.0003, 0.00003, and 0.00001; none of these worked without warmup. 49 | 50 | ``` 51 | CUDA_VISIBLE_DEVICES=0 fairseq-train ./data-bin/iwslt14.tokenized.de-en -a transformer_iwslt_de_en --optimizer novograd --lr 0.0003 -s en -t de --label-smoothing 0.1 --dropout 0.3 --max-tokens 4000 --min-lr '1e-09' --lr-scheduler poly --weight-decay 5e-5 --criterion label_smoothed_cross_entropy --max-update 70000 --adam-betas '(0.9, 0.999)' --save-dir /ckp/nmt/novograd --tb-tag novograd --user-dir ./my_module --fp16 --restore-file x.pt 52 | ``` 53 | -------------------------------------------------------------------------------- /radam/__init__.py: -------------------------------------------------------------------------------- 1 | from .radam import RAdam, PlainRAdam, AdamW 2 | -------------------------------------------------------------------------------- /radam/radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | class RAdam(Optimizer): 6 | 7 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False): 8 | if not 0.0 <= lr: 9 | raise ValueError("Invalid learning rate: {}".format(lr)) 10 | if not 0.0 <= eps: 11 | raise ValueError("Invalid epsilon value: {}".format(eps)) 12 | if not 0.0 <= betas[0] < 1.0: 13 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 14 | if not 0.0 <= betas[1] < 1.0: 15 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 16 | 17 | self.degenerated_to_sgd = degenerated_to_sgd 18 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 19 | for param in params: 20 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 21 | param['buffer'] = [[None, None, None] for _ in range(10)] 22 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 23 | super(RAdam, self).__init__(params, defaults) 24 | 25 | def __setstate__(self, state): 26 | super(RAdam, self).__setstate__(state) 27 | 28 | def step(self, closure=None): 29 | 30 | loss = None 31 | if closure is not None: 32 | loss = closure() 33 | 34 | for group in self.param_groups: 35 | 36 | for p in group['params']: 37 | if p.grad is None: 38 | continue 39 | grad = p.grad.data.float() 40 | if grad.is_sparse: 41 | raise RuntimeError('RAdam does not support sparse gradients') 42 | 43 | p_data_fp32 = p.data.float() 44 | 45 | state = self.state[p] 46 | 47 | if len(state) == 0: 48 | state['step'] = 0 49 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 50 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 51 | else: 52 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 53 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 54 | 55 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 56 | beta1, beta2 = group['betas'] 57 | 58 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 59 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 60 | 61 | 
state['step'] += 1 62 | buffered = group['buffer'][int(state['step'] % 10)] 63 | if state['step'] == buffered[0]: 64 | N_sma, step_size = buffered[1], buffered[2] 65 | else: 66 | buffered[0] = state['step'] 67 | beta2_t = beta2 ** state['step'] 68 | N_sma_max = 2 / (1 - beta2) - 1 69 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 70 | buffered[1] = N_sma 71 | 72 | # more conservative since it's an approximated value 73 | if N_sma >= 5: 74 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 75 | elif self.degenerated_to_sgd: 76 | step_size = 1.0 / (1 - beta1 ** state['step']) 77 | else: 78 | step_size = -1 79 | buffered[2] = step_size 80 | 81 | # more conservative since it's an approximated value 82 | if N_sma >= 5: 83 | if group['weight_decay'] != 0: 84 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 85 | denom = exp_avg_sq.sqrt().add_(group['eps']) 86 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 87 | p.data.copy_(p_data_fp32) 88 | elif step_size > 0: 89 | if group['weight_decay'] != 0: 90 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 91 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 92 | p.data.copy_(p_data_fp32) 93 | 94 | return loss 95 | 96 | class PlainRAdam(Optimizer): 97 | 98 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=False): 99 | if not 0.0 <= lr: 100 | raise ValueError("Invalid learning rate: {}".format(lr)) 101 | if not 0.0 <= eps: 102 | raise ValueError("Invalid epsilon value: {}".format(eps)) 103 | if not 0.0 <= betas[0] < 1.0: 104 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 105 | if not 0.0 <= betas[1] < 1.0: 106 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 107 | 108 | self.degenerated_to_sgd = degenerated_to_sgd 109 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 110 | 111 | super(PlainRAdam, self).__init__(params, defaults) 112 | 113 | def __setstate__(self, state): 114 | super(PlainRAdam, self).__setstate__(state) 115 | 116 | def step(self, closure=None): 117 | 118 | loss = None 119 | if closure is not None: 120 | loss = closure() 121 | 122 | for group in self.param_groups: 123 | 124 | for p in group['params']: 125 | if p.grad is None: 126 | continue 127 | grad = p.grad.data.float() 128 | if grad.is_sparse: 129 | raise RuntimeError('RAdam does not support sparse gradients') 130 | 131 | p_data_fp32 = p.data.float() 132 | 133 | state = self.state[p] 134 | 135 | if len(state) == 0: 136 | state['step'] = 0 137 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 138 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 139 | else: 140 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 141 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 142 | 143 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 144 | beta1, beta2 = group['betas'] 145 | 146 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 147 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 148 | 149 | state['step'] += 1 150 | beta2_t = beta2 ** state['step'] 151 | N_sma_max = 2 / (1 - beta2) - 1 152 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 153 | 154 | 155 | # more conservative since it's an approximated value 156 | if N_sma >= 5: 157 | if group['weight_decay'] != 0: 158 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], 
p_data_fp32) 159 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 160 | denom = exp_avg_sq.sqrt().add_(group['eps']) 161 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 162 | p.data.copy_(p_data_fp32) 163 | elif self.degenerated_to_sgd: 164 | if group['weight_decay'] != 0: 165 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 166 | step_size = group['lr'] / (1 - beta1 ** state['step']) 167 | p_data_fp32.add_(-step_size, exp_avg) 168 | p.data.copy_(p_data_fp32) 169 | 170 | return loss 171 | 172 | 173 | class AdamW(Optimizer): 174 | 175 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 176 | if not 0.0 <= lr: 177 | raise ValueError("Invalid learning rate: {}".format(lr)) 178 | if not 0.0 <= eps: 179 | raise ValueError("Invalid epsilon value: {}".format(eps)) 180 | if not 0.0 <= betas[0] < 1.0: 181 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 182 | if not 0.0 <= betas[1] < 1.0: 183 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 184 | 185 | defaults = dict(lr=lr, betas=betas, eps=eps, 186 | weight_decay=weight_decay, warmup = warmup) 187 | super(AdamW, self).__init__(params, defaults) 188 | 189 | def __setstate__(self, state): 190 | super(AdamW, self).__setstate__(state) 191 | 192 | def step(self, closure=None): 193 | loss = None 194 | if closure is not None: 195 | loss = closure() 196 | 197 | for group in self.param_groups: 198 | 199 | for p in group['params']: 200 | if p.grad is None: 201 | continue 202 | grad = p.grad.data.float() 203 | if grad.is_sparse: 204 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 205 | 206 | p_data_fp32 = p.data.float() 207 | 208 | state = self.state[p] 209 | 210 | if len(state) == 0: 211 | state['step'] = 0 212 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 213 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 214 | else: 215 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 216 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 217 | 218 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 219 | beta1, beta2 = group['betas'] 220 | 221 | state['step'] += 1 222 | 223 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 224 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 225 | 226 | denom = exp_avg_sq.sqrt().add_(group['eps']) 227 | bias_correction1 = 1 - beta1 ** state['step'] 228 | bias_correction2 = 1 - beta2 ** state['step'] 229 | 230 | if group['warmup'] > state['step']: 231 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 232 | else: 233 | scheduled_lr = group['lr'] 234 | 235 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 236 | 237 | if group['weight_decay'] != 0: 238 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 239 | 240 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 241 | 242 | p.data.copy_(p_data_fp32) 243 | 244 | return loss 245 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import re 3 | 4 | try: 5 | import torch 6 | has_dev_pytorch = "dev" in torch.__version__ 7 | except ImportError: 8 | has_dev_pytorch = False 9 | 10 | # Base equirements 11 | install_requires = 
[ 12 | "torch", 13 | ] 14 | 15 | if has_dev_pytorch: # Remove the PyTorch requirement 16 | install_requires = [ 17 | install_require for install_require in install_requires 18 | if "torch" != re.split(r"(=|<|>)", install_require)[0] 19 | ] 20 | 21 | setup( 22 | name='RAdam', 23 | version='0.0.1', 24 | url='https://github.com/LiyuanLucasLiu/RAdam.git', 25 | author='Liyuan Liu', 26 | author_email='llychinalz@gmail.com', 27 | description='Implementation of the RAdam optimization algorithm described in On the Variance of the Adaptive Learning Rate and Beyond (https://arxiv.org/abs/1908.03265)', 28 | packages=find_packages(), 29 | install_requires=install_requires, 30 | ) 31 | --------------------------------------------------------------------------------
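For reference, a minimal usage sketch of the packaged optimizer defined in radam/radam.py: after installing the package (e.g. `pip install .` from the repository root is assumed here), `RAdam` can be dropped in wherever `torch.optim.Adam` is used. The toy model and random data below are placeholders, not part of the repository.

```
import torch
import torch.nn.functional as F

from radam import RAdam  # exported by radam/__init__.py

model = torch.nn.Linear(10, 2)  # stand-in for a real network
optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-4)

for step in range(100):
    inputs = torch.randn(32, 10)           # placeholder batch
    targets = torch.randint(0, 2, (32,))   # placeholder labels
    loss = F.cross_entropy(model(inputs), targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```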