├── .gitignore
├── LICENSE
├── PROMPT_GUIDE.md
├── README.md
├── assets
│ ├── cond_and_image.jpg
│ ├── examples
│ │ ├── canny
│ │ │ └── canny.webp
│ │ ├── depth
│ │ │ └── astronaut.webp
│ │ ├── id_customization
│ │ │ ├── chenhao
│ │ │ │ ├── image_0.png
│ │ │ │ ├── image_1.png
│ │ │ │ └── image_2.png
│ │ │ └── chika
│ │ │   ├── image_0.png
│ │ │   └── image_1.webp
│ │ ├── image_editing
│ │ │ └── astronaut.webp
│ │ ├── images
│ │ │ ├── cat.webp
│ │ │ └── cat_on_table.webp
│ │ ├── inpainting
│ │ │ └── giorno.webp
│ │ ├── semantic_map
│ │ │ ├── dragon.webp
│ │ │ ├── dragon_and_birds.png
│ │ │ └── dragon_birds_woman.webp
│ │ └── subject_driven
│ │   └── chill_guy.jpg
│ ├── gradio_cached_examples
│ │ └── 42
│ │   ├── Generated Images
│ │   │ ├── 17b88522de30372d1f35
│ │   │ │ └── image.webp
│ │   │ ├── 317b1e50f07fbdcc2f45
│ │   │ │ └── image.webp
│ │   │ ├── 53575bb9e63d8bc1a71c
│ │   │ │ └── image.webp
│ │   │ ├── 581111895a66f93b1f10
│ │   │ │ └── image.webp
│ │   │ ├── 5d11a379682726a1dc31
│ │   │ │ └── image.webp
│ │   │ ├── 6beeb89d25287417e72d
│ │   │ │ └── image.webp
│ │   │ ├── 80d358ca48472ae5c466
│ │   │ │ └── image.webp
│ │   │ ├── 81bd6641e26a411b2da6
│ │   │ │ └── image.webp
│ │   │ ├── 892c2282d95cba917fa2
│ │   │ │ └── image.webp
│ │   │ ├── 8a9a8c9e07527586fc15
│ │   │ │ └── image.webp
│ │   │ ├── 8fd76ebb6657d26dd92c
│ │   │ │ └── image.webp
│ │   │ ├── 930ca5ca2692b479b752
│ │   │ │ └── image.webp
│ │   │ ├── 96928c780e841920bfcc
│ │   │ │ └── image.webp
│ │   │ ├── 9bdb84088da3c475e4bd
│ │   │ │ └── image.webp
│ │   │ ├── b6e34a7e72d74a80b52f
│ │   │ │ └── image.webp
│ │   │ ├── c8580411711e5880bb78
│ │   │ │ └── image.webp
│ │   │ └── fe7d4925dcd9c0186674
│ │   │   └── image.webp
│ │   ├── Image Preview
│ │   │ ├── 06eab237f5c35c1b745f
│ │   │ │ └── image.webp
│ │   │ ├── 08c2007af01acb165e9a
│ │   │ │ └── image.webp
│ │   │ ├── 133fc4374bccf2d6b2ca
│ │   │ │ └── image.webp
│ │   │ ├── 2b1daf3b917623fd8ff2
│ │   │ │ └── image.webp
│ │   │ ├── 2d82244403afa8c210e3
│ │   │ │ └── image.webp
│ │   │ ├── 303d58e0a1cdb5b953f9
│ │   │ │ └── image.webp
│ │   │ ├── 4284edec7b2efdd3fef1
│ │   │ │ └── image.webp
│ │   │ ├── 8348cbc4b20a443fc03b
│ │   │ │ └── image.webp
│ │   │ ├── 86bf806faee9fbb5d92e
│ │   │ │ └── image.webp
│ │   │ ├── 9065eadbcdacd92fe32f
│ │   │ │ └── image.webp
│ │   │ └── e47602f78caeedc4c6d3
│ │   │   └── image.webp
│ │   ├── Input Images
│ │   │ ├── 0ca5ab7959eed9fef525
│ │   │ │ └── astronaut.webp
│ │   │ ├── 10e196d148112fc7eb6a
│ │   │ │ └── dragon_birds_woman.webp
│ │   │ ├── 3a98c4847c8c4e64851e
│ │   │ │ └── giorno.webp
│ │   │ ├── 43c8d5a2278e70960d7b
│ │   │ │ └── image_1.png
│ │   │ ├── 491863b2f47c65f9ac67
│ │   │ │ └── image_0.png
│ │   │ ├── 7a5a351a562f9423582e
│ │   │ │ └── image_2.png
│ │   │ ├── 830bb33b0629cc76770d
│ │   │ │ └── cat_on_table.webp
│ │   │ ├── bbfa399fd2eb84e96645
│ │   │ │ └── chill_guy.jpg
│ │   │ ├── c2c85d82ceb9a9574063
│ │   │ │ └── image_0.png
│ │   │ ├── c875f74b3d73b2ea804f
│ │   │ │ └── astronaut.webp
│ │   │ └── f63287119763bed48822
│ │   │   └── cat.webp
│ │   └── log.csv
│ ├── onediffusion_appendix_faceid.jpg
│ ├── onediffusion_appendix_faceid_3.jpg
│ ├── onediffusion_appendix_multiview.jpg
│ ├── onediffusion_appendix_multiview_2.jpg
│ ├── onediffusion_appendix_text2multiview.pdf
│ ├── onediffusion_editing.jpg
│ ├── onediffusion_zeroshot.jpg
│ ├── promptguide_complex.jpg
│ ├── promptguide_idtask.jpg
│ ├── subject_driven.jpg
│ ├── teaser.png
│ ├── text2image.jpg
│ └── text2multiview.jpg
├── docker
│ └── Dockerfile
├── gradio_demo.py
├── inference.py
├── onediffusion
│ ├── dataset
│ │ ├── multitask
│ │ │ └── multiview.py
│ │ ├── raydiff_utils.py
│ │ ├── transforms.py
│ │ └── utils.py
│ ├── diffusion
│ │ └── pipelines
│ │   ├── image_processor.py
│ │   └── onediffusion.py
│ └── models
│   └── denoiser
│     ├── __init__.py
│     └── nextdit
│       ├── __init__.py
│       ├── layers.py
│       └── modeling_nextdit.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 | .vscode/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 |
142 | Section 2 -- Scope.
143 |
144 | a. License grant.
145 |
146 | 1. Subject to the terms and conditions of this Public License,
147 | the Licensor hereby grants You a worldwide, royalty-free,
148 | non-sublicensable, non-exclusive, irrevocable license to
149 | exercise the Licensed Rights in the Licensed Material to:
150 |
151 | a. reproduce and Share the Licensed Material, in whole or
152 | in part, for NonCommercial purposes only; and
153 |
154 | b. produce, reproduce, and Share Adapted Material for
155 | NonCommercial purposes only.
156 |
157 | 2. Exceptions and Limitations. For the avoidance of doubt, where
158 | Exceptions and Limitations apply to Your use, this Public
159 | License does not apply, and You do not need to comply with
160 | its terms and conditions.
161 |
162 | 3. Term. The term of this Public License is specified in Section
163 | 6(a).
164 |
165 | 4. Media and formats; technical modifications allowed. The
166 | Licensor authorizes You to exercise the Licensed Rights in
167 | all media and formats whether now known or hereafter created,
168 | and to make technical modifications necessary to do so. The
169 | Licensor waives and/or agrees not to assert any right or
170 | authority to forbid You from making technical modifications
171 | necessary to exercise the Licensed Rights, including
172 | technical modifications necessary to circumvent Effective
173 | Technological Measures. For purposes of this Public License,
174 | simply making modifications authorized by this Section 2(a)
175 | (4) never produces Adapted Material.
176 |
177 | 5. Downstream recipients.
178 |
179 | a. Offer from the Licensor -- Licensed Material. Every
180 | recipient of the Licensed Material automatically
181 | receives an offer from the Licensor to exercise the
182 | Licensed Rights under the terms and conditions of this
183 | Public License.
184 |
185 | b. No downstream restrictions. You may not offer or impose
186 | any additional or different terms or conditions on, or
187 | apply any Effective Technological Measures to, the
188 | Licensed Material if doing so restricts exercise of the
189 | Licensed Rights by any recipient of the Licensed
190 | Material.
191 |
192 | 6. No endorsement. Nothing in this Public License constitutes or
193 | may be construed as permission to assert or imply that You
194 | are, or that Your use of the Licensed Material is, connected
195 | with, or sponsored, endorsed, or granted official status by,
196 | the Licensor or others designated to receive attribution as
197 | provided in Section 3(a)(1)(A)(i).
198 |
199 | b. Other rights.
200 |
201 | 1. Moral rights, such as the right of integrity, are not
202 | licensed under this Public License, nor are publicity,
203 | privacy, and/or other similar personality rights; however, to
204 | the extent possible, the Licensor waives and/or agrees not to
205 | assert any such rights held by the Licensor to the limited
206 | extent necessary to allow You to exercise the Licensed
207 | Rights, but not otherwise.
208 |
209 | 2. Patent and trademark rights are not licensed under this
210 | Public License.
211 |
212 | 3. To the extent possible, the Licensor waives any right to
213 | collect royalties from You for the exercise of the Licensed
214 | Rights, whether directly or through a collecting society
215 | under any voluntary or waivable statutory or compulsory
216 | licensing scheme. In all other cases the Licensor expressly
217 | reserves any right to collect such royalties, including when
218 | the Licensed Material is used other than for NonCommercial
219 | purposes.
220 |
221 |
222 | Section 3 -- License Conditions.
223 |
224 | Your exercise of the Licensed Rights is expressly made subject to the
225 | following conditions.
226 |
227 | a. Attribution.
228 |
229 | 1. If You Share the Licensed Material (including in modified
230 | form), You must:
231 |
232 | a. retain the following if it is supplied by the Licensor
233 | with the Licensed Material:
234 |
235 | i. identification of the creator(s) of the Licensed
236 | Material and any others designated to receive
237 | attribution, in any reasonable manner requested by
238 | the Licensor (including by pseudonym if
239 | designated);
240 |
241 | ii. a copyright notice;
242 |
243 | iii. a notice that refers to this Public License;
244 |
245 | iv. a notice that refers to the disclaimer of
246 | warranties;
247 |
248 | v. a URI or hyperlink to the Licensed Material to the
249 | extent reasonably practicable;
250 |
251 | b. indicate if You modified the Licensed Material and
252 | retain an indication of any previous modifications; and
253 |
254 | c. indicate the Licensed Material is licensed under this
255 | Public License, and include the text of, or the URI or
256 | hyperlink to, this Public License.
257 |
258 | 2. You may satisfy the conditions in Section 3(a)(1) in any
259 | reasonable manner based on the medium, means, and context in
260 | which You Share the Licensed Material. For example, it may be
261 | reasonable to satisfy the conditions by providing a URI or
262 | hyperlink to a resource that includes the required
263 | information.
264 |
265 | 3. If requested by the Licensor, You must remove any of the
266 | information required by Section 3(a)(1)(A) to the extent
267 | reasonably practicable.
268 |
269 | 4. If You Share Adapted Material You produce, the Adapter's
270 | License You apply must not prevent recipients of the Adapted
271 | Material from complying with this Public License.
272 |
273 |
274 | Section 4 -- Sui Generis Database Rights.
275 |
276 | Where the Licensed Rights include Sui Generis Database Rights that
277 | apply to Your use of the Licensed Material:
278 |
279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280 | to extract, reuse, reproduce, and Share all or a substantial
281 | portion of the contents of the database for NonCommercial purposes
282 | only;
283 |
284 | b. if You include all or a substantial portion of the database
285 | contents in a database in which You have Sui Generis Database
286 | Rights, then the database in which You have Sui Generis Database
287 | Rights (but not its individual contents) is Adapted Material; and
288 |
289 | c. You must comply with the conditions in Section 3(a) if You Share
290 | all or a substantial portion of the contents of the database.
291 |
292 | For the avoidance of doubt, this Section 4 supplements and does not
293 | replace Your obligations under this Public License where the Licensed
294 | Rights include other Copyright and Similar Rights.
295 |
296 |
297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298 |
299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309 |
310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319 |
320 | c. The disclaimer of warranties and limitation of liability provided
321 | above shall be interpreted in a manner that, to the extent
322 | possible, most closely approximates an absolute disclaimer and
323 | waiver of all liability.
324 |
325 |
326 | Section 6 -- Term and Termination.
327 |
328 | a. This Public License applies for the term of the Copyright and
329 | Similar Rights licensed here. However, if You fail to comply with
330 | this Public License, then Your rights under this Public License
331 | terminate automatically.
332 |
333 | b. Where Your right to use the Licensed Material has terminated under
334 | Section 6(a), it reinstates:
335 |
336 | 1. automatically as of the date the violation is cured, provided
337 | it is cured within 30 days of Your discovery of the
338 | violation; or
339 |
340 | 2. upon express reinstatement by the Licensor.
341 |
342 | For the avoidance of doubt, this Section 6(b) does not affect any
343 | right the Licensor may have to seek remedies for Your violations
344 | of this Public License.
345 |
346 | c. For the avoidance of doubt, the Licensor may also offer the
347 | Licensed Material under separate terms or conditions or stop
348 | distributing the Licensed Material at any time; however, doing so
349 | will not terminate this Public License.
350 |
351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352 | License.
353 |
354 |
355 | Section 7 -- Other Terms and Conditions.
356 |
357 | a. The Licensor shall not be bound by any additional or different
358 | terms or conditions communicated by You unless expressly agreed.
359 |
360 | b. Any arrangements, understandings, or agreements regarding the
361 | Licensed Material not stated herein are separate from and
362 | independent of the terms and conditions of this Public License.
363 |
364 |
365 | Section 8 -- Interpretation.
366 |
367 | a. For the avoidance of doubt, this Public License does not, and
368 | shall not be interpreted to, reduce, limit, restrict, or impose
369 | conditions on any use of the Licensed Material that could lawfully
370 | be made without permission under this Public License.
371 |
372 | b. To the extent possible, if any provision of this Public License is
373 | deemed unenforceable, it shall be automatically reformed to the
374 | minimum extent necessary to make it enforceable. If the provision
375 | cannot be reformed, it shall be severed from this Public License
376 | without affecting the enforceability of the remaining terms and
377 | conditions.
378 |
379 | c. No term or condition of this Public License will be waived and no
380 | failure to comply consented to unless expressly agreed to by the
381 | Licensor.
382 |
383 | d. Nothing in this Public License constitutes or may be interpreted
384 | as a limitation upon, or waiver of, any privileges and immunities
385 | that apply to the Licensor or You, including from the legal
386 | processes of any jurisdiction or authority.
387 |
388 | =======================================================================
389 |
390 | Creative Commons is not a party to its public
391 | licenses. Notwithstanding, Creative Commons may elect to apply one of
392 | its public licenses to material it publishes and in those instances
393 | will be considered the “Licensor.” The text of the Creative Commons
394 | public licenses is dedicated to the public domain under the CC0 Public
395 | Domain Dedication. Except for the limited purpose of indicating that
396 | material is shared under a Creative Commons public license or as
397 | otherwise permitted by the Creative Commons policies published at
398 | creativecommons.org/policies, Creative Commons does not authorize the
399 | use of the trademark "Creative Commons" or any other trademark or logo
400 | of Creative Commons without its prior written consent including,
401 | without limitation, in connection with any unauthorized modifications
402 | to any of its public licenses or any other arrangements,
403 | understandings, or agreements concerning use of licensed material. For
404 | the avoidance of doubt, this paragraph does not form part of the
405 | public licenses.
406 |
407 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/PROMPT_GUIDE.md:
--------------------------------------------------------------------------------
1 | # Prompt Guide
2 |
3 | All examples are generated with a CFG scale of $4.2$ and $50$ steps, and are not cherry-picked unless otherwise stated. The negative prompt is set to:
4 | ```
5 | monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation
6 | ```
7 |
8 | ## 1. Text-to-Image
9 |
10 | ### 1.1 Long and detailed prompts give (much) better results.
11 |
12 | Since our training data consisted of long and detailed prompts, the model is more likely to generate better images when given detailed prompts.
13 |
14 |
15 | The model shows good text adherence with long and complex prompts, as shown in the images below. We use the first $20$ prompts from [simoryu's examples](https://cloneofsimo.github.io/compare_aura_sd3/). For the full prompts and results from other models, refer to the link above.
16 |
17 |
18 |
19 |
20 |
21 |
22 | ### 1.2 Resolution
23 |
24 | For text-to-image, the model generally works well with height and width in the range $[768; 1280]$ (height and width must be divisible by 16). For other tasks, it performs best at resolutions around $512$. A small helper for snapping a requested size to these constraints is sketched below.
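
As an illustration of these constraints, here is a minimal sketch; the `snap_resolution` helper is hypothetical and not part of the repository, it simply clamps a requested size to the recommended range and rounds it down to a multiple of 16:

```python
def snap_resolution(height, width, low=768, high=1280):
    """Clamp a requested size to [low, high] and round down to a multiple of 16."""
    def snap(v):
        v = max(low, min(high, v))  # keep within the recommended range
        return (v // 16) * 16       # height/width must be divisible by 16
    return snap(height), snap(width)

print(snap_resolution(1000, 777))  # -> (992, 768)
```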
25 |
26 | ## 2. ID Customization & Subject-driven generation
27 |
28 | - The expected length of source captions is $30$ to $75$ words. Empirically, we find that longer prompts can help preserve the ID better, but they might hinder text adherence for the target caption.
29 |
30 | - We find it better to add some descriptions (e.g., from the source caption) to the target caption to preserve the identity, especially for complex subjects with delicate details.
31 |
32 |
33 |
34 |
35 |
36 | ## 3. Multiview generation
37 |
38 | We recommend not using captions that describe facial features (e.g., "looking at the camera") to mitigate multi-faced/Janus problems.
39 |
40 | ## 4. Image editing
41 |
42 | We find it is generally better to set the guidance scale to a lower value, e.g., $[3; 3.5]$, to avoid over-saturated results.
43 |
44 | ## 5. Special tokens and available colors
45 |
46 | ### 5.1 Task Tokens
47 |
48 | | Task | Token | Additional Tokens |
49 | |:---------------------|:---------------------------|:------------------|
50 | | Text to Image | `[[text2image]]` | |
51 | | Deblurring | `[[deblurring]]` | |
52 | | Inpainting | `[[image_inpainting]]` | |
53 | | Canny-edge and Image | `[[canny2image]]` | |
54 | | Depth and Image | `[[depth2image]]` | |
55 | | Hed and Image | `[[hed2img]]` | |
56 | | Pose and Image | `[[pose2image]]` | |
57 | | Image editing with Instruction | `[[image_editing]]` | |
58 | | Semantic map and Image | `[[semanticmap2image]]` | `<#00FFFF cyan mask: object/to/segment>` |
59 | | Boundingbox and Image | `[[boundingbox2image]]` | `<#00FFFF cyan boundingbox: object/to/detect>` |
60 | | ID customization | `[[faceid]]` | `[[img0]] target/caption [[img1]] caption/of/source/image_1 [[img2]] caption/of/source/image_2 [[img3]] caption/of/source/image_3` |
61 | | Multiview | `[[multiview]]` | |
62 | | Subject-Driven | `[[subject_driven]]` | ` [[img0]] target/caption/goes/here [[img1]] insert/source/caption` |
63 |
64 |
65 | Note that you can replace the cyan color above with any color from the table below, and you can include multiple additional tokens to detect/segment multiple classes.
66 |
67 | ### 5.2 Available colors
68 |
69 |
70 | | Hex Code | Color Name |
71 | |:---------|:-----------|
72 | | #FF0000 | red |
73 | | #00FF00 | lime |
74 | | #0000FF | blue |
75 | | #FFFF00 | yellow |
76 | | #FF00FF | magenta |
77 | | #00FFFF | cyan |
78 | | #FFA500 | orange |
79 | | #800080 | purple |
80 | | #A52A2A | brown |
81 | | #008000 | green |
82 | | #FFC0CB | pink |
83 | | #008080 | teal |
84 | | #FF8C00 | darkorange |
85 | | #8A2BE2 | blueviolet |
86 | | #006400 | darkgreen |
87 | | #FF4500 | orangered |
88 | | #000080 | navy |
89 | | #FFD700 | gold |
90 | | #40E0D0 | turquoise |
91 | | #DA70D6 | orchid |
92 |
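As a rough, non-authoritative sketch of how the tokens in the two tables above can be combined, the helpers below assemble prompt strings for the semantic-map and ID-customization tasks. These functions are our own illustration (not repository code), and the exact placement of the caption relative to the additional tokens is an assumption; the gradio demo formats prompts for you.

```python
def semantic_map_prompt(caption, regions):
    """regions: list of (hex_code, color_name, class_name) tuples from the color table above."""
    tokens = " ".join(f"<{hex_code} {name} mask: {cls}>" for hex_code, name, cls in regions)
    return f"[[semanticmap2image]] {tokens} {caption}"

def faceid_prompt(target_caption, source_captions):
    """ID customization: [[img0]] carries the target caption, [[img1]]..[[imgN]] the source captions."""
    parts = ["[[faceid]]", "[[img0]]", target_caption]
    for i, cap in enumerate(source_captions, start=1):
        parts += [f"[[img{i}]]", cap]
    return " ".join(parts)

print(semantic_map_prompt(
    "a dragon flying over a medieval town",
    [("#00FFFF", "cyan", "dragon"), ("#FF0000", "red", "town")],
))
print(faceid_prompt(
    "a woman in a red kimono standing in a bamboo forest",
    ["a close-up photo of a woman with long black hair"],
))
```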
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # One Diffusion to Generate Them All
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | News |
21 | Quick start |
22 | Prompt guide & Supported tasks |
23 | Qualitative results |
24 | License |
25 | Citation
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 | This is the official repo of OneDiffusion, a versatile, large-scale diffusion model that seamlessly supports bidirectional image synthesis and understanding across diverse tasks.
36 |
37 | **For more detail, read our paper [here](https://arxiv.org/abs/2411.16318).**
38 |
39 | ## News
40 | - 📦 2024/12/11: [Huggingface space](https://huggingface.co/spaces/lehduong/OneDiffusion) is online. Reduced the VRAM requirement for running the demo with Molmo to 21 GB.
41 | - 📦 2024/12/10: Released [weights](https://huggingface.co/lehduong/OneDiffusion) and inference code.
42 | - ✨ 2024/12/06: Added instruction-based image editing.
43 | - ✨ 2024/12/02: Added subject-driven generation.
44 |
45 | ## Installation
46 | ```
47 | conda create -n onediffusion_env python=3.8 &&
48 | conda activate onediffusion_env &&
49 | pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118 &&
50 | pip install "git+https://github.com/facebookresearch/pytorch3d.git" &&
51 | pip install -r requirements.txt
52 | ```
53 |
54 | ## Quick start
55 |
56 | Check `inference.py` for more details. For text-to-image, you can use the code snippet below.
57 |
58 | ```python
59 | import torch
60 | from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline
61 |
62 | device = torch.device('cuda:0')
63 |
64 | pipeline = OneDiffusionPipeline.from_pretrained("lehduong/OneDiffusion").to(device=device, dtype=torch.bfloat16)
65 |
66 | NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"
67 |
68 | output = pipeline(
69 | prompt="[[text2image]] A bipedal black cat wearing a huge oversized witch hat, a wizards robe, casting a spell,in an enchanted forest. The scene is filled with fireflies and moss on surrounding rocks and trees",
70 | negative_prompt=NEGATIVE_PROMPT,
71 | num_inference_steps=50,
72 | guidance_scale=4,
73 | height=1024,
74 | width=1024,
75 | )
76 | output.images[0].save('text2image_output.jpg')
77 | ```
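
The same call works for other documented settings. As a minimal variation of the snippet above (the prompt and output filename are our own), this uses the CFG of $4.2$ and $50$ steps from [PROMPT_GUIDE.md](PROMPT_GUIDE.md) with a portrait resolution inside the recommended range:

```python
# Reuses `pipeline` and NEGATIVE_PROMPT from the snippet above; only documented
# arguments change (steps/CFG follow PROMPT_GUIDE.md, resolution is divisible by 16).
output = pipeline(
    prompt="[[text2image]] A serene mountain lake at sunrise, mist rolling over the water, pine trees along the shore, ultra detailed, cinematic lighting",
    negative_prompt=NEGATIVE_PROMPT,
    num_inference_steps=50,
    guidance_scale=4.2,
    height=1280,
    width=768,
)
output.images[0].save('text2image_portrait.jpg')
```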
78 |
79 | You can run the gradio demo with:
80 | ```
81 | python gradio_demo.py --captioner molmo # [molmo, llava, disable]
82 | ```
83 | The demo provides guidance and helps format the prompt properly for each task.
84 |
85 | - By default, it loads the **quantized** Molmo for captioning source images. ~~which significantly increases memory usage. You generally need a GPU with at least $40$ GB of memory to run the demo.~~ You generally need a GPU with at least $21$ GB of memory to run the demo.
86 |
87 | - Opting to use LLaVA can reduce this requirement to $\approx 27$ GB, though the resulting captions may be less accurate in some cases.
88 |
89 | - You can also manually provide the caption for each input image and run with `disable` mode. In this mode, the returned caption is an empty string, but you should still press the `Generate Caption` button so that the code formats the input text properly. The memory requirement for this mode is $\approx 12$ GB.
90 |
91 | Note that the memory requirements above can change if you use a higher resolution or more input images.
92 |
93 | ## Qualitative Results
94 |
95 | ### 1. Text-to-Image
96 |
97 |
98 |
99 |
100 |
101 | ### 2. ID customization
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 | ### 3. Multiview generation
112 |
113 | Single image to multiview:
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 | Text to multiview:
124 |
125 |
126 |
127 |
128 |
129 | ### 4. Condition-to-Image and vice versa
130 |
131 |
132 |
133 |
134 | ### 5. Subject-driven generation
135 |
136 | We finetuned the model on the [Subject-200K](https://huggingface.co/datasets/Yuanshi/Subjects200K) dataset (along with all other tasks) for an additional 40k steps. The model is now capable of subject-driven generation.
137 |
138 |
139 |
140 |
141 |
142 | ### 6. Text-guided image editing
143 |
144 | We finetuned the model on the [OmniEdit](https://huggingface.co/datasets/TIGER-Lab/OmniEdit-Filtered-1.2M) dataset for an additional 30K steps.
145 |
146 |
147 |
148 |
149 |
150 | ### 7. Zero-shot Task combinations
151 |
152 | We found that the model can handle multiple tasks in a zero-shot setting by combining condition images and task tokens without any fine-tuning, as shown in the examples below. However, its performance on these combined tasks might not be robust, and the model’s behavior may change if the order of task tokens or captions is altered. For example, when using both image inpainting and ID customization together, the target prompt and the caption of the masked image must be identical. If you plan to use such combinations, we recommend fine-tuning the model on these tasks to achieve better performance and simpler usage.
153 |
154 |
155 |
156 |
157 |
158 |
159 | ## License
160 |
161 | The model is trained on several non-commercially licensed datasets (e.g., DL3DV, Unsplash); thus, the **model weights** are released under a CC BY-NC license as described in [LICENSE](https://github.com/lehduong/onediffusion/blob/main/LICENSE).
162 |
163 | ## Citation
164 |
165 | ```bibtex
166 | @misc{le2024diffusiongenerate,
167 | title={One Diffusion to Generate Them All},
168 | author={Duong H. Le and Tuan Pham and Sangho Lee and Christopher Clark and Aniruddha Kembhavi and Stephan Mandt and Ranjay Krishna and Jiasen Lu},
169 | year={2024},
170 | eprint={2411.16318},
171 | archivePrefix={arXiv},
172 | primaryClass={cs.CV},
173 | url={https://arxiv.org/abs/2411.16318},
174 | }
175 | ```
176 |
--------------------------------------------------------------------------------
/assets/cond_and_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/cond_and_image.jpg
--------------------------------------------------------------------------------
/assets/examples/canny/canny.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/canny/canny.webp
--------------------------------------------------------------------------------
/assets/examples/depth/astronaut.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/depth/astronaut.webp
--------------------------------------------------------------------------------
/assets/examples/id_customization/chenhao/image_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/id_customization/chenhao/image_0.png
--------------------------------------------------------------------------------
/assets/examples/id_customization/chenhao/image_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/id_customization/chenhao/image_1.png
--------------------------------------------------------------------------------
/assets/examples/id_customization/chenhao/image_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/id_customization/chenhao/image_2.png
--------------------------------------------------------------------------------
/assets/examples/id_customization/chika/image_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/id_customization/chika/image_0.png
--------------------------------------------------------------------------------
/assets/examples/id_customization/chika/image_1.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/id_customization/chika/image_1.webp
--------------------------------------------------------------------------------
/assets/examples/image_editing/astronaut.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/image_editing/astronaut.webp
--------------------------------------------------------------------------------
/assets/examples/images/cat.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/images/cat.webp
--------------------------------------------------------------------------------
/assets/examples/images/cat_on_table.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/images/cat_on_table.webp
--------------------------------------------------------------------------------
/assets/examples/inpainting/giorno.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/inpainting/giorno.webp
--------------------------------------------------------------------------------
/assets/examples/semantic_map/dragon.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/semantic_map/dragon.webp
--------------------------------------------------------------------------------
/assets/examples/semantic_map/dragon_and_birds.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/semantic_map/dragon_and_birds.png
--------------------------------------------------------------------------------
/assets/examples/semantic_map/dragon_birds_woman.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/semantic_map/dragon_birds_woman.webp
--------------------------------------------------------------------------------
/assets/examples/subject_driven/chill_guy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/subject_driven/chill_guy.jpg
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/17b88522de30372d1f35/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/17b88522de30372d1f35/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/317b1e50f07fbdcc2f45/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/317b1e50f07fbdcc2f45/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/53575bb9e63d8bc1a71c/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/53575bb9e63d8bc1a71c/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/581111895a66f93b1f10/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/581111895a66f93b1f10/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/5d11a379682726a1dc31/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/5d11a379682726a1dc31/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/6beeb89d25287417e72d/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/6beeb89d25287417e72d/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/80d358ca48472ae5c466/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/80d358ca48472ae5c466/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/81bd6641e26a411b2da6/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/81bd6641e26a411b2da6/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/892c2282d95cba917fa2/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/892c2282d95cba917fa2/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/8a9a8c9e07527586fc15/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/8a9a8c9e07527586fc15/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/8fd76ebb6657d26dd92c/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/8fd76ebb6657d26dd92c/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/930ca5ca2692b479b752/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/930ca5ca2692b479b752/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/96928c780e841920bfcc/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/96928c780e841920bfcc/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/9bdb84088da3c475e4bd/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/9bdb84088da3c475e4bd/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/b6e34a7e72d74a80b52f/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/b6e34a7e72d74a80b52f/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/c8580411711e5880bb78/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/c8580411711e5880bb78/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Generated Images/fe7d4925dcd9c0186674/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/fe7d4925dcd9c0186674/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Image Preview/06eab237f5c35c1b745f/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/06eab237f5c35c1b745f/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Image Preview/08c2007af01acb165e9a/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/08c2007af01acb165e9a/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Image Preview/133fc4374bccf2d6b2ca/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/133fc4374bccf2d6b2ca/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Image Preview/2b1daf3b917623fd8ff2/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/2b1daf3b917623fd8ff2/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Image Preview/2d82244403afa8c210e3/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/2d82244403afa8c210e3/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Image Preview/303d58e0a1cdb5b953f9/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/303d58e0a1cdb5b953f9/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Image Preview/4284edec7b2efdd3fef1/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/4284edec7b2efdd3fef1/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Image Preview/8348cbc4b20a443fc03b/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/8348cbc4b20a443fc03b/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Image Preview/86bf806faee9fbb5d92e/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/86bf806faee9fbb5d92e/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Image Preview/9065eadbcdacd92fe32f/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/9065eadbcdacd92fe32f/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Image Preview/e47602f78caeedc4c6d3/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/e47602f78caeedc4c6d3/image.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Input Images/0ca5ab7959eed9fef525/astronaut.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/0ca5ab7959eed9fef525/astronaut.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Input Images/10e196d148112fc7eb6a/dragon_birds_woman.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/10e196d148112fc7eb6a/dragon_birds_woman.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Input Images/3a98c4847c8c4e64851e/giorno.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/3a98c4847c8c4e64851e/giorno.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Input Images/43c8d5a2278e70960d7b/image_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/43c8d5a2278e70960d7b/image_1.png
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Input Images/491863b2f47c65f9ac67/image_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/491863b2f47c65f9ac67/image_0.png
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Input Images/7a5a351a562f9423582e/image_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/7a5a351a562f9423582e/image_2.png
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Input Images/830bb33b0629cc76770d/cat_on_table.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/830bb33b0629cc76770d/cat_on_table.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Input Images/bbfa399fd2eb84e96645/chill_guy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/bbfa399fd2eb84e96645/chill_guy.jpg
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Input Images/c2c85d82ceb9a9574063/image_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/c2c85d82ceb9a9574063/image_0.png
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Input Images/c875f74b3d73b2ea804f/astronaut.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/c875f74b3d73b2ea804f/astronaut.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/Input Images/f63287119763bed48822/cat.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/f63287119763bed48822/cat.webp
--------------------------------------------------------------------------------
/assets/gradio_cached_examples/42/log.csv:
--------------------------------------------------------------------------------
1 | Generated Images,Generation Status,Input Images,component 3,Image Preview,flag,username,timestamp
2 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/b6e34a7e72d74a80b52f/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/d8cd3b6318cdd3c5d4bdd4c37c9a2f2df0e47a91647e22cde72befa22a9bb748/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,[],,[],,,2024-12-13 20:14:34.209093
3 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/17b88522de30372d1f35/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/cc0dc567fb017145e3fbf502ef9e10ba9cefa8c45e15c62d3d060eaea986f6b1/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/491863b2f47c65f9ac67/image_0.png"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image_0.png"", ""size"": 1191470, ""orig_name"": ""image_0.png"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/8348cbc4b20a443fc03b/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:14:49.555938
4 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/5d11a379682726a1dc31/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/26f13a9435ed998c8ef570925c40d60345f5ca26b27a7d08c3faa76b397455b5/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/c2c85d82ceb9a9574063/image_0.png"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image_0.png"", ""size"": 628293, ""orig_name"": ""image_0.png"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""assets/gradio_cached_examples/42/Input Images/43c8d5a2278e70960d7b/image_1.png"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image_1.png"", ""size"": 551085, ""orig_name"": ""image_1.png"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""assets/gradio_cached_examples/42/Input Images/7a5a351a562f9423582e/image_2.png"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image_2.png"", ""size"": 620161, ""orig_name"": ""image_2.png"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/9065eadbcdacd92fe32f/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/2d82244403afa8c210e3/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/86bf806faee9fbb5d92e/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:15:13.029766
5 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/fe7d4925dcd9c0186674/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/6be14ba2f6c3de6620f526c45f5043c6d252a4cfb1d2400a168e70f005ec4a61/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/930ca5ca2692b479b752/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/53575bb9e63d8bc1a71c/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/830bb33b0629cc76770d/cat_on_table.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/cat_on_table.webp"", ""size"": 41180, ""orig_name"": ""cat_on_table.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/133fc4374bccf2d6b2ca/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:16:09.060610
6 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/581111895a66f93b1f10/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/ce1da06372f563fdca891efc30288e2a85d568f2351b54c36a19b75a77ac8d8b/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/10e196d148112fc7eb6a/dragon_birds_woman.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/dragon_birds_woman.webp"", ""size"": 16180, ""orig_name"": ""dragon_birds_woman.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/2b1daf3b917623fd8ff2/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:16:19.208900
7 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/892c2282d95cba917fa2/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/441882b3d10fbb91786980f9c484e518b3a8056558a7fad1fbe4d7db4aec408e/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/bbfa399fd2eb84e96645/chill_guy.jpg"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/chill_guy.jpg"", ""size"": 61273, ""orig_name"": ""chill_guy.jpg"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/303d58e0a1cdb5b953f9/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:16:33.375436
8 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/317b1e50f07fbdcc2f45/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/12e1af3a972f41b652b3ec04cd7e256ef43815e0e8ded1a9d758ffc18585ada1/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/c875f74b3d73b2ea804f/astronaut.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/astronaut.webp"", ""size"": 5528, ""orig_name"": ""astronaut.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/06eab237f5c35c1b745f/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:16:43.378123
9 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/8a9a8c9e07527586fc15/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/1be0cd7d19f19d2ca9bef889b2d0a17c1ee6e2b1cab471db19103c116461a954/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/6beeb89d25287417e72d/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/f63287119763bed48822/cat.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/cat.webp"", ""size"": 57702, ""orig_name"": ""cat.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/e47602f78caeedc4c6d3/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:16:53.637761
10 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/80d358ca48472ae5c466/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/fd6c6cd46e17b3233178109e2fbd1593a88280484c665301fe329d14ba7dd913/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/0ca5ab7959eed9fef525/astronaut.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/astronaut.webp"", ""size"": 62630, ""orig_name"": ""astronaut.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/4284edec7b2efdd3fef1/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:17:05.900479
11 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/96928c780e841920bfcc/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/6f7388d917738246baaadb98c1a8f045374e9c182a3b7697c18a05f8f1bff3e6/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/81bd6641e26a411b2da6/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/8fd76ebb6657d26dd92c/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/c8580411711e5880bb78/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,[],,[],,,2024-12-13 20:18:01.760111
12 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/9bdb84088da3c475e4bd/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/00a1b6bb103d9ed89daaae43449e6975a39134b3f5ba2ce6e23b23f6c9f967f2/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/3a98c4847c8c4e64851e/giorno.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/giorno.webp"", ""size"": 17570, ""orig_name"": ""giorno.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/08c2007af01acb165e9a/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:18:11.526344
13 |
--------------------------------------------------------------------------------
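Note on assets/gradio_cached_examples/42/log.csv: each gallery column above ("Generated Images", "Input Images", "Image Preview") stores a JSON-encoded list of Gradio FileData entries inside a single CSV cell. A minimal sketch of reading the log offline with the standard library, assuming the schema shown above (the script itself is not part of the repo):

import csv
import json

# Walk the cached-example log and print which input images map to which generated images.
with open("assets/gradio_cached_examples/42/log.csv", newline="") as f:
    for row in csv.DictReader(f):
        generated = json.loads(row["Generated Images"]) if row["Generated Images"] else []
        inputs = json.loads(row["Input Images"]) if row["Input Images"] else []
        gen_paths = [item["image"]["path"] for item in generated]
        in_paths = [item["path"] for item in inputs]
        print(row["timestamp"], row["Generation Status"], in_paths, "->", gen_paths)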
/assets/onediffusion_appendix_faceid.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/onediffusion_appendix_faceid.jpg
--------------------------------------------------------------------------------
/assets/onediffusion_appendix_faceid_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/onediffusion_appendix_faceid_3.jpg
--------------------------------------------------------------------------------
/assets/onediffusion_appendix_multiview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/onediffusion_appendix_multiview.jpg
--------------------------------------------------------------------------------
/assets/onediffusion_appendix_multiview_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/onediffusion_appendix_multiview_2.jpg
--------------------------------------------------------------------------------
/assets/onediffusion_appendix_text2multiview.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/onediffusion_appendix_text2multiview.pdf
--------------------------------------------------------------------------------
/assets/onediffusion_editing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/onediffusion_editing.jpg
--------------------------------------------------------------------------------
/assets/onediffusion_zeroshot.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/onediffusion_zeroshot.jpg
--------------------------------------------------------------------------------
/assets/promptguide_complex.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/promptguide_complex.jpg
--------------------------------------------------------------------------------
/assets/promptguide_idtask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/promptguide_idtask.jpg
--------------------------------------------------------------------------------
/assets/subject_driven.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/subject_driven.jpg
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/teaser.png
--------------------------------------------------------------------------------
/assets/text2image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/text2image.jpg
--------------------------------------------------------------------------------
/assets/text2multiview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/text2multiview.jpg
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # Inspired by https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile
2 | # ARG COMPAT=0
3 | ARG PERSONAL=0
4 | # FROM nvidia/cuda:11.3.1-devel-ubuntu20.04 as base-0
5 | FROM nvcr.io/nvidia/pytorch:22.12-py3 as base
6 |
7 | ENV HOST docker
8 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
9 | # https://serverfault.com/questions/683605/docker-container-time-timezone-will-not-reflect-changes
10 | ENV TZ America/Los_Angeles
11 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
12 |
13 | # git for installing dependencies
14 | # tzdata to set time zone
15 | # wget and unzip to download data
16 | # [2021-09-09] TD: zsh, stow, subversion, fasd are for setting up my personal environment.
17 | # [2021-12-07] TD: openmpi-bin for MPI (multi-node training)
18 | RUN apt-get update && apt-get install -y --no-install-recommends \
19 | build-essential \
20 | cmake \
21 | curl \
22 | ca-certificates \
23 | sudo \
24 | less \
25 | htop \
26 | git \
27 | tzdata \
28 | wget \
29 | tmux \
30 | zip \
31 | unzip \
32 | zsh stow subversion fasd \
33 | && rm -rf /var/lib/apt/lists/*
34 | # openmpi-bin \
35 |
36 | # Allow running mpirun as root
37 | # ENV OMPI_ALLOW_RUN_AS_ROOT=1 OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1
38 |
39 | # # Create a non-root user and switch to it
40 | # RUN adduser --disabled-password --gecos '' --shell /bin/bash user \
41 | # && echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user
42 | # USER user
43 |
44 | # All users can use /home/user as their home directory
45 | ENV HOME=/home/user
46 | RUN mkdir -p /home/user && chmod 777 /home/user
47 | WORKDIR /home/user
48 |
49 | # Set up personal environment
50 | # FROM base-${COMPAT} as env-0
51 | FROM base as env-0
52 | FROM env-0 as env-1
53 | # Use ONBUILD so that the dotfiles dir doesn't need to exist unless we're building a personal image
54 | # https://stackoverflow.com/questions/31528384/conditional-copy-add-in-dockerfile
55 | ONBUILD COPY dotfiles ./dotfiles
56 | ONBUILD RUN cd ~/dotfiles && stow bash zsh tmux && sudo chsh -s /usr/bin/zsh $(whoami)
57 | # nvcr pytorch image sets SHELL=/bin/bash
58 | ONBUILD ENV SHELL=/bin/zsh
59 |
60 | FROM env-${PERSONAL} as packages
61 |
62 | # Disable pip cache: https://stackoverflow.com/questions/45594707/what-is-pips-no-cache-dir-good-for
63 | ENV PIP_NO_CACHE_DIR=1
64 |
65 | # # apex and pytorch-fast-transformers take a while to compile so we install them first
66 | # TD [2022-04-28] apex is already installed. In case we need a newer commit:
67 | # RUN pip install --upgrade --force-reinstall --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" --global-option="--fmha" --global-option="--fast_layer_norm" --global-option="--xentropy" git+https://github.com/NVIDIA/apex.git#egg=apex
68 |
69 | # xgboost conflicts with deepspeed
70 | RUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.7
71 |
72 | # General packages that we don't care about the version
73 | # zstandard to extract the_pile dataset
74 | # psutil to get the number of cpu physical cores
75 | # twine to upload package to PyPI
76 | RUN pip install pytest matplotlib jupyter ipython ipdb gpustat scikit-learn spacy munch einops opt_einsum fvcore gsutil cmake pykeops zstandard psutil h5py twine gdown \
77 | && python -m spacy download en_core_web_sm
78 | # hydra
79 | RUN pip install hydra-core==1.3.1 hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 pyrootutils rich
80 | # Core packages
81 | RUN pip install transformers==4.45.2 datasets==3.0.1 pytorch-lightning==2.2.1 triton==2.3.1 wandb==0.16.3 controlnet_aux==0.0.9 timm==0.6.7 torchmetrics==1.3.2
82 | # torchmetrics 0.11.0 broke hydra's instantiate
83 |
84 | # For MLPerf
85 | RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
86 |
87 | RUN pip install accelerate==0.34.2
88 |
89 | RUN pip install diffusers==0.30.3
90 |
91 | RUN pip install deepspeed==0.15.2
92 |
93 | RUN pip install sentencepiece==0.1.99
94 |
95 | RUN pip install pillow==10.2.0
96 |
97 | RUN pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118
98 |
99 | # Install FlashAttention
100 | RUN pip install flash-attn==2.6.3
101 |
102 | # Install CUDA extensions for fused dense
103 | RUN pip install git+https://github.com/Dao-AILab/flash-attention@v2.6.3#subdirectory=csrc/fused_dense_lib
104 |
105 | RUN pip install jaxtyping mediapipe gradio
106 |
107 | RUN pip install "git+https://github.com/facebookresearch/pytorch3d.git"
108 |
109 | RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
110 |
111 | RUN pip install opencv-python==4.5.5.64
112 |
113 | RUN pip install opencv-python-headless==4.5.5.64
114 |
115 | RUN pip install huggingface_hub==0.24
116 |
117 | RUN pip install numpy==1.24.4
118 |
119 |
120 |
--------------------------------------------------------------------------------
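Note on docker/Dockerfile: the image pins exact versions of the core ML stack, so it can be useful to verify that a built container actually matches those pins. A minimal sketch, assuming it is run with the container's Python (the pin list is copied from the RUN lines above; the check script itself is not part of the repo):

from importlib.metadata import PackageNotFoundError, version

# Version pins copied from the Dockerfile above; for deepspeed the later 0.15.2 pin wins.
PINS = {
    "transformers": "4.45.2",
    "diffusers": "0.30.3",
    "accelerate": "0.34.2",
    "deepspeed": "0.15.2",
    "torch": "2.3.1",
    "torchvision": "0.18.1",
    "numpy": "1.24.4",
}

for name, expected in PINS.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        installed = "<missing>"
    status = "ok" if installed.startswith(expected) else "MISMATCH"
    print(f"{name:15} expected {expected:10} got {installed:15} {status}")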
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline
4 | from PIL import Image
5 |
6 |
7 | device = torch.device('cuda:0')
8 | pipeline = OneDiffusionPipeline.from_pretrained("lehduong/OneDiffusion").to(device=device, dtype=torch.bfloat16)
9 | NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"
10 | output_dir = "outputs"
11 | os.makedirs(output_dir, exist_ok=True)
12 |
13 |
14 |
15 | ################################################################################################
16 | ## 1. Text-to-image
17 | ################################################################################################
18 | output = pipeline(
19 | prompt="[[text2image]] A bipedal black cat wearing a huge oversized witch hat, a wizards robe, casting a spell,in an enchanted forest. The scene is filled with fireflies and moss on surrounding rocks and trees",
20 | negative_prompt=NEGATIVE_PROMPT,
21 | num_inference_steps=50,
22 | guidance_scale=4,
23 | height=1024,
24 | width=1024,
25 | )
26 | output.images[0].save(f'{output_dir}/text2image_output.jpg')
27 |
28 |
29 |
30 | ################################################################################################
31 | ## 2. Semantic to Image (and any other condition2image tasks):
32 | ################################################################################################
33 | images = [
34 | Image.open("assets/examples/semantic_map/dragon_birds_woman.webp")
35 | ]
36 | prompt = "[[semanticmap2image]] <#00ffff Cyan mask: dragon> <#ff0000 yellow mask: bird> <#800080 purple mask: woman> A woman in a red dress with gold floral patterns stands in a traditional Japanese-style building. She has black hair and wears a gold choker and earrings. Behind her, a large orange and white dragon coils around the structure. Two white birds fly near her. The building features paper windows and a wooden roof with lanterns. The scene blends traditional Japanese architecture with fantastical elements, creating a mystical atmosphere."
37 | # set the denoise mask to [1, 0] to denoise the image given the condition
38 | # by default, height and width are chosen so that the input image is minimally cropped
39 | ret = pipeline.img2img(
40 | image=images,
41 | num_inference_steps=50,
42 | prompt=prompt,
43 | denoise_mask=[1, 0],
44 | guidance_scale=4,
45 | negative_prompt=NEGATIVE_PROMPT,
46 | # height=512,
47 | # width=512,
48 | )
49 | ret.images[0].save(f"{output_dir}/semanticmap2image_output.jpg")
50 |
51 |
52 |
53 | ################################################################################################
54 | ## 3. Depth Estimation (and any other image2condition tasks):
55 | ################################################################################################
56 | images = [
57 | Image.open("assets/examples/images/cat_on_table.webp"),
58 | ]
59 | prompt = "[[depth2image]] cat sitting on a table" # you can omit the caption, i.e., set the prompt to just "[[depth2image]]"
60 | # set the denoise mask to [0, 1] to denoise the condition (depth, pose, hed, canny, semantic map, etc.)
61 | # by default, height and width are chosen so that the input image is minimally cropped
62 | ret = pipeline.img2img(
63 | image=images,
64 | num_inference_steps=50,
65 | prompt=prompt,
66 | denoise_mask=[0, 1],
67 | guidance_scale=4,
68 |     negative_prompt=NEGATIVE_PROMPT,
69 | # height=512,
70 | # width=512,
71 | )
72 | ret.images[0].save(f"{output_dir}/image2depth_output.jpg")
73 |
74 |
75 |
76 | ################################################################################################
77 | ## 4. ID Customization
78 | ################################################################################################
79 | images = [
80 | Image.open("assets/examples/id_customization/chenhao/image_0.png"),
81 | Image.open("assets/examples/id_customization/chenhao/image_1.png"),
82 | Image.open("assets/examples/id_customization/chenhao/image_2.png")
83 | ]
84 | ##### we set the denoise mask to [1, 0, 0, 0], which means the generated image will be the first view and the next three views will be conditions (in the same order as the `images` list)
85 | #### prompt format: [[faceid]] [[img0]] target/caption [[img1]] caption/of/first/image [[img2]] caption/of/second/image [[img3]] caption/of/third/image
86 | prompt = "[[faceid]] \
87 | [[img0]] A woman dressed in traditional attire with intricate headpieces. She is looking at the camera and having neutral expression. \
88 | [[img1]] A woman with long dark hair, smiling warmly while wearing a floral dress. \
89 | [[img2]] A woman in traditional clothing holding a lace parasol, with her hair styled elegantly. \
90 | [[img3]] A woman in elaborate traditional attire and jewelry, with an ornate headdress, looking intently forward. \
91 | "
92 | # by default, all images will be cropped according to the FIRST input image.
93 | ret = pipeline.img2img(image=images, num_inference_steps=75, prompt=prompt, denoise_mask=[1, 0, 0, 0], guidance_scale=4, negative_prompt=NEGATIVE_PROMPT)
94 | ret.images[0].save(f"{output_dir}/idcustomization_output.jpg")
95 |
96 |
97 |
98 | ################################################################################################
99 | ## 5. Image to multiview
100 | ################################################################################################
101 | images = [
102 | Image.open("assets/examples/images/cat_on_table.webp"),
103 | ]
104 | prompt = "[[multiview]] A cat with orange and white fur sits on a round wooden table. The cat has striking green eyes and a pink nose. Its ears are perked up, and its tail is curled around its body. The background is blurred, showing a white wall, a wooden chair, and a wooden table with a white pot and green plant. A white curtain is visible on the right side. The cat's gaze is directed slightly to the right, and its paws are white. The overall scene creates a cozy, domestic atmosphere with the cat as the central focus."
105 | # denoise mask: [0, 0, 1, 0, 1, 0, 1, 0] is for [img_0, camera of img0, img_1, camera of img1, img_2, camera of img2, img_3, camera of img3]
106 | # since we provide 1 input image and all camera positions, we do not need to denoise those values, so their mask entries are set to 0
107 | # set the mask of [img_1, img_2, img_3] to 1 to generate novel views
108 | # NOTE: only SQUARE images are supported
109 | ret = pipeline.img2img(
110 | image=images,
111 | num_inference_steps=60,
112 | prompt=prompt,
113 | negative_prompt=NEGATIVE_PROMPT,
114 | denoise_mask=[0, 0, 1, 0, 1, 0, 1, 0],
115 | guidance_scale=4,
116 |     multiview_azimuths=[0,20,40,60], # azimuths relative to the first view
117 |     multiview_elevations=[0,0,0,0], # elevations relative to the first view
118 | multiview_distances=[1.5,1.5,1.5,1.5],
119 | # multiview_c2ws=None, # you can provide the camera-to-world matrix of shape [N, 4,4] for ALL views, camera extrinsics matrix is relative to first view
120 | # multiview_intrinsics=None, # provide the intrinsics matrix if c2ws is used
121 | height=512,
122 | width=512,
123 | is_multiview=True,
124 | )
125 | for i in range(len(ret.images)):
126 | ret.images[i].save(f"{output_dir}/img2multiview_output_view{i+1}.jpg")
127 |
128 |
129 |
--------------------------------------------------------------------------------
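Note on inference.py: the script above covers one direction per task; the same img2img + denoise_mask convention also allows the two directions to be chained. Below is a hedged sketch meant to be appended to inference.py (it reuses pipeline, NEGATIVE_PROMPT and output_dir defined there): it first estimates a depth map and then generates a restyled image conditioned on that map. The second call assumes that [[depth2image]] with denoise_mask=[1, 0] conditions on a depth map, mirroring the condition2image convention of section 2; the prompt text is illustrative only.

from PIL import Image

source = Image.open("assets/examples/images/cat_on_table.webp")

# Step 1: image -> depth (denoise the condition, as in section 3).
depth = pipeline.img2img(
    image=[source],
    prompt="[[depth2image]] cat sitting on a table",
    negative_prompt=NEGATIVE_PROMPT,
    denoise_mask=[0, 1],
    num_inference_steps=50,
    guidance_scale=4,
).images[0]
depth.save(f"{output_dir}/chained_depth.jpg")

# Step 2: depth -> image (denoise the image, as in section 2), with a new caption.
restyled = pipeline.img2img(
    image=[depth],
    prompt="[[depth2image]] a watercolor painting of a cat sitting on a round wooden table",
    negative_prompt=NEGATIVE_PROMPT,
    denoise_mask=[1, 0],
    num_inference_steps=50,
    guidance_scale=4,
).images[0]
restyled.save(f"{output_dir}/chained_depth2image.jpg")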
/onediffusion/dataset/multitask/multiview.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | from PIL import Image
5 | import torch
6 | from typing import List, Tuple, Union
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms
9 | import torchvision.transforms as T
10 | from onediffusion.dataset.utils import *
11 | import glob
12 |
13 | from onediffusion.dataset.raydiff_utils import cameras_to_rays, first_camera_transform, normalize_cameras
14 | from onediffusion.dataset.transforms import CenterCropResizeImage
15 | from pytorch3d.renderer import PerspectiveCameras
16 |
17 | import numpy as np
18 |
19 | def _cameras_from_opencv_projection(
20 | R: torch.Tensor,
21 | tvec: torch.Tensor,
22 | camera_matrix: torch.Tensor,
23 | image_size: torch.Tensor,
24 | do_normalize_cameras,
25 | normalize_scale,
26 | ) -> PerspectiveCameras:
27 | focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
28 | principal_point = camera_matrix[:, :2, 2]
29 |
30 | # Retype the image_size correctly and flip to width, height.
31 | image_size_wh = image_size.to(R).flip(dims=(1,))
32 |
33 | # Screen to NDC conversion:
34 | # For non square images, we scale the points such that smallest side
35 | # has range [-1, 1] and the largest side has range [-u, u], with u > 1.
36 | # This convention is consistent with the PyTorch3D renderer, as well as
37 | # the transformation function `get_ndc_to_screen_transform`.
38 | scale = image_size_wh.to(R).min(dim=1, keepdim=True)[0] / 2.0
39 | scale = scale.expand(-1, 2)
40 | c0 = image_size_wh / 2.0
41 |
42 | # Get the PyTorch3D focal length and principal point.
43 | focal_pytorch3d = focal_length / scale
44 | p0_pytorch3d = -(principal_point - c0) / scale
45 |
46 | # For R, T we flip x, y axes (opencv screen space has an opposite
47 | # orientation of screen axes).
48 | # We also transpose R (opencv multiplies points from the opposite=left side).
49 | R_pytorch3d = R.clone().permute(0, 2, 1)
50 | T_pytorch3d = tvec.clone()
51 | R_pytorch3d[:, :, :2] *= -1
52 | T_pytorch3d[:, :2] *= -1
53 |
54 | cams = PerspectiveCameras(
55 | R=R_pytorch3d,
56 | T=T_pytorch3d,
57 | focal_length=focal_pytorch3d,
58 | principal_point=p0_pytorch3d,
59 | image_size=image_size,
60 | device=R.device,
61 | )
62 |
63 | if do_normalize_cameras:
64 | cams, _ = normalize_cameras(cams, scale=normalize_scale)
65 |
66 | cams = first_camera_transform(cams, rotation_only=False)
67 | return cams
68 |
69 | def calculate_rays(Ks, sizes, Rs, Ts, target_size, use_plucker=True, do_normalize_cameras=False, normalize_scale=1.0):
70 | cameras = _cameras_from_opencv_projection(
71 | R=Rs,
72 | tvec=Ts,
73 | camera_matrix=Ks,
74 | image_size=sizes,
75 | do_normalize_cameras=do_normalize_cameras,
76 | normalize_scale=normalize_scale
77 | )
78 |
79 | rays_embedding = cameras_to_rays(
80 | cameras=cameras,
81 | num_patches_x=target_size,
82 | num_patches_y=target_size,
83 | crop_parameters=None,
84 | use_plucker=use_plucker
85 | )
86 |
87 | return rays_embedding.rays
88 |
89 | def convert_rgba_to_rgb_white_bg(image):
90 | """Convert RGBA image to RGB with white background"""
91 | if image.mode == 'RGBA':
92 | # Create a white background
93 | background = Image.new('RGBA', image.size, (255, 255, 255, 255))
94 | # Composite the image onto the white background
95 | return Image.alpha_composite(background, image).convert('RGB')
96 | return image.convert('RGB')
97 |
98 | class MultiviewDataset(Dataset):
99 | def __init__(
100 | self,
101 | scene_folders: str,
102 |         samples_per_set: Union[int, Tuple[int, int]], # a fixed number of views, or a (min, max) range to sample from per scene
103 | transform=None,
104 | caption_keys: Union[str, List] = "caption",
105 | multiscale=False,
106 | aspect_ratio_type=ASPECT_RATIO_512,
107 | c2w_scaling=1.7,
108 |         default_max_distance=1, # default max distance across all cameras of a scene
109 |         do_normalize=True, # whether to normalize the translation of c2w with max_distance
110 |         swap_xz=False, # whether to swap the x and z axes of 3D scenes
111 |         valid_paths: str = "",
112 |         frame_sliding_windows: float = None # limit all sampled frames to a window of this size so that camera poses are not too different
113 | ):
114 | if not isinstance(samples_per_set, tuple) and not isinstance(samples_per_set, list):
115 | samples_per_set = (samples_per_set, samples_per_set)
116 | self.samples_range = samples_per_set # Tuple of (min_samples, max_samples)
117 | self.transform = transform
118 | self.caption_keys = caption_keys if isinstance(caption_keys, list) else [caption_keys]
119 | self.aspect_ratio = aspect_ratio_type
120 | self.scene_folders = sorted(glob.glob(scene_folders))
121 | # filter out scene folders that do not have transforms.json
122 | self.scene_folders = list(filter(lambda x: os.path.exists(os.path.join(x, "transforms.json")), self.scene_folders))
123 |
124 | # if valid_paths.txt exists, only use paths in that file
125 | if os.path.exists(valid_paths):
126 | with open(valid_paths, 'r') as f:
127 | valid_scene_folders = f.read().splitlines()
128 | self.scene_folders = sorted(valid_scene_folders)
129 |
130 | self.c2w_scaling = c2w_scaling
131 | self.do_normalize = do_normalize
132 | self.default_max_distance = default_max_distance
133 | self.swap_xz = swap_xz
134 | self.frame_sliding_windows = frame_sliding_windows
135 |
136 | if multiscale:
137 | assert self.aspect_ratio in [ASPECT_RATIO_512, ASPECT_RATIO_1024, ASPECT_RATIO_2048, ASPECT_RATIO_2880]
138 | if self.aspect_ratio in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]:
139 | self.interpolate_model = T.InterpolationMode.LANCZOS
140 | self.ratio_index = {}
141 | self.ratio_nums = {}
142 | for k, v in self.aspect_ratio.items():
143 | self.ratio_index[float(k)] = [] # used for self.getitem
144 | self.ratio_nums[float(k)] = 0 # used for batch-sampler
145 |
146 | def __len__(self):
147 | return len(self.scene_folders)
148 |
149 | def __getitem__(self, idx):
150 | try:
151 | scene_path = self.scene_folders[idx]
152 |
153 | if os.path.exists(os.path.join(scene_path, "images")):
154 | image_folder = os.path.join(scene_path, "images")
155 | downscale_factor = 1
156 | elif os.path.exists(os.path.join(scene_path, "images_4")):
157 | image_folder = os.path.join(scene_path, "images_4")
158 | downscale_factor = 1 / 4
159 | elif os.path.exists(os.path.join(scene_path, "images_8")):
160 | image_folder = os.path.join(scene_path, "images_8")
161 | downscale_factor = 1 / 8
162 | else:
163 | raise NotImplementedError
164 |
165 | json_path = os.path.join(scene_path, "transforms.json")
166 | caption_path = os.path.join(scene_path, "caption.json")
167 | image_files = os.listdir(image_folder)
168 |
169 | with open(json_path, 'r') as f:
170 | json_data = json.load(f)
171 | height, width = json_data['h'], json_data['w']
172 |
173 | dh, dw = int(height * downscale_factor), int(width * downscale_factor)
174 | fl_x, fl_y = json_data['fl_x'] * downscale_factor, json_data['fl_y'] * downscale_factor
175 | cx = dw // 2
176 | cy = dh // 2
177 |
178 | frame_list = json_data['frames']
179 |
180 | # Randomly select number of samples
181 |
182 | samples_per_set = random.randint(self.samples_range[0], self.samples_range[1])
183 |
184 | # uniformly for all scenes
185 | if self.frame_sliding_windows is None:
186 | selected_indices = random.sample(range(len(frame_list)), min(samples_per_set, len(frame_list)))
187 | # limit the multiview to be in a sliding window (to avoid catastrophic difference in camera angles)
188 | else:
189 | # Determine the starting index of the sliding window
190 | if len(frame_list) <= self.frame_sliding_windows:
191 |                 # If the frame list has at most frame_sliding_windows frames, use the entire list
192 | window_start = 0
193 | window_end = len(frame_list)
194 | else:
195 | # Randomly select a starting point for the window
196 | window_start = random.randint(0, len(frame_list) - self.frame_sliding_windows)
197 | window_end = window_start + self.frame_sliding_windows
198 |
199 | # Get the indices within the sliding window
200 | window_indices = list(range(window_start, window_end))
201 |
202 | # Randomly sample indices from the window
203 | selected_indices = random.sample(window_indices, samples_per_set)
204 |
205 | image_files = [os.path.basename(frame_list[i]['file_path']) for i in selected_indices]
206 | image_paths = [os.path.join(image_folder, file) for file in image_files]
207 |
208 | # Load images and convert RGBA to RGB with white background
209 | images = [convert_rgba_to_rgb_white_bg(Image.open(image_path)) for image_path in image_paths]
210 |
211 | if self.transform:
212 | images = [self.transform(image) for image in images]
213 | else:
214 | closest_size, closest_ratio = self.aspect_ratio['1.0'], 1.0
215 | closest_size = tuple(map(int, closest_size))
216 | transform = T.Compose([
217 | T.ToTensor(),
218 | CenterCropResizeImage(closest_size),
219 | T.Normalize([.5], [.5]),
220 | ])
221 | images = [transform(image) for image in images]
222 | images = torch.stack(images)
223 |
224 | c2ws = [frame_list[i]['transform_matrix'] for i in selected_indices]
225 | c2ws = torch.tensor(c2ws).reshape(-1, 4, 4)
226 | # max_distance = json_data.get('max_distance', self.default_max_distance)
227 | # if 'max_distance' not in json_data.keys():
228 | # print(f"not found `max_distance` in json path: {json_path}")
229 |
230 | if self.swap_xz:
231 | swap_xz = torch.tensor([[[0, 0, 1., 0],
232 | [0, 1., 0, 0],
233 | [-1., 0, 0, 0],
234 | [0, 0, 0, 1.]]])
235 | c2ws = swap_xz @ c2ws
236 |
237 | # OPENGL to OPENCV
238 | c2ws[:, 0:3, 1:3] *= -1
239 | c2ws = c2ws[:, [1, 0, 2, 3], :]
240 | c2ws[:, 2, :] *= -1
241 |
242 | w2cs = torch.inverse(c2ws)
243 | K = torch.tensor([[[fl_x, 0, cx], [0, fl_y, cy], [0, 0, 1]]]).repeat(len(c2ws), 1, 1)
244 | Rs = w2cs[:, :3, :3]
245 | Ts = w2cs[:, :3, 3]
246 | sizes = torch.tensor([[dh, dw]]).repeat(len(c2ws), 1)
247 |
248 | # get ray embedding and padding last dimension to 16 (num channels of VAE)
249 | # rays_od = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, use_plucker=False, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling)
250 | rays = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling)
251 | rays = rays.reshape(samples_per_set, closest_size[0] // 8, closest_size[1] // 8, 6)
252 | # padding = (0, 10) # pad the last dimension to 16
253 | # rays = torch.nn.functional.pad(rays, padding, "constant", 0)
254 | rays = torch.cat([rays, rays, rays[..., :4]], dim=-1) * 1.658
255 |
256 | if os.path.exists(caption_path):
257 | with open(caption_path, 'r') as f:
258 | caption_key = random.choice(self.caption_keys)
259 | caption = json.load(f).get(caption_key, "")
260 | else:
261 | caption = ""
262 |
263 | caption = "[[multiview]] " + caption if caption else "[[multiview]]"
264 |
265 | return {
266 | 'pixel_values': images,
267 | 'rays': rays,
268 | 'aspect_ratio': closest_ratio,
269 | 'caption': caption,
270 | 'height': dh,
271 | 'width': dw,
272 | # 'origins': rays_od[..., :3],
273 | # 'dirs': rays_od[..., 3:6]
274 | }
275 | except Exception as e:
276 | return self.__getitem__(random.randint(0, len(self.scene_folders) - 1))
277 |
--------------------------------------------------------------------------------
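Note on MultiviewDataset: each matched scene folder must contain a transforms.json (with h, w, fl_x, fl_y and a frames list) plus an images/, images_4/ or images_8/ directory, and each item yields the stacked views, their 16-channel ray embeddings and a "[[multiview]]" caption. A minimal usage sketch, assuming scenes under a hypothetical data/scenes/ root (the path and loader settings are placeholders, not values from the repo):

from torch.utils.data import DataLoader

from onediffusion.dataset.multitask.multiview import MultiviewDataset

# Hypothetical scene root; every matched folder needs transforms.json and an image directory.
dataset = MultiviewDataset(
    scene_folders="data/scenes/*",
    samples_per_set=(2, 4),   # sample between 2 and 4 views per scene
    caption_keys="caption",
)

# batch_size=1 because the number of sampled views can differ between scenes.
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
batch = next(iter(loader))
print(batch["pixel_values"].shape)  # (1, n_views, 3, H, W)
print(batch["rays"].shape)          # (1, n_views, H // 8, W // 8, 16)
print(batch["caption"][0])          # "[[multiview]] <scene caption>"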
/onediffusion/dataset/raydiff_utils.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | Adapted from code originally written by David Novotny.
4 | """
5 |
6 | import torch
7 | from pytorch3d.transforms import Rotate, Translate
8 |
9 | import cv2
10 | import numpy as np
11 | import torch
12 | from pytorch3d.renderer import PerspectiveCameras, RayBundle
13 |
14 | def intersect_skew_line_groups(p, r, mask):
15 | # p, r both of shape (B, N, n_intersected_lines, 3)
16 | # mask of shape (B, N, n_intersected_lines)
17 | p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
18 | if p_intersect is None:
19 | return None, None, None, None
20 | _, p_line_intersect = point_line_distance(
21 | p, r, p_intersect[..., None, :].expand_as(p)
22 | )
23 | intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(
24 | dim=-1
25 | )
26 | return p_intersect, p_line_intersect, intersect_dist_squared, r
27 |
28 |
29 | def intersect_skew_lines_high_dim(p, r, mask=None):
30 | # Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions
31 | dim = p.shape[-1]
32 | # make sure the heading vectors are l2-normed
33 | if mask is None:
34 | mask = torch.ones_like(p[..., 0])
35 | r = torch.nn.functional.normalize(r, dim=-1)
36 |
37 | eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
38 | I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
39 | sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
40 |
41 | # I_eps = torch.zeros_like(I_min_cov.sum(dim=-3)) + 1e-10
42 | # p_intersect = torch.pinverse(I_min_cov.sum(dim=-3) + I_eps).matmul(sum_proj)[..., 0]
43 | p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
44 |
45 | # I_min_cov.sum(dim=-3): torch.Size([1, 1, 3, 3])
46 | # sum_proj: torch.Size([1, 1, 3, 1])
47 |
48 | # p_intersect = np.linalg.lstsq(I_min_cov.sum(dim=-3).numpy(), sum_proj.numpy(), rcond=None)[0]
49 |
50 | if torch.any(torch.isnan(p_intersect)):
51 | print(p_intersect)
52 | return None, None
53 |         # ipdb.set_trace()  # unreachable after the return above; ipdb is also not imported
54 |         # assert False
55 | return p_intersect, r
56 |
57 |
58 | def point_line_distance(p1, r1, p2):
59 | df = p2 - p1
60 | proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
61 | line_pt_nearest = p2 - proj_vector
62 | d = (proj_vector).norm(dim=-1)
63 | return d, line_pt_nearest
64 |
65 |
66 | def compute_optical_axis_intersection(cameras):
67 | centers = cameras.get_camera_center()
68 | principal_points = cameras.principal_point
69 |
70 | one_vec = torch.ones((len(cameras), 1), device=centers.device)
71 | optical_axis = torch.cat((principal_points, one_vec), -1)
72 |
73 | # optical_axis = torch.cat(
74 | # (principal_points, cameras.focal_length[:, 0].unsqueeze(1)), -1
75 | # )
76 |
77 | pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
78 | pp2 = torch.diagonal(pp, dim1=0, dim2=1).T
79 |
80 | directions = pp2 - centers
81 | centers = centers.unsqueeze(0).unsqueeze(0)
82 | directions = directions.unsqueeze(0).unsqueeze(0)
83 |
84 | p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(
85 | p=centers, r=directions, mask=None
86 | )
87 |
88 | if p_intersect is None:
89 | dist = None
90 | else:
91 | p_intersect = p_intersect.squeeze().unsqueeze(0)
92 | dist = (p_intersect - centers).norm(dim=-1)
93 |
94 | return p_intersect, dist, p_line_intersect, pp2, r
95 |
96 |
97 | def normalize_cameras(cameras, scale=1.0):
98 | """
99 | Normalizes cameras such that the optical axes point to the origin, the rotation is
100 | identity, and the norm of the translation of the first camera is 1.
101 |
102 | Args:
103 | cameras (pytorch3d.renderer.cameras.CamerasBase).
104 | scale (float): Norm of the translation of the first camera.
105 |
106 | Returns:
107 | new_cameras (pytorch3d.renderer.cameras.CamerasBase): Normalized cameras.
108 | undo_transform (function): Function that undoes the normalization.
109 | """
110 |
111 | # Let distance from first camera to origin be unit
112 | new_cameras = cameras.clone()
113 | new_transform = (
114 | new_cameras.get_world_to_view_transform()
115 |     )  # note: R here may not be a valid rotation matrix
116 | p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(
117 | cameras
118 | )
119 |
120 | if p_intersect is None:
121 | print("Warning: optical axes code has a nan. Returning identity cameras.")
122 | new_cameras.R[:] = torch.eye(3, device=cameras.R.device, dtype=cameras.R.dtype)
123 | new_cameras.T[:] = torch.tensor(
124 | [0, 0, 1], device=cameras.T.device, dtype=cameras.T.dtype
125 | )
126 | return new_cameras, lambda x: x
127 |
128 | d = dist.squeeze(dim=1).squeeze(dim=0)[0]
129 | # Degenerate case
130 | if d == 0:
131 | print(cameras.T)
132 | print(new_transform.get_matrix()[:, 3, :3])
133 | assert False
134 | assert d != 0
135 |
136 | # Can't figure out how to make scale part of the transform too without messing up R.
137 | # Ideally, we would just wrap it all in a single Pytorch3D transform so that it
138 | # would work with any structure (eg PointClouds, Meshes).
139 | tR = Rotate(new_cameras.R[0].unsqueeze(0)).inverse()
140 | tT = Translate(p_intersect)
141 | t = tR.compose(tT)
142 |
143 | new_transform = t.compose(new_transform)
144 | new_cameras.R = new_transform.get_matrix()[:, :3, :3]
145 | new_cameras.T = new_transform.get_matrix()[:, 3, :3] / d * scale
146 |
147 | def undo_transform(cameras):
148 | cameras_copy = cameras.clone()
149 | cameras_copy.T *= d / scale
150 | new_t = (
151 | t.inverse().compose(cameras_copy.get_world_to_view_transform()).get_matrix()
152 | )
153 | cameras_copy.R = new_t[:, :3, :3]
154 | cameras_copy.T = new_t[:, 3, :3]
155 | return cameras_copy
156 |
157 | return new_cameras, undo_transform
158 |
159 | def first_camera_transform(cameras, rotation_only=True):
160 | new_cameras = cameras.clone()
161 | new_transform = new_cameras.get_world_to_view_transform()
162 | tR = Rotate(new_cameras.R[0].unsqueeze(0))
163 | if rotation_only:
164 | t = tR.inverse()
165 | else:
166 | tT = Translate(new_cameras.T[0].unsqueeze(0))
167 | t = tR.compose(tT).inverse()
168 |
169 | new_transform = t.compose(new_transform)
170 | new_cameras.R = new_transform.get_matrix()[:, :3, :3]
171 | new_cameras.T = new_transform.get_matrix()[:, 3, :3]
172 |
173 | return new_cameras
174 |
175 |
176 | def get_identity_cameras_with_intrinsics(cameras):
177 | D = len(cameras)
178 | device = cameras.R.device
179 |
180 | new_cameras = cameras.clone()
181 | new_cameras.R = torch.eye(3, device=device).unsqueeze(0).repeat((D, 1, 1))
182 | new_cameras.T = torch.zeros((D, 3), device=device)
183 |
184 | return new_cameras
185 |
186 |
187 | def normalize_cameras_batch(cameras, scale=1.0, normalize_first_camera=False):
188 | new_cameras = []
189 | undo_transforms = []
190 | for cam in cameras:
191 | if normalize_first_camera:
192 | # Normalize cameras such that first camera is identity and origin is at
193 | # first camera center.
194 | normalized_cameras = first_camera_transform(cam, rotation_only=False)
195 | undo_transform = None
196 | else:
197 | normalized_cameras, undo_transform = normalize_cameras(cam, scale=scale)
198 | new_cameras.append(normalized_cameras)
199 | undo_transforms.append(undo_transform)
200 | return new_cameras, undo_transforms
201 |
202 |
203 | class Rays(object):
204 | def __init__(
205 | self,
206 | rays=None,
207 | origins=None,
208 | directions=None,
209 | moments=None,
210 | is_plucker=False,
211 | moments_rescale=1.0,
212 | ndc_coordinates=None,
213 | crop_parameters=None,
214 | num_patches_x=16,
215 | num_patches_y=16,
216 | ):
217 | """
218 | Ray class to keep track of current ray representation.
219 |
220 | Args:
221 | rays: (..., 6).
222 | origins: (..., 3).
223 | directions: (..., 3).
224 | moments: (..., 3).
225 | is_plucker: If True, rays are in plucker coordinates (Default: False).
226 | moments_rescale: Rescale the moment component of the rays by a scalar.
227 | ndc_coordinates: (..., 2): NDC coordinates of each ray.
228 | """
229 | if rays is not None:
230 | self.rays = rays
231 | self._is_plucker = is_plucker
232 | elif origins is not None and directions is not None:
233 | self.rays = torch.cat((origins, directions), dim=-1)
234 | self._is_plucker = False
235 | elif directions is not None and moments is not None:
236 | self.rays = torch.cat((directions, moments), dim=-1)
237 | self._is_plucker = True
238 | else:
239 | raise Exception("Invalid combination of arguments")
240 |
241 | if moments_rescale != 1.0:
242 | self.rescale_moments(moments_rescale)
243 |
244 | if ndc_coordinates is not None:
245 | self.ndc_coordinates = ndc_coordinates
246 | elif crop_parameters is not None:
247 | # (..., H, W, 2)
248 | xy_grid = compute_ndc_coordinates(
249 | crop_parameters,
250 | num_patches_x=num_patches_x,
251 | num_patches_y=num_patches_y,
252 | )[..., :2]
253 | xy_grid = xy_grid.reshape(*xy_grid.shape[:-3], -1, 2)
254 | self.ndc_coordinates = xy_grid
255 | else:
256 | self.ndc_coordinates = None
257 |
258 | def __getitem__(self, index):
259 | return Rays(
260 | rays=self.rays[index],
261 | is_plucker=self._is_plucker,
262 | ndc_coordinates=(
263 | self.ndc_coordinates[index]
264 | if self.ndc_coordinates is not None
265 | else None
266 | ),
267 | )
268 |
269 | def to_spatial(self, include_ndc_coordinates=False):
270 | """
271 | Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W)
272 |
273 | Returns:
274 | torch.Tensor: (..., 6, H, W)
275 | """
276 | rays = self.to_plucker().rays
277 | *batch_dims, P, D = rays.shape
278 | H = W = int(np.sqrt(P))
279 | assert H * W == P
280 | rays = torch.transpose(rays, -1, -2) # (..., 6, H * W)
281 | rays = rays.reshape(*batch_dims, D, H, W)
282 | if include_ndc_coordinates:
283 | ndc_coords = self.ndc_coordinates.transpose(-1, -2) # (..., 2, H * W)
284 | ndc_coords = ndc_coords.reshape(*batch_dims, 2, H, W)
285 | rays = torch.cat((rays, ndc_coords), dim=-3)
286 | return rays
287 |
288 | def rescale_moments(self, scale):
289 | """
290 | Rescale the moment component of the rays by a scalar. Might be desirable since
291 | moments may come from a very narrow distribution.
292 |
293 | Note that this modifies in place!
294 | """
295 | if self.is_plucker:
296 | self.rays[..., 3:] *= scale
297 | return self
298 | else:
299 | return self.to_plucker().rescale_moments(scale)
300 |
301 | @classmethod
302 | def from_spatial(cls, rays, moments_rescale=1.0, ndc_coordinates=None):
303 | """
304 | Converts rays from spatial representation: (..., 6, H, W) --> (..., H * W, 6)
305 |
306 | Args:
307 | rays: (..., 6, H, W)
308 |
309 | Returns:
310 | Rays: (..., H * W, 6)
311 | """
312 | *batch_dims, D, H, W = rays.shape
313 | rays = rays.reshape(*batch_dims, D, H * W)
314 | rays = torch.transpose(rays, -1, -2)
315 | return cls(
316 | rays=rays,
317 | is_plucker=True,
318 | moments_rescale=moments_rescale,
319 | ndc_coordinates=ndc_coordinates,
320 | )
321 |
322 | def to_point_direction(self, normalize_moment=True):
323 | """
324 |         Convert to point-direction representation.
325 |
326 | Returns:
327 | rays: (..., 6).
328 | """
329 | if self._is_plucker:
330 | direction = torch.nn.functional.normalize(self.rays[..., :3], dim=-1)
331 | moment = self.rays[..., 3:]
332 | if normalize_moment:
333 | c = torch.linalg.norm(direction, dim=-1, keepdim=True)
334 | moment = moment / c
335 | points = torch.cross(direction, moment, dim=-1)
336 | return Rays(
337 | rays=torch.cat((points, direction), dim=-1),
338 | is_plucker=False,
339 | ndc_coordinates=self.ndc_coordinates,
340 | )
341 | else:
342 | return self
343 |
344 | def to_plucker(self):
345 | """
346 |         Convert to plucker (direction, moment) representation.
347 | """
348 | if self.is_plucker:
349 | return self
350 | else:
351 | ray = self.rays.clone()
352 | ray_origins = ray[..., :3]
353 | ray_directions = ray[..., 3:]
354 | # Normalize ray directions to unit vectors
355 | ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True)
356 | plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1)
357 | new_ray = torch.cat([ray_directions, plucker_normal], dim=-1)
358 | return Rays(
359 | rays=new_ray, is_plucker=True, ndc_coordinates=self.ndc_coordinates
360 | )
361 |
362 | def get_directions(self, normalize=True):
363 | if self.is_plucker:
364 | directions = self.rays[..., :3]
365 | else:
366 | directions = self.rays[..., 3:]
367 | if normalize:
368 | directions = torch.nn.functional.normalize(directions, dim=-1)
369 | return directions
370 |
371 | def get_origins(self):
372 | if self.is_plucker:
373 | origins = self.to_point_direction().get_origins()
374 | else:
375 | origins = self.rays[..., :3]
376 | return origins
377 |
378 | def get_moments(self):
379 | if self.is_plucker:
380 | moments = self.rays[..., 3:]
381 | else:
382 | moments = self.to_plucker().get_moments()
383 | return moments
384 |
385 | def get_ndc_coordinates(self):
386 | return self.ndc_coordinates
387 |
388 | @property
389 | def is_plucker(self):
390 | return self._is_plucker
391 |
392 | @property
393 | def device(self):
394 | return self.rays.device
395 |
396 | def __repr__(self, *args, **kwargs):
397 | ray_str = self.rays.__repr__(*args, **kwargs)[6:] # remove "tensor"
398 | if self._is_plucker:
399 | return "PluRay" + ray_str
400 | else:
401 | return "DirRay" + ray_str
402 |
403 | def to(self, device):
404 | self.rays = self.rays.to(device)
405 |
406 | def clone(self):
407 | return Rays(rays=self.rays.clone(), is_plucker=self._is_plucker)
408 |
409 | @property
410 | def shape(self):
411 | return self.rays.shape
412 |
413 | def visualize(self):
414 | directions = torch.nn.functional.normalize(self.get_directions(), dim=-1).cpu()
415 | moments = torch.nn.functional.normalize(self.get_moments(), dim=-1).cpu()
416 | return (directions + 1) / 2, (moments + 1) / 2
417 |
418 | def to_ray_bundle(self, length=0.3, recenter=True):
419 | lengths = torch.ones_like(self.get_origins()[..., :2]) * length
420 | lengths[..., 0] = 0
421 | if recenter:
422 | centers, _ = intersect_skew_lines_high_dim(
423 | self.get_origins(), self.get_directions()
424 | )
425 | centers = centers.unsqueeze(1).repeat(1, lengths.shape[1], 1)
426 | else:
427 | centers = self.get_origins()
428 | return RayBundle(
429 | origins=centers,
430 | directions=self.get_directions(),
431 | lengths=lengths,
432 | xys=self.get_directions(),
433 | )
434 |
435 |
436 | def cameras_to_rays(
437 | cameras,
438 | crop_parameters,
439 | use_half_pix=True,
440 | use_plucker=True,
441 | num_patches_x=16,
442 | num_patches_y=16,
443 | ):
444 | """
445 |     Unprojects a grid of points on the image plane into rays emanating from the camera center.
446 |
447 | Args:
448 | cameras: Pytorch3D cameras to unproject. Can be batched.
449 | crop_parameters: Crop parameters in NDC (cc_x, cc_y, crop_width, scale).
450 | Shape is (B, 4).
451 | use_half_pix: If True, use half pixel offset (Default: True).
452 |         use_plucker: If True, return rays in plucker coordinates (Default: True).
453 | num_patches_x: Number of patches in x direction (Default: 16).
454 | num_patches_y: Number of patches in y direction (Default: 16).
455 | """
456 | unprojected = []
457 | crop_parameters_list = (
458 | crop_parameters if crop_parameters is not None else [None for _ in cameras]
459 | )
460 | for camera, crop_param in zip(cameras, crop_parameters_list):
461 | xyd_grid = compute_ndc_coordinates(
462 | crop_parameters=crop_param,
463 | use_half_pix=use_half_pix,
464 | num_patches_x=num_patches_x,
465 | num_patches_y=num_patches_y,
466 | )
467 |
468 | unprojected.append(
469 | camera.unproject_points(
470 | xyd_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True
471 | )
472 | )
473 | unprojected = torch.stack(unprojected, dim=0) # (N, P, 3)
474 | origins = cameras.get_camera_center().unsqueeze(1) # (N, 1, 3)
475 | origins = origins.repeat(1, num_patches_x * num_patches_y, 1) # (N, P, 3)
476 | directions = unprojected - origins
477 |
478 | rays = Rays(
479 | origins=origins,
480 | directions=directions,
481 | crop_parameters=crop_parameters,
482 | num_patches_x=num_patches_x,
483 | num_patches_y=num_patches_y,
484 | )
485 | if use_plucker:
486 | return rays.to_plucker()
487 | return rays
488 |
489 |
490 | def rays_to_cameras(
491 | rays,
492 | crop_parameters,
493 | num_patches_x=16,
494 | num_patches_y=16,
495 | use_half_pix=True,
496 | sampled_ray_idx=None,
497 | cameras=None,
498 | focal_length=(3.453,),
499 | ):
500 | """
501 |     If cameras are provided, their intrinsics are used. Otherwise the provided
502 |     focal_length(s) are used. The dataset default is 3.32.
503 |
504 | Args:
505 | rays (Rays): (N, P, 6)
506 | crop_parameters (torch.Tensor): (N, 4)
507 | """
508 | device = rays.device
509 | origins = rays.get_origins()
510 | directions = rays.get_directions()
511 | camera_centers, _ = intersect_skew_lines_high_dim(origins, directions)
512 |
513 | # Retrieve target rays
514 | if cameras is None:
515 | if len(focal_length) == 1:
516 | focal_length = focal_length * rays.shape[0]
517 | I_camera = PerspectiveCameras(focal_length=focal_length, device=device)
518 | else:
519 | # Use same intrinsics but reset to identity extrinsics.
520 | I_camera = cameras.clone()
521 | I_camera.R[:] = torch.eye(3, device=device)
522 | I_camera.T[:] = torch.zeros(3, device=device)
523 | I_patch_rays = cameras_to_rays(
524 | cameras=I_camera,
525 | num_patches_x=num_patches_x,
526 | num_patches_y=num_patches_y,
527 | use_half_pix=use_half_pix,
528 | crop_parameters=crop_parameters,
529 | ).get_directions()
530 |
531 | if sampled_ray_idx is not None:
532 | I_patch_rays = I_patch_rays[:, sampled_ray_idx]
533 |
534 | # Compute optimal rotation to align rays
535 | R = torch.zeros_like(I_camera.R)
536 | for i in range(len(I_camera)):
537 | R[i] = compute_optimal_rotation_alignment(
538 | I_patch_rays[i],
539 | directions[i],
540 | )
541 |
542 | # Construct and return rotated camera
543 | cam = I_camera.clone()
544 | cam.R = R
545 | cam.T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2)
546 | return cam
547 |
548 |
549 | # https://www.reddit.com/r/learnmath/comments/v1crd7/linear_algebra_qr_to_ql_decomposition/
550 | def ql_decomposition(A):
551 | P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device).float()
552 | A_tilde = torch.matmul(A, P)
553 | Q_tilde, R_tilde = torch.linalg.qr(A_tilde)
554 | Q = torch.matmul(Q_tilde, P)
555 | L = torch.matmul(torch.matmul(P, R_tilde), P)
556 | d = torch.diag(L)
557 | Q[:, 0] *= torch.sign(d[0])
558 | Q[:, 1] *= torch.sign(d[1])
559 | Q[:, 2] *= torch.sign(d[2])
560 | L[0] *= torch.sign(d[0])
561 | L[1] *= torch.sign(d[1])
562 | L[2] *= torch.sign(d[2])
563 | return Q, L
564 |
565 |
566 | def rays_to_cameras_homography(
567 | rays,
568 | crop_parameters,
569 | num_patches_x=16,
570 | num_patches_y=16,
571 | use_half_pix=True,
572 | sampled_ray_idx=None,
573 | reproj_threshold=0.2,
574 | ):
575 | """
576 | Args:
577 | rays (Rays): (N, P, 6)
578 | crop_parameters (torch.Tensor): (N, 4)
579 | """
580 | device = rays.device
581 | origins = rays.get_origins()
582 | directions = rays.get_directions()
583 | camera_centers, _ = intersect_skew_lines_high_dim(origins, directions)
584 |
585 | # Retrieve target rays
586 | I_camera = PerspectiveCameras(focal_length=[1] * rays.shape[0], device=device)
587 | I_patch_rays = cameras_to_rays(
588 | cameras=I_camera,
589 | num_patches_x=num_patches_x,
590 | num_patches_y=num_patches_y,
591 | use_half_pix=use_half_pix,
592 | crop_parameters=crop_parameters,
593 | ).get_directions()
594 |
595 | if sampled_ray_idx is not None:
596 | I_patch_rays = I_patch_rays[:, sampled_ray_idx]
597 |
598 | # Compute optimal rotation to align rays
599 | Rs = []
600 | focal_lengths = []
601 | principal_points = []
602 | for i in range(rays.shape[-3]):
603 | R, f, pp = compute_optimal_rotation_intrinsics(
604 | I_patch_rays[i],
605 | directions[i],
606 | reproj_threshold=reproj_threshold,
607 | )
608 | Rs.append(R)
609 | focal_lengths.append(f)
610 | principal_points.append(pp)
611 |
612 | R = torch.stack(Rs)
613 | focal_lengths = torch.stack(focal_lengths)
614 | principal_points = torch.stack(principal_points)
615 | T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2)
616 | return PerspectiveCameras(
617 | R=R,
618 | T=T,
619 | focal_length=focal_lengths,
620 | principal_point=principal_points,
621 | device=device,
622 | )
623 |
624 |
625 | def compute_optimal_rotation_alignment(A, B):
626 | """
627 | Compute optimal R that minimizes: || A - B @ R ||_F
628 |
629 | Args:
630 | A (torch.Tensor): (N, 3)
631 | B (torch.Tensor): (N, 3)
632 |
633 | Returns:
634 | R (torch.tensor): (3, 3)
635 | """
636 | # normally with R @ B, this would be A @ B.T
637 | H = B.T @ A
638 | U, _, Vh = torch.linalg.svd(H, full_matrices=True)
639 | s = torch.linalg.det(U @ Vh)
640 | S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device))
641 | return U @ S_prime @ Vh
642 |
643 |
644 | def compute_optimal_rotation_intrinsics(
645 | rays_origin, rays_target, z_threshold=1e-4, reproj_threshold=0.2
646 | ):
647 | """
648 | Note: for some reason, f seems to be 1/f.
649 |
650 | Args:
651 | rays_origin (torch.Tensor): (N, 3)
652 | rays_target (torch.Tensor): (N, 3)
653 | z_threshold (float): Threshold for z value to be considered valid.
654 |
655 | Returns:
656 | R (torch.tensor): (3, 3)
657 | focal_length (torch.tensor): (2,)
658 | principal_point (torch.tensor): (2,)
659 | """
660 | device = rays_origin.device
661 | z_mask = torch.logical_and(
662 | torch.abs(rays_target) > z_threshold, torch.abs(rays_origin) > z_threshold
663 | )[:, 2]
664 | rays_target = rays_target[z_mask]
665 | rays_origin = rays_origin[z_mask]
666 | rays_origin = rays_origin[:, :2] / rays_origin[:, -1:]
667 | rays_target = rays_target[:, :2] / rays_target[:, -1:]
668 |
669 | A, _ = cv2.findHomography(
670 | rays_origin.cpu().numpy(),
671 | rays_target.cpu().numpy(),
672 | cv2.RANSAC,
673 | reproj_threshold,
674 | )
675 | A = torch.from_numpy(A).float().to(device)
676 |
677 | if torch.linalg.det(A) < 0:
678 | A = -A
679 |
680 | R, L = ql_decomposition(A)
681 | L = L / L[2][2]
682 |
683 | f = torch.stack((L[0][0], L[1][1]))
684 | pp = torch.stack((L[2][0], L[2][1]))
685 | return R, f, pp
686 |
687 |
688 | def compute_ndc_coordinates(
689 | crop_parameters=None,
690 | use_half_pix=True,
691 | num_patches_x=16,
692 | num_patches_y=16,
693 | device=None,
694 | ):
695 | """
696 | Computes NDC Grid using crop_parameters. If crop_parameters is not provided,
697 | then it assumes that the crop is the entire image (corresponding to an NDC grid
698 | where top left corner is (1, 1) and bottom right corner is (-1, -1)).
699 | """
700 | if crop_parameters is None:
701 | cc_x, cc_y, width = 0, 0, 2
702 | else:
703 | if len(crop_parameters.shape) > 1:
704 | return torch.stack(
705 | [
706 | compute_ndc_coordinates(
707 | crop_parameters=crop_param,
708 | use_half_pix=use_half_pix,
709 | num_patches_x=num_patches_x,
710 | num_patches_y=num_patches_y,
711 | )
712 | for crop_param in crop_parameters
713 | ],
714 | dim=0,
715 | )
716 | device = crop_parameters.device
717 | cc_x, cc_y, width, _ = crop_parameters
718 |
719 | dx = 1 / num_patches_x
720 | dy = 1 / num_patches_y
721 | if use_half_pix:
722 | min_y = 1 - dy
723 | max_y = -min_y
724 | min_x = 1 - dx
725 | max_x = -min_x
726 | else:
727 | min_y = min_x = 1
728 | max_y = -1 + 2 * dy
729 | max_x = -1 + 2 * dx
730 |
731 | y, x = torch.meshgrid(
732 | torch.linspace(min_y, max_y, num_patches_y, dtype=torch.float32, device=device),
733 | torch.linspace(min_x, max_x, num_patches_x, dtype=torch.float32, device=device),
734 | indexing="ij",
735 | )
736 | x_prime = x * width / 2 - cc_x
737 | y_prime = y * width / 2 - cc_y
738 | xyd_grid = torch.stack([x_prime, y_prime, torch.ones_like(x)], dim=-1)
739 | return xyd_grid
740 |
--------------------------------------------------------------------------------
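The utilities above are easiest to follow end to end: cameras are encoded as a dense grid of Plucker rays and can be decoded back to cameras. Below is a minimal sketch, assuming `pytorch3d` is installed and the package root is on `PYTHONPATH`; the poses, focal lengths, and crop parameters are illustrative values, not ones used by the repository.

import torch
from pytorch3d.renderer import PerspectiveCameras

from onediffusion.dataset.raydiff_utils import cameras_to_rays, rays_to_cameras

# Two toy cameras: identity rotations, slightly different translations (illustrative only).
cameras = PerspectiveCameras(
    focal_length=[3.453, 3.453],
    R=torch.eye(3).unsqueeze(0).repeat(2, 1, 1),
    T=torch.tensor([[0.0, 0.0, 2.0], [0.1, 0.0, 2.0]]),
)

# crop_parameters is (cc_x, cc_y, crop_width, scale) in NDC; (0, 0, 2, 1) means "no crop".
crop_parameters = torch.tensor([[0.0, 0.0, 2.0, 1.0]] * 2)

# Encode each camera as a 16x16 grid of Plucker rays (direction, moment) ...
rays = cameras_to_rays(cameras, crop_parameters, num_patches_x=16, num_patches_y=16)
print(rays.to_spatial().shape)  # (2, 6, 16, 16), usable as a dense conditioning map

# ... and decode the rays back to cameras (recovered up to the usual scale ambiguity).
recovered = rays_to_cameras(rays, crop_parameters, num_patches_x=16, num_patches_y=16)
print(recovered.R.shape, recovered.T.shape)  # (2, 3, 3), (2, 3)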
/onediffusion/dataset/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def crop(image, i, j, h, w):
5 | """
6 | Args:
7 | image (torch.tensor): Image to be cropped. Size is (C, H, W)
8 | """
9 | if len(image.size()) != 3:
10 | raise ValueError("image should be a 3D tensor")
11 | return image[..., i : i + h, j : j + w]
12 |
13 | def resize(image, target_size, interpolation_mode):
14 | if len(target_size) != 2:
15 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
16 | return F.interpolate(image.unsqueeze(0), size=target_size, mode=interpolation_mode, align_corners=False).squeeze(0)
17 |
18 | def resize_scale(image, target_size, interpolation_mode):
19 | if len(target_size) != 2:
20 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
21 | H, W = image.size(-2), image.size(-1)
22 | scale_ = target_size[0] / min(H, W)
23 | return F.interpolate(image.unsqueeze(0), scale_factor=scale_, mode=interpolation_mode, align_corners=False).squeeze(0)
24 |
25 | def resized_crop(image, i, j, h, w, size, interpolation_mode="bilinear"):
26 | """
27 |     Spatially crop the image, then resize the cropped region.
28 |     Args:
29 |         image (torch.tensor): Image to be cropped. Size is (C, H, W)
30 |         i (int): Row index (top edge) of the upper-left corner of the crop.
31 |         j (int): Column index (left edge) of the upper-left corner of the crop.
32 |         h (int): Height of the cropped region.
33 |         w (int): Width of the cropped region.
34 |         size (tuple(int, int)): Height and width of the resized image.
35 | Returns:
36 | image (torch.tensor): Resized and cropped image. Size is (C, H, W)
37 | """
38 | if len(image.size()) != 3:
39 | raise ValueError("image should be a 3D torch.tensor")
40 | image = crop(image, i, j, h, w)
41 | image = resize(image, size, interpolation_mode)
42 | return image
43 |
44 | def center_crop(image, crop_size):
45 | if len(image.size()) != 3:
46 | raise ValueError("image should be a 3D torch.tensor")
47 | h, w = image.size(-2), image.size(-1)
48 | th, tw = crop_size
49 | if h < th or w < tw:
50 | raise ValueError("height and width must be no smaller than crop_size")
51 | i = int(round((h - th) / 2.0))
52 | j = int(round((w - tw) / 2.0))
53 | return crop(image, i, j, th, tw)
54 |
55 | def center_crop_using_short_edge(image):
56 | if len(image.size()) != 3:
57 | raise ValueError("image should be a 3D torch.tensor")
58 | h, w = image.size(-2), image.size(-1)
59 | if h < w:
60 | th, tw = h, h
61 | i = 0
62 | j = int(round((w - tw) / 2.0))
63 | else:
64 | th, tw = w, w
65 | i = int(round((h - th) / 2.0))
66 | j = 0
67 | return crop(image, i, j, th, tw)
68 |
69 | class CenterCropResizeImage:
70 | """
71 |     Center-crop the image to the target aspect ratio, then resize the crop to the desired size.
72 |     The crop is chosen so that the discarded area is minimized and no padding is needed.
73 | """
74 | def __init__(self, size, interpolation_mode="bilinear"):
75 | if isinstance(size, tuple):
76 | if len(size) != 2:
77 | raise ValueError(f"Size should be a tuple (height, width), instead got {size}")
78 | self.size = size
79 | else:
80 | self.size = (size, size)
81 | self.interpolation_mode = interpolation_mode
82 |
83 | def __call__(self, image):
84 | """
85 | Args:
86 | image (torch.Tensor): Image to be resized and cropped. Size is (C, H, W)
87 |
88 | Returns:
89 | torch.Tensor: Resized and cropped image. Size is (C, target_height, target_width)
90 | """
91 | target_height, target_width = self.size
92 | target_aspect = target_width / target_height
93 |
94 | # Get current image shape and aspect ratio
95 | _, height, width = image.shape
96 | height, width = float(height), float(width)
97 | current_aspect = width / height
98 |
99 | # Calculate crop dimensions
100 | if current_aspect > target_aspect:
101 | # Image is wider than target, crop width
102 | crop_height = height
103 | crop_width = height * target_aspect
104 | else:
105 | # Image is taller than target, crop height
106 | crop_height = width / target_aspect
107 | crop_width = width
108 |
109 | # Calculate crop coordinates (center crop)
110 | y1 = (height - crop_height) / 2
111 | x1 = (width - crop_width) / 2
112 |
113 | # Perform the crop
114 | cropped_image = crop(image, int(y1), int(x1), int(crop_height), int(crop_width))
115 |
116 | # Resize the cropped image to the target size
117 | resized_image = resize(cropped_image, self.size, self.interpolation_mode)
118 |
119 | return resized_image
120 |
121 | # Example usage
122 | if __name__ == "__main__":
123 | # Create a sample image tensor
124 | sample_image = torch.rand(3, 480, 640) # (C, H, W)
125 |
126 | # Initialize the transform
127 | transform = CenterCropResizeImage(size=(224, 224), interpolation_mode="bilinear")
128 |
129 | # Apply the transform
130 | transformed_image = transform(sample_image)
131 |
132 | print(f"Original image shape: {sample_image.shape}")
133 | print(f"Transformed image shape: {transformed_image.shape}")
--------------------------------------------------------------------------------
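Since `CenterCropResizeImage` first crops to the target aspect ratio and only then resizes, a quick check with a non-square target makes the arithmetic concrete. A small sketch, assuming the package is importable; the input and target sizes are arbitrary.

import torch
from onediffusion.dataset.transforms import CenterCropResizeImage

image = torch.rand(3, 480, 640)                     # H=480, W=640, aspect ratio ~1.33
transform = CenterCropResizeImage(size=(512, 256))  # target aspect ratio 256/512 = 0.5

# current_aspect (1.33) > target_aspect (0.5), so the full height is kept:
# crop_height = 480, crop_width = 480 * 0.5 = 240, starting at x1 = (640 - 240) / 2 = 200.
out = transform(image)
print(out.shape)  # torch.Size([3, 512, 256])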
/onediffusion/dataset/utils.py:
--------------------------------------------------------------------------------
1 |
2 | ASPECT_RATIO_2880 = {
3 | '0.25': [1408.0, 5760.0], '0.26': [1408.0, 5568.0], '0.27': [1408.0, 5376.0], '0.28': [1408.0, 5184.0],
4 | '0.32': [1600.0, 4992.0], '0.33': [1600.0, 4800.0], '0.34': [1600.0, 4672.0], '0.4': [1792.0, 4480.0],
5 | '0.42': [1792.0, 4288.0], '0.47': [1920.0, 4096.0], '0.49': [1920.0, 3904.0], '0.51': [1920.0, 3776.0],
6 | '0.55': [2112.0, 3840.0], '0.59': [2112.0, 3584.0], '0.68': [2304.0, 3392.0], '0.72': [2304.0, 3200.0],
7 | '0.78': [2496.0, 3200.0], '0.83': [2496.0, 3008.0], '0.89': [2688.0, 3008.0], '0.93': [2688.0, 2880.0],
8 | '1.0': [2880.0, 2880.0], '1.07': [2880.0, 2688.0], '1.12': [3008.0, 2688.0], '1.21': [3008.0, 2496.0],
9 | '1.28': [3200.0, 2496.0], '1.39': [3200.0, 2304.0], '1.47': [3392.0, 2304.0], '1.7': [3584.0, 2112.0],
10 | '1.82': [3840.0, 2112.0], '2.03': [3904.0, 1920.0], '2.13': [4096.0, 1920.0], '2.39': [4288.0, 1792.0],
11 | '2.5': [4480.0, 1792.0], '2.92': [4672.0, 1600.0], '3.0': [4800.0, 1600.0], '3.12': [4992.0, 1600.0],
12 | '3.68': [5184.0, 1408.0], '3.82': [5376.0, 1408.0], '3.95': [5568.0, 1408.0], '4.0': [5760.0, 1408.0]
13 | }
14 |
15 | ASPECT_RATIO_2048 = {
16 | '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0], '0.27': [1024.0, 3840.0], '0.28': [1024.0, 3712.0],
17 | '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0],
18 | '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0],
19 | '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0],
20 | '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0],
21 | '1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0],
22 | '1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0],
23 | '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0],
24 | '2.5': [3200.0, 1280.0], '2.89': [3328.0, 1152.0], '3.0': [3456.0, 1152.0], '3.11': [3584.0, 1152.0],
25 | '3.62': [3712.0, 1024.0], '3.75': [3840.0, 1024.0], '3.88': [3968.0, 1024.0], '4.0': [4096.0, 1024.0]
26 | }
27 |
28 | ASPECT_RATIO_1024 = {
29 | '0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.],
30 | '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
31 | '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
32 | '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
33 | '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
34 | '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
35 | '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
36 | '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
37 | '2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.],
38 | '3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.],
39 | }
40 |
41 | ASPECT_RATIO_512 = {
42 | '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
43 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
44 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
45 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
46 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
47 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
48 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
49 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
50 | '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
51 | '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
52 | }
53 |
54 |
55 | ASPECT_RATIO_384 = {
56 | '0.25': [192.0, 768.0],
57 | '0.26': [192.0, 736.0],
58 | '0.27': [208.0, 768.0],
59 | '0.28': [208.0, 736.0],
60 | '0.33': [240.0, 720.0],
61 | '0.4': [256.0, 640.0],
62 | '0.42': [304.0, 720.0],
63 | '0.48': [368.0, 768.0],
64 | '0.5': [384.0, 768.0],
65 | '0.52': [384.0, 736.0],
66 | '0.57': [384.0, 672.0],
67 | '0.6': [384.0, 640.0],
68 | '0.73': [384.0, 528.0],
69 | '0.77': [384.0, 496.0],
70 | '0.83': [384.0, 464.0],
71 | '0.89': [384.0, 432.0],
72 | '0.92': [384.0, 416.0],
73 | '1.0': [384.0, 384.0],
74 | '1.09': [384.0, 352.0],
75 | '1.14': [384.0, 336.0],
76 | '1.2': [384.0, 320.0],
77 | '1.26': [384.0, 304.0],
78 | '1.33': [384.0, 288.0],
79 | '1.41': [384.0, 272.0],
80 | '1.6': [384.0, 240.0],
81 | '1.71': [384.0, 224.0],
82 | '2.0': [384.0, 192.0],
83 | '2.4': [384.0, 160.0],
84 | '2.88': [368.0, 128.0],
85 | '3.0': [384.0, 128.0],
86 | '3.43': [384.0, 112.0],
87 | '4.0': [384.0, 96.0]
88 | }
89 |
90 | ASPECT_RATIO_256 = {
91 | '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0],
92 | '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
93 | '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
94 | '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
95 | '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
96 | '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
97 | '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
98 | '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
99 | '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0],
100 | '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0]
101 | }
102 |
103 | ASPECT_RATIO_256_TEST = {
104 | '0.25': [128.0, 512.0], '0.28': [128.0, 464.0],
105 | '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
106 | '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
107 | '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
108 | '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
109 | '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
110 | '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
111 | '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
112 | '2.5': [400.0, 160.0], '3.0': [432.0, 144.0],
113 | '4.0': [512.0, 128.0]
114 | }
115 |
116 | ASPECT_RATIO_512_TEST = {
117 | '0.25': [256.0, 1024.0], '0.28': [256.0, 928.0],
118 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
119 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
120 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
121 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
122 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
123 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
124 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
125 | '2.5': [800.0, 320.0], '3.0': [864.0, 288.0],
126 | '4.0': [1024.0, 256.0]
127 | }
128 |
129 | ASPECT_RATIO_1024_TEST = {
130 | '0.25': [512., 2048.], '0.28': [512., 1856.],
131 | '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.],
132 | '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.],
133 | '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.],
134 | '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.],
135 | '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.],
136 | '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.],
137 | '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.],
138 | '2.5': [1600., 640.], '3.0': [1728., 576.],
139 | '4.0': [2048., 512.],
140 | }
141 |
142 | ASPECT_RATIO_2048_TEST = {
143 | '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0],
144 | '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0],
145 | '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0],
146 | '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0],
147 | '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0],
148 | '1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0],
149 | '1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0],
150 | '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0],
151 | '2.5': [3200.0, 1280.0], '3.0': [3456.0, 1152.0],
152 | '4.0': [4096.0, 1024.0]
153 | }
154 |
155 | ASPECT_RATIO_2880_TEST = {
156 | '0.25': [2048.0, 8192.0], '0.26': [2048.0, 7936.0],
157 | '0.32': [2304.0, 7168.0], '0.33': [2304.0, 6912.0], '0.35': [2304.0, 6656.0], '0.4': [2560.0, 6400.0],
158 | '0.42': [2560.0, 6144.0], '0.48': [2816.0, 5888.0], '0.5': [2816.0, 5632.0], '0.52': [2816.0, 5376.0],
159 | '0.57': [3072.0, 5376.0], '0.6': [3072.0, 5120.0], '0.68': [3328.0, 4864.0], '0.72': [3328.0, 4608.0],
160 | '0.78': [3584.0, 4608.0], '0.82': [3584.0, 4352.0], '0.88': [3840.0, 4352.0], '0.94': [3840.0, 4096.0],
161 | '1.0': [4096.0, 4096.0], '1.07': [4096.0, 3840.0], '1.13': [4352.0, 3840.0], '1.21': [4352.0, 3584.0],
162 | '1.29': [4608.0, 3584.0], '1.38': [4608.0, 3328.0], '1.46': [4864.0, 3328.0], '1.67': [5120.0, 3072.0],
163 | '1.75': [5376.0, 3072.0], '2.0': [5632.0, 2816.0], '2.09': [5888.0, 2816.0], '2.4': [6144.0, 2560.0],
164 | '2.5': [6400.0, 2560.0], '3.0': [6912.0, 2304.0],
165 | '4.0': [8192.0, 2048.0],
166 | }
167 |
168 | def get_chunks(lst, n):
169 | for i in range(0, len(lst), n):
170 | yield lst[i:i + n]
171 |
172 | def get_closest_ratio(height: float, width: float, ratios: dict):
173 | aspect_ratio = height / width
174 | closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
175 | return ratios[closest_ratio], float(closest_ratio)
176 |
--------------------------------------------------------------------------------
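The aspect-ratio tables above act as resolution buckets: `get_closest_ratio` maps an arbitrary image size to the bucket whose ratio key is numerically closest. A small sketch, assuming the package is importable; the input resolution is just an example.

from onediffusion.dataset.utils import ASPECT_RATIO_512, get_closest_ratio

height, width = 720.0, 1280.0  # aspect ratio 720 / 1280 = 0.5625
(bucket_h, bucket_w), ratio = get_closest_ratio(height, width, ASPECT_RATIO_512)
print(bucket_h, bucket_w, ratio)  # 384.0 672.0 0.57 -- '0.57' is the closest key to 0.5625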
/onediffusion/diffusion/pipelines/image_processor.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | import warnings
17 | from typing import List, Optional, Tuple, Union
18 |
19 | import numpy as np
20 | import PIL.Image
21 | import torch
22 | import torch.nn.functional as F
23 | import torchvision.transforms as T
24 | from PIL import Image, ImageFilter, ImageOps
25 |
26 | from diffusers.configuration_utils import ConfigMixin, register_to_config
27 | from diffusers.utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
28 |
29 | from onediffusion.dataset.transforms import CenterCropResizeImage
30 |
31 | PipelineImageInput = Union[
32 | PIL.Image.Image,
33 | np.ndarray,
34 | torch.Tensor,
35 | List[PIL.Image.Image],
36 | List[np.ndarray],
37 | List[torch.Tensor],
38 | ]
39 |
40 | PipelineDepthInput = PipelineImageInput
41 |
42 |
43 | def is_valid_image(image):
44 | return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
45 |
46 |
47 | def is_valid_image_imagelist(images):
48 |     # check if the image input is one of the supported formats for image and image list:
49 |     # it can be one of the 3 cases below:
50 |     # (1) a 4-d pytorch tensor or numpy array,
51 |     # (2) a valid single image: a PIL.Image.Image, a 2-d np.ndarray or torch.Tensor (grayscale image), or a 3-d np.ndarray or torch.Tensor,
52 |     # (3) a list of valid single images.
53 | if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
54 | return True
55 | elif is_valid_image(images):
56 | return True
57 | elif isinstance(images, list):
58 | return all(is_valid_image(image) for image in images)
59 | return False
60 |
61 |
62 | class VaeImageProcessorOneDiffuser(ConfigMixin):
63 | """
64 | Image processor for VAE.
65 |
66 | Args:
67 | do_resize (`bool`, *optional*, defaults to `True`):
68 | Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
69 | `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
70 | vae_scale_factor (`int`, *optional*, defaults to `8`):
71 | VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
72 | resample (`str`, *optional*, defaults to `lanczos`):
73 | Resampling filter to use when resizing the image.
74 | do_normalize (`bool`, *optional*, defaults to `True`):
75 | Whether to normalize the image to [-1,1].
76 | do_binarize (`bool`, *optional*, defaults to `False`):
77 | Whether to binarize the image to 0/1.
78 |         do_convert_rgb (`bool`, *optional*, defaults to `False`):
79 |             Whether to convert the images to RGB format.
80 |         do_convert_grayscale (`bool`, *optional*, defaults to `False`):
81 | Whether to convert the images to grayscale format.
82 | """
83 |
84 | config_name = CONFIG_NAME
85 |
86 | @register_to_config
87 | def __init__(
88 | self,
89 | do_resize: bool = True,
90 | vae_scale_factor: int = 8,
91 | vae_latent_channels: int = 4,
92 | resample: str = "lanczos",
93 | do_normalize: bool = True,
94 | do_binarize: bool = False,
95 | do_convert_rgb: bool = False,
96 | do_convert_grayscale: bool = False,
97 | ):
98 | super().__init__()
99 | if do_convert_rgb and do_convert_grayscale:
100 |             raise ValueError(
101 |                 "`do_convert_rgb` and `do_convert_grayscale` cannot both be set to `True`."
102 |                 " If you intended to convert the image into RGB format, please set `do_convert_grayscale = False`."
103 |                 " If you intended to convert the image into grayscale format, please set `do_convert_rgb = False`."
104 |             )
105 |
106 | @staticmethod
107 | def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
108 | """
109 | Convert a numpy image or a batch of images to a PIL image.
110 | """
111 | if images.ndim == 3:
112 | images = images[None, ...]
113 | images = (images * 255).round().astype("uint8")
114 | if images.shape[-1] == 1:
115 | # special case for grayscale (single channel) images
116 | pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
117 | else:
118 | pil_images = [Image.fromarray(image) for image in images]
119 |
120 | return pil_images
121 |
122 | @staticmethod
123 | def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
124 | """
125 | Convert a PIL image or a list of PIL images to NumPy arrays.
126 | """
127 | if not isinstance(images, list):
128 | images = [images]
129 | images = [np.array(image).astype(np.float32) / 255.0 for image in images]
130 | images = np.stack(images, axis=0)
131 |
132 | return images
133 |
134 | @staticmethod
135 | def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
136 | """
137 | Convert a NumPy image to a PyTorch tensor.
138 | """
139 | if images.ndim == 3:
140 | images = images[..., None]
141 |
142 | images = torch.from_numpy(images.transpose(0, 3, 1, 2))
143 | return images
144 |
145 | @staticmethod
146 | def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
147 | """
148 | Convert a PyTorch tensor to a NumPy image.
149 | """
150 | images = images.cpu().permute(0, 2, 3, 1).float().numpy()
151 | return images
152 |
153 | @staticmethod
154 | def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
155 | """
156 | Normalize an image array to [-1,1].
157 | """
158 | return 2.0 * images - 1.0
159 |
160 | @staticmethod
161 | def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
162 | """
163 | Denormalize an image array to [0,1].
164 | """
165 | return (images / 2 + 0.5).clamp(0, 1)
166 |
167 | @staticmethod
168 | def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
169 | """
170 | Converts a PIL image to RGB format.
171 | """
172 | image = image.convert("RGB")
173 |
174 | return image
175 |
176 | @staticmethod
177 | def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
178 | """
179 | Converts a PIL image to grayscale format.
180 | """
181 | image = image.convert("L")
182 |
183 | return image
184 |
185 | @staticmethod
186 | def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
187 | """
188 | Applies Gaussian blur to an image.
189 | """
190 | image = image.filter(ImageFilter.GaussianBlur(blur_factor))
191 |
192 | return image
193 |
194 | @staticmethod
195 | def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
196 | """
197 |         Finds a rectangular region that contains all masked areas in an image, and expands the region to match the
198 |         aspect ratio of the original image; for example, if the user drew a mask in a 128x32 region, and the dimensions
199 |         for processing are 512x512, the region will be expanded to 128x128.
200 |
201 | Args:
202 | mask_image (PIL.Image.Image): Mask image.
203 | width (int): Width of the image to be processed.
204 | height (int): Height of the image to be processed.
205 | pad (int, optional): Padding to be added to the crop region. Defaults to 0.
206 |
207 | Returns:
208 |             tuple: (x1, y1, x2, y2) representing a rectangular region that contains all masked areas in the image and
209 |                 matches the original aspect ratio.
210 | """
211 |
212 | mask_image = mask_image.convert("L")
213 | mask = np.array(mask_image)
214 |
215 |         # 1. find a rectangular region that contains all masked areas in an image
216 | h, w = mask.shape
217 | crop_left = 0
218 | for i in range(w):
219 | if not (mask[:, i] == 0).all():
220 | break
221 | crop_left += 1
222 |
223 | crop_right = 0
224 | for i in reversed(range(w)):
225 | if not (mask[:, i] == 0).all():
226 | break
227 | crop_right += 1
228 |
229 | crop_top = 0
230 | for i in range(h):
231 | if not (mask[i] == 0).all():
232 | break
233 | crop_top += 1
234 |
235 | crop_bottom = 0
236 | for i in reversed(range(h)):
237 | if not (mask[i] == 0).all():
238 | break
239 | crop_bottom += 1
240 |
241 | # 2. add padding to the crop region
242 | x1, y1, x2, y2 = (
243 | int(max(crop_left - pad, 0)),
244 | int(max(crop_top - pad, 0)),
245 | int(min(w - crop_right + pad, w)),
246 | int(min(h - crop_bottom + pad, h)),
247 | )
248 |
249 | # 3. expands crop region to match the aspect ratio of the image to be processed
250 | ratio_crop_region = (x2 - x1) / (y2 - y1)
251 | ratio_processing = width / height
252 |
253 | if ratio_crop_region > ratio_processing:
254 | desired_height = (x2 - x1) / ratio_processing
255 | desired_height_diff = int(desired_height - (y2 - y1))
256 | y1 -= desired_height_diff // 2
257 | y2 += desired_height_diff - desired_height_diff // 2
258 | if y2 >= mask_image.height:
259 | diff = y2 - mask_image.height
260 | y2 -= diff
261 | y1 -= diff
262 | if y1 < 0:
263 | y2 -= y1
264 | y1 -= y1
265 | if y2 >= mask_image.height:
266 | y2 = mask_image.height
267 | else:
268 | desired_width = (y2 - y1) * ratio_processing
269 | desired_width_diff = int(desired_width - (x2 - x1))
270 | x1 -= desired_width_diff // 2
271 | x2 += desired_width_diff - desired_width_diff // 2
272 | if x2 >= mask_image.width:
273 | diff = x2 - mask_image.width
274 | x2 -= diff
275 | x1 -= diff
276 | if x1 < 0:
277 | x2 -= x1
278 | x1 -= x1
279 | if x2 >= mask_image.width:
280 | x2 = mask_image.width
281 |
282 | return x1, y1, x2, y2
283 |
284 | def _resize_and_fill(
285 | self,
286 | image: PIL.Image.Image,
287 | width: int,
288 | height: int,
289 | ) -> PIL.Image.Image:
290 | """
291 | Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
292 | the image within the dimensions, filling empty with data from image.
293 |
294 | Args:
295 | image: The image to resize.
296 | width: The width to resize the image to.
297 | height: The height to resize the image to.
298 | """
299 |
300 | ratio = width / height
301 | src_ratio = image.width / image.height
302 |
303 | src_w = width if ratio < src_ratio else image.width * height // image.height
304 | src_h = height if ratio >= src_ratio else image.height * width // image.width
305 |
306 | resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
307 | res = Image.new("RGB", (width, height))
308 | res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
309 |
310 | if ratio < src_ratio:
311 | fill_height = height // 2 - src_h // 2
312 | if fill_height > 0:
313 | res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
314 | res.paste(
315 | resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
316 | box=(0, fill_height + src_h),
317 | )
318 | elif ratio > src_ratio:
319 | fill_width = width // 2 - src_w // 2
320 | if fill_width > 0:
321 | res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
322 | res.paste(
323 | resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
324 | box=(fill_width + src_w, 0),
325 | )
326 |
327 | return res
328 |
329 | def _resize_and_crop(
330 | self,
331 | image: PIL.Image.Image,
332 | width: int,
333 | height: int,
334 | ) -> PIL.Image.Image:
335 | """
336 | Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
337 | the image within the dimensions, cropping the excess.
338 |
339 | Args:
340 | image: The image to resize.
341 | width: The width to resize the image to.
342 | height: The height to resize the image to.
343 | """
344 | ratio = width / height
345 | src_ratio = image.width / image.height
346 |
347 | src_w = width if ratio > src_ratio else image.width * height // image.height
348 | src_h = height if ratio <= src_ratio else image.height * width // image.width
349 |
350 | resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
351 | res = Image.new("RGB", (width, height))
352 | res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
353 | return res
354 |
355 | def resize(
356 | self,
357 | image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
358 | height: int,
359 | width: int,
360 | resize_mode: str = "default", # "default", "fill", "crop"
361 | ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
362 | """
363 | Resize image.
364 |
365 | Args:
366 | image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
367 | The image input, can be a PIL image, numpy array or pytorch tensor.
368 | height (`int`):
369 | The height to resize to.
370 | width (`int`):
371 | The width to resize to.
372 | resize_mode (`str`, *optional*, defaults to `default`):
373 |                 The resize mode to use; can be one of `default`, `fill`, or `crop`. If `default`, the image is resized
374 |                 to the specified width and height, which may not preserve the original aspect ratio. If `fill`, the
375 |                 image is resized to fit within the specified width and height while keeping the aspect ratio, then
376 |                 centered within the dimensions, with the empty space filled with data from the image. If `crop`, the
377 |                 image is resized to cover the specified width and height while keeping the aspect ratio, then centered
378 |                 within the dimensions, with the excess cropped. Note that resize_mode `fill` and `crop` are only
379 |                 supported for PIL image input.
380 |
381 | Returns:
382 | `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
383 | The resized image.
384 | """
385 | if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
386 | raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
387 | if isinstance(image, PIL.Image.Image):
388 | if resize_mode == "default":
389 | image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
390 | elif resize_mode == "fill":
391 | image = self._resize_and_fill(image, width, height)
392 | elif resize_mode == "crop":
393 | image = self._resize_and_crop(image, width, height)
394 | else:
395 | raise ValueError(f"resize_mode {resize_mode} is not supported")
396 |
397 | elif isinstance(image, torch.Tensor):
398 | image = torch.nn.functional.interpolate(
399 | image,
400 | size=(height, width),
401 | )
402 | elif isinstance(image, np.ndarray):
403 | image = self.numpy_to_pt(image)
404 | image = torch.nn.functional.interpolate(
405 | image,
406 | size=(height, width),
407 | )
408 | image = self.pt_to_numpy(image)
409 | return image
410 |
411 | def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
412 | """
413 | Create a mask.
414 |
415 | Args:
416 | image (`PIL.Image.Image`):
417 | The image input, should be a PIL image.
418 |
419 | Returns:
420 | `PIL.Image.Image`:
421 |                 The binarized image. Values less than 0.5 are set to 0; values greater than or equal to 0.5 are set to 1.
422 | """
423 | image[image < 0.5] = 0
424 | image[image >= 0.5] = 1
425 |
426 | return image
427 |
428 | def get_default_height_width(
429 | self,
430 | image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
431 | height: Optional[int] = None,
432 | width: Optional[int] = None,
433 | ) -> Tuple[int, int]:
434 | """
435 |         This function returns the height and width, rounded down to the nearest integer multiple of
436 |         `vae_scale_factor`.
437 | 
438 |         Args:
439 |             image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
440 |                 The image input, which can be a PIL image, numpy array or pytorch tensor. If it is a numpy array, it
441 |                 should have shape `[batch, height, width]` or `[batch, height, width, channel]`; if it is a pytorch
442 |                 tensor, it should have shape `[batch, channel, height, width]`.
443 |             height (`int`, *optional*, defaults to `None`):
444 |                 The height of the preprocessed image. If `None`, the height of the `image` input is used.
445 |             width (`int`, *optional*, defaults to `None`):
446 |                 The width of the preprocessed image. If `None`, the width of the `image` input is used.
447 | """
448 |
449 | if height is None:
450 | if isinstance(image, PIL.Image.Image):
451 | height = image.height
452 | elif isinstance(image, torch.Tensor):
453 | height = image.shape[2]
454 | else:
455 | height = image.shape[1]
456 |
457 | if width is None:
458 | if isinstance(image, PIL.Image.Image):
459 | width = image.width
460 | elif isinstance(image, torch.Tensor):
461 | width = image.shape[3]
462 | else:
463 | width = image.shape[2]
464 |
465 | width, height = (
466 | x - x % self.config.vae_scale_factor for x in (width, height)
467 | ) # resize to integer multiple of vae_scale_factor
468 |
469 | return height, width
470 |
471 | def preprocess(
472 | self,
473 | image: PipelineImageInput,
474 | height: Optional[int] = None,
475 | width: Optional[int] = None,
476 | resize_mode: str = "default", # "default", "fill", "crop"
477 | crops_coords: Optional[Tuple[int, int, int, int]] = None,
478 | do_crop: bool = True,
479 | ) -> torch.Tensor:
480 | """
481 | Preprocess the image input.
482 |
483 | Args:
484 | image (`pipeline_image_input`):
485 | The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
486 | supported formats.
487 |             height (`int`, *optional*, defaults to `None`):
488 |                 The height of the preprocessed image. If `None`, `get_default_height_width()` is used to get the
489 |                 default height.
490 |             width (`int`, *optional*, defaults to `None`):
491 |                 The width of the preprocessed image. If `None`, `get_default_height_width()` is used to get the default width.
492 |             resize_mode (`str`, *optional*, defaults to `default`):
493 |                 The resize mode; can be one of `default`, `fill`, or `crop`. If `default`, the image is resized to the
494 |                 specified width and height, which may not preserve the original aspect ratio. If `fill`, the image is
495 |                 resized to fit within the specified width and height while keeping the aspect ratio, then centered
496 |                 within the dimensions, with the empty space filled with data from the image. If `crop`, the image is
497 |                 resized to cover the specified width and height while keeping the aspect ratio, then centered within
498 |                 the dimensions, with the excess cropped. Note that resize_mode `fill` and `crop` are only supported
499 |                 for PIL image input.
500 | crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
501 | The crop coordinates for each image in the batch. If `None`, will not crop the image.
502 | """
503 | supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
504 |
505 | # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
506 | if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
507 | if isinstance(image, torch.Tensor):
508 | # if image is a pytorch tensor could have 2 possible shapes:
509 | # 1. batch x height x width: we should insert the channel dimension at position 1
510 | # 2. channel x height x width: we should insert batch dimension at position 0,
511 | # however, since both channel and batch dimension has same size 1, it is same to insert at position 1
512 | # for simplicity, we insert a dimension of size 1 at position 1 for both cases
513 | image = image.unsqueeze(1)
514 | else:
515 | # if it is a numpy array, it could have 2 possible shapes:
516 | # 1. batch x height x width: insert channel dimension on last position
517 | # 2. height x width x channel: insert batch dimension on first position
518 | if image.shape[-1] == 1:
519 | image = np.expand_dims(image, axis=0)
520 | else:
521 | image = np.expand_dims(image, axis=-1)
522 |
523 | if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
524 | warnings.warn(
525 | "Passing `image` as a list of 4d np.ndarray is deprecated."
526 | "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
527 | FutureWarning,
528 | )
529 | image = np.concatenate(image, axis=0)
530 | if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
531 | warnings.warn(
532 | "Passing `image` as a list of 4d torch.Tensor is deprecated."
533 | "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
534 | FutureWarning,
535 | )
536 | image = torch.cat(image, axis=0)
537 |
538 | if not is_valid_image_imagelist(image):
539 | raise ValueError(
540 | f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
541 | )
542 | if not isinstance(image, list):
543 | image = [image]
544 |
545 | if isinstance(image[0], PIL.Image.Image):
546 | pass
547 | elif isinstance(image[0], np.ndarray):
548 | image = self.numpy_to_pil(image)
549 | elif isinstance(image[0], torch.Tensor):
550 | image = self.pt_to_numpy(image)
551 | image = self.numpy_to_pil(image)
552 |
553 | if do_crop:
554 | transforms = T.Compose([
555 | T.Lambda(lambda image: image.convert('RGB')),
556 | T.ToTensor(),
557 | CenterCropResizeImage((height, width)),
558 | T.Normalize([.5], [.5]),
559 | ])
560 | else:
561 | transforms = T.Compose([
562 | T.Lambda(lambda image: image.convert('RGB')),
563 | T.ToTensor(),
564 | T.Resize((height, width)),
565 | T.Normalize([.5], [.5]),
566 | ])
567 | image = torch.stack([transforms(i) for i in image])
568 |
569 | # expected range [0,1], normalize to [-1,1]
570 | do_normalize = self.config.do_normalize
571 | if do_normalize and image.min() < 0:
572 | warnings.warn(
573 | "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
574 | f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
575 | FutureWarning,
576 | )
577 | do_normalize = False
578 | if do_normalize:
579 | image = self.normalize(image)
580 |
581 | if self.config.do_binarize:
582 | image = self.binarize(image)
583 |
584 | return image
585 |
586 | def postprocess(
587 | self,
588 | image: torch.Tensor,
589 | output_type: str = "pil",
590 | do_denormalize: Optional[List[bool]] = None,
591 | ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
592 | """
593 | Postprocess the image output from tensor to `output_type`.
594 |
595 | Args:
596 | image (`torch.Tensor`):
597 | The image input, should be a pytorch tensor with shape `B x C x H x W`.
598 | output_type (`str`, *optional*, defaults to `pil`):
599 | The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
600 | do_denormalize (`List[bool]`, *optional*, defaults to `None`):
601 | Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
602 | `VaeImageProcessor` config.
603 |
604 | Returns:
605 | `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
606 | The postprocessed image.
607 | """
608 | if not isinstance(image, torch.Tensor):
609 | raise ValueError(
610 | f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
611 | )
612 | if output_type not in ["latent", "pt", "np", "pil"]:
613 | deprecation_message = (
614 | f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
615 | "`pil`, `np`, `pt`, `latent`"
616 | )
617 | deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
618 | output_type = "np"
619 |
620 | if output_type == "latent":
621 | return image
622 |
623 | if do_denormalize is None:
624 | do_denormalize = [self.config.do_normalize] * image.shape[0]
625 |
626 | image = torch.stack(
627 | [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
628 | )
629 |
630 | if output_type == "pt":
631 | return image
632 |
633 | image = self.pt_to_numpy(image)
634 |
635 | if output_type == "np":
636 | return image
637 |
638 | if output_type == "pil":
639 | return self.numpy_to_pil(image)
640 |
641 | def apply_overlay(
642 | self,
643 | mask: PIL.Image.Image,
644 | init_image: PIL.Image.Image,
645 | image: PIL.Image.Image,
646 | crop_coords: Optional[Tuple[int, int, int, int]] = None,
647 | ) -> PIL.Image.Image:
648 | """
649 |         Overlay the inpainting output onto the original image.
650 | """
651 |
652 | width, height = image.width, image.height
653 |
654 | init_image = self.resize(init_image, width=width, height=height)
655 | mask = self.resize(mask, width=width, height=height)
656 |
657 | init_image_masked = PIL.Image.new("RGBa", (width, height))
658 | init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
659 | init_image_masked = init_image_masked.convert("RGBA")
660 |
661 | if crop_coords is not None:
662 | x, y, x2, y2 = crop_coords
663 | w = x2 - x
664 | h = y2 - y
665 | base_image = PIL.Image.new("RGBA", (width, height))
666 | image = self.resize(image, height=h, width=w, resize_mode="crop")
667 | base_image.paste(image, (x, y))
668 | image = base_image.convert("RGB")
669 |
670 | image = image.convert("RGBA")
671 | image.alpha_composite(init_image_masked)
672 | image = image.convert("RGB")
673 |
674 | return image
675 |
--------------------------------------------------------------------------------
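A minimal round trip through `VaeImageProcessorOneDiffuser` shows the two halves working together: `preprocess` center-crops, resizes, and scales pixel values to roughly [-1, 1], while `postprocess` denormalizes and converts back to PIL images. The sketch below assumes the package is importable; the input image and target size are placeholders.

import PIL.Image
from onediffusion.diffusion.pipelines.image_processor import VaeImageProcessorOneDiffuser

processor = VaeImageProcessorOneDiffuser(vae_scale_factor=8)

image = PIL.Image.new("RGB", (640, 480), color=(128, 64, 32))  # stand-in for a real input
tensor = processor.preprocess(image, height=512, width=512)    # center-crop + resize + normalize
print(tensor.shape)                                            # torch.Size([1, 3, 512, 512])

pil_out = processor.postprocess(tensor, output_type="pil")     # denormalize and convert back
print(len(pil_out), pil_out[0].size)                           # 1 (512, 512)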
/onediffusion/models/denoiser/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 | nextdit
3 | )
--------------------------------------------------------------------------------
/onediffusion/models/denoiser/nextdit/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling_nextdit import NextDiT
--------------------------------------------------------------------------------
/onediffusion/models/denoiser/nextdit/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from typing import Callable, Optional
6 |
7 | import warnings
8 |
9 | 
10 | 
11 |
12 | try:
13 |     from apex.normalization import FusedRMSNorm as RMSNorm
14 | except ImportError:
15 |     warnings.warn("Cannot import apex RMSNorm, falling back to the vanilla implementation")
16 |
17 |     # Vanilla RMSNorm, defined only when the fused apex kernel is unavailable.
18 |     class RMSNorm(torch.nn.Module):
19 |         def __init__(self, dim: int, eps: float = 1e-6):
20 |             """
21 |             Initialize the RMSNorm normalization layer.
22 |             Args:
23 |                 dim (int): The dimension of the input tensor.
24 |                 eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
25 |             Attributes:
26 |                 eps (float): A small value added to the denominator for numerical stability.
27 |                 weight (nn.Parameter): Learnable scaling parameter.
28 |             """
29 |             super().__init__()
30 |             self.eps = eps
31 |             self.weight = nn.Parameter(torch.ones(dim))
32 |
33 |         def _norm(self, x):
34 |             """
35 |             Apply the RMS normalization to the input tensor.
36 |             Args:
37 |                 x (torch.Tensor): The input tensor.
38 |             Returns:
39 |                 torch.Tensor: The normalized tensor.
40 |             """
41 |             return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
42 |
43 |         def forward(self, x):
44 |             """
45 |             Forward pass through the RMSNorm layer.
46 |             Args:
47 |                 x (torch.Tensor): The input tensor.
48 |             Returns:
49 |                 torch.Tensor: The output tensor after applying RMSNorm.
50 |             """
51 |             output = self._norm(x.float()).type_as(x)
52 |             return output * self.weight
53 |
54 |
55 | def modulate(x, scale):
56 | return x * (1 + scale.unsqueeze(1))
57 |
58 | class LLamaFeedForward(nn.Module):
59 | """
60 | Corresponds to the FeedForward layer in Next DiT.
61 | """
62 | def __init__(
63 | self,
64 | dim: int,
65 | hidden_dim: int,
66 | multiple_of: int,
67 | ffn_dim_multiplier: Optional[float] = None,
68 | zeros_initialize: bool = True,
69 | dtype: torch.dtype = torch.float32,
70 | ):
71 | super().__init__()
72 | self.dim = dim
73 | self.hidden_dim = hidden_dim
74 | self.multiple_of = multiple_of
75 | self.ffn_dim_multiplier = ffn_dim_multiplier
76 | self.zeros_initialize = zeros_initialize
77 | self.dtype = dtype
78 |
79 | # Compute hidden_dim based on the given formula
80 | hidden_dim_calculated = int(2 * self.hidden_dim / 3)
81 | if self.ffn_dim_multiplier is not None:
82 | hidden_dim_calculated = int(self.ffn_dim_multiplier * hidden_dim_calculated)
83 | hidden_dim_calculated = self.multiple_of * ((hidden_dim_calculated + self.multiple_of - 1) // self.multiple_of)
84 |
85 | # Define linear layers
86 | self.w1 = nn.Linear(self.dim, hidden_dim_calculated, bias=False)
87 | self.w2 = nn.Linear(hidden_dim_calculated, self.dim, bias=False)
88 | self.w3 = nn.Linear(self.dim, hidden_dim_calculated, bias=False)
89 |
90 | # Initialize weights
91 | if self.zeros_initialize:
92 | nn.init.zeros_(self.w2.weight)
93 | else:
94 | nn.init.xavier_uniform_(self.w2.weight)
95 | nn.init.xavier_uniform_(self.w1.weight)
96 | nn.init.xavier_uniform_(self.w3.weight)
97 |
98 | def _forward_silu_gating(self, x1, x3):
99 | return F.silu(x1) * x3
100 |
101 | def forward(self, x):
102 | return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
103 |
104 | class FinalLayer(nn.Module):
105 | """
106 | The final layer of Next-DiT.
107 | """
108 | def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
109 | super().__init__()
110 | self.hidden_size = hidden_size
111 | self.patch_size = patch_size
112 | self.out_channels = out_channels
113 |
114 | # LayerNorm without learnable parameters (elementwise_affine=False)
115 | self.norm_final = nn.LayerNorm(self.hidden_size, eps=1e-6, elementwise_affine=False)
116 | self.linear = nn.Linear(self.hidden_size, np.prod(self.patch_size) * self.out_channels, bias=True)
117 | nn.init.zeros_(self.linear.weight)
118 | nn.init.zeros_(self.linear.bias)
119 |
120 | self.adaLN_modulation = nn.Sequential(
121 | nn.SiLU(),
122 | nn.Linear(self.hidden_size, self.hidden_size),
123 | )
124 | # Initialize the last layer with zeros
125 | nn.init.zeros_(self.adaLN_modulation[1].weight)
126 | nn.init.zeros_(self.adaLN_modulation[1].bias)
127 |
128 | def forward(self, x, c):
129 | scale = self.adaLN_modulation(c)
130 | x = modulate(self.norm_final(x), scale)
131 | x = self.linear(x)
132 | return x
--------------------------------------------------------------------------------
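Usage note (illustrative sketch, not a file in this repository): `LLamaFeedForward` above is the SwiGLU feed-forward from LLaMA. The effective hidden width is `int(2 * hidden_dim / 3)`, optionally scaled by `ffn_dim_multiplier`, then rounded up to a multiple of `multiple_of`. The shape check below uses toy sizes chosen only for the example and assumes it is run from the repository root so the `onediffusion` package is importable.

    import torch
    from onediffusion.models.denoiser.nextdit.layers import LLamaFeedForward, RMSNorm

    ffn = LLamaFeedForward(dim=512, hidden_dim=4 * 512, multiple_of=256)
    # int(2 * 2048 / 3) = 1365, rounded up to the next multiple of 256 -> 1536
    assert ffn.w1.out_features == 1536

    x = torch.randn(2, 16, 512)      # (batch, tokens, dim)
    # w2 is zero-initialized by default, so the output is all zeros at init,
    # but the shape is preserved: SwiGLU maps back to the model dimension.
    print(ffn(x).shape)              # torch.Size([2, 16, 512])
    print(RMSNorm(512)(x).shape)     # per-token RMS normalization, same shape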
/onediffusion/models/denoiser/nextdit/modeling_nextdit.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import einops
7 | from diffusers.configuration_utils import ConfigMixin, register_to_config
8 | from diffusers.models.modeling_utils import ModelMixin
9 | from typing import Any, Tuple, Optional
10 | from flash_attn import flash_attn_varlen_func
11 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
12 |
13 | from .layers import LLamaFeedForward, RMSNorm
14 |
15 | # import frasch
16 |
17 |
18 | def modulate(x, scale):
19 | return x * (1 + scale)
20 |
21 | class TimestepEmbedder(nn.Module):
22 | """
23 | Embeds scalar timesteps into vector representations.
24 | """
25 | def __init__(self, hidden_size, frequency_embedding_size=256):
26 | super().__init__()
27 | self.hidden_size = hidden_size
28 | self.frequency_embedding_size = frequency_embedding_size
29 | self.mlp = nn.Sequential(
30 | nn.Linear(self.frequency_embedding_size, self.hidden_size),
31 | nn.SiLU(),
32 | nn.Linear(self.hidden_size, self.hidden_size),
33 | )
34 |
35 | @staticmethod
36 | def timestep_embedding(t, dim, max_period=10000):
37 | """
38 | Create sinusoidal timestep embeddings.
39 |         :param t: an (N, L) tensor of timestep values, one per token.
40 |         :param dim: the dimension of the output.
41 |         :param max_period: controls the minimum frequency of the embeddings.
42 |         :return: an (N, L, D) tensor of positional embeddings.
43 | """
44 | half = dim // 2
45 | freqs = torch.exp(
46 | -np.log(max_period) * torch.arange(0, half, dtype=t.dtype) / half
47 | ).to(t.device)
48 | args = t[:, :, None] * freqs[None, :]
49 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
50 | if dim % 2:
51 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :, :1])], dim=-1)
52 | return embedding
53 |
54 | def forward(self, t):
55 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
56 | t_freq = t_freq.to(self.mlp[0].weight.dtype)
57 | return self.mlp(t_freq)
58 |
59 | class FinalLayer(nn.Module):
60 | def __init__(self, hidden_size, num_patches, out_channels):
61 | super().__init__()
62 | self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
63 | self.linear = nn.Linear(hidden_size, num_patches * out_channels)
64 | self.adaLN_modulation = nn.Sequential(
65 | nn.SiLU(),
66 | nn.Linear(min(hidden_size, 1024), hidden_size),
67 | )
68 |
69 | def forward(self, x, c):
70 | scale = self.adaLN_modulation(c)
71 | x = modulate(self.norm_final(x), scale)
72 | x = self.linear(x)
73 | return x
74 |
75 | class Attention(nn.Module):
76 | def __init__(
77 | self,
78 | dim,
79 | n_heads,
80 | n_kv_heads=None,
81 | qk_norm=False,
82 | y_dim=0,
83 | base_seqlen=None,
84 | proportional_attn=False,
85 | attention_dropout=0.0,
86 | max_position_embeddings=384,
87 | ):
88 | super().__init__()
89 | self.dim = dim
90 | self.n_heads = n_heads
91 | self.n_kv_heads = n_kv_heads or n_heads
92 | self.qk_norm = qk_norm
93 | self.y_dim = y_dim
94 | self.base_seqlen = base_seqlen
95 | self.proportional_attn = proportional_attn
96 | self.attention_dropout = attention_dropout
97 | self.max_position_embeddings = max_position_embeddings
98 |
99 | self.head_dim = dim // n_heads
100 |
101 | self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False)
102 | self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
103 | self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False)
104 |
105 | if y_dim > 0:
106 | self.wk_y = nn.Linear(y_dim, self.n_kv_heads * self.head_dim, bias=False)
107 | self.wv_y = nn.Linear(y_dim, self.n_kv_heads * self.head_dim, bias=False)
108 | self.gate = nn.Parameter(torch.zeros(n_heads))
109 |
110 | self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False)
111 |
112 | if qk_norm:
113 | self.q_norm = nn.LayerNorm(self.n_heads * self.head_dim)
114 | self.k_norm = nn.LayerNorm(self.n_kv_heads * self.head_dim)
115 | if y_dim > 0:
116 | self.ky_norm = nn.LayerNorm(self.n_kv_heads * self.head_dim, eps=1e-6)
117 | else:
118 | self.ky_norm = nn.Identity()
119 | else:
120 | self.q_norm = nn.Identity()
121 | self.k_norm = nn.Identity()
122 | self.ky_norm = nn.Identity()
123 |
124 |
125 | @staticmethod
126 | def apply_rotary_emb(xq, xk, freqs_cis):
127 | # xq, xk: [batch_size, seq_len, n_heads, head_dim]
128 | # freqs_cis: [1, seq_len, 1, head_dim]
129 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
130 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)
131 |
132 | xq_complex = torch.view_as_complex(xq_)
133 | xk_complex = torch.view_as_complex(xk_)
134 |
135 | freqs_cis = freqs_cis.unsqueeze(2)
136 |
137 | # Apply freqs_cis
138 | xq_out = xq_complex * freqs_cis
139 | xk_out = xk_complex * freqs_cis
140 |
141 | # Convert back to real numbers
142 | xq_out = torch.view_as_real(xq_out).flatten(-2)
143 | xk_out = torch.view_as_real(xk_out).flatten(-2)
144 |
145 | return xq_out.type_as(xq), xk_out.type_as(xk)
146 |
147 | # copied from huggingface modeling_llama.py
148 | def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
149 | def _get_unpad_data(attention_mask):
150 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
151 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
152 | max_seqlen_in_batch = seqlens_in_batch.max().item()
153 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
154 | return (
155 | indices,
156 | cu_seqlens,
157 | max_seqlen_in_batch,
158 | )
159 |
160 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
161 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
162 |
163 | key_layer = index_first_axis(
164 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
165 | indices_k,
166 | )
167 | value_layer = index_first_axis(
168 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
169 | indices_k,
170 | )
171 | if query_length == kv_seq_len:
172 | query_layer = index_first_axis(
173 | query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim),
174 | indices_k,
175 | )
176 | cu_seqlens_q = cu_seqlens_k
177 | max_seqlen_in_batch_q = max_seqlen_in_batch_k
178 | indices_q = indices_k
179 | elif query_length == 1:
180 | max_seqlen_in_batch_q = 1
181 | cu_seqlens_q = torch.arange(
182 | batch_size + 1, dtype=torch.int32, device=query_layer.device
183 | ) # There is a memcpy here, that is very bad.
184 | indices_q = cu_seqlens_q[:-1]
185 | query_layer = query_layer.squeeze(1)
186 | else:
187 | # The -q_len: slice assumes left padding.
188 | attention_mask = attention_mask[:, -query_length:]
189 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
190 |
191 | return (
192 | query_layer,
193 | key_layer,
194 | value_layer,
195 | indices_q,
196 | (cu_seqlens_q, cu_seqlens_k),
197 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
198 | )
199 |
200 | def forward(
201 | self,
202 | x,
203 | x_mask,
204 | freqs_cis,
205 | y=None,
206 | y_mask=None,
207 | init_cache=False,
208 | ):
209 | bsz, seqlen, _ = x.size()
210 | xq = self.wq(x)
211 | xk = self.wk(x)
212 | xv = self.wv(x)
213 |
214 | if x_mask is None:
215 | x_mask = torch.ones(bsz, seqlen, dtype=torch.bool, device=x.device)
216 | inp_dtype = xq.dtype
217 |
218 | xq = self.q_norm(xq)
219 | xk = self.k_norm(xk)
220 |
221 | xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
222 | xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
223 | xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
224 |
225 | if self.n_kv_heads != self.n_heads:
226 | n_rep = self.n_heads // self.n_kv_heads
227 | xk = xk.repeat_interleave(n_rep, dim=2)
228 | xv = xv.repeat_interleave(n_rep, dim=2)
229 |
230 | freqs_cis = freqs_cis.to(xq.device)
231 | xq, xk = self.apply_rotary_emb(xq, xk, freqs_cis)
232 |
233 | if inp_dtype in [torch.float16, torch.bfloat16]:
234 | # begin var_len flash attn
235 | (
236 | query_states,
237 | key_states,
238 | value_states,
239 | indices_q,
240 | cu_seq_lens,
241 | max_seq_lens,
242 | ) = self._upad_input(xq, xk, xv, x_mask, seqlen)
243 |
244 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens
245 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
246 |
247 | attn_output_unpad = flash_attn_varlen_func(
248 | query_states.to(inp_dtype),
249 | key_states.to(inp_dtype),
250 | value_states.to(inp_dtype),
251 | cu_seqlens_q=cu_seqlens_q,
252 | cu_seqlens_k=cu_seqlens_k,
253 | max_seqlen_q=max_seqlen_in_batch_q,
254 | max_seqlen_k=max_seqlen_in_batch_k,
255 | dropout_p=0.0,
256 | causal=False,
257 | softmax_scale=None,
258 | softcap=30,
259 | )
260 | output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
261 | else:
262 | output = (
263 | F.scaled_dot_product_attention(
264 | xq.permute(0, 2, 1, 3),
265 | xk.permute(0, 2, 1, 3),
266 | xv.permute(0, 2, 1, 3),
267 | attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_heads, seqlen, -1),
268 | scale=None,
269 | )
270 | .permute(0, 2, 1, 3)
271 | .to(inp_dtype)
272 | ) #ok
273 |
274 |
275 | if hasattr(self, "wk_y"):
276 | yk = self.ky_norm(self.wk_y(y)).view(bsz, -1, self.n_kv_heads, self.head_dim)
277 | yv = self.wv_y(y).view(bsz, -1, self.n_kv_heads, self.head_dim)
278 | n_rep = self.n_heads // self.n_kv_heads
279 | # if n_rep >= 1:
280 | # yk = yk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
281 | # yv = yv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
282 | if n_rep >= 1:
283 | yk = einops.repeat(yk, "b l h d -> b l (repeat h) d", repeat=n_rep)
284 | yv = einops.repeat(yv, "b l h d -> b l (repeat h) d", repeat=n_rep)
285 | output_y = F.scaled_dot_product_attention(
286 | xq.permute(0, 2, 1, 3),
287 | yk.permute(0, 2, 1, 3),
288 | yv.permute(0, 2, 1, 3),
289 | y_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_heads, seqlen, -1).to(torch.bool),
290 | ).permute(0, 2, 1, 3)
291 | output_y = output_y * self.gate.tanh().view(1, 1, -1, 1)
292 | output = output + output_y
293 |
294 | output = output.flatten(-2)
295 | output = self.wo(output)
296 |
297 | return output.to(inp_dtype)
298 |
299 | class TransformerBlock(nn.Module):
300 | """
301 | Corresponds to the Transformer block in the JAX code.
302 | """
303 | def __init__(
304 | self,
305 | dim,
306 | n_heads,
307 | n_kv_heads,
308 | multiple_of,
309 | ffn_dim_multiplier,
310 | norm_eps,
311 | qk_norm,
312 | y_dim,
313 | max_position_embeddings,
314 | ):
315 | super().__init__()
316 | self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, y_dim=y_dim, max_position_embeddings=max_position_embeddings)
317 | self.feed_forward = LLamaFeedForward(
318 | dim=dim,
319 | hidden_dim=4 * dim,
320 | multiple_of=multiple_of,
321 | ffn_dim_multiplier=ffn_dim_multiplier,
322 | )
323 | self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
324 | self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
325 | self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
326 | self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
327 | self.adaLN_modulation = nn.Sequential(
328 | nn.SiLU(),
329 | nn.Linear(min(dim, 1024), 4 * dim),
330 | )
331 | self.attention_y_norm = RMSNorm(y_dim, eps=norm_eps)
332 |
333 | def forward(
334 | self,
335 | x,
336 | x_mask,
337 | freqs_cis,
338 | y,
339 | y_mask,
340 | adaln_input=None,
341 | ):
342 | if adaln_input is not None:
343 | scales_gates = self.adaLN_modulation(adaln_input)
344 | # TODO: Duong - check the dimension of chunking
345 | # scale_msa, gate_msa, scale_mlp, gate_mlp = scales_gates.chunk(4, dim=-1)
346 | scale_msa, gate_msa, scale_mlp, gate_mlp = scales_gates.chunk(4, dim=-1)
347 | x = x + torch.tanh(gate_msa) * self.attention_norm2(
348 | self.attention(
349 | modulate(self.attention_norm1(x), scale_msa), # ok
350 | x_mask,
351 | freqs_cis,
352 | self.attention_y_norm(y), # ok
353 | y_mask,
354 | )
355 | )
356 | x = x + torch.tanh(gate_mlp) * self.ffn_norm2(
357 | self.feed_forward(
358 | modulate(self.ffn_norm1(x), scale_mlp),
359 | )
360 | )
361 | else:
362 | x = x + self.attention_norm2(
363 | self.attention(
364 | self.attention_norm1(x),
365 | x_mask,
366 | freqs_cis,
367 | self.attention_y_norm(y),
368 | y_mask,
369 | )
370 | )
371 | x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
372 | return x
373 |
374 |
375 | class NextDiT(ModelMixin, ConfigMixin):
376 | """
377 | Diffusion model with a Transformer backbone for joint image-video training.
378 | """
379 | @register_to_config
380 | def __init__(
381 | self,
382 | input_size=(1, 32, 32),
383 | patch_size=(1, 2, 2),
384 | in_channels=16,
385 | hidden_size=4096,
386 | depth=32,
387 | num_heads=32,
388 | num_kv_heads=None,
389 | multiple_of=256,
390 | ffn_dim_multiplier=None,
391 | norm_eps=1e-5,
392 | pred_sigma=False,
393 | caption_channels=4096,
394 | qk_norm=False,
395 | norm_type="rms",
396 | model_max_length=120,
397 | rotary_max_length=384,
398 | rotary_max_length_t=None
399 | ):
400 | super().__init__()
401 | self.input_size = input_size
402 | self.patch_size = patch_size
403 | self.in_channels = in_channels
404 | self.hidden_size = hidden_size
405 | self.depth = depth
406 | self.num_heads = num_heads
407 | self.num_kv_heads = num_kv_heads or num_heads
408 | self.multiple_of = multiple_of
409 | self.ffn_dim_multiplier = ffn_dim_multiplier
410 | self.norm_eps = norm_eps
411 | self.pred_sigma = pred_sigma
412 | self.caption_channels = caption_channels
413 | self.qk_norm = qk_norm
414 | self.norm_type = norm_type
415 | self.model_max_length = model_max_length
416 | self.rotary_max_length = rotary_max_length
417 | self.rotary_max_length_t = rotary_max_length_t
418 | self.out_channels = in_channels * 2 if pred_sigma else in_channels
419 |
420 | self.x_embedder = nn.Linear(np.prod(self.patch_size) * in_channels, hidden_size)
421 |
422 | self.t_embedder = TimestepEmbedder(min(hidden_size, 1024))
423 | self.y_embedder = nn.Sequential(
424 | nn.LayerNorm(caption_channels, eps=1e-6),
425 | nn.Linear(caption_channels, min(hidden_size, 1024)),
426 | )
427 |
428 | self.layers = nn.ModuleList([
429 | TransformerBlock(
430 | dim=hidden_size,
431 | n_heads=num_heads,
432 | n_kv_heads=self.num_kv_heads,
433 | multiple_of=multiple_of,
434 | ffn_dim_multiplier=ffn_dim_multiplier,
435 | norm_eps=norm_eps,
436 | qk_norm=qk_norm,
437 | y_dim=caption_channels,
438 | max_position_embeddings=rotary_max_length,
439 | )
440 | for _ in range(depth)
441 | ])
442 |
443 | self.final_layer = FinalLayer(
444 | hidden_size=hidden_size,
445 | num_patches=np.prod(patch_size),
446 | out_channels=self.out_channels,
447 | )
448 |
449 | assert (hidden_size // num_heads) % 6 == 0, "3d rope needs head dim to be divisible by 6"
450 |
451 | self.freqs_cis = self.precompute_freqs_cis(
452 | hidden_size // num_heads,
453 | self.rotary_max_length,
454 | end_t=self.rotary_max_length_t
455 | )
456 |
457 | def to(self, *args, **kwargs):
458 | self = super().to(*args, **kwargs)
459 | # self.freqs_cis = self.freqs_cis.to(*args, **kwargs)
460 | return self
461 |
462 | @staticmethod
463 | def precompute_freqs_cis(
464 | dim: int,
465 | end: int,
466 | end_t: int = None,
467 | theta: float = 10000.0,
468 | scale_factor: float = 1.0,
469 | scale_watershed: float = 1.0,
470 | timestep: float = 1.0,
471 | ):
472 | if timestep < scale_watershed:
473 | linear_factor = scale_factor
474 | ntk_factor = 1.0
475 | else:
476 | linear_factor = 1.0
477 | ntk_factor = scale_factor
478 |
479 | theta = theta * ntk_factor
480 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 6)[: (dim // 6)] / dim)) / linear_factor
481 |
482 | timestep = torch.arange(end, dtype=torch.float32)
483 | freqs = torch.outer(timestep, freqs).float()
484 | freqs_cis = torch.exp(1j * freqs)
485 |
486 | if end_t is not None:
487 | freqs_t = 1.0 / (theta ** (torch.arange(0, dim, 6)[: (dim // 6)] / dim)) / linear_factor
488 | timestep_t = torch.arange(end_t, dtype=torch.float32)
489 | freqs_t = torch.outer(timestep_t, freqs_t).float()
490 | freqs_cis_t = torch.exp(1j * freqs_t)
491 | freqs_cis_t = freqs_cis_t.view(end_t, 1, 1, dim // 6).repeat(1, end, end, 1)
492 | else:
493 | end_t = end
494 | freqs_cis_t = freqs_cis.view(end_t, 1, 1, dim // 6).repeat(1, end, end, 1)
495 |
496 | freqs_cis_h = freqs_cis.view(1, end, 1, dim // 6).repeat(end_t, 1, end, 1)
497 | freqs_cis_w = freqs_cis.view(1, 1, end, dim // 6).repeat(end_t, end, 1, 1)
498 | freqs_cis = torch.cat([freqs_cis_t, freqs_cis_h, freqs_cis_w], dim=-1).view(end_t, end, end, -1)
499 | return freqs_cis
500 |
501 | def forward(
502 | self,
503 | samples,
504 | timesteps,
505 | encoder_hidden_states,
506 | encoder_attention_mask,
507 | scale_factor: float = 1.0, # scale_factor for rotary embedding
508 | scale_watershed: float = 1.0, # scale_watershed for rotary embedding
509 | ):
510 | if samples.ndim == 4: # B C H W
511 | samples = samples[:, None, ...] # B F C H W
512 |
513 | precomputed_freqs_cis = None
514 | if scale_factor != 1 or scale_watershed != 1:
515 | precomputed_freqs_cis = self.precompute_freqs_cis(
516 | self.hidden_size // self.num_heads,
517 | self.rotary_max_length,
518 | end_t=self.rotary_max_length_t,
519 | scale_factor=scale_factor,
520 | scale_watershed=scale_watershed,
521 | timestep=torch.max(timesteps.cpu()).item()
522 | )
523 |
524 | if len(timesteps.shape) == 5:
525 | t, *_ = self.patchify(timesteps, precomputed_freqs_cis)
526 | timesteps = t.mean(dim=-1)
527 | elif len(timesteps.shape) == 1:
528 | timesteps = timesteps[:, None, None, None, None].expand_as(samples)
529 | t, *_ = self.patchify(timesteps, precomputed_freqs_cis)
530 | timesteps = t.mean(dim=-1)
531 | samples, T, H, W, freqs_cis = self.patchify(samples, precomputed_freqs_cis)
532 | samples = self.x_embedder(samples)
533 | t = self.t_embedder(timesteps)
534 |
535 | encoder_attention_mask_float = encoder_attention_mask[..., None].float()
536 | encoder_hidden_states_pool = (encoder_hidden_states * encoder_attention_mask_float).sum(dim=1) / (encoder_attention_mask_float.sum(dim=1) + 1e-8)
537 | encoder_hidden_states_pool = encoder_hidden_states_pool.to(samples.dtype)
538 | y = self.y_embedder(encoder_hidden_states_pool)
539 | y = y.unsqueeze(1).expand(-1, samples.size(1), -1)
540 |
541 | adaln_input = t + y
542 |
543 | for block in self.layers:
544 | samples = block(samples, None, freqs_cis, encoder_hidden_states, encoder_attention_mask, adaln_input)
545 |
546 | samples = self.final_layer(samples, adaln_input)
547 | samples = self.unpatchify(samples, T, H, W)
548 |
549 | return samples
550 |
551 | def patchify(self, x, precompute_freqs_cis=None):
552 | # pytorch is C, H, W
553 | B, T, C, H, W = x.size()
554 | pT, pH, pW = self.patch_size
555 | x = x.view(B, T // pT, pT, C, H // pH, pH, W // pW, pW)
556 | x = x.permute(0, 1, 4, 6, 2, 5, 7, 3)
557 | x = x.reshape(B, -1, pT * pH * pW * C)
558 | if precompute_freqs_cis is None:
559 | freqs_cis = self.freqs_cis[: T // pT, :H // pH, :W // pW].reshape(-1, * self.freqs_cis.shape[3:])[None].to(x.device)
560 | else:
561 | freqs_cis = precompute_freqs_cis[: T // pT, :H // pH, :W // pW].reshape(-1, * precompute_freqs_cis.shape[3:])[None].to(x.device)
562 | return x, T // pT, H // pH, W // pW, freqs_cis
563 |
564 | def unpatchify(self, x, T, H, W):
565 | B = x.size(0)
566 | C = self.out_channels
567 | pT, pH, pW = self.patch_size
568 | x = x.view(B, T, H, W, pT, pH, pW, C)
569 | x = x.permute(0, 1, 4, 7, 2, 5, 3, 6)
570 | x = x.reshape(B, T * pT, C, H * pH, W * pW)
571 | return x
572 |
--------------------------------------------------------------------------------
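Usage note (illustrative sketch, not a file in this repository): a minimal CPU smoke test of `NextDiT` with toy hyper-parameters assumed only for the example (the released checkpoints use much larger settings). In float32 the attention falls back to `F.scaled_dot_product_attention`, so no GPU is needed for the forward pass, but `flash-attn` must still be importable because it is imported at module level; run from the repository root so the `onediffusion` package resolves.

    import torch
    from onediffusion.models.denoiser.nextdit import NextDiT

    # Toy configuration: head_dim = 96 / 4 = 24 is divisible by 6, as the 3D RoPE requires.
    model = NextDiT(
        input_size=(1, 32, 32),
        patch_size=(1, 2, 2),
        in_channels=4,
        hidden_size=96,
        depth=2,
        num_heads=4,
        caption_channels=64,
        rotary_max_length=64,   # keep the precomputed RoPE table small for the example
    )

    latents = torch.randn(1, 4, 32, 32)           # B, C, H, W (a frame axis is added internally)
    timesteps = torch.tensor([0.5])               # one timestep per sample
    text_emb = torch.randn(1, 8, 64)              # dummy caption embeddings
    text_mask = torch.ones(1, 8, dtype=torch.bool)

    out = model(latents, timesteps, text_emb, text_mask)
    print(out.shape)                              # torch.Size([1, 1, 4, 32, 32])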
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytest
2 | matplotlib
3 | scikit-learn
4 | scipy
5 | numpy
6 | einops
7 | einsum
8 | fvcore
9 | h5py
10 | twine
11 | sentencepiece
12 | tensorflow-cpu==2.11.0
13 | protobuf==3.19.6
14 | transformers==4.45.2
15 | huggingface_hub==0.24
16 | accelerate==0.34.2
17 | diffusers==0.30.3
18 | pillow==10.2.0
19 | torch==2.3.1
20 | torchvision==0.18.1
21 | torchaudio==2.3.1
22 | flash-attn==2.6.3
23 | git+https://github.com/Dao-AILab/flash-attention@v2.6.3#subdirectory=csrc/fused_dense_lib
24 | jaxtyping
25 | mediapipe
26 | gradio
27 | git+https://github.com/facebookresearch/pytorch3d.git
28 | opencv-python==4.5.5.64
29 | opencv-python-headless==4.5.5.64
30 | bitsandbytes==0.45.0
--------------------------------------------------------------------------------