├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── teaser.png
├── configs
│   ├── cifar10_generate_images.yaml
│   └── im256_generate_images.yaml
├── dnnlib
│   ├── __init__.py
│   └── util.py
├── env.yml
├── generate_images.py
├── model_card.md
├── torch_utils
│   ├── __init__.py
│   ├── distributed.py
│   ├── misc.py
│   ├── persistence.py
│   └── training_stats.py
└── training
    ├── __init__.py
    ├── dit.py
    ├── encoders.py
    ├── preconds.py
    └── unets.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 |
163 | # Project-related
164 | datasets
165 |
166 | *-runs
167 | outputs/
168 |
169 | slurm*
170 | debug.sh
171 | *.out
172 |
173 |
174 | wandb/
175 |
176 | multirun/
177 |
178 | *.json
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial-ShareAlike 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
58 | Public License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License
63 | ("Public License"). To the extent this Public License may be
64 | interpreted as a contract, You are granted the Licensed Rights in
65 | consideration of Your acceptance of these terms and conditions, and the
66 | Licensor grants You such rights in consideration of benefits the
67 | Licensor receives from making the Licensed Material available under
68 | these terms and conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. BY-NC-SA Compatible License means a license listed at
88 | creativecommons.org/compatiblelicenses, approved by Creative
89 | Commons as essentially the equivalent of this Public License.
90 |
91 | d. Copyright and Similar Rights means copyright and/or similar rights
92 | closely related to copyright including, without limitation,
93 | performance, broadcast, sound recording, and Sui Generis Database
94 | Rights, without regard to how the rights are labeled or
95 | categorized. For purposes of this Public License, the rights
96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
97 | Rights.
98 |
99 | e. Effective Technological Measures means those measures that, in the
100 | absence of proper authority, may not be circumvented under laws
101 | fulfilling obligations under Article 11 of the WIPO Copyright
102 | Treaty adopted on December 20, 1996, and/or similar international
103 | agreements.
104 |
105 | f. Exceptions and Limitations means fair use, fair dealing, and/or
106 | any other exception or limitation to Copyright and Similar Rights
107 | that applies to Your use of the Licensed Material.
108 |
109 | g. License Elements means the license attributes listed in the name
110 | of a Creative Commons Public License. The License Elements of this
111 | Public License are Attribution, NonCommercial, and ShareAlike.
112 |
113 | h. Licensed Material means the artistic or literary work, database,
114 | or other material to which the Licensor applied this Public
115 | License.
116 |
117 | i. Licensed Rights means the rights granted to You subject to the
118 | terms and conditions of this Public License, which are limited to
119 | all Copyright and Similar Rights that apply to Your use of the
120 | Licensed Material and that the Licensor has authority to license.
121 |
122 | j. Licensor means the individual(s) or entity(ies) granting rights
123 | under this Public License.
124 |
125 | k. NonCommercial means not primarily intended for or directed towards
126 | commercial advantage or monetary compensation. For purposes of
127 | this Public License, the exchange of the Licensed Material for
128 | other material subject to Copyright and Similar Rights by digital
129 | file-sharing or similar means is NonCommercial provided there is
130 | no payment of monetary compensation in connection with the
131 | exchange.
132 |
133 | l. Share means to provide material to the public by any means or
134 | process that requires permission under the Licensed Rights, such
135 | as reproduction, public display, public performance, distribution,
136 | dissemination, communication, or importation, and to make material
137 | available to the public including in ways that members of the
138 | public may access the material from a place and at a time
139 | individually chosen by them.
140 |
141 | m. Sui Generis Database Rights means rights other than copyright
142 | resulting from Directive 96/9/EC of the European Parliament and of
143 | the Council of 11 March 1996 on the legal protection of databases,
144 | as amended and/or succeeded, as well as other essentially
145 | equivalent rights anywhere in the world.
146 |
147 | n. You means the individual or entity exercising the Licensed Rights
148 | under this Public License. Your has a corresponding meaning.
149 |
150 |
151 | Section 2 -- Scope.
152 |
153 | a. License grant.
154 |
155 | 1. Subject to the terms and conditions of this Public License,
156 | the Licensor hereby grants You a worldwide, royalty-free,
157 | non-sublicensable, non-exclusive, irrevocable license to
158 | exercise the Licensed Rights in the Licensed Material to:
159 |
160 | a. reproduce and Share the Licensed Material, in whole or
161 | in part, for NonCommercial purposes only; and
162 |
163 | b. produce, reproduce, and Share Adapted Material for
164 | NonCommercial purposes only.
165 |
166 | 2. Exceptions and Limitations. For the avoidance of doubt, where
167 | Exceptions and Limitations apply to Your use, this Public
168 | License does not apply, and You do not need to comply with
169 | its terms and conditions.
170 |
171 | 3. Term. The term of this Public License is specified in Section
172 | 6(a).
173 |
174 | 4. Media and formats; technical modifications allowed. The
175 | Licensor authorizes You to exercise the Licensed Rights in
176 | all media and formats whether now known or hereafter created,
177 | and to make technical modifications necessary to do so. The
178 | Licensor waives and/or agrees not to assert any right or
179 | authority to forbid You from making technical modifications
180 | necessary to exercise the Licensed Rights, including
181 | technical modifications necessary to circumvent Effective
182 | Technological Measures. For purposes of this Public License,
183 | simply making modifications authorized by this Section 2(a)
184 | (4) never produces Adapted Material.
185 |
186 | 5. Downstream recipients.
187 |
188 | a. Offer from the Licensor -- Licensed Material. Every
189 | recipient of the Licensed Material automatically
190 | receives an offer from the Licensor to exercise the
191 | Licensed Rights under the terms and conditions of this
192 | Public License.
193 |
194 | b. Additional offer from the Licensor -- Adapted Material.
195 | Every recipient of Adapted Material from You
196 | automatically receives an offer from the Licensor to
197 | exercise the Licensed Rights in the Adapted Material
198 | under the conditions of the Adapter's License You apply.
199 |
200 | c. No downstream restrictions. You may not offer or impose
201 | any additional or different terms or conditions on, or
202 | apply any Effective Technological Measures to, the
203 | Licensed Material if doing so restricts exercise of the
204 | Licensed Rights by any recipient of the Licensed
205 | Material.
206 |
207 | 6. No endorsement. Nothing in this Public License constitutes or
208 | may be construed as permission to assert or imply that You
209 | are, or that Your use of the Licensed Material is, connected
210 | with, or sponsored, endorsed, or granted official status by,
211 | the Licensor or others designated to receive attribution as
212 | provided in Section 3(a)(1)(A)(i).
213 |
214 | b. Other rights.
215 |
216 | 1. Moral rights, such as the right of integrity, are not
217 | licensed under this Public License, nor are publicity,
218 | privacy, and/or other similar personality rights; however, to
219 | the extent possible, the Licensor waives and/or agrees not to
220 | assert any such rights held by the Licensor to the limited
221 | extent necessary to allow You to exercise the Licensed
222 | Rights, but not otherwise.
223 |
224 | 2. Patent and trademark rights are not licensed under this
225 | Public License.
226 |
227 | 3. To the extent possible, the Licensor waives any right to
228 | collect royalties from You for the exercise of the Licensed
229 | Rights, whether directly or through a collecting society
230 | under any voluntary or waivable statutory or compulsory
231 | licensing scheme. In all other cases the Licensor expressly
232 | reserves any right to collect such royalties, including when
233 | the Licensed Material is used other than for NonCommercial
234 | purposes.
235 |
236 |
237 | Section 3 -- License Conditions.
238 |
239 | Your exercise of the Licensed Rights is expressly made subject to the
240 | following conditions.
241 |
242 | a. Attribution.
243 |
244 | 1. If You Share the Licensed Material (including in modified
245 | form), You must:
246 |
247 | a. retain the following if it is supplied by the Licensor
248 | with the Licensed Material:
249 |
250 | i. identification of the creator(s) of the Licensed
251 | Material and any others designated to receive
252 | attribution, in any reasonable manner requested by
253 | the Licensor (including by pseudonym if
254 | designated);
255 |
256 | ii. a copyright notice;
257 |
258 | iii. a notice that refers to this Public License;
259 |
260 | iv. a notice that refers to the disclaimer of
261 | warranties;
262 |
263 | v. a URI or hyperlink to the Licensed Material to the
264 | extent reasonably practicable;
265 |
266 | b. indicate if You modified the Licensed Material and
267 | retain an indication of any previous modifications; and
268 |
269 | c. indicate the Licensed Material is licensed under this
270 | Public License, and include the text of, or the URI or
271 | hyperlink to, this Public License.
272 |
273 | 2. You may satisfy the conditions in Section 3(a)(1) in any
274 | reasonable manner based on the medium, means, and context in
275 | which You Share the Licensed Material. For example, it may be
276 | reasonable to satisfy the conditions by providing a URI or
277 | hyperlink to a resource that includes the required
278 | information.
279 | 3. If requested by the Licensor, You must remove any of the
280 | information required by Section 3(a)(1)(A) to the extent
281 | reasonably practicable.
282 |
283 | b. ShareAlike.
284 |
285 | In addition to the conditions in Section 3(a), if You Share
286 | Adapted Material You produce, the following conditions also apply.
287 |
288 | 1. The Adapter's License You apply must be a Creative Commons
289 | license with the same License Elements, this version or
290 | later, or a BY-NC-SA Compatible License.
291 |
292 | 2. You must include the text of, or the URI or hyperlink to, the
293 | Adapter's License You apply. You may satisfy this condition
294 | in any reasonable manner based on the medium, means, and
295 | context in which You Share Adapted Material.
296 |
297 | 3. You may not offer or impose any additional or different terms
298 | or conditions on, or apply any Effective Technological
299 | Measures to, Adapted Material that restrict exercise of the
300 | rights granted under the Adapter's License You apply.
301 |
302 |
303 | Section 4 -- Sui Generis Database Rights.
304 |
305 | Where the Licensed Rights include Sui Generis Database Rights that
306 | apply to Your use of the Licensed Material:
307 |
308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
309 | to extract, reuse, reproduce, and Share all or a substantial
310 | portion of the contents of the database for NonCommercial purposes
311 | only;
312 |
313 | b. if You include all or a substantial portion of the database
314 | contents in a database in which You have Sui Generis Database
315 | Rights, then the database in which You have Sui Generis Database
316 | Rights (but not its individual contents) is Adapted Material,
317 | including for purposes of Section 3(b); and
318 |
319 | c. You must comply with the conditions in Section 3(a) if You Share
320 | all or a substantial portion of the contents of the database.
321 |
322 | For the avoidance of doubt, this Section 4 supplements and does not
323 | replace Your obligations under this Public License where the Licensed
324 | Rights include other Copyright and Similar Rights.
325 |
326 |
327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
328 |
329 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
339 |
340 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
349 |
350 | c. The disclaimer of warranties and limitation of liability provided
351 | above shall be interpreted in a manner that, to the extent
352 | possible, most closely approximates an absolute disclaimer and
353 | waiver of all liability.
354 |
355 |
356 | Section 6 -- Term and Termination.
357 |
358 | a. This Public License applies for the term of the Copyright and
359 | Similar Rights licensed here. However, if You fail to comply with
360 | this Public License, then Your rights under this Public License
361 | terminate automatically.
362 |
363 | b. Where Your right to use the Licensed Material has terminated under
364 | Section 6(a), it reinstates:
365 |
366 | 1. automatically as of the date the violation is cured, provided
367 | it is cured within 30 days of Your discovery of the
368 | violation; or
369 |
370 | 2. upon express reinstatement by the Licensor.
371 |
372 | For the avoidance of doubt, this Section 6(b) does not affect any
373 | right the Licensor may have to seek remedies for Your violations
374 | of this Public License.
375 |
376 | c. For the avoidance of doubt, the Licensor may also offer the
377 | Licensed Material under separate terms or conditions or stop
378 | distributing the Licensed Material at any time; however, doing so
379 | will not terminate this Public License.
380 |
381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
382 | License.
383 |
384 |
385 | Section 7 -- Other Terms and Conditions.
386 |
387 | a. The Licensor shall not be bound by any additional or different
388 | terms or conditions communicated by You unless expressly agreed.
389 |
390 | b. Any arrangements, understandings, or agreements regarding the
391 | Licensed Material not stated herein are separate from and
392 | independent of the terms and conditions of this Public License.
393 |
394 |
395 | Section 8 -- Interpretation.
396 |
397 | a. For the avoidance of doubt, this Public License does not, and
398 | shall not be interpreted to, reduce, limit, restrict, or impose
399 | conditions on any use of the Licensed Material that could lawfully
400 | be made without permission under this Public License.
401 |
402 | b. To the extent possible, if any provision of this Public License is
403 | deemed unenforceable, it shall be automatically reformed to the
404 | minimum extent necessary to make it enforceable. If the provision
405 | cannot be reformed, it shall be severed from this Public License
406 | without affecting the enforceability of the remaining terms and
407 | conditions.
408 |
409 | c. No term or condition of this Public License will be waived and no
410 | failure to comply consented to unless expressly agreed to by the
411 | Licensor.
412 |
413 | d. Nothing in this Public License constitutes or may be interpreted
414 | as a limitation upon, or waiver of, any privileges and immunities
415 | that apply to the Licensor or You, including from the legal
416 | processes of any jurisdiction or authority.
417 |
418 | =======================================================================
419 |
420 | Creative Commons is not a party to its public
421 | licenses. Notwithstanding, Creative Commons may elect to apply one of
422 | its public licenses to material it publishes and in those instances
423 | will be considered the "Licensor." The text of the Creative Commons
424 | public licenses is dedicated to the public domain under the CC0 Public
425 | Domain Dedication. Except for the limited purpose of indicating that
426 | material is shared under a Creative Commons public license or as
427 | otherwise permitted by the Creative Commons policies published at
428 | creativecommons.org/policies, Creative Commons does not authorize the
429 | use of the trademark "Creative Commons" or any other trademark or logo
430 | of Creative Commons without its prior written consent including,
431 | without limitation, in connection with any unauthorized modifications
432 | to any of its public licenses or any other arrangements,
433 | understandings, or agreements concerning use of licensed material. For
434 | the avoidance of doubt, this paragraph does not form part of the
435 | public licenses.
436 |
437 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Inductive Moment Matching
2 |
3 |
4 | Official Implementation of [Inductive Moment Matching](https://arxiv.org/abs/2503.07565)
5 |
6 | ¹Luma AI, ²Stanford University
7 |
8 | ![teaser](assets/teaser.png)
9 |
31 | Also check out our accompanying [position paper](https://arxiv.org/abs/2503.07154), which explains the motivation behind new generative paradigms and ways of designing them.
32 |
33 | # Dependencies
34 |
35 | To install all packages in this codebase along with their dependencies, run
36 | ```sh
37 | conda env create -f env.yml
38 | ```
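
This creates the `imm` environment defined in `env.yml`; activate it before running any of the scripts below:
```sh
conda activate imm
```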
39 |
40 | # Pre-trained models
41 |
42 | We provide pretrained checkpoints through our [repo](https://huggingface.co/lumaai/imm) on Hugging Face (a download example follows this list):
43 | * IMM on CIFAR-10: [cifar10.pt](https://huggingface.co/lumaai/imm/resolve/main/cifar10.pt).
44 | * IMM on ImageNet-256x256:
45 | 1. `t-s` is passed as second time embedding, trained with `a=2`: [imagenet256_ts_a2.pkl](https://huggingface.co/lumaai/imm/resolve/main/imagenet256_ts_a2.pkl).
46 | 2. `s` is passed as second time embedding directly, trained with `a=1`: [imagenet256_s_a1.pkl](https://huggingface.co/lumaai/imm/resolve/main/imagenet256_s_a1.pkl).
47 |
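For example, a checkpoint can be fetched directly from the URLs above; assuming you want the `s`-embedding ImageNet model in the current directory (shown with `wget`, but any HTTP client works):

```sh
wget https://huggingface.co/lumaai/imm/resolve/main/imagenet256_s_a1.pkl
```
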
48 | # Sampling
49 |
50 | The checkpoints can be tested via
51 | ```sh
52 | python generate_images.py --config-name=CONFIG_NAME eval.resume=CKPT_PATH REPLACEMENT_ARGS
53 | ```
54 | where `CONFIG_NAME` is `im256_generate_images.yaml` or `cifar10_generate_images.yaml` and `CKPT_PATH` is the path to your checkpoint. When loading `imagenet256_s_a1.pkl`, `REPLACEMENT_ARGS` must be `network.temb_type=identity`; otherwise, leave `REPLACEMENT_ARGS` empty.
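
For example, to sample from `imagenet256_s_a1.pkl` (assuming the checkpoint was downloaded to the current directory):
```sh
python generate_images.py --config-name=im256_generate_images.yaml eval.resume=imagenet256_s_a1.pkl network.temb_type=identity
```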
55 |
56 | # Checklist
57 |
58 | - [x] Add model weights and model definitions.
59 | - [x] Add inference scripts.
60 | - [ ] Add evaluation scripts.
61 | - [ ] Add training scripts.
62 |
63 | # Acknowledgements
64 |
65 | Some of the utility functions are based on [EDM](https://github.com/NVlabs/edm); those parts of the code are covered by [this license](https://github.com/NVlabs/edm/blob/main/LICENSE.txt).
66 |
67 | # Citation
68 |
69 | ```
70 | @article{zhou2025inductive,
71 | title={Inductive Moment Matching},
72 | author={Zhou, Linqi and Ermon, Stefano and Song, Jiaming},
73 | journal={arXiv preprint arXiv:2503.07565},
74 | year={2025}
75 | }
76 | ```
77 |
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lumalabs/imm/c9473e9395910d3dd03ffffb6b0a8b4fb760c3a1/assets/teaser.png
--------------------------------------------------------------------------------
/configs/cifar10_generate_images.yaml:
--------------------------------------------------------------------------------
1 |
2 | label_dim: 0 # unconditional
3 | resolution: 32 # image resolution (CIFAR-10 is modeled directly in pixel space)
4 | channels: 3
5 |
6 | encoder:
7 | class_name: training.encoders.StandardRGBEncoder
8 |
9 |
10 | network:
11 | class_name: training.preconds.IMMPrecond
12 |     # DDPM++ (SongUNet) backbone
13 | model_type: "SongUNet"
14 | embedding_type: "positional"
15 | encoder_type: "standard"
16 | decoder_type: "standard"
17 | channel_mult_noise: 1
18 | resample_filter: [1, 1]
19 | model_channels: 128
20 | channel_mult: [2, 2, 2]
21 | s_embed: true
22 | dropout: 0.2
23 |
24 | noise_schedule: fm
25 |
26 | f_type: simple_edm
27 | temb_type: identity
28 | time_scale: 1000
29 |
30 |
31 | eps: 0.006
32 | T: 0.994
33 |
34 |
35 | sampling:
36 | 1_step:
37 | name: pushforward_generator_fn
38 | mid_nt: null
39 |
40 | 2_steps:
41 | name: pushforward_generator_fn
42 |         mid_nt: [1.4] # intermediate noise level visited between T and eps
43 |
44 | eval:
45 | seed: 42
46 | batch_size: 256
47 | cudnn_benchmark: true
48 | resume: null
49 |
50 |
51 | hydra:
52 | output_subdir: null
53 | run:
54 | dir: .
--------------------------------------------------------------------------------
/configs/im256_generate_images.yaml:
--------------------------------------------------------------------------------
1 |
2 |
3 | label_dim: 1000
4 | resolution: 32 #latent resolution
5 | channels: 4
6 |
7 | dataloader:
8 | pin_memory: true
9 | num_workers: 1
10 | prefetch_factor: 2
11 |
12 | encoder:
13 | class_name: training.encoders.StabilityVAEEncoder
14 | vae_name: stabilityai/sd-vae-ft-ema
15 | final_std: 0.5
16 |     raw_mean: [0.86488, -0.27787343, 0.21616915, 0.3738409]
17 |     raw_std: [4.85503674, 5.31922414, 3.93725398, 3.9870003]
18 | use_fp16: true
19 |
20 |
21 | network:
22 | class_name: training.preconds.IMMPrecond
23 |
24 | model_type: "DiT_XL_2"
25 | s_embed: true
26 |
27 | noise_schedule: fm
28 |
29 |     # sample function
30 | f_type: euler_fm
31 | temb_type: stride
32 | time_scale: 1000
33 |
34 | sigma_data: 0.5
35 |
36 | eps: 0.
37 | T: 0.994
38 |
39 |
40 |
41 | sampling:
42 |
43 | 1_steps_cfg1.5_pushforward_uniform:
44 | name: pushforward_generator_fn
45 | discretization: uniform
46 | num_steps: 1
47 | cfg_scale: 1.5
48 |
49 |
50 | 2_steps_cfg1.5_pushforward_uniform:
51 | name: pushforward_generator_fn
52 | discretization: uniform
53 | num_steps: 2
54 | cfg_scale: 1.5
55 |
56 | 4_steps_cfg1.5_pushforward_uniform:
57 | name: pushforward_generator_fn
58 | discretization: uniform
59 | num_steps: 4
60 | cfg_scale: 1.5
61 |
62 | 8_steps_cfg1.5_pushforward_uniform:
63 | name: pushforward_generator_fn
64 | discretization: uniform
65 | num_steps: 8
66 | cfg_scale: 1.5
67 |
68 |
69 |
70 | eval:
71 | seed: 42
72 | batch_size: 256
73 | cudnn_benchmark: true
74 | resume: null
75 |
76 |
77 | hydra:
78 | output_subdir: null
79 | run:
80 | dir: .
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | from .util import EasyDict, make_cache_dir_path
9 |
--------------------------------------------------------------------------------
/dnnlib/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Miscellaneous utility classes and functions."""
9 |
10 | import ctypes
11 | import fnmatch
12 | import importlib
13 | import inspect
14 | import numpy as np
15 | import os
16 | import shutil
17 | import sys
18 | import types
19 | import io
20 | import pickle
21 | import re
22 | import requests
23 | import html
24 | import hashlib
25 | import glob
26 | import tempfile
27 | import urllib
28 | import urllib.request
29 | import uuid
30 |
31 | from distutils.util import strtobool
32 | from typing import Any, List, Tuple, Union, Optional
33 |
34 |
35 | # Util classes
36 | # ------------------------------------------------------------------------------------------
37 |
38 |
39 | class EasyDict(dict):
40 | """Convenience class that behaves like a dict but allows access with the attribute syntax."""
41 |
42 | def __getattr__(self, name: str) -> Any:
43 | try:
44 | return self[name]
45 | except KeyError:
46 | raise AttributeError(name)
47 |
48 | def __setattr__(self, name: str, value: Any) -> None:
49 | self[name] = value
50 |
51 | def __delattr__(self, name: str) -> None:
52 | del self[name]
53 |
54 |
55 | class Logger(object):
56 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
57 |
58 | def __init__(
59 | self,
60 | file_name: Optional[str] = None,
61 | file_mode: str = "w",
62 | should_flush: bool = True,
63 | ):
64 | self.file = None
65 |
66 | if file_name is not None:
67 | self.file = open(file_name, file_mode)
68 |
69 | self.should_flush = should_flush
70 | self.stdout = sys.stdout
71 | self.stderr = sys.stderr
72 |
73 | sys.stdout = self
74 | sys.stderr = self
75 |
76 | def __enter__(self) -> "Logger":
77 | return self
78 |
79 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
80 | self.close()
81 |
82 | def write(self, text: Union[str, bytes]) -> None:
83 | """Write text to stdout (and a file) and optionally flush."""
84 | if isinstance(text, bytes):
85 | text = text.decode()
86 | if (
87 | len(text) == 0
88 | ): # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
89 | return
90 |
91 | if self.file is not None:
92 | self.file.write(text)
93 |
94 | self.stdout.write(text)
95 |
96 | if self.should_flush:
97 | self.flush()
98 |
99 | def flush(self) -> None:
100 | """Flush written text to both stdout and a file, if open."""
101 | if self.file is not None:
102 | self.file.flush()
103 |
104 | self.stdout.flush()
105 |
106 | def close(self) -> None:
107 | """Flush, close possible files, and remove stdout/stderr mirroring."""
108 | self.flush()
109 |
110 | # if using multiple loggers, prevent closing in wrong order
111 | if sys.stdout is self:
112 | sys.stdout = self.stdout
113 | if sys.stderr is self:
114 | sys.stderr = self.stderr
115 |
116 | if self.file is not None:
117 | self.file.close()
118 | self.file = None
119 |
120 |
121 | # Cache directories
122 | # ------------------------------------------------------------------------------------------
123 |
124 | _dnnlib_cache_dir = None
125 |
126 |
127 | def set_cache_dir(path: str) -> None:
128 | global _dnnlib_cache_dir
129 | _dnnlib_cache_dir = path
130 |
131 |
132 | def make_cache_dir_path(*paths: str) -> str:
133 | if _dnnlib_cache_dir is not None:
134 | return os.path.join(_dnnlib_cache_dir, *paths)
135 | if "DNNLIB_CACHE_DIR" in os.environ:
136 | return os.path.join(os.environ["DNNLIB_CACHE_DIR"], *paths)
137 | if "HOME" in os.environ:
138 | return os.path.join(os.environ["HOME"], ".cache", "dnnlib", *paths)
139 | if "USERPROFILE" in os.environ:
140 | return os.path.join(os.environ["USERPROFILE"], ".cache", "dnnlib", *paths)
141 | return os.path.join(tempfile.gettempdir(), ".cache", "dnnlib", *paths)
142 |
143 |
144 | # Small util functions
145 | # ------------------------------------------------------------------------------------------
146 |
147 |
148 | def format_time(seconds: Union[int, float]) -> str:
149 |     """Convert seconds to a human-readable string with days, hours, minutes, and seconds."""
150 | s = int(np.rint(seconds))
151 |
152 | if s < 60:
153 | return "{0}s".format(s)
154 | elif s < 60 * 60:
155 | return "{0}m {1:02}s".format(s // 60, s % 60)
156 | elif s < 24 * 60 * 60:
157 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
158 | else:
159 | return "{0}d {1:02}h {2:02}m".format(
160 | s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60
161 | )
162 |
163 |
164 | def format_time_brief(seconds: Union[int, float]) -> str:
165 |     """Convert seconds to a human-readable string, using a briefer format than format_time()."""
166 | s = int(np.rint(seconds))
167 |
168 | if s < 60:
169 | return "{0}s".format(s)
170 | elif s < 60 * 60:
171 | return "{0}m {1:02}s".format(s // 60, s % 60)
172 | elif s < 24 * 60 * 60:
173 | return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
174 | else:
175 | return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
176 |
177 |
178 | def ask_yes_no(question: str) -> bool:
179 | """Ask the user the question until the user inputs a valid answer."""
180 | while True:
181 | try:
182 | print("{0} [y/n]".format(question))
183 | return strtobool(input().lower())
184 | except ValueError:
185 | pass
186 |
187 |
188 | def tuple_product(t: Tuple) -> Any:
189 | """Calculate the product of the tuple elements."""
190 | result = 1
191 |
192 | for v in t:
193 | result *= v
194 |
195 | return result
196 |
197 |
198 | _str_to_ctype = {
199 | "uint8": ctypes.c_ubyte,
200 | "uint16": ctypes.c_uint16,
201 | "uint32": ctypes.c_uint32,
202 | "uint64": ctypes.c_uint64,
203 | "int8": ctypes.c_byte,
204 | "int16": ctypes.c_int16,
205 | "int32": ctypes.c_int32,
206 | "int64": ctypes.c_int64,
207 | "float32": ctypes.c_float,
208 | "float64": ctypes.c_double,
209 | }
210 |
211 |
212 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
213 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
214 | type_str = None
215 |
216 | if isinstance(type_obj, str):
217 | type_str = type_obj
218 | elif hasattr(type_obj, "__name__"):
219 | type_str = type_obj.__name__
220 | elif hasattr(type_obj, "name"):
221 | type_str = type_obj.name
222 | else:
223 | raise RuntimeError("Cannot infer type name from input")
224 |
225 | assert type_str in _str_to_ctype.keys()
226 |
227 | my_dtype = np.dtype(type_str)
228 | my_ctype = _str_to_ctype[type_str]
229 |
230 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
231 |
232 | return my_dtype, my_ctype
233 |
234 |
235 | def is_pickleable(obj: Any) -> bool:
236 | try:
237 | with io.BytesIO() as stream:
238 | pickle.dump(obj, stream)
239 | return True
240 | except:
241 | return False
242 |
243 |
244 | # Functionality to import modules/objects by name, and call functions by name
245 | # ------------------------------------------------------------------------------------------
246 |
247 |
248 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
249 | """Searches for the underlying module behind the name to some python object.
250 | Returns the module and the object name (original name with module part removed)."""
251 |
252 | # allow convenience shorthands, substitute them by full names
253 | obj_name = re.sub("^np.", "numpy.", obj_name)
254 | obj_name = re.sub("^tf.", "tensorflow.", obj_name)
255 |
256 | # list alternatives for (module_name, local_obj_name)
257 | parts = obj_name.split(".")
258 | name_pairs = [
259 | (".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)
260 | ]
261 |
262 | # try each alternative in turn
263 | for module_name, local_obj_name in name_pairs:
264 | try:
265 | module = importlib.import_module(module_name) # may raise ImportError
266 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
267 | return module, local_obj_name
268 | except:
269 | pass
270 |
271 | # maybe some of the modules themselves contain errors?
272 | for module_name, _local_obj_name in name_pairs:
273 | try:
274 | importlib.import_module(module_name) # may raise ImportError
275 | except ImportError:
276 | if not str(sys.exc_info()[1]).startswith(
277 | "No module named '" + module_name + "'"
278 | ):
279 | raise
280 |
281 | # maybe the requested attribute is missing?
282 | for module_name, local_obj_name in name_pairs:
283 | try:
284 | module = importlib.import_module(module_name) # may raise ImportError
285 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
286 | except ImportError:
287 | pass
288 |
289 | # we are out of luck, but we have no idea why
290 | raise ImportError(obj_name)
291 |
292 |
293 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
294 | """Traverses the object name and returns the last (rightmost) python object."""
295 | if obj_name == "":
296 | return module
297 | obj = module
298 | for part in obj_name.split("."):
299 | obj = getattr(obj, part)
300 | return obj
301 |
302 |
303 | def get_obj_by_name(name: str) -> Any:
304 | """Finds the python object with the given name."""
305 | module, obj_name = get_module_from_obj_name(name)
306 | return get_obj_from_module(module, obj_name)
307 |
308 |
309 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
310 | """Finds the python object with the given name and calls it as a function."""
311 | assert func_name is not None
312 | func_obj = get_obj_by_name(func_name)
313 | assert callable(func_obj)
314 | return func_obj(*args, **kwargs)
315 |
316 |
317 | def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
318 | """Finds the python class with the given name and constructs it with the given arguments."""
319 | return call_func_by_name(*args, func_name=class_name, **kwargs)
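#
# For example (class name taken from this repo's configs), generate_images.py
# builds the CIFAR-10 encoder via:
#
#   encoder = construct_class_by_name(class_name='training.encoders.StandardRGBEncoder')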
320 |
321 |
322 | def get_module_dir_by_obj_name(obj_name: str) -> str:
323 | """Get the directory path of the module containing the given object name."""
324 | module, _ = get_module_from_obj_name(obj_name)
325 | return os.path.dirname(inspect.getfile(module))
326 |
327 |
328 | def is_top_level_function(obj: Any) -> bool:
329 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
330 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
331 |
332 |
333 | def get_top_level_function_name(obj: Any) -> str:
334 | """Return the fully-qualified name of a top-level function."""
335 | assert is_top_level_function(obj)
336 | module = obj.__module__
337 | if module == "__main__":
338 | module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
339 | return module + "." + obj.__name__
340 |
341 |
342 | # File system helpers
343 | # ------------------------------------------------------------------------------------------
344 |
345 |
346 | def list_dir_recursively_with_ignore(
347 | dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False
348 | ) -> List[Tuple[str, str]]:
349 | """List all files recursively in a given directory while ignoring given file and directory names.
350 | Returns list of tuples containing both absolute and relative paths."""
351 | assert os.path.isdir(dir_path)
352 | base_name = os.path.basename(os.path.normpath(dir_path))
353 |
354 | if ignores is None:
355 | ignores = []
356 |
357 | result = []
358 |
359 | for root, dirs, files in os.walk(dir_path, topdown=True):
360 | for ignore_ in ignores:
361 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
362 |
363 | # dirs need to be edited in-place
364 | for d in dirs_to_remove:
365 | dirs.remove(d)
366 |
367 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
368 |
369 | absolute_paths = [os.path.join(root, f) for f in files]
370 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
371 |
372 | if add_base_to_relative:
373 | relative_paths = [os.path.join(base_name, p) for p in relative_paths]
374 |
375 | assert len(absolute_paths) == len(relative_paths)
376 | result += zip(absolute_paths, relative_paths)
377 |
378 | return result
379 |
380 |
381 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
382 | """Takes in a list of tuples of (src, dst) paths and copies files.
383 | Will create all necessary directories."""
384 | for file in files:
385 | target_dir_name = os.path.dirname(file[1])
386 |
387 | # will create all intermediate-level directories
388 | if not os.path.exists(target_dir_name):
389 | os.makedirs(target_dir_name)
390 |
391 | shutil.copyfile(file[0], file[1])
392 |
393 |
394 | # URL helpers
395 | # ------------------------------------------------------------------------------------------
396 |
397 |
398 | def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
399 | """Determine whether the given object is a valid URL string."""
400 | if not isinstance(obj, str) or not "://" in obj:
401 | return False
402 | if allow_file_urls and obj.startswith("file://"):
403 | return True
404 | try:
405 | res = requests.compat.urlparse(obj)
406 | if not res.scheme or not res.netloc or not "." in res.netloc:
407 | return False
408 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
409 | if not res.scheme or not res.netloc or not "." in res.netloc:
410 | return False
411 | except:
412 | return False
413 | return True
414 |
415 |
416 | def open_url(
417 | url: str,
418 | cache_dir: str = None,
419 | num_attempts: int = 10,
420 | verbose: bool = True,
421 | return_filename: bool = False,
422 | cache: bool = True,
423 | ) -> Any:
424 | """Download the given URL and return a binary-mode file object to access the data."""
425 | assert num_attempts >= 1
426 | assert not (return_filename and (not cache))
427 |
428 |     # Doesn't look like a URL scheme, so interpret it as a local filename.
429 | if not re.match("^[a-z]+://", url):
430 | return url if return_filename else open(url, "rb")
431 |
432 | # Handle file URLs. This code handles unusual file:// patterns that
433 | # arise on Windows:
434 | #
435 | # file:///c:/foo.txt
436 | #
437 | # which would translate to a local '/c:/foo.txt' filename that's
438 | # invalid. Drop the forward slash for such pathnames.
439 | #
440 | # If you touch this code path, you should test it on both Linux and
441 | # Windows.
442 | #
443 | # Some internet resources suggest using urllib.request.url2pathname() but
444 |     # that converts forward slashes to backslashes and this causes
445 | # its own set of problems.
446 | if url.startswith("file://"):
447 | filename = urllib.parse.urlparse(url).path
448 | if re.match(r"^/[a-zA-Z]:", filename):
449 | filename = filename[1:]
450 | return filename if return_filename else open(filename, "rb")
451 |
452 | assert is_url(url)
453 |
454 | # Lookup from cache.
455 | if cache_dir is None:
456 | cache_dir = make_cache_dir_path("downloads")
457 |
458 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
459 | if cache:
460 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
461 | if len(cache_files) == 1:
462 | filename = cache_files[0]
463 | return filename if return_filename else open(filename, "rb")
464 |
465 | # Download.
466 | url_name = None
467 | url_data = None
468 | with requests.Session() as session:
469 | if verbose:
470 | print("Downloading %s ..." % url, end="", flush=True)
471 | for attempts_left in reversed(range(num_attempts)):
472 | try:
473 | with session.get(url) as res:
474 | res.raise_for_status()
475 | if len(res.content) == 0:
476 | raise IOError("No data received")
477 |
478 | if len(res.content) < 8192:
479 | content_str = res.content.decode("utf-8")
480 | if "download_warning" in res.headers.get("Set-Cookie", ""):
481 | links = [
482 | html.unescape(link)
483 | for link in content_str.split('"')
484 | if "export=download" in link
485 | ]
486 | if len(links) == 1:
487 | url = requests.compat.urljoin(url, links[0])
488 | raise IOError("Google Drive virus checker nag")
489 | if "Google Drive - Quota exceeded" in content_str:
490 | raise IOError(
491 | "Google Drive download quota exceeded -- please try again later"
492 | )
493 |
494 | match = re.search(
495 | r'filename="([^"]*)"',
496 | res.headers.get("Content-Disposition", ""),
497 | )
498 | url_name = match[1] if match else url
499 | url_data = res.content
500 | if verbose:
501 | print(" done")
502 | break
503 | except KeyboardInterrupt:
504 | raise
505 | except:
506 | if not attempts_left:
507 | if verbose:
508 | print(" failed")
509 | raise
510 | if verbose:
511 | print(".", end="", flush=True)
512 |
513 | # Save to cache.
514 | if cache:
515 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
516 | safe_name = safe_name[: min(len(safe_name), 128)]
517 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
518 | temp_file = os.path.join(
519 | cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name
520 | )
521 | os.makedirs(cache_dir, exist_ok=True)
522 | with open(temp_file, "wb") as f:
523 | f.write(url_data)
524 | os.replace(temp_file, cache_file) # atomic
525 | if return_filename:
526 | return cache_file
527 |
528 | # Return data as file object.
529 | assert not return_filename
530 | return io.BytesIO(url_data)
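#
# Typical use, as in generate_images.py (accepts both local paths and URLs):
#
#   with open_url(ckpt_path_or_url, verbose=True) as f:  # ckpt_path_or_url: placeholder
#       data = pickle.load(f)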
531 |
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: imm
2 | channels:
3 | - pytorch
4 | - nvidia
5 | dependencies:
6 | - python==3.9.18
7 | - pip
8 | - click
9 | - requests
10 | - pillow
11 | - numpy
12 | - scipy
13 | - psutil
14 | - tqdm
15 | - imageio
16 | - pytorch=2.5.1
17 | - pytorch-cuda=12.1
18 | - pip:
19 | - einops
20 | - matplotlib
21 | - seaborn
22 | - wandb
23 | - timm==1.0.8
24 | - imageio-ffmpeg
25 | - pyspng
26 | - omegaconf==2.3.0
27 | - hydra-core==1.3.2
28 | - diffusers==0.31.0
29 |
30 |
--------------------------------------------------------------------------------
/generate_images.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import json
4 |
5 | import pickle
6 | import functools
7 | import numpy as np
8 |
9 | import torch
10 | import dnnlib
11 |
12 | import torchvision.utils as vutils
13 | import warnings
14 |
15 | from omegaconf import OmegaConf
16 | from torch_utils import misc
17 | import hydra
18 |
19 | warnings.filterwarnings(
20 | "ignore", "Grad strides do not match bucket view strides"
21 | ) # False warning printed by PyTorch 1.12.
22 |
23 |
24 |
25 |
26 | # ----------------------------------------------------------------------------
27 |
28 | def generator_fn(*args, name='pushforward_generator_fn', **kwargs):
29 | return globals()[name](*args, **kwargs)
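# Dispatch by name: each entry under `sampling:` in the configs supplies `name`
# (e.g. pushforward_generator_fn) plus keyword arguments for one of the
# generator functions defined below.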
30 |
31 |
32 |
33 |
34 | @torch.no_grad()
35 | def pushforward_generator_fn(net, latents, class_labels=None, discretization=None, mid_nt=None, num_steps=None, cfg_scale=None):
36 | # Time step discretization.
37 | if discretization == 'uniform':
38 | t_steps = torch.linspace(net.T, net.eps, num_steps+1, dtype=torch.float64, device=latents.device)
39 |     elif discretization == 'edm':  # rho-spaced noise levels, as in EDM (Karras et al., 2022)
40 |         nt_min = net.get_log_nt(torch.as_tensor(net.eps, dtype=torch.float64)).exp().item()
41 |         nt_max = net.get_log_nt(torch.as_tensor(net.T, dtype=torch.float64)).exp().item()
42 |         rho = 7
43 |         step_indices = torch.arange(num_steps+1, dtype=torch.float64, device=latents.device)
44 |         nt_steps = (nt_max ** (1 / rho) + step_indices / num_steps * (nt_min ** (1 / rho) - nt_max ** (1 / rho))) ** rho
45 |         t_steps = net.nt_to_t(nt_steps)
46 | else:
47 | if mid_nt is None:
48 | mid_nt = []
49 | mid_t = [net.nt_to_t(torch.as_tensor(nt)).item() for nt in mid_nt]
50 | t_steps = torch.tensor(
51 | [net.T] + list(mid_t), dtype=torch.float64, device=latents.device
52 | )
53 |         # t_0 = T, t_N = eps
54 |         t_steps = torch.cat([t_steps, torch.ones_like(t_steps[:1]) * net.eps])
55 |
56 | # Sampling steps
57 | x = latents.to(torch.float64)
58 |
59 |     for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
60 |         # One network evaluation pushes x forward from t_cur to t_next.
61 |         x = net.cfg_forward(x, t_cur, t_next, class_labels=class_labels, cfg_scale=cfg_scale).to(
62 |             torch.float64
63 |         )
64 |
65 |
66 |     return x
67 |
68 | @torch.no_grad()
69 | def restart_generator_fn(net, latents, class_labels=None, discretization=None, mid_nt=None, num_steps=None, cfg_scale=None):
70 | # Time step discretization.
71 | if discretization == 'uniform':
72 | t_steps = torch.linspace(net.T, net.eps, num_steps+1, dtype=torch.float64, device=latents.device)[:-1]
73 | elif discretization == 'edm':
74 | nt_min = net.get_log_nt(torch.as_tensor(net.eps, dtype=torch.float64)).exp().item()
75 | nt_max = net.get_log_nt(torch.as_tensor(net.T, dtype=torch.float64)).exp().item()
76 | rho = 7
77 | step_indices = torch.arange(num_steps+1, dtype=torch.float64, device=latents.device)
78 | nt_steps = (nt_max ** (1 / rho) + step_indices / (num_steps) * (nt_min ** (1 / rho) - nt_max ** (1 / rho))) ** rho
79 | t_steps = net.nt_to_t(nt_steps)[:-1]
80 | else:
81 | if mid_nt is None:
82 | mid_nt = []
83 | mid_t = [net.nt_to_t(torch.as_tensor(nt)).item() for nt in mid_nt]
84 | t_steps = torch.tensor(
85 | [net.T] + list(mid_t), dtype=torch.float64, device=latents.device
86 | )
87 | # Sampling steps
88 | x = latents.to(torch.float64)
89 |
90 |     for i, t_cur in enumerate(t_steps):
91 |         # Jump from t_cur all the way to eps in a single evaluation,
92 |         # then (below) re-noise back up to the next time step: restart sampling.
93 |         x = net.cfg_forward(x, t_cur, torch.ones_like(t_cur) * net.eps, class_labels=class_labels, cfg_scale=cfg_scale).to(
94 |             torch.float64
95 |         )
96 |
97 |         if i < len(t_steps) - 1:
98 |             x, _ = net.add_noise(x, t_steps[i+1])
99 |
100 | return x
101 |
102 |
103 |
104 | # ----------------------------------------------------------------------------
105 |
106 | @hydra.main(version_base=None, config_path="configs")
107 | def main(cfg):
108 |
109 | device = torch.device("cuda")
110 | config = OmegaConf.create(OmegaConf.to_yaml(cfg, resolve=True))
111 |
112 | # Random seed.
113 | if config.eval.seed is None:
114 |
115 |         seed = torch.randint(1 << 31, size=[], device=device)
116 |         if torch.distributed.is_initialized(): torch.distributed.broadcast(seed, src=0)  # only sync when running distributed
117 |         config.eval.seed = int(seed)
118 |
119 | # Checkpoint to evaluate.
120 | resume_pkl = cfg.eval.resume
121 | cudnn_benchmark = config.eval.cudnn_benchmark
122 | seed = config.eval.seed
123 | encoder_kwargs = config.encoder
124 |
125 | batch_size = config.eval.batch_size
126 | sample_kwargs_dict = config.get('sampling', {})
127 | # Initialize.
128 | np.random.seed(seed % (1 << 31))
129 | torch.manual_seed(np.random.randint(1 << 31))
130 | torch.backends.cudnn.benchmark = cudnn_benchmark
131 | torch.backends.cudnn.allow_tf32 = True
132 | torch.backends.cuda.matmul.allow_tf32 = True
133 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
134 |
135 | print('Setting up encoder...')
136 | encoder = dnnlib.util.construct_class_by_name(**encoder_kwargs)
137 |
138 | # Construct network.
139 | print("Constructing network...")
140 |
141 | interface_kwargs = dict(
142 | img_resolution=config.resolution,
143 | img_channels=config.channels,
144 | label_dim=config.label_dim,
145 | )
146 | if config.get('network', None) is not None:
147 | network_kwargs = config.network
148 | net = dnnlib.util.construct_class_by_name(
149 | **network_kwargs, **interface_kwargs
150 | ) # subclass of torch.nn.Module
151 | net.eval().requires_grad_(False).to(device)
152 |
153 |     # Load network weights from the checkpoint.
154 | with dnnlib.util.open_url(resume_pkl, verbose=True) as f:
155 | data = pickle.load(f)
156 |
157 | if config.get('network', None) is not None:
158 | misc.copy_params_and_buffers(
159 | src_module=data['ema'], dst_module=net, require_all=True
160 | )
161 | else:
162 | net = data['ema'].eval().requires_grad_(False).to(device)
163 |
164 |
165 | grid_z = net.get_init_noise(
166 | [batch_size, net.img_channels, net.img_resolution, net.img_resolution],
167 | device,
168 | )
169 | if net.label_dim > 0:
170 | labels = torch.randint(0, net.label_dim, (batch_size,), device=device)
171 | grid_c = torch.nn.functional.one_hot(labels, num_classes=net.label_dim)
172 | else:
173 | grid_c = None
174 |
175 | # Few-step Evaluation.
176 | generator_fn_dict = {k: functools.partial(generator_fn, **sample_kwargs) for k, sample_kwargs in sample_kwargs_dict.items()}
177 | print("Sample images...")
178 | res = {}
179 | for key, gen_fn in generator_fn_dict.items():
180 | images = gen_fn(net, grid_z, grid_c)
181 |         images = encoder.decode(images.to(device)).detach().cpu()
182 |
183 | vutils.save_image(
184 | images / 255.,
185 | os.path.join(f"{key}_samples.png"),
186 | nrow=int(np.sqrt(images.shape[0])),
187 | normalize=False,
188 | )
189 |
190 | res[key] = images
191 |
192 | print('done.')
193 |
194 | # ----------------------------------------------------------------------------
195 |
196 | if __name__ == "__main__":
197 | main()
198 |
199 | # ----------------------------------------------------------------------------
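# Example launch (illustrative; flags follow standard Hydra conventions, and
# the config names come from ./configs):
#
#   python generate_images.py --config-name im256_generate_images \
#       eval.resume=<path-or-url-to-checkpoint.pkl>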
200 |
--------------------------------------------------------------------------------
/model_card.md:
--------------------------------------------------------------------------------
1 | # Model Card
2 |
3 | These are Inductive Moment Matching (IMM) models described in the paper [Inductive Moment Matching](https://arxiv.org/abs/2503.07565).
4 |
5 | We provide the following pretrained checkpoints through our [repo](https://huggingface.co/lumaai/imm) on Hugging Face:
6 | * IMM on CIFAR-10: [cifar10.pt](https://huggingface.co/lumaai/imm/resolve/main/cifar10.pt).
7 | * IMM on ImageNet-256x256:
8 |    1. `t-s` is passed as the second time embedding, trained with `a=2`: [imagenet256_ts_a2.pkl](https://huggingface.co/lumaai/imm/resolve/main/imagenet256_ts_a2.pkl).
9 |    2. `s` is passed as the second time embedding directly, trained with `a=1`: [imagenet256_s_a1.pkl](https://huggingface.co/lumaai/imm/resolve/main/imagenet256_s_a1.pkl).
10 |
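As a point of reference, here is a minimal sketch of loading one of these checkpoints with the utilities in this repository (mirroring `generate_images.py`, which reads the EMA weights from the `'ema'` key; `dnnlib` and `torch_utils` must be importable for unpickling):

```python
import pickle
import dnnlib  # from this repository

url = "https://huggingface.co/lumaai/imm/resolve/main/imagenet256_ts_a2.pkl"
with dnnlib.util.open_url(url) as f:
    data = pickle.load(f)
net = data["ema"].eval().requires_grad_(False)  # EMA weights, ready for sampling
```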
11 |
12 | ## Intended Use
13 |
14 | This model is provided exclusively for research purposes. Acceptable uses include:
15 |
16 | - Academic research on generative modeling techniques
17 | - Benchmarking against other generative models
18 | - Educational purposes to understand Inductive Moment Matching algorithms
19 | - Exploration of model capabilities in controlled research environments
20 |
21 | Prohibited Uses:
22 |
23 | - Any commercial applications or commercial product development
24 | - Integration into products or services offered to customers
25 | - Generation of content for commercial distribution
26 | - Any applications that could result in harm, including but not limited to:
27 | - Creating deceptive or misleading content
28 | - Generating harmful, offensive, or discriminatory outputs
29 | - Circumventing security systems
30 | - Creating deepfakes or other potentially harmful synthetic media
31 | - Any use case that could negatively impact individuals or society
32 |
33 | ## Limitations
34 |
35 | The IMM models have several limitations common to image generation models:
36 |
37 | - Limited Resolution: The models are trained at fixed resolutions (32x32 for CIFAR-10 and 256x256 for ImageNet), and generating images at significantly higher resolutions may result in quality degradation or artifacts.
38 | - Computational Resources: Training and inference require substantial computational resources, which may limit their practical applications in resource-constrained environments.
39 | - Training Data Limitations: The models are trained on specific datasets (CIFAR-10 and ImageNet) and may not generalize well to other domains or data distributions.
40 | - Generalization to Unseen Data: As is common for generative models, sample quality may degrade on conditions or content far from the training distribution.
41 |
42 |
43 |
--------------------------------------------------------------------------------
/torch_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/torch_utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from . import training_stats
4 | import torch.distributed as dist
5 | import datetime
6 |
7 | # ----------------------------------------------------------------------------
8 |
9 |
10 | def init():
11 | if "MASTER_ADDR" not in os.environ:
12 | os.environ["MASTER_ADDR"] = "localhost"
13 | if "MASTER_PORT" not in os.environ:
14 | os.environ["MASTER_PORT"] = "29500"
15 | if "RANK" not in os.environ:
16 | os.environ["RANK"] = "0"
17 | if "LOCAL_RANK" not in os.environ:
18 | os.environ["LOCAL_RANK"] = "0"
19 | if "WORLD_SIZE" not in os.environ:
20 | os.environ["WORLD_SIZE"] = "1"
21 |
22 |     # NCCL/EFA tuning for the authors' cluster. Note that these are set
23 |     # unconditionally and override any values already in the environment.
24 |     os.environ["NCCL_SOCKET_IFNAME"] = "enp"
25 |     os.environ["FI_EFA_SET_CUDA_SYNC_MEMOPS"] = "0"
26 |     os.environ["NCCL_BUFFSIZE"] = "8388608"
27 |     os.environ["NCCL_P2P_NET_CHUNKSIZE"] = "524288"
28 |     os.environ['TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC'] = '1200'
29 |     os.environ['TORCH_NCCL_ENABLE_MONITORING'] = '0'
31 |
32 | backend = "gloo" if os.name == "nt" else "nccl"
33 | torch.distributed.init_process_group(backend=backend, init_method="env://", timeout=datetime.timedelta(minutes=120),)
34 | torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
35 |
36 | sync_device = torch.device("cuda") if get_world_size() > 1 else None
37 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device)
38 |
39 |
40 | # ----------------------------------------------------------------------------
41 |
42 |
43 | def get_rank():
44 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
45 |
46 |
47 | # ----------------------------------------------------------------------------
48 |
49 |
50 | def get_world_size():
51 | return (
52 | torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
53 | )
54 |
55 |
56 | # ----------------------------------------------------------------------------
57 |
58 |
59 | def should_stop():
60 | return False
61 |
62 |
63 | # ----------------------------------------------------------------------------
64 |
65 |
66 | def update_progress(cur, total):
67 | _ = cur, total
68 |
69 |
70 | # ----------------------------------------------------------------------------
71 |
72 |
73 | def print0(*args, **kwargs):
74 | if get_rank() == 0:
75 | print(*args, **kwargs)
76 |
77 |
78 | # ----------------------------------------------------------------------------
79 |
80 |
81 |
82 | broadcast = dist.broadcast
83 | new_group = dist.new_group
84 | barrier = dist.barrier
85 | all_gather = dist.all_gather
86 | send = dist.send
87 | recv = dist.recv
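# Example (illustrative): launch the entry script with torchrun, which populates
# RANK / LOCAL_RANK / WORLD_SIZE in the environment before init() reads them:
#
#   torchrun --nproc_per_node=8 generate_images.py ...
#
# Call init() once at startup, then use print0(...) to log from rank 0 only.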
--------------------------------------------------------------------------------
/torch_utils/misc.py:
--------------------------------------------------------------------------------
1 | import re
2 | import contextlib
3 | import numpy as np
4 | import torch
5 | import warnings
6 | import dnnlib
7 | import functools
8 | from . import persistence
9 |
10 | # ----------------------------------------------------------------------------
11 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
12 | # same constant is used multiple times.
13 |
14 | _constant_cache = dict()
15 |
16 |
17 | def constant(value, shape=None, dtype=None, device=None, memory_format=None):
18 | value = np.asarray(value)
19 | if shape is not None:
20 | shape = tuple(shape)
21 | if dtype is None:
22 | dtype = torch.get_default_dtype()
23 | if device is None:
24 | device = torch.device("cpu")
25 | if memory_format is None:
26 | memory_format = torch.contiguous_format
27 |
28 | key = (
29 | value.shape,
30 | value.dtype,
31 | value.tobytes(),
32 | shape,
33 | dtype,
34 | device,
35 | memory_format,
36 | )
37 | tensor = _constant_cache.get(key, None)
38 | if tensor is None:
39 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
40 | if shape is not None:
41 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
42 | tensor = tensor.contiguous(memory_format=memory_format)
43 | _constant_cache[key] = tensor
44 | return tensor
45 |
46 | #----------------------------------------------------------------------------
47 | # Variant of constant() that inherits dtype and device from the given
48 | # reference tensor by default.
49 |
50 | def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
51 | if dtype is None:
52 | dtype = ref.dtype
53 | if device is None:
54 | device = ref.device
55 | return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
56 |
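# Example (sketch): const_like() is convenient inside model code where the
# target dtype/device should follow an activation tensor `x`:
#
#   half = const_like(x, 0.5)        # scalar 0.5 on x's device, in x's dtype
#   eye = const_like(x, np.eye(4))   # 4x4 identity, likewise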
57 | #----------------------------------------------------------------------------
58 | # Cached construction of temporary tensors in pinned CPU memory.
59 |
60 | @functools.lru_cache(None)
61 | def pinned_buf(shape, dtype):
62 | return torch.empty(shape, dtype=dtype).pin_memory()
63 |
64 | #----------------------------------------------------------------------------
65 | # Symbolic assert.
66 |
67 | try:
68 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
69 | except AttributeError:
70 | symbolic_assert = torch.Assert # 1.7.0
71 |
72 |
73 | # ----------------------------------------------------------------------------
74 | # Replace NaN/Inf with specified numerical values.
75 |
76 | try:
77 | nan_to_num = torch.nan_to_num # 1.8.0a0
78 | except AttributeError:
79 |
80 | def nan_to_num(
81 | input, nan=0.0, posinf=None, neginf=None, *, out=None
82 | ): # pylint: disable=redefined-builtin
83 | assert isinstance(input, torch.Tensor)
84 | if posinf is None:
85 | posinf = torch.finfo(input.dtype).max
86 | if neginf is None:
87 | neginf = torch.finfo(input.dtype).min
88 | assert nan == 0
89 | return torch.clamp(
90 | input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out
91 | )
92 |
93 |
94 | # ----------------------------------------------------------------------------
95 | # Symbolic assert.
96 |
97 | try:
98 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
99 | except AttributeError:
100 | symbolic_assert = torch.Assert # 1.7.0
101 |
102 | # ----------------------------------------------------------------------------
103 | # Context manager to temporarily suppress known warnings in torch.jit.trace().
104 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
105 |
106 |
107 | @contextlib.contextmanager
108 | def suppress_tracer_warnings():
109 | flt = ("ignore", None, torch.jit.TracerWarning, None, 0)
110 | warnings.filters.insert(0, flt)
111 | yield
112 | warnings.filters.remove(flt)
113 |
114 |
115 | # ----------------------------------------------------------------------------
116 | # Assert that the shape of a tensor matches the given list of integers.
117 | # None indicates that the size of a dimension is allowed to vary.
118 | # Performs symbolic assertion when used in torch.jit.trace().
119 |
120 |
121 | def assert_shape(tensor, ref_shape):
122 | if tensor.ndim != len(ref_shape):
123 | raise AssertionError(
124 | f"Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}"
125 | )
126 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
127 | if ref_size is None:
128 | pass
129 | elif isinstance(ref_size, torch.Tensor):
130 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
131 | symbolic_assert(
132 | torch.equal(torch.as_tensor(size), ref_size),
133 | f"Wrong size for dimension {idx}",
134 | )
135 | elif isinstance(size, torch.Tensor):
136 | with suppress_tracer_warnings(): # as_tensor results are registered as constants
137 | symbolic_assert(
138 | torch.equal(size, torch.as_tensor(ref_size)),
139 | f"Wrong size for dimension {idx}: expected {ref_size}",
140 | )
141 | elif size != ref_size:
142 | raise AssertionError(
143 | f"Wrong size for dimension {idx}: got {size}, expected {ref_size}"
144 | )
145 |
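# Example (sketch): allow a variable batch dimension while pinning the rest.
#
#   assert_shape(images, [None, 3, 256, 256])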
146 |
147 | # ----------------------------------------------------------------------------
148 | # Function decorator that calls torch.autograd.profiler.record_function().
149 |
150 |
151 | def profiled_function(fn):
152 | def decorator(*args, **kwargs):
153 | with torch.autograd.profiler.record_function(fn.__name__):
154 | return fn(*args, **kwargs)
155 |
156 | decorator.__name__ = fn.__name__
157 | return decorator
158 |
159 |
160 | # ----------------------------------------------------------------------------
161 | # Sampler for torch.utils.data.DataLoader that loops over the dataset
162 | # indefinitely, shuffling items as it goes.
163 |
164 |
165 | class InfiniteSampler(torch.utils.data.Sampler):
166 | def __init__(
167 | self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5
168 | ):
169 | assert len(dataset) > 0
170 | assert num_replicas > 0
171 | assert 0 <= rank < num_replicas
172 | assert 0 <= window_size <= 1
173 | super().__init__(dataset)
174 | self.dataset = dataset
175 | self.rank = rank
176 | self.num_replicas = num_replicas
177 | self.shuffle = shuffle
178 | self.seed = seed
179 | self.window_size = window_size
180 |
181 | def __iter__(self):
182 | order = np.arange(len(self.dataset))
183 | rnd = None
184 | window = 0
185 | if self.shuffle:
186 | rnd = np.random.RandomState(self.seed)
187 | rnd.shuffle(order)
188 | window = int(np.rint(order.size * self.window_size))
189 |
190 | idx = 0
191 | while True:
192 | i = idx % order.size
193 | if idx % self.num_replicas == self.rank:
194 | yield order[i]
195 | if window >= 2:
196 | j = (i - rnd.randint(window)) % order.size
197 | order[i], order[j] = order[j], order[i]
198 | idx += 1
199 |
200 |
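# Example (sketch): endless shuffled loading; `dataset` is any map-style torch
# dataset, and the iterator below never raises StopIteration.
#
#   sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, seed=0)
#   loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=64))
#   batch = next(loader)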
201 | # ----------------------------------------------------------------------------
202 | # Utilities for operating with torch.nn.Module parameters and buffers.
203 |
204 |
205 | def params_and_buffers(module):
206 | assert isinstance(module, torch.nn.Module)
207 | return list(module.parameters()) + list(module.buffers())
208 |
209 |
210 | def named_params_and_buffers(module):
211 | assert isinstance(module, torch.nn.Module)
212 | return list(module.named_parameters()) + list(module.named_buffers())
213 |
214 |
215 | @torch.no_grad()
216 | def copy_params_and_buffers(src_module, dst_module, require_all=False):
217 | assert isinstance(src_module, torch.nn.Module)
218 | assert isinstance(dst_module, torch.nn.Module)
219 | src_tensors = dict(named_params_and_buffers(src_module))
220 | for name, tensor in named_params_and_buffers(dst_module):
221 | assert (name in src_tensors) or (not require_all)
222 | if name in src_tensors:
223 | tensor.copy_(src_tensors[name])
224 |
225 |
226 | # ----------------------------------------------------------------------------
227 | # Context manager for easily enabling/disabling DistributedDataParallel
228 | # synchronization.
229 |
230 |
231 | @contextlib.contextmanager
232 | def ddp_sync(module, sync):
233 | assert isinstance(module, torch.nn.Module)
234 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
235 | yield
236 | else:
237 | with module.no_sync():
238 | yield
239 |
240 |
241 | # ----------------------------------------------------------------------------
242 | # Check DistributedDataParallel consistency across processes.
243 |
244 |
245 | def check_ddp_consistency(module, ignore_regex=None):
246 | assert isinstance(module, torch.nn.Module)
247 | for name, tensor in named_params_and_buffers(module):
248 | fullname = type(module).__name__ + "." + name
249 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
250 | continue
251 | tensor = tensor.detach()
252 | if tensor.is_floating_point():
253 | tensor = nan_to_num(tensor)
254 | other = tensor.clone()
255 | torch.distributed.broadcast(tensor=other, src=0)
256 | assert (tensor == other).all(), fullname
257 |
258 |
259 | # ----------------------------------------------------------------------------
260 | # Print summary table of module hierarchy.
261 |
262 |
263 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
264 | assert isinstance(module, torch.nn.Module)
265 | assert not isinstance(module, torch.jit.ScriptModule)
266 | assert isinstance(inputs, (tuple, list))
267 |
268 | # Register hooks.
269 | entries = []
270 | nesting = [0]
271 |
272 | def pre_hook(_mod, _inputs):
273 | nesting[0] += 1
274 |
275 | def post_hook(mod, _inputs, outputs):
276 | nesting[0] -= 1
277 | if nesting[0] <= max_nesting:
278 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
279 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
280 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
281 |
282 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
283 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
284 |
285 | # Run module.
286 | outputs = module(*inputs)
287 | for hook in hooks:
288 | hook.remove()
289 |
290 | # Identify unique outputs, parameters, and buffers.
291 | tensors_seen = set()
292 | for e in entries:
293 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
294 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
295 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
296 | tensors_seen |= {
297 | id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs
298 | }
299 |
300 | # Filter out redundant entries.
301 | if skip_redundant:
302 | entries = [
303 | e
304 | for e in entries
305 | if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)
306 | ]
307 |
308 | # Construct table.
309 | rows = [
310 | [type(module).__name__, "Parameters", "Buffers", "Output shape", "Datatype"]
311 | ]
312 | rows += [["---"] * len(rows[0])]
313 | param_total = 0
314 | buffer_total = 0
315 | submodule_names = {mod: name for name, mod in module.named_modules()}
316 | for e in entries:
317 | name = "" if e.mod is module else submodule_names[e.mod]
318 | param_size = sum(t.numel() for t in e.unique_params)
319 | buffer_size = sum(t.numel() for t in e.unique_buffers)
320 | output_shapes = [str(list(t.shape)) for t in e.outputs]
321 | output_dtypes = [str(t.dtype).split(".")[-1] for t in e.outputs]
322 | rows += [
323 | [
324 | name + (":0" if len(e.outputs) >= 2 else ""),
325 | str(param_size) if param_size else "-",
326 | str(buffer_size) if buffer_size else "-",
327 | (output_shapes + ["-"])[0],
328 | (output_dtypes + ["-"])[0],
329 | ]
330 | ]
331 | for idx in range(1, len(e.outputs)):
332 | rows += [
333 | [name + f":{idx}", "-", "-", output_shapes[idx], output_dtypes[idx]]
334 | ]
335 | param_total += param_size
336 | buffer_total += buffer_size
337 | rows += [["---"] * len(rows[0])]
338 | rows += [["Total", str(param_total), str(buffer_total), "-", "-"]]
339 |
340 | # Print table.
341 | widths = [max(len(cell) for cell in column) for column in zip(*rows)]
342 | print()
343 | for row in rows:
344 | print(
345 | " ".join(
346 | cell + " " * (width - len(cell)) for cell, width in zip(row, widths)
347 | )
348 | )
349 | print()
350 | return outputs
351 |
352 |
353 | # ----------------------------------------------------------------------------
354 |
355 |
356 |
357 | import abc
358 | @persistence.persistent_class
359 | class ActivationHook(abc.ABC):
360 | def __init__(self, modules_to_watch):
361 | self.modules_to_watch = modules_to_watch
362 |
363 | self._hook_result = {}
364 |
365 | @property
366 | def hook_result(self):
367 | return self._hook_result
368 |
369 | @abc.abstractmethod
370 | def __call__(self, module, input, output):
371 | pass
372 |
373 |     def watch(self, models_dict):
374 |         # Tag each module of interest with a `_hook_name` so that __call__
375 |         # can tell which activations it is seeing.
376 |         acc = []
377 |         for k, model in models_dict.items():
378 |             for name, module in model.named_modules():
379 |                 if self.modules_to_watch == 'all':
380 |                     module._hook_name = name
381 |                 else:
382 |                     for mw in self.modules_to_watch:
383 |                         if mw in name and name not in acc:
384 |                             module._hook_name = k + '.' + name
385 |                             acc.append(name)
386 |
389 | def clear(self):
390 | self._hook_result = {}
391 |
392 |
393 | @persistence.persistent_class
394 | class ActivationMagnitudeHook(ActivationHook):
395 | def __init__(self, modules_to_watch='all'):
396 | super().__init__(modules_to_watch)
397 |
398 |     def __call__(self, module, input, output):
399 |         if hasattr(module, '_hook_name'):
400 |             # Only track modules registered via watch().
401 |             key = 'activations/' + module._hook_name + '_magnitude_div_10000'
402 |             if isinstance(output, torch.Tensor):
403 |                 # Divide by 1e4 before taking the norm to prevent overflow.
404 |                 self._hook_result[key] = (output.detach() / 10000).flatten(1).norm(1).mean().item()
405 |             else:
406 |                 self._hook_result[key] = 0
407 |
408 |
--------------------------------------------------------------------------------
/torch_utils/persistence.py:
--------------------------------------------------------------------------------
1 | """Facilities for pickling Python code alongside other data.
2 |
3 | The pickled code is automatically imported into a separate Python module
4 | during unpickling. This way, any previously exported pickles will remain
5 | usable even if the original code is no longer available, or if the current
6 | version of the code is not consistent with what was originally pickled."""
7 |
8 | import sys
9 | import pickle
10 | import io
11 | import inspect
12 | import copy
13 | import uuid
14 | import types
15 | import dnnlib
16 |
17 | #----------------------------------------------------------------------------
18 |
19 | _version = 6 # internal version number
20 | _decorators = set() # {decorator_class, ...}
21 | _import_hooks = [] # [hook_function, ...]
22 | _module_to_src_dict = dict() # {module: src, ...}
23 | _src_to_module_dict = dict() # {src: module, ...}
24 |
25 | #----------------------------------------------------------------------------
26 |
27 | def persistent_class(orig_class):
28 | r"""Class decorator that extends a given class to save its source code
29 | when pickled.
30 |
31 | Example:
32 |
33 | from torch_utils import persistence
34 |
35 | @persistence.persistent_class
36 | class MyNetwork(torch.nn.Module):
37 | def __init__(self, num_inputs, num_outputs):
38 | super().__init__()
39 | self.fc = MyLayer(num_inputs, num_outputs)
40 | ...
41 |
42 | @persistence.persistent_class
43 | class MyLayer(torch.nn.Module):
44 | ...
45 |
46 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its
47 | source code alongside other internal state (e.g., parameters, buffers,
48 | and submodules). This way, any previously exported pickle will remain
49 | usable even if the class definitions have been modified or are no
50 | longer available.
51 |
52 | The decorator saves the source code of the entire Python module
53 | containing the decorated class. It does *not* save the source code of
54 | any imported modules. Thus, the imported modules must be available
55 | during unpickling, also including `torch_utils.persistence` itself.
56 |
57 | It is ok to call functions defined in the same module from the
58 | decorated class. However, if the decorated class depends on other
59 | classes defined in the same module, they must be decorated as well.
60 | This is illustrated in the above example in the case of `MyLayer`.
61 |
62 | It is also possible to employ the decorator just-in-time before
63 | calling the constructor. For example:
64 |
65 | cls = MyLayer
66 | if want_to_make_it_persistent:
67 | cls = persistence.persistent_class(cls)
68 | layer = cls(num_inputs, num_outputs)
69 |
70 | As an additional feature, the decorator also keeps track of the
71 | arguments that were used to construct each instance of the decorated
72 | class. The arguments can be queried via `obj.init_args` and
73 | `obj.init_kwargs`, and they are automatically pickled alongside other
74 | object state. This feature can be disabled on a per-instance basis
75 | by setting `self._record_init_args = False` in the constructor.
76 |
77 | A typical use case is to first unpickle a previous instance of a
78 | persistent class, and then upgrade it to use the latest version of
79 | the source code:
80 |
81 | with open('old_pickle.pkl', 'rb') as f:
82 | old_net = pickle.load(f)
83 |         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
84 | misc.copy_params_and_buffers(old_net, new_net, require_all=True)
85 | """
86 | assert isinstance(orig_class, type)
87 | if is_persistent(orig_class):
88 | return orig_class
89 |
90 | assert orig_class.__module__ in sys.modules
91 | orig_module = sys.modules[orig_class.__module__]
92 | orig_module_src = _module_to_src(orig_module)
93 |
94 | class Decorator(orig_class):
95 | _orig_module_src = orig_module_src
96 | _orig_class_name = orig_class.__name__
97 |
98 | def __init__(self, *args, **kwargs):
99 | super().__init__(*args, **kwargs)
100 | record_init_args = getattr(self, '_record_init_args', True)
101 | self._init_args = copy.deepcopy(args) if record_init_args else None
102 | self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None
103 | assert orig_class.__name__ in orig_module.__dict__
104 | _check_pickleable(self.__reduce__())
105 |
106 | @property
107 | def init_args(self):
108 | assert self._init_args is not None
109 | return copy.deepcopy(self._init_args)
110 |
111 | @property
112 | def init_kwargs(self):
113 | assert self._init_kwargs is not None
114 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
115 |
116 | def __reduce__(self):
117 | fields = list(super().__reduce__())
118 | fields += [None] * max(3 - len(fields), 0)
119 | if fields[0] is not _reconstruct_persistent_obj:
120 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
121 | fields[0] = _reconstruct_persistent_obj # reconstruct func
122 | fields[1] = (meta,) # reconstruct args
123 | fields[2] = None # state dict
124 | return tuple(fields)
125 |
126 | Decorator.__name__ = orig_class.__name__
127 | Decorator.__module__ = orig_class.__module__
128 | _decorators.add(Decorator)
129 | return Decorator
130 |
131 | #----------------------------------------------------------------------------
132 |
133 | def is_persistent(obj):
134 | r"""Test whether the given object or class is persistent, i.e.,
135 | whether it will save its source code when pickled.
136 | """
137 | try:
138 | if obj in _decorators:
139 | return True
140 | except TypeError:
141 | pass
142 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
143 |
144 | #----------------------------------------------------------------------------
145 |
146 | def import_hook(hook):
147 | r"""Register an import hook that is called whenever a persistent object
148 | is being unpickled. A typical use case is to patch the pickled source
149 | code to avoid errors and inconsistencies when the API of some imported
150 | module has changed.
151 |
152 | The hook should have the following signature:
153 |
154 | hook(meta) -> modified meta
155 |
156 | `meta` is an instance of `dnnlib.EasyDict` with the following fields:
157 |
158 | type: Type of the persistent object, e.g. `'class'`.
159 | version: Internal version number of `torch_utils.persistence`.
160 |         module_src: Original source code of the Python module.
161 | class_name: Class name in the original Python module.
162 | state: Internal state of the object.
163 |
164 | Example:
165 |
166 | @persistence.import_hook
167 | def wreck_my_network(meta):
168 | if meta.class_name == 'MyNetwork':
169 | print('MyNetwork is being imported. I will wreck it!')
170 | meta.module_src = meta.module_src.replace("True", "False")
171 | return meta
172 | """
173 | assert callable(hook)
174 | _import_hooks.append(hook)
175 |
176 | #----------------------------------------------------------------------------
177 |
178 | def _reconstruct_persistent_obj(meta):
179 | r"""Hook that is called internally by the `pickle` module to unpickle
180 | a persistent object.
181 | """
182 | meta = dnnlib.EasyDict(meta)
183 |
184 | meta.state = dnnlib.EasyDict(meta.state) if meta.state is not None else dnnlib.EasyDict()
185 | for hook in _import_hooks:
186 | meta = hook(meta)
187 | assert meta is not None
188 |
189 | assert meta.version == _version
190 | module = _src_to_module(meta.module_src)
191 |
192 | assert meta.type == 'class'
193 | orig_class = module.__dict__[meta.class_name]
194 | decorator_class = persistent_class(orig_class)
195 | obj = decorator_class.__new__(decorator_class)
196 |
197 | setstate = getattr(obj, '__setstate__', None)
198 | if callable(setstate):
199 | setstate(meta.state) # pylint: disable=not-callable
200 | else:
201 | obj.__dict__.update(meta.state)
202 |
203 | return obj
204 |
205 | #----------------------------------------------------------------------------
206 |
207 | def _module_to_src(module):
208 | r"""Query the source code of a given Python module.
209 | """
210 | src = _module_to_src_dict.get(module, None)
211 | if src is None:
212 | src = inspect.getsource(module)
213 | _module_to_src_dict[module] = src
214 | _src_to_module_dict[src] = module
215 | return src
216 |
217 | def _src_to_module(src):
218 | r"""Get or create a Python module for the given source code.
219 | """
220 | module = _src_to_module_dict.get(src, None)
221 | if module is None:
222 | module_name = "_imported_module_" + uuid.uuid4().hex
223 | module = types.ModuleType(module_name)
224 | sys.modules[module_name] = module
225 | _module_to_src_dict[module] = src
226 | _src_to_module_dict[src] = module
227 | exec(src, module.__dict__) # pylint: disable=exec-used
228 | return module
229 |
230 | #----------------------------------------------------------------------------
231 |
232 | def _check_pickleable(obj):
233 | r"""Check that the given object is pickleable, raising an exception if
234 | it is not. This function is expected to be considerably more efficient
235 | than actually pickling the object.
236 | """
237 | def recurse(obj):
238 | if isinstance(obj, (list, tuple, set)):
239 | return [recurse(x) for x in obj]
240 | if isinstance(obj, dict):
241 | return [[recurse(x), recurse(y)] for x, y in obj.items()]
242 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
243 | return None # Python primitive types are pickleable.
244 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']:
245 | return None # NumPy arrays and PyTorch tensors are pickleable.
246 | if is_persistent(obj):
247 | return None # Persistent objects are pickleable, by virtue of the constructor check.
248 | return obj
249 | with io.BytesIO() as f:
250 | pickle.dump(recurse(obj), f)
251 |
252 | #----------------------------------------------------------------------------
253 |
--------------------------------------------------------------------------------
/torch_utils/training_stats.py:
--------------------------------------------------------------------------------
1 | """Facilities for reporting and collecting training statistics across
2 | multiple processes and devices. The interface is designed to minimize
3 | synchronization overhead as well as the amount of boilerplate in user
4 | code."""
5 |
6 | import re
7 | import numpy as np
8 | import torch
9 | import torch.distributed
10 | import dnnlib
11 |
12 | from collections import defaultdict
13 |
14 | from . import misc
15 |
16 | # ----------------------------------------------------------------------------
17 |
18 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
19 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction.
20 | _counter_dtype = torch.float64 # Data type to use for the internal counters.
21 | _rank = 0 # Rank of the current process.
22 | _sync_device = (
23 | None # Device to use for multiprocess communication. None = single-process.
24 | )
25 | _sync_called = False # Has _sync() been called yet?
26 | _counters = (
27 | dict()
28 | ) # Running counters on each device, updated by report(): name => device => torch.Tensor
29 | _cumulative = (
30 | dict()
31 | ) # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
32 |
33 | # ----------------------------------------------------------------------------
34 |
35 |
36 | def init_multiprocessing(rank, sync_device):
37 | r"""Initializes `torch_utils.training_stats` for collecting statistics
38 | across multiple processes.
39 |
40 | This function must be called after
41 | `torch.distributed.init_process_group()` and before `Collector.update()`.
42 | The call is not necessary if multi-process collection is not needed.
43 |
44 | Args:
45 | rank: Rank of the current process.
46 | sync_device: PyTorch device to use for inter-process
47 | communication, or None to disable multi-process
48 | collection. Typically `torch.device('cuda', rank)`.
49 | """
50 | global _rank, _sync_device
51 | assert not _sync_called
52 | _rank = rank
53 | _sync_device = sync_device
54 |
55 |
56 | # ----------------------------------------------------------------------------
57 |
58 |
59 | @misc.profiled_function
60 | def report(name, value, ts=None, max_t=1, num_bins=4):
61 | r"""Broadcasts the given set of scalars to all interested instances of
62 | `Collector`, across device and process boundaries.
63 |
64 | This function is expected to be extremely cheap and can be safely
65 | called from anywhere in the training loop, loss function, or inside a
66 | `torch.nn.Module`.
67 |
68 | Warning: The current implementation expects the set of unique names to
69 | be consistent across processes. Please make sure that `report()` is
70 | called at least once for each unique name by each process, and in the
71 | same order. If a given process has no scalars to broadcast, it can do
72 | `report(name, [])` (empty list).
73 |
74 |     Args:
75 |         name:     Arbitrary string specifying the name of the statistic.
76 |                   Averages are accumulated separately for each unique name.
77 |         value:    Arbitrary set of scalars. Can be a list, tuple,
78 |                   NumPy array, PyTorch tensor, or Python scalar.
79 |         ts:       Optional tensor of per-scalar times. If given, each scalar
80 |                   is bucketed into one of `num_bins` bins over [0, max_t] and
81 |                   reported under the name `f"{name}_q{bin}"` instead of `name`.
82 |         max_t:    Upper end of the time range used for bucketing (default: 1).
83 |         num_bins: Number of time buckets (default: 4).
79 |
80 | Returns:
81 | The same `value` that was passed in.
82 | """
83 | value_in = value
84 | quantiles = {f"{name}_q{quartile}": [] for quartile in range(num_bins)}
85 | if ts is not None:
86 | for sub_t, sub_loss in zip(ts.cpu().numpy(), value):
87 | if isinstance(sub_loss, torch.Tensor):
88 | sub_loss = sub_loss.detach().cpu().numpy()
89 |
90 | quartile = int(num_bins * min(sub_t, max_t-1e-3) / max_t)
91 |
92 | quantiles[f"{name}_q{quartile}"].append(sub_loss.item())
93 |
94 | else:
95 | quantiles[name] = value
96 |
97 | for name, value in quantiles.items():
98 | if name not in _counters:
99 | _counters[name] = dict()
100 |
101 |         elems = torch.as_tensor(value)
102 |         if elems.numel() == 0:
103 |             continue  # nothing to record; the entry created above keeps the name registered
104 |
105 |         elems = elems.detach().flatten().to(_reduce_dtype)
106 |         moments = torch.stack(
107 |             [
108 |                 torch.ones_like(elems).sum(),
109 |                 elems.sum(),
110 |                 elems.square().sum(),
111 |             ]
112 |         )
113 |         assert moments.ndim == 1 and moments.shape[0] == _num_moments
114 |         moments = moments.to(_counter_dtype)
125 |
126 | device = moments.device
127 | if device not in _counters[name]:
128 | _counters[name][device] = torch.zeros_like(moments)
129 | _counters[name][device].add_(moments)
130 | return value_in
131 |
132 |
133 | # ----------------------------------------------------------------------------
134 |
135 |
136 | def report0(name, value):
137 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
138 | but ignores any scalars provided by the other processes.
139 | See `report()` for further details.
140 | """
141 |
142 | report(name, value if _rank == 0 else [])
143 | return value
144 |
145 |
146 | # ----------------------------------------------------------------------------
147 |
148 |
149 | class Collector:
150 | r"""Collects the scalars broadcasted by `report()` and `report0()` and
151 | computes their long-term averages (mean and standard deviation) over
152 | user-defined periods of time.
153 |
154 | The averages are first collected into internal counters that are not
155 | directly visible to the user. They are then copied to the user-visible
156 | state as a result of calling `update()` and can then be queried using
157 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
158 | internal counters for the next round, so that the user-visible state
159 | effectively reflects averages collected between the last two calls to
160 | `update()`.
161 |
162 | Args:
163 | regex: Regular expression defining which statistics to
164 | collect. The default is to collect everything.
165 | keep_previous: Whether to retain the previous averages if no
166 | scalars were collected on a given round
167 | (default: True).
168 | """
169 |
170 | def __init__(self, regex=".*", keep_previous=True):
171 | self._regex = re.compile(regex)
172 | self._keep_previous = keep_previous
173 | self._cumulative = dict()
174 | self._moments = dict()
175 | self._moments.clear()
176 |
177 | def names(self):
178 | r"""Returns the names of all statistics broadcasted so far that
179 | match the regular expression specified at construction time.
180 | """
181 | return [name for name in _counters if self._regex.fullmatch(name)]
182 |
183 | def update(
184 | self,
185 | disable_sync=False
186 | ):
187 | r"""Copies current values of the internal counters to the
188 | user-visible state and resets them for the next round.
189 |
190 | If `keep_previous=True` was specified at construction time, the
191 | operation is skipped for statistics that have received no scalars
192 | since the last update, retaining their previous averages.
193 |
194 | This method performs a number of GPU-to-CPU transfers and one
195 | `torch.distributed.all_reduce()`. It is intended to be called
196 | periodically in the main training loop, typically once every
197 | N training steps.
198 | """
199 | if not self._keep_previous:
200 | self._moments.clear()
201 |
202 | for name, cumulative in _sync(self.names(), disable=disable_sync):
203 | if name not in self._cumulative:
204 | self._cumulative[name] = torch.zeros(
205 | [_num_moments], dtype=_counter_dtype
206 | )
207 | delta = cumulative - self._cumulative[name]
208 | self._cumulative[name].copy_(cumulative)
209 | if float(delta[0]) != 0:
210 | self._moments[name] = delta
211 |
212 | def _get_delta(self, name):
213 | r"""Returns the raw moments that were accumulated for the given
214 | statistic between the last two calls to `update()`, or zero if
215 | no scalars were collected.
216 | """
217 | assert self._regex.fullmatch(name)
218 | if name not in self._moments:
219 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
220 | return self._moments[name]
221 |
222 | def num(self, name):
223 | r"""Returns the number of scalars that were accumulated for the given
224 | statistic between the last two calls to `update()`, or zero if
225 | no scalars were collected.
226 | """
227 | delta = self._get_delta(name)
228 | return int(delta[0])
229 |
230 | def mean(self, name):
231 | r"""Returns the mean of the scalars that were accumulated for the
232 | given statistic between the last two calls to `update()`, or NaN if
233 | no scalars were collected.
234 | """
235 | delta = self._get_delta(name)
236 | if int(delta[0]) == 0:
237 | return float("nan")
238 | return float(delta[1] / delta[0])
239 |
240 | def std(self, name):
241 | r"""Returns the standard deviation of the scalars that were
242 | accumulated for the given statistic between the last two calls to
243 | `update()`, or NaN if no scalars were collected.
244 | """
245 | delta = self._get_delta(name)
246 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
247 | return float("nan")
248 | if int(delta[0]) == 1:
249 | return float(0)
250 | mean = float(delta[1] / delta[0])
251 | raw_var = float(delta[2] / delta[0])
252 | return np.sqrt(max(raw_var - np.square(mean), 0))
253 |
254 | def as_dict(self):
255 | r"""Returns the averages accumulated between the last two calls to
256 | `update()` as an `dnnlib.EasyDict`. The contents are as follows:
257 |
258 | dnnlib.EasyDict(
259 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
260 | ...
261 | )
262 | """
263 | stats = dnnlib.EasyDict()
264 | for name in self.names():
265 | stats[name] = dnnlib.EasyDict(
266 | mean=self.mean(name),
267 | )
268 | return stats
269 |
270 | def __getitem__(self, name):
271 | r"""Convenience getter.
272 | `collector[name]` is a synonym for `collector.mean(name)`.
273 | """
274 | return self.mean(name)
275 |
276 |
277 | # ----------------------------------------------------------------------------
278 |
279 |
280 | def _sync(names, disable=False):
281 | r"""Synchronize the global cumulative counters across devices and
282 | processes. Called internally by `Collector.update()`.
283 | """
284 | if len(names) == 0:
285 | return []
286 |
287 | global _sync_called
288 | _sync_called = True
289 |
290 | # Collect deltas within current rank.
291 | deltas = []
292 | device = _sync_device if _sync_device is not None else torch.device("cpu")
293 |
294 | for name in names:
295 |
296 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
297 | for counter in _counters[name].values():
298 | delta.add_(counter.to(device))
299 | counter.copy_(torch.zeros_like(counter))
300 | deltas.append(delta)
301 | deltas = torch.stack(deltas)
302 |
303 | # Sum deltas across ranks.
304 | if _sync_device is not None and not disable:
305 | torch.distributed.all_reduce(deltas)
306 | # Update cumulative values.
307 | deltas = deltas.cpu()
308 | for idx, name in enumerate(names):
309 | if name not in _cumulative:
310 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
311 | _cumulative[name].add_(deltas[idx])
312 |
313 | # Return name-value pairs.
314 | return [(name, _cumulative[name]) for name in names]
315 |
316 |
317 | # ----------------------------------------------------------------------------
318 | # Convenience.
319 |
320 | default_collector = Collector()
321 |
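# Example (sketch): report scalars anywhere in the training loop, then
# periodically fold them into the user-visible averages.
#
#   training_stats.report('Loss/total', loss)
#   ...
#   training_stats.default_collector.update()
#   print(training_stats.default_collector.mean('Loss/total'))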
322 | # ----------------------------------------------------------------------------
323 |
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | # empty
9 |
--------------------------------------------------------------------------------
/training/dit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # GLIDE: https://github.com/openai/glide-text2im
9 | # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10 | # --------------------------------------------------------
11 |
12 | import torch
13 | import torch.nn as nn
14 | import numpy as np
15 | import math
16 | import functools
17 | from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
18 |
19 | from torch_utils import persistence
20 | from einops import repeat
21 |
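# adaLN modulation: `shift` and `scale` are per-sample (N, D) vectors;
# unsqueeze(1) broadcasts them across the token dimension of x, which is (N, T, D).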
22 | def modulate(x, shift, scale):
23 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
24 |
25 | #################################################################################
26 | # Embedding Layers for Timesteps and Class Labels #
27 | #################################################################################
28 |
29 |
30 | @persistence.persistent_class
31 | class FourierEmbedding(torch.nn.Module):
32 | def __init__(self, num_channels, scale=16, **kwargs):
33 | super().__init__()
34 | print("FourierEmbedding scale:", scale)
35 | self.register_buffer("freqs", torch.randn(num_channels // 2) * scale)
36 |
37 | def forward(self, x):
38 | dtype = x.dtype
39 |         x = x.to(torch.float64).ger(2 * np.pi * self.freqs.to(torch.float64))
40 | x = torch.cat([x.cos(), x.sin()], dim=1).to(dtype)
41 | return x
42 |
43 | @persistence.persistent_class
44 | class TimestepEmbedder(nn.Module):
45 | """
46 | Embeds scalar timesteps into vector representations.
47 | """
48 |
49 | def __init__(self, hidden_size, frequency_embedding_size=256, embedding_type='positional', use_mlp=True, scale=1):
50 | super().__init__()
51 | self.use_mlp = use_mlp
52 | self.hidden_size = hidden_size
53 | if use_mlp:
54 | self.mlp = nn.Sequential(
55 | nn.Linear(frequency_embedding_size, hidden_size , bias=True),
56 | nn.SiLU(),
57 | nn.Linear(hidden_size , hidden_size, bias=True),
58 | )
59 | self.frequency_embedding_size = frequency_embedding_size
60 |
61 | self.embedding_type = embedding_type
62 |
63 | if self.embedding_type == 'fourier':
64 | self.register_buffer("freqs", torch.randn(frequency_embedding_size // 2) * scale)
65 |
66 |
67 | @staticmethod
68 | def positional_timestep_embedding(t, dim, max_period=10000 ):
69 | """
70 | Create sinusoidal timestep embeddings.
71 | :param t: a 1-D Tensor of N indices, one per batch element.
72 | These may be fractional.
73 | :param dim: the dimension of the output.
74 | :param max_period: controls the minimum frequency of the embeddings.
75 | :return: an (N, D) Tensor of positional embeddings.
76 | """
77 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
78 | half = dim // 2
79 | freqs = torch.exp(
80 | -math.log(max_period)
81 | * torch.arange(start=0, end=half, dtype=torch.float64)
82 | / half
83 | ).to(device=t.device)
84 | args = t[:, None].to(torch.float64) * freqs[None]
85 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
86 | if dim % 2:
87 | embedding = torch.cat(
88 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
89 | )
90 | return embedding
91 |
92 |     def fourier_timestep_embedding(self, t):
93 |         x = t.to(torch.float64).ger(2 * np.pi * self.freqs.to(torch.float64))
94 |         x = torch.cat([x.cos(), x.sin()], dim=1)
95 |         return x
97 |
98 |     def forward(self, t):
99 |         if self.embedding_type == 'positional':
100 |             t_freq = self.positional_timestep_embedding(t, self.frequency_embedding_size)
101 |         elif self.embedding_type == 'fourier':
102 |             t_freq = self.fourier_timestep_embedding(t)
103 |         else:
104 |             raise ValueError(f'Unknown embedding_type: {self.embedding_type}')
103 |
104 | if self.use_mlp:
105 |             t_emb = self.mlp(t_freq.to(dtype=t.dtype))
106 | else:
107 | t_emb = t_freq.to(dtype=t.dtype)
108 | return t_emb
109 |
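# Example (sketch): the positional variant maps N scalar times to N sinusoidal
# feature vectors before the optional MLP, e.g.
#
#   emb = TimestepEmbedder.positional_timestep_embedding(torch.tensor([0.1, 0.9]), 256)
#   # emb.shape == (2, 256)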
110 |
111 | @persistence.persistent_class
112 | class LabelEmbedder(nn.Module):
113 | """
114 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
115 | """
116 |
117 | def __init__(self, num_classes, hidden_size, dropout_prob):
118 | super().__init__()
119 | use_cfg_embedding = dropout_prob > 0
120 | self.embedding_table = nn.Embedding(
121 | num_classes + use_cfg_embedding, hidden_size
122 | )
123 | self.num_classes = num_classes
124 | self.dropout_prob = dropout_prob
125 |
126 | def token_drop(self, labels, force_drop_ids=None):
127 | """
128 | Drops labels to enable classifier-free guidance.
129 | """
130 | if force_drop_ids is None:
131 | drop_ids = (
132 | torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
133 | )
134 | else:
135 | drop_ids = force_drop_ids == 1
136 | labels = torch.where(drop_ids, self.num_classes, labels)
137 | return labels
138 |
139 | def forward(self, labels, train, force_drop_ids=None):
140 | use_dropout = self.dropout_prob > 0
141 | if (train and use_dropout) or (force_drop_ids is not None):
142 | labels = self.token_drop(labels, force_drop_ids)
143 | embeddings = self.embedding_table(labels)
144 | return embeddings
145 |
146 |
147 | #################################################################################
148 | # Core DiT Model #
149 | #################################################################################
150 |
151 |
152 |
153 |
154 | @persistence.persistent_class
155 | class DiTBlock(nn.Module):
156 | """
157 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
158 | """
159 |
160 | def __init__(self, hidden_size, num_heads, temb_size, mlp_ratio=4.0, skip=False, dropout=0, **block_kwargs):
161 | super().__init__()
162 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6 )
163 | self.attn = Attention(
164 | hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
165 | )
166 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False , eps=1e-6)
167 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
168 | approx_gelu = lambda: nn.GELU(approximate="tanh")
169 | self.mlp = Mlp(
170 | in_features=hidden_size,
171 | hidden_features=mlp_hidden_dim,
172 | act_layer=approx_gelu,
173 | drop=dropout,
174 | )
175 | self.adaLN_modulation = nn.Sequential(
176 | nn.SiLU(), nn.Linear(temb_size, 6 * hidden_size, bias=True)
177 | )
178 | self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) if skip else None
179 |
180 | def forward(self, x, c, ):
181 |
182 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
183 | (self.adaLN_modulation(c )).chunk(6, dim=1)
184 | )
185 |
186 | x = x + gate_msa.unsqueeze(1) * self.attn(
187 | modulate(self.norm1(x), shift_msa, scale_msa )
188 | )
189 | x = x + gate_mlp.unsqueeze(1) * self.mlp(
190 | modulate(self.norm2(x), shift_mlp, scale_mlp )
191 | )
192 |
193 | return x
194 |
195 |
196 | @persistence.persistent_class
197 | class FinalLayer(nn.Module):
198 | """
199 | The final layer of DiT.
200 | """
201 |
202 | def __init__(self, hidden_size, patch_size, out_channels):
203 | super().__init__()
204 | self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
205 | self.linear = nn.Linear(
206 | hidden_size, patch_size * patch_size * out_channels, bias=True
207 | )
208 | self.adaLN_modulation = nn.Sequential(
209 | nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
210 | )
211 |
212 | def forward(self, x, c):
213 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
214 | x = modulate(self.norm_final(x), shift, scale)
215 | x = self.linear(x)
216 | return x
217 |
218 |
219 | @persistence.persistent_class
220 | class DiT(nn.Module):
221 | """
222 | Diffusion model with a Transformer backbone.
223 | """
224 |
225 | def __init__(
226 | self,
227 | img_resolution,
228 | patch_size=2,
229 | in_channels=4,
230 | hidden_size=1152,
231 | depth=28,
232 | num_heads=16,
233 | mlp_ratio=4.0,
234 | class_dropout_prob=0.,
235 | num_classes=1000,
236 | s_embed=True,
237 | qk_norm=False,
238 | skip=False,
239 | embedding_kwargs={},
240 | temb_mult=1,
241 | dropout=0,
242 | **kwargs
243 | ):
244 | super().__init__()
245 | self.in_channels = in_channels
246 | self.out_channels = in_channels
247 | self.patch_size = patch_size
248 | self.num_heads = num_heads
249 | self.skip = skip
250 | temb_size = hidden_size * temb_mult
251 |
252 | self.s_embed = s_embed
253 | if s_embed:
254 | self.s_embedder = TimestepEmbedder(temb_size, **embedding_kwargs )
255 |
256 | self.x_embedder = PatchEmbed(
257 | img_resolution, patch_size, in_channels, hidden_size, bias=True,
258 | )
259 | self.t_embedder = TimestepEmbedder(temb_size, **embedding_kwargs )
260 | self.y_embedder = LabelEmbedder(num_classes + 1, temb_size, class_dropout_prob)
261 | num_patches = self.x_embedder.num_patches
262 | # Will use fixed sin-cos embedding:
263 | self.pos_embed = nn.Parameter(
264 | torch.zeros(1, num_patches, hidden_size), requires_grad=False
265 | )
266 | self.blocks = nn.ModuleList(
267 | [
268 | DiTBlock(hidden_size, num_heads,temb_size, mlp_ratio=mlp_ratio, qk_norm=qk_norm, dropout=dropout, )
269 | for _ in range(depth)
270 | ]
271 | )
272 | self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
273 | self.initialize_weights()
274 |
275 | def initialize_weights(self):
276 | # Initialize transformer layers:
277 | def _basic_init(module):
278 | if isinstance(module, nn.Linear):
279 | torch.nn.init.xavier_uniform_(module.weight)
280 | if module.bias is not None:
281 | nn.init.constant_(module.bias, 0)
282 | elif isinstance(module, nn.LayerNorm):
283 | if module.bias is not None:
284 | nn.init.constant_(module.bias, 0)
285 | if module.weight is not None:
286 | nn.init.constant_(module.weight, 1.0)
287 |
288 |
289 | self.apply(_basic_init)
290 |
291 | # Initialize (and freeze) pos_embed by sin-cos embedding:
292 | pos_embed = get_2d_sincos_pos_embed(
293 | self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5),
294 | )
295 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
296 |
297 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
298 | w = self.x_embedder.proj.weight.data
299 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
300 | nn.init.constant_(self.x_embedder.proj.bias, 0)
301 |
302 | # Initialize label embedding table:
303 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
304 |
305 | # Initialize timestep embedding MLP:
306 | if self.t_embedder.use_mlp:
307 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
308 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
309 |
310 | if self.s_embed and self.s_embedder.use_mlp:
311 | # Initialize timestep embedding MLP:
312 | nn.init.normal_(self.s_embedder.mlp[0].weight, std=0.02)
313 | nn.init.normal_(self.s_embedder.mlp[2].weight, std=0.02)
314 |
315 | # Zero-out adaLN modulation layers in DiT blocks:
316 | for block in self.blocks:
317 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
318 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
319 | # Zero-out output layers:
320 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
321 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
322 | nn.init.constant_(self.final_layer.linear.weight, 0)
323 | nn.init.constant_(self.final_layer.linear.bias, 0)
324 |
325 |
326 | def unpatchify(self, x):
327 | """
328 | x: (N, T, patch_size**2 * C)
329 | imgs: (N, H, W, C)
330 | """
331 | c = self.out_channels
332 | p = self.x_embedder.patch_size[0]
333 | h = w = int(x.shape[1] ** 0.5)
334 | assert h * w == x.shape[1]
335 |
336 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
337 | x = torch.einsum("nhwpqc->nchpwq", x)
338 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
339 | return imgs
340 |
341 | def forward(self, x,
342 | noise_labels_t,
343 | noise_labels_s=None,
344 | class_labels=None,
345 | **kwargs):
346 | """
347 | Forward pass of DiT.
348 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
349 | t: (N,) tensor of diffusion timesteps
350 | y: (N, num_classes) tensor of one-hot vector labels, all zeros denote unconditional
351 | """
352 |         is_uncond = 1 - class_labels.sum(dim=1, keepdims=True)  # 1 if unconditional, 0 otherwise
353 |         y = torch.cat([class_labels, is_uncond], dim=1)
354 |         y = y.argmax(dim=1)  # to (N,) tensor of class-label indices
355 |
356 | x = (
357 | self.x_embedder(x) + self.pos_embed
358 | ) # (N, T, D), where T = H * W / patch_size ** 2
359 | if noise_labels_t.shape[0] == 1:
360 | noise_labels_t = repeat(noise_labels_t, '1 ... -> B ...', B=x.shape[0])
361 |
362 |         t = self.t_embedder(noise_labels_t)  # (N, D)
363 |         if noise_labels_s is not None and self.s_embed:
364 |             if noise_labels_s.shape[0] == 1:
365 |                 noise_labels_s = repeat(noise_labels_s, '1 ... -> B ...', B=x.shape[0])
366 |             s = self.s_embedder(noise_labels_s)
367 |             t = t + s
368 |
373 | y = self.y_embedder(y, self.training) # (N, D)
374 | c = t + y # (N, D)
375 |
376 | for block in self.blocks:
377 | x = block(x, c) # (N, T, D)
378 |
379 |
380 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
381 | x = self.unpatchify(x) # (N, out_channels, H, W)
382 |
383 | return x
384 |
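# Example (sketch): calling the backbone directly with one-hot labels; an
# all-zeros label row selects the learned unconditional embedding.
#
#   net = DiT_XL_2(img_resolution=32, in_channels=4, num_classes=1000)
#   x = torch.randn(4, 4, 32, 32)    # (N, C, H, W) latents
#   t = torch.rand(4)                # (N,) timesteps
#   y = torch.nn.functional.one_hot(torch.arange(4), num_classes=1000).float()
#   out = net(x, t, class_labels=y)  # (N, C, H, W)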
385 |
386 | #################################################################################
387 | # Sine/Cosine Positional Embedding Functions #
388 | #################################################################################
389 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
390 |
391 |
392 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
393 | """
394 | grid_size: int of the grid height and width
395 | return:
396 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
397 | """
398 | grid_h = np.arange(grid_size, dtype=np.float64)
399 | grid_w = np.arange(grid_size, dtype=np.float64)
400 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
401 | grid = np.stack(grid, axis=0)
402 |
403 | grid = grid.reshape([2, 1, grid_size, grid_size])
404 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
405 | if cls_token and extra_tokens > 0:
406 | pos_embed = np.concatenate(
407 | [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
408 | )
409 | return pos_embed
410 |
411 |
412 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
413 | assert embed_dim % 2 == 0
414 |
415 | # use half of dimensions to encode grid_h
416 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
417 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
418 |
419 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
420 | return emb
421 |
422 |
423 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
424 | """
425 | embed_dim: output dimension for each position
426 | pos: a list of positions to be encoded: size (M,)
427 | out: (M, D)
428 | """
429 | assert embed_dim % 2 == 0
430 | omega = np.arange(embed_dim // 2, dtype=np.float64)
431 | omega /= embed_dim / 2.0
432 | omega = 1.0 / 10000**omega # (D/2,)
433 |
434 | pos = pos.reshape(-1) # (M,)
435 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
436 |
437 | emb_sin = np.sin(out) # (M, D/2)
438 | emb_cos = np.cos(out) # (M, D/2)
439 |
440 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
441 | return emb
442 |
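# A minimal sketch (editor's addition): for a DiT with hidden size D = 384 and an
# 8x8 patch grid, the table has one row per patch, with D/2 channels encoding the
# height position and D/2 the width position:
#
#   pe = get_2d_sincos_pos_embed(embed_dim=384, grid_size=8)
#   assert pe.shape == (8 * 8, 384)   # matches pos_embed before unsqueeze(0)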
443 |
444 | #################################################################################
445 | # DiT Configs #
446 | #################################################################################
447 |
448 |
449 | @persistence.persistent_class
450 | class DiT_XL_2(DiT):
451 | def __init__(self, **kwargs):
452 | super().__init__(patch_size=2, hidden_size=1152, depth=28, num_heads=16, **kwargs)
453 |
454 | @persistence.persistent_class
455 | class DiT_XL_4(DiT):
456 | def __init__(self, **kwargs):
457 | super().__init__(patch_size=4, hidden_size=1152, depth=28, num_heads=16, **kwargs)
458 |
459 | @persistence.persistent_class
460 | class DiT_XL_8(DiT):
461 | def __init__(self, **kwargs):
462 | super().__init__(patch_size=8, hidden_size=1152, depth=28, num_heads=16, **kwargs)
463 |
464 | @persistence.persistent_class
465 | class DiT_L_2(DiT):
466 | def __init__(self, **kwargs):
467 | super().__init__(patch_size=2, hidden_size=1024, depth=24, num_heads=16, **kwargs)
468 |
469 | @persistence.persistent_class
470 | class DiT_L_4(DiT):
471 | def __init__(self, **kwargs):
472 | super().__init__(patch_size=4, hidden_size=1024, depth=24, num_heads=16, **kwargs)
473 |
474 | @persistence.persistent_class
475 | class DiT_L_8(DiT):
476 | def __init__(self, **kwargs):
477 | super().__init__(patch_size=8, hidden_size=1024, depth=24, num_heads=16, **kwargs)
478 |
479 | @persistence.persistent_class
480 | class DiT_B_2(DiT):
481 | def __init__(self, **kwargs):
482 | super().__init__(patch_size=2, hidden_size=768, depth=12, num_heads=12, **kwargs)
483 |
484 | @persistence.persistent_class
485 | class DiT_B_4(DiT):
486 | def __init__(self, **kwargs):
487 | super().__init__(patch_size=4, hidden_size=768, depth=12, num_heads=12, **kwargs)
488 |
489 | @persistence.persistent_class
490 | class DiT_B_8(DiT):
491 | def __init__(self, **kwargs):
492 | super().__init__(patch_size=8, hidden_size=768, depth=12, num_heads=12, **kwargs)
493 |
494 | @persistence.persistent_class
495 | class DiT_S_2(DiT):
496 | def __init__(self, **kwargs):
497 | super().__init__(patch_size=2, hidden_size=384, depth=12, num_heads=6, **kwargs)
498 |
499 | @persistence.persistent_class
500 | class DiT_S_4(DiT):
501 | def __init__(self, **kwargs):
502 | super().__init__(patch_size=4, hidden_size=384, depth=12, num_heads=6, **kwargs)
503 |
504 | @persistence.persistent_class
505 | class DiT_S_8(DiT):
506 | def __init__(self, **kwargs):
507 | super().__init__(patch_size=8, hidden_size=384, depth=12, num_heads=6, **kwargs)
508 |
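# Naming convention (editor's note): DiT_{size}_{patch}; e.g. DiT_B_4 is the Base
# model with 4x4 patches, so a 32x32 latent yields (32/4)**2 = 64 tokens.
# A hypothetical instantiation, mirroring how IMMPrecond constructs its backbone
# in training/preconds.py (the remaining DiT arguments are defined earlier in this file):
#
#   model = DiT_S_2(img_resolution=32, img_channels=4,
#                   in_channels=4, out_channels=4, label_dim=10)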
509 |
--------------------------------------------------------------------------------
/training/encoders.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Converting between pixel and latent representations of image data."""
9 |
10 | import os
11 | import warnings
12 | import numpy as np
13 | import torch
14 | from torch_utils import persistence
15 | from torch_utils import misc
16 |
17 | warnings.filterwarnings('ignore', 'torch.utils._pytree._register_pytree_node is deprecated.')
18 | warnings.filterwarnings('ignore', '`resume_download` is deprecated')
19 |
20 | #----------------------------------------------------------------------------
21 | # Abstract base class for encoders/decoders that convert back and forth
22 | # between pixel and latent representations of image data.
23 | #
24 | # Logically, "raw pixels" are first encoded into "raw latents" that are
25 | # then further encoded into "final latents". Decoding, on the other hand,
26 | # goes directly from the final latents to raw pixels. The final latents are
27 | # used as inputs and outputs of the model, whereas the raw latents are
28 | # stored in the dataset. This separation provides added flexibility in terms
29 | # of performing just-in-time adjustments, such as data whitening, without
30 | # having to construct a new dataset.
31 | #
32 | # All image data is represented as PyTorch tensors in NCHW order.
33 | # Raw pixels are represented as 3-channel uint8.
34 |
35 | @persistence.persistent_class
36 | class Encoder:
37 | def __init__(self):
38 | pass
39 |
40 | def init(self, device): # force lazy init to happen now
41 | pass
42 |
43 | def __getstate__(self):
44 | return self.__dict__
45 |
46 | def encode(self, x): # raw pixels => final latents
47 | return self.encode_latents(self.encode_pixels(x))
48 |
49 | def encode_pixels(self, x): # raw pixels => raw latents
50 | raise NotImplementedError # to be overridden by subclass
51 |
52 | def encode_latents(self, x): # raw latents => final latents
53 | raise NotImplementedError # to be overridden by subclass
54 |
55 | def decode(self, x): # final latents => raw pixels
56 | raise NotImplementedError # to be overridden by subclass
57 |
58 | #----------------------------------------------------------------------------
59 | # Identity encoder that passes the data through unchanged.
60 |
61 | @persistence.persistent_class
62 | class IdentityEncoder(Encoder):
63 | def __init__(self):
64 | super().__init__()
65 |
66 | def encode_pixels(self, x): # raw pixels => raw latents
67 | return x
68 |
69 | def encode_latents(self, x): # raw latents => final latents
70 | return x
71 | def encode(self, x):
72 | return x
73 | def decode(self, x): # final latents => raw pixels
74 | return x
75 | #----------------------------------------------------------------------------
76 | # Standard RGB encoder that scales the pixel data into [-1, +1].
77 |
78 | @persistence.persistent_class
79 | class StandardRGBEncoder(Encoder):
80 | def __init__(self):
81 | super().__init__()
82 |
83 | def encode_pixels(self, x): # raw pixels => raw latents
84 | return x
85 |
86 | def encode_latents(self, x): # raw latents => final latents
87 | return x.to(torch.float32) / 127.5 - 1
88 |
89 | def encode(self, x):
90 | return self.encode_latents(self.encode_pixels(x))
91 |
92 |
93 | def decode(self, x): # final latents => raw pixels
94 |
95 | return (x.to(torch.float32) * 127.5 + 128).clip(0, 255).to(torch.uint8)
96 |
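# Example (editor's sketch): the uint8 -> [-1, +1] -> uint8 round trip is exact,
# since decode adds 128 (= 127.5 + 0.5) before the truncating uint8 cast:
#
#   import torch
#   enc = StandardRGBEncoder()
#   pixels = torch.randint(0, 256, [2, 3, 32, 32], dtype=torch.uint8)
#   latents = enc.encode(pixels)               # float32 in [-1, +1]
#   assert torch.equal(enc.decode(latents), pixels)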
97 | #----------------------------------------------------------------------------
98 | # Pre-trained VAE encoder from Stability AI.
99 |
100 | @persistence.persistent_class
101 | class StabilityVAEEncoder(Encoder):
102 | def __init__(self,
103 | vae_name = 'stabilityai/sd-vae-ft-mse', # Name of the VAE to use.
104 | raw_mean = [1.56, -0.695, 0.483, 0.729], # Assumed mean of the raw latents.
105 | raw_std = [5.27, 5.91, 4.21, 4.31], # Assumed standard deviation of the raw latents.
106 | final_mean = 0, # Desired mean of the final latents.
107 | final_std = 0.5, # Desired standard deviation of the final latents.
108 | batch_size = 8, # Batch size to use when running the VAE.
109 | use_fp16 = False, # Data type to use for the latents.
110 | ):
111 | super().__init__()
112 | self.vae_name = vae_name
113 | self.scale = np.float32(final_std) / np.float32(raw_std)
114 | self.bias = np.float32(final_mean) - np.float32(raw_mean) * self.scale
115 | self.batch_size = int(batch_size)
116 | self._vae = None
117 | self.dtype = torch.float16 if use_fp16 else torch.float32
118 |
119 | def init(self, device): # force lazy init to happen now
120 | super().init(device)
121 | if self._vae is None:
122 | self._vae = load_stability_vae(self.vae_name, device=device, dtype=self.dtype)
123 | else:
124 | self._vae.to(device)
125 |
126 | def __getstate__(self):
127 | return dict(super().__getstate__(), _vae=None) # do not pickle the vae
128 |
129 | def _run_vae_encoder(self, x):
130 | dtype = x.dtype
131 | d = self._vae.encode(x.to(self.dtype))['latent_dist']
132 | return torch.cat([d.mean, d.std], dim=1).to(dtype)
133 |
134 | def _run_vae_decoder(self, x):
135 | dtype = x.dtype
136 | return self._vae.decode(x.to(self.dtype))['sample'].to(dtype)
137 |
138 | def encode_pixels(self, x): # raw pixels => raw latents
139 | self.init(x.device)
140 | x = x.to(torch.float32) / 127.5 - 1
141 | x = torch.cat([self._run_vae_encoder(batch) for batch in x.split(self.batch_size)])
142 | return x
143 | def encode_latents(self, x): # raw latents => final latents
144 | mean, std = x.to(torch.float32).chunk(2, dim=1)
145 | x = mean + torch.randn_like(mean) * std
146 | x = x * misc.const_like(x, self.scale).reshape(1, -1, 1, 1)
147 | x = x + misc.const_like(x, self.bias).reshape(1, -1, 1, 1)
148 | return x
149 |
150 | def encode(self, x):
151 | if x.shape[1] == 2 * 4: # raw latents: concatenated posterior mean and std
152 | return self.encode_latents(x)
153 | elif x.shape[1] == 3:
154 | return self.encode_latents(self.encode_pixels(x))
155 | else:
156 | raise ValueError(f'Invalid number of channels: {x.shape[1]}')
157 |
158 | def decode_latents_to_pixels(self, x):
159 | self.init(x.device)
160 | x = x.to(torch.float32)
161 | x = x - misc.const_like(x, self.bias).reshape(1, -1, 1, 1)
162 | x = x / misc.const_like(x, self.scale).reshape(1, -1, 1, 1)
163 | x = torch.cat([self._run_vae_decoder(batch) for batch in x.split(self.batch_size)])
164 | x = (x * 0.5 + 0.5).clamp(0,1).mul(255).to(torch.uint8)
165 | return x
166 |
167 | def decode(self, x): # final latents => raw pixels
168 | if x.shape[1] == 2 * 4: # raw latents: sample from the posterior before decoding
169 | mean, std = x.to(torch.float32).chunk(2, dim=1)
170 | x = mean + torch.randn_like(mean) * std
171 | return self.decode_latents_to_pixels(x)
172 | elif x.shape[1] == 4:
173 | return self.decode_latents_to_pixels(x)
174 | elif x.shape[1] == 3:
175 | return x
176 | else:
177 | raise ValueError(f'Invalid number of channels: {x.shape[1]}')
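# How the whitening constants act (editor's note): encode_latents draws
# z ~ N(mean, std) from the VAE posterior and standardizes it channel-wise,
#     final = raw * scale + bias,  scale = final_std / raw_std,
#                                  bias  = final_mean - raw_mean * scale,
# so latents matching the assumed raw statistics come out with mean final_mean (0)
# and std final_std (0.5); decode_latents_to_pixels applies the exact inverse.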
178 | #----------------------------------------------------------------------------
179 |
180 | def load_stability_vae(vae_name='stabilityai/sd-vae-ft-mse', device=torch.device('cpu'), dtype=torch.float32):
181 | import dnnlib
182 | cache_dir = dnnlib.make_cache_dir_path('diffusers')
183 | os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
184 | os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
185 | os.environ['HF_HOME'] = cache_dir
186 |
187 | import diffusers # pip install diffusers # pyright: ignore [reportMissingImports]
188 | try:
189 | # First try with local_files_only to avoid hitting the Hugging Face Hub if the model is already in cache.
190 | vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, cache_dir=cache_dir, local_files_only=True, torch_dtype=dtype)
191 | except Exception:
192 | # Could not load the model from cache; try without local_files_only.
193 | vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, cache_dir=cache_dir, torch_dtype=dtype)
194 | return torch.compile(vae.eval().requires_grad_(False).to(device), mode="max-autotune", fullgraph=True)
195 |
196 | #----------------------------------------------------------------------------
--------------------------------------------------------------------------------
/training/preconds.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Model architectures and preconditioning schemes used in the paper
9 | "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10 |
11 | import numpy as np
12 | import torch
13 | from torch_utils import persistence
14 | from training.unets import *
15 | from training.dit import *
16 |
17 |
18 | @persistence.persistent_class
19 | class IMMPrecond(torch.nn.Module):
20 |
21 | def __init__(
22 | self,
23 | img_resolution, # Image resolution.
24 | img_channels, # Number of color channels.
25 | label_dim=0, # Number of class labels, 0 = unconditional.
26 | mixed_precision=None,
27 | noise_schedule="fm",
28 | model_type="SongUNet",
29 | sigma_data=0.5,
30 | f_type="euler_fm",
31 | T=0.994,
32 | eps=0.,
33 | temb_type='identity',
34 | time_scale=1000.,
35 | **model_kwargs, # Keyword arguments for the underlying model.
36 | ):
37 | super().__init__()
38 |
39 |
40 | self.img_resolution = img_resolution
41 | self.img_channels = img_channels
42 |
43 | self.label_dim = label_dim
44 | self.use_mixed_precision = mixed_precision is not None
45 | if mixed_precision == 'bf16':
46 | self.mixed_precision = torch.bfloat16
47 | elif mixed_precision == 'fp16':
48 | self.mixed_precision = torch.float16
49 | elif mixed_precision is None:
50 | self.mixed_precision = torch.float32
51 | else:
52 | raise ValueError(f"Unknown mixed_precision: {mixed_precision}")
53 |
54 |
55 | self.noise_schedule = noise_schedule
56 |
57 | self.T = T
58 | self.eps = eps
59 |
60 | self.sigma_data = sigma_data
61 |
62 | self.f_type = f_type
63 |
64 | self.nt_low = self.get_log_nt(torch.tensor(self.eps, dtype=torch.float64)).exp().numpy().item()
65 | self.nt_high = self.get_log_nt(torch.tensor(self.T, dtype=torch.float64)).exp().numpy().item()
66 |
67 | self.model = globals()[model_type](
68 | img_resolution=img_resolution,
69 | img_channels=img_channels,
70 | in_channels=img_channels,
71 | out_channels=img_channels,
72 | label_dim=label_dim,
73 | **model_kwargs,
74 | )
75 | print('# Mparams:', sum(p.numel() for p in self.model.parameters()) / 1000000)
76 |
77 | self.time_scale = time_scale
78 |
79 |
80 | self.temb_type = temb_type
81 |
82 | if self.f_type == 'euler_fm':
83 | assert self.noise_schedule == 'fm'
84 |
85 |
86 | def get_logsnr(self, t):
87 | dtype = t.dtype
88 | t = t.to(torch.float64)
89 | if self.noise_schedule == "vp_cosine":
90 | logsnr = -2 * torch.log(torch.tan(t * torch.pi * 0.5))
91 |
92 | elif self.noise_schedule == "fm":
93 | logsnr = 2 * ((1 - t).log() - t.log())
94 | else: raise ValueError(f'Unknown noise_schedule: {self.noise_schedule}')
95 | logsnr = logsnr.to(dtype)
96 | return logsnr
97 |
98 | def get_log_nt(self, t):
99 | logsnr_t = self.get_logsnr(t)
100 |
101 | return -0.5 * logsnr_t
102 |
103 | def get_alpha_sigma(self, t):
104 | if self.noise_schedule == 'fm':
105 | alpha_t = (1 - t)
106 | sigma_t = t
107 | elif self.noise_schedule == 'vp_cosine':
108 | alpha_t = torch.cos(t * torch.pi * 0.5)
109 | sigma_t = torch.sin(t * torch.pi * 0.5)
110 | else: raise ValueError(f'Unknown noise_schedule: {self.noise_schedule}')
111 | return alpha_t, sigma_t
112 |
113 | def add_noise(self, y, t, noise=None):
114 |
115 | if noise is None:
116 | noise = torch.randn_like(y) * self.sigma_data
117 |
118 | alpha_t, sigma_t = self.get_alpha_sigma(t)
119 |
120 | return alpha_t * y + sigma_t * noise, noise
121 |
122 | def ddim(self, yt, y, t, s, noise=None):
123 | alpha_t, sigma_t = self.get_alpha_sigma(t)
124 | alpha_s, sigma_s = self.get_alpha_sigma(s)
125 |
126 |
127 | if noise is None:
128 | ys = (alpha_s - alpha_t * sigma_s / sigma_t) * y + sigma_s / sigma_t * yt
129 | else:
130 | ys = alpha_s * y + sigma_s * noise
131 | return ys
132 |
133 |
134 |
135 | def simple_edm_sample_function(self, yt, y, t, s):
136 | alpha_t, sigma_t = self.get_alpha_sigma(t)
137 | alpha_s, sigma_s = self.get_alpha_sigma(s)
138 |
139 | c_skip = (alpha_t * alpha_s + sigma_t * sigma_s) / (alpha_t**2 + sigma_t**2)
140 |
141 | c_out = - (alpha_s * sigma_t - alpha_t * sigma_s) * (alpha_t**2 + sigma_t**2).rsqrt() * self.sigma_data
142 |
143 | return c_skip * yt + c_out * y
144 |
145 | def euler_fm_sample_function(self, yt, y, t, s):
146 | assert self.noise_schedule == 'fm'
147 |
148 |
149 | return yt - (t - s) * self.sigma_data * y
150 |
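# Editor's note on the sample functions above: with the fm schedule
# (alpha_t = 1 - t, sigma_t = t), euler_fm_sample_function is one Euler step of the
# probability-flow ODE with the network output read as a velocity,
#     y_s = y_t + (s - t) * sigma_data * F  =  y_t - (t - s) * sigma_data * F,
# while simple_edm_sample_function mixes y_t and the prediction with coefficients
# satisfying the boundary condition c_skip -> 1, c_out -> 0 as s -> t.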
151 | def nt_to_t(self, nt):
152 | dtype = nt.dtype
153 | nt = nt.to(torch.float64)
154 | if self.noise_schedule == "vp_cosine":
155 | t = torch.arctan(nt) / (torch.pi * 0.5)
156 |
157 | elif self.noise_schedule == "fm":
158 | t = nt / (1 + nt)
159 |
160 | t = torch.nan_to_num(t, nan=1)
161 |
162 | t = t.to(dtype)
163 |
164 |
165 | # Sanity check: for these schedules t must lie in [0, 1].
166 | if (
167 | (self.noise_schedule.startswith("vp") or self.noise_schedule == "fm")
168 | and t.max() > 1
169 | ):
170 | raise ValueError(f"t out of range: {t.min().item()}, {t.max().item()}")
171 | return t
172 |
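# Consistency check (editor's sketch), assuming `precond` is an IMMPrecond with
# noise_schedule='fm': nt_to_t inverts get_log_nt, since nt = t / (1 - t) implies
# t = nt / (1 + nt):
#
#   t = torch.tensor([0.25], dtype=torch.float64)
#   nt = precond.get_log_nt(t).exp()           # 1/3
#   assert torch.allclose(precond.nt_to_t(nt), t)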
173 | def get_init_noise(self, shape, device):
174 |
175 | noise = torch.randn(shape, device=device) * self.sigma_data
176 | return noise
177 |
178 | def forward_model(
179 | self,
180 | model,
181 | x,
182 | t,
183 | s,
184 | class_labels=None,
185 | force_fp32=False,
186 | **model_kwargs,
187 | ):
188 |
189 |
190 |
191 | alpha_t, sigma_t = self.get_alpha_sigma(t)
192 |
193 | c_in = (alpha_t**2 + sigma_t**2).rsqrt() / self.sigma_data
194 | if self.temb_type == 'identity':
195 | c_noise_t = t * self.time_scale
196 | c_noise_s = s * self.time_scale if s is not None else None
197 |
198 | elif self.temb_type == 'stride':
199 | c_noise_t = t * self.time_scale
200 | c_noise_s = (t - s) * self.time_scale if s is not None else None
201 |
202 | else:
203 | raise ValueError(f'Unknown temb_type: {self.temb_type}')
204 | with torch.amp.autocast('cuda', enabled=self.use_mixed_precision and not force_fp32, dtype=self.mixed_precision):
205 | F_x = model(
206 | c_in * x,
207 | c_noise_t.flatten(),
208 | c_noise_s.flatten() if c_noise_s is not None else None,
209 | class_labels=class_labels,
210 | **model_kwargs,
211 | )
212 | return F_x
213 |
214 |
215 | def forward(
216 | self,
217 | x,
218 | t,
219 | s=None,
220 | class_labels=None,
221 | force_fp32=False,
222 | **model_kwargs,
223 | ):
224 | dtype = t.dtype
225 | class_labels = (
226 | None
227 | if self.label_dim == 0
228 | else (
229 | torch.zeros([1, self.label_dim], device=x.device)
230 | if class_labels is None
231 | else class_labels.to(torch.float32).reshape(-1, self.label_dim)
232 | )
233 | )
234 |
235 | F_x = self.forward_model(
236 | self.model,
237 | x.to(torch.float32),
238 | t.to(torch.float32).reshape(-1, 1, 1, 1),
239 | s.to(torch.float32).reshape(-1, 1, 1, 1) if s is not None else None,
240 | class_labels,
241 | force_fp32,
242 | **model_kwargs,
243 | )
244 | F_x = F_x.to(dtype)
245 |
246 | if self.f_type == "identity":
247 | F_x = self.ddim(x, F_x, t, s)
248 | elif self.f_type == "simple_edm":
249 | F_x = self.simple_edm_sample_function(x, F_x, t, s)
250 | elif self.f_type == "euler_fm":
251 | F_x = self.euler_fm_sample_function(x, F_x, t, s)
252 | else:
253 | raise NotImplementedError
254 |
255 | return F_x
256 |
257 | def cfg_forward(
258 | self,
259 | x,
260 | t,
261 | s=None,
262 | class_labels=None,
263 | force_fp32=False,
264 | cfg_scale=None,
265 | **model_kwargs,
266 | ):
267 | dtype = t.dtype
268 | class_labels = (
269 | None
270 | if self.label_dim == 0
271 | else (
272 | torch.zeros([1, self.label_dim], device=x.device)
273 | if class_labels is None
274 | else class_labels.to(torch.float32).reshape(-1, self.label_dim)
275 | )
276 | )
277 | if cfg_scale is not None:
278 |
279 | x_cfg = torch.cat([x, x], dim=0)
280 | class_labels = torch.cat([torch.zeros_like(class_labels), class_labels], dim=0)
281 | else:
282 | x_cfg = x
283 | F_x = self.forward_model(
284 | self.model,
285 | x_cfg.to(torch.float32),
286 | t.to(torch.float32).reshape(-1, 1, 1, 1),
287 | s.to(torch.float32).reshape(-1, 1, 1, 1) if s is not None else None,
288 | class_labels=class_labels,
289 | force_fp32=force_fp32,
290 | **model_kwargs,
291 | )
292 | F_x = F_x.to(dtype)
293 |
294 | if cfg_scale is not None:
295 | uncond_F = F_x[:len(x)]
296 | cond_F = F_x[len(x):]
297 |
298 | F_x = uncond_F + cfg_scale * (cond_F - uncond_F)
299 |
300 | if self.f_type == "identity":
301 | F_x = self.ddim(x, F_x, t, s)
302 | elif self.f_type == "simple_edm":
303 | F_x = self.simple_edm_sample_function(x, F_x, t, s)
304 | elif self.f_type == "euler_fm":
305 | F_x = self.euler_fm_sample_function(x, F_x, t, s)
306 | else:
307 | raise NotImplementedError
308 |
309 | return F_x
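# Usage sketch (editor's addition): classifier-free guidance runs the duplicated
# batch once with zeroed labels and once with the true labels, then blends
#     F = F_uncond + cfg_scale * (F_cond - F_uncond),
# so cfg_scale = 1 recovers the purely conditional output. Assuming `net` is an
# IMMPrecond and `onehot` holds (N, label_dim) one-hot labels:
#
#   x_s = net.cfg_forward(x_t, t, s, class_labels=onehot, cfg_scale=1.5)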
--------------------------------------------------------------------------------
/training/unets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # This work is licensed under a Creative Commons
4 | # Attribution-NonCommercial-ShareAlike 4.0 International License.
5 | # You should have received a copy of the license along with this
6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7 |
8 | """Model architectures and preconditioning schemes used in the paper
9 | "Elucidating the Design Space of Diffusion-Based Generative Models"."""
10 |
11 | import numpy as np
12 | import torch
13 | from torch_utils import persistence
14 | from torch.nn.functional import silu
15 |
16 | # ----------------------------------------------------------------------------
17 | # Unified routine for initializing weights and biases.
18 |
19 |
20 | def weight_init(shape, mode, fan_in, fan_out):
21 | if mode == "xavier_uniform":
22 | return np.sqrt(6 / (fan_in + fan_out)) * (torch.rand(*shape) * 2 - 1)
23 | if mode == "xavier_normal":
24 | return np.sqrt(2 / (fan_in + fan_out)) * torch.randn(*shape)
25 | if mode == "kaiming_uniform":
26 | return np.sqrt(3 / fan_in) * (torch.rand(*shape) * 2 - 1)
27 | if mode == "kaiming_normal":
28 | return np.sqrt(1 / fan_in) * torch.randn(*shape)
29 | raise ValueError(f'Invalid init mode "{mode}"')
30 |
31 |
32 | # ----------------------------------------------------------------------------
33 | # Fully-connected layer.
34 |
35 |
36 | @persistence.persistent_class
37 | class Linear(torch.nn.Module):
38 | def __init__(
39 | self,
40 | in_features,
41 | out_features,
42 | bias=True,
43 | init_mode="kaiming_normal",
44 | init_weight=1,
45 | init_bias=0,
46 | ):
47 | super().__init__()
48 | self.in_features = in_features
49 | self.out_features = out_features
50 | init_kwargs = dict(mode=init_mode, fan_in=in_features, fan_out=out_features)
51 | self.weight = torch.nn.Parameter(
52 | weight_init([out_features, in_features], **init_kwargs) * init_weight
53 | )
54 | self.bias = (
55 | torch.nn.Parameter(weight_init([out_features], **init_kwargs) * init_bias)
56 | if bias
57 | else None
58 | )
59 |
60 | def forward(self, x):
61 | x = x @ self.weight.to(x.dtype).t()
62 | if self.bias is not None:
63 | x = x.add_(self.bias.to(x.dtype))
64 | return x
65 |
66 |
67 | # ----------------------------------------------------------------------------
68 | # Convolutional layer with optional up/downsampling.
69 |
70 |
71 | @persistence.persistent_class
72 | class Conv2d(torch.nn.Module):
73 | def __init__(
74 | self,
75 | in_channels,
76 | out_channels,
77 | kernel,
78 | bias=True,
79 | up=False,
80 | down=False,
81 | resample_filter=[1, 1],
82 | fused_resample=False,
83 | init_mode="kaiming_normal",
84 | init_weight=1,
85 | init_bias=0,
86 | ):
87 | assert not (up and down)
88 | super().__init__()
89 | self.in_channels = in_channels
90 | self.out_channels = out_channels
91 | self.up = up
92 | self.down = down
93 | self.fused_resample = fused_resample
94 | init_kwargs = dict(
95 | mode=init_mode,
96 | fan_in=in_channels * kernel * kernel,
97 | fan_out=out_channels * kernel * kernel,
98 | )
99 | self.weight = (
100 | torch.nn.Parameter(
101 | weight_init([out_channels, in_channels, kernel, kernel], **init_kwargs)
102 | * init_weight
103 | )
104 | if kernel
105 | else None
106 | )
107 | self.bias = (
108 | torch.nn.Parameter(weight_init([out_channels], **init_kwargs) * init_bias)
109 | if kernel and bias
110 | else None
111 | )
112 | f = torch.as_tensor(resample_filter, dtype=torch.float32)
113 | f = f.ger(f).unsqueeze(0).unsqueeze(1) / f.sum().square()
114 | self.register_buffer("resample_filter", f if up or down else None)
115 |
116 | def forward(self, x):
117 | w = self.weight.to(x.dtype) if self.weight is not None else None
118 | b = self.bias.to(x.dtype) if self.bias is not None else None
119 | f = (
120 | self.resample_filter.to(x.dtype)
121 | if self.resample_filter is not None
122 | else None
123 | )
124 | w_pad = w.shape[-1] // 2 if w is not None else 0
125 | f_pad = (f.shape[-1] - 1) // 2 if f is not None else 0
126 |
127 | if self.fused_resample and self.up and w is not None:
128 | x = torch.nn.functional.conv_transpose2d(
129 | x,
130 | f.mul(4).tile([self.in_channels, 1, 1, 1]),
131 | groups=self.in_channels,
132 | stride=2,
133 | padding=max(f_pad - w_pad, 0),
134 | )
135 | x = torch.nn.functional.conv2d(x, w, padding=max(w_pad - f_pad, 0))
136 | elif self.fused_resample and self.down and w is not None:
137 | x = torch.nn.functional.conv2d(x, w, padding=w_pad + f_pad)
138 | x = torch.nn.functional.conv2d(
139 | x,
140 | f.tile([self.out_channels, 1, 1, 1]),
141 | groups=self.out_channels,
142 | stride=2,
143 | )
144 | else:
145 | if self.up:
146 | x = torch.nn.functional.conv_transpose2d(
147 | x,
148 | f.mul(4).tile([self.in_channels, 1, 1, 1]),
149 | groups=self.in_channels,
150 | stride=2,
151 | padding=f_pad,
152 | )
153 | if self.down:
154 | x = torch.nn.functional.conv2d(
155 | x,
156 | f.tile([self.in_channels, 1, 1, 1]),
157 | groups=self.in_channels,
158 | stride=2,
159 | padding=f_pad,
160 | )
161 | if w is not None:
162 | x = torch.nn.functional.conv2d(x, w, padding=w_pad)
163 | if b is not None:
164 | x = x.add_(b.reshape(1, -1, 1, 1))
165 | return x
166 |
167 |
168 | # ----------------------------------------------------------------------------
169 | # Group normalization.
170 |
171 |
172 | @persistence.persistent_class
173 | class GroupNorm(torch.nn.Module):
174 | def __init__(self, num_channels, num_groups=32, min_channels_per_group=4, eps=1e-5):
175 | super().__init__()
176 | self.num_groups = min(num_groups, num_channels // min_channels_per_group)
177 | self.eps = eps
178 | self.weight = torch.nn.Parameter(torch.ones(num_channels))
179 | self.bias = torch.nn.Parameter(torch.zeros(num_channels))
180 |
181 | def forward(self, x, *args, **kwargs):
182 | x = torch.nn.functional.group_norm(
183 | x,
184 | num_groups=self.num_groups,
185 | weight=self.weight.to(x.dtype),
186 | bias=self.bias.to(x.dtype),
187 | eps=self.eps,
188 | )
189 | return x
190 |
191 |
192 |
193 | # ----------------------------------------------------------------------------
194 | # Attention weight computation, i.e., softmax(Q^T * K).
195 | # Performs all computation using FP32, but uses the original datatype for
196 | # inputs/outputs/gradients to conserve memory.
197 |
198 |
199 | class AttentionOp(torch.autograd.Function):
200 | @staticmethod
201 | def forward(q, k):
202 | w = (
203 | torch.einsum(
204 | "ncq,nck->nqk",
205 | q.to(torch.float32),
206 | (k / np.sqrt(k.shape[1])).to(torch.float32),
207 | )
208 | .softmax(dim=2)
209 | .to(q.dtype)
210 | )
211 | return w
212 |
213 | @staticmethod
214 | def setup_context(ctx, inputs, outputs):
215 | q, k = inputs
216 | w = outputs
217 | ctx.save_for_backward(q, k, w)
218 |
219 |
220 | @staticmethod
221 | def backward(ctx, dw):
222 | q, k, w = ctx.saved_tensors
223 | db = torch._softmax_backward_data(
224 | grad_output=dw.to(torch.float32),
225 | output=w.to(torch.float32),
226 | dim=2,
227 | input_dtype=torch.float32,
228 | )
229 | dq = torch.einsum("nck,nqk->ncq", k.to(torch.float32), db).to(
230 | q.dtype
231 | ) / np.sqrt(k.shape[1])
232 | dk = torch.einsum("ncq,nqk->nck", q.to(torch.float32), db).to(
233 | k.dtype
234 | ) / np.sqrt(k.shape[1])
235 | return dq, dk
236 |
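# Example (editor's sketch): AttentionOp expects q and k flattened to
# [batch * heads, channels_per_head, tokens] and returns per-head attention
# weights of shape [batch * heads, tokens, tokens], computed in FP32 but cast
# back to the input dtype:
#
#   q = torch.randn(4, 64, 256)
#   k = torch.randn(4, 64, 256)
#   v = torch.randn(4, 64, 256)
#   w = AttentionOp.apply(q, k)                # (4, 256, 256)
#   a = torch.einsum('nqk,nck->ncq', w, v)     # (4, 64, 256), as in UNetBlock.forward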
237 |
238 | @persistence.persistent_class
239 | class Attention(torch.nn.Module):
240 | def forward(self, q, k):
241 | w = (
242 | torch.einsum(
243 | "ncq,nck->nqk",
244 | q.to(torch.float32),
245 | (k / np.sqrt(k.shape[1])).to(torch.float32),
246 | )
247 | .softmax(dim=2)
248 | .to(q.dtype)
249 | )
250 | return w
251 |
252 |
253 | # ----------------------------------------------------------------------------
254 | # Unified U-Net block with optional up/downsampling and self-attention.
255 | # Represents the union of all features employed by the DDPM++, NCSN++, and
256 | # ADM architectures.
257 |
258 |
259 |
260 | @persistence.persistent_class
261 | class UNetBlock(torch.nn.Module):
262 | def __init__(self,
263 | in_channels, out_channels, emb_channels, up=False, down=False, attention=False,
264 | num_heads=None, channels_per_head=64, dropout=0, skip_scale=1, eps=1e-5,
265 | resample_filter=[1,1], resample_proj=False, adaptive_scale=True,
266 | init=dict(), init_zero=dict(init_weight=0), init_attn=None,
267 | ):
268 | super().__init__()
269 | self.in_channels = in_channels
270 | self.out_channels = out_channels
271 | self.emb_channels = emb_channels
272 | self.num_heads = 0 if not attention else num_heads if num_heads is not None else out_channels // channels_per_head
273 | self.dropout = dropout
274 | self.skip_scale = skip_scale
275 | self.adaptive_scale = adaptive_scale
276 |
277 | self.norm0 = GroupNorm(num_channels=in_channels, eps=eps)
278 | self.conv0 = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=3, up=up, down=down, resample_filter=resample_filter, **init)
279 | self.affine = Linear(in_features=emb_channels, out_features=out_channels*(2 if adaptive_scale else 1), **init)
280 | self.norm1 = GroupNorm(num_channels=out_channels, eps=eps)
281 | self.conv1 = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=3, **init_zero)
282 |
283 | self.skip = None
284 | if out_channels != in_channels or up or down:
285 | kernel = 1 if resample_proj or out_channels!= in_channels else 0
286 | self.skip = Conv2d(in_channels=in_channels, out_channels=out_channels, kernel=kernel, up=up, down=down, resample_filter=resample_filter, **init)
287 |
288 | if self.num_heads:
289 | self.norm2 = GroupNorm(num_channels=out_channels, eps=eps)
290 | self.qkv = Conv2d(in_channels=out_channels, out_channels=out_channels*3, kernel=1, **(init_attn if init_attn is not None else init))
291 | self.proj = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=1, **init_zero)
292 |
293 | def forward(self, x, emb):
294 | orig = x
295 | x = self.conv0(silu(self.norm0(x)))
296 |
297 | params = self.affine(emb).unsqueeze(2).unsqueeze(3).to(x.dtype)
298 | if self.adaptive_scale:
299 | scale, shift = params.chunk(chunks=2, dim=1)
300 | x = silu(torch.addcmul(shift, self.norm1(x), scale + 1))
301 | else:
302 | x = silu(self.norm1(x.add_(params)))
303 |
304 | x = self.conv1(torch.nn.functional.dropout(x, p=self.dropout, training=self.training))
305 | x = x.add_(self.skip(orig) if self.skip is not None else orig)
306 | x = x * self.skip_scale
307 |
308 | if self.num_heads:
309 | q, k, v = self.qkv(self.norm2(x)).reshape(x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1).unbind(2)
310 | w = AttentionOp.apply(q, k)
311 | a = torch.einsum('nqk,nck->ncq', w, v)
312 | x = self.proj(a.reshape(*x.shape)).add_(x)
313 | x = x * self.skip_scale
314 | return x
315 |
316 |
317 | # ----------------------------------------------------------------------------
318 | # Timestep embedding used in the DDPM++ and ADM architectures.
319 |
320 |
321 | @persistence.persistent_class
322 | class PositionalEmbedding(torch.nn.Module):
323 | def __init__(self, num_channels, max_positions=10000, endpoint=False):
324 | super().__init__()
325 | self.num_channels = num_channels
326 | self.max_positions = max_positions
327 | self.endpoint = endpoint
328 |
329 | def forward(self, x):
330 | freqs = torch.arange(
331 | start=0, end=self.num_channels // 2, dtype=torch.float64, device=x.device
332 | )
333 | freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
334 | freqs = (1 / self.max_positions) ** freqs
335 | x = x.ger(freqs)
336 | x = torch.cat([x.cos(), x.sin()], dim=1).to(x.dtype)
337 | return x
338 |
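# Shape sketch (editor's addition): the embedding is the outer product of the N
# timesteps with num_channels//2 geometric frequencies, concatenated as [cos | sin]:
#
#   emb = PositionalEmbedding(num_channels=128)
#   out = emb(torch.rand(16))                  # (16, 128)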
339 |
340 | # ----------------------------------------------------------------------------
341 | # Timestep embedding used in the NCSN++ architecture.
342 |
343 |
344 | @persistence.persistent_class
345 | class FourierEmbedding(torch.nn.Module):
346 | def __init__(self, num_channels, scale=0.02, learnable=False, **kwargs):
347 | super().__init__()
348 | print("FourierEmbedding scale:", scale)
349 | if learnable:
350 | self.freqs = torch.nn.Parameter(torch.randn(num_channels // 2) * scale)
351 | else:
352 | self.register_buffer("freqs", torch.randn(num_channels // 2) * scale)
353 |
354 | def forward(self, x):
355 | x = x.ger(2 * np.pi * self.freqs)
356 | x = torch.cat([x.cos(), x.sin()], dim=1).to(x.dtype)
357 | return x
358 |
359 |
360 |
361 |
362 | # ----------------------------------------------------------------------------
363 | # Reimplementation of the DDPM++ and NCSN++ architectures from the paper
364 | # "Score-Based Generative Modeling through Stochastic Differential
365 | # Equations". Equivalent to the original implementation by Song et al.,
366 | # available at https://github.com/yang-song/score_sde_pytorch
367 |
368 |
369 | @persistence.persistent_class
370 | class SongUNet(torch.nn.Module):
371 | def __init__(self,
372 | img_resolution, # Image resolution at input/output.
373 | in_channels, # Number of color channels at input.
374 | out_channels, # Number of color channels at output.
375 | label_dim = 0, # Number of class labels, 0 = unconditional.
376 | augment_dim = 0, # Augmentation label dimensionality, 0 = no augmentation.
377 |
378 | model_channels = 128, # Base multiplier for the number of channels.
379 | channel_mult = [1,2,2,2], # Per-resolution multipliers for the number of channels.
380 | channel_mult_emb = 4, # Multiplier for the dimensionality of the embedding vector.
381 | num_blocks = 4, # Number of residual blocks per resolution.
382 | attn_resolutions = [16], # List of resolutions with self-attention.
383 | dropout = 0.10, # Dropout probability of intermediate activations.
384 | label_dropout = 0, # Dropout probability of class labels for classifier-free guidance.
385 |
386 | embedding_type = 'positional', # Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++.
387 | channel_mult_noise = 1, # Timestep embedding size: 1 for DDPM++, 2 for NCSN++.
388 | encoder_type = 'standard', # Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++.
389 | decoder_type = 'standard', # Decoder architecture: 'standard' for both DDPM++ and NCSN++.
390 | resample_filter = [1,1], # Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
391 | s_embed=True,
392 | share_tsemb=True,
393 | embedding_kwargs = {},
394 | **kwargs
395 | ):
396 | assert embedding_type in ['fourier', 'positional']
397 | assert encoder_type in ['standard', 'skip', 'residual']
398 | assert decoder_type in ['standard', 'skip']
399 |
400 | super().__init__()
401 | self.label_dropout = label_dropout
402 | emb_channels = model_channels * channel_mult_emb
403 | noise_channels = model_channels * channel_mult_noise
404 | init = dict(init_mode='xavier_uniform')
405 | init_zero = dict(init_mode='xavier_uniform', init_weight=1e-5)
406 | init_attn = dict(init_mode='xavier_uniform', init_weight=np.sqrt(0.2))
407 | block_kwargs = dict(
408 | emb_channels=emb_channels, num_heads=1, dropout=dropout, skip_scale=np.sqrt(0.5), eps=1e-6,
409 | resample_filter=resample_filter, resample_proj=True, adaptive_scale=False,
410 | init=init, init_zero=init_zero, init_attn=init_attn,
411 | )
412 |
413 | # Mapping.
414 | self.map_noise = PositionalEmbedding(num_channels=noise_channels, endpoint=True) if embedding_type == 'positional' else FourierEmbedding(num_channels=noise_channels, **embedding_kwargs)
415 | self.map_label = Linear(in_features=label_dim, out_features=noise_channels, **init) if label_dim else None
416 | self.map_augment = Linear(in_features=augment_dim, out_features=noise_channels, bias=False, **init) if augment_dim else None
417 | self.map_layer0 = Linear(in_features=noise_channels, out_features=emb_channels, **init)
418 | self.map_layer1 = Linear(in_features=emb_channels, out_features=emb_channels, **init)
419 |
420 | self.s_embed = s_embed
421 | if s_embed:
422 |
423 | if embedding_type == "positional":
424 | self.map_noise_s = PositionalEmbedding(
425 | num_channels=noise_channels, endpoint=True
426 | )
427 | elif embedding_type == "fourier":
428 | self.map_noise_s = FourierEmbedding(
429 | num_channels=noise_channels, **embedding_kwargs
430 | )
431 | self.map_layer0_s = Linear(
432 | in_features=noise_channels,
433 | out_features=emb_channels,
434 | **init,
435 | )
436 | self.map_layer1_s = Linear(
437 | in_features=emb_channels, out_features=emb_channels, **init
438 | )
439 | # Encoder.
440 | self.enc = torch.nn.ModuleDict()
441 | cout = in_channels
442 | caux = in_channels
443 | for level, mult in enumerate(channel_mult):
444 | res = img_resolution >> level
445 | if level == 0:
446 | cin = cout
447 | cout = model_channels
448 | self.enc[f'{res}x{res}_conv'] = Conv2d(in_channels=cin, out_channels=cout, kernel=3, **init)
449 | else:
450 | self.enc[f'{res}x{res}_down'] = UNetBlock(in_channels=cout, out_channels=cout, down=True, **block_kwargs)
451 | if encoder_type == 'skip':
452 | self.enc[f'{res}x{res}_aux_down'] = Conv2d(in_channels=caux, out_channels=caux, kernel=0, down=True, resample_filter=resample_filter)
453 | self.enc[f'{res}x{res}_aux_skip'] = Conv2d(in_channels=caux, out_channels=cout, kernel=1, **init)
454 | if encoder_type == 'residual':
455 | self.enc[f'{res}x{res}_aux_residual'] = Conv2d(in_channels=caux, out_channels=cout, kernel=3, down=True, resample_filter=resample_filter, fused_resample=True, **init)
456 | caux = cout
457 | for idx in range(num_blocks):
458 | cin = cout
459 | cout = model_channels * mult
460 | attn = (res in attn_resolutions)
461 | self.enc[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs)
462 | skips = [block.out_channels for name, block in self.enc.items() if 'aux' not in name]
463 |
464 | # Decoder.
465 | self.dec = torch.nn.ModuleDict()
466 | for level, mult in reversed(list(enumerate(channel_mult))):
467 | res = img_resolution >> level
468 | if level == len(channel_mult) - 1:
469 | self.dec[f'{res}x{res}_in0'] = UNetBlock(in_channels=cout, out_channels=cout, attention=True, **block_kwargs)
470 | self.dec[f'{res}x{res}_in1'] = UNetBlock(in_channels=cout, out_channels=cout, **block_kwargs)
471 | else:
472 | self.dec[f'{res}x{res}_up'] = UNetBlock(in_channels=cout, out_channels=cout, up=True, **block_kwargs)
473 | for idx in range(num_blocks + 1):
474 | cin = cout + skips.pop()
475 | cout = model_channels * mult
476 | attn = (idx == num_blocks and res in attn_resolutions)
477 | self.dec[f'{res}x{res}_block{idx}'] = UNetBlock(in_channels=cin, out_channels=cout, attention=attn, **block_kwargs)
478 | if decoder_type == 'skip' or level == 0:
479 | if decoder_type == 'skip' and level < len(channel_mult) - 1:
480 | self.dec[f'{res}x{res}_aux_up'] = Conv2d(in_channels=out_channels, out_channels=out_channels, kernel=0, up=True, resample_filter=resample_filter)
481 | self.dec[f'{res}x{res}_aux_norm'] = GroupNorm(num_channels=cout, eps=1e-6)
482 | self.dec[f'{res}x{res}_aux_conv'] = Conv2d(in_channels=cout, out_channels=out_channels, kernel=3, **init_zero)
483 |
484 | def forward(self, x,
485 | noise_labels_t,
486 | noise_labels_s=None,
487 | class_labels=None,
488 | augment_labels=None):
489 | # Mapping.
490 | emb = self.map_noise(noise_labels_t)
491 | emb = emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) # swap sin/cos
492 | if self.map_label is not None:
493 | tmp = class_labels
494 | if self.training and self.label_dropout:
495 | tmp = tmp * (torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout).to(tmp.dtype)
496 | emb = emb + self.map_label(tmp * np.sqrt(self.map_label.in_features))
497 | if self.map_augment is not None and augment_labels is not None:
498 | emb = emb + self.map_augment(augment_labels)
499 | emb = silu(self.map_layer0(emb))
500 | emb = self.map_layer1(emb)
501 |
502 |
503 | if noise_labels_s is not None and self.s_embed:
504 |
505 | emb_s = self.map_noise_s(noise_labels_s)
506 | emb_s = (
507 | emb_s.reshape(emb_s.shape[0], 2, -1).flip(1).reshape(*emb_s.shape)
508 | ) # swap sin/cos
509 | emb_s = silu(self.map_layer0_s(emb_s))
510 | emb_s = self.map_layer1_s(emb_s)
511 | emb = emb + emb_s
512 |
513 | emb = silu(emb)
514 |
515 | # Encoder.
516 | skips = []
517 | aux = x
518 | for name, block in self.enc.items():
519 | if 'aux_down' in name:
520 | aux = block(aux)
521 | elif 'aux_skip' in name:
522 | x = skips[-1] = x + block(aux)
523 | elif 'aux_residual' in name:
524 | x = skips[-1] = aux = (x + block(aux)) / np.sqrt(2)
525 | else:
526 | x = block(x, emb) if isinstance(block, UNetBlock) else block(x)
527 | skips.append(x)
528 |
529 | # Decoder.
530 | aux = None
531 | tmp = None
532 | for name, block in self.dec.items():
533 | if 'aux_up' in name:
534 | aux = block(aux)
535 | elif 'aux_norm' in name:
536 | tmp = block(x)
537 | elif 'aux_conv' in name:
538 | tmp = block(silu(tmp))
539 | aux = tmp if aux is None else tmp + aux
540 | else:
541 | if x.shape[1] != block.in_channels:
542 | x = torch.cat([x, skips.pop()], dim=1)
543 | x = block(x, emb)
544 |
545 | return aux
546 |
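# Editor's note on the decoder loop above: skip connections are consumed implicitly.
# Whenever a decoder block's in_channels exceeds the running channel count, a stored
# encoder activation must be concatenated first, which is why the loop tests
# `x.shape[1] != block.in_channels` instead of tracking skip indices explicitly.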
547 |
548 |
549 |
550 | # ----------------------------------------------------------------------------
551 | # Reimplementation of the ADM architecture from the paper
552 | # "Diffusion Models Beat GANS on Image Synthesis". Equivalent to the
553 | # original implementation by Dhariwal and Nichol, available at
554 | # https://github.com/openai/guided-diffusion
555 |
556 |
557 | @persistence.persistent_class
558 | class DhariwalUNet(torch.nn.Module):
559 |
560 | def __init__(
561 | self,
562 | img_resolution, # Image resolution at input/output.
563 | in_channels, # Number of color channels at input.
564 | out_channels, # Number of color channels at output.
565 | label_dim=0, # Number of class labels, 0 = unconditional.
566 | augment_dim=0, # Augmentation label dimensionality, 0 = no augmentation.
567 | model_channels=192, # Base multiplier for the number of channels.
568 | channel_mult=[
569 | 1,
570 | 2,
571 | 3,
572 | 4,
573 | ], # Per-resolution multipliers for the number of channels.
574 | channel_mult_emb=4, # Multiplier for the dimensionality of the embedding vector.
575 | num_blocks=3, # Number of residual blocks per resolution.
576 | attn_resolutions=[32, 16, 8], # List of resolutions with self-attention.
577 | dropout=0.10, # Dropout probability of intermediate activations.
578 | label_dropout=0, # Dropout probability of class labels for classifier-free guidance.
579 | s_embed=True,
580 | **kwargs
581 | ):
582 | super().__init__()
583 | self.label_dropout = label_dropout
584 | emb_channels = model_channels * channel_mult_emb
585 | init = dict(
586 | init_mode="kaiming_uniform",
587 | init_weight=np.sqrt(1 / 3),
588 | init_bias=np.sqrt(1 / 3),
589 | )
590 | init_zero = dict(init_mode="kaiming_uniform", init_weight=0, init_bias=0)
591 | block_kwargs = dict(
592 | emb_channels=emb_channels,
593 | channels_per_head=64,
594 | dropout=dropout,
595 | init=init,
596 | init_zero=init_zero,
597 | )
598 |
599 | # Mapping.
600 | self.map_noise = PositionalEmbedding(num_channels=model_channels)
601 | self.s_embed = s_embed
602 | if s_embed:
603 | self.map_noise_s = self.map_noise
604 |
605 | self.map_layer0_s = Linear(
606 | in_features=model_channels,
607 | out_features=emb_channels,
608 | **init,
609 | )
610 | self.map_layer1_s = Linear(
611 | in_features=emb_channels, out_features=emb_channels, **init
612 | )
613 | self.map_augment = (
614 | Linear(
615 | in_features=augment_dim,
616 | out_features=model_channels,
617 | bias=False,
618 | **init_zero,
619 | )
620 | if augment_dim
621 | else None
622 | )
623 | self.map_layer0 = Linear(
624 | in_features=model_channels,
625 | out_features=emb_channels,
626 | **init,
627 | )
628 | self.map_layer1 = Linear(
629 | in_features=emb_channels,
630 | out_features=emb_channels,
631 | **init,
632 | )
633 | self.map_label = (
634 | Linear(
635 | in_features=label_dim,
636 | out_features=emb_channels,
637 | bias=False,
638 | init_mode="kaiming_normal",
639 | init_weight=np.sqrt(label_dim),
640 | )
641 | if label_dim
642 | else None
643 | )
644 |
645 | # Encoder.
646 | self.enc = torch.nn.ModuleDict()
647 | cout = in_channels
648 | for level, mult in enumerate(channel_mult):
649 | res = img_resolution >> level
650 | if level == 0:
651 | cin = cout
652 | cout = model_channels * mult
653 | self.enc[f"{res}x{res}_conv"] = Conv2d(
654 | in_channels=cin, out_channels=cout, kernel=3, **init
655 | )
656 | else:
657 | self.enc[f"{res}x{res}_down"] = UNetBlock(
658 | in_channels=cout,
659 | out_channels=cout,
660 | down=True,
661 | **block_kwargs,
662 | )
663 | for idx in range(num_blocks):
664 | cin = cout
665 | cout = model_channels * mult
666 | self.enc[f"{res}x{res}_block{idx}"] = UNetBlock(
667 | in_channels=cin,
668 | out_channels=cout,
669 | attention=(res in attn_resolutions),
670 | **block_kwargs,
671 | )
672 | skips = [block.out_channels for block in self.enc.values()]
673 |
674 | # Decoder.
675 | self.dec = torch.nn.ModuleDict()
676 | for level, mult in reversed(list(enumerate(channel_mult))):
677 | res = img_resolution >> level
678 | if level == len(channel_mult) - 1:
679 | self.dec[f"{res}x{res}_in0"] = UNetBlock(
680 | in_channels=cout,
681 | out_channels=cout,
682 | attention=True,
683 | **block_kwargs,
684 | )
685 | self.dec[f"{res}x{res}_in1"] = UNetBlock(
686 | in_channels=cout,
687 | out_channels=cout,
688 | **block_kwargs,
689 | )
690 | else:
691 | self.dec[f"{res}x{res}_up"] = UNetBlock(
692 | in_channels=cout,
693 | out_channels=cout,
694 | up=True,
695 | **block_kwargs,
696 | )
697 | for idx in range(num_blocks + 1):
698 | cin = cout + skips.pop()
699 | cout = model_channels * mult
700 | self.dec[f"{res}x{res}_block{idx}"] = UNetBlock(
701 | in_channels=cin,
702 | out_channels=cout,
703 | attention=(res in attn_resolutions),
704 | **block_kwargs,
705 | )
706 |
707 | self.out_norm = GroupNorm(num_channels=cout)
708 | self.out_conv = Conv2d(
709 | in_channels=cout, out_channels=out_channels, kernel=3, **init_zero
710 | )
711 |
712 | def forward(
713 | self,
714 | x,
715 | noise_labels_t,
716 | noise_labels_s=None,
717 | class_labels=None,
718 | augment_labels=None,
719 | ):
720 |
721 | # Mapping.
722 | emb = self.map_noise(noise_labels_t)
723 | if self.map_augment is not None and augment_labels is not None:
724 | emb = emb + self.map_augment(augment_labels)
725 | emb = silu(self.map_layer0(emb))
726 | emb = self.map_layer1(emb)
727 | if self.map_label is not None:
728 | tmp = class_labels
729 | if self.training and self.label_dropout:
730 | tmp = tmp * (
731 | torch.rand([x.shape[0], 1], device=x.device) >= self.label_dropout
732 | ).to(tmp.dtype)
733 | emb = emb + self.map_label(tmp)
734 | if noise_labels_s is not None and self.s_embed:
735 |
736 | emb_s = self.map_noise_s(noise_labels_s)
737 | emb_s = silu(self.map_layer0_s(emb_s))
738 | emb_s = self.map_layer1_s(emb_s)
739 | emb = emb + emb_s
740 |
741 | emb = silu(emb)
742 |
743 | # Encoder.
744 | skips = []
745 | for block in self.enc.values():
746 | x = block(x, emb) if isinstance(block, UNetBlock) else block(x)
747 | skips.append(x)
748 |
749 | # Decoder.
750 | for block in self.dec.values():
751 | if x.shape[1] != block.in_channels:
752 | x = torch.cat([x, skips.pop()], dim=1)
753 | x = block(x, emb)
754 | x = self.out_conv(silu(self.out_norm(x)))
755 | return x
756 |
757 |
758 |
--------------------------------------------------------------------------------