├── .gitignore ├── LICENSE ├── README.md ├── assets └── teaser.png ├── configs ├── cifar10_generate_images.yaml └── im256_generate_images.yaml ├── dnnlib ├── __init__.py └── util.py ├── env.yml ├── generate_images.py ├── model_card.md ├── torch_utils ├── __init__.py ├── distributed.py ├── misc.py ├── persistence.py └── training_stats.py └── training ├── __init__.py ├── dit.py ├── encoders.py ├── preconds.py └── unets.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | 163 | # Project-related 164 | datasets 165 | 166 | *-runs 167 | outputs/ 168 | 169 | slurm* 170 | debug.sh 171 | *.out 172 | 173 | 174 | wandb/ 175 | 176 | multirun/ 177 | 178 | *.json -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-ShareAlike 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. 
More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 58 | Public License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 63 | ("Public License"). To the extent this Public License may be 64 | interpreted as a contract, You are granted the Licensed Rights in 65 | consideration of Your acceptance of these terms and conditions, and the 66 | Licensor grants You such rights in consideration of benefits the 67 | Licensor receives from making the Licensed Material available under 68 | these terms and conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. BY-NC-SA Compatible License means a license listed at 88 | creativecommons.org/compatiblelicenses, approved by Creative 89 | Commons as essentially the equivalent of this Public License. 90 | 91 | d. Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | 99 | e. 
Effective Technological Measures means those measures that, in the 100 | absence of proper authority, may not be circumvented under laws 101 | fulfilling obligations under Article 11 of the WIPO Copyright 102 | Treaty adopted on December 20, 1996, and/or similar international 103 | agreements. 104 | 105 | f. Exceptions and Limitations means fair use, fair dealing, and/or 106 | any other exception or limitation to Copyright and Similar Rights 107 | that applies to Your use of the Licensed Material. 108 | 109 | g. License Elements means the license attributes listed in the name 110 | of a Creative Commons Public License. The License Elements of this 111 | Public License are Attribution, NonCommercial, and ShareAlike. 112 | 113 | h. Licensed Material means the artistic or literary work, database, 114 | or other material to which the Licensor applied this Public 115 | License. 116 | 117 | i. Licensed Rights means the rights granted to You subject to the 118 | terms and conditions of this Public License, which are limited to 119 | all Copyright and Similar Rights that apply to Your use of the 120 | Licensed Material and that the Licensor has authority to license. 121 | 122 | j. Licensor means the individual(s) or entity(ies) granting rights 123 | under this Public License. 124 | 125 | k. NonCommercial means not primarily intended for or directed towards 126 | commercial advantage or monetary compensation. For purposes of 127 | this Public License, the exchange of the Licensed Material for 128 | other material subject to Copyright and Similar Rights by digital 129 | file-sharing or similar means is NonCommercial provided there is 130 | no payment of monetary compensation in connection with the 131 | exchange. 132 | 133 | l. Share means to provide material to the public by any means or 134 | process that requires permission under the Licensed Rights, such 135 | as reproduction, public display, public performance, distribution, 136 | dissemination, communication, or importation, and to make material 137 | available to the public including in ways that members of the 138 | public may access the material from a place and at a time 139 | individually chosen by them. 140 | 141 | m. Sui Generis Database Rights means rights other than copyright 142 | resulting from Directive 96/9/EC of the European Parliament and of 143 | the Council of 11 March 1996 on the legal protection of databases, 144 | as amended and/or succeeded, as well as other essentially 145 | equivalent rights anywhere in the world. 146 | 147 | n. You means the individual or entity exercising the Licensed Rights 148 | under this Public License. Your has a corresponding meaning. 149 | 150 | 151 | Section 2 -- Scope. 152 | 153 | a. License grant. 154 | 155 | 1. Subject to the terms and conditions of this Public License, 156 | the Licensor hereby grants You a worldwide, royalty-free, 157 | non-sublicensable, non-exclusive, irrevocable license to 158 | exercise the Licensed Rights in the Licensed Material to: 159 | 160 | a. reproduce and Share the Licensed Material, in whole or 161 | in part, for NonCommercial purposes only; and 162 | 163 | b. produce, reproduce, and Share Adapted Material for 164 | NonCommercial purposes only. 165 | 166 | 2. Exceptions and Limitations. For the avoidance of doubt, where 167 | Exceptions and Limitations apply to Your use, this Public 168 | License does not apply, and You do not need to comply with 169 | its terms and conditions. 170 | 171 | 3. Term. 
The term of this Public License is specified in Section 172 | 6(a). 173 | 174 | 4. Media and formats; technical modifications allowed. The 175 | Licensor authorizes You to exercise the Licensed Rights in 176 | all media and formats whether now known or hereafter created, 177 | and to make technical modifications necessary to do so. The 178 | Licensor waives and/or agrees not to assert any right or 179 | authority to forbid You from making technical modifications 180 | necessary to exercise the Licensed Rights, including 181 | technical modifications necessary to circumvent Effective 182 | Technological Measures. For purposes of this Public License, 183 | simply making modifications authorized by this Section 2(a) 184 | (4) never produces Adapted Material. 185 | 186 | 5. Downstream recipients. 187 | 188 | a. Offer from the Licensor -- Licensed Material. Every 189 | recipient of the Licensed Material automatically 190 | receives an offer from the Licensor to exercise the 191 | Licensed Rights under the terms and conditions of this 192 | Public License. 193 | 194 | b. Additional offer from the Licensor -- Adapted Material. 195 | Every recipient of Adapted Material from You 196 | automatically receives an offer from the Licensor to 197 | exercise the Licensed Rights in the Adapted Material 198 | under the conditions of the Adapter's License You apply. 199 | 200 | c. No downstream restrictions. You may not offer or impose 201 | any additional or different terms or conditions on, or 202 | apply any Effective Technological Measures to, the 203 | Licensed Material if doing so restricts exercise of the 204 | Licensed Rights by any recipient of the Licensed 205 | Material. 206 | 207 | 6. No endorsement. Nothing in this Public License constitutes or 208 | may be construed as permission to assert or imply that You 209 | are, or that Your use of the Licensed Material is, connected 210 | with, or sponsored, endorsed, or granted official status by, 211 | the Licensor or others designated to receive attribution as 212 | provided in Section 3(a)(1)(A)(i). 213 | 214 | b. Other rights. 215 | 216 | 1. Moral rights, such as the right of integrity, are not 217 | licensed under this Public License, nor are publicity, 218 | privacy, and/or other similar personality rights; however, to 219 | the extent possible, the Licensor waives and/or agrees not to 220 | assert any such rights held by the Licensor to the limited 221 | extent necessary to allow You to exercise the Licensed 222 | Rights, but not otherwise. 223 | 224 | 2. Patent and trademark rights are not licensed under this 225 | Public License. 226 | 227 | 3. To the extent possible, the Licensor waives any right to 228 | collect royalties from You for the exercise of the Licensed 229 | Rights, whether directly or through a collecting society 230 | under any voluntary or waivable statutory or compulsory 231 | licensing scheme. In all other cases the Licensor expressly 232 | reserves any right to collect such royalties, including when 233 | the Licensed Material is used other than for NonCommercial 234 | purposes. 235 | 236 | 237 | Section 3 -- License Conditions. 238 | 239 | Your exercise of the Licensed Rights is expressly made subject to the 240 | following conditions. 241 | 242 | a. Attribution. 243 | 244 | 1. If You Share the Licensed Material (including in modified 245 | form), You must: 246 | 247 | a. retain the following if it is supplied by the Licensor 248 | with the Licensed Material: 249 | 250 | i. 
identification of the creator(s) of the Licensed 251 | Material and any others designated to receive 252 | attribution, in any reasonable manner requested by 253 | the Licensor (including by pseudonym if 254 | designated); 255 | 256 | ii. a copyright notice; 257 | 258 | iii. a notice that refers to this Public License; 259 | 260 | iv. a notice that refers to the disclaimer of 261 | warranties; 262 | 263 | v. a URI or hyperlink to the Licensed Material to the 264 | extent reasonably practicable; 265 | 266 | b. indicate if You modified the Licensed Material and 267 | retain an indication of any previous modifications; and 268 | 269 | c. indicate the Licensed Material is licensed under this 270 | Public License, and include the text of, or the URI or 271 | hyperlink to, this Public License. 272 | 273 | 2. You may satisfy the conditions in Section 3(a)(1) in any 274 | reasonable manner based on the medium, means, and context in 275 | which You Share the Licensed Material. For example, it may be 276 | reasonable to satisfy the conditions by providing a URI or 277 | hyperlink to a resource that includes the required 278 | information. 279 | 3. If requested by the Licensor, You must remove any of the 280 | information required by Section 3(a)(1)(A) to the extent 281 | reasonably practicable. 282 | 283 | b. ShareAlike. 284 | 285 | In addition to the conditions in Section 3(a), if You Share 286 | Adapted Material You produce, the following conditions also apply. 287 | 288 | 1. The Adapter's License You apply must be a Creative Commons 289 | license with the same License Elements, this version or 290 | later, or a BY-NC-SA Compatible License. 291 | 292 | 2. You must include the text of, or the URI or hyperlink to, the 293 | Adapter's License You apply. You may satisfy this condition 294 | in any reasonable manner based on the medium, means, and 295 | context in which You Share Adapted Material. 296 | 297 | 3. You may not offer or impose any additional or different terms 298 | or conditions on, or apply any Effective Technological 299 | Measures to, Adapted Material that restrict exercise of the 300 | rights granted under the Adapter's License You apply. 301 | 302 | 303 | Section 4 -- Sui Generis Database Rights. 304 | 305 | Where the Licensed Rights include Sui Generis Database Rights that 306 | apply to Your use of the Licensed Material: 307 | 308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 309 | to extract, reuse, reproduce, and Share all or a substantial 310 | portion of the contents of the database for NonCommercial purposes 311 | only; 312 | 313 | b. if You include all or a substantial portion of the database 314 | contents in a database in which You have Sui Generis Database 315 | Rights, then the database in which You have Sui Generis Database 316 | Rights (but not its individual contents) is Adapted Material, 317 | including for purposes of Section 3(b); and 318 | 319 | c. You must comply with the conditions in Section 3(a) if You Share 320 | all or a substantial portion of the contents of the database. 321 | 322 | For the avoidance of doubt, this Section 4 supplements and does not 323 | replace Your obligations under this Public License where the Licensed 324 | Rights include other Copyright and Similar Rights. 325 | 326 | 327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 328 | 329 | a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 339 | 340 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 349 | 350 | c. The disclaimer of warranties and limitation of liability provided 351 | above shall be interpreted in a manner that, to the extent 352 | possible, most closely approximates an absolute disclaimer and 353 | waiver of all liability. 354 | 355 | 356 | Section 6 -- Term and Termination. 357 | 358 | a. This Public License applies for the term of the Copyright and 359 | Similar Rights licensed here. However, if You fail to comply with 360 | this Public License, then Your rights under this Public License 361 | terminate automatically. 362 | 363 | b. Where Your right to use the Licensed Material has terminated under 364 | Section 6(a), it reinstates: 365 | 366 | 1. automatically as of the date the violation is cured, provided 367 | it is cured within 30 days of Your discovery of the 368 | violation; or 369 | 370 | 2. upon express reinstatement by the Licensor. 371 | 372 | For the avoidance of doubt, this Section 6(b) does not affect any 373 | right the Licensor may have to seek remedies for Your violations 374 | of this Public License. 375 | 376 | c. For the avoidance of doubt, the Licensor may also offer the 377 | Licensed Material under separate terms or conditions or stop 378 | distributing the Licensed Material at any time; however, doing so 379 | will not terminate this Public License. 380 | 381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 382 | License. 383 | 384 | 385 | Section 7 -- Other Terms and Conditions. 386 | 387 | a. The Licensor shall not be bound by any additional or different 388 | terms or conditions communicated by You unless expressly agreed. 389 | 390 | b. Any arrangements, understandings, or agreements regarding the 391 | Licensed Material not stated herein are separate from and 392 | independent of the terms and conditions of this Public License. 393 | 394 | 395 | Section 8 -- Interpretation. 396 | 397 | a. For the avoidance of doubt, this Public License does not, and 398 | shall not be interpreted to, reduce, limit, restrict, or impose 399 | conditions on any use of the Licensed Material that could lawfully 400 | be made without permission under this Public License. 401 | 402 | b. 
To the extent possible, if any provision of this Public License is 403 | deemed unenforceable, it shall be automatically reformed to the 404 | minimum extent necessary to make it enforceable. If the provision 405 | cannot be reformed, it shall be severed from this Public License 406 | without affecting the enforceability of the remaining terms and 407 | conditions. 408 | 409 | c. No term or condition of this Public License will be waived and no 410 | failure to comply consented to unless expressly agreed to by the 411 | Licensor. 412 | 413 | d. Nothing in this Public License constitutes or may be interpreted 414 | as a limitation upon, or waiver of, any privileges and immunities 415 | that apply to the Licensor or You, including from the legal 416 | processes of any jurisdiction or authority. 417 | 418 | ======================================================================= 419 | 420 | Creative Commons is not a party to its public 421 | licenses. Notwithstanding, Creative Commons may elect to apply one of 422 | its public licenses to material it publishes and in those instances 423 | will be considered the "Licensor." The text of the Creative Commons 424 | public licenses is dedicated to the public domain under the CC0 Public 425 | Domain Dedication. Except for the limited purpose of indicating that 426 | material is shared under a Creative Commons public license or as 427 | otherwise permitted by the Creative Commons policies published at 428 | creativecommons.org/policies, Creative Commons does not authorize the 429 | use of the trademark "Creative Commons" or any other trademark or logo 430 | of Creative Commons without its prior written consent including, 431 | without limitation, in connection with any unauthorized modifications 432 | to any of its public licenses or any other arrangements, 433 | understandings, or agreements concerning use of licensed material. For 434 | the avoidance of doubt, this paragraph does not form part of the 435 | public licenses. 436 | 437 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Inductive Moment Matching 2 | 3 | 4 | Official Implementation of [Inductive Moment Matching](https://arxiv.org/abs/2503.07565) 5 | 6 |

![teaser](assets/teaser.png)

Linqi Zhou<sup>1</sup>, Stefano Ermon<sup>2</sup>, Jiaming Song<sup>1</sup>

<sup>1</sup>Luma AI, <sup>2</sup>Stanford University

[Paper](https://arxiv.org/abs/2503.07565) [Blog]
30 | 31 | Also check out our accompanying [position paper](https://arxiv.org/abs/2503.07154) that explains the motivation and ways of designing new generative paradigms. 32 | 33 | # Dependencies 34 | 35 | To install all packages in this codebase along with their dependencies, run 36 | ```sh 37 | conda env create -f env.yml 38 | ``` 39 | 40 | # Pre-trained models 41 | 42 | We provide pretrained checkpoints through our [repo](https://huggingface.co/lumaai/imm) on Hugging Face: 43 | * IMM on CIFAR-10: [cifar10.pt](https://huggingface.co/lumaai/imm/resolve/main/cifar10.pt). 44 | * IMM on ImageNet-256x256: 45 | 1. `t-s` is passed as the second time embedding, trained with `a=2`: [imagenet256_ts_a2.pkl](https://huggingface.co/lumaai/imm/resolve/main/imagenet256_ts_a2.pkl). 46 | 2. `s` is passed as the second time embedding directly, trained with `a=1`: [imagenet256_s_a1.pkl](https://huggingface.co/lumaai/imm/resolve/main/imagenet256_s_a1.pkl). 47 | 48 | # Sampling 49 | 50 | The checkpoints can be tested via 51 | ```sh 52 | python generate_images.py --config-name=CONFIG_NAME eval.resume=CKPT_PATH REPLACEMENT_ARGS 53 | ``` 54 | where `CONFIG_NAME` is `im256_generate_images.yaml` or `cifar10_generate_images.yaml` and `CKPT_PATH` is the path to your checkpoint. When loading `imagenet256_s_a1.pkl`, `REPLACEMENT_ARGS` needs to be `network.temb_type=identity`. Otherwise, `REPLACEMENT_ARGS` is empty. 55 | 56 | # Checklist 57 | 58 | - [x] Add model weights and model definitions. 59 | - [x] Add inference scripts. 60 | - [ ] Add evaluation scripts. 61 | - [ ] Add training scripts. 62 | 63 | # Acknowledgements 64 | 65 | Some of the utility functions are based on [EDM](https://github.com/NVlabs/edm), and parts of the code are therefore covered by [this license](https://github.com/NVlabs/edm/blob/main/LICENSE.txt). 66 | 67 | # Citation 68 | 69 | ``` 70 | @article{zhou2025inductive, 71 | title={Inductive Moment Matching}, 72 | author={Zhou, Linqi and Ermon, Stefano and Song, Jiaming}, 73 | journal={arXiv preprint arXiv:2503.07565}, 74 | year={2025} 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lumalabs/imm/c9473e9395910d3dd03ffffb6b0a8b4fb760c3a1/assets/teaser.png -------------------------------------------------------------------------------- /configs/cifar10_generate_images.yaml: -------------------------------------------------------------------------------- 1 | 2 | label_dim: 0 # unconditional 3 | resolution: 32 # image resolution (StandardRGBEncoder works directly in pixel space) 4 | channels: 3 5 | 6 | encoder: 7 | class_name: training.encoders.StandardRGBEncoder 8 | 9 | 10 | network: 11 | class_name: training.preconds.IMMPrecond 12 | # ddpmpp 13 | model_type: "SongUNet" 14 | embedding_type: "positional" 15 | encoder_type: "standard" 16 | decoder_type: "standard" 17 | channel_mult_noise: 1 18 | resample_filter: [1, 1] 19 | model_channels: 128 20 | channel_mult: [2, 2, 2] 21 | s_embed: true 22 | dropout: 0.2 23 | 24 | noise_schedule: fm 25 | 26 | f_type: simple_edm 27 | temb_type: identity 28 | time_scale: 1000 29 | 30 | 31 | eps: 0.006 32 | T: 0.994 33 | 34 | 35 | sampling: 36 | 1_step: 37 | name: pushforward_generator_fn 38 | mid_nt: null 39 | 40 | 2_steps: 41 | name: pushforward_generator_fn 42 | mid_nt: [1.4] 43 | 44 | eval: 45 | seed: 42 46 | batch_size: 256 47 | cudnn_benchmark: true 48 | resume: null 49 | 50 | 51 | hydra: 52 | output_subdir: null 53 | run: 54 | dir: .
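To make the sampling entries concrete, here is a minimal sketch of how the `2_steps` variant above maps onto `generate_images.py` (shown in full later in this dump). It is a sketch under assumptions: the local checkpoint path `./cifar10.pt` and the batch size of 16 are hypothetical, and the repo root must be on `sys.path` so the pickled EMA module can be resolved; the official entry point remains the `generate_images.py` command in the README.

```python
# Minimal sketch (not the official entry point): 2-step CIFAR-10 sampling,
# mirroring generate_images.py. Assumes the checkpoint was downloaded locally
# as ./cifar10.pt and that the repo root is on sys.path.
import pickle
import torch
import dnnlib
from generate_images import pushforward_generator_fn

device = torch.device("cuda")
with dnnlib.util.open_url("./cifar10.pt", verbose=True) as f:
    net = pickle.load(f)["ema"].eval().requires_grad_(False).to(device)

# CIFAR-10 is unconditional (label_dim: 0), so class_labels stays None.
latents = net.get_init_noise(
    [16, net.img_channels, net.img_resolution, net.img_resolution], device
)
# mid_nt=[1.4] inserts one intermediate noise level, reproducing the
# "2_steps" sampler above; mid_nt=None gives the single-step "1_step" variant.
images = pushforward_generator_fn(net, latents, mid_nt=[1.4])
```

The script's `main()` additionally decodes the samples with the configured encoder and saves them as an image grid.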
-------------------------------------------------------------------------------- /configs/im256_generate_images.yaml: -------------------------------------------------------------------------------- 1 | 2 | 3 | label_dim: 1000 4 | resolution: 32 #latent resolution 5 | channels: 4 6 | 7 | dataloader: 8 | pin_memory: true 9 | num_workers: 1 10 | prefetch_factor: 2 11 | 12 | encoder: 13 | class_name: training.encoders.StabilityVAEEncoder 14 | vae_name: stabilityai/sd-vae-ft-ema 15 | final_std: 0.5 16 | raw_mean: [ 0.86488, -0.27787343, 0.21616915, 0.3738409 ] 17 | raw_std: [4.85503674, 5.31922414, 3.93725398 , 3.9870003 ] 18 | use_fp16: true 19 | 20 | 21 | network: 22 | class_name: training.preconds.IMMPrecond 23 | 24 | model_type: "DiT_XL_2" 25 | s_embed: true 26 | 27 | noise_schedule: fm 28 | 29 | #sample function 30 | f_type: euler_fm 31 | temb_type: stride 32 | time_scale: 1000 33 | 34 | sigma_data: 0.5 35 | 36 | eps: 0. 37 | T: 0.994 38 | 39 | 40 | 41 | sampling: 42 | 43 | 1_steps_cfg1.5_pushforward_uniform: 44 | name: pushforward_generator_fn 45 | discretization: uniform 46 | num_steps: 1 47 | cfg_scale: 1.5 48 | 49 | 50 | 2_steps_cfg1.5_pushforward_uniform: 51 | name: pushforward_generator_fn 52 | discretization: uniform 53 | num_steps: 2 54 | cfg_scale: 1.5 55 | 56 | 4_steps_cfg1.5_pushforward_uniform: 57 | name: pushforward_generator_fn 58 | discretization: uniform 59 | num_steps: 4 60 | cfg_scale: 1.5 61 | 62 | 8_steps_cfg1.5_pushforward_uniform: 63 | name: pushforward_generator_fn 64 | discretization: uniform 65 | num_steps: 8 66 | cfg_scale: 1.5 67 | 68 | 69 | 70 | eval: 71 | seed: 42 72 | batch_size: 256 73 | cudnn_benchmark: true 74 | resume: null 75 | 76 | 77 | hydra: 78 | output_subdir: null 79 | run: 80 | dir: . -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | from .util import EasyDict, make_cache_dir_path 9 | -------------------------------------------------------------------------------- /dnnlib/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Miscellaneous utility classes and functions.""" 9 | 10 | import ctypes 11 | import fnmatch 12 | import importlib 13 | import inspect 14 | import numpy as np 15 | import os 16 | import shutil 17 | import sys 18 | import types 19 | import io 20 | import pickle 21 | import re 22 | import requests 23 | import html 24 | import hashlib 25 | import glob 26 | import tempfile 27 | import urllib 28 | import urllib.request 29 | import uuid 30 | 31 | from distutils.util import strtobool 32 | from typing import Any, List, Tuple, Union, Optional 33 | 34 | 35 | # Util classes 36 | # ------------------------------------------------------------------------------------------ 37 | 38 | 39 | class EasyDict(dict): 40 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 41 | 42 | def __getattr__(self, name: str) -> Any: 43 | try: 44 | return self[name] 45 | except KeyError: 46 | raise AttributeError(name) 47 | 48 | def __setattr__(self, name: str, value: Any) -> None: 49 | self[name] = value 50 | 51 | def __delattr__(self, name: str) -> None: 52 | del self[name] 53 | 54 | 55 | class Logger(object): 56 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.""" 57 | 58 | def __init__( 59 | self, 60 | file_name: Optional[str] = None, 61 | file_mode: str = "w", 62 | should_flush: bool = True, 63 | ): 64 | self.file = None 65 | 66 | if file_name is not None: 67 | self.file = open(file_name, file_mode) 68 | 69 | self.should_flush = should_flush 70 | self.stdout = sys.stdout 71 | self.stderr = sys.stderr 72 | 73 | sys.stdout = self 74 | sys.stderr = self 75 | 76 | def __enter__(self) -> "Logger": 77 | return self 78 | 79 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: 80 | self.close() 81 | 82 | def write(self, text: Union[str, bytes]) -> None: 83 | """Write text to stdout (and a file) and optionally flush.""" 84 | if isinstance(text, bytes): 85 | text = text.decode() 86 | if ( 87 | len(text) == 0 88 | ): # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash 89 | return 90 | 91 | if self.file is not None: 92 | self.file.write(text) 93 | 94 | self.stdout.write(text) 95 | 96 | if self.should_flush: 97 | self.flush() 98 | 99 | def flush(self) -> None: 100 | """Flush written text to both stdout and a file, if open.""" 101 | if self.file is not None: 102 | self.file.flush() 103 | 104 | self.stdout.flush() 105 | 106 | def close(self) -> None: 107 | """Flush, close possible files, and remove stdout/stderr mirroring.""" 108 | self.flush() 109 | 110 | # if using multiple loggers, prevent closing in wrong order 111 | if sys.stdout is self: 112 | sys.stdout = self.stdout 113 | if sys.stderr is self: 114 | sys.stderr = self.stderr 115 | 116 | if self.file is not None: 117 | self.file.close() 118 | self.file = None 119 | 120 | 121 | # Cache directories 122 | # ------------------------------------------------------------------------------------------ 123 | 124 | _dnnlib_cache_dir = None 125 | 126 | 127 | def set_cache_dir(path: str) -> None: 128 | global _dnnlib_cache_dir 129 | _dnnlib_cache_dir = path 130 | 131 | 132 | def make_cache_dir_path(*paths: str) -> str: 133 | if _dnnlib_cache_dir is not None: 134 | return os.path.join(_dnnlib_cache_dir, *paths) 135 | if "DNNLIB_CACHE_DIR" in os.environ: 136 | return os.path.join(os.environ["DNNLIB_CACHE_DIR"], *paths) 137 | if "HOME" in 
os.environ: 138 | return os.path.join(os.environ["HOME"], ".cache", "dnnlib", *paths) 139 | if "USERPROFILE" in os.environ: 140 | return os.path.join(os.environ["USERPROFILE"], ".cache", "dnnlib", *paths) 141 | return os.path.join(tempfile.gettempdir(), ".cache", "dnnlib", *paths) 142 | 143 | 144 | # Small util functions 145 | # ------------------------------------------------------------------------------------------ 146 | 147 | 148 | def format_time(seconds: Union[int, float]) -> str: 149 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 150 | s = int(np.rint(seconds)) 151 | 152 | if s < 60: 153 | return "{0}s".format(s) 154 | elif s < 60 * 60: 155 | return "{0}m {1:02}s".format(s // 60, s % 60) 156 | elif s < 24 * 60 * 60: 157 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60) 158 | else: 159 | return "{0}d {1:02}h {2:02}m".format( 160 | s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60 161 | ) 162 | 163 | 164 | def format_time_brief(seconds: Union[int, float]) -> str: 165 | """Convert the seconds to human readable string with days, hours, minutes and seconds.""" 166 | s = int(np.rint(seconds)) 167 | 168 | if s < 60: 169 | return "{0}s".format(s) 170 | elif s < 60 * 60: 171 | return "{0}m {1:02}s".format(s // 60, s % 60) 172 | elif s < 24 * 60 * 60: 173 | return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60) 174 | else: 175 | return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24) 176 | 177 | 178 | def ask_yes_no(question: str) -> bool: 179 | """Ask the user the question until the user inputs a valid answer.""" 180 | while True: 181 | try: 182 | print("{0} [y/n]".format(question)) 183 | return strtobool(input().lower()) 184 | except ValueError: 185 | pass 186 | 187 | 188 | def tuple_product(t: Tuple) -> Any: 189 | """Calculate the product of the tuple elements.""" 190 | result = 1 191 | 192 | for v in t: 193 | result *= v 194 | 195 | return result 196 | 197 | 198 | _str_to_ctype = { 199 | "uint8": ctypes.c_ubyte, 200 | "uint16": ctypes.c_uint16, 201 | "uint32": ctypes.c_uint32, 202 | "uint64": ctypes.c_uint64, 203 | "int8": ctypes.c_byte, 204 | "int16": ctypes.c_int16, 205 | "int32": ctypes.c_int32, 206 | "int64": ctypes.c_int64, 207 | "float32": ctypes.c_float, 208 | "float64": ctypes.c_double, 209 | } 210 | 211 | 212 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]: 213 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.""" 214 | type_str = None 215 | 216 | if isinstance(type_obj, str): 217 | type_str = type_obj 218 | elif hasattr(type_obj, "__name__"): 219 | type_str = type_obj.__name__ 220 | elif hasattr(type_obj, "name"): 221 | type_str = type_obj.name 222 | else: 223 | raise RuntimeError("Cannot infer type name from input") 224 | 225 | assert type_str in _str_to_ctype.keys() 226 | 227 | my_dtype = np.dtype(type_str) 228 | my_ctype = _str_to_ctype[type_str] 229 | 230 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype) 231 | 232 | return my_dtype, my_ctype 233 | 234 | 235 | def is_pickleable(obj: Any) -> bool: 236 | try: 237 | with io.BytesIO() as stream: 238 | pickle.dump(obj, stream) 239 | return True 240 | except: 241 | return False 242 | 243 | 244 | # Functionality to import modules/objects by name, and call functions by name 245 | # ------------------------------------------------------------------------------------------ 246 | 247 | 248 | def 
get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]: 249 | """Searches for the underlying module behind the name to some python object. 250 | Returns the module and the object name (original name with module part removed).""" 251 | 252 | # allow convenience shorthands, substitute them by full names 253 | obj_name = re.sub("^np.", "numpy.", obj_name) 254 | obj_name = re.sub("^tf.", "tensorflow.", obj_name) 255 | 256 | # list alternatives for (module_name, local_obj_name) 257 | parts = obj_name.split(".") 258 | name_pairs = [ 259 | (".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1) 260 | ] 261 | 262 | # try each alternative in turn 263 | for module_name, local_obj_name in name_pairs: 264 | try: 265 | module = importlib.import_module(module_name) # may raise ImportError 266 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 267 | return module, local_obj_name 268 | except: 269 | pass 270 | 271 | # maybe some of the modules themselves contain errors? 272 | for module_name, _local_obj_name in name_pairs: 273 | try: 274 | importlib.import_module(module_name) # may raise ImportError 275 | except ImportError: 276 | if not str(sys.exc_info()[1]).startswith( 277 | "No module named '" + module_name + "'" 278 | ): 279 | raise 280 | 281 | # maybe the requested attribute is missing? 282 | for module_name, local_obj_name in name_pairs: 283 | try: 284 | module = importlib.import_module(module_name) # may raise ImportError 285 | get_obj_from_module(module, local_obj_name) # may raise AttributeError 286 | except ImportError: 287 | pass 288 | 289 | # we are out of luck, but we have no idea why 290 | raise ImportError(obj_name) 291 | 292 | 293 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any: 294 | """Traverses the object name and returns the last (rightmost) python object.""" 295 | if obj_name == "": 296 | return module 297 | obj = module 298 | for part in obj_name.split("."): 299 | obj = getattr(obj, part) 300 | return obj 301 | 302 | 303 | def get_obj_by_name(name: str) -> Any: 304 | """Finds the python object with the given name.""" 305 | module, obj_name = get_module_from_obj_name(name) 306 | return get_obj_from_module(module, obj_name) 307 | 308 | 309 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any: 310 | """Finds the python object with the given name and calls it as a function.""" 311 | assert func_name is not None 312 | func_obj = get_obj_by_name(func_name) 313 | assert callable(func_obj) 314 | return func_obj(*args, **kwargs) 315 | 316 | 317 | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any: 318 | """Finds the python class with the given name and constructs it with the given arguments.""" 319 | return call_func_by_name(*args, func_name=class_name, **kwargs) 320 | 321 | 322 | def get_module_dir_by_obj_name(obj_name: str) -> str: 323 | """Get the directory path of the module containing the given object name.""" 324 | module, _ = get_module_from_obj_name(obj_name) 325 | return os.path.dirname(inspect.getfile(module)) 326 | 327 | 328 | def is_top_level_function(obj: Any) -> bool: 329 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.""" 330 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__ 331 | 332 | 333 | def get_top_level_function_name(obj: Any) -> str: 334 | """Return the fully-qualified name of a top-level function.""" 335 | assert is_top_level_function(obj) 336 | module = 
obj.__module__ 337 | if module == "__main__": 338 | module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0] 339 | return module + "." + obj.__name__ 340 | 341 | 342 | # File system helpers 343 | # ------------------------------------------------------------------------------------------ 344 | 345 | 346 | def list_dir_recursively_with_ignore( 347 | dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False 348 | ) -> List[Tuple[str, str]]: 349 | """List all files recursively in a given directory while ignoring given file and directory names. 350 | Returns list of tuples containing both absolute and relative paths.""" 351 | assert os.path.isdir(dir_path) 352 | base_name = os.path.basename(os.path.normpath(dir_path)) 353 | 354 | if ignores is None: 355 | ignores = [] 356 | 357 | result = [] 358 | 359 | for root, dirs, files in os.walk(dir_path, topdown=True): 360 | for ignore_ in ignores: 361 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)] 362 | 363 | # dirs need to be edited in-place 364 | for d in dirs_to_remove: 365 | dirs.remove(d) 366 | 367 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)] 368 | 369 | absolute_paths = [os.path.join(root, f) for f in files] 370 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths] 371 | 372 | if add_base_to_relative: 373 | relative_paths = [os.path.join(base_name, p) for p in relative_paths] 374 | 375 | assert len(absolute_paths) == len(relative_paths) 376 | result += zip(absolute_paths, relative_paths) 377 | 378 | return result 379 | 380 | 381 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None: 382 | """Takes in a list of tuples of (src, dst) paths and copies files. 383 | Will create all necessary directories.""" 384 | for file in files: 385 | target_dir_name = os.path.dirname(file[1]) 386 | 387 | # will create all intermediate-level directories 388 | if not os.path.exists(target_dir_name): 389 | os.makedirs(target_dir_name) 390 | 391 | shutil.copyfile(file[0], file[1]) 392 | 393 | 394 | # URL helpers 395 | # ------------------------------------------------------------------------------------------ 396 | 397 | 398 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool: 399 | """Determine whether the given object is a valid URL string.""" 400 | if not isinstance(obj, str) or not "://" in obj: 401 | return False 402 | if allow_file_urls and obj.startswith("file://"): 403 | return True 404 | try: 405 | res = requests.compat.urlparse(obj) 406 | if not res.scheme or not res.netloc or not "." in res.netloc: 407 | return False 408 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/")) 409 | if not res.scheme or not res.netloc or not "." in res.netloc: 410 | return False 411 | except: 412 | return False 413 | return True 414 | 415 | 416 | def open_url( 417 | url: str, 418 | cache_dir: str = None, 419 | num_attempts: int = 10, 420 | verbose: bool = True, 421 | return_filename: bool = False, 422 | cache: bool = True, 423 | ) -> Any: 424 | """Download the given URL and return a binary-mode file object to access the data.""" 425 | assert num_attempts >= 1 426 | assert not (return_filename and (not cache)) 427 | 428 | # Doesn't look like an URL scheme so interpret it as a local filename. 429 | if not re.match("^[a-z]+://", url): 430 | return url if return_filename else open(url, "rb") 431 | 432 | # Handle file URLs. 
This code handles unusual file:// patterns that 433 | # arise on Windows: 434 | # 435 | # file:///c:/foo.txt 436 | # 437 | # which would translate to a local '/c:/foo.txt' filename that's 438 | # invalid. Drop the forward slash for such pathnames. 439 | # 440 | # If you touch this code path, you should test it on both Linux and 441 | # Windows. 442 | # 443 | # Some internet resources suggest using urllib.request.url2pathname() but 444 | # but that converts forward slashes to backslashes and this causes 445 | # its own set of problems. 446 | if url.startswith("file://"): 447 | filename = urllib.parse.urlparse(url).path 448 | if re.match(r"^/[a-zA-Z]:", filename): 449 | filename = filename[1:] 450 | return filename if return_filename else open(filename, "rb") 451 | 452 | assert is_url(url) 453 | 454 | # Lookup from cache. 455 | if cache_dir is None: 456 | cache_dir = make_cache_dir_path("downloads") 457 | 458 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 459 | if cache: 460 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*")) 461 | if len(cache_files) == 1: 462 | filename = cache_files[0] 463 | return filename if return_filename else open(filename, "rb") 464 | 465 | # Download. 466 | url_name = None 467 | url_data = None 468 | with requests.Session() as session: 469 | if verbose: 470 | print("Downloading %s ..." % url, end="", flush=True) 471 | for attempts_left in reversed(range(num_attempts)): 472 | try: 473 | with session.get(url) as res: 474 | res.raise_for_status() 475 | if len(res.content) == 0: 476 | raise IOError("No data received") 477 | 478 | if len(res.content) < 8192: 479 | content_str = res.content.decode("utf-8") 480 | if "download_warning" in res.headers.get("Set-Cookie", ""): 481 | links = [ 482 | html.unescape(link) 483 | for link in content_str.split('"') 484 | if "export=download" in link 485 | ] 486 | if len(links) == 1: 487 | url = requests.compat.urljoin(url, links[0]) 488 | raise IOError("Google Drive virus checker nag") 489 | if "Google Drive - Quota exceeded" in content_str: 490 | raise IOError( 491 | "Google Drive download quota exceeded -- please try again later" 492 | ) 493 | 494 | match = re.search( 495 | r'filename="([^"]*)"', 496 | res.headers.get("Content-Disposition", ""), 497 | ) 498 | url_name = match[1] if match else url 499 | url_data = res.content 500 | if verbose: 501 | print(" done") 502 | break 503 | except KeyboardInterrupt: 504 | raise 505 | except: 506 | if not attempts_left: 507 | if verbose: 508 | print(" failed") 509 | raise 510 | if verbose: 511 | print(".", end="", flush=True) 512 | 513 | # Save to cache. 514 | if cache: 515 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name) 516 | safe_name = safe_name[: min(len(safe_name), 128)] 517 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name) 518 | temp_file = os.path.join( 519 | cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name 520 | ) 521 | os.makedirs(cache_dir, exist_ok=True) 522 | with open(temp_file, "wb") as f: 523 | f.write(url_data) 524 | os.replace(temp_file, cache_file) # atomic 525 | if return_filename: 526 | return cache_file 527 | 528 | # Return data as file object. 
529 | assert not return_filename 530 | return io.BytesIO(url_data) 531 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: imm 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python==3.9.18 7 | - pip 8 | - click 9 | - requests 10 | - pillow 11 | - numpy 12 | - scipy 13 | - psutil 14 | - tqdm 15 | - imageio 16 | - pytorch=2.5.1 17 | - pytorch-cuda=12.1 18 | - pip: 19 | - einops 20 | - matplotlib 21 | - seaborn 22 | - wandb 23 | - timm==1.0.8 24 | - imageio-ffmpeg 25 | - pyspng 26 | - omegaconf==2.3.0 27 | - hydra-core==1.3.2 28 | - diffusers==0.31.0 29 | 30 | -------------------------------------------------------------------------------- /generate_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | 5 | import pickle 6 | import functools 7 | import numpy as np 8 | 9 | import torch 10 | import dnnlib 11 | 12 | import torchvision.utils as vutils 13 | import warnings 14 | 15 | from omegaconf import OmegaConf 16 | from torch_utils import misc 17 | import hydra 18 | 19 | warnings.filterwarnings( 20 | "ignore", "Grad strides do not match bucket view strides" 21 | ) # False warning printed by PyTorch 1.12. 22 | 23 | 24 | 25 | 26 | # ---------------------------------------------------------------------------- 27 | 28 | def generator_fn(*args, name='pushforward_generator_fn', **kwargs): 29 | return globals()[name](*args, **kwargs) 30 | 31 | 32 | 33 | 34 | @torch.no_grad() 35 | def pushforward_generator_fn(net, latents, class_labels=None, discretization=None, mid_nt=None, num_steps=None, cfg_scale=None, ): 36 | # Time step discretization. 37 | if discretization == 'uniform': 38 | t_steps = torch.linspace(net.T, net.eps, num_steps+1, dtype=torch.float64, device=latents.device) 39 | elif discretization == 'edm': 40 | nt_min = net.get_log_nt(torch.as_tensor(net.eps, dtype=torch.float64)).exp().item() 41 | nt_max = net.get_log_nt(torch.as_tensor(net.T, dtype=torch.float64)).exp().item() 42 | rho = 7 43 | step_indices = torch.arange(num_steps+1, dtype=torch.float64, device=latents.device) 44 | nt_steps = (nt_max ** (1 / rho) + step_indices / (num_steps) * (nt_min ** (1 / rho) - nt_max ** (1 / rho))) ** rho 45 | t_steps = net.nt_to_t(nt_steps) 46 | else: 47 | if mid_nt is None: 48 | mid_nt = [] 49 | mid_t = [net.nt_to_t(torch.as_tensor(nt)).item() for nt in mid_nt] 50 | t_steps = torch.tensor( 51 | [net.T] + list(mid_t), dtype=torch.float64, device=latents.device 52 | ) 53 | # t_0 = T, t_N = 0 54 | t_steps = torch.cat([t_steps, torch.ones_like(t_steps[:1]) * net.eps]) 55 | 56 | # Sampling steps 57 | x = latents.to(torch.float64) 58 | 59 | for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): 60 | 61 | x = net.cfg_forward(x, t_cur, t_next, class_labels=class_labels, cfg_scale=cfg_scale ).to( 62 | torch.float64 63 | ) 64 | 65 | 66 | return x 67 | 68 | @torch.no_grad() 69 | def restart_generator_fn(net, latents, class_labels=None, discretization=None, mid_nt=None, num_steps=None, cfg_scale=None ): 70 | # Time step discretization. 
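# 'uniform' spaces t linearly between T and eps; 'edm' applies the rho=7
# schedule from EDM in noise-level (nt) space and maps back via net.nt_to_t;
# with neither set, the explicit mid_nt values from the config are used.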
71 | if discretization == 'uniform': 72 | t_steps = torch.linspace(net.T, net.eps, num_steps+1, dtype=torch.float64, device=latents.device)[:-1] 73 | elif discretization == 'edm': 74 | nt_min = net.get_log_nt(torch.as_tensor(net.eps, dtype=torch.float64)).exp().item() 75 | nt_max = net.get_log_nt(torch.as_tensor(net.T, dtype=torch.float64)).exp().item() 76 | rho = 7 77 | step_indices = torch.arange(num_steps+1, dtype=torch.float64, device=latents.device) 78 | nt_steps = (nt_max ** (1 / rho) + step_indices / (num_steps) * (nt_min ** (1 / rho) - nt_max ** (1 / rho))) ** rho 79 | t_steps = net.nt_to_t(nt_steps)[:-1] 80 | else: 81 | if mid_nt is None: 82 | mid_nt = [] 83 | mid_t = [net.nt_to_t(torch.as_tensor(nt)).item() for nt in mid_nt] 84 | t_steps = torch.tensor( 85 | [net.T] + list(mid_t), dtype=torch.float64, device=latents.device 86 | ) 87 | # Sampling steps 88 | x = latents.to(torch.float64) 89 | 90 | for i, t_cur in enumerate(t_steps): 91 | 92 | 93 | x = net.cfg_forward(x, t_cur, torch.ones_like(t_cur) * net.eps, class_labels=class_labels, cfg_scale=cfg_scale ).to( 94 | torch.float64 95 | ) 96 | 97 | if i < len(t_steps) - 1: 98 | x, _ = net.add_noise(x, t_steps[i+1]) 99 | 100 | return x 101 | 102 | 103 | 104 | # ---------------------------------------------------------------------------- 105 | 106 | @hydra.main(version_base=None, config_path="configs") 107 | def main(cfg): 108 | 109 | device = torch.device("cuda") 110 | config = OmegaConf.create(OmegaConf.to_yaml(cfg, resolve=True)) 111 | 112 | # Random seed. 113 | if config.eval.seed is None: 114 | 115 | seed = torch.randint(1 << 31, size=[], device=device) 116 | torch.distributed.broadcast(seed, src=0) 117 | config.eval.seed = int(seed) 118 | 119 | # Checkpoint to evaluate. 120 | resume_pkl = cfg.eval.resume 121 | cudnn_benchmark = config.eval.cudnn_benchmark 122 | seed = config.eval.seed 123 | encoder_kwargs = config.encoder 124 | 125 | batch_size = config.eval.batch_size 126 | sample_kwargs_dict = config.get('sampling', {}) 127 | # Initialize. 128 | np.random.seed(seed % (1 << 31)) 129 | torch.manual_seed(np.random.randint(1 << 31)) 130 | torch.backends.cudnn.benchmark = cudnn_benchmark 131 | torch.backends.cudnn.allow_tf32 = True 132 | torch.backends.cuda.matmul.allow_tf32 = True 133 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 134 | 135 | print('Setting up encoder...') 136 | encoder = dnnlib.util.construct_class_by_name(**encoder_kwargs) 137 | 138 | # Construct network. 139 | print("Constructing network...") 140 | 141 | interface_kwargs = dict( 142 | img_resolution=config.resolution, 143 | img_channels=config.channels, 144 | label_dim=config.label_dim, 145 | ) 146 | if config.get('network', None) is not None: 147 | network_kwargs = config.network 148 | net = dnnlib.util.construct_class_by_name( 149 | **network_kwargs, **interface_kwargs 150 | ) # subclass of torch.nn.Module 151 | net.eval().requires_grad_(False).to(device) 152 | 153 | # Resume training from previous snapshot. 
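# Note: despite the comment above, this script only loads the snapshot for
# evaluation; the 'ema' entry holds the exponential-moving-average weights
# used for sampling.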
154 | with dnnlib.util.open_url(resume_pkl, verbose=True) as f: 155 | data = pickle.load(f) 156 | 157 | if config.get('network', None) is not None: 158 | misc.copy_params_and_buffers( 159 | src_module=data['ema'], dst_module=net, require_all=True 160 | ) 161 | else: 162 | net = data['ema'].eval().requires_grad_(False).to(device) 163 | 164 | 165 | grid_z = net.get_init_noise( 166 | [batch_size, net.img_channels, net.img_resolution, net.img_resolution], 167 | device, 168 | ) 169 | if net.label_dim > 0: 170 | labels = torch.randint(0, net.label_dim, (batch_size,), device=device) 171 | grid_c = torch.nn.functional.one_hot(labels, num_classes=net.label_dim) 172 | else: 173 | grid_c = None 174 | 175 | # Few-step Evaluation. 176 | generator_fn_dict = {k: functools.partial(generator_fn, **sample_kwargs) for k, sample_kwargs in sample_kwargs_dict.items()} 177 | print("Sample images...") 178 | res = {} 179 | for key, gen_fn in generator_fn_dict.items(): 180 | images = gen_fn(net, grid_z, grid_c) 181 | images = encoder.decode(images.to(device)).detach().cpu() 182 | 183 | vutils.save_image( 184 | images / 255., 185 | os.path.join(f"{key}_samples.png"), 186 | nrow=int(np.sqrt(images.shape[0])), 187 | normalize=False, 188 | ) 189 | 190 | res[key] = images 191 | 192 | print('done.') 193 | 194 | # ---------------------------------------------------------------------------- 195 | 196 | if __name__ == "__main__": 197 | main() 198 | 199 | # ---------------------------------------------------------------------------- 200 | -------------------------------------------------------------------------------- /model_card.md: -------------------------------------------------------------------------------- 1 | # Model Card 2 | 3 | These are the Inductive Moment Matching (IMM) models described in the paper [Inductive Moment Matching](https://arxiv.org/abs/2503.07565). 4 | 5 | We provide the following pretrained checkpoints through our [repo](https://huggingface.co/lumaai/imm) on Hugging Face: 6 | * IMM on CIFAR-10: [cifar10.pt](https://huggingface.co/lumaai/imm/resolve/main/cifar10.pt). 7 | * IMM on ImageNet-256x256: 8 | 1. `t-s` is passed as the second time embedding, trained with `a=2`: [imagenet256_ts_a2.pkl](https://huggingface.co/lumaai/imm/resolve/main/imagenet256_ts_a2.pkl). 9 | 2. `s` is passed as the second time embedding directly, trained with `a=1`: [imagenet256_s_a1.pkl](https://huggingface.co/lumaai/imm/resolve/main/imagenet256_s_a1.pkl). 10 | 11 | 12 | ## Intended Use 13 | 14 | This model is provided exclusively for research purposes.
15 | 
16 | - Academic research on generative modeling techniques
17 | - Benchmarking against other generative models
18 | - Educational purposes to understand Inductive Moment Matching algorithms
19 | - Exploration of model capabilities in controlled research environments
20 | 
21 | Prohibited uses include:
22 | 
23 | - Any commercial applications or commercial product development
24 | - Integration into products or services offered to customers
25 | - Generation of content for commercial distribution
26 | - Any applications that could result in harm, including but not limited to:
27 |   - Creating deceptive or misleading content
28 |   - Generating harmful, offensive, or discriminatory outputs
29 |   - Circumventing security systems
30 |   - Creating deepfakes or other potentially harmful synthetic media
31 | - Any use case that could negatively impact individuals or society
32 | 
33 | ## Limitations
34 | 
35 | The IMM models have several limitations common to image generation models:
36 | 
37 | - Limited Resolution: The models are trained at fixed resolutions (32x32 for CIFAR-10, 256x256 for ImageNet), and generating images at significantly higher resolutions may result in quality degradation or artifacts.
38 | - Computational Resources: Training and inference require substantial computational resources, which may limit practical use in resource-constrained environments.
39 | - Training Data Limitations: The models are trained on specific datasets (CIFAR-10 and ImageNet) and may not generalize well to other domains or data distributions.
40 | - Generalization to Unseen Data: The models may not generalize well to unseen data or domains, a common limitation of generative models.
41 | 
42 | 
43 | 
--------------------------------------------------------------------------------
/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 | 
8 | # empty
9 | 
--------------------------------------------------------------------------------
/torch_utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from . 
import training_stats 4 | import torch.distributed as dist 5 | import datetime 6 | 7 | # ---------------------------------------------------------------------------- 8 | 9 | 10 | def init(): 11 | if "MASTER_ADDR" not in os.environ: 12 | os.environ["MASTER_ADDR"] = "localhost" 13 | if "MASTER_PORT" not in os.environ: 14 | os.environ["MASTER_PORT"] = "29500" 15 | if "RANK" not in os.environ: 16 | os.environ["RANK"] = "0" 17 | if "LOCAL_RANK" not in os.environ: 18 | os.environ["LOCAL_RANK"] = "0" 19 | if "WORLD_SIZE" not in os.environ: 20 | os.environ["WORLD_SIZE"] = "1" 21 | 22 | os.environ["NCCL_SOCKET_IFNAME"] = "enp" 23 | os.environ["FI_EFA_SET_CUDA_SYNC_MEMOPS"] = "0" 24 | 25 | os.environ["NCCL_BUFFSIZE"] = "8388608" 26 | os.environ["NCCL_P2P_NET_CHUNKSIZE"] = "524288" 27 | 28 | 29 | os.environ['TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC'] = '1200' 30 | os.environ['TORCH_NCCL_ENABLE_MONITORING'] = '0' 31 | 32 | backend = "gloo" if os.name == "nt" else "nccl" 33 | torch.distributed.init_process_group(backend=backend, init_method="env://", timeout=datetime.timedelta(minutes=120),) 34 | torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0"))) 35 | 36 | sync_device = torch.device("cuda") if get_world_size() > 1 else None 37 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 38 | 39 | 40 | # ---------------------------------------------------------------------------- 41 | 42 | 43 | def get_rank(): 44 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 45 | 46 | 47 | # ---------------------------------------------------------------------------- 48 | 49 | 50 | def get_world_size(): 51 | return ( 52 | torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 53 | ) 54 | 55 | 56 | # ---------------------------------------------------------------------------- 57 | 58 | 59 | def should_stop(): 60 | return False 61 | 62 | 63 | # ---------------------------------------------------------------------------- 64 | 65 | 66 | def update_progress(cur, total): 67 | _ = cur, total 68 | 69 | 70 | # ---------------------------------------------------------------------------- 71 | 72 | 73 | def print0(*args, **kwargs): 74 | if get_rank() == 0: 75 | print(*args, **kwargs) 76 | 77 | 78 | # ---------------------------------------------------------------------------- 79 | 80 | 81 | 82 | broadcast = dist.broadcast 83 | new_group = dist.new_group 84 | barrier = dist.barrier 85 | all_gather = dist.all_gather 86 | send = dist.send 87 | recv = dist.recv -------------------------------------------------------------------------------- /torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | import re 2 | import contextlib 3 | import numpy as np 4 | import torch 5 | import warnings 6 | import dnnlib 7 | import functools 8 | from . import persistence 9 | 10 | # ---------------------------------------------------------------------------- 11 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 12 | # same constant is used multiple times. 
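A usage sketch for the `constant()` helper defined just below (the value here is hypothetical):

```python
import torch
from torch_utils.misc import constant

# Both calls share a single cache entry, so the tensor is materialized once.
a = constant([[1, 0], [0, 1]], dtype=torch.float32)
b = constant([[1, 0], [0, 1]], dtype=torch.float32)
assert a is b  # served from _constant_cache
```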
13 | 14 | _constant_cache = dict() 15 | 16 | 17 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 18 | value = np.asarray(value) 19 | if shape is not None: 20 | shape = tuple(shape) 21 | if dtype is None: 22 | dtype = torch.get_default_dtype() 23 | if device is None: 24 | device = torch.device("cpu") 25 | if memory_format is None: 26 | memory_format = torch.contiguous_format 27 | 28 | key = ( 29 | value.shape, 30 | value.dtype, 31 | value.tobytes(), 32 | shape, 33 | dtype, 34 | device, 35 | memory_format, 36 | ) 37 | tensor = _constant_cache.get(key, None) 38 | if tensor is None: 39 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 40 | if shape is not None: 41 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 42 | tensor = tensor.contiguous(memory_format=memory_format) 43 | _constant_cache[key] = tensor 44 | return tensor 45 | 46 | #---------------------------------------------------------------------------- 47 | # Variant of constant() that inherits dtype and device from the given 48 | # reference tensor by default. 49 | 50 | def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None): 51 | if dtype is None: 52 | dtype = ref.dtype 53 | if device is None: 54 | device = ref.device 55 | return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format) 56 | 57 | #---------------------------------------------------------------------------- 58 | # Cached construction of temporary tensors in pinned CPU memory. 59 | 60 | @functools.lru_cache(None) 61 | def pinned_buf(shape, dtype): 62 | return torch.empty(shape, dtype=dtype).pin_memory() 63 | 64 | #---------------------------------------------------------------------------- 65 | # Symbolic assert. 66 | 67 | try: 68 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 69 | except AttributeError: 70 | symbolic_assert = torch.Assert # 1.7.0 71 | 72 | 73 | # ---------------------------------------------------------------------------- 74 | # Replace NaN/Inf with specified numerical values. 75 | 76 | try: 77 | nan_to_num = torch.nan_to_num # 1.8.0a0 78 | except AttributeError: 79 | 80 | def nan_to_num( 81 | input, nan=0.0, posinf=None, neginf=None, *, out=None 82 | ): # pylint: disable=redefined-builtin 83 | assert isinstance(input, torch.Tensor) 84 | if posinf is None: 85 | posinf = torch.finfo(input.dtype).max 86 | if neginf is None: 87 | neginf = torch.finfo(input.dtype).min 88 | assert nan == 0 89 | return torch.clamp( 90 | input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out 91 | ) 92 | 93 | 94 | # ---------------------------------------------------------------------------- 95 | # Symbolic assert. 96 | 97 | try: 98 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 99 | except AttributeError: 100 | symbolic_assert = torch.Assert # 1.7.0 101 | 102 | # ---------------------------------------------------------------------------- 103 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 
104 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 105 | 106 | 107 | @contextlib.contextmanager 108 | def suppress_tracer_warnings(): 109 | flt = ("ignore", None, torch.jit.TracerWarning, None, 0) 110 | warnings.filters.insert(0, flt) 111 | yield 112 | warnings.filters.remove(flt) 113 | 114 | 115 | # ---------------------------------------------------------------------------- 116 | # Assert that the shape of a tensor matches the given list of integers. 117 | # None indicates that the size of a dimension is allowed to vary. 118 | # Performs symbolic assertion when used in torch.jit.trace(). 119 | 120 | 121 | def assert_shape(tensor, ref_shape): 122 | if tensor.ndim != len(ref_shape): 123 | raise AssertionError( 124 | f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}" 125 | ) 126 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 127 | if ref_size is None: 128 | pass 129 | elif isinstance(ref_size, torch.Tensor): 130 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 131 | symbolic_assert( 132 | torch.equal(torch.as_tensor(size), ref_size), 133 | f"Wrong size for dimension {idx}", 134 | ) 135 | elif isinstance(size, torch.Tensor): 136 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 137 | symbolic_assert( 138 | torch.equal(size, torch.as_tensor(ref_size)), 139 | f"Wrong size for dimension {idx}: expected {ref_size}", 140 | ) 141 | elif size != ref_size: 142 | raise AssertionError( 143 | f"Wrong size for dimension {idx}: got {size}, expected {ref_size}" 144 | ) 145 | 146 | 147 | # ---------------------------------------------------------------------------- 148 | # Function decorator that calls torch.autograd.profiler.record_function(). 149 | 150 | 151 | def profiled_function(fn): 152 | def decorator(*args, **kwargs): 153 | with torch.autograd.profiler.record_function(fn.__name__): 154 | return fn(*args, **kwargs) 155 | 156 | decorator.__name__ = fn.__name__ 157 | return decorator 158 | 159 | 160 | # ---------------------------------------------------------------------------- 161 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 162 | # indefinitely, shuffling items as it goes. 163 | 164 | 165 | class InfiniteSampler(torch.utils.data.Sampler): 166 | def __init__( 167 | self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5 168 | ): 169 | assert len(dataset) > 0 170 | assert num_replicas > 0 171 | assert 0 <= rank < num_replicas 172 | assert 0 <= window_size <= 1 173 | super().__init__(dataset) 174 | self.dataset = dataset 175 | self.rank = rank 176 | self.num_replicas = num_replicas 177 | self.shuffle = shuffle 178 | self.seed = seed 179 | self.window_size = window_size 180 | 181 | def __iter__(self): 182 | order = np.arange(len(self.dataset)) 183 | rnd = None 184 | window = 0 185 | if self.shuffle: 186 | rnd = np.random.RandomState(self.seed) 187 | rnd.shuffle(order) 188 | window = int(np.rint(order.size * self.window_size)) 189 | 190 | idx = 0 191 | while True: 192 | i = idx % order.size 193 | if idx % self.num_replicas == self.rank: 194 | yield order[i] 195 | if window >= 2: 196 | j = (i - rnd.randint(window)) % order.size 197 | order[i], order[j] = order[j], order[i] 198 | idx += 1 199 | 200 | 201 | # ---------------------------------------------------------------------------- 202 | # Utilities for operating with torch.nn.Module parameters and buffers. 
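A usage sketch for the parameter/buffer helpers defined below (the modules here are hypothetical):

```python
import torch
from torch_utils.misc import copy_params_and_buffers

src = torch.nn.Linear(8, 8)
dst = torch.nn.Linear(8, 8)
# Copy tensors by name; require_all=True asserts that every destination
# parameter/buffer has a matching source tensor.
copy_params_and_buffers(src_module=src, dst_module=dst, require_all=True)
assert torch.equal(src.weight, dst.weight)
```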
203 | 204 | 205 | def params_and_buffers(module): 206 | assert isinstance(module, torch.nn.Module) 207 | return list(module.parameters()) + list(module.buffers()) 208 | 209 | 210 | def named_params_and_buffers(module): 211 | assert isinstance(module, torch.nn.Module) 212 | return list(module.named_parameters()) + list(module.named_buffers()) 213 | 214 | 215 | @torch.no_grad() 216 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 217 | assert isinstance(src_module, torch.nn.Module) 218 | assert isinstance(dst_module, torch.nn.Module) 219 | src_tensors = dict(named_params_and_buffers(src_module)) 220 | for name, tensor in named_params_and_buffers(dst_module): 221 | assert (name in src_tensors) or (not require_all) 222 | if name in src_tensors: 223 | tensor.copy_(src_tensors[name]) 224 | 225 | 226 | # ---------------------------------------------------------------------------- 227 | # Context manager for easily enabling/disabling DistributedDataParallel 228 | # synchronization. 229 | 230 | 231 | @contextlib.contextmanager 232 | def ddp_sync(module, sync): 233 | assert isinstance(module, torch.nn.Module) 234 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 235 | yield 236 | else: 237 | with module.no_sync(): 238 | yield 239 | 240 | 241 | # ---------------------------------------------------------------------------- 242 | # Check DistributedDataParallel consistency across processes. 243 | 244 | 245 | def check_ddp_consistency(module, ignore_regex=None): 246 | assert isinstance(module, torch.nn.Module) 247 | for name, tensor in named_params_and_buffers(module): 248 | fullname = type(module).__name__ + "." + name 249 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 250 | continue 251 | tensor = tensor.detach() 252 | if tensor.is_floating_point(): 253 | tensor = nan_to_num(tensor) 254 | other = tensor.clone() 255 | torch.distributed.broadcast(tensor=other, src=0) 256 | assert (tensor == other).all(), fullname 257 | 258 | 259 | # ---------------------------------------------------------------------------- 260 | # Print summary table of module hierarchy. 261 | 262 | 263 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 264 | assert isinstance(module, torch.nn.Module) 265 | assert not isinstance(module, torch.jit.ScriptModule) 266 | assert isinstance(inputs, (tuple, list)) 267 | 268 | # Register hooks. 269 | entries = [] 270 | nesting = [0] 271 | 272 | def pre_hook(_mod, _inputs): 273 | nesting[0] += 1 274 | 275 | def post_hook(mod, _inputs, outputs): 276 | nesting[0] -= 1 277 | if nesting[0] <= max_nesting: 278 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 279 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 280 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 281 | 282 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 283 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 284 | 285 | # Run module. 286 | outputs = module(*inputs) 287 | for hook in hooks: 288 | hook.remove() 289 | 290 | # Identify unique outputs, parameters, and buffers. 
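    # (Tensors are deduplicated by id(), so parameters or buffers shared by
    # several submodules are attributed only to the first entry that
    # reports them.)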
291 | tensors_seen = set() 292 | for e in entries: 293 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 294 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 295 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 296 | tensors_seen |= { 297 | id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs 298 | } 299 | 300 | # Filter out redundant entries. 301 | if skip_redundant: 302 | entries = [ 303 | e 304 | for e in entries 305 | if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs) 306 | ] 307 | 308 | # Construct table. 309 | rows = [ 310 | [type(module).__name__, "Parameters", "Buffers", "Output shape", "Datatype"] 311 | ] 312 | rows += [["---"] * len(rows[0])] 313 | param_total = 0 314 | buffer_total = 0 315 | submodule_names = {mod: name for name, mod in module.named_modules()} 316 | for e in entries: 317 | name = "" if e.mod is module else submodule_names[e.mod] 318 | param_size = sum(t.numel() for t in e.unique_params) 319 | buffer_size = sum(t.numel() for t in e.unique_buffers) 320 | output_shapes = [str(list(t.shape)) for t in e.outputs] 321 | output_dtypes = [str(t.dtype).split(".")[-1] for t in e.outputs] 322 | rows += [ 323 | [ 324 | name + (":0" if len(e.outputs) >= 2 else ""), 325 | str(param_size) if param_size else "-", 326 | str(buffer_size) if buffer_size else "-", 327 | (output_shapes + ["-"])[0], 328 | (output_dtypes + ["-"])[0], 329 | ] 330 | ] 331 | for idx in range(1, len(e.outputs)): 332 | rows += [ 333 | [name + f":{idx}", "-", "-", output_shapes[idx], output_dtypes[idx]] 334 | ] 335 | param_total += param_size 336 | buffer_total += buffer_size 337 | rows += [["---"] * len(rows[0])] 338 | rows += [["Total", str(param_total), str(buffer_total), "-", "-"]] 339 | 340 | # Print table. 341 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 342 | print() 343 | for row in rows: 344 | print( 345 | " ".join( 346 | cell + " " * (width - len(cell)) for cell, width in zip(row, widths) 347 | ) 348 | ) 349 | print() 350 | return outputs 351 | 352 | 353 | # ---------------------------------------------------------------------------- 354 | 355 | 356 | 357 | import abc 358 | @persistence.persistent_class 359 | class ActivationHook(abc.ABC): 360 | def __init__(self, modules_to_watch): 361 | self.modules_to_watch = modules_to_watch 362 | 363 | self._hook_result = {} 364 | 365 | @property 366 | def hook_result(self): 367 | return self._hook_result 368 | 369 | @abc.abstractmethod 370 | def __call__(self, module, input, output): 371 | pass 372 | 373 | def watch(self, models_dict): 374 | 375 | acc = [] 376 | #register name for easy access 377 | for k in models_dict: 378 | model = models_dict[k] 379 | for name, module in model.named_modules(): 380 | if self.modules_to_watch == 'all': 381 | module._hook_name = name 382 | else: 383 | for mw in self.modules_to_watch: 384 | if mw in name and name not in acc: 385 | module._hook_name = k + '.' 
+ name 386 | acc.append(name) 387 | 388 | 389 | def clear(self): 390 | self._hook_result = {} 391 | 392 | 393 | @persistence.persistent_class 394 | class ActivationMagnitudeHook(ActivationHook): 395 | def __init__(self, modules_to_watch='all'): 396 | super().__init__(modules_to_watch) 397 | 398 | def __call__(self, module, input, output): 399 | if hasattr(module, '_hook_name'): 400 | # only track registered modules 401 | if isinstance(output, torch.Tensor): 402 | output_ = output.detach() 403 | 404 | self._hook_result['activations/' + module._hook_name + '_magnitude_div_10000'] = (output_/10000).flatten(1).norm(1).mean().item() #prevent overflow 405 | else: 406 | self._hook_result['activations/' + module._hook_name + '_magnitude_div_10000'] = 0 407 | 408 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | """Facilities for pickling Python code alongside other data. 2 | 3 | The pickled code is automatically imported into a separate Python module 4 | during unpickling. This way, any previously exported pickles will remain 5 | usable even if the original code is no longer available, or if the current 6 | version of the code is not consistent with what was originally pickled.""" 7 | 8 | import sys 9 | import pickle 10 | import io 11 | import inspect 12 | import copy 13 | import uuid 14 | import types 15 | import dnnlib 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | _version = 6 # internal version number 20 | _decorators = set() # {decorator_class, ...} 21 | _import_hooks = [] # [hook_function, ...] 22 | _module_to_src_dict = dict() # {module: src, ...} 23 | _src_to_module_dict = dict() # {src: module, ...} 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | def persistent_class(orig_class): 28 | r"""Class decorator that extends a given class to save its source code 29 | when pickled. 30 | 31 | Example: 32 | 33 | from torch_utils import persistence 34 | 35 | @persistence.persistent_class 36 | class MyNetwork(torch.nn.Module): 37 | def __init__(self, num_inputs, num_outputs): 38 | super().__init__() 39 | self.fc = MyLayer(num_inputs, num_outputs) 40 | ... 41 | 42 | @persistence.persistent_class 43 | class MyLayer(torch.nn.Module): 44 | ... 45 | 46 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 47 | source code alongside other internal state (e.g., parameters, buffers, 48 | and submodules). This way, any previously exported pickle will remain 49 | usable even if the class definitions have been modified or are no 50 | longer available. 51 | 52 | The decorator saves the source code of the entire Python module 53 | containing the decorated class. It does *not* save the source code of 54 | any imported modules. Thus, the imported modules must be available 55 | during unpickling, also including `torch_utils.persistence` itself. 56 | 57 | It is ok to call functions defined in the same module from the 58 | decorated class. However, if the decorated class depends on other 59 | classes defined in the same module, they must be decorated as well. 60 | This is illustrated in the above example in the case of `MyLayer`. 61 | 62 | It is also possible to employ the decorator just-in-time before 63 | calling the constructor. 
For example:
64 | 
65 |         cls = MyLayer
66 |         if want_to_make_it_persistent:
67 |             cls = persistence.persistent_class(cls)
68 |         layer = cls(num_inputs, num_outputs)
69 | 
70 |     As an additional feature, the decorator also keeps track of the
71 |     arguments that were used to construct each instance of the decorated
72 |     class. The arguments can be queried via `obj.init_args` and
73 |     `obj.init_kwargs`, and they are automatically pickled alongside other
74 |     object state. This feature can be disabled on a per-instance basis
75 |     by setting `self._record_init_args = False` in the constructor.
76 | 
77 |     A typical use case is to first unpickle a previous instance of a
78 |     persistent class, and then upgrade it to use the latest version of
79 |     the source code:
80 | 
81 |         with open('old_pickle.pkl', 'rb') as f:
82 |             old_net = pickle.load(f)
83 |         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
84 |         misc.copy_params_and_buffers(old_net, new_net, require_all=True)
85 |     """
86 |     assert isinstance(orig_class, type)
87 |     if is_persistent(orig_class):
88 |         return orig_class
89 | 
90 |     assert orig_class.__module__ in sys.modules
91 |     orig_module = sys.modules[orig_class.__module__]
92 |     orig_module_src = _module_to_src(orig_module)
93 | 
94 |     class Decorator(orig_class):
95 |         _orig_module_src = orig_module_src
96 |         _orig_class_name = orig_class.__name__
97 | 
98 |         def __init__(self, *args, **kwargs):
99 |             super().__init__(*args, **kwargs)
100 |             record_init_args = getattr(self, '_record_init_args', True)
101 |             self._init_args = copy.deepcopy(args) if record_init_args else None
102 |             self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
103 |             assert orig_class.__name__ in orig_module.__dict__
104 |             _check_pickleable(self.__reduce__())
105 | 
106 |         @property
107 |         def init_args(self):
108 |             assert self._init_args is not None
109 |             return copy.deepcopy(self._init_args)
110 | 
111 |         @property
112 |         def init_kwargs(self):
113 |             assert self._init_kwargs is not None
114 |             return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
115 | 
116 |         def __reduce__(self):
117 |             fields = list(super().__reduce__())
118 |             fields += [None] * max(3 - len(fields), 0)
119 |             if fields[0] is not _reconstruct_persistent_obj:
120 |                 meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
121 |                 fields[0] = _reconstruct_persistent_obj  # reconstruct func
122 |                 fields[1] = (meta,)  # reconstruct args
123 |                 fields[2] = None  # state dict
124 |             return tuple(fields)
125 | 
126 |     Decorator.__name__ = orig_class.__name__
127 |     Decorator.__module__ = orig_class.__module__
128 |     _decorators.add(Decorator)
129 |     return Decorator
130 | 
131 | #----------------------------------------------------------------------------
132 | 
133 | def is_persistent(obj):
134 |     r"""Test whether the given object or class is persistent, i.e.,
135 |     whether it will save its source code when pickled.
136 |     """
137 |     try:
138 |         if obj in _decorators:
139 |             return True
140 |     except TypeError:
141 |         pass
142 |     return type(obj) in _decorators  # pylint: disable=unidiomatic-typecheck
143 | 
144 | #----------------------------------------------------------------------------
145 | 
146 | def import_hook(hook):
147 |     r"""Register an import hook that is called whenever a persistent object
148 |     is being unpickled. A typical use case is to patch the pickled source
149 |     code to avoid errors and inconsistencies when the API of some imported
150 |     module has changed.
151 | 152 | The hook should have the following signature: 153 | 154 | hook(meta) -> modified meta 155 | 156 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 157 | 158 | type: Type of the persistent object, e.g. `'class'`. 159 | version: Internal version number of `torch_utils.persistence`. 160 | module_src Original source code of the Python module. 161 | class_name: Class name in the original Python module. 162 | state: Internal state of the object. 163 | 164 | Example: 165 | 166 | @persistence.import_hook 167 | def wreck_my_network(meta): 168 | if meta.class_name == 'MyNetwork': 169 | print('MyNetwork is being imported. I will wreck it!') 170 | meta.module_src = meta.module_src.replace("True", "False") 171 | return meta 172 | """ 173 | assert callable(hook) 174 | _import_hooks.append(hook) 175 | 176 | #---------------------------------------------------------------------------- 177 | 178 | def _reconstruct_persistent_obj(meta): 179 | r"""Hook that is called internally by the `pickle` module to unpickle 180 | a persistent object. 181 | """ 182 | meta = dnnlib.EasyDict(meta) 183 | 184 | meta.state = dnnlib.EasyDict(meta.state) if meta.state is not None else dnnlib.EasyDict() 185 | for hook in _import_hooks: 186 | meta = hook(meta) 187 | assert meta is not None 188 | 189 | assert meta.version == _version 190 | module = _src_to_module(meta.module_src) 191 | 192 | assert meta.type == 'class' 193 | orig_class = module.__dict__[meta.class_name] 194 | decorator_class = persistent_class(orig_class) 195 | obj = decorator_class.__new__(decorator_class) 196 | 197 | setstate = getattr(obj, '__setstate__', None) 198 | if callable(setstate): 199 | setstate(meta.state) # pylint: disable=not-callable 200 | else: 201 | obj.__dict__.update(meta.state) 202 | 203 | return obj 204 | 205 | #---------------------------------------------------------------------------- 206 | 207 | def _module_to_src(module): 208 | r"""Query the source code of a given Python module. 209 | """ 210 | src = _module_to_src_dict.get(module, None) 211 | if src is None: 212 | src = inspect.getsource(module) 213 | _module_to_src_dict[module] = src 214 | _src_to_module_dict[src] = module 215 | return src 216 | 217 | def _src_to_module(src): 218 | r"""Get or create a Python module for the given source code. 219 | """ 220 | module = _src_to_module_dict.get(src, None) 221 | if module is None: 222 | module_name = "_imported_module_" + uuid.uuid4().hex 223 | module = types.ModuleType(module_name) 224 | sys.modules[module_name] = module 225 | _module_to_src_dict[module] = src 226 | _src_to_module_dict[src] = module 227 | exec(src, module.__dict__) # pylint: disable=exec-used 228 | return module 229 | 230 | #---------------------------------------------------------------------------- 231 | 232 | def _check_pickleable(obj): 233 | r"""Check that the given object is pickleable, raising an exception if 234 | it is not. This function is expected to be considerably more efficient 235 | than actually pickling the object. 236 | """ 237 | def recurse(obj): 238 | if isinstance(obj, (list, tuple, set)): 239 | return [recurse(x) for x in obj] 240 | if isinstance(obj, dict): 241 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 242 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 243 | return None # Python primitive types are pickleable. 
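        # (The branches below prune large leaf objects instead of pickling
        # them, which is what makes this check cheaper than a full pickle.)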
244 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 245 | return None # NumPy arrays and PyTorch tensors are pickleable. 246 | if is_persistent(obj): 247 | return None # Persistent objects are pickleable, by virtue of the constructor check. 248 | return obj 249 | with io.BytesIO() as f: 250 | pickle.dump(recurse(obj), f) 251 | 252 | #---------------------------------------------------------------------------- 253 | -------------------------------------------------------------------------------- /torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | """Facilities for reporting and collecting training statistics across 2 | multiple processes and devices. The interface is designed to minimize 3 | synchronization overhead as well as the amount of boilerplate in user 4 | code.""" 5 | 6 | import re 7 | import numpy as np 8 | import torch 9 | import torch.distributed 10 | import dnnlib 11 | 12 | from collections import defaultdict 13 | 14 | from . import misc 15 | 16 | # ---------------------------------------------------------------------------- 17 | 18 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 19 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 20 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 21 | _rank = 0 # Rank of the current process. 22 | _sync_device = ( 23 | None # Device to use for multiprocess communication. None = single-process. 24 | ) 25 | _sync_called = False # Has _sync() been called yet? 26 | _counters = ( 27 | dict() 28 | ) # Running counters on each device, updated by report(): name => device => torch.Tensor 29 | _cumulative = ( 30 | dict() 31 | ) # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 32 | 33 | # ---------------------------------------------------------------------------- 34 | 35 | 36 | def init_multiprocessing(rank, sync_device): 37 | r"""Initializes `torch_utils.training_stats` for collecting statistics 38 | across multiple processes. 39 | 40 | This function must be called after 41 | `torch.distributed.init_process_group()` and before `Collector.update()`. 42 | The call is not necessary if multi-process collection is not needed. 43 | 44 | Args: 45 | rank: Rank of the current process. 46 | sync_device: PyTorch device to use for inter-process 47 | communication, or None to disable multi-process 48 | collection. Typically `torch.device('cuda', rank)`. 49 | """ 50 | global _rank, _sync_device 51 | assert not _sync_called 52 | _rank = rank 53 | _sync_device = sync_device 54 | 55 | 56 | # ---------------------------------------------------------------------------- 57 | 58 | 59 | @misc.profiled_function 60 | def report(name, value, ts=None, max_t=1, num_bins=4): 61 | r"""Broadcasts the given set of scalars to all interested instances of 62 | `Collector`, across device and process boundaries. 63 | 64 | This function is expected to be extremely cheap and can be safely 65 | called from anywhere in the training loop, loss function, or inside a 66 | `torch.nn.Module`. 67 | 68 | Warning: The current implementation expects the set of unique names to 69 | be consistent across processes. Please make sure that `report()` is 70 | called at least once for each unique name by each process, and in the 71 | same order. If a given process has no scalars to broadcast, it can do 72 | `report(name, [])` (empty list). 
73 | 74 | Args: 75 | name: Arbitrary string specifying the name of the statistic. 76 | Averages are accumulated separately for each unique name. 77 | value: Arbitrary set of scalars. Can be a list, tuple, 78 | NumPy array, PyTorch tensor, or Python scalar. 79 | 80 | Returns: 81 | The same `value` that was passed in. 82 | """ 83 | value_in = value 84 | quantiles = {f"{name}_q{quartile}": [] for quartile in range(num_bins)} 85 | if ts is not None: 86 | for sub_t, sub_loss in zip(ts.cpu().numpy(), value): 87 | if isinstance(sub_loss, torch.Tensor): 88 | sub_loss = sub_loss.detach().cpu().numpy() 89 | 90 | quartile = int(num_bins * min(sub_t, max_t-1e-3) / max_t) 91 | 92 | quantiles[f"{name}_q{quartile}"].append(sub_loss.item()) 93 | 94 | else: 95 | quantiles[name] = value 96 | 97 | for name, value in quantiles.items(): 98 | if name not in _counters: 99 | _counters[name] = dict() 100 | 101 | elems = torch.as_tensor(value) 102 | if elems.numel() == 0: 103 | elems = torch.zeros([1], dtype=_reduce_dtype) 104 | moments = torch.stack( 105 | [ 106 | torch.zeros_like(elems).sum(), 107 | elems.sum(), 108 | elems.square().sum(), 109 | ] 110 | ).to(_counter_dtype) 111 | 112 | continue 113 | else: 114 | 115 | elems = elems.detach().flatten().to(_reduce_dtype) 116 | moments = torch.stack( 117 | [ 118 | torch.ones_like(elems).sum(), 119 | elems.sum(), 120 | elems.square().sum(), 121 | ] 122 | ) 123 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 124 | moments = moments.to(_counter_dtype) 125 | 126 | device = moments.device 127 | if device not in _counters[name]: 128 | _counters[name][device] = torch.zeros_like(moments) 129 | _counters[name][device].add_(moments) 130 | return value_in 131 | 132 | 133 | # ---------------------------------------------------------------------------- 134 | 135 | 136 | def report0(name, value): 137 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 138 | but ignores any scalars provided by the other processes. 139 | See `report()` for further details. 140 | """ 141 | 142 | report(name, value if _rank == 0 else []) 143 | return value 144 | 145 | 146 | # ---------------------------------------------------------------------------- 147 | 148 | 149 | class Collector: 150 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 151 | computes their long-term averages (mean and standard deviation) over 152 | user-defined periods of time. 153 | 154 | The averages are first collected into internal counters that are not 155 | directly visible to the user. They are then copied to the user-visible 156 | state as a result of calling `update()` and can then be queried using 157 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 158 | internal counters for the next round, so that the user-visible state 159 | effectively reflects averages collected between the last two calls to 160 | `update()`. 161 | 162 | Args: 163 | regex: Regular expression defining which statistics to 164 | collect. The default is to collect everything. 165 | keep_previous: Whether to retain the previous averages if no 166 | scalars were collected on a given round 167 | (default: True). 
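    Example (a usage sketch; the statistic name is hypothetical):

        collector = training_stats.Collector(regex='Loss/.*')
        training_stats.report('Loss/total', [0.5, 0.7])
        collector.update()
        print(collector.mean('Loss/total'))  # 0.6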
168 |     """
169 | 
170 |     def __init__(self, regex=".*", keep_previous=True):
171 |         self._regex = re.compile(regex)
172 |         self._keep_previous = keep_previous
173 |         self._cumulative = dict()
174 |         self._moments = dict()
175 |         self._moments.clear()
176 | 
177 |     def names(self):
178 |         r"""Returns the names of all statistics broadcasted so far that
179 |         match the regular expression specified at construction time.
180 |         """
181 |         return [name for name in _counters if self._regex.fullmatch(name)]
182 | 
183 |     def update(
184 |         self,
185 |         disable_sync=False
186 |     ):
187 |         r"""Copies current values of the internal counters to the
188 |         user-visible state and resets them for the next round.
189 | 
190 |         If `keep_previous=True` was specified at construction time, the
191 |         operation is skipped for statistics that have received no scalars
192 |         since the last update, retaining their previous averages.
193 | 
194 |         This method performs a number of GPU-to-CPU transfers and one
195 |         `torch.distributed.all_reduce()`. It is intended to be called
196 |         periodically in the main training loop, typically once every
197 |         N training steps.
198 |         """
199 |         if not self._keep_previous:
200 |             self._moments.clear()
201 | 
202 |         for name, cumulative in _sync(self.names(), disable=disable_sync):
203 |             if name not in self._cumulative:
204 |                 self._cumulative[name] = torch.zeros(
205 |                     [_num_moments], dtype=_counter_dtype
206 |                 )
207 |             delta = cumulative - self._cumulative[name]
208 |             self._cumulative[name].copy_(cumulative)
209 |             if float(delta[0]) != 0:
210 |                 self._moments[name] = delta
211 | 
212 |     def _get_delta(self, name):
213 |         r"""Returns the raw moments that were accumulated for the given
214 |         statistic between the last two calls to `update()`, or zero if
215 |         no scalars were collected.
216 |         """
217 |         assert self._regex.fullmatch(name)
218 |         if name not in self._moments:
219 |             self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
220 |         return self._moments[name]
221 | 
222 |     def num(self, name):
223 |         r"""Returns the number of scalars that were accumulated for the given
224 |         statistic between the last two calls to `update()`, or zero if
225 |         no scalars were collected.
226 |         """
227 |         delta = self._get_delta(name)
228 |         return int(delta[0])
229 | 
230 |     def mean(self, name):
231 |         r"""Returns the mean of the scalars that were accumulated for the
232 |         given statistic between the last two calls to `update()`, or NaN if
233 |         no scalars were collected.
234 |         """
235 |         delta = self._get_delta(name)
236 |         if int(delta[0]) == 0:
237 |             return float("nan")
238 |         return float(delta[1] / delta[0])
239 | 
240 |     def std(self, name):
241 |         r"""Returns the standard deviation of the scalars that were
242 |         accumulated for the given statistic between the last two calls to
243 |         `update()`, or NaN if no scalars were collected.
244 |         """
245 |         delta = self._get_delta(name)
246 |         if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
247 |             return float("nan")
248 |         if int(delta[0]) == 1:
249 |             return float(0)
250 |         mean = float(delta[1] / delta[0])
251 |         raw_var = float(delta[2] / delta[0])
252 |         return np.sqrt(max(raw_var - np.square(mean), 0))
253 | 
254 |     def as_dict(self):
255 |         r"""Returns the averages accumulated between the last two calls to
256 |         `update()` as a `dnnlib.EasyDict`. The contents are as follows:
257 | 
258 |             dnnlib.EasyDict(
259 |                 NAME = dnnlib.EasyDict(mean=FLOAT),
260 |                 ...
261 | ) 262 | """ 263 | stats = dnnlib.EasyDict() 264 | for name in self.names(): 265 | stats[name] = dnnlib.EasyDict( 266 | mean=self.mean(name), 267 | ) 268 | return stats 269 | 270 | def __getitem__(self, name): 271 | r"""Convenience getter. 272 | `collector[name]` is a synonym for `collector.mean(name)`. 273 | """ 274 | return self.mean(name) 275 | 276 | 277 | # ---------------------------------------------------------------------------- 278 | 279 | 280 | def _sync(names, disable=False): 281 | r"""Synchronize the global cumulative counters across devices and 282 | processes. Called internally by `Collector.update()`. 283 | """ 284 | if len(names) == 0: 285 | return [] 286 | 287 | global _sync_called 288 | _sync_called = True 289 | 290 | # Collect deltas within current rank. 291 | deltas = [] 292 | device = _sync_device if _sync_device is not None else torch.device("cpu") 293 | 294 | for name in names: 295 | 296 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 297 | for counter in _counters[name].values(): 298 | delta.add_(counter.to(device)) 299 | counter.copy_(torch.zeros_like(counter)) 300 | deltas.append(delta) 301 | deltas = torch.stack(deltas) 302 | 303 | # Sum deltas across ranks. 304 | if _sync_device is not None and not disable: 305 | torch.distributed.all_reduce(deltas) 306 | # Update cumulative values. 307 | deltas = deltas.cpu() 308 | for idx, name in enumerate(names): 309 | if name not in _cumulative: 310 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 311 | _cumulative[name].add_(deltas[idx]) 312 | 313 | # Return name-value pairs. 314 | return [(name, _cumulative[name]) for name in names] 315 | 316 | 317 | # ---------------------------------------------------------------------------- 318 | # Convenience. 319 | 320 | default_collector = Collector() 321 | 322 | # ---------------------------------------------------------------------------- 323 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /training/dit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # GLIDE: https://github.com/openai/glide-text2im 9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py 10 | # -------------------------------------------------------- 11 | 12 | import torch 13 | import torch.nn as nn 14 | import numpy as np 15 | import math 16 | import functools 17 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp 18 | 19 | from torch_utils import persistence 20 | from einops import repeat 21 | 22 | def modulate(x, shift, scale ): 23 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 24 | 25 | ################################################################################# 26 | # Embedding Layers for Timesteps and Class Labels # 27 | ################################################################################# 28 | 29 | 30 | @persistence.persistent_class 31 | class FourierEmbedding(torch.nn.Module): 32 | def __init__(self, num_channels, scale=16, **kwargs): 33 | super().__init__() 34 | print("FourierEmbedding scale:", scale) 35 | self.register_buffer("freqs", torch.randn(num_channels // 2) * scale) 36 | 37 | def forward(self, x): 38 | dtype = x.dtype 39 | x = x.to(torch.float64).ger((2 * np.pi * self.freqs.to(torch.float64)) ) 40 | x = torch.cat([x.cos(), x.sin()], dim=1).to(dtype) 41 | return x 42 | 43 | @persistence.persistent_class 44 | class TimestepEmbedder(nn.Module): 45 | """ 46 | Embeds scalar timesteps into vector representations. 47 | """ 48 | 49 | def __init__(self, hidden_size, frequency_embedding_size=256, embedding_type='positional', use_mlp=True, scale=1): 50 | super().__init__() 51 | self.use_mlp = use_mlp 52 | self.hidden_size = hidden_size 53 | if use_mlp: 54 | self.mlp = nn.Sequential( 55 | nn.Linear(frequency_embedding_size, hidden_size , bias=True), 56 | nn.SiLU(), 57 | nn.Linear(hidden_size , hidden_size, bias=True), 58 | ) 59 | self.frequency_embedding_size = frequency_embedding_size 60 | 61 | self.embedding_type = embedding_type 62 | 63 | if self.embedding_type == 'fourier': 64 | self.register_buffer("freqs", torch.randn(frequency_embedding_size // 2) * scale) 65 | 66 | 67 | @staticmethod 68 | def positional_timestep_embedding(t, dim, max_period=10000 ): 69 | """ 70 | Create sinusoidal timestep embeddings. 71 | :param t: a 1-D Tensor of N indices, one per batch element. 72 | These may be fractional. 73 | :param dim: the dimension of the output. 74 | :param max_period: controls the minimum frequency of the embeddings. 75 | :return: an (N, D) Tensor of positional embeddings. 
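        Concretely, with half = dim // 2 and freqs[i] = max_period ** (-i / half),
        row n of the result is [cos(t_n * freqs), sin(t_n * freqs)], concatenated
        along the last dimension.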
76 | """ 77 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 78 | half = dim // 2 79 | freqs = torch.exp( 80 | -math.log(max_period) 81 | * torch.arange(start=0, end=half, dtype=torch.float64) 82 | / half 83 | ).to(device=t.device) 84 | args = t[:, None].to(torch.float64) * freqs[None] 85 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 86 | if dim % 2: 87 | embedding = torch.cat( 88 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 89 | ) 90 | return embedding 91 | 92 | def fourier_timestep_embedding(self, t, ): 93 | x = t.to(torch.float64).ger((2 * np.pi * self.freqs.to(torch.float64)) ) 94 | x = torch.cat([x.cos(), x.sin()], dim=1) 95 | 96 | return x 97 | 98 | def forward(self, t ): 99 | if self.embedding_type == 'positional': 100 | t_freq = self.positional_timestep_embedding(t, self.frequency_embedding_size) 101 | elif self.embedding_type == 'fourier': 102 | t_freq = self.fourier_timestep_embedding(t) 103 | 104 | if self.use_mlp: 105 | t_emb = self.mlp(t_freq.to(dtype=t.dtype) ) 106 | else: 107 | t_emb = t_freq.to(dtype=t.dtype) 108 | return t_emb 109 | 110 | 111 | @persistence.persistent_class 112 | class LabelEmbedder(nn.Module): 113 | """ 114 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 115 | """ 116 | 117 | def __init__(self, num_classes, hidden_size, dropout_prob): 118 | super().__init__() 119 | use_cfg_embedding = dropout_prob > 0 120 | self.embedding_table = nn.Embedding( 121 | num_classes + use_cfg_embedding, hidden_size 122 | ) 123 | self.num_classes = num_classes 124 | self.dropout_prob = dropout_prob 125 | 126 | def token_drop(self, labels, force_drop_ids=None): 127 | """ 128 | Drops labels to enable classifier-free guidance. 129 | """ 130 | if force_drop_ids is None: 131 | drop_ids = ( 132 | torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 133 | ) 134 | else: 135 | drop_ids = force_drop_ids == 1 136 | labels = torch.where(drop_ids, self.num_classes, labels) 137 | return labels 138 | 139 | def forward(self, labels, train, force_drop_ids=None): 140 | use_dropout = self.dropout_prob > 0 141 | if (train and use_dropout) or (force_drop_ids is not None): 142 | labels = self.token_drop(labels, force_drop_ids) 143 | embeddings = self.embedding_table(labels) 144 | return embeddings 145 | 146 | 147 | ################################################################################# 148 | # Core DiT Model # 149 | ################################################################################# 150 | 151 | 152 | 153 | 154 | @persistence.persistent_class 155 | class DiTBlock(nn.Module): 156 | """ 157 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
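    The conditioning vector `c` is mapped by `adaLN_modulation` to six signals
    (shift/scale/gate for the attention branch and for the MLP branch).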
158 | """ 159 | 160 | def __init__(self, hidden_size, num_heads, temb_size, mlp_ratio=4.0, skip=False, dropout=0, **block_kwargs): 161 | super().__init__() 162 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6 ) 163 | self.attn = Attention( 164 | hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs 165 | ) 166 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False , eps=1e-6) 167 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 168 | approx_gelu = lambda: nn.GELU(approximate="tanh") 169 | self.mlp = Mlp( 170 | in_features=hidden_size, 171 | hidden_features=mlp_hidden_dim, 172 | act_layer=approx_gelu, 173 | drop=dropout, 174 | ) 175 | self.adaLN_modulation = nn.Sequential( 176 | nn.SiLU(), nn.Linear(temb_size, 6 * hidden_size, bias=True) 177 | ) 178 | self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) if skip else None 179 | 180 | def forward(self, x, c, ): 181 | 182 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 183 | (self.adaLN_modulation(c )).chunk(6, dim=1) 184 | ) 185 | 186 | x = x + gate_msa.unsqueeze(1) * self.attn( 187 | modulate(self.norm1(x), shift_msa, scale_msa ) 188 | ) 189 | x = x + gate_mlp.unsqueeze(1) * self.mlp( 190 | modulate(self.norm2(x), shift_mlp, scale_mlp ) 191 | ) 192 | 193 | return x 194 | 195 | 196 | @persistence.persistent_class 197 | class FinalLayer(nn.Module): 198 | """ 199 | The final layer of DiT. 200 | """ 201 | 202 | def __init__(self, hidden_size, patch_size, out_channels): 203 | super().__init__() 204 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 205 | self.linear = nn.Linear( 206 | hidden_size, patch_size * patch_size * out_channels, bias=True 207 | ) 208 | self.adaLN_modulation = nn.Sequential( 209 | nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) 210 | ) 211 | 212 | def forward(self, x, c): 213 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 214 | x = modulate(self.norm_final(x), shift, scale) 215 | x = self.linear(x) 216 | return x 217 | 218 | 219 | @persistence.persistent_class 220 | class DiT(nn.Module): 221 | """ 222 | Diffusion model with a Transformer backbone. 
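    In this variant, a second time input (`noise_labels_s`) can be embedded by
    a separate `TimestepEmbedder` when `s_embed=True`; its embedding is added
    to the primary timestep embedding before conditioning.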
223 | """ 224 | 225 | def __init__( 226 | self, 227 | img_resolution, 228 | patch_size=2, 229 | in_channels=4, 230 | hidden_size=1152, 231 | depth=28, 232 | num_heads=16, 233 | mlp_ratio=4.0, 234 | class_dropout_prob=0., 235 | num_classes=1000, 236 | s_embed=True, 237 | qk_norm=False, 238 | skip=False, 239 | embedding_kwargs={}, 240 | temb_mult=1, 241 | dropout=0, 242 | **kwargs 243 | ): 244 | super().__init__() 245 | self.in_channels = in_channels 246 | self.out_channels = in_channels 247 | self.patch_size = patch_size 248 | self.num_heads = num_heads 249 | self.skip = skip 250 | temb_size = hidden_size * temb_mult 251 | 252 | self.s_embed = s_embed 253 | if s_embed: 254 | self.s_embedder = TimestepEmbedder(temb_size, **embedding_kwargs ) 255 | 256 | self.x_embedder = PatchEmbed( 257 | img_resolution, patch_size, in_channels, hidden_size, bias=True, 258 | ) 259 | self.t_embedder = TimestepEmbedder(temb_size, **embedding_kwargs ) 260 | self.y_embedder = LabelEmbedder(num_classes + 1, temb_size, class_dropout_prob) 261 | num_patches = self.x_embedder.num_patches 262 | # Will use fixed sin-cos embedding: 263 | self.pos_embed = nn.Parameter( 264 | torch.zeros(1, num_patches, hidden_size), requires_grad=False 265 | ) 266 | self.blocks = nn.ModuleList( 267 | [ 268 | DiTBlock(hidden_size, num_heads,temb_size, mlp_ratio=mlp_ratio, qk_norm=qk_norm, dropout=dropout, ) 269 | for _ in range(depth) 270 | ] 271 | ) 272 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) 273 | self.initialize_weights() 274 | 275 | def initialize_weights(self): 276 | # Initialize transformer layers: 277 | def _basic_init(module): 278 | if isinstance(module, nn.Linear): 279 | torch.nn.init.xavier_uniform_(module.weight) 280 | if module.bias is not None: 281 | nn.init.constant_(module.bias, 0) 282 | elif isinstance(module, nn.LayerNorm): 283 | if module.bias is not None: 284 | nn.init.constant_(module.bias, 0) 285 | if module.weight is not None: 286 | nn.init.constant_(module.weight, 1.0) 287 | 288 | 289 | self.apply(_basic_init) 290 | 291 | # Initialize (and freeze) pos_embed by sin-cos embedding: 292 | pos_embed = get_2d_sincos_pos_embed( 293 | self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5), 294 | ) 295 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 296 | 297 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): 298 | w = self.x_embedder.proj.weight.data 299 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 300 | nn.init.constant_(self.x_embedder.proj.bias, 0) 301 | 302 | # Initialize label embedding table: 303 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 304 | 305 | # Initialize timestep embedding MLP: 306 | if self.t_embedder.use_mlp: 307 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 308 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 309 | 310 | if self.s_embed and self.s_embedder.use_mlp: 311 | # Initialize timestep embedding MLP: 312 | nn.init.normal_(self.s_embedder.mlp[0].weight, std=0.02) 313 | nn.init.normal_(self.s_embedder.mlp[2].weight, std=0.02) 314 | 315 | # Zero-out adaLN modulation layers in DiT blocks: 316 | for block in self.blocks: 317 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 318 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 319 | # Zero-out output layers: 320 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 321 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 322 | 
nn.init.constant_(self.final_layer.linear.weight, 0)
323 |         nn.init.constant_(self.final_layer.linear.bias, 0)
324 | 
325 | 
326 |     def unpatchify(self, x):
327 |         """
328 |         x: (N, T, patch_size**2 * C)
329 |         imgs: (N, C, H, W)
330 |         """
331 |         c = self.out_channels
332 |         p = self.x_embedder.patch_size[0]
333 |         h = w = int(x.shape[1] ** 0.5)
334 |         assert h * w == x.shape[1]
335 | 
336 |         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
337 |         x = torch.einsum("nhwpqc->nchpwq", x)
338 |         imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
339 |         return imgs
340 | 
341 |     def forward(self, x,
342 |                 noise_labels_t,
343 |                 noise_labels_s=None,
344 |                 class_labels=None,
345 |                 **kwargs):
346 |         """
347 |         Forward pass of DiT.
348 |         x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
349 |         noise_labels_t: (N,) tensor of diffusion timesteps
350 |         class_labels: (N, num_classes) tensor of one-hot vector labels; all zeros denote unconditional
351 |         """
352 |         is_uncond = 1 - class_labels.sum(dim=1, keepdims=True)  # 1 if unconditional, 0 otherwise
353 |         y = torch.cat([class_labels, is_uncond], dim=1)
354 |         y = y.argmax(dim=1)  # to (N,) tensor of class labels
355 | 
356 |         x = (
357 |             self.x_embedder(x) + self.pos_embed
358 |         )  # (N, T, D), where T = H * W / patch_size ** 2
359 |         if noise_labels_t.shape[0] == 1:
360 |             noise_labels_t = repeat(noise_labels_t, '1 ... -> B ...', B=x.shape[0])
361 | 
362 |         t = self.t_embedder(noise_labels_t)  # (N, D)
363 |         if noise_labels_s is not None and self.s_embed:
364 | 
365 |             if noise_labels_s.shape[0] == 1:
366 |                 noise_labels_s = repeat(noise_labels_s, '1 ... -> B ...', B=x.shape[0])
367 | 
368 |             s = self.s_embedder(noise_labels_s)
369 | 
370 |             t = t + s
371 | 
372 | 
373 |         y = self.y_embedder(y, self.training)  # (N, D)
374 |         c = t + y  # (N, D)
375 | 
376 |         for block in self.blocks:
377 |             x = block(x, c)  # (N, T, D)
378 | 
379 | 
380 |         x = self.final_layer(x, c)  # (N, T, patch_size ** 2 * out_channels)
381 |         x = self.unpatchify(x)  # (N, out_channels, H, W)
382 | 
383 |         return x
384 | 
385 | 
386 | #################################################################################
387 | #                   Sine/Cosine Positional Embedding Functions                  #
388 | #################################################################################
389 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
390 | 
391 | 
392 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
393 |     """
394 |     grid_size: int of the grid height and width
395 |     return:
396 |     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
397 |     """
398 |     grid_h = np.arange(grid_size, dtype=np.float64)
399 |     grid_w = np.arange(grid_size, dtype=np.float64)
400 |     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
401 |     grid = np.stack(grid, axis=0)
402 | 
403 |     grid = grid.reshape([2, 1, grid_size, grid_size])
404 |     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
405 |     if cls_token and extra_tokens > 0:
406 |         pos_embed = np.concatenate(
407 |             [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
408 |         )
409 |     return pos_embed
410 | 
411 | 
412 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
413 |     assert embed_dim % 2 == 0
414 | 
415 |     # use half of dimensions to encode grid_h
416 |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
417 |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
418 | 
419 |     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
420 |     return 
emb 421 | 422 | 423 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 424 | """ 425 | embed_dim: output dimension for each position 426 | pos: a list of positions to be encoded: size (M,) 427 | out: (M, D) 428 | """ 429 | assert embed_dim % 2 == 0 430 | omega = np.arange(embed_dim // 2, dtype=np.float64) 431 | omega /= embed_dim / 2.0 432 | omega = 1.0 / 10000**omega # (D/2,) 433 | 434 | pos = pos.reshape(-1) # (M,) 435 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 436 | 437 | emb_sin = np.sin(out) # (M, D/2) 438 | emb_cos = np.cos(out) # (M, D/2) 439 | 440 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 441 | return emb 442 | 443 | 444 | ################################################################################# 445 | # DiT Configs # 446 | ################################################################################# 447 | 448 | 449 | @persistence.persistent_class 450 | class DiT_XL_2(DiT): 451 | def __init__(self, **kwargs): 452 | super().__init__(patch_size=2, hidden_size=1152, depth=28, num_heads=16, **kwargs) 453 | 454 | @persistence.persistent_class 455 | class DiT_XL_4(DiT): 456 | def __init__(self, **kwargs): 457 | super().__init__(patch_size=4, hidden_size=1152, depth=28, num_heads=16, **kwargs) 458 | 459 | @persistence.persistent_class 460 | class DiT_XL_8(DiT): 461 | def __init__(self, **kwargs): 462 | super().__init__(patch_size=8, hidden_size=1152, depth=28, num_heads=16, **kwargs) 463 | 464 | @persistence.persistent_class 465 | class DiT_L_2(DiT): 466 | def __init__(self, **kwargs): 467 | super().__init__(patch_size=2, hidden_size=1024, depth=24, num_heads=16, **kwargs) 468 | 469 | @persistence.persistent_class 470 | class DiT_L_4(DiT): 471 | def __init__(self, **kwargs): 472 | super().__init__(patch_size=4, hidden_size=1024, depth=24, num_heads=16, **kwargs) 473 | 474 | @persistence.persistent_class 475 | class DiT_L_8(DiT): 476 | def __init__(self, **kwargs): 477 | super().__init__(patch_size=8, hidden_size=1024, depth=24, num_heads=16, **kwargs) 478 | 479 | @persistence.persistent_class 480 | class DiT_B_2(DiT): 481 | def __init__(self, **kwargs): 482 | super().__init__(patch_size=2, hidden_size=768, depth=12, num_heads=12, **kwargs) 483 | 484 | @persistence.persistent_class 485 | class DiT_B_4(DiT): 486 | def __init__(self, **kwargs): 487 | super().__init__(patch_size=4, hidden_size=768, depth=12, num_heads=12, **kwargs) 488 | 489 | @persistence.persistent_class 490 | class DiT_B_8(DiT): 491 | def __init__(self, **kwargs): 492 | super().__init__(patch_size=8, hidden_size=768, depth=12, num_heads=12, **kwargs) 493 | 494 | @persistence.persistent_class 495 | class DiT_S_2(DiT): 496 | def __init__(self, **kwargs): 497 | super().__init__(patch_size=2, hidden_size=384, depth=12, num_heads=6, **kwargs) 498 | 499 | @persistence.persistent_class 500 | class DiT_S_4(DiT): 501 | def __init__(self, **kwargs): 502 | super().__init__(patch_size=4, hidden_size=384, depth=12, num_heads=6, **kwargs) 503 | 504 | @persistence.persistent_class 505 | class DiT_S_8(DiT): 506 | def __init__(self, **kwargs): 507 | super().__init__(patch_size=8, hidden_size=384, depth=12, num_heads=6, **kwargs) 508 | 509 | -------------------------------------------------------------------------------- /training/encoders.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 | 
8 | """Converting between pixel and latent representations of image data."""
9 | 
10 | import os
11 | import warnings
12 | import numpy as np
13 | import torch
14 | from torch_utils import persistence
15 | from torch_utils import misc
16 | 
17 | warnings.filterwarnings('ignore', 'torch.utils._pytree._register_pytree_node is deprecated.')
18 | warnings.filterwarnings('ignore', '`resume_download` is deprecated')
19 | 
20 | #----------------------------------------------------------------------------
21 | # Abstract base class for encoders/decoders that convert back and forth
22 | # between pixel and latent representations of image data.
23 | #
24 | # Logically, "raw pixels" are first encoded into "raw latents" that are
25 | # then further encoded into "final latents". Decoding, on the other hand,
26 | # goes directly from the final latents to raw pixels. The final latents are
27 | # used as inputs and outputs of the model, whereas the raw latents are
28 | # stored in the dataset. This separation provides added flexibility in terms
29 | # of performing just-in-time adjustments, such as data whitening, without
30 | # having to construct a new dataset.
31 | #
32 | # All image data is represented as PyTorch tensors in NCHW order.
33 | # Raw pixels are represented as 3-channel uint8.
34 | 
35 | @persistence.persistent_class
36 | class Encoder:
37 |     def __init__(self):
38 |         pass
39 | 
40 |     def init(self, device): # force lazy init to happen now
41 |         pass
42 | 
43 |     def __getstate__(self):
44 |         return self.__dict__
45 | 
46 |     def encode(self, x): # raw pixels => final latents
47 |         return self.encode_latents(self.encode_pixels(x))
48 | 
49 |     def encode_pixels(self, x): # raw pixels => raw latents
50 |         raise NotImplementedError # to be overridden by subclass
51 | 
52 |     def encode_latents(self, x): # raw latents => final latents
53 |         raise NotImplementedError # to be overridden by subclass
54 | 
55 |     def decode(self, x): # final latents => raw pixels
56 |         raise NotImplementedError # to be overridden by subclass
57 | 
58 | #----------------------------------------------------------------------------
59 | # Identity encoder that passes the data through unchanged.
60 | 
61 | @persistence.persistent_class
62 | class IdentityEncoder(Encoder):
63 |     def __init__(self):
64 |         super().__init__()
65 | 
66 |     def encode_pixels(self, x): # raw pixels => raw latents
67 |         return x
68 | 
69 |     def encode_latents(self, x): # raw latents => final latents
70 |         return x
71 |     def encode(self, x):
72 |         return x
73 |     def decode(self, x): # final latents => raw pixels
74 |         return x
75 | #----------------------------------------------------------------------------
76 | # Standard RGB encoder that scales the pixel data into [-1, +1].
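#----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file): the
# pixels => latents => pixels round trip described above, using the
# StandardRGBEncoder defined below. The random uint8 batch is an arbitrary
# stand-in for dataset images; exact recovery relies on decode() rounding.

def _example_encoder_round_trip():
    pixels = torch.randint(0, 256, [4, 3, 32, 32], dtype=torch.uint8)  # raw pixels, NCHW
    enc = StandardRGBEncoder()
    latents = enc.encode(pixels)    # float32 "final latents" in [-1, +1]
    restored = enc.decode(latents)  # back to 3-channel uint8 pixels
    assert torch.equal(restored, pixels)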
77 | 
78 | @persistence.persistent_class
79 | class StandardRGBEncoder(Encoder):
80 |     def __init__(self):
81 |         super().__init__()
82 | 
83 |     def encode_pixels(self, x): # raw pixels => raw latents
84 |         return x
85 | 
86 |     def encode_latents(self, x): # raw latents => final latents
87 |         return x.to(torch.float32) / 127.5 - 1
88 | 
89 |     def encode(self, x):
90 |         return self.encode_latents(self.encode_pixels(x))
91 | 
92 | 
93 |     def decode(self, x): # final latents => raw pixels
94 | 
95 |         return (x.to(torch.float32) * 127.5 + 128).clip(0, 255).to(torch.uint8)
96 | 
97 | #----------------------------------------------------------------------------
98 | # Pre-trained VAE encoder from Stability AI.
99 | 
100 | @persistence.persistent_class
101 | class StabilityVAEEncoder(Encoder):
102 |     def __init__(self,
103 |         vae_name    = 'stabilityai/sd-vae-ft-mse',  # Name of the VAE to use.
104 |         raw_mean    = [1.56, -0.695, 0.483, 0.729], # Assumed mean of the raw latents.
105 |         raw_std     = [5.27, 5.91, 4.21, 4.31],     # Assumed standard deviation of the raw latents.
106 |         final_mean  = 0,                            # Desired mean of the final latents.
107 |         final_std   = 0.5,                          # Desired standard deviation of the final latents.
108 |         batch_size  = 8,                            # Batch size to use when running the VAE.
109 |         use_fp16    = False,                        # Run the VAE in FP16 instead of FP32.
110 |     ):
111 |         super().__init__()
112 |         self.vae_name = vae_name
113 |         self.scale = np.float32(final_std) / np.float32(raw_std)
114 |         self.bias = np.float32(final_mean) - np.float32(raw_mean) * self.scale
115 |         self.batch_size = int(batch_size)
116 |         self._vae = None
117 |         self.dtype = torch.float16 if use_fp16 else torch.float32
118 | 
119 |     def init(self, device): # force lazy init to happen now
120 |         super().init(device)
121 |         if self._vae is None:
122 |             self._vae = load_stability_vae(self.vae_name, device=device, dtype=self.dtype)
123 |         else:
124 |             self._vae.to(device)
125 | 
126 |     def __getstate__(self):
127 |         return dict(super().__getstate__(), _vae=None) # do not pickle the vae
128 | 
129 |     def _run_vae_encoder(self, x):
130 |         dtype = x.dtype
131 |         d = self._vae.encode(x.to(self.dtype))['latent_dist']
132 |         return torch.cat([d.mean, d.std], dim=1).to(dtype)
133 | 
134 |     def _run_vae_decoder(self, x):
135 |         dtype = x.dtype
136 |         return self._vae.decode(x.to(self.dtype))['sample'].to(dtype)
137 | 
138 |     def encode_pixels(self, x): # raw pixels => raw latents
139 |         self.init(x.device)
140 |         x = x.to(torch.float32) / 127.5 - 1
141 |         x = torch.cat([self._run_vae_encoder(batch) for batch in x.split(self.batch_size)])
142 |         return x
143 |     def encode_latents(self, x): # raw latents => final latents
144 |         mean, std = x.to(torch.float32).chunk(2, dim=1)
145 |         x = mean + torch.randn_like(mean) * std
146 |         x = x * misc.const_like(x, self.scale).reshape(1, -1, 1, 1)
147 |         x = x + misc.const_like(x, self.bias).reshape(1, -1, 1, 1)
148 |         return x
149 | 
150 |     def encode(self, x):
151 |         if x.shape[1] == 2 * 4: # raw latents: mean and std stacked along the channel dim.
152 |             return self.encode_latents(x)
153 |         elif x.shape[1] == 3:
154 |             return self.encode_latents(self.encode_pixels(x))
155 |         else:
156 |             raise ValueError(f'Invalid number of channels: {x.shape[1]}')
157 | 
158 |     def decode_latents_to_pixels(self, x):
159 |         self.init(x.device)
160 |         x = x.to(torch.float32)
161 |         x = x - misc.const_like(x, self.bias).reshape(1, -1, 1, 1)
162 |         x = x / misc.const_like(x, self.scale).reshape(1, -1, 1, 1)
163 |         x = torch.cat([self._run_vae_decoder(batch) for batch in x.split(self.batch_size)])
164 |         x = (x * 0.5 + 0.5).clamp(0, 1).mul(255).to(torch.uint8)
165 |         return x
166 | 
167 |     def
decode(self, x): # final latents => raw pixels
168 |         if x.shape[1] == 2 * 4: # raw latents: mean and std stacked along the channel dim.
169 |             mean, std = x.to(torch.float32).chunk(2, dim=1)
170 |             x = mean + torch.randn_like(mean) * std
171 |             return self.decode_latents_to_pixels(x)
172 |         elif x.shape[1] == 4:
173 |             return self.decode_latents_to_pixels(x)
174 |         elif x.shape[1] == 3:
175 |             return x
176 |         else:
177 |             raise ValueError(f'Invalid number of channels: {x.shape[1]}')
178 | #----------------------------------------------------------------------------
179 | 
180 | def load_stability_vae(vae_name='stabilityai/sd-vae-ft-mse', device=torch.device('cpu'), dtype=torch.float32):
181 |     import dnnlib
182 |     cache_dir = dnnlib.make_cache_dir_path('diffusers')
183 |     os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
184 |     os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
185 |     os.environ['HF_HOME'] = cache_dir
186 | 
187 |     import diffusers # pip install diffusers # pyright: ignore [reportMissingImports]
188 |     try:
189 |         # First try with local_files_only to avoid consulting the Hugging Face Hub if the model is already in the cache.
190 |         vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, cache_dir=cache_dir, local_files_only=True, torch_dtype=dtype)
191 |     except Exception:
192 |         # Could not load the model from cache; try again without local_files_only.
193 |         vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, cache_dir=cache_dir, torch_dtype=dtype)
194 |     return torch.compile(vae.eval().requires_grad_(False).to(device), mode="max-autotune", fullgraph=True)
195 | 
196 | #----------------------------------------------------------------------------
--------------------------------------------------------------------------------
/training/preconds.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 | 
8 | """Model architectures and preconditioning schemes used in the paper
9 | "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10 | 
11 | import numpy as np
12 | import torch
13 | from torch_utils import persistence
14 | from training.unets import *
15 | from training.dit import *
16 | 
17 | 
18 | @persistence.persistent_class
19 | class IMMPrecond(torch.nn.Module):
20 | 
21 |     def __init__(
22 |         self,
23 |         img_resolution,         # Image resolution.
24 |         img_channels,           # Number of color channels.
25 |         label_dim=0,            # Number of class labels, 0 = unconditional.
26 |         mixed_precision=None,   # 'bf16', 'fp16', or None = full FP32.
27 |         noise_schedule="fm",    # 'fm' or 'vp_cosine'.
28 |         model_type="SongUNet",  # Class name of the underlying backbone.
29 |         sigma_data=0.5,         # Expected standard deviation of the data.
30 |         f_type="euler_fm",      # 'identity', 'simple_edm', or 'euler_fm'.
31 |         T=0.994,                # Maximum time.
32 |         eps=0.,                 # Minimum time.
33 |         temb_type='identity',   # 'identity' or 'stride' time embedding.
34 |         time_scale=1000.,       # Scale applied to t and s before embedding.
35 |         **model_kwargs, # Keyword arguments for the underlying model.
36 | ): 37 | super().__init__() 38 | 39 | 40 | self.img_resolution = img_resolution 41 | self.img_channels = img_channels 42 | 43 | self.label_dim = label_dim 44 | self.use_mixed_precision = mixed_precision is not None 45 | if mixed_precision == 'bf16': 46 | self.mixed_precision = torch.bfloat16 47 | elif mixed_precision == 'fp16': 48 | self.mixed_precision = torch.float16 49 | elif mixed_precision is None: 50 | self.mixed_precision = torch.float32 51 | else: 52 | raise ValueError(f"Unknown mixed_precision: {mixed_precision}") 53 | 54 | 55 | self.noise_schedule = noise_schedule 56 | 57 | self.T = T 58 | self.eps = eps 59 | 60 | self.sigma_data = sigma_data 61 | 62 | self.f_type = f_type 63 | 64 | self.nt_low = self.get_log_nt(torch.tensor(self.eps, dtype=torch.float64)).exp().numpy().item() 65 | self.nt_high = self.get_log_nt(torch.tensor(self.T, dtype=torch.float64)).exp().numpy().item() 66 | 67 | self.model = globals()[model_type]( 68 | img_resolution=img_resolution, 69 | img_channels=img_channels, 70 | in_channels=img_channels, 71 | out_channels=img_channels, 72 | label_dim=label_dim, 73 | **model_kwargs, 74 | ) 75 | print('# Mparams:', sum(p.numel() for p in self.model.parameters()) / 1000000) 76 | 77 | self.time_scale = time_scale 78 | 79 | 80 | self.temb_type = temb_type 81 | 82 | if self.f_type == 'euler_fm': 83 | assert self.noise_schedule == 'fm' 84 | 85 | 86 | def get_logsnr(self, t): 87 | dtype = t.dtype 88 | t = t.to(torch.float64) 89 | if self.noise_schedule == "vp_cosine": 90 | logsnr = -2 * torch.log(torch.tan(t * torch.pi * 0.5)) 91 | 92 | elif self.noise_schedule == "fm": 93 | logsnr = 2 * ((1 - t).log() - t.log()) 94 | 95 | logsnr = logsnr.to(dtype) 96 | return logsnr 97 | 98 | def get_log_nt(self, t): 99 | logsnr_t = self.get_logsnr(t) 100 | 101 | return -0.5 * logsnr_t 102 | 103 | def get_alpha_sigma(self, t): 104 | if self.noise_schedule == 'fm': 105 | alpha_t = (1 - t) 106 | sigma_t = t 107 | elif self.noise_schedule == 'vp_cosine': 108 | alpha_t = torch.cos(t * torch.pi * 0.5) 109 | sigma_t = torch.sin(t * torch.pi * 0.5) 110 | 111 | return alpha_t, sigma_t 112 | 113 | def add_noise(self, y, t, noise=None): 114 | 115 | if noise is None: 116 | noise = torch.randn_like(y) * self.sigma_data 117 | 118 | alpha_t, sigma_t = self.get_alpha_sigma(t) 119 | 120 | return alpha_t * y + sigma_t * noise, noise 121 | 122 | def ddim(self, yt, y, t, s, noise=None): 123 | alpha_t, sigma_t = self.get_alpha_sigma(t) 124 | alpha_s, sigma_s = self.get_alpha_sigma(s) 125 | 126 | 127 | if noise is None: 128 | ys = (alpha_s - alpha_t * sigma_s / sigma_t) * y + sigma_s / sigma_t * yt 129 | else: 130 | ys = alpha_s * y + sigma_s * noise 131 | return ys 132 | 133 | 134 | 135 | def simple_edm_sample_function(self, yt, y, t, s ): 136 | alpha_t, sigma_t = self.get_alpha_sigma(t) 137 | alpha_s, sigma_s = self.get_alpha_sigma(s) 138 | 139 | c_skip = (alpha_t * alpha_s + sigma_t * sigma_s) / (alpha_t**2 + sigma_t**2) 140 | 141 | c_out = - (alpha_s * sigma_t - alpha_t * sigma_s) * (alpha_t**2 + sigma_t**2).rsqrt() * self.sigma_data 142 | 143 | return c_skip * yt + c_out * y 144 | 145 | def euler_fm_sample_function(self, yt, y, t, s ): 146 | assert self.noise_schedule == 'fm' 147 | 148 | 149 | return yt - (t - s) * self.sigma_data * y 150 | 151 | def nt_to_t(self, nt): 152 | dtype = nt.dtype 153 | nt = nt.to(torch.float64) 154 | if self.noise_schedule == "vp_cosine": 155 | t = torch.arctan(nt) / (torch.pi * 0.5) 156 | 157 | elif self.noise_schedule == "fm": 158 | t = nt / (1 + nt) 159 | 160 | t = 
torch.nan_to_num(t, nan=1)
161 | 
162 |         t = t.to(dtype)
163 | 
164 |         # Sanity-check that t stays within the valid range for the supported schedules.
165 |         if (
166 |             (self.noise_schedule.startswith("vp")
167 |             or self.noise_schedule == "fm")
168 |             and t.max() > 1
169 |         ):
170 |             raise ValueError(f"t out of range: {t.min().item()}, {t.max().item()}")
171 |         return t
172 | 
173 |     def get_init_noise(self, shape, device):
174 | 
175 |         noise = torch.randn(shape, device=device) * self.sigma_data
176 |         return noise
177 | 
178 |     def forward_model(
179 |         self,
180 |         model,
181 |         x,
182 |         t,
183 |         s,
184 |         class_labels=None,
185 |         force_fp32=False,
186 |         **model_kwargs,
187 |     ):
188 | 
189 | 
190 | 
191 |         alpha_t, sigma_t = self.get_alpha_sigma(t)
192 | 
193 |         c_in = (alpha_t**2 + sigma_t**2).rsqrt() / self.sigma_data
194 |         if self.temb_type == 'identity':
195 | 
196 |             c_noise_t = t * self.time_scale
197 |             c_noise_s = s * self.time_scale
198 | 
199 |         elif self.temb_type == 'stride':
200 | 
201 |             c_noise_t = t * self.time_scale
202 |             c_noise_s = (t - s) * self.time_scale
203 | 
204 |         with torch.amp.autocast('cuda', enabled=self.use_mixed_precision and not force_fp32, dtype=self.mixed_precision):
205 |             F_x = model(
206 |                 (c_in * x),
207 |                 c_noise_t.flatten(),
208 |                 c_noise_s.flatten(),
209 |                 class_labels=class_labels,
210 |                 **model_kwargs,
211 |             )
212 |         return F_x
213 | 
214 | 
215 |     def forward(
216 |         self,
217 |         x,
218 |         t,
219 |         s=None,
220 |         class_labels=None,
221 |         force_fp32=False,
222 |         **model_kwargs,
223 |     ):
224 |         dtype = t.dtype
225 |         class_labels = (
226 |             None
227 |             if self.label_dim == 0
228 |             else (
229 |                 torch.zeros([1, self.label_dim], device=x.device)
230 |                 if class_labels is None
231 |                 else class_labels.to(torch.float32).reshape(-1, self.label_dim)
232 |             )
233 |         )
234 | 
235 |         F_x = self.forward_model(
236 |             self.model,
237 |             x.to(torch.float32),
238 |             t.to(torch.float32).reshape(-1, 1, 1, 1),
239 |             s.to(torch.float32).reshape(-1, 1, 1, 1) if s is not None else None,
240 |             class_labels,
241 |             force_fp32,
242 |             **model_kwargs,
243 |         )
244 |         F_x = F_x.to(dtype)
245 | 
246 |         if self.f_type == "identity":
247 |             F_x = self.ddim(x, F_x, t, s)
248 |         elif self.f_type == "simple_edm":
249 |             F_x = self.simple_edm_sample_function(x, F_x, t, s)
250 |         elif self.f_type == "euler_fm":
251 |             F_x = self.euler_fm_sample_function(x, F_x, t, s)
252 |         else:
253 |             raise NotImplementedError
254 | 
255 |         return F_x
256 | 
257 |     def cfg_forward(
258 |         self,
259 |         x,
260 |         t,
261 |         s=None,
262 |         class_labels=None,
263 |         force_fp32=False,
264 |         cfg_scale=None,
265 |         **model_kwargs,
266 |     ):
267 |         dtype = t.dtype
268 |         class_labels = (
269 |             None
270 |             if self.label_dim == 0
271 |             else (
272 |                 torch.zeros([1, self.label_dim], device=x.device)
273 |                 if class_labels is None
274 |                 else class_labels.to(torch.float32).reshape(-1, self.label_dim)
275 |             )
276 |         )
277 |         if cfg_scale is not None:
278 |             # Duplicate the batch: first half unconditional, second half conditional.
279 |             x_cfg = torch.cat([x, x], dim=0)
280 |             class_labels = torch.cat([torch.zeros_like(class_labels), class_labels], dim=0)
281 |         else:
282 |             x_cfg = x
283 |         F_x = self.forward_model(
284 |             self.model,
285 |             x_cfg.to(torch.float32),
286 |             t.to(torch.float32).reshape(-1, 1, 1, 1),
287 |             s.to(torch.float32).reshape(-1, 1, 1, 1) if s is not None else None,
288 |             class_labels=class_labels,
289 |             force_fp32=force_fp32,
290 |             **model_kwargs,
291 |         )
292 |         F_x = F_x.to(dtype)
293 | 
294 |         if cfg_scale is not None:
295 |             uncond_F = F_x[:len(x)]
296 |             cond_F = F_x[len(x):]
297 |             # Classifier-free guidance: extrapolate away from the unconditional prediction.
298 |             F_x = uncond_F + cfg_scale * (cond_F - uncond_F)
299 | 
300 |         if self.f_type == "identity":
301 |             F_x = self.ddim(x, F_x, t, s)
302 | 
elif self.f_type == "simple_edm": 303 | F_x = self.simple_edm_sample_function(x, F_x , t, s) 304 | elif self.f_type == "euler_fm": 305 | F_x = self.euler_fm_sample_function(x, F_x , t, s) 306 | else: 307 | raise NotImplementedError 308 | 309 | return F_x -------------------------------------------------------------------------------- /training/unets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Model architectures and preconditioning schemes used in the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import numpy as np 12 | import torch 13 | from torch_utils import persistence 14 | from torch.nn.functional import silu 15 | 16 | # ---------------------------------------------------------------------------- 17 | # Unified routine for initializing weights and biases. 18 | 19 | 20 | def weight_init(shape, mode, fan_in, fan_out): 21 | if mode == "xavier_uniform": 22 | return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1) 23 | if mode == "xavier_normal": 24 | return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape) 25 | if mode == "kaiming_uniform": 26 | return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1) 27 | if mode == "kaiming_normal": 28 | return np.sqrt(1 / fan_in) * torch.randn(*shape) 29 | raise ValueError(f'Invalid init mode "{mode}"') 30 | 31 | 32 | # ---------------------------------------------------------------------------- 33 | # Fully-connected layer. 34 | 35 | 36 | @persistence.persistent_class 37 | class Linear(torch.nn.Module): 38 | def __init__( 39 | self, 40 | in_features, 41 | out_features, 42 | bias=True, 43 | init_mode="kaiming_normal", 44 | init_weight=1, 45 | init_bias=0, 46 | ): 47 | super().__init__() 48 | self.in_features = in_features 49 | self.out_features = out_features 50 | init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features) 51 | self.weight = torch.nn.Parameter( 52 | weight_init([out_features, in_features], **init_kwargs) * init_weight 53 | ) 54 | self.bias = ( 55 | torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias) 56 | if bias 57 | else None 58 | ) 59 | 60 | def forward(self, x): 61 | x = x @ self.weight.to(x.dtype).t() 62 | if self.bias is not None: 63 | x = x.add_(self.bias.to(x.dtype)) 64 | return x 65 | 66 | 67 | # ---------------------------------------------------------------------------- 68 | # Convolutional layer with optional up/downsampling. 
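# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file): the
# scaling rules implemented by weight_init() above. For "kaiming_normal",
# a unit Gaussian is scaled by sqrt(1 / fan_in); the shape and fan values
# below are arbitrary.

def _example_weight_init_scale():
    w = weight_init([256, 128], mode="kaiming_normal", fan_in=128, fan_out=256)
    expected_std = (1 / 128) ** 0.5
    assert abs(w.std().item() - expected_std) < 1e-2  # matches up to sampling noise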
69 | 70 | 71 | @persistence.persistent_class 72 | class Conv2d(torch.nn.Module): 73 | def __init__( 74 | self, 75 | in_channels, 76 | out_channels, 77 | kernel, 78 | bias=True, 79 | up=False, 80 | down=False, 81 | resample_filter=[1, 1], 82 | fused_resample=False, 83 | init_mode="kaiming_normal", 84 | init_weight=1, 85 | init_bias=0, 86 | ): 87 | assert not (up and down) 88 | super().__init__() 89 | self.in_channels = in_channels 90 | self.out_channels = out_channels 91 | self.up = up 92 | self.down = down 93 | self.fused_resample = fused_resample 94 | init_kwargs = dict( 95 | mode=init_mode, 96 | fan_in=in_channels * kernel * kernel, 97 | fan_out=out_channels * kernel * kernel, 98 | ) 99 | self.weight = ( 100 | torch.nn.Parameter( 101 | weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs) 102 | * init_weight 103 | ) 104 | if kernel 105 | else None 106 | ) 107 | self.bias = ( 108 | torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias) 109 | if kernel and bias 110 | else None 111 | ) 112 | f = torch.as_tensor(resample_filter, dtype=torch.float32) 113 | f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square() 114 | self.register_buffer("resample_filter", f if up or down else None) 115 | 116 | def forward(self, x): 117 | w = self.weight.to(x.dtype) if self.weight is not None else None 118 | b = self.bias.to(x.dtype) if self.bias is not None else None 119 | f = ( 120 | self.resample_filter.to(x.dtype) 121 | if self.resample_filter is not None 122 | else None 123 | ) 124 | w_pad = w.shape[-1] // 2 if w is not None else 0 125 | f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0 126 | 127 | if self.fused_resample and self.up and w is not None: 128 | x = torch.nn.functional.conv_transpose2d( 129 | x, 130 | f.mul(4).tile([self.in_channels, 1, 1, 1]), 131 | groups=self.in_channels, 132 | stride=2, 133 | padding=max(f_pad - w_pad, 0), 134 | ) 135 | x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0)) 136 | elif self.fused_resample and self.down and w is not None: 137 | x = torch.nn.functional.conv2d(x, w, padding=w_pad + f_pad) 138 | x = torch.nn.functional.conv2d( 139 | x, 140 | f.tile([self.out_channels, 1, 1, 1]), 141 | groups=self.out_channels, 142 | stride=2, 143 | ) 144 | else: 145 | if self.up: 146 | x = torch.nn.functional.conv_transpose2d( 147 | x, 148 | f.mul(4).tile([self.in_channels, 1, 1, 1]), 149 | groups=self.in_channels, 150 | stride=2, 151 | padding=f_pad, 152 | ) 153 | if self.down: 154 | x = torch.nn.functional.conv2d( 155 | x, 156 | f.tile([self.in_channels, 1, 1, 1]), 157 | groups=self.in_channels, 158 | stride=2, 159 | padding=f_pad, 160 | ) 161 | if w is not None: 162 | x = torch.nn.functional.conv2d(x, w, padding=w_pad) 163 | if b is not None: 164 | x = x.add_(b.reshape(1, -1, 1, 1)) 165 | return x 166 | 167 | 168 | # ---------------------------------------------------------------------------- 169 | # Group normalization. 
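# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file): with
# up=True the Conv2d layer above doubles the spatial resolution, and with
# down=True it halves it. The channel counts and input size are arbitrary.

def _example_conv2d_resampling():
    x = torch.randn(2, 8, 16, 16)
    up = Conv2d(in_channels=8, out_channels=8, kernel=3, up=True)
    down = Conv2d(in_channels=8, out_channels=8, kernel=3, down=True)
    assert up(x).shape == (2, 8, 32, 32)
    assert down(x).shape == (2, 8, 8, 8)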
170 | 171 | 172 | @persistence.persistent_class 173 | class GroupNorm(torch.nn.Module): 174 | def __init__(self, num_channels, num_groups=32, min_channels_per_group=4, eps=1e-5): 175 | super().__init__() 176 | self.num_groups = min(num_groups, num_channels // min_channels_per_group) 177 | self.eps = eps 178 | self.weight = torch.nn.Parameter(torch.ones(num_channels)) 179 | self.bias = torch.nn.Parameter(torch.zeros(num_channels)) 180 | 181 | def forward(self, x, *args, **kwargs): 182 | x = torch.nn.functional.group_norm( 183 | x, 184 | num_groups=self.num_groups, 185 | weight=self.weight.to(x.dtype), 186 | bias=self.bias.to(x.dtype), 187 | eps=self.eps, 188 | ) 189 | return x 190 | 191 | 192 | 193 | # ---------------------------------------------------------------------------- 194 | # Attention weight computation, i.e., softmax(Q^T * K). 195 | # Performs all computation using FP32, but uses the original datatype for 196 | # inputs/outputs/gradients to conserve memory. 197 | 198 | 199 | class AttentionOp(torch.autograd.Function): 200 | @staticmethod 201 | def forward(q, k): 202 | w = ( 203 | torch.einsum( 204 | "ncq,nck->nqk", 205 | q.to(torch.float32), 206 | (k / np.sqrt(k.shape[1])).to(torch.float32), 207 | ) 208 | .softmax(dim=2) 209 | .to(q.dtype) 210 | ) 211 | return w 212 | 213 | @staticmethod 214 | def setup_context(ctx, inputs, outputs): 215 | q,k = inputs 216 | w = outputs 217 | ctx.save_for_backward(q, k, w) 218 | # ctx.w = w 219 | 220 | @staticmethod 221 | def backward(ctx, dw): 222 | q, k, w = ctx.saved_tensors 223 | db = torch._softmax_backward_data( 224 | grad_output=dw.to(torch.float32), 225 | output=w.to(torch.float32), 226 | dim=2, 227 | input_dtype=torch.float32, 228 | ) 229 | dq = torch.einsum("nck,nqk->ncq", k.to(torch.float32), db).to( 230 | q.dtype 231 | ) / np.sqrt(k.shape[1]) 232 | dk = torch.einsum("ncq,nqk->nck", q.to(torch.float32), db).to( 233 | k.dtype 234 | ) / np.sqrt(k.shape[1]) 235 | return dq, dk 236 | 237 | 238 | @persistence.persistent_class 239 | class Attention(torch.nn.Module): 240 | def forward(self, q, k): 241 | w = ( 242 | torch.einsum( 243 | "ncq,nck->nqk", 244 | q.to(torch.float32), 245 | (k / np.sqrt(k.shape[1])).to(torch.float32), 246 | ) 247 | .softmax(dim=2) 248 | .to(q.dtype) 249 | ) 250 | return w 251 | 252 | 253 | # ---------------------------------------------------------------------------- 254 | # Unified U-Net block with optional up/downsampling and self-attention. 255 | # Represents the union of all features employed by the DDPM++, NCSN++, and 256 | # ADM architectures. 
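# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file): the
# AttentionOp above matches a plain softmax(Q^T K / sqrt(C)) reference; the
# FP32 trick only changes the dtype used internally, not the result.
# Shapes are arbitrary (N, C, tokens).

def _example_attention_op():
    q = torch.randn(2, 16, 5)
    k = torch.randn(2, 16, 7)
    w = AttentionOp.apply(q, k)
    ref = torch.einsum("ncq,nck->nqk", q, k / np.sqrt(16)).softmax(dim=2)
    assert torch.allclose(w, ref, atol=1e-6)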
257 | 258 | 259 | 260 | @persistence.persistent_class 261 | class UNetBlock(torch.nn.Module): 262 | def __init__(self, 263 | in_channels, out_channels, emb_channels, up=False, down=False, attention=False, 264 | num_heads=None, channels_per_head=64, dropout=0, skip_scale=1, eps=1e-5, 265 | resample_filter=[1,1], resample_proj=False, adaptive_scale=True, 266 | init=dict(), init_zero=dict(init_weight=0), init_attn=None, 267 | ): 268 | super().__init__() 269 | self.in_channels = in_channels 270 | self.out_channels = out_channels 271 | self.emb_channels = emb_channels 272 | self.num_heads = 0 if not attention else num_heads if num_heads is not None else out_channels // channels_per_head 273 | self.dropout = dropout 274 | self.skip_scale = skip_scale 275 | self.adaptive_scale = adaptive_scale 276 | 277 | self.norm0 = GroupNorm(num_channels=in_channels, eps=eps) 278 | self.conv0 = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=3, up=up, down=down, resample_filter=resample_filter, **init) 279 | self.affine = Linear(in_features=emb_channels, out_features=out_channels*(2 if adaptive_scale else 1), **init) 280 | self.norm1 = GroupNorm(num_channels=out_channels, eps=eps) 281 | self.conv1 = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero) 282 | 283 | self.skip = None 284 | if out_channels != in_channels or up or down: 285 | kernel = 1 if resample_proj or out_channels!= in_channels else 0 286 | self.skip = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=kernel, up=up, down=down, resample_filter=resample_filter, **init) 287 | 288 | if self.num_heads: 289 | self.norm2 = GroupNorm(num_channels=out_channels, eps=eps) 290 | self.qkv = Conv2d(in_channels=out_channels, out_channels=out_channels*3, kernel=1, **(init_attn if init_attn is not None else init)) 291 | self.proj = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=1, **init_zero) 292 | 293 | def forward(self, x, emb): 294 | orig = x 295 | x = self.conv0(silu(self.norm0(x))) 296 | 297 | params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype) 298 | if self.adaptive_scale: 299 | scale, shift = params.chunk(chunks=2, dim=1) 300 | x = silu(torch.addcmul(shift, self.norm1(x), scale + 1)) 301 | else: 302 | x = silu(self.norm1(x.add_(params))) 303 | 304 | x = self.conv1(torch.nn.functional.dropout(x, p=self.dropout, training=self.training)) 305 | x = x.add_(self.skip(orig) if self.skip is not None else orig) 306 | x = x * self.skip_scale 307 | 308 | if self.num_heads: 309 | q, k, v = self.qkv(self.norm2(x)).reshape(x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1).unbind(2) 310 | w = AttentionOp.apply(q, k) 311 | a = torch.einsum('nqk,nck->ncq', w, v) 312 | x = self.proj(a.reshape(*x.shape)).add_(x) 313 | x = x * self.skip_scale 314 | return x 315 | 316 | 317 | # ---------------------------------------------------------------------------- 318 | # Timestep embedding used in the DDPM++ and ADM architectures. 
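# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file): a single
# UNetBlock defined above, with downsampling and self-attention enabled. The
# embedding plays the role of the (timestep + class) conditioning vector;
# all sizes are arbitrary.

def _example_unet_block():
    block = UNetBlock(in_channels=32, out_channels=64, emb_channels=128,
                      down=True, attention=True)
    x = torch.randn(2, 32, 16, 16)
    emb = torch.randn(2, 128)
    assert block(x, emb).shape == (2, 64, 8, 8)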
319 | 
320 | 
321 | @persistence.persistent_class
322 | class PositionalEmbedding(torch.nn.Module):
323 |     def __init__(self, num_channels, max_positions=10000, endpoint=False):
324 |         super().__init__()
325 |         self.num_channels = num_channels
326 |         self.max_positions = max_positions
327 |         self.endpoint = endpoint
328 | 
329 |     def forward(self, x):
330 |         in_dtype = x.dtype # remember the input dtype; the frequencies are computed in FP64.
331 |         freqs = torch.arange(start=0, end=self.num_channels // 2, dtype=torch.float64, device=x.device)
332 |         freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
333 |         freqs = (1 / self.max_positions) ** freqs
334 |         x = x.ger(freqs)
335 |         x = torch.cat([x.cos(), x.sin()], dim=1).to(in_dtype)
336 |         return x
337 | 
338 | 
339 | 
340 | # ----------------------------------------------------------------------------
341 | # Timestep embedding used in the NCSN++ architecture.
342 | 
343 | 
344 | @persistence.persistent_class
345 | class FourierEmbedding(torch.nn.Module):
346 |     def __init__(self, num_channels, scale=0.02, learnable=False, **kwargs):
347 |         super().__init__()
348 |         print("FourierEmbedding scale:", scale)
349 |         if learnable:
350 |             self.freqs = torch.nn.Parameter(torch.randn(num_channels // 2) * scale)
351 |         else:
352 |             self.register_buffer("freqs", torch.randn(num_channels // 2) * scale)
353 | 
354 |     def forward(self, x):
355 |         x = x.ger(2 * np.pi * self.freqs)
356 |         x = torch.cat([x.cos(), x.sin()], dim=1)
357 |         return x
358 | 
359 | 
360 | 
361 | 
362 | # ----------------------------------------------------------------------------
363 | # Reimplementation of the DDPM++ and NCSN++ architectures from the paper
364 | # "Score-Based Generative Modeling through Stochastic Differential
365 | # Equations". Equivalent to the original implementation by Song et al.,
366 | # available at https://github.com/yang-song/score_sde_pytorch
367 | 
368 | 
369 | @persistence.persistent_class
370 | class SongUNet(torch.nn.Module):
371 |     def __init__(self,
372 |         img_resolution,                     # Image resolution at input/output.
373 |         in_channels,                        # Number of color channels at input.
374 |         out_channels,                       # Number of color channels at output.
375 |         label_dim           = 0,            # Number of class labels, 0 = unconditional.
376 |         augment_dim         = 0,            # Augmentation label dimensionality, 0 = no augmentation.
377 | 
378 |         model_channels      = 128,          # Base multiplier for the number of channels.
379 |         channel_mult        = [1,2,2,2],    # Per-resolution multipliers for the number of channels.
380 |         channel_mult_emb    = 4,            # Multiplier for the dimensionality of the embedding vector.
381 |         num_blocks          = 4,            # Number of residual blocks per resolution.
382 |         attn_resolutions    = [16],         # List of resolutions with self-attention.
383 |         dropout             = 0.10,         # Dropout probability of intermediate activations.
384 |         label_dropout       = 0,            # Dropout probability of class labels for classifier-free guidance.
385 | 
386 |         embedding_type      = 'positional', # Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++.
387 |         channel_mult_noise  = 1,            # Timestep embedding size: 1 for DDPM++, 2 for NCSN++.
388 |         encoder_type        = 'standard',   # Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++.
389 |         decoder_type        = 'standard',   # Decoder architecture: 'standard' for both DDPM++ and NCSN++.
390 |         resample_filter     = [1,1],        # Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
391 | s_embed=True, 392 | share_tsemb=True, 393 | embedding_kwargs = {}, 394 | **kwargs 395 | ): 396 | assert embedding_type in ['fourier', 'positional'] 397 | assert encoder_type in ['standard', 'skip', 'residual'] 398 | assert decoder_type in ['standard', 'skip'] 399 | 400 | super().__init__() 401 | self.label_dropout = label_dropout 402 | emb_channels = model_channels * channel_mult_emb 403 | noise_channels = model_channels * channel_mult_noise 404 | init = dict(init_mode='xavier_uniform') 405 | init_zero = dict(init_mode='xavier_uniform', init_weight=1e-5) 406 | init_attn = dict(init_mode='xavier_uniform', init_weight=np.sqrt(0.2)) 407 | block_kwargs = dict( 408 | emb_channels=emb_channels, num_heads=1, dropout=dropout, skip_scale=np.sqrt(0.5), eps=1e-6, 409 | resample_filter=resample_filter, resample_proj=True, adaptive_scale=False, 410 | init=init, init_zero=init_zero, init_attn=init_attn, 411 | ) 412 | 413 | # Mapping. 414 | self.map_noise = PositionalEmbedding(num_channels=noise_channels, endpoint=True) if embedding_type == 'positional' else FourierEmbedding(num_channels=noise_channels, **embedding_kwargs) 415 | self.map_label = Linear(in_features=label_dim, out_features=noise_channels, **init) if label_dim else None 416 | self.map_augment = Linear(in_features=augment_dim, out_features=noise_channels, bias=False, **init) if augment_dim else None 417 | self.map_layer0 = Linear(in_features=noise_channels, out_features=emb_channels, **init) 418 | self.map_layer1 = Linear(in_features=emb_channels, out_features=emb_channels, **init) 419 | 420 | self.s_embed = s_embed 421 | if s_embed: 422 | 423 | if embedding_type == "positional": 424 | self.map_noise_s = PositionalEmbedding( 425 | num_channels=noise_channels, endpoint=True 426 | ) 427 | elif embedding_type == "fourier": 428 | self.map_noise_s = FourierEmbedding( 429 | num_channels=noise_channels, **embedding_kwargs 430 | ) 431 | self.map_layer0_s = Linear( 432 | in_features=noise_channels, 433 | out_features=emb_channels, 434 | **init, 435 | ) 436 | self.map_layer1_s = Linear( 437 | in_features=emb_channels, out_features=emb_channels, **init 438 | ) 439 | # Encoder. 440 | self.enc = torch.nn.ModuleDict() 441 | cout = in_channels 442 | caux = in_channels 443 | for level, mult in enumerate(channel_mult): 444 | res = img_resolution >> level 445 | if level == 0: 446 | cin = cout 447 | cout = model_channels 448 | self.enc[f'{res}x{res}_conv'] = Conv2d(in_channels=cin, out_channels=cout, kernel=3, **init) 449 | else: 450 | self.enc[f'{res}x{res}_down'] = UNetBlock(in_channels=cout, out_channels=cout, down=True, **block_kwargs) 451 | if encoder_type == 'skip': 452 | self.enc[f'{res}x{res}_aux_down'] = Conv2d(in_channels=caux, out_channels=caux, kernel=0, down=True, resample_filter=resample_filter) 453 | self.enc[f'{res}x{res}_aux_skip'] = Conv2d(in_channels=caux, out_channels=cout, kernel=1, **init) 454 | if encoder_type == 'residual': 455 | self.enc[f'{res}x{res}_aux_residual'] = Conv2d(in_channels=caux, out_channels=cout, kernel=3, down=True, resample_filter=resample_filter, fused_resample=True, **init) 456 | caux = cout 457 | for idx in range(num_blocks): 458 | cin = cout 459 | cout = model_channels * mult 460 | attn = (res in attn_resolutions) 461 | self.enc[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs) 462 | skips = [block.out_channels for name, block in self.enc.items() if 'aux' not in name] 463 | 464 | # Decoder. 
465 | self.dec = torch.nn.ModuleDict() 466 | for level, mult in reversed(list(enumerate(channel_mult))): 467 | res = img_resolution >> level 468 | if level == len(channel_mult) - 1: 469 | self.dec[f'{res}x{res}_in0'] = UNetBlock(in_channels=cout, out_channels=cout, attention=True, **block_kwargs) 470 | self.dec[f'{res}x{res}_in1'] = UNetBlock(in_channels=cout, out_channels=cout, **block_kwargs) 471 | else: 472 | self.dec[f'{res}x{res}_up'] = UNetBlock(in_channels=cout, out_channels=cout, up=True, **block_kwargs) 473 | for idx in range(num_blocks + 1): 474 | cin = cout + skips.pop() 475 | cout = model_channels * mult 476 | attn = (idx == num_blocks and res in attn_resolutions) 477 | self.dec[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs) 478 | if decoder_type == 'skip' or level == 0: 479 | if decoder_type == 'skip' and level < len(channel_mult) - 1: 480 | self.dec[f'{res}x{res}_aux_up'] = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=0, up=True, resample_filter=resample_filter) 481 | self.dec[f'{res}x{res}_aux_norm'] = GroupNorm(num_channels=cout, eps=1e-6) 482 | self.dec[f'{res}x{res}_aux_conv'] = Conv2d(in_channels=cout, out_channels=out_channels, kernel=3, **init_zero) 483 | 484 | def forward(self,x, 485 | noise_labels_t, 486 | noise_labels_s=None, 487 | class_labels=None, 488 | augment_labels=None, ): 489 | # Mapping. 490 | emb = self.map_noise(noise_labels_t) 491 | emb = emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) # swap sin/cos 492 | if self.map_label is not None: 493 | tmp = class_labels 494 | if self.training and self.label_dropout: 495 | tmp = tmp * (torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout).to(tmp.dtype) 496 | emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features)) 497 | if self.map_augment is not None and augment_labels is not None: 498 | emb = emb + self.map_augment(augment_labels) 499 | emb = silu(self.map_layer0(emb)) 500 | emb = self.map_layer1(emb) 501 | 502 | 503 | if noise_labels_s is not None and self.s_embed: 504 | 505 | emb_s = self.map_noise_s(noise_labels_s) 506 | emb_s = ( 507 | emb_s.reshape(emb_s.shape[0], 2, -1).flip(1).reshape(*emb_s.shape) 508 | ) # swap sin/cos 509 | emb_s = silu(self.map_layer0_s(emb_s)) 510 | emb_s = self.map_layer1_s(emb_s) 511 | emb = emb + emb_s 512 | 513 | emb = silu(emb) 514 | 515 | # Encoder. 516 | skips = [] 517 | aux = x 518 | for name, block in self.enc.items(): 519 | if 'aux_down' in name: 520 | aux = block(aux) 521 | elif 'aux_skip' in name: 522 | x = skips[-1] = x + block(aux) 523 | elif 'aux_residual' in name: 524 | x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2) 525 | else: 526 | x = block(x, emb) if isinstance(block, UNetBlock) else block(x) 527 | skips.append(x) 528 | 529 | # Decoder. 530 | aux = None 531 | tmp = None 532 | for name, block in self.dec.items(): 533 | if 'aux_up' in name: 534 | aux = block(aux) 535 | elif 'aux_norm' in name: 536 | tmp = block(x) 537 | elif 'aux_conv' in name: 538 | tmp = block(silu(tmp)) 539 | aux = tmp if aux is None else tmp + aux 540 | else: 541 | if x.shape[1] != block.in_channels: 542 | x = torch.cat([x, skips.pop()], dim=1) 543 | x = block(x, emb) 544 | 545 | return aux 546 | 547 | 548 | 549 | 550 | # ---------------------------------------------------------------------------- 551 | # Reimplementation of the ADM architecture from the paper 552 | # "Diffusion Models Beat GANS on Image Synthesis". 
Equivalent to the
553 | # original implementation by Dhariwal and Nichol, available at
554 | # https://github.com/openai/guided-diffusion
555 | 
556 | 
557 | @persistence.persistent_class
558 | class DhariwalUNet(torch.nn.Module):
559 | 
560 |     def __init__(
561 |         self,
562 |         img_resolution,                 # Image resolution at input/output.
563 |         in_channels,                    # Number of color channels at input.
564 |         out_channels,                   # Number of color channels at output.
565 |         label_dim=0,                    # Number of class labels, 0 = unconditional.
566 |         augment_dim=0,                  # Augmentation label dimensionality, 0 = no augmentation.
567 |         model_channels=192,             # Base multiplier for the number of channels.
568 |         channel_mult=[
569 |             1,
570 |             2,
571 |             3,
572 |             4,
573 |         ],                              # Per-resolution multipliers for the number of channels.
574 |         channel_mult_emb=4,             # Multiplier for the dimensionality of the embedding vector.
575 |         num_blocks=3,                   # Number of residual blocks per resolution.
576 |         attn_resolutions=[32, 16, 8],   # List of resolutions with self-attention.
577 |         dropout=0.10,                   # Dropout probability of intermediate activations.
578 |         label_dropout=0,                # Dropout probability of class labels for classifier-free guidance.
579 |         s_embed=True,                   # Add an embedding of the second time input s.
580 |         **kwargs
581 |     ):
582 |         super().__init__()
583 |         self.label_dropout = label_dropout
584 |         emb_channels = model_channels * channel_mult_emb
585 |         init = dict(
586 |             init_mode="kaiming_uniform",
587 |             init_weight=np.sqrt(1 / 3),
588 |             init_bias=np.sqrt(1 / 3),
589 |         )
590 |         init_zero = dict(init_mode="kaiming_uniform", init_weight=0, init_bias=0)
591 |         block_kwargs = dict(
592 |             emb_channels=emb_channels,
593 |             channels_per_head=64,
594 |             dropout=dropout,
595 |             init=init,
596 |             init_zero=init_zero,
597 |         )
598 | 
599 |         # Mapping.
600 |         self.map_noise = PositionalEmbedding(num_channels=model_channels)
601 |         self.s_embed = s_embed
602 |         if s_embed:
603 |             self.map_noise_s = self.map_noise
604 | 
605 |             self.map_layer0_s = Linear(
606 |                 in_features=model_channels,
607 |                 out_features=emb_channels,
608 |                 **init,
609 |             )
610 |             self.map_layer1_s = Linear(
611 |                 in_features=emb_channels, out_features=emb_channels, **init
612 |             )
613 |         self.map_augment = (
614 |             Linear(
615 |                 in_features=augment_dim,
616 |                 out_features=model_channels,
617 |                 bias=False,
618 |                 **init_zero,
619 |             )
620 |             if augment_dim
621 |             else None
622 |         )
623 |         self.map_layer0 = Linear(
624 |             in_features=model_channels,
625 |             out_features=emb_channels,
626 |             **init,
627 |         )
628 |         self.map_layer1 = Linear(
629 |             in_features=emb_channels,
630 |             out_features=emb_channels,
631 |             **init,
632 |         )
633 |         self.map_label = (
634 |             Linear(
635 |                 in_features=label_dim,
636 |                 out_features=emb_channels,
637 |                 bias=False,
638 |                 init_mode="kaiming_normal",
639 |                 init_weight=np.sqrt(label_dim),
640 |             )
641 |             if label_dim
642 |             else None
643 |         )
644 | 
645 |         # Encoder.
646 | self.enc = torch.nn.ModuleDict() 647 | cout = in_channels 648 | for level, mult in enumerate(channel_mult): 649 | res = img_resolution >> level 650 | if level == 0: 651 | cin = cout 652 | cout = model_channels * mult 653 | self.enc[f"{res}x{res}_conv"] = Conv2d( 654 | in_channels=cin, out_channels=cout, kernel=3, **init 655 | ) 656 | else: 657 | self.enc[f"{res}x{res}_down"] = UNetBlock( 658 | in_channels=cout, 659 | out_channels=cout, 660 | down=True, 661 | **block_kwargs, 662 | ) 663 | for idx in range(num_blocks): 664 | cin = cout 665 | cout = model_channels * mult 666 | self.enc[f"{res}x{res}_block{idx}"] = UNetBlock( 667 | in_channels=cin, 668 | out_channels=cout, 669 | attention=(res in attn_resolutions), 670 | **block_kwargs, 671 | ) 672 | skips = [block.out_channels for block in self.enc.values()] 673 | 674 | # Decoder. 675 | self.dec = torch.nn.ModuleDict() 676 | for level, mult in reversed(list(enumerate(channel_mult))): 677 | res = img_resolution >> level 678 | if level == len(channel_mult) - 1: 679 | self.dec[f"{res}x{res}_in0"] = UNetBlock( 680 | in_channels=cout, 681 | out_channels=cout, 682 | attention=True, 683 | **block_kwargs, 684 | ) 685 | self.dec[f"{res}x{res}_in1"] = UNetBlock( 686 | in_channels=cout, 687 | out_channels=cout, 688 | **block_kwargs, 689 | ) 690 | else: 691 | self.dec[f"{res}x{res}_up"] = UNetBlock( 692 | in_channels=cout, 693 | out_channels=cout, 694 | up=True, 695 | **block_kwargs, 696 | ) 697 | for idx in range(num_blocks + 1): 698 | cin = cout + skips.pop() 699 | cout = model_channels * mult 700 | self.dec[f"{res}x{res}_block{idx}"] = UNetBlock( 701 | in_channels=cin, 702 | out_channels=cout, 703 | attention=(res in attn_resolutions), 704 | **block_kwargs, 705 | ) 706 | 707 | self.out_norm = GroupNorm(num_channels=cout) 708 | self.out_conv = Conv2d( 709 | in_channels=cout, out_channels=out_channels, kernel=3, **init_zero 710 | ) 711 | 712 | def forward( 713 | self, 714 | x, 715 | noise_labels_t, 716 | noise_labels_s=None, 717 | class_labels=None, 718 | augment_labels=None, 719 | ): 720 | 721 | # Mapping. 722 | emb = self.map_noise(noise_labels_t) 723 | if self.map_augment is not None and augment_labels is not None: 724 | emb = emb + self.map_augment(augment_labels) 725 | emb = silu(self.map_layer0(emb)) 726 | emb = self.map_layer1(emb) 727 | if self.map_label is not None: 728 | tmp = class_labels 729 | if self.training and self.label_dropout: 730 | tmp = tmp * ( 731 | torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout 732 | ).to(tmp.dtype) 733 | emb = emb + self.map_label(tmp) 734 | if noise_labels_s is not None and self.s_embed: 735 | 736 | emb_s = self.map_noise_s(noise_labels_s) 737 | emb_s = silu(self.map_layer0_s(emb_s)) 738 | emb_s = self.map_layer1_s(emb_s) 739 | emb = emb + emb_s 740 | 741 | emb = silu(emb) 742 | 743 | # Encoder. 744 | skips = [] 745 | for block in self.enc.values(): 746 | x = block(x, emb) if isinstance(block, UNetBlock) else block(x) 747 | skips.append(x) 748 | 749 | # Decoder. 750 | for block in self.dec.values(): 751 | if x.shape[1] != block.in_channels: 752 | x = torch.cat([x, skips.pop()], dim=1) 753 | x = block(x, emb) 754 | x = self.out_conv(silu(self.out_norm(x))) 755 | return x 756 | 757 | 758 | --------------------------------------------------------------------------------
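# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the repository): composing the
# pieces above. IMMPrecond from training/preconds.py wraps a backbone such as
# SongUNet and applies the "fm" schedule, for which alpha_t = 1 - t,
# sigma_t = t, and hence logsnr(t) = 2 * (log(1 - t) - log(t)). The
# resolution, channel counts, and noise levels below are arbitrary.

def _example_imm_precond_forward():
    import torch
    from training.preconds import IMMPrecond
    net = IMMPrecond(img_resolution=32, img_channels=3, model_type='SongUNet',
                     model_channels=32, num_blocks=1)
    x = torch.randn(2, 3, 32, 32) * net.sigma_data  # noisy input at level t
    t = torch.full([2, 1, 1, 1], 0.8)
    s = torch.full([2, 1, 1, 1], 0.6)
    out = net(x, t, s)  # one step from noise level t toward s
    assert out.shape == x.shape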