├── .gitignore ├── LICENSE.md ├── README.md ├── poetry.lock ├── pyproject.toml ├── tests ├── __init__.py └── test_torch_optim_sparse.py └── torch_optim_sparse ├── __init__.py ├── sparser_adam.py ├── sparser_adamw.py ├── sparser_sgd.py └── sparser_sgdw.py /.gitignore: -------------------------------------------------------------------------------- 1 | # added by gitignore-cli 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # pytype static type analyzer 137 | .pytype/ 138 | 139 | # Cython debug symbols 140 | cython_debug/ 141 | 142 | .python-version 143 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 karl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torch-optim-sparse 2 | 3 | This library implements "sparser" versions of PyTorch optimizers, which apply momentum and weight decay updates only to parameters whose gradients are non-zero. 4 | 5 | It contains four optimizers: 6 | - SparserSGD 7 | - SparserAdam 8 | - SparserSGDW 9 | - SparserAdamW 10 | 11 | The latter two follow the approaches outlined in ["Decoupled Weight Decay Regularization"](https://arxiv.org/abs/1711.05101) by Loshchilov & Hutter from ICLR 2019. 12 | 13 | Except for SparserSGDW, they're all straightforward ports of existing PyTorch optimizers, modified only to convert momentum and weight decay to sparse updates. Since PyTorch has no SGDW, SparserSGDW additionally changes where and how weight decay is applied, as outlined in the paper above.
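14 | 15 | ## Usage 16 | 17 | These optimizers expect sparse gradients, such as those produced by an `nn.Embedding` created with `sparse=True`. A minimal sketch (the model and hyper-parameter values here are illustrative, not recommendations): 18 | 19 | ```python 20 | import torch 21 | import torch.nn as nn 22 | 23 | from torch_optim_sparse import SparserSGD 24 | 25 | # Embeddings with sparse=True produce sparse gradients, which these optimizers require 26 | embedding = nn.Embedding(num_embeddings=1000, embedding_dim=16, sparse=True) 27 | optimizer = SparserSGD(embedding.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) 28 | 29 | ids = torch.tensor([3, 7, 42])      # only these rows receive gradients 30 | loss = embedding(ids).pow(2).sum()  # toy loss 31 | optimizer.zero_grad() 32 | loss.backward()   # the gradient is a sparse tensor over rows 3, 7, 42 33 | optimizer.step()  # momentum and weight decay touch only those rows 34 | ``` 35 | -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | category = "dev" 3 | description = "An abstract syntax tree for Python with inference support."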
4 | name = "astroid" 5 | optional = false 6 | python-versions = ">=3.5" 7 | version = "2.4.2" 8 | 9 | [package.dependencies] 10 | lazy-object-proxy = ">=1.4.0,<1.5.0" 11 | six = ">=1.12,<2.0" 12 | wrapt = ">=1.11,<2.0" 13 | 14 | [package.dependencies.typed-ast] 15 | python = "<3.8" 16 | version = ">=1.4.0,<1.5" 17 | 18 | [[package]] 19 | category = "dev" 20 | description = "Atomic file writes." 21 | marker = "sys_platform == \"win32\"" 22 | name = "atomicwrites" 23 | optional = false 24 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 25 | version = "1.4.0" 26 | 27 | [[package]] 28 | category = "dev" 29 | description = "Classes Without Boilerplate" 30 | name = "attrs" 31 | optional = false 32 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 33 | version = "20.2.0" 34 | 35 | [package.extras] 36 | dev = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "sphinx", "sphinx-rtd-theme", "pre-commit"] 37 | docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"] 38 | tests = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"] 39 | tests_no_zope = ["coverage (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six"] 40 | 41 | [[package]] 42 | category = "dev" 43 | description = "Cross-platform colored terminal text." 44 | marker = "sys_platform == \"win32\"" 45 | name = "colorama" 46 | optional = false 47 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 48 | version = "0.4.3" 49 | 50 | [[package]] 51 | category = "main" 52 | description = "Clean single-source support for Python 3 and 2" 53 | name = "future" 54 | optional = false 55 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" 56 | version = "0.18.2" 57 | 58 | [[package]] 59 | category = "dev" 60 | description = "Read metadata from Python packages" 61 | marker = "python_version < \"3.8\"" 62 | name = "importlib-metadata" 63 | optional = false 64 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" 65 | version = "1.7.0" 66 | 67 | [package.dependencies] 68 | zipp = ">=0.5" 69 | 70 | [package.extras] 71 | docs = ["sphinx", "rst.linker"] 72 | testing = ["packaging", "pep517", "importlib-resources (>=1.3)"] 73 | 74 | [[package]] 75 | category = "dev" 76 | description = "A Python utility / library to sort Python imports." 77 | name = "isort" 78 | optional = false 79 | python-versions = ">=3.6,<4.0" 80 | version = "5.5.3" 81 | 82 | [package.extras] 83 | colors = ["colorama (>=0.4.3,<0.5.0)"] 84 | pipfile_deprecated_finder = ["pipreqs", "requirementslib"] 85 | requirements_deprecated_finder = ["pipreqs", "pip-api"] 86 | 87 | [[package]] 88 | category = "dev" 89 | description = "A fast and thorough lazy object proxy." 90 | name = "lazy-object-proxy" 91 | optional = false 92 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 93 | version = "1.4.3" 94 | 95 | [[package]] 96 | category = "dev" 97 | description = "McCabe checker, plugin for flake8" 98 | name = "mccabe" 99 | optional = false 100 | python-versions = "*" 101 | version = "0.6.1" 102 | 103 | [[package]] 104 | category = "dev" 105 | description = "More routines for operating on iterables, beyond itertools" 106 | name = "more-itertools" 107 | optional = false 108 | python-versions = ">=3.5" 109 | version = "8.5.0" 110 | 111 | [[package]] 112 | category = "main" 113 | description = "NumPy is the fundamental package for array computing with Python." 
114 | name = "numpy" 115 | optional = false 116 | python-versions = ">=3.6" 117 | version = "1.19.2" 118 | 119 | [[package]] 120 | category = "dev" 121 | description = "Core utilities for Python packages" 122 | name = "packaging" 123 | optional = false 124 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 125 | version = "20.4" 126 | 127 | [package.dependencies] 128 | pyparsing = ">=2.0.2" 129 | six = "*" 130 | 131 | [[package]] 132 | category = "dev" 133 | description = "plugin and hook calling mechanisms for python" 134 | name = "pluggy" 135 | optional = false 136 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 137 | version = "0.13.1" 138 | 139 | [package.dependencies] 140 | [package.dependencies.importlib-metadata] 141 | python = "<3.8" 142 | version = ">=0.12" 143 | 144 | [package.extras] 145 | dev = ["pre-commit", "tox"] 146 | 147 | [[package]] 148 | category = "dev" 149 | description = "library with cross-python path, ini-parsing, io, code, log facilities" 150 | name = "py" 151 | optional = false 152 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 153 | version = "1.9.0" 154 | 155 | [[package]] 156 | category = "dev" 157 | description = "python code static checker" 158 | name = "pylint" 159 | optional = false 160 | python-versions = ">=3.5.*" 161 | version = "2.6.0" 162 | 163 | [package.dependencies] 164 | astroid = ">=2.4.0,<=2.5" 165 | colorama = "*" 166 | isort = ">=4.2.5,<6" 167 | mccabe = ">=0.6,<0.7" 168 | toml = ">=0.7.1" 169 | 170 | [[package]] 171 | category = "dev" 172 | description = "Python parsing module" 173 | name = "pyparsing" 174 | optional = false 175 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" 176 | version = "2.4.7" 177 | 178 | [[package]] 179 | category = "dev" 180 | description = "pytest: simple powerful testing with Python" 181 | name = "pytest" 182 | optional = false 183 | python-versions = ">=3.5" 184 | version = "5.4.3" 185 | 186 | [package.dependencies] 187 | atomicwrites = ">=1.0" 188 | attrs = ">=17.4.0" 189 | colorama = "*" 190 | more-itertools = ">=4.0.0" 191 | packaging = "*" 192 | pluggy = ">=0.12,<1.0" 193 | py = ">=1.5.0" 194 | wcwidth = "*" 195 | 196 | [package.dependencies.importlib-metadata] 197 | python = "<3.8" 198 | version = ">=0.12" 199 | 200 | [package.extras] 201 | checkqa-mypy = ["mypy (v0.761)"] 202 | testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] 203 | 204 | [[package]] 205 | category = "dev" 206 | description = "Python 2 and 3 compatibility utilities" 207 | name = "six" 208 | optional = false 209 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" 210 | version = "1.15.0" 211 | 212 | [[package]] 213 | category = "dev" 214 | description = "Python Library for Tom's Obvious, Minimal Language" 215 | name = "toml" 216 | optional = false 217 | python-versions = "*" 218 | version = "0.10.1" 219 | 220 | [[package]] 221 | category = "main" 222 | description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" 223 | name = "torch" 224 | optional = false 225 | python-versions = ">=3.6.1" 226 | version = "1.6.0" 227 | 228 | [package.dependencies] 229 | future = "*" 230 | numpy = "*" 231 | 232 | [[package]] 233 | category = "dev" 234 | description = "a fork of Python 2 and 3 ast modules with type comment support" 235 | marker = "implementation_name == \"cpython\" and python_version < \"3.8\"" 236 | name = "typed-ast" 237 | optional = false 238 | python-versions = "*" 239 | version = "1.4.1" 240 | 241 | [[package]] 242 | 
category = "dev" 243 | description = "Measures the displayed width of unicode strings in a terminal" 244 | name = "wcwidth" 245 | optional = false 246 | python-versions = "*" 247 | version = "0.2.5" 248 | 249 | [[package]] 250 | category = "dev" 251 | description = "Module for decorators, wrappers and monkey patching." 252 | name = "wrapt" 253 | optional = false 254 | python-versions = "*" 255 | version = "1.12.1" 256 | 257 | [[package]] 258 | category = "dev" 259 | description = "Backport of pathlib-compatible object wrapper for zip files" 260 | marker = "python_version < \"3.8\"" 261 | name = "zipp" 262 | optional = false 263 | python-versions = ">=3.6" 264 | version = "3.1.0" 265 | 266 | [package.extras] 267 | docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] 268 | testing = ["jaraco.itertools", "func-timeout"] 269 | 270 | [metadata] 271 | content-hash = "00b34c3fa2f28e8ce03a95e4a03749bc9466705460266662275fe3ab5bfd4dd2" 272 | python-versions = "^3.7" 273 | 274 | [metadata.files] 275 | astroid = [ 276 | {file = "astroid-2.4.2-py3-none-any.whl", hash = "sha256:bc58d83eb610252fd8de6363e39d4f1d0619c894b0ed24603b881c02e64c7386"}, 277 | {file = "astroid-2.4.2.tar.gz", hash = "sha256:2f4078c2a41bf377eea06d71c9d2ba4eb8f6b1af2135bec27bbbb7d8f12bb703"}, 278 | ] 279 | atomicwrites = [ 280 | {file = "atomicwrites-1.4.0-py2.py3-none-any.whl", hash = "sha256:6d1784dea7c0c8d4a5172b6c620f40b6e4cbfdf96d783691f2e1302a7b88e197"}, 281 | {file = "atomicwrites-1.4.0.tar.gz", hash = "sha256:ae70396ad1a434f9c7046fd2dd196fc04b12f9e91ffb859164193be8b6168a7a"}, 282 | ] 283 | attrs = [ 284 | {file = "attrs-20.2.0-py2.py3-none-any.whl", hash = "sha256:fce7fc47dfc976152e82d53ff92fa0407700c21acd20886a13777a0d20e655dc"}, 285 | {file = "attrs-20.2.0.tar.gz", hash = "sha256:26b54ddbbb9ee1d34d5d3668dd37d6cf74990ab23c828c2888dccdceee395594"}, 286 | ] 287 | colorama = [ 288 | {file = "colorama-0.4.3-py2.py3-none-any.whl", hash = "sha256:7d73d2a99753107a36ac6b455ee49046802e59d9d076ef8e47b61499fa29afff"}, 289 | {file = "colorama-0.4.3.tar.gz", hash = "sha256:e96da0d330793e2cb9485e9ddfd918d456036c7149416295932478192f4436a1"}, 290 | ] 291 | future = [ 292 | {file = "future-0.18.2.tar.gz", hash = "sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d"}, 293 | ] 294 | importlib-metadata = [ 295 | {file = "importlib_metadata-1.7.0-py2.py3-none-any.whl", hash = "sha256:dc15b2969b4ce36305c51eebe62d418ac7791e9a157911d58bfb1f9ccd8e2070"}, 296 | {file = "importlib_metadata-1.7.0.tar.gz", hash = "sha256:90bb658cdbbf6d1735b6341ce708fc7024a3e14e99ffdc5783edea9f9b077f83"}, 297 | ] 298 | isort = [ 299 | {file = "isort-5.5.3-py3-none-any.whl", hash = "sha256:c16eaa7432a1c004c585d79b12ad080c6c421dd18fe27982ca11f95e6898e432"}, 300 | {file = "isort-5.5.3.tar.gz", hash = "sha256:6187a9f1ce8784cbc6d1b88790a43e6083a6302f03e9ae482acc0f232a98c843"}, 301 | ] 302 | lazy-object-proxy = [ 303 | {file = "lazy-object-proxy-1.4.3.tar.gz", hash = "sha256:f3900e8a5de27447acbf900b4750b0ddfd7ec1ea7fbaf11dfa911141bc522af0"}, 304 | {file = "lazy_object_proxy-1.4.3-cp27-cp27m-macosx_10_13_x86_64.whl", hash = "sha256:a2238e9d1bb71a56cd710611a1614d1194dc10a175c1e08d75e1a7bcc250d442"}, 305 | {file = "lazy_object_proxy-1.4.3-cp27-cp27m-win32.whl", hash = "sha256:efa1909120ce98bbb3777e8b6f92237f5d5c8ea6758efea36a473e1d38f7d3e4"}, 306 | {file = "lazy_object_proxy-1.4.3-cp27-cp27m-win_amd64.whl", hash = "sha256:4677f594e474c91da97f489fea5b7daa17b5517190899cf213697e48d3902f5a"}, 307 | {file = 
"lazy_object_proxy-1.4.3-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:0c4b206227a8097f05c4dbdd323c50edf81f15db3b8dc064d08c62d37e1a504d"}, 308 | {file = "lazy_object_proxy-1.4.3-cp34-cp34m-manylinux1_x86_64.whl", hash = "sha256:d945239a5639b3ff35b70a88c5f2f491913eb94871780ebfabb2568bd58afc5a"}, 309 | {file = "lazy_object_proxy-1.4.3-cp34-cp34m-win32.whl", hash = "sha256:9651375199045a358eb6741df3e02a651e0330be090b3bc79f6d0de31a80ec3e"}, 310 | {file = "lazy_object_proxy-1.4.3-cp34-cp34m-win_amd64.whl", hash = "sha256:eba7011090323c1dadf18b3b689845fd96a61ba0a1dfbd7f24b921398affc357"}, 311 | {file = "lazy_object_proxy-1.4.3-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:48dab84ebd4831077b150572aec802f303117c8cc5c871e182447281ebf3ac50"}, 312 | {file = "lazy_object_proxy-1.4.3-cp35-cp35m-win32.whl", hash = "sha256:ca0a928a3ddbc5725be2dd1cf895ec0a254798915fb3a36af0964a0a4149e3db"}, 313 | {file = "lazy_object_proxy-1.4.3-cp35-cp35m-win_amd64.whl", hash = "sha256:194d092e6f246b906e8f70884e620e459fc54db3259e60cf69a4d66c3fda3449"}, 314 | {file = "lazy_object_proxy-1.4.3-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:97bb5884f6f1cdce0099f86b907aa41c970c3c672ac8b9c8352789e103cf3156"}, 315 | {file = "lazy_object_proxy-1.4.3-cp36-cp36m-win32.whl", hash = "sha256:cb2c7c57005a6804ab66f106ceb8482da55f5314b7fcb06551db1edae4ad1531"}, 316 | {file = "lazy_object_proxy-1.4.3-cp36-cp36m-win_amd64.whl", hash = "sha256:8d859b89baf8ef7f8bc6b00aa20316483d67f0b1cbf422f5b4dc56701c8f2ffb"}, 317 | {file = "lazy_object_proxy-1.4.3-cp37-cp37m-macosx_10_13_x86_64.whl", hash = "sha256:1be7e4c9f96948003609aa6c974ae59830a6baecc5376c25c92d7d697e684c08"}, 318 | {file = "lazy_object_proxy-1.4.3-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:d74bb8693bf9cf75ac3b47a54d716bbb1a92648d5f781fc799347cfc95952383"}, 319 | {file = "lazy_object_proxy-1.4.3-cp37-cp37m-win32.whl", hash = "sha256:9b15f3f4c0f35727d3a0fba4b770b3c4ebbb1fa907dbcc046a1d2799f3edd142"}, 320 | {file = "lazy_object_proxy-1.4.3-cp37-cp37m-win_amd64.whl", hash = "sha256:9254f4358b9b541e3441b007a0ea0764b9d056afdeafc1a5569eee1cc6c1b9ea"}, 321 | {file = "lazy_object_proxy-1.4.3-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:a6ae12d08c0bf9909ce12385803a543bfe99b95fe01e752536a60af2b7797c62"}, 322 | {file = "lazy_object_proxy-1.4.3-cp38-cp38-win32.whl", hash = "sha256:5541cada25cd173702dbd99f8e22434105456314462326f06dba3e180f203dfd"}, 323 | {file = "lazy_object_proxy-1.4.3-cp38-cp38-win_amd64.whl", hash = "sha256:59f79fef100b09564bc2df42ea2d8d21a64fdcda64979c0fa3db7bdaabaf6239"}, 324 | ] 325 | mccabe = [ 326 | {file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"}, 327 | {file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"}, 328 | ] 329 | more-itertools = [ 330 | {file = "more-itertools-8.5.0.tar.gz", hash = "sha256:6f83822ae94818eae2612063a5101a7311e68ae8002005b5e05f03fd74a86a20"}, 331 | {file = "more_itertools-8.5.0-py3-none-any.whl", hash = "sha256:9b30f12df9393f0d28af9210ff8efe48d10c94f73e5daf886f10c4b0b0b4f03c"}, 332 | ] 333 | numpy = [ 334 | {file = "numpy-1.19.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:b594f76771bc7fc8a044c5ba303427ee67c17a09b36e1fa32bde82f5c419d17a"}, 335 | {file = "numpy-1.19.2-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:e6ddbdc5113628f15de7e4911c02aed74a4ccff531842c583e5032f6e5a179bd"}, 336 | {file = "numpy-1.19.2-cp36-cp36m-manylinux1_x86_64.whl", hash = 
"sha256:3733640466733441295b0d6d3dcbf8e1ffa7e897d4d82903169529fd3386919a"}, 337 | {file = "numpy-1.19.2-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:4339741994c775396e1a274dba3609c69ab0f16056c1077f18979bec2a2c2e6e"}, 338 | {file = "numpy-1.19.2-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:7c6646314291d8f5ea900a7ea9c4261f834b5b62159ba2abe3836f4fa6705526"}, 339 | {file = "numpy-1.19.2-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:7118f0a9f2f617f921ec7d278d981244ba83c85eea197be7c5a4f84af80a9c3c"}, 340 | {file = "numpy-1.19.2-cp36-cp36m-win32.whl", hash = "sha256:9a3001248b9231ed73894c773142658bab914645261275f675d86c290c37f66d"}, 341 | {file = "numpy-1.19.2-cp36-cp36m-win_amd64.whl", hash = "sha256:967c92435f0b3ba37a4257c48b8715b76741410467e2bdb1097e8391fccfae15"}, 342 | {file = "numpy-1.19.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:d526fa58ae4aead839161535d59ea9565863bb0b0bdb3cc63214613fb16aced4"}, 343 | {file = "numpy-1.19.2-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:eb25c381d168daf351147713f49c626030dcff7a393d5caa62515d415a6071d8"}, 344 | {file = "numpy-1.19.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:62139af94728d22350a571b7c82795b9d59be77fc162414ada6c8b6a10ef5d02"}, 345 | {file = "numpy-1.19.2-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:0c66da1d202c52051625e55a249da35b31f65a81cb56e4c69af0dfb8fb0125bf"}, 346 | {file = "numpy-1.19.2-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:2117536e968abb7357d34d754e3733b0d7113d4c9f1d921f21a3d96dec5ff716"}, 347 | {file = "numpy-1.19.2-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:54045b198aebf41bf6bf4088012777c1d11703bf74461d70cd350c0af2182e45"}, 348 | {file = "numpy-1.19.2-cp37-cp37m-win32.whl", hash = "sha256:aba1d5daf1144b956bc87ffb87966791f5e9f3e1f6fab3d7f581db1f5b598f7a"}, 349 | {file = "numpy-1.19.2-cp37-cp37m-win_amd64.whl", hash = "sha256:addaa551b298052c16885fc70408d3848d4e2e7352de4e7a1e13e691abc734c1"}, 350 | {file = "numpy-1.19.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:58d66a6b3b55178a1f8a5fe98df26ace76260a70de694d99577ddeab7eaa9a9d"}, 351 | {file = "numpy-1.19.2-cp38-cp38-manylinux1_i686.whl", hash = "sha256:59f3d687faea7a4f7f93bd9665e5b102f32f3fa28514f15b126f099b7997203d"}, 352 | {file = "numpy-1.19.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cebd4f4e64cfe87f2039e4725781f6326a61f095bc77b3716502bed812b385a9"}, 353 | {file = "numpy-1.19.2-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:c35a01777f81e7333bcf276b605f39c872e28295441c265cd0c860f4b40148c1"}, 354 | {file = "numpy-1.19.2-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:d7ac33585e1f09e7345aa902c281bd777fdb792432d27fca857f39b70e5dd31c"}, 355 | {file = "numpy-1.19.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:04c7d4ebc5ff93d9822075ddb1751ff392a4375e5885299445fcebf877f179d5"}, 356 | {file = "numpy-1.19.2-cp38-cp38-win32.whl", hash = "sha256:51ee93e1fac3fe08ef54ff1c7f329db64d8a9c5557e6c8e908be9497ac76374b"}, 357 | {file = "numpy-1.19.2-cp38-cp38-win_amd64.whl", hash = "sha256:1669ec8e42f169ff715a904c9b2105b6640f3f2a4c4c2cb4920ae8b2785dac65"}, 358 | {file = "numpy-1.19.2-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:0bfd85053d1e9f60234f28f63d4a5147ada7f432943c113a11afcf3e65d9d4c8"}, 359 | {file = "numpy-1.19.2.zip", hash = "sha256:0d310730e1e793527065ad7dde736197b705d0e4c9999775f212b03c44a8484c"}, 360 | ] 361 | packaging = [ 362 | {file = "packaging-20.4-py2.py3-none-any.whl", hash = "sha256:998416ba6962ae7fbd6596850b80e17859a5753ba17c32284f67bfff33784181"}, 363 | {file = 
"packaging-20.4.tar.gz", hash = "sha256:4357f74f47b9c12db93624a82154e9b120fa8293699949152b22065d556079f8"}, 364 | ] 365 | pluggy = [ 366 | {file = "pluggy-0.13.1-py2.py3-none-any.whl", hash = "sha256:966c145cd83c96502c3c3868f50408687b38434af77734af1e9ca461a4081d2d"}, 367 | {file = "pluggy-0.13.1.tar.gz", hash = "sha256:15b2acde666561e1298d71b523007ed7364de07029219b604cf808bfa1c765b0"}, 368 | ] 369 | py = [ 370 | {file = "py-1.9.0-py2.py3-none-any.whl", hash = "sha256:366389d1db726cd2fcfc79732e75410e5fe4d31db13692115529d34069a043c2"}, 371 | {file = "py-1.9.0.tar.gz", hash = "sha256:9ca6883ce56b4e8da7e79ac18787889fa5206c79dcc67fb065376cd2fe03f342"}, 372 | ] 373 | pylint = [ 374 | {file = "pylint-2.6.0-py3-none-any.whl", hash = "sha256:bfe68f020f8a0fece830a22dd4d5dddb4ecc6137db04face4c3420a46a52239f"}, 375 | {file = "pylint-2.6.0.tar.gz", hash = "sha256:bb4a908c9dadbc3aac18860550e870f58e1a02c9f2c204fdf5693d73be061210"}, 376 | ] 377 | pyparsing = [ 378 | {file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"}, 379 | {file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"}, 380 | ] 381 | pytest = [ 382 | {file = "pytest-5.4.3-py3-none-any.whl", hash = "sha256:5c0db86b698e8f170ba4582a492248919255fcd4c79b1ee64ace34301fb589a1"}, 383 | {file = "pytest-5.4.3.tar.gz", hash = "sha256:7979331bfcba207414f5e1263b5a0f8f521d0f457318836a7355531ed1a4c7d8"}, 384 | ] 385 | six = [ 386 | {file = "six-1.15.0-py2.py3-none-any.whl", hash = "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"}, 387 | {file = "six-1.15.0.tar.gz", hash = "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259"}, 388 | ] 389 | toml = [ 390 | {file = "toml-0.10.1-py2.py3-none-any.whl", hash = "sha256:bda89d5935c2eac546d648028b9901107a595863cb36bae0c73ac804a9b4ce88"}, 391 | {file = "toml-0.10.1.tar.gz", hash = "sha256:926b612be1e5ce0634a2ca03470f95169cf16f939018233a670519cb4ac58b0f"}, 392 | ] 393 | torch = [ 394 | {file = "torch-1.6.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:7669f4d923b5758e28b521ea749c795ed67ff24b45ba20296bc8cff706d08df8"}, 395 | {file = "torch-1.6.0-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:728facb972a5952323c6d790c2c5922b2b35c44b0bc7bdfa02f8639727671a0c"}, 396 | {file = "torch-1.6.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:87d65c01d1b70bb46070824f28bfd93c86d3c5c56b90cbbe836a3f2491d91c76"}, 397 | {file = "torch-1.6.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:3838bd01af7dfb1f78573973f6842ce75b17e8e4f22be99c891dcb7c94bc13f5"}, 398 | {file = "torch-1.6.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:5357873e243bcfa804c32dc341f564e9a4c12addfc9baae4ee857fcc09a0a216"}, 399 | {file = "torch-1.6.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:4f9a4ad7947cef566afb0a323d99009fe8524f0b0f2ca1fb7ad5de0400381a5b"}, 400 | ] 401 | typed-ast = [ 402 | {file = "typed_ast-1.4.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:73d785a950fc82dd2a25897d525d003f6378d1cb23ab305578394694202a58c3"}, 403 | {file = "typed_ast-1.4.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:aaee9905aee35ba5905cfb3c62f3e83b3bec7b39413f0a7f19be4e547ea01ebb"}, 404 | {file = "typed_ast-1.4.1-cp35-cp35m-win32.whl", hash = "sha256:0c2c07682d61a629b68433afb159376e24e5b2fd4641d35424e462169c0a7919"}, 405 | {file = "typed_ast-1.4.1-cp35-cp35m-win_amd64.whl", hash = "sha256:4083861b0aa07990b619bd7ddc365eb7fa4b817e99cf5f8d9cf21a42780f6e01"}, 406 | {file = 
"typed_ast-1.4.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:269151951236b0f9a6f04015a9004084a5ab0d5f19b57de779f908621e7d8b75"}, 407 | {file = "typed_ast-1.4.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:24995c843eb0ad11a4527b026b4dde3da70e1f2d8806c99b7b4a7cf491612652"}, 408 | {file = "typed_ast-1.4.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:fe460b922ec15dd205595c9b5b99e2f056fd98ae8f9f56b888e7a17dc2b757e7"}, 409 | {file = "typed_ast-1.4.1-cp36-cp36m-win32.whl", hash = "sha256:4e3e5da80ccbebfff202a67bf900d081906c358ccc3d5e3c8aea42fdfdfd51c1"}, 410 | {file = "typed_ast-1.4.1-cp36-cp36m-win_amd64.whl", hash = "sha256:249862707802d40f7f29f6e1aad8d84b5aa9e44552d2cc17384b209f091276aa"}, 411 | {file = "typed_ast-1.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8ce678dbaf790dbdb3eba24056d5364fb45944f33553dd5869b7580cdbb83614"}, 412 | {file = "typed_ast-1.4.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:c9e348e02e4d2b4a8b2eedb48210430658df6951fa484e59de33ff773fbd4b41"}, 413 | {file = "typed_ast-1.4.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:bcd3b13b56ea479b3650b82cabd6b5343a625b0ced5429e4ccad28a8973f301b"}, 414 | {file = "typed_ast-1.4.1-cp37-cp37m-win32.whl", hash = "sha256:d5d33e9e7af3b34a40dc05f498939f0ebf187f07c385fd58d591c533ad8562fe"}, 415 | {file = "typed_ast-1.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355"}, 416 | {file = "typed_ast-1.4.1-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:d205b1b46085271b4e15f670058ce182bd1199e56b317bf2ec004b6a44f911f6"}, 417 | {file = "typed_ast-1.4.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:6daac9731f172c2a22ade6ed0c00197ee7cc1221aa84cfdf9c31defeb059a907"}, 418 | {file = "typed_ast-1.4.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:498b0f36cc7054c1fead3d7fc59d2150f4d5c6c56ba7fb150c013fbc683a8d2d"}, 419 | {file = "typed_ast-1.4.1-cp38-cp38-win32.whl", hash = "sha256:715ff2f2df46121071622063fc7543d9b1fd19ebfc4f5c8895af64a77a8c852c"}, 420 | {file = "typed_ast-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:fc0fea399acb12edbf8a628ba8d2312f583bdbdb3335635db062fa98cf71fca4"}, 421 | {file = "typed_ast-1.4.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:d43943ef777f9a1c42bf4e552ba23ac77a6351de620aa9acf64ad54933ad4d34"}, 422 | {file = "typed_ast-1.4.1.tar.gz", hash = "sha256:8c8aaad94455178e3187ab22c8b01a3837f8ee50e09cf31f1ba129eb293ec30b"}, 423 | ] 424 | wcwidth = [ 425 | {file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"}, 426 | {file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"}, 427 | ] 428 | wrapt = [ 429 | {file = "wrapt-1.12.1.tar.gz", hash = "sha256:b62ffa81fb85f4332a4f609cab4ac40709470da05643a082ec1eb88e6d9b97d7"}, 430 | ] 431 | zipp = [ 432 | {file = "zipp-3.1.0-py3-none-any.whl", hash = "sha256:aa36550ff0c0b7ef7fa639055d797116ee891440eac1a56f378e2d3179e0320b"}, 433 | {file = "zipp-3.1.0.tar.gz", hash = "sha256:c599e4d75c98f6798c509911d08a22e6c021d074469042177c8c86fb92eefd96"}, 434 | ] 435 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "torch-optim-sparse" 3 | version = "0.1.3" 4 | description = "PyTorch optimizers with sparse momentum and weight decay" 5 | authors = ["Karl Higley "] 6 | 7 | [tool.poetry.dependencies] 8 | python = "^3.7" 9 | 
torch = "^1.6.0" 10 | 11 | [tool.poetry.dev-dependencies] 12 | pytest = "^5.2" 13 | pylint = "^2.6.0" 14 | 15 | [build-system] 16 | requires = ["poetry>=0.12"] 17 | build-backend = "poetry.masonry.api" 18 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/karlhigley/torch-optim-sparse/b6dd32f60f27dbf7707ff1e050e9f193d538345f/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_torch_optim_sparse.py: -------------------------------------------------------------------------------- 1 | from torch_optim_sparse import __version__ 2 | 3 | 4 | def test_version(): 5 | assert __version__ == '0.1.3' 6 | -------------------------------------------------------------------------------- /torch_optim_sparse/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.1.3' 2 | 3 | from .sparser_adam import SparserAdam 4 | from .sparser_adamw import SparserAdamW 5 | from .sparser_sgd import SparserSGD 6 | from .sparser_sgdw import SparserSGDW 7 | 8 | 9 | def convert_lr(eff_lr, momentum=0.0, beta1=0.0, beta2=0.0, batch_size=1): 10 | """Calculates the learning rate to use for rough equivalence with plain SGD. 11 | 12 | Useful for sweeping a single set of hyper-parameters across multiple optimizers 13 | and getting them all to converge with settings within roughly the same order 14 | of magnitude. Accounts for the effects of mini-batch size, momentum, and the 15 | adaptive learning rates of the Adam variants. 16 | 17 | All params except the effective learning rate are optional; supply only the params that are 18 | relevant to the optimizer you want to use. 19 | 20 | Args: 21 | eff_lr (float): The effective learning rate you want. 22 | momentum (float, optional): The SGD momentum coefficient. Defaults to 0.0, but 0.9 is typical. 23 | beta1 (float, optional): The Adam first moment coefficient. Defaults to 0.0, but 0.9 is typical. 24 | beta2 (float, optional): The Adam second moment coefficient. Defaults to 0.0, but 0.999 is typical. 25 | batch_size (int, optional): The number of examples in a mini-batch. Defaults to 1.
26 | 27 | Returns: 28 | lr (float): The adjusted learning rate to supply to the optimizer 29 | """ 30 | lr = eff_lr 31 | 32 | if beta1 != 0.0 or beta2 != 0.0: 33 | lr = lr * (1 - beta2) / (1 - beta1) 34 | 35 | if momentum != 0.0: 36 | lr = lr * (1 - momentum) 37 | 38 | if batch_size > 1: 39 | lr = lr * batch_size 40 | 41 | return lr 42 |
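43 | 44 | # Worked example (illustrative, not part of the API): how one effective learning 45 | # rate maps to per-optimizer settings: 46 | # 47 | #   convert_lr(0.01)                                 # plain SGD: 0.01 48 | #   convert_lr(0.01, momentum=0.9)                   # scaled by (1 - momentum): 0.001 49 | #   convert_lr(0.01, beta1=0.9, beta2=0.999)         # scaled by (1 - beta2) / (1 - beta1): 1e-4 50 | #   convert_lr(0.01, beta1=0.9, beta2=0.999, batch_size=32)  # scaled back up by batch size: 3.2e-3 -------------------------------------------------------------------------------- /torch_optim_sparse/sparser_adam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim import Optimizer 4 | 5 | 6 | class SparserAdam(Optimizer): 7 | 8 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0): 9 | params = list(params) 10 | 11 | if not 0.0 < lr: 12 | raise ValueError("Invalid learning rate: {}".format(lr)) 13 | if not 0.0 < eps: 14 | raise ValueError("Invalid epsilon value: {}".format(eps)) 15 | if not 0.0 <= betas[0] < 1.0: 16 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 17 | if not 0.0 <= betas[1] < 1.0: 18 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 19 | if not 0.0 <= weight_decay: 20 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 21 | 22 | sparse_params = [] 23 | for index, param in enumerate(params): 24 | if isinstance(param, dict): 25 | for d_index, d_param in enumerate(param.get("params", [])): 26 | if d_param.is_sparse: 27 | sparse_params.append([index, d_index]) 28 | elif param.is_sparse: 29 | sparse_params.append(index) 30 | if sparse_params: 31 | raise ValueError( 32 | f"Sparse params at indices {sparse_params}: SparserAdam requires dense parameter tensors" 33 | ) 34 | 35 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 36 | super(SparserAdam, self).__init__(params, defaults) 37 | 38 | @torch.no_grad() 39 | def step(self, closure=None): 40 | """Performs a single optimization step. 41 | 42 | Arguments: 43 | closure (callable, optional): A closure that reevaluates the model 44 | and returns the loss.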
45 | """ 46 | loss = None 47 | if closure is not None: 48 | with torch.enable_grad(): 49 | loss = closure() 50 | 51 | for group in self.param_groups: 52 | for p in group['params']: 53 | if p.grad is None: 54 | continue 55 | grad = p.grad 56 | if not grad.is_sparse: 57 | raise RuntimeError('SparserAdam does not support dense gradients, please consider Adam instead') 58 | 59 | state = self.state[p] 60 | 61 | # State initialization 62 | if len(state) == 0: 63 | state['step'] = 0 64 | # Exponential moving average of gradient values 65 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 66 | # Exponential moving average of squared gradient values 67 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 68 | 69 | state['step'] += 1 70 | 71 | if group['weight_decay'] != 0: 72 | grad = grad.coalesce() 73 | grad.add_(p.sparse_mask(grad), alpha=group['weight_decay']) 74 | 75 | grad = grad.coalesce() # the update is non-linear so indices must be unique 76 | 77 | grad_indices = grad._indices() 78 | grad_values = grad._values() 79 | size = grad.size() 80 | 81 | def make_sparse(values): 82 | constructor = grad.new 83 | if grad_indices.dim() == 0 or values.dim() == 0: 84 | return constructor().resize_as_(grad) 85 | return constructor(grad_indices, values, size) 86 | 87 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 88 | beta1, beta2 = group['betas'] 89 | 90 | # Decay the first and second moment running average coefficient 91 | # old <- b * old + (1 - b) * new 92 | # <==> old += (1 - b) * (new - old) 93 | old_exp_avg_values = exp_avg.sparse_mask(grad)._values() 94 | exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1) 95 | exp_avg.add_(make_sparse(exp_avg_update_values)) 96 | old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values() 97 | exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2) 98 | exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values)) 99 | 100 | # Dense addition again is intended, avoiding another sparse_mask 101 | numer = exp_avg_update_values.add_(old_exp_avg_values) 102 | exp_avg_sq_update_values.add_(old_exp_avg_sq_values) 103 | denom = exp_avg_sq_update_values.sqrt_().add_(group['eps']) 104 | del exp_avg_update_values, exp_avg_sq_update_values 105 | 106 | bias_correction1 = 1 - beta1 ** state['step'] 107 | bias_correction2 = 1 - beta2 ** state['step'] 108 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 109 | 110 | p.add_(make_sparse(-step_size * numer.div_(denom))) 111 | 112 | return loss 113 | -------------------------------------------------------------------------------- /torch_optim_sparse/sparser_adamw.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim import Optimizer 4 | 5 | 6 | class SparserAdamW(Optimizer): 7 | 8 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0): 9 | params = list(params) 10 | 11 | if not 0.0 < lr: 12 | raise ValueError("Invalid learning rate: {}".format(lr)) 13 | if not 0.0 < eps: 14 | raise ValueError("Invalid epsilon value: {}".format(eps)) 15 | if not 0.0 <= betas[0] < 1.0: 16 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 17 | if not 0.0 <= betas[1] < 1.0: 18 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 19 | if not 0.0 <= weight_decay: 20 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 21 |
22 | sparse_params = [] 23 | for index, param in enumerate(params): 24 | if isinstance(param, dict): 25 | for d_index, d_param in enumerate(param.get("params", [])): 26 | if d_param.is_sparse: 27 | sparse_params.append([index, d_index]) 28 | elif param.is_sparse: 29 | sparse_params.append(index) 30 | if sparse_params: 31 | raise ValueError( 32 | f"Sparse params at indices {sparse_params}: SparserAdamW requires dense parameter tensors" 33 | ) 34 | 35 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 36 | super(SparserAdamW, self).__init__(params, defaults) 37 | 38 | @torch.no_grad() 39 | def step(self, closure=None): 40 | """Performs a single optimization step. 41 | 42 | Arguments: 43 | closure (callable, optional): A closure that reevaluates the model 44 | and returns the loss. 45 | """ 46 | loss = None 47 | if closure is not None: 48 | with torch.enable_grad(): 49 | loss = closure() 50 | 51 | for group in self.param_groups: 52 | for p in group['params']: 53 | if p.grad is None: 54 | continue 55 | 56 | # Perform optimization step 57 | grad = p.grad 58 | if not grad.is_sparse: 59 | raise RuntimeError('SparserAdamW does not support dense gradients, please consider AdamW instead') 60 | 61 | # State initialization 62 | state = self.state[p] 63 | 64 | if len(state) == 0: 65 | state['step'] = 0 66 | # Exponential moving average of gradient values 67 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) 68 | # Exponential moving average of squared gradient values 69 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 70 | 71 | state['step'] += 1 72 | 73 | grad = grad.coalesce() # the update is non-linear so indices must be unique 74 | grad_indices = grad._indices() 75 | grad_values = grad._values() 76 | size = grad.size() 77 | 78 | # Perform decoupled weight decay on the active rows only; indexing with an 79 | # in-place mul_ would modify a copy, so subtract a sparse-masked copy instead 80 | p.add_(p.sparse_mask(grad), alpha=-group['lr'] * group['weight_decay']) 81 | 82 | def make_sparse(values): 83 | constructor = grad.new 84 | if grad_indices.dim() == 0 or values.dim() == 0: 85 | return constructor().resize_as_(grad) 86 | return constructor(grad_indices, values, size) 87 | 88 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 89 | beta1, beta2 = group['betas'] 90 | 91 | # Decay the first and second moment running average coefficient 92 | # old <- b * old + (1 - b) * new 93 | # <==> old += (1 - b) * (new - old) 94 | old_exp_avg_values = exp_avg.sparse_mask(grad)._values() 95 | exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1) 96 | exp_avg.add_(make_sparse(exp_avg_update_values)) 97 | old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values() 98 | exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2) 99 | exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values)) 100 | 101 | # Dense addition again is intended, avoiding another sparse_mask 102 | numer = exp_avg_update_values.add_(old_exp_avg_values) 103 | exp_avg_sq_update_values.add_(old_exp_avg_sq_values) 104 | denom = exp_avg_sq_update_values.sqrt_().add_(group['eps']) 105 | del exp_avg_update_values, exp_avg_sq_update_values 106 | 107 | bias_correction1 = 1 - beta1 ** state['step'] 108 | bias_correction2 = 1 - beta2 ** state['step'] 109 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 110 | 111 | p.add_(make_sparse(-step_size * numer.div_(denom))) 112 | 113 | return loss 114 |
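115 | 116 | # Usage sketch (illustrative): unlike SparserAdam, which folds weight decay into 117 | # the sparse gradient (L2-style, rescaled by Adam's adaptive denominator), 118 | # SparserAdamW shrinks the active rows directly by lr * weight_decay: 119 | # 120 | #   emb = torch.nn.Embedding(1000, 16, sparse=True) 121 | #   optimizer = SparserAdamW(emb.parameters(), lr=1e-3, weight_decay=1e-2) 122 | #   emb(torch.tensor([0, 5])).sum().backward() 123 | #   optimizer.step()  # rows 0 and 5 get the Adam step plus decoupled decay; other rows are untouched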
-------------------------------------------------------------------------------- /torch_optim_sparse/sparser_sgd.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim import SGD 4 | 5 | 6 | class SparserSGD(SGD): 7 | 8 | @torch.no_grad() 9 | def step(self, closure=None): 10 | """Performs a single optimization step. 11 | 12 | Arguments: 13 | closure (callable, optional): A closure that reevaluates the model 14 | and returns the loss. 15 | """ 16 | loss = None 17 | if closure is not None: 18 | with torch.enable_grad(): 19 | loss = closure() 20 | 21 | for group in self.param_groups: 22 | lr = group['lr'] 23 | weight_decay = group['weight_decay'] 24 | momentum = group['momentum'] 25 | dampening = group['dampening'] 26 | nesterov = group['nesterov'] 27 | 28 | for p in group['params']: 29 | if p.grad is None: 30 | continue 31 | grad = p.grad 32 | 33 | grad = grad.coalesce() 34 | grad_inds = grad._indices() 35 | grad_values = grad._values() 36 | size = grad.size() 37 | 38 | def make_sparse(values): 39 | constructor = grad.new 40 | if grad_inds.dim() == 0 or values.dim() == 0: 41 | return constructor().resize_as_(grad) 42 | return constructor(grad_inds, values.reshape(grad_values.shape), size) 43 | 44 | if weight_decay != 0: 45 | param_values = p.data[grad_inds].squeeze() 46 | grad_values.add_(param_values, alpha=weight_decay) 47 | 48 | if momentum != 0: 49 | param_state = self.state[p] 50 | 51 | if 'momentum_buffer' not in param_state: 52 | buf = param_state['momentum_buffer'] = torch.clone(grad).detach().to_dense() 53 | else: 54 | buf = param_state['momentum_buffer'] 55 | # Decay the buffer only where the sparse gradient is non-zero; indexed assignment is required, since buf[grad_inds].mul_() would modify a copy 56 | buf[grad_inds] = buf[grad_inds] * momentum 57 | buf.add_(grad, alpha=(1-dampening)) 58 | 59 | mom_values = buf[grad_inds].squeeze() 60 | 61 | if nesterov: 62 | mom_values = grad_values.add(mom_values, alpha=momentum) 63 | 64 | p.data.add_(make_sparse(mom_values), alpha=-lr) 65 | else: 66 | p.add_(grad, alpha=-lr) 67 | 68 | return loss 69 | -------------------------------------------------------------------------------- /torch_optim_sparse/sparser_sgdw.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim import SGD 4 | 5 | 6 | class SparserSGDW(SGD): 7 | 8 | @torch.no_grad() 9 | def step(self, closure=None): 10 | """Performs a single optimization step. 11 | 12 | Arguments: 13 | closure (callable, optional): A closure that reevaluates the model 14 | and returns the loss.
15 | """ 16 | loss = None 17 | if closure is not None: 18 | with torch.enable_grad(): 19 | loss = closure() 20 | 21 | for group in self.param_groups: 22 | lr = group['lr'] 23 | weight_decay = group['weight_decay'] 24 | momentum = group['momentum'] 25 | dampening = group['dampening'] 26 | nesterov = group['nesterov'] 27 | 28 | for p in group['params']: 29 | if p.grad is None: 30 | continue 31 | grad = p.grad 32 | 33 | grad = grad.coalesce() 34 | grad_inds = grad._indices() 35 | grad_values = grad._values() 36 | size = grad.size() 37 | 38 | def make_sparse(values): 39 | constructor = grad.new 40 | if grad_inds.dim() == 0 or values.dim() == 0: 41 | return constructor().resize_as_(grad) 42 | return constructor(grad_inds, values.reshape(grad_values.shape), size) 43 | 44 | if momentum != 0: 45 | param_state = self.state[p] 46 | 47 | if "momentum_buffer" not in param_state: 48 | buf = param_state["momentum_buffer"] = torch.clone(grad).detach().to_dense() 49 | else: 50 | buf = param_state["momentum_buffer"] 51 | # Decay the buffer only where the sparse gradient is non-zero; indexed assignment is required, since buf[grad_inds].mul_() would modify a copy 52 | buf[grad_inds] = buf[grad_inds] * momentum 53 | buf.add_(grad, alpha=(1-dampening)) 54 | 55 | mom_values = buf[grad_inds].squeeze() 56 | 57 | if nesterov: 58 | mom_values = grad_values.add(mom_values, alpha=momentum) 59 | 60 | p.data.add_(make_sparse(mom_values), alpha=-lr) 61 | else: 62 | p.add_(grad, alpha=-lr) 63 | 64 | if weight_decay != 0: 65 | p.add_(p.sparse_mask(grad), alpha=-lr*weight_decay) 66 | 67 | return loss 68 |
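69 | 70 | # Quick check (illustrative): rows that never receive a gradient are left untouched 71 | # by the momentum and weight decay updates. 72 | # 73 | #   emb = torch.nn.Embedding(10, 4, sparse=True) 74 | #   opt = SparserSGDW(emb.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-2) 75 | #   before = emb.weight.detach().clone() 76 | #   emb(torch.tensor([1, 3])).sum().backward() 77 | #   opt.step() 78 | #   after = emb.weight.detach() 79 | #   assert torch.equal(before[[0, 2, 4]], after[[0, 2, 4]])  # inactive rows unchanged 80 | #   assert not torch.equal(before[[1, 3]], after[[1, 3]])    # active rows updated --------------------------------------------------------------------------------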