├── .gitignore ├── LICENSE ├── PROMPT_GUIDE.md ├── README.md ├── assets ├── cond_and_image.jpg ├── examples │ ├── canny │ │ └── canny.webp │ ├── depth │ │ └── astronaut.webp │ ├── id_customization │ │ ├── chenhao │ │ │ ├── image_0.png │ │ │ ├── image_1.png │ │ │ └── image_2.png │ │ └── chika │ │ │ ├── image_0.png │ │ │ └── image_1.webp │ ├── image_editing │ │ └── astronaut.webp │ ├── images │ │ ├── cat.webp │ │ └── cat_on_table.webp │ ├── inpainting │ │ └── giorno.webp │ ├── semantic_map │ │ ├── dragon.webp │ │ ├── dragon_and_birds.png │ │ └── dragon_birds_woman.webp │ └── subject_driven │ │ └── chill_guy.jpg ├── gradio_cached_examples │ └── 42 │ │ ├── Generated Images │ │ ├── 17b88522de30372d1f35 │ │ │ └── image.webp │ │ ├── 317b1e50f07fbdcc2f45 │ │ │ └── image.webp │ │ ├── 53575bb9e63d8bc1a71c │ │ │ └── image.webp │ │ ├── 581111895a66f93b1f10 │ │ │ └── image.webp │ │ ├── 5d11a379682726a1dc31 │ │ │ └── image.webp │ │ ├── 6beeb89d25287417e72d │ │ │ └── image.webp │ │ ├── 80d358ca48472ae5c466 │ │ │ └── image.webp │ │ ├── 81bd6641e26a411b2da6 │ │ │ └── image.webp │ │ ├── 892c2282d95cba917fa2 │ │ │ └── image.webp │ │ ├── 8a9a8c9e07527586fc15 │ │ │ └── image.webp │ │ ├── 8fd76ebb6657d26dd92c │ │ │ └── image.webp │ │ ├── 930ca5ca2692b479b752 │ │ │ └── image.webp │ │ ├── 96928c780e841920bfcc │ │ │ └── image.webp │ │ ├── 9bdb84088da3c475e4bd │ │ │ └── image.webp │ │ ├── b6e34a7e72d74a80b52f │ │ │ └── image.webp │ │ ├── c8580411711e5880bb78 │ │ │ └── image.webp │ │ └── fe7d4925dcd9c0186674 │ │ │ └── image.webp │ │ ├── Image Preview │ │ ├── 06eab237f5c35c1b745f │ │ │ └── image.webp │ │ ├── 08c2007af01acb165e9a │ │ │ └── image.webp │ │ ├── 133fc4374bccf2d6b2ca │ │ │ └── image.webp │ │ ├── 2b1daf3b917623fd8ff2 │ │ │ └── image.webp │ │ ├── 2d82244403afa8c210e3 │ │ │ └── image.webp │ │ ├── 303d58e0a1cdb5b953f9 │ │ │ └── image.webp │ │ ├── 4284edec7b2efdd3fef1 │ │ │ └── image.webp │ │ ├── 8348cbc4b20a443fc03b │ │ │ └── image.webp │ │ ├── 86bf806faee9fbb5d92e │ │ │ └── image.webp │ │ ├── 9065eadbcdacd92fe32f │ │ │ └── image.webp │ │ └── e47602f78caeedc4c6d3 │ │ │ └── image.webp │ │ ├── Input Images │ │ ├── 0ca5ab7959eed9fef525 │ │ │ └── astronaut.webp │ │ ├── 10e196d148112fc7eb6a │ │ │ └── dragon_birds_woman.webp │ │ ├── 3a98c4847c8c4e64851e │ │ │ └── giorno.webp │ │ ├── 43c8d5a2278e70960d7b │ │ │ └── image_1.png │ │ ├── 491863b2f47c65f9ac67 │ │ │ └── image_0.png │ │ ├── 7a5a351a562f9423582e │ │ │ └── image_2.png │ │ ├── 830bb33b0629cc76770d │ │ │ └── cat_on_table.webp │ │ ├── bbfa399fd2eb84e96645 │ │ │ └── chill_guy.jpg │ │ ├── c2c85d82ceb9a9574063 │ │ │ └── image_0.png │ │ ├── c875f74b3d73b2ea804f │ │ │ └── astronaut.webp │ │ └── f63287119763bed48822 │ │ │ └── cat.webp │ │ └── log.csv ├── onediffusion_appendix_faceid.jpg ├── onediffusion_appendix_faceid_3.jpg ├── onediffusion_appendix_multiview.jpg ├── onediffusion_appendix_multiview_2.jpg ├── onediffusion_appendix_text2multiview.pdf ├── onediffusion_editing.jpg ├── onediffusion_zeroshot.jpg ├── promptguide_complex.jpg ├── promptguide_idtask.jpg ├── subject_driven.jpg ├── teaser.png ├── text2image.jpg └── text2multiview.jpg ├── docker └── Dockerfile ├── gradio_demo.py ├── inference.py ├── onediffusion ├── dataset │ ├── multitask │ │ └── multiview.py │ ├── raydiff_utils.py │ ├── transforms.py │ └── utils.py ├── diffusion │ └── pipelines │ │ ├── image_processor.py │ │ └── onediffusion.py └── models │ └── denoiser │ ├── __init__.py │ └── nextdit │ ├── __init__.py │ ├── layers.py │ └── modeling_nextdit.py └── requirements.txt /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | .vscode/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. 
More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. 
Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. 
The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. 
indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. 
The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. 
Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.

Creative Commons may be contacted at creativecommons.org.

-------------------------------------------------------------------------------- /PROMPT_GUIDE.md: --------------------------------------------------------------------------------

# Prompt Guide

All examples are generated with a CFG scale of $4.2$ and $50$ sampling steps, and are not cherry-picked unless otherwise stated. The negative prompt is set to:
```
monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation
```

## 1. Text-to-Image

### 1.1 Long and detailed prompts give (much) better results.

Since our training used long and detailed prompts, the model is more likely to generate better images when given detailed prompts.

The model shows good text adherence with long and complex prompts, as in the images below. We use the first $20$ prompts from [Simo Ryu's examples](https://cloneofsimo.github.io/compare_aura_sd3/). For the full prompts and the results of other models, refer to the link above. A minimal code sketch using these settings follows the comparison image.

18 | Text-to-Image results 19 |
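For reference, the settings above (CFG $4.2$, $50$ steps, and the negative prompt at the top of this guide) map onto the pipeline as in the minimal sketch below. It assumes the `OneDiffusionPipeline` API shown in the README quick start; the detailed prompt itself is only an illustration.

```python
import torch
from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline

device = torch.device("cuda:0")
pipeline = OneDiffusionPipeline.from_pretrained("lehduong/OneDiffusion").to(device=device, dtype=torch.bfloat16)

# Use the full negative prompt from the top of this guide; it is abbreviated here for readability.
NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, jpeg artifacts, signature, watermark, username, blurry"

# A long, detailed prompt (illustrative) tends to give better results than a short one.
detailed_prompt = (
    "[[text2image]] A weathered stone lighthouse on a rocky cliff at dusk, warm light "
    "spilling from its lantern room, waves crashing on the rocks below, seagulls circling, "
    "dramatic storm clouds on the horizon, painted in a moody impressionist style"
)

output = pipeline(
    prompt=detailed_prompt,
    negative_prompt=NEGATIVE_PROMPT,
    num_inference_steps=50,   # settings used for the examples in this guide
    guidance_scale=4.2,
    height=1024,              # height and width must be divisible by 16
    width=1024,
)
output.images[0].save("detailed_prompt_output.jpg")
```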

### 1.2 Resolution

The model generally works well with height and width in the range of $[768; 1280]$ (height and width must be divisible by 16) for text-to-image. For other tasks, it performs best at a resolution of around $512$.

## 2. ID Customization & Subject-driven generation

- The expected length of source captions is $30$ to $75$ words. Empirically, we find that longer source captions help preserve the identity better, but they might hinder text adherence for the target caption.

- We find it better to copy some descriptions (e.g., from the source caption) into the target caption to preserve the identity, especially for complex subjects with delicate details. A minimal prompt-construction sketch follows the ablation figure below.

33 | ablation id task 34 |
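To make the guidance above concrete, the sketch below assembles an ID-customization prompt using the task-token layout defined in Section 5.1. The captions are purely illustrative; the assembled string is what you pass as the prompt alongside your source image(s) (see `inference.py` or the gradio demo for the full generation call).

```python
# Illustrative captions: the source caption is roughly 30-75 words, and the target
# caption reuses descriptive phrases from the source to help preserve the identity.
source_caption = (
    "A young woman with long black hair and a small mole under her left eye, wearing "
    "a red silk dress with golden embroidery, photographed indoors under soft warm "
    "light, looking slightly to the side with a calm expression"
)
target_caption = (
    "A young woman with long black hair and a small mole under her left eye, wearing "
    "a red silk dress with golden embroidery, standing in a blooming cherry-blossom "
    "garden at sunset"
)

# Token layout from Section 5.1: [[faceid]] [[img0]] <target caption> [[img1]] <source caption> ...
prompt = f"[[faceid]] [[img0]] {target_caption} [[img1]] {source_caption}"
print(prompt)
```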

35 | 36 | ## 3. Multiview generation 37 | 38 | We recommend not use captions, which describe the facial features e.g., looking at the camera, etc, to mitigate multifaced/janus problems. 39 | 40 | ## 4. Image editing 41 | 42 | We find it's generally better to set the guidance scale to lower value e.g., $[3; 3.5]$ to avoid over-saturation results. 43 | 44 | ## 5. Special tokens and available colors 45 | 46 | ### 5.1 Task Tokens 47 | 48 | | Task | Token | Additional Tokens | 49 | |:---------------------|:---------------------------|:------------------| 50 | | Text to Image | `[[text2image]]` | | 51 | | Deblurring | `[[deblurring]]` | | 52 | | Inpainting | `[[image_inpainting]]` | | 53 | | Canny-edge and Image | `[[canny2image]]` | | 54 | | Depth and Image | `[[depth2image]]` | | 55 | | Hed and Image | `[[hed2img]]` | | 56 | | Pose and Image | `[[pose2image]]` | | 57 | | Image editing with Instruction | `[[image_editing]]` | | 58 | | Semantic map and Image| `[[semanticmap2image]]` | `<#00FFFF cyan mask: object/to/segment>` | 59 | | Boundingbox and Image | `[[boundingbox2image]]` | `<#00FFFF cyan boundingbox: object/to/detect>` | 60 | | ID customization | `[[faceid]]` | `[[img0]] target/caption [[img1]] caption/of/source/image_1 [[img2]] caption/of/source/image_2 [[img3]] caption/of/source/image_3` | 61 | | Multiview | `[[multiview]]` | | 62 | | Subject-Driven | `[[subject_driven]]` | ` [[img0]] target/caption/goes/here [[img1]] insert/source/caption` | 63 | 64 | 65 | Note that you can replace the cyan color above with any from below table and have multiple additional tokens to detect/segment multiple classes. 66 | 67 | ### 5.2 Available colors 68 | 69 | 70 | | Hex Code | Color Name | 71 | |:---------|:-----------| 72 | | #FF0000 | red | 73 | | #00FF00 | lime | 74 | | #0000FF | blue | 75 | | #FFFF00 | yellow | 76 | | #FF00FF | magenta | 77 | | #00FFFF | cyan | 78 | | #FFA500 | orange | 79 | | #800080 | purple | 80 | | #A52A2A | brown | 81 | | #008000 | green | 82 | | #FFC0CB | pink | 83 | | #008080 | teal | 84 | | #FF8C00 | darkorange | 85 | | #8A2BE2 | blueviolet | 86 | | #006400 | darkgreen | 87 | | #FF4500 | orangered | 88 | | #000080 | navy | 89 | | #FFD700 | gold | 90 | | #40E0D0 | turquoise | 91 | | #DA70D6 | orchid | 92 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # One Diffusion to Generate Them All 2 | 3 |

4 | 5 | Build 6 | 7 | 8 | Build 9 | 10 | 11 | License 12 | 13 | 14 | Build 15 | 16 |

17 | 18 |

19 |

20 | News | 21 | Quick start | 22 | Prompt guide & Supported tasks | 23 | Qualitative results | 24 | License | 25 | Citation 26 |

27 |

28 | 29 | 30 |

31 | Teaser Image 32 |

This is the official repo of OneDiffusion, a versatile, large-scale diffusion model that seamlessly supports bidirectional image synthesis and understanding across diverse tasks.

**For more details, read our paper [here](https://arxiv.org/abs/2411.16318).**

## News
- 📦 2024/12/11: [Huggingface space](https://huggingface.co/spaces/lehduong/OneDiffusion) is online. Reduced the VRAM requirement for running the demo with Molmo to 21 GB.
- 📦 2024/12/10: Released [weights](https://huggingface.co/lehduong/OneDiffusion) and inference code.
- ✨ 2024/12/06: Added instruction-based image editing.
- ✨ 2024/12/02: Added subject-driven generation.

## Installation
```
conda create -n onediffusion_env python=3.8 &&
conda activate onediffusion_env &&
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118 &&
pip install "git+https://github.com/facebookresearch/pytorch3d.git" &&
pip install -r requirements.txt
```

## Quick start

Check `inference.py` for more details. For text-to-image, you can use the code snippet below.

```
import torch
from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline

device = torch.device('cuda:0')

pipeline = OneDiffusionPipeline.from_pretrained("lehduong/OneDiffusion").to(device=device, dtype=torch.bfloat16)

NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"

output = pipeline(
    prompt="[[text2image]] A bipedal black cat wearing a huge oversized witch hat, a wizard's robe, casting a spell, in an enchanted forest. The scene is filled with fireflies and moss on surrounding rocks and trees",
    negative_prompt=NEGATIVE_PROMPT,
    num_inference_steps=50,
    guidance_scale=4,
    height=1024,
    width=1024,
)
output.images[0].save('text2image_output.jpg')
```

You can run the gradio demo with:
```
python gradio_demo.py --captioner molmo # [molmo, llava, disable]
```
The demo provides guidance and helps format the prompt properly for each task.

- By default, it loads the **quantized** Molmo for captioning source images. ~~You generally need a GPU with at least $40$ GB of memory to run the demo.~~ You generally need a GPU with at least $21$ GB of memory to run the demo.

- Opting to use LLaVA can reduce this requirement to $\approx 27$ GB, though the resulting captions may be less accurate in some cases.

- You can also manually provide the caption for each input image and run with `disable` mode. In this mode, the returned caption is an empty string, but you should still press the `Generate Caption` button so that the code formats the input text properly. The memory requirement for this mode is $\approx 12$ GB.

Note that the required memory may increase if you use a higher resolution or more input images.

## Qualitative Results

### 1. Text-to-Image

97 | Text-to-Image results 98 |

99 | 100 | 101 | ### 2. ID customization 102 | 103 |

104 | ID customization 105 |

106 | 107 |

108 | ID customization non-human subject 109 |

110 | 111 | ### 3. Multiview generation 112 | 113 | Single image to multiview: 114 | 115 |

116 | Image to multiview 117 |

118 | 119 |

120 | image to multiview 121 |

122 | 123 | Text to multiview: 124 | 125 |

126 | Text to multiview image 127 |

128 | 129 | ### 4. Condition-to-Image and vice versa 130 |

131 | Condition and Image 132 |

### 5. Subject-driven generation

We finetuned the model on the [Subject-200K](https://huggingface.co/datasets/Yuanshi/Subjects200K) dataset (along with all other tasks) for an additional 40k steps. The model is now capable of subject-driven generation.

139 | Subject driven generation 140 |

### 6. Text-guided image editing

We finetuned the model on the [OmniEdit](https://huggingface.co/datasets/TIGER-Lab/OmniEdit-Filtered-1.2M) dataset for an additional 30K steps.

147 | Text-guide editing 148 |

149 | 150 | ### 7. Zero-shot Task combinations 151 | 152 | We found that the model can handle multiple tasks in a zero-shot setting by combining condition images and task tokens without any fine-tuning, as shown in the examples below. However, its performance on these combined tasks might not be robust, and the model’s behavior may change if the order of task tokens or captions is altered. For example, when using both image inpainting and ID customization together, the target prompt and the caption of the masked image must be identical. If you plan to use such combinations, we recommend fine-tuning the model on these tasks to achieve better performance and simpler usage. 153 | 154 | 155 |

156 | Subject driven generation 157 |
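As a purely illustrative sketch of the constraint described above, the snippet below reuses a single caption string both as the target prompt and as the caption of the masked image when combining inpainting with ID customization. The token ordering is an assumption based on the formats in PROMPT_GUIDE.md and may need to be adjusted, since the model's behavior can change when the order of task tokens or captions is altered.

```python
# Hypothetical combination of [[image_inpainting]] and [[faceid]] in a zero-shot setting.
# The key point from the paragraph above: the same caption string serves both as the
# target prompt and as the caption of the masked input image.
shared_caption = (
    "A man in a navy suit standing on a rainy street at night, neon signs reflecting "
    "on the wet pavement"
)

combined_prompt = (
    "[[image_inpainting]] [[faceid]] "
    f"[[img0]] {shared_caption} "  # target prompt
    f"[[img1]] {shared_caption}"   # caption of the masked image (kept identical)
)
print(combined_prompt)
```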

158 | 159 | ## License 160 | 161 | The model is trained on several non-commercially licensed datasets (e.g., DL3DV, Unsplash), thus, **model weights** are released under a CC BY-NC license as described in [LICENSE](https://github.com/lehduong/onediffusion/blob/main/LICENSE). 162 | 163 | ## Citation 164 | 165 | ```bibtex 166 | @misc{le2024diffusiongenerate, 167 | title={One Diffusion to Generate Them All}, 168 | author={Duong H. Le and Tuan Pham and Sangho Lee and Christopher Clark and Aniruddha Kembhavi and Stephan Mandt and Ranjay Krishna and Jiasen Lu}, 169 | year={2024}, 170 | eprint={2411.16318}, 171 | archivePrefix={arXiv}, 172 | primaryClass={cs.CV}, 173 | url={https://arxiv.org/abs/2411.16318}, 174 | } 175 | ``` 176 | -------------------------------------------------------------------------------- /assets/cond_and_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/cond_and_image.jpg -------------------------------------------------------------------------------- /assets/examples/canny/canny.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/canny/canny.webp -------------------------------------------------------------------------------- /assets/examples/depth/astronaut.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/depth/astronaut.webp -------------------------------------------------------------------------------- /assets/examples/id_customization/chenhao/image_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/id_customization/chenhao/image_0.png -------------------------------------------------------------------------------- /assets/examples/id_customization/chenhao/image_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/id_customization/chenhao/image_1.png -------------------------------------------------------------------------------- /assets/examples/id_customization/chenhao/image_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/id_customization/chenhao/image_2.png -------------------------------------------------------------------------------- /assets/examples/id_customization/chika/image_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/id_customization/chika/image_0.png -------------------------------------------------------------------------------- /assets/examples/id_customization/chika/image_1.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/id_customization/chika/image_1.webp 
-------------------------------------------------------------------------------- /assets/examples/image_editing/astronaut.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/image_editing/astronaut.webp -------------------------------------------------------------------------------- /assets/examples/images/cat.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/images/cat.webp -------------------------------------------------------------------------------- /assets/examples/images/cat_on_table.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/images/cat_on_table.webp -------------------------------------------------------------------------------- /assets/examples/inpainting/giorno.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/inpainting/giorno.webp -------------------------------------------------------------------------------- /assets/examples/semantic_map/dragon.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/semantic_map/dragon.webp -------------------------------------------------------------------------------- /assets/examples/semantic_map/dragon_and_birds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/semantic_map/dragon_and_birds.png -------------------------------------------------------------------------------- /assets/examples/semantic_map/dragon_birds_woman.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/semantic_map/dragon_birds_woman.webp -------------------------------------------------------------------------------- /assets/examples/subject_driven/chill_guy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/examples/subject_driven/chill_guy.jpg -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/17b88522de30372d1f35/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/17b88522de30372d1f35/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/317b1e50f07fbdcc2f45/image.webp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/317b1e50f07fbdcc2f45/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/53575bb9e63d8bc1a71c/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/53575bb9e63d8bc1a71c/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/581111895a66f93b1f10/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/581111895a66f93b1f10/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/5d11a379682726a1dc31/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/5d11a379682726a1dc31/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/6beeb89d25287417e72d/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/6beeb89d25287417e72d/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/80d358ca48472ae5c466/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/80d358ca48472ae5c466/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/81bd6641e26a411b2da6/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/81bd6641e26a411b2da6/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/892c2282d95cba917fa2/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/892c2282d95cba917fa2/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/8a9a8c9e07527586fc15/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated 
Images/8a9a8c9e07527586fc15/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/8fd76ebb6657d26dd92c/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/8fd76ebb6657d26dd92c/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/930ca5ca2692b479b752/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/930ca5ca2692b479b752/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/96928c780e841920bfcc/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/96928c780e841920bfcc/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/9bdb84088da3c475e4bd/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/9bdb84088da3c475e4bd/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/b6e34a7e72d74a80b52f/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/b6e34a7e72d74a80b52f/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/c8580411711e5880bb78/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/c8580411711e5880bb78/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Generated Images/fe7d4925dcd9c0186674/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Generated Images/fe7d4925dcd9c0186674/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Image Preview/06eab237f5c35c1b745f/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/06eab237f5c35c1b745f/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Image 
Preview/08c2007af01acb165e9a/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/08c2007af01acb165e9a/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Image Preview/133fc4374bccf2d6b2ca/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/133fc4374bccf2d6b2ca/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Image Preview/2b1daf3b917623fd8ff2/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/2b1daf3b917623fd8ff2/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Image Preview/2d82244403afa8c210e3/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/2d82244403afa8c210e3/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Image Preview/303d58e0a1cdb5b953f9/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/303d58e0a1cdb5b953f9/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Image Preview/4284edec7b2efdd3fef1/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/4284edec7b2efdd3fef1/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Image Preview/8348cbc4b20a443fc03b/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/8348cbc4b20a443fc03b/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Image Preview/86bf806faee9fbb5d92e/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/86bf806faee9fbb5d92e/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Image Preview/9065eadbcdacd92fe32f/image.webp: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/9065eadbcdacd92fe32f/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Image Preview/e47602f78caeedc4c6d3/image.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Image Preview/e47602f78caeedc4c6d3/image.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Input Images/0ca5ab7959eed9fef525/astronaut.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/0ca5ab7959eed9fef525/astronaut.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Input Images/10e196d148112fc7eb6a/dragon_birds_woman.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/10e196d148112fc7eb6a/dragon_birds_woman.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Input Images/3a98c4847c8c4e64851e/giorno.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/3a98c4847c8c4e64851e/giorno.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Input Images/43c8d5a2278e70960d7b/image_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/43c8d5a2278e70960d7b/image_1.png -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Input Images/491863b2f47c65f9ac67/image_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/491863b2f47c65f9ac67/image_0.png -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Input Images/7a5a351a562f9423582e/image_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/7a5a351a562f9423582e/image_2.png -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Input Images/830bb33b0629cc76770d/cat_on_table.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input 
Images/830bb33b0629cc76770d/cat_on_table.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Input Images/bbfa399fd2eb84e96645/chill_guy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/bbfa399fd2eb84e96645/chill_guy.jpg -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Input Images/c2c85d82ceb9a9574063/image_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/c2c85d82ceb9a9574063/image_0.png -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Input Images/c875f74b3d73b2ea804f/astronaut.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/c875f74b3d73b2ea804f/astronaut.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/Input Images/f63287119763bed48822/cat.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/gradio_cached_examples/42/Input Images/f63287119763bed48822/cat.webp -------------------------------------------------------------------------------- /assets/gradio_cached_examples/42/log.csv: -------------------------------------------------------------------------------- 1 | Generated Images,Generation Status,Input Images,component 3,Image Preview,flag,username,timestamp 2 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/b6e34a7e72d74a80b52f/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/d8cd3b6318cdd3c5d4bdd4c37c9a2f2df0e47a91647e22cde72befa22a9bb748/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,[],,[],,,2024-12-13 20:14:34.209093 3 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/17b88522de30372d1f35/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/cc0dc567fb017145e3fbf502ef9e10ba9cefa8c45e15c62d3d060eaea986f6b1/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/491863b2f47c65f9ac67/image_0.png"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image_0.png"", ""size"": 1191470, ""orig_name"": ""image_0.png"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/8348cbc4b20a443fc03b/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, 
""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:14:49.555938 4 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/5d11a379682726a1dc31/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/26f13a9435ed998c8ef570925c40d60345f5ca26b27a7d08c3faa76b397455b5/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/c2c85d82ceb9a9574063/image_0.png"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image_0.png"", ""size"": 628293, ""orig_name"": ""image_0.png"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""assets/gradio_cached_examples/42/Input Images/43c8d5a2278e70960d7b/image_1.png"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image_1.png"", ""size"": 551085, ""orig_name"": ""image_1.png"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, {""path"": ""assets/gradio_cached_examples/42/Input Images/7a5a351a562f9423582e/image_2.png"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image_2.png"", ""size"": 620161, ""orig_name"": ""image_2.png"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/9065eadbcdacd92fe32f/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/2d82244403afa8c210e3/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/86bf806faee9fbb5d92e/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:15:13.029766 5 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/fe7d4925dcd9c0186674/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/6be14ba2f6c3de6620f526c45f5043c6d252a4cfb1d2400a168e70f005ec4a61/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/930ca5ca2692b479b752/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/53575bb9e63d8bc1a71c/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, 
""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/830bb33b0629cc76770d/cat_on_table.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/cat_on_table.webp"", ""size"": 41180, ""orig_name"": ""cat_on_table.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/133fc4374bccf2d6b2ca/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:16:09.060610 6 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/581111895a66f93b1f10/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/ce1da06372f563fdca891efc30288e2a85d568f2351b54c36a19b75a77ac8d8b/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/10e196d148112fc7eb6a/dragon_birds_woman.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/dragon_birds_woman.webp"", ""size"": 16180, ""orig_name"": ""dragon_birds_woman.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/2b1daf3b917623fd8ff2/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:16:19.208900 7 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/892c2282d95cba917fa2/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/441882b3d10fbb91786980f9c484e518b3a8056558a7fad1fbe4d7db4aec408e/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/bbfa399fd2eb84e96645/chill_guy.jpg"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/chill_guy.jpg"", ""size"": 61273, ""orig_name"": ""chill_guy.jpg"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/303d58e0a1cdb5b953f9/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:16:33.375436 8 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/317b1e50f07fbdcc2f45/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/12e1af3a972f41b652b3ec04cd7e256ef43815e0e8ded1a9d758ffc18585ada1/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed 
successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/c875f74b3d73b2ea804f/astronaut.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/astronaut.webp"", ""size"": 5528, ""orig_name"": ""astronaut.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/06eab237f5c35c1b745f/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:16:43.378123 9 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/8a9a8c9e07527586fc15/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/1be0cd7d19f19d2ca9bef889b2d0a17c1ee6e2b1cab471db19103c116461a954/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/6beeb89d25287417e72d/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/f63287119763bed48822/cat.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/cat.webp"", ""size"": 57702, ""orig_name"": ""cat.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/e47602f78caeedc4c6d3/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:16:53.637761 10 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/80d358ca48472ae5c466/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/fd6c6cd46e17b3233178109e2fbd1593a88280484c665301fe329d14ba7dd913/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/0ca5ab7959eed9fef525/astronaut.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/astronaut.webp"", ""size"": 62630, ""orig_name"": ""astronaut.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/4284edec7b2efdd3fef1/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:17:05.900479 11 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/96928c780e841920bfcc/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated 
Images/6f7388d917738246baaadb98c1a8f045374e9c182a3b7697c18a05f8f1bff3e6/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/81bd6641e26a411b2da6/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/8fd76ebb6657d26dd92c/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}, {""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/c8580411711e5880bb78/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,[],,[],,,2024-12-13 20:18:01.760111 12 | "[{""image"": {""path"": ""assets/gradio_cached_examples/42/Generated Images/9bdb84088da3c475e4bd/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/00a1b6bb103d9ed89daaae43449e6975a39134b3f5ba2ce6e23b23f6c9f967f2/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",Generation completed successfully!,"[{""path"": ""assets/gradio_cached_examples/42/Input Images/3a98c4847c8c4e64851e/giorno.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/giorno.webp"", ""size"": 17570, ""orig_name"": ""giorno.webp"", ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}]",,"[{""image"": {""path"": ""assets/gradio_cached_examples/42/Image Preview/08c2007af01acb165e9a/image.webp"", ""url"": ""/file=./assets/gradio_cached_examples/42/Generated Images/image.webp"", ""size"": null, ""orig_name"": null, ""mime_type"": null, ""is_stream"": false, ""meta"": {""_type"": ""gradio.FileData""}}, ""caption"": null}]",,,2024-12-13 20:18:11.526344 13 | -------------------------------------------------------------------------------- /assets/onediffusion_appendix_faceid.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/onediffusion_appendix_faceid.jpg -------------------------------------------------------------------------------- /assets/onediffusion_appendix_faceid_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/onediffusion_appendix_faceid_3.jpg -------------------------------------------------------------------------------- /assets/onediffusion_appendix_multiview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/onediffusion_appendix_multiview.jpg 
-------------------------------------------------------------------------------- /assets/onediffusion_appendix_multiview_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/onediffusion_appendix_multiview_2.jpg -------------------------------------------------------------------------------- /assets/onediffusion_appendix_text2multiview.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/onediffusion_appendix_text2multiview.pdf -------------------------------------------------------------------------------- /assets/onediffusion_editing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/onediffusion_editing.jpg -------------------------------------------------------------------------------- /assets/onediffusion_zeroshot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/onediffusion_zeroshot.jpg -------------------------------------------------------------------------------- /assets/promptguide_complex.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/promptguide_complex.jpg -------------------------------------------------------------------------------- /assets/promptguide_idtask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/promptguide_idtask.jpg -------------------------------------------------------------------------------- /assets/subject_driven.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/subject_driven.jpg -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/teaser.png -------------------------------------------------------------------------------- /assets/text2image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/text2image.jpg -------------------------------------------------------------------------------- /assets/text2multiview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lehduong/OneDiffusion/6a64303695bea9209d87a2943b329705e4925207/assets/text2multiview.jpg -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Inspired by https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile 2 | # ARG 
COMPAT=0 3 | ARG PERSONAL=0 4 | # FROM nvidia/cuda:11.3.1-devel-ubuntu20.04 as base-0 5 | FROM nvcr.io/nvidia/pytorch:22.12-py3 as base 6 | 7 | ENV HOST docker 8 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 9 | # https://serverfault.com/questions/683605/docker-container-time-timezone-will-not-reflect-changes 10 | ENV TZ America/Los_Angeles 11 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 12 | 13 | # git for installing dependencies 14 | # tzdata to set time zone 15 | # wget and unzip to download data 16 | # [2021-09-09] TD: zsh, stow, subversion, fasd are for setting up my personal environment. 17 | # [2021-12-07] TD: openmpi-bin for MPI (multi-node training) 18 | RUN apt-get update && apt-get install -y --no-install-recommends \ 19 | build-essential \ 20 | cmake \ 21 | curl \ 22 | ca-certificates \ 23 | sudo \ 24 | less \ 25 | htop \ 26 | git \ 27 | tzdata \ 28 | wget \ 29 | tmux \ 30 | zip \ 31 | unzip \ 32 | zsh stow subversion fasd \ 33 | && rm -rf /var/lib/apt/lists/* 34 | # openmpi-bin \ 35 | 36 | # Allow running runmpi as root 37 | # ENV OMPI_ALLOW_RUN_AS_ROOT=1 OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 38 | 39 | # # Create a non-root user and switch to it 40 | # RUN adduser --disabled-password --gecos '' --shell /bin/bash user \ 41 | # && echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user 42 | # USER user 43 | 44 | # All users can use /home/user as their home directory 45 | ENV HOME=/home/user 46 | RUN mkdir -p /home/user && chmod 777 /home/user 47 | WORKDIR /home/user 48 | 49 | # Set up personal environment 50 | # FROM base-${COMPAT} as env-0 51 | FROM base as env-0 52 | FROM env-0 as env-1 53 | # Use ONBUILD so that the dotfiles dir doesn't need to exist unless we're building a personal image 54 | # https://stackoverflow.com/questions/31528384/conditional-copy-add-in-dockerfile 55 | ONBUILD COPY dotfiles ./dotfiles 56 | ONBUILD RUN cd ~/dotfiles && stow bash zsh tmux && sudo chsh -s /usr/bin/zsh $(whoami) 57 | # nvcr pytorch image sets SHELL=/bin/bash 58 | ONBUILD ENV SHELL=/bin/zsh 59 | 60 | FROM env-${PERSONAL} as packages 61 | 62 | # Disable pip cache: https://stackoverflow.com/questions/45594707/what-is-pips-no-cache-dir-good-for 63 | ENV PIP_NO_CACHE_DIR=1 64 | 65 | # # apex and pytorch-fast-transformers take a while to compile so we install them first 66 | # TD [2022-04-28] apex is already installed. 
In case we need a newer commit: 67 | # RUN pip install --upgrade --force-reinstall --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" --global-option="--fmha" --global-option="--fast_layer_norm" --global-option="--xentropy" git+https://github.com/NVIDIA/apex.git#egg=apex 68 | 69 | # xgboost conflicts with deepspeed 70 | RUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.7 71 | 72 | # General packages that we don't care about the version 73 | # zstandard to extract the_pile dataset 74 | # psutil to get the number of cpu physical cores 75 | # twine to upload package to PyPI 76 | RUN pip install pytest matplotlib jupyter ipython ipdb gpustat scikit-learn spacy munch einops opt_einsum fvcore gsutil cmake pykeops zstandard psutil h5py twine gdown \ 77 | && python -m spacy download en_core_web_sm 78 | # hydra 79 | RUN pip install hydra-core==1.3.1 hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 pyrootutils rich 80 | # Core packages 81 | RUN pip install transformers==4.45.2 datasets==3.0.1 pytorch-lightning==2.2.1 triton==2.3.1 wandb==0.16.3 controlnet_aux==0.0.9 timm==0.6.7 torchmetrics==1.3.2 82 | # torchmetrics 0.11.0 broke hydra's instantiate 83 | 84 | # For MLPerf 85 | RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 86 | 87 | RUN pip install accelerate==0.34.2 88 | 89 | RUN pip install diffusers==0.30.3 90 | 91 | RUN pip install deepspeed==0.15.2 92 | 93 | RUN pip install sentencepiece==0.1.99 94 | 95 | RUN pip install pillow==10.2.0 96 | 97 | RUN pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu118 98 | 99 | # Install FlashAttention 100 | RUN pip install flash-attn==2.6.3 101 | 102 | # Install CUDA extensions for fused dense 103 | RUN pip install git+https://github.com/Dao-AILab/flash-attention@v2.6.3#subdirectory=csrc/fused_dense_lib 104 | 105 | RUN pip install jaxtyping mediapipe gradio 106 | 107 | RUN pip install "git+https://github.com/facebookresearch/pytorch3d.git" 108 | 109 | RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y 110 | 111 | RUN pip install opencv-python==4.5.5.64 112 | 113 | RUN pip install opencv-python-headless==4.5.5.64 114 | 115 | RUN pip install huggingface_hub==0.24 116 | 117 | RUN pip install numpy==1.24.4 118 | 119 | 120 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline 4 | from PIL import Image 5 | 6 | 7 | device = torch.device('cuda:0') 8 | pipeline = OneDiffusionPipeline.from_pretrained("lehduong/OneDiffusion").to(device=device, dtype=torch.bfloat16) 9 | NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation" 10 | output_dir = "outputs" 11 | os.makedirs(output_dir, exist_ok=True) 12 | 13 | 14 | 15 | ################################################################################################ 16 | ## 1. 
Text-to-image 17 | ################################################################################################ 18 | output = pipeline( 19 | prompt="[[text2image]] A bipedal black cat wearing a huge oversized witch hat, a wizard's robe, casting a spell, in an enchanted forest. The scene is filled with fireflies and moss on surrounding rocks and trees", 20 | negative_prompt=NEGATIVE_PROMPT, 21 | num_inference_steps=50, 22 | guidance_scale=4, 23 | height=1024, 24 | width=1024, 25 | ) 26 | output.images[0].save(f'{output_dir}/text2image_output.jpg') 27 | 28 | 29 | 30 | ################################################################################################ 31 | ## 2. Semantic to Image (and any other condition2image tasks): 32 | ################################################################################################ 33 | images = [ 34 | Image.open("assets/examples/semantic_map/dragon_birds_woman.webp") 35 | ] 36 | prompt = "[[semanticmap2image]] <#00ffff Cyan mask: dragon> <#ff0000 yellow mask: bird> <#800080 purple mask: woman> A woman in a red dress with gold floral patterns stands in a traditional Japanese-style building. She has black hair and wears a gold choker and earrings. Behind her, a large orange and white dragon coils around the structure. Two white birds fly near her. The building features paper windows and a wooden roof with lanterns. The scene blends traditional Japanese architecture with fantastical elements, creating a mystical atmosphere." 37 | # set the denoise mask to [1, 0] to denoise the image from the conditions 38 | # by default, the height and width will be set so that the input image is minimally cropped 39 | ret = pipeline.img2img( 40 | image=images, 41 | num_inference_steps=50, 42 | prompt=prompt, 43 | denoise_mask=[1, 0], 44 | guidance_scale=4, 45 | negative_prompt=NEGATIVE_PROMPT, 46 | # height=512, 47 | # width=512, 48 | ) 49 | ret.images[0].save(f"{output_dir}/semanticmap2image_output.jpg") 50 | 51 | 52 | 53 | ################################################################################################ 54 | ## 3. Depth Estimation (and any other image2condition tasks): 55 | ################################################################################################ 56 | images = [ 57 | Image.open("assets/examples/images/cat_on_table.webp"), 58 | ] 59 | prompt = "[[depth2image]] cat sitting on a table" # you can omit the caption, i.e., set the prompt to "[[depth2image]]" 60 | # set the denoise mask to [0, 1] to denoise the condition (depth, pose, hed, canny, semantic map, etc.) 61 | # by default, the height and width will be set so that the input image is minimally cropped 62 | ret = pipeline.img2img( 63 | image=images, 64 | num_inference_steps=50, 65 | prompt=prompt, 66 | denoise_mask=[0, 1], 67 | guidance_scale=4, 68 | negative_prompt=NEGATIVE_PROMPT, 69 | # height=512, 70 | # width=512, 71 | ) 72 | ret.images[0].save(f"{output_dir}/image2depth_output.jpg") 73 | 74 | 75 | 76 | ################################################################################################ 77 | ## 4. 
ID Customization 78 | ################################################################################################ 79 | images = [ 80 | Image.open("assets/examples/id_customization/chenhao/image_0.png"), 81 | Image.open("assets/examples/id_customization/chenhao/image_1.png"), 82 | Image.open("assets/examples/id_customization/chenhao/image_2.png") 83 | ] 84 | ##### we set the denoise mask to [1, 0, 0, 0], which means the generated image will be the first view and the next three views will be the conditions (in the same order as the `images` list) 85 | #### prompt format: [[faceid]] [[img0]] target/caption [[img1]] caption/of/first/image [[img2]] caption/of/second/image [[img3]] caption/of/third/image 86 | prompt = "[[faceid]] \ 87 | [[img0]] A woman dressed in traditional attire with intricate headpieces. She is looking at the camera with a neutral expression. \ 88 | [[img1]] A woman with long dark hair, smiling warmly while wearing a floral dress. \ 89 | [[img2]] A woman in traditional clothing holding a lace parasol, with her hair styled elegantly. \ 90 | [[img3]] A woman in elaborate traditional attire and jewelry, with an ornate headdress, looking intently forward. \ 91 | " 92 | # by default, all images will be cropped according to the FIRST input image. 93 | ret = pipeline.img2img(image=images, num_inference_steps=75, prompt=prompt, denoise_mask=[1, 0, 0, 0], guidance_scale=4, negative_prompt=NEGATIVE_PROMPT) 94 | ret.images[0].save(f"{output_dir}/idcustomization_output.jpg") 95 | 96 | 97 | 98 | ################################################################################################ 99 | ## 5. Image to multiview 100 | ################################################################################################ 101 | images = [ 102 | Image.open("assets/examples/images/cat_on_table.webp"), 103 | ] 104 | prompt = "[[multiview]] A cat with orange and white fur sits on a round wooden table. The cat has striking green eyes and a pink nose. Its ears are perked up, and its tail is curled around its body. The background is blurred, showing a white wall, a wooden chair, and a wooden table with a white pot and green plant. A white curtain is visible on the right side. The cat's gaze is directed slightly to the right, and its paws are white. The overall scene creates a cozy, domestic atmosphere with the cat as the central focus."
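# ----------------------------------------------------------------------------------------------
# Illustrative sketch (editor's addition, not part of the original script): the denoise_mask
# passed in the multiview call below interleaves image slots and camera/ray slots as
# [img_0, cam_0, img_1, cam_1, ...]. A 0 marks a slot that is provided as a condition; a 1 marks
# a slot to be denoised (generated). The helper name and signature here are assumptions made
# purely for illustration; they do not exist in the OneDiffusion codebase.
def make_multiview_denoise_mask(num_views: int, num_given_images: int = 1) -> list:
    mask = []
    for i in range(num_views):
        mask.append(0 if i < num_given_images else 1)  # image slot: given -> 0, generated -> 1
        mask.append(0)                                  # camera slot: poses are provided, never denoised
    return mask

# Reproduces the mask used below: one given input image, three novel views to generate.
assert make_multiview_denoise_mask(4, num_given_images=1) == [0, 0, 1, 0, 1, 0, 1, 0]
# ----------------------------------------------------------------------------------------------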
105 | # denoise mask: [0, 0, 1, 0, 1, 0, 1, 0] is for [img_0, camera of img0, img_1, camera of img1, img_2, camera of img2, img_3, camera of img3] 106 | # since we provide 1 input image and all camera positions, we don't need to denoise those values, so their masks are set to 0 107 | # set the mask of [img_1, img_2, img_3] to 1 to generate novel views 108 | # NOTE: only SQUARE images are supported 109 | ret = pipeline.img2img( 110 | image=images, 111 | num_inference_steps=60, 112 | prompt=prompt, 113 | negative_prompt=NEGATIVE_PROMPT, 114 | denoise_mask=[0, 0, 1, 0, 1, 0, 1, 0], 115 | guidance_scale=4, 116 | multiview_azimuths=[0,20,40,60], # azimuths relative to the first view 117 | multiview_elevations=[0,0,0,0], # elevations relative to the first view 118 | multiview_distances=[1.5,1.5,1.5,1.5], 119 | # multiview_c2ws=None, # you can provide the camera-to-world matrices of shape [N, 4, 4] for ALL views; the camera extrinsics are relative to the first view 120 | # multiview_intrinsics=None, # provide the intrinsics matrix if c2ws is used 121 | height=512, 122 | width=512, 123 | is_multiview=True, 124 | ) 125 | for i in range(len(ret.images)): 126 | ret.images[i].save(f"{output_dir}/img2multiview_output_view{i+1}.jpg") 127 | 128 | 129 | -------------------------------------------------------------------------------- /onediffusion/dataset/multitask/multiview.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | from PIL import Image 5 | import torch 6 | from typing import List, Tuple, Union 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | import torchvision.transforms as T 10 | from onediffusion.dataset.utils import * 11 | import glob 12 | 13 | from onediffusion.dataset.raydiff_utils import cameras_to_rays, first_camera_transform, normalize_cameras 14 | from onediffusion.dataset.transforms import CenterCropResizeImage 15 | from pytorch3d.renderer import PerspectiveCameras 16 | 17 | import numpy as np 18 | 19 | def _cameras_from_opencv_projection( 20 | R: torch.Tensor, 21 | tvec: torch.Tensor, 22 | camera_matrix: torch.Tensor, 23 | image_size: torch.Tensor, 24 | do_normalize_cameras, 25 | normalize_scale, 26 | ) -> PerspectiveCameras: 27 | focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1) 28 | principal_point = camera_matrix[:, :2, 2] 29 | 30 | # Retype the image_size correctly and flip to width, height. 31 | image_size_wh = image_size.to(R).flip(dims=(1,)) 32 | 33 | # Screen to NDC conversion: 34 | # For non square images, we scale the points such that smallest side 35 | # has range [-1, 1] and the largest side has range [-u, u], with u > 1. 36 | # This convention is consistent with the PyTorch3D renderer, as well as 37 | # the transformation function `get_ndc_to_screen_transform`. 38 | scale = image_size_wh.to(R).min(dim=1, keepdim=True)[0] / 2.0 39 | scale = scale.expand(-1, 2) 40 | c0 = image_size_wh / 2.0 41 | 42 | # Get the PyTorch3D focal length and principal point. 43 | focal_pytorch3d = focal_length / scale 44 | p0_pytorch3d = -(principal_point - c0) / scale 45 | 46 | # For R, T we flip x, y axes (opencv screen space has an opposite 47 | # orientation of screen axes). 48 | # We also transpose R (opencv multiplies points from the opposite=left side). 
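# (Illustrative note added for clarity; not in the original source.) In matrix form, the
# conversion performed by the next few lines is, per camera, with S = diag(-1, -1, 1):
#     R_pytorch3d = R^T @ S   (transpose, then negate the first two columns)
#     T_pytorch3d = S @ tvec  (negate the first two components of the translation)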
49 | R_pytorch3d = R.clone().permute(0, 2, 1) 50 | T_pytorch3d = tvec.clone() 51 | R_pytorch3d[:, :, :2] *= -1 52 | T_pytorch3d[:, :2] *= -1 53 | 54 | cams = PerspectiveCameras( 55 | R=R_pytorch3d, 56 | T=T_pytorch3d, 57 | focal_length=focal_pytorch3d, 58 | principal_point=p0_pytorch3d, 59 | image_size=image_size, 60 | device=R.device, 61 | ) 62 | 63 | if do_normalize_cameras: 64 | cams, _ = normalize_cameras(cams, scale=normalize_scale) 65 | 66 | cams = first_camera_transform(cams, rotation_only=False) 67 | return cams 68 | 69 | def calculate_rays(Ks, sizes, Rs, Ts, target_size, use_plucker=True, do_normalize_cameras=False, normalize_scale=1.0): 70 | cameras = _cameras_from_opencv_projection( 71 | R=Rs, 72 | tvec=Ts, 73 | camera_matrix=Ks, 74 | image_size=sizes, 75 | do_normalize_cameras=do_normalize_cameras, 76 | normalize_scale=normalize_scale 77 | ) 78 | 79 | rays_embedding = cameras_to_rays( 80 | cameras=cameras, 81 | num_patches_x=target_size, 82 | num_patches_y=target_size, 83 | crop_parameters=None, 84 | use_plucker=use_plucker 85 | ) 86 | 87 | return rays_embedding.rays 88 | 89 | def convert_rgba_to_rgb_white_bg(image): 90 | """Convert RGBA image to RGB with white background""" 91 | if image.mode == 'RGBA': 92 | # Create a white background 93 | background = Image.new('RGBA', image.size, (255, 255, 255, 255)) 94 | # Composite the image onto the white background 95 | return Image.alpha_composite(background, image).convert('RGB') 96 | return image.convert('RGB') 97 | 98 | class MultiviewDataset(Dataset): 99 | def __init__( 100 | self, 101 | scene_folders: str, 102 | samples_per_set: Union[int, Tuple[int, int]], # Changed from samples_per_set to samples_range 103 | transform=None, 104 | caption_keys: Union[str, List] = "caption", 105 | multiscale=False, 106 | aspect_ratio_type=ASPECT_RATIO_512, 107 | c2w_scaling=1.7, 108 | default_max_distance=1, # default max distance from all camera of a scene , 109 | do_normalize=True, # whether normalize translation of c2w with max_distance 110 | swap_xz=False, # whether swap x and z axis of 3D scenes 111 | valid_paths: str = "", 112 | frame_sliding_windows: float = None # limit all sampled frames to be within this window, so that camera poses won't be too different 113 | ): 114 | if not isinstance(samples_per_set, tuple) and not isinstance(samples_per_set, list): 115 | samples_per_set = (samples_per_set, samples_per_set) 116 | self.samples_range = samples_per_set # Tuple of (min_samples, max_samples) 117 | self.transform = transform 118 | self.caption_keys = caption_keys if isinstance(caption_keys, list) else [caption_keys] 119 | self.aspect_ratio = aspect_ratio_type 120 | self.scene_folders = sorted(glob.glob(scene_folders)) 121 | # filter out scene folders that do not have transforms.json 122 | self.scene_folders = list(filter(lambda x: os.path.exists(os.path.join(x, "transforms.json")), self.scene_folders)) 123 | 124 | # if valid_paths.txt exists, only use paths in that file 125 | if os.path.exists(valid_paths): 126 | with open(valid_paths, 'r') as f: 127 | valid_scene_folders = f.read().splitlines() 128 | self.scene_folders = sorted(valid_scene_folders) 129 | 130 | self.c2w_scaling = c2w_scaling 131 | self.do_normalize = do_normalize 132 | self.default_max_distance = default_max_distance 133 | self.swap_xz = swap_xz 134 | self.frame_sliding_windows = frame_sliding_windows 135 | 136 | if multiscale: 137 | assert self.aspect_ratio in [ASPECT_RATIO_512, ASPECT_RATIO_1024, ASPECT_RATIO_2048, ASPECT_RATIO_2880] 138 | if self.aspect_ratio in 
[ASPECT_RATIO_2048, ASPECT_RATIO_2880]: 139 | self.interpolate_model = T.InterpolationMode.LANCZOS 140 | self.ratio_index = {} 141 | self.ratio_nums = {} 142 | for k, v in self.aspect_ratio.items(): 143 | self.ratio_index[float(k)] = [] # used for self.getitem 144 | self.ratio_nums[float(k)] = 0 # used for batch-sampler 145 | 146 | def __len__(self): 147 | return len(self.scene_folders) 148 | 149 | def __getitem__(self, idx): 150 | try: 151 | scene_path = self.scene_folders[idx] 152 | 153 | if os.path.exists(os.path.join(scene_path, "images")): 154 | image_folder = os.path.join(scene_path, "images") 155 | downscale_factor = 1 156 | elif os.path.exists(os.path.join(scene_path, "images_4")): 157 | image_folder = os.path.join(scene_path, "images_4") 158 | downscale_factor = 1 / 4 159 | elif os.path.exists(os.path.join(scene_path, "images_8")): 160 | image_folder = os.path.join(scene_path, "images_8") 161 | downscale_factor = 1 / 8 162 | else: 163 | raise NotImplementedError 164 | 165 | json_path = os.path.join(scene_path, "transforms.json") 166 | caption_path = os.path.join(scene_path, "caption.json") 167 | image_files = os.listdir(image_folder) 168 | 169 | with open(json_path, 'r') as f: 170 | json_data = json.load(f) 171 | height, width = json_data['h'], json_data['w'] 172 | 173 | dh, dw = int(height * downscale_factor), int(width * downscale_factor) 174 | fl_x, fl_y = json_data['fl_x'] * downscale_factor, json_data['fl_y'] * downscale_factor 175 | cx = dw // 2 176 | cy = dh // 2 177 | 178 | frame_list = json_data['frames'] 179 | 180 | # Randomly select number of samples 181 | 182 | samples_per_set = random.randint(self.samples_range[0], self.samples_range[1]) 183 | 184 | # uniformly for all scenes 185 | if self.frame_sliding_windows is None: 186 | selected_indices = random.sample(range(len(frame_list)), min(samples_per_set, len(frame_list))) 187 | # limit the multiview to be in a sliding window (to avoid catastrophic difference in camera angles) 188 | else: 189 | # Determine the starting index of the sliding window 190 | if len(frame_list) <= self.frame_sliding_windows: 191 | # If the frame list is smaller than or equal to X, use the entire list 192 | window_start = 0 193 | window_end = len(frame_list) 194 | else: 195 | # Randomly select a starting point for the window 196 | window_start = random.randint(0, len(frame_list) - self.frame_sliding_windows) 197 | window_end = window_start + self.frame_sliding_windows 198 | 199 | # Get the indices within the sliding window 200 | window_indices = list(range(window_start, window_end)) 201 | 202 | # Randomly sample indices from the window 203 | selected_indices = random.sample(window_indices, samples_per_set) 204 | 205 | image_files = [os.path.basename(frame_list[i]['file_path']) for i in selected_indices] 206 | image_paths = [os.path.join(image_folder, file) for file in image_files] 207 | 208 | # Load images and convert RGBA to RGB with white background 209 | images = [convert_rgba_to_rgb_white_bg(Image.open(image_path)) for image_path in image_paths] 210 | 211 | if self.transform: 212 | images = [self.transform(image) for image in images] 213 | else: 214 | closest_size, closest_ratio = self.aspect_ratio['1.0'], 1.0 215 | closest_size = tuple(map(int, closest_size)) 216 | transform = T.Compose([ 217 | T.ToTensor(), 218 | CenterCropResizeImage(closest_size), 219 | T.Normalize([.5], [.5]), 220 | ]) 221 | images = [transform(image) for image in images] 222 | images = torch.stack(images) 223 | 224 | c2ws = [frame_list[i]['transform_matrix'] for i in 
selected_indices] 225 | c2ws = torch.tensor(c2ws).reshape(-1, 4, 4) 226 | # max_distance = json_data.get('max_distance', self.default_max_distance) 227 | # if 'max_distance' not in json_data.keys(): 228 | # print(f"not found `max_distance` in json path: {json_path}") 229 | 230 | if self.swap_xz: 231 | swap_xz = torch.tensor([[[0, 0, 1., 0], 232 | [0, 1., 0, 0], 233 | [-1., 0, 0, 0], 234 | [0, 0, 0, 1.]]]) 235 | c2ws = swap_xz @ c2ws 236 | 237 | # OPENGL to OPENCV 238 | c2ws[:, 0:3, 1:3] *= -1 239 | c2ws = c2ws[:, [1, 0, 2, 3], :] 240 | c2ws[:, 2, :] *= -1 241 | 242 | w2cs = torch.inverse(c2ws) 243 | K = torch.tensor([[[fl_x, 0, cx], [0, fl_y, cy], [0, 0, 1]]]).repeat(len(c2ws), 1, 1) 244 | Rs = w2cs[:, :3, :3] 245 | Ts = w2cs[:, :3, 3] 246 | sizes = torch.tensor([[dh, dw]]).repeat(len(c2ws), 1) 247 | 248 | # get ray embedding and padding last dimension to 16 (num channels of VAE) 249 | # rays_od = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, use_plucker=False, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling) 250 | rays = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling) 251 | rays = rays.reshape(samples_per_set, closest_size[0] // 8, closest_size[1] // 8, 6) 252 | # padding = (0, 10) # pad the last dimension to 16 253 | # rays = torch.nn.functional.pad(rays, padding, "constant", 0) 254 | rays = torch.cat([rays, rays, rays[..., :4]], dim=-1) * 1.658 255 | 256 | if os.path.exists(caption_path): 257 | with open(caption_path, 'r') as f: 258 | caption_key = random.choice(self.caption_keys) 259 | caption = json.load(f).get(caption_key, "") 260 | else: 261 | caption = "" 262 | 263 | caption = "[[multiview]] " + caption if caption else "[[multiview]]" 264 | 265 | return { 266 | 'pixel_values': images, 267 | 'rays': rays, 268 | 'aspect_ratio': closest_ratio, 269 | 'caption': caption, 270 | 'height': dh, 271 | 'width': dw, 272 | # 'origins': rays_od[..., :3], 273 | # 'dirs': rays_od[..., 3:6] 274 | } 275 | except Exception as e: 276 | return self.__getitem__(random.randint(0, len(self.scene_folders) - 1)) 277 | -------------------------------------------------------------------------------- /onediffusion/dataset/raydiff_utils.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Adapted from code originally written by David Novotny. 
4 | """ 5 | 6 | import torch 7 | from pytorch3d.transforms import Rotate, Translate 8 | 9 | import cv2 10 | import numpy as np 11 | import torch 12 | from pytorch3d.renderer import PerspectiveCameras, RayBundle 13 | 14 | def intersect_skew_line_groups(p, r, mask): 15 | # p, r both of shape (B, N, n_intersected_lines, 3) 16 | # mask of shape (B, N, n_intersected_lines) 17 | p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask) 18 | if p_intersect is None: 19 | return None, None, None, None 20 | _, p_line_intersect = point_line_distance( 21 | p, r, p_intersect[..., None, :].expand_as(p) 22 | ) 23 | intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum( 24 | dim=-1 25 | ) 26 | return p_intersect, p_line_intersect, intersect_dist_squared, r 27 | 28 | 29 | def intersect_skew_lines_high_dim(p, r, mask=None): 30 | # Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions 31 | dim = p.shape[-1] 32 | # make sure the heading vectors are l2-normed 33 | if mask is None: 34 | mask = torch.ones_like(p[..., 0]) 35 | r = torch.nn.functional.normalize(r, dim=-1) 36 | 37 | eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None] 38 | I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None] 39 | sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3) 40 | 41 | # I_eps = torch.zeros_like(I_min_cov.sum(dim=-3)) + 1e-10 42 | # p_intersect = torch.pinverse(I_min_cov.sum(dim=-3) + I_eps).matmul(sum_proj)[..., 0] 43 | p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0] 44 | 45 | # I_min_cov.sum(dim=-3): torch.Size([1, 1, 3, 3]) 46 | # sum_proj: torch.Size([1, 1, 3, 1]) 47 | 48 | # p_intersect = np.linalg.lstsq(I_min_cov.sum(dim=-3).numpy(), sum_proj.numpy(), rcond=None)[0] 49 | 50 | if torch.any(torch.isnan(p_intersect)): 51 | print(p_intersect) 52 | return None, None 53 | ipdb.set_trace() 54 | assert False 55 | return p_intersect, r 56 | 57 | 58 | def point_line_distance(p1, r1, p2): 59 | df = p2 - p1 60 | proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1) 61 | line_pt_nearest = p2 - proj_vector 62 | d = (proj_vector).norm(dim=-1) 63 | return d, line_pt_nearest 64 | 65 | 66 | def compute_optical_axis_intersection(cameras): 67 | centers = cameras.get_camera_center() 68 | principal_points = cameras.principal_point 69 | 70 | one_vec = torch.ones((len(cameras), 1), device=centers.device) 71 | optical_axis = torch.cat((principal_points, one_vec), -1) 72 | 73 | # optical_axis = torch.cat( 74 | # (principal_points, cameras.focal_length[:, 0].unsqueeze(1)), -1 75 | # ) 76 | 77 | pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True) 78 | pp2 = torch.diagonal(pp, dim1=0, dim2=1).T 79 | 80 | directions = pp2 - centers 81 | centers = centers.unsqueeze(0).unsqueeze(0) 82 | directions = directions.unsqueeze(0).unsqueeze(0) 83 | 84 | p_intersect, p_line_intersect, _, r = intersect_skew_line_groups( 85 | p=centers, r=directions, mask=None 86 | ) 87 | 88 | if p_intersect is None: 89 | dist = None 90 | else: 91 | p_intersect = p_intersect.squeeze().unsqueeze(0) 92 | dist = (p_intersect - centers).norm(dim=-1) 93 | 94 | return p_intersect, dist, p_line_intersect, pp2, r 95 | 96 | 97 | def normalize_cameras(cameras, scale=1.0): 98 | """ 99 | Normalizes cameras such that the optical axes point to the origin, the rotation is 100 | identity, and the norm of the translation of the first camera is 1. 101 | 102 | Args: 103 | cameras (pytorch3d.renderer.cameras.CamerasBase). 
104 | scale (float): Norm of the translation of the first camera. 105 | 106 | Returns: 107 | new_cameras (pytorch3d.renderer.cameras.CamerasBase): Normalized cameras. 108 | undo_transform (function): Function that undoes the normalization. 109 | """ 110 | 111 | # Let distance from first camera to origin be unit 112 | new_cameras = cameras.clone() 113 | new_transform = ( 114 | new_cameras.get_world_to_view_transform() 115 | ) # potential R is not valid matrix 116 | p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection( 117 | cameras 118 | ) 119 | 120 | if p_intersect is None: 121 | print("Warning: optical axes code has a nan. Returning identity cameras.") 122 | new_cameras.R[:] = torch.eye(3, device=cameras.R.device, dtype=cameras.R.dtype) 123 | new_cameras.T[:] = torch.tensor( 124 | [0, 0, 1], device=cameras.T.device, dtype=cameras.T.dtype 125 | ) 126 | return new_cameras, lambda x: x 127 | 128 | d = dist.squeeze(dim=1).squeeze(dim=0)[0] 129 | # Degenerate case 130 | if d == 0: 131 | print(cameras.T) 132 | print(new_transform.get_matrix()[:, 3, :3]) 133 | assert False 134 | assert d != 0 135 | 136 | # Can't figure out how to make scale part of the transform too without messing up R. 137 | # Ideally, we would just wrap it all in a single Pytorch3D transform so that it 138 | # would work with any structure (eg PointClouds, Meshes). 139 | tR = Rotate(new_cameras.R[0].unsqueeze(0)).inverse() 140 | tT = Translate(p_intersect) 141 | t = tR.compose(tT) 142 | 143 | new_transform = t.compose(new_transform) 144 | new_cameras.R = new_transform.get_matrix()[:, :3, :3] 145 | new_cameras.T = new_transform.get_matrix()[:, 3, :3] / d * scale 146 | 147 | def undo_transform(cameras): 148 | cameras_copy = cameras.clone() 149 | cameras_copy.T *= d / scale 150 | new_t = ( 151 | t.inverse().compose(cameras_copy.get_world_to_view_transform()).get_matrix() 152 | ) 153 | cameras_copy.R = new_t[:, :3, :3] 154 | cameras_copy.T = new_t[:, 3, :3] 155 | return cameras_copy 156 | 157 | return new_cameras, undo_transform 158 | 159 | def first_camera_transform(cameras, rotation_only=True): 160 | new_cameras = cameras.clone() 161 | new_transform = new_cameras.get_world_to_view_transform() 162 | tR = Rotate(new_cameras.R[0].unsqueeze(0)) 163 | if rotation_only: 164 | t = tR.inverse() 165 | else: 166 | tT = Translate(new_cameras.T[0].unsqueeze(0)) 167 | t = tR.compose(tT).inverse() 168 | 169 | new_transform = t.compose(new_transform) 170 | new_cameras.R = new_transform.get_matrix()[:, :3, :3] 171 | new_cameras.T = new_transform.get_matrix()[:, 3, :3] 172 | 173 | return new_cameras 174 | 175 | 176 | def get_identity_cameras_with_intrinsics(cameras): 177 | D = len(cameras) 178 | device = cameras.R.device 179 | 180 | new_cameras = cameras.clone() 181 | new_cameras.R = torch.eye(3, device=device).unsqueeze(0).repeat((D, 1, 1)) 182 | new_cameras.T = torch.zeros((D, 3), device=device) 183 | 184 | return new_cameras 185 | 186 | 187 | def normalize_cameras_batch(cameras, scale=1.0, normalize_first_camera=False): 188 | new_cameras = [] 189 | undo_transforms = [] 190 | for cam in cameras: 191 | if normalize_first_camera: 192 | # Normalize cameras such that first camera is identity and origin is at 193 | # first camera center. 
194 | normalized_cameras = first_camera_transform(cam, rotation_only=False) 195 | undo_transform = None 196 | else: 197 | normalized_cameras, undo_transform = normalize_cameras(cam, scale=scale) 198 | new_cameras.append(normalized_cameras) 199 | undo_transforms.append(undo_transform) 200 | return new_cameras, undo_transforms 201 | 202 | 203 | class Rays(object): 204 | def __init__( 205 | self, 206 | rays=None, 207 | origins=None, 208 | directions=None, 209 | moments=None, 210 | is_plucker=False, 211 | moments_rescale=1.0, 212 | ndc_coordinates=None, 213 | crop_parameters=None, 214 | num_patches_x=16, 215 | num_patches_y=16, 216 | ): 217 | """ 218 | Ray class to keep track of current ray representation. 219 | 220 | Args: 221 | rays: (..., 6). 222 | origins: (..., 3). 223 | directions: (..., 3). 224 | moments: (..., 3). 225 | is_plucker: If True, rays are in plucker coordinates (Default: False). 226 | moments_rescale: Rescale the moment component of the rays by a scalar. 227 | ndc_coordinates: (..., 2): NDC coordinates of each ray. 228 | """ 229 | if rays is not None: 230 | self.rays = rays 231 | self._is_plucker = is_plucker 232 | elif origins is not None and directions is not None: 233 | self.rays = torch.cat((origins, directions), dim=-1) 234 | self._is_plucker = False 235 | elif directions is not None and moments is not None: 236 | self.rays = torch.cat((directions, moments), dim=-1) 237 | self._is_plucker = True 238 | else: 239 | raise Exception("Invalid combination of arguments") 240 | 241 | if moments_rescale != 1.0: 242 | self.rescale_moments(moments_rescale) 243 | 244 | if ndc_coordinates is not None: 245 | self.ndc_coordinates = ndc_coordinates 246 | elif crop_parameters is not None: 247 | # (..., H, W, 2) 248 | xy_grid = compute_ndc_coordinates( 249 | crop_parameters, 250 | num_patches_x=num_patches_x, 251 | num_patches_y=num_patches_y, 252 | )[..., :2] 253 | xy_grid = xy_grid.reshape(*xy_grid.shape[:-3], -1, 2) 254 | self.ndc_coordinates = xy_grid 255 | else: 256 | self.ndc_coordinates = None 257 | 258 | def __getitem__(self, index): 259 | return Rays( 260 | rays=self.rays[index], 261 | is_plucker=self._is_plucker, 262 | ndc_coordinates=( 263 | self.ndc_coordinates[index] 264 | if self.ndc_coordinates is not None 265 | else None 266 | ), 267 | ) 268 | 269 | def to_spatial(self, include_ndc_coordinates=False): 270 | """ 271 | Converts rays to spatial representation: (..., H * W, 6) --> (..., 6, H, W) 272 | 273 | Returns: 274 | torch.Tensor: (..., 6, H, W) 275 | """ 276 | rays = self.to_plucker().rays 277 | *batch_dims, P, D = rays.shape 278 | H = W = int(np.sqrt(P)) 279 | assert H * W == P 280 | rays = torch.transpose(rays, -1, -2) # (..., 6, H * W) 281 | rays = rays.reshape(*batch_dims, D, H, W) 282 | if include_ndc_coordinates: 283 | ndc_coords = self.ndc_coordinates.transpose(-1, -2) # (..., 2, H * W) 284 | ndc_coords = ndc_coords.reshape(*batch_dims, 2, H, W) 285 | rays = torch.cat((rays, ndc_coords), dim=-3) 286 | return rays 287 | 288 | def rescale_moments(self, scale): 289 | """ 290 | Rescale the moment component of the rays by a scalar. Might be desirable since 291 | moments may come from a very narrow distribution. 292 | 293 | Note that this modifies in place! 
294 | """ 295 | if self.is_plucker: 296 | self.rays[..., 3:] *= scale 297 | return self 298 | else: 299 | return self.to_plucker().rescale_moments(scale) 300 | 301 | @classmethod 302 | def from_spatial(cls, rays, moments_rescale=1.0, ndc_coordinates=None): 303 | """ 304 | Converts rays from spatial representation: (..., 6, H, W) --> (..., H * W, 6) 305 | 306 | Args: 307 | rays: (..., 6, H, W) 308 | 309 | Returns: 310 | Rays: (..., H * W, 6) 311 | """ 312 | *batch_dims, D, H, W = rays.shape 313 | rays = rays.reshape(*batch_dims, D, H * W) 314 | rays = torch.transpose(rays, -1, -2) 315 | return cls( 316 | rays=rays, 317 | is_plucker=True, 318 | moments_rescale=moments_rescale, 319 | ndc_coordinates=ndc_coordinates, 320 | ) 321 | 322 | def to_point_direction(self, normalize_moment=True): 323 | """ 324 | Convert to point direction representation . 325 | 326 | Returns: 327 | rays: (..., 6). 328 | """ 329 | if self._is_plucker: 330 | direction = torch.nn.functional.normalize(self.rays[..., :3], dim=-1) 331 | moment = self.rays[..., 3:] 332 | if normalize_moment: 333 | c = torch.linalg.norm(direction, dim=-1, keepdim=True) 334 | moment = moment / c 335 | points = torch.cross(direction, moment, dim=-1) 336 | return Rays( 337 | rays=torch.cat((points, direction), dim=-1), 338 | is_plucker=False, 339 | ndc_coordinates=self.ndc_coordinates, 340 | ) 341 | else: 342 | return self 343 | 344 | def to_plucker(self): 345 | """ 346 | Convert to plucker representation . 347 | """ 348 | if self.is_plucker: 349 | return self 350 | else: 351 | ray = self.rays.clone() 352 | ray_origins = ray[..., :3] 353 | ray_directions = ray[..., 3:] 354 | # Normalize ray directions to unit vectors 355 | ray_directions = ray_directions / ray_directions.norm(dim=-1, keepdim=True) 356 | plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1) 357 | new_ray = torch.cat([ray_directions, plucker_normal], dim=-1) 358 | return Rays( 359 | rays=new_ray, is_plucker=True, ndc_coordinates=self.ndc_coordinates 360 | ) 361 | 362 | def get_directions(self, normalize=True): 363 | if self.is_plucker: 364 | directions = self.rays[..., :3] 365 | else: 366 | directions = self.rays[..., 3:] 367 | if normalize: 368 | directions = torch.nn.functional.normalize(directions, dim=-1) 369 | return directions 370 | 371 | def get_origins(self): 372 | if self.is_plucker: 373 | origins = self.to_point_direction().get_origins() 374 | else: 375 | origins = self.rays[..., :3] 376 | return origins 377 | 378 | def get_moments(self): 379 | if self.is_plucker: 380 | moments = self.rays[..., 3:] 381 | else: 382 | moments = self.to_plucker().get_moments() 383 | return moments 384 | 385 | def get_ndc_coordinates(self): 386 | return self.ndc_coordinates 387 | 388 | @property 389 | def is_plucker(self): 390 | return self._is_plucker 391 | 392 | @property 393 | def device(self): 394 | return self.rays.device 395 | 396 | def __repr__(self, *args, **kwargs): 397 | ray_str = self.rays.__repr__(*args, **kwargs)[6:] # remove "tensor" 398 | if self._is_plucker: 399 | return "PluRay" + ray_str 400 | else: 401 | return "DirRay" + ray_str 402 | 403 | def to(self, device): 404 | self.rays = self.rays.to(device) 405 | 406 | def clone(self): 407 | return Rays(rays=self.rays.clone(), is_plucker=self._is_plucker) 408 | 409 | @property 410 | def shape(self): 411 | return self.rays.shape 412 | 413 | def visualize(self): 414 | directions = torch.nn.functional.normalize(self.get_directions(), dim=-1).cpu() 415 | moments = torch.nn.functional.normalize(self.get_moments(), 
dim=-1).cpu() 416 | return (directions + 1) / 2, (moments + 1) / 2 417 | 418 | def to_ray_bundle(self, length=0.3, recenter=True): 419 | lengths = torch.ones_like(self.get_origins()[..., :2]) * length 420 | lengths[..., 0] = 0 421 | if recenter: 422 | centers, _ = intersect_skew_lines_high_dim( 423 | self.get_origins(), self.get_directions() 424 | ) 425 | centers = centers.unsqueeze(1).repeat(1, lengths.shape[1], 1) 426 | else: 427 | centers = self.get_origins() 428 | return RayBundle( 429 | origins=centers, 430 | directions=self.get_directions(), 431 | lengths=lengths, 432 | xys=self.get_directions(), 433 | ) 434 | 435 | 436 | def cameras_to_rays( 437 | cameras, 438 | crop_parameters, 439 | use_half_pix=True, 440 | use_plucker=True, 441 | num_patches_x=16, 442 | num_patches_y=16, 443 | ): 444 | """ 445 | Unprojects rays from camera center to grid on image plane. 446 | 447 | Args: 448 | cameras: Pytorch3D cameras to unproject. Can be batched. 449 | crop_parameters: Crop parameters in NDC (cc_x, cc_y, crop_width, scale). 450 | Shape is (B, 4). 451 | use_half_pix: If True, use half pixel offset (Default: True). 452 | use_plucker: If True, return rays in plucker coordinates (Default: False). 453 | num_patches_x: Number of patches in x direction (Default: 16). 454 | num_patches_y: Number of patches in y direction (Default: 16). 455 | """ 456 | unprojected = [] 457 | crop_parameters_list = ( 458 | crop_parameters if crop_parameters is not None else [None for _ in cameras] 459 | ) 460 | for camera, crop_param in zip(cameras, crop_parameters_list): 461 | xyd_grid = compute_ndc_coordinates( 462 | crop_parameters=crop_param, 463 | use_half_pix=use_half_pix, 464 | num_patches_x=num_patches_x, 465 | num_patches_y=num_patches_y, 466 | ) 467 | 468 | unprojected.append( 469 | camera.unproject_points( 470 | xyd_grid.reshape(-1, 3), world_coordinates=True, from_ndc=True 471 | ) 472 | ) 473 | unprojected = torch.stack(unprojected, dim=0) # (N, P, 3) 474 | origins = cameras.get_camera_center().unsqueeze(1) # (N, 1, 3) 475 | origins = origins.repeat(1, num_patches_x * num_patches_y, 1) # (N, P, 3) 476 | directions = unprojected - origins 477 | 478 | rays = Rays( 479 | origins=origins, 480 | directions=directions, 481 | crop_parameters=crop_parameters, 482 | num_patches_x=num_patches_x, 483 | num_patches_y=num_patches_y, 484 | ) 485 | if use_plucker: 486 | return rays.to_plucker() 487 | return rays 488 | 489 | 490 | def rays_to_cameras( 491 | rays, 492 | crop_parameters, 493 | num_patches_x=16, 494 | num_patches_y=16, 495 | use_half_pix=True, 496 | sampled_ray_idx=None, 497 | cameras=None, 498 | focal_length=(3.453,), 499 | ): 500 | """ 501 | If cameras are provided, will use those intrinsics. Otherwise will use the provided 502 | focal_length(s). Dataset default is 3.32. 503 | 504 | Args: 505 | rays (Rays): (N, P, 6) 506 | crop_parameters (torch.Tensor): (N, 4) 507 | """ 508 | device = rays.device 509 | origins = rays.get_origins() 510 | directions = rays.get_directions() 511 | camera_centers, _ = intersect_skew_lines_high_dim(origins, directions) 512 | 513 | # Retrieve target rays 514 | if cameras is None: 515 | if len(focal_length) == 1: 516 | focal_length = focal_length * rays.shape[0] 517 | I_camera = PerspectiveCameras(focal_length=focal_length, device=device) 518 | else: 519 | # Use same intrinsics but reset to identity extrinsics. 
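# (R is reset to the identity and T to zero just below, so the patch rays unprojected from I_camera depend only on the intrinsics; the per-camera rotation computed further down then aligns these canonical rays with the observed ray directions.)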
520 | I_camera = cameras.clone() 521 | I_camera.R[:] = torch.eye(3, device=device) 522 | I_camera.T[:] = torch.zeros(3, device=device) 523 | I_patch_rays = cameras_to_rays( 524 | cameras=I_camera, 525 | num_patches_x=num_patches_x, 526 | num_patches_y=num_patches_y, 527 | use_half_pix=use_half_pix, 528 | crop_parameters=crop_parameters, 529 | ).get_directions() 530 | 531 | if sampled_ray_idx is not None: 532 | I_patch_rays = I_patch_rays[:, sampled_ray_idx] 533 | 534 | # Compute optimal rotation to align rays 535 | R = torch.zeros_like(I_camera.R) 536 | for i in range(len(I_camera)): 537 | R[i] = compute_optimal_rotation_alignment( 538 | I_patch_rays[i], 539 | directions[i], 540 | ) 541 | 542 | # Construct and return rotated camera 543 | cam = I_camera.clone() 544 | cam.R = R 545 | cam.T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2) 546 | return cam 547 | 548 | 549 | # https://www.reddit.com/r/learnmath/comments/v1crd7/linear_algebra_qr_to_ql_decomposition/ 550 | def ql_decomposition(A): 551 | P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device).float() 552 | A_tilde = torch.matmul(A, P) 553 | Q_tilde, R_tilde = torch.linalg.qr(A_tilde) 554 | Q = torch.matmul(Q_tilde, P) 555 | L = torch.matmul(torch.matmul(P, R_tilde), P) 556 | d = torch.diag(L) 557 | Q[:, 0] *= torch.sign(d[0]) 558 | Q[:, 1] *= torch.sign(d[1]) 559 | Q[:, 2] *= torch.sign(d[2]) 560 | L[0] *= torch.sign(d[0]) 561 | L[1] *= torch.sign(d[1]) 562 | L[2] *= torch.sign(d[2]) 563 | return Q, L 564 | 565 | 566 | def rays_to_cameras_homography( 567 | rays, 568 | crop_parameters, 569 | num_patches_x=16, 570 | num_patches_y=16, 571 | use_half_pix=True, 572 | sampled_ray_idx=None, 573 | reproj_threshold=0.2, 574 | ): 575 | """ 576 | Args: 577 | rays (Rays): (N, P, 6) 578 | crop_parameters (torch.Tensor): (N, 4) 579 | """ 580 | device = rays.device 581 | origins = rays.get_origins() 582 | directions = rays.get_directions() 583 | camera_centers, _ = intersect_skew_lines_high_dim(origins, directions) 584 | 585 | # Retrieve target rays 586 | I_camera = PerspectiveCameras(focal_length=[1] * rays.shape[0], device=device) 587 | I_patch_rays = cameras_to_rays( 588 | cameras=I_camera, 589 | num_patches_x=num_patches_x, 590 | num_patches_y=num_patches_y, 591 | use_half_pix=use_half_pix, 592 | crop_parameters=crop_parameters, 593 | ).get_directions() 594 | 595 | if sampled_ray_idx is not None: 596 | I_patch_rays = I_patch_rays[:, sampled_ray_idx] 597 | 598 | # Compute optimal rotation to align rays 599 | Rs = [] 600 | focal_lengths = [] 601 | principal_points = [] 602 | for i in range(rays.shape[-3]): 603 | R, f, pp = compute_optimal_rotation_intrinsics( 604 | I_patch_rays[i], 605 | directions[i], 606 | reproj_threshold=reproj_threshold, 607 | ) 608 | Rs.append(R) 609 | focal_lengths.append(f) 610 | principal_points.append(pp) 611 | 612 | R = torch.stack(Rs) 613 | focal_lengths = torch.stack(focal_lengths) 614 | principal_points = torch.stack(principal_points) 615 | T = -torch.matmul(R.transpose(1, 2), camera_centers.unsqueeze(2)).squeeze(2) 616 | return PerspectiveCameras( 617 | R=R, 618 | T=T, 619 | focal_length=focal_lengths, 620 | principal_point=principal_points, 621 | device=device, 622 | ) 623 | 624 | 625 | def compute_optimal_rotation_alignment(A, B): 626 | """ 627 | Compute optimal R that minimizes: || A - B @ R ||_F 628 | 629 | Args: 630 | A (torch.Tensor): (N, 3) 631 | B (torch.Tensor): (N, 3) 632 | 633 | Returns: 634 | R (torch.tensor): (3, 3) 635 | """ 636 | # normally with R @ B, this 
would be A @ B.T 637 | H = B.T @ A 638 | U, _, Vh = torch.linalg.svd(H, full_matrices=True) 639 | s = torch.linalg.det(U @ Vh) 640 | S_prime = torch.diag(torch.tensor([1, 1, torch.sign(s)], device=A.device)) 641 | return U @ S_prime @ Vh 642 | 643 | 644 | def compute_optimal_rotation_intrinsics( 645 | rays_origin, rays_target, z_threshold=1e-4, reproj_threshold=0.2 646 | ): 647 | """ 648 | Note: for some reason, f seems to be 1/f. 649 | 650 | Args: 651 | rays_origin (torch.Tensor): (N, 3) 652 | rays_target (torch.Tensor): (N, 3) 653 | z_threshold (float): Threshold for z value to be considered valid. 654 | 655 | Returns: 656 | R (torch.tensor): (3, 3) 657 | focal_length (torch.tensor): (2,) 658 | principal_point (torch.tensor): (2,) 659 | """ 660 | device = rays_origin.device 661 | z_mask = torch.logical_and( 662 | torch.abs(rays_target) > z_threshold, torch.abs(rays_origin) > z_threshold 663 | )[:, 2] 664 | rays_target = rays_target[z_mask] 665 | rays_origin = rays_origin[z_mask] 666 | rays_origin = rays_origin[:, :2] / rays_origin[:, -1:] 667 | rays_target = rays_target[:, :2] / rays_target[:, -1:] 668 | 669 | A, _ = cv2.findHomography( 670 | rays_origin.cpu().numpy(), 671 | rays_target.cpu().numpy(), 672 | cv2.RANSAC, 673 | reproj_threshold, 674 | ) 675 | A = torch.from_numpy(A).float().to(device) 676 | 677 | if torch.linalg.det(A) < 0: 678 | A = -A 679 | 680 | R, L = ql_decomposition(A) 681 | L = L / L[2][2] 682 | 683 | f = torch.stack((L[0][0], L[1][1])) 684 | pp = torch.stack((L[2][0], L[2][1])) 685 | return R, f, pp 686 | 687 | 688 | def compute_ndc_coordinates( 689 | crop_parameters=None, 690 | use_half_pix=True, 691 | num_patches_x=16, 692 | num_patches_y=16, 693 | device=None, 694 | ): 695 | """ 696 | Computes NDC Grid using crop_parameters. If crop_parameters is not provided, 697 | then it assumes that the crop is the entire image (corresponding to an NDC grid 698 | where top left corner is (1, 1) and bottom right corner is (-1, -1)). 699 | """ 700 | if crop_parameters is None: 701 | cc_x, cc_y, width = 0, 0, 2 702 | else: 703 | if len(crop_parameters.shape) > 1: 704 | return torch.stack( 705 | [ 706 | compute_ndc_coordinates( 707 | crop_parameters=crop_param, 708 | use_half_pix=use_half_pix, 709 | num_patches_x=num_patches_x, 710 | num_patches_y=num_patches_y, 711 | ) 712 | for crop_param in crop_parameters 713 | ], 714 | dim=0, 715 | ) 716 | device = crop_parameters.device 717 | cc_x, cc_y, width, _ = crop_parameters 718 | 719 | dx = 1 / num_patches_x 720 | dy = 1 / num_patches_y 721 | if use_half_pix: 722 | min_y = 1 - dy 723 | max_y = -min_y 724 | min_x = 1 - dx 725 | max_x = -min_x 726 | else: 727 | min_y = min_x = 1 728 | max_y = -1 + 2 * dy 729 | max_x = -1 + 2 * dx 730 | 731 | y, x = torch.meshgrid( 732 | torch.linspace(min_y, max_y, num_patches_y, dtype=torch.float32, device=device), 733 | torch.linspace(min_x, max_x, num_patches_x, dtype=torch.float32, device=device), 734 | indexing="ij", 735 | ) 736 | x_prime = x * width / 2 - cc_x 737 | y_prime = y * width / 2 - cc_y 738 | xyd_grid = torch.stack([x_prime, y_prime, torch.ones_like(x)], dim=-1) 739 | return xyd_grid 740 | -------------------------------------------------------------------------------- /onediffusion/dataset/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def crop(image, i, j, h, w): 5 | """ 6 | Args: 7 | image (torch.tensor): Image to be cropped. 
Size is (C, H, W) 8 | """ 9 | if len(image.size()) != 3: 10 | raise ValueError("image should be a 3D tensor") 11 | return image[..., i : i + h, j : j + w] 12 | 13 | def resize(image, target_size, interpolation_mode): 14 | if len(target_size) != 2: 15 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") 16 | return F.interpolate(image.unsqueeze(0), size=target_size, mode=interpolation_mode, align_corners=False).squeeze(0) 17 | 18 | def resize_scale(image, target_size, interpolation_mode): 19 | if len(target_size) != 2: 20 | raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") 21 | H, W = image.size(-2), image.size(-1) 22 | scale_ = target_size[0] / min(H, W) 23 | return F.interpolate(image.unsqueeze(0), scale_factor=scale_, mode=interpolation_mode, align_corners=False).squeeze(0) 24 | 25 | def resized_crop(image, i, j, h, w, size, interpolation_mode="bilinear"): 26 | """ 27 | Do spatial cropping and resizing to the image 28 | Args: 29 | image (torch.tensor): Image to be cropped. Size is (C, H, W) 30 | i (int): i in (i,j) i.e coordinates of the upper left corner. 31 | j (int): j in (i,j) i.e coordinates of the upper left corner. 32 | h (int): Height of the cropped region. 33 | w (int): Width of the cropped region. 34 | size (tuple(int, int)): height and width of resized image 35 | Returns: 36 | image (torch.tensor): Resized and cropped image. Size is (C, H, W) 37 | """ 38 | if len(image.size()) != 3: 39 | raise ValueError("image should be a 3D torch.tensor") 40 | image = crop(image, i, j, h, w) 41 | image = resize(image, size, interpolation_mode) 42 | return image 43 | 44 | def center_crop(image, crop_size): 45 | if len(image.size()) != 3: 46 | raise ValueError("image should be a 3D torch.tensor") 47 | h, w = image.size(-2), image.size(-1) 48 | th, tw = crop_size 49 | if h < th or w < tw: 50 | raise ValueError("height and width must be no smaller than crop_size") 51 | i = int(round((h - th) / 2.0)) 52 | j = int(round((w - tw) / 2.0)) 53 | return crop(image, i, j, th, tw) 54 | 55 | def center_crop_using_short_edge(image): 56 | if len(image.size()) != 3: 57 | raise ValueError("image should be a 3D torch.tensor") 58 | h, w = image.size(-2), image.size(-1) 59 | if h < w: 60 | th, tw = h, h 61 | i = 0 62 | j = int(round((w - tw) / 2.0)) 63 | else: 64 | th, tw = w, w 65 | i = int(round((h - th) / 2.0)) 66 | j = 0 67 | return crop(image, i, j, th, tw) 68 | 69 | class CenterCropResizeImage: 70 | """ 71 | Resize the image while maintaining aspect ratio, and then crop it to the desired size. 72 | The resizing is done such that the area of padding/cropping is minimized. 73 | """ 74 | def __init__(self, size, interpolation_mode="bilinear"): 75 | if isinstance(size, tuple): 76 | if len(size) != 2: 77 | raise ValueError(f"Size should be a tuple (height, width), instead got {size}") 78 | self.size = size 79 | else: 80 | self.size = (size, size) 81 | self.interpolation_mode = interpolation_mode 82 | 83 | def __call__(self, image): 84 | """ 85 | Args: 86 | image (torch.Tensor): Image to be resized and cropped. Size is (C, H, W) 87 | 88 | Returns: 89 | torch.Tensor: Resized and cropped image. 
Size is (C, target_height, target_width) 90 | """ 91 | target_height, target_width = self.size 92 | target_aspect = target_width / target_height 93 | 94 | # Get current image shape and aspect ratio 95 | _, height, width = image.shape 96 | height, width = float(height), float(width) 97 | current_aspect = width / height 98 | 99 | # Calculate crop dimensions 100 | if current_aspect > target_aspect: 101 | # Image is wider than target, crop width 102 | crop_height = height 103 | crop_width = height * target_aspect 104 | else: 105 | # Image is taller than target, crop height 106 | crop_height = width / target_aspect 107 | crop_width = width 108 | 109 | # Calculate crop coordinates (center crop) 110 | y1 = (height - crop_height) / 2 111 | x1 = (width - crop_width) / 2 112 | 113 | # Perform the crop 114 | cropped_image = crop(image, int(y1), int(x1), int(crop_height), int(crop_width)) 115 | 116 | # Resize the cropped image to the target size 117 | resized_image = resize(cropped_image, self.size, self.interpolation_mode) 118 | 119 | return resized_image 120 | 121 | # Example usage 122 | if __name__ == "__main__": 123 | # Create a sample image tensor 124 | sample_image = torch.rand(3, 480, 640) # (C, H, W) 125 | 126 | # Initialize the transform 127 | transform = CenterCropResizeImage(size=(224, 224), interpolation_mode="bilinear") 128 | 129 | # Apply the transform 130 | transformed_image = transform(sample_image) 131 | 132 | print(f"Original image shape: {sample_image.shape}") 133 | print(f"Transformed image shape: {transformed_image.shape}") -------------------------------------------------------------------------------- /onediffusion/dataset/utils.py: -------------------------------------------------------------------------------- 1 | 2 | ASPECT_RATIO_2880 = { 3 | '0.25': [1408.0, 5760.0], '0.26': [1408.0, 5568.0], '0.27': [1408.0, 5376.0], '0.28': [1408.0, 5184.0], 4 | '0.32': [1600.0, 4992.0], '0.33': [1600.0, 4800.0], '0.34': [1600.0, 4672.0], '0.4': [1792.0, 4480.0], 5 | '0.42': [1792.0, 4288.0], '0.47': [1920.0, 4096.0], '0.49': [1920.0, 3904.0], '0.51': [1920.0, 3776.0], 6 | '0.55': [2112.0, 3840.0], '0.59': [2112.0, 3584.0], '0.68': [2304.0, 3392.0], '0.72': [2304.0, 3200.0], 7 | '0.78': [2496.0, 3200.0], '0.83': [2496.0, 3008.0], '0.89': [2688.0, 3008.0], '0.93': [2688.0, 2880.0], 8 | '1.0': [2880.0, 2880.0], '1.07': [2880.0, 2688.0], '1.12': [3008.0, 2688.0], '1.21': [3008.0, 2496.0], 9 | '1.28': [3200.0, 2496.0], '1.39': [3200.0, 2304.0], '1.47': [3392.0, 2304.0], '1.7': [3584.0, 2112.0], 10 | '1.82': [3840.0, 2112.0], '2.03': [3904.0, 1920.0], '2.13': [4096.0, 1920.0], '2.39': [4288.0, 1792.0], 11 | '2.5': [4480.0, 1792.0], '2.92': [4672.0, 1600.0], '3.0': [4800.0, 1600.0], '3.12': [4992.0, 1600.0], 12 | '3.68': [5184.0, 1408.0], '3.82': [5376.0, 1408.0], '3.95': [5568.0, 1408.0], '4.0': [5760.0, 1408.0] 13 | } 14 | 15 | ASPECT_RATIO_2048 = { 16 | '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0], '0.27': [1024.0, 3840.0], '0.28': [1024.0, 3712.0], 17 | '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0], 18 | '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0], 19 | '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0], 20 | '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0], 21 | '1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0], 22 | 
'1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0], 23 | '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0], 24 | '2.5': [3200.0, 1280.0], '2.89': [3328.0, 1152.0], '3.0': [3456.0, 1152.0], '3.11': [3584.0, 1152.0], 25 | '3.62': [3712.0, 1024.0], '3.75': [3840.0, 1024.0], '3.88': [3968.0, 1024.0], '4.0': [4096.0, 1024.0] 26 | } 27 | 28 | ASPECT_RATIO_1024 = { 29 | '0.25': [512., 2048.], '0.26': [512., 1984.], '0.27': [512., 1920.], '0.28': [512., 1856.], 30 | '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.], 31 | '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.], 32 | '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.], 33 | '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.], 34 | '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.], 35 | '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.], 36 | '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.], 37 | '2.5': [1600., 640.], '2.89': [1664., 576.], '3.0': [1728., 576.], '3.11': [1792., 576.], 38 | '3.62': [1856., 512.], '3.75': [1920., 512.], '3.88': [1984., 512.], '4.0': [2048., 512.], 39 | } 40 | 41 | ASPECT_RATIO_512 = { 42 | '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0], 43 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], 44 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], 45 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], 46 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], 47 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], 48 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], 49 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], 50 | '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0], 51 | '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0] 52 | } 53 | 54 | 55 | ASPECT_RATIO_384 = { 56 | '0.25': [192.0, 768.0], 57 | '0.26': [192.0, 736.0], 58 | '0.27': [208.0, 768.0], 59 | '0.28': [208.0, 736.0], 60 | '0.33': [240.0, 720.0], 61 | '0.4': [256.0, 640.0], 62 | '0.42': [304.0, 720.0], 63 | '0.48': [368.0, 768.0], 64 | '0.5': [384.0, 768.0], 65 | '0.52': [384.0, 736.0], 66 | '0.57': [384.0, 672.0], 67 | '0.6': [384.0, 640.0], 68 | '0.73': [384.0, 528.0], 69 | '0.77': [384.0, 496.0], 70 | '0.83': [384.0, 464.0], 71 | '0.89': [384.0, 432.0], 72 | '0.92': [384.0, 416.0], 73 | '1.0': [384.0, 384.0], 74 | '1.09': [384.0, 352.0], 75 | '1.14': [384.0, 336.0], 76 | '1.2': [384.0, 320.0], 77 | '1.26': [384.0, 304.0], 78 | '1.33': [384.0, 288.0], 79 | '1.41': [384.0, 272.0], 80 | '1.6': [384.0, 240.0], 81 | '1.71': [384.0, 224.0], 82 | '2.0': [384.0, 192.0], 83 | '2.4': [384.0, 160.0], 84 | '2.88': [368.0, 128.0], 85 | '3.0': [384.0, 128.0], 86 | '3.43': [384.0, 112.0], 87 | '4.0': [384.0, 96.0] 88 | } 89 | 90 | ASPECT_RATIO_256 = { 91 | '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': 
[128.0, 464.0], 92 | '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0], 93 | '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0], 94 | '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0], 95 | '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0], 96 | '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0], 97 | '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0], 98 | '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0], 99 | '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0], 100 | '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0] 101 | } 102 | 103 | ASPECT_RATIO_256_TEST = { 104 | '0.25': [128.0, 512.0], '0.28': [128.0, 464.0], 105 | '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0], 106 | '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0], 107 | '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0], 108 | '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0], 109 | '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0], 110 | '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0], 111 | '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0], 112 | '2.5': [400.0, 160.0], '3.0': [432.0, 144.0], 113 | '4.0': [512.0, 128.0] 114 | } 115 | 116 | ASPECT_RATIO_512_TEST = { 117 | '0.25': [256.0, 1024.0], '0.28': [256.0, 928.0], 118 | '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], 119 | '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], 120 | '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], 121 | '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], 122 | '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], 123 | '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], 124 | '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], 125 | '2.5': [800.0, 320.0], '3.0': [864.0, 288.0], 126 | '4.0': [1024.0, 256.0] 127 | } 128 | 129 | ASPECT_RATIO_1024_TEST = { 130 | '0.25': [512., 2048.], '0.28': [512., 1856.], 131 | '0.32': [576., 1792.], '0.33': [576., 1728.], '0.35': [576., 1664.], '0.4': [640., 1600.], 132 | '0.42': [640., 1536.], '0.48': [704., 1472.], '0.5': [704., 1408.], '0.52': [704., 1344.], 133 | '0.57': [768., 1344.], '0.6': [768., 1280.], '0.68': [832., 1216.], '0.72': [832., 1152.], 134 | '0.78': [896., 1152.], '0.82': [896., 1088.], '0.88': [960., 1088.], '0.94': [960., 1024.], 135 | '1.0': [1024., 1024.], '1.07': [1024., 960.], '1.13': [1088., 960.], '1.21': [1088., 896.], 136 | '1.29': [1152., 896.], '1.38': [1152., 832.], '1.46': [1216., 832.], '1.67': [1280., 768.], 137 | '1.75': [1344., 768.], '2.0': [1408., 704.], '2.09': [1472., 704.], '2.4': [1536., 640.], 138 | '2.5': [1600., 640.], '3.0': [1728., 576.], 139 | '4.0': [2048., 512.], 
140 | } 141 | 142 | ASPECT_RATIO_2048_TEST = { 143 | '0.25': [1024.0, 4096.0], '0.26': [1024.0, 3968.0], 144 | '0.32': [1152.0, 3584.0], '0.33': [1152.0, 3456.0], '0.35': [1152.0, 3328.0], '0.4': [1280.0, 3200.0], 145 | '0.42': [1280.0, 3072.0], '0.48': [1408.0, 2944.0], '0.5': [1408.0, 2816.0], '0.52': [1408.0, 2688.0], 146 | '0.57': [1536.0, 2688.0], '0.6': [1536.0, 2560.0], '0.68': [1664.0, 2432.0], '0.72': [1664.0, 2304.0], 147 | '0.78': [1792.0, 2304.0], '0.82': [1792.0, 2176.0], '0.88': [1920.0, 2176.0], '0.94': [1920.0, 2048.0], 148 | '1.0': [2048.0, 2048.0], '1.07': [2048.0, 1920.0], '1.13': [2176.0, 1920.0], '1.21': [2176.0, 1792.0], 149 | '1.29': [2304.0, 1792.0], '1.38': [2304.0, 1664.0], '1.46': [2432.0, 1664.0], '1.67': [2560.0, 1536.0], 150 | '1.75': [2688.0, 1536.0], '2.0': [2816.0, 1408.0], '2.09': [2944.0, 1408.0], '2.4': [3072.0, 1280.0], 151 | '2.5': [3200.0, 1280.0], '3.0': [3456.0, 1152.0], 152 | '4.0': [4096.0, 1024.0] 153 | } 154 | 155 | ASPECT_RATIO_2880_TEST = { 156 | '0.25': [2048.0, 8192.0], '0.26': [2048.0, 7936.0], 157 | '0.32': [2304.0, 7168.0], '0.33': [2304.0, 6912.0], '0.35': [2304.0, 6656.0], '0.4': [2560.0, 6400.0], 158 | '0.42': [2560.0, 6144.0], '0.48': [2816.0, 5888.0], '0.5': [2816.0, 5632.0], '0.52': [2816.0, 5376.0], 159 | '0.57': [3072.0, 5376.0], '0.6': [3072.0, 5120.0], '0.68': [3328.0, 4864.0], '0.72': [3328.0, 4608.0], 160 | '0.78': [3584.0, 4608.0], '0.82': [3584.0, 4352.0], '0.88': [3840.0, 4352.0], '0.94': [3840.0, 4096.0], 161 | '1.0': [4096.0, 4096.0], '1.07': [4096.0, 3840.0], '1.13': [4352.0, 3840.0], '1.21': [4352.0, 3584.0], 162 | '1.29': [4608.0, 3584.0], '1.38': [4608.0, 3328.0], '1.46': [4864.0, 3328.0], '1.67': [5120.0, 3072.0], 163 | '1.75': [5376.0, 3072.0], '2.0': [5632.0, 2816.0], '2.09': [5888.0, 2816.0], '2.4': [6144.0, 2560.0], 164 | '2.5': [6400.0, 2560.0], '3.0': [6912.0, 2304.0], 165 | '4.0': [8192.0, 2048.0], 166 | } 167 | 168 | def get_chunks(lst, n): 169 | for i in range(0, len(lst), n): 170 | yield lst[i:i + n] 171 | 172 | def get_closest_ratio(height: float, width: float, ratios: dict): 173 | aspect_ratio = height / width 174 | closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio)) 175 | return ratios[closest_ratio], float(closest_ratio) 176 | -------------------------------------------------------------------------------- /onediffusion/diffusion/pipelines/image_processor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
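# This module provides the image pre-/post-processing used by the OneDiffusion pipeline: a VaeImageProcessor-style class with PIL/NumPy/torch conversions, mask and crop helpers, and center-crop-resize preprocessing built on CenterCropResizeImage.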
14 | 15 | import math 16 | import warnings 17 | from typing import List, Optional, Tuple, Union 18 | 19 | import numpy as np 20 | import PIL.Image 21 | import torch 22 | import torch.nn.functional as F 23 | import torchvision.transforms as T 24 | from PIL import Image, ImageFilter, ImageOps 25 | 26 | from diffusers.configuration_utils import ConfigMixin, register_to_config 27 | from diffusers.utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate 28 | 29 | from onediffusion.dataset.transforms import CenterCropResizeImage 30 | 31 | PipelineImageInput = Union[ 32 | PIL.Image.Image, 33 | np.ndarray, 34 | torch.Tensor, 35 | List[PIL.Image.Image], 36 | List[np.ndarray], 37 | List[torch.Tensor], 38 | ] 39 | 40 | PipelineDepthInput = PipelineImageInput 41 | 42 | 43 | def is_valid_image(image): 44 | return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) 45 | 46 | 47 | def is_valid_image_imagelist(images): 48 | # check if the image input is one of the supported formats for image and image list: 49 | # it can be either one of below 3 50 | # (1) a 4d pytorch tensor or numpy array, 51 | # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor 52 | # (3) a list of valid image 53 | if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: 54 | return True 55 | elif is_valid_image(images): 56 | return True 57 | elif isinstance(images, list): 58 | return all(is_valid_image(image) for image in images) 59 | return False 60 | 61 | 62 | class VaeImageProcessorOneDiffuser(ConfigMixin): 63 | """ 64 | Image processor for VAE. 65 | 66 | Args: 67 | do_resize (`bool`, *optional*, defaults to `True`): 68 | Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept 69 | `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method. 70 | vae_scale_factor (`int`, *optional*, defaults to `8`): 71 | VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor. 72 | resample (`str`, *optional*, defaults to `lanczos`): 73 | Resampling filter to use when resizing the image. 74 | do_normalize (`bool`, *optional*, defaults to `True`): 75 | Whether to normalize the image to [-1,1]. 76 | do_binarize (`bool`, *optional*, defaults to `False`): 77 | Whether to binarize the image to 0/1. 78 | do_convert_rgb (`bool`, *optional*, defaults to be `False`): 79 | Whether to convert the images to RGB format. 80 | do_convert_grayscale (`bool`, *optional*, defaults to be `False`): 81 | Whether to convert the images to grayscale format. 
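Example (a minimal sketch; `pil_image` is any PIL image you supply): after `processor = VaeImageProcessorOneDiffuser(vae_scale_factor=8)`, calling `processor.preprocess(pil_image, height=512, width=512, do_crop=True)` returns a `(1, 3, 512, 512)` tensor with values in `[-1, 1]`, and `processor.postprocess(tensor, output_type="pil")` converts such a tensor back to a list of PIL images.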
82 | """ 83 | 84 | config_name = CONFIG_NAME 85 | 86 | @register_to_config 87 | def __init__( 88 | self, 89 | do_resize: bool = True, 90 | vae_scale_factor: int = 8, 91 | vae_latent_channels: int = 4, 92 | resample: str = "lanczos", 93 | do_normalize: bool = True, 94 | do_binarize: bool = False, 95 | do_convert_rgb: bool = False, 96 | do_convert_grayscale: bool = False, 97 | ): 98 | super().__init__() 99 | if do_convert_rgb and do_convert_grayscale: 100 | raise ValueError( 101 | "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`," 102 | " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.", 103 | " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`", 104 | ) 105 | 106 | @staticmethod 107 | def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]: 108 | """ 109 | Convert a numpy image or a batch of images to a PIL image. 110 | """ 111 | if images.ndim == 3: 112 | images = images[None, ...] 113 | images = (images * 255).round().astype("uint8") 114 | if images.shape[-1] == 1: 115 | # special case for grayscale (single channel) images 116 | pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] 117 | else: 118 | pil_images = [Image.fromarray(image) for image in images] 119 | 120 | return pil_images 121 | 122 | @staticmethod 123 | def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray: 124 | """ 125 | Convert a PIL image or a list of PIL images to NumPy arrays. 126 | """ 127 | if not isinstance(images, list): 128 | images = [images] 129 | images = [np.array(image).astype(np.float32) / 255.0 for image in images] 130 | images = np.stack(images, axis=0) 131 | 132 | return images 133 | 134 | @staticmethod 135 | def numpy_to_pt(images: np.ndarray) -> torch.Tensor: 136 | """ 137 | Convert a NumPy image to a PyTorch tensor. 138 | """ 139 | if images.ndim == 3: 140 | images = images[..., None] 141 | 142 | images = torch.from_numpy(images.transpose(0, 3, 1, 2)) 143 | return images 144 | 145 | @staticmethod 146 | def pt_to_numpy(images: torch.Tensor) -> np.ndarray: 147 | """ 148 | Convert a PyTorch tensor to a NumPy image. 149 | """ 150 | images = images.cpu().permute(0, 2, 3, 1).float().numpy() 151 | return images 152 | 153 | @staticmethod 154 | def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: 155 | """ 156 | Normalize an image array to [-1,1]. 157 | """ 158 | return 2.0 * images - 1.0 159 | 160 | @staticmethod 161 | def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: 162 | """ 163 | Denormalize an image array to [0,1]. 164 | """ 165 | return (images / 2 + 0.5).clamp(0, 1) 166 | 167 | @staticmethod 168 | def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: 169 | """ 170 | Converts a PIL image to RGB format. 171 | """ 172 | image = image.convert("RGB") 173 | 174 | return image 175 | 176 | @staticmethod 177 | def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image: 178 | """ 179 | Converts a PIL image to grayscale format. 180 | """ 181 | image = image.convert("L") 182 | 183 | return image 184 | 185 | @staticmethod 186 | def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image: 187 | """ 188 | Applies Gaussian blur to an image. 
189 | """ 190 | image = image.filter(ImageFilter.GaussianBlur(blur_factor)) 191 | 192 | return image 193 | 194 | @staticmethod 195 | def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0): 196 | """ 197 | Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect 198 | ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for 199 | processing are 512x512, the region will be expanded to 128x128. 200 | 201 | Args: 202 | mask_image (PIL.Image.Image): Mask image. 203 | width (int): Width of the image to be processed. 204 | height (int): Height of the image to be processed. 205 | pad (int, optional): Padding to be added to the crop region. Defaults to 0. 206 | 207 | Returns: 208 | tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and 209 | matches the original aspect ratio. 210 | """ 211 | 212 | mask_image = mask_image.convert("L") 213 | mask = np.array(mask_image) 214 | 215 | # 1. find a rectangular region that contains all masked ares in an image 216 | h, w = mask.shape 217 | crop_left = 0 218 | for i in range(w): 219 | if not (mask[:, i] == 0).all(): 220 | break 221 | crop_left += 1 222 | 223 | crop_right = 0 224 | for i in reversed(range(w)): 225 | if not (mask[:, i] == 0).all(): 226 | break 227 | crop_right += 1 228 | 229 | crop_top = 0 230 | for i in range(h): 231 | if not (mask[i] == 0).all(): 232 | break 233 | crop_top += 1 234 | 235 | crop_bottom = 0 236 | for i in reversed(range(h)): 237 | if not (mask[i] == 0).all(): 238 | break 239 | crop_bottom += 1 240 | 241 | # 2. add padding to the crop region 242 | x1, y1, x2, y2 = ( 243 | int(max(crop_left - pad, 0)), 244 | int(max(crop_top - pad, 0)), 245 | int(min(w - crop_right + pad, w)), 246 | int(min(h - crop_bottom + pad, h)), 247 | ) 248 | 249 | # 3. expands crop region to match the aspect ratio of the image to be processed 250 | ratio_crop_region = (x2 - x1) / (y2 - y1) 251 | ratio_processing = width / height 252 | 253 | if ratio_crop_region > ratio_processing: 254 | desired_height = (x2 - x1) / ratio_processing 255 | desired_height_diff = int(desired_height - (y2 - y1)) 256 | y1 -= desired_height_diff // 2 257 | y2 += desired_height_diff - desired_height_diff // 2 258 | if y2 >= mask_image.height: 259 | diff = y2 - mask_image.height 260 | y2 -= diff 261 | y1 -= diff 262 | if y1 < 0: 263 | y2 -= y1 264 | y1 -= y1 265 | if y2 >= mask_image.height: 266 | y2 = mask_image.height 267 | else: 268 | desired_width = (y2 - y1) * ratio_processing 269 | desired_width_diff = int(desired_width - (x2 - x1)) 270 | x1 -= desired_width_diff // 2 271 | x2 += desired_width_diff - desired_width_diff // 2 272 | if x2 >= mask_image.width: 273 | diff = x2 - mask_image.width 274 | x2 -= diff 275 | x1 -= diff 276 | if x1 < 0: 277 | x2 -= x1 278 | x1 -= x1 279 | if x2 >= mask_image.width: 280 | x2 = mask_image.width 281 | 282 | return x1, y1, x2, y2 283 | 284 | def _resize_and_fill( 285 | self, 286 | image: PIL.Image.Image, 287 | width: int, 288 | height: int, 289 | ) -> PIL.Image.Image: 290 | """ 291 | Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center 292 | the image within the dimensions, filling empty with data from image. 293 | 294 | Args: 295 | image: The image to resize. 296 | width: The width to resize the image to. 297 | height: The height to resize the image to. 
298 | """ 299 | 300 | ratio = width / height 301 | src_ratio = image.width / image.height 302 | 303 | src_w = width if ratio < src_ratio else image.width * height // image.height 304 | src_h = height if ratio >= src_ratio else image.height * width // image.width 305 | 306 | resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"]) 307 | res = Image.new("RGB", (width, height)) 308 | res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) 309 | 310 | if ratio < src_ratio: 311 | fill_height = height // 2 - src_h // 2 312 | if fill_height > 0: 313 | res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) 314 | res.paste( 315 | resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), 316 | box=(0, fill_height + src_h), 317 | ) 318 | elif ratio > src_ratio: 319 | fill_width = width // 2 - src_w // 2 320 | if fill_width > 0: 321 | res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) 322 | res.paste( 323 | resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), 324 | box=(fill_width + src_w, 0), 325 | ) 326 | 327 | return res 328 | 329 | def _resize_and_crop( 330 | self, 331 | image: PIL.Image.Image, 332 | width: int, 333 | height: int, 334 | ) -> PIL.Image.Image: 335 | """ 336 | Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center 337 | the image within the dimensions, cropping the excess. 338 | 339 | Args: 340 | image: The image to resize. 341 | width: The width to resize the image to. 342 | height: The height to resize the image to. 343 | """ 344 | ratio = width / height 345 | src_ratio = image.width / image.height 346 | 347 | src_w = width if ratio > src_ratio else image.width * height // image.height 348 | src_h = height if ratio <= src_ratio else image.height * width // image.width 349 | 350 | resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"]) 351 | res = Image.new("RGB", (width, height)) 352 | res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) 353 | return res 354 | 355 | def resize( 356 | self, 357 | image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], 358 | height: int, 359 | width: int, 360 | resize_mode: str = "default", # "default", "fill", "crop" 361 | ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: 362 | """ 363 | Resize image. 364 | 365 | Args: 366 | image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`): 367 | The image input, can be a PIL image, numpy array or pytorch tensor. 368 | height (`int`): 369 | The height to resize to. 370 | width (`int`): 371 | The width to resize to. 372 | resize_mode (`str`, *optional*, defaults to `default`): 373 | The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit 374 | within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, 375 | will resize the image to fit within the specified width and height, maintaining the aspect ratio, and 376 | then center the image within the dimensions, filling empty with data from image. If `crop`, will resize 377 | the image to fit within the specified width and height, maintaining the aspect ratio, and then center 378 | the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only 379 | supported for PIL image input. 
380 | 381 | Returns: 382 | `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: 383 | The resized image. 384 | """ 385 | if resize_mode != "default" and not isinstance(image, PIL.Image.Image): 386 | raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}") 387 | if isinstance(image, PIL.Image.Image): 388 | if resize_mode == "default": 389 | image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample]) 390 | elif resize_mode == "fill": 391 | image = self._resize_and_fill(image, width, height) 392 | elif resize_mode == "crop": 393 | image = self._resize_and_crop(image, width, height) 394 | else: 395 | raise ValueError(f"resize_mode {resize_mode} is not supported") 396 | 397 | elif isinstance(image, torch.Tensor): 398 | image = torch.nn.functional.interpolate( 399 | image, 400 | size=(height, width), 401 | ) 402 | elif isinstance(image, np.ndarray): 403 | image = self.numpy_to_pt(image) 404 | image = torch.nn.functional.interpolate( 405 | image, 406 | size=(height, width), 407 | ) 408 | image = self.pt_to_numpy(image) 409 | return image 410 | 411 | def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image: 412 | """ 413 | Create a mask. 414 | 415 | Args: 416 | image (`PIL.Image.Image`): 417 | The image input, should be a PIL image. 418 | 419 | Returns: 420 | `PIL.Image.Image`: 421 | The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1. 422 | """ 423 | image[image < 0.5] = 0 424 | image[image >= 0.5] = 1 425 | 426 | return image 427 | 428 | def get_default_height_width( 429 | self, 430 | image: Union[PIL.Image.Image, np.ndarray, torch.Tensor], 431 | height: Optional[int] = None, 432 | width: Optional[int] = None, 433 | ) -> Tuple[int, int]: 434 | """ 435 | This function return the height and width that are downscaled to the next integer multiple of 436 | `vae_scale_factor`. 437 | 438 | Args: 439 | image(`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`): 440 | The image input, can be a PIL image, numpy array or pytorch tensor. if it is a numpy array, should have 441 | shape `[batch, height, width]` or `[batch, height, width, channel]` if it is a pytorch tensor, should 442 | have shape `[batch, channel, height, width]`. 443 | height (`int`, *optional*, defaults to `None`): 444 | The height in preprocessed image. If `None`, will use the height of `image` input. 445 | width (`int`, *optional*`, defaults to `None`): 446 | The width in preprocessed. If `None`, will use the width of the `image` input. 447 | """ 448 | 449 | if height is None: 450 | if isinstance(image, PIL.Image.Image): 451 | height = image.height 452 | elif isinstance(image, torch.Tensor): 453 | height = image.shape[2] 454 | else: 455 | height = image.shape[1] 456 | 457 | if width is None: 458 | if isinstance(image, PIL.Image.Image): 459 | width = image.width 460 | elif isinstance(image, torch.Tensor): 461 | width = image.shape[3] 462 | else: 463 | width = image.shape[2] 464 | 465 | width, height = ( 466 | x - x % self.config.vae_scale_factor for x in (width, height) 467 | ) # resize to integer multiple of vae_scale_factor 468 | 469 | return height, width 470 | 471 | def preprocess( 472 | self, 473 | image: PipelineImageInput, 474 | height: Optional[int] = None, 475 | width: Optional[int] = None, 476 | resize_mode: str = "default", # "default", "fill", "crop" 477 | crops_coords: Optional[Tuple[int, int, int, int]] = None, 478 | do_crop: bool = True, 479 | ) -> torch.Tensor: 480 | """ 481 | Preprocess the image input. 
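Inputs are converted to RGB, resized (and center-cropped when `do_crop` is set) to `(height, width)`, stacked along the batch dimension, and returned as a `torch.Tensor` of shape `(batch, 3, height, width)` with values in `[-1, 1]`.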
482 | 483 | Args: 484 | image (`pipeline_image_input`): 485 | The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of 486 | supported formats. 487 | height (`int`, *optional*, defaults to `None`): 488 | The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default 489 | height. 490 | width (`int`, *optional*`, defaults to `None`): 491 | The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width. 492 | resize_mode (`str`, *optional*, defaults to `default`): 493 | The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within 494 | the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will 495 | resize the image to fit within the specified width and height, maintaining the aspect ratio, and then 496 | center the image within the dimensions, filling empty with data from image. If `crop`, will resize the 497 | image to fit within the specified width and height, maintaining the aspect ratio, and then center the 498 | image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only 499 | supported for PIL image input. 500 | crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`): 501 | The crop coordinates for each image in the batch. If `None`, will not crop the image. 502 | """ 503 | supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor) 504 | 505 | # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image 506 | if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3: 507 | if isinstance(image, torch.Tensor): 508 | # if image is a pytorch tensor could have 2 possible shapes: 509 | # 1. batch x height x width: we should insert the channel dimension at position 1 510 | # 2. channel x height x width: we should insert batch dimension at position 0, 511 | # however, since both channel and batch dimension has same size 1, it is same to insert at position 1 512 | # for simplicity, we insert a dimension of size 1 at position 1 for both cases 513 | image = image.unsqueeze(1) 514 | else: 515 | # if it is a numpy array, it could have 2 possible shapes: 516 | # 1. batch x height x width: insert channel dimension on last position 517 | # 2. height x width x channel: insert batch dimension on first position 518 | if image.shape[-1] == 1: 519 | image = np.expand_dims(image, axis=0) 520 | else: 521 | image = np.expand_dims(image, axis=-1) 522 | 523 | if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4: 524 | warnings.warn( 525 | "Passing `image` as a list of 4d np.ndarray is deprecated." 526 | "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray", 527 | FutureWarning, 528 | ) 529 | image = np.concatenate(image, axis=0) 530 | if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4: 531 | warnings.warn( 532 | "Passing `image` as a list of 4d torch.Tensor is deprecated." 533 | "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor", 534 | FutureWarning, 535 | ) 536 | image = torch.cat(image, axis=0) 537 | 538 | if not is_valid_image_imagelist(image): 539 | raise ValueError( 540 | f"Input is in incorrect format. 
Currently, we only support {', '.join(str(x) for x in supported_formats)}" 541 | ) 542 | if not isinstance(image, list): 543 | image = [image] 544 | 545 | if isinstance(image[0], PIL.Image.Image): 546 | pass 547 | elif isinstance(image[0], np.ndarray): 548 | image = self.numpy_to_pil(image) 549 | elif isinstance(image[0], torch.Tensor): 550 | image = self.pt_to_numpy(image) 551 | image = self.numpy_to_pil(image) 552 | 553 | if do_crop: 554 | transforms = T.Compose([ 555 | T.Lambda(lambda image: image.convert('RGB')), 556 | T.ToTensor(), 557 | CenterCropResizeImage((height, width)), 558 | T.Normalize([.5], [.5]), 559 | ]) 560 | else: 561 | transforms = T.Compose([ 562 | T.Lambda(lambda image: image.convert('RGB')), 563 | T.ToTensor(), 564 | T.Resize((height, width)), 565 | T.Normalize([.5], [.5]), 566 | ]) 567 | image = torch.stack([transforms(i) for i in image]) 568 | 569 | # expected range [0,1], normalize to [-1,1] 570 | do_normalize = self.config.do_normalize 571 | if do_normalize and image.min() < 0: 572 | warnings.warn( 573 | "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " 574 | f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", 575 | FutureWarning, 576 | ) 577 | do_normalize = False 578 | if do_normalize: 579 | image = self.normalize(image) 580 | 581 | if self.config.do_binarize: 582 | image = self.binarize(image) 583 | 584 | return image 585 | 586 | def postprocess( 587 | self, 588 | image: torch.Tensor, 589 | output_type: str = "pil", 590 | do_denormalize: Optional[List[bool]] = None, 591 | ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: 592 | """ 593 | Postprocess the image output from tensor to `output_type`. 594 | 595 | Args: 596 | image (`torch.Tensor`): 597 | The image input, should be a pytorch tensor with shape `B x C x H x W`. 598 | output_type (`str`, *optional*, defaults to `pil`): 599 | The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. 600 | do_denormalize (`List[bool]`, *optional*, defaults to `None`): 601 | Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the 602 | `VaeImageProcessor` config. 603 | 604 | Returns: 605 | `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: 606 | The postprocessed image. 607 | """ 608 | if not isinstance(image, torch.Tensor): 609 | raise ValueError( 610 | f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" 611 | ) 612 | if output_type not in ["latent", "pt", "np", "pil"]: 613 | deprecation_message = ( 614 | f"the output_type {output_type} is outdated and has been set to `np`. 
Please make sure to set it to one of these instead: " 615 | "`pil`, `np`, `pt`, `latent`" 616 | ) 617 | deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) 618 | output_type = "np" 619 | 620 | if output_type == "latent": 621 | return image 622 | 623 | if do_denormalize is None: 624 | do_denormalize = [self.config.do_normalize] * image.shape[0] 625 | 626 | image = torch.stack( 627 | [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])] 628 | ) 629 | 630 | if output_type == "pt": 631 | return image 632 | 633 | image = self.pt_to_numpy(image) 634 | 635 | if output_type == "np": 636 | return image 637 | 638 | if output_type == "pil": 639 | return self.numpy_to_pil(image) 640 | 641 | def apply_overlay( 642 | self, 643 | mask: PIL.Image.Image, 644 | init_image: PIL.Image.Image, 645 | image: PIL.Image.Image, 646 | crop_coords: Optional[Tuple[int, int, int, int]] = None, 647 | ) -> PIL.Image.Image: 648 | """ 649 | overlay the inpaint output to the original image 650 | """ 651 | 652 | width, height = image.width, image.height 653 | 654 | init_image = self.resize(init_image, width=width, height=height) 655 | mask = self.resize(mask, width=width, height=height) 656 | 657 | init_image_masked = PIL.Image.new("RGBa", (width, height)) 658 | init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L"))) 659 | init_image_masked = init_image_masked.convert("RGBA") 660 | 661 | if crop_coords is not None: 662 | x, y, x2, y2 = crop_coords 663 | w = x2 - x 664 | h = y2 - y 665 | base_image = PIL.Image.new("RGBA", (width, height)) 666 | image = self.resize(image, height=h, width=w, resize_mode="crop") 667 | base_image.paste(image, (x, y)) 668 | image = base_image.convert("RGB") 669 | 670 | image = image.convert("RGBA") 671 | image.alpha_composite(init_image_masked) 672 | image = image.convert("RGB") 673 | 674 | return image 675 | -------------------------------------------------------------------------------- /onediffusion/models/denoiser/__init__.py: -------------------------------------------------------------------------------- 1 | from . import ( 2 | nextdit 3 | ) -------------------------------------------------------------------------------- /onediffusion/models/denoiser/nextdit/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_nextdit import NextDiT -------------------------------------------------------------------------------- /onediffusion/models/denoiser/nextdit/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from typing import Callable, Optional 6 | 7 | import warnings 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | try: 13 | from apex.normalization import FusedRMSNorm as RMSNorm 14 | except ImportError: 15 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") 16 | 17 | 18 | class RMSNorm(torch.nn.Module): 19 | def __init__(self, dim: int, eps: float = 1e-6): 20 | """ 21 | Initialize the RMSNorm normalization layer. 22 | Args: 23 | dim (int): The dimension of the input tensor. 24 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 25 | Attributes: 26 | eps (float): A small value added to the denominator for numerical stability. 
27 | weight (nn.Parameter): Learnable scaling parameter. 28 | """ 29 | super().__init__() 30 | self.eps = eps 31 | self.weight = nn.Parameter(torch.ones(dim)) 32 | 33 | def _norm(self, x): 34 | """ 35 | Apply the RMSNorm normalization to the input tensor. 36 | Args: 37 | x (torch.Tensor): The input tensor. 38 | Returns: 39 | torch.Tensor: The normalized tensor. 40 | """ 41 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 42 | 43 | def forward(self, x): 44 | """ 45 | Forward pass through the RMSNorm layer. 46 | Args: 47 | x (torch.Tensor): The input tensor. 48 | Returns: 49 | torch.Tensor: The output tensor after applying RMSNorm. 50 | """ 51 | output = self._norm(x.float()).type_as(x) 52 | return output * self.weight 53 | 54 | 55 | def modulate(x, scale): 56 | return x * (1 + scale.unsqueeze(1)) 57 | 58 | class LLamaFeedForward(nn.Module): 59 | """ 60 | Corresponds to the FeedForward layer in Next DiT. 61 | """ 62 | def __init__( 63 | self, 64 | dim: int, 65 | hidden_dim: int, 66 | multiple_of: int, 67 | ffn_dim_multiplier: Optional[float] = None, 68 | zeros_initialize: bool = True, 69 | dtype: torch.dtype = torch.float32, 70 | ): 71 | super().__init__() 72 | self.dim = dim 73 | self.hidden_dim = hidden_dim 74 | self.multiple_of = multiple_of 75 | self.ffn_dim_multiplier = ffn_dim_multiplier 76 | self.zeros_initialize = zeros_initialize 77 | self.dtype = dtype 78 | 79 | # Compute hidden_dim based on the given formula 80 | hidden_dim_calculated = int(2 * self.hidden_dim / 3) 81 | if self.ffn_dim_multiplier is not None: 82 | hidden_dim_calculated = int(self.ffn_dim_multiplier * hidden_dim_calculated) 83 | hidden_dim_calculated = self.multiple_of * ((hidden_dim_calculated + self.multiple_of - 1) // self.multiple_of) 84 | 85 | # Define linear layers 86 | self.w1 = nn.Linear(self.dim, hidden_dim_calculated, bias=False) 87 | self.w2 = nn.Linear(hidden_dim_calculated, self.dim, bias=False) 88 | self.w3 = nn.Linear(self.dim, hidden_dim_calculated, bias=False) 89 | 90 | # Initialize weights 91 | if self.zeros_initialize: 92 | nn.init.zeros_(self.w2.weight) 93 | else: 94 | nn.init.xavier_uniform_(self.w2.weight) 95 | nn.init.xavier_uniform_(self.w1.weight) 96 | nn.init.xavier_uniform_(self.w3.weight) 97 | 98 | def _forward_silu_gating(self, x1, x3): 99 | return F.silu(x1) * x3 100 | 101 | def forward(self, x): 102 | return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) 103 | 104 | class FinalLayer(nn.Module): 105 | """ 106 | The final layer of Next-DiT. 
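Shape sketch (illustrative values): with hidden_size=1152, patch_size=4 and out_channels=4, the linear head has np.prod(patch_size) * out_channels = 16 outputs, so the forward pass maps x of shape (B, L, 1152) and conditioning c of shape (B, 1152) to (B, L, 16) via LayerNorm, adaLN scale modulation and a zero-initialized linear projection.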
107 | """ 108 | def __init__(self, hidden_size: int, patch_size: int, out_channels: int): 109 | super().__init__() 110 | self.hidden_size = hidden_size 111 | self.patch_size = patch_size 112 | self.out_channels = out_channels 113 | 114 | # LayerNorm without learnable parameters (elementwise_affine=False) 115 | self.norm_final = nn.LayerNorm(self.hidden_size, eps=1e-6, elementwise_affine=False) 116 | self.linear = nn.Linear(self.hidden_size, np.prod(self.patch_size) * self.out_channels, bias=True) 117 | nn.init.zeros_(self.linear.weight) 118 | nn.init.zeros_(self.linear.bias) 119 | 120 | self.adaLN_modulation = nn.Sequential( 121 | nn.SiLU(), 122 | nn.Linear(self.hidden_size, self.hidden_size), 123 | ) 124 | # Initialize the last layer with zeros 125 | nn.init.zeros_(self.adaLN_modulation[1].weight) 126 | nn.init.zeros_(self.adaLN_modulation[1].bias) 127 | 128 | def forward(self, x, c): 129 | scale = self.adaLN_modulation(c) 130 | x = modulate(self.norm_final(x), scale) 131 | x = self.linear(x) 132 | return x -------------------------------------------------------------------------------- /onediffusion/models/denoiser/nextdit/modeling_nextdit.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import einops 7 | from diffusers.configuration_utils import ConfigMixin, register_to_config 8 | from diffusers.models.modeling_utils import ModelMixin 9 | from typing import Any, Tuple, Optional 10 | from flash_attn import flash_attn_varlen_func 11 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 12 | 13 | from .layers import LLamaFeedForward, RMSNorm 14 | 15 | # import frasch 16 | 17 | 18 | def modulate(x, scale): 19 | return x * (1 + scale) 20 | 21 | class TimestepEmbedder(nn.Module): 22 | """ 23 | Embeds scalar timesteps into vector representations. 24 | """ 25 | def __init__(self, hidden_size, frequency_embedding_size=256): 26 | super().__init__() 27 | self.hidden_size = hidden_size 28 | self.frequency_embedding_size = frequency_embedding_size 29 | self.mlp = nn.Sequential( 30 | nn.Linear(self.frequency_embedding_size, self.hidden_size), 31 | nn.SiLU(), 32 | nn.Linear(self.hidden_size, self.hidden_size), 33 | ) 34 | 35 | @staticmethod 36 | def timestep_embedding(t, dim, max_period=10000): 37 | """ 38 | Create sinusoidal timestep embeddings. 39 | :param t: a 1-D Tensor of N indices, one per batch element. 40 | :param dim: the dimension of the output. 41 | :param max_period: controls the minimum frequency of the embeddings. 42 | :return: an (N, D) Tensor of positional embeddings. 
43 | """ 44 | half = dim // 2 45 | freqs = torch.exp( 46 | -np.log(max_period) * torch.arange(0, half, dtype=t.dtype) / half 47 | ).to(t.device) 48 | args = t[:, :, None] * freqs[None, :] 49 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 50 | if dim % 2: 51 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :, :1])], dim=-1) 52 | return embedding 53 | 54 | def forward(self, t): 55 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 56 | t_freq = t_freq.to(self.mlp[0].weight.dtype) 57 | return self.mlp(t_freq) 58 | 59 | class FinalLayer(nn.Module): 60 | def __init__(self, hidden_size, num_patches, out_channels): 61 | super().__init__() 62 | self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False) 63 | self.linear = nn.Linear(hidden_size, num_patches * out_channels) 64 | self.adaLN_modulation = nn.Sequential( 65 | nn.SiLU(), 66 | nn.Linear(min(hidden_size, 1024), hidden_size), 67 | ) 68 | 69 | def forward(self, x, c): 70 | scale = self.adaLN_modulation(c) 71 | x = modulate(self.norm_final(x), scale) 72 | x = self.linear(x) 73 | return x 74 | 75 | class Attention(nn.Module): 76 | def __init__( 77 | self, 78 | dim, 79 | n_heads, 80 | n_kv_heads=None, 81 | qk_norm=False, 82 | y_dim=0, 83 | base_seqlen=None, 84 | proportional_attn=False, 85 | attention_dropout=0.0, 86 | max_position_embeddings=384, 87 | ): 88 | super().__init__() 89 | self.dim = dim 90 | self.n_heads = n_heads 91 | self.n_kv_heads = n_kv_heads or n_heads 92 | self.qk_norm = qk_norm 93 | self.y_dim = y_dim 94 | self.base_seqlen = base_seqlen 95 | self.proportional_attn = proportional_attn 96 | self.attention_dropout = attention_dropout 97 | self.max_position_embeddings = max_position_embeddings 98 | 99 | self.head_dim = dim // n_heads 100 | 101 | self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False) 102 | self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False) 103 | self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False) 104 | 105 | if y_dim > 0: 106 | self.wk_y = nn.Linear(y_dim, self.n_kv_heads * self.head_dim, bias=False) 107 | self.wv_y = nn.Linear(y_dim, self.n_kv_heads * self.head_dim, bias=False) 108 | self.gate = nn.Parameter(torch.zeros(n_heads)) 109 | 110 | self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False) 111 | 112 | if qk_norm: 113 | self.q_norm = nn.LayerNorm(self.n_heads * self.head_dim) 114 | self.k_norm = nn.LayerNorm(self.n_kv_heads * self.head_dim) 115 | if y_dim > 0: 116 | self.ky_norm = nn.LayerNorm(self.n_kv_heads * self.head_dim, eps=1e-6) 117 | else: 118 | self.ky_norm = nn.Identity() 119 | else: 120 | self.q_norm = nn.Identity() 121 | self.k_norm = nn.Identity() 122 | self.ky_norm = nn.Identity() 123 | 124 | 125 | @staticmethod 126 | def apply_rotary_emb(xq, xk, freqs_cis): 127 | # xq, xk: [batch_size, seq_len, n_heads, head_dim] 128 | # freqs_cis: [1, seq_len, 1, head_dim] 129 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2) 130 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2) 131 | 132 | xq_complex = torch.view_as_complex(xq_) 133 | xk_complex = torch.view_as_complex(xk_) 134 | 135 | freqs_cis = freqs_cis.unsqueeze(2) 136 | 137 | # Apply freqs_cis 138 | xq_out = xq_complex * freqs_cis 139 | xk_out = xk_complex * freqs_cis 140 | 141 | # Convert back to real numbers 142 | xq_out = torch.view_as_real(xq_out).flatten(-2) 143 | xk_out = torch.view_as_real(xk_out).flatten(-2) 144 | 145 | return xq_out.type_as(xq), xk_out.type_as(xk) 146 | 147 | # copied from huggingface 
modeling_llama.py 148 | def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): 149 | def _get_unpad_data(attention_mask): 150 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 151 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 152 | max_seqlen_in_batch = seqlens_in_batch.max().item() 153 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 154 | return ( 155 | indices, 156 | cu_seqlens, 157 | max_seqlen_in_batch, 158 | ) 159 | 160 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) 161 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 162 | 163 | key_layer = index_first_axis( 164 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), 165 | indices_k, 166 | ) 167 | value_layer = index_first_axis( 168 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), 169 | indices_k, 170 | ) 171 | if query_length == kv_seq_len: 172 | query_layer = index_first_axis( 173 | query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim), 174 | indices_k, 175 | ) 176 | cu_seqlens_q = cu_seqlens_k 177 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 178 | indices_q = indices_k 179 | elif query_length == 1: 180 | max_seqlen_in_batch_q = 1 181 | cu_seqlens_q = torch.arange( 182 | batch_size + 1, dtype=torch.int32, device=query_layer.device 183 | ) # There is a memcpy here, that is very bad. 184 | indices_q = cu_seqlens_q[:-1] 185 | query_layer = query_layer.squeeze(1) 186 | else: 187 | # The -q_len: slice assumes left padding. 188 | attention_mask = attention_mask[:, -query_length:] 189 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) 190 | 191 | return ( 192 | query_layer, 193 | key_layer, 194 | value_layer, 195 | indices_q, 196 | (cu_seqlens_q, cu_seqlens_k), 197 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 198 | ) 199 | 200 | def forward( 201 | self, 202 | x, 203 | x_mask, 204 | freqs_cis, 205 | y=None, 206 | y_mask=None, 207 | init_cache=False, 208 | ): 209 | bsz, seqlen, _ = x.size() 210 | xq = self.wq(x) 211 | xk = self.wk(x) 212 | xv = self.wv(x) 213 | 214 | if x_mask is None: 215 | x_mask = torch.ones(bsz, seqlen, dtype=torch.bool, device=x.device) 216 | inp_dtype = xq.dtype 217 | 218 | xq = self.q_norm(xq) 219 | xk = self.k_norm(xk) 220 | 221 | xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim) 222 | xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim) 223 | xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim) 224 | 225 | if self.n_kv_heads != self.n_heads: 226 | n_rep = self.n_heads // self.n_kv_heads 227 | xk = xk.repeat_interleave(n_rep, dim=2) 228 | xv = xv.repeat_interleave(n_rep, dim=2) 229 | 230 | freqs_cis = freqs_cis.to(xq.device) 231 | xq, xk = self.apply_rotary_emb(xq, xk, freqs_cis) 232 | 233 | if inp_dtype in [torch.float16, torch.bfloat16]: 234 | # begin var_len flash attn 235 | ( 236 | query_states, 237 | key_states, 238 | value_states, 239 | indices_q, 240 | cu_seq_lens, 241 | max_seq_lens, 242 | ) = self._upad_input(xq, xk, xv, x_mask, seqlen) 243 | 244 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 245 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 246 | 247 | attn_output_unpad = flash_attn_varlen_func( 248 | query_states.to(inp_dtype), 249 | key_states.to(inp_dtype), 250 | value_states.to(inp_dtype), 251 | cu_seqlens_q=cu_seqlens_q, 252 | cu_seqlens_k=cu_seqlens_k, 253 | 
max_seqlen_q=max_seqlen_in_batch_q, 254 | max_seqlen_k=max_seqlen_in_batch_k, 255 | dropout_p=0.0, 256 | causal=False, 257 | softmax_scale=None, 258 | softcap=30, 259 | ) 260 | output = pad_input(attn_output_unpad, indices_q, bsz, seqlen) 261 | else: 262 | output = ( 263 | F.scaled_dot_product_attention( 264 | xq.permute(0, 2, 1, 3), 265 | xk.permute(0, 2, 1, 3), 266 | xv.permute(0, 2, 1, 3), 267 | attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_heads, seqlen, -1), 268 | scale=None, 269 | ) 270 | .permute(0, 2, 1, 3) 271 | .to(inp_dtype) 272 | ) #ok 273 | 274 | 275 | if hasattr(self, "wk_y"): 276 | yk = self.ky_norm(self.wk_y(y)).view(bsz, -1, self.n_kv_heads, self.head_dim) 277 | yv = self.wv_y(y).view(bsz, -1, self.n_kv_heads, self.head_dim) 278 | n_rep = self.n_heads // self.n_kv_heads 279 | # if n_rep >= 1: 280 | # yk = yk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) 281 | # yv = yv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) 282 | if n_rep >= 1: 283 | yk = einops.repeat(yk, "b l h d -> b l (repeat h) d", repeat=n_rep) 284 | yv = einops.repeat(yv, "b l h d -> b l (repeat h) d", repeat=n_rep) 285 | output_y = F.scaled_dot_product_attention( 286 | xq.permute(0, 2, 1, 3), 287 | yk.permute(0, 2, 1, 3), 288 | yv.permute(0, 2, 1, 3), 289 | y_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_heads, seqlen, -1).to(torch.bool), 290 | ).permute(0, 2, 1, 3) 291 | output_y = output_y * self.gate.tanh().view(1, 1, -1, 1) 292 | output = output + output_y 293 | 294 | output = output.flatten(-2) 295 | output = self.wo(output) 296 | 297 | return output.to(inp_dtype) 298 | 299 | class TransformerBlock(nn.Module): 300 | """ 301 | Corresponds to the Transformer block in the JAX code. 302 | """ 303 | def __init__( 304 | self, 305 | dim, 306 | n_heads, 307 | n_kv_heads, 308 | multiple_of, 309 | ffn_dim_multiplier, 310 | norm_eps, 311 | qk_norm, 312 | y_dim, 313 | max_position_embeddings, 314 | ): 315 | super().__init__() 316 | self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, y_dim=y_dim, max_position_embeddings=max_position_embeddings) 317 | self.feed_forward = LLamaFeedForward( 318 | dim=dim, 319 | hidden_dim=4 * dim, 320 | multiple_of=multiple_of, 321 | ffn_dim_multiplier=ffn_dim_multiplier, 322 | ) 323 | self.attention_norm1 = RMSNorm(dim, eps=norm_eps) 324 | self.attention_norm2 = RMSNorm(dim, eps=norm_eps) 325 | self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) 326 | self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) 327 | self.adaLN_modulation = nn.Sequential( 328 | nn.SiLU(), 329 | nn.Linear(min(dim, 1024), 4 * dim), 330 | ) 331 | self.attention_y_norm = RMSNorm(y_dim, eps=norm_eps) 332 | 333 | def forward( 334 | self, 335 | x, 336 | x_mask, 337 | freqs_cis, 338 | y, 339 | y_mask, 340 | adaln_input=None, 341 | ): 342 | if adaln_input is not None: 343 | scales_gates = self.adaLN_modulation(adaln_input) 344 | # TODO: Duong - check the dimension of chunking 345 | # scale_msa, gate_msa, scale_mlp, gate_mlp = scales_gates.chunk(4, dim=-1) 346 | scale_msa, gate_msa, scale_mlp, gate_mlp = scales_gates.chunk(4, dim=-1) 347 | x = x + torch.tanh(gate_msa) * self.attention_norm2( 348 | self.attention( 349 | modulate(self.attention_norm1(x), scale_msa), # ok 350 | x_mask, 351 | freqs_cis, 352 | self.attention_y_norm(y), # ok 353 | y_mask, 354 | ) 355 | ) 356 | x = x + torch.tanh(gate_mlp) * self.ffn_norm2( 357 | self.feed_forward( 358 | modulate(self.ffn_norm1(x), scale_mlp), 359 | ) 360 | ) 361 | else: 362 | x = x + self.attention_norm2( 363 | self.attention( 364 | 
self.attention_norm1(x), 365 | x_mask, 366 | freqs_cis, 367 | self.attention_y_norm(y), 368 | y_mask, 369 | ) 370 | ) 371 | x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) 372 | return x 373 | 374 | 375 | class NextDiT(ModelMixin, ConfigMixin): 376 | """ 377 | Diffusion model with a Transformer backbone for joint image-video training. 378 | """ 379 | @register_to_config 380 | def __init__( 381 | self, 382 | input_size=(1, 32, 32), 383 | patch_size=(1, 2, 2), 384 | in_channels=16, 385 | hidden_size=4096, 386 | depth=32, 387 | num_heads=32, 388 | num_kv_heads=None, 389 | multiple_of=256, 390 | ffn_dim_multiplier=None, 391 | norm_eps=1e-5, 392 | pred_sigma=False, 393 | caption_channels=4096, 394 | qk_norm=False, 395 | norm_type="rms", 396 | model_max_length=120, 397 | rotary_max_length=384, 398 | rotary_max_length_t=None 399 | ): 400 | super().__init__() 401 | self.input_size = input_size 402 | self.patch_size = patch_size 403 | self.in_channels = in_channels 404 | self.hidden_size = hidden_size 405 | self.depth = depth 406 | self.num_heads = num_heads 407 | self.num_kv_heads = num_kv_heads or num_heads 408 | self.multiple_of = multiple_of 409 | self.ffn_dim_multiplier = ffn_dim_multiplier 410 | self.norm_eps = norm_eps 411 | self.pred_sigma = pred_sigma 412 | self.caption_channels = caption_channels 413 | self.qk_norm = qk_norm 414 | self.norm_type = norm_type 415 | self.model_max_length = model_max_length 416 | self.rotary_max_length = rotary_max_length 417 | self.rotary_max_length_t = rotary_max_length_t 418 | self.out_channels = in_channels * 2 if pred_sigma else in_channels 419 | 420 | self.x_embedder = nn.Linear(np.prod(self.patch_size) * in_channels, hidden_size) 421 | 422 | self.t_embedder = TimestepEmbedder(min(hidden_size, 1024)) 423 | self.y_embedder = nn.Sequential( 424 | nn.LayerNorm(caption_channels, eps=1e-6), 425 | nn.Linear(caption_channels, min(hidden_size, 1024)), 426 | ) 427 | 428 | self.layers = nn.ModuleList([ 429 | TransformerBlock( 430 | dim=hidden_size, 431 | n_heads=num_heads, 432 | n_kv_heads=self.num_kv_heads, 433 | multiple_of=multiple_of, 434 | ffn_dim_multiplier=ffn_dim_multiplier, 435 | norm_eps=norm_eps, 436 | qk_norm=qk_norm, 437 | y_dim=caption_channels, 438 | max_position_embeddings=rotary_max_length, 439 | ) 440 | for _ in range(depth) 441 | ]) 442 | 443 | self.final_layer = FinalLayer( 444 | hidden_size=hidden_size, 445 | num_patches=np.prod(patch_size), 446 | out_channels=self.out_channels, 447 | ) 448 | 449 | assert (hidden_size // num_heads) % 6 == 0, "3d rope needs head dim to be divisible by 6" 450 | 451 | self.freqs_cis = self.precompute_freqs_cis( 452 | hidden_size // num_heads, 453 | self.rotary_max_length, 454 | end_t=self.rotary_max_length_t 455 | ) 456 | 457 | def to(self, *args, **kwargs): 458 | self = super().to(*args, **kwargs) 459 | # self.freqs_cis = self.freqs_cis.to(*args, **kwargs) 460 | return self 461 | 462 | @staticmethod 463 | def precompute_freqs_cis( 464 | dim: int, 465 | end: int, 466 | end_t: int = None, 467 | theta: float = 10000.0, 468 | scale_factor: float = 1.0, 469 | scale_watershed: float = 1.0, 470 | timestep: float = 1.0, 471 | ): 472 | if timestep < scale_watershed: 473 | linear_factor = scale_factor 474 | ntk_factor = 1.0 475 | else: 476 | linear_factor = 1.0 477 | ntk_factor = scale_factor 478 | 479 | theta = theta * ntk_factor 480 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 6)[: (dim // 6)] / dim)) / linear_factor 481 | 482 | timestep = torch.arange(end, dtype=torch.float32) 483 | freqs = 
torch.outer(timestep, freqs).float() 484 | freqs_cis = torch.exp(1j * freqs) 485 | 486 | if end_t is not None: 487 | freqs_t = 1.0 / (theta ** (torch.arange(0, dim, 6)[: (dim // 6)] / dim)) / linear_factor 488 | timestep_t = torch.arange(end_t, dtype=torch.float32) 489 | freqs_t = torch.outer(timestep_t, freqs_t).float() 490 | freqs_cis_t = torch.exp(1j * freqs_t) 491 | freqs_cis_t = freqs_cis_t.view(end_t, 1, 1, dim // 6).repeat(1, end, end, 1) 492 | else: 493 | end_t = end 494 | freqs_cis_t = freqs_cis.view(end_t, 1, 1, dim // 6).repeat(1, end, end, 1) 495 | 496 | freqs_cis_h = freqs_cis.view(1, end, 1, dim // 6).repeat(end_t, 1, end, 1) 497 | freqs_cis_w = freqs_cis.view(1, 1, end, dim // 6).repeat(end_t, end, 1, 1) 498 | freqs_cis = torch.cat([freqs_cis_t, freqs_cis_h, freqs_cis_w], dim=-1).view(end_t, end, end, -1) 499 | return freqs_cis 500 | 501 | def forward( 502 | self, 503 | samples, 504 | timesteps, 505 | encoder_hidden_states, 506 | encoder_attention_mask, 507 | scale_factor: float = 1.0, # scale_factor for rotary embedding 508 | scale_watershed: float = 1.0, # scale_watershed for rotary embedding 509 | ): 510 | if samples.ndim == 4: # B C H W 511 | samples = samples[:, None, ...] # B F C H W 512 | 513 | precomputed_freqs_cis = None 514 | if scale_factor != 1 or scale_watershed != 1: 515 | precomputed_freqs_cis = self.precompute_freqs_cis( 516 | self.hidden_size // self.num_heads, 517 | self.rotary_max_length, 518 | end_t=self.rotary_max_length_t, 519 | scale_factor=scale_factor, 520 | scale_watershed=scale_watershed, 521 | timestep=torch.max(timesteps.cpu()).item() 522 | ) 523 | 524 | if len(timesteps.shape) == 5: 525 | t, *_ = self.patchify(timesteps, precomputed_freqs_cis) 526 | timesteps = t.mean(dim=-1) 527 | elif len(timesteps.shape) == 1: 528 | timesteps = timesteps[:, None, None, None, None].expand_as(samples) 529 | t, *_ = self.patchify(timesteps, precomputed_freqs_cis) 530 | timesteps = t.mean(dim=-1) 531 | samples, T, H, W, freqs_cis = self.patchify(samples, precomputed_freqs_cis) 532 | samples = self.x_embedder(samples) 533 | t = self.t_embedder(timesteps) 534 | 535 | encoder_attention_mask_float = encoder_attention_mask[..., None].float() 536 | encoder_hidden_states_pool = (encoder_hidden_states * encoder_attention_mask_float).sum(dim=1) / (encoder_attention_mask_float.sum(dim=1) + 1e-8) 537 | encoder_hidden_states_pool = encoder_hidden_states_pool.to(samples.dtype) 538 | y = self.y_embedder(encoder_hidden_states_pool) 539 | y = y.unsqueeze(1).expand(-1, samples.size(1), -1) 540 | 541 | adaln_input = t + y 542 | 543 | for block in self.layers: 544 | samples = block(samples, None, freqs_cis, encoder_hidden_states, encoder_attention_mask, adaln_input) 545 | 546 | samples = self.final_layer(samples, adaln_input) 547 | samples = self.unpatchify(samples, T, H, W) 548 | 549 | return samples 550 | 551 | def patchify(self, x, precompute_freqs_cis=None): 552 | # pytorch is C, H, W 553 | B, T, C, H, W = x.size() 554 | pT, pH, pW = self.patch_size 555 | x = x.view(B, T // pT, pT, C, H // pH, pH, W // pW, pW) 556 | x = x.permute(0, 1, 4, 6, 2, 5, 7, 3) 557 | x = x.reshape(B, -1, pT * pH * pW * C) 558 | if precompute_freqs_cis is None: 559 | freqs_cis = self.freqs_cis[: T // pT, :H // pH, :W // pW].reshape(-1, * self.freqs_cis.shape[3:])[None].to(x.device) 560 | else: 561 | freqs_cis = precompute_freqs_cis[: T // pT, :H // pH, :W // pW].reshape(-1, * precompute_freqs_cis.shape[3:])[None].to(x.device) 562 | return x, T // pT, H // pH, W // pW, freqs_cis 563 | 564 | def 
unpatchify(self, x, T, H, W): 565 | B = x.size(0) 566 | C = self.out_channels 567 | pT, pH, pW = self.patch_size 568 | x = x.view(B, T, H, W, pT, pH, pW, C) 569 | x = x.permute(0, 1, 4, 7, 2, 5, 3, 6) 570 | x = x.reshape(B, T * pT, C, H * pH, W * pW) 571 | return x 572 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | matplotlib 3 | scikit-learn 4 | scipy 5 | numpy 6 | einops 7 | einsum 8 | fvcore 9 | h5py 10 | twine 11 | sentencepiece 12 | tensorflow-cpu==2.11.0 13 | protobuf==3.19.6 14 | transformers==4.45.2 15 | huggingface_hub==0.24 16 | accelerate==0.34.2 17 | diffusers==0.30.3 18 | pillow==10.2.0 19 | torch==2.3.1 20 | torchvision==0.18.1 21 | torchaudio==2.3.1 22 | flash-attn==2.6.3 23 | git+https://github.com/Dao-AILab/flash-attention@v2.6.3#subdirectory=csrc/fused_dense_lib 24 | jaxtyping 25 | mediapipe 26 | gradio 27 | git+https://github.com/facebookresearch/pytorch3d.git 28 | opencv-python==4.5.5.64 29 | opencv-python-headless==4.5.5.64 30 | bitsandbytes==0.45.0 --------------------------------------------------------------------------------
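The SwiGLU width computed in LLamaFeedForward (layers.py) can be checked by hand. The sketch below is not part of the repository; it simply replays the formula with the values TransformerBlock actually passes (hidden_dim = 4 * dim, multiple_of from the model config, ffn_dim_multiplier = None) for an illustrative dim of 4096.

    dim = 4096                                # illustrative transformer width
    hidden_dim = 4 * dim                      # TransformerBlock passes hidden_dim = 4 * dim
    multiple_of = 256                         # NextDiT config default
    ffn_dim_multiplier = None                 # config default

    hidden = int(2 * hidden_dim / 3)          # 10922: SwiGLU keeps roughly 2/3 of 4 * dim
    if ffn_dim_multiplier is not None:
        hidden = int(ffn_dim_multiplier * hidden)
    hidden = multiple_of * ((hidden + multiple_of - 1) // multiple_of)
    print(hidden)                             # 11008, rounded up to a multiple of 256

The rounding step is why w1, w2 and w3 end up with the familiar LLaMA-style intermediate size rather than exactly two thirds of 4 * dim.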
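precompute_freqs_cis builds one rotary table per spatio-temporal axis, which is why NextDiT.__init__ asserts that the head dimension is divisible by 6. A quick shape check follows, assuming the package is importable from the repo root and flash-attn is installed (modeling_nextdit.py imports it at load time); head_dim = 96 and end = 16 are hypothetical values chosen only to satisfy the assert and keep the table small.

    from onediffusion.models.denoiser.nextdit.modeling_nextdit import NextDiT

    head_dim, end = 96, 16                    # hypothetical small sizes
    freqs_cis = NextDiT.precompute_freqs_cis(head_dim, end)
    # Each axis (T, H, W) contributes head_dim // 6 complex rotations; concatenated
    # they give head_dim // 2 rotations per position, matching Attention.apply_rotary_emb,
    # which pairs up the last dimension of the queries and keys.
    print(freqs_cis.shape, freqs_cis.dtype)   # torch.Size([16, 16, 16, 48]) torch.complex64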
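Finally, a minimal smoke test of the full denoiser, again as a sketch rather than the released configuration: the sizes below are hypothetical and chosen so that head_dim (= hidden_size / num_heads) stays divisible by 6 and everything fits on CPU. With float32 inputs, Attention.forward takes the scaled_dot_product_attention branch, so flash-attn only needs to be installed for the import, not executed.

    import torch
    from onediffusion.models.denoiser.nextdit.modeling_nextdit import NextDiT

    model = NextDiT(
        input_size=(1, 32, 32),
        patch_size=(1, 2, 2),
        in_channels=4,
        hidden_size=384,       # head_dim = 384 / 4 = 96, divisible by 6 for 3D RoPE
        depth=2,
        num_heads=4,
        multiple_of=32,
        caption_channels=64,
        rotary_max_length=32,
    )

    B, L_text = 2, 8
    samples = torch.randn(B, 4, 32, 32)                 # promoted to B x F x C x H x W inside forward
    timesteps = torch.rand(B)
    encoder_hidden_states = torch.randn(B, L_text, 64)  # stand-in for caption features
    encoder_attention_mask = torch.ones(B, L_text)

    out = model(samples, timesteps, encoder_hidden_states, encoder_attention_mask)
    print(out.shape)   # torch.Size([2, 1, 4, 32, 32]): patchify/unpatchify round-trips the input shape

The output channel count equals in_channels because pred_sigma defaults to False; setting pred_sigma=True doubles out_channels, as in the class definition above.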