├── .gitignore ├── LICENSE ├── README.md ├── assets └── default_schedule.png ├── psgd_jax ├── __init__.py ├── affine.py ├── kron.py ├── low_rank_approximation.py ├── psgd_test.py ├── utils.py └── xmat.py └── pyproject.toml /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | .idea/ 163 | 164 | # wandb 165 | wandb 166 | 167 | .DS_Store 168 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. 
More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution 4.0 International Public License 58 | 59 | By exercising the Licensed Rights (defined below), You accept and agree 60 | to be bound by the terms and conditions of this Creative Commons 61 | Attribution 4.0 International Public License ("Public License"). To the 62 | extent this Public License may be interpreted as a contract, You are 63 | granted the Licensed Rights in consideration of Your acceptance of 64 | these terms and conditions, and the Licensor grants You such rights in 65 | consideration of benefits the Licensor receives from making the 66 | Licensed Material available under these terms and conditions. 67 | 68 | 69 | Section 1 -- Definitions. 70 | 71 | a. Adapted Material means material subject to Copyright and Similar 72 | Rights that is derived from or based upon the Licensed Material 73 | and in which the Licensed Material is translated, altered, 74 | arranged, transformed, or otherwise modified in a manner requiring 75 | permission under the Copyright and Similar Rights held by the 76 | Licensor. For purposes of this Public License, where the Licensed 77 | Material is a musical work, performance, or sound recording, 78 | Adapted Material is always produced where the Licensed Material is 79 | synched in timed relation with a moving image. 80 | 81 | b. Adapter's License means the license You apply to Your Copyright 82 | and Similar Rights in Your contributions to Adapted Material in 83 | accordance with the terms and conditions of this Public License. 84 | 85 | c. Copyright and Similar Rights means copyright and/or similar rights 86 | closely related to copyright including, without limitation, 87 | performance, broadcast, sound recording, and Sui Generis Database 88 | Rights, without regard to how the rights are labeled or 89 | categorized. For purposes of this Public License, the rights 90 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 91 | Rights. 92 | 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. 
Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. Share means to provide material to the public by any means or 116 | process that requires permission under the Licensed Rights, such 117 | as reproduction, public display, public performance, distribution, 118 | dissemination, communication, or importation, and to make material 119 | available to the public including in ways that members of the 120 | public may access the material from a place and at a time 121 | individually chosen by them. 122 | 123 | j. Sui Generis Database Rights means rights other than copyright 124 | resulting from Directive 96/9/EC of the European Parliament and of 125 | the Council of 11 March 1996 on the legal protection of databases, 126 | as amended and/or succeeded, as well as other essentially 127 | equivalent rights anywhere in the world. 128 | 129 | k. You means the individual or entity exercising the Licensed Rights 130 | under this Public License. Your has a corresponding meaning. 131 | 132 | 133 | Section 2 -- Scope. 134 | 135 | a. License grant. 136 | 137 | 1. Subject to the terms and conditions of this Public License, 138 | the Licensor hereby grants You a worldwide, royalty-free, 139 | non-sublicensable, non-exclusive, irrevocable license to 140 | exercise the Licensed Rights in the Licensed Material to: 141 | 142 | a. reproduce and Share the Licensed Material, in whole or 143 | in part; and 144 | 145 | b. produce, reproduce, and Share Adapted Material. 146 | 147 | 2. Exceptions and Limitations. For the avoidance of doubt, where 148 | Exceptions and Limitations apply to Your use, this Public 149 | License does not apply, and You do not need to comply with 150 | its terms and conditions. 151 | 152 | 3. Term. The term of this Public License is specified in Section 153 | 6(a). 154 | 155 | 4. Media and formats; technical modifications allowed. The 156 | Licensor authorizes You to exercise the Licensed Rights in 157 | all media and formats whether now known or hereafter created, 158 | and to make technical modifications necessary to do so. The 159 | Licensor waives and/or agrees not to assert any right or 160 | authority to forbid You from making technical modifications 161 | necessary to exercise the Licensed Rights, including 162 | technical modifications necessary to circumvent Effective 163 | Technological Measures. For purposes of this Public License, 164 | simply making modifications authorized by this Section 2(a) 165 | (4) never produces Adapted Material. 166 | 167 | 5. Downstream recipients. 168 | 169 | a. Offer from the Licensor -- Licensed Material. Every 170 | recipient of the Licensed Material automatically 171 | receives an offer from the Licensor to exercise the 172 | Licensed Rights under the terms and conditions of this 173 | Public License. 174 | 175 | b. No downstream restrictions. 
You may not offer or impose 176 | any additional or different terms or conditions on, or 177 | apply any Effective Technological Measures to, the 178 | Licensed Material if doing so restricts exercise of the 179 | Licensed Rights by any recipient of the Licensed 180 | Material. 181 | 182 | 6. No endorsement. Nothing in this Public License constitutes or 183 | may be construed as permission to assert or imply that You 184 | are, or that Your use of the Licensed Material is, connected 185 | with, or sponsored, endorsed, or granted official status by, 186 | the Licensor or others designated to receive attribution as 187 | provided in Section 3(a)(1)(A)(i). 188 | 189 | b. Other rights. 190 | 191 | 1. Moral rights, such as the right of integrity, are not 192 | licensed under this Public License, nor are publicity, 193 | privacy, and/or other similar personality rights; however, to 194 | the extent possible, the Licensor waives and/or agrees not to 195 | assert any such rights held by the Licensor to the limited 196 | extent necessary to allow You to exercise the Licensed 197 | Rights, but not otherwise. 198 | 199 | 2. Patent and trademark rights are not licensed under this 200 | Public License. 201 | 202 | 3. To the extent possible, the Licensor waives any right to 203 | collect royalties from You for the exercise of the Licensed 204 | Rights, whether directly or through a collecting society 205 | under any voluntary or waivable statutory or compulsory 206 | licensing scheme. In all other cases the Licensor expressly 207 | reserves any right to collect such royalties. 208 | 209 | 210 | Section 3 -- License Conditions. 211 | 212 | Your exercise of the Licensed Rights is expressly made subject to the 213 | following conditions. 214 | 215 | a. Attribution. 216 | 217 | 1. If You Share the Licensed Material (including in modified 218 | form), You must: 219 | 220 | a. retain the following if it is supplied by the Licensor 221 | with the Licensed Material: 222 | 223 | i. identification of the creator(s) of the Licensed 224 | Material and any others designated to receive 225 | attribution, in any reasonable manner requested by 226 | the Licensor (including by pseudonym if 227 | designated); 228 | 229 | ii. a copyright notice; 230 | 231 | iii. a notice that refers to this Public License; 232 | 233 | iv. a notice that refers to the disclaimer of 234 | warranties; 235 | 236 | v. a URI or hyperlink to the Licensed Material to the 237 | extent reasonably practicable; 238 | 239 | b. indicate if You modified the Licensed Material and 240 | retain an indication of any previous modifications; and 241 | 242 | c. indicate the Licensed Material is licensed under this 243 | Public License, and include the text of, or the URI or 244 | hyperlink to, this Public License. 245 | 246 | 2. You may satisfy the conditions in Section 3(a)(1) in any 247 | reasonable manner based on the medium, means, and context in 248 | which You Share the Licensed Material. For example, it may be 249 | reasonable to satisfy the conditions by providing a URI or 250 | hyperlink to a resource that includes the required 251 | information. 252 | 253 | 3. If requested by the Licensor, You must remove any of the 254 | information required by Section 3(a)(1)(A) to the extent 255 | reasonably practicable. 256 | 257 | 4. If You Share Adapted Material You produce, the Adapter's 258 | License You apply must not prevent recipients of the Adapted 259 | Material from complying with this Public License. 
260 | 261 | 262 | Section 4 -- Sui Generis Database Rights. 263 | 264 | Where the Licensed Rights include Sui Generis Database Rights that 265 | apply to Your use of the Licensed Material: 266 | 267 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 268 | to extract, reuse, reproduce, and Share all or a substantial 269 | portion of the contents of the database; 270 | 271 | b. if You include all or a substantial portion of the database 272 | contents in a database in which You have Sui Generis Database 273 | Rights, then the database in which You have Sui Generis Database 274 | Rights (but not its individual contents) is Adapted Material; and 275 | 276 | c. You must comply with the conditions in Section 3(a) if You Share 277 | all or a substantial portion of the contents of the database. 278 | 279 | For the avoidance of doubt, this Section 4 supplements and does not 280 | replace Your obligations under this Public License where the Licensed 281 | Rights include other Copyright and Similar Rights. 282 | 283 | 284 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 285 | 286 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 287 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 288 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 289 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 290 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 291 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 292 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 293 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 294 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 295 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 296 | 297 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 298 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 299 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 300 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 301 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 302 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 303 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 304 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 305 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 306 | 307 | c. The disclaimer of warranties and limitation of liability provided 308 | above shall be interpreted in a manner that, to the extent 309 | possible, most closely approximates an absolute disclaimer and 310 | waiver of all liability. 311 | 312 | 313 | Section 6 -- Term and Termination. 314 | 315 | a. This Public License applies for the term of the Copyright and 316 | Similar Rights licensed here. However, if You fail to comply with 317 | this Public License, then Your rights under this Public License 318 | terminate automatically. 319 | 320 | b. Where Your right to use the Licensed Material has terminated under 321 | Section 6(a), it reinstates: 322 | 323 | 1. automatically as of the date the violation is cured, provided 324 | it is cured within 30 days of Your discovery of the 325 | violation; or 326 | 327 | 2. upon express reinstatement by the Licensor. 328 | 329 | For the avoidance of doubt, this Section 6(b) does not affect any 330 | right the Licensor may have to seek remedies for Your violations 331 | of this Public License. 332 | 333 | c. 
For the avoidance of doubt, the Licensor may also offer the 334 | Licensed Material under separate terms or conditions or stop 335 | distributing the Licensed Material at any time; however, doing so 336 | will not terminate this Public License. 337 | 338 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 339 | License. 340 | 341 | 342 | Section 7 -- Other Terms and Conditions. 343 | 344 | a. The Licensor shall not be bound by any additional or different 345 | terms or conditions communicated by You unless expressly agreed. 346 | 347 | b. Any arrangements, understandings, or agreements regarding the 348 | Licensed Material not stated herein are separate from and 349 | independent of the terms and conditions of this Public License. 350 | 351 | 352 | Section 8 -- Interpretation. 353 | 354 | a. For the avoidance of doubt, this Public License does not, and 355 | shall not be interpreted to, reduce, limit, restrict, or impose 356 | conditions on any use of the Licensed Material that could lawfully 357 | be made without permission under this Public License. 358 | 359 | b. To the extent possible, if any provision of this Public License is 360 | deemed unenforceable, it shall be automatically reformed to the 361 | minimum extent necessary to make it enforceable. If the provision 362 | cannot be reformed, it shall be severed from this Public License 363 | without affecting the enforceability of the remaining terms and 364 | conditions. 365 | 366 | c. No term or condition of this Public License will be waived and no 367 | failure to comply consented to unless expressly agreed to by the 368 | Licensor. 369 | 370 | d. Nothing in this Public License constitutes or may be interpreted 371 | as a limitation upon, or waiver of, any privileges and immunities 372 | that apply to the Licensor or You, including from the legal 373 | processes of any jurisdiction or authority. 374 | 375 | 376 | ======================================================================= 377 | 378 | Creative Commons is not a party to its public 379 | licenses. Notwithstanding, Creative Commons may elect to apply one of 380 | its public licenses to material it publishes and in those instances 381 | will be considered the “Licensor.” The text of the Creative Commons 382 | public licenses is dedicated to the public domain under the CC0 Public 383 | Domain Dedication. Except for the limited purpose of indicating that 384 | material is shared under a Creative Commons public license or as 385 | otherwise permitted by the Creative Commons policies published at 386 | creativecommons.org/policies, Creative Commons does not authorize the 387 | use of the trademark "Creative Commons" or any other trademark or logo 388 | of Creative Commons without its prior written consent including, 389 | without limitation, in connection with any unauthorized modifications 390 | to any of its public licenses or any other arrangements, 391 | understandings, or agreements concerning use of licensed material. For 392 | the avoidance of doubt, this paragraph does not form part of the 393 | public licenses. 394 | 395 | Creative Commons may be contacted at creativecommons.org. 396 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PSGD (Preconditioned Stochastic Gradient Descent) 2 | 3 | For original PSGD repo, see [psgd_torch](https://github.com/lixilinx/psgd_torch). 
4 | 5 | For PyTorch Kron version, see [kron_torch](https://github.com/evanatyourservice/kron_torch). 6 | 7 | Implementations of [PSGD optimizers](https://github.com/lixilinx/psgd_torch) in JAX (optax-style). 8 | PSGD is a second-order optimizer originally created by Xi-Lin Li that uses either a hessian-based 9 | or whitening-based (gg^T) preconditioner and Lie groups to improve training convergence, 10 | generalization, and efficiency. I highly suggest taking a look at the readme of Xi-Lin's PSGD repo, linked 11 | above, for interesting details on how PSGD works and for experiments using PSGD. There are also 12 | paper resources listed near the bottom of this readme. 13 | 14 | ### `kron`: 15 | 16 | The most versatile and easy-to-use PSGD optimizer is `kron`, which uses a Kronecker-factored 17 | preconditioner. It has fewer hyperparameters that need tuning than adam, and can generally act as a 18 | drop-in replacement. 19 | 20 | ## Installation 21 | 22 | ```bash 23 | pip install psgd-jax 24 | ``` 25 | 26 | ## Basic Usage (Kron) 27 | 28 | By default, kron schedules the preconditioner update probability to start at 1.0 and anneal to 0.03 29 | over the beginning of training, so training will be slightly slower at the start but will speed up 30 | by around 4k steps. 31 | 32 | For basic usage, use the `kron` optimizer like any other optax optimizer: 33 | 34 | ```python 35 | from psgd_jax.kron import kron 36 | 37 | optimizer = kron() 38 | opt_state = optimizer.init(params) 39 | 40 | updates, opt_state = optimizer.update(grads, opt_state) 41 | params = optax.apply_updates(params, updates) 42 | ``` 43 | 44 | **Basic hyperparameters:** 45 | 46 | TLDR: Learning rate and weight decay act similarly to adam's; start with adam-like settings and go 47 | from there. Maybe use a slightly lower learning rate (for example, half of adam's). There is no b2 or epsilon. 48 | 49 | These next 3 settings control whether a dimension's preconditioner is diagonal or triangular. 50 | For example, for a layer with shape (256, 128), triangular preconditioners would be shapes (256, 256) 51 | and (128, 128), and diagonal preconditioners would be shapes (256,) and (128,). Depending on how 52 | these settings are chosen, `kron` can balance between memory/speed and effectiveness. Defaults lead 53 | to most preconditioners being triangular except for 1-dimensional layers and very large dimensions. 54 | 55 | `max_size_triangular`: Any dimension with size above this value will have a diagonal preconditioner. 56 | 57 | `min_ndim_triangular`: Any tensor with fewer than this number of dims will have all diagonal 58 | preconditioners. Default is 2, so single-dim layers like bias and scale will use diagonal 59 | preconditioners. 60 | 61 | `memory_save_mode`: Can be None, 'one_diag', or 'all_diag'. None is default and lets all 62 | preconditioners be triangular. 'one_diag' sets the largest or last dim per layer as diagonal 63 | using `np.argsort(shape)[::-1][0]`. 'all_diag' sets all preconditioners to be diagonal. 64 | 65 | `preconditioner_update_probability`: Preconditioner update probability uses a schedule by default 66 | that works well for most cases. It anneals from 1 to 0.03 at the beginning of training, so training 67 | will be slightly slower at the start but will speed up by around 4k steps. PSGD generally benefits 68 | from more preconditioner updates at the start of training, but once the preconditioner is learned 69 | it's okay to do them less often. An easy way to adjust update frequency is to define your own schedule 70 | using the `precond_update_prob_schedule` function in kron.py (just changing the `min_prob` value 71 | is easiest) and pass this into kron through the `preconditioner_update_probability` hyperparameter, for example:
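A minimal sketch of a custom schedule (the `min_prob` value here is only illustrative, and `kron`'s other arguments are left at their defaults):

```python
from psgd_jax.kron import kron, precond_update_prob_schedule

# same exponential anneal as the default schedule, but with a higher floor so
# the preconditioner keeps updating 10% of the time once annealed
optimizer = kron(
    preconditioner_update_probability=precond_update_prob_schedule(min_prob=0.1)
)
```

A constant float (e.g. `preconditioner_update_probability=0.1`) also works if you prefer a fixed update frequency.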
72 | 73 | This is the default schedule defined in the `precond_update_prob_schedule` function at the top of kron.py: 74 | 75 | ![Default Schedule](assets/default_schedule.png) 76 | 77 | 78 | **Sharding:** 79 | 80 | Kron contains einsums, and in general the first axis of the preconditioner matrices is the 81 | contracting axis. 82 | 83 | If using only FSDP, I usually shard the last axis of each preconditioner matrix and call it good. 84 | 85 | However, if using tensor parallelism in addition to FSDP, you may want to think more carefully about how 86 | the preconditioners are sharded in train_state. For example, with grads of shape (256, 128) and kron 87 | preconditioners of shapes (256, 256) and (128, 128), if the grads are sharded as (fsdp, tensor), 88 | then you may want to shard the (256, 256) preconditioner as (fsdp, tensor) and the (128, 128) 89 | preconditioner as (tensor, fsdp) so the grads and their preconditioners have similar contracting axes. 90 | 91 | 92 | **Scanned layers:** 93 | 94 | If you are scanning layers in your network, you can also have kron scan over these layers while 95 | updating and applying the preconditioner. Simply pass in a pytree through `scanned_layers` with 96 | the same structure as your params, with bool values indicating which layers are scanned. PSGD will 97 | vmap over the first dims of those layers. If you need a more advanced scanning setup, please open 98 | an issue. 99 | 100 | For very large models, the preconditioner update may use too much memory all at once when scanning, 101 | in which case you can set `lax_map_scanned_layers` to `True` and set `lax_map_batch_size` to a 102 | reasonable batch size for your setup (`lax.map` scans over batches of vmap, see JAX docs). If 103 | your net has 32 scanned layers and you're hitting OOM during the optimizer step, setting `lax_map_batch_size` to 104 | 16 or 8 will split the preconditioner update into 2 or 4 batches respectively. 105 | 106 | 107 | ## Advanced Usage (XMat, LRA, Affine) 108 | 109 | Other forms of PSGD include XMat, LRA, and Affine. PSGD defaults to a gradient 110 | whitening type preconditioner (gg^T).
In this case, you can use PSGD like any other 111 | optax optimizer: 112 | 113 | ```python 114 | import jax 115 | import jax.numpy as jnp 116 | import optax 117 | from psgd_jax.xmat import xmat # or low_rank_approximation, affine 118 | 119 | 120 | def loss_fn(params, x): 121 | return jnp.sum((params - x) ** 2) 122 | 123 | 124 | params = jnp.array([1.0, 2.0, 3.0]) 125 | x = jnp.array([0.0, 0.0, 0.0]) 126 | 127 | # make optimizer and init state 128 | opt = xmat( 129 | learning_rate=1.0, 130 | b1=0.0, 131 | preconditioner_update_probability=1.0, # preconditioner update frequency 132 | ) 133 | opt_state = opt.init(params) 134 | 135 | 136 | def step(params, x, opt_state): 137 | loss_val, grad = jax.value_and_grad(loss_fn)(params, x) 138 | updates, opt_state = opt.update(grad, opt_state) 139 | params = optax.apply_updates(params, updates) 140 | return params, opt_state, loss_val 141 | 142 | 143 | while True: 144 | params, opt_state, loss_val = step(params, x, opt_state) 145 | print(loss_val) 146 | if loss_val < 1e-4: 147 | print("yay") 148 | break 149 | 150 | # Expected output: 151 | # 14.0 152 | # 5.1563816 153 | # 1.7376599 154 | # 0.6118454 155 | # 0.18457186 156 | # 0.056664664 157 | # 0.014270116 158 | # 0.0027846962 159 | # 0.00018843572 160 | # 4.3836744e-06 161 | # yay 162 | ``` 163 | 164 | However, PSGD can also be used with a hessian vector product. If values are provided for PSGD's extra 165 | update function arguments `Hvp`, `vector`, and `update_preconditioner`, PSGD automatically 166 | uses hessian-based preconditioning. `Hvp` is the hessian vector product, `vector` is the random 167 | vector used to calculate the hessian vector product, and `update_preconditioner` is a boolean 168 | that tells PSGD whether we're updating the preconditioner this step (passed in real hvp and 169 | vector) or not (passed in dummy hvp and vector). 170 | 171 | The `hessian_helper` function can help with this and generally replace `jax.value_and_grad`: 172 | 173 | ```python 174 | import jax 175 | import jax.numpy as jnp 176 | import optax 177 | from psgd_jax.xmat import xmat # or low_rank_approximation, affine 178 | from psgd_jax import hessian_helper 179 | 180 | 181 | def loss_fn(params, x): 182 | return jnp.sum((params - x) ** 2) 183 | 184 | 185 | params = jnp.array([1.0, 2.0, 3.0]) 186 | x = jnp.array([0.0, 0.0, 0.0]) 187 | 188 | # make optimizer and init state 189 | # no need to set 'preconditioner_update_probability' here, it's handled by hessian_helper 190 | opt = xmat( 191 | learning_rate=1.0, 192 | b1=0.0, 193 | ) 194 | opt_state = opt.init(params) 195 | 196 | 197 | def step(key, params, x, opt_state): 198 | # replace jax.value_and_grad with the hessian_helper: 199 | key, subkey = jax.random.split(key) 200 | loss_fn_out, grad, hvp, vector, update_precond = hessian_helper( 201 | subkey, 202 | loss_fn, 203 | params, 204 | loss_fn_extra_args=(x,), 205 | has_aux=False, 206 | preconditioner_update_probability=1.0, # update frequency handled in hessian_helper 207 | ) 208 | loss_val = loss_fn_out 209 | 210 | # Pass hvp, random vector, and whether we're updating the preconditioner 211 | # this step into the update function. PSGD will automatically switch to 212 | # hessian-based preconditioning when these are provided. 
213 | updates, opt_state = opt.update( 214 | grad, 215 | opt_state, 216 | Hvp=hvp, 217 | vector=vector, 218 | update_preconditioner=update_precond 219 | ) 220 | 221 | params = optax.apply_updates(params, updates) 222 | return key, params, opt_state, loss_val 223 | 224 | 225 | key = jax.random.PRNGKey(0) 226 | while True: 227 | key, params, opt_state, loss_val = step(key, params, x, opt_state) 228 | print(loss_val) 229 | if loss_val < 1e-4: 230 | print("yay") 231 | break 232 | 233 | # Expected output: 234 | # 14.0 235 | # 7.460699e-14 236 | # yay 237 | ``` 238 | 239 | If `preconditioner_update_probability` is lowered, time is saved by calculating the hessian less 240 | often, but convergence could be slower. 241 | 242 | ## PSGD variants 243 | 244 | `psgd_jax.kron` - `psgd_jax.xmat` - `psgd_jax.low_rank_approximation` - `psgd_jax.affine` 245 | 246 | There are four variants of PSGD: Kron, which uses Kronecker-factored preconditioners for tensors 247 | of any number of dimensions; XMat, which uses an x-shaped global preconditioner; LRA, which uses 248 | a low-rank approximation global preconditioner; and Affine, which uses Kronecker-factored 249 | preconditioners for matrices. 250 | 251 | **Kron:** 252 | 253 | Kron uses Kronecker-factored preconditioners for tensors of any number of dimensions. It's very 254 | versatile, has fewer hyperparameters that need tuning than adam, and can generally act as a drop-in 255 | replacement for adam. 256 | 257 | **XMat:** 258 | 259 | XMat is very simple to use, uses global hessian information for its preconditioner, and has 260 | memory use of only n_params * 3 (including momentum, which is optional; set b1 to 0 to disable). 261 | 262 | **LRA:** 263 | 264 | LRA uses a low-rank approximation of the hessian for its preconditioner and can give very strong 265 | results. It has memory use of n_params * (2 * rank + 1) (n_params * (2 * rank) without momentum). 266 | 267 | **Affine:** 268 | 269 | Affine does not use global hessian information, but can be powerful nonetheless and may use 270 | less memory than xmat, LRA, or adam. `max_size_triangular` and `max_skew_triangular` determine whether 271 | a dimension's preconditioner is triangular or diagonal. Affine and Kron are nearly identical for matrices. 272 | 273 | 274 | ## Resources 275 | 276 | PSGD papers and resources, as listed in Xi-Lin's repo: 277 | 278 | 1) Xi-Lin Li. Preconditioned stochastic gradient descent, [arXiv:1512.04202](https://arxiv.org/abs/1512.04202), 2015. (General ideas of PSGD, preconditioner fitting losses and Kronecker product preconditioners.) 279 | 2) Xi-Lin Li. Preconditioner on matrix Lie group for SGD, [arXiv:1809.10232](https://arxiv.org/abs/1809.10232), 2018. (Focus on preconditioners with the affine Lie group.) 280 | 3) Xi-Lin Li. Black box Lie group preconditioners for SGD, [arXiv:2211.04422](https://arxiv.org/abs/2211.04422), 2022. (Mainly about the LRA preconditioner. See [these supplementary materials](https://drive.google.com/file/d/1CTNx1q67_py87jn-0OI-vSLcsM1K7VsM/view) for detailed math derivations.) 281 | 4) Xi-Lin Li. Stochastic Hessian fittings on Lie groups, [arXiv:2402.11858](https://arxiv.org/abs/2402.11858), 2024. (Some theoretical works on the efficiency of PSGD. The Hessian fitting problem is shown to be strongly convex on set ${\rm GL}(n, \mathbb{R})/R_{\rm polar}$.) 282 | 5) Omead Pooladzandi, Xi-Lin Li. Curvature-informed SGD via general purpose Lie-group preconditioners, [arXiv:2402.04553](https://arxiv.org/abs/2402.04553), 2024.
(Plenty of benchmark results and analyses for PSGD vs. other optimizers.) 283 | 284 | 285 | ## License 286 | 287 | [![CC BY 4.0][cc-by-image]][cc-by] 288 | 289 | This work is licensed under a [Creative Commons Attribution 4.0 International License][cc-by]. 290 | 291 | 2024 Evan Walters, Omead Pooladzandi, Xi-Lin Li 292 | 293 | 294 | [cc-by]: http://creativecommons.org/licenses/by/4.0/ 295 | [cc-by-image]: https://licensebuttons.net/l/by/4.0/88x31.png 296 | [cc-by-shield]: https://img.shields.io/badge/License-CC%20BY%204.0-lightgrey.svg 297 | -------------------------------------------------------------------------------- /assets/default_schedule.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/evanatyourservice/psgd_jax/be9552446381ce1ad9fcf9632c5c1045a9fa44c3/assets/default_schedule.png -------------------------------------------------------------------------------- /psgd_jax/__init__.py: -------------------------------------------------------------------------------- 1 | from psgd_jax.utils import hessian_helper 2 | from psgd_jax.kron import precond_update_prob_schedule -------------------------------------------------------------------------------- /psgd_jax/affine.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Union, Callable, NamedTuple, List 2 | 3 | import jax 4 | from jax import numpy as jnp 5 | from jax.random import PRNGKey 6 | 7 | from optax import tree_utils as otu 8 | from optax._src import base, transform, clipping 9 | from optax._src.numerics import safe_int32_increment 10 | from optax._src.utils import canonicalize_dtype 11 | from optax._src.combine import chain 12 | 13 | from psgd_jax.utils import add_eps, apply_momentum 14 | 15 | 16 | class PSGDAffineState(NamedTuple): 17 | count: jax.Array 18 | key: PRNGKey 19 | mu: Optional[base.Updates] 20 | Qs: List[List[jax.Array]] 21 | 22 | 23 | def scale_by_affine( 24 | preconditioner_update_probability: float = 1.0, 25 | b1: float = 0.9, 26 | nesterov: bool = False, 27 | max_size_triangular: int = 4096, 28 | max_skew_triangular: int = 128, 29 | precond_lr: Union[float, Callable[[int], float]] = 0.1, 30 | precond_init_scale: Optional[float] = None, 31 | update_global_norm_clip: Optional[float] = None, 32 | step_normalizer_order: str = "2nd", 33 | seed: Optional[PRNGKey] = None, 34 | mu_dtype: Optional[Union[str, jnp.dtype]] = None, 35 | precision: str = "tensorfloat32", 36 | ) -> base.GradientTransformationExtraArgs: 37 | """ 38 | Implements Affine PSGD from https://github.com/lixilinx/psgd_torch. 39 | 40 | Args: 41 | preconditioner_update_probability: float, probability of updating the 42 | preconditioner. 43 | b1: float, momentum parameter. 44 | nesterov: bool, whether to use Nesterov momentum. 45 | max_size_triangular: int, max size for affine preconditioner to be 46 | triangular. 47 | max_skew_triangular: int, max skew for affine preconditioner to be 48 | triangular. 49 | precond_lr: float or callable, learning rate for the preconditioner. 50 | precond_init_scale: optional float, initial scale for the preconditioner. 51 | update_global_norm_clip: optional float, clip updates by global norm. 52 | step_normalizer_order: str, '1st' or '2nd'. 53 | seed: Optional PRNGKey, random seed. 54 | mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator. 55 | Defaults to the same dtype as the parameters. 
56 | precision: str, precision for matmul, 'bfloat16', 'tensorfloat32', 'float32'. 57 | 58 | Returns: 59 | optax.GradientTransformationExtraArgs 60 | """ 61 | mu_dtype = canonicalize_dtype(mu_dtype) 62 | 63 | def init_fn(params): 64 | key = seed if seed is not None else jax.random.PRNGKey(36) 65 | 66 | # momentum 67 | mu = None 68 | if b1 > 0: 69 | print("PSGD: Using momentum.") 70 | mu = otu.tree_zeros_like(params, mu_dtype) 71 | 72 | # preconditioners 73 | affine_reshapers = [_shape_as_matrix(x) for x in jax.tree.leaves(params)] 74 | Qs = [ 75 | _initQ(s[2], max_size_triangular, max_skew_triangular, jnp.float32) 76 | for s in affine_reshapers 77 | ] 78 | 79 | # initial state 80 | return PSGDAffineState(count=jnp.zeros([], jnp.int32), key=key, mu=mu, Qs=Qs) 81 | 82 | def update_fn( 83 | updates: base.Updates, 84 | state: PSGDAffineState, 85 | params: base.Params = None, 86 | Hvp: Optional[base.Updates] = None, 87 | vector: Optional[base.Updates] = None, 88 | update_preconditioner: Optional[bool] = None, 89 | ): 90 | del params 91 | # use hessian preconditioning if hessian provided 92 | # otherwise use gg^T whitening type preconditioning 93 | hessian_based_preconditioning = Hvp is not None 94 | if hessian_based_preconditioning and ( 95 | vector is None or update_preconditioner is None 96 | ): 97 | raise ValueError( 98 | "If using Hessian-based preconditioning, must also pass in random vector and " 99 | "update_preconditioner to PSGD's update function. See README for more info." 100 | ) 101 | 102 | count_inc = safe_int32_increment(state.count) 103 | key = state.key 104 | affine_reshapers = [_shape_as_matrix(x) for x in jax.tree.leaves(updates)] 105 | 106 | precond_lr_in = precond_lr 107 | if isinstance(precond_lr, Callable): 108 | precond_lr_in = precond_lr(count_inc) 109 | 110 | def _update_precond(key: PRNGKey, state: PSGDAffineState, Hvs, vs): 111 | Hvs = [r[0](x) for x, r in zip(jax.tree.leaves(Hvs), affine_reshapers)] 112 | 113 | if hessian_based_preconditioning: 114 | vs = [r[0](x) for x, r in zip(jax.tree.leaves(vs), affine_reshapers)] 115 | 116 | # init Qs 117 | def init_q(v, h): 118 | if precond_init_scale is not None: 119 | return precond_init_scale 120 | else: 121 | return (jnp.sum(v * v.conj()) / jnp.sum(h * h.conj())) ** 0.25 122 | 123 | Qs = jax.lax.cond( 124 | state.count == 0, 125 | lambda: [ 126 | [init_q(v, h) ** 0.5 * q for q in Qlr] 127 | for v, h, Qlr in zip(vs, Hvs, state.Qs) 128 | ], 129 | lambda: state.Qs, 130 | ) 131 | 132 | # update preconditioner 133 | key, subkey = jax.random.split(key) 134 | keys = jax.random.split(subkey, len(Qs)) 135 | Qs = [ 136 | _update_precond_affine_math_( 137 | k, 138 | Qlr[0], 139 | Qlr[1], 140 | v, 141 | h, 142 | precond_lr_in, 143 | step_normalizer_order, 144 | precision, 145 | ) 146 | for (k, Qlr, v, h) in zip( 147 | keys, Qs, jax.tree.leaves(vs), jax.tree.leaves(Hvs) 148 | ) 149 | ] 150 | else: 151 | # init Qs 152 | def init_q(g): 153 | if precond_init_scale is not None: 154 | return precond_init_scale 155 | else: 156 | return (g.size / jnp.sum(g * g.conj())) ** 0.25 157 | 158 | Qs = jax.lax.cond( 159 | state.count == 0, 160 | lambda: [ 161 | [init_q(g) ** 0.5 * q for q in Qlr] 162 | for g, Qlr in zip(Hvs, state.Qs) 163 | ], 164 | lambda: state.Qs, 165 | ) 166 | 167 | # update preconditioner 168 | key, subkey = jax.random.split(key) 169 | keys = jax.random.split(subkey, len(Qs)) 170 | Qs = [ 171 | _update_precond_affine_dropv_math( 172 | k, 173 | Qlr[0], 174 | Qlr[1], 175 | h, 176 | precond_lr_in, 177 | step_normalizer_order, 178 | 
precision, 179 | ) 180 | for (k, Qlr, h) in zip(keys, Qs, jax.tree.leaves(Hvs)) 181 | ] 182 | 183 | return key, Qs 184 | 185 | def _dont_update_precond(key, state, Hvs, vs): 186 | return key, state.Qs 187 | 188 | if not hessian_based_preconditioning: 189 | # update cond not passed in, create here 190 | key, subkey = jax.random.split(key) 191 | update_preconditioner = jnp.logical_or( 192 | jax.random.uniform(subkey) < preconditioner_update_probability, 193 | state.count < 2, 194 | ) 195 | # use grads as Hvp 196 | Hvp = updates 197 | 198 | key, Qs = jax.lax.cond( 199 | update_preconditioner, 200 | _update_precond, 201 | _dont_update_precond, 202 | key, 203 | state, 204 | Hvp, 205 | vector, 206 | ) 207 | 208 | # momentum 209 | mu = None 210 | if state.mu is not None: 211 | updates, mu = apply_momentum(updates, state.mu, count_inc, b1, nesterov) 212 | 213 | # preconditioning 214 | flat_updates = [ 215 | r[0](u) for u, r in zip(jax.tree.leaves(updates), affine_reshapers) 216 | ] 217 | flat_updates = [ 218 | _precond_grad_affine_math(Qlr[0], Qlr[1], g) 219 | for (Qlr, g) in zip(Qs, flat_updates) 220 | ] 221 | flat_updates = [r[1](u) for u, r in zip(flat_updates, affine_reshapers)] 222 | updates = jax.tree_unflatten(jax.tree.structure(updates), flat_updates) 223 | 224 | # clipping 225 | if update_global_norm_clip is not None: 226 | updates, _ = clipping.clip_by_global_norm(update_global_norm_clip).update( 227 | updates, base.EmptyState 228 | ) 229 | 230 | mu = otu.tree_cast(mu, mu_dtype) 231 | state = PSGDAffineState(count=count_inc, key=key, mu=mu, Qs=Qs) 232 | return updates, state 233 | 234 | return base.GradientTransformationExtraArgs(init_fn, update_fn) 235 | 236 | 237 | def affine( 238 | learning_rate: Union[float, Callable[[int], float]] = 0.01, 239 | preconditioner_update_probability: float = 1.0, 240 | b1: float = 0.9, 241 | nesterov: bool = False, 242 | weight_decay: float = 0.0, 243 | mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, 244 | max_size_triangular: int = 4096, 245 | max_skew_triangular: int = 128, 246 | precond_lr: Union[float, Callable[[int], float]] = 0.1, 247 | precond_init_scale: Optional[float] = None, 248 | update_global_norm_clip: Optional[float] = None, 249 | step_normalizer_order: str = "2nd", 250 | seed: Optional[PRNGKey] = None, 251 | mu_dtype: Optional[Union[str, jnp.dtype]] = None, 252 | precision: str = "tensorfloat32", 253 | ) -> base.GradientTransformationExtraArgs: 254 | """ 255 | Implements Affine PSGD from https://github.com/lixilinx/psgd_torch. 256 | 257 | Args: 258 | learning_rate: float or callable, learning rate. 259 | preconditioner_update_probability: float, probability of updating the 260 | preconditioner. 261 | b1: float, momentum parameter. 262 | nesterov: bool, whether to use Nesterov momentum. 263 | weight_decay: float, weight decay. 264 | mask: optional Any or callable, mask to apply to parameters. 265 | max_size_triangular: int, max size for affine preconditioner to be 266 | triangular. 267 | max_skew_triangular: int, max skew for affine preconditioner to be 268 | triangular. 269 | precond_lr: float or callable, learning rate for the preconditioner. 270 | precond_init_scale: optional float, initial scale for the preconditioner. 271 | update_global_norm_clip: optional float, clip updates by global norm. 272 | step_normalizer_order: str, '1st' or '2nd'. 273 | seed: Optional PRNGKey, random seed. 274 | mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator. 275 | Defaults to the same dtype as the parameters. 
276 | precision: str, precision for matmul, 'bfloat16', 'tensorfloat32', 'float32'. 277 | 278 | Returns: 279 | optax.GradientTransformationExtraArgs 280 | """ 281 | opt = [ 282 | scale_by_affine( 283 | preconditioner_update_probability=preconditioner_update_probability, 284 | b1=b1, 285 | nesterov=nesterov, 286 | max_size_triangular=max_size_triangular, 287 | max_skew_triangular=max_skew_triangular, 288 | precond_lr=precond_lr, 289 | precond_init_scale=precond_init_scale, 290 | update_global_norm_clip=update_global_norm_clip, 291 | step_normalizer_order=step_normalizer_order, 292 | seed=seed, 293 | mu_dtype=mu_dtype, 294 | precision=precision, 295 | ) 296 | ] 297 | if weight_decay > 0: 298 | opt.append(transform.add_decayed_weights(weight_decay, mask=mask)) 299 | opt.append(transform.scale_by_learning_rate(learning_rate)) 300 | return chain(*opt) 301 | 302 | 303 | def _norm_lower_bound(A: jax.Array): 304 | """ 305 | Returns a cheap lower bound for the spectral norm of A. 306 | Numerical results on random matrices with a wide range of distributions and sizes suggest, 307 | norm(A) <= sqrt(2) * norm_lower_bound(A) 308 | Looks to be a very tight lower bound. 309 | """ 310 | max_abs = jnp.max(jnp.abs(A)) 311 | 312 | def calc(A): 313 | A = A / max_abs 314 | 315 | aa = jnp.real(A * A.conj()) 316 | 317 | aa_sum0 = jnp.sum(aa, axis=0) 318 | aa_sum1 = jnp.sum(aa, axis=1) 319 | i = jnp.argmax(aa_sum0, 0) 320 | j = jnp.argmax(aa_sum1, 0) 321 | value0 = jax.lax.dynamic_index_in_dim(aa_sum0, i, 0, keepdims=False) 322 | value1 = jax.lax.dynamic_index_in_dim(aa_sum1, j, 0, keepdims=False) 323 | 324 | def gt_branch(): 325 | x = jax.lax.dynamic_index_in_dim(A, i, 1, keepdims=False) 326 | x = x.conj() @ A 327 | return max_abs * jnp.linalg.norm((x / jnp.linalg.norm(x)) @ A.conj().T) 328 | 329 | def le_branch(): 330 | x = jax.lax.dynamic_index_in_dim(A, j, 0, keepdims=False) 331 | x = A @ x.conj() 332 | return max_abs * jnp.linalg.norm(A.conj().T @ (x / jnp.linalg.norm(x))) 333 | 334 | return jax.lax.cond(value0 > value1, gt_branch, le_branch) 335 | 336 | def pass_calc(A): 337 | return max_abs 338 | 339 | return jax.lax.cond(max_abs > 0, calc, pass_calc, A) 340 | 341 | 342 | def _shape_as_matrix(x: jax.Array) -> tuple: 343 | """Reshapes tensor x to a matrix with conditions to improve efficiency. 344 | 345 | From original pytorch version. 346 | 347 | Args: 348 | x: jax.Array, tensor to be reshaped. 349 | 350 | Returns: 351 | tuple where first element is function that convert x to matrix, second 352 | element is function that converts matrix back to x, and third element 353 | is the shape of x as a matrix. 
354 | """ 355 | 356 | def prod(arr): 357 | # prod = lambda arr: 1 if len(arr)==0 else arr[0]*prod(arr[1:]) 358 | result = 1 359 | for a in arr: 360 | result *= a 361 | return result 362 | 363 | def permutations(p0): 364 | # generate all the permutations of the original one p0 365 | if len(p0) == 1: 366 | yield p0 367 | else: 368 | for i in range(len(p0)): 369 | for q in permutations(p0[:i] + p0[i + 1 :]): 370 | yield p0[i], *q 371 | 372 | # here begins the processing 373 | if x.ndim == 2: # t already is a matrix, do nothing 374 | return lambda u: u, lambda v: v, x.shape 375 | elif x.ndim < 2: # scalar or vector, simple reshape to matrix 376 | mtx_shape = (1, x.size) 377 | return ( 378 | lambda u, shape=mtx_shape: u.reshape(shape), 379 | lambda v, shape=x.shape: v.reshape(shape), 380 | mtx_shape, 381 | ) 382 | else: # higher order tensor, a little complicated 383 | p0, s0 = tuple(range(x.ndim)), x.shape # original permutation and shape 384 | min_precond_size, opt_p, opt_s, opt_i = float("inf"), None, None, None 385 | for p in permutations(p0): 386 | s = tuple(s0[j] for j in p) 387 | for i in range(1, len(p)): 388 | if (new_size := prod(s[:i]) ** 2 + prod(s[i:]) ** 2) < min_precond_size: 389 | min_precond_size = new_size 390 | opt_p, opt_s, opt_i = p, s, i 391 | 392 | if opt_p == p0: # no permutation is needed, just reshaping 393 | mtx_shape = (prod(s0[:opt_i]), prod(s0[opt_i:])) 394 | return ( 395 | lambda u, shape=mtx_shape: u.reshape(shape), 396 | lambda v, shape=s0: v.reshape(shape), 397 | mtx_shape, 398 | ) 399 | else: # need both permutation and reshaping 400 | mtx_shape = (prod(opt_s[:opt_i]), prod(opt_s[opt_i:])) 401 | q = tuple( 402 | pair[1] for pair in sorted([(k, i) for (i, k) in enumerate(opt_p)]) 403 | ) 404 | return ( 405 | lambda u, permute=opt_p, shape=mtx_shape: u.transpose(permute).reshape( 406 | shape 407 | ), 408 | lambda v, permute=q, shape=opt_s: v.reshape(shape).transpose(permute), 409 | mtx_shape, 410 | ) 411 | 412 | 413 | def _initQ(shape, max_size, max_skew, dtype=jnp.float32): 414 | """ 415 | It initializes Q = kron(Q2, Q1) for param p to scale * I, 416 | where Q1 and Q2 can reduce to diagonal matrices to save memory if 417 | max_size or max_skew are set to small numbers. 
418 | """ 419 | assert len(shape) == 2, "preconditioned param shape must be 2D" 420 | s1, s2 = shape 421 | if s1 < 2 or s1 > max_size or s1 > max_skew * s2: 422 | Q1 = jnp.ones(s1, dtype=dtype) 423 | else: 424 | Q1 = jnp.eye(s1, dtype=dtype) 425 | 426 | if s2 < 2 or s2 > max_size or s2 > max_skew * s1: 427 | Q2 = jnp.ones(s2, dtype=dtype) 428 | else: 429 | Q2 = jnp.eye(s2, dtype=dtype) 430 | 431 | return [Q1, Q2] 432 | 433 | 434 | def _solve_triangular(a, b, upper, left=True): 435 | """jax.lax.linalg.triangular_solve rewritten to match PyTorch convention.""" 436 | return jax.lax.linalg.triangular_solve(a, b, left_side=left, lower=not upper) 437 | 438 | 439 | def _update_precond_affine_math_( 440 | key, Ql, Qr, dX, dG, precond_lr, step_normalizer, precision 441 | ): 442 | with jax.default_matmul_precision(precision): 443 | if Ql.ndim == 2: 444 | if Qr.ndim == 2: # Ql.dim()=2 and Qr.dim()=2: 445 | A = jnp.linalg.multi_dot([Ql, dG, Qr.conj().T]) 446 | Bh = _solve_triangular( 447 | Ql.conj().T, 448 | _solve_triangular(Qr, dX, upper=True, left=False), 449 | upper=False, 450 | ) 451 | 452 | AhA, BhB = A.conj().T @ A, Bh @ Bh.conj().T 453 | AAh, BBh = A @ A.conj().T, Bh.conj().T @ Bh 454 | grad1 = jnp.triu(AAh - BhB) 455 | grad2 = jnp.triu(AhA - BBh) 456 | 457 | if step_normalizer == "2nd": 458 | step1 = precond_lr / add_eps(_norm_lower_bound(AAh + BhB)) 459 | step2 = precond_lr / add_eps(_norm_lower_bound(AhA + BBh)) 460 | else: 461 | step1 = precond_lr / add_eps(_norm_lower_bound(grad1)) 462 | step2 = precond_lr / add_eps(_norm_lower_bound(grad2)) 463 | 464 | Ql -= step1 * grad1 @ Ql 465 | Qr -= step2 * grad2 @ Qr 466 | else: # Ql.dim()=2 and Qr.dim()=1: 467 | A = Ql @ (dG * Qr.conj()) 468 | Bh = _solve_triangular(Ql.conj().T, dX / Qr, upper=False) 469 | 470 | AAh, BhB = A @ A.conj().T, Bh @ Bh.conj().T 471 | AAc, BBc = jnp.sum(A * A.conj(), axis=0), jnp.sum( 472 | Bh * Bh.conj(), axis=0 473 | ) 474 | grad1 = jnp.triu(AAh - BhB) 475 | grad2 = AAc - BBc 476 | 477 | if step_normalizer == "2nd": 478 | step1 = precond_lr / add_eps(_norm_lower_bound(AAh + BhB)) 479 | step2 = precond_lr / add_eps(jnp.max(jnp.real(AAc + BBc))) 480 | else: 481 | step1 = precond_lr / add_eps(_norm_lower_bound(grad1)) 482 | step2 = precond_lr / add_eps(jnp.max(jnp.abs(grad2))) 483 | 484 | Ql -= step1 * grad1 @ Ql 485 | Qr -= step2 * grad2 * Qr 486 | else: 487 | if Qr.ndim == 2: # Ql.dim()=1 and Qr.dim()=2: 488 | A = (Ql[:, None] * dG) @ Qr.conj().T 489 | Bh = _solve_triangular(Qr, dX, upper=True, left=False) / ( 490 | Ql.conj()[:, None] 491 | ) 492 | 493 | AAc, BBc = jnp.sum(A * A.conj(), axis=1), jnp.sum( 494 | Bh * Bh.conj(), axis=1 495 | ) 496 | AhA, BBh = A.conj().T @ A, Bh.conj().T @ Bh 497 | grad1 = AAc - BBc 498 | grad2 = jnp.triu(AhA - BBh) 499 | 500 | if step_normalizer == "2nd": 501 | step1 = precond_lr / add_eps(jnp.max(jnp.real(AAc + BBc))) 502 | step2 = precond_lr / add_eps(_norm_lower_bound(AhA + BBh)) 503 | else: 504 | step1 = precond_lr / add_eps(jnp.max(jnp.abs(grad1))) 505 | step2 = precond_lr / add_eps(_norm_lower_bound(grad2)) 506 | 507 | Ql -= step1 * grad1 * Ql 508 | Qr -= step2 * grad2 @ Qr 509 | else: # Ql.dim()=1 and Qr.dim()=1: 510 | A = Ql[:, None] * dG * Qr.conj() 511 | Bh = dX / Qr / Ql.conj()[:, None] 512 | 513 | AAc1, BBc1 = jnp.sum(A * A.conj(), axis=1), jnp.sum( 514 | Bh * Bh.conj(), axis=1 515 | ) 516 | AAc2, BBc2 = jnp.sum(A * A.conj(), axis=0), jnp.sum( 517 | Bh * Bh.conj(), axis=0 518 | ) 519 | grad1 = AAc1 - BBc1 520 | grad2 = AAc2 - BBc2 521 | 522 | if step_normalizer == "2nd": 523 
| step1 = precond_lr / add_eps(jnp.max(jnp.real(AAc1 + BBc1))) 524 | step2 = precond_lr / add_eps(jnp.max(jnp.real(AAc2 + BBc2))) 525 | else: 526 | step1 = precond_lr / add_eps(jnp.max(jnp.abs(grad1))) 527 | step2 = precond_lr / add_eps(jnp.max(jnp.abs(grad2))) 528 | 529 | Ql -= step1 * grad1 * Ql 530 | Qr -= step2 * grad2 * Qr 531 | 532 | def _balance(Ql, Qr): 533 | max_l = jnp.max(jnp.abs(Ql)) 534 | max_r = jnp.max(jnp.abs(Qr)) 535 | 536 | rho = jnp.sqrt(max_l / max_r) 537 | Ql /= rho 538 | Qr *= rho 539 | return Ql, Qr 540 | 541 | key, subkey = jax.random.split(key) 542 | Ql, Qr = jax.lax.cond( 543 | jax.random.uniform(subkey) < 0.01, _balance, lambda ql, qr: (ql, qr), Ql, Qr 544 | ) 545 | 546 | return [Ql, Qr] 547 | 548 | 549 | def _update_precond_affine_dropv_math( 550 | key, Ql, Qr, dG, precond_lr, step_normalizer, precision 551 | ): 552 | with jax.default_matmul_precision(precision): 553 | 554 | def balance(key, Ql, Qr): 555 | def _balance(Ql, Qr): 556 | max_l = jnp.max(jnp.abs(Ql)) 557 | max_r = jnp.max(jnp.abs(Qr)) 558 | 559 | rho = jnp.sqrt(max_l / max_r) 560 | Ql /= rho 561 | Qr *= rho 562 | return Ql, Qr 563 | 564 | Ql, Qr = jax.lax.cond( 565 | jax.random.uniform(key) < 0.01, 566 | _balance, 567 | lambda ql, qr: (ql, qr), 568 | Ql, 569 | Qr, 570 | ) 571 | return Ql, Qr 572 | 573 | if Ql.ndim == 1 and Qr.ndim == 1: 574 | # drop v when both dims use diagonal preconditioners 575 | A = Ql[:, None] * dG * Qr.conj() 576 | invQQl, invQQr = 1 / (Ql * Ql.conj()), 1 / (Qr * Qr.conj()) 577 | 578 | AAc1, BBc1 = jnp.sum(A * A.conj(), axis=1), jnp.sum(invQQr) * invQQl 579 | AAc2, BBc2 = jnp.sum(A * A.conj(), axis=0), jnp.sum(invQQl) * invQQr 580 | grad1 = AAc1 - BBc1 581 | grad2 = AAc2 - BBc2 582 | 583 | if step_normalizer == "2nd": 584 | step1 = precond_lr / add_eps(jnp.max(jnp.real(AAc1 + BBc1))) 585 | step2 = precond_lr / add_eps(jnp.max(jnp.real(AAc2 + BBc2))) 586 | else: 587 | step1 = precond_lr / add_eps(jnp.max(jnp.abs(grad1))) 588 | step2 = precond_lr / add_eps(jnp.max(jnp.abs(grad2))) 589 | 590 | Ql = Ql - step1 * grad1 * Ql 591 | Qr = Qr - step2 * grad2 * Qr 592 | 593 | key, subkey = jax.random.split(key) 594 | Ql, Qr = balance(subkey, Ql, Qr) 595 | 596 | elif Ql.ndim == 1 and Qr.ndim == 2 and Ql.shape[0] >= Qr.shape[0]: 597 | # drop v when left is diagonal, right is dense, and gradient is a tall matrix 598 | A = (Ql[:, None] * dG) @ Qr.conj().T 599 | invQQl = 1 / (Ql * Ql.conj()) 600 | invQr = _solve_triangular( 601 | Qr, jnp.eye(Qr.shape[0], dtype=Qr.dtype), upper=True 602 | ) 603 | invQQr = invQr.conj().T @ invQr 604 | 605 | AAc, BBc = jnp.sum(A * A.conj(), axis=1), jnp.trace(invQQr) * invQQl 606 | AhA, BBh = A.conj().T @ A, jnp.sum(invQQl) * invQQr 607 | grad1 = AAc - BBc 608 | grad2 = jnp.triu(AhA - BBh) 609 | 610 | if step_normalizer == "2nd": 611 | step1 = precond_lr / add_eps(jnp.max(jnp.real(AAc + BBc))) 612 | step2 = precond_lr / add_eps(_norm_lower_bound(AhA + BBh)) 613 | else: 614 | step1 = precond_lr / add_eps(jnp.max(jnp.abs(grad1))) 615 | step2 = precond_lr / add_eps(_norm_lower_bound(grad2)) 616 | 617 | Ql -= step1 * grad1 * Ql 618 | Qr -= step2 * grad2 @ Qr 619 | 620 | key, subkey = jax.random.split(key) 621 | Ql, Qr = balance(subkey, Ql, Qr) 622 | 623 | elif Qr.ndim == 1 and Ql.ndim == 2 and Qr.shape[0] >= Ql.shape[0]: 624 | # drop v when right is diagonal, left is dense, and gradient is a short matrix 625 | A = Ql @ (dG * Qr.conj()) 626 | invQl = _solve_triangular( 627 | Ql, jnp.eye(Ql.shape[0], dtype=Ql.dtype), upper=True 628 | ) 629 | invQQl = invQl.conj().T @ 
invQl 630 | invQQr = 1 / (Qr * Qr.conj()) 631 | 632 | AAh, BhB = A @ A.conj().T, jnp.sum(invQQr) * invQQl 633 | AAc, BBc = jnp.sum(A * A.conj(), axis=0), jnp.trace(invQQl) * invQQr 634 | grad1 = jnp.triu(AAh - BhB) 635 | grad2 = AAc - BBc 636 | 637 | if step_normalizer == "2nd": 638 | step1 = precond_lr / add_eps(_norm_lower_bound(AAh + BhB)) 639 | step2 = precond_lr / add_eps(jnp.max(jnp.real(AAc + BBc))) 640 | else: 641 | step1 = precond_lr / add_eps(_norm_lower_bound(grad1)) 642 | step2 = precond_lr / add_eps(jnp.max(jnp.abs(grad2))) 643 | 644 | Ql -= step1 * grad1 @ Ql 645 | Qr -= step2 * grad2 * Qr 646 | 647 | key, subkey = jax.random.split(key) 648 | Ql, Qr = balance(subkey, Ql, Qr) 649 | 650 | else: 651 | # keeping v as an auxiliary variable could save computations (tradeoff of performance, similar to Hutchinson’s trick) when 652 | # 1) gradient is a tall matrix, but left side is a dense preconditioner, right side is diagonal 653 | # 2) gradient is a short matrix, but left side is a diagonal preconditioner, right side is dense 654 | # 3) both sides use dense preconditioner, but gradient is skewed (no saving for square shape gradient) 655 | key, subkey = jax.random.split(key) 656 | v = otu.tree_random_like(subkey, dG, jax.random.normal) 657 | key, subkey = jax.random.split(key) 658 | return _update_precond_affine_math_( 659 | subkey, Ql, Qr, v, dG, precond_lr, step_normalizer, precision 660 | ) 661 | 662 | return [Ql, Qr] 663 | 664 | 665 | def _precond_grad_affine_math(Ql, Qr, grad): 666 | if Ql.ndim == 2: 667 | if Qr.ndim == 2: # Ql.ndim=2 and Qr.ndim=2: 668 | return jnp.linalg.multi_dot([Ql.conj().T, Ql, grad, Qr.conj().T, Qr]) 669 | else: # Ql.ndim=2 and Qr.ndim=1: 670 | return jnp.linalg.multi_dot([Ql.conj().T, Ql, grad * (Qr * Qr.conj())]) 671 | else: 672 | if Qr.ndim == 2: # Ql.ndim=1 and Qr.ndim=2: 673 | return jnp.linalg.multi_dot( 674 | [(Ql * Ql.conj())[:, None] * grad, Qr.conj().T, Qr] 675 | ) 676 | else: # Ql.ndim=1 and Qr.ndim=1: 677 | return (Ql * Ql.conj())[:, None] * grad * (Qr * Qr.conj()) 678 | -------------------------------------------------------------------------------- /psgd_jax/kron.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional, Union, Callable 2 | from functools import partial 3 | import string 4 | import numpy as np 5 | 6 | import chex 7 | import jax 8 | from jax import vmap 9 | import jax.numpy as jnp 10 | import flax.linen as nn 11 | from optax import tree_utils as otu 12 | from optax._src import base, transform, clipping, numerics 13 | from optax._src.numerics import safe_int32_increment 14 | from optax._src.utils import canonicalize_dtype 15 | from optax._src.combine import chain 16 | 17 | 18 | def precond_update_prob_schedule( 19 | max_prob=1.0, min_prob=0.03, decay=0.001, flat_start=500 20 | ): 21 | """Anneal preconditioner update probability during beginning of training. 22 | 23 | PSGD benefits from more preconditioner updates at the beginning of training, 24 | but once the preconditioner is learned the update probability can drop low. 25 | 26 | This schedule is an exponential anneal with a flat start. Default settings keep 27 | update probability at 1.0 for 250 steps then exponentially anneal down to 28 | `min_prob` by 4000 steps. Default settings work well for most models and 29 | training regimes. 
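For example, a custom schedule can be passed straight to the optimizer (a minimal sketch; the argument values are illustrative and `kron` is defined later in this file):

    schedule = precond_update_prob_schedule(min_prob=0.05, flat_start=1000)
    optimizer = kron(preconditioner_update_probability=schedule)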
30 | """ 31 | 32 | def _schedule(n): 33 | """Exponential anneal with flat start.""" 34 | return jnp.minimum( 35 | jnp.maximum(max_prob * jnp.exp(-decay * (n - flat_start)), min_prob), 36 | max_prob, 37 | ) 38 | 39 | return _schedule 40 | 41 | 42 | def scale_by_kron( 43 | b1: float = 0.9, 44 | preconditioner_update_probability: Union[ 45 | float, Callable[[int], float] 46 | ] = precond_update_prob_schedule(), 47 | max_size_triangular: int = 8192, 48 | min_ndim_triangular: int = 2, 49 | memory_save_mode: Optional[str] = None, 50 | momentum_into_precond_update: bool = True, 51 | preconditioner_lr: float = 0.1, 52 | preconditioner_init_scale: float = 1.0, 53 | mu_dtype: Optional[Union[str, jnp.dtype]] = None, 54 | precond_dtype: Optional[Union[str, jnp.dtype]] = None, 55 | precond_update_precision: Optional[str] = "tensorfloat32", 56 | precond_grads_precision: Optional[str] = None, 57 | scanned_layers: Optional[base.Params] = None, 58 | lax_map_scanned_layers: bool = False, 59 | lax_map_batch_size: int = 8, 60 | ) -> base.GradientTransformationExtraArgs: 61 | """ 62 | Implements PSGD Kron from https://github.com/lixilinx/psgd_torch. 63 | 64 | Args: 65 | b1: float, momentum parameter. 66 | preconditioner_update_probability: float, probability of updating the 67 | preconditioner. Default anneals from 1.0 to 0.03 by 4000 steps. 68 | max_size_triangular: int, max size for dim's preconditioner to be triangular. 69 | min_ndim_triangular: int, minimum number of dimensions a layer needs to have 70 | triangular preconditioners. 71 | memory_save_mode: optional str, None, 'one_diag', or 'all_diag', None is default 72 | to set all preconditioners to be triangular, 'one_diag' sets the largest 73 | or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners 74 | to be diagonal. 75 | momentum_into_precond_update: bool, whether to send momentum into preconditioner 76 | update instead of raw gradients. 77 | preconditioner_lr: float, learning rate for preconditioner. 78 | preconditioner_init_scale: float, scale for preconditioner initialization. 79 | mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator. 80 | Defaults to the same dtype as the parameters. 81 | precond_dtype: optional str or jnp.dtype, dtype of the preconditioner. 82 | precond_update_precision: str, precision for matmul during preconditioner update, 83 | 'bfloat16', 'tensorfloat32', 'float32'. 84 | precond_grads_precision: str, precision for matmul during preconditioning grads, 85 | 'bfloat16', 'tensorfloat32', 'float32'. 86 | scanned_layers: optional base.Params, tree of bool same structure as params 87 | indicating scanned layers. PSGD will vmap over the first dim. 88 | lax_map_scanned_layers: bool, whether to use lax.map for scanned layers 89 | instead of vmap. Useful to save memory with large models. 90 | lax_map_batch_size: int, batch size for lax.map, see JAX docs for more info. 
91 | 92 | Returns: 93 | optax.GradientTransformationExtraArgs 94 | """ 95 | mu_dtype = canonicalize_dtype(mu_dtype) 96 | precond_dtype = canonicalize_dtype(precond_dtype) 97 | 98 | def map_fn(do_map, fn, *args): 99 | """Maybe map a fn along first axis.""" 100 | if do_map: 101 | if lax_map_scanned_layers: 102 | return jax.lax.map( 103 | lambda xs: fn(*xs), 104 | xs=args, 105 | batch_size=lax_map_batch_size if lax_map_batch_size > 1 else None, 106 | ) 107 | else: 108 | return vmap(fn)(*args) 109 | else: 110 | return fn(*args) 111 | 112 | def init_fn(params): 113 | params = jax.tree.map( 114 | lambda x: x.unbox() if isinstance(x, nn.Partitioned) else x, 115 | params, 116 | is_leaf=lambda v: isinstance(v, (chex.Array, nn.Partitioned)), 117 | ) 118 | 119 | scanned_layers_ = scanned_layers 120 | if scanned_layers is None: 121 | scanned_layers_ = jax.tree.map(lambda _: False, params) 122 | 123 | # momentum 124 | mu = None 125 | if b1 > 0: 126 | mu = jax.tree.map(lambda x: jnp.zeros_like(x, dtype=mu_dtype), params) 127 | 128 | # preconditioners 129 | Qs = [ 130 | _init_Q_exprs( 131 | t[0] if s else t, 132 | preconditioner_init_scale, 133 | max_size_triangular, 134 | min_ndim_triangular, 135 | memory_save_mode, 136 | precond_dtype, 137 | )[0] 138 | for t, s in zip(jax.tree.leaves(params), jax.tree.leaves(scanned_layers_)) 139 | ] 140 | # broadcast for scanned layers 141 | Qs = [ 142 | ( 143 | jax.tree.map( 144 | lambda d: jnp.repeat(jnp.expand_dims(d, 0), t.shape[0], axis=0), q 145 | ) 146 | if s 147 | else q 148 | ) 149 | for q, t, s in zip( 150 | Qs, jax.tree.leaves(params), jax.tree.leaves(scanned_layers_) 151 | ) 152 | ] 153 | Qs = jax.tree.structure(params).unflatten(Qs) 154 | 155 | # Calculate sizes for nu (preconditioner) and mu (momentum) 156 | Qs_n_elements = sum([q.size for q in jax.tree.leaves(Qs)]) 157 | Qs_size_MB = sum( 158 | [q.size * q.dtype.itemsize / (2**20) for q in jax.tree.leaves(Qs)] 159 | ) 160 | if jax.process_index() == 0: 161 | print( 162 | f"PSGD Preconditioners size: {Qs_n_elements} elements, " 163 | f"{Qs_size_MB:.2f} MB" 164 | ) 165 | if mu is not None: 166 | mu_n_elements = sum([p.size for p in jax.tree.leaves(mu)]) 167 | mu_size_MB = sum( 168 | [p.size * p.dtype.itemsize / (2**20) for p in jax.tree.leaves(mu)] 169 | ) 170 | if jax.process_index() == 0: 171 | print( 172 | f"PSGD Momentum size: {mu_n_elements} elements, {mu_size_MB:.2f} MB" 173 | ) 174 | 175 | # initial state 176 | return dict( 177 | count=jnp.zeros([], jnp.int32), 178 | mu=mu, 179 | Qs_preconditioners=Qs, 180 | update_counter=jnp.zeros([], jnp.int32), 181 | ) 182 | 183 | def update_fn(updates: base.Updates, state: dict, params: base.Params = None): 184 | del params 185 | count_inc = safe_int32_increment(state["count"]) 186 | key = jax.random.fold_in(jax.random.PRNGKey(5318008), state["count"]) 187 | 188 | # account for flax.linen.Partitioned grads and params 189 | boxed_updates, grads_structure = jax.tree.flatten( 190 | updates, is_leaf=lambda v: isinstance(v, (chex.Array, nn.Partitioned)) 191 | ) 192 | flax_partitioned = False 193 | if isinstance(boxed_updates[0], nn.Partitioned): 194 | flax_partitioned = True 195 | updates = [u.unbox() for u in boxed_updates] 196 | updates = grads_structure.unflatten(updates) 197 | 198 | scanned_layers_ = scanned_layers 199 | if scanned_layers is None: 200 | scanned_layers_ = jax.tree.map(lambda _: False, updates) 201 | 202 | update_prob_in = preconditioner_update_probability 203 | if isinstance(preconditioner_update_probability, Callable): 204 | update_prob_in = 
preconditioner_update_probability(count_inc) 205 | 206 | # momentum 207 | mu = None 208 | momentum_updates = updates 209 | if state["mu"] is not None: 210 | mu = otu.tree_update_moment(updates, state["mu"], b1, 1) 211 | momentum_updates = otu.tree_bias_correction(mu, b1, count_inc) 212 | 213 | # flatten pytrees 214 | updates, grads_structure = jax.tree.flatten(updates) 215 | momentum_updates = grads_structure.flatten_up_to(momentum_updates) 216 | Qs = grads_structure.flatten_up_to(state["Qs_preconditioners"]) 217 | scanned_layers_ = grads_structure.flatten_up_to(scanned_layers_) 218 | 219 | # get einsum expressions 220 | expressions = [ 221 | _init_Q_exprs( 222 | t[0] if s else t, 223 | preconditioner_init_scale, 224 | max_size_triangular, 225 | min_ndim_triangular, 226 | memory_save_mode, 227 | precond_dtype, 228 | existing_Q=jax.tree.map(lambda d: d[0], Q) if s else Q, 229 | ) 230 | for t, s, Q in zip(updates, scanned_layers_, Qs) 231 | ] 232 | 233 | # maybe update preconditioner 234 | def update_preconditioner(key, Qs): 235 | with jax.default_matmul_precision(precond_update_precision): 236 | if momentum_into_precond_update: 237 | precond_updates_in = momentum_updates 238 | else: 239 | precond_updates_in = updates 240 | 241 | # balance preconditioners about every 100 updates 242 | def balance_Qs(Qs: List[List[jax.Array]]): 243 | def _balance_Q(Q: List[jax.Array]): 244 | norms = jnp.array( 245 | [jnp.max(jnp.abs(q)) for q in Q], dtype=jnp.float32 246 | ) 247 | gmean = jnp.prod(norms) ** (1 / len(norms)) 248 | to_mul = gmean / norms 249 | return [q * x.astype(q.dtype) for q, x in zip(Q, to_mul)] 250 | 251 | return [ 252 | map_fn(s, _balance_Q, Q) if len(Q) > 1 else Q 253 | for Q, s in zip(Qs, scanned_layers_) 254 | ] 255 | 256 | key, subkey = jax.random.split(key) 257 | do_balances = jax.random.uniform(subkey) < 0.01 258 | Qs = jax.lax.cond(do_balances, balance_Qs, lambda qs: qs, Qs) 259 | 260 | # create random vectors 261 | key, subkey = jax.random.split(key) 262 | Vs_keys = jax.random.split(subkey, len(precond_updates_in)) 263 | Vs = [ 264 | jax.random.normal(k, shape=g.shape, dtype=g.dtype) 265 | for k, g in zip(Vs_keys, precond_updates_in) 266 | ] 267 | 268 | # damp based on machine precision (f32 probably enough) 269 | damp_eps = jnp.sqrt(jnp.finfo(jnp.float32).eps) 270 | precond_updates_in = jax.tree.map( 271 | lambda g, v: g + damp_eps.astype(g.dtype) * jnp.mean(jnp.abs(g)) * v, 272 | precond_updates_in, 273 | Vs, 274 | ) 275 | 276 | # form conjB 277 | conjBs = [ 278 | map_fn(s, _conjB, Q, g, v) 279 | for s, Q, g, v in zip(scanned_layers_, Qs, precond_updates_in, Vs) 280 | ] 281 | 282 | # update Qs 283 | new_Qs = [ 284 | map_fn( 285 | s, 286 | partial( 287 | _update_precond, exprs=exprs, precond_lr=preconditioner_lr 288 | ), 289 | Q, 290 | g, 291 | conjb, 292 | ) 293 | for s, exprs, Q, g, conjb in zip( 294 | scanned_layers_, expressions, Qs, precond_updates_in, conjBs 295 | ) 296 | ] 297 | 298 | new_Qs = otu.tree_cast(new_Qs, precond_dtype) 299 | return new_Qs 300 | 301 | # update preconditioner deterministically 302 | update_counter_inc = safe_int32_increment(state["update_counter"]) 303 | do_update = update_counter_inc >= 1 / update_prob_in 304 | update_counter_inc = jnp.where(do_update, 0, update_counter_inc) 305 | key, subkey = jax.random.split(key) 306 | Qs = jax.lax.cond(do_update, update_preconditioner, lambda _, qs: qs, subkey, Qs) 307 | 308 | # precondition gradients 309 | with jax.default_matmul_precision(precond_grads_precision): 310 | precond_gs = [ 311 | map_fn(s, 
partial(_precond_grad, exprs=exprs), Q, g) 312 | for s, exprs, Q, g in zip( 313 | scanned_layers_, expressions, Qs, momentum_updates 314 | ) 315 | ] 316 | 317 | # RMS of pre_grad should be 1.0, so let's cap at 1.1 318 | def _clip_fn(u): 319 | clip_denom = jnp.maximum( 320 | 1.0, 321 | jnp.sqrt(jnp.mean(numerics.abs_sq(u))) / 1.1) 322 | return u / clip_denom 323 | 324 | precond_gs = jax.tree.map(_clip_fn, precond_gs) 325 | 326 | # box preconditioned grads 327 | if flax_partitioned: 328 | precond_gs = [ 329 | u.replace_boxed(pg) for u, pg in zip(boxed_updates, precond_gs) 330 | ] 331 | 332 | # unflatten pytrees 333 | updates = grads_structure.unflatten(precond_gs) 334 | Qs = grads_structure.unflatten(Qs) 335 | 336 | # dtypes and new state 337 | mu = otu.tree_cast(mu, mu_dtype) 338 | Qs = otu.tree_cast(Qs, precond_dtype) 339 | state = dict( 340 | count=count_inc, 341 | mu=mu, 342 | Qs_preconditioners=Qs, 343 | update_counter=update_counter_inc, 344 | ) 345 | 346 | return updates, state 347 | 348 | return base.GradientTransformationExtraArgs(init_fn, update_fn) 349 | 350 | 351 | def kron( 352 | learning_rate: Union[float, Callable[[int], float]] = 0.001, 353 | b1: float = 0.9, 354 | weight_decay: float = 0.0, 355 | weight_decay_mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, 356 | preconditioner_update_probability: Union[ 357 | float, Callable[[int], float] 358 | ] = precond_update_prob_schedule(), 359 | max_size_triangular: int = 8192, 360 | min_ndim_triangular: int = 2, 361 | memory_save_mode: Optional[str] = None, 362 | momentum_into_precond_update: bool = True, 363 | preconditioner_lr: float = 0.1, 364 | preconditioner_init_scale: float = 1.0, 365 | mu_dtype: Optional[Union[str, jnp.dtype]] = None, 366 | precond_dtype: Optional[Union[str, jnp.dtype]] = None, 367 | precond_update_precision: Optional[str] = "tensorfloat32", 368 | precond_grads_precision: Optional[str] = None, 369 | scanned_layers: Optional[base.Params] = None, 370 | lax_map_scanned_layers: bool = False, 371 | lax_map_batch_size: int = 8, 372 | ) -> base.GradientTransformationExtraArgs: 373 | """ 374 | Implements PSGD Kron from https://github.com/lixilinx/psgd_torch. 375 | 376 | Args: 377 | learning_rate: float or callable, learning rate. 378 | b1: float, momentum parameter. 379 | weight_decay: float, weight decay. 380 | weight_decay_mask: optional Any or callable, pytree of bool same structure 381 | as params with weight decay applied to True elements. 382 | preconditioner_update_probability: float, probability of updating the 383 | preconditioner. Default anneals from 1.0 to 0.03 by 4000 steps. 384 | max_size_triangular: int, max size for dim's preconditioner to be triangular. 385 | min_ndim_triangular: int, minimum number of dimensions a layer needs to have 386 | triangular preconditioners. 387 | memory_save_mode: optional str, None, 'one_diag', or 'all_diag', None is default 388 | to set all preconditioners to be triangular. 'one_diag' sets only the largest 389 | or last dim in a layer to be diagonal, and 'all_diag' sets all preconditioners 390 | to be diagonal. 391 | momentum_into_precond_update: bool, whether to send momentum into preconditioner 392 | update instead of raw gradients. 393 | preconditioner_lr: float, learning rate for preconditioner. 394 | preconditioner_init_scale: float, scale for preconditioner initialization. 395 | mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator. 396 | Defaults to the same dtype as the parameters. 
397 | precond_dtype: optional str or jnp.dtype, dtype of the preconditioner. 398 | precond_update_precision: str, precision for matmul during preconditioner update, 399 | 'bfloat16', 'tensorfloat32', 'float32'. 400 | precond_grads_precision: str, precision for matmul during preconditioning grads, 401 | 'bfloat16', 'tensorfloat32', 'float32'. 402 | scanned_layers: optional base.Params, tree of bool same structure as params 403 | indicating scanned layers. PSGD will vmap over the first dim. 404 | lax_map_scanned_layers: bool, whether to use lax.map for scanned layers 405 | instead of vmap. Useful to save memory with large models. 406 | lax_map_batch_size: int, batch size for lax.map, see JAX docs for more info. 407 | 408 | Returns: 409 | optax.GradientTransformationExtraArgs 410 | """ 411 | optimizer = [ 412 | scale_by_kron( 413 | b1=b1, 414 | preconditioner_update_probability=preconditioner_update_probability, 415 | max_size_triangular=max_size_triangular, 416 | min_ndim_triangular=min_ndim_triangular, 417 | memory_save_mode=memory_save_mode, 418 | momentum_into_precond_update=momentum_into_precond_update, 419 | preconditioner_lr=preconditioner_lr, 420 | preconditioner_init_scale=preconditioner_init_scale, 421 | mu_dtype=mu_dtype, 422 | precond_dtype=precond_dtype, 423 | precond_update_precision=precond_update_precision, 424 | precond_grads_precision=precond_grads_precision, 425 | scanned_layers=scanned_layers, 426 | lax_map_scanned_layers=lax_map_scanned_layers, 427 | lax_map_batch_size=lax_map_batch_size, 428 | ) 429 | ] 430 | if weight_decay > 0.0: 431 | optimizer.append(transform.add_decayed_weights(weight_decay, weight_decay_mask)) 432 | optimizer.append(transform.scale_by_learning_rate(learning_rate)) 433 | return chain(*optimizer) 434 | 435 | 436 | def _add_tiny(x): 437 | return x + jnp.finfo(x.dtype).tiny 438 | 439 | 440 | def _norm_lower_bound(A: jax.Array): 441 | """Returns a cheap lower bound for the spectral norm of A. 442 | 443 | Numerical results on random matrices with a wide range of distributions and 444 | sizes suggest, norm(A) <= sqrt(2) * norm_lower_bound(A). Looks to be a very 445 | tight lower bound. 446 | """ 447 | max_abs = jnp.max(jnp.abs(A)) 448 | 449 | def calc(A): 450 | A = A / max_abs 451 | A_conj = A.conj() 452 | 453 | aa = jnp.real(A * A_conj) 454 | 455 | aa_sum0 = jnp.sum(aa, axis=0) 456 | aa_sum1 = jnp.sum(aa, axis=1) 457 | i = jnp.argmax(aa_sum0, 0) 458 | j = jnp.argmax(aa_sum1, 0) 459 | value0 = jax.lax.dynamic_index_in_dim(aa_sum0, i, 0, keepdims=False) 460 | value1 = jax.lax.dynamic_index_in_dim(aa_sum1, j, 0, keepdims=False) 461 | 462 | def gt_branch(): 463 | x = jax.lax.dynamic_index_in_dim(A, i, 1, keepdims=False) 464 | x = x.conj() @ A 465 | return max_abs * jnp.linalg.norm((x / jnp.linalg.norm(x)) @ A_conj.T) 466 | 467 | def le_branch(): 468 | x = jax.lax.dynamic_index_in_dim(A, j, 0, keepdims=False) 469 | x = A @ x.conj() 470 | return max_abs * jnp.linalg.norm(A_conj.T @ (x / jnp.linalg.norm(x))) 471 | 472 | return jax.lax.cond(value0 > value1, gt_branch, le_branch) 473 | 474 | def no_calc(_): 475 | return max_abs 476 | 477 | return jax.lax.cond(max_abs > 0, calc, no_calc, A) 478 | 479 | 480 | def _init_Q_exprs( 481 | t, scale, max_size, min_ndim_triangular, memory_save_mode, dtype, existing_Q=None 482 | ): 483 | """For a scalar or tensor `t`, we initialize its preconditioner `Q` and 484 | reusable contraction expressions for updating `Q` and preconditioning gradient. 
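As a worked example (traced from the letter assignments below, shown only for reference): a rank-2 tensor with triangular preconditioners on both dims yields exprA = "an,bo,no->ab", exprGs = ("nb,Ab->nA", "ao,aB->oB"), and exprP = "an,bo,aA,bB,AB->no"; the exact letters depend only on the tensor rank and which dims are diagonal.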
485 | """ 486 | letters = string.ascii_lowercase + string.ascii_uppercase 487 | 488 | shape = t.shape 489 | if len(shape) == 0: # scalar 490 | Q = ( 491 | [scale * jnp.ones_like(t, dtype=dtype)] 492 | if existing_Q is None 493 | else existing_Q 494 | ) 495 | exprA = ",->" 496 | exprGs = [",->"] 497 | exprP = ",,->" 498 | else: # tensor 499 | if len(shape) > 13: 500 | raise ValueError( 501 | f"Got tensor with dim {len(t.shape)}; Einstein runs out of letters!" 502 | ) 503 | 504 | scale = scale ** (1 / len(shape)) 505 | 506 | if memory_save_mode is None: 507 | dim_diag = [False for _ in shape] 508 | elif memory_save_mode == "one_diag": 509 | rev_sorted_dims = np.argsort(shape)[::-1] 510 | dim_diag = [False for _ in shape] 511 | dim_diag[rev_sorted_dims[0]] = True 512 | elif memory_save_mode == "all_diag": 513 | dim_diag = [True for _ in shape] 514 | else: 515 | raise ValueError( 516 | f"Invalid memory_save_mode: {memory_save_mode}, must be one of " 517 | "[None, 'one_diag', 'all_diag']" 518 | ) 519 | 520 | Q = [] if existing_Q is None else existing_Q 521 | piece1A, piece2A, piece3A = ([], "", "") 522 | exprGs = [] 523 | piece1P, piece2P, piece3P, piece4P = ([], [], "", "") 524 | for i, (size, dim_d) in enumerate(zip(shape, dim_diag)): 525 | if ( 526 | size == 1 527 | or size > max_size 528 | or len(shape) < min_ndim_triangular 529 | or dim_d 530 | ): 531 | # use diagonal matrix as preconditioner for this dim 532 | if existing_Q is None: 533 | Q.append(scale * jnp.ones(size, dtype=dtype)) 534 | 535 | piece1A.append(letters[i]) 536 | piece2A = piece2A + letters[i] 537 | piece3A = piece3A + letters[i] 538 | 539 | piece1 = "".join( 540 | [ 541 | (letters[i + 13] if j == i else letters[j]) 542 | for j in range(len(shape)) 543 | ] 544 | ) 545 | exprGs.append(piece1 + "," + piece1 + "->" + letters[i + 13]) 546 | 547 | piece1P.append(letters[i + 13]) 548 | piece2P.append(letters[i + 13]) 549 | piece3P = piece3P + letters[i + 13] 550 | piece4P = piece4P + letters[i + 13] 551 | else: 552 | # use triangular matrix as preconditioner for this dim 553 | if existing_Q is None: 554 | Q.append(scale * jnp.eye(size, dtype=dtype)) 555 | 556 | piece1A.append(letters[i] + letters[i + 13]) 557 | piece2A = piece2A + letters[i + 13] 558 | piece3A = piece3A + letters[i] 559 | 560 | piece1 = "".join( 561 | [ 562 | (letters[i + 13] if j == i else letters[j]) 563 | for j in range(len(shape)) 564 | ] 565 | ) 566 | piece2 = "".join( 567 | [ 568 | (letters[i + 26] if j == i else letters[j]) 569 | for j in range(len(shape)) 570 | ] 571 | ) 572 | exprGs.append( 573 | piece1 + "," + piece2 + "->" + letters[i + 13] + letters[i + 26] 574 | ) 575 | 576 | a, b, c = (letters[i], letters[i + 13], letters[i + 26]) 577 | piece1P.append(a + b) 578 | piece2P.append(a + c) 579 | piece3P = piece3P + c 580 | piece4P = piece4P + b 581 | 582 | exprA = ",".join(piece1A) + "," + piece2A + "->" + piece3A 583 | exprP = ( 584 | ",".join(piece1P) + "," + ",".join(piece2P) + "," + piece3P + "->" + piece4P 585 | ) 586 | 587 | exprGs = tuple(exprGs) 588 | if existing_Q is not None: 589 | return exprA, exprGs, exprP 590 | return [Q, (exprA, exprGs, exprP)] 591 | 592 | 593 | def _solve_triangular_right(X, A): 594 | """Compute X @ inv(A). 595 | 596 | A triangular solve has roughly the same complexity as a matmul. 
597 | """ 598 | X_ndim = X.ndim 599 | if X_ndim < 2: 600 | X = X[None, :] 601 | 602 | dtype_in = jnp.promote_types(A.dtype, X.dtype) 603 | A, X = A.astype(dtype_in), X.astype(dtype_in) 604 | leading_dims = 0 605 | if X.ndim > 2: 606 | leading_dims = X.ndim - 2 607 | solve_fn = partial(jax.lax.linalg.triangular_solve, left_side=False, lower=False) 608 | for _ in range(leading_dims): 609 | solve_fn = vmap(solve_fn, in_axes=(None, 0)) 610 | solution = solve_fn(A, X) 611 | 612 | if X_ndim < 2: 613 | return solution[0] 614 | return solution 615 | 616 | 617 | def _conjB(Q, G, V): 618 | """Compute conjB.""" 619 | order = G.ndim 620 | p = list(range(order)) 621 | conjB = jnp.transpose(V.conj(), p[1:] + p[:1]) 622 | for i, q in enumerate(Q): 623 | conjB = conjB / q if q.ndim < 2 else _solve_triangular_right(conjB, q) 624 | if i < order - 1: 625 | conjB = jnp.swapaxes(conjB, i, order - 1) 626 | return conjB 627 | 628 | 629 | def _update_precond(Q, G, conjB, exprs, precond_lr): 630 | """Compute A and update Q.""" 631 | exprA, exprGs, _ = exprs 632 | 633 | A = jnp.einsum(exprA, *Q, G) 634 | 635 | A_conj = A.conj() 636 | conjB_conj = conjB.conj() 637 | 638 | def _update_single_q(i, q): 639 | term1 = jnp.einsum(exprGs[i], A, A_conj) 640 | term2 = jnp.einsum(exprGs[i], conjB_conj, conjB) 641 | 642 | tmp = term1 - term2 643 | tmp *= precond_lr 644 | if q.ndim < 2: 645 | tmp *= q 646 | tmp /= _add_tiny(jnp.max(jnp.abs(term1 + term2))) 647 | q -= tmp 648 | else: 649 | tmp = jnp.triu(tmp) 650 | tmp /= _add_tiny(_norm_lower_bound(term1 + term2)) 651 | tmp @= q 652 | q -= tmp 653 | return q 654 | 655 | return [_update_single_q(i, q) for i, q in enumerate(Q)] 656 | 657 | 658 | def _precond_grad(Q, G, exprs): 659 | """Precondition gradient G with preconditioner Q.""" 660 | exprP = exprs[-1] 661 | return jnp.einsum(exprP, *[q.conj() for q in Q], *Q, G) 662 | -------------------------------------------------------------------------------- /psgd_jax/low_rank_approximation.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Union, Callable, NamedTuple 2 | 3 | import jax 4 | from jax import numpy as jnp 5 | from jax.random import PRNGKey 6 | 7 | from optax import tree_utils as otu 8 | from optax._src import base, transform, clipping 9 | from optax._src.numerics import safe_int32_increment 10 | from optax._src.utils import canonicalize_dtype 11 | from optax._src.combine import chain 12 | 13 | from psgd_jax.utils import add_eps, apply_momentum 14 | 15 | 16 | class PSGDLRAState(NamedTuple): 17 | count: jax.Array 18 | key: PRNGKey 19 | mu: Optional[base.Updates] 20 | U: jax.Array 21 | V: jax.Array 22 | d: jax.Array 23 | 24 | 25 | def scale_by_lra( 26 | preconditioner_update_probability: float = 1.0, 27 | b1: float = 0.9, 28 | nesterov: bool = False, 29 | uvd_rank_of_approximation: int = 10, 30 | precond_lr: Union[float, Callable[[int], float]] = 0.1, 31 | precond_init_scale: Optional[float] = None, 32 | update_global_norm_clip: Optional[float] = None, 33 | step_normalizer_order: str = "2nd", 34 | seed: Optional[PRNGKey] = None, 35 | mu_dtype: Optional[Union[str, jnp.dtype]] = None, 36 | precision: str = "tensorfloat32", 37 | ) -> base.GradientTransformationExtraArgs: 38 | """ 39 | Implements UVd PSGD from https://github.com/lixilinx/psgd_torch. 40 | 41 | Args: 42 | preconditioner_update_probability: float, probability of updating the 43 | preconditioner. 44 | b1: float, momentum parameter. 45 | nesterov: bool, whether to use Nesterov momentum. 
46 | uvd_rank_of_approximation: int, rank of approximation for uvd preconditioner. 47 | precond_lr: float or callable, learning rate for the preconditioner. 48 | precond_init_scale: optional float, initial scale for the preconditioner. 49 | update_global_norm_clip: optional float, clip updates by global norm. 50 | step_normalizer_order: str, '1st' or '2nd'. 51 | seed: Optional PRNGKey, random seed. 52 | mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator. 53 | Defaults to the same dtype as the parameters. 54 | precision: str, precision for matmul, 'bfloat16', 'tensorfloat32', 'float32'. 55 | 56 | Returns: 57 | optax.GradientTransformationExtraArgs 58 | """ 59 | mu_dtype = canonicalize_dtype(mu_dtype) 60 | 61 | def init_fn(params): 62 | key = seed if seed is not None else jax.random.PRNGKey(36) 63 | 64 | # momentum 65 | mu = None 66 | if b1 > 0: 67 | print("PSGD: Using momentum.") 68 | mu = otu.tree_zeros_like(params, mu_dtype) 69 | 70 | # preconditioners 71 | n_params = sum([x.size for x in jax.tree.leaves(params)]) 72 | key, subkey = jax.random.split(key) 73 | U = jax.random.normal( 74 | subkey, (n_params, uvd_rank_of_approximation), dtype=jnp.float32 75 | ) 76 | U /= (n_params * (uvd_rank_of_approximation + 10)) ** 0.5 77 | 78 | key, subkey = jax.random.split(key) 79 | V = jax.random.normal( 80 | subkey, (n_params, uvd_rank_of_approximation), dtype=jnp.float32 81 | ) 82 | V /= (n_params * (uvd_rank_of_approximation + 10)) ** 0.5 83 | 84 | d = jnp.ones((n_params, 1), jnp.float32) 85 | 86 | # initial state 87 | return PSGDLRAState( 88 | count=jnp.zeros([], jnp.int32), key=key, mu=mu, U=U, V=V, d=d 89 | ) 90 | 91 | def update_fn( 92 | updates: base.Updates, 93 | state: PSGDLRAState, 94 | params: base.Params = None, 95 | Hvp: Optional[base.Updates] = None, 96 | vector: Optional[base.Updates] = None, 97 | update_preconditioner: Optional[bool] = None, 98 | ): 99 | del params 100 | # use hessian preconditioning if hessian provided 101 | # otherwise use gg^T whitening type preconditioning 102 | hessian_based_preconditioning = Hvp is not None 103 | if hessian_based_preconditioning and ( 104 | vector is None or update_preconditioner is None 105 | ): 106 | raise ValueError( 107 | "If using Hessian-based preconditioning, must also pass in random vector and " 108 | "update_preconditioner to PSGD's update function. See README for more info." 
109 | ) 110 | 111 | count_inc = safe_int32_increment(state.count) 112 | key = state.key 113 | 114 | precond_lr_in = precond_lr 115 | if isinstance(precond_lr, Callable): 116 | precond_lr_in = precond_lr(count_inc) 117 | 118 | def _update_precond(key: PRNGKey, state: PSGDLRAState, Hvs, vs): 119 | v = jnp.concatenate( 120 | [jnp.reshape(x, (-1, 1)) for x in jax.tree.leaves(vs)], 0 121 | ) 122 | h = jnp.concatenate( 123 | [jnp.reshape(x, (-1, 1)) for x in jax.tree.leaves(Hvs)], 0 124 | ) 125 | 126 | # init d 127 | if precond_init_scale is not None: 128 | init_scale = precond_init_scale 129 | else: 130 | if hessian_based_preconditioning: 131 | init_scale = (jnp.sum(v * v) / jnp.sum(h * h)) ** 0.25 132 | else: 133 | init_scale = (len(h) / jnp.sum(jnp.square(h))) ** 0.25 134 | d = jax.lax.cond( 135 | state.count == 0, lambda: state.d * init_scale, lambda: state.d 136 | ) 137 | 138 | # update preconditioner 139 | key, subkey = jax.random.split(key) 140 | U, V, d = _update_precond_UVd_math( 141 | subkey, 142 | state.U, 143 | state.V, 144 | d, 145 | v, 146 | h, 147 | precond_lr_in, 148 | step_normalizer_order, 149 | precision, 150 | ) 151 | 152 | return key, U, V, d 153 | 154 | def _dont_update_precond(key, state, Hvs, vs): 155 | return key, state.U, state.V, state.d 156 | 157 | if not hessian_based_preconditioning: 158 | # update cond and vector not passed in, create here 159 | key, subkey = jax.random.split(key) 160 | update_preconditioner = jnp.logical_or( 161 | jax.random.uniform(subkey) < preconditioner_update_probability, 162 | state.count < 2, 163 | ) 164 | key, subkey = jax.random.split(key) 165 | vector = otu.tree_random_like(subkey, updates, jax.random.normal) 166 | # use grads as Hvp 167 | Hvp = updates 168 | 169 | key, U, V, d = jax.lax.cond( 170 | update_preconditioner, 171 | _update_precond, 172 | _dont_update_precond, 173 | key, 174 | state, 175 | Hvp, 176 | vector, 177 | ) 178 | 179 | # momentum 180 | mu = None 181 | if state.mu is not None: 182 | updates, mu = apply_momentum(updates, state.mu, count_inc, b1, nesterov) 183 | 184 | # preconditioning 185 | flat_updates = jnp.concatenate( 186 | [jnp.reshape(x, (-1, 1)) for x in jax.tree.leaves(updates)], 0 187 | ) 188 | flat_updates = _precond_grad_UVd_math(U, V, d, flat_updates) 189 | with jax.ensure_compile_time_eval(): 190 | params_struct = jax.tree.structure(updates) 191 | param_sizes = [x.size for x in jax.tree.leaves(updates)] 192 | param_cumsizes = [x.item() for x in jnp.cumsum(jnp.array(param_sizes))] 193 | param_shapes = [x.shape for x in jax.tree.leaves(updates)] 194 | flat_updates = [ 195 | jnp.reshape(flat_updates[idx - size : idx], s) 196 | for idx, size, s in zip(param_cumsizes, param_sizes, param_shapes) 197 | ] 198 | updates = jax.tree.unflatten(params_struct, flat_updates) 199 | 200 | # clipping 201 | if update_global_norm_clip is not None: 202 | updates, _ = clipping.clip_by_global_norm(update_global_norm_clip).update( 203 | updates, base.EmptyState 204 | ) 205 | 206 | mu = otu.tree_cast(mu, mu_dtype) 207 | state = PSGDLRAState(count=count_inc, key=key, mu=mu, U=U, V=V, d=d) 208 | return updates, state 209 | 210 | return base.GradientTransformationExtraArgs(init_fn, update_fn) 211 | 212 | 213 | def low_rank_approximation( 214 | learning_rate: Union[float, Callable[[int], float]] = 0.01, 215 | preconditioner_update_probability: float = 1.0, 216 | b1: float = 0.9, 217 | nesterov: bool = False, 218 | weight_decay: float = 0.0, 219 | mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, 220 | 
uvd_rank_of_approximation: int = 10, 221 | precond_lr: Union[float, Callable[[int], float]] = 0.1, 222 | precond_init_scale: Optional[float] = None, 223 | update_global_norm_clip: Optional[float] = None, 224 | step_normalizer_order: str = "2nd", 225 | seed: Optional[PRNGKey] = None, 226 | mu_dtype: Optional[Union[str, jnp.dtype]] = None, 227 | precision: str = "tensorfloat32", 228 | ) -> base.GradientTransformationExtraArgs: 229 | """ 230 | Implements UVd PSGD from https://github.com/lixilinx/psgd_torch. 231 | 232 | Args: 233 | learning_rate: float or callable, learning rate for the optimizer. 234 | preconditioner_update_probability: float, probability of updating the 235 | preconditioner. 236 | b1: float, momentum parameter. 237 | nesterov: bool, whether to use Nesterov momentum. 238 | weight_decay: float, weight decay. 239 | mask: optional mask for weight decay. 240 | uvd_rank_of_approximation: int, rank of approximation for uvd preconditioner. 241 | precond_lr: float or callable, learning rate for the preconditioner. 242 | precond_init_scale: optional float, initial scale for the preconditioner. 243 | update_global_norm_clip: optional float, clip updates by global norm. 244 | step_normalizer_order: str, '1st' or '2nd'. 245 | seed: Optional PRNGKey, random seed. 246 | mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator. 247 | Defaults to the same dtype as the parameters. 248 | precision: str, precision for matmul, 'bfloat16', 'tensorfloat32', 'float32'. 249 | 250 | Returns: 251 | optax.GradientTransformationExtraArgs 252 | """ 253 | opt = [ 254 | scale_by_lra( 255 | preconditioner_update_probability=preconditioner_update_probability, 256 | b1=b1, 257 | nesterov=nesterov, 258 | uvd_rank_of_approximation=uvd_rank_of_approximation, 259 | precond_lr=precond_lr, 260 | precond_init_scale=precond_init_scale, 261 | update_global_norm_clip=update_global_norm_clip, 262 | step_normalizer_order=step_normalizer_order, 263 | seed=seed, 264 | mu_dtype=mu_dtype, 265 | precision=precision, 266 | ) 267 | ] 268 | if weight_decay > 0: 269 | opt.append(transform.add_decayed_weights(weight_decay, mask=mask)) 270 | opt.append(transform.scale_by_learning_rate(learning_rate)) 271 | return chain(*opt) 272 | 273 | 274 | def _IpUVtmatvec(U, V, x): 275 | """Returns (I + U*V')*x. All variables are either matrices or column vectors.""" 276 | return x + jnp.matmul(U, jnp.matmul(V.T, x)) 277 | 278 | 279 | def _update_precond_UVd_math( 280 | key, U, V, d, v, h, precond_lr, step_normalizer, precision 281 | ): 282 | """ 283 | Update preconditioner Q = (I + U*V')*diag(d) with (vector, Hessian-vector product) = (v, h). 284 | State variables U, V and d are updated inplace. 285 | 286 | U, V, d, v, and h are either matrices or column vectors. 
287 | """ 288 | with jax.default_matmul_precision(precision): 289 | # balance the numerical dynamic ranges of U and V; optional 290 | def _balance(U, V): 291 | normU = jnp.linalg.norm(U) 292 | normV = jnp.linalg.norm(V) 293 | rho = jnp.sqrt(normU / normV) 294 | U = U / rho 295 | V = V * rho 296 | return U, V 297 | 298 | key, subkey = jax.random.split(key) 299 | U, V = jax.lax.cond( 300 | jax.random.uniform(subkey) < 0.01, _balance, lambda u, v: (u, v), U, V 301 | ) 302 | 303 | Qh = _IpUVtmatvec(U, V, d * h) 304 | Ph = d * _IpUVtmatvec(V, U, Qh) 305 | 306 | VtU = V.T @ U 307 | I = jnp.eye(VtU.shape[0], dtype=VtU.dtype) 308 | IpVtU = I + VtU 309 | invQtv = v / d 310 | 311 | # cast to float32 for accuracy, no slowdown as 'a' is only (rank, rank) 312 | orig_dtype = U.dtype 313 | IpVtU = IpVtU.astype(jnp.float32) 314 | U_solve = jnp.linalg.solve(IpVtU.T, (U.T @ invQtv).astype(jnp.float32)) 315 | invQtv = invQtv - V @ U_solve.astype(orig_dtype) 316 | V_solve = jnp.linalg.solve(IpVtU, (V.T @ invQtv).astype(jnp.float32)) 317 | invPv = invQtv - U @ V_solve.astype(orig_dtype) 318 | IpVtU = IpVtU.astype(orig_dtype) 319 | invPv = invPv / d 320 | 321 | nablaD = Ph * h - v * invPv 322 | if step_normalizer == "2nd": 323 | mu = precond_lr * jnp.min( 324 | jax.lax.rsqrt(add_eps(Ph * Ph + v * v)) 325 | * jax.lax.rsqrt(add_eps(h * h + invPv * invPv)) 326 | ) # two seperate rsqrt's to avoid underflow 327 | else: 328 | mu = precond_lr / add_eps(jnp.max(jnp.abs(nablaD))) 329 | d -= mu * d * nablaD 330 | 331 | # update either U or V, not both at the same time 332 | a, b = Qh, invQtv 333 | 334 | def _update_U(U, V): 335 | atV = a.T @ V 336 | btV = b.T @ V 337 | atVVt = atV @ V.T 338 | btVVt = btV @ V.T 339 | if step_normalizer == "2nd": 340 | mu = precond_lr / add_eps( 341 | jnp.linalg.norm(a) * jnp.linalg.norm(atVVt) 342 | + jnp.linalg.norm(b) * jnp.linalg.norm(btVVt) 343 | ) 344 | else: # '1st' 345 | norm = jnp.sqrt( 346 | jnp.abs( 347 | (a.T @ a) * (atVVt @ atVVt.T) 348 | + (b.T @ b) * (btVVt @ btVVt.T) 349 | - 2 * (a.T @ b) * (atVVt @ btVVt.T) 350 | ) 351 | ) 352 | mu = precond_lr / add_eps(norm) 353 | 354 | U -= mu * (a @ (atV @ IpVtU) - b @ (btV @ IpVtU)) 355 | 356 | return U, V 357 | 358 | def _update_V(U, V): 359 | atU = a.T @ U 360 | btU = b.T @ U 361 | UUta = U @ atU.T 362 | UUtb = U @ btU.T 363 | if step_normalizer == "2nd": 364 | mu = precond_lr / add_eps( 365 | jnp.linalg.norm(a) * jnp.linalg.norm(UUta) 366 | + jnp.linalg.norm(b) * jnp.linalg.norm(UUtb) 367 | ) 368 | else: # '1st' 369 | norm = jnp.sqrt( 370 | jnp.abs( 371 | (UUta.T @ UUta) * (a.T @ a) 372 | + (UUtb.T @ UUtb) * (b.T @ b) 373 | - 2 * (UUta.T @ UUtb) * (a.T @ b) 374 | ) 375 | ) 376 | mu = precond_lr / add_eps(norm) 377 | 378 | V -= mu * ((a + V @ atU.T) @ atU - (b + V @ btU.T) @ btU) 379 | 380 | return U, V 381 | 382 | U, V = jax.lax.cond(jax.random.uniform(key) < 0.5, _update_U, _update_V, U, V) 383 | 384 | return U, V, d 385 | 386 | 387 | def _precond_grad_UVd_math(U, V, d, g): 388 | """ 389 | Preconditioning gradient g with Q = (I + U*V')*diag(d). 390 | 391 | All variables here are either matrices or column vectors. 
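Equivalently, this returns P*g with P = Q'*Q, where Q = (I + U*V')*diag(d) is the preconditioner fitted in `_update_precond_UVd_math` above.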
392 | """ 393 | g = _IpUVtmatvec(U, V, d * g) 394 | g = d * _IpUVtmatvec(V, U, g) 395 | return g 396 | -------------------------------------------------------------------------------- /psgd_jax/psgd_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | from typing import Union, Optional 4 | import numpy as np 5 | from matplotlib import pyplot as plt 6 | from pprint import pprint 7 | 8 | import jax 9 | from jax import numpy as jnp, jit, sharding 10 | from jax.random import uniform 11 | from jax.experimental import mesh_utils 12 | import optax 13 | 14 | from psgd_jax import hessian_helper 15 | from psgd_jax.xmat import xmat 16 | from psgd_jax.low_rank_approximation import low_rank_approximation 17 | from psgd_jax.affine import affine 18 | from psgd_jax.kron import kron 19 | 20 | 21 | os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" 22 | 23 | 24 | def _plot_rosenbrock(test_iter, plot_title, losses, save_dir=None): 25 | """plot rosenbrock test from batch of results. 26 | 27 | Adapted from https://github.com/jettify/pytorch-optimizer""" 28 | 29 | def rosenbrock(p): 30 | x, y = p 31 | return (1 - x) ** 2 + 100 * (y - x**2) ** 2 32 | 33 | x = np.linspace(-2, 2, 250) 34 | y = np.linspace(-1, 3, 250) 35 | minimum = (1.0, 1.0) 36 | 37 | X, Y = np.meshgrid(x, y) 38 | Z = rosenbrock([X, Y]) 39 | 40 | # plot 4 subplots 41 | fig = plt.figure(figsize=(16, 10)) 42 | 43 | # plot losses in top left 44 | ax = fig.add_subplot(2, 2, 1) 45 | ax.plot(losses) 46 | ax.set_title(f"Losses (final loss = {losses[-1]})") 47 | ax.set_yscale("log") 48 | ax.set_ylim([min(losses) * 0.5, max(losses) * 2]) 49 | 50 | # plot three examples 51 | for i, sample in enumerate(test_iter): 52 | iter_x, iter_y = sample[0, :], sample[1, :] 53 | ax = fig.add_subplot(2, 2, i + 2) 54 | ax.contour(X, Y, Z, 90, cmap="jet") 55 | ax.plot(iter_x, iter_y, color="r", marker="x", markersize=4) 56 | ax.set_title(f"{plot_title}, {len(iter_x) - 1} steps") 57 | ax.set_xlim([-2, 2]) 58 | ax.set_ylim([-1, 3]) 59 | ax.plot(*minimum, "gD") 60 | if i == 2: 61 | break 62 | 63 | if save_dir is not None: 64 | plt.savefig(os.path.join(save_dir, f"{plot_title}.png")) 65 | plt.show() 66 | 67 | 68 | @jit 69 | def _loss_fn_rosenbrock(xs): 70 | # rosenbrock function 71 | l = lambda x, y: (1 - x) ** 2 + 1 * (y - x**2) ** 2 72 | flat_xs = jax.tree.leaves(xs)[:-1] 73 | return sum([l(x[0], x[1]) for x in flat_xs]) / len(flat_xs) 74 | 75 | 76 | @jit 77 | def _make_params(key): 78 | # params in [-2, 2] and [-1, 3] 79 | n_sets = 16 80 | keys = jax.random.split(key, n_sets * 2) 81 | keys = jnp.reshape(keys, (n_sets, 2, 2)) 82 | params = { 83 | f"{i:02}": jnp.array( 84 | [ 85 | uniform(k[0], [], jnp.float32, -2, -1), 86 | uniform(k[1], [], jnp.float32, -1, 3), 87 | ] 88 | ) 89 | for i, k in enumerate(keys) 90 | } 91 | params["00"] = jnp.array([-2, 2], dtype=jnp.float32) 92 | params["scalar"] = jnp.array(0.0, dtype=jnp.float32) 93 | return params 94 | 95 | 96 | def _run_test( 97 | optimizer: Union[ 98 | optax.GradientTransformation, optax.GradientTransformationExtraArgs 99 | ], 100 | opt_state: optax.OptState, 101 | params: dict, 102 | steps: int, 103 | psgd_use_hessian: Optional[bool] = False, 104 | psgd_update_probability: float = 1.0, 105 | ): 106 | def loop_body(i, state): 107 | params, opt_state, key, losses, recorded_params = state 108 | 109 | key, subkey = jax.random.split(key) 110 | if psgd_use_hessian: 111 | # use helper to compute hvp and pass into PSGD 112 
| loss_out, grads, hvp, vector, update_precond = hessian_helper( 113 | subkey, 114 | i, 115 | _loss_fn_rosenbrock, 116 | params, 117 | loss_fn_extra_args=(), 118 | has_aux=False, 119 | preconditioner_update_probability=psgd_update_probability, 120 | ) 121 | updates, opt_state = optimizer.update( 122 | grads, 123 | opt_state, 124 | params, 125 | Hvp=hvp, 126 | vector=vector, 127 | update_preconditioner=update_precond, 128 | ) 129 | else: 130 | loss_out, updates = jax.value_and_grad(_loss_fn_rosenbrock)(params) 131 | updates, opt_state = optimizer.update(updates, opt_state, params) 132 | 133 | params = optax.apply_updates(params, updates) 134 | losses = losses.at[i].set(loss_out) 135 | recorded_params = recorded_params.at[:, :, i + 1].set( 136 | jnp.stack(jax.tree.leaves(params)[:-1]) 137 | ) 138 | return params, opt_state, key, losses, recorded_params 139 | 140 | losses = jnp.zeros([steps]) 141 | recorded_params = jnp.zeros([len(jax.tree.leaves(params)[:-1]), 2, steps + 1]) 142 | recorded_params = recorded_params.at[:, :, 0].set( 143 | jnp.stack(jax.tree.leaves(params)[:-1]) 144 | ) 145 | init_state = (params, opt_state, jax.random.PRNGKey(0), losses, recorded_params) 146 | params, opt_state, _, losses, recorded_params = jax.lax.fori_loop( 147 | 0, steps, loop_body, init_state 148 | ) 149 | 150 | return params, opt_state, losses, recorded_params 151 | 152 | 153 | def main(): 154 | print("Testing PSGD variants on Rosenbrock function") 155 | 156 | for use_hessian in [False, True]: 157 | for precond_type in ["kron", "xmat", "low_rank_approximation", "affine"]: 158 | if use_hessian and precond_type == "kron": 159 | # kron just uses whitening (gg^T) 160 | continue 161 | steps = 500 162 | psgd_update_probability = 1.0 163 | learning_rate = optax.linear_schedule(0.1, 0.0, steps) 164 | kwargs = { 165 | "learning_rate": learning_rate, 166 | "preconditioner_update_probability": psgd_update_probability, 167 | "b1": 0.9, 168 | "precond_lr": 0.1, 169 | "update_global_norm_clip": np.sqrt(32.0), 170 | } 171 | if precond_type == "xmat": 172 | optimizer = partial(xmat, **kwargs) 173 | elif precond_type == "low_rank_approximation": 174 | optimizer = partial(low_rank_approximation, **kwargs) 175 | elif precond_type == "affine": 176 | optimizer = partial(affine, **kwargs) 177 | elif precond_type == "kron": 178 | del kwargs["precond_lr"] 179 | del kwargs["update_global_norm_clip"] 180 | optimizer = partial( 181 | kron, 182 | memory_save_mode=None, 183 | momentum_into_precond_update=False, 184 | **kwargs, 185 | ) 186 | else: 187 | optimizer = None 188 | 189 | plot_title = f"{precond_type} PSGD {'Hvp' if use_hessian else 'gg^T'}" 190 | print(plot_title) 191 | 192 | seed = np.random.randint(0, 2**30) 193 | 194 | params = _make_params(jax.random.PRNGKey(seed)) 195 | 196 | optimizer = optimizer() 197 | opt_state = optimizer.init(params) 198 | pprint(opt_state) 199 | 200 | P = sharding.PartitionSpec 201 | devices = mesh_utils.create_device_mesh((2,)) 202 | mesh = sharding.Mesh(devices, ("m",)) 203 | 204 | def create_spec(x): 205 | if x.size > 1: 206 | shards = ("m" if x.shape[0] % 2 == 0 else None,) + (None,) * ( 207 | x.ndim - 1 208 | ) 209 | return sharding.NamedSharding(mesh, P(*shards)) 210 | else: 211 | return sharding.NamedSharding(mesh, P()) 212 | 213 | params_sharding = jax.tree.map(create_spec, params) 214 | opt_state_sharding = jax.tree.map(create_spec, opt_state) 215 | 216 | params = jax.device_put(params, params_sharding) 217 | opt_state = jax.device_put(opt_state, opt_state_sharding) 218 | 219 | 
initial_loss = _loss_fn_rosenbrock(params) 220 | print(f"Initial loss = {initial_loss}") 221 | 222 | run_test_fn = jit( 223 | _run_test, 224 | static_argnums=(0, 3, 4, 5), 225 | out_shardings=( 226 | params_sharding, 227 | opt_state_sharding, 228 | sharding.NamedSharding(mesh, P()), 229 | sharding.NamedSharding(mesh, P()), 230 | ), 231 | ) 232 | 233 | params, opt_state, losses, recorded_params = run_test_fn( 234 | optimizer, 235 | opt_state, 236 | params, 237 | steps, 238 | use_hessian, 239 | psgd_update_probability, 240 | ) 241 | 242 | final_loss = _loss_fn_rosenbrock(params) 243 | print(f"Final loss = {final_loss}") 244 | 245 | print("Output sharding:") 246 | print(jax.tree.map(lambda x: x.sharding, params)) 247 | print(jax.tree.map(lambda x: x.sharding, opt_state)) 248 | 249 | _plot_rosenbrock(recorded_params, plot_title, losses) 250 | 251 | 252 | if __name__ == "__main__": 253 | main() 254 | -------------------------------------------------------------------------------- /psgd_jax/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple 2 | 3 | import jax 4 | from jax import numpy as jnp 5 | from jax.random import PRNGKey 6 | from optax import tree_utils as otu 7 | from optax._src import base 8 | from optax._src.numerics import safe_int32_increment 9 | 10 | 11 | def hessian_helper( 12 | key: PRNGKey, 13 | train_step: int, 14 | loss_fn: Callable, 15 | params: base.Params, 16 | loss_fn_extra_args: Tuple = (), 17 | has_aux: bool = False, 18 | preconditioner_update_probability: float = 1.0, 19 | ): 20 | """Helper function for computing hessian vector product for PSGD. 21 | 22 | This helps handle the calculation of a hessian vector product if wanting to use exact 23 | hvp instead of the default gradient whitening style preconditioner. It returns the 24 | loss fn output, gradients, hvp, random vector, and a bool of whether we're updating the 25 | preconditioner this step. The hvp, vector, and update cond are then passed into PSGD's 26 | update fn. This fn is not needed if wanting to use the default gradient whitening style 27 | preconditioner. 28 | 29 | Args: 30 | key: PRNGKey, random key. 31 | train_step: int, current train step needed to init preconditioner on first step. 32 | loss_fn: callable, loss function. 33 | params: flax.Params, model parameters. 34 | loss_fn_extra_args: tuple, extra arguments for loss function to be used as 35 | `loss_fn(params, *loss_fn_extra_args)`. 36 | has_aux: bool, whether loss function has aux output. 37 | preconditioner_update_probability: float, probability of updating the preconditioner. 38 | 39 | Returns: 40 | loss_out: jnp.ndarray, output of loss function. 41 | grads: flax.Params, gradients. 42 | hvp: flax.Params, hessian vector product. 43 | vector: flax.Params, random vector. 44 | update_preconditioner: bool, whether we're updating preconditioner this step. 
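Example (a sketch mirroring the usage in psgd_jax/psgd_test.py; `optimizer` is any PSGD transform in this package that accepts the extra Hvp/vector/update_preconditioner arguments, e.g. xmat or affine, and `key`, `step`, `loss_fn` are placeholders):

    loss_out, grads, hvp, vector, update_precond = hessian_helper(
        key, step, loss_fn, params, preconditioner_update_probability=1.0
    )
    updates, opt_state = optimizer.update(
        grads, opt_state, params,
        Hvp=hvp, vector=vector, update_preconditioner=update_precond,
    )
    params = optax.apply_updates(params, updates)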
45 | """ 46 | obj_fn = lambda params: loss_fn(params, *loss_fn_extra_args) 47 | key1, key2 = jax.random.split(key) 48 | 49 | def grad_fn(params): 50 | loss_out, grad = jax.value_and_grad(obj_fn, has_aux=has_aux)(params) 51 | return grad, loss_out 52 | 53 | def hvp_fn(params): 54 | vector = otu.tree_random_like(key1, params, jax.random.normal) 55 | grad, hvp, loss_out = jax.jvp(grad_fn, (params,), (vector,), has_aux=True) 56 | return grad, loss_out, hvp, vector 57 | 58 | # TODO (evanatyourservice): finite difference hvp option 59 | 60 | def g_fn(params): 61 | grad, loss_out = grad_fn(params) 62 | dummy_hvp = jax.tree.map(jnp.zeros_like, params) 63 | dummy_vector = jax.tree.map(jnp.zeros_like, params) 64 | return grad, loss_out, dummy_hvp, dummy_vector 65 | 66 | update_precond = jnp.logical_or( 67 | jax.random.uniform(key2) < preconditioner_update_probability, train_step < 2 68 | ) 69 | 70 | grad, loss_out, hvp, vector = jax.lax.cond(update_precond, hvp_fn, g_fn, params) 71 | return loss_out, grad, hvp, vector, update_precond 72 | 73 | 74 | def apply_momentum( 75 | updates: base.Updates, momentum: base.Updates, step, b1, nesterov 76 | ) -> Tuple[base.Updates, base.Updates]: 77 | # ema 78 | mu = otu.tree_update_moment(updates, momentum, b1, 1) 79 | if nesterov: 80 | # nesterov momentum for ema with bias correction 81 | # https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ 82 | updates = jax.tree.map( 83 | lambda m, g: b1 * m + (1 - b1) * g, 84 | otu.tree_bias_correction(mu, b1, safe_int32_increment(step)), 85 | otu.tree_bias_correction(updates, b1, step), 86 | ) 87 | else: 88 | # bias correction only 89 | updates = otu.tree_bias_correction(mu, b1, step) 90 | 91 | return updates, mu 92 | 93 | 94 | def add_eps(x): 95 | return jnp.clip(x, 1e-25, None) 96 | -------------------------------------------------------------------------------- /psgd_jax/xmat.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional, Union, Callable, NamedTuple 2 | 3 | import jax 4 | from jax import numpy as jnp 5 | from jax.random import PRNGKey 6 | 7 | from optax import tree_utils as otu 8 | from optax._src import base, transform, clipping 9 | from optax._src.numerics import safe_int32_increment 10 | from optax._src.utils import canonicalize_dtype 11 | from optax._src.combine import chain 12 | 13 | from psgd_jax.utils import add_eps, apply_momentum 14 | 15 | 16 | class PSGDXMatState(NamedTuple): 17 | count: jax.Array 18 | key: PRNGKey 19 | mu: Optional[base.Updates] 20 | a: jax.Array 21 | b: jax.Array 22 | 23 | 24 | def scale_by_xmat( 25 | preconditioner_update_probability: float = 1.0, 26 | b1: float = 0.9, 27 | nesterov: bool = False, 28 | precond_lr: Union[float, Callable[[int], float]] = 0.1, 29 | precond_init_scale: Optional[float] = None, 30 | update_global_norm_clip: Optional[float] = None, 31 | step_normalizer_order: str = "2nd", 32 | seed: Optional[PRNGKey] = None, 33 | mu_dtype: Optional[Union[str, jnp.dtype]] = None, 34 | precision: str = "tensorfloat32", 35 | ) -> base.GradientTransformationExtraArgs: 36 | """ 37 | Implements XMat PSGD from https://github.com/lixilinx/psgd_torch. 38 | 39 | Args: 40 | preconditioner_update_probability: float, probability of updating the 41 | preconditioner. 42 | b1: float, momentum parameter. 43 | nesterov: bool, whether to use Nesterov momentum. 44 | precond_lr: float or callable, learning rate for the preconditioner. 45 | precond_init_scale: optional float, initial scale for the preconditioner. 
46 | update_global_norm_clip: optional float, clip updates by global norm. 47 | step_normalizer_order: str, '1st' or '2nd'. 48 | seed: Optional PRNGKey, random seed. 49 | mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator. 50 | Defaults to the same dtype as the parameters. 51 | precision: str, precision for matmul, 'bfloat16', 'tensorfloat32', 'float32'. 52 | 53 | Returns: 54 | optax.GradientTransformationExtraArgs 55 | """ 56 | mu_dtype = canonicalize_dtype(mu_dtype) 57 | 58 | def init_fn(params): 59 | key = seed if seed is not None else jax.random.PRNGKey(36) 60 | 61 | # momentum 62 | mu = None 63 | if b1 > 0: 64 | print("PSGD: Using momentum.") 65 | mu = otu.tree_zeros_like(params, mu_dtype) 66 | 67 | # preconditioner 68 | n_params = sum([x.size for x in jax.tree.leaves(params)]) 69 | a = jnp.ones((n_params,), jnp.float32) 70 | b = jnp.zeros((n_params,), jnp.float32) 71 | 72 | # initial state 73 | return PSGDXMatState(count=jnp.zeros([], jnp.int32), key=key, mu=mu, a=a, b=b) 74 | 75 | def update_fn( 76 | updates: base.Updates, 77 | state: PSGDXMatState, 78 | params: base.Params = None, 79 | Hvp: Optional[base.Updates] = None, 80 | vector: Optional[base.Updates] = None, 81 | update_preconditioner: Optional[bool] = None, 82 | ): 83 | del params 84 | # use hessian preconditioning if hessian provided 85 | # otherwise use gg^T whitening type preconditioning 86 | hessian_based_preconditioning = Hvp is not None 87 | if hessian_based_preconditioning and ( 88 | vector is None or update_preconditioner is None 89 | ): 90 | raise ValueError( 91 | "If using Hessian-based preconditioning, must also pass in random vector and " 92 | "update_preconditioner to PSGD's update function. See README for more info." 93 | ) 94 | 95 | count_inc = safe_int32_increment(state.count) 96 | key = state.key 97 | 98 | precond_lr_in = precond_lr 99 | if isinstance(precond_lr, Callable): 100 | precond_lr_in = precond_lr(count_inc) 101 | 102 | def _update_precond(key: PRNGKey, state: PSGDXMatState, Hvs, vs): 103 | v = jnp.concatenate([jnp.reshape(x, (-1,)) for x in jax.tree.leaves(vs)], 0) 104 | h = jnp.concatenate( 105 | [jnp.reshape(x, (-1,)) for x in jax.tree.leaves(Hvs)], 0 106 | ) 107 | 108 | # init a 109 | if precond_init_scale is not None: 110 | init_scale = precond_init_scale 111 | else: 112 | if hessian_based_preconditioning: 113 | init_scale = (jnp.sum(v * v) / jnp.sum(h * h)) ** 0.25 114 | else: 115 | init_scale = (len(h) / jnp.sum(jnp.square(h))) ** 0.25 116 | a = jax.lax.cond( 117 | state.count == 0, lambda: state.a * init_scale, lambda: state.a 118 | ) 119 | 120 | # update preconditioner 121 | a, b = _update_precond_Xmat_math_( 122 | a, state.b, v, h, precond_lr_in, step_normalizer_order, precision 123 | ) 124 | 125 | return key, a, b 126 | 127 | def _dont_update_precond(key, state, Hvs, vs): 128 | return key, state.a, state.b 129 | 130 | if not hessian_based_preconditioning: 131 | # update cond and vector not passed in, create here 132 | key, subkey = jax.random.split(key) 133 | update_preconditioner = jnp.logical_or( 134 | jax.random.uniform(subkey) < preconditioner_update_probability, 135 | state.count < 2, 136 | ) 137 | key, subkey = jax.random.split(key) 138 | vector = otu.tree_random_like(subkey, updates, jax.random.normal) 139 | # use grads as Hvp 140 | Hvp = updates 141 | 142 | key, a, b = jax.lax.cond( 143 | update_preconditioner, 144 | _update_precond, 145 | _dont_update_precond, 146 | key, 147 | state, 148 | Hvp, 149 | vector, 150 | ) 151 | 152 | # momentum 153 | mu = None 154 | 
if state.mu is not None: 155 | updates, mu = apply_momentum(updates, state.mu, count_inc, b1, nesterov) 156 | 157 | # preconditioning 158 | flat_updates = jnp.concatenate( 159 | [jnp.reshape(x, (-1,)) for x in jax.tree.leaves(updates)], 0 160 | ) 161 | flat_updates = _precond_grad_Xmat_math(a, b, flat_updates) 162 | with jax.ensure_compile_time_eval(): 163 | params_struct = jax.tree.structure(updates) 164 | param_sizes = [x.size for x in jax.tree.leaves(updates)] 165 | param_cumsizes = [x.item() for x in jnp.cumsum(jnp.array(param_sizes))] 166 | param_shapes = [x.shape for x in jax.tree.leaves(updates)] 167 | flat_updates = [ 168 | jnp.reshape(flat_updates[idx - size : idx], s) 169 | for idx, size, s in zip(param_cumsizes, param_sizes, param_shapes) 170 | ] 171 | updates = jax.tree.unflatten(params_struct, flat_updates) 172 | 173 | # clipping 174 | if update_global_norm_clip: 175 | updates, _ = clipping.clip_by_global_norm(update_global_norm_clip).update( 176 | updates, base.EmptyState 177 | ) 178 | 179 | mu = otu.tree_cast(mu, mu_dtype) 180 | state = PSGDXMatState(count=count_inc, key=key, mu=mu, a=a, b=b) 181 | return updates, state 182 | 183 | return base.GradientTransformationExtraArgs(init_fn, update_fn) 184 | 185 | 186 | def xmat( 187 | learning_rate: Union[float, Callable[[int], float]] = 0.01, 188 | preconditioner_update_probability: float = 1.0, 189 | b1: float = 0.9, 190 | nesterov: bool = False, 191 | weight_decay: float = 0.0, 192 | mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, 193 | precond_lr: Union[float, Callable[[int], float]] = 0.1, 194 | precond_init_scale: Optional[float] = None, 195 | update_global_norm_clip: Optional[float] = None, 196 | step_normalizer_order: str = "2nd", 197 | seed: Optional[PRNGKey] = None, 198 | mu_dtype: Optional[Union[str, jnp.dtype]] = None, 199 | precision: str = "tensorfloat32", 200 | ) -> base.GradientTransformationExtraArgs: 201 | """ 202 | Implements XMat PSGD from https://github.com/lixilinx/psgd_torch. 203 | 204 | Args: 205 | learning_rate: float or callable, learning rate for the optimizer. 206 | preconditioner_update_probability: float, probability of updating the 207 | preconditioner. 208 | b1: float, momentum parameter. 209 | nesterov: bool, whether to use Nesterov momentum. 210 | weight_decay: float, weight decay. 211 | mask: optional mask for weight decay. 212 | precond_lr: float or callable, learning rate for the preconditioner. 213 | precond_init_scale: optional float, initial scale for the preconditioner. 214 | update_global_norm_clip: optional float, clip updates by global norm. 215 | step_normalizer_order: str, '1st' or '2nd'. 216 | seed: Optional PRNGKey, random seed. 217 | mu_dtype: optional str or jnp.dtype, dtype of the momentum accumulator. 218 | Defaults to the same dtype as the parameters. 219 | precision: str, precision for matmul, 'bfloat16', 'tensorfloat32', 'float32'. 
220 | 221 | Returns: 222 | optax.GradientTransformationExtraArgs 223 | """ 224 | opt = [ 225 | scale_by_xmat( 226 | preconditioner_update_probability=preconditioner_update_probability, 227 | b1=b1, 228 | nesterov=nesterov, 229 | precond_lr=precond_lr, 230 | precond_init_scale=precond_init_scale, 231 | update_global_norm_clip=update_global_norm_clip, 232 | step_normalizer_order=step_normalizer_order, 233 | seed=seed, 234 | mu_dtype=mu_dtype, 235 | precision=precision, 236 | ) 237 | ] 238 | if weight_decay > 0: 239 | opt.append(transform.add_decayed_weights(weight_decay, mask=mask)) 240 | opt.append(transform.scale_by_learning_rate(learning_rate)) 241 | return chain(*opt) 242 | 243 | 244 | def _update_precond_Xmat_math_(a, b, v, h, precond_lr, step_normalizer, precision): 245 | """ 246 | Update preconditioner Q = diag(a) + adiag(b) with (vector, Hessian-vector product) = (v, h). 247 | """ 248 | with jax.default_matmul_precision(precision): 249 | Qh = a * h + b * jnp.flip(h, 0) 250 | aflip, bflip = jnp.flip(a, 0), jnp.flip(b, 0) 251 | invQtv = (aflip * v - bflip * jnp.flip(v, 0)) / (a * aflip - b * bflip) 252 | 253 | u, v = Qh * Qh, invQtv * invQtv 254 | nablaA = u - v 255 | nablaB = Qh * jnp.flip(Qh, 0) - invQtv * jnp.flip(invQtv, 0) 256 | q, r = jnp.divmod(len(nablaB), 2) 257 | nablaB = jnp.where(r == 1, nablaB.at[q].set(0), nablaB) 258 | 259 | if step_normalizer == "2nd": 260 | mu = precond_lr / add_eps(jnp.max(u + v)) 261 | else: 262 | mu = precond_lr / add_eps( 263 | jnp.maximum(jnp.max(jnp.abs(nablaA)), jnp.max(jnp.abs(nablaB))) 264 | ) 265 | 266 | a -= mu * (nablaA * a + nablaB * bflip) 267 | b -= mu * (nablaA * b + nablaB * aflip) 268 | 269 | return a, b 270 | 271 | 272 | def _precond_grad_Xmat_math(a, b, g): 273 | """ 274 | Preconditioning gradient g with Q = diag(a) + adiag(b). 275 | 276 | All variables here are either matrices or column vectors. 277 | """ 278 | ab = a * b 279 | return (a * a + jnp.flip(b * b, 0)) * g + (ab + jnp.flip(ab, 0)) * jnp.flip(g, 0) 280 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["flit_core >=3.2,<4"] 3 | build-backend = "flit_core.buildapi" 4 | 5 | [project] 6 | name = "psgd-jax" 7 | version = "0.2.9" 8 | description = "An implementation of PSGD optimizer in JAX." 9 | readme = { file = "README.md", content-type = "text/markdown" } 10 | license = { file = "LICENSE" } 11 | requires-python = ">=3.9" 12 | authors = [ 13 | { name = "Evan Walters" }, 14 | { name = "Omead Pooladzandi" }, 15 | { name = "Xi-Lin Li" }, 16 | ] 17 | keywords = [ 18 | "python", 19 | "machine learning", 20 | "optimization", 21 | "jax", 22 | ] 23 | classifiers = [ 24 | "Environment :: Console", 25 | "Programming Language :: Python", 26 | "Intended Audience :: Developers", 27 | "Operating System :: OS Independent", 28 | "Programming Language :: Python :: 3", 29 | "Intended Audience :: Science/Research", 30 | "Development Status :: 4 - Beta", 31 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 32 | "Topic :: Software Development :: Libraries :: Python Modules", 33 | ] 34 | dependencies = [ 35 | "numpy", 36 | "matplotlib", 37 | "jax", 38 | "optax", 39 | ] 40 | 41 | [project.urls] 42 | homepage = "https://github.com/evanatyourservice/psgd_jax" 43 | repository = "https://github.com/evanatyourservice/psgd_jax" --------------------------------------------------------------------------------
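Usage sketch (illustrative, not part of the repository files above): it wires the `kron` optimizer into a standard optax training step the same way psgd_jax/psgd_test.py does, but with a hypothetical toy linear model and made-up data in place of the Rosenbrock test.

    import jax
    import jax.numpy as jnp
    import optax
    from psgd_jax.kron import kron

    def loss_fn(params, batch):
        # simple least-squares loss on a toy linear model (hypothetical example)
        x, y = batch
        pred = x @ params["w"] + params["b"]
        return jnp.mean((pred - y) ** 2)

    params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}
    optimizer = kron(learning_rate=1e-3, b1=0.9, weight_decay=0.0)
    opt_state = optimizer.init(params)

    @jax.jit
    def train_step(params, opt_state, batch):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        # kron follows the usual optax signature: update(grads, state, params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    x, y = jnp.ones((8, 4)), jnp.ones((8, 1))
    params, opt_state, loss = train_step(params, opt_state, (x, y))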