├── LICENSE ├── README-ZH.md ├── README.md ├── assets ├── airport_people_crowd_busy.gif ├── beach_ocean_waves_water_sand.gif ├── bee_honey_insect_beehive_nature.gif ├── coffee_beans_caffeine_coffee_shop.gif ├── fish_underwater_aquarium_swim.gif ├── forest_woods_mystical_morning.gif ├── hair_wind_girl_woman_people.gif ├── ocean_beach_sunset_sea_atmosphere.gif ├── reeds_grass_wind_golden_sunshine.gif ├── sea_ocean_seagulls_birds_sunset.gif ├── woman_flowers_plants_field_garden.gif └── wood_anemones_wildflower_flower.gif ├── configs ├── stable_diffusion │ └── tokenizer │ │ ├── merges.txt │ │ ├── special_tokens_map.json │ │ ├── tokenizer_config.json │ │ └── vocab.json └── stable_diffusion_xl │ └── tokenizer_2 │ ├── merges.txt │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── vocab.json ├── data └── pixabay100 │ ├── metadata.json │ └── videos │ └── Put videos here.txt ├── environment.yml ├── litesora ├── data │ ├── __init__.py │ ├── text_video_dataset.py │ └── utils.py ├── models │ ├── __init__.py │ ├── attention.py │ ├── patchify.py │ ├── sdxl_text_encoder_2.py │ ├── svd_vae.py │ ├── utils.py │ └── video_dit.py ├── pipelines │ ├── __init__.py │ └── pixel_video_dit.py ├── schedulers │ ├── __init__.py │ └── ddim.py └── trainers │ └── v1.py └── models ├── denoising_model └── Put denoising model checkpoints here.txt ├── text_encoder └── Put text encoder checkpoints here.txt └── vae └── Put VAE checkpoints here.txt /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README-ZH.md: -------------------------------------------------------------------------------- 1 | # Lite-Sora 2 | 3 | ## 简介 4 | 5 | lite-sora项目是一个Sora技术复现的项目,它由华东师范大学和ModelScope社区共同发起,探索Sora背后的视频生成技术的最小复现和简洁实现,我们希望可以提供简洁易读的代码方便大家一起改进实验,不断探索提升开源视频生成技术的上限。 6 | 7 | ## 技术路线 8 | 9 | * [x] 搭建基础架构 10 | * [ ] 模型 11 | * [x] Text Encoder(基于 Stable Diffusion XL 中的 [Text Encoder](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder_2/model.safetensors)) 12 | * [x] VideoDiT(基于 [Facebook DiT](https://github.com/facebookresearch/DiT)) 13 | * [ ] VideoVAE 14 | * [x] Scheduler(基于 [DDIM](https://arxiv.org/abs/2010.02502)) 15 | * [x] Trainer(基于 [PyTorch-lightning](https://lightning.ai/docs/pytorch/stable/)) 16 | * [x] 小规模数据集验证 17 | * [x] [Pixabay100](https://github.com/ECNU-CILAB/Pixabay100) 18 | * [ ] 在大规模数据集上训练 Video Encoder & Decoder 19 | * [ ] 在大规模数据集上训练 VideoDiT 20 | 21 | ## 使用 22 | 23 | ### Python 环境搭建 24 | 25 | ``` 26 | conda env create -f environment.yml 27 | conda activate litesora 28 | ``` 29 | 30 | ### 下载模型 31 | 32 | * `models/text_encoder/model.safetensors`: 来自 Stable Diffusion XL 的 Text Encoder,[下载链接](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/text_encoder_2/model.safetensors) 33 | * `models/denoising_model/model.safetensors`:我们在 [Pixabay100](https://github.com/ECNU-CILAB/Pixabay100) 数据集上训练的模型,该模型可以证明我们的训练代码能够正常拟合训练数据,分辨率为 64*64。**这个模型显然是过拟合的,由于训练数据少,尚不具备泛化能力,仅用于验证训练算法的正确性。** [下载链接](https://huggingface.co/ECNU-CILab/lite-sora-v1-pixabay100/resolve/main/denoising_model/model.safetensors) 34 | * `models/vae/model.safetensors`: Stable Video Diffusion 的 VAE. [下载链接](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/resolve/main/vae/diffusion_pytorch_model.fp16.safetensors) 35 | 36 | ### 训练 37 | 38 | ```python 39 | from litesora.data import TextVideoDataset 40 | from litesora.models import SDXLTextEncoder2 41 | from litesora.trainers.v1 import LightningVideoDiT 42 | import lightning as pl 43 | import torch 44 | 45 | 46 | if __name__ == '__main__': 47 | # dataset and data loader 48 | dataset = TextVideoDataset("data/pixabay100", "data/pixabay100/metadata.json", 49 | num_frames=64, height=64, width=64) 50 | train_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=8) 51 | 52 | # model 53 | model = LightningVideoDiT(learning_rate=1e-5) 54 | model.text_encoder.load_state_dict_from_diffusers("models/text_encoder/model.safetensors") 55 | 56 | # train 57 | trainer = pl.Trainer(max_epochs=100000, accelerator="gpu", devices="auto", callbacks=[ 58 | pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1) 59 | ]) 60 | trainer.fit(model=model, train_dataloaders=train_loader) 61 | ``` 62 | 63 | 训练程序启动后,可开启 `tensorboard` 监视训练进度 64 | 65 | ``` 66 | tensorboard --logdir . 
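# 假设使用 Lightning 默认的 TensorBoardLogger(未显式指定 logger 时),日志通常写入当前目录下的 lightning_logs/,也可以直接指向该目录:
# tensorboard --logdir lightning_logs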
67 | ``` 68 | 69 | ### 推理 70 | 71 | * 在像素空间生成一个视频。 72 | 73 | ```python 74 | from litesora.models import SDXLTextEncoder2, VideoDiT 75 | from litesora.pipelines import PixelVideoDiTPipeline 76 | from litesora.data import save_video 77 | import torch 78 | 79 | 80 | # models 81 | text_encoder = SDXLTextEncoder2.from_diffusers("models/text_encoder/model.safetensors") 82 | denoising_model = VideoDiT.from_pretrained("models/denoising_model/model.safetensors") 83 | 84 | # pipeline 85 | pipe = PixelVideoDiTPipeline(torch_dtype=torch.float16, device="cuda") 86 | pipe.fetch_models(text_encoder, denoising_model) 87 | 88 | # generate a video 89 | prompt = "woman, flowers, plants, field, garden" 90 | video = pipe(prompt=prompt, num_inference_steps=100) 91 | 92 | # save the video (the resolution is 64*64, we enlarge it to 512*512 here) 93 | save_video(video, "output.mp4", upscale=8) 94 | ``` 95 | 96 | * 把一个视频从像素空间编码到隐空间,然后解码它 97 | 98 | ```python 99 | from litesora.models import SDVAEEncoder, SVDVAEDecoder 100 | from litesora.data import load_video, tensor2video, concat_video, save_video 101 | import torch 102 | from tqdm import tqdm 103 | 104 | 105 | frames = load_video("data/pixabay100/videos/168572 (Original).mp4", 106 | num_frames=1024, height=1024, width=1024, random_crop=False) 107 | frames = frames.to(dtype=torch.float16, device="cpu") 108 | 109 | encoder = SDVAEEncoder.from_diffusers("models/vae/model.safetensors").to(dtype=torch.float16, device="cuda") 110 | decoder = SVDVAEDecoder.from_diffusers("models/vae/model.safetensors").to(dtype=torch.float16, device="cuda") 111 | 112 | with torch.no_grad(): 113 | print(frames.shape) 114 | latents = encoder.encode_video(frames, progress_bar=tqdm) 115 | print(latents.shape) 116 | decoded_frames = decoder.decode_video(latents, progress_bar=tqdm) 117 | 118 | video = tensor2video(concat_video([frames, decoded_frames])) 119 | save_video(video, "video.mp4", fps=24) 120 | ``` 121 | 122 | ### 现阶段效果展示 123 | 124 | 我们在 [Pixabay100](https://github.com/ECNU-CILAB/Pixabay100) 数据集上训练的模型,该模型可以证明我们的训练代码能够正常拟合训练数据,分辨率为 64*64。**这个模型显然是过拟合的,由于训练数据少,尚不具备泛化能力,仅用于验证训练算法的正确性。** [下载链接](https://huggingface.co/ECNU-CILab/lite-sora-v1-pixabay100/resolve/main/denoising_model/model.safetensors) 125 | 126 | |airport, people, crowd, busy|beach, ocean, waves, water, sand|bee, honey, insect, beehive, nature|coffee, beans, caffeine, coffee, shop| 127 | |-|-|-|-| 128 | |![](assets/airport_people_crowd_busy.gif)|![](assets/beach_ocean_waves_water_sand.gif)|![](assets/bee_honey_insect_beehive_nature.gif)|![](assets/coffee_beans_caffeine_coffee_shop.gif)| 129 | |fish, underwater, aquarium, swim|forest, woods, mystical, morning|ocean, beach, sunset, sea, atmosphere|hair, wind, girl, woman, people| 130 | |![](assets/fish_underwater_aquarium_swim.gif)|![](assets/forest_woods_mystical_morning.gif)|![](assets/ocean_beach_sunset_sea_atmosphere.gif)|![](assets/hair_wind_girl_woman_people.gif)| 131 | |reeds, grass, wind, golden, sunshine|sea, ocean, seagulls, birds, sunset|woman, flowers, plants, field, garden|wood, anemones, wildflower, flower| 132 | |![](assets/reeds_grass_wind_golden_sunshine.gif)|![](assets/sea_ocean_seagulls_birds_sunset.gif)|![](assets/woman_flowers_plants_field_garden.gif)|![](assets/wood_anemones_wildflower_flower.gif)| 133 | 134 | 我们采用 [Stable-Video-Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt) 中的 VAE 模型来做视频的编码和解码。我们的代码支持非常长的高分辨率视频! 
135 | 136 | https://github.com/modelscope/lite-sora/assets/35051019/dc205719-d0bc-4bca-b117-ff5aa19ebd86 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lite-Sora 2 | 3 | ## Introduction 4 | 5 | The lite-sora project is an initiative to replicate Sora, co-launched by East China Normal University and the ModelScope community. It aims to explore the minimal reproduction and streamlined implementation of the video generation algorithms behind Sora. We hope to provide concise and readable code to facilitate collective experimentation and improvement, continuously pushing the boundaries of open-source video generation technology. 6 | 7 | ## Roadmap 8 | 9 | * [x] Implement the base architecture 10 | * [ ] Models 11 | * [x] Text Encoder(based on Stable Diffusion XL's [Text Encoder](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/text_encoder_2/model.safetensors)) 12 | * [x] VideoDiT(based on [Facebook DiT](https://github.com/facebookresearch/DiT)) 13 | * [ ] VideoVAE 14 | * [x] Scheduler(based on [DDIM](https://arxiv.org/abs/2010.02502)) 15 | * [x] Trainer(based on [PyTorch-lightning](https://lightning.ai/docs/pytorch/stable/)) 16 | * [x] Validate on small datasets 17 | * [x] [Pixabay100](https://github.com/ECNU-CILAB/Pixabay100) 18 | * [ ] Train Video Encoder & Decoder on large datasets 19 | * [ ] Train VideoDiT on large datasets 20 | 21 | ## Usage 22 | 23 | ### Python Environment 24 | 25 | ``` 26 | conda env create -f environment.yml 27 | conda activate litesora 28 | ``` 29 | 30 | ### Download Models 31 | 32 | * `models/text_encoder/model.safetensors`: Stable Diffusion XL's Text Encoder. [download](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/text_encoder_2/model.safetensors) 33 | * `models/denoising_model/model.safetensors`:We trained a denoising model using a small dataset [Pixabay100](https://github.com/ECNU-CILAB/Pixabay100). This model serves to demonstrate that our training code is capable of fitting the training data properly, with a resolution of 64*64. **Obviously this model is overfitting due to the limited amount of training data, and thus it lacks generalization capability at this stage. Its purpose is solely for verifying the correctness of the training algorithm.** [download](https://huggingface.co/ECNU-CILab/lite-sora-v1-pixabay100/resolve/main/denoising_model/model.safetensors) 34 | * `models/vae/model.safetensors`: Stable Video Diffusion's VAE. 
[download](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/resolve/main/vae/diffusion_pytorch_model.fp16.safetensors) 35 | 36 | ### Training 37 | 38 | ```python 39 | from litesora.data import TextVideoDataset 40 | from litesora.models import SDXLTextEncoder2 41 | from litesora.trainers.v1 import LightningVideoDiT 42 | import lightning as pl 43 | import torch 44 | 45 | 46 | if __name__ == '__main__': 47 | # dataset and data loader 48 | dataset = TextVideoDataset("data/pixabay100", "data/pixabay100/metadata.json", 49 | num_frames=64, height=64, width=64) 50 | train_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=8) 51 | 52 | # model 53 | model = LightningVideoDiT(learning_rate=1e-5) 54 | model.text_encoder.load_state_dict_from_diffusers("models/text_encoder/model.safetensors") 55 | 56 | # train 57 | trainer = pl.Trainer(max_epochs=100000, accelerator="gpu", devices="auto", callbacks=[ 58 | pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1) 59 | ]) 60 | trainer.fit(model=model, train_dataloaders=train_loader) 61 | ``` 62 | 63 | While the training program is running, you can launch `tensorboard` to see the training loss. 64 | 65 | ``` 66 | tensorboard --logdir . 67 | ``` 68 | 69 | ### Inference 70 | 71 | * Synthesize a video in the pixel space. 72 | 73 | ```python 74 | from litesora.models import SDXLTextEncoder2, VideoDiT 75 | from litesora.pipelines import PixelVideoDiTPipeline 76 | from litesora.data import save_video 77 | import torch 78 | 79 | 80 | # models 81 | text_encoder = SDXLTextEncoder2.from_diffusers("models/text_encoder/model.safetensors") 82 | denoising_model = VideoDiT.from_pretrained("models/denoising_model/model.safetensors") 83 | 84 | # pipeline 85 | pipe = PixelVideoDiTPipeline(torch_dtype=torch.float16, device="cuda") 86 | pipe.fetch_models(text_encoder, denoising_model) 87 | 88 | # generate a video 89 | prompt = "woman, flowers, plants, field, garden" 90 | video = pipe(prompt=prompt, num_inference_steps=100) 91 | 92 | # save the video (the resolution is 64*64, we enlarge it to 512*512 here) 93 | save_video(video, "output.mp4", upscale=8) 94 | ``` 95 | 96 | * Encode a video into the latent space, and then decode it. 97 | 98 | 99 | ```python 100 | from litesora.models import SDVAEEncoder, SVDVAEDecoder 101 | from litesora.data import load_video, tensor2video, concat_video, save_video 102 | import torch 103 | from tqdm import tqdm 104 | 105 | 106 | frames = load_video("data/pixabay100/videos/168572 (Original).mp4", 107 | num_frames=1024, height=1024, width=1024, random_crop=False) 108 | frames = frames.to(dtype=torch.float16, device="cpu") 109 | 110 | encoder = SDVAEEncoder.from_diffusers("models/vae/model.safetensors").to(dtype=torch.float16, device="cuda") 111 | decoder = SVDVAEDecoder.from_diffusers("models/vae/model.safetensors").to(dtype=torch.float16, device="cuda") 112 | 113 | with torch.no_grad(): 114 | print(frames.shape) 115 | latents = encoder.encode_video(frames, progress_bar=tqdm) 116 | print(latents.shape) 117 | decoded_frames = decoder.decode_video(latents, progress_bar=tqdm) 118 | 119 | video = tensor2video(concat_video([frames, decoded_frames])) 120 | save_video(video, "video.mp4", fps=24) 121 | ``` 122 | 123 | 124 | ### Results (Experimental) 125 | 126 | We trained a denoising model using a small dataset [Pixabay100](https://github.com/ECNU-CILAB/Pixabay100). This model serves to demonstrate that our training code is capable of fitting the training data properly, with a resolution of 64*64. 
**Obviously this model is overfitting due to the limited amount of training data, and thus it lacks generalization capability at this stage. Its purpose is solely for verifying the correctness of the training algorithm.** [download](https://huggingface.co/ECNU-CILab/lite-sora-v1-pixabay100/resolve/main/denoising_model/model.safetensors) 127 | 128 | |airport, people, crowd, busy|beach, ocean, waves, water, sand|bee, honey, insect, beehive, nature|coffee, beans, caffeine, coffee, shop| 129 | |-|-|-|-| 130 | |![](assets/airport_people_crowd_busy.gif)|![](assets/beach_ocean_waves_water_sand.gif)|![](assets/bee_honey_insect_beehive_nature.gif)|![](assets/coffee_beans_caffeine_coffee_shop.gif)| 131 | |fish, underwater, aquarium, swim|forest, woods, mystical, morning|ocean, beach, sunset, sea, atmosphere|hair, wind, girl, woman, people| 132 | |![](assets/fish_underwater_aquarium_swim.gif)|![](assets/forest_woods_mystical_morning.gif)|![](assets/ocean_beach_sunset_sea_atmosphere.gif)|![](assets/hair_wind_girl_woman_people.gif)| 133 | |reeds, grass, wind, golden, sunshine|sea, ocean, seagulls, birds, sunset|woman, flowers, plants, field, garden|wood, anemones, wildflower, flower| 134 | |![](assets/reeds_grass_wind_golden_sunshine.gif)|![](assets/sea_ocean_seagulls_birds_sunset.gif)|![](assets/woman_flowers_plants_field_garden.gif)|![](assets/wood_anemones_wildflower_flower.gif)| 135 | 136 | We leverage the VAE model from [Stable-Video-Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt) to encode videos to the latent space. Our code supports extremely long high-resolution videos! 137 | 138 | https://github.com/modelscope/lite-sora/assets/35051019/dc205719-d0bc-4bca-b117-ff5aa19ebd86 139 | -------------------------------------------------------------------------------- /assets/airport_people_crowd_busy.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelscope/lite-sora/75caae14a22f4a127cc36c13d8171376450d7887/assets/airport_people_crowd_busy.gif -------------------------------------------------------------------------------- /assets/beach_ocean_waves_water_sand.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelscope/lite-sora/75caae14a22f4a127cc36c13d8171376450d7887/assets/beach_ocean_waves_water_sand.gif -------------------------------------------------------------------------------- /assets/bee_honey_insect_beehive_nature.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelscope/lite-sora/75caae14a22f4a127cc36c13d8171376450d7887/assets/bee_honey_insect_beehive_nature.gif -------------------------------------------------------------------------------- /assets/coffee_beans_caffeine_coffee_shop.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelscope/lite-sora/75caae14a22f4a127cc36c13d8171376450d7887/assets/coffee_beans_caffeine_coffee_shop.gif -------------------------------------------------------------------------------- /assets/fish_underwater_aquarium_swim.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelscope/lite-sora/75caae14a22f4a127cc36c13d8171376450d7887/assets/fish_underwater_aquarium_swim.gif -------------------------------------------------------------------------------- 
/assets/forest_woods_mystical_morning.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelscope/lite-sora/75caae14a22f4a127cc36c13d8171376450d7887/assets/forest_woods_mystical_morning.gif -------------------------------------------------------------------------------- /assets/hair_wind_girl_woman_people.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelscope/lite-sora/75caae14a22f4a127cc36c13d8171376450d7887/assets/hair_wind_girl_woman_people.gif -------------------------------------------------------------------------------- /assets/ocean_beach_sunset_sea_atmosphere.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelscope/lite-sora/75caae14a22f4a127cc36c13d8171376450d7887/assets/ocean_beach_sunset_sea_atmosphere.gif -------------------------------------------------------------------------------- /assets/reeds_grass_wind_golden_sunshine.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelscope/lite-sora/75caae14a22f4a127cc36c13d8171376450d7887/assets/reeds_grass_wind_golden_sunshine.gif -------------------------------------------------------------------------------- /assets/sea_ocean_seagulls_birds_sunset.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelscope/lite-sora/75caae14a22f4a127cc36c13d8171376450d7887/assets/sea_ocean_seagulls_birds_sunset.gif -------------------------------------------------------------------------------- /assets/woman_flowers_plants_field_garden.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelscope/lite-sora/75caae14a22f4a127cc36c13d8171376450d7887/assets/woman_flowers_plants_field_garden.gif -------------------------------------------------------------------------------- /assets/wood_anemones_wildflower_flower.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelscope/lite-sora/75caae14a22f4a127cc36c13d8171376450d7887/assets/wood_anemones_wildflower_flower.gif -------------------------------------------------------------------------------- /configs/stable_diffusion/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": "<|endoftext|>", 17 | "unk_token": { 18 | "content": "<|endoftext|>", 19 | "lstrip": false, 20 | "normalized": true, 21 | "rstrip": false, 22 | "single_word": false 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /configs/stable_diffusion/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "bos_token": { 4 | "__type": "AddedToken", 5 | "content": "<|startoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false 10 | }, 11 | "do_lower_case": true, 12 | 
"eos_token": { 13 | "__type": "AddedToken", 14 | "content": "<|endoftext|>", 15 | "lstrip": false, 16 | "normalized": true, 17 | "rstrip": false, 18 | "single_word": false 19 | }, 20 | "errors": "replace", 21 | "model_max_length": 77, 22 | "name_or_path": "openai/clip-vit-large-patch14", 23 | "pad_token": "<|endoftext|>", 24 | "special_tokens_map_file": "./special_tokens_map.json", 25 | "tokenizer_class": "CLIPTokenizer", 26 | "unk_token": { 27 | "__type": "AddedToken", 28 | "content": "<|endoftext|>", 29 | "lstrip": false, 30 | "normalized": true, 31 | "rstrip": false, 32 | "single_word": false 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": "!", 17 | "unk_token": { 18 | "content": "<|endoftext|>", 19 | "lstrip": false, 20 | "normalized": true, 21 | "rstrip": false, 22 | "single_word": false 23 | } 24 | } -------------------------------------------------------------------------------- /configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "added_tokens_decoder": { 4 | "0": { 5 | "content": "!", 6 | "lstrip": false, 7 | "normalized": false, 8 | "rstrip": false, 9 | "single_word": false, 10 | "special": true 11 | }, 12 | "49406": { 13 | "content": "<|startoftext|>", 14 | "lstrip": false, 15 | "normalized": true, 16 | "rstrip": false, 17 | "single_word": false, 18 | "special": true 19 | }, 20 | "49407": { 21 | "content": "<|endoftext|>", 22 | "lstrip": false, 23 | "normalized": true, 24 | "rstrip": false, 25 | "single_word": false, 26 | "special": true 27 | } 28 | }, 29 | "bos_token": "<|startoftext|>", 30 | "clean_up_tokenization_spaces": true, 31 | "do_lower_case": true, 32 | "eos_token": "<|endoftext|>", 33 | "errors": "replace", 34 | "model_max_length": 77, 35 | "pad_token": "!", 36 | "tokenizer_class": "CLIPTokenizer", 37 | "unk_token": "<|endoftext|>" 38 | } -------------------------------------------------------------------------------- /data/pixabay100/metadata.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "path": "videos/dubrovnik_-_12866 (1080p).mp4", 4 | "text": "dubrovnik, sunset, sea, city" 5 | }, 6 | { 7 | "path": "videos/seoul_-_21985 (Original).mp4", 8 | "text": "seoul, city, streets, road, night, view" 9 | }, 10 | { 11 | "path": "videos/daffodil_-_34826 (1080p).mp4", 12 | "text": "easter, bell, daffodil, easter, spring" 13 | }, 14 | { 15 | "path": "videos/airport_-_36510 (Original).mp4", 16 | "text": "airport, people, crowd, busy" 17 | }, 18 | { 19 | "path": "videos/surf_-_36609 (1080p).mp4", 20 | "text": "surf, sea, ocean, wave, beach, water" 21 | }, 22 | { 23 | "path": "videos/ijen_-_37079 (1080p).mp4", 24 | "text": "ijen, crater, indonesia, mine, volcano" 25 | }, 26 | { 27 | "path": "videos/waterfall_-_37088 (1080p).mp4", 28 | "text": "waterfall, water, river, nature" 29 | }, 30 | { 31 | "path": "videos/bee_-_37144 (Original).mp4", 32 | "text": "bee, oilseed, rape, 
blossom, bloom" 33 | }, 34 | { 35 | "path": "videos/underwater_-_37712 (Original).mp4", 36 | "text": "underwater, sea, ocean, water" 37 | }, 38 | { 39 | "path": "videos/woman_-_38084 (Original).mp4", 40 | "text": "woman, phone, smartphone, technology" 41 | }, 42 | { 43 | "path": "videos/bee_-_39116 (Original).mp4", 44 | "text": "bee, honey, insect, beehive, nature" 45 | }, 46 | { 47 | "path": "videos/cycling_-_39183 (Original).mp4", 48 | "text": "cycling, man, bike, helmet, road" 49 | }, 50 | { 51 | "path": "videos/nature_-_42420 (Original).mp4", 52 | "text": "nature, rain, plant, water, garden" 53 | }, 54 | { 55 | "path": "videos/china_-_43238 (Original).mp4", 56 | "text": "china, the, poet, tea, culture, read" 57 | }, 58 | { 59 | "path": "videos/girl_-_43459 (1080p).mp4", 60 | "text": "girl, writing, student, female, book" 61 | }, 62 | { 63 | "path": "videos/hair_-_43633 (Original).mp4", 64 | "text": "hair, wind, girl, woman, people" 65 | }, 66 | { 67 | "path": "videos/river_-_44509 (1080p).mp4", 68 | "text": "river, urban, city, architecture" 69 | }, 70 | { 71 | "path": "videos/fishing_-_44534 (Original).mp4", 72 | "text": "fishing, fish, fisherman, water" 73 | }, 74 | { 75 | "path": "videos/trail_-_44702 (1080p).mp4", 76 | "text": "path, people, helicopter, mountains" 77 | }, 78 | { 79 | "path": "videos/creux_de_van_-_45150 (Original).mp4", 80 | "text": "clouds, mountains, panorama, outdoor" 81 | }, 82 | { 83 | "path": "videos/coffee_-_45358 (Original).mp4", 84 | "text": "coffee, beans, caffeine, coffee, shop" 85 | }, 86 | { 87 | "path": "videos/glacier_express_-_45569 (1080p).mp4", 88 | "text": "glacier, express, bernina, express, rhb" 89 | }, 90 | { 91 | "path": "videos/bird_-_46026 (Original).mp4", 92 | "text": "bird, parrot, nature, animal, colorful" 93 | }, 94 | { 95 | "path": "videos/alone_-_46637 (Original).mp4", 96 | "text": "alone, person, man, people, guitar" 97 | }, 98 | { 99 | "path": "videos/coffee_-_46989 (1080p).mp4", 100 | "text": "coffee, hot, drink, coffee, pot" 101 | }, 102 | { 103 | "path": "videos/windmill_-_47905 (1080p).mp4", 104 | "text": "windmill, landscape, energy, sky" 105 | }, 106 | { 107 | "path": "videos/golden_-_48569 (1080p).mp4", 108 | "text": "golden, particles, overlay, decoration" 109 | }, 110 | { 111 | "path": "videos/forest_-_49981 (Original).mp4", 112 | "text": "forest, woods, mystical, morning" 113 | }, 114 | { 115 | "path": "videos/sea_-_53127 (Original).mp4", 116 | "text": "sea, beach, sunset, tree" 117 | }, 118 | { 119 | "path": "videos/fog_-_53358 (1080p).mp4", 120 | "text": "fog, mountains, road, foggy, dramatic" 121 | }, 122 | { 123 | "path": "videos/trafic_-_53902 (Original).mp4", 124 | "text": "traffic, city, traffic, drone" 125 | }, 126 | { 127 | "path": "videos/sheep_-_57647 (1080p).mp4", 128 | "text": "sheep, morning, sunray, nature" 129 | }, 130 | { 131 | "path": "videos/road_-_57993 (1080p).mp4", 132 | "text": "road, highway, drive, car, alps" 133 | }, 134 | { 135 | "path": "videos/woman_-_58142 (Original).mp4", 136 | "text": "woman, flowers, plants, field, garden" 137 | }, 138 | { 139 | "path": "videos/mountains_-_59291 (1080p).mp4", 140 | "text": "mountains, fog, clouds, sunset" 141 | }, 142 | { 143 | "path": "videos/sunflowers_-_59483 (1080p).mp4", 144 | "text": "sunflowers, bees, field, bloom, summer" 145 | }, 146 | { 147 | "path": "videos/mountains_-_61818 (Original).mp4", 148 | "text": "mountains, trees, woods, forest" 149 | }, 150 | { 151 | "path": "videos/ocean_-_62249 (Original).mp4", 152 | "text": "ocean, sunset, sea, 
beach, coast" 153 | }, 154 | { 155 | "path": "videos/woman_-_63241 (Original).mp4", 156 | "text": "woman, mask, healthcare, young" 157 | }, 158 | { 159 | "path": "videos/clouds_-_64759 (1080p).mp4", 160 | "text": "clouds, cumulus, sky, fluffy, outdoors" 161 | }, 162 | { 163 | "path": "videos/clouds_-_64767 (1080p).mp4", 164 | "text": "clouds, cumulus, sky, fluffy, outdoors" 165 | }, 166 | { 167 | "path": "videos/record_-_65390 (1080p).mp4", 168 | "text": "record, record, player, vinyl" 169 | }, 170 | { 171 | "path": "videos/cat_-_65438 (1080p).mp4", 172 | "text": "cat, feline, animal, mammal, kitten" 173 | }, 174 | { 175 | "path": "videos/ocean_-_65560 (Original).mp4", 176 | "text": "ocean, waves, water, shore, florida" 177 | }, 178 | { 179 | "path": "videos/highland_cows_-_65903 (1080p).mp4", 180 | "text": "highland, cows, cows, cattle, scotland" 181 | }, 182 | { 183 | "path": "videos/grass_-_66810 (1080p).mp4", 184 | "text": "grass, dew, field, fog, morning" 185 | }, 186 | { 187 | "path": "videos/flowers_-_66823 (1080p).mp4", 188 | "text": "flowers, cherry, flower, petals, bloom" 189 | }, 190 | { 191 | "path": "videos/lake_-_67201 (1080p).mp4", 192 | "text": "lake, houses, hill, mountain, boat" 193 | }, 194 | { 195 | "path": "videos/ink_-_67358 (1080p).mp4", 196 | "text": "ink, water, underwater, foam, smoke" 197 | }, 198 | { 199 | "path": "videos/waves_-_70796 (1080p).mp4", 200 | "text": "waves, ocean, sea, beach, byron, bay" 201 | }, 202 | { 203 | "path": "videos/waves_-_71122 (1080p).mp4", 204 | "text": "waves, sea, ocean, storm, water, tide" 205 | }, 206 | { 207 | "path": "videos/blackthorn_-_71457 (1080p).mp4", 208 | "text": "blackthorn, bud, blossom, bloom" 209 | }, 210 | { 211 | "path": "videos/bike_-_72566 (1080p).mp4", 212 | "text": "bike, beach, biker, mountain, bike" 213 | }, 214 | { 215 | "path": "videos/flowers_-_72763 (1080p).mp4", 216 | "text": "flower, dandelion, wild, flowers" 217 | }, 218 | { 219 | "path": "videos/sand_-_73847 (1080p).mp4", 220 | "text": "sand, hand, beach, dessert, sea" 221 | }, 222 | { 223 | "path": "videos/ocean_-_74888 (1080p).mp4", 224 | "text": "ocean, birds, waves, beach, sunset" 225 | }, 226 | { 227 | "path": "videos/water_lilies_-_75008 (1080p).mp4", 228 | "text": "water, lilies, lotus, pond" 229 | }, 230 | { 231 | "path": "videos/skyscrapers_-_80724 (1080p).mp4", 232 | "text": "skyscrapers, sunset, skyline, clouds" 233 | }, 234 | { 235 | "path": "videos/mountains_-_81945 (1080p).mp4", 236 | "text": "mountains, parachute, paragliding" 237 | }, 238 | { 239 | "path": "videos/sunrise_-_83880 (Original).mp4", 240 | "text": "sunrise, sunbeams, trees, woods" 241 | }, 242 | { 243 | "path": "videos/people_-_84973 (1080p).mp4", 244 | "text": "people, sunset, stroll, moscow, glare" 245 | }, 246 | { 247 | "path": "videos/aurora_borealis_-_90877 (1080p).mp4", 248 | "text": "aurora, borealis, northern, lights" 249 | }, 250 | { 251 | "path": "videos/mountains_-_91545 (1080p).mp4", 252 | "text": "mountains, alps, fog, snow, clouds" 253 | }, 254 | { 255 | "path": "videos/skyscrapers_-_91744 (1080p).mp4", 256 | "text": "skyscrapers, buildings, city, urban" 257 | }, 258 | { 259 | "path": "videos/skate_-_110734 (1080p).mp4", 260 | "text": "skate, sport, water, action, exercise" 261 | }, 262 | { 263 | "path": "videos/sakura_-_110790 (Original).mp4", 264 | "text": "sakura, flowers, spring, nature, japan" 265 | }, 266 | { 267 | "path": "videos/jellyfish_-_110877 (Original).mp4", 268 | "text": "jellyfish, underwater, ocean, water" 269 | }, 270 | { 271 | "path": 
"videos/fish_-_110879 (Original).mp4", 272 | "text": "fish, underwater, aquarium, swim" 273 | }, 274 | { 275 | "path": "videos/forest_-_111101 (Original).mp4", 276 | "text": "forest, nature, fog, mist, morning" 277 | }, 278 | { 279 | "path": "videos/sunset_-_111204 (1080p).mp4", 280 | "text": "sunset, nature, sky, dubai" 281 | }, 282 | { 283 | "path": "videos/beach_-_111263 (1080p).mp4", 284 | "text": "beach, ocean, waves, water, sand" 285 | }, 286 | { 287 | "path": "videos/nature_-_111508 (Original).mp4", 288 | "text": "nature, mountains, grass, flowers" 289 | }, 290 | { 291 | "path": "videos/daffodils_-_112389 (1080p).mp4", 292 | "text": "daffodils, easter, bells, spring" 293 | }, 294 | { 295 | "path": "videos/wood_anemones_-_112429 (Original).mp4", 296 | "text": "wood, anemones, wildflower, flower" 297 | }, 298 | { 299 | "path": "videos/woman_of_the_sea_-_112722 (Original).mp4", 300 | "text": "live, wallpaper, sea, water, waves" 301 | }, 302 | { 303 | "path": "videos/windmill_-_112957 (Original).mp4", 304 | "text": "windmill, turbine, energy" 305 | }, 306 | { 307 | "path": "videos/blossoms_-_113004 (Original).mp4", 308 | "text": "blooms, tree, blossoms, spring, branch" 309 | }, 310 | { 311 | "path": "videos/frog_-_113403 (Original).mp4", 312 | "text": "frog, reeds, green" 313 | }, 314 | { 315 | "path": "videos/135658 (1080p).mp4", 316 | "text": "ocean, beach, sunset, sea, atmosphere" 317 | }, 318 | { 319 | "path": "videos/138588 (1080p).mp4", 320 | "text": "sea, ocean, beach, waves, sun, dusk" 321 | }, 322 | { 323 | "path": "videos/140111 (1080p).mp4", 324 | "text": "sea, ocean, seagulls, birds, sunset" 325 | }, 326 | { 327 | "path": "videos/141964 (Original).mp4", 328 | "text": "christmas, christmas, decorations" 329 | }, 330 | { 331 | "path": "videos/142579 (Original).mp4", 332 | "text": "pine, forest, snow, winter, snowfall" 333 | }, 334 | { 335 | "path": "videos/146169 (1080p).mp4", 336 | "text": "clouds, sunlight, view, drone, nature" 337 | }, 338 | { 339 | "path": "videos/152740 (Original).mp4", 340 | "text": "reeds, grass, wind, golden, sunshine" 341 | }, 342 | { 343 | "path": "videos/153167-804706404 (1080p).mp4", 344 | "text": "beach, sea, sand, island, elafonisos" 345 | }, 346 | { 347 | "path": "videos/153976-806571973 (1080p).mp4", 348 | "text": "sunset, sea, sun, evening, atmosphere" 349 | }, 350 | { 351 | "path": "videos/158349 (1080p).mp4", 352 | "text": "beach, waves, natural, scenery, morning" 353 | }, 354 | { 355 | "path": "videos/158384 (1080p).mp4", 356 | "text": "clouds, cloudscape, wind, nature" 357 | }, 358 | { 359 | "path": "videos/158980 (1080p).mp4", 360 | "text": "lion, animal, wildlife, safari, lions" 361 | }, 362 | { 363 | "path": "videos/159627 (1080p).mp4", 364 | "text": "flow, rocks, water, fluent, stones" 365 | }, 366 | { 367 | "path": "videos/160767 (1080p).mp4", 368 | "text": "ocean, beach, waves, breaking, waves" 369 | }, 370 | { 371 | "path": "videos/161071 (1080p).mp4", 372 | "text": "river, road, mountain, forest, nature" 373 | }, 374 | { 375 | "path": "videos/161178 (Original).mp4", 376 | "text": "flower, gemswurz, yellow" 377 | }, 378 | { 379 | "path": "videos/163333 (1080p).mp4", 380 | "text": "dandelion, blossom, bloom, faded" 381 | }, 382 | { 383 | "path": "videos/163869 (1080p).mp4", 384 | "text": "flower, water, lilies, lake" 385 | }, 386 | { 387 | "path": "videos/164215 (1080p).mp4", 388 | "text": "flower, poppy, wind, wildflower, plant" 389 | }, 390 | { 391 | "path": "videos/164360 (1080p).mp4", 392 | "text": "fogging, landscape, forest, 
grasslands" 393 | }, 394 | { 395 | "path": "videos/166808 (Original).mp4", 396 | "text": "carpenter, tool, handmade, industry" 397 | }, 398 | { 399 | "path": "videos/168572 (Original).mp4", 400 | "text": "snowball, flower, white, spherical" 401 | } 402 | ] -------------------------------------------------------------------------------- /data/pixabay100/videos/Put videos here.txt: -------------------------------------------------------------------------------- 1 | Video urls can be found here: https://github.com/ECNU-CILAB/Pixabay100. 2 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: litesora 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - python=3.9.16 8 | - pip=23.0.1 9 | - cudatoolkit 10 | - pytorch==2.1.0 11 | - lightning 12 | - transformers 13 | - pip: 14 | - einops 15 | - imageio 16 | - imageio[ffmpeg] 17 | - tensorboard 18 | -------------------------------------------------------------------------------- /litesora/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .text_video_dataset import TextVideoDataset 3 | -------------------------------------------------------------------------------- /litesora/data/text_video_dataset.py: -------------------------------------------------------------------------------- 1 | import torch, json, os, imageio 2 | from einops import rearrange 3 | from PIL import Image 4 | import numpy as np 5 | 6 | 7 | class TextVideoDataset(torch.utils.data.Dataset): 8 | def __init__(self, base_path, metadata_path, num_frames=64, height=64, width=64): 9 | with open(metadata_path, "r") as f: 10 | metadata = json.load(f) 11 | self.path = [os.path.join(base_path, i["path"]) for i in metadata] 12 | self.text = [i["text"] for i in metadata] 13 | self.num_frames = num_frames 14 | self.height = height 15 | self.width = width 16 | 17 | def crop_and_resize(self, image, height, width): 18 | image = np.array(image) 19 | image_height, image_width, _ = image.shape 20 | if image_height / image_width < height / width: 21 | croped_width = int(image_height / height * width) 22 | left = (image_width - croped_width) // 2 23 | image = image[:, left: left+croped_width] 24 | image = Image.fromarray(image).resize((width, height)) 25 | else: 26 | croped_height = int(image_width / width * height) 27 | left = (image_height - croped_height) // 2 28 | image = image[left: left+croped_height, :] 29 | image = Image.fromarray(image).resize((width, height)) 30 | return image 31 | 32 | def load_video(self, file_path, num_frames, height, width): 33 | frames = [] 34 | reader = imageio.get_reader(file_path) 35 | for frame in reader: 36 | frame = self.crop_and_resize(frame, height, width) 37 | frames.append(frame) 38 | if len(frames)>=num_frames: 39 | break 40 | frames = torch.tensor(np.stack(frames)) 41 | reader.close() 42 | return frames 43 | 44 | def process_video_frames(self, frames): 45 | frames = frames / 127.5 - 1 46 | frames = rearrange(frames, "T H W C -> C T H W") 47 | return frames 48 | 49 | def __getitem__(self, index): 50 | video_file = self.path[index % len(self.path)] 51 | text = self.text[index % len(self.path)] 52 | frames = self.load_video(video_file, self.num_frames, self.height, self.width) 53 | frames = self.process_video_frames(frames) 54 | return {"frames": frames, "text": text} 55 | 56 | def __len__(self): 57 | return 
len(self.path) 58 | -------------------------------------------------------------------------------- /litesora/data/utils.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import numpy as np 3 | from PIL import Image 4 | import torch 5 | from einops import rearrange 6 | 7 | 8 | def save_video(frames, save_path, fps=30, quality=5, upscale=1): 9 | height, width, _ = frames[0].shape 10 | writer = imageio.get_writer(save_path, fps=fps, quality=quality) 11 | for frame in frames: 12 | frame = np.array(Image.fromarray(frame).resize((width*upscale, height*upscale), Image.NEAREST)) 13 | writer.append_data(frame) 14 | writer.close() 15 | 16 | 17 | def crop_and_resize(image, height, width, start_height, start_width): 18 | image = np.array(image) 19 | image_height, image_width, _ = image.shape 20 | if image_height / image_width < height / width: 21 | croped_width = int(image_height / height * width) 22 | left = start_width 23 | image = image[:, left: left+croped_width] 24 | image = Image.fromarray(image).convert("RGB").resize((width, height)) 25 | else: 26 | croped_height = int(image_width / width * height) 27 | left = start_height 28 | image = image[left: left+croped_height, :] 29 | image = Image.fromarray(image).convert("RGB").resize((width, height)) 30 | return image 31 | 32 | 33 | def load_video(file_path, num_frames, height, width, random_crop=True): 34 | frames = [] 35 | reader = imageio.get_reader(file_path) 36 | if reader.count_frames() < num_frames: 37 | return None 38 | if random_crop: 39 | start_frame = torch.randint(0, reader.count_frames() - num_frames + 1, (1,))[0] 40 | else: 41 | start_frame = 0 42 | w, h = reader.get_meta_data()["size"] 43 | if width / height < w / h: 44 | position = torch.rand(1)[0] if random_crop else 0.5 45 | start_width = int(position * (w - h / height * width)) 46 | start_height = 0 47 | else: 48 | start_width = 0 49 | position = torch.rand(1)[0] if random_crop else 0.5 50 | start_height = int(position * (h - w / width * height)) 51 | for frame_id in range(start_frame, start_frame + num_frames): 52 | frame = reader.get_data(frame_id) 53 | frame = crop_and_resize(frame, height, width, start_height, start_width) 54 | frames.append(frame) 55 | frames = torch.tensor(np.stack(frames)) 56 | frames = frames / 127.5 - 1 57 | frames = rearrange(frames, "T H W C -> C T H W") 58 | reader.close() 59 | return frames 60 | 61 | 62 | def tensor2video(frames): 63 | frames = rearrange(frames, "C T H W -> T H W C") 64 | frames = ((frames + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8) 65 | return frames 66 | 67 | 68 | def concat_video(videos): 69 | video = torch.concat(videos, dim=-1) 70 | return video 71 | -------------------------------------------------------------------------------- /litesora/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .sdxl_text_encoder_2 import SDXLTextEncoder2 2 | from .video_dit import VideoDiT 3 | from .svd_vae import SDVAEEncoder, SVDVAEDecoder 4 | -------------------------------------------------------------------------------- /litesora/models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Attention(torch.nn.Module): 5 | 6 | def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): 7 | super().__init__() 8 | dim_inner = head_dim * num_heads 9 | kv_dim = kv_dim if kv_dim is not None else q_dim 10 
| self.num_heads = num_heads 11 | self.head_dim = head_dim 12 | 13 | self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) 14 | self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) 15 | self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) 16 | self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) 17 | 18 | def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): 19 | if encoder_hidden_states is None: 20 | encoder_hidden_states = hidden_states 21 | 22 | batch_size = encoder_hidden_states.shape[0] 23 | 24 | q = self.to_q(hidden_states) 25 | k = self.to_k(encoder_hidden_states) 26 | v = self.to_v(encoder_hidden_states) 27 | 28 | q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 29 | k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 30 | v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) 31 | 32 | hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) 33 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) 34 | hidden_states = hidden_states.to(q.dtype) 35 | 36 | hidden_states = self.to_out(hidden_states) 37 | 38 | return hidden_states 39 | 40 | def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): 41 | return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask) 42 | -------------------------------------------------------------------------------- /litesora/models/patchify.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import repeat, rearrange 3 | 4 | 5 | def get_pos_embedding_1d(embed_dim, pos): 6 | omega = torch.arange(embed_dim // 2).to(torch.float64) * (2.0 / embed_dim) 7 | omega = 1.0 / 10000**omega 8 | 9 | pos = pos.reshape(-1) 10 | out = torch.einsum("m,d->md", pos, omega) 11 | emb_sin = torch.sin(out) 12 | emb_cos = torch.cos(out) 13 | 14 | emb = torch.concatenate([emb_sin, emb_cos], axis=1) 15 | return emb 16 | 17 | 18 | def get_pos_embedding_2d(embed_dim, grid_size, base_size=16): 19 | grid_h = torch.arange(grid_size[0]) / (grid_size[0] / base_size) 20 | grid_w = torch.arange(grid_size[1]) / (grid_size[1] / base_size) 21 | # In the original implementation of DiT, the h and w seem to be reversed. 
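    # Clarifying note: grid_w and grid_h are stacked below and each axis is embedded with a
    # 1D sin-cos embedding of size embed_dim // 2; the rearrange afterwards concatenates the
    # two per-axis halves for every position, giving the DiT-style 2D positional embedding.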
22 | grid = torch.stack([ 23 | repeat(grid_w, "W -> H W", H=grid_size[0], W=grid_size[1]), 24 | repeat(grid_h, "H -> H W", H=grid_size[0], W=grid_size[1]), 25 | ]) 26 | pos_embed = get_pos_embedding_1d(embed_dim // 2, grid) 27 | pos_embed = rearrange(pos_embed, "(C N) D -> N (C D)", C=2) 28 | return pos_embed 29 | 30 | 31 | class PatchEmbed(torch.nn.Module): 32 | def __init__( 33 | self, 34 | base_size=16, 35 | patch_size=16, 36 | in_channels=3, 37 | embed_dim=768 38 | ): 39 | super().__init__() 40 | self.base_size = base_size 41 | self.patch_size = patch_size 42 | self.embed_dim = embed_dim 43 | self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True) 44 | 45 | def forward(self, latent): 46 | pos_embed = get_pos_embedding_2d( 47 | embed_dim=self.embed_dim, 48 | grid_size=(latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size), 49 | base_size=self.base_size 50 | ) 51 | pos_embed = pos_embed.float().unsqueeze(0).to(latent.device) 52 | 53 | latent = self.proj(latent) 54 | latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC 55 | 56 | return (latent + pos_embed).to(latent.dtype) 57 | -------------------------------------------------------------------------------- /litesora/models/sdxl_text_encoder_2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .attention import Attention 3 | from .utils import load_state_dict 4 | 5 | 6 | class CLIPEncoderLayer(torch.nn.Module): 7 | def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): 8 | super().__init__() 9 | self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) 10 | self.layer_norm1 = torch.nn.LayerNorm(embed_dim) 11 | self.layer_norm2 = torch.nn.LayerNorm(embed_dim) 12 | self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) 13 | self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) 14 | 15 | self.use_quick_gelu = use_quick_gelu 16 | 17 | def quickGELU(self, x): 18 | return x * torch.sigmoid(1.702 * x) 19 | 20 | def forward(self, hidden_states, attn_mask=None): 21 | residual = hidden_states 22 | 23 | hidden_states = self.layer_norm1(hidden_states) 24 | hidden_states = self.attn(hidden_states, attn_mask=attn_mask) 25 | hidden_states = residual + hidden_states 26 | 27 | residual = hidden_states 28 | hidden_states = self.layer_norm2(hidden_states) 29 | hidden_states = self.fc1(hidden_states) 30 | if self.use_quick_gelu: 31 | hidden_states = self.quickGELU(hidden_states) 32 | else: 33 | hidden_states = torch.nn.functional.gelu(hidden_states) 34 | hidden_states = self.fc2(hidden_states) 35 | hidden_states = residual + hidden_states 36 | 37 | return hidden_states 38 | 39 | 40 | class SDXLTextEncoder2(torch.nn.Module): 41 | def __init__(self, embed_dim=1280, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=32, encoder_intermediate_size=5120): 42 | super().__init__() 43 | 44 | # token_embedding 45 | self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) 46 | 47 | # position_embeds (This is a fixed tensor) 48 | self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) 49 | 50 | # encoders 51 | self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=20, head_dim=64, use_quick_gelu=False) for _ in range(num_encoder_layers)]) 52 | 53 | # attn_mask 54 | self.attn_mask = 
self.attention_mask(max_position_embeddings) 55 | 56 | # final_layer_norm 57 | self.final_layer_norm = torch.nn.LayerNorm(embed_dim) 58 | 59 | # text_projection 60 | self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False) 61 | 62 | def attention_mask(self, length): 63 | mask = torch.empty(length, length) 64 | mask.fill_(float("-inf")) 65 | mask.triu_(1) 66 | return mask 67 | 68 | def forward(self, input_ids): 69 | embeds = self.token_embedding(input_ids) + self.position_embeds 70 | attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) 71 | for encoder_id, encoder in enumerate(self.encoders): 72 | embeds = encoder(embeds, attn_mask=attn_mask) 73 | embeds = self.final_layer_norm(embeds) 74 | pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)] 75 | pooled_embeds = self.text_projection(pooled_embeds) 76 | return pooled_embeds 77 | 78 | def load_state_dict_from_diffusers(self, file_path=None, state_dict=None): 79 | if state_dict is None: 80 | state_dict = load_state_dict(file_path) 81 | rename_dict = { 82 | "text_model.embeddings.token_embedding.weight": "token_embedding.weight", 83 | "text_model.embeddings.position_embedding.weight": "position_embeds", 84 | "text_model.final_layer_norm.weight": "final_layer_norm.weight", 85 | "text_model.final_layer_norm.bias": "final_layer_norm.bias", 86 | "text_projection.weight": "text_projection.weight" 87 | } 88 | attn_rename_dict = { 89 | "self_attn.q_proj": "attn.to_q", 90 | "self_attn.k_proj": "attn.to_k", 91 | "self_attn.v_proj": "attn.to_v", 92 | "self_attn.out_proj": "attn.to_out", 93 | "layer_norm1": "layer_norm1", 94 | "layer_norm2": "layer_norm2", 95 | "mlp.fc1": "fc1", 96 | "mlp.fc2": "fc2", 97 | } 98 | state_dict_ = {} 99 | for name in state_dict: 100 | if name in rename_dict: 101 | param = state_dict[name] 102 | if name == "text_model.embeddings.position_embedding.weight": 103 | param = param.reshape((1, param.shape[0], param.shape[1])) 104 | state_dict_[rename_dict[name]] = param 105 | elif name.startswith("text_model.encoder.layers."): 106 | param = state_dict[name] 107 | names = name.split(".") 108 | layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] 109 | name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) 110 | state_dict_[name_] = param 111 | self.load_state_dict(state_dict_) 112 | 113 | @staticmethod 114 | def from_diffusers(file_path=None, state_dict=None): 115 | model = SDXLTextEncoder2() 116 | model.eval() 117 | model.load_state_dict_from_diffusers(file_path, state_dict) 118 | return model 119 | 120 | -------------------------------------------------------------------------------- /litesora/models/svd_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .attention import Attention 3 | from .utils import load_state_dict 4 | from einops import rearrange, repeat 5 | 6 | 7 | class VAEAttentionBlock(torch.nn.Module): 8 | 9 | def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5): 10 | super().__init__() 11 | inner_dim = num_attention_heads * attention_head_dim 12 | 13 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True) 14 | 15 | self.transformer_blocks = torch.nn.ModuleList([ 16 | Attention( 17 | inner_dim, 18 | num_attention_heads, 19 | attention_head_dim, 20 | bias_q=True, 21 | bias_kv=True, 22 | bias_out=True 23 | ) 24 | for d in 
range(num_layers) 25 | ]) 26 | 27 | def forward(self, hidden_states): 28 | batch, _, height, width = hidden_states.shape 29 | residual = hidden_states 30 | 31 | hidden_states = self.norm(hidden_states) 32 | inner_dim = hidden_states.shape[1] 33 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 34 | 35 | for block in self.transformer_blocks: 36 | hidden_states = block(hidden_states) 37 | 38 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 39 | hidden_states = hidden_states + residual 40 | 41 | return hidden_states 42 | 43 | 44 | 45 | class ResnetBlock(torch.nn.Module): 46 | def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5): 47 | super().__init__() 48 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 49 | self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 50 | if temb_channels is not None: 51 | self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) 52 | self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) 53 | self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 54 | self.nonlinearity = torch.nn.SiLU() 55 | self.conv_shortcut = None 56 | if in_channels != out_channels: 57 | self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True) 58 | 59 | def forward(self, hidden_states): 60 | x = hidden_states 61 | x = self.norm1(x) 62 | x = self.nonlinearity(x) 63 | x = self.conv1(x) 64 | x = self.norm2(x) 65 | x = self.nonlinearity(x) 66 | x = self.conv2(x) 67 | if self.conv_shortcut is not None: 68 | hidden_states = self.conv_shortcut(hidden_states) 69 | hidden_states = hidden_states + x 70 | return hidden_states 71 | 72 | 73 | 74 | class DownSampler(torch.nn.Module): 75 | def __init__(self, channels, padding=1, extra_padding=False): 76 | super().__init__() 77 | self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding) 78 | self.extra_padding = extra_padding 79 | 80 | def forward(self, hidden_states): 81 | if self.extra_padding: 82 | hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0) 83 | hidden_states = self.conv(hidden_states) 84 | return hidden_states 85 | 86 | 87 | 88 | class UpSampler(torch.nn.Module): 89 | def __init__(self, channels): 90 | super().__init__() 91 | self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1) 92 | 93 | def forward(self, hidden_states): 94 | hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest") 95 | hidden_states = self.conv(hidden_states) 96 | return hidden_states 97 | 98 | 99 | 100 | class TemporalResnetBlock(torch.nn.Module): 101 | 102 | def __init__(self, in_channels, out_channels, groups=32, eps=1e-5): 103 | super().__init__() 104 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 105 | self.conv1 = torch.nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0)) 106 | self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) 107 | self.conv2 = torch.nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0)) 108 | self.nonlinearity = torch.nn.SiLU() 109 | self.mix_factor = torch.nn.Parameter(torch.Tensor([0.5])) 110 | 111 | def 
forward(self, hidden_states): 112 | x_spatial = hidden_states 113 | x = rearrange(hidden_states, "T C H W -> 1 C T H W") 114 | x = self.norm1(x) 115 | x = self.nonlinearity(x) 116 | x = self.conv1(x) 117 | x = self.norm2(x) 118 | x = self.nonlinearity(x) 119 | x = self.conv2(x) 120 | x_temporal = hidden_states + x[0].permute(1, 0, 2, 3) 121 | alpha = torch.sigmoid(self.mix_factor) 122 | hidden_states = alpha * x_temporal + (1 - alpha) * x_spatial 123 | return hidden_states 124 | 125 | 126 | 127 | class SDVAEEncoder(torch.nn.Module): 128 | def __init__(self): 129 | super().__init__() 130 | self.scaling_factor = 0.18215 131 | self.quant_conv = torch.nn.Conv2d(8, 8, kernel_size=1) 132 | self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1) 133 | 134 | self.blocks = torch.nn.ModuleList([ 135 | # DownEncoderBlock2D 136 | ResnetBlock(128, 128, eps=1e-6), 137 | ResnetBlock(128, 128, eps=1e-6), 138 | DownSampler(128, padding=0, extra_padding=True), 139 | # DownEncoderBlock2D 140 | ResnetBlock(128, 256, eps=1e-6), 141 | ResnetBlock(256, 256, eps=1e-6), 142 | DownSampler(256, padding=0, extra_padding=True), 143 | # DownEncoderBlock2D 144 | ResnetBlock(256, 512, eps=1e-6), 145 | ResnetBlock(512, 512, eps=1e-6), 146 | DownSampler(512, padding=0, extra_padding=True), 147 | # DownEncoderBlock2D 148 | ResnetBlock(512, 512, eps=1e-6), 149 | ResnetBlock(512, 512, eps=1e-6), 150 | # UNetMidBlock2D 151 | ResnetBlock(512, 512, eps=1e-6), 152 | VAEAttentionBlock(1, 512, 512, 1, eps=1e-6), 153 | ResnetBlock(512, 512, eps=1e-6), 154 | ]) 155 | 156 | self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6) 157 | self.conv_act = torch.nn.SiLU() 158 | self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1) 159 | 160 | 161 | def forward(self, sample): 162 | # 1. pre-process 163 | hidden_states = self.conv_in(sample) 164 | 165 | # 2. blocks 166 | for i, block in enumerate(self.blocks): 167 | hidden_states = block(hidden_states) 168 | 169 | # 3. 
output 170 | hidden_states = self.conv_norm_out(hidden_states) 171 | hidden_states = self.conv_act(hidden_states) 172 | hidden_states = self.conv_out(hidden_states) 173 | hidden_states = self.quant_conv(hidden_states) 174 | hidden_states = hidden_states[:, :4] 175 | hidden_states *= self.scaling_factor 176 | 177 | return hidden_states 178 | 179 | 180 | def encode_video(self, sample, batch_size=8, progress_bar=lambda x:x): 181 | data_device = sample.device 182 | computation_device = self.conv_in.weight.device 183 | hidden_states = [] 184 | sample = rearrange(sample, "C T H W -> T C H W") 185 | 186 | for i in progress_bar(range(0, sample.shape[0], batch_size)): 187 | hidden_states_batch = self.forward(sample[i: i+batch_size].to(computation_device)) 188 | hidden_states.append(hidden_states_batch.to(data_device)) 189 | 190 | hidden_states = torch.concat(hidden_states, dim=0) 191 | hidden_states = rearrange(hidden_states, "T C H W -> C T H W") 192 | return hidden_states 193 | 194 | 195 | def load_state_dict_from_diffusers(self, file_path=None, state_dict=None): 196 | if state_dict is None: 197 | state_dict = load_state_dict(file_path) 198 | 199 | # architecture 200 | block_types = [ 201 | 'ResnetBlock', 'ResnetBlock', 'DownSampler', 202 | 'ResnetBlock', 'ResnetBlock', 'DownSampler', 203 | 'ResnetBlock', 'ResnetBlock', 'DownSampler', 204 | 'ResnetBlock', 'ResnetBlock', 205 | 'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock' 206 | ] 207 | 208 | # Rename each parameter 209 | local_rename_dict = { 210 | "quant_conv": "quant_conv", 211 | "encoder.conv_in": "conv_in", 212 | "encoder.mid_block.attentions.0.group_norm": "blocks.12.norm", 213 | "encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q", 214 | "encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k", 215 | "encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v", 216 | "encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out", 217 | "encoder.mid_block.resnets.0.norm1": "blocks.11.norm1", 218 | "encoder.mid_block.resnets.0.conv1": "blocks.11.conv1", 219 | "encoder.mid_block.resnets.0.norm2": "blocks.11.norm2", 220 | "encoder.mid_block.resnets.0.conv2": "blocks.11.conv2", 221 | "encoder.mid_block.resnets.1.norm1": "blocks.13.norm1", 222 | "encoder.mid_block.resnets.1.conv1": "blocks.13.conv1", 223 | "encoder.mid_block.resnets.1.norm2": "blocks.13.norm2", 224 | "encoder.mid_block.resnets.1.conv2": "blocks.13.conv2", 225 | "encoder.conv_norm_out": "conv_norm_out", 226 | "encoder.conv_out": "conv_out", 227 | } 228 | name_list = sorted([name for name in state_dict]) 229 | rename_dict = {} 230 | block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1} 231 | last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""} 232 | for name in name_list: 233 | names = name.split(".") 234 | name_prefix = ".".join(names[:-1]) 235 | if name_prefix in local_rename_dict: 236 | rename_dict[name] = local_rename_dict[name_prefix] + "." 
+ names[-1] 237 | elif name.startswith("encoder.down_blocks"): 238 | block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]] 239 | block_type_with_id = ".".join(names[:5]) 240 | if block_type_with_id != last_block_type_with_id[block_type]: 241 | block_id[block_type] += 1 242 | last_block_type_with_id[block_type] = block_type_with_id 243 | while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type: 244 | block_id[block_type] += 1 245 | block_type_with_id = ".".join(names[:5]) 246 | names = ["blocks", str(block_id[block_type])] + names[5:] 247 | rename_dict[name] = ".".join(names) 248 | 249 | # Convert state_dict 250 | state_dict_ = {} 251 | for name, param in state_dict.items(): 252 | if name in rename_dict: 253 | state_dict_[rename_dict[name]] = param 254 | self.load_state_dict(state_dict_) 255 | 256 | 257 | @staticmethod 258 | def from_diffusers(file_path=None, state_dict=None): 259 | model = SDVAEEncoder() 260 | model.eval() 261 | model.load_state_dict_from_diffusers(file_path, state_dict) 262 | return model 263 | 264 | 265 | 266 | class SVDVAEDecoder(torch.nn.Module): 267 | def __init__(self): 268 | super().__init__() 269 | self.scaling_factor = 0.18215 270 | self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1) 271 | 272 | self.blocks = torch.nn.ModuleList([ 273 | # UNetMidBlock 274 | ResnetBlock(512, 512, eps=1e-6), 275 | TemporalResnetBlock(512, 512, eps=1e-6), 276 | VAEAttentionBlock(1, 512, 512, 1, eps=1e-6), 277 | ResnetBlock(512, 512, eps=1e-6), 278 | TemporalResnetBlock(512, 512, eps=1e-6), 279 | # UpDecoderBlock 280 | ResnetBlock(512, 512, eps=1e-6), 281 | TemporalResnetBlock(512, 512, eps=1e-6), 282 | ResnetBlock(512, 512, eps=1e-6), 283 | TemporalResnetBlock(512, 512, eps=1e-6), 284 | ResnetBlock(512, 512, eps=1e-6), 285 | TemporalResnetBlock(512, 512, eps=1e-6), 286 | UpSampler(512), 287 | # UpDecoderBlock 288 | ResnetBlock(512, 512, eps=1e-6), 289 | TemporalResnetBlock(512, 512, eps=1e-6), 290 | ResnetBlock(512, 512, eps=1e-6), 291 | TemporalResnetBlock(512, 512, eps=1e-6), 292 | ResnetBlock(512, 512, eps=1e-6), 293 | TemporalResnetBlock(512, 512, eps=1e-6), 294 | UpSampler(512), 295 | # UpDecoderBlock 296 | ResnetBlock(512, 256, eps=1e-6), 297 | TemporalResnetBlock(256, 256, eps=1e-6), 298 | ResnetBlock(256, 256, eps=1e-6), 299 | TemporalResnetBlock(256, 256, eps=1e-6), 300 | ResnetBlock(256, 256, eps=1e-6), 301 | TemporalResnetBlock(256, 256, eps=1e-6), 302 | UpSampler(256), 303 | # UpDecoderBlock 304 | ResnetBlock(256, 128, eps=1e-6), 305 | TemporalResnetBlock(128, 128, eps=1e-6), 306 | ResnetBlock(128, 128, eps=1e-6), 307 | TemporalResnetBlock(128, 128, eps=1e-6), 308 | ResnetBlock(128, 128, eps=1e-6), 309 | TemporalResnetBlock(128, 128, eps=1e-6), 310 | ]) 311 | 312 | self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5) 313 | self.conv_act = torch.nn.SiLU() 314 | self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1) 315 | self.time_conv_out = torch.nn.Conv3d(3, 3, kernel_size=(3, 1, 1), padding=(1, 0, 0)) 316 | 317 | 318 | def forward(self, sample): 319 | # 1. pre-process 320 | hidden_states = rearrange(sample, "C T H W -> T C H W") 321 | hidden_states = hidden_states / self.scaling_factor 322 | hidden_states = self.conv_in(hidden_states) 323 | 324 | # 2. blocks 325 | for i, block in enumerate(self.blocks): 326 | hidden_states = block(hidden_states) 327 | 328 | # 3. 
output 329 | hidden_states = self.conv_norm_out(hidden_states) 330 | hidden_states = self.conv_act(hidden_states) 331 | hidden_states = self.conv_out(hidden_states) 332 | hidden_states = rearrange(hidden_states, "T C H W -> C T H W") 333 | hidden_states = self.time_conv_out(hidden_states) 334 | 335 | return hidden_states 336 | 337 | 338 | def build_mask(self, data, is_bound): 339 | _, T, H, W = data.shape 340 | t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W) 341 | h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W) 342 | w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W) 343 | border_width = (T + H + W) // 6 344 | pad = torch.ones_like(t) * border_width 345 | mask = torch.stack([ 346 | pad if is_bound[0] else t + 1, 347 | pad if is_bound[1] else T - t, 348 | pad if is_bound[2] else h + 1, 349 | pad if is_bound[3] else H - h, 350 | pad if is_bound[4] else w + 1, 351 | pad if is_bound[5] else W - w 352 | ]).min(dim=0).values 353 | mask = mask.clip(1, border_width) 354 | mask = (mask / border_width).to(dtype=data.dtype, device=data.device) 355 | mask = rearrange(mask, "T H W -> 1 T H W") 356 | return mask 357 | 358 | 359 | def decode_video( 360 | self, sample, 361 | batch_time=32, batch_height=64, batch_width=64, 362 | stride_time=16, stride_height=32, stride_width=32, 363 | progress_bar=lambda x:x 364 | ): 365 | data_device = sample.device 366 | computation_device = self.conv_in.weight.device 367 | torch_dtype = sample.dtype 368 | _, T, H, W = sample.shape 369 | 370 | weight = torch.zeros((1, T, H*8, W*8), dtype=torch_dtype, device=data_device) 371 | values = torch.zeros((3, T, H*8, W*8), dtype=torch_dtype, device=data_device) 372 | 373 | # Split tasks 374 | tasks = [] 375 | for t in range(0, T, stride_time): 376 | for h in range(0, H, stride_height): 377 | for w in range(0, W, stride_width): 378 | if (t-stride_time >= 0 and t-stride_time+batch_time >= T)\ 379 | or (h-stride_height >= 0 and h-stride_height+batch_height >= H)\ 380 | or (w-stride_width >= 0 and w-stride_width+batch_width >= W): 381 | continue 382 | tasks.append((t, t+batch_time, h, h+batch_height, w, w+batch_width)) 383 | 384 | # Run 385 | for tl, tr, hl, hr, wl, wr in progress_bar(tasks): 386 | sample_batch = sample[:, tl:tr, hl:hr, wl:wr].to(computation_device) 387 | sample_batch = self.forward(sample_batch).to(data_device) 388 | mask = self.build_mask(sample_batch, is_bound=(tl==0, tr>=T, hl==0, hr>=H, wl==0, wr>=W)) 389 | values[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += sample_batch * mask 390 | weight[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += mask 391 | values /= weight 392 | return values 393 | 394 | 395 | def load_state_dict_from_diffusers(self, file_path=None, state_dict=None): 396 | if state_dict is None: 397 | state_dict = load_state_dict(file_path) 398 | 399 | static_rename_dict = { 400 | "decoder.conv_in": "conv_in", 401 | "decoder.mid_block.attentions.0.group_norm": "blocks.2.norm", 402 | "decoder.mid_block.attentions.0.to_q": "blocks.2.transformer_blocks.0.to_q", 403 | "decoder.mid_block.attentions.0.to_k": "blocks.2.transformer_blocks.0.to_k", 404 | "decoder.mid_block.attentions.0.to_v": "blocks.2.transformer_blocks.0.to_v", 405 | "decoder.mid_block.attentions.0.to_out.0": "blocks.2.transformer_blocks.0.to_out", 406 | "decoder.up_blocks.0.upsamplers.0.conv": "blocks.11.conv", 407 | "decoder.up_blocks.1.upsamplers.0.conv": "blocks.18.conv", 408 | "decoder.up_blocks.2.upsamplers.0.conv": "blocks.25.conv", 409 | "decoder.conv_norm_out": "conv_norm_out", 410 | "decoder.conv_out": "conv_out", 411 | 
"decoder.time_conv_out": "time_conv_out" 412 | } 413 | prefix_rename_dict = { 414 | "decoder.mid_block.resnets.0.spatial_res_block": "blocks.0", 415 | "decoder.mid_block.resnets.0.temporal_res_block": "blocks.1", 416 | "decoder.mid_block.resnets.0.time_mixer": "blocks.1", 417 | "decoder.mid_block.resnets.1.spatial_res_block": "blocks.3", 418 | "decoder.mid_block.resnets.1.temporal_res_block": "blocks.4", 419 | "decoder.mid_block.resnets.1.time_mixer": "blocks.4", 420 | 421 | "decoder.up_blocks.0.resnets.0.spatial_res_block": "blocks.5", 422 | "decoder.up_blocks.0.resnets.0.temporal_res_block": "blocks.6", 423 | "decoder.up_blocks.0.resnets.0.time_mixer": "blocks.6", 424 | "decoder.up_blocks.0.resnets.1.spatial_res_block": "blocks.7", 425 | "decoder.up_blocks.0.resnets.1.temporal_res_block": "blocks.8", 426 | "decoder.up_blocks.0.resnets.1.time_mixer": "blocks.8", 427 | "decoder.up_blocks.0.resnets.2.spatial_res_block": "blocks.9", 428 | "decoder.up_blocks.0.resnets.2.temporal_res_block": "blocks.10", 429 | "decoder.up_blocks.0.resnets.2.time_mixer": "blocks.10", 430 | 431 | "decoder.up_blocks.1.resnets.0.spatial_res_block": "blocks.12", 432 | "decoder.up_blocks.1.resnets.0.temporal_res_block": "blocks.13", 433 | "decoder.up_blocks.1.resnets.0.time_mixer": "blocks.13", 434 | "decoder.up_blocks.1.resnets.1.spatial_res_block": "blocks.14", 435 | "decoder.up_blocks.1.resnets.1.temporal_res_block": "blocks.15", 436 | "decoder.up_blocks.1.resnets.1.time_mixer": "blocks.15", 437 | "decoder.up_blocks.1.resnets.2.spatial_res_block": "blocks.16", 438 | "decoder.up_blocks.1.resnets.2.temporal_res_block": "blocks.17", 439 | "decoder.up_blocks.1.resnets.2.time_mixer": "blocks.17", 440 | 441 | "decoder.up_blocks.2.resnets.0.spatial_res_block": "blocks.19", 442 | "decoder.up_blocks.2.resnets.0.temporal_res_block": "blocks.20", 443 | "decoder.up_blocks.2.resnets.0.time_mixer": "blocks.20", 444 | "decoder.up_blocks.2.resnets.1.spatial_res_block": "blocks.21", 445 | "decoder.up_blocks.2.resnets.1.temporal_res_block": "blocks.22", 446 | "decoder.up_blocks.2.resnets.1.time_mixer": "blocks.22", 447 | "decoder.up_blocks.2.resnets.2.spatial_res_block": "blocks.23", 448 | "decoder.up_blocks.2.resnets.2.temporal_res_block": "blocks.24", 449 | "decoder.up_blocks.2.resnets.2.time_mixer": "blocks.24", 450 | 451 | "decoder.up_blocks.3.resnets.0.spatial_res_block": "blocks.26", 452 | "decoder.up_blocks.3.resnets.0.temporal_res_block": "blocks.27", 453 | "decoder.up_blocks.3.resnets.0.time_mixer": "blocks.27", 454 | "decoder.up_blocks.3.resnets.1.spatial_res_block": "blocks.28", 455 | "decoder.up_blocks.3.resnets.1.temporal_res_block": "blocks.29", 456 | "decoder.up_blocks.3.resnets.1.time_mixer": "blocks.29", 457 | "decoder.up_blocks.3.resnets.2.spatial_res_block": "blocks.30", 458 | "decoder.up_blocks.3.resnets.2.temporal_res_block": "blocks.31", 459 | "decoder.up_blocks.3.resnets.2.time_mixer": "blocks.31", 460 | } 461 | suffix_rename_dict = { 462 | "norm1.weight": "norm1.weight", 463 | "conv1.weight": "conv1.weight", 464 | "norm2.weight": "norm2.weight", 465 | "conv2.weight": "conv2.weight", 466 | "conv_shortcut.weight": "conv_shortcut.weight", 467 | "norm1.bias": "norm1.bias", 468 | "conv1.bias": "conv1.bias", 469 | "norm2.bias": "norm2.bias", 470 | "conv2.bias": "conv2.bias", 471 | "conv_shortcut.bias": "conv_shortcut.bias", 472 | "mix_factor": "mix_factor", 473 | } 474 | 475 | state_dict_ = {} 476 | for name in static_rename_dict: 477 | state_dict_[static_rename_dict[name] + ".weight"] = state_dict[name + 
".weight"] 478 | state_dict_[static_rename_dict[name] + ".bias"] = state_dict[name + ".bias"] 479 | for prefix_name in prefix_rename_dict: 480 | for suffix_name in suffix_rename_dict: 481 | name = prefix_name + "." + suffix_name 482 | name_ = prefix_rename_dict[prefix_name] + "." + suffix_rename_dict[suffix_name] 483 | if name in state_dict: 484 | state_dict_[name_] = state_dict[name] 485 | self.load_state_dict(state_dict_) 486 | 487 | @staticmethod 488 | def from_diffusers(file_path=None, state_dict=None): 489 | model = SVDVAEDecoder() 490 | model.eval() 491 | model.load_state_dict_from_diffusers(file_path, state_dict) 492 | return model 493 | -------------------------------------------------------------------------------- /litesora/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors import safe_open 3 | 4 | 5 | def load_state_dict(file_path, torch_dtype=None): 6 | if file_path.endswith(".safetensors"): 7 | state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype) 8 | else: 9 | state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype) 10 | return state_dict 11 | 12 | 13 | def load_state_dict_from_safetensors(file_path, torch_dtype=None): 14 | state_dict = {} 15 | with safe_open(file_path, framework="pt", device="cpu") as f: 16 | for k in f.keys(): 17 | state_dict[k] = f.get_tensor(k) 18 | if torch_dtype is not None: 19 | state_dict[k] = state_dict[k].to(torch_dtype) 20 | return state_dict 21 | 22 | 23 | def load_state_dict_from_bin(file_path, torch_dtype=None): 24 | state_dict = torch.load(file_path, map_location="cpu") 25 | if torch_dtype is not None: 26 | state_dict = {i: state_dict[i].to(torch_dtype) for i in state_dict} 27 | return state_dict 28 | -------------------------------------------------------------------------------- /litesora/models/video_dit.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | from .attention import Attention 3 | from .utils import load_state_dict 4 | from einops import rearrange, repeat 5 | from functools import reduce 6 | 7 | 8 | class Timesteps(torch.nn.Module): 9 | def __init__(self, num_channels): 10 | super().__init__() 11 | self.num_channels = num_channels 12 | 13 | def forward(self, timesteps): 14 | half_dim = self.num_channels // 2 15 | exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) / half_dim 16 | timesteps = timesteps.unsqueeze(-1) 17 | emb = timesteps.float() * torch.exp(exponent) 18 | emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1) 19 | return emb 20 | 21 | 22 | class ConditioningFusion(torch.nn.Module): 23 | def __init__(self, dims_in, dim_out): 24 | super().__init__() 25 | self.projs = torch.nn.ModuleList([torch.nn.Linear(dim_in, dim_out) for dim_in in dims_in]) 26 | 27 | def forward(self, conditionings): 28 | conditionings = [proj(conditioning) for conditioning, proj in zip(conditionings, self.projs)] 29 | conditionings = torch.stack(conditionings).sum(axis=0) 30 | return conditionings 31 | 32 | 33 | class AdaLayerNormZero(torch.nn.Module): 34 | def __init__(self, dim_time, dim_text, dim_out): 35 | super().__init__() 36 | self.fusion = ConditioningFusion([dim_time, dim_text], dim_out) 37 | self.linear = torch.nn.Linear(dim_out, 6 * dim_out) 38 | 39 | def forward(self, time_emb, text_emb): 40 | conditionings = self.fusion([time_emb, text_emb]) 41 | conditionings = 
self.linear(torch.nn.functional.silu(conditionings)).unsqueeze(1) 42 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = conditionings.chunk(6, dim=-1) 43 | return shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp 44 | 45 | 46 | class DiTFeedForward(torch.nn.Module): 47 | def __init__(self, dim): 48 | super().__init__() 49 | self.proj_in = torch.nn.Linear(dim, dim * 4) 50 | self.proj_out = torch.nn.Linear(dim * 4, dim) 51 | 52 | def forward(self, hidden_states): 53 | dtype = hidden_states.dtype 54 | hidden_states = self.proj_in(hidden_states) 55 | hidden_states = torch.nn.functional.gelu(hidden_states.to(dtype=torch.float32), approximate="tanh").to(dtype=dtype) 56 | hidden_states = self.proj_out(hidden_states) 57 | return hidden_states 58 | 59 | 60 | class DiTBlock(torch.nn.Module): 61 | def __init__(self, dim_out, dim_time, dim_text, dim_head): 62 | super().__init__() 63 | self.adaln = AdaLayerNormZero(dim_time, dim_text, dim_out) 64 | self.norm1 = torch.nn.LayerNorm(dim_out, eps=1e-5, elementwise_affine=False) 65 | self.attn1 = Attention(dim_out, dim_out // dim_head, dim_head, bias_q=True, bias_kv=True, bias_out=True) 66 | self.norm2 = torch.nn.LayerNorm(dim_out, 1e-5, elementwise_affine=False) 67 | self.ff = DiTFeedForward(dim_out) 68 | 69 | 70 | def forward(self, hidden_states, time_emb, text_emb): 71 | # 0. AdaLayerNormZero (Conditioning Fusion) 72 | beta_1, gamma_1, alpha_1, beta_2, gamma_2, alpha_2 = self.adaln(time_emb, text_emb) 73 | 74 | # 1. Layer Norm 75 | norm_hidden_states = self.norm1(hidden_states) 76 | 77 | # 2. Scale, Shift 78 | norm_hidden_states = norm_hidden_states * (1 + gamma_1) + beta_1 79 | 80 | # 3. Multi-Head Self-Attention 81 | attn_output = self.attn1(norm_hidden_states) 82 | 83 | # 4. Scale & Add 84 | hidden_states = alpha_1 * attn_output + hidden_states 85 | 86 | # 5. Layer Norm 87 | norm_hidden_states = self.norm2(hidden_states) 88 | 89 | # 6. Scale & Shift 90 | norm_hidden_states = norm_hidden_states * (1 + gamma_2) + beta_2 91 | 92 | # 7. Pointwise Feedforward 93 | ff_output = self.ff(norm_hidden_states) 94 | 95 | # 8. 
Scale & Add 96 |         hidden_states = alpha_2 * ff_output + hidden_states 97 | 98 |         return hidden_states 99 | 100 | 101 | class VideoPatchEmbed(torch.nn.Module): 102 |     def __init__(self, base_size=(16, 16, 16), patch_size=(16, 16, 16), in_channels=3, embed_dim=512): 103 |         super().__init__() 104 |         self.base_size = base_size 105 |         self.patch_size = patch_size 106 |         self.embed_dim = embed_dim 107 |         self.proj_pos = torch.nn.Linear(embed_dim*3, embed_dim) 108 |         self.proj_latent = torch.nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True) 109 | 110 |     def get_pos_embedding_1d(self, embed_dim, pos): 111 |         omega = torch.arange(embed_dim // 2).to(torch.float64) * (2.0 / embed_dim) 112 |         omega = 1.0 / 10000**omega 113 | 114 |         pos = pos.reshape(-1) 115 |         out = torch.einsum("m,d->md", pos, omega) 116 |         emb_sin = torch.sin(out) 117 |         emb_cos = torch.cos(out) 118 | 119 |         emb = torch.concatenate([emb_sin, emb_cos], axis=1) 120 |         return emb 121 | 122 |     def get_pos_embedding_3d(self, embed_dim, grid_size, base_size): 123 |         grid_t = torch.arange(grid_size[0]) / (grid_size[0] / base_size[0]) 124 |         grid_h = torch.arange(grid_size[1]) / (grid_size[1] / base_size[1]) 125 |         grid_w = torch.arange(grid_size[2]) / (grid_size[2] / base_size[2]) 126 |         grid = torch.stack([ 127 |             repeat(grid_t, "T -> T H W", T=grid_size[0], H=grid_size[1], W=grid_size[2]), 128 |             repeat(grid_h, "H -> T H W", T=grid_size[0], H=grid_size[1], W=grid_size[2]), 129 |             repeat(grid_w, "W -> T H W", T=grid_size[0], H=grid_size[1], W=grid_size[2]), 130 |         ]) 131 |         pos_embed = self.get_pos_embedding_1d(embed_dim, grid) 132 |         pos_embed = rearrange(pos_embed, "(C N) D -> N (C D)", C=3) 133 |         return pos_embed 134 | 135 |     def forward(self, latent): 136 |         pos_embed = self.get_pos_embedding_3d( 137 |             self.embed_dim, 138 |             (latent.shape[-3] // self.patch_size[0], latent.shape[-2] // self.patch_size[1], latent.shape[-1] // self.patch_size[2]), 139 |             self.base_size 140 |         ) 141 |         pos_embed = pos_embed.unsqueeze(0).to(dtype=latent.dtype, device=latent.device) 142 |         pos_embed = self.proj_pos(pos_embed) 143 | 144 |         latent = self.proj_latent(latent) 145 |         latent = rearrange(latent, "B C T H W -> B (T H W) C") 146 | 147 |         return (latent + pos_embed).to(latent.dtype) 148 | 149 | 150 | class TimeEmbed(torch.nn.Module): 151 |     def __init__(self, dim_time): 152 |         super().__init__() 153 |         self.time_proj = Timesteps(dim_time) 154 |         self.time_embedding = torch.nn.Sequential( 155 |             torch.nn.Linear(dim_time, dim_time), 156 |             torch.nn.SiLU(), 157 |             torch.nn.Linear(dim_time, dim_time) 158 |         ) 159 | 160 |     def forward(self, timesteps, dtype=torch.float32): 161 |         time_emb = self.time_proj(timesteps).to(dtype=dtype) 162 |         time_emb = self.time_embedding(time_emb) 163 |         return time_emb 164 | 165 | 166 | class VideoDiT(torch.nn.Module): 167 |     def __init__(self, dim_hidden=1024, dim_time=1024, dim_text=1280, dim_head=64, num_blocks=16, patch_size=(4, 4, 4), in_channels=3): 168 |         super().__init__() 169 |         self.time_emb = TimeEmbed(dim_time) 170 |         self.patchify = VideoPatchEmbed((16, 16, 16), patch_size, in_channels, dim_hidden) 171 |         self.blocks = torch.nn.ModuleList([DiTBlock(dim_hidden, dim_time, dim_text, dim_head) for _ in range(num_blocks)]) 172 |         self.norm_out = torch.nn.LayerNorm(dim_hidden, eps=1e-5, elementwise_affine=False) 173 |         self.proj_out = torch.nn.Linear(dim_hidden, reduce(lambda x,y: x*y, patch_size) * in_channels, bias=True) 174 | 175 |     def forward(self, hidden_states, timesteps, text_emb): 176 |         # Shape 177 |         B, C, T, H, W = 
hidden_states.shape 178 | 179 | # Time Embedding 180 | time_emb = self.time_emb(timesteps, dtype=hidden_states.dtype) 181 | 182 | # Patchify 183 | hidden_states = self.patchify(hidden_states) 184 | 185 | # DiT Blocks 186 | for block in self.blocks: 187 | hidden_states = block(hidden_states, time_emb, text_emb) 188 | 189 | # The following computation is different from the original version of DiT 190 | # We make it simple. 191 | hidden_states = self.norm_out(hidden_states) 192 | hidden_states = self.proj_out(hidden_states) 193 | hidden_states = rearrange( 194 | hidden_states, 195 | "B (T H W) (PT PH PW C) -> B C (T PT) (H PH) (W PW)", 196 | T=T//self.patchify.patch_size[0], H=H//self.patchify.patch_size[1], W=W//self.patchify.patch_size[2], 197 | PT=self.patchify.patch_size[0], PH=self.patchify.patch_size[1], PW=self.patchify.patch_size[2] 198 | ) 199 | 200 | return hidden_states 201 | 202 | @staticmethod 203 | def from_pretrained(file_path): 204 | state_dict = load_state_dict(file_path) 205 | if "state_dict" in state_dict: 206 | state_dict = state_dict["state_dict"] 207 | state_dict = {i[len("denoising_model."):]: state_dict[i] for i in state_dict if i.startswith("denoising_model.")} 208 | 209 | model = VideoDiT() 210 | model.eval() 211 | model.load_state_dict(state_dict) 212 | return model 213 | -------------------------------------------------------------------------------- /litesora/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .pixel_video_dit import PixelVideoDiTPipeline 2 | -------------------------------------------------------------------------------- /litesora/pipelines/pixel_video_dit.py: -------------------------------------------------------------------------------- 1 | from ..models.video_dit import VideoDiT 2 | from ..models.sdxl_text_encoder_2 import SDXLTextEncoder2 3 | from ..schedulers.ddim import DDIMScheduler 4 | from transformers import CLIPTokenizer 5 | import torch 6 | from tqdm import tqdm 7 | from PIL import Image 8 | import numpy as np 9 | from einops import rearrange 10 | 11 | 12 | class PixelVideoDiTPipeline(torch.nn.Module): 13 | 14 | def __init__(self, device="cuda", torch_dtype=torch.float16): 15 | super().__init__() 16 | self.tokenizer = CLIPTokenizer.from_pretrained("configs/stable_diffusion_xl/tokenizer_2") 17 | self.scheduler = DDIMScheduler() 18 | self.device = device 19 | self.torch_dtype = torch_dtype 20 | self.text_encoder: SDXLTextEncoder2 = None 21 | self.denoising_model: VideoDiT = None 22 | 23 | 24 | def fetch_models(self, text_encoder, denoising_model): 25 | self.text_encoder = text_encoder.to(dtype=self.torch_dtype, device=self.device) 26 | self.denoising_model = denoising_model.to(dtype=self.torch_dtype, device=self.device) 27 | 28 | 29 | def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): 30 | image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] 31 | image = image.cpu().permute(1, 2, 0).numpy() 32 | image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) 33 | return image 34 | 35 | 36 | def decode_video(self, frames): 37 | frames = rearrange(frames[0], "C T H W -> T H W C") 38 | frames = ((frames + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8) 39 | return frames 40 | 41 | 42 | def tokenize(self, prompt): 43 | input_ids = self.tokenizer( 44 | prompt, 45 | return_tensors="pt", 46 | padding="max_length", 47 | max_length=77, 48 | truncation=True 49 | ).input_ids 50 | 
return input_ids 51 | 52 | 53 | def encode_prompt(self, prompt): 54 | input_ids = self.tokenize(prompt).to(self.device) 55 | text_emb = self.text_encoder(input_ids) 56 | return text_emb 57 | 58 | 59 | @torch.no_grad() 60 | def __call__( 61 | self, 62 | prompt="", 63 | negative_prompt="", 64 | cfg_scale=1.0, 65 | use_cfg=True, 66 | denoising_strength=1.0, 67 | num_frames=64, 68 | height=64, 69 | width=64, 70 | num_inference_steps=20, 71 | progress_bar_cmd=tqdm 72 | ): 73 | # Prepare scheduler 74 | self.scheduler.set_timesteps(num_inference_steps, denoising_strength) 75 | 76 | # Prepare latent tensors 77 | latents = torch.randn((1, 3, num_frames, height, width), device=self.device, dtype=self.torch_dtype) 78 | 79 | # TODO: Encode prompts 80 | prompt_emb_posi = self.encode_prompt(prompt) 81 | prompt_emb_nega = self.encode_prompt(negative_prompt) 82 | 83 | # Denoise 84 | for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): 85 | timestep = torch.IntTensor((timestep,)).to(self.device) 86 | 87 | # Classifier-free guidance 88 | if use_cfg: 89 | noise_pred_posi = self.denoising_model(latents, timestep, prompt_emb_posi) 90 | noise_pred_nega = self.denoising_model(latents, timestep, prompt_emb_nega) 91 | noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) 92 | else: 93 | noise_pred = self.denoising_model(latents, timestep, prompt_emb_posi) 94 | 95 | # Call scheduler 96 | latents = self.scheduler.step(noise_pred, timestep, latents) 97 | 98 | # Decode video 99 | video = self.decode_video(latents) 100 | 101 | return video 102 | -------------------------------------------------------------------------------- /litesora/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .ddim import DDIMScheduler 2 | -------------------------------------------------------------------------------- /litesora/schedulers/ddim.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | 3 | 4 | class DDIMScheduler(): 5 | 6 | def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"): 7 | self.num_train_timesteps = num_train_timesteps 8 | if beta_schedule == "scaled_linear": 9 | betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32)) 10 | elif beta_schedule == "linear": 11 | betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 12 | else: 13 | raise NotImplementedError(f"{beta_schedule} is not implemented") 14 | self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist() 15 | self.set_timesteps(10) 16 | 17 | 18 | def set_timesteps(self, num_inference_steps, denoising_strength=1.0): 19 | # The timesteps are aligned to 999...0, which is different from other implementations, 20 | # but I think this implementation is more reasonable in theory. 
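        # For example, with the defaults (num_train_timesteps=1000, denoising_strength=1.0) and
        # num_inference_steps=5: max_timestep = 999, step_length = 999 / 4 = 249.75, and the
        # schedule becomes [999, 749, 500, 250, 0], i.e. it always starts at the last training
        # timestep and ends exactly at 0.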
21 | max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0) 22 | num_inference_steps = min(num_inference_steps, max_timestep + 1) 23 | if num_inference_steps == 1: 24 | self.timesteps = [max_timestep] 25 | else: 26 | step_length = max_timestep / (num_inference_steps - 1) 27 | self.timesteps = [round(max_timestep - i*step_length) for i in range(num_inference_steps)] 28 | 29 | 30 | def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev): 31 | weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t) 32 | weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t) 33 | 34 | prev_sample = sample * weight_x + model_output * weight_e 35 | 36 | return prev_sample 37 | 38 | 39 | def step(self, model_output, timestep, sample, to_final=False): 40 | alpha_prod_t = self.alphas_cumprod[timestep] 41 | timestep_id = self.timesteps.index(timestep) 42 | if to_final or timestep_id + 1 >= len(self.timesteps): 43 | alpha_prod_t_prev = 1.0 44 | else: 45 | timestep_prev = self.timesteps[timestep_id + 1] 46 | alpha_prod_t_prev = self.alphas_cumprod[timestep_prev] 47 | 48 | return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev) 49 | 50 | 51 | def return_to_timestep(self, timestep, sample, sample_stablized): 52 | alpha_prod_t = self.alphas_cumprod[timestep] 53 | noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t) 54 | return noise_pred 55 | 56 | 57 | def add_noise(self, original_samples, noise, timestep): 58 | sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep]) 59 | sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep]) 60 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 61 | return noisy_samples 62 | 63 | -------------------------------------------------------------------------------- /litesora/trainers/v1.py: -------------------------------------------------------------------------------- 1 | import lightning as pl 2 | import torch 3 | from transformers import CLIPTokenizer 4 | from ..models import SDXLTextEncoder2, VideoDiT 5 | from ..schedulers import DDIMScheduler 6 | 7 | 8 | class LightningVideoDiT(pl.LightningModule): 9 | def __init__(self, learning_rate=1e-5): 10 | super().__init__() 11 | self.tokenizer = CLIPTokenizer.from_pretrained("configs/stable_diffusion_xl/tokenizer_2") 12 | self.text_encoder = SDXLTextEncoder2() 13 | self.denoising_model = VideoDiT() 14 | self.noise_scheduler = DDIMScheduler() 15 | self.learning_rate = learning_rate 16 | self.text_encoder.requires_grad_(False) 17 | 18 | def tokenize(self, prompt): 19 | input_ids = self.tokenizer( 20 | prompt, 21 | return_tensors="pt", 22 | padding="max_length", 23 | max_length=77, 24 | truncation=True 25 | ).input_ids 26 | return input_ids 27 | 28 | def training_step(self, batch, batch_idx): 29 | hidden_states, text = batch["frames"], batch["text"] 30 | 31 | with torch.no_grad(): 32 | input_ids = self.tokenize(text[0]).to(self.device) 33 | text_emb = self.text_encoder(input_ids) 34 | 35 | timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (1,), device="cuda") 36 | noise = torch.randn_like(hidden_states) 37 | noisy_latents = self.noise_scheduler.add_noise(hidden_states, noise, timesteps) 38 | 39 | model_pred = self.denoising_model(noisy_latents, timesteps, text_emb) 40 | loss = torch.nn.functional.mse_loss(model_pred, noise, reduction="mean") 41 | 42 | self.log("train_loss", loss) 43 | return loss 44 | 45 | def 
configure_optimizers(self): 46 | optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) 47 | return optimizer -------------------------------------------------------------------------------- /models/denoising_model/Put denoising model checkpoints here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelscope/lite-sora/75caae14a22f4a127cc36c13d8171376450d7887/models/denoising_model/Put denoising model checkpoints here.txt -------------------------------------------------------------------------------- /models/text_encoder/Put text encoder checkpoints here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelscope/lite-sora/75caae14a22f4a127cc36c13d8171376450d7887/models/text_encoder/Put text encoder checkpoints here.txt -------------------------------------------------------------------------------- /models/vae/Put VAE checkpoints here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/modelscope/lite-sora/75caae14a22f4a127cc36c13d8171376450d7887/models/vae/Put VAE checkpoints here.txt --------------------------------------------------------------------------------
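For reference, the pieces above compose into a minimal text-to-video inference sketch. It assumes a CUDA device, that the script is run from the repository root (so the tokenizer config under configs/ resolves), and that checkpoints have been placed in the models/ folders; the checkpoint file names below are placeholders, not files shipped with this repository.

import torch
from litesora.models import SDXLTextEncoder2, VideoDiT
from litesora.pipelines import PixelVideoDiTPipeline

# Placeholder checkpoint paths: point these at the files you put under models/.
text_encoder = SDXLTextEncoder2.from_diffusers("models/text_encoder/model.safetensors")
denoising_model = VideoDiT.from_pretrained("models/denoising_model/model.ckpt")

pipe = PixelVideoDiTPipeline(device="cuda", torch_dtype=torch.float16)
pipe.fetch_models(text_encoder, denoising_model)

# The pipeline denoises directly in pixel space and returns a uint8 array
# of shape (num_frames, height, width, 3).
frames = pipe(prompt="a woman walking through a field of flowers", num_frames=64, height=64, width=64, num_inference_steps=100)

The returned frames can then be saved individually or assembled into a video with any imaging library.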