├── .gitignore ├── LICENSE ├── README.md ├── ckpt └── put_your_model_here.txt ├── configs ├── inference │ ├── config_16z.yaml │ ├── config_16z_cap.yaml │ ├── config_4z.yaml │ └── config_4z_cap.yaml └── train │ ├── config_16z.yaml │ ├── config_16z_cap.yaml │ ├── config_16z_joint.yaml │ ├── config_4z.yaml │ ├── config_4z_cap.yaml │ └── config_4z_joint.yaml ├── data ├── dataset.py └── lightning_data.py ├── docs ├── case1 │ ├── fkanimal2.gif │ └── gtanimal2.gif ├── case2 │ ├── fkcloseshot1.gif │ └── gtcloseshot1.gif ├── case3 │ ├── fkface.gif │ └── gtface.gif ├── case4 │ ├── fkmotion4.gif │ └── gtmotion4.gif ├── case5 │ ├── fkview7.gif │ └── gtview7.gif └── sota-table.png ├── evaluation ├── compute_metrics.py └── compute_metrics_img.py ├── examples ├── images │ ├── gt │ │ ├── 00000091.jpg │ │ ├── 00000103.jpg │ │ ├── 00000110.jpg │ │ ├── 00000212.jpg │ │ ├── 00000268.jpg │ │ ├── 00000592.jpg │ │ ├── 00006871.jpg │ │ ├── 00007252.jpg │ │ ├── 00007826.jpg │ │ └── 00008868.jpg │ └── recon │ │ ├── 00000091.jpeg │ │ ├── 00000110.jpeg │ │ ├── 00000212.jpeg │ │ ├── 00000268.jpeg │ │ ├── 00000592.jpeg │ │ ├── 00006871.jpeg │ │ ├── 00007252.jpeg │ │ ├── 00007826.jpeg │ │ └── 00008868.jpeg └── videos │ ├── gt │ ├── 40.mp4 │ ├── 40.txt │ ├── 8.mp4 │ ├── 8.txt │ ├── animal.mp4 │ ├── animal.txt │ ├── closeshot.mp4 │ ├── closeshot.txt │ ├── face.mp4 │ ├── face.txt │ ├── view.mp4 │ └── view.txt │ └── recon │ ├── 40_reconstructed.mp4 │ ├── 8_reconstructed.mp4 │ ├── animal_reconstructed.mp4 │ ├── closeshot_reconstructed.mp4 │ ├── face_reconstructed.mp4 │ └── view_reconstructed.mp4 ├── inference_image.py ├── inference_video.py ├── requirements.txt ├── scripts ├── evaluation_image.sh ├── evaluation_video.sh ├── run_inference_image.sh ├── run_inference_video.sh └── run_train.sh ├── src ├── distributions.py ├── models │ ├── autoencoder.py │ ├── autoencoder2plus1d_1dcnn.py │ └── autoencoder_temporal.py └── modules │ ├── ae_modules.py │ ├── attention_temporal_videoae.py │ ├── losses │ ├── __init__.py │ └── contperceptual.py │ ├── t5.py │ └── utils.py ├── train.py └── utils ├── callbacks.py ├── common_utils.py ├── save_video.py └── train_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | 4 | .vscode 5 | .DS_Store 6 | .idea 7 | .git 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-NoDerivatives 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 
14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 58 | International Public License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial-NoDerivatives 4.0 International Public 63 | License ("Public License"). To the extent this Public License may be 64 | interpreted as a contract, You are granted the Licensed Rights in 65 | consideration of Your acceptance of these terms and conditions, and the 66 | Licensor grants You such rights in consideration of benefits the 67 | Licensor receives from making the Licensed Material available under 68 | these terms and conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. 
For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Copyright and Similar Rights means copyright and/or similar rights 84 | closely related to copyright including, without limitation, 85 | performance, broadcast, sound recording, and Sui Generis Database 86 | Rights, without regard to how the rights are labeled or 87 | categorized. For purposes of this Public License, the rights 88 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 89 | Rights. 90 | 91 | c. Effective Technological Measures means those measures that, in the 92 | absence of proper authority, may not be circumvented under laws 93 | fulfilling obligations under Article 11 of the WIPO Copyright 94 | Treaty adopted on December 20, 1996, and/or similar international 95 | agreements. 96 | 97 | d. Exceptions and Limitations means fair use, fair dealing, and/or 98 | any other exception or limitation to Copyright and Similar Rights 99 | that applies to Your use of the Licensed Material. 100 | 101 | e. Licensed Material means the artistic or literary work, database, 102 | or other material to which the Licensor applied this Public 103 | License. 104 | 105 | f. Licensed Rights means the rights granted to You subject to the 106 | terms and conditions of this Public License, which are limited to 107 | all Copyright and Similar Rights that apply to Your use of the 108 | Licensed Material and that the Licensor has authority to license. 109 | 110 | g. Licensor means the individual(s) or entity(ies) granting rights 111 | under this Public License. 112 | 113 | h. NonCommercial means not primarily intended for or directed towards 114 | commercial advantage or monetary compensation. For purposes of 115 | this Public License, the exchange of the Licensed Material for 116 | other material subject to Copyright and Similar Rights by digital 117 | file-sharing or similar means is NonCommercial provided there is 118 | no payment of monetary compensation in connection with the 119 | exchange. 120 | 121 | i. Share means to provide material to the public by any means or 122 | process that requires permission under the Licensed Rights, such 123 | as reproduction, public display, public performance, distribution, 124 | dissemination, communication, or importation, and to make material 125 | available to the public including in ways that members of the 126 | public may access the material from a place and at a time 127 | individually chosen by them. 128 | 129 | j. Sui Generis Database Rights means rights other than copyright 130 | resulting from Directive 96/9/EC of the European Parliament and of 131 | the Council of 11 March 1996 on the legal protection of databases, 132 | as amended and/or succeeded, as well as other essentially 133 | equivalent rights anywhere in the world. 134 | 135 | k. You means the individual or entity exercising the Licensed Rights 136 | under this Public License. Your has a corresponding meaning. 137 | 138 | 139 | Section 2 -- Scope. 140 | 141 | a. License grant. 142 | 143 | 1. Subject to the terms and conditions of this Public License, 144 | the Licensor hereby grants You a worldwide, royalty-free, 145 | non-sublicensable, non-exclusive, irrevocable license to 146 | exercise the Licensed Rights in the Licensed Material to: 147 | 148 | a. 
reproduce and Share the Licensed Material, in whole or 149 | in part, for NonCommercial purposes only; and 150 | 151 | b. produce and reproduce, but not Share, Adapted Material 152 | for NonCommercial purposes only. 153 | 154 | 2. Exceptions and Limitations. For the avoidance of doubt, where 155 | Exceptions and Limitations apply to Your use, this Public 156 | License does not apply, and You do not need to comply with 157 | its terms and conditions. 158 | 159 | 3. Term. The term of this Public License is specified in Section 160 | 6(a). 161 | 162 | 4. Media and formats; technical modifications allowed. The 163 | Licensor authorizes You to exercise the Licensed Rights in 164 | all media and formats whether now known or hereafter created, 165 | and to make technical modifications necessary to do so. The 166 | Licensor waives and/or agrees not to assert any right or 167 | authority to forbid You from making technical modifications 168 | necessary to exercise the Licensed Rights, including 169 | technical modifications necessary to circumvent Effective 170 | Technological Measures. For purposes of this Public License, 171 | simply making modifications authorized by this Section 2(a) 172 | (4) never produces Adapted Material. 173 | 174 | 5. Downstream recipients. 175 | 176 | a. Offer from the Licensor -- Licensed Material. Every 177 | recipient of the Licensed Material automatically 178 | receives an offer from the Licensor to exercise the 179 | Licensed Rights under the terms and conditions of this 180 | Public License. 181 | 182 | b. No downstream restrictions. You may not offer or impose 183 | any additional or different terms or conditions on, or 184 | apply any Effective Technological Measures to, the 185 | Licensed Material if doing so restricts exercise of the 186 | Licensed Rights by any recipient of the Licensed 187 | Material. 188 | 189 | 6. No endorsement. Nothing in this Public License constitutes or 190 | may be construed as permission to assert or imply that You 191 | are, or that Your use of the Licensed Material is, connected 192 | with, or sponsored, endorsed, or granted official status by, 193 | the Licensor or others designated to receive attribution as 194 | provided in Section 3(a)(1)(A)(i). 195 | 196 | b. Other rights. 197 | 198 | 1. Moral rights, such as the right of integrity, are not 199 | licensed under this Public License, nor are publicity, 200 | privacy, and/or other similar personality rights; however, to 201 | the extent possible, the Licensor waives and/or agrees not to 202 | assert any such rights held by the Licensor to the limited 203 | extent necessary to allow You to exercise the Licensed 204 | Rights, but not otherwise. 205 | 206 | 2. Patent and trademark rights are not licensed under this 207 | Public License. 208 | 209 | 3. To the extent possible, the Licensor waives any right to 210 | collect royalties from You for the exercise of the Licensed 211 | Rights, whether directly or through a collecting society 212 | under any voluntary or waivable statutory or compulsory 213 | licensing scheme. In all other cases the Licensor expressly 214 | reserves any right to collect such royalties, including when 215 | the Licensed Material is used other than for NonCommercial 216 | purposes. 217 | 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material, You must: 227 | 228 | a. 
retain the following if it is supplied by the Licensor 229 | with the Licensed Material: 230 | 231 | i. identification of the creator(s) of the Licensed 232 | Material and any others designated to receive 233 | attribution, in any reasonable manner requested by 234 | the Licensor (including by pseudonym if 235 | designated); 236 | 237 | ii. a copyright notice; 238 | 239 | iii. a notice that refers to this Public License; 240 | 241 | iv. a notice that refers to the disclaimer of 242 | warranties; 243 | 244 | v. a URI or hyperlink to the Licensed Material to the 245 | extent reasonably practicable; 246 | 247 | b. indicate if You modified the Licensed Material and 248 | retain an indication of any previous modifications; and 249 | 250 | c. indicate the Licensed Material is licensed under this 251 | Public License, and include the text of, or the URI or 252 | hyperlink to, this Public License. 253 | 254 | For the avoidance of doubt, You do not have permission under 255 | this Public License to Share Adapted Material. 256 | 257 | 2. You may satisfy the conditions in Section 3(a)(1) in any 258 | reasonable manner based on the medium, means, and context in 259 | which You Share the Licensed Material. For example, it may be 260 | reasonable to satisfy the conditions by providing a URI or 261 | hyperlink to a resource that includes the required 262 | information. 263 | 264 | 3. If requested by the Licensor, You must remove any of the 265 | information required by Section 3(a)(1)(A) to the extent 266 | reasonably practicable. 267 | 268 | 269 | Section 4 -- Sui Generis Database Rights. 270 | 271 | Where the Licensed Rights include Sui Generis Database Rights that 272 | apply to Your use of the Licensed Material: 273 | 274 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 275 | to extract, reuse, reproduce, and Share all or a substantial 276 | portion of the contents of the database for NonCommercial purposes 277 | only and provided You do not Share Adapted Material; 278 | 279 | b. if You include all or a substantial portion of the database 280 | contents in a database in which You have Sui Generis Database 281 | Rights, then the database in which You have Sui Generis Database 282 | Rights (but not its individual contents) is Adapted Material; and 283 | 284 | c. You must comply with the conditions in Section 3(a) if You Share 285 | all or a substantial portion of the contents of the database. 286 | 287 | For the avoidance of doubt, this Section 4 supplements and does not 288 | replace Your obligations under this Public License where the Licensed 289 | Rights include other Copyright and Similar Rights. 290 | 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | 350 | Section 7 -- Other Terms and Conditions. 351 | 352 | a. The Licensor shall not be bound by any additional or different 353 | terms or conditions communicated by You unless expressly agreed. 354 | 355 | b. Any arrangements, understandings, or agreements regarding the 356 | Licensed Material not stated herein are separate from and 357 | independent of the terms and conditions of this Public License. 358 | 359 | 360 | Section 8 -- Interpretation. 361 | 362 | a. For the avoidance of doubt, this Public License does not, and 363 | shall not be interpreted to, reduce, limit, restrict, or impose 364 | conditions on any use of the Licensed Material that could lawfully 365 | be made without permission under this Public License. 366 | 367 | b. To the extent possible, if any provision of this Public License is 368 | deemed unenforceable, it shall be automatically reformed to the 369 | minimum extent necessary to make it enforceable. If the provision 370 | cannot be reformed, it shall be severed from this Public License 371 | without affecting the enforceability of the remaining terms and 372 | conditions. 373 | 374 | c. No term or condition of this Public License will be waived and no 375 | failure to comply consented to unless expressly agreed to by the 376 | Licensor. 377 | 378 | d. 
Nothing in this Public License constitutes or may be interpreted 379 | as a limitation upon, or waiver of, any privileges and immunities 380 | that apply to the Licensor or You, including from the legal 381 | processes of any jurisdiction or authority. 382 | 383 | ======================================================================= 384 | 385 | Creative Commons is not a party to its public 386 | licenses. Notwithstanding, Creative Commons may elect to apply one of 387 | its public licenses to material it publishes and in those instances 388 | will be considered the “Licensor.” The text of the Creative Commons 389 | public licenses is dedicated to the public domain under the CC0 Public 390 | Domain Dedication. Except for the limited purpose of indicating that 391 | material is shared under a Creative Commons public license or as 392 | otherwise permitted by the Creative Commons policies published at 393 | creativecommons.org/policies, Creative Commons does not authorize the 394 | use of the trademark "Creative Commons" or any other trademark or logo 395 | of Creative Commons without its prior written consent including, 396 | without limitation, in connection with any unauthorized modifications 397 | to any of its public licenses or any other arrangements, 398 | understandings, or agreements concerning use of licensed material. For 399 | the avoidance of doubt, this paragraph does not form part of the 400 | public licenses. 401 | 402 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VideoVAE+: Large Motion Video Autoencoding with Cross-modal Video VAE 2 | 3 | | Ground Truth (GT) | Reconstructed | 4 | |-------------------|---------------| 5 | | GT Video 1 | Reconstructed Video 1 | 6 | | GT Video 2 | Reconstructed Video 2 | 7 | | GT Video 3 | Reconstructed Video 3 | 8 | | GT Video 4 | Reconstructed Video 4 | 9 | | GT Video 5 | Reconstructed Video 5 | 10 | 11 | [Yazhou Xing](https://yzxing87.github.io)\*, [Yang Fei](https://sunfly04.github.io)\*, [Yingqing He](https://yingqinghe.github.io)\*†, [Jingye Chen](https://jingyechen.github.io), [Jiaxin Xie](https://jiaxinxie97.github.io/Jiaxin-Xie), [Xiaowei Chi](https://scholar.google.com/citations?user=Vl1X_-sAAAAJ&hl=zh-CN), [Qifeng Chen](https://cqf.io/)† (*equal contribution, †corresponding author) 12 | *The Hong Kong University of Science and Technology* 13 | 14 | #### [Project Page](https://yzxing87.github.io/vae/) | [Paper](https://arxiv.org/abs/2412.17805) | [High-Res Demo](https://www.youtube.com/embed/Kb4rn9z9xAA) 15 | 16 | A state-of-the-art **Video Variational Autoencoder (VAE)** designed for high-fidelity video reconstruction. This project leverages cross-modal and joint video-image training to enhance reconstruction quality. 17 | 18 | --- 19 | 20 | ## ✨ Features 21 | 22 | - **High-Fidelity Reconstruction**: Achieve superior image and video reconstruction quality. 23 | - **Cross-Modal Reconstruction**: Utilize captions to guide the reconstruction process. 24 | - **State-of-the-Art Performance**: Set new benchmarks in video reconstruction tasks. 
25 | 26 | ![SOTA Table](docs/sota-table.png) 27 | --- 28 | 29 | ## 📰 News 30 | - [Jan 2025] 🏋️ Released training code & better pretrained 4z-text weight 31 | - [Dec 2024] 🚀 Released inference code and pretrained models 32 | - [Dec 2024] 📝 Released paper on [arXiv](https://arxiv.org/abs/2412.17805) 33 | - [Dec 2024] 💡 Project page is live at [VideoVAE+](https://yzxing87.github.io/vae/) 34 | 35 | --- 36 | 37 | ## ⏰ Todo 38 | 39 | - [x] **Release Pretrained Model Weights** 40 | - [x] **Release Inference Code** 41 | - [x] **Release Training Code** 42 | 43 | --- 44 | 45 | ## 🚀 Get Started 46 | 47 | Follow these steps to set up your environment and run the code: 48 | 49 | ### 1. Clone the Repository 50 | 51 | ```bash 52 | git clone https://github.com/VideoVerses/VideoVAEPlus.git 53 | cd VideoVAEPlus 54 | ``` 55 | 56 | ### 2. Set Up the Environment 57 | 58 | Create a Conda environment and install dependencies: 59 | 60 | ```bash 61 | conda create --name vae python=3.10 -y 62 | conda activate vae 63 | pip install -r requirements.txt 64 | ``` 65 | 66 | --- 67 | 68 | ## 📦 Pretrained Models 69 | 70 | | Model Name | Latent Channels | Download Link | 71 | |-----------------|-----------------|------------------| 72 | | sota-4z | 4 | [Download](https://drive.google.com/file/d/1WEKBdRFjEUxwcBgX_thckXklD8s6dDTj/view?usp=drive_link) | 73 | | sota-4z-text | 4 | [Download](https://drive.google.com/file/d/1QfqrKIWu5zG10U-xRgeF8Dhp__njC8OH/view?usp=sharing) | 74 | | sota-16z | 16 | [Download](https://drive.google.com/file/d/13v2Pq6dG1jo7RNImxNOXr9-WizgMiJ7M/view?usp=sharing) | 75 | | sota-16z-text | 16 | [Download](https://drive.google.com/file/d/1iYCAtmdaOX0V41p0vbt_6g8kRS1EK56p/view?usp=sharing) | 76 | 77 | - **Note**: '4z' and '16z' indicate the number of latent channels in the VAE model. Models with 'text' support text guidance. 78 | 79 | --- 80 | 81 | ## 📁 Data Preparation 82 | 83 | To reconstruct videos and images using our VAE model, organize your data in the following structure: 84 | 85 | ### Videos 86 | 87 | Place your videos and optional captions in the `examples/videos/gt` directory. 88 | 89 | #### Directory Structure: 90 | 91 | ``` 92 | examples/videos/ 93 | ├── gt/ 94 | │ ├── video1.mp4 95 | │ ├── video1.txt # Optional caption 96 | │ ├── video2.mp4 97 | │ ├── video2.txt 98 | │ └── ... 99 | ├── recon/ 100 | └── (reconstructed videos will be saved here) 101 | ``` 102 | 103 | - **Captions**: For cross-modal reconstruction, include a `.txt` file with the same name as the video containing its caption. 104 | 105 | ### Images 106 | 107 | Place your images in the `examples/images/gt` directory. 108 | 109 | #### Directory Structure: 110 | 111 | ``` 112 | examples/images/ 113 | ├── gt/ 114 | │ ├── image1.jpg 115 | │ ├── image2.png 116 | │ └── ... 117 | ├── recon/ 118 | └── (reconstructed images will be saved here) 119 | ``` 120 | 121 | - **Note**: The images dataset does not require captions. 122 | 123 | --- 124 | 125 | ## 🔧 Inference 126 | 127 | Our video VAE supports both image and video reconstruction. 128 | 129 | Please ensure that the `ckpt_path` in all your configuration files is set to the actual path of your checkpoint. 
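For a quick sanity check before launching inference, the snippet below reads a config and verifies that the checkpoint it references actually exists. This is only a minimal sketch: it assumes PyYAML is installed, and the helper name `check_ckpt_path` is ours, not part of this repository.

```python
import os
import yaml  # PyYAML; assumed available in the environment


def check_ckpt_path(config_path: str) -> str:
    """Read an inference config and verify that model.params.ckpt_path points to a real file."""
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    ckpt = cfg["model"]["params"]["ckpt_path"]
    if not os.path.isfile(ckpt):
        raise FileNotFoundError(
            f"{config_path} points to '{ckpt}', which does not exist. "
            "Download a pretrained checkpoint into ckpt/ or edit ckpt_path."
        )
    return ckpt


# Example usage (hypothetical):
# check_ckpt_path("configs/inference/config_16z.yaml")
```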
130 | 131 | ### Video Reconstruction 132 | 133 | Run video reconstruction using: 134 | 135 | ```bash 136 | bash scripts/run_inference_video.sh 137 | ``` 138 | 139 | This is equivalent to: 140 | 141 | ```bash 142 | python inference_video.py \ 143 | --data_root 'examples/videos/gt' \ 144 | --out_root 'examples/videos/recon' \ 145 | --config_path 'configs/inference/config_16z.yaml' \ 146 | --chunk_size 8 \ 147 | --resolution 720 1280 148 | ``` 149 | 150 | - If the chunk size is too large, you may encounter memory issues. In this case, reduce the `chunk_size` parameter. Ensure the `chunk_size` is divisible by 4. 151 | 152 | - To enable cross-modal reconstruction using captions, modify `config_path` to `'configs/inference/config_16z_cap.yaml'` for the 16-channel model with caption guidance. 153 | 154 | ### Image Reconstruction 155 | 156 | Run image reconstruction using: 157 | 158 | ```bash 159 | bash scripts/run_inference_image.sh 160 | ``` 161 | 162 | This is equivalent to: 163 | 164 | ```bash 165 | python inference_image.py \ 166 | --data_root 'examples/images/gt' \ 167 | --out_root 'examples/images/recon' \ 168 | --config_path 'configs/inference/config_16z.yaml' \ 169 | --batch_size 1 170 | ``` 171 | 172 | - **Note**: The batch size is set to 1 because the images in the example folder have varying resolutions. If you have a batch of images with the same resolution, you can increase the batch size to accelerate inference. 173 | 174 | --- 175 | 176 | ## 🏋️ Training 177 | 178 | ### Quick Start 179 | 180 | To start training, use the following command: 181 | 182 | ```bash 183 | bash scripts/run_train.sh config_16z 184 | ``` 185 | 186 | This default command trains the 16-channel model on video reconstruction using a single GPU. 187 | 188 | ### Configuration Options 189 | 190 | You can modify the training configuration by changing the config parameter: 191 | 192 | - `config_4z`: 4-channel model 193 | - `config_4z_joint`: 4-channel model trained jointly on both image and video data 194 | - `config_4z_cap`: 4-channel model with text guidance 195 | - `config_16z`: Default 16-channel model 196 | - `config_16z_joint`: 16-channel model trained jointly on both image and video data 197 | - `config_16z_cap`: 16-channel model with text guidance 198 | 199 | Note: Do not include the `.yaml` extension when specifying the config. 200 | 201 | ### Data Preparation 202 | 203 | #### Dataset Structure 204 | The training data should be organized in a CSV file with the following format: 205 | 206 | ```csv 207 | path,text 208 | /absolute/path/to/video1.mp4,A person walking on the beach 209 | /absolute/path/to/video2.mp4,A car driving down the road 210 | ``` 211 | 212 | #### Requirements: 213 | - Use absolute paths for video files 214 | - Include two columns: `path` and `text` 215 | - For training without text guidance, leave the caption column empty but maintain the CSV structure 216 | 217 | #### Example CSV: 218 | ```csv 219 | # With captions 220 | /data/videos/clip1.mp4,A dog playing in the park 221 | /data/videos/clip2.mp4,Sunset over the ocean 222 | 223 | # Without captions 224 | /data/videos/clip1.mp4, 225 | /data/videos/clip2.mp4, 226 | ``` 227 | 228 | --- 229 | 230 | ## 📊 Evaluation 231 | 232 | Use the provided scripts to evaluate reconstruction quality using **PSNR**, **SSIM**, and **LPIPS** metrics.
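The scripts in `evaluation/` are the reference implementation. Purely as an illustration, the per-frame computation typically looks like the sketch below; it assumes `scikit-image` (≥ 0.19 for `channel_axis`), `torch`, and the `lpips` package, and the helper name `frame_metrics` is hypothetical rather than part of this repository.

```python
import numpy as np
import torch
import lpips
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# LPIPS network is loaded once and reused (AlexNet backbone, as is common practice).
_lpips_fn = lpips.LPIPS(net="alex")


def frame_metrics(gt: np.ndarray, recon: np.ndarray):
    """Score one ground-truth frame against its reconstruction.

    Both inputs are assumed to be H x W x 3 uint8 RGB arrays.
    """
    psnr = peak_signal_noise_ratio(gt, recon, data_range=255)
    ssim = structural_similarity(gt, recon, channel_axis=-1, data_range=255)

    # LPIPS expects NCHW float tensors scaled to [-1, 1]
    def to_tensor(x: np.ndarray) -> torch.Tensor:
        return torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0).float() / 127.5 - 1.0

    with torch.no_grad():
        lpips_val = _lpips_fn(to_tensor(gt), to_tensor(recon)).item()
    return psnr, ssim, lpips_val
```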
233 | 234 | ### Evaluate Image Reconstruction 235 | 236 | ```bash 237 | bash scripts/evaluation_image.sh 238 | ``` 239 | 240 | ### Evaluate Video Reconstruction 241 | 242 | ```bash 243 | bash scripts/evaluation_video.sh 244 | ``` 245 | 246 | --- 247 | 248 | ## 📝 License 249 | 250 | Please follow [CC-BY-NC-ND](./LICENSE). 251 | 252 | ## Star History 253 | 254 | [![Star History Chart](https://api.star-history.com/svg?repos=VideoVerses/VideoVAEPlus&type=Date)](https://star-history.com/#VideoVerses/VideoVAEPlus&Date) 255 | -------------------------------------------------------------------------------- /ckpt/put_your_model_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/ckpt/put_your_model_here.txt -------------------------------------------------------------------------------- /configs/inference/config_16z.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | scale_lr: False 4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn 5 | params: 6 | monitor: "val/rec_loss" 7 | video_key: video 8 | image_key: video 9 | ckpt_path: ckpt/sota-4-16z.ckpt 10 | input_dim: 5 11 | ignore_keys_3d: ['loss'] 12 | caption_guide: False 13 | use_quant_conv: False 14 | img_video_joint_train: False 15 | 16 | lossconfig: 17 | target: src.modules.losses.LPIPSWithDiscriminator3D 18 | params: 19 | disc_start: 50001 20 | kl_weight: 0 21 | disc_weight: 0.5 22 | 23 | ddconfig: 24 | double_z: True 25 | z_channels: 16 26 | resolution: 216 27 | in_channels: 3 28 | out_ch: 3 29 | ch: 128 30 | ch_mult: [ 1,2,4,4 ] 31 | temporal_down_factor: 1 32 | num_res_blocks: 2 33 | attn_resolutions: [] 34 | dropout: 0.0 35 | 36 | ppconfig: 37 | temporal_scale_factor: 4 38 | z_channels: 16 39 | out_ch: 16 40 | ch: 16 41 | attn_temporal_factor: [] -------------------------------------------------------------------------------- /configs/inference/config_16z_cap.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 #5.80e-04 3 | scale_lr: False 4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn 5 | params: 6 | monitor: "val/rec_loss" 7 | video_key: video 8 | image_key: video 9 | img_video_joint_train: False 10 | caption_guide: True 11 | use_quant_conv: False 12 | t5_model_max_length: 100 13 | 14 | ckpt_path: ckpt/sota-4-16z-text.ckpt 15 | input_dim: 5 16 | ignore_keys_3d: ['loss'] 17 | 18 | lossconfig: 19 | target: src.modules.losses.LPIPSWithDiscriminator3D 20 | params: 21 | disc_start: 50001 22 | kl_weight: 0 23 | disc_weight: 0.5 24 | 25 | ddconfig: 26 | double_z: True 27 | z_channels: 16 28 | resolution: 216 29 | in_channels: 3 30 | out_ch: 3 31 | ch: 128 32 | ch_mult: [ 1,2,4,4 ] 33 | temporal_down_factor: 1 34 | num_res_blocks: 2 35 | attn_resolutions: [27, 54, 108, 216] 36 | dropout: 0.0 37 | 38 | ppconfig: 39 | temporal_scale_factor: 4 40 | z_channels: 16 41 | out_ch: 16 42 | ch: 16 43 | attn_temporal_factor: [2,4] -------------------------------------------------------------------------------- /configs/inference/config_4z.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 #5.80e-04 3 | scale_lr: False 4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn 5 | params: 6 | monitor: "val/rec_loss" 7 | 
embed_dim: 4 8 | video_key: video 9 | image_key: video #jpg 10 | ckpt_path: ckpt/sota-4-4z.ckpt 11 | input_dim: 5 12 | ignore_keys_3d: ['loss'] 13 | 14 | img_video_joint_train: False 15 | caption_guide: False 16 | use_quant_conv: True 17 | 18 | lossconfig: 19 | target: src.modules.losses.LPIPSWithDiscriminator3D 20 | params: 21 | disc_start: 50001 22 | kl_weight: 0 23 | disc_weight: 0.5 24 | 25 | ddconfig: 26 | double_z: True 27 | z_channels: 4 28 | resolution: 216 29 | in_channels: 3 30 | out_ch: 3 31 | ch: 128 32 | ch_mult: [ 1,2,4,4 ] 33 | temporal_down_factor: 1 34 | num_res_blocks: 2 35 | attn_resolutions: [ ] 36 | dropout: 0.0 37 | 38 | ppconfig: 39 | temporal_scale_factor: 4 40 | z_channels: 4 41 | out_ch: 4 42 | ch: 4 43 | attn_temporal_factor: [] -------------------------------------------------------------------------------- /configs/inference/config_4z_cap.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | scale_lr: False 4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn 5 | params: 6 | monitor: "val/rec_loss" 7 | video_key: video 8 | image_key: video 9 | img_video_joint_train: False 10 | caption_guide: True 11 | use_quant_conv: True 12 | t5_model_max_length: 100 13 | 14 | ckpt_path: ckpt/sota-4-4z-text.ckpt 15 | input_dim: 5 16 | 17 | lossconfig: 18 | target: src.modules.losses.LPIPSWithDiscriminator3D 19 | params: 20 | disc_start: 50001 21 | kl_weight: 0 22 | disc_weight: 0.5 23 | 24 | ddconfig: 25 | double_z: True 26 | z_channels: 4 27 | resolution: 216 28 | in_channels: 3 29 | out_ch: 3 30 | ch: 128 31 | ch_mult: [ 1,2,4,4 ] 32 | temporal_down_factor: 1 33 | num_res_blocks: 2 34 | attn_resolutions: [27, 54, 108, 216] 35 | dropout: 0.0 36 | 37 | ppconfig: 38 | temporal_scale_factor: 4 39 | z_channels: 4 40 | out_ch: 4 41 | ch: 4 42 | attn_temporal_factor: [2,4] -------------------------------------------------------------------------------- /configs/train/config_16z.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 #5.80e-04 3 | scale_lr: False 4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn 5 | params: 6 | monitor: "val/rec_loss" 7 | video_key: video 8 | image_key: video 9 | img_video_joint_train: False 10 | caption_guide: False 11 | use_quant_conv: False 12 | 13 | ignore_keys_3d: ['loss'] 14 | ckpt_path: ckpt/sota-4-16z.ckpt 15 | input_dim: 5 16 | 17 | lossconfig: 18 | target: src.modules.losses.LPIPSWithDiscriminator3D 19 | params: 20 | disc_start: 50001 21 | kl_weight: 0.000001 22 | disc_weight: 0.5 23 | 24 | ddconfig: 25 | double_z: True 26 | z_channels: 16 27 | resolution: 216 28 | in_channels: 3 29 | out_ch: 3 30 | ch: 128 31 | ch_mult: [ 1,2,4,4 ] 32 | temporal_down_factor: 1 33 | num_res_blocks: 2 34 | attn_resolutions: [ ] 35 | dropout: 0.0 36 | 37 | ppconfig: 38 | temporal_scale_factor: 4 39 | z_channels: 16 40 | out_ch: 16 41 | ch: 16 42 | attn_temporal_factor: [] 43 | 44 | data: 45 | target: data.lightning_data.DataModuleFromConfig 46 | params: 47 | img_video_joint_train: False 48 | batch_size: 1 49 | num_workers: 32 50 | wrap: false 51 | train: 52 | target: data.dataset.DatasetVideoLoader 53 | params: 54 | csv_file: path/to/your.csv 55 | resolution: [216, 216] 56 | video_length: 16 57 | subset_split: train 58 | validation: 59 | target: data.dataset.DatasetVideoLoader 60 | params: 61 | csv_file: path/to/your.csv 62 | resolution: [216, 216] 63 | 
video_length: 16 64 | subset_split: val 65 | 66 | lightning: 67 | find_unused_parameters: True 68 | callbacks: 69 | image_logger: 70 | target: utils.callbacks.ImageLogger 71 | params: 72 | batch_frequency: 1009 73 | max_images: 8 74 | metrics_over_trainsteps_checkpoint: 75 | target: pytorch_lightning.callbacks.ModelCheckpoint 76 | params: 77 | filename: '{epoch:06}-{step:09}' 78 | save_weights_only: False 79 | every_n_train_steps: 5000 80 | trainer: 81 | benchmark: True 82 | accumulate_grad_batches: 2 83 | batch_size: 1 84 | num_workers: 32 85 | max_epochs: 3000 86 | modelcheckpoint: 87 | target: pytorch_lightning.callbacks.ModelCheckpoint 88 | params: 89 | every_n_train_steps: 3000 90 | filename: "{epoch:04}-{step:06}" 91 | -------------------------------------------------------------------------------- /configs/train/config_16z_cap.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | scale_lr: False 4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn 5 | params: 6 | monitor: "val/rec_loss" 7 | video_key: video 8 | image_key: video 9 | img_video_joint_train: False 10 | caption_guide: True 11 | use_quant_conv: False 12 | t5_model_max_length: 100 13 | 14 | ckpt_path: ckpt/sota-4-16z-text.ckpt 15 | 16 | input_dim: 5 17 | ignore_keys_3d: ['loss'] 18 | 19 | lossconfig: 20 | target: src.modules.losses.LPIPSWithDiscriminator3D 21 | params: 22 | disc_start: 50001 23 | kl_weight: 0.000001 24 | disc_weight: 0.5 25 | 26 | ddconfig: 27 | double_z: True 28 | z_channels: 16 29 | resolution: 216 30 | in_channels: 3 31 | out_ch: 3 32 | ch: 128 33 | ch_mult: [ 1,2,4,4 ] 34 | temporal_down_factor: 1 35 | num_res_blocks: 2 36 | attn_resolutions: [27, 54, 108, 216] 37 | dropout: 0.0 38 | 39 | ppconfig: 40 | temporal_scale_factor: 4 41 | z_channels: 16 42 | out_ch: 16 43 | ch: 16 44 | attn_temporal_factor: [2, 4] 45 | 46 | data: 47 | target: data.lightning_data.DataModuleFromConfig 48 | params: 49 | batch_size: 1 50 | num_workers: 32 51 | wrap: false 52 | train: 53 | target: data.dataset.DatasetVideoLoader 54 | params: 55 | csv_file: path/to/your.csv 56 | resolution: [216, 216] 57 | video_length: 16 58 | subset_split: train 59 | validation: 60 | target: data.dataset.DatasetVideoLoader 61 | params: 62 | csv_file: path/to/your.csv 63 | resolution: [216, 216] 64 | video_length: 16 65 | subset_split: val 66 | 67 | lightning: 68 | find_unused_parameters: True 69 | callbacks: 70 | image_logger: 71 | target: utils.callbacks.ImageLogger 72 | params: 73 | batch_frequency: 509 74 | max_images: 8 75 | metrics_over_trainsteps_checkpoint: 76 | target: pytorch_lightning.callbacks.ModelCheckpoint 77 | params: 78 | filename: '{epoch:06}-{step:09}' 79 | save_weights_only: True 80 | every_n_train_steps: 5000 81 | trainer: 82 | benchmark: True 83 | accumulate_grad_batches: 2 84 | batch_size: 1 85 | num_workers: 32 86 | max_epochs: 3000 87 | modelcheckpoint: 88 | target: pytorch_lightning.callbacks.ModelCheckpoint 89 | params: 90 | every_n_train_steps: 3000 91 | filename: "{epoch:04}-{step:06}" -------------------------------------------------------------------------------- /configs/train/config_16z_joint.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 #5.80e-04 3 | scale_lr: False 4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn 5 | params: 6 | monitor: "val/rec_loss" 7 | video_key: video 8 | image_key: video 9 | 
img_video_joint_train: True 10 | caption_guide: False 11 | use_quant_conv: False 12 | 13 | ignore_keys_3d: ['loss'] 14 | ckpt_path: ckpt/sota-4-16z.ckpt 15 | input_dim: 5 16 | 17 | lossconfig: 18 | target: src.modules.losses.LPIPSWithDiscriminator 19 | params: 20 | disc_start: 50001 21 | kl_weight: 0.000001 22 | disc_weight: 0.5 23 | 24 | ddconfig: 25 | double_z: True 26 | z_channels: 16 27 | resolution: 216 28 | in_channels: 3 29 | out_ch: 3 30 | ch: 128 31 | ch_mult: [ 1,2,4,4 ] 32 | temporal_down_factor: 1 33 | num_res_blocks: 2 34 | attn_resolutions: [ ] 35 | dropout: 0.0 36 | 37 | ppconfig: 38 | temporal_scale_factor: 4 39 | z_channels: 16 40 | out_ch: 16 41 | ch: 16 42 | attn_temporal_factor: [] 43 | 44 | data: 45 | target: data.lightning_data.DataModuleFromConfig 46 | params: 47 | img_video_joint_train: True 48 | batch_size: 1 49 | num_workers: 20 50 | wrap: false 51 | train: 52 | target: data.dataset.DatasetVideoLoader 53 | params: 54 | csv_file: path/to/your.csv 55 | resolution: [216, 216] 56 | video_length: 16 57 | subset_split: train 58 | validation: 59 | target: data.dataset.DatasetVideoLoader 60 | params: 61 | csv_file: path/to/your.csv 62 | resolution: [216, 216] 63 | video_length: 16 64 | subset_split: val 65 | 66 | lightning: 67 | find_unused_parameters: True 68 | callbacks: 69 | image_logger: 70 | target: utils.callbacks.ImageLogger 71 | params: 72 | batch_frequency: 1009 73 | max_images: 8 74 | metrics_over_trainsteps_checkpoint: 75 | target: pytorch_lightning.callbacks.ModelCheckpoint 76 | params: 77 | filename: '{epoch:06}-{step:09}' 78 | save_weights_only: False 79 | every_n_train_steps: 5000 80 | trainer: 81 | benchmark: True 82 | accumulate_grad_batches: 2 83 | batch_size: 1 84 | num_workers: 20 85 | max_epochs: 3000 86 | modelcheckpoint: 87 | target: pytorch_lightning.callbacks.ModelCheckpoint 88 | params: 89 | every_n_train_steps: 3000 90 | filename: "{epoch:04}-{step:06}" 91 | -------------------------------------------------------------------------------- /configs/train/config_4z.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 #5.80e-04 3 | scale_lr: False 4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn 5 | params: 6 | monitor: "val/rec_loss" 7 | embed_dim: 4 8 | video_key: video 9 | image_key: video #jpg 10 | ckpt_path: ckpt/sota-4-4z.ckpt 11 | input_dim: 5 12 | ignore_keys_3d: ['loss'] 13 | 14 | img_video_joint_train: False 15 | caption_guide: False 16 | use_quant_conv: True 17 | 18 | lossconfig: 19 | target: src.modules.losses.LPIPSWithDiscriminator3D 20 | params: 21 | disc_start: 50001 22 | kl_weight: 0.000001 23 | disc_weight: 0.5 24 | 25 | ddconfig: 26 | double_z: True 27 | z_channels: 4 28 | resolution: 216 29 | in_channels: 3 30 | out_ch: 3 31 | ch: 128 32 | ch_mult: [ 1,2,4,4 ] 33 | temporal_down_factor: 1 34 | num_res_blocks: 2 35 | attn_resolutions: [ ] 36 | dropout: 0.0 37 | 38 | ppconfig: 39 | temporal_scale_factor: 4 40 | z_channels: 4 41 | out_ch: 4 42 | ch: 4 # 16*4 43 | attn_temporal_factor: [] 44 | 45 | data: 46 | target: data.lightning_data.DataModuleFromConfig 47 | params: 48 | img_video_joint_train: False 49 | batch_size: 1 50 | num_workers: 32 51 | wrap: false 52 | train: 53 | target: data.dataset.DatasetVideoLoader 54 | params: 55 | csv_file: path/to/your.csv 56 | resolution: [216, 216] 57 | video_length: 16 58 | subset_split: train 59 | validation: 60 | target: data.dataset.DatasetVideoLoader 61 | params: 62 | csv_file: 
path/to/your.csv 63 | resolution: [216, 216] 64 | video_length: 16 65 | subset_split: val 66 | 67 | lightning: 68 | find_unused_parameters: True 69 | callbacks: 70 | image_logger: 71 | target: utils.callbacks.ImageLogger 72 | params: 73 | batch_frequency: 1009 74 | max_images: 8 75 | metrics_over_trainsteps_checkpoint: 76 | target: pytorch_lightning.callbacks.ModelCheckpoint 77 | params: 78 | filename: '{epoch:06}-{step:09}' 79 | save_weights_only: False 80 | every_n_train_steps: 5000 81 | trainer: 82 | benchmark: True 83 | accumulate_grad_batches: 2 84 | batch_size: 1 85 | num_workers: 32 86 | max_epochs: 3000 87 | modelcheckpoint: 88 | target: pytorch_lightning.callbacks.ModelCheckpoint 89 | params: 90 | every_n_train_steps: 3000 91 | filename: "{epoch:04}-{step:06}" 92 | -------------------------------------------------------------------------------- /configs/train/config_4z_cap.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | scale_lr: False 4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn 5 | params: 6 | embed_dim: 4 7 | monitor: "val/rec_loss" 8 | video_key: video 9 | image_key: video 10 | img_video_joint_train: False 11 | caption_guide: True 12 | use_quant_conv: True 13 | t5_model_max_length: 100 14 | 15 | ckpt_path: ckpt/sota-4-4z-text.ckpt 16 | input_dim: 5 17 | 18 | ignore_keys_3d: ['loss'] 19 | 20 | lossconfig: 21 | target: src.modules.losses.LPIPSWithDiscriminator3D 22 | params: 23 | disc_start: 50001 24 | kl_weight: 0.000001 25 | disc_weight: 0.5 26 | 27 | ddconfig: 28 | double_z: True 29 | z_channels: 4 30 | resolution: 216 31 | in_channels: 3 32 | out_ch: 3 33 | ch: 128 34 | ch_mult: [ 1,2,4,4 ] 35 | temporal_down_factor: 1 36 | num_res_blocks: 2 37 | attn_resolutions: [27, 54, 108, 216] 38 | dropout: 0.0 39 | 40 | ppconfig: 41 | temporal_scale_factor: 4 42 | z_channels: 4 43 | out_ch: 4 44 | ch: 4 45 | attn_temporal_factor: [2,4] 46 | 47 | data: 48 | target: data.lightning_data.DataModuleFromConfig 49 | params: 50 | batch_size: 1 51 | num_workers: 32 52 | wrap: false 53 | train: 54 | target: data.dataset.DatasetVideoLoader 55 | params: 56 | csv_file: path/to/your.csv 57 | resolution: [216, 216] 58 | video_length: 16 59 | subset_split: train 60 | validation: 61 | target: data.dataset.DatasetVideoLoader 62 | params: 63 | csv_file: path/to/your.csv 64 | resolution: [216, 216] 65 | video_length: 16 66 | subset_split: val 67 | 68 | lightning: 69 | find_unused_parameters: True 70 | callbacks: 71 | image_logger: 72 | target: utils.callbacks.ImageLogger 73 | params: 74 | batch_frequency: 509 75 | max_images: 8 76 | metrics_over_trainsteps_checkpoint: 77 | target: pytorch_lightning.callbacks.ModelCheckpoint 78 | params: 79 | filename: '{epoch:06}-{step:09}' 80 | save_weights_only: True 81 | every_n_train_steps: 5000 82 | trainer: 83 | benchmark: True 84 | accumulate_grad_batches: 2 85 | batch_size: 1 86 | num_workers: 32 87 | max_epochs: 3000 88 | modelcheckpoint: 89 | target: pytorch_lightning.callbacks.ModelCheckpoint 90 | params: 91 | every_n_train_steps: 3000 92 | filename: "{epoch:04}-{step:06}" -------------------------------------------------------------------------------- /configs/train/config_4z_joint.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | scale_lr: False 4 | target: src.models.autoencoder2plus1d_1dcnn.AutoencoderKL2plus1D_1dcnn 5 | params: 6 | monitor: "val/rec_loss" 7 | 
embed_dim: 4 8 | video_key: video 9 | image_key: video 10 | ckpt_path: ckpt/sota-4-4z.ckpt 11 | input_dim: 5 12 | 13 | ignore_keys_3d: ['loss'] 14 | 15 | img_video_joint_train: True 16 | caption_guide: False 17 | use_quant_conv: True 18 | 19 | lossconfig: 20 | target: src.modules.losses.LPIPSWithDiscriminator 21 | params: 22 | disc_start: 50001 23 | kl_weight: 0.000001 24 | disc_weight: 0.5 25 | 26 | ddconfig: 27 | double_z: True 28 | z_channels: 4 29 | resolution: 216 30 | in_channels: 3 31 | out_ch: 3 32 | ch: 128 33 | ch_mult: [ 1,2,4,4 ] 34 | temporal_down_factor: 1 35 | num_res_blocks: 2 36 | attn_resolutions: [ ] 37 | dropout: 0.0 38 | 39 | ppconfig: 40 | temporal_scale_factor: 4 41 | z_channels: 4 42 | out_ch: 4 43 | ch: 4 44 | attn_temporal_factor: [] 45 | 46 | data: 47 | target: data.lightning_data.DataModuleFromConfig 48 | params: 49 | img_video_joint_train: True 50 | batch_size: 1 51 | num_workers: 20 52 | wrap: false 53 | train: 54 | target: data.dataset.DatasetVideoLoader 55 | params: 56 | csv_file: path/to/your.csv 57 | resolution: [216, 216] 58 | video_length: 16 59 | subset_split: train 60 | validation: 61 | target: data.dataset.DatasetVideoLoader 62 | params: 63 | csv_file: path/to/your.csv 64 | resolution: [216, 216] 65 | video_length: 16 66 | subset_split: val 67 | 68 | lightning: 69 | find_unused_parameters: True 70 | callbacks: 71 | image_logger: 72 | target: utils.callbacks.ImageLogger 73 | params: 74 | batch_frequency: 1009 75 | max_images: 8 76 | metrics_over_trainsteps_checkpoint: 77 | target: pytorch_lightning.callbacks.ModelCheckpoint 78 | params: 79 | filename: '{epoch:06}-{step:09}' 80 | save_weights_only: False 81 | every_n_train_steps: 5000 82 | trainer: 83 | benchmark: True 84 | accumulate_grad_batches: 2 85 | batch_size: 1 86 | num_workers: 20 87 | max_epochs: 3000 88 | modelcheckpoint: 89 | target: pytorch_lightning.callbacks.ModelCheckpoint 90 | params: 91 | every_n_train_steps: 3000 92 | filename: "{epoch:04}-{step:06}" 93 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | from torch.utils.data import Dataset 5 | from decord import VideoReader, cpu 6 | import pandas as pd 7 | 8 | 9 | class DatasetVideoLoader(Dataset): 10 | """ 11 | Dataset for loading videos and captions from a CSV file. 12 | CSV file contains two columns: 'path' and 'text', where: 13 | - 'path' is the path to the video file 14 | - 'text' is the caption for the video. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | csv_file, 20 | resolution, 21 | video_length, 22 | frame_stride=4, 23 | subset_split="all", 24 | clip_length=1.0, 25 | random_stride=False, 26 | mode="video", 27 | ): 28 | self.csv_file = csv_file 29 | self.resolution = resolution 30 | self.video_length = video_length 31 | self.subset_split = subset_split 32 | self.frame_stride = frame_stride 33 | self.clip_length = clip_length 34 | self.random_stride = random_stride 35 | self.mode = mode 36 | 37 | assert self.subset_split in ["train", "test", "val", "all"] 38 | self.exts = ["avi", "mp4", "webm"] 39 | 40 | if isinstance(self.resolution, int): 41 | self.resolution = [self.resolution, self.resolution] 42 | 43 | # Load dataset from CSV file 44 | self._make_dataset() 45 | 46 | def _make_dataset(self): 47 | """ 48 | Load video paths and captions from the CSV file. 
49 | """ 50 | self.videos = pd.read_csv(self.csv_file) 51 | print(f"Loaded {len(self.videos)} videos from {self.csv_file}") 52 | 53 | if self.subset_split == "val": 54 | self.videos = self.videos[-300:] 55 | elif self.subset_split == "train": 56 | self.videos = self.videos[:-300] 57 | elif self.subset_split == "test": 58 | self.videos = self.videos[-30:] 59 | 60 | print(f"Number of videos = {len(self.videos)}") 61 | 62 | # Create video indices for image mode 63 | self.video_indices = list(range(len(self.videos))) 64 | 65 | def set_mode(self, mode): 66 | self.mode = mode 67 | 68 | def _get_video_path(self, index): 69 | return self.videos.iloc[index]["path"] 70 | 71 | def __getitem__(self, index): 72 | if self.mode == "image": 73 | return self.__getitem__images(index) 74 | else: 75 | return self.__getitem__video(index) 76 | 77 | def __getitem__video(self, index): 78 | while True: 79 | video_path = self.videos.iloc[index]["path"] 80 | caption = self.videos.iloc[index]["text"] 81 | 82 | try: 83 | video_reader = VideoReader( 84 | video_path, 85 | ctx=cpu(0), 86 | width=self.resolution[1], 87 | height=self.resolution[0], 88 | ) 89 | if len(video_reader) < self.video_length: 90 | index = (index + 1) % len(self.videos) 91 | continue 92 | else: 93 | break 94 | except Exception as e: 95 | print(f"Load video failed! path = {video_path}, error: {str(e)}") 96 | index = (index + 1) % len(self.videos) 97 | continue 98 | 99 | if self.random_stride: 100 | self.frame_stride = random.choice([4, 8, 12, 16]) 101 | 102 | all_frames = list(range(0, len(video_reader), self.frame_stride)) 103 | if len(all_frames) < self.video_length: 104 | all_frames = list(range(0, len(video_reader), 1)) 105 | 106 | # Select random clip 107 | rand_idx = random.randint(0, len(all_frames) - self.video_length) 108 | frame_indices = all_frames[rand_idx : rand_idx + self.video_length] 109 | frames = video_reader.get_batch(frame_indices) 110 | assert ( 111 | frames.shape[0] == self.video_length 112 | ), f"{len(frames)}, self.video_length={self.video_length}" 113 | 114 | frames = ( 115 | torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() 116 | ) # [t,h,w,c] -> [c,t,h,w] 117 | assert ( 118 | frames.shape[2] == self.resolution[0] 119 | and frames.shape[3] == self.resolution[1] 120 | ), f"frames={frames.shape}, self.resolution={self.resolution}" 121 | frames = (frames / 255 - 0.5) * 2 122 | 123 | return {"video": frames, "caption": caption, "is_video": True} 124 | 125 | def __getitem__images(self, index): 126 | frames_list = [] 127 | for i in range(self.video_length): 128 | # Get a unique video for each frame 129 | video_index = (index + i) % len(self.video_indices) 130 | video_path = self._get_video_path(video_index) 131 | 132 | try: 133 | video_reader = VideoReader( 134 | video_path, 135 | ctx=cpu(0), 136 | width=self.resolution[1], 137 | height=self.resolution[0], 138 | ) 139 | except Exception as e: 140 | print(f"Load video failed! 
path = {video_path}, error = {e}") 141 | # Skip this video and try the next one 142 | return self.__getitem__images((index + 1) % len(self.video_indices)) 143 | 144 | # Randomly select a frame from the video 145 | rand_idx = random.randint(0, len(video_reader) - 1) 146 | frame = video_reader[rand_idx] 147 | frame_tensor = ( 148 | torch.tensor(frame.asnumpy()).permute(2, 0, 1).float().unsqueeze(0) 149 | ) # [h,w,c] -> [c,h,w] -> [1, c, h, w] 150 | 151 | frames_list.append(frame_tensor) 152 | 153 | frames = torch.cat(frames_list, dim=0) 154 | frames = (frames / 255 - 0.5) * 2 155 | frames = frames.permute(1, 0, 2, 3) 156 | assert ( 157 | frames.shape[2] == self.resolution[0] 158 | and frames.shape[3] == self.resolution[1] 159 | ), f"frame={frames.shape}, self.resolution={self.resolution}" 160 | 161 | data = {"video": frames, "is_video": False} 162 | return data 163 | 164 | def __len__(self): 165 | return len(self.videos) 166 | -------------------------------------------------------------------------------- /data/lightning_data.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | 4 | import torch 5 | import pytorch_lightning as pl 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | import argparse, os, sys, glob 9 | 10 | os.chdir(sys.path[0]) 11 | sys.path.append("..") 12 | 13 | from utils.common_utils import instantiate_from_config 14 | 15 | 16 | def worker_init_fn(_): 17 | worker_info = torch.utils.data.get_worker_info() 18 | 19 | dataset = worker_info.dataset 20 | worker_id = worker_info.id 21 | 22 | mode = "image" if worker_id < worker_info.num_workers * 0.2 else "video" 23 | print(f"Mode is {mode}") 24 | dataset.set_mode(mode) 25 | 26 | return np.random.seed(np.random.get_state()[1][0] + worker_id) 27 | 28 | 29 | class WrappedDataset(Dataset): 30 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" 31 | 32 | def __init__(self, dataset): 33 | self.data = dataset 34 | 35 | def __len__(self): 36 | return len(self.data) 37 | 38 | def __getitem__(self, idx): 39 | return self.data[idx] 40 | 41 | 42 | class DataModuleFromConfig(pl.LightningDataModule): 43 | def __init__( 44 | self, 45 | batch_size, 46 | train=None, 47 | validation=None, 48 | test=None, 49 | predict=None, 50 | wrap=False, 51 | num_workers=None, 52 | shuffle_test_loader=False, 53 | img_video_joint_train=False, 54 | shuffle_val_dataloader=False, 55 | train_img=None, 56 | test_max_n_samples=None, 57 | ): 58 | super().__init__() 59 | self.batch_size = batch_size 60 | self.dataset_configs = dict() 61 | self.num_workers = num_workers if num_workers is not None else batch_size * 2 62 | self.use_worker_init_fn = img_video_joint_train 63 | if train is not None: 64 | self.dataset_configs["train"] = train 65 | self.train_dataloader = self._train_dataloader 66 | if validation is not None: 67 | self.dataset_configs["validation"] = validation 68 | self.val_dataloader = partial( 69 | self._val_dataloader, shuffle=shuffle_val_dataloader 70 | ) 71 | if test is not None: 72 | self.dataset_configs["test"] = test 73 | self.test_dataloader = partial( 74 | self._test_dataloader, shuffle=shuffle_test_loader 75 | ) 76 | if predict is not None: 77 | self.dataset_configs["predict"] = predict 78 | self.predict_dataloader = self._predict_dataloader 79 | # train 2 dataset 80 | # if img_loader is not None: 81 | # img_data = instantiate_from_config(img_loader) 82 | # img_data.setup() 83 | if train_img is not None: 84 | if 
train_img["params"]["batch_size"] == -1: 85 | train_img["params"]["batch_size"] = ( 86 | batch_size * train["params"]["video_length"] 87 | ) 88 | print( 89 | "Set train_img batch_size to {}".format( 90 | train_img["params"]["batch_size"] 91 | ) 92 | ) 93 | img_data = instantiate_from_config(train_img) 94 | self.img_loader = img_data.train_dataloader() 95 | else: 96 | self.img_loader = None 97 | self.wrap = wrap 98 | self.test_max_n_samples = test_max_n_samples 99 | self.collate_fn = None 100 | 101 | def prepare_data(self): 102 | # for data_cfg in self.dataset_configs.values(): 103 | # instantiate_from_config(data_cfg) 104 | pass 105 | 106 | def setup(self, stage=None): 107 | self.datasets = dict( 108 | (k, instantiate_from_config(self.dataset_configs[k])) 109 | for k in self.dataset_configs 110 | ) 111 | if self.wrap: 112 | for k in self.datasets: 113 | self.datasets[k] = WrappedDataset(self.datasets[k]) 114 | 115 | def _train_dataloader(self): 116 | if self.use_worker_init_fn: 117 | init_fn = worker_init_fn 118 | else: 119 | init_fn = None 120 | loader = DataLoader( 121 | self.datasets["train"], 122 | batch_size=self.batch_size, 123 | num_workers=self.num_workers, 124 | shuffle=True, 125 | worker_init_fn=init_fn, 126 | collate_fn=self.collate_fn, 127 | ) 128 | if self.img_loader is not None: 129 | return {"loader_video": loader, "loader_img": self.img_loader} 130 | else: 131 | return loader 132 | 133 | def _val_dataloader(self, shuffle=False): 134 | if self.use_worker_init_fn: 135 | init_fn = worker_init_fn 136 | else: 137 | init_fn = None 138 | return DataLoader( 139 | self.datasets["validation"], 140 | batch_size=self.batch_size, 141 | num_workers=self.num_workers, 142 | worker_init_fn=init_fn, 143 | shuffle=shuffle, 144 | collate_fn=self.collate_fn, 145 | ) 146 | 147 | def _test_dataloader(self, shuffle=False): 148 | if self.use_worker_init_fn: 149 | init_fn = worker_init_fn 150 | else: 151 | init_fn = None 152 | 153 | if self.test_max_n_samples is not None: 154 | dataset = torch.utils.data.Subset( 155 | self.datasets["test"], list(range(self.test_max_n_samples)) 156 | ) 157 | else: 158 | dataset = self.datasets["test"] 159 | return DataLoader( 160 | dataset, 161 | batch_size=self.batch_size, 162 | num_workers=self.num_workers, 163 | worker_init_fn=init_fn, 164 | shuffle=shuffle, 165 | collate_fn=self.collate_fn, 166 | ) 167 | 168 | def _predict_dataloader(self, shuffle=False): 169 | if self.use_worker_init_fn: 170 | init_fn = worker_init_fn 171 | else: 172 | init_fn = None 173 | return DataLoader( 174 | self.datasets["predict"], 175 | batch_size=self.batch_size, 176 | num_workers=self.num_workers, 177 | worker_init_fn=init_fn, 178 | collate_fn=self.collate_fn, 179 | ) 180 | -------------------------------------------------------------------------------- /docs/case1/fkanimal2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case1/fkanimal2.gif -------------------------------------------------------------------------------- /docs/case1/gtanimal2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case1/gtanimal2.gif -------------------------------------------------------------------------------- /docs/case2/fkcloseshot1.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case2/fkcloseshot1.gif -------------------------------------------------------------------------------- /docs/case2/gtcloseshot1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case2/gtcloseshot1.gif -------------------------------------------------------------------------------- /docs/case3/fkface.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case3/fkface.gif -------------------------------------------------------------------------------- /docs/case3/gtface.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case3/gtface.gif -------------------------------------------------------------------------------- /docs/case4/fkmotion4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case4/fkmotion4.gif -------------------------------------------------------------------------------- /docs/case4/gtmotion4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case4/gtmotion4.gif -------------------------------------------------------------------------------- /docs/case5/fkview7.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case5/fkview7.gif -------------------------------------------------------------------------------- /docs/case5/gtview7.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/case5/gtview7.gif -------------------------------------------------------------------------------- /docs/sota-table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/docs/sota-table.png -------------------------------------------------------------------------------- /evaluation/compute_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import math 5 | from glob import glob 6 | from skimage.metrics import structural_similarity as compare_ssim 7 | import imageio 8 | import lpips 9 | import torch 10 | from tqdm import tqdm 11 | import logging 12 | 13 | # Configure logging 14 | logging.basicConfig( 15 | level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s" 16 | ) 17 | 18 | # Argument parser 19 | parser = argparse.ArgumentParser( 20 | description="Compute PSNR, SSIM, and LPIPS for videos." 
21 | ) 22 | parser.add_argument( 23 | "--root1", 24 | "-r1", 25 | type=str, 26 | required=True, 27 | help="Directory for the first set of videos.", 28 | ) 29 | parser.add_argument( 30 | "--root2", 31 | "-r2", 32 | type=str, 33 | required=True, 34 | help="Directory for the second set of videos.", 35 | ) 36 | parser.add_argument("--ssim", action="store_true", default=False, help="Compute SSIM.") 37 | parser.add_argument("--psnr", action="store_true", default=False, help="Compute PSNR.") 38 | parser.add_argument( 39 | "--lpips", action="store_true", default=False, help="Compute LPIPS." 40 | ) 41 | 42 | args = parser.parse_args() 43 | 44 | # Define metric functions 45 | 46 | 47 | def compute_psnr(img1, img2): 48 | mse = np.mean((img1 / 255.0 - img2 / 255.0) ** 2) 49 | if mse < 1.0e-10: 50 | return 100 51 | PIXEL_MAX = 1 52 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 53 | 54 | 55 | def compute_ssim(img1, img2): 56 | if np.all(img1 == img1[0, 0, 0]) or np.all(img2 == img2[0, 0, 0]): 57 | return 1.0 58 | return compare_ssim(img1, img2, data_range=img1.max() - img1.min(), channel_axis=-1) 59 | 60 | 61 | def compute_lpips(img1, img2, loss_fn): 62 | img1_tensor = ( 63 | torch.from_numpy(img1 / 255.0) 64 | .float() 65 | .permute(2, 0, 1) 66 | .unsqueeze(0) 67 | .to("cuda:0") 68 | ) 69 | img2_tensor = ( 70 | torch.from_numpy(img2 / 255.0) 71 | .float() 72 | .permute(2, 0, 1) 73 | .unsqueeze(0) 74 | .to("cuda:0") 75 | ) 76 | 77 | img1_tensor = img1_tensor * 2 - 1 # Normalize to [-1, 1] 78 | img2_tensor = img2_tensor * 2 - 1 79 | 80 | return loss_fn(img1_tensor, img2_tensor).item() 81 | 82 | 83 | def read_video(file_path): 84 | try: 85 | video = imageio.get_reader(file_path) 86 | frames = [frame for frame in video] 87 | video.close() 88 | return frames 89 | except Exception as e: 90 | logging.error(f"Error reading video {file_path}: {e}") 91 | return [] 92 | 93 | 94 | def save_results(results, root1, root2, output_file="metrics.txt"): 95 | with open(output_file, "a") as f: 96 | f.write("\n") 97 | f.write(f"Root1: {root1}\n") 98 | f.write(f"Root2: {root2}\n") 99 | for metric, value in results.items(): 100 | f.write(f"{metric}: {value}\n") 101 | f.write("\n") 102 | logging.info(f"Results saved to {output_file}") 103 | 104 | 105 | def main(): 106 | # Load video paths 107 | all_videos1 = sorted(glob(os.path.join(args.root1, "*mp4"))) 108 | all_videos2 = sorted(glob(os.path.join(args.root2, "*mp4"))) 109 | 110 | assert len(all_videos1) == len( 111 | all_videos2 112 | ), f"Number of files mismatch: {len(all_videos1)} in {args.root1}, {len(all_videos2)} in {args.root2}" 113 | 114 | # Metrics storage 115 | metric_psnr = [] 116 | metric_ssim = [] 117 | metric_lpips = [] 118 | 119 | # Initialize LPIPS model if needed 120 | lpips_model = None 121 | if args.lpips: 122 | lpips_model = lpips.LPIPS(net="alex").to("cuda:0") 123 | logging.info("Initialized LPIPS model (AlexNet).") 124 | 125 | for vid1_path, vid2_path in tqdm( 126 | zip(all_videos1, all_videos2), total=len(all_videos1), desc="Processing videos" 127 | ): 128 | vid1_frames = read_video(vid1_path) 129 | vid2_frames = read_video(vid2_path) 130 | 131 | if not vid1_frames or not vid2_frames: 132 | logging.error( 133 | f"Skipping video pair due to read failure: {vid1_path}, {vid2_path}" 134 | ) 135 | continue 136 | 137 | assert len(vid1_frames) == len( 138 | vid2_frames 139 | ), f"Frame count mismatch: {len(vid1_frames)} in {vid1_path}, {len(vid2_frames)} in {vid2_path}" 140 | 141 | # Process each pair of frames 142 | for f1, f2 in zip(vid1_frames, 
vid2_frames): 143 | if args.psnr: 144 | try: 145 | psnr_value = compute_psnr(f1, f2) 146 | metric_psnr.append(psnr_value) 147 | except Exception as e: 148 | logging.error(f"Error computing PSNR for frames: {e}") 149 | 150 | if args.ssim: 151 | try: 152 | ssim_value = compute_ssim(f1, f2) 153 | metric_ssim.append(ssim_value) 154 | except Exception as e: 155 | logging.error(f"Error computing SSIM for frames: {e}") 156 | 157 | if args.lpips: 158 | try: 159 | lpips_value = compute_lpips(f1, f2, lpips_model) 160 | metric_lpips.append(lpips_value) 161 | except Exception as e: 162 | logging.error(f"Error computing LPIPS for frames: {e}") 163 | 164 | # Compute average metrics 165 | results = {} 166 | if args.psnr and metric_psnr: 167 | results["PSNR"] = sum(metric_psnr) / len(metric_psnr) 168 | if args.ssim and metric_ssim: 169 | results["SSIM"] = sum(metric_ssim) / len(metric_ssim) 170 | if args.lpips and metric_lpips: 171 | results["LPIPS"] = sum(metric_lpips) / len(metric_lpips) 172 | 173 | # Print and save results 174 | logging.info(f"Results: {results}") 175 | save_results(results, args.root1, args.root2) 176 | 177 | 178 | if __name__ == "__main__": 179 | main() 180 | -------------------------------------------------------------------------------- /evaluation/compute_metrics_img.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | import math 5 | from glob import glob 6 | from skimage.metrics import structural_similarity as compare_ssim 7 | import imageio 8 | import lpips 9 | import torch 10 | from tqdm import tqdm 11 | import logging 12 | 13 | # Configure logging 14 | logging.basicConfig( 15 | level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s" 16 | ) 17 | 18 | # Argument parser 19 | parser = argparse.ArgumentParser( 20 | description="Calculate PSNR, SSIM, and LPIPS between two sets of images." 21 | ) 22 | parser.add_argument( 23 | "--root1", 24 | "-r1", 25 | type=str, 26 | required=True, 27 | help="Directory for the first set of images.", 28 | ) 29 | parser.add_argument( 30 | "--root2", 31 | "-r2", 32 | type=str, 33 | required=True, 34 | help="Directory for the second set of images.", 35 | ) 36 | parser.add_argument("--ssim", action="store_true", default=False, help="Compute SSIM.") 37 | parser.add_argument("--psnr", action="store_true", default=False, help="Compute PSNR.") 38 | parser.add_argument( 39 | "--lpips", action="store_true", default=False, help="Compute LPIPS." 
40 | ) 41 | 42 | args = parser.parse_args() 43 | 44 | # Define metric functions 45 | 46 | 47 | def compute_psnr(img1, img2): 48 | mse = np.mean((img1 / 255.0 - img2 / 255.0) ** 2) 49 | if mse < 1.0e-10: 50 | return 100 51 | PIXEL_MAX = 1 52 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 53 | 54 | 55 | def compute_ssim(img1, img2): 56 | return compare_ssim(img1, img2, data_range=img1.max() - img1.min(), channel_axis=-1) 57 | 58 | 59 | def compute_lpips(img1, img2, loss_fn): 60 | img1_tensor = ( 61 | torch.from_numpy(img1 / 255.0) 62 | .float() 63 | .permute(2, 0, 1) 64 | .unsqueeze(0) 65 | .to("cuda:0") 66 | ) 67 | img2_tensor = ( 68 | torch.from_numpy(img2 / 255.0) 69 | .float() 70 | .permute(2, 0, 1) 71 | .unsqueeze(0) 72 | .to("cuda:0") 73 | ) 74 | 75 | img1_tensor = img1_tensor * 2 - 1 # Normalize to [-1, 1] 76 | img2_tensor = img2_tensor * 2 - 1 77 | 78 | return loss_fn(img1_tensor, img2_tensor).item() 79 | 80 | 81 | def read_image(file_path): 82 | try: 83 | return imageio.imread(file_path) 84 | except Exception as e: 85 | logging.error(f"Error reading image {file_path}: {e}") 86 | return None 87 | 88 | 89 | def save_results(results, root1, root2, output_file="metrics.txt"): 90 | with open(output_file, "a") as f: 91 | f.write("\n") 92 | f.write(f"Root1: {root1}\n") 93 | f.write(f"Root2: {root2}\n") 94 | for metric, value in results.items(): 95 | f.write(f"{metric}: {value}\n") 96 | f.write("\n") 97 | logging.info(f"Results saved to {output_file}") 98 | 99 | 100 | def main(): 101 | # Load image paths 102 | all_images1 = sorted(glob(os.path.join(args.root1, "*jpeg"))) 103 | all_images2 = sorted(glob(os.path.join(args.root2, "*jpeg"))) 104 | 105 | assert len(all_images1) == len( 106 | all_images2 107 | ), f"Number of files mismatch: {len(all_images1)} in {args.root1}, {len(all_images2)} in {args.root2}" 108 | 109 | # Metrics storage 110 | metric_psnr = [] 111 | metric_ssim = [] 112 | metric_lpips = [] 113 | 114 | lpips_model = None 115 | if args.lpips: 116 | lpips_model = lpips.LPIPS(net="alex").to("cuda:0") 117 | logging.info("Initialized LPIPS model (AlexNet).") 118 | 119 | # Compute metrics for each pair of images 120 | for i, (img1_path, img2_path) in enumerate( 121 | tqdm( 122 | zip(all_images1, all_images2), 123 | total=len(all_images1), 124 | desc="Processing images", 125 | ) 126 | ): 127 | img1 = read_image(img1_path) 128 | img2 = read_image(img2_path) 129 | if img1 is None or img2 is None: 130 | logging.warning(f"Skipping pair: {img1_path}, {img2_path}") 131 | continue 132 | 133 | if args.psnr: 134 | try: 135 | psnr_value = compute_psnr(img1, img2) 136 | metric_psnr.append(psnr_value) 137 | except Exception as e: 138 | logging.error(f"Error computing PSNR for {img1_path}, {img2_path}: {e}") 139 | 140 | if args.ssim: 141 | try: 142 | ssim_value = compute_ssim(img1, img2) 143 | metric_ssim.append(ssim_value) 144 | except Exception as e: 145 | logging.error(f"Error computing SSIM for {img1_path}, {img2_path}: {e}") 146 | 147 | if args.lpips: 148 | try: 149 | lpips_value = compute_lpips(img1, img2, lpips_model) 150 | metric_lpips.append(lpips_value) 151 | except Exception as e: 152 | logging.error( 153 | f"Error computing LPIPS for {img1_path}, {img2_path}: {e}" 154 | ) 155 | 156 | results = {} 157 | if args.psnr and metric_psnr: 158 | results["PSNR"] = sum(metric_psnr) / len(metric_psnr) 159 | if args.ssim and metric_ssim: 160 | results["SSIM"] = sum(metric_ssim) / len(metric_ssim) 161 | if args.lpips and metric_lpips: 162 | results["LPIPS"] = sum(metric_lpips) / 
len(metric_lpips) 163 | 164 | # Print and save results 165 | logging.info(f"Results: {results}") 166 | save_results(results, args.root1, args.root2) 167 | 168 | 169 | if __name__ == "__main__": 170 | main() 171 | -------------------------------------------------------------------------------- /examples/images/gt/00000091.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00000091.jpg -------------------------------------------------------------------------------- /examples/images/gt/00000103.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00000103.jpg -------------------------------------------------------------------------------- /examples/images/gt/00000110.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00000110.jpg -------------------------------------------------------------------------------- /examples/images/gt/00000212.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00000212.jpg -------------------------------------------------------------------------------- /examples/images/gt/00000268.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00000268.jpg -------------------------------------------------------------------------------- /examples/images/gt/00000592.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00000592.jpg -------------------------------------------------------------------------------- /examples/images/gt/00006871.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00006871.jpg -------------------------------------------------------------------------------- /examples/images/gt/00007252.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00007252.jpg -------------------------------------------------------------------------------- /examples/images/gt/00007826.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00007826.jpg -------------------------------------------------------------------------------- /examples/images/gt/00008868.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/gt/00008868.jpg 
-------------------------------------------------------------------------------- /examples/images/recon/00000091.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00000091.jpeg -------------------------------------------------------------------------------- /examples/images/recon/00000110.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00000110.jpeg -------------------------------------------------------------------------------- /examples/images/recon/00000212.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00000212.jpeg -------------------------------------------------------------------------------- /examples/images/recon/00000268.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00000268.jpeg -------------------------------------------------------------------------------- /examples/images/recon/00000592.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00000592.jpeg -------------------------------------------------------------------------------- /examples/images/recon/00006871.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00006871.jpeg -------------------------------------------------------------------------------- /examples/images/recon/00007252.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00007252.jpeg -------------------------------------------------------------------------------- /examples/images/recon/00007826.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00007826.jpeg -------------------------------------------------------------------------------- /examples/images/recon/00008868.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/images/recon/00008868.jpeg -------------------------------------------------------------------------------- /examples/videos/gt/40.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/gt/40.mp4 -------------------------------------------------------------------------------- /examples/videos/gt/40.txt: -------------------------------------------------------------------------------- 1 | The video features a 
gray tabby cat lying on a wooden deck, grooming itself in the sunlight. The cat is seen licking its paw and then using it to clean its face and fur. The background includes some greenery, indicating an outdoor setting. -------------------------------------------------------------------------------- /examples/videos/gt/8.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/gt/8.mp4 -------------------------------------------------------------------------------- /examples/videos/gt/8.txt: -------------------------------------------------------------------------------- 1 | The video showcases a glassblowing process, featuring a skilled artisan working with molten glass in a furnace. The scene is set in a dark environment with a bright orange glow from the furnace illuminating the glass and the artisan's tools. The glassblowing process involves shaping and manipulating the molten glass to create various forms, highlighting the intricate and artistic nature of this traditional craft. -------------------------------------------------------------------------------- /examples/videos/gt/animal.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/gt/animal.mp4 -------------------------------------------------------------------------------- /examples/videos/gt/animal.txt: -------------------------------------------------------------------------------- 1 | The video features a juvenile Black-crowned Night Heron perched on a branch. The heron is primarily gray and brown with white speckling on its feathers. It has a long, pointed beak and is actively preening its feathers, using its beak to smooth and clean them. The background is a soft, blurred green, suggesting foliage and possibly other trees. The lighting is natural and warm, indicating that the shot was likely taken during the day. The overall tone of the video is peaceful and focused on the natural behavior of the bird. A small logo "8 EARTH" is visible in the bottom left corner. -------------------------------------------------------------------------------- /examples/videos/gt/closeshot.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/gt/closeshot.mp4 -------------------------------------------------------------------------------- /examples/videos/gt/closeshot.txt: -------------------------------------------------------------------------------- 1 | The video features a detailed macro view of a 50mm prime lens, likely for a DSLR or mirrorless camera. The image highlights the lens's build quality, showcasing the textured grip and the aperture markings ranging from f/1.8 to f/16. The distance scale is also visible, indicating focus points. A red accent ring around the lens barrel provides a subtle visual highlight. The lens's glass reflects light, creating a soft, out-of-focus bokeh effect. This shot emphasizes the precision and design of professional camera equipment. 
-------------------------------------------------------------------------------- /examples/videos/gt/face.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/gt/face.mp4 -------------------------------------------------------------------------------- /examples/videos/gt/face.txt: -------------------------------------------------------------------------------- 1 | The video captures a man with a beard and glasses talking to the camera. He is gesturing with his hands as he explains something. Behind him, there are Star Trek action figures in their boxes and other items on a wooden cabinet. The man is likely discussing or explaining a process or topic. -------------------------------------------------------------------------------- /examples/videos/gt/view.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/gt/view.mp4 -------------------------------------------------------------------------------- /examples/videos/gt/view.txt: -------------------------------------------------------------------------------- 1 | The video captures the vibrant and unique atmosphere of Fremont Street in Las Vegas. The combination of the large mechanical praying mantis, the fire effects, and the text overlay suggests a lively and engaging walking tour experience. The video is likely the title card or intro to a video about Fremont Street. -------------------------------------------------------------------------------- /examples/videos/recon/40_reconstructed.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/recon/40_reconstructed.mp4 -------------------------------------------------------------------------------- /examples/videos/recon/8_reconstructed.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/recon/8_reconstructed.mp4 -------------------------------------------------------------------------------- /examples/videos/recon/animal_reconstructed.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/recon/animal_reconstructed.mp4 -------------------------------------------------------------------------------- /examples/videos/recon/closeshot_reconstructed.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/recon/closeshot_reconstructed.mp4 -------------------------------------------------------------------------------- /examples/videos/recon/face_reconstructed.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/recon/face_reconstructed.mp4 -------------------------------------------------------------------------------- /examples/videos/recon/view_reconstructed.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/VideoVerses/VideoVAEPlus/c829d91254e63dbfe0cbcb97f62c40768a099698/examples/videos/recon/view_reconstructed.mp4 -------------------------------------------------------------------------------- /inference_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | from glob import glob 5 | import argparse 6 | from omegaconf import OmegaConf 7 | from utils.common_utils import instantiate_from_config 8 | import torchvision.transforms as transforms 9 | import numpy as np 10 | from PIL import Image 11 | 12 | logging.basicConfig( 13 | level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s" 14 | ) 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser(description="Image Inference Script") 19 | parser.add_argument( 20 | "--data_root", 21 | type=str, 22 | required=True, 23 | help="Path to the folder containing input images.", 24 | ) 25 | parser.add_argument( 26 | "--out_root", type=str, required=True, help="Path to save reconstructed images." 27 | ) 28 | parser.add_argument( 29 | "--config_path", 30 | type=str, 31 | required=True, 32 | help="Path to the model configuration file.", 33 | ) 34 | parser.add_argument( 35 | "--batch_size", type=int, default=16, help="Batch size for image processing." 36 | ) 37 | parser.add_argument( 38 | "--device", 39 | type=str, 40 | default="cuda:0", 41 | help="Device to run inference on (e.g., 'cpu', 'cuda:0').", 42 | ) 43 | return parser.parse_args() 44 | 45 | 46 | def data_processing(img_path): 47 | try: 48 | img = Image.open(img_path).convert("RGB") 49 | transform = transforms.Compose( 50 | [ 51 | transforms.ToTensor(), 52 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 53 | ] 54 | ) 55 | return transform(img) 56 | except Exception as e: 57 | logging.error(f"Error processing image {img_path}: {e}") 58 | return None 59 | 60 | 61 | def save_img(tensor, save_path): 62 | try: 63 | tensor = (tensor + 1) / 2 # Denormalize 64 | tensor = tensor.clamp(0, 1).detach().cpu() 65 | to_pil = transforms.ToPILImage() 66 | img = to_pil(tensor) 67 | img.save(save_path, format="JPEG") 68 | logging.info(f"Image saved to {save_path}") 69 | except Exception as e: 70 | logging.error(f"Error saving image to {save_path}: {e}") 71 | 72 | 73 | def process_batch(image_list, img_name_list, model, device, out_root): 74 | try: 75 | frames = torch.stack(image_list) # [batch_size, c, h, w] 76 | frames = frames.unsqueeze(1) # [batch_size, 1, c, h, w] 77 | frames = frames.permute(0, 2, 1, 3, 4) # [batch_size, c, 1, h, w] 78 | 79 | with torch.no_grad(): 80 | frames = frames.to(device) 81 | dec, _ = model.forward(frames, sample_posterior=False, mask_temporal=True) 82 | dec = dec.squeeze(2) # [batch_size, c, h, w] 83 | 84 | for i in range(len(image_list)): 85 | output_img = dec[i] 86 | save_img(output_img, os.path.join(out_root, img_name_list[i] + ".jpeg")) 87 | except Exception as e: 88 | logging.error(f"Error processing batch: {e}") 89 | 90 | 91 | def main(): 92 | args = parse_args() 93 | 94 | os.makedirs(args.out_root, exist_ok=True) 95 | 96 | config = OmegaConf.load(args.config_path) 97 | model = instantiate_from_config(config.model) 98 | model = model.to(args.device) 99 | model.eval() 100 | 101 | # Load all image paths 102 | all_images = sorted(glob(os.path.join(args.data_root, "*jpeg"))) 103 | if not all_images: 104 | logging.error(f"No images found in 
{args.data_root}") 105 | return 106 | 107 | batch_size = args.batch_size 108 | image_list = [] 109 | img_name_list = [] 110 | 111 | logging.info(f"Starting inference on {len(all_images)} images...") 112 | 113 | for img_path in all_images: 114 | img = data_processing(img_path) # [c, h, w] 115 | if img is None: 116 | logging.warning(f"Skipping invalid image {img_path}") 117 | continue 118 | 119 | img_name = os.path.basename(img_path).split(".")[0] 120 | image_list.append(img) 121 | img_name_list.append(img_name) 122 | 123 | # Process a batch when full 124 | if len(image_list) == batch_size: 125 | logging.info(f"Processing batch of {batch_size} images...") 126 | process_batch(image_list, img_name_list, model, args.device, args.out_root) 127 | 128 | # Clear lists for next batch 129 | image_list = [] 130 | img_name_list = [] 131 | 132 | # Process any remaining images 133 | if len(image_list) > 0: 134 | logging.info(f"Processing remaining {len(image_list)} images...") 135 | process_batch(image_list, img_name_list, model, args.device, args.out_root) 136 | 137 | logging.info("Inference completed successfully!") 138 | 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /inference_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import logging 5 | from decord import VideoReader, cpu 6 | from glob import glob 7 | from omegaconf import OmegaConf 8 | import numpy as np 9 | import imageio 10 | from tqdm import tqdm 11 | from utils.common_utils import instantiate_from_config 12 | from src.modules.t5 import T5Embedder 13 | import torchvision 14 | 15 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 16 | logging.basicConfig( 17 | level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s" 18 | ) 19 | 20 | 21 | def parse_args(): 22 | """Parse command-line arguments.""" 23 | parser = argparse.ArgumentParser(description="Video VAE Inference Script") 24 | parser.add_argument( 25 | "--data_root", 26 | type=str, 27 | required=True, 28 | help="Path to the folder containing input videos.", 29 | ) 30 | parser.add_argument( 31 | "--out_root", type=str, required=True, help="Path to save reconstructed videos." 
32 | ) 33 | parser.add_argument( 34 | "--config_path", 35 | type=str, 36 | required=True, 37 | help="Path to the model configuration file.", 38 | ) 39 | parser.add_argument( 40 | "--device", 41 | type=str, 42 | default="cuda:0", 43 | help="Device to run inference on (e.g., 'cpu', 'cuda:0').", 44 | ) 45 | parser.add_argument( 46 | "--chunk_size", 47 | type=int, 48 | default=16, 49 | help="Number of frames per chunk for processing.", 50 | ) 51 | parser.add_argument( 52 | "--resolution", 53 | type=int, 54 | nargs=2, 55 | default=[720, 1280], 56 | help="Resolution to process videos (height, width).", 57 | ) 58 | return parser.parse_args() 59 | 60 | 61 | def data_processing(video_path, resolution): 62 | """Load and preprocess video data.""" 63 | try: 64 | video_reader = VideoReader(video_path, ctx=cpu(0)) 65 | video_resolution = video_reader[0].shape 66 | 67 | # Rescale resolution to match specified limits 68 | resolution = [ 69 | min(video_resolution[0], resolution[0]), 70 | min(video_resolution[1], resolution[1]), 71 | ] 72 | video_reader = VideoReader( 73 | video_path, ctx=cpu(0), width=resolution[1], height=resolution[0] 74 | ) 75 | 76 | video_length = len(video_reader) 77 | vid_fps = video_reader.get_avg_fps() 78 | frame_indices = list(range(0, video_length)) 79 | frames = video_reader.get_batch(frame_indices) 80 | assert ( 81 | frames.shape[0] == video_length 82 | ), f"Frame mismatch: {len(frames)} != {video_length}" 83 | 84 | frames = ( 85 | torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float() 86 | ) # [t, h, w, c] -> [c, t, h, w] 87 | frames = (frames / 255 - 0.5) * 2 # Normalize to [-1, 1] 88 | return frames, vid_fps 89 | except Exception as e: 90 | logging.error(f"Error processing video {video_path}: {e}") 91 | return None, None 92 | 93 | 94 | def save_video(tensor, save_path, fps: float): 95 | """Save video tensor to a file.""" 96 | try: 97 | tensor = torch.clamp((tensor + 1) / 2, 0, 1) * 255 98 | arr = tensor.detach().cpu().squeeze().to(torch.uint8) 99 | c, t, h, w = arr.shape 100 | 101 | torchvision.io.write_video(save_path, arr.permute(1, 2, 3, 0), fps=fps, options={'codec': 'libx264', 'crf': '15'}) 102 | logging.info(f"Video saved to {save_path}") 103 | except Exception as e: 104 | logging.error(f"Error saving video {save_path}: {e}") 105 | 106 | 107 | def process_in_chunks( 108 | video_data, 109 | model, 110 | chunk_size, 111 | text_embeddings=None, 112 | text_attn_mask=None, 113 | device="cuda:0", 114 | ): 115 | try: 116 | assert chunk_size % 4 == 0, "Chunk size must be a multiple of 4." 
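        # Chunks are kept 4-frame aligned, presumably to match the temporal downsampling
        # factor of the autoencoder (EncoderTemporal1DCNN defaults to temporal_scale_factor=4):
        # inputs whose length is not a multiple of 4 are padded below by repeating the last
        # frame, and the padded frames are trimmed again after decoding.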
117 | num_frames = video_data.size(2) 118 | padding_frames = 0 119 | output_chunks = [] 120 | 121 | # Pad video to make the frame count divisible by 4 122 | if num_frames % 4 != 0: 123 | padding_frames = 4 - (num_frames % 4) 124 | padding = video_data[:, :, -1:, :, :].repeat(1, 1, padding_frames, 1, 1) 125 | video_data = torch.cat((video_data, padding), dim=2) 126 | num_frames = video_data.size(2) 127 | 128 | start = 0 129 | 130 | while start < num_frames: 131 | end = min(start + chunk_size, num_frames) 132 | chunk = video_data[:, :, start:end, :, :] 133 | 134 | with torch.no_grad(): 135 | chunk = chunk.to(device) 136 | if text_embeddings is not None and text_attn_mask is not None: 137 | recon_chunk, _ = model.forward( 138 | chunk, 139 | text_embeddings=text_embeddings, 140 | text_attn_mask=text_attn_mask, 141 | sample_posterior=False, 142 | ) 143 | else: 144 | recon_chunk, _ = model.forward(chunk, sample_posterior=False) 145 | recon_chunk = recon_chunk.cpu().float() 146 | output_chunks.append(recon_chunk) 147 | start += chunk_size 148 | 149 | ret = torch.cat(output_chunks, dim=2) 150 | if padding_frames > 0: 151 | ret = ret[:, :, :-padding_frames, :, :] 152 | return ret 153 | except Exception as e: 154 | logging.error(f"Error processing chunks: {e}") 155 | return None 156 | 157 | 158 | def main(): 159 | """Main function for video VAE inference.""" 160 | args = parse_args() 161 | 162 | os.makedirs(args.out_root, exist_ok=True) 163 | config = OmegaConf.load(args.config_path) 164 | 165 | # Initialize model 166 | model = instantiate_from_config(config.model) 167 | is_t5 = getattr(model, "caption_guide", False) 168 | model = model.to(args.device) 169 | model.eval() 170 | 171 | # Initialize text embedder if T5 is used 172 | text_embedder = None 173 | if is_t5: 174 | text_embedder = T5Embedder( 175 | device=args.device, model_max_length=model.t5_model_max_length 176 | ) 177 | 178 | # Get all videos 179 | all_videos = sorted(glob(os.path.join(args.data_root, "*.mp4"))) 180 | if not all_videos: 181 | logging.error(f"No videos found in {args.data_root}") 182 | return 183 | 184 | # Process each video 185 | for video_path in tqdm(all_videos, desc="Processing videos", unit="video"): 186 | logging.info(f"Processing video: {video_path}") 187 | frames, vid_fps = data_processing(video_path, args.resolution) 188 | if frames is None: 189 | continue 190 | 191 | video_name = os.path.basename(video_path).split(".")[0] 192 | frames = torch.unsqueeze(frames, dim=0) # Add batch dimension 193 | 194 | with torch.no_grad(): 195 | if is_t5: 196 | # Load caption if available 197 | text_path = os.path.join(args.data_root, f"{video_name}.txt") 198 | try: 199 | with open(text_path, "r") as f: 200 | caption = [f.read()] 201 | except Exception as e: 202 | logging.warning(f"Caption file not found for {video_name}: {e}") 203 | caption = [""] 204 | 205 | text_embedding, text_attn_mask = text_embedder.get_text_embeddings( 206 | caption 207 | ) 208 | text_embedding = text_embedding.to(args.device, dtype=model.dtype) 209 | text_attn_mask = text_attn_mask.to(args.device, dtype=model.dtype) 210 | 211 | video_recon = process_in_chunks( 212 | frames, 213 | model, 214 | args.chunk_size, 215 | text_embedding, 216 | text_attn_mask, 217 | device=args.device, 218 | ) 219 | else: 220 | video_recon = process_in_chunks( 221 | frames, model, args.chunk_size, device=args.device 222 | ) 223 | 224 | if video_recon is not None: 225 | save_path = os.path.join( 226 | args.out_root, f"{video_name}_reconstructed.mp4" 227 | ) 228 | 
save_video(video_recon, save_path, vid_fps) 229 | 230 | 231 | if __name__ == "__main__": 232 | main() 233 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | av==12.0.0 2 | accelerate==0.34.2 3 | academictorrents==2.3.3 4 | albumentations==1.4.16 5 | apex==0.9.10dev 6 | beautifulsoup4==4.12.3 7 | decord==0.6.0 8 | diffusers==0.30.3 9 | einops==0.8.0 10 | fairscale==0.4.13 11 | ftfy==6.2.3 12 | huggingface-hub==0.23.2 13 | imageio==2.33.1 14 | kornia==0.7.3 15 | moviepy==1.0.3 16 | more_itertools==10.5.0 17 | numpy==1.26.3 18 | nvidia-cublas-cu12==12.1.3.1 19 | nvidia-cuda-cupti-cu12==12.1.105 20 | nvidia-cuda-nvrtc-cu12==12.1.105 21 | nvidia-cuda-runtime-cu12==12.1.105 22 | nvidia-cudnn-cu12==8.9.2.26 23 | nvidia-cufft-cu12==11.0.2.54 24 | nvidia-curand-cu12==10.3.2.106 25 | nvidia-cusolver-cu12==11.4.5.107 26 | nvidia-cusparse-cu12==12.1.0.106 27 | nvidia-nccl-cu12==2.19.3 28 | nvidia-nvjitlink-cu12==12.1.105 29 | nvidia-nvtx-cu12==12.1.105 30 | omegaconf==2.3.0 31 | opencv-python==4.9.0.80 32 | packaging==24.0 33 | pandas==2.2.1 34 | psutil==6.0.0 35 | Pillow==10.4.0 36 | pytorch_lightning==1.9.4 37 | PyYAML==6.0.1 38 | protobuf==3.20.* 39 | Requests==2.32.3 40 | safetensors==0.4.5 41 | scipy==1.14.1 42 | sentencepiece==0.2.0 43 | tensorboard==2.18.0 44 | taming-transformers==0.0.1 45 | tensorboardX==2.6.2.2 46 | timm==1.0.9 47 | torch==2.2.0 48 | torchaudio==2.2.0 49 | torchmetrics==1.3.1 50 | torchvision==0.17.0 51 | tokenizers==0.13.3 52 | tqdm==4.66.2 53 | transformers==4.25.1 54 | typing_extensions==4.12.2 55 | xformers==0.0.24 56 | lpips -------------------------------------------------------------------------------- /scripts/evaluation_image.sh: -------------------------------------------------------------------------------- 1 | python evaluation/compute_metrics_img.py \ 2 | --root1 "examples/images/gt" \ 3 | --root2 "examples/images/recon" \ 4 | --ssim \ 5 | --psnr \ 6 | --lpips -------------------------------------------------------------------------------- /scripts/evaluation_video.sh: -------------------------------------------------------------------------------- 1 | python evaluation/compute_metrics.py \ 2 | --root1 "examples/videos/gt" \ 3 | --root2 "examples/videos/recon" \ 4 | --ssim \ 5 | --psnr \ 6 | --lpips -------------------------------------------------------------------------------- /scripts/run_inference_image.sh: -------------------------------------------------------------------------------- 1 | python inference_image.py \ 2 | --data_root 'examples/images/gt' \ 3 | --out_root 'examples/images/recon' \ 4 | --config_path 'configs/inference/config_16z.yaml' \ 5 | --batch_size 1 -------------------------------------------------------------------------------- /scripts/run_inference_video.sh: -------------------------------------------------------------------------------- 1 | python inference_video.py \ 2 | --data_root 'examples/videos/gt' \ 3 | --out_root 'examples/videos/recon' \ 4 | --config_path 'configs/inference/config_16z.yaml' \ 5 | --chunk_size 8 --resolution 720 1280 6 | -------------------------------------------------------------------------------- /scripts/run_train.sh: -------------------------------------------------------------------------------- 1 | yaml="configs/train/$1.yaml" 2 | exp_name="VideoVAEPlus_$1" 3 | 4 | n_HOST=1 5 | elastic=1 6 | GPUName="A" 7 | current_time=$(date +%Y%m%d%H%M%S) 8 | 9 | 
out_dir_name="${exp_name}_${n_HOST}nodes_e${elastic}_${GPUName}_$current_time" 10 | res_root="./debug" 11 | 12 | mkdir -p $res_root/$out_dir_name 13 | 14 | torchrun \ 15 | --nproc_per_node=1 --nnodes=1 --master_port=16666 \ 16 | train.py \ 17 | --base $yaml \ 18 | -t --devices 0, \ 19 | lightning.trainer.num_nodes=1 \ 20 | --name ${out_dir_name} \ 21 | --logdir $res_root \ 22 | --auto_resume True \ -------------------------------------------------------------------------------- /src/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self, noise=None): 38 | if noise is None: 39 | noise = torch.randn(self.mean.shape) 40 | 41 | x = self.mean + self.std * noise.to(device=self.parameters.device) 42 | return x 43 | 44 | def kl(self, other=None): 45 | if self.deterministic: 46 | return torch.Tensor([0.0]) 47 | else: 48 | if other is None: 49 | return 0.5 * torch.sum( 50 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 51 | dim=[1, 2, 3], 52 | ) 53 | else: 54 | return 0.5 * torch.sum( 55 | torch.pow(self.mean - other.mean, 2) / other.var 56 | + self.var / other.var 57 | - 1.0 58 | - self.logvar 59 | + other.logvar, 60 | dim=[1, 2, 3], 61 | ) 62 | 63 | def nll(self, sample, dims=[1, 2, 3]): 64 | if self.deterministic: 65 | return torch.Tensor([0.0]) 66 | logtwopi = np.log(2.0 * np.pi) 67 | return 0.5 * torch.sum( 68 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 69 | dim=dims, 70 | ) 71 | 72 | def mode(self): 73 | return self.mean 74 | 75 | 76 | def normal_kl(mean1, logvar1, mean2, logvar2): 77 | """ 78 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 79 | Compute the KL divergence between two gaussians. 80 | Shapes are automatically broadcasted, so batches can be compared to 81 | scalars, among other use cases. 82 | """ 83 | tensor = None 84 | for obj in (mean1, logvar1, mean2, logvar2): 85 | if isinstance(obj, torch.Tensor): 86 | tensor = obj 87 | break 88 | assert tensor is not None, "at least one argument must be a Tensor" 89 | 90 | # Force variances to be Tensors. Broadcasting helps convert scalars to 91 | # Tensors, but it does not work for torch.exp(). 
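    # The return statement below implements the element-wise closed form
    #   KL(N(m1, v1) || N(m2, v2)) = 0.5 * (log(v2 / v1) + (v1 + (m1 - m2)^2) / v2 - 1)
    # with v = exp(logvar), broadcast across the input shapes.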
92 | logvar1, logvar2 = [ 93 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 94 | for x in (logvar1, logvar2) 95 | ] 96 | 97 | return 0.5 * ( 98 | -1.0 99 | + logvar2 100 | - logvar1 101 | + torch.exp(logvar1 - logvar2) 102 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 103 | ) 104 | -------------------------------------------------------------------------------- /src/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | 4 | import torch.nn.functional as F 5 | import pytorch_lightning as pl 6 | 7 | from src.modules.ae_modules import Encoder, Decoder 8 | from src.distributions import DiagonalGaussianDistribution 9 | from utils.common_utils import instantiate_from_config 10 | 11 | 12 | class AutoencoderKL(pl.LightningModule): 13 | def __init__( 14 | self, 15 | ddconfig, 16 | lossconfig, 17 | embed_dim, 18 | use_quant_conv=True, 19 | ckpt_path=None, 20 | ignore_keys=[], 21 | image_key="image", 22 | colorize_nlabels=None, 23 | monitor=None, 24 | test=False, 25 | logdir=None, 26 | input_dim=4, 27 | test_args=None, 28 | ): 29 | super().__init__() 30 | self.image_key = image_key 31 | self.encoder = Encoder(**ddconfig) 32 | self.decoder = Decoder(**ddconfig) 33 | self.loss = instantiate_from_config(lossconfig) 34 | assert ddconfig["double_z"] 35 | 36 | if use_quant_conv: 37 | self.quant_conv = torch.nn.Conv2d( 38 | 2 * ddconfig["z_channels"], 2 * embed_dim, 1 39 | ) 40 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 41 | self.embed_dim = embed_dim 42 | 43 | self.use_quant_conv = use_quant_conv 44 | 45 | self.input_dim = input_dim 46 | self.test = test 47 | self.test_args = test_args 48 | self.logdir = logdir 49 | if colorize_nlabels is not None: 50 | assert type(colorize_nlabels) == int 51 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 52 | if monitor is not None: 53 | self.monitor = monitor 54 | if ckpt_path is not None: 55 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 56 | 57 | def init_from_ckpt(self, path, ignore_keys=list()): 58 | sd = torch.load(path, map_location="cpu") 59 | try: 60 | self._cur_epoch = sd["epoch"] 61 | sd = sd["state_dict"] 62 | except: 63 | self._cur_epoch = "null" 64 | keys = list(sd.keys()) 65 | for k in keys: 66 | for ik in ignore_keys: 67 | if k.startswith(ik): 68 | # print("Deleting key {} from state_dict.".format(k)) 69 | del sd[k] 70 | self.load_state_dict(sd, strict=False) 71 | # self.load_state_dict(sd, strict=True) 72 | print(f"Restored from {path}") 73 | 74 | def encode(self, x, **kwargs): 75 | 76 | h = self.encoder(x) 77 | moments = h 78 | if self.use_quant_conv: 79 | moments = self.quant_conv(h) 80 | posterior = DiagonalGaussianDistribution(moments) 81 | return posterior 82 | 83 | def decode(self, z, **kwargs): 84 | if self.use_quant_conv: 85 | z = self.post_quant_conv(z) 86 | dec = self.decoder(z) 87 | return dec 88 | 89 | def forward(self, input, sample_posterior=True): 90 | posterior = self.encode(input) 91 | if sample_posterior: 92 | z = posterior.sample() 93 | else: 94 | z = posterior.mode() 95 | dec = self.decode(z) 96 | return dec, posterior 97 | 98 | def get_input(self, batch, k): 99 | x = batch[k] 100 | if x.dim() == 5 and self.input_dim == 4: 101 | b, c, t, h, w = x.shape 102 | self.b = b 103 | self.t = t 104 | x = rearrange(x, "b c t h w -> (b t) c h w") 105 | 106 | return x 107 | 108 | def training_step(self, batch, batch_idx, optimizer_idx): 109 | inputs 
= self.get_input(batch, self.image_key) 110 | reconstructions, posterior = self(inputs) 111 | 112 | if optimizer_idx == 0: 113 | # train encoder+decoder+logvar 114 | aeloss, log_dict_ae = self.loss( 115 | inputs, 116 | reconstructions, 117 | posterior, 118 | optimizer_idx, 119 | self.global_step, 120 | last_layer=self.get_last_layer(), 121 | split="train", 122 | ) 123 | self.log( 124 | "aeloss", 125 | aeloss, 126 | prog_bar=True, 127 | logger=True, 128 | on_step=True, 129 | on_epoch=True, 130 | ) 131 | self.log_dict( 132 | log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False 133 | ) 134 | return aeloss 135 | 136 | if optimizer_idx == 1: 137 | # train the discriminator 138 | discloss, log_dict_disc = self.loss( 139 | inputs, 140 | reconstructions, 141 | posterior, 142 | optimizer_idx, 143 | self.global_step, 144 | last_layer=self.get_last_layer(), 145 | split="train", 146 | ) 147 | 148 | self.log( 149 | "discloss", 150 | discloss, 151 | prog_bar=True, 152 | logger=True, 153 | on_step=True, 154 | on_epoch=True, 155 | ) 156 | self.log_dict( 157 | log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False 158 | ) 159 | return discloss 160 | 161 | def validation_step(self, batch, batch_idx): 162 | inputs = self.get_input(batch, self.image_key) 163 | reconstructions, posterior = self(inputs) 164 | aeloss, log_dict_ae = self.loss( 165 | inputs, 166 | reconstructions, 167 | posterior, 168 | 0, 169 | self.global_step, 170 | last_layer=self.get_last_layer(), 171 | split="val", 172 | ) 173 | 174 | discloss, log_dict_disc = self.loss( 175 | inputs, 176 | reconstructions, 177 | posterior, 178 | 1, 179 | self.global_step, 180 | last_layer=self.get_last_layer(), 181 | split="val", 182 | ) 183 | 184 | recontructions = reconstructions.cpu().detach() 185 | 186 | self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) 187 | self.log_dict(log_dict_ae) 188 | self.log_dict(log_dict_disc) 189 | return self.log_dict 190 | 191 | def configure_optimizers(self): 192 | lr = self.learning_rate 193 | opt_ae = torch.optim.Adam( 194 | list(self.encoder.parameters()) 195 | + list(self.decoder.parameters()) 196 | + list(self.quant_conv.parameters()) 197 | + list(self.post_quant_conv.parameters()), 198 | lr=lr, 199 | betas=(0.5, 0.9), 200 | ) 201 | opt_disc = torch.optim.Adam( 202 | self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) 203 | ) 204 | return [opt_ae, opt_disc], [] 205 | 206 | def get_last_layer(self): 207 | return self.decoder.conv_out.weight 208 | 209 | @torch.no_grad() 210 | def log_images(self, batch, only_inputs=False, **kwargs): 211 | log = dict() 212 | x = self.get_input(batch, self.image_key) 213 | x = x.to(self.device) 214 | if not only_inputs: 215 | xrec, posterior = self(x) 216 | if x.shape[1] > 3: 217 | # colorize with random projection 218 | assert xrec.shape[1] > 3 219 | x = self.to_rgb(x) 220 | xrec = self.to_rgb(xrec) 221 | 222 | log["samples"] = self.decode(torch.randn_like(posterior.sample())) 223 | xrec = xrec.cpu().detach() 224 | log["reconstructions"] = xrec 225 | 226 | x = x.cpu().detach() 227 | log["inputs"] = x 228 | return log 229 | 230 | def to_rgb(self, x): 231 | assert self.image_key == "segmentation" 232 | if not hasattr(self, "colorize"): 233 | self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) 234 | x = F.conv2d(x, weight=self.colorize) 235 | x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 236 | return x 237 | 238 | 239 | class IdentityFirstStage(torch.nn.Module): 240 | def __init__(self, *args, vq_interface=False, 
**kwargs): 241 | # TODO: Should be true by default but check to not break older stuff 242 | self.vq_interface = vq_interface 243 | super().__init__() 244 | 245 | def encode(self, x, *args, **kwargs): 246 | return x 247 | 248 | def decode(self, x, *args, **kwargs): 249 | return x 250 | 251 | def quantize(self, x, *args, **kwargs): 252 | if self.vq_interface: 253 | return x, None, [None, None, None] 254 | return x 255 | 256 | def forward(self, x, *args, **kwargs): 257 | return x 258 | -------------------------------------------------------------------------------- /src/models/autoencoder_temporal.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from src.modules.attention_temporal_videoae import * 6 | from einops import rearrange, reduce, repeat 7 | 8 | try: 9 | import xformers 10 | import xformers.ops as xops 11 | 12 | XFORMERS_IS_AVAILBLE = True 13 | except: 14 | XFORMERS_IS_AVAILBLE = False 15 | 16 | 17 | def silu(x): 18 | # swish 19 | return x * torch.sigmoid(x) 20 | 21 | 22 | class SiLU(nn.Module): 23 | def __init__(self): 24 | super(SiLU, self).__init__() 25 | 26 | def forward(self, x): 27 | return silu(x) 28 | 29 | 30 | def Normalize(in_channels, norm_type="group"): 31 | assert norm_type in ["group", "batch"] 32 | if norm_type == "group": 33 | return torch.nn.GroupNorm( 34 | num_groups=32, num_channels=in_channels, eps=1e-6, affine=True 35 | ) 36 | elif norm_type == "batch": 37 | return torch.nn.SyncBatchNorm(in_channels) 38 | 39 | 40 | # Does not support dilation 41 | 42 | 43 | class SamePadConv3d(nn.Module): 44 | def __init__( 45 | self, 46 | in_channels, 47 | out_channels, 48 | kernel_size, 49 | stride=1, 50 | bias=True, 51 | padding_type="replicate", 52 | ): 53 | super().__init__() 54 | if isinstance(kernel_size, int): 55 | kernel_size = (kernel_size,) * 3 56 | if isinstance(stride, int): 57 | stride = (stride,) * 3 58 | 59 | # assumes that the input shape is divisible by stride 60 | total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) 61 | pad_input = [] 62 | for p in total_pad[::-1]: # reverse since F.pad starts from last dim 63 | pad_input.append((p // 2 + p % 2, p // 2)) 64 | pad_input = sum(pad_input, tuple()) 65 | self.pad_input = pad_input 66 | self.padding_type = padding_type 67 | 68 | self.conv = nn.Conv3d( 69 | in_channels, out_channels, kernel_size, stride=stride, padding=0, bias=bias 70 | ) 71 | 72 | def forward(self, x): 73 | # print(x.dtype) 74 | return self.conv(F.pad(x, self.pad_input, mode=self.padding_type)) 75 | 76 | 77 | class SamePadConvTranspose3d(nn.Module): 78 | def __init__( 79 | self, 80 | in_channels, 81 | out_channels, 82 | kernel_size, 83 | stride=1, 84 | bias=True, 85 | padding_type="replicate", 86 | ): 87 | super().__init__() 88 | if isinstance(kernel_size, int): 89 | kernel_size = (kernel_size,) * 3 90 | if isinstance(stride, int): 91 | stride = (stride,) * 3 92 | 93 | total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) 94 | pad_input = [] 95 | for p in total_pad[::-1]: # reverse since F.pad starts from last dim 96 | pad_input.append((p // 2 + p % 2, p // 2)) 97 | pad_input = sum(pad_input, tuple()) 98 | self.pad_input = pad_input 99 | self.padding_type = padding_type 100 | 101 | self.convt = nn.ConvTranspose3d( 102 | in_channels, 103 | out_channels, 104 | kernel_size, 105 | stride=stride, 106 | bias=bias, 107 | padding=tuple([k - 1 for k in kernel_size]), 108 | ) 109 | 110 | def forward(self, x): 111 | return 
self.convt(F.pad(x, self.pad_input, mode=self.padding_type)) 112 | 113 | 114 | class ResBlock(nn.Module): 115 | def __init__( 116 | self, 117 | in_channels, 118 | out_channels=None, 119 | conv_shortcut=False, 120 | dropout=0.0, 121 | norm_type="group", 122 | padding_type="replicate", 123 | ): 124 | super().__init__() 125 | self.in_channels = in_channels 126 | out_channels = in_channels if out_channels is None else out_channels 127 | self.out_channels = out_channels 128 | self.use_conv_shortcut = conv_shortcut 129 | 130 | self.norm1 = Normalize(in_channels, norm_type) 131 | self.conv1 = SamePadConv3d( 132 | in_channels, out_channels, kernel_size=3, padding_type=padding_type 133 | ) 134 | self.dropout = torch.nn.Dropout(dropout) 135 | self.norm2 = Normalize(in_channels, norm_type) 136 | self.conv2 = SamePadConv3d( 137 | out_channels, out_channels, kernel_size=3, padding_type=padding_type 138 | ) 139 | if self.in_channels != self.out_channels: 140 | self.conv_shortcut = SamePadConv3d( 141 | in_channels, out_channels, kernel_size=3, padding_type=padding_type 142 | ) 143 | 144 | def forward(self, x): 145 | h = x 146 | h = self.norm1(h) 147 | h = silu(h) 148 | h = self.conv1(h) 149 | h = self.norm2(h) 150 | h = silu(h) 151 | h = self.conv2(h) 152 | 153 | if self.in_channels != self.out_channels: 154 | x = self.conv_shortcut(x) 155 | 156 | return x + h 157 | 158 | 159 | class SpatialCrossAttention(nn.Module): 160 | def __init__( 161 | self, 162 | query_dim, 163 | patch_size=1, 164 | context_dim=None, 165 | heads=8, 166 | dim_head=64, 167 | dropout=0.0, 168 | ): 169 | super().__init__() 170 | inner_dim = dim_head * heads 171 | context_dim = default(context_dim, query_dim) 172 | 173 | self.scale = dim_head**-0.5 174 | self.heads = heads 175 | self.dim_head = dim_head 176 | 177 | # print(f"query dimension is {query_dim}") 178 | 179 | self.patch_size = patch_size 180 | patch_dim = query_dim * patch_size * patch_size 181 | self.norm = nn.LayerNorm(patch_dim) 182 | 183 | self.to_q = nn.Linear(patch_dim, inner_dim, bias=False) 184 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 185 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 186 | 187 | self.to_out = nn.Sequential( 188 | nn.Linear(inner_dim, patch_dim), nn.Dropout(dropout) 189 | ) 190 | self.attention_op: Optional[Any] = None 191 | 192 | def forward(self, x, context=None, mask=None): 193 | b, c, t, height, width = x.shape 194 | 195 | # patch: [patch_size, patch_size] 196 | divide_factor_height = height // self.patch_size 197 | divide_factor_width = width // self.patch_size 198 | x = rearrange( 199 | x, 200 | "b c t (df1 ph) (df2 pw) -> (b t) (df1 df2) (ph pw c)", 201 | df1=divide_factor_height, 202 | df2=divide_factor_width, 203 | ph=self.patch_size, 204 | pw=self.patch_size, 205 | ) 206 | x = self.norm(x) 207 | 208 | context = default(context, x) 209 | context = repeat(context, "b n d -> (b t) n d", b=b, t=t) 210 | 211 | q = self.to_q(x) 212 | k = self.to_k(context) 213 | v = self.to_v(context) 214 | 215 | q, k, v = map( 216 | lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=self.heads), (q, k, v) 217 | ) 218 | 219 | if exists(mask): 220 | mask = rearrange(mask, "b ... 
-> b (...)") 221 | mask = repeat(mask, "b j -> (b t h) () j", t=t, h=self.heads) 222 | 223 | if XFORMERS_IS_AVAILBLE: 224 | if exists(mask): 225 | mask = mask.to(q.dtype) 226 | max_neg_value = -torch.finfo(q.dtype).max 227 | 228 | attn_bias = torch.zeros_like(mask) 229 | attn_bias.masked_fill_(mask <= 0.5, max_neg_value) 230 | 231 | mask = mask.detach().cpu() 232 | attn_bias = attn_bias.expand(-1, q.shape[1], -1) 233 | 234 | attn_bias_expansion_q = (attn_bias.shape[1] + 7) // 8 * 8 235 | attn_bias_expansion_k = (attn_bias.shape[2] + 7) // 8 * 8 236 | 237 | attn_bias_expansion = torch.zeros( 238 | (attn_bias.shape[0], attn_bias_expansion_q, attn_bias_expansion_k), 239 | dtype=attn_bias.dtype, 240 | device=attn_bias.device, 241 | ) 242 | attn_bias_expansion[:, : attn_bias.shape[1], : attn_bias.shape[2]] = ( 243 | attn_bias 244 | ) 245 | 246 | attn_bias = attn_bias.detach().cpu() 247 | 248 | out = xops.memory_efficient_attention( 249 | q, 250 | k, 251 | v, 252 | attn_bias=attn_bias_expansion[ 253 | :, : attn_bias.shape[1], : attn_bias.shape[2] 254 | ], 255 | scale=self.scale, 256 | ) 257 | else: 258 | out = xops.memory_efficient_attention(q, k, v, scale=self.scale) 259 | else: 260 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 261 | if exists(mask): 262 | max_neg_value = -torch.finfo(sim.dtype).max 263 | sim.masked_fill_(~(mask > 0.5), max_neg_value) 264 | attn = sim.softmax(dim=-1) 265 | out = einsum("b i j, b j d -> b i d", attn, v) 266 | 267 | out = rearrange(out, "(b h) n d -> b n (h d)", h=self.heads) 268 | 269 | ret = self.to_out(out) 270 | ret = rearrange( 271 | ret, 272 | "(b t) (df1 df2) (ph pw c) -> b c t (df1 ph) (df2 pw)", 273 | b=b, 274 | t=t, 275 | df1=divide_factor_height, 276 | df2=divide_factor_width, 277 | ph=self.patch_size, 278 | pw=self.patch_size, 279 | ) 280 | return ret 281 | 282 | 283 | # ---------------------------------------------------------------------------------------------------= 284 | 285 | 286 | class EncoderTemporal1DCNN(nn.Module): 287 | def __init__( 288 | self, 289 | *, 290 | ch, 291 | out_ch, 292 | attn_temporal_factor=[], 293 | temporal_scale_factor=4, 294 | hidden_channel=128, 295 | **ignore_kwargs 296 | ): 297 | super().__init__() 298 | 299 | self.ch = ch 300 | self.temb_ch = 0 301 | self.temporal_scale_factor = temporal_scale_factor 302 | 303 | # conv_in + resblock + down_block + resblock + down_block + final_block 304 | self.conv_in = SamePadConv3d( 305 | ch, hidden_channel, kernel_size=3, padding_type="replicate" 306 | ) 307 | 308 | self.mid_blocks = nn.ModuleList() 309 | 310 | num_ds = int(math.log2(temporal_scale_factor)) 311 | norm_type = "group" 312 | 313 | curr_temporal_factor = 1 314 | for i in range(num_ds): 315 | block = nn.Module() 316 | # compute in_ch, out_ch, stride 317 | in_channels = hidden_channel * 2**i 318 | out_channels = hidden_channel * 2 ** (i + 1) 319 | temporal_stride = 2 320 | curr_temporal_factor = curr_temporal_factor * 2 321 | 322 | block.down = SamePadConv3d( 323 | in_channels, 324 | out_channels, 325 | kernel_size=3, 326 | stride=(temporal_stride, 1, 1), 327 | padding_type="replicate", 328 | ) 329 | block.res = ResBlock(out_channels, out_channels, norm_type=norm_type) 330 | 331 | block.attn = nn.ModuleList() 332 | if curr_temporal_factor in attn_temporal_factor: 333 | block.attn.append( 334 | SpatialCrossAttention(query_dim=out_channels, context_dim=1024) 335 | ) 336 | 337 | self.mid_blocks.append(block) 338 | # n_times_downsample -= 1 339 | 340 | self.final_block = nn.Sequential( 341 | 
Normalize(out_channels, norm_type), 342 | SiLU(), 343 | SamePadConv3d( 344 | out_channels, out_ch * 2, kernel_size=3, padding_type="replicate" 345 | ), 346 | ) 347 | 348 | self.initialize_weights() 349 | 350 | def initialize_weights(self): 351 | # Initialize transformer layers: 352 | def _basic_init(module): 353 | if isinstance(module, nn.Linear): 354 | if module.weight.requires_grad_: 355 | torch.nn.init.xavier_uniform_(module.weight) 356 | if module.bias is not None: 357 | nn.init.constant_(module.bias, 0) 358 | if isinstance(module, nn.Conv3d): 359 | torch.nn.init.xavier_uniform_(module.weight) 360 | if module.bias is not None: 361 | nn.init.constant_(module.bias, 0) 362 | 363 | self.apply(_basic_init) 364 | 365 | def forward(self, x, text_embeddings=None, text_attn_mask=None): 366 | # x: [b c t h w] 367 | # x: [1, 4, 16, 32, 32] 368 | # timestep embedding 369 | h = self.conv_in(x) 370 | for block in self.mid_blocks: 371 | h = block.down(h) 372 | h = block.res(h) 373 | if len(block.attn) > 0: 374 | for attn in block.attn: 375 | h = attn(h, context=text_embeddings, mask=text_attn_mask) + h 376 | 377 | h = self.final_block(h) 378 | 379 | return h 380 | 381 | 382 | class TemporalUpsample(nn.Module): 383 | def __init__( 384 | self, size=None, scale_factor=None, mode="nearest", align_corners=None 385 | ): 386 | super(TemporalUpsample, self).__init__() 387 | self.size = size 388 | self.scale_factor = scale_factor 389 | self.mode = mode 390 | self.align_corners = align_corners 391 | 392 | def forward(self, x): 393 | return F.interpolate( 394 | x, 395 | size=self.size, 396 | scale_factor=self.scale_factor, 397 | mode=self.mode, 398 | align_corners=self.align_corners, 399 | ) 400 | 401 | 402 | class DecoderTemporal1DCNN(nn.Module): 403 | def __init__( 404 | self, 405 | *, 406 | ch, 407 | out_ch, 408 | attn_temporal_factor=[], 409 | temporal_scale_factor=4, 410 | hidden_channel=128, 411 | **ignore_kwargs 412 | ): 413 | super().__init__() 414 | 415 | self.ch = ch 416 | self.temb_ch = 0 417 | self.temporal_scale_factor = temporal_scale_factor 418 | 419 | num_us = int(math.log2(temporal_scale_factor)) 420 | norm_type = "group" 421 | 422 | # conv_in, mid_blocks, final_block 423 | # out channel of encoder, before the last conv layer 424 | enc_out_channels = hidden_channel * 2**num_us 425 | self.conv_in = SamePadConv3d( 426 | ch, enc_out_channels, kernel_size=3, padding_type="replicate" 427 | ) 428 | 429 | self.mid_blocks = nn.ModuleList() 430 | curr_temporal_factor = self.temporal_scale_factor 431 | 432 | for i in range(num_us): 433 | block = nn.Module() 434 | in_channels = ( 435 | enc_out_channels if i == 0 else hidden_channel * 2 ** (num_us - i + 1) 436 | ) # max_us: 3 437 | out_channels = hidden_channel * 2 ** (num_us - i) 438 | temporal_stride = 2 439 | # block.up = SamePadConvTranspose3d(in_channels, out_channels, kernel_size=3, stride=(temporal_stride, 1, 1)) 440 | block.up = torch.nn.ConvTranspose3d( 441 | in_channels, 442 | out_channels, 443 | kernel_size=(3, 3, 3), 444 | stride=(2, 1, 1), 445 | padding=(1, 1, 1), 446 | output_padding=(1, 0, 0), 447 | ) 448 | block.res1 = ResBlock(out_channels, out_channels, norm_type=norm_type) 449 | block.attn1 = nn.ModuleList() 450 | 451 | if curr_temporal_factor in attn_temporal_factor: 452 | block.attn1.append( 453 | SpatialCrossAttention(query_dim=out_channels, context_dim=1024) 454 | ) 455 | 456 | block.res2 = ResBlock(out_channels, out_channels, norm_type=norm_type) 457 | 458 | block.attn2 = nn.ModuleList() 459 | if curr_temporal_factor in 
attn_temporal_factor: 460 | block.attn2.append( 461 | SpatialCrossAttention(query_dim=out_channels, context_dim=1024) 462 | ) 463 | 464 | curr_temporal_factor = curr_temporal_factor / 2 465 | self.mid_blocks.append(block) 466 | 467 | self.conv_last = SamePadConv3d(out_channels, out_ch, kernel_size=3) 468 | 469 | self.initialize_weights() 470 | 471 | def initialize_weights(self): 472 | # Initialize transformer layers: 473 | def _basic_init(module): 474 | if isinstance(module, nn.Linear): 475 | if module.weight.requires_grad_: 476 | torch.nn.init.xavier_uniform_(module.weight) 477 | if module.bias is not None: 478 | nn.init.constant_(module.bias, 0) 479 | if isinstance(module, nn.Conv3d): 480 | torch.nn.init.xavier_uniform_(module.weight) 481 | if module.bias is not None: 482 | nn.init.constant_(module.bias, 0) 483 | if isinstance(module, nn.ConvTranspose3d): 484 | torch.nn.init.xavier_uniform_(module.weight) 485 | if module.bias is not None: 486 | nn.init.constant_(module.bias, 0) 487 | 488 | self.apply(_basic_init) 489 | 490 | def forward(self, x, text_embeddings=None, text_attn_mask=None): 491 | # x: [b c t h w] 492 | h = self.conv_in(x) 493 | for i, block in enumerate(self.mid_blocks): 494 | h = block.up(h) 495 | h = block.res1(h) 496 | if len(block.attn1) > 0: 497 | for attn in block.attn1: 498 | h = attn(h, context=text_embeddings, mask=text_attn_mask) + h 499 | 500 | h = block.res2(h) 501 | if len(block.attn2) > 0: 502 | for attn in block.attn2: 503 | h = attn(h, context=text_embeddings, mask=text_attn_mask) + h 504 | 505 | h = self.conv_last(h) 506 | 507 | return h 508 | -------------------------------------------------------------------------------- /src/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from src.modules.losses.contperceptual import ( 2 | LPIPSWithDiscriminator, 3 | MSEWithDiscriminator, 4 | LPIPSWithDiscriminator3D, 5 | ) 6 | -------------------------------------------------------------------------------- /src/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
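# NOTE: the wildcard import above supplies the helpers used throughout this module:
# LPIPS, NLayerDiscriminator, weights_init, hinge_d_loss, vanilla_d_loss and adopt_weight.
#
# Minimal usage sketch (illustrative only; the constructor values are hypothetical and
# not taken from this repo's configs, and `model`, `step`, `inputs`, `reconstructions`,
# `posterior` are placeholders). The autoencoder's training_step invokes the loss
# roughly as follows:
#
#     loss_fn = LPIPSWithDiscriminator(disc_start=50001, kl_weight=1e-6, disc_weight=0.5)
#     aeloss, log = loss_fn(
#         inputs, reconstructions, posterior,
#         optimizer_idx=0, global_step=step,
#         last_layer=model.get_last_layer(), split="train",
#     )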
5 | import functools 6 | 7 | 8 | class LPIPSWithDiscriminator(nn.Module): 9 | def __init__( 10 | self, 11 | disc_start, 12 | logvar_init=0.0, 13 | kl_weight=1.0, 14 | pixelloss_weight=1.0, 15 | disc_num_layers=3, 16 | disc_in_channels=3, 17 | disc_factor=1.0, 18 | disc_weight=1.0, 19 | perceptual_weight=1.0, 20 | use_actnorm=False, 21 | disc_conditional=False, 22 | disc_loss="hinge", 23 | max_bs=None, 24 | ): 25 | 26 | super().__init__() 27 | assert disc_loss in ["hinge", "vanilla"] 28 | self.kl_weight = kl_weight 29 | self.pixel_weight = pixelloss_weight 30 | self.perceptual_loss = LPIPS().eval() 31 | self.perceptual_weight = perceptual_weight 32 | # output log variance 33 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 34 | 35 | self.discriminator = NLayerDiscriminator( 36 | input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm 37 | ).apply(weights_init) 38 | self.discriminator_iter_start = disc_start 39 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 40 | self.disc_factor = disc_factor 41 | self.discriminator_weight = disc_weight 42 | self.disc_conditional = disc_conditional 43 | self.max_bs = max_bs 44 | 45 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 46 | if last_layer is not None: 47 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 48 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 49 | else: 50 | nll_grads = torch.autograd.grad( 51 | nll_loss, self.last_layer[0], retain_graph=True 52 | )[0] 53 | g_grads = torch.autograd.grad( 54 | g_loss, self.last_layer[0], retain_graph=True 55 | )[0] 56 | 57 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 58 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 59 | d_weight = d_weight * self.discriminator_weight 60 | return d_weight 61 | 62 | def forward( 63 | self, 64 | inputs, 65 | reconstructions, 66 | posteriors, 67 | optimizer_idx, 68 | global_step, 69 | last_layer=None, 70 | cond=None, 71 | split="train", 72 | weights=None, 73 | ): 74 | if inputs.dim() == 5: 75 | inputs = rearrange(inputs, "b c t h w -> (b t) c h w") 76 | if reconstructions.dim() == 5: 77 | reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w") 78 | 79 | # print('loss shape: ', inputs.shape, reconstructions.shape) 80 | # exit() 81 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 82 | if self.perceptual_weight > 0: 83 | if self.max_bs is not None and self.max_bs < inputs.shape[0]: 84 | input_list = torch.split(inputs, self.max_bs, dim=0) 85 | reconstruction_list = torch.split(reconstructions, self.max_bs, dim=0) 86 | p_losses = [ 87 | self.perceptual_loss( 88 | inputs.contiguous(), reconstructions.contiguous() 89 | ) 90 | for inputs, reconstructions in zip(input_list, reconstruction_list) 91 | ] 92 | p_loss = torch.cat(p_losses, dim=0) 93 | else: 94 | p_loss = self.perceptual_loss( 95 | inputs.contiguous(), reconstructions.contiguous() 96 | ) 97 | rec_loss = rec_loss + self.perceptual_weight * p_loss 98 | 99 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 100 | weighted_nll_loss = nll_loss 101 | if weights is not None: 102 | weighted_nll_loss = weights * nll_loss 103 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 104 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 105 | 106 | kl_loss = posteriors.kl() 107 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 108 | 109 | if global_step < self.discriminator_iter_start: 
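# Warm-up phase: until `global_step` reaches `discriminator_iter_start`, only the
# weighted reconstruction NLL and the KL term contribute to the loss and the method
# returns early, so the adversarial branch below never runs. After the threshold,
# `adopt_weight` enables the GAN term and `calculate_adaptive_weight` rescales it
# against the NLL gradients taken at the decoder's last layer.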
110 | loss = weighted_nll_loss + self.kl_weight * kl_loss 111 | log = { 112 | "{}/total_loss".format(split): loss.clone().detach().mean(), 113 | "{}/logvar".format(split): self.logvar.detach(), 114 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 115 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 116 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 117 | } 118 | 119 | return loss, log 120 | 121 | # now the GAN part 122 | if optimizer_idx == 0: 123 | # generator update 124 | if cond is None: 125 | assert not self.disc_conditional 126 | logits_fake = self.discriminator(reconstructions.contiguous()) 127 | else: 128 | assert self.disc_conditional 129 | logits_fake = self.discriminator( 130 | torch.cat((reconstructions.contiguous(), cond), dim=1) 131 | ) 132 | g_loss = -torch.mean(logits_fake) 133 | 134 | if self.disc_factor > 0.0: 135 | try: 136 | d_weight = self.calculate_adaptive_weight( 137 | nll_loss, g_loss, last_layer=last_layer 138 | ) 139 | except RuntimeError: 140 | assert not self.training 141 | d_weight = torch.tensor(0.0) 142 | else: 143 | d_weight = torch.tensor(0.0) 144 | 145 | disc_factor = adopt_weight( 146 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 147 | ) 148 | loss = ( 149 | weighted_nll_loss 150 | + self.kl_weight * kl_loss 151 | + d_weight * disc_factor * g_loss 152 | ) 153 | 154 | log = { 155 | "{}/total_loss".format(split): loss.clone().detach().mean(), 156 | "{}/logvar".format(split): self.logvar.detach(), 157 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 158 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 159 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 160 | "{}/d_weight".format(split): d_weight.detach(), 161 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 162 | "{}/g_loss".format(split): g_loss.detach().mean(), 163 | } 164 | return loss, log 165 | 166 | if optimizer_idx == 1: 167 | # second pass for discriminator update 168 | if cond is None: 169 | logits_real = self.discriminator(inputs.contiguous().detach()) 170 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 171 | else: 172 | logits_real = self.discriminator( 173 | torch.cat((inputs.contiguous().detach(), cond), dim=1) 174 | ) 175 | logits_fake = self.discriminator( 176 | torch.cat((reconstructions.contiguous().detach(), cond), dim=1) 177 | ) 178 | 179 | disc_factor = adopt_weight( 180 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 181 | ) 182 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 183 | 184 | log = { 185 | "{}/disc_loss".format(split): d_loss.clone().detach().mean(), 186 | "{}/logits_real".format(split): logits_real.detach().mean(), 187 | "{}/logits_fake".format(split): logits_fake.detach().mean(), 188 | } 189 | return d_loss, log 190 | 191 | 192 | ### Modified for 1dcnn lpips -> mse 193 | 194 | 195 | class MSEWithDiscriminator(nn.Module): 196 | def __init__( 197 | self, 198 | disc_start, 199 | logvar_init=0.0, 200 | kl_weight=1.0, 201 | pixelloss_weight=1.0, 202 | disc_num_layers=3, 203 | disc_in_channels=4, 204 | disc_factor=1.0, 205 | disc_weight=1.0, 206 | perceptual_weight=1.0, 207 | use_actnorm=False, 208 | disc_conditional=False, 209 | disc_loss="hinge", 210 | max_bs=None, 211 | ): 212 | 213 | super().__init__() 214 | assert disc_loss in ["hinge", "vanilla"] 215 | self.kl_weight = kl_weight 216 | self.pixel_weight = pixelloss_weight 217 | self.perceptual_loss = nn.MSELoss() 218 | self.perceptual_weight = perceptual_weight 219 | # 
output log variance 220 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 221 | 222 | self.discriminator = NLayerDiscriminator( 223 | input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm 224 | ).apply(weights_init) 225 | self.discriminator_iter_start = disc_start 226 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 227 | self.disc_factor = disc_factor 228 | self.discriminator_weight = disc_weight 229 | self.disc_conditional = disc_conditional 230 | self.max_bs = max_bs 231 | 232 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 233 | if last_layer is not None: 234 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 235 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 236 | else: 237 | nll_grads = torch.autograd.grad( 238 | nll_loss, self.last_layer[0], retain_graph=True 239 | )[0] 240 | g_grads = torch.autograd.grad( 241 | g_loss, self.last_layer[0], retain_graph=True 242 | )[0] 243 | 244 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 245 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 246 | d_weight = d_weight * self.discriminator_weight 247 | return d_weight 248 | 249 | def forward( 250 | self, 251 | inputs, 252 | reconstructions, 253 | posteriors, 254 | optimizer_idx, 255 | global_step, 256 | last_layer=None, 257 | cond=None, 258 | split="train", 259 | weights=None, 260 | ): 261 | if inputs.dim() == 5: 262 | inputs = rearrange(inputs, "b c t h w -> (b t) c h w") 263 | if reconstructions.dim() == 5: 264 | reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w") 265 | 266 | # print('loss shape: ', inputs.shape, reconstructions.shape) 267 | # exit() 268 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 269 | if self.perceptual_weight > 0: 270 | p_loss = self.perceptual_loss( 271 | inputs.contiguous(), reconstructions.contiguous() 272 | ) 273 | rec_loss = rec_loss + self.perceptual_weight * p_loss 274 | 275 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 276 | weighted_nll_loss = nll_loss 277 | if weights is not None: 278 | weighted_nll_loss = weights * nll_loss 279 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 280 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 281 | 282 | kl_loss = posteriors.kl() 283 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 284 | 285 | if global_step < self.discriminator_iter_start: 286 | loss = weighted_nll_loss + self.kl_weight * kl_loss 287 | log = { 288 | "{}/total_loss".format(split): loss.clone().detach().mean(), 289 | "{}/logvar".format(split): self.logvar.detach(), 290 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 291 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 292 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 293 | } 294 | 295 | return loss, log 296 | 297 | # now the GAN part 298 | if optimizer_idx == 0: 299 | # generator update 300 | if cond is None: 301 | assert not self.disc_conditional 302 | logits_fake = self.discriminator(reconstructions.contiguous()) 303 | else: 304 | assert self.disc_conditional 305 | logits_fake = self.discriminator( 306 | torch.cat((reconstructions.contiguous(), cond), dim=1) 307 | ) 308 | g_loss = -torch.mean(logits_fake) 309 | 310 | if self.disc_factor > 0.0: 311 | try: 312 | d_weight = self.calculate_adaptive_weight( 313 | nll_loss, g_loss, last_layer=last_layer 314 | ) 315 | except RuntimeError: 316 | assert not self.training 317 | 
d_weight = torch.tensor(0.0) 318 | else: 319 | d_weight = torch.tensor(0.0) 320 | 321 | disc_factor = adopt_weight( 322 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 323 | ) 324 | loss = ( 325 | weighted_nll_loss 326 | + self.kl_weight * kl_loss 327 | + d_weight * disc_factor * g_loss 328 | ) 329 | 330 | log = { 331 | "{}/total_loss".format(split): loss.clone().detach().mean(), 332 | "{}/logvar".format(split): self.logvar.detach(), 333 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 334 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 335 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 336 | "{}/d_weight".format(split): d_weight.detach(), 337 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 338 | "{}/g_loss".format(split): g_loss.detach().mean(), 339 | } 340 | return loss, log 341 | 342 | if optimizer_idx == 1: 343 | # second pass for discriminator update 344 | if cond is None: 345 | logits_real = self.discriminator(inputs.contiguous().detach()) 346 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 347 | else: 348 | logits_real = self.discriminator( 349 | torch.cat((inputs.contiguous().detach(), cond), dim=1) 350 | ) 351 | logits_fake = self.discriminator( 352 | torch.cat((reconstructions.contiguous().detach(), cond), dim=1) 353 | ) 354 | 355 | disc_factor = adopt_weight( 356 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 357 | ) 358 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 359 | 360 | log = { 361 | "{}/disc_loss".format(split): d_loss.clone().detach().mean(), 362 | "{}/logits_real".format(split): logits_real.detach().mean(), 363 | "{}/logits_fake".format(split): logits_fake.detach().mean(), 364 | } 365 | return d_loss, log 366 | 367 | 368 | class NLayerDiscriminator3D(nn.Module): 369 | """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs.""" 370 | 371 | def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False): 372 | """ 373 | Construct a 3D PatchGAN discriminator 374 | 375 | Parameters: 376 | input_nc (int) -- the number of channels in input volumes 377 | ndf (int) -- the number of filters in the last conv layer 378 | n_layers (int) -- the number of conv layers in the discriminator 379 | use_actnorm (bool) -- flag to use actnorm instead of batchnorm 380 | """ 381 | super(NLayerDiscriminator3D, self).__init__() 382 | if not use_actnorm: 383 | norm_layer = nn.BatchNorm3d 384 | else: 385 | raise NotImplementedError("Not implemented.") 386 | if type(norm_layer) == functools.partial: 387 | use_bias = norm_layer.func != nn.BatchNorm3d 388 | else: 389 | use_bias = norm_layer != nn.BatchNorm3d 390 | 391 | kw = 3 392 | padw = 1 393 | sequence = [ 394 | nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 395 | nn.LeakyReLU(0.2, True), 396 | ] 397 | nf_mult = 1 398 | nf_mult_prev = 1 399 | for n in range(1, n_layers): # gradually increase the number of filters 400 | nf_mult_prev = nf_mult 401 | nf_mult = min(2**n, 8) 402 | sequence += [ 403 | nn.Conv3d( 404 | ndf * nf_mult_prev, 405 | ndf * nf_mult, 406 | kernel_size=(kw, kw, kw), 407 | stride=(2 if n == 1 else 1, 2, 2), 408 | padding=padw, 409 | bias=use_bias, 410 | ), 411 | norm_layer(ndf * nf_mult), 412 | nn.LeakyReLU(0.2, True), 413 | ] 414 | 415 | nf_mult_prev = nf_mult 416 | nf_mult = min(2**n_layers, 8) 417 | sequence += [ 418 | nn.Conv3d( 419 | ndf * nf_mult_prev, 420 | ndf * nf_mult, 421 | kernel_size=(kw, kw, kw), 422 | stride=1, 423 | padding=padw, 424 | 
bias=use_bias, 425 | ), 426 | norm_layer(ndf * nf_mult), 427 | nn.LeakyReLU(0.2, True), 428 | ] 429 | 430 | sequence += [ 431 | nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 432 | ] # output 1 channel prediction map 433 | self.main = nn.Sequential(*sequence) 434 | 435 | def forward(self, input): 436 | """Standard forward.""" 437 | return self.main(input) 438 | 439 | 440 | class LPIPSWithDiscriminator3D(nn.Module): 441 | def __init__( 442 | self, 443 | disc_start, 444 | logvar_init=0.0, 445 | kl_weight=1.0, 446 | pixelloss_weight=1.0, 447 | perceptual_weight=1.0, 448 | # --- Discriminator Loss --- 449 | disc_num_layers=3, 450 | disc_in_channels=3, 451 | disc_factor=1.0, 452 | disc_weight=1.0, 453 | use_actnorm=False, 454 | disc_conditional=False, 455 | disc_loss="hinge", 456 | ): 457 | 458 | super().__init__() 459 | assert disc_loss in ["hinge", "vanilla"] 460 | self.kl_weight = kl_weight 461 | self.pixel_weight = pixelloss_weight 462 | self.perceptual_loss = LPIPS().eval() 463 | self.perceptual_weight = perceptual_weight 464 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 465 | 466 | self.discriminator = NLayerDiscriminator3D( 467 | input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm 468 | ).apply(weights_init) 469 | self.discriminator_iter_start = disc_start 470 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 471 | self.disc_factor = disc_factor 472 | self.discriminator_weight = disc_weight 473 | self.disc_conditional = disc_conditional 474 | 475 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 476 | if last_layer is not None: 477 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 478 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 479 | else: 480 | nll_grads = torch.autograd.grad( 481 | nll_loss, self.last_layer[0], retain_graph=True 482 | )[0] 483 | g_grads = torch.autograd.grad( 484 | g_loss, self.last_layer[0], retain_graph=True 485 | )[0] 486 | 487 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 488 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 489 | d_weight = d_weight * self.discriminator_weight 490 | return d_weight 491 | 492 | def forward( 493 | self, 494 | inputs, 495 | reconstructions, 496 | posteriors, 497 | optimizer_idx, 498 | global_step, 499 | split="train", 500 | weights=None, 501 | last_layer=None, 502 | cond=None, 503 | ): 504 | t = inputs.shape[2] 505 | inputs = rearrange(inputs, "b c t h w -> (b t) c h w") 506 | reconstructions = rearrange(reconstructions, "b c t h w -> (b t) c h w") 507 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 508 | if self.perceptual_weight > 0: 509 | p_loss = self.perceptual_loss( 510 | inputs.contiguous(), reconstructions.contiguous() 511 | ) 512 | rec_loss = rec_loss + self.perceptual_weight * p_loss 513 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 514 | weighted_nll_loss = nll_loss 515 | if weights is not None: 516 | weighted_nll_loss = weights * nll_loss 517 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 518 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 519 | kl_loss = posteriors.kl() 520 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 521 | 522 | if global_step < self.discriminator_iter_start: 523 | loss = weighted_nll_loss + self.kl_weight * kl_loss 524 | log = { 525 | "{}/total_loss".format(split): loss.clone().detach().mean(), 526 | "{}/logvar".format(split): 
self.logvar.detach(), 527 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 528 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 529 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 530 | } 531 | 532 | return loss, log 533 | 534 | inputs = rearrange(inputs, "(b t) c h w -> b c t h w", t=t) 535 | reconstructions = rearrange(reconstructions, "(b t) c h w -> b c t h w", t=t) 536 | # GAN Part 537 | if optimizer_idx == 0: 538 | # generator update 539 | if cond is None: 540 | assert not self.disc_conditional 541 | logits_fake = self.discriminator(reconstructions.contiguous()) 542 | else: 543 | assert self.disc_conditional 544 | logits_fake = self.discriminator( 545 | torch.cat((reconstructions.contiguous(), cond), dim=1) 546 | ) 547 | g_loss = -torch.mean(logits_fake) 548 | 549 | if self.disc_factor > 0.0: 550 | try: 551 | d_weight = self.calculate_adaptive_weight( 552 | nll_loss, g_loss, last_layer=last_layer 553 | ) 554 | except RuntimeError as e: 555 | assert not self.training, print(e) 556 | d_weight = torch.tensor(0.0) 557 | else: 558 | d_weight = torch.tensor(0.0) 559 | 560 | disc_factor = adopt_weight( 561 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 562 | ) 563 | loss = ( 564 | weighted_nll_loss 565 | + self.kl_weight * kl_loss 566 | + d_weight * disc_factor * g_loss 567 | ) 568 | log = { 569 | "{}/total_loss".format(split): loss.clone().detach().mean(), 570 | "{}/logvar".format(split): self.logvar.detach(), 571 | "{}/kl_loss".format(split): kl_loss.detach().mean(), 572 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 573 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 574 | "{}/d_weight".format(split): d_weight.detach(), 575 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 576 | "{}/g_loss".format(split): g_loss.detach().mean(), 577 | } 578 | return loss, log 579 | 580 | if optimizer_idx == 1: 581 | if cond is None: 582 | logits_real = self.discriminator(inputs.contiguous().detach()) 583 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 584 | else: 585 | logits_real = self.discriminator( 586 | torch.cat((inputs.contiguous().detach(), cond), dim=1) 587 | ) 588 | logits_fake = self.discriminator( 589 | torch.cat((reconstructions.contiguous().detach(), cond), dim=1) 590 | ) 591 | 592 | disc_factor = adopt_weight( 593 | self.disc_factor, global_step, threshold=self.discriminator_iter_start 594 | ) 595 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 596 | 597 | log = { 598 | "{}/disc_loss".format(split): d_loss.clone().detach().mean(), 599 | "{}/logits_real".format(split): logits_real.detach().mean(), 600 | "{}/logits_fake".format(split): logits_fake.detach().mean(), 601 | } 602 | return d_loss, log 603 | -------------------------------------------------------------------------------- /src/modules/t5.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import re 4 | import html 5 | import urllib.parse as ul 6 | 7 | import ftfy 8 | import torch 9 | from bs4 import BeautifulSoup 10 | from transformers import T5EncoderModel, AutoTokenizer 11 | from huggingface_hub import hf_hub_download 12 | 13 | 14 | class T5Embedder: 15 | available_models = ["flan-t5-large"] 16 | bad_punct_regex = re.compile( 17 | r"[" 18 | + "#®•©™&@·º½¾¿¡§~" 19 | + "\)" 20 | + "\(" 21 | + "\]" 22 | + "\[" 23 | + "\}" 24 | + "\{" 25 | + "\|" 26 | + "\\" 27 | + "\/" 28 | + "\*" 29 | + r"]{1,}" 30 | ) # noqa 31 | 32 | def __init__( 
33 | self, 34 | device, 35 | dir_or_name="flan-t5-large", 36 | *, 37 | local_cache=False, 38 | cache_dir=None, 39 | hf_token=None, 40 | use_text_preprocessing=True, 41 | t5_model_kwargs=None, 42 | torch_dtype=None, 43 | use_offload_folder=None, 44 | model_max_length=180, 45 | ): 46 | self.device = torch.device(device) 47 | print(f"T5 embedder is on {self.device}") 48 | self.torch_dtype = torch_dtype or torch.bfloat16 49 | if t5_model_kwargs is None: 50 | t5_model_kwargs = { 51 | "low_cpu_mem_usage": True, 52 | "torch_dtype": self.torch_dtype, 53 | } 54 | if use_offload_folder is not None: 55 | t5_model_kwargs["offload_folder"] = use_offload_folder 56 | t5_model_kwargs["device_map"] = { 57 | "shared": self.device, 58 | "encoder.embed_tokens": self.device, 59 | "encoder.block.0": self.device, 60 | "encoder.block.1": self.device, 61 | "encoder.block.2": self.device, 62 | "encoder.block.3": self.device, 63 | "encoder.block.4": self.device, 64 | "encoder.block.5": self.device, 65 | "encoder.block.6": self.device, 66 | "encoder.block.7": self.device, 67 | "encoder.block.8": self.device, 68 | "encoder.block.9": self.device, 69 | "encoder.block.10": self.device, 70 | "encoder.block.11": self.device, 71 | "encoder.block.12": "disk", 72 | "encoder.block.13": "disk", 73 | "encoder.block.14": "disk", 74 | "encoder.block.15": "disk", 75 | "encoder.block.16": "disk", 76 | "encoder.block.17": "disk", 77 | "encoder.block.18": "disk", 78 | "encoder.block.19": "disk", 79 | "encoder.block.20": "disk", 80 | "encoder.block.21": "disk", 81 | "encoder.block.22": "disk", 82 | "encoder.block.23": "disk", 83 | "encoder.final_layer_norm": "disk", 84 | "encoder.dropout": "disk", 85 | } 86 | else: 87 | t5_model_kwargs["device_map"] = { 88 | "shared": self.device, 89 | "encoder": self.device, 90 | } 91 | 92 | self.use_text_preprocessing = use_text_preprocessing 93 | self.hf_token = hf_token 94 | self.cache_dir = cache_dir or os.path.expanduser("~/.cache/IF_") 95 | self.dir_or_name = dir_or_name 96 | tokenizer_path, path = dir_or_name, dir_or_name 97 | if local_cache: 98 | cache_dir = os.path.join(self.cache_dir, dir_or_name) 99 | tokenizer_path, path = cache_dir, cache_dir 100 | elif dir_or_name in self.available_models: 101 | cache_dir = os.path.join(self.cache_dir, dir_or_name) 102 | for filename in [ 103 | "config.json", 104 | "special_tokens_map.json", 105 | "spiece.model", 106 | "tokenizer_config.json", 107 | "pytorch_model.bin", 108 | ]: 109 | 110 | hf_hub_download( 111 | repo_id=f"google/{dir_or_name}", 112 | filename=filename, 113 | cache_dir=cache_dir, 114 | force_filename=filename, 115 | token=self.hf_token, 116 | ) 117 | tokenizer_path, path = cache_dir, cache_dir 118 | else: 119 | cache_dir = os.path.join(self.cache_dir, "flan-t5-large") 120 | for filename in [ 121 | "config.json", 122 | "special_tokens_map.json", 123 | "spiece.model", 124 | "tokenizer_config.json", 125 | ]: 126 | hf_hub_download( 127 | repo_id="google/flan-t5-large", 128 | filename=filename, 129 | cache_dir=cache_dir, 130 | force_filename=filename, 131 | token=self.hf_token, 132 | ) 133 | tokenizer_path = cache_dir 134 | 135 | print(tokenizer_path) 136 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) 137 | self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval() 138 | self.model_max_length = model_max_length 139 | 140 | def get_text_embeddings(self, texts): 141 | texts = [self.text_preprocessing(text) for text in texts] 142 | 143 | # print(self.model_max_length) 144 | 145 | text_tokens_and_mask = 
self.tokenizer( 146 | texts, 147 | max_length=self.model_max_length, 148 | padding="max_length", 149 | truncation=True, 150 | return_attention_mask=True, 151 | add_special_tokens=True, 152 | return_tensors="pt", 153 | ) 154 | 155 | text_tokens_and_mask["input_ids"] = text_tokens_and_mask["input_ids"] 156 | text_tokens_and_mask["attention_mask"] = text_tokens_and_mask["attention_mask"] 157 | 158 | with torch.no_grad(): 159 | text_encoder_embs = self.model( 160 | input_ids=text_tokens_and_mask["input_ids"].to(self.device), 161 | attention_mask=text_tokens_and_mask["attention_mask"].to(self.device), 162 | )["last_hidden_state"].detach() 163 | return text_encoder_embs, text_tokens_and_mask["attention_mask"].to(self.device) 164 | 165 | def text_preprocessing(self, text): 166 | if self.use_text_preprocessing: 167 | # The exact text cleaning as was in the training stage: 168 | text = self.clean_caption(text) 169 | text = self.clean_caption(text) 170 | return text 171 | else: 172 | return text.lower().strip() 173 | 174 | @staticmethod 175 | def basic_clean(text): 176 | text = ftfy.fix_text(text) 177 | text = html.unescape(html.unescape(text)) 178 | return text.strip() 179 | 180 | def clean_caption(self, caption): 181 | caption = str(caption) 182 | caption = ul.unquote_plus(caption) 183 | caption = caption.strip().lower() 184 | caption = re.sub("", "person", caption) 185 | # urls: 186 | caption = re.sub( 187 | r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa 188 | "", 189 | caption, 190 | ) # regex for urls 191 | caption = re.sub( 192 | r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa 193 | "", 194 | caption, 195 | ) # regex for urls 196 | # html: 197 | caption = BeautifulSoup(caption, features="html.parser").text 198 | 199 | # @ 200 | caption = re.sub(r"@[\w\d]+\b", "", caption) 201 | 202 | # 31C0—31EF CJK Strokes 203 | # 31F0—31FF Katakana Phonetic Extensions 204 | # 3200—32FF Enclosed CJK Letters and Months 205 | # 3300—33FF CJK Compatibility 206 | # 3400—4DBF CJK Unified Ideographs Extension A 207 | # 4DC0—4DFF Yijing Hexagram Symbols 208 | # 4E00—9FFF CJK Unified Ideographs 209 | caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) 210 | caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) 211 | caption = re.sub(r"[\u3200-\u32ff]+", "", caption) 212 | caption = re.sub(r"[\u3300-\u33ff]+", "", caption) 213 | caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) 214 | caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) 215 | caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) 216 | ####################################################### 217 | 218 | # все виды тире / all types of dash --> "-" 219 | caption = re.sub( 220 | r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa 221 | "-", 222 | caption, 223 | ) 224 | 225 | # кавычки к одному стандарту 226 | caption = re.sub(r"[`´«»“”¨]", '"', caption) 227 | caption = re.sub(r"[‘’]", "'", caption) 228 | 229 | # " 230 | caption = re.sub(r""?", "", caption) 231 | # & 232 | caption = re.sub(r"&", "", caption) 233 | 234 | # ip adresses: 235 | caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) 236 | 237 | # article ids: 238 | caption = re.sub(r"\d:\d\d\s+$", "", caption) 239 | 240 | # \n 241 | caption = re.sub(r"\\n", " ", caption) 242 | 243 | # "#123" 244 | caption = re.sub(r"#\d{1,3}\b", "", caption) 245 | # "#12345.." 
246 | caption = re.sub(r"#\d{5,}\b", "", caption) 247 | # "123456.." 248 | caption = re.sub(r"\b\d{6,}\b", "", caption) 249 | # filenames: 250 | caption = re.sub( 251 | r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption 252 | ) 253 | 254 | # 255 | caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" 256 | caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" 257 | 258 | caption = re.sub( 259 | self.bad_punct_regex, r" ", caption 260 | ) # ***AUSVERKAUFT***, #AUSVERKAUFT 261 | caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " 262 | 263 | # this-is-my-cute-cat / this_is_my_cute_cat 264 | regex2 = re.compile(r"(?:\-|\_)") 265 | if len(re.findall(regex2, caption)) > 3: 266 | caption = re.sub(regex2, " ", caption) 267 | 268 | caption = self.basic_clean(caption) 269 | 270 | caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 271 | caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc 272 | caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 273 | 274 | caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) 275 | caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) 276 | caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) 277 | caption = re.sub( 278 | r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption 279 | ) 280 | caption = re.sub(r"\bpage\s+\d+\b", "", caption) 281 | 282 | caption = re.sub( 283 | r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption 284 | ) # j2d1a2a... 285 | 286 | caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) 287 | 288 | caption = re.sub(r"\b\s+\:\s+", r": ", caption) 289 | caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) 290 | caption = re.sub(r"\s+", " ", caption) 291 | 292 | caption.strip() 293 | 294 | caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) 295 | caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) 296 | caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) 297 | caption = re.sub(r"^\.\S+$", "", caption) 298 | 299 | return caption.strip() 300 | 301 | def find_phrase_indices(self, sentence, phrase): 302 | sentence_tokens = self.tokenizer.tokenize(sentence) 303 | phrase_tokens = self.tokenizer.tokenize(phrase) 304 | 305 | phrase_len = len(phrase_tokens) 306 | for i in range(len(sentence_tokens) - phrase_len + 1): 307 | if sentence_tokens[i : i + phrase_len] == phrase_tokens: 308 | return i + 1, i + phrase_len + 1 309 | return None 310 | -------------------------------------------------------------------------------- /src/modules/utils.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 
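# This module collects small shared helpers: distributed gathering (gather_data),
# autocast and gradient-checkpointing wrappers, tensor utilities (extract_into_tensor,
# noise_like, mean_flat), layer factories (conv_nd, avg_pool_nd, linear, normalization)
# and the HybridConditioner used for combined concat / cross-attention conditioning.
#
# Minimal sketch of the `checkpoint` helper defined below (the names `block`, `x`,
# `context` and `use_checkpoint` are placeholders, not part of this file):
#
#     h = checkpoint(block.forward, (x, context), block.parameters(), flag=use_checkpoint)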
9 | 10 | import torch.nn as nn 11 | from utils.common_utils import instantiate_from_config 12 | 13 | import math 14 | from inspect import isfunction 15 | import torch 16 | from torch import nn 17 | import torch.distributed as dist 18 | 19 | 20 | def gather_data(data, return_np=True): 21 | """gather data from multiple processes to one list""" 22 | data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())] 23 | dist.all_gather(data_list, data) # gather not supported with NCCL 24 | if return_np: 25 | data_list = [data.cpu().numpy() for data in data_list] 26 | return data_list 27 | 28 | 29 | def autocast(f): 30 | def do_autocast(*args, **kwargs): 31 | with torch.cuda.amp.autocast( 32 | enabled=True, 33 | dtype=torch.get_autocast_gpu_dtype(), 34 | cache_enabled=torch.is_autocast_cache_enabled(), 35 | ): 36 | return f(*args, **kwargs) 37 | 38 | return do_autocast 39 | 40 | 41 | def extract_into_tensor(a, t, x_shape): 42 | b, *_ = t.shape 43 | out = a.gather(-1, t) 44 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 45 | 46 | 47 | def noise_like(shape, device, repeat=False): 48 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( 49 | shape[0], *((1,) * (len(shape) - 1)) 50 | ) 51 | noise = lambda: torch.randn(shape, device=device) 52 | return repeat_noise() if repeat else noise() 53 | 54 | 55 | def default(val, d): 56 | if exists(val): 57 | return val 58 | return d() if isfunction(d) else d 59 | 60 | 61 | def exists(val): 62 | return val is not None 63 | 64 | 65 | def identity(*args, **kwargs): 66 | return nn.Identity() 67 | 68 | 69 | def uniq(arr): 70 | return {el: True for el in arr}.keys() 71 | 72 | 73 | def mean_flat(tensor): 74 | """ 75 | Take the mean over all non-batch dimensions. 76 | """ 77 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 78 | 79 | 80 | def ismap(x): 81 | if not isinstance(x, torch.Tensor): 82 | return False 83 | return (len(x.shape) == 4) and (x.shape[1] > 3) 84 | 85 | 86 | def isimage(x): 87 | if not isinstance(x, torch.Tensor): 88 | return False 89 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 90 | 91 | 92 | def max_neg_value(t): 93 | return -torch.finfo(t.dtype).max 94 | 95 | 96 | def shape_to_str(x): 97 | shape_str = "x".join([str(x) for x in x.shape]) 98 | return shape_str 99 | 100 | 101 | def init_(tensor): 102 | dim = tensor.shape[-1] 103 | std = 1 / math.sqrt(dim) 104 | tensor.uniform_(-std, std) 105 | return tensor 106 | 107 | 108 | ckpt = torch.utils.checkpoint.checkpoint 109 | 110 | 111 | def checkpoint(func, inputs, params, flag): 112 | """ 113 | Evaluate a function without caching intermediate activations, allowing for 114 | reduced memory at the expense of extra compute in the backward pass. 115 | :param func: the function to evaluate. 116 | :param inputs: the argument sequence to pass to `func`. 117 | :param params: a sequence of parameters `func` depends on but does not 118 | explicitly take as arguments. 119 | :param flag: if False, disable gradient checkpointing. 120 | """ 121 | if flag: 122 | return ckpt(func, *inputs) 123 | else: 124 | return func(*inputs) 125 | 126 | 127 | def disabled_train(self, mode=True): 128 | """Overwrite model.train with this function to make sure train/eval mode 129 | does not change anymore.""" 130 | return self 131 | 132 | 133 | def zero_module(module): 134 | """ 135 | Zero out the parameters of a module and return it. 
136 | """ 137 | for p in module.parameters(): 138 | p.detach().zero_() 139 | return module 140 | 141 | 142 | def scale_module(module, scale): 143 | """ 144 | Scale the parameters of a module and return it. 145 | """ 146 | for p in module.parameters(): 147 | p.detach().mul_(scale) 148 | return module 149 | 150 | 151 | def conv_nd(dims, *args, **kwargs): 152 | """ 153 | Create a 1D, 2D, or 3D convolution module. 154 | """ 155 | if dims == 1: 156 | return nn.Conv1d(*args, **kwargs) 157 | elif dims == 2: 158 | return nn.Conv2d(*args, **kwargs) 159 | elif dims == 3: 160 | return nn.Conv3d(*args, **kwargs) 161 | raise ValueError(f"unsupported dimensions: {dims}") 162 | 163 | 164 | def linear(*args, **kwargs): 165 | """ 166 | Create a linear module. 167 | """ 168 | return nn.Linear(*args, **kwargs) 169 | 170 | 171 | def avg_pool_nd(dims, *args, **kwargs): 172 | """ 173 | Create a 1D, 2D, or 3D average pooling module. 174 | """ 175 | if dims == 1: 176 | return nn.AvgPool1d(*args, **kwargs) 177 | elif dims == 2: 178 | return nn.AvgPool2d(*args, **kwargs) 179 | elif dims == 3: 180 | return nn.AvgPool3d(*args, **kwargs) 181 | raise ValueError(f"unsupported dimensions: {dims}") 182 | 183 | 184 | def nonlinearity(type="silu"): 185 | if type == "silu": 186 | return nn.SiLU() 187 | elif type == "leaky_relu": 188 | return nn.LeakyReLU() 189 | 190 | 191 | class GroupNormSpecific(nn.GroupNorm): 192 | def forward(self, x): 193 | if x.dtype == torch.float16 or x.dtype == torch.bfloat16: 194 | return super().forward(x).type(x.dtype) 195 | else: 196 | return super().forward(x.float()).type(x.dtype) 197 | 198 | 199 | def normalization(channels, num_groups=32): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 
204 | """ 205 | return GroupNormSpecific(num_groups, channels) 206 | 207 | 208 | class HybridConditioner(nn.Module): 209 | 210 | def __init__(self, c_concat_config, c_crossattn_config): 211 | super().__init__() 212 | self.concat_conditioner = instantiate_from_config(c_concat_config) 213 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 214 | 215 | def forward(self, c_concat, c_crossattn): 216 | c_concat = self.concat_conditioner(c_concat) 217 | c_crossattn = self.crossattn_conditioner(c_crossattn) 218 | return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} 219 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, datetime 2 | from omegaconf import OmegaConf 3 | from transformers import logging as transf_logging 4 | 5 | import torch 6 | import pytorch_lightning as pl 7 | from pytorch_lightning import seed_everything 8 | from pytorch_lightning.trainer import Trainer 9 | 10 | sys.path.insert(0, os.getcwd()) 11 | from utils.common_utils import instantiate_from_config 12 | from utils.train_utils import ( 13 | get_trainer_callbacks, 14 | get_trainer_logger, 15 | get_trainer_strategy, 16 | ) 17 | from utils.train_utils import ( 18 | set_logger, 19 | init_workspace, 20 | load_checkpoints, 21 | get_autoresume_path, 22 | ) 23 | 24 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 25 | 26 | 27 | def get_parser(**parser_kwargs): 28 | parser = argparse.ArgumentParser(**parser_kwargs) 29 | parser.add_argument( 30 | "--seed", "-s", type=int, default=20230211, help="seed for seed_everything" 31 | ) 32 | parser.add_argument( 33 | "--name", "-n", type=str, default="", help="experiment name, as saving folder" 34 | ) 35 | 36 | parser.add_argument( 37 | "--base", 38 | "-b", 39 | nargs="*", 40 | metavar="base_config.yaml", 41 | help="paths to base configs. Loaded from left-to-right. 
" 42 | "Parameters can be overwritten or added with command-line options of the form `--key value`.", 43 | default=list(), 44 | ) 45 | 46 | parser.add_argument( 47 | "--train", "-t", action="store_true", default=False, help="train" 48 | ) 49 | parser.add_argument("--val", "-v", action="store_true", default=False, help="val") 50 | parser.add_argument("--test", action="store_true", default=False, help="test") 51 | 52 | parser.add_argument( 53 | "--logdir", 54 | "-l", 55 | type=str, 56 | default="logs", 57 | help="directory for logging dat shit", 58 | ) 59 | parser.add_argument( 60 | "--auto_resume", 61 | action="store_true", 62 | default=False, 63 | help="resume from full-info checkpoint", 64 | ) 65 | parser.add_argument( 66 | "--debug", 67 | "-d", 68 | action="store_true", 69 | default=False, 70 | help="enable post-mortem debugging", 71 | ) 72 | 73 | return parser 74 | 75 | 76 | def get_nondefault_trainer_args(args): 77 | parser = argparse.ArgumentParser() 78 | parser = Trainer.add_argparse_args(parser) 79 | default_trainer_args = parser.parse_args([]) 80 | return sorted( 81 | k 82 | for k in vars(default_trainer_args) 83 | if getattr(args, k) != getattr(default_trainer_args, k) 84 | ) 85 | 86 | 87 | if __name__ == "__main__": 88 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 89 | try: 90 | local_rank = int(os.environ.get("LOCAL_RANK")) 91 | global_rank = int(os.environ.get("RANK")) 92 | num_rank = int(os.environ.get("WORLD_SIZE")) 93 | except: 94 | local_rank, global_rank, num_rank = 0, 0, 1 95 | # print(f'local_rank: {local_rank} | global_rank:{global_rank} | num_rank:{num_rank}') 96 | 97 | parser = get_parser() 98 | ## Extends existing argparse by default Trainer attributes 99 | parser = Trainer.add_argparse_args(parser) 100 | args, unknown = parser.parse_known_args() 101 | ## disable transformer warning 102 | transf_logging.set_verbosity_error() 103 | seed_everything(args.seed) 104 | 105 | ## yaml configs: "model" | "data" | "lightning" 106 | configs = [OmegaConf.load(cfg) for cfg in args.base] 107 | cli = OmegaConf.from_dotlist(unknown) 108 | config = OmegaConf.merge(*configs, cli) 109 | lightning_config = config.pop("lightning", OmegaConf.create()) 110 | trainer_config = lightning_config.get("trainer", OmegaConf.create()) 111 | 112 | ## setup workspace directories 113 | workdir, ckptdir, cfgdir, loginfo = init_workspace( 114 | args.name, args.logdir, config, lightning_config, global_rank 115 | ) 116 | logger = set_logger( 117 | logfile=os.path.join(loginfo, "log_%d:%s.txt" % (global_rank, now)) 118 | ) 119 | logger.info("@lightning version: %s [>=1.8 required]" % (pl.__version__)) 120 | 121 | ## MODEL CONFIG >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 122 | logger.info("***** Configing Model *****") 123 | config.model.params.logdir = workdir 124 | model = instantiate_from_config(config.model) 125 | 126 | if args.auto_resume: 127 | ## the saved checkpoint must be: full-info checkpoint 128 | resume_ckpt_path = get_autoresume_path(workdir) 129 | if resume_ckpt_path is not None: 130 | args.resume_from_checkpoint = resume_ckpt_path 131 | logger.info("Resuming from checkpoint: %s" % args.resume_from_checkpoint) 132 | ## just in case train empy parameters only 133 | else: 134 | model = load_checkpoints(model, config.model) 135 | logger.warning("Auto-resuming skipped as No checkpoit found!") 136 | else: 137 | model = load_checkpoints(model, config.model) 138 | 139 | ## update trainer config 140 | for k in get_nondefault_trainer_args(args): 141 | 
trainer_config[k] = getattr(args, k) 142 | 143 | print(trainer_config) 144 | num_nodes = trainer_config.num_nodes 145 | ngpu_per_node = trainer_config.devices 146 | logger.info(f"Running on {num_rank}={num_nodes}x{ngpu_per_node} GPUs") 147 | 148 | ## setup learning rate 149 | base_lr = config.model.base_learning_rate 150 | bs = config.data.params.batch_size 151 | if getattr(config.model, "scale_lr", True): 152 | model.learning_rate = num_rank * bs * base_lr 153 | else: 154 | model.learning_rate = base_lr 155 | 156 | ## DATA CONFIG >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 157 | logger.info("***** Configing Data *****") 158 | data = instantiate_from_config(config.data) 159 | data.setup() 160 | for k in data.datasets: 161 | logger.info( 162 | f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}" 163 | ) 164 | 165 | ## TRAINER CONFIG >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 166 | logger.info("***** Configing Trainer *****") 167 | if "accelerator" not in trainer_config: 168 | trainer_config["accelerator"] = "gpu" 169 | 170 | torch.set_float32_matmul_precision("medium") 171 | 172 | ## setup trainer args: pl-logger and callbacks 173 | trainer_kwargs = dict() 174 | trainer_kwargs["num_sanity_val_steps"] = 0 175 | logger_cfg = get_trainer_logger(lightning_config, workdir, args.debug) 176 | trainer_kwargs["logger"] = instantiate_from_config(logger_cfg) 177 | 178 | ## setup callbacks 179 | callbacks_cfg = get_trainer_callbacks( 180 | lightning_config, config, workdir, ckptdir, logger 181 | ) 182 | trainer_kwargs["callbacks"] = [ 183 | instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg 184 | ] 185 | strategy_cfg = get_trainer_strategy(lightning_config) 186 | trainer_kwargs["strategy"] = ( 187 | strategy_cfg 188 | if type(strategy_cfg) == str 189 | else instantiate_from_config(strategy_cfg) 190 | ) 191 | trainer_kwargs["precision"] = lightning_config.get("precision", "bf16") 192 | trainer_kwargs["sync_batchnorm"] = False 193 | 194 | ## trainer config: others 195 | if ( 196 | "train" in config.data.params 197 | and config.data.params.train.target == "lvdm.data.hdvila.HDVila" 198 | or ( 199 | "validation" in config.data.params 200 | and config.data.params.validation.target == "lvdm.data.hdvila.HDVila" 201 | ) 202 | ): 203 | trainer_kwargs["replace_sampler_ddp"] = False 204 | 205 | ## for debug 206 | # trainer_kwargs["fast_dev_run"] = 10 207 | # trainer_kwargs["limit_train_batches"] = 1./32 208 | # trainer_kwargs["limit_val_batches"] = 0.01 209 | # trainer_kwargs["val_check_interval"] = 20 #float: epoch ratio | integer: batch num 210 | 211 | trainer_args = argparse.Namespace(**trainer_config) 212 | trainer = Trainer.from_argparse_args(trainer_args, **trainer_kwargs) 213 | 214 | ## allow checkpointing via USR1 215 | def melk(*args, **kwargs): 216 | ## run all checkpoint hooks 217 | if trainer.global_rank == 0: 218 | print("Summoning checkpoint.") 219 | ckpt_path = os.path.join(ckptdir, "last_summoning.ckpt") 220 | trainer.save_checkpoint(ckpt_path) 221 | 222 | def divein(*args, **kwargs): 223 | if trainer.global_rank == 0: 224 | import pudb 225 | 226 | pudb.set_trace() 227 | 228 | import signal 229 | 230 | signal.signal(signal.SIGUSR1, melk) 231 | signal.signal(signal.SIGUSR2, divein) 232 | 233 | ## Running LOOP >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 234 | logger.info("***** Running the Loop *****") 235 | if args.train: 236 | try: 237 | if "strategy" in lightning_config: 238 | logger.info("") 239 | 
239 | ## deepspeed 240 | with torch.cuda.amp.autocast(): 241 | trainer.fit(model, data) 242 | else: 243 | logger.info("No strategy specified: training with the default DDP-sharded strategy") 244 | ## ddp sharded 245 | trainer.fit(model, data) 246 | except Exception: 247 | # melk() 248 | raise 249 | if args.val: 250 | trainer.validate(model, data) 251 | if args.test or not trainer.interrupted: 252 | trainer.test(model, data) 253 | --------------------------------------------------------------------------------
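Note on the config handling in train.py above: the YAML files passed via --base are merged with any extra command-line overrides through OmegaConf, and the lightning section is then split off. The snippet below is a minimal, self-contained sketch of that flow; every value in it is an illustrative placeholder and is not taken from the shipped configs under configs/train/.

from omegaconf import OmegaConf

# Base config in the same "model" | "data" | "lightning" layout train.py expects
# (placeholder values, not the shipped YAML).
base = OmegaConf.create(
    {
        "model": {"base_learning_rate": 1.0e-4},
        "data": {"params": {"batch_size": 2}},
        "lightning": {"trainer": {"devices": 1, "num_nodes": 1}},
    }
)

# Unknown CLI arguments are parsed as a dot-list and override the YAML values.
cli = OmegaConf.from_dotlist(["data.params.batch_size=4"])
config = OmegaConf.merge(base, cli)

# train.py then splits off the lightning/trainer sub-config.
lightning_config = config.pop("lightning", OmegaConf.create())
trainer_config = lightning_config.get("trainer", OmegaConf.create())

print(config.data.params.batch_size)  # 4  (CLI override wins)
print(trainer_config.devices)         # 1

Everything left in config after the pop feeds instantiate_from_config for the model and data, while the trainer sub-config becomes the Trainer arguments.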
/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | 5 | mainlogger = logging.getLogger("mainlogger") 6 | 7 | import torch 8 | import torchvision 9 | import pytorch_lightning as pl 10 | from pytorch_lightning.callbacks import Callback 11 | from pytorch_lightning.utilities import rank_zero_only 12 | from pytorch_lightning.utilities import rank_zero_info 13 | from utils.save_video import log_local, prepare_to_log 14 | 15 | 16 | class ImageLogger(Callback): 17 | def __init__( 18 | self, 19 | batch_frequency, 20 | max_images=8, 21 | clamp=True, 22 | rescale=True, 23 | save_dir=None, 24 | to_local=False, 25 | log_images_kwargs=None, 26 | ): 27 | super().__init__() 28 | self.rescale = rescale 29 | self.batch_freq = batch_frequency 30 | self.max_images = max_images 31 | self.to_local = to_local 32 | self.clamp = clamp 33 | self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {} 34 | if self.to_local: 35 | ## default save dir 36 | self.save_dir = os.path.join(save_dir, "images") 37 | os.makedirs(os.path.join(self.save_dir, "train"), exist_ok=True) 38 | os.makedirs(os.path.join(self.save_dir, "val"), exist_ok=True) 39 | 40 | def log_to_tensorboard(self, pl_module, batch_logs, filename, split, save_fps=10): 41 | """log images and videos to tensorboard""" 42 | global_step = pl_module.global_step 43 | for key in batch_logs: 44 | value = batch_logs[key] 45 | tag = "gs%d-%s/%s-%s" % (global_step, split, filename, key) 46 | if isinstance(value, list) and isinstance(value[0], str): 47 | captions = " |------| ".join(value) 48 | pl_module.logger.experiment.add_text( 49 | tag, captions, global_step=global_step 50 | ) 51 | elif isinstance(value, torch.Tensor) and value.dim() == 5: 52 | video = value 53 | n = video.shape[0] 54 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 55 | frame_grids = [ 56 | torchvision.utils.make_grid(framesheet, nrow=int(n)) 57 | for framesheet in video 58 | ] # [3, n*h, 1*w] 59 | grid = torch.stack( 60 | frame_grids, dim=0 61 | ) # stack in temporal dim [t, 3, n*h, w] 62 | grid = (grid + 1.0) / 2.0 63 | grid = grid.unsqueeze(dim=0) 64 | pl_module.logger.experiment.add_video( 65 | tag, grid, fps=save_fps, global_step=global_step 66 | ) 67 | elif isinstance(value, torch.Tensor) and value.dim() == 4: 68 | img = value 69 | grid = torchvision.utils.make_grid(img, nrow=int(img.shape[0]))  # all batch images in one row 70 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 71 | pl_module.logger.experiment.add_image( 72 | tag, grid, global_step=global_step 73 | ) 74 | else: 75 | pass 76 | 77 | @rank_zero_only 78 | def log_batch_imgs(self, pl_module, batch, batch_idx, split="train"): 79 | """generate images, then save and log to tensorboard""" 80 | skip_freq = self.batch_freq if split == "train" else 5 81 | if (batch_idx + 1) % skip_freq == 0: 82 | is_train = pl_module.training 83 | if is_train: 84 | pl_module.eval() 85 | 86 | with torch.no_grad(): 87 | log_func = pl_module.log_images 88 | batch_logs = log_func(batch, split=split, **self.log_images_kwargs) 89 | 90 | ## process: move to CPU and clamp 91 | batch_logs = prepare_to_log(batch_logs, self.max_images, self.clamp) 92 | torch.cuda.empty_cache() 93 | 94 | filename = "ep{}_idx{}_rank{}".format( 95 | pl_module.current_epoch, batch_idx, pl_module.global_rank 96 | ) 97 | if self.to_local: 98 | mainlogger.info("Log [%s] batch <%s> to local ..." % (split, filename)) 99 | filename = "gs{}_".format(pl_module.global_step) + filename 100 | log_local( 101 | batch_logs, 102 | os.path.join(self.save_dir, split), 103 | filename, 104 | save_fps=10, 105 | ) 106 | else: 107 | mainlogger.info( 108 | "Log [%s] batch <%s> to tensorboard ..." % (split, filename) 109 | ) 110 | self.log_to_tensorboard( 111 | pl_module, batch_logs, filename, split, save_fps=10 112 | ) 113 | mainlogger.info("Finish!") 114 | 115 | if is_train: 116 | pl_module.train() 117 | 118 | def on_train_batch_end( 119 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None 120 | ): 121 | if self.batch_freq != -1 and pl_module.logdir: 122 | self.log_batch_imgs(pl_module, batch, batch_idx, split="train") 123 | 124 | def on_validation_batch_end( 125 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None 126 | ): 127 | ## unlike validation_step(), which evaluates the whole validation set and keeps only the latest result, 128 | ## this hook logs a small subset for every validation run without overwriting earlier ones 129 | if self.batch_freq != -1 and pl_module.logdir: 130 | self.log_batch_imgs(pl_module, batch, batch_idx, split="val") 131 | if hasattr(pl_module, "calibrate_grad_norm"): 132 | if ( 133 | pl_module.calibrate_grad_norm and batch_idx % 25 == 0 134 | ) and batch_idx > 0: 135 | self.log_gradients(trainer, pl_module, batch_idx=batch_idx) 136 | 137 | 138 | """ 139 | class DataModeSwitcher(Callback): 140 | def on_epoch_start(self, trainer, pl_module): 141 | mode = 'image' if random.random() <= 0.3 else 'video' 142 | trainer.datamodule.dataset.set_mode(mode) 143 | if trainer.global_rank == 0: 144 | torch.distributed.barrier() 145 | """ 146 | 147 | 148 | class CUDACallback(Callback): 149 | # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py 150 | def on_train_epoch_start(self, trainer, pl_module): 151 | # Reset the peak GPU memory counter 152 | # (device lookup depends on the lightning version) 153 | if int((pl.__version__).split(".")[1]) >= 7: 154 | gpu_index = trainer.strategy.root_device.index 155 | else: 156 | gpu_index = trainer.root_gpu 157 | torch.cuda.reset_peak_memory_stats(gpu_index) 158 | torch.cuda.synchronize(gpu_index) 159 | self.start_time = time.time() 160 | 161 | def on_train_epoch_end(self, trainer, pl_module): 162 | if int((pl.__version__).split(".")[1]) >= 7: 163 | gpu_index = trainer.strategy.root_device.index 164 | else: 165 | gpu_index = trainer.root_gpu 166 | torch.cuda.synchronize(gpu_index) 167 | max_memory = torch.cuda.max_memory_allocated(gpu_index) / 2**20 168 | epoch_time = time.time() - self.start_time 169 | 170 | try: 171 | max_memory = trainer.training_type_plugin.reduce(max_memory) 172 | epoch_time = trainer.training_type_plugin.reduce(epoch_time) 173 | 174 | rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") 175 | rank_zero_info(f"Average Peak memory: {max_memory:.2f} MiB") 176 | except AttributeError: 177 | pass 178 | -------------------------------------------------------------------------------- /utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import cv2, os 4 | import torch 5 | import 
torch.distributed as dist 6 | 7 | 8 | def count_params(model, verbose=False): 9 | total_params = sum(p.numel() for p in model.parameters()) 10 | if verbose: 11 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 12 | return total_params 13 | 14 | 15 | def check_istarget(name, para_list): 16 | """ 17 | name: full name of source para 18 | para_list: partial name of target para 19 | """ 20 | istarget = False 21 | for para in para_list: 22 | if para in name: 23 | return True 24 | return istarget 25 | 26 | 27 | def instantiate_from_config(config): 28 | if not "target" in config: 29 | if config == "__is_first_stage__": 30 | return None 31 | elif config == "__is_unconditional__": 32 | return None 33 | raise KeyError("Expected key `target` to instantiate.") 34 | 35 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 36 | 37 | 38 | def get_obj_from_str(string, reload=False): 39 | module, cls = string.rsplit(".", 1) 40 | if reload: 41 | module_imp = importlib.import_module(module) 42 | importlib.reload(module_imp) 43 | return getattr(importlib.import_module(module, package=None), cls) 44 | 45 | 46 | def load_npz_from_dir(data_dir): 47 | data = [ 48 | np.load(os.path.join(data_dir, data_name))["arr_0"] 49 | for data_name in os.listdir(data_dir) 50 | ] 51 | data = np.concatenate(data, axis=0) 52 | return data 53 | 54 | 55 | def load_npz_from_paths(data_paths): 56 | data = [np.load(data_path)["arr_0"] for data_path in data_paths] 57 | data = np.concatenate(data, axis=0) 58 | return data 59 | 60 | 61 | def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None): 62 | h, w = image.shape[:2] 63 | if resize_short_edge is not None: 64 | k = resize_short_edge / min(h, w) 65 | else: 66 | k = max_resolution / (h * w) 67 | k = k**0.5 68 | h = int(np.round(h * k / 64)) * 64 69 | w = int(np.round(w * k / 64)) * 64 70 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4) 71 | return image 72 | 73 | 74 | def setup_dist(args): 75 | if dist.is_initialized(): 76 | return 77 | torch.cuda.set_device(args.local_rank) 78 | torch.distributed.init_process_group("nccl", init_method="env://") 79 | -------------------------------------------------------------------------------- /utils/save_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | from PIL import Image 5 | from einops import rearrange 6 | 7 | import torch 8 | import torchvision 9 | from torch import Tensor 10 | from torchvision.utils import make_grid 11 | from torchvision.transforms.functional import to_tensor 12 | from PIL import Image, ImageDraw, ImageFont 13 | 14 | 15 | def save_video_tensor_to_mp4(video, path, fps): 16 | # b,c,t,h,w 17 | video = video.detach().cpu() 18 | video = torch.clamp(video.float(), -1.0, 1.0) 19 | n = video.shape[0] 20 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 21 | frame_grids = [ 22 | torchvision.utils.make_grid(framesheet, nrow=int(n)) for framesheet in video 23 | ] # [3, 1*h, n*w] 24 | grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] 25 | grid = (grid + 1.0) / 2.0 26 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) 27 | torchvision.io.write_video( 28 | path, grid, fps=fps, video_codec="h264", options={"crf": "10"} 29 | ) 30 | 31 | 32 | def save_video_tensor_to_frames(video, dir): 33 | os.makedirs(dir, exist_ok=True) 34 | # b,c,t,h,w 35 | video = video.detach().cpu() 36 | video = 
torch.clamp(video.float(), -1.0, 1.0) 37 | n = video.shape[0] 38 | assert n == 1 39 | video = video[0] # cthw 40 | video = video.permute(1, 2, 3, 0) # thwc 41 | # video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 42 | video = (video + 1.0) / 2.0 * 255 43 | video = video.to(torch.uint8).numpy() 44 | for i in range(video.shape[0]): 45 | img = video[i] # hwc 46 | image = Image.fromarray(img) 47 | image.save(os.path.join(dir, f"frame{i:03d}.jpg"), q=95) 48 | 49 | 50 | def frames_to_mp4(frame_dir, output_path, fps): 51 | def read_first_n_frames(d: os.PathLike, num_frames: int): 52 | if num_frames: 53 | images = [ 54 | Image.open(os.path.join(d, f)) 55 | for f in sorted(os.listdir(d))[:num_frames] 56 | ] 57 | else: 58 | images = [Image.open(os.path.join(d, f)) for f in sorted(os.listdir(d))] 59 | images = [to_tensor(x) for x in images] 60 | return torch.stack(images) 61 | 62 | videos = read_first_n_frames(frame_dir, num_frames=None) 63 | videos = videos.mul(255).to(torch.uint8).permute(0, 2, 3, 1) 64 | torchvision.io.write_video( 65 | output_path, videos, fps=fps, video_codec="h264", options={"crf": "10"} 66 | ) 67 | 68 | 69 | def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None): 70 | """ 71 | video: torch.Tensor, b,c,t,h,w, 0-1 72 | if -1~1, enable rescale=True 73 | """ 74 | n = video.shape[0] 75 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 76 | nrow = int(np.sqrt(n)) if nrow is None else nrow 77 | frame_grids = [ 78 | torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video 79 | ] # [3, grid_h, grid_w] 80 | grid = torch.stack( 81 | frame_grids, dim=0 82 | ) # stack in temporal dim [T, 3, grid_h, grid_w] 83 | grid = torch.clamp(grid.float(), -1.0, 1.0) 84 | if rescale: 85 | grid = (grid + 1.0) / 2.0 86 | grid = ( 87 | (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) 88 | ) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3] 89 | # print(f'Save video to {savepath}') 90 | torchvision.io.write_video( 91 | savepath, grid, fps=fps, video_codec="h264", options={"crf": "10"} 92 | ) 93 | 94 | 95 | def tensor2videogrids(video, root, filename, fps, rescale=True, clamp=True): 96 | 97 | assert video.dim() == 5 # b,c,t,h,w 98 | assert isinstance(video, torch.Tensor) 99 | 100 | video = video.detach().cpu() 101 | if clamp: 102 | video = torch.clamp(video, -1.0, 1.0) 103 | n = video.shape[0] 104 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 105 | frame_grids = [ 106 | torchvision.utils.make_grid(framesheet, nrow=int(np.sqrt(n))) 107 | for framesheet in video 108 | ] # [3, grid_h, grid_w] 109 | grid = torch.stack( 110 | frame_grids, dim=0 111 | ) # stack in temporal dim [T, 3, grid_h, grid_w] 112 | if rescale: 113 | grid = (grid + 1.0) / 2.0 114 | grid = ( 115 | (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) 116 | ) # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3] 117 | path = os.path.join(root, filename) 118 | # print('Save video ...') 119 | torchvision.io.write_video( 120 | path, grid, fps=fps, video_codec="h264", options={"crf": "10"} 121 | ) 122 | # print('Finish!') 123 | 124 | 125 | def log_txt_as_img(wh, xc, size=10): 126 | # wh a tuple of (width, height) 127 | # xc a list of captions to plot 128 | b = len(xc) 129 | txts = list() 130 | for bi in range(b): 131 | txt = Image.new("RGB", wh, color="white") 132 | draw = ImageDraw.Draw(txt) 133 | font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) 134 | nc = int(40 * (wh[0] / 256)) 135 | lines = "\n".join( 136 | xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) 137 | ) 138 | 139 | try: 140 | 
draw.text((0, 0), lines, fill="black", font=font) 141 | except UnicodeEncodeError: 142 | print("Cant encode string for logging. Skipping.") 143 | 144 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 145 | txts.append(txt) 146 | txts = np.stack(txts) 147 | txts = torch.tensor(txts) 148 | return txts 149 | 150 | 151 | def log_local(batch_logs, save_dir, filename, save_fps=10, rescale=True): 152 | if batch_logs is None: 153 | return None 154 | """ save images and videos from images dict """ 155 | 156 | def save_img_grid(grid, path, rescale): 157 | if rescale: 158 | grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w 159 | grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1) 160 | grid = grid.numpy() 161 | grid = (grid * 255).astype(np.uint8) 162 | os.makedirs(os.path.split(path)[0], exist_ok=True) 163 | Image.fromarray(grid).save(path) 164 | 165 | for key in batch_logs: 166 | value = batch_logs[key] 167 | if isinstance(value, list) and isinstance(value[0], str): 168 | ## a batch of captions 169 | path = os.path.join(save_dir, "%s-%s.txt" % (key, filename)) 170 | with open(path, "w") as f: 171 | for i, txt in enumerate(value): 172 | f.write(f"idx={i}, txt={txt}\n") 173 | f.close() 174 | elif isinstance(value, torch.Tensor) and value.dim() == 5: 175 | ## save video grids 176 | video = value # b,c,t,h,w 177 | ## only save grayscale or rgb mode 178 | if video.shape[1] != 1 and video.shape[1] != 3: 179 | continue 180 | n = video.shape[0] 181 | video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w 182 | frame_grids = [ 183 | torchvision.utils.make_grid(framesheet, nrow=int(1)) 184 | for framesheet in video 185 | ] # [3, n*h, 1*w] 186 | grid = torch.stack( 187 | frame_grids, dim=0 188 | ) # stack in temporal dim [t, 3, n*h, w] 189 | if rescale: 190 | grid = (grid + 1.0) / 2.0 191 | grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) 192 | path = os.path.join(save_dir, "%s-%s.mp4" % (key, filename)) 193 | torchvision.io.write_video( 194 | path, grid, fps=save_fps, video_codec="h264", options={"crf": "10"} 195 | ) 196 | 197 | ## save frame sheet 198 | img = value 199 | video_frames = rearrange(img, "b c t h w -> (b t) c h w") 200 | t = img.shape[2] 201 | grid = torchvision.utils.make_grid(video_frames, nrow=t) 202 | path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename)) 203 | # save_img_grid(grid, path, rescale) 204 | elif isinstance(value, torch.Tensor) and value.dim() == 4: 205 | ## save image grids 206 | img = value 207 | ## only save grayscale or rgb mode 208 | if img.shape[1] != 1 and img.shape[1] != 3: 209 | continue 210 | n = img.shape[0] 211 | grid = torchvision.utils.make_grid(img, nrow=1) 212 | path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename)) 213 | save_img_grid(grid, path, rescale) 214 | else: 215 | pass 216 | 217 | 218 | def prepare_to_log(batch_logs, max_images=100000, clamp=True): 219 | if batch_logs is None: 220 | return None 221 | # process 222 | for key in batch_logs: 223 | if batch_logs[key] is not None: 224 | N = ( 225 | batch_logs[key].shape[0] 226 | if hasattr(batch_logs[key], "shape") 227 | else len(batch_logs[key]) 228 | ) 229 | N = min(N, max_images) 230 | batch_logs[key] = batch_logs[key][:N] 231 | ## in batch_logs: images & caption 232 | if isinstance(batch_logs[key], torch.Tensor): 233 | batch_logs[key] = batch_logs[key].detach().cpu() 234 | if clamp: 235 | try: 236 | batch_logs[key] = torch.clamp( 237 | batch_logs[key].float(), -1.0, 1.0 238 | ) 239 | except RuntimeError: 240 | print("clamp_scalar_cpu not implemented for Half") 241 | return batch_logs 
242 | 243 | 244 | # ---------------------------------------------------------------------------------------------- 245 | 246 | 247 | def fill_with_black_squares(video, desired_len: int) -> Tensor: 248 | if len(video) >= desired_len: 249 | return video 250 | 251 | return torch.cat( 252 | [ 253 | video, 254 | torch.zeros_like(video[0]) 255 | .unsqueeze(0) 256 | .repeat(desired_len - len(video), 1, 1, 1), 257 | ], 258 | dim=0, 259 | ) 260 | 261 | 262 | # ---------------------------------------------------------------------------------------------- 263 | def load_num_videos(data_path, num_videos): 264 | # first argument can be either data_path of np array 265 | if isinstance(data_path, str): 266 | videos = np.load(data_path)["arr_0"] # NTHWC 267 | elif isinstance(data_path, np.ndarray): 268 | videos = data_path 269 | else: 270 | raise Exception 271 | 272 | if num_videos is not None: 273 | videos = videos[:num_videos, :, :, :, :] 274 | return videos 275 | 276 | 277 | def npz_to_video_grid( 278 | data_path, out_path, num_frames, fps, num_videos=None, nrow=None, verbose=True 279 | ): 280 | # videos = torch.tensor(np.load(data_path)['arr_0']).permute(0,1,4,2,3).div_(255).mul_(2) - 1.0 # NTHWC->NTCHW, np int -> torch tensor 0-1 281 | if isinstance(data_path, str): 282 | videos = load_num_videos(data_path, num_videos) 283 | elif isinstance(data_path, np.ndarray): 284 | videos = data_path 285 | else: 286 | raise Exception 287 | n, t, h, w, c = videos.shape 288 | videos_th = [] 289 | for i in range(n): 290 | video = videos[i, :, :, :, :] 291 | images = [video[j, :, :, :] for j in range(t)] 292 | images = [to_tensor(img) for img in images] 293 | video = torch.stack(images) 294 | videos_th.append(video) 295 | if verbose: 296 | videos = [ 297 | fill_with_black_squares(v, num_frames) 298 | for v in tqdm(videos_th, desc="Adding empty frames") 299 | ] # NTCHW 300 | else: 301 | videos = [fill_with_black_squares(v, num_frames) for v in videos_th] # NTCHW 302 | 303 | frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4) # [T, N, C, H, W] 304 | if nrow is None: 305 | nrow = int(np.ceil(np.sqrt(n))) 306 | if verbose: 307 | frame_grids = [ 308 | make_grid(fs, nrow=nrow) for fs in tqdm(frame_grids, desc="Making grids") 309 | ] 310 | else: 311 | frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids] 312 | 313 | if os.path.dirname(out_path) != "": 314 | os.makedirs(os.path.dirname(out_path), exist_ok=True) 315 | frame_grids = ( 316 | (torch.stack(frame_grids) * 255).to(torch.uint8).permute(0, 2, 3, 1) 317 | ) # [T, H, W, C] 318 | torchvision.io.write_video( 319 | out_path, frame_grids, fps=fps, video_codec="h264", options={"crf": "10"} 320 | ) 321 | -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from omegaconf import OmegaConf 3 | from collections import OrderedDict 4 | import logging 5 | 6 | mainlogger = logging.getLogger("mainlogger") 7 | 8 | import torch 9 | from collections import OrderedDict 10 | 11 | 12 | def init_workspace(name, logdir, model_config, lightning_config, rank=0): 13 | workdir = os.path.join(logdir, name) 14 | ckptdir = os.path.join(workdir, "checkpoints") 15 | cfgdir = os.path.join(workdir, "configs") 16 | loginfo = os.path.join(workdir, "loginfo") 17 | 18 | # Create logdirs and save configs (all ranks will do to avoid missing directory error if rank:0 is slower) 19 | os.makedirs(workdir, exist_ok=True) 20 | os.makedirs(ckptdir, 
exist_ok=True) 21 | os.makedirs(cfgdir, exist_ok=True) 22 | os.makedirs(loginfo, exist_ok=True) 23 | 24 | if rank == 0: 25 | if ( 26 | "callbacks" in lightning_config 27 | and "metrics_over_trainsteps_checkpoint" in lightning_config.callbacks 28 | ): 29 | os.makedirs(os.path.join(ckptdir, "trainstep_checkpoints"), exist_ok=True) 30 | OmegaConf.save(model_config, os.path.join(cfgdir, "model.yaml")) 31 | OmegaConf.save( 32 | OmegaConf.create({"lightning": lightning_config}), 33 | os.path.join(cfgdir, "lightning.yaml"), 34 | ) 35 | return workdir, ckptdir, cfgdir, loginfo 36 | 37 | 38 | def check_config_attribute(config, name): 39 | if name in config: 40 | value = getattr(config, name) 41 | return value 42 | else: 43 | return None 44 | 45 | 46 | def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, logger): 47 | default_callbacks_cfg = { 48 | "model_checkpoint": { 49 | "target": "pytorch_lightning.callbacks.ModelCheckpoint", 50 | "params": { 51 | "dirpath": ckptdir, 52 | "filename": "{epoch}", 53 | "verbose": True, 54 | "save_last": True, 55 | }, 56 | }, 57 | "batch_logger": { 58 | "target": "utils.callbacks.ImageLogger", 59 | "params": { 60 | "save_dir": logdir, 61 | "batch_frequency": 1000, 62 | "max_images": 4, 63 | "clamp": True, 64 | }, 65 | }, 66 | "learning_rate_logger": { 67 | "target": "pytorch_lightning.callbacks.LearningRateMonitor", 68 | "params": {"logging_interval": "step", "log_momentum": False}, 69 | }, 70 | "cuda_callback": {"target": "utils.callbacks.CUDACallback"}, 71 | } 72 | 73 | ## optional setting for saving checkpoints 74 | monitor_metric = check_config_attribute(config.model.params, "monitor") 75 | if monitor_metric is not None: 76 | mainlogger.info(f"Monitoring {monitor_metric} as checkpoint metric.") 77 | default_callbacks_cfg["model_checkpoint"]["params"]["monitor"] = monitor_metric 78 | default_callbacks_cfg["model_checkpoint"]["params"]["save_top_k"] = 3 79 | default_callbacks_cfg["model_checkpoint"]["params"]["mode"] = "min" 80 | 81 | if "callbacks" in lightning_config and "metrics_over_trainsteps_checkpoint" in lightning_config.callbacks: 82 | mainlogger.info( 83 | "Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
84 | ) 85 | default_metrics_over_trainsteps_ckpt_dict = { 86 | "metrics_over_trainsteps_checkpoint": { 87 | "target": "pytorch_lightning.callbacks.ModelCheckpoint", 88 | "params": { 89 | "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"), 90 | "filename": "{epoch}-{step}", 91 | "verbose": True, 92 | "save_top_k": -1, 93 | "every_n_train_steps": 10000, 94 | "save_weights_only": True, 95 | }, 96 | } 97 | } 98 | default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict) 99 | 100 | if "callbacks" in lightning_config: 101 | callbacks_cfg = lightning_config.callbacks 102 | else: 103 | callbacks_cfg = OmegaConf.create() 104 | callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) 105 | 106 | return callbacks_cfg 107 | 108 | 109 | def get_trainer_logger(lightning_config, logdir, on_debug): 110 | default_logger_cfgs = { 111 | "tensorboard": { 112 | "target": "pytorch_lightning.loggers.TensorBoardLogger", 113 | "params": { 114 | "save_dir": logdir, 115 | "name": "tensorboard", 116 | }, 117 | }, 118 | "testtube": { 119 | "target": "pytorch_lightning.loggers.CSVLogger", 120 | "params": { 121 | "name": "testtube", 122 | "save_dir": logdir, 123 | }, 124 | }, 125 | } 126 | os.makedirs(os.path.join(logdir, "tensorboard"), exist_ok=True) 127 | default_logger_cfg = default_logger_cfgs["tensorboard"] 128 | if "logger" in lightning_config: 129 | logger_cfg = lightning_config.logger 130 | else: 131 | logger_cfg = OmegaConf.create() 132 | logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg) 133 | return logger_cfg 134 | 135 | 136 | def get_trainer_strategy(lightning_config): 137 | default_strategy_dict = { 138 | "target": "pytorch_lightning.strategies.DDPShardedStrategy" 139 | } 140 | if "strategy" in lightning_config: 141 | strategy_cfg = lightning_config.strategy 142 | return strategy_cfg 143 | else: 144 | strategy_cfg = OmegaConf.create() 145 | 146 | strategy_cfg = OmegaConf.merge(default_strategy_dict, strategy_cfg) 147 | return strategy_cfg 148 | 149 | 150 | def load_checkpoints(model, model_cfg): 151 | ## special load setting for adapter training 152 | if check_config_attribute(model_cfg, "adapter_only"): 153 | pretrained_ckpt = model_cfg.pretrained_checkpoint 154 | assert os.path.exists(pretrained_ckpt), ( 155 | "Error: Pre-trained checkpoint NOT found at:%s" % pretrained_ckpt 156 | ) 157 | mainlogger.info( 158 | ">>> Load weights from pretrained checkpoint (training adapter only)" 159 | ) 160 | print(f"Loading model from {pretrained_ckpt}") 161 | ## only load weight for the backbone model (e.g. 
latent diffusion model) 162 | state_dict = torch.load(pretrained_ckpt, map_location="cpu") 163 | if "state_dict" in list(state_dict.keys()): 164 | state_dict = state_dict["state_dict"] 165 | else: 166 | # deepspeed checkpoint: strip the 16-character wrapper prefix from each key 167 | dp_state_dict = OrderedDict() 168 | for key in state_dict["module"].keys(): 169 | dp_state_dict[key[16:]] = state_dict["module"][key] 170 | state_dict = dp_state_dict 171 | model.load_state_dict(state_dict, strict=False) 172 | model.empty_paras = None 173 | return model 174 | empty_paras = None 175 | 176 | if check_config_attribute(model_cfg, "pretrained_checkpoint"): 177 | pretrained_ckpt = model_cfg.pretrained_checkpoint 178 | assert os.path.exists(pretrained_ckpt), ( 179 | "Error: Pre-trained checkpoint NOT found at: %s" % pretrained_ckpt 180 | ) 181 | mainlogger.info(">>> Load weights from pretrained checkpoint") 182 | # mainlogger.info(pretrained_ckpt) 183 | print(f"Loading model from {pretrained_ckpt}") 184 | pl_sd = torch.load(pretrained_ckpt, map_location="cpu") 185 | try: 186 | if "state_dict" in pl_sd.keys(): 187 | model.load_state_dict(pl_sd["state_dict"]) 188 | else: 189 | # deepspeed checkpoint 190 | new_pl_sd = OrderedDict() 191 | for key in pl_sd["module"].keys(): 192 | new_pl_sd[key[16:]] = pl_sd["module"][key] 193 | model.load_state_dict(new_pl_sd) 194 | except Exception: 195 | model.load_state_dict(pl_sd) 196 | else: 197 | empty_paras = None 198 | 199 | ## record empty params 200 | model.empty_paras = empty_paras 201 | return model 202 | 203 | 204 | def get_autoresume_path(logdir): 205 | ckpt = os.path.join(logdir, "checkpoints", "last.ckpt") 206 | if os.path.exists(ckpt): 207 | try: 208 | tmp = torch.load(ckpt, map_location="cpu") 209 | e = tmp["epoch"] 210 | gs = tmp["global_step"] 211 | mainlogger.info(f"[INFO] Resume from epoch {e}, global step {gs}!") 212 | del tmp 213 | except Exception: 214 | try: 215 | mainlogger.info("Load last.ckpt failed!") 216 | ckpts = sorted( 217 | [ 218 | f 219 | for f in os.listdir(os.path.join(logdir, "checkpoints")) 220 | if not os.path.isdir(f) 221 | ] 222 | ) 223 | mainlogger.info(f"all available checkpoints: {ckpts}") 224 | ckpts.remove("last.ckpt") 225 | if "trainstep_checkpoints" in ckpts: 226 | ckpts.remove("trainstep_checkpoints") 227 | ckpt_path = ckpts[-1] 228 | ckpt = os.path.join(logdir, "checkpoints", ckpt_path) 229 | mainlogger.info(f"Select resuming ckpt: {ckpt}") 230 | except ValueError: 231 | mainlogger.info("Load last.ckpt failed, and there are no other checkpoints!") 232 | 233 | resume_checkpt_path = ckpt 234 | mainlogger.info(f"[INFO] resume from: {ckpt}") 235 | else: 236 | resume_checkpt_path = None 237 | mainlogger.info( 238 | f"[INFO] no checkpoint found in current workspace: {os.path.join(logdir, 'checkpoints')}" 239 | ) 240 | 241 | return resume_checkpt_path 242 | 243 | 244 | def set_logger(logfile, name="mainlogger"): 245 | logger = logging.getLogger(name) 246 | logger.setLevel(logging.INFO) 247 | fh = logging.FileHandler(logfile, mode="w") 248 | fh.setLevel(logging.INFO) 249 | ch = logging.StreamHandler() 250 | ch.setLevel(logging.DEBUG) 251 | fh.setFormatter(logging.Formatter("%(asctime)s-%(levelname)s: %(message)s")) 252 | ch.setFormatter(logging.Formatter("%(message)s")) 253 | logger.addHandler(fh) 254 | logger.addHandler(ch) 255 | return logger 256 | --------------------------------------------------------------------------------
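Closing reference: the target/params blocks used by train.py and by get_trainer_callbacks above are resolved at runtime by utils.common_utils.instantiate_from_config, which imports the dotted path and constructs it with the params mapping. Below is a minimal sketch of that mechanism; the ModelCheckpoint target matches one of the defaults in get_trainer_callbacks, but the dirpath value is an illustrative placeholder.

from omegaconf import OmegaConf
from utils.common_utils import instantiate_from_config

# Config block in the same shape as the entries of default_callbacks_cfg
# in utils/train_utils.py; the dirpath is a placeholder.
cfg = OmegaConf.create(
    {
        "target": "pytorch_lightning.callbacks.ModelCheckpoint",
        "params": {"dirpath": "logs/demo/checkpoints", "save_last": True},
    }
)

# get_obj_from_str imports pytorch_lightning.callbacks and returns the
# ModelCheckpoint class, which is then constructed with **params.
callback = instantiate_from_config(cfg)
print(type(callback).__name__)  # ModelCheckpoint

The same mechanism instantiates the model, the data module, the pl-logger, and the strategy from their respective config blocks.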